diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/naga/src | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
95 files changed, 76114 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/arena.rs b/third_party/rust/naga/src/arena.rs new file mode 100644 index 0000000000..c37538667f --- /dev/null +++ b/third_party/rust/naga/src/arena.rs @@ -0,0 +1,772 @@ +use std::{cmp::Ordering, fmt, hash, marker::PhantomData, num::NonZeroU32, ops}; + +/// An unique index in the arena array that a handle points to. +/// The "non-zero" part ensures that an `Option<Handle<T>>` has +/// the same size and representation as `Handle<T>`. +type Index = NonZeroU32; + +use crate::{FastIndexSet, Span}; + +#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)] +#[error("Handle {index} of {kind} is either not present, or inaccessible yet")] +pub struct BadHandle { + pub kind: &'static str, + pub index: usize, +} + +impl BadHandle { + fn new<T>(handle: Handle<T>) -> Self { + Self { + kind: std::any::type_name::<T>(), + index: handle.index(), + } + } +} + +/// A strongly typed reference to an arena item. +/// +/// A `Handle` value can be used as an index into an [`Arena`] or [`UniqueArena`]. +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr( + any(feature = "serialize", feature = "deserialize"), + serde(transparent) +)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct Handle<T> { + index: Index, + #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))] + marker: PhantomData<T>, +} + +impl<T> Clone for Handle<T> { + fn clone(&self) -> Self { + *self + } +} + +impl<T> Copy for Handle<T> {} + +impl<T> PartialEq for Handle<T> { + fn eq(&self, other: &Self) -> bool { + self.index == other.index + } +} + +impl<T> Eq for Handle<T> {} + +impl<T> PartialOrd for Handle<T> { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } +} + +impl<T> Ord for Handle<T> { + fn cmp(&self, other: &Self) -> Ordering { + self.index.cmp(&other.index) + } +} + +impl<T> fmt::Debug for Handle<T> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "[{}]", self.index) + } +} + +impl<T> hash::Hash for Handle<T> { + fn hash<H: hash::Hasher>(&self, hasher: &mut H) { + self.index.hash(hasher) + } +} + +impl<T> Handle<T> { + #[cfg(test)] + pub const DUMMY: Self = Handle { + index: unsafe { NonZeroU32::new_unchecked(u32::MAX) }, + marker: PhantomData, + }; + + pub(crate) const fn new(index: Index) -> Self { + Handle { + index, + marker: PhantomData, + } + } + + /// Returns the zero-based index of this handle. + pub const fn index(self) -> usize { + let index = self.index.get() - 1; + index as usize + } + + /// Convert a `usize` index into a `Handle<T>`. + fn from_usize(index: usize) -> Self { + let handle_index = u32::try_from(index + 1) + .ok() + .and_then(Index::new) + .expect("Failed to insert into arena. Handle overflows"); + Handle::new(handle_index) + } + + /// Convert a `usize` index into a `Handle<T>`, without range checks. + const unsafe fn from_usize_unchecked(index: usize) -> Self { + Handle::new(Index::new_unchecked((index + 1) as u32)) + } +} + +/// A strongly typed range of handles. +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr( + any(feature = "serialize", feature = "deserialize"), + serde(transparent) +)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct Range<T> { + inner: ops::Range<u32>, + #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))] + marker: PhantomData<T>, +} + +impl<T> Range<T> { + pub(crate) const fn erase_type(self) -> Range<()> { + let Self { inner, marker: _ } = self; + Range { + inner, + marker: PhantomData, + } + } +} + +// NOTE: Keep this diagnostic in sync with that of [`BadHandle`]. +#[derive(Clone, Debug, thiserror::Error)] +#[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")] +pub struct BadRangeError { + // This error is used for many `Handle` types, but there's no point in making this generic, so + // we just flatten them all to `Handle<()>` here. + kind: &'static str, + range: Range<()>, +} + +impl BadRangeError { + pub fn new<T>(range: Range<T>) -> Self { + Self { + kind: std::any::type_name::<T>(), + range: range.erase_type(), + } + } +} + +impl<T> Clone for Range<T> { + fn clone(&self) -> Self { + Range { + inner: self.inner.clone(), + marker: self.marker, + } + } +} + +impl<T> fmt::Debug for Range<T> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "[{}..{}]", self.inner.start + 1, self.inner.end) + } +} + +impl<T> Iterator for Range<T> { + type Item = Handle<T>; + fn next(&mut self) -> Option<Self::Item> { + if self.inner.start < self.inner.end { + self.inner.start += 1; + Some(Handle { + index: NonZeroU32::new(self.inner.start).unwrap(), + marker: self.marker, + }) + } else { + None + } + } +} + +impl<T> Range<T> { + /// Return a range enclosing handles `first` through `last`, inclusive. + pub fn new_from_bounds(first: Handle<T>, last: Handle<T>) -> Self { + Self { + inner: (first.index() as u32)..(last.index() as u32 + 1), + marker: Default::default(), + } + } + + /// return the first and last handles included in `self`. + /// + /// If `self` is an empty range, there are no handles included, so + /// return `None`. + pub fn first_and_last(&self) -> Option<(Handle<T>, Handle<T>)> { + if self.inner.start < self.inner.end { + Some(( + // `Range::new_from_bounds` expects a 1-based, start- and + // end-inclusive range, but `self.inner` is a zero-based, + // end-exclusive range. + Handle::new(Index::new(self.inner.start + 1).unwrap()), + Handle::new(Index::new(self.inner.end).unwrap()), + )) + } else { + None + } + } + + /// Return the zero-based index range covered by `self`. + pub fn zero_based_index_range(&self) -> ops::Range<u32> { + self.inner.clone() + } + + /// Construct a `Range` that covers the zero-based indices in `inner`. + pub fn from_zero_based_index_range(inner: ops::Range<u32>, arena: &Arena<T>) -> Self { + // Since `inner` is a `Range<u32>`, we only need to check that + // the start and end are well-ordered, and that the end fits + // within `arena`. + assert!(inner.start <= inner.end); + assert!(inner.end as usize <= arena.len()); + Self { + inner, + marker: Default::default(), + } + } +} + +/// An arena holding some kind of component (e.g., type, constant, +/// instruction, etc.) that can be referenced. +/// +/// Adding new items to the arena produces a strongly-typed [`Handle`]. +/// The arena can be indexed using the given handle to obtain +/// a reference to the stored item. +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "serialize", serde(transparent))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] +pub struct Arena<T> { + /// Values of this arena. + data: Vec<T>, + #[cfg_attr(feature = "serialize", serde(skip))] + span_info: Vec<Span>, +} + +impl<T> Default for Arena<T> { + fn default() -> Self { + Self::new() + } +} + +impl<T: fmt::Debug> fmt::Debug for Arena<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_map().entries(self.iter()).finish() + } +} + +impl<T> Arena<T> { + /// Create a new arena with no initial capacity allocated. + pub const fn new() -> Self { + Arena { + data: Vec::new(), + span_info: Vec::new(), + } + } + + /// Extracts the inner vector. + #[allow(clippy::missing_const_for_fn)] // ignore due to requirement of #![feature(const_precise_live_drops)] + pub fn into_inner(self) -> Vec<T> { + self.data + } + + /// Returns the current number of items stored in this arena. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns `true` if the arena contains no elements. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Returns an iterator over the items stored in this arena, returning both + /// the item's handle and a reference to it. + pub fn iter(&self) -> impl DoubleEndedIterator<Item = (Handle<T>, &T)> { + self.data + .iter() + .enumerate() + .map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) }) + } + + /// Returns a iterator over the items stored in this arena, + /// returning both the item's handle and a mutable reference to it. + pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T)> { + self.data + .iter_mut() + .enumerate() + .map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) }) + } + + /// Adds a new value to the arena, returning a typed handle. + pub fn append(&mut self, value: T, span: Span) -> Handle<T> { + let index = self.data.len(); + self.data.push(value); + self.span_info.push(span); + Handle::from_usize(index) + } + + /// Fetch a handle to an existing type. + pub fn fetch_if<F: Fn(&T) -> bool>(&self, fun: F) -> Option<Handle<T>> { + self.data + .iter() + .position(fun) + .map(|index| unsafe { Handle::from_usize_unchecked(index) }) + } + + /// Adds a value with a custom check for uniqueness: + /// returns a handle pointing to + /// an existing element if the check succeeds, or adds a new + /// element otherwise. + pub fn fetch_if_or_append<F: Fn(&T, &T) -> bool>( + &mut self, + value: T, + span: Span, + fun: F, + ) -> Handle<T> { + if let Some(index) = self.data.iter().position(|d| fun(d, &value)) { + unsafe { Handle::from_usize_unchecked(index) } + } else { + self.append(value, span) + } + } + + /// Adds a value with a check for uniqueness, where the check is plain comparison. + pub fn fetch_or_append(&mut self, value: T, span: Span) -> Handle<T> + where + T: PartialEq, + { + self.fetch_if_or_append(value, span, T::eq) + } + + pub fn try_get(&self, handle: Handle<T>) -> Result<&T, BadHandle> { + self.data + .get(handle.index()) + .ok_or_else(|| BadHandle::new(handle)) + } + + /// Get a mutable reference to an element in the arena. + pub fn get_mut(&mut self, handle: Handle<T>) -> &mut T { + self.data.get_mut(handle.index()).unwrap() + } + + /// Get the range of handles from a particular number of elements to the end. + pub fn range_from(&self, old_length: usize) -> Range<T> { + Range { + inner: old_length as u32..self.data.len() as u32, + marker: PhantomData, + } + } + + /// Clears the arena keeping all allocations + pub fn clear(&mut self) { + self.data.clear() + } + + pub fn get_span(&self, handle: Handle<T>) -> Span { + *self + .span_info + .get(handle.index()) + .unwrap_or(&Span::default()) + } + + /// Assert that `handle` is valid for this arena. + pub fn check_contains_handle(&self, handle: Handle<T>) -> Result<(), BadHandle> { + if handle.index() < self.data.len() { + Ok(()) + } else { + Err(BadHandle::new(handle)) + } + } + + /// Assert that `range` is valid for this arena. + pub fn check_contains_range(&self, range: &Range<T>) -> Result<(), BadRangeError> { + // Since `range.inner` is a `Range<u32>`, we only need to check that the + // start precedes the end, and that the end is in range. + if range.inner.start > range.inner.end { + return Err(BadRangeError::new(range.clone())); + } + + // Empty ranges are tolerated: they can be produced by compaction. + if range.inner.start == range.inner.end { + return Ok(()); + } + + // `range.inner` is zero-based, but end-exclusive, so `range.inner.end` + // is actually the right one-based index for the last handle within the + // range. + let last_handle = Handle::new(range.inner.end.try_into().unwrap()); + if self.check_contains_handle(last_handle).is_err() { + return Err(BadRangeError::new(range.clone())); + } + + Ok(()) + } + + #[cfg(feature = "compact")] + pub(crate) fn retain_mut<P>(&mut self, mut predicate: P) + where + P: FnMut(Handle<T>, &mut T) -> bool, + { + let mut index = 0; + let mut retained = 0; + self.data.retain_mut(|elt| { + let handle = Handle::new(Index::new(index as u32 + 1).unwrap()); + let keep = predicate(handle, elt); + + // Since `predicate` needs mutable access to each element, + // we can't feasibly call it twice, so we have to compact + // spans by hand in parallel as part of this iteration. + if keep { + self.span_info[retained] = self.span_info[index]; + retained += 1; + } + + index += 1; + keep + }); + + self.span_info.truncate(retained); + } +} + +#[cfg(feature = "deserialize")] +impl<'de, T> serde::Deserialize<'de> for Arena<T> +where + T: serde::Deserialize<'de>, +{ + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let data = Vec::deserialize(deserializer)?; + let span_info = std::iter::repeat(Span::default()) + .take(data.len()) + .collect(); + + Ok(Self { data, span_info }) + } +} + +impl<T> ops::Index<Handle<T>> for Arena<T> { + type Output = T; + fn index(&self, handle: Handle<T>) -> &T { + &self.data[handle.index()] + } +} + +impl<T> ops::IndexMut<Handle<T>> for Arena<T> { + fn index_mut(&mut self, handle: Handle<T>) -> &mut T { + &mut self.data[handle.index()] + } +} + +impl<T> ops::Index<Range<T>> for Arena<T> { + type Output = [T]; + fn index(&self, range: Range<T>) -> &[T] { + &self.data[range.inner.start as usize..range.inner.end as usize] + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn append_non_unique() { + let mut arena: Arena<u8> = Arena::new(); + let t1 = arena.append(0, Default::default()); + let t2 = arena.append(0, Default::default()); + assert!(t1 != t2); + assert!(arena[t1] == arena[t2]); + } + + #[test] + fn append_unique() { + let mut arena: Arena<u8> = Arena::new(); + let t1 = arena.append(0, Default::default()); + let t2 = arena.append(1, Default::default()); + assert!(t1 != t2); + assert!(arena[t1] != arena[t2]); + } + + #[test] + fn fetch_or_append_non_unique() { + let mut arena: Arena<u8> = Arena::new(); + let t1 = arena.fetch_or_append(0, Default::default()); + let t2 = arena.fetch_or_append(0, Default::default()); + assert!(t1 == t2); + assert!(arena[t1] == arena[t2]) + } + + #[test] + fn fetch_or_append_unique() { + let mut arena: Arena<u8> = Arena::new(); + let t1 = arena.fetch_or_append(0, Default::default()); + let t2 = arena.fetch_or_append(1, Default::default()); + assert!(t1 != t2); + assert!(arena[t1] != arena[t2]); + } +} + +/// An arena whose elements are guaranteed to be unique. +/// +/// A `UniqueArena` holds a set of unique values of type `T`, each with an +/// associated [`Span`]. Inserting a value returns a `Handle<T>`, which can be +/// used to index the `UniqueArena` and obtain shared access to the `T` element. +/// Access via a `Handle` is an array lookup - no hash lookup is necessary. +/// +/// The element type must implement `Eq` and `Hash`. Insertions of equivalent +/// elements, according to `Eq`, all return the same `Handle`. +/// +/// Once inserted, elements may not be mutated. +/// +/// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like, +/// `UniqueArena` is `HashSet`-like. +#[cfg_attr(feature = "clone", derive(Clone))] +pub struct UniqueArena<T> { + set: FastIndexSet<T>, + + /// Spans for the elements, indexed by handle. + /// + /// The length of this vector is always equal to `set.len()`. `FastIndexSet` + /// promises that its elements "are indexed in a compact range, without + /// holes in the range 0..set.len()", so we can always use the indices + /// returned by insertion as indices into this vector. + span_info: Vec<Span>, +} + +impl<T> UniqueArena<T> { + /// Create a new arena with no initial capacity allocated. + pub fn new() -> Self { + UniqueArena { + set: FastIndexSet::default(), + span_info: Vec::new(), + } + } + + /// Return the current number of items stored in this arena. + pub fn len(&self) -> usize { + self.set.len() + } + + /// Return `true` if the arena contains no elements. + pub fn is_empty(&self) -> bool { + self.set.is_empty() + } + + /// Clears the arena, keeping all allocations. + pub fn clear(&mut self) { + self.set.clear(); + self.span_info.clear(); + } + + /// Return the span associated with `handle`. + /// + /// If a value has been inserted multiple times, the span returned is the + /// one provided with the first insertion. + pub fn get_span(&self, handle: Handle<T>) -> Span { + *self + .span_info + .get(handle.index()) + .unwrap_or(&Span::default()) + } + + #[cfg(feature = "compact")] + pub(crate) fn drain_all(&mut self) -> UniqueArenaDrain<T> { + UniqueArenaDrain { + inner_elts: self.set.drain(..), + inner_spans: self.span_info.drain(..), + index: Index::new(1).unwrap(), + } + } +} + +#[cfg(feature = "compact")] +pub(crate) struct UniqueArenaDrain<'a, T> { + inner_elts: indexmap::set::Drain<'a, T>, + inner_spans: std::vec::Drain<'a, Span>, + index: Index, +} + +#[cfg(feature = "compact")] +impl<'a, T> Iterator for UniqueArenaDrain<'a, T> { + type Item = (Handle<T>, T, Span); + + fn next(&mut self) -> Option<Self::Item> { + match self.inner_elts.next() { + Some(elt) => { + let handle = Handle::new(self.index); + self.index = self.index.checked_add(1).unwrap(); + let span = self.inner_spans.next().unwrap(); + Some((handle, elt, span)) + } + None => None, + } + } +} + +impl<T: Eq + hash::Hash> UniqueArena<T> { + /// Returns an iterator over the items stored in this arena, returning both + /// the item's handle and a reference to it. + pub fn iter(&self) -> impl DoubleEndedIterator<Item = (Handle<T>, &T)> { + self.set.iter().enumerate().map(|(i, v)| { + let position = i + 1; + let index = unsafe { Index::new_unchecked(position as u32) }; + (Handle::new(index), v) + }) + } + + /// Insert a new value into the arena. + /// + /// Return a [`Handle<T>`], which can be used to index this arena to get a + /// shared reference to the element. + /// + /// If this arena already contains an element that is `Eq` to `value`, + /// return a `Handle` to the existing element, and drop `value`. + /// + /// If `value` is inserted into the arena, associate `span` with + /// it. An element's span can be retrieved with the [`get_span`] + /// method. + /// + /// [`Handle<T>`]: Handle + /// [`get_span`]: UniqueArena::get_span + pub fn insert(&mut self, value: T, span: Span) -> Handle<T> { + let (index, added) = self.set.insert_full(value); + + if added { + debug_assert!(index == self.span_info.len()); + self.span_info.push(span); + } + + debug_assert!(self.set.len() == self.span_info.len()); + + Handle::from_usize(index) + } + + /// Replace an old value with a new value. + /// + /// # Panics + /// + /// - if the old value is not in the arena + /// - if the new value already exists in the arena + pub fn replace(&mut self, old: Handle<T>, new: T) { + let (index, added) = self.set.insert_full(new); + assert!(added && index == self.set.len() - 1); + + self.set.swap_remove_index(old.index()).unwrap(); + } + + /// Return this arena's handle for `value`, if present. + /// + /// If this arena already contains an element equal to `value`, + /// return its handle. Otherwise, return `None`. + pub fn get(&self, value: &T) -> Option<Handle<T>> { + self.set + .get_index_of(value) + .map(|index| unsafe { Handle::from_usize_unchecked(index) }) + } + + /// Return this arena's value at `handle`, if that is a valid handle. + pub fn get_handle(&self, handle: Handle<T>) -> Result<&T, BadHandle> { + self.set + .get_index(handle.index()) + .ok_or_else(|| BadHandle::new(handle)) + } + + /// Assert that `handle` is valid for this arena. + pub fn check_contains_handle(&self, handle: Handle<T>) -> Result<(), BadHandle> { + if handle.index() < self.set.len() { + Ok(()) + } else { + Err(BadHandle::new(handle)) + } + } +} + +impl<T> Default for UniqueArena<T> { + fn default() -> Self { + Self::new() + } +} + +impl<T: fmt::Debug + Eq + hash::Hash> fmt::Debug for UniqueArena<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_map().entries(self.iter()).finish() + } +} + +impl<T> ops::Index<Handle<T>> for UniqueArena<T> { + type Output = T; + fn index(&self, handle: Handle<T>) -> &T { + &self.set[handle.index()] + } +} + +#[cfg(feature = "serialize")] +impl<T> serde::Serialize for UniqueArena<T> +where + T: Eq + hash::Hash + serde::Serialize, +{ + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: serde::Serializer, + { + self.set.serialize(serializer) + } +} + +#[cfg(feature = "deserialize")] +impl<'de, T> serde::Deserialize<'de> for UniqueArena<T> +where + T: Eq + hash::Hash + serde::Deserialize<'de>, +{ + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + let set = FastIndexSet::deserialize(deserializer)?; + let span_info = std::iter::repeat(Span::default()).take(set.len()).collect(); + + Ok(Self { set, span_info }) + } +} + +//Note: largely borrowed from `HashSet` implementation +#[cfg(feature = "arbitrary")] +impl<'a, T> arbitrary::Arbitrary<'a> for UniqueArena<T> +where + T: Eq + hash::Hash + arbitrary::Arbitrary<'a>, +{ + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> { + let mut arena = Self::default(); + for elem in u.arbitrary_iter()? { + arena.set.insert(elem?); + arena.span_info.push(Span::UNDEFINED); + } + Ok(arena) + } + + fn arbitrary_take_rest(u: arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> { + let mut arena = Self::default(); + for elem in u.arbitrary_take_rest_iter()? { + arena.set.insert(elem?); + arena.span_info.push(Span::UNDEFINED); + } + Ok(arena) + } + + #[inline] + fn size_hint(depth: usize) -> (usize, Option<usize>) { + let depth_hint = <usize as arbitrary::Arbitrary>::size_hint(depth); + arbitrary::size_hint::and(depth_hint, (0, None)) + } +} diff --git a/third_party/rust/naga/src/back/dot/mod.rs b/third_party/rust/naga/src/back/dot/mod.rs new file mode 100644 index 0000000000..1556371df1 --- /dev/null +++ b/third_party/rust/naga/src/back/dot/mod.rs @@ -0,0 +1,703 @@ +/*! +Backend for [DOT][dot] (Graphviz). + +This backend writes a graph in the DOT language, for the ease +of IR inspection and debugging. + +[dot]: https://graphviz.org/doc/info/lang.html +*/ + +use crate::{ + arena::Handle, + valid::{FunctionInfo, ModuleInfo}, +}; + +use std::{ + borrow::Cow, + fmt::{Error as FmtError, Write as _}, +}; + +/// Configuration options for the dot backend +#[derive(Clone, Default)] +pub struct Options { + /// Only emit function bodies + pub cfg_only: bool, +} + +/// Identifier used to address a graph node +type NodeId = usize; + +/// Stores the target nodes for control flow statements +#[derive(Default, Clone, Copy)] +struct Targets { + /// The node, if some, where continue operations will land + continue_target: Option<usize>, + /// The node, if some, where break operations will land + break_target: Option<usize>, +} + +/// Stores information about the graph of statements +#[derive(Default)] +struct StatementGraph { + /// List of node names + nodes: Vec<&'static str>, + /// List of edges of the control flow, the items are defined as + /// (from, to, label) + flow: Vec<(NodeId, NodeId, &'static str)>, + /// List of implicit edges of the control flow, used for jump + /// operations such as continue or break, the items are defined as + /// (from, to, label, color_id) + jumps: Vec<(NodeId, NodeId, &'static str, usize)>, + /// List of dependency relationships between a statement node and + /// expressions + dependencies: Vec<(NodeId, Handle<crate::Expression>, &'static str)>, + /// List of expression emitted by statement node + emits: Vec<(NodeId, Handle<crate::Expression>)>, + /// List of function call by statement node + calls: Vec<(NodeId, Handle<crate::Function>)>, +} + +impl StatementGraph { + /// Adds a new block to the statement graph, returning the first and last node, respectively + fn add(&mut self, block: &[crate::Statement], targets: Targets) -> (NodeId, NodeId) { + use crate::Statement as S; + + // The first node of the block isn't a statement but a virtual node + let root = self.nodes.len(); + self.nodes.push(if root == 0 { "Root" } else { "Node" }); + // Track the last placed node, this will be returned to the caller and + // will also be used to generate the control flow edges + let mut last_node = root; + for statement in block { + // Reserve a new node for the current statement and link it to the + // node of the previous statement + let id = self.nodes.len(); + self.flow.push((last_node, id, "")); + self.nodes.push(""); // reserve space + + // Track the node identifier for the merge node, the merge node is + // the last node of a statement, normally this is the node itself, + // but for control flow statements such as `if`s and `switch`s this + // is a virtual node where all branches merge back. + let mut merge_id = id; + + self.nodes[id] = match *statement { + S::Emit(ref range) => { + for handle in range.clone() { + self.emits.push((id, handle)); + } + "Emit" + } + S::Kill => "Kill", //TODO: link to the beginning + S::Break => { + // Try to link to the break target, otherwise produce + // a broken connection + if let Some(target) = targets.break_target { + self.jumps.push((id, target, "Break", 5)) + } else { + self.jumps.push((id, root, "Broken", 7)) + } + "Break" + } + S::Continue => { + // Try to link to the continue target, otherwise produce + // a broken connection + if let Some(target) = targets.continue_target { + self.jumps.push((id, target, "Continue", 5)) + } else { + self.jumps.push((id, root, "Broken", 7)) + } + "Continue" + } + S::Barrier(_flags) => "Barrier", + S::Block(ref b) => { + let (other, last) = self.add(b, targets); + self.flow.push((id, other, "")); + // All following nodes should connect to the end of the block + // statement so change the merge id to it. + merge_id = last; + "Block" + } + S::If { + condition, + ref accept, + ref reject, + } => { + self.dependencies.push((id, condition, "condition")); + let (accept_id, accept_last) = self.add(accept, targets); + self.flow.push((id, accept_id, "accept")); + let (reject_id, reject_last) = self.add(reject, targets); + self.flow.push((id, reject_id, "reject")); + + // Create a merge node, link the branches to it and set it + // as the merge node to make the next statement node link to it + merge_id = self.nodes.len(); + self.nodes.push("Merge"); + self.flow.push((accept_last, merge_id, "")); + self.flow.push((reject_last, merge_id, "")); + + "If" + } + S::Switch { + selector, + ref cases, + } => { + self.dependencies.push((id, selector, "selector")); + + // Create a merge node and set it as the merge node to make + // the next statement node link to it + merge_id = self.nodes.len(); + self.nodes.push("Merge"); + + // Create a new targets structure and set the break target + // to the merge node + let mut targets = targets; + targets.break_target = Some(merge_id); + + for case in cases { + let (case_id, case_last) = self.add(&case.body, targets); + let label = match case.value { + crate::SwitchValue::Default => "default", + _ => "case", + }; + self.flow.push((id, case_id, label)); + // Link the last node of the branch to the merge node + self.flow.push((case_last, merge_id, "")); + } + "Switch" + } + S::Loop { + ref body, + ref continuing, + break_if, + } => { + // Create a new targets structure and set the break target + // to the merge node, this must happen before generating the + // continuing block since it can break. + let mut targets = targets; + targets.break_target = Some(id); + + let (continuing_id, continuing_last) = self.add(continuing, targets); + + // Set the the continue target to the beginning + // of the newly generated continuing block + targets.continue_target = Some(continuing_id); + + let (body_id, body_last) = self.add(body, targets); + + self.flow.push((id, body_id, "body")); + + // Link the last node of the body to the continuing block + self.flow.push((body_last, continuing_id, "continuing")); + // Link the last node of the continuing block back to the + // beginning of the loop body + self.flow.push((continuing_last, body_id, "continuing")); + + if let Some(expr) = break_if { + self.dependencies.push((continuing_id, expr, "break if")); + } + + "Loop" + } + S::Return { value } => { + if let Some(expr) = value { + self.dependencies.push((id, expr, "value")); + } + "Return" + } + S::Store { pointer, value } => { + self.dependencies.push((id, value, "value")); + self.emits.push((id, pointer)); + "Store" + } + S::ImageStore { + image, + coordinate, + array_index, + value, + } => { + self.dependencies.push((id, image, "image")); + self.dependencies.push((id, coordinate, "coordinate")); + if let Some(expr) = array_index { + self.dependencies.push((id, expr, "array_index")); + } + self.dependencies.push((id, value, "value")); + "ImageStore" + } + S::Call { + function, + ref arguments, + result, + } => { + for &arg in arguments { + self.dependencies.push((id, arg, "arg")); + } + if let Some(expr) = result { + self.emits.push((id, expr)); + } + self.calls.push((id, function)); + "Call" + } + S::Atomic { + pointer, + ref fun, + value, + result, + } => { + self.emits.push((id, result)); + self.dependencies.push((id, pointer, "pointer")); + self.dependencies.push((id, value, "value")); + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + self.dependencies.push((id, cmp, "cmp")); + } + "Atomic" + } + S::WorkGroupUniformLoad { pointer, result } => { + self.emits.push((id, result)); + self.dependencies.push((id, pointer, "pointer")); + "WorkGroupUniformLoad" + } + S::RayQuery { query, ref fun } => { + self.dependencies.push((id, query, "query")); + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + self.dependencies.push(( + id, + acceleration_structure, + "acceleration_structure", + )); + self.dependencies.push((id, descriptor, "descriptor")); + "RayQueryInitialize" + } + crate::RayQueryFunction::Proceed { result } => { + self.emits.push((id, result)); + "RayQueryProceed" + } + crate::RayQueryFunction::Terminate => "RayQueryTerminate", + } + } + }; + // Set the last node to the merge node + last_node = merge_id; + } + (root, last_node) + } +} + +#[allow(clippy::manual_unwrap_or)] +fn name(option: &Option<String>) -> &str { + match *option { + Some(ref name) => name, + None => "", + } +} + +/// set39 color scheme from <https://graphviz.org/doc/info/colors.html> +const COLORS: &[&str] = &[ + "white", // pattern starts at 1 + "#8dd3c7", "#ffffb3", "#bebada", "#fb8072", "#80b1d3", "#fdb462", "#b3de69", "#fccde5", + "#d9d9d9", +]; + +fn write_fun( + output: &mut String, + prefix: String, + fun: &crate::Function, + info: Option<&FunctionInfo>, + options: &Options, +) -> Result<(), FmtError> { + writeln!(output, "\t\tnode [ style=filled ]")?; + + if !options.cfg_only { + for (handle, var) in fun.local_variables.iter() { + writeln!( + output, + "\t\t{}_l{} [ shape=hexagon label=\"{:?} '{}'\" ]", + prefix, + handle.index(), + handle, + name(&var.name), + )?; + } + + write_function_expressions(output, &prefix, fun, info)?; + } + + let mut sg = StatementGraph::default(); + sg.add(&fun.body, Targets::default()); + for (index, label) in sg.nodes.into_iter().enumerate() { + writeln!( + output, + "\t\t{prefix}_s{index} [ shape=square label=\"{label}\" ]", + )?; + } + for (from, to, label) in sg.flow { + writeln!( + output, + "\t\t{prefix}_s{from} -> {prefix}_s{to} [ arrowhead=tee label=\"{label}\" ]", + )?; + } + for (from, to, label, color_id) in sg.jumps { + writeln!( + output, + "\t\t{}_s{} -> {}_s{} [ arrowhead=tee style=dashed color=\"{}\" label=\"{}\" ]", + prefix, from, prefix, to, COLORS[color_id], label, + )?; + } + + if !options.cfg_only { + for (to, expr, label) in sg.dependencies { + writeln!( + output, + "\t\t{}_e{} -> {}_s{} [ label=\"{}\" ]", + prefix, + expr.index(), + prefix, + to, + label, + )?; + } + for (from, to) in sg.emits { + writeln!( + output, + "\t\t{}_s{} -> {}_e{} [ style=dotted ]", + prefix, + from, + prefix, + to.index(), + )?; + } + } + + for (from, function) in sg.calls { + writeln!( + output, + "\t\t{}_s{} -> f{}_s0", + prefix, + from, + function.index(), + )?; + } + + Ok(()) +} + +fn write_function_expressions( + output: &mut String, + prefix: &str, + fun: &crate::Function, + info: Option<&FunctionInfo>, +) -> Result<(), FmtError> { + enum Payload<'a> { + Arguments(&'a [Handle<crate::Expression>]), + Local(Handle<crate::LocalVariable>), + Global(Handle<crate::GlobalVariable>), + } + + let mut edges = crate::FastHashMap::<&str, _>::default(); + let mut payload = None; + for (handle, expression) in fun.expressions.iter() { + use crate::Expression as E; + let (label, color_id) = match *expression { + E::Literal(_) => ("Literal".into(), 2), + E::Constant(_) => ("Constant".into(), 2), + E::ZeroValue(_) => ("ZeroValue".into(), 2), + E::Compose { ref components, .. } => { + payload = Some(Payload::Arguments(components)); + ("Compose".into(), 3) + } + E::Access { base, index } => { + edges.insert("base", base); + edges.insert("index", index); + ("Access".into(), 1) + } + E::AccessIndex { base, index } => { + edges.insert("base", base); + (format!("AccessIndex[{index}]").into(), 1) + } + E::Splat { size, value } => { + edges.insert("value", value); + (format!("Splat{size:?}").into(), 3) + } + E::Swizzle { + size, + vector, + pattern, + } => { + edges.insert("vector", vector); + (format!("Swizzle{:?}", &pattern[..size as usize]).into(), 3) + } + E::FunctionArgument(index) => (format!("Argument[{index}]").into(), 1), + E::GlobalVariable(h) => { + payload = Some(Payload::Global(h)); + ("Global".into(), 2) + } + E::LocalVariable(h) => { + payload = Some(Payload::Local(h)); + ("Local".into(), 1) + } + E::Load { pointer } => { + edges.insert("pointer", pointer); + ("Load".into(), 4) + } + E::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset: _, + level, + depth_ref, + } => { + edges.insert("image", image); + edges.insert("sampler", sampler); + edges.insert("coordinate", coordinate); + if let Some(expr) = array_index { + edges.insert("array_index", expr); + } + match level { + crate::SampleLevel::Auto => {} + crate::SampleLevel::Zero => {} + crate::SampleLevel::Exact(expr) => { + edges.insert("level", expr); + } + crate::SampleLevel::Bias(expr) => { + edges.insert("bias", expr); + } + crate::SampleLevel::Gradient { x, y } => { + edges.insert("grad_x", x); + edges.insert("grad_y", y); + } + } + if let Some(expr) = depth_ref { + edges.insert("depth_ref", expr); + } + let string = match gather { + Some(component) => Cow::Owned(format!("ImageGather{component:?}")), + _ => Cow::Borrowed("ImageSample"), + }; + (string, 5) + } + E::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + edges.insert("image", image); + edges.insert("coordinate", coordinate); + if let Some(expr) = array_index { + edges.insert("array_index", expr); + } + if let Some(sample) = sample { + edges.insert("sample", sample); + } + if let Some(level) = level { + edges.insert("level", level); + } + ("ImageLoad".into(), 5) + } + E::ImageQuery { image, query } => { + edges.insert("image", image); + let args = match query { + crate::ImageQuery::Size { level } => { + if let Some(expr) = level { + edges.insert("level", expr); + } + Cow::from("ImageSize") + } + _ => Cow::Owned(format!("{query:?}")), + }; + (args, 7) + } + E::Unary { op, expr } => { + edges.insert("expr", expr); + (format!("{op:?}").into(), 6) + } + E::Binary { op, left, right } => { + edges.insert("left", left); + edges.insert("right", right); + (format!("{op:?}").into(), 6) + } + E::Select { + condition, + accept, + reject, + } => { + edges.insert("condition", condition); + edges.insert("accept", accept); + edges.insert("reject", reject); + ("Select".into(), 3) + } + E::Derivative { axis, ctrl, expr } => { + edges.insert("", expr); + (format!("d{axis:?}{ctrl:?}").into(), 8) + } + E::Relational { fun, argument } => { + edges.insert("arg", argument); + (format!("{fun:?}").into(), 6) + } + E::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + edges.insert("arg", arg); + if let Some(expr) = arg1 { + edges.insert("arg1", expr); + } + if let Some(expr) = arg2 { + edges.insert("arg2", expr); + } + if let Some(expr) = arg3 { + edges.insert("arg3", expr); + } + (format!("{fun:?}").into(), 7) + } + E::As { + kind, + expr, + convert, + } => { + edges.insert("", expr); + let string = match convert { + Some(width) => format!("Convert<{kind:?},{width}>"), + None => format!("Bitcast<{kind:?}>"), + }; + (string.into(), 3) + } + E::CallResult(_function) => ("CallResult".into(), 4), + E::AtomicResult { .. } => ("AtomicResult".into(), 4), + E::WorkGroupUniformLoadResult { .. } => ("WorkGroupUniformLoadResult".into(), 4), + E::ArrayLength(expr) => { + edges.insert("", expr); + ("ArrayLength".into(), 7) + } + E::RayQueryProceedResult => ("rayQueryProceedResult".into(), 4), + E::RayQueryGetIntersection { query, committed } => { + edges.insert("", query); + let ty = if committed { "Committed" } else { "Candidate" }; + (format!("rayQueryGet{}Intersection", ty).into(), 4) + } + }; + + // give uniform expressions an outline + let color_attr = match info { + Some(info) if info[handle].uniformity.non_uniform_result.is_none() => "fillcolor", + _ => "color", + }; + writeln!( + output, + "\t\t{}_e{} [ {}=\"{}\" label=\"{:?} {}\" ]", + prefix, + handle.index(), + color_attr, + COLORS[color_id], + handle, + label, + )?; + + for (key, edge) in edges.drain() { + writeln!( + output, + "\t\t{}_e{} -> {}_e{} [ label=\"{}\" ]", + prefix, + edge.index(), + prefix, + handle.index(), + key, + )?; + } + match payload.take() { + Some(Payload::Arguments(list)) => { + write!(output, "\t\t{{")?; + for &comp in list { + write!(output, " {}_e{}", prefix, comp.index())?; + } + writeln!(output, " }} -> {}_e{}", prefix, handle.index())?; + } + Some(Payload::Local(h)) => { + writeln!( + output, + "\t\t{}_l{} -> {}_e{}", + prefix, + h.index(), + prefix, + handle.index(), + )?; + } + Some(Payload::Global(h)) => { + writeln!( + output, + "\t\tg{} -> {}_e{} [fillcolor=gray]", + h.index(), + prefix, + handle.index(), + )?; + } + None => {} + } + } + + Ok(()) +} + +/// Write shader module to a [`String`]. +pub fn write( + module: &crate::Module, + mod_info: Option<&ModuleInfo>, + options: Options, +) -> Result<String, FmtError> { + use std::fmt::Write as _; + + let mut output = String::new(); + output += "digraph Module {\n"; + + if !options.cfg_only { + writeln!(output, "\tsubgraph cluster_globals {{")?; + writeln!(output, "\t\tlabel=\"Globals\"")?; + for (handle, var) in module.global_variables.iter() { + writeln!( + output, + "\t\tg{} [ shape=hexagon label=\"{:?} {:?}/'{}'\" ]", + handle.index(), + handle, + var.space, + name(&var.name), + )?; + } + writeln!(output, "\t}}")?; + } + + for (handle, fun) in module.functions.iter() { + let prefix = format!("f{}", handle.index()); + writeln!(output, "\tsubgraph cluster_{prefix} {{")?; + writeln!( + output, + "\t\tlabel=\"Function{:?}/'{}'\"", + handle, + name(&fun.name) + )?; + let info = mod_info.map(|a| &a[handle]); + write_fun(&mut output, prefix, fun, info, &options)?; + writeln!(output, "\t}}")?; + } + for (ep_index, ep) in module.entry_points.iter().enumerate() { + let prefix = format!("ep{ep_index}"); + writeln!(output, "\tsubgraph cluster_{prefix} {{")?; + writeln!(output, "\t\tlabel=\"{:?}/'{}'\"", ep.stage, ep.name)?; + let info = mod_info.map(|a| a.get_entry_point(ep_index)); + write_fun(&mut output, prefix, &ep.function, info, &options)?; + writeln!(output, "\t}}")?; + } + + output += "}\n"; + Ok(output) +} diff --git a/third_party/rust/naga/src/back/glsl/features.rs b/third_party/rust/naga/src/back/glsl/features.rs new file mode 100644 index 0000000000..e7de05f695 --- /dev/null +++ b/third_party/rust/naga/src/back/glsl/features.rs @@ -0,0 +1,536 @@ +use super::{BackendResult, Error, Version, Writer}; +use crate::{ + back::glsl::{Options, WriterFlags}, + AddressSpace, Binding, Expression, Handle, ImageClass, ImageDimension, Interpolation, Sampling, + Scalar, ScalarKind, ShaderStage, StorageFormat, Type, TypeInner, +}; +use std::fmt::Write; + +bitflags::bitflags! { + /// Structure used to encode additions to GLSL that aren't supported by all versions. + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct Features: u32 { + /// Buffer address space support. + const BUFFER_STORAGE = 1; + const ARRAY_OF_ARRAYS = 1 << 1; + /// 8 byte floats. + const DOUBLE_TYPE = 1 << 2; + /// More image formats. + const FULL_IMAGE_FORMATS = 1 << 3; + const MULTISAMPLED_TEXTURES = 1 << 4; + const MULTISAMPLED_TEXTURE_ARRAYS = 1 << 5; + const CUBE_TEXTURES_ARRAY = 1 << 6; + const COMPUTE_SHADER = 1 << 7; + /// Image load and early depth tests. + const IMAGE_LOAD_STORE = 1 << 8; + const CONSERVATIVE_DEPTH = 1 << 9; + /// Interpolation and auxiliary qualifiers. + /// + /// Perspective, Flat, and Centroid are available in all GLSL versions we support. + const NOPERSPECTIVE_QUALIFIER = 1 << 11; + const SAMPLE_QUALIFIER = 1 << 12; + const CLIP_DISTANCE = 1 << 13; + const CULL_DISTANCE = 1 << 14; + /// Sample ID. + const SAMPLE_VARIABLES = 1 << 15; + /// Arrays with a dynamic length. + const DYNAMIC_ARRAY_SIZE = 1 << 16; + const MULTI_VIEW = 1 << 17; + /// Texture samples query + const TEXTURE_SAMPLES = 1 << 18; + /// Texture levels query + const TEXTURE_LEVELS = 1 << 19; + /// Image size query + const IMAGE_SIZE = 1 << 20; + /// Dual source blending + const DUAL_SOURCE_BLENDING = 1 << 21; + /// Instance index + /// + /// We can always support this, either through the language or a polyfill + const INSTANCE_INDEX = 1 << 22; + } +} + +/// Helper structure used to store the required [`Features`] needed to output a +/// [`Module`](crate::Module) +/// +/// Provides helper methods to check for availability and writing required extensions +pub struct FeaturesManager(Features); + +impl FeaturesManager { + /// Creates a new [`FeaturesManager`] instance + pub const fn new() -> Self { + Self(Features::empty()) + } + + /// Adds to the list of required [`Features`] + pub fn request(&mut self, features: Features) { + self.0 |= features + } + + /// Checks if the list of features [`Features`] contains the specified [`Features`] + pub fn contains(&mut self, features: Features) -> bool { + self.0.contains(features) + } + + /// Checks that all required [`Features`] are available for the specified + /// [`Version`] otherwise returns an [`Error::MissingFeatures`]. + pub fn check_availability(&self, version: Version) -> BackendResult { + // Will store all the features that are unavailable + let mut missing = Features::empty(); + + // Helper macro to check for feature availability + macro_rules! check_feature { + // Used when only core glsl supports the feature + ($feature:ident, $core:literal) => { + if self.0.contains(Features::$feature) + && (version < Version::Desktop($core) || version.is_es()) + { + missing |= Features::$feature; + } + }; + // Used when both core and es support the feature + ($feature:ident, $core:literal, $es:literal) => { + if self.0.contains(Features::$feature) + && (version < Version::Desktop($core) || version < Version::new_gles($es)) + { + missing |= Features::$feature; + } + }; + } + + check_feature!(COMPUTE_SHADER, 420, 310); + check_feature!(BUFFER_STORAGE, 400, 310); + check_feature!(DOUBLE_TYPE, 150); + check_feature!(CUBE_TEXTURES_ARRAY, 130, 310); + check_feature!(MULTISAMPLED_TEXTURES, 150, 300); + check_feature!(MULTISAMPLED_TEXTURE_ARRAYS, 150, 310); + check_feature!(ARRAY_OF_ARRAYS, 120, 310); + check_feature!(IMAGE_LOAD_STORE, 130, 310); + check_feature!(CONSERVATIVE_DEPTH, 130, 300); + check_feature!(NOPERSPECTIVE_QUALIFIER, 130); + check_feature!(SAMPLE_QUALIFIER, 400, 320); + check_feature!(CLIP_DISTANCE, 130, 300 /* with extension */); + check_feature!(CULL_DISTANCE, 450, 300 /* with extension */); + check_feature!(SAMPLE_VARIABLES, 400, 300); + check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310); + check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */); + match version { + Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300), + _ => check_feature!(MULTI_VIEW, 140, 310), + }; + // Only available on glsl core, this means that opengl es can't query the number + // of samples nor levels in a image and neither do bound checks on the sample nor + // the level argument of texelFecth + check_feature!(TEXTURE_SAMPLES, 150); + check_feature!(TEXTURE_LEVELS, 130); + check_feature!(IMAGE_SIZE, 430, 310); + + // Return an error if there are missing features + if missing.is_empty() { + Ok(()) + } else { + Err(Error::MissingFeatures(missing)) + } + } + + /// Helper method used to write all needed extensions + /// + /// # Notes + /// This won't check for feature availability so it might output extensions that aren't even + /// supported.[`check_availability`](Self::check_availability) will check feature availability + pub fn write(&self, options: &Options, mut out: impl Write) -> BackendResult { + if self.0.contains(Features::COMPUTE_SHADER) && !options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_compute_shader.txt + writeln!(out, "#extension GL_ARB_compute_shader : require")?; + } + + if self.0.contains(Features::BUFFER_STORAGE) && !options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_storage_buffer_object.txt + writeln!( + out, + "#extension GL_ARB_shader_storage_buffer_object : require" + )?; + } + + if self.0.contains(Features::DOUBLE_TYPE) && options.version < Version::Desktop(400) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_gpu_shader_fp64.txt + writeln!(out, "#extension GL_ARB_gpu_shader_fp64 : require")?; + } + + if self.0.contains(Features::CUBE_TEXTURES_ARRAY) { + if options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_cube_map_array.txt + writeln!(out, "#extension GL_EXT_texture_cube_map_array : require")?; + } else if options.version < Version::Desktop(400) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_cube_map_array.txt + writeln!(out, "#extension GL_ARB_texture_cube_map_array : require")?; + } + } + + if self.0.contains(Features::MULTISAMPLED_TEXTURE_ARRAYS) && options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_texture_storage_multisample_2d_array.txt + writeln!( + out, + "#extension GL_OES_texture_storage_multisample_2d_array : require" + )?; + } + + if self.0.contains(Features::ARRAY_OF_ARRAYS) && options.version < Version::Desktop(430) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_arrays_of_arrays.txt + writeln!(out, "#extension ARB_arrays_of_arrays : require")?; + } + + if self.0.contains(Features::IMAGE_LOAD_STORE) { + if self.0.contains(Features::FULL_IMAGE_FORMATS) && options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/NV/NV_image_formats.txt + writeln!(out, "#extension GL_NV_image_formats : require")?; + } + + if options.version < Version::Desktop(420) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_image_load_store.txt + writeln!(out, "#extension GL_ARB_shader_image_load_store : require")?; + } + } + + if self.0.contains(Features::CONSERVATIVE_DEPTH) { + if options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_conservative_depth.txt + writeln!(out, "#extension GL_EXT_conservative_depth : require")?; + } + + if options.version < Version::Desktop(420) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt + writeln!(out, "#extension GL_ARB_conservative_depth : require")?; + } + } + + if (self.0.contains(Features::CLIP_DISTANCE) || self.0.contains(Features::CULL_DISTANCE)) + && options.version.is_es() + { + // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_clip_cull_distance.txt + writeln!(out, "#extension GL_EXT_clip_cull_distance : require")?; + } + + if self.0.contains(Features::SAMPLE_VARIABLES) && options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_sample_variables.txt + writeln!(out, "#extension GL_OES_sample_variables : require")?; + } + + if self.0.contains(Features::MULTI_VIEW) { + if let Version::Embedded { is_webgl: true, .. } = options.version { + // https://www.khronos.org/registry/OpenGL/extensions/OVR/OVR_multiview2.txt + writeln!(out, "#extension GL_OVR_multiview2 : require")?; + } else { + // https://github.com/KhronosGroup/GLSL/blob/master/extensions/ext/GL_EXT_multiview.txt + writeln!(out, "#extension GL_EXT_multiview : require")?; + } + } + + if self.0.contains(Features::TEXTURE_SAMPLES) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_texture_image_samples.txt + writeln!( + out, + "#extension GL_ARB_shader_texture_image_samples : require" + )?; + } + + if self.0.contains(Features::TEXTURE_LEVELS) && options.version < Version::Desktop(430) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_query_levels.txt + writeln!(out, "#extension GL_ARB_texture_query_levels : require")?; + } + if self.0.contains(Features::DUAL_SOURCE_BLENDING) && options.version.is_es() { + // https://registry.khronos.org/OpenGL/extensions/EXT/EXT_blend_func_extended.txt + writeln!(out, "#extension GL_EXT_blend_func_extended : require")?; + } + + if self.0.contains(Features::INSTANCE_INDEX) { + if options.writer_flags.contains(WriterFlags::DRAW_PARAMETERS) { + // https://registry.khronos.org/OpenGL/extensions/ARB/ARB_shader_draw_parameters.txt + writeln!(out, "#extension GL_ARB_shader_draw_parameters : require")?; + } + } + + Ok(()) + } +} + +impl<'a, W> Writer<'a, W> { + /// Helper method that searches the module for all the needed [`Features`] + /// + /// # Errors + /// If the version doesn't support any of the needed [`Features`] a + /// [`Error::MissingFeatures`] will be returned + pub(super) fn collect_required_features(&mut self) -> BackendResult { + let ep_info = self.info.get_entry_point(self.entry_point_idx as usize); + + if let Some(depth_test) = self.entry_point.early_depth_test { + // If IMAGE_LOAD_STORE is supported for this version of GLSL + if self.options.version.supports_early_depth_test() { + self.features.request(Features::IMAGE_LOAD_STORE); + } + + if depth_test.conservative.is_some() { + self.features.request(Features::CONSERVATIVE_DEPTH); + } + } + + for arg in self.entry_point.function.arguments.iter() { + self.varying_required_features(arg.binding.as_ref(), arg.ty); + } + if let Some(ref result) = self.entry_point.function.result { + self.varying_required_features(result.binding.as_ref(), result.ty); + } + + if let ShaderStage::Compute = self.entry_point.stage { + self.features.request(Features::COMPUTE_SHADER) + } + + if self.multiview.is_some() { + self.features.request(Features::MULTI_VIEW); + } + + for (ty_handle, ty) in self.module.types.iter() { + match ty.inner { + TypeInner::Scalar(scalar) + | TypeInner::Vector { scalar, .. } + | TypeInner::Matrix { scalar, .. } => self.scalar_required_features(scalar), + TypeInner::Array { base, size, .. } => { + if let TypeInner::Array { .. } = self.module.types[base].inner { + self.features.request(Features::ARRAY_OF_ARRAYS) + } + + // If the array is dynamically sized + if size == crate::ArraySize::Dynamic { + let mut is_used = false; + + // Check if this type is used in a global that is needed by the current entrypoint + for (global_handle, global) in self.module.global_variables.iter() { + // Skip unused globals + if ep_info[global_handle].is_empty() { + continue; + } + + // If this array is the type of a global, then this array is used + if global.ty == ty_handle { + is_used = true; + break; + } + + // If the type of this global is a struct + if let crate::TypeInner::Struct { ref members, .. } = + self.module.types[global.ty].inner + { + // Check the last element of the struct to see if it's type uses + // this array + if let Some(last) = members.last() { + if last.ty == ty_handle { + is_used = true; + break; + } + } + } + } + + // If this dynamically size array is used, we need dynamic array size support + if is_used { + self.features.request(Features::DYNAMIC_ARRAY_SIZE); + } + } + } + TypeInner::Image { + dim, + arrayed, + class, + } => { + if arrayed && dim == ImageDimension::Cube { + self.features.request(Features::CUBE_TEXTURES_ARRAY) + } + + match class { + ImageClass::Sampled { multi: true, .. } + | ImageClass::Depth { multi: true } => { + self.features.request(Features::MULTISAMPLED_TEXTURES); + if arrayed { + self.features.request(Features::MULTISAMPLED_TEXTURE_ARRAYS); + } + } + ImageClass::Storage { format, .. } => match format { + StorageFormat::R8Unorm + | StorageFormat::R8Snorm + | StorageFormat::R8Uint + | StorageFormat::R8Sint + | StorageFormat::R16Uint + | StorageFormat::R16Sint + | StorageFormat::R16Float + | StorageFormat::Rg8Unorm + | StorageFormat::Rg8Snorm + | StorageFormat::Rg8Uint + | StorageFormat::Rg8Sint + | StorageFormat::Rg16Uint + | StorageFormat::Rg16Sint + | StorageFormat::Rg16Float + | StorageFormat::Rgb10a2Uint + | StorageFormat::Rgb10a2Unorm + | StorageFormat::Rg11b10Float + | StorageFormat::Rg32Uint + | StorageFormat::Rg32Sint + | StorageFormat::Rg32Float => { + self.features.request(Features::FULL_IMAGE_FORMATS) + } + _ => {} + }, + ImageClass::Sampled { multi: false, .. } + | ImageClass::Depth { multi: false } => {} + } + } + _ => {} + } + } + + let mut push_constant_used = false; + + for (handle, global) in self.module.global_variables.iter() { + if ep_info[handle].is_empty() { + continue; + } + match global.space { + AddressSpace::WorkGroup => self.features.request(Features::COMPUTE_SHADER), + AddressSpace::Storage { .. } => self.features.request(Features::BUFFER_STORAGE), + AddressSpace::PushConstant => { + if push_constant_used { + return Err(Error::MultiplePushConstants); + } + push_constant_used = true; + } + _ => {} + } + } + + // We will need to pass some of the members to a closure, so we need + // to separate them otherwise the borrow checker will complain, this + // shouldn't be needed in rust 2021 + let &mut Self { + module, + info, + ref mut features, + entry_point, + entry_point_idx, + ref policies, + .. + } = self; + + // Loop trough all expressions in both functions and the entry point + // to check for needed features + for (expressions, info) in module + .functions + .iter() + .map(|(h, f)| (&f.expressions, &info[h])) + .chain(std::iter::once(( + &entry_point.function.expressions, + info.get_entry_point(entry_point_idx as usize), + ))) + { + for (_, expr) in expressions.iter() { + match *expr { + // Check for queries that need aditonal features + Expression::ImageQuery { + image, + query, + .. + } => match query { + // Storage images use `imageSize` which is only available + // in glsl > 420 + // + // layers queries are also implemented as size queries + crate::ImageQuery::Size { .. } | crate::ImageQuery::NumLayers => { + if let TypeInner::Image { + class: crate::ImageClass::Storage { .. }, .. + } = *info[image].ty.inner_with(&module.types) { + features.request(Features::IMAGE_SIZE) + } + }, + crate::ImageQuery::NumLevels => features.request(Features::TEXTURE_LEVELS), + crate::ImageQuery::NumSamples => features.request(Features::TEXTURE_SAMPLES), + } + , + // Check for image loads that needs bound checking on the sample + // or level argument since this requires a feature + Expression::ImageLoad { + sample, level, .. + } => { + if policies.image_load != crate::proc::BoundsCheckPolicy::Unchecked { + if sample.is_some() { + features.request(Features::TEXTURE_SAMPLES) + } + + if level.is_some() { + features.request(Features::TEXTURE_LEVELS) + } + } + } + _ => {} + } + } + } + + self.features.check_availability(self.options.version) + } + + /// Helper method that checks the [`Features`] needed by a scalar + fn scalar_required_features(&mut self, scalar: Scalar) { + if scalar.kind == ScalarKind::Float && scalar.width == 8 { + self.features.request(Features::DOUBLE_TYPE); + } + } + + fn varying_required_features(&mut self, binding: Option<&Binding>, ty: Handle<Type>) { + match self.module.types[ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + for member in members { + self.varying_required_features(member.binding.as_ref(), member.ty); + } + } + _ => { + if let Some(binding) = binding { + match *binding { + Binding::BuiltIn(built_in) => match built_in { + crate::BuiltIn::ClipDistance => { + self.features.request(Features::CLIP_DISTANCE) + } + crate::BuiltIn::CullDistance => { + self.features.request(Features::CULL_DISTANCE) + } + crate::BuiltIn::SampleIndex => { + self.features.request(Features::SAMPLE_VARIABLES) + } + crate::BuiltIn::ViewIndex => { + self.features.request(Features::MULTI_VIEW) + } + crate::BuiltIn::InstanceIndex => { + self.features.request(Features::INSTANCE_INDEX) + } + _ => {} + }, + Binding::Location { + location: _, + interpolation, + sampling, + second_blend_source, + } => { + if interpolation == Some(Interpolation::Linear) { + self.features.request(Features::NOPERSPECTIVE_QUALIFIER); + } + if sampling == Some(Sampling::Sample) { + self.features.request(Features::SAMPLE_QUALIFIER); + } + if second_blend_source { + self.features.request(Features::DUAL_SOURCE_BLENDING); + } + } + } + } + } + } + } +} diff --git a/third_party/rust/naga/src/back/glsl/keywords.rs b/third_party/rust/naga/src/back/glsl/keywords.rs new file mode 100644 index 0000000000..857c935e68 --- /dev/null +++ b/third_party/rust/naga/src/back/glsl/keywords.rs @@ -0,0 +1,484 @@ +pub const RESERVED_KEYWORDS: &[&str] = &[ + // + // GLSL 4.6 keywords, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L2004-L2322 + // GLSL ES 3.2 keywords, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/es/3.2/GLSL_ES_Specification_3.20.html#L2166-L2478 + // + // Note: The GLSL ES 3.2 keywords are the same as GLSL 4.6 keywords with some residing in the reserved section. + // The only exception are the missing Vulkan keywords which I think is an oversight (see https://github.com/KhronosGroup/OpenGL-Registry/issues/585). + // + "const", + "uniform", + "buffer", + "shared", + "attribute", + "varying", + "coherent", + "volatile", + "restrict", + "readonly", + "writeonly", + "atomic_uint", + "layout", + "centroid", + "flat", + "smooth", + "noperspective", + "patch", + "sample", + "invariant", + "precise", + "break", + "continue", + "do", + "for", + "while", + "switch", + "case", + "default", + "if", + "else", + "subroutine", + "in", + "out", + "inout", + "int", + "void", + "bool", + "true", + "false", + "float", + "double", + "discard", + "return", + "vec2", + "vec3", + "vec4", + "ivec2", + "ivec3", + "ivec4", + "bvec2", + "bvec3", + "bvec4", + "uint", + "uvec2", + "uvec3", + "uvec4", + "dvec2", + "dvec3", + "dvec4", + "mat2", + "mat3", + "mat4", + "mat2x2", + "mat2x3", + "mat2x4", + "mat3x2", + "mat3x3", + "mat3x4", + "mat4x2", + "mat4x3", + "mat4x4", + "dmat2", + "dmat3", + "dmat4", + "dmat2x2", + "dmat2x3", + "dmat2x4", + "dmat3x2", + "dmat3x3", + "dmat3x4", + "dmat4x2", + "dmat4x3", + "dmat4x4", + "lowp", + "mediump", + "highp", + "precision", + "sampler1D", + "sampler1DShadow", + "sampler1DArray", + "sampler1DArrayShadow", + "isampler1D", + "isampler1DArray", + "usampler1D", + "usampler1DArray", + "sampler2D", + "sampler2DShadow", + "sampler2DArray", + "sampler2DArrayShadow", + "isampler2D", + "isampler2DArray", + "usampler2D", + "usampler2DArray", + "sampler2DRect", + "sampler2DRectShadow", + "isampler2DRect", + "usampler2DRect", + "sampler2DMS", + "isampler2DMS", + "usampler2DMS", + "sampler2DMSArray", + "isampler2DMSArray", + "usampler2DMSArray", + "sampler3D", + "isampler3D", + "usampler3D", + "samplerCube", + "samplerCubeShadow", + "isamplerCube", + "usamplerCube", + "samplerCubeArray", + "samplerCubeArrayShadow", + "isamplerCubeArray", + "usamplerCubeArray", + "samplerBuffer", + "isamplerBuffer", + "usamplerBuffer", + "image1D", + "iimage1D", + "uimage1D", + "image1DArray", + "iimage1DArray", + "uimage1DArray", + "image2D", + "iimage2D", + "uimage2D", + "image2DArray", + "iimage2DArray", + "uimage2DArray", + "image2DRect", + "iimage2DRect", + "uimage2DRect", + "image2DMS", + "iimage2DMS", + "uimage2DMS", + "image2DMSArray", + "iimage2DMSArray", + "uimage2DMSArray", + "image3D", + "iimage3D", + "uimage3D", + "imageCube", + "iimageCube", + "uimageCube", + "imageCubeArray", + "iimageCubeArray", + "uimageCubeArray", + "imageBuffer", + "iimageBuffer", + "uimageBuffer", + "struct", + // Vulkan keywords + "texture1D", + "texture1DArray", + "itexture1D", + "itexture1DArray", + "utexture1D", + "utexture1DArray", + "texture2D", + "texture2DArray", + "itexture2D", + "itexture2DArray", + "utexture2D", + "utexture2DArray", + "texture2DRect", + "itexture2DRect", + "utexture2DRect", + "texture2DMS", + "itexture2DMS", + "utexture2DMS", + "texture2DMSArray", + "itexture2DMSArray", + "utexture2DMSArray", + "texture3D", + "itexture3D", + "utexture3D", + "textureCube", + "itextureCube", + "utextureCube", + "textureCubeArray", + "itextureCubeArray", + "utextureCubeArray", + "textureBuffer", + "itextureBuffer", + "utextureBuffer", + "sampler", + "samplerShadow", + "subpassInput", + "isubpassInput", + "usubpassInput", + "subpassInputMS", + "isubpassInputMS", + "usubpassInputMS", + // Reserved keywords + "common", + "partition", + "active", + "asm", + "class", + "union", + "enum", + "typedef", + "template", + "this", + "resource", + "goto", + "inline", + "noinline", + "public", + "static", + "extern", + "external", + "interface", + "long", + "short", + "half", + "fixed", + "unsigned", + "superp", + "input", + "output", + "hvec2", + "hvec3", + "hvec4", + "fvec2", + "fvec3", + "fvec4", + "filter", + "sizeof", + "cast", + "namespace", + "using", + "sampler3DRect", + // + // GLSL 4.6 Built-In Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13314 + // + // Angle and Trigonometry Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13469-L13561C5 + "radians", + "degrees", + "sin", + "cos", + "tan", + "asin", + "acos", + "atan", + "sinh", + "cosh", + "tanh", + "asinh", + "acosh", + "atanh", + // Exponential Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13569-L13620 + "pow", + "exp", + "log", + "exp2", + "log2", + "sqrt", + "inversesqrt", + // Common Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13628-L13908 + "abs", + "sign", + "floor", + "trunc", + "round", + "roundEven", + "ceil", + "fract", + "mod", + "modf", + "min", + "max", + "clamp", + "mix", + "step", + "smoothstep", + "isnan", + "isinf", + "floatBitsToInt", + "floatBitsToUint", + "intBitsToFloat", + "uintBitsToFloat", + "fma", + "frexp", + "ldexp", + // Floating-Point Pack and Unpack Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13916-L14007 + "packUnorm2x16", + "packSnorm2x16", + "packUnorm4x8", + "packSnorm4x8", + "unpackUnorm2x16", + "unpackSnorm2x16", + "unpackUnorm4x8", + "unpackSnorm4x8", + "packHalf2x16", + "unpackHalf2x16", + "packDouble2x32", + "unpackDouble2x32", + // Geometric Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14014-L14121 + "length", + "distance", + "dot", + "cross", + "normalize", + "ftransform", + "faceforward", + "reflect", + "refract", + // Matrix Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14151-L14215 + "matrixCompMult", + "outerProduct", + "transpose", + "determinant", + "inverse", + // Vector Relational Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14259-L14322 + "lessThan", + "lessThanEqual", + "greaterThan", + "greaterThanEqual", + "equal", + "notEqual", + "any", + "all", + "not", + // Integer Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14335-L14432 + "uaddCarry", + "usubBorrow", + "umulExtended", + "imulExtended", + "bitfieldExtract", + "bitfieldInsert", + "bitfieldReverse", + "bitCount", + "findLSB", + "findMSB", + // Texture Query Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14645-L14732 + "textureSize", + "textureQueryLod", + "textureQueryLevels", + "textureSamples", + // Texel Lookup Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14736-L14997 + "texture", + "textureProj", + "textureLod", + "textureOffset", + "texelFetch", + "texelFetchOffset", + "textureProjOffset", + "textureLodOffset", + "textureProjLod", + "textureProjLodOffset", + "textureGrad", + "textureGradOffset", + "textureProjGrad", + "textureProjGradOffset", + // Texture Gather Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15077-L15154 + "textureGather", + "textureGatherOffset", + "textureGatherOffsets", + // Compatibility Profile Texture Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15161-L15220 + "texture1D", + "texture1DProj", + "texture1DLod", + "texture1DProjLod", + "texture2D", + "texture2DProj", + "texture2DLod", + "texture2DProjLod", + "texture3D", + "texture3DProj", + "texture3DLod", + "texture3DProjLod", + "textureCube", + "textureCubeLod", + "shadow1D", + "shadow2D", + "shadow1DProj", + "shadow2DProj", + "shadow1DLod", + "shadow2DLod", + "shadow1DProjLod", + "shadow2DProjLod", + // Atomic Counter Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15241-L15531 + "atomicCounterIncrement", + "atomicCounterDecrement", + "atomicCounter", + "atomicCounterAdd", + "atomicCounterSubtract", + "atomicCounterMin", + "atomicCounterMax", + "atomicCounterAnd", + "atomicCounterOr", + "atomicCounterXor", + "atomicCounterExchange", + "atomicCounterCompSwap", + // Atomic Memory Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15563-L15624 + "atomicAdd", + "atomicMin", + "atomicMax", + "atomicAnd", + "atomicOr", + "atomicXor", + "atomicExchange", + "atomicCompSwap", + // Image Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15763-L15878 + "imageSize", + "imageSamples", + "imageLoad", + "imageStore", + "imageAtomicAdd", + "imageAtomicMin", + "imageAtomicMax", + "imageAtomicAnd", + "imageAtomicOr", + "imageAtomicXor", + "imageAtomicExchange", + "imageAtomicCompSwap", + // Geometry Shader Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15886-L15932 + "EmitStreamVertex", + "EndStreamPrimitive", + "EmitVertex", + "EndPrimitive", + // Fragment Processing Functions, Derivative Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16041-L16114 + "dFdx", + "dFdy", + "dFdxFine", + "dFdyFine", + "dFdxCoarse", + "dFdyCoarse", + "fwidth", + "fwidthFine", + "fwidthCoarse", + // Fragment Processing Functions, Interpolation Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16150-L16198 + "interpolateAtCentroid", + "interpolateAtSample", + "interpolateAtOffset", + // Noise Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16214-L16243 + "noise1", + "noise2", + "noise3", + "noise4", + // Shader Invocation Control Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16255-L16276 + "barrier", + // Shader Memory Control Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16336-L16382 + "memoryBarrier", + "memoryBarrierAtomicCounter", + "memoryBarrierBuffer", + "memoryBarrierShared", + "memoryBarrierImage", + "groupMemoryBarrier", + // Subpass-Input Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16451-L16470 + "subpassLoad", + // Shader Invocation Group Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16483-L16511 + "anyInvocation", + "allInvocations", + "allInvocationsEqual", + // + // entry point name (should not be shadowed) + // + "main", + // Naga utilities: + super::MODF_FUNCTION, + super::FREXP_FUNCTION, + super::FIRST_INSTANCE_BINDING, +]; diff --git a/third_party/rust/naga/src/back/glsl/mod.rs b/third_party/rust/naga/src/back/glsl/mod.rs new file mode 100644 index 0000000000..e346d43257 --- /dev/null +++ b/third_party/rust/naga/src/back/glsl/mod.rs @@ -0,0 +1,4532 @@ +/*! +Backend for [GLSL][glsl] (OpenGL Shading Language). + +The main structure is [`Writer`], it maintains internal state that is used +to output a [`Module`](crate::Module) into glsl + +# Supported versions +### Core +- 330 +- 400 +- 410 +- 420 +- 430 +- 450 + +### ES +- 300 +- 310 + +[glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php +*/ + +// GLSL is mostly a superset of C but it also removes some parts of it this is a list of relevant +// aspects for this backend. +// +// The most notable change is the introduction of the version preprocessor directive that must +// always be the first line of a glsl file and is written as +// `#version number profile` +// `number` is the version itself (i.e. 300) and `profile` is the +// shader profile we only support "core" and "es", the former is used in desktop applications and +// the later is used in embedded contexts, mobile devices and browsers. Each one as it's own +// versions (at the time of writing this the latest version for "core" is 460 and for "es" is 320) +// +// Other important preprocessor addition is the extension directive which is written as +// `#extension name: behaviour` +// Extensions provide increased features in a plugin fashion but they aren't required to be +// supported hence why they are called extensions, that's why `behaviour` is used it specifies +// whether the extension is strictly required or if it should only be enabled if needed. In our case +// when we use extensions we set behaviour to `require` always. +// +// The only thing that glsl removes that makes a difference are pointers. +// +// Additions that are relevant for the backend are the discard keyword, the introduction of +// vector, matrices, samplers, image types and functions that provide common shader operations + +pub use features::Features; + +use crate::{ + back, + proc::{self, NameKey}, + valid, Handle, ShaderStage, TypeInner, +}; +use features::FeaturesManager; +use std::{ + cmp::Ordering, + fmt, + fmt::{Error as FmtError, Write}, + mem, +}; +use thiserror::Error; + +/// Contains the features related code and the features querying method +mod features; +/// Contains a constant with a slice of all the reserved keywords RESERVED_KEYWORDS +mod keywords; + +/// List of supported `core` GLSL versions. +pub const SUPPORTED_CORE_VERSIONS: &[u16] = &[140, 150, 330, 400, 410, 420, 430, 440, 450, 460]; +/// List of supported `es` GLSL versions. +pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320]; + +/// The suffix of the variable that will hold the calculated clamped level +/// of detail for bounds checking in `ImageLoad` +const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod"; + +pub(crate) const MODF_FUNCTION: &str = "naga_modf"; +pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; + +// Must match code in glsl_built_in +pub const FIRST_INSTANCE_BINDING: &str = "naga_vs_first_instance"; + +/// Mapping between resources and bindings. +pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, u8>; + +impl crate::AtomicFunction { + const fn to_glsl(self) -> &'static str { + match self { + Self::Add | Self::Subtract => "Add", + Self::And => "And", + Self::InclusiveOr => "Or", + Self::ExclusiveOr => "Xor", + Self::Min => "Min", + Self::Max => "Max", + Self::Exchange { compare: None } => "Exchange", + Self::Exchange { compare: Some(_) } => "", //TODO + } + } +} + +impl crate::AddressSpace { + const fn is_buffer(&self) -> bool { + match *self { + crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => true, + _ => false, + } + } + + /// Whether a variable with this address space can be initialized + const fn initializable(&self) -> bool { + match *self { + crate::AddressSpace::Function | crate::AddressSpace::Private => true, + crate::AddressSpace::WorkGroup + | crate::AddressSpace::Uniform + | crate::AddressSpace::Storage { .. } + | crate::AddressSpace::Handle + | crate::AddressSpace::PushConstant => false, + } + } +} + +/// A GLSL version. +#[derive(Debug, Copy, Clone, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum Version { + /// `core` GLSL. + Desktop(u16), + /// `es` GLSL. + Embedded { version: u16, is_webgl: bool }, +} + +impl Version { + /// Create a new gles version + pub const fn new_gles(version: u16) -> Self { + Self::Embedded { + version, + is_webgl: false, + } + } + + /// Returns true if self is `Version::Embedded` (i.e. is a es version) + const fn is_es(&self) -> bool { + match *self { + Version::Desktop(_) => false, + Version::Embedded { .. } => true, + } + } + + /// Returns true if targeting WebGL + const fn is_webgl(&self) -> bool { + match *self { + Version::Desktop(_) => false, + Version::Embedded { is_webgl, .. } => is_webgl, + } + } + + /// Checks the list of currently supported versions and returns true if it contains the + /// specified version + /// + /// # Notes + /// As an invalid version number will never be added to the supported version list + /// so this also checks for version validity + fn is_supported(&self) -> bool { + match *self { + Version::Desktop(v) => SUPPORTED_CORE_VERSIONS.contains(&v), + Version::Embedded { version: v, .. } => SUPPORTED_ES_VERSIONS.contains(&v), + } + } + + fn supports_io_locations(&self) -> bool { + *self >= Version::Desktop(330) || *self >= Version::new_gles(300) + } + + /// Checks if the version supports all of the explicit layouts: + /// - `location=` qualifiers for bindings + /// - `binding=` qualifiers for resources + /// + /// Note: `location=` for vertex inputs and fragment outputs is supported + /// unconditionally for GLES 300. + fn supports_explicit_locations(&self) -> bool { + *self >= Version::Desktop(410) || *self >= Version::new_gles(310) + } + + fn supports_early_depth_test(&self) -> bool { + *self >= Version::Desktop(130) || *self >= Version::new_gles(310) + } + + fn supports_std430_layout(&self) -> bool { + *self >= Version::Desktop(430) || *self >= Version::new_gles(310) + } + + fn supports_fma_function(&self) -> bool { + *self >= Version::Desktop(400) || *self >= Version::new_gles(320) + } + + fn supports_integer_functions(&self) -> bool { + *self >= Version::Desktop(400) || *self >= Version::new_gles(310) + } + + fn supports_frexp_function(&self) -> bool { + *self >= Version::Desktop(400) || *self >= Version::new_gles(310) + } + + fn supports_derivative_control(&self) -> bool { + *self >= Version::Desktop(450) + } +} + +impl PartialOrd for Version { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + match (*self, *other) { + (Version::Desktop(x), Version::Desktop(y)) => Some(x.cmp(&y)), + (Version::Embedded { version: x, .. }, Version::Embedded { version: y, .. }) => { + Some(x.cmp(&y)) + } + _ => None, + } + } +} + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Version::Desktop(v) => write!(f, "{v} core"), + Version::Embedded { version: v, .. } => write!(f, "{v} es"), + } + } +} + +bitflags::bitflags! { + /// Configuration flags for the [`Writer`]. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct WriterFlags: u32 { + /// Flip output Y and extend Z from (0, 1) to (-1, 1). + const ADJUST_COORDINATE_SPACE = 0x1; + /// Supports GL_EXT_texture_shadow_lod on the host, which provides + /// additional functions on shadows and arrays of shadows. + const TEXTURE_SHADOW_LOD = 0x2; + /// Supports ARB_shader_draw_parameters on the host, which provides + /// support for `gl_BaseInstanceARB`, `gl_BaseVertexARB`, and `gl_DrawIDARB`. + const DRAW_PARAMETERS = 0x4; + /// Include unused global variables, constants and functions. By default the output will exclude + /// global variables that are not used in the specified entrypoint (including indirect use), + /// all constant declarations, and functions that use excluded global variables. + const INCLUDE_UNUSED_ITEMS = 0x10; + /// Emit `PointSize` output builtin to vertex shaders, which is + /// required for drawing with `PointList` topology. + /// + /// https://registry.khronos.org/OpenGL/specs/es/3.2/GLSL_ES_Specification_3.20.html#built-in-language-variables + /// The variable gl_PointSize is intended for a shader to write the size of the point to be rasterized. It is measured in pixels. + /// If gl_PointSize is not written to, its value is undefined in subsequent pipe stages. + const FORCE_POINT_SIZE = 0x20; + } +} + +/// Configuration used in the [`Writer`]. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Options { + /// The GLSL version to be used. + pub version: Version, + /// Configuration flags for the [`Writer`]. + pub writer_flags: WriterFlags, + /// Map of resources association to binding locations. + pub binding_map: BindingMap, + /// Should workgroup variables be zero initialized (by polyfilling)? + pub zero_initialize_workgroup_memory: bool, +} + +impl Default for Options { + fn default() -> Self { + Options { + version: Version::new_gles(310), + writer_flags: WriterFlags::ADJUST_COORDINATE_SPACE, + binding_map: BindingMap::default(), + zero_initialize_workgroup_memory: true, + } + } +} + +/// A subset of options meant to be changed per pipeline. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct PipelineOptions { + /// The stage of the entry point. + pub shader_stage: ShaderStage, + /// The name of the entry point. + /// + /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown. + pub entry_point: String, + /// How many views to render to, if doing multiview rendering. + pub multiview: Option<std::num::NonZeroU32>, +} + +#[derive(Debug)] +pub struct VaryingLocation { + /// The location of the global. + /// This corresponds to `layout(location = ..)` in GLSL. + pub location: u32, + /// The index which can be used for dual source blending. + /// This corresponds to `layout(index = ..)` in GLSL. + pub index: u32, +} + +/// Reflection info for texture mappings and uniforms. +#[derive(Debug)] +pub struct ReflectionInfo { + /// Mapping between texture names and variables/samplers. + pub texture_mapping: crate::FastHashMap<String, TextureMapping>, + /// Mapping between uniform variables and names. + pub uniforms: crate::FastHashMap<Handle<crate::GlobalVariable>, String>, + /// Mapping between names and attribute locations. + pub varying: crate::FastHashMap<String, VaryingLocation>, + /// List of push constant items in the shader. + pub push_constant_items: Vec<PushConstantItem>, +} + +/// Mapping between a texture and its sampler, if it exists. +/// +/// GLSL pre-Vulkan has no concept of separate textures and samplers. Instead, everything is a +/// `gsamplerN` where `g` is the scalar type and `N` is the dimension. But naga uses separate textures +/// and samplers in the IR, so the backend produces a [`FastHashMap`](crate::FastHashMap) with the texture name +/// as a key and a [`TextureMapping`] as a value. This way, the user knows where to bind. +/// +/// [`Storage`](crate::ImageClass::Storage) images produce `gimageN` and don't have an associated sampler, +/// so the [`sampler`](Self::sampler) field will be [`None`]. +#[derive(Debug, Clone)] +pub struct TextureMapping { + /// Handle to the image global variable. + pub texture: Handle<crate::GlobalVariable>, + /// Handle to the associated sampler global variable, if it exists. + pub sampler: Option<Handle<crate::GlobalVariable>>, +} + +/// All information to bind a single uniform value to the shader. +/// +/// Push constants are emulated using traditional uniforms in OpenGL. +/// +/// These are composed of a set of primitives (scalar, vector, matrix) that +/// are given names. Because they are not backed by the concept of a buffer, +/// we must do the work of calculating the offset of each primitive in the +/// push constant block. +#[derive(Debug, Clone)] +pub struct PushConstantItem { + /// GL uniform name for the item. This name is the same as if you were + /// to access it directly from a GLSL shader. + /// + /// The with the following example, the following names will be generated, + /// one name per GLSL uniform. + /// + /// ```glsl + /// struct InnerStruct { + /// value: f32, + /// } + /// + /// struct PushConstant { + /// InnerStruct inner; + /// vec4 array[2]; + /// } + /// + /// uniform PushConstants _push_constant_binding_cs; + /// ``` + /// + /// ```text + /// - _push_constant_binding_cs.inner.value + /// - _push_constant_binding_cs.array[0] + /// - _push_constant_binding_cs.array[1] + /// ``` + /// + pub access_path: String, + /// Type of the uniform. This will only ever be a scalar, vector, or matrix. + pub ty: Handle<crate::Type>, + /// The offset in the push constant memory block this uniform maps to. + /// + /// The size of the uniform can be derived from the type. + pub offset: u32, +} + +/// Helper structure that generates a number +#[derive(Default)] +struct IdGenerator(u32); + +impl IdGenerator { + /// Generates a number that's guaranteed to be unique for this `IdGenerator` + fn generate(&mut self) -> u32 { + // It's just an increasing number but it does the job + let ret = self.0; + self.0 += 1; + ret + } +} + +/// Assorted options needed for generating varyings. +#[derive(Clone, Copy)] +struct VaryingOptions { + output: bool, + targeting_webgl: bool, + draw_parameters: bool, +} + +impl VaryingOptions { + const fn from_writer_options(options: &Options, output: bool) -> Self { + Self { + output, + targeting_webgl: options.version.is_webgl(), + draw_parameters: options.writer_flags.contains(WriterFlags::DRAW_PARAMETERS), + } + } +} + +/// Helper wrapper used to get a name for a varying +/// +/// Varying have different naming schemes depending on their binding: +/// - Varyings with builtin bindings get the from [`glsl_built_in`]. +/// - Varyings with location bindings are named `_S_location_X` where `S` is a +/// prefix identifying which pipeline stage the varying connects, and `X` is +/// the location. +struct VaryingName<'a> { + binding: &'a crate::Binding, + stage: ShaderStage, + options: VaryingOptions, +} +impl fmt::Display for VaryingName<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self.binding { + crate::Binding::Location { + second_blend_source: true, + .. + } => { + write!(f, "_fs2p_location1",) + } + crate::Binding::Location { location, .. } => { + let prefix = match (self.stage, self.options.output) { + (ShaderStage::Compute, _) => unreachable!(), + // pipeline to vertex + (ShaderStage::Vertex, false) => "p2vs", + // vertex to fragment + (ShaderStage::Vertex, true) | (ShaderStage::Fragment, false) => "vs2fs", + // fragment to pipeline + (ShaderStage::Fragment, true) => "fs2p", + }; + write!(f, "_{prefix}_location{location}",) + } + crate::Binding::BuiltIn(built_in) => { + write!(f, "{}", glsl_built_in(built_in, self.options)) + } + } + } +} + +impl ShaderStage { + const fn to_str(self) -> &'static str { + match self { + ShaderStage::Compute => "cs", + ShaderStage::Fragment => "fs", + ShaderStage::Vertex => "vs", + } + } +} + +/// Shorthand result used internally by the backend +type BackendResult<T = ()> = Result<T, Error>; + +/// A GLSL compilation error. +#[derive(Debug, Error)] +pub enum Error { + /// A error occurred while writing to the output. + #[error("Format error")] + FmtError(#[from] FmtError), + /// The specified [`Version`] doesn't have all required [`Features`]. + /// + /// Contains the missing [`Features`]. + #[error("The selected version doesn't support {0:?}")] + MissingFeatures(Features), + /// [`AddressSpace::PushConstant`](crate::AddressSpace::PushConstant) was used more than + /// once in the entry point, which isn't supported. + #[error("Multiple push constants aren't supported")] + MultiplePushConstants, + /// The specified [`Version`] isn't supported. + #[error("The specified version isn't supported")] + VersionNotSupported, + /// The entry point couldn't be found. + #[error("The requested entry point couldn't be found")] + EntryPointNotFound, + /// A call was made to an unsupported external. + #[error("A call was made to an unsupported external: {0}")] + UnsupportedExternal(String), + /// A scalar with an unsupported width was requested. + #[error("A scalar with an unsupported width was requested: {0:?}")] + UnsupportedScalar(crate::Scalar), + /// A image was used with multiple samplers, which isn't supported. + #[error("A image was used with multiple samplers")] + ImageMultipleSamplers, + #[error("{0}")] + Custom(String), +} + +/// Binary operation with a different logic on the GLSL side. +enum BinaryOperation { + /// Vector comparison should use the function like `greaterThan()`, etc. + VectorCompare, + /// Vector component wise operation; used to polyfill unsupported ops like `|` and `&` for `bvecN`'s + VectorComponentWise, + /// GLSL `%` is SPIR-V `OpUMod/OpSMod` and `mod()` is `OpFMod`, but [`BinaryOperator::Modulo`](crate::BinaryOperator::Modulo) is `OpFRem`. + Modulo, + /// Any plain operation. No additional logic required. + Other, +} + +/// Writer responsible for all code generation. +pub struct Writer<'a, W> { + // Inputs + /// The module being written. + module: &'a crate::Module, + /// The module analysis. + info: &'a valid::ModuleInfo, + /// The output writer. + out: W, + /// User defined configuration to be used. + options: &'a Options, + /// The bound checking policies to be used + policies: proc::BoundsCheckPolicies, + + // Internal State + /// Features manager used to store all the needed features and write them. + features: FeaturesManager, + namer: proc::Namer, + /// A map with all the names needed for writing the module + /// (generated by a [`Namer`](crate::proc::Namer)). + names: crate::FastHashMap<NameKey, String>, + /// A map with the names of global variables needed for reflections. + reflection_names_globals: crate::FastHashMap<Handle<crate::GlobalVariable>, String>, + /// The selected entry point. + entry_point: &'a crate::EntryPoint, + /// The index of the selected entry point. + entry_point_idx: proc::EntryPointIndex, + /// A generator for unique block numbers. + block_id: IdGenerator, + /// Set of expressions that have associated temporary variables. + named_expressions: crate::NamedExpressions, + /// Set of expressions that need to be baked to avoid unnecessary repetition in output + need_bake_expressions: back::NeedBakeExpressions, + /// How many views to render to, if doing multiview rendering. + multiview: Option<std::num::NonZeroU32>, + /// Mapping of varying variables to their location. Needed for reflections. + varying: crate::FastHashMap<String, VaryingLocation>, +} + +impl<'a, W: Write> Writer<'a, W> { + /// Creates a new [`Writer`] instance. + /// + /// # Errors + /// - If the version specified is invalid or supported. + /// - If the entry point couldn't be found in the module. + /// - If the version specified doesn't support some used features. + pub fn new( + out: W, + module: &'a crate::Module, + info: &'a valid::ModuleInfo, + options: &'a Options, + pipeline_options: &'a PipelineOptions, + policies: proc::BoundsCheckPolicies, + ) -> Result<Self, Error> { + // Check if the requested version is supported + if !options.version.is_supported() { + log::error!("Version {}", options.version); + return Err(Error::VersionNotSupported); + } + + // Try to find the entry point and corresponding index + let ep_idx = module + .entry_points + .iter() + .position(|ep| { + pipeline_options.shader_stage == ep.stage && pipeline_options.entry_point == ep.name + }) + .ok_or(Error::EntryPointNotFound)?; + + // Generate a map with names required to write the module + let mut names = crate::FastHashMap::default(); + let mut namer = proc::Namer::default(); + namer.reset( + module, + keywords::RESERVED_KEYWORDS, + &[], + &[], + &[ + "gl_", // all GL built-in variables + "_group", // all normal bindings + "_push_constant_binding_", // all push constant bindings + ], + &mut names, + ); + + // Build the instance + let mut this = Self { + module, + info, + out, + options, + policies, + + namer, + features: FeaturesManager::new(), + names, + reflection_names_globals: crate::FastHashMap::default(), + entry_point: &module.entry_points[ep_idx], + entry_point_idx: ep_idx as u16, + multiview: pipeline_options.multiview, + block_id: IdGenerator::default(), + named_expressions: Default::default(), + need_bake_expressions: Default::default(), + varying: Default::default(), + }; + + // Find all features required to print this module + this.collect_required_features()?; + + Ok(this) + } + + /// Writes the [`Module`](crate::Module) as glsl to the output + /// + /// # Notes + /// If an error occurs while writing, the output might have been written partially + /// + /// # Panics + /// Might panic if the module is invalid + pub fn write(&mut self) -> Result<ReflectionInfo, Error> { + // We use `writeln!(self.out)` throughout the write to add newlines + // to make the output more readable + + let es = self.options.version.is_es(); + + // Write the version (It must be the first thing or it isn't a valid glsl output) + writeln!(self.out, "#version {}", self.options.version)?; + // Write all the needed extensions + // + // This used to be the last thing being written as it allowed to search for features while + // writing the module saving some loops but some older versions (420 or less) required the + // extensions to appear before being used, even though extensions are part of the + // preprocessor not the processor ¯\_(ツ)_/¯ + self.features.write(self.options, &mut self.out)?; + + // Write the additional extensions + if self + .options + .writer_flags + .contains(WriterFlags::TEXTURE_SHADOW_LOD) + { + // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_shadow_lod.txt + writeln!(self.out, "#extension GL_EXT_texture_shadow_lod : require")?; + } + + // glsl es requires a precision to be specified for floats and ints + // TODO: Should this be user configurable? + if es { + writeln!(self.out)?; + writeln!(self.out, "precision highp float;")?; + writeln!(self.out, "precision highp int;")?; + writeln!(self.out)?; + } + + if self.entry_point.stage == ShaderStage::Compute { + let workgroup_size = self.entry_point.workgroup_size; + writeln!( + self.out, + "layout(local_size_x = {}, local_size_y = {}, local_size_z = {}) in;", + workgroup_size[0], workgroup_size[1], workgroup_size[2] + )?; + writeln!(self.out)?; + } + + if self.entry_point.stage == ShaderStage::Vertex + && !self + .options + .writer_flags + .contains(WriterFlags::DRAW_PARAMETERS) + && self.features.contains(Features::INSTANCE_INDEX) + { + writeln!(self.out, "uniform uint {FIRST_INSTANCE_BINDING};")?; + writeln!(self.out)?; + } + + // Enable early depth tests if needed + if let Some(depth_test) = self.entry_point.early_depth_test { + // If early depth test is supported for this version of GLSL + if self.options.version.supports_early_depth_test() { + writeln!(self.out, "layout(early_fragment_tests) in;")?; + + if let Some(conservative) = depth_test.conservative { + use crate::ConservativeDepth as Cd; + + let depth = match conservative { + Cd::GreaterEqual => "greater", + Cd::LessEqual => "less", + Cd::Unchanged => "unchanged", + }; + writeln!(self.out, "layout (depth_{depth}) out float gl_FragDepth;")?; + } + writeln!(self.out)?; + } else { + log::warn!( + "Early depth testing is not supported for this version of GLSL: {}", + self.options.version + ); + } + } + + if self.entry_point.stage == ShaderStage::Vertex && self.options.version.is_webgl() { + if let Some(multiview) = self.multiview.as_ref() { + writeln!(self.out, "layout(num_views = {multiview}) in;")?; + writeln!(self.out)?; + } + } + + // Write struct types. + // + // This are always ordered because the IR is structured in a way that + // you can't make a struct without adding all of its members first. + for (handle, ty) in self.module.types.iter() { + if let TypeInner::Struct { ref members, .. } = ty.inner { + // Structures ending with runtime-sized arrays can only be + // rendered as shader storage blocks in GLSL, not stand-alone + // struct types. + if !self.module.types[members.last().unwrap().ty] + .inner + .is_dynamically_sized(&self.module.types) + { + let name = &self.names[&NameKey::Type(handle)]; + write!(self.out, "struct {name} ")?; + self.write_struct_body(handle, members)?; + writeln!(self.out, ";")?; + } + } + } + + // Write functions to create special types. + for (type_key, struct_ty) in self.module.special_types.predeclared_types.iter() { + match type_key { + &crate::PredeclaredType::ModfResult { size, width } + | &crate::PredeclaredType::FrexpResult { size, width } => { + let arg_type_name_owner; + let arg_type_name = if let Some(size) = size { + arg_type_name_owner = + format!("{}vec{}", if width == 8 { "d" } else { "" }, size as u8); + &arg_type_name_owner + } else if width == 8 { + "double" + } else { + "float" + }; + + let other_type_name_owner; + let (defined_func_name, called_func_name, other_type_name) = + if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) { + (MODF_FUNCTION, "modf", arg_type_name) + } else { + let other_type_name = if let Some(size) = size { + other_type_name_owner = format!("ivec{}", size as u8); + &other_type_name_owner + } else { + "int" + }; + (FREXP_FUNCTION, "frexp", other_type_name) + }; + + let struct_name = &self.names[&NameKey::Type(*struct_ty)]; + + writeln!(self.out)?; + if !self.options.version.supports_frexp_function() + && matches!(type_key, &crate::PredeclaredType::FrexpResult { .. }) + { + writeln!( + self.out, + "{struct_name} {defined_func_name}({arg_type_name} arg) {{ + {other_type_name} other = arg == {arg_type_name}(0) ? {other_type_name}(0) : {other_type_name}({arg_type_name}(1) + log2(arg)); + {arg_type_name} fract = arg * exp2({arg_type_name}(-other)); + return {struct_name}(fract, other); +}}", + )?; + } else { + writeln!( + self.out, + "{struct_name} {defined_func_name}({arg_type_name} arg) {{ + {other_type_name} other; + {arg_type_name} fract = {called_func_name}(arg, other); + return {struct_name}(fract, other); +}}", + )?; + } + } + &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {} + } + } + + // Write all named constants + let mut constants = self + .module + .constants + .iter() + .filter(|&(_, c)| c.name.is_some()) + .peekable(); + while let Some((handle, _)) = constants.next() { + self.write_global_constant(handle)?; + // Add extra newline for readability on last iteration + if constants.peek().is_none() { + writeln!(self.out)?; + } + } + + let ep_info = self.info.get_entry_point(self.entry_point_idx as usize); + + // Write the globals + // + // Unless explicitly disabled with WriterFlags::INCLUDE_UNUSED_ITEMS, + // we filter all globals that aren't used by the selected entry point as they might be + // interfere with each other (i.e. two globals with the same location but different with + // different classes) + let include_unused = self + .options + .writer_flags + .contains(WriterFlags::INCLUDE_UNUSED_ITEMS); + for (handle, global) in self.module.global_variables.iter() { + let is_unused = ep_info[handle].is_empty(); + if !include_unused && is_unused { + continue; + } + + match self.module.types[global.ty].inner { + // We treat images separately because they might require + // writing the storage format + TypeInner::Image { + mut dim, + arrayed, + class, + } => { + // Gather the storage format if needed + let storage_format_access = match self.module.types[global.ty].inner { + TypeInner::Image { + class: crate::ImageClass::Storage { format, access }, + .. + } => Some((format, access)), + _ => None, + }; + + if dim == crate::ImageDimension::D1 && es { + dim = crate::ImageDimension::D2 + } + + // Gether the location if needed + let layout_binding = if self.options.version.supports_explicit_locations() { + let br = global.binding.as_ref().unwrap(); + self.options.binding_map.get(br).cloned() + } else { + None + }; + + // Write all the layout qualifiers + if layout_binding.is_some() || storage_format_access.is_some() { + write!(self.out, "layout(")?; + if let Some(binding) = layout_binding { + write!(self.out, "binding = {binding}")?; + } + if let Some((format, _)) = storage_format_access { + let format_str = glsl_storage_format(format)?; + let separator = match layout_binding { + Some(_) => ",", + None => "", + }; + write!(self.out, "{separator}{format_str}")?; + } + write!(self.out, ") ")?; + } + + if let Some((_, access)) = storage_format_access { + self.write_storage_access(access)?; + } + + // All images in glsl are `uniform` + // The trailing space is important + write!(self.out, "uniform ")?; + + // write the type + // + // This is way we need the leading space because `write_image_type` doesn't add + // any spaces at the beginning or end + self.write_image_type(dim, arrayed, class)?; + + // Finally write the name and end the global with a `;` + // The leading space is important + let global_name = self.get_global_name(handle, global); + writeln!(self.out, " {global_name};")?; + writeln!(self.out)?; + + self.reflection_names_globals.insert(handle, global_name); + } + // glsl has no concept of samplers so we just ignore it + TypeInner::Sampler { .. } => continue, + // All other globals are written by `write_global` + _ => { + self.write_global(handle, global)?; + // Add a newline (only for readability) + writeln!(self.out)?; + } + } + } + + for arg in self.entry_point.function.arguments.iter() { + self.write_varying(arg.binding.as_ref(), arg.ty, false)?; + } + if let Some(ref result) = self.entry_point.function.result { + self.write_varying(result.binding.as_ref(), result.ty, true)?; + } + writeln!(self.out)?; + + // Write all regular functions + for (handle, function) in self.module.functions.iter() { + // Check that the function doesn't use globals that aren't supported + // by the current entry point + if !include_unused && !ep_info.dominates_global_use(&self.info[handle]) { + continue; + } + + let fun_info = &self.info[handle]; + + // Skip functions that that are not compatible with this entry point's stage. + // + // When validation is enabled, it rejects modules whose entry points try to call + // incompatible functions, so if we got this far, then any functions incompatible + // with our selected entry point must not be used. + // + // When validation is disabled, `fun_info.available_stages` is always just + // `ShaderStages::all()`, so this will write all functions in the module, and + // the downstream GLSL compiler will catch any problems. + if !fun_info.available_stages.contains(ep_info.available_stages) { + continue; + } + + // Write the function + self.write_function(back::FunctionType::Function(handle), function, fun_info)?; + + writeln!(self.out)?; + } + + self.write_function( + back::FunctionType::EntryPoint(self.entry_point_idx), + &self.entry_point.function, + ep_info, + )?; + + // Add newline at the end of file + writeln!(self.out)?; + + // Collect all reflection info and return it to the user + self.collect_reflection_info() + } + + fn write_array_size( + &mut self, + base: Handle<crate::Type>, + size: crate::ArraySize, + ) -> BackendResult { + write!(self.out, "[")?; + + // Write the array size + // Writes nothing if `ArraySize::Dynamic` + match size { + crate::ArraySize::Constant(size) => { + write!(self.out, "{size}")?; + } + crate::ArraySize::Dynamic => (), + } + + write!(self.out, "]")?; + + if let TypeInner::Array { + base: next_base, + size: next_size, + .. + } = self.module.types[base].inner + { + self.write_array_size(next_base, next_size)?; + } + + Ok(()) + } + + /// Helper method used to write value types + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_value_type(&mut self, inner: &TypeInner) -> BackendResult { + match *inner { + // Scalars are simple we just get the full name from `glsl_scalar` + TypeInner::Scalar(scalar) + | TypeInner::Atomic(scalar) + | TypeInner::ValuePointer { + size: None, + scalar, + space: _, + } => write!(self.out, "{}", glsl_scalar(scalar)?.full)?, + // Vectors are just `gvecN` where `g` is the scalar prefix and `N` is the vector size + TypeInner::Vector { size, scalar } + | TypeInner::ValuePointer { + size: Some(size), + scalar, + space: _, + } => write!(self.out, "{}vec{}", glsl_scalar(scalar)?.prefix, size as u8)?, + // Matrices are written with `gmatMxN` where `g` is the scalar prefix (only floats and + // doubles are allowed), `M` is the columns count and `N` is the rows count + // + // glsl supports a matrix shorthand `gmatN` where `N` = `M` but it doesn't justify the + // extra branch to write matrices this way + TypeInner::Matrix { + columns, + rows, + scalar, + } => write!( + self.out, + "{}mat{}x{}", + glsl_scalar(scalar)?.prefix, + columns as u8, + rows as u8 + )?, + // GLSL arrays are written as `type name[size]` + // Here we only write the size of the array i.e. `[size]` + // Base `type` and `name` should be written outside + TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?, + // Write all variants instead of `_` so that if new variants are added a + // no exhaustiveness error is thrown + TypeInner::Pointer { .. } + | TypeInner::Struct { .. } + | TypeInner::Image { .. } + | TypeInner::Sampler { .. } + | TypeInner::AccelerationStructure + | TypeInner::RayQuery + | TypeInner::BindingArray { .. } => { + return Err(Error::Custom(format!("Unable to write type {inner:?}"))) + } + } + + Ok(()) + } + + /// Helper method used to write non image/sampler types + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_type(&mut self, ty: Handle<crate::Type>) -> BackendResult { + match self.module.types[ty].inner { + // glsl has no pointer types so just write types as normal and loads are skipped + TypeInner::Pointer { base, .. } => self.write_type(base), + // glsl structs are written as just the struct name + TypeInner::Struct { .. } => { + // Get the struct name + let name = &self.names[&NameKey::Type(ty)]; + write!(self.out, "{name}")?; + Ok(()) + } + // glsl array has the size separated from the base type + TypeInner::Array { base, .. } => self.write_type(base), + ref other => self.write_value_type(other), + } + } + + /// Helper method to write a image type + /// + /// # Notes + /// Adds no leading or trailing whitespace + fn write_image_type( + &mut self, + dim: crate::ImageDimension, + arrayed: bool, + class: crate::ImageClass, + ) -> BackendResult { + // glsl images consist of four parts the scalar prefix, the image "type", the dimensions + // and modifiers + // + // There exists two image types + // - sampler - for sampled images + // - image - for storage images + // + // There are three possible modifiers that can be used together and must be written in + // this order to be valid + // - MS - used if it's a multisampled image + // - Array - used if it's an image array + // - Shadow - used if it's a depth image + use crate::ImageClass as Ic; + + let (base, kind, ms, comparison) = match class { + Ic::Sampled { kind, multi: true } => ("sampler", kind, "MS", ""), + Ic::Sampled { kind, multi: false } => ("sampler", kind, "", ""), + Ic::Depth { multi: true } => ("sampler", crate::ScalarKind::Float, "MS", ""), + Ic::Depth { multi: false } => ("sampler", crate::ScalarKind::Float, "", "Shadow"), + Ic::Storage { format, .. } => ("image", format.into(), "", ""), + }; + + let precision = if self.options.version.is_es() { + "highp " + } else { + "" + }; + + write!( + self.out, + "{}{}{}{}{}{}{}", + precision, + glsl_scalar(crate::Scalar { kind, width: 4 })?.prefix, + base, + glsl_dimension(dim), + ms, + if arrayed { "Array" } else { "" }, + comparison + )?; + + Ok(()) + } + + /// Helper method used to write non images/sampler globals + /// + /// # Notes + /// Adds a newline + /// + /// # Panics + /// If the global has type sampler + fn write_global( + &mut self, + handle: Handle<crate::GlobalVariable>, + global: &crate::GlobalVariable, + ) -> BackendResult { + if self.options.version.supports_explicit_locations() { + if let Some(ref br) = global.binding { + match self.options.binding_map.get(br) { + Some(binding) => { + let layout = match global.space { + crate::AddressSpace::Storage { .. } => { + if self.options.version.supports_std430_layout() { + "std430, " + } else { + "std140, " + } + } + crate::AddressSpace::Uniform => "std140, ", + _ => "", + }; + write!(self.out, "layout({layout}binding = {binding}) ")? + } + None => { + log::debug!("unassigned binding for {:?}", global.name); + if let crate::AddressSpace::Storage { .. } = global.space { + if self.options.version.supports_std430_layout() { + write!(self.out, "layout(std430) ")? + } + } + } + } + } + } + + if let crate::AddressSpace::Storage { access } = global.space { + self.write_storage_access(access)?; + } + + if let Some(storage_qualifier) = glsl_storage_qualifier(global.space) { + write!(self.out, "{storage_qualifier} ")?; + } + + match global.space { + crate::AddressSpace::Private => { + self.write_simple_global(handle, global)?; + } + crate::AddressSpace::WorkGroup => { + self.write_simple_global(handle, global)?; + } + crate::AddressSpace::PushConstant => { + self.write_simple_global(handle, global)?; + } + crate::AddressSpace::Uniform => { + self.write_interface_block(handle, global)?; + } + crate::AddressSpace::Storage { .. } => { + self.write_interface_block(handle, global)?; + } + // A global variable in the `Function` address space is a + // contradiction in terms. + crate::AddressSpace::Function => unreachable!(), + // Textures and samplers are handled directly in `Writer::write`. + crate::AddressSpace::Handle => unreachable!(), + } + + Ok(()) + } + + fn write_simple_global( + &mut self, + handle: Handle<crate::GlobalVariable>, + global: &crate::GlobalVariable, + ) -> BackendResult { + self.write_type(global.ty)?; + write!(self.out, " ")?; + self.write_global_name(handle, global)?; + + if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner { + self.write_array_size(base, size)?; + } + + if global.space.initializable() && is_value_init_supported(self.module, global.ty) { + write!(self.out, " = ")?; + if let Some(init) = global.init { + self.write_const_expr(init)?; + } else { + self.write_zero_init_value(global.ty)?; + } + } + + writeln!(self.out, ";")?; + + if let crate::AddressSpace::PushConstant = global.space { + let global_name = self.get_global_name(handle, global); + self.reflection_names_globals.insert(handle, global_name); + } + + Ok(()) + } + + /// Write an interface block for a single Naga global. + /// + /// Write `block_name { members }`. Since `block_name` must be unique + /// between blocks and structs, we add `_block_ID` where `ID` is a + /// `IdGenerator` generated number. Write `members` in the same way we write + /// a struct's members. + fn write_interface_block( + &mut self, + handle: Handle<crate::GlobalVariable>, + global: &crate::GlobalVariable, + ) -> BackendResult { + // Write the block name, it's just the struct name appended with `_block_ID` + let ty_name = &self.names[&NameKey::Type(global.ty)]; + let block_name = format!( + "{}_block_{}{:?}", + // avoid double underscores as they are reserved in GLSL + ty_name.trim_end_matches('_'), + self.block_id.generate(), + self.entry_point.stage, + ); + write!(self.out, "{block_name} ")?; + self.reflection_names_globals.insert(handle, block_name); + + match self.module.types[global.ty].inner { + crate::TypeInner::Struct { ref members, .. } + if self.module.types[members.last().unwrap().ty] + .inner + .is_dynamically_sized(&self.module.types) => + { + // Structs with dynamically sized arrays must have their + // members lifted up as members of the interface block. GLSL + // can't write such struct types anyway. + self.write_struct_body(global.ty, members)?; + write!(self.out, " ")?; + self.write_global_name(handle, global)?; + } + _ => { + // A global of any other type is written as the sole member + // of the interface block. Since the interface block is + // anonymous, this becomes visible in the global scope. + write!(self.out, "{{ ")?; + self.write_type(global.ty)?; + write!(self.out, " ")?; + self.write_global_name(handle, global)?; + if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner { + self.write_array_size(base, size)?; + } + write!(self.out, "; }}")?; + } + } + + writeln!(self.out, ";")?; + + Ok(()) + } + + /// Helper method used to find which expressions of a given function require baking + /// + /// # Notes + /// Clears `need_bake_expressions` set before adding to it + fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) { + use crate::Expression; + self.need_bake_expressions.clear(); + for (fun_handle, expr) in func.expressions.iter() { + let expr_info = &info[fun_handle]; + let min_ref_count = func.expressions[fun_handle].bake_ref_count(); + if min_ref_count <= expr_info.ref_count { + self.need_bake_expressions.insert(fun_handle); + } + + let inner = expr_info.ty.inner_with(&self.module.types); + + if let Expression::Math { fun, arg, arg1, .. } = *expr { + match fun { + crate::MathFunction::Dot => { + // if the expression is a Dot product with integer arguments, + // then the args needs baking as well + if let TypeInner::Scalar(crate::Scalar { kind, .. }) = *inner { + match kind { + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + self.need_bake_expressions.insert(arg); + self.need_bake_expressions.insert(arg1.unwrap()); + } + _ => {} + } + } + } + crate::MathFunction::CountLeadingZeros => { + if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { + self.need_bake_expressions.insert(arg); + } + } + _ => {} + } + } + } + } + + /// Helper method used to get a name for a global + /// + /// Globals have different naming schemes depending on their binding: + /// - Globals without bindings use the name from the [`Namer`](crate::proc::Namer) + /// - Globals with resource binding are named `_group_X_binding_Y` where `X` + /// is the group and `Y` is the binding + fn get_global_name( + &self, + handle: Handle<crate::GlobalVariable>, + global: &crate::GlobalVariable, + ) -> String { + match (&global.binding, global.space) { + (&Some(ref br), _) => { + format!( + "_group_{}_binding_{}_{}", + br.group, + br.binding, + self.entry_point.stage.to_str() + ) + } + (&None, crate::AddressSpace::PushConstant) => { + format!("_push_constant_binding_{}", self.entry_point.stage.to_str()) + } + (&None, _) => self.names[&NameKey::GlobalVariable(handle)].clone(), + } + } + + /// Helper method used to write a name for a global without additional heap allocation + fn write_global_name( + &mut self, + handle: Handle<crate::GlobalVariable>, + global: &crate::GlobalVariable, + ) -> BackendResult { + match (&global.binding, global.space) { + (&Some(ref br), _) => write!( + self.out, + "_group_{}_binding_{}_{}", + br.group, + br.binding, + self.entry_point.stage.to_str() + )?, + (&None, crate::AddressSpace::PushConstant) => write!( + self.out, + "_push_constant_binding_{}", + self.entry_point.stage.to_str() + )?, + (&None, _) => write!( + self.out, + "{}", + &self.names[&NameKey::GlobalVariable(handle)] + )?, + } + + Ok(()) + } + + /// Write a GLSL global that will carry a Naga entry point's argument or return value. + /// + /// A Naga entry point's arguments and return value are rendered in GLSL as + /// variables at global scope with the `in` and `out` storage qualifiers. + /// The code we generate for `main` loads from all the `in` globals into + /// appropriately named locals. Before it returns, `main` assigns the + /// components of its return value into all the `out` globals. + /// + /// This function writes a declaration for one such GLSL global, + /// representing a value passed into or returned from [`self.entry_point`] + /// that has a [`Location`] binding. The global's name is generated based on + /// the location index and the shader stages being connected; see + /// [`VaryingName`]. This means we don't need to know the names of + /// arguments, just their types and bindings. + /// + /// Emit nothing for entry point arguments or return values with [`BuiltIn`] + /// bindings; `main` will read from or assign to the appropriate GLSL + /// special variable; these are pre-declared. As an exception, we do declare + /// `gl_Position` or `gl_FragCoord` with the `invariant` qualifier if + /// needed. + /// + /// Use `output` together with [`self.entry_point.stage`] to determine which + /// shader stages are being connected, and choose the `in` or `out` storage + /// qualifier. + /// + /// [`self.entry_point`]: Writer::entry_point + /// [`self.entry_point.stage`]: crate::EntryPoint::stage + /// [`Location`]: crate::Binding::Location + /// [`BuiltIn`]: crate::Binding::BuiltIn + fn write_varying( + &mut self, + binding: Option<&crate::Binding>, + ty: Handle<crate::Type>, + output: bool, + ) -> Result<(), Error> { + // For a struct, emit a separate global for each member with a binding. + if let crate::TypeInner::Struct { ref members, .. } = self.module.types[ty].inner { + for member in members { + self.write_varying(member.binding.as_ref(), member.ty, output)?; + } + return Ok(()); + } + + let binding = match binding { + None => return Ok(()), + Some(binding) => binding, + }; + + let (location, interpolation, sampling, second_blend_source) = match *binding { + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source, + } => (location, interpolation, sampling, second_blend_source), + crate::Binding::BuiltIn(built_in) => { + if let crate::BuiltIn::Position { invariant: true } = built_in { + match (self.options.version, self.entry_point.stage) { + ( + Version::Embedded { + version: 300, + is_webgl: true, + }, + ShaderStage::Fragment, + ) => { + // `invariant gl_FragCoord` is not allowed in WebGL2 and possibly + // OpenGL ES in general (waiting on confirmation). + // + // See https://github.com/KhronosGroup/WebGL/issues/3518 + } + _ => { + writeln!( + self.out, + "invariant {};", + glsl_built_in( + built_in, + VaryingOptions::from_writer_options(self.options, output) + ) + )?; + } + } + } + return Ok(()); + } + }; + + // Write the interpolation modifier if needed + // + // We ignore all interpolation and auxiliary modifiers that aren't used in fragment + // shaders' input globals or vertex shaders' output globals. + let emit_interpolation_and_auxiliary = match self.entry_point.stage { + ShaderStage::Vertex => output, + ShaderStage::Fragment => !output, + ShaderStage::Compute => false, + }; + + // Write the I/O locations, if allowed + let io_location = if self.options.version.supports_explicit_locations() + || !emit_interpolation_and_auxiliary + { + if self.options.version.supports_io_locations() { + if second_blend_source { + write!(self.out, "layout(location = {location}, index = 1) ")?; + } else { + write!(self.out, "layout(location = {location}) ")?; + } + None + } else { + Some(VaryingLocation { + location, + index: second_blend_source as u32, + }) + } + } else { + None + }; + + // Write the interpolation qualifier. + if let Some(interp) = interpolation { + if emit_interpolation_and_auxiliary { + write!(self.out, "{} ", glsl_interpolation(interp))?; + } + } + + // Write the sampling auxiliary qualifier. + // + // Before GLSL 4.2, the `centroid` and `sample` qualifiers were required to appear + // immediately before the `in` / `out` qualifier, so we'll just follow that rule + // here, regardless of the version. + if let Some(sampling) = sampling { + if emit_interpolation_and_auxiliary { + if let Some(qualifier) = glsl_sampling(sampling) { + write!(self.out, "{qualifier} ")?; + } + } + } + + // Write the input/output qualifier. + write!(self.out, "{} ", if output { "out" } else { "in" })?; + + // Write the type + // `write_type` adds no leading or trailing spaces + self.write_type(ty)?; + + // Finally write the global name and end the global with a `;` and a newline + // Leading space is important + let vname = VaryingName { + binding: &crate::Binding::Location { + location, + interpolation: None, + sampling: None, + second_blend_source, + }, + stage: self.entry_point.stage, + options: VaryingOptions::from_writer_options(self.options, output), + }; + writeln!(self.out, " {vname};")?; + + if let Some(location) = io_location { + self.varying.insert(vname.to_string(), location); + } + + Ok(()) + } + + /// Helper method used to write functions (both entry points and regular functions) + /// + /// # Notes + /// Adds a newline + fn write_function( + &mut self, + ty: back::FunctionType, + func: &crate::Function, + info: &valid::FunctionInfo, + ) -> BackendResult { + // Create a function context for the function being written + let ctx = back::FunctionCtx { + ty, + info, + expressions: &func.expressions, + named_expressions: &func.named_expressions, + }; + + self.named_expressions.clear(); + self.update_expressions_to_bake(func, info); + + // Write the function header + // + // glsl headers are the same as in c: + // `ret_type name(args)` + // `ret_type` is the return type + // `name` is the function name + // `args` is a comma separated list of `type name` + // | - `type` is the argument type + // | - `name` is the argument name + + // Start by writing the return type if any otherwise write void + // This is the only place where `void` is a valid type + // (though it's more a keyword than a type) + if let back::FunctionType::EntryPoint(_) = ctx.ty { + write!(self.out, "void")?; + } else if let Some(ref result) = func.result { + self.write_type(result.ty)?; + if let TypeInner::Array { base, size, .. } = self.module.types[result.ty].inner { + self.write_array_size(base, size)? + } + } else { + write!(self.out, "void")?; + } + + // Write the function name and open parentheses for the argument list + let function_name = match ctx.ty { + back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)], + back::FunctionType::EntryPoint(_) => "main", + }; + write!(self.out, " {function_name}(")?; + + // Write the comma separated argument list + // + // We need access to `Self` here so we use the reference passed to the closure as an + // argument instead of capturing as that would cause a borrow checker error + let arguments = match ctx.ty { + back::FunctionType::EntryPoint(_) => &[][..], + back::FunctionType::Function(_) => &func.arguments, + }; + let arguments: Vec<_> = arguments + .iter() + .enumerate() + .filter(|&(_, arg)| match self.module.types[arg.ty].inner { + TypeInner::Sampler { .. } => false, + _ => true, + }) + .collect(); + self.write_slice(&arguments, |this, _, &(i, arg)| { + // Write the argument type + match this.module.types[arg.ty].inner { + // We treat images separately because they might require + // writing the storage format + TypeInner::Image { + dim, + arrayed, + class, + } => { + // Write the storage format if needed + if let TypeInner::Image { + class: crate::ImageClass::Storage { format, .. }, + .. + } = this.module.types[arg.ty].inner + { + write!(this.out, "layout({}) ", glsl_storage_format(format)?)?; + } + + // write the type + // + // This is way we need the leading space because `write_image_type` doesn't add + // any spaces at the beginning or end + this.write_image_type(dim, arrayed, class)?; + } + TypeInner::Pointer { base, .. } => { + // write parameter qualifiers + write!(this.out, "inout ")?; + this.write_type(base)?; + } + // All other types are written by `write_type` + _ => { + this.write_type(arg.ty)?; + } + } + + // Write the argument name + // The leading space is important + write!(this.out, " {}", &this.names[&ctx.argument_key(i as u32)])?; + + // Write array size + match this.module.types[arg.ty].inner { + TypeInner::Array { base, size, .. } => { + this.write_array_size(base, size)?; + } + TypeInner::Pointer { base, .. } => { + if let TypeInner::Array { base, size, .. } = this.module.types[base].inner { + this.write_array_size(base, size)?; + } + } + _ => {} + } + + Ok(()) + })?; + + // Close the parentheses and open braces to start the function body + writeln!(self.out, ") {{")?; + + if self.options.zero_initialize_workgroup_memory + && ctx.ty.is_compute_entry_point(self.module) + { + self.write_workgroup_variables_initialization(&ctx)?; + } + + // Compose the function arguments from globals, in case of an entry point. + if let back::FunctionType::EntryPoint(ep_index) = ctx.ty { + let stage = self.module.entry_points[ep_index as usize].stage; + for (index, arg) in func.arguments.iter().enumerate() { + write!(self.out, "{}", back::INDENT)?; + self.write_type(arg.ty)?; + let name = &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]; + write!(self.out, " {name}")?; + write!(self.out, " = ")?; + match self.module.types[arg.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + self.write_type(arg.ty)?; + write!(self.out, "(")?; + for (index, member) in members.iter().enumerate() { + let varying_name = VaryingName { + binding: member.binding.as_ref().unwrap(), + stage, + options: VaryingOptions::from_writer_options(self.options, false), + }; + if index != 0 { + write!(self.out, ", ")?; + } + write!(self.out, "{varying_name}")?; + } + writeln!(self.out, ");")?; + } + _ => { + let varying_name = VaryingName { + binding: arg.binding.as_ref().unwrap(), + stage, + options: VaryingOptions::from_writer_options(self.options, false), + }; + writeln!(self.out, "{varying_name};")?; + } + } + } + } + + // Write all function locals + // Locals are `type name (= init)?;` where the init part (including the =) are optional + // + // Always adds a newline + for (handle, local) in func.local_variables.iter() { + // Write indentation (only for readability) and the type + // `write_type` adds no trailing space + write!(self.out, "{}", back::INDENT)?; + self.write_type(local.ty)?; + + // Write the local name + // The leading space is important + write!(self.out, " {}", self.names[&ctx.name_key(handle)])?; + // Write size for array type + if let TypeInner::Array { base, size, .. } = self.module.types[local.ty].inner { + self.write_array_size(base, size)?; + } + // Write the local initializer if needed + if let Some(init) = local.init { + // Put the equal signal only if there's a initializer + // The leading and trailing spaces aren't needed but help with readability + write!(self.out, " = ")?; + + // Write the constant + // `write_constant` adds no trailing or leading space/newline + self.write_expr(init, &ctx)?; + } else if is_value_init_supported(self.module, local.ty) { + write!(self.out, " = ")?; + self.write_zero_init_value(local.ty)?; + } + + // Finish the local with `;` and add a newline (only for readability) + writeln!(self.out, ";")? + } + + // Write the function body (statement list) + for sta in func.body.iter() { + // Write a statement, the indentation should always be 1 when writing the function body + // `write_stmt` adds a newline + self.write_stmt(sta, &ctx, back::Level(1))?; + } + + // Close braces and add a newline + writeln!(self.out, "}}")?; + + Ok(()) + } + + fn write_workgroup_variables_initialization( + &mut self, + ctx: &back::FunctionCtx, + ) -> BackendResult { + let mut vars = self + .module + .global_variables + .iter() + .filter(|&(handle, var)| { + !ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + .peekable(); + + if vars.peek().is_some() { + let level = back::Level(1); + + writeln!(self.out, "{level}if (gl_LocalInvocationID == uvec3(0u)) {{")?; + + for (handle, var) in vars { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{}{} = ", level.next(), name)?; + self.write_zero_init_value(var.ty)?; + writeln!(self.out, ";")?; + } + + writeln!(self.out, "{level}}}")?; + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + } + + Ok(()) + } + + /// Write a list of comma separated `T` values using a writer function `F`. + /// + /// The writer function `F` receives a mutable reference to `self` that if needed won't cause + /// borrow checker issues (using for example a closure with `self` will cause issues), the + /// second argument is the 0 based index of the element on the list, and the last element is + /// a reference to the element `T` being written + /// + /// # Notes + /// - Adds no newlines or leading/trailing whitespace + /// - The last element won't have a trailing `,` + fn write_slice<T, F: FnMut(&mut Self, u32, &T) -> BackendResult>( + &mut self, + data: &[T], + mut f: F, + ) -> BackendResult { + // Loop through `data` invoking `f` for each element + for (index, item) in data.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + f(self, index as u32, item)?; + } + + Ok(()) + } + + /// Helper method used to write global constants + fn write_global_constant(&mut self, handle: Handle<crate::Constant>) -> BackendResult { + write!(self.out, "const ")?; + let constant = &self.module.constants[handle]; + self.write_type(constant.ty)?; + let name = &self.names[&NameKey::Constant(handle)]; + write!(self.out, " {name}")?; + if let TypeInner::Array { base, size, .. } = self.module.types[constant.ty].inner { + self.write_array_size(base, size)?; + } + write!(self.out, " = ")?; + self.write_const_expr(constant.init)?; + writeln!(self.out, ";")?; + Ok(()) + } + + /// Helper method used to output a dot product as an arithmetic expression + /// + fn write_dot_product( + &mut self, + arg: Handle<crate::Expression>, + arg1: Handle<crate::Expression>, + size: usize, + ctx: &back::FunctionCtx, + ) -> BackendResult { + // Write parentheses around the dot product expression to prevent operators + // with different precedences from applying earlier. + write!(self.out, "(")?; + + // Cycle trough all the components of the vector + for index in 0..size { + let component = back::COMPONENTS[index]; + // Write the addition to the previous product + // This will print an extra '+' at the beginning but that is fine in glsl + write!(self.out, " + ")?; + // Write the first vector expression, this expression is marked to be + // cached so unless it can't be cached (for example, it's a Constant) + // it shouldn't produce large expressions. + self.write_expr(arg, ctx)?; + // Access the current component on the first vector + write!(self.out, ".{component} * ")?; + // Write the second vector expression, this expression is marked to be + // cached so unless it can't be cached (for example, it's a Constant) + // it shouldn't produce large expressions. + self.write_expr(arg1, ctx)?; + // Access the current component on the second vector + write!(self.out, ".{component}")?; + } + + write!(self.out, ")")?; + Ok(()) + } + + /// Helper method used to write structs + /// + /// # Notes + /// Ends in a newline + fn write_struct_body( + &mut self, + handle: Handle<crate::Type>, + members: &[crate::StructMember], + ) -> BackendResult { + // glsl structs are written as in C + // `struct name() { members };` + // | `struct` is a keyword + // | `name` is the struct name + // | `members` is a semicolon separated list of `type name` + // | `type` is the member type + // | `name` is the member name + writeln!(self.out, "{{")?; + + for (idx, member) in members.iter().enumerate() { + // The indentation is only for readability + write!(self.out, "{}", back::INDENT)?; + + match self.module.types[member.ty].inner { + TypeInner::Array { + base, + size, + stride: _, + } => { + self.write_type(base)?; + write!( + self.out, + " {}", + &self.names[&NameKey::StructMember(handle, idx as u32)] + )?; + // Write [size] + self.write_array_size(base, size)?; + // Newline is important + writeln!(self.out, ";")?; + } + _ => { + // Write the member type + // Adds no trailing space + self.write_type(member.ty)?; + + // Write the member name and put a semicolon + // The leading space is important + // All members must have a semicolon even the last one + writeln!( + self.out, + " {};", + &self.names[&NameKey::StructMember(handle, idx as u32)] + )?; + } + } + } + + write!(self.out, "}}")?; + Ok(()) + } + + /// Helper method used to write statements + /// + /// # Notes + /// Always adds a newline + fn write_stmt( + &mut self, + sta: &crate::Statement, + ctx: &back::FunctionCtx, + level: back::Level, + ) -> BackendResult { + use crate::Statement; + + match *sta { + // This is where we can generate intermediate constants for some expression types. + Statement::Emit(ref range) => { + for handle in range.clone() { + let ptr_class = ctx.resolve_type(handle, &self.module.types).pointer_space(); + let expr_name = if ptr_class.is_some() { + // GLSL can't save a pointer-valued expression in a variable, + // but we shouldn't ever need to: they should never be named expressions, + // and none of the expression types flagged by bake_ref_count can be pointer-valued. + None + } else if let Some(name) = ctx.named_expressions.get(&handle) { + // Front end provides names for all variables at the start of writing. + // But we write them to step by step. We need to recache them + // Otherwise, we could accidentally write variable name instead of full expression. + // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords. + Some(self.namer.call(name)) + } else if self.need_bake_expressions.contains(&handle) { + Some(format!("{}{}", back::BAKE_PREFIX, handle.index())) + } else { + None + }; + + // If we are going to write an `ImageLoad` next and the target image + // is sampled and we are using the `Restrict` policy for bounds + // checking images we need to write a local holding the clamped lod. + if let crate::Expression::ImageLoad { + image, + level: Some(level_expr), + .. + } = ctx.expressions[handle] + { + if let TypeInner::Image { + class: crate::ImageClass::Sampled { .. }, + .. + } = *ctx.resolve_type(image, &self.module.types) + { + if let proc::BoundsCheckPolicy::Restrict = self.policies.image_load { + write!(self.out, "{level}")?; + self.write_clamped_lod(ctx, handle, image, level_expr)? + } + } + } + + if let Some(name) = expr_name { + write!(self.out, "{level}")?; + self.write_named_expr(handle, name, handle, ctx)?; + } + } + } + // Blocks are simple we just need to write the block statements between braces + // We could also just print the statements but this is more readable and maps more + // closely to the IR + Statement::Block(ref block) => { + write!(self.out, "{level}")?; + writeln!(self.out, "{{")?; + for sta in block.iter() { + // Increase the indentation to help with readability + self.write_stmt(sta, ctx, level.next())? + } + writeln!(self.out, "{level}}}")? + } + // Ifs are written as in C: + // ``` + // if(condition) { + // accept + // } else { + // reject + // } + // ``` + Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{level}")?; + write!(self.out, "if (")?; + self.write_expr(condition, ctx)?; + writeln!(self.out, ") {{")?; + + for sta in accept { + // Increase indentation to help with readability + self.write_stmt(sta, ctx, level.next())?; + } + + // If there are no statements in the reject block we skip writing it + // This is only for readability + if !reject.is_empty() { + writeln!(self.out, "{level}}} else {{")?; + + for sta in reject { + // Increase indentation to help with readability + self.write_stmt(sta, ctx, level.next())?; + } + } + + writeln!(self.out, "{level}}}")? + } + // Switch are written as in C: + // ``` + // switch (selector) { + // // Fallthrough + // case label: + // block + // // Non fallthrough + // case label: + // block + // break; + // default: + // block + // } + // ``` + // Where the `default` case happens isn't important but we put it last + // so that we don't need to print a `break` for it + Statement::Switch { + selector, + ref cases, + } => { + // Start the switch + write!(self.out, "{level}")?; + write!(self.out, "switch(")?; + self.write_expr(selector, ctx)?; + writeln!(self.out, ") {{")?; + + // Write all cases + let l2 = level.next(); + for case in cases { + match case.value { + crate::SwitchValue::I32(value) => write!(self.out, "{l2}case {value}:")?, + crate::SwitchValue::U32(value) => write!(self.out, "{l2}case {value}u:")?, + crate::SwitchValue::Default => write!(self.out, "{l2}default:")?, + } + + let write_block_braces = !(case.fall_through && case.body.is_empty()); + if write_block_braces { + writeln!(self.out, " {{")?; + } else { + writeln!(self.out)?; + } + + for sta in case.body.iter() { + self.write_stmt(sta, ctx, l2.next())?; + } + + if !case.fall_through && case.body.last().map_or(true, |s| !s.is_terminator()) { + writeln!(self.out, "{}break;", l2.next())?; + } + + if write_block_braces { + writeln!(self.out, "{l2}}}")?; + } + } + + writeln!(self.out, "{level}}}")? + } + // Loops in naga IR are based on wgsl loops, glsl can emulate the behaviour by using a + // while true loop and appending the continuing block to the body resulting on: + // ``` + // bool loop_init = true; + // while(true) { + // if (!loop_init) { <continuing> } + // loop_init = false; + // <body> + // } + // ``` + Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + if !continuing.is_empty() || break_if.is_some() { + let gate_name = self.namer.call("loop_init"); + writeln!(self.out, "{level}bool {gate_name} = true;")?; + writeln!(self.out, "{level}while(true) {{")?; + let l2 = level.next(); + let l3 = l2.next(); + writeln!(self.out, "{l2}if (!{gate_name}) {{")?; + for sta in continuing { + self.write_stmt(sta, ctx, l3)?; + } + if let Some(condition) = break_if { + write!(self.out, "{l3}if (")?; + self.write_expr(condition, ctx)?; + writeln!(self.out, ") {{")?; + writeln!(self.out, "{}break;", l3.next())?; + writeln!(self.out, "{l3}}}")?; + } + writeln!(self.out, "{l2}}}")?; + writeln!(self.out, "{}{} = false;", level.next(), gate_name)?; + } else { + writeln!(self.out, "{level}while(true) {{")?; + } + for sta in body { + self.write_stmt(sta, ctx, level.next())?; + } + writeln!(self.out, "{level}}}")? + } + // Break, continue and return as written as in C + // `break;` + Statement::Break => { + write!(self.out, "{level}")?; + writeln!(self.out, "break;")? + } + // `continue;` + Statement::Continue => { + write!(self.out, "{level}")?; + writeln!(self.out, "continue;")? + } + // `return expr;`, `expr` is optional + Statement::Return { value } => { + write!(self.out, "{level}")?; + match ctx.ty { + back::FunctionType::Function(_) => { + write!(self.out, "return")?; + // Write the expression to be returned if needed + if let Some(expr) = value { + write!(self.out, " ")?; + self.write_expr(expr, ctx)?; + } + writeln!(self.out, ";")?; + } + back::FunctionType::EntryPoint(ep_index) => { + let mut has_point_size = false; + let ep = &self.module.entry_points[ep_index as usize]; + if let Some(ref result) = ep.function.result { + let value = value.unwrap(); + match self.module.types[result.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + let temp_struct_name = match ctx.expressions[value] { + crate::Expression::Compose { .. } => { + let return_struct = "_tmp_return"; + write!( + self.out, + "{} {} = ", + &self.names[&NameKey::Type(result.ty)], + return_struct + )?; + self.write_expr(value, ctx)?; + writeln!(self.out, ";")?; + write!(self.out, "{level}")?; + Some(return_struct) + } + _ => None, + }; + + for (index, member) in members.iter().enumerate() { + if let Some(crate::Binding::BuiltIn( + crate::BuiltIn::PointSize, + )) = member.binding + { + has_point_size = true; + } + + let varying_name = VaryingName { + binding: member.binding.as_ref().unwrap(), + stage: ep.stage, + options: VaryingOptions::from_writer_options( + self.options, + true, + ), + }; + write!(self.out, "{varying_name} = ")?; + + if let Some(struct_name) = temp_struct_name { + write!(self.out, "{struct_name}")?; + } else { + self.write_expr(value, ctx)?; + } + + // Write field name + writeln!( + self.out, + ".{};", + &self.names + [&NameKey::StructMember(result.ty, index as u32)] + )?; + write!(self.out, "{level}")?; + } + } + _ => { + let name = VaryingName { + binding: result.binding.as_ref().unwrap(), + stage: ep.stage, + options: VaryingOptions::from_writer_options( + self.options, + true, + ), + }; + write!(self.out, "{name} = ")?; + self.write_expr(value, ctx)?; + writeln!(self.out, ";")?; + write!(self.out, "{level}")?; + } + } + } + + let is_vertex_stage = self.module.entry_points[ep_index as usize].stage + == ShaderStage::Vertex; + if is_vertex_stage + && self + .options + .writer_flags + .contains(WriterFlags::ADJUST_COORDINATE_SPACE) + { + writeln!( + self.out, + "gl_Position.yz = vec2(-gl_Position.y, gl_Position.z * 2.0 - gl_Position.w);", + )?; + write!(self.out, "{level}")?; + } + + if is_vertex_stage + && self + .options + .writer_flags + .contains(WriterFlags::FORCE_POINT_SIZE) + && !has_point_size + { + writeln!(self.out, "gl_PointSize = 1.0;")?; + write!(self.out, "{level}")?; + } + writeln!(self.out, "return;")?; + } + } + } + // This is one of the places were glsl adds to the syntax of C in this case the discard + // keyword which ceases all further processing in a fragment shader, it's called OpKill + // in spir-v that's why it's called `Statement::Kill` + Statement::Kill => writeln!(self.out, "{level}discard;")?, + Statement::Barrier(flags) => { + self.write_barrier(flags, level)?; + } + // Stores in glsl are just variable assignments written as `pointer = value;` + Statement::Store { pointer, value } => { + write!(self.out, "{level}")?; + self.write_expr(pointer, ctx)?; + write!(self.out, " = ")?; + self.write_expr(value, ctx)?; + writeln!(self.out, ";")? + } + Statement::WorkGroupUniformLoad { pointer, result } => { + // GLSL doesn't have pointers, which means that this backend needs to ensure that + // the actual "loading" is happening between the two barriers. + // This is done in `Emit` by never emitting a variable name for pointer variables + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + + let result_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + write!(self.out, "{level}")?; + // Expressions cannot have side effects, so just writing the expression here is fine. + self.write_named_expr(pointer, result_name, result, ctx)?; + + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + } + // Stores a value into an image. + Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + write!(self.out, "{level}")?; + self.write_image_store(ctx, image, coordinate, array_index, value)? + } + // A `Call` is written `name(arguments)` where `arguments` is a comma separated expressions list + Statement::Call { + function, + ref arguments, + result, + } => { + write!(self.out, "{level}")?; + if let Some(expr) = result { + let name = format!("{}{}", back::BAKE_PREFIX, expr.index()); + let result = self.module.functions[function].result.as_ref().unwrap(); + self.write_type(result.ty)?; + write!(self.out, " {name}")?; + if let TypeInner::Array { base, size, .. } = self.module.types[result.ty].inner + { + self.write_array_size(base, size)? + } + write!(self.out, " = ")?; + self.named_expressions.insert(expr, name); + } + write!(self.out, "{}(", &self.names[&NameKey::Function(function)])?; + let arguments: Vec<_> = arguments + .iter() + .enumerate() + .filter_map(|(i, arg)| { + let arg_ty = self.module.functions[function].arguments[i].ty; + match self.module.types[arg_ty].inner { + TypeInner::Sampler { .. } => None, + _ => Some(*arg), + } + }) + .collect(); + self.write_slice(&arguments, |this, _, arg| this.write_expr(*arg, ctx))?; + writeln!(self.out, ");")? + } + Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.resolve_type(result, &self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + let fun_str = fun.to_glsl(); + write!(self.out, "atomic{fun_str}(")?; + self.write_expr(pointer, ctx)?; + write!(self.out, ", ")?; + // handle the special cases + match *fun { + crate::AtomicFunction::Subtract => { + // we just wrote `InterlockedAdd`, so negate the argument + write!(self.out, "-")?; + } + crate::AtomicFunction::Exchange { compare: Some(_) } => { + return Err(Error::Custom( + "atomic CompareExchange is not implemented".to_string(), + )); + } + _ => {} + } + self.write_expr(value, ctx)?; + writeln!(self.out, ");")?; + } + Statement::RayQuery { .. } => unreachable!(), + } + + Ok(()) + } + + /// Write a const expression. + /// + /// Write `expr`, a handle to an [`Expression`] in the current [`Module`]'s + /// constant expression arena, as GLSL expression. + /// + /// # Notes + /// Adds no newlines or leading/trailing whitespace + /// + /// [`Expression`]: crate::Expression + /// [`Module`]: crate::Module + fn write_const_expr(&mut self, expr: Handle<crate::Expression>) -> BackendResult { + self.write_possibly_const_expr( + expr, + &self.module.const_expressions, + |expr| &self.info[expr], + |writer, expr| writer.write_const_expr(expr), + ) + } + + /// Write [`Expression`] variants that can occur in both runtime and const expressions. + /// + /// Write `expr`, a handle to an [`Expression`] in the arena `expressions`, + /// as as GLSL expression. This must be one of the [`Expression`] variants + /// that is allowed to occur in constant expressions. + /// + /// Use `write_expression` to write subexpressions. + /// + /// This is the common code for `write_expr`, which handles arbitrary + /// runtime expressions, and `write_const_expr`, which only handles + /// const-expressions. Each of those callers passes itself (essentially) as + /// the `write_expression` callback, so that subexpressions are restricted + /// to the appropriate variants. + /// + /// # Notes + /// Adds no newlines or leading/trailing whitespace + /// + /// [`Expression`]: crate::Expression + fn write_possibly_const_expr<'w, I, E>( + &'w mut self, + expr: Handle<crate::Expression>, + expressions: &crate::Arena<crate::Expression>, + info: I, + write_expression: E, + ) -> BackendResult + where + I: Fn(Handle<crate::Expression>) -> &'w proc::TypeResolution, + E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult, + { + use crate::Expression; + + match expressions[expr] { + Expression::Literal(literal) => { + match literal { + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero which is needed for a valid glsl float constant + crate::Literal::F64(value) => write!(self.out, "{:?}LF", value)?, + crate::Literal::F32(value) => write!(self.out, "{:?}", value)?, + // Unsigned integers need a `u` at the end + // + // While `core` doesn't necessarily need it, it's allowed and since `es` needs it we + // always write it as the extra branch wouldn't have any benefit in readability + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + crate::Literal::I64(_) => { + return Err(Error::Custom("GLSL has no 64-bit integer type".into())); + } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } + } + } + Expression::Constant(handle) => { + let constant = &self.module.constants[handle]; + if constant.name.is_some() { + write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; + } else { + self.write_const_expr(constant.init)?; + } + } + Expression::ZeroValue(ty) => { + self.write_zero_init_value(ty)?; + } + Expression::Compose { ty, ref components } => { + self.write_type(ty)?; + + if let TypeInner::Array { base, size, .. } = self.module.types[ty].inner { + self.write_array_size(base, size)?; + } + + write!(self.out, "(")?; + for (index, component) in components.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + write_expression(self, *component)?; + } + write!(self.out, ")")? + } + // `Splat` needs to actually write down a vector, it's not always inferred in GLSL. + Expression::Splat { size: _, value } => { + let resolved = info(expr).inner_with(&self.module.types); + self.write_value_type(resolved)?; + write!(self.out, "(")?; + write_expression(self, value)?; + write!(self.out, ")")? + } + _ => unreachable!(), + } + + Ok(()) + } + + /// Helper method to write expressions + /// + /// # Notes + /// Doesn't add any newlines or leading/trailing spaces + fn write_expr( + &mut self, + expr: Handle<crate::Expression>, + ctx: &back::FunctionCtx, + ) -> BackendResult { + use crate::Expression; + + if let Some(name) = self.named_expressions.get(&expr) { + write!(self.out, "{name}")?; + return Ok(()); + } + + match ctx.expressions[expr] { + Expression::Literal(_) + | Expression::Constant(_) + | Expression::ZeroValue(_) + | Expression::Compose { .. } + | Expression::Splat { .. } => { + self.write_possibly_const_expr( + expr, + ctx.expressions, + |expr| &ctx.info[expr].ty, + |writer, expr| writer.write_expr(expr, ctx), + )?; + } + // `Access` is applied to arrays, vectors and matrices and is written as indexing + Expression::Access { base, index } => { + self.write_expr(base, ctx)?; + write!(self.out, "[")?; + self.write_expr(index, ctx)?; + write!(self.out, "]")? + } + // `AccessIndex` is the same as `Access` except that the index is a constant and it can + // be applied to structs, in this case we need to find the name of the field at that + // index and write `base.field_name` + Expression::AccessIndex { base, index } => { + self.write_expr(base, ctx)?; + + let base_ty_res = &ctx.info[base].ty; + let mut resolved = base_ty_res.inner_with(&self.module.types); + let base_ty_handle = match *resolved { + TypeInner::Pointer { base, space: _ } => { + resolved = &self.module.types[base].inner; + Some(base) + } + _ => base_ty_res.handle(), + }; + + match *resolved { + TypeInner::Vector { .. } => { + // Write vector access as a swizzle + write!(self.out, ".{}", back::COMPONENTS[index as usize])? + } + TypeInner::Matrix { .. } + | TypeInner::Array { .. } + | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?, + TypeInner::Struct { .. } => { + // This will never panic in case the type is a `Struct`, this is not true + // for other types so we can only check while inside this match arm + let ty = base_ty_handle.unwrap(); + + write!( + self.out, + ".{}", + &self.names[&NameKey::StructMember(ty, index)] + )? + } + ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))), + } + } + // `Swizzle` adds a few letters behind the dot. + Expression::Swizzle { + size, + vector, + pattern, + } => { + self.write_expr(vector, ctx)?; + write!(self.out, ".")?; + for &sc in pattern[..size as usize].iter() { + self.out.write_char(back::COMPONENTS[sc as usize])?; + } + } + // Function arguments are written as the argument name + Expression::FunctionArgument(pos) => { + write!(self.out, "{}", &self.names[&ctx.argument_key(pos)])? + } + // Global variables need some special work for their name but + // `get_global_name` does the work for us + Expression::GlobalVariable(handle) => { + let global = &self.module.global_variables[handle]; + self.write_global_name(handle, global)? + } + // A local is written as it's name + Expression::LocalVariable(handle) => { + write!(self.out, "{}", self.names[&ctx.name_key(handle)])? + } + // glsl has no pointers so there's no load operation, just write the pointer expression + Expression::Load { pointer } => self.write_expr(pointer, ctx)?, + // `ImageSample` is a bit complicated compared to the rest of the IR. + // + // First there are three variations depending whether the sample level is explicitly set, + // if it's automatic or it it's bias: + // `texture(image, coordinate)` - Automatic sample level + // `texture(image, coordinate, bias)` - Bias sample level + // `textureLod(image, coordinate, level)` - Zero or Exact sample level + // + // Furthermore if `depth_ref` is some we need to append it to the coordinate vector + Expression::ImageSample { + image, + sampler: _, //TODO? + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + let dim = match *ctx.resolve_type(image, &self.module.types) { + TypeInner::Image { dim, .. } => dim, + _ => unreachable!(), + }; + + if dim == crate::ImageDimension::Cube + && array_index.is_some() + && depth_ref.is_some() + { + match level { + crate::SampleLevel::Zero + | crate::SampleLevel::Exact(_) + | crate::SampleLevel::Gradient { .. } + | crate::SampleLevel::Bias(_) => { + return Err(Error::Custom(String::from( + "gsamplerCubeArrayShadow isn't supported in textureGrad, \ + textureLod or texture with bias", + ))) + } + crate::SampleLevel::Auto => {} + } + } + + // textureLod on sampler2DArrayShadow and samplerCubeShadow does not exist in GLSL. + // To emulate this, we will have to use textureGrad with a constant gradient of 0. + let workaround_lod_array_shadow_as_grad = (array_index.is_some() + || dim == crate::ImageDimension::Cube) + && depth_ref.is_some() + && gather.is_none() + && !self + .options + .writer_flags + .contains(WriterFlags::TEXTURE_SHADOW_LOD); + + //Write the function to be used depending on the sample level + let fun_name = match level { + crate::SampleLevel::Zero if gather.is_some() => "textureGather", + crate::SampleLevel::Auto | crate::SampleLevel::Bias(_) => "texture", + crate::SampleLevel::Zero | crate::SampleLevel::Exact(_) => { + if workaround_lod_array_shadow_as_grad { + "textureGrad" + } else { + "textureLod" + } + } + crate::SampleLevel::Gradient { .. } => "textureGrad", + }; + let offset_name = match offset { + Some(_) => "Offset", + None => "", + }; + + write!(self.out, "{fun_name}{offset_name}(")?; + + // Write the image that will be used + self.write_expr(image, ctx)?; + // The space here isn't required but it helps with readability + write!(self.out, ", ")?; + + // We need to get the coordinates vector size to later build a vector that's `size + 1` + // if `depth_ref` is some, if it isn't a vector we panic as that's not a valid expression + let mut coord_dim = match *ctx.resolve_type(coordinate, &self.module.types) { + TypeInner::Vector { size, .. } => size as u8, + TypeInner::Scalar { .. } => 1, + _ => unreachable!(), + }; + + if array_index.is_some() { + coord_dim += 1; + } + let merge_depth_ref = depth_ref.is_some() && gather.is_none() && coord_dim < 4; + if merge_depth_ref { + coord_dim += 1; + } + + let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es(); + let is_vec = tex_1d_hack || coord_dim != 1; + // Compose a new texture coordinates vector + if is_vec { + write!(self.out, "vec{}(", coord_dim + tex_1d_hack as u8)?; + } + self.write_expr(coordinate, ctx)?; + if tex_1d_hack { + write!(self.out, ", 0.0")?; + } + if let Some(expr) = array_index { + write!(self.out, ", ")?; + self.write_expr(expr, ctx)?; + } + if merge_depth_ref { + write!(self.out, ", ")?; + self.write_expr(depth_ref.unwrap(), ctx)?; + } + if is_vec { + write!(self.out, ")")?; + } + + if let (Some(expr), false) = (depth_ref, merge_depth_ref) { + write!(self.out, ", ")?; + self.write_expr(expr, ctx)?; + } + + match level { + // Auto needs no more arguments + crate::SampleLevel::Auto => (), + // Zero needs level set to 0 + crate::SampleLevel::Zero => { + if workaround_lod_array_shadow_as_grad { + let vec_dim = match dim { + crate::ImageDimension::Cube => 3, + _ => 2, + }; + write!(self.out, ", vec{vec_dim}(0.0), vec{vec_dim}(0.0)")?; + } else if gather.is_none() { + write!(self.out, ", 0.0")?; + } + } + // Exact and bias require another argument + crate::SampleLevel::Exact(expr) => { + if workaround_lod_array_shadow_as_grad { + log::warn!("Unable to `textureLod` a shadow array, ignoring the LOD"); + write!(self.out, ", vec2(0,0), vec2(0,0)")?; + } else { + write!(self.out, ", ")?; + self.write_expr(expr, ctx)?; + } + } + crate::SampleLevel::Bias(_) => { + // This needs to be done after the offset writing + } + crate::SampleLevel::Gradient { x, y } => { + // If we are using sampler2D to replace sampler1D, we also + // need to make sure to use vec2 gradients + if tex_1d_hack { + write!(self.out, ", vec2(")?; + self.write_expr(x, ctx)?; + write!(self.out, ", 0.0)")?; + write!(self.out, ", vec2(")?; + self.write_expr(y, ctx)?; + write!(self.out, ", 0.0)")?; + } else { + write!(self.out, ", ")?; + self.write_expr(x, ctx)?; + write!(self.out, ", ")?; + self.write_expr(y, ctx)?; + } + } + } + + if let Some(constant) = offset { + write!(self.out, ", ")?; + if tex_1d_hack { + write!(self.out, "ivec2(")?; + } + self.write_const_expr(constant)?; + if tex_1d_hack { + write!(self.out, ", 0)")?; + } + } + + // Bias is always the last argument + if let crate::SampleLevel::Bias(expr) = level { + write!(self.out, ", ")?; + self.write_expr(expr, ctx)?; + } + + if let (Some(component), None) = (gather, depth_ref) { + write!(self.out, ", {}", component as usize)?; + } + + // End the function + write!(self.out, ")")? + } + Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => self.write_image_load(expr, ctx, image, coordinate, array_index, sample, level)?, + // Query translates into one of the: + // - textureSize/imageSize + // - textureQueryLevels + // - textureSamples/imageSamples + Expression::ImageQuery { image, query } => { + use crate::ImageClass; + + // This will only panic if the module is invalid + let (dim, class) = match *ctx.resolve_type(image, &self.module.types) { + TypeInner::Image { + dim, + arrayed: _, + class, + } => (dim, class), + _ => unreachable!(), + }; + let components = match dim { + crate::ImageDimension::D1 => 1, + crate::ImageDimension::D2 => 2, + crate::ImageDimension::D3 => 3, + crate::ImageDimension::Cube => 2, + }; + + if let crate::ImageQuery::Size { .. } = query { + match components { + 1 => write!(self.out, "uint(")?, + _ => write!(self.out, "uvec{components}(")?, + } + } else { + write!(self.out, "uint(")?; + } + + match query { + crate::ImageQuery::Size { level } => { + match class { + ImageClass::Sampled { multi, .. } | ImageClass::Depth { multi } => { + write!(self.out, "textureSize(")?; + self.write_expr(image, ctx)?; + if let Some(expr) = level { + let cast_to_int = matches!( + *ctx.resolve_type(expr, &self.module.types), + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + .. + }) + ); + + write!(self.out, ", ")?; + + if cast_to_int { + write!(self.out, "int(")?; + } + + self.write_expr(expr, ctx)?; + + if cast_to_int { + write!(self.out, ")")?; + } + } else if !multi { + // All textureSize calls requires an lod argument + // except for multisampled samplers + write!(self.out, ", 0")?; + } + } + ImageClass::Storage { .. } => { + write!(self.out, "imageSize(")?; + self.write_expr(image, ctx)?; + } + } + write!(self.out, ")")?; + if components != 1 || self.options.version.is_es() { + write!(self.out, ".{}", &"xyz"[..components])?; + } + } + crate::ImageQuery::NumLevels => { + write!(self.out, "textureQueryLevels(",)?; + self.write_expr(image, ctx)?; + write!(self.out, ")",)?; + } + crate::ImageQuery::NumLayers => { + let fun_name = match class { + ImageClass::Sampled { .. } | ImageClass::Depth { .. } => "textureSize", + ImageClass::Storage { .. } => "imageSize", + }; + write!(self.out, "{fun_name}(")?; + self.write_expr(image, ctx)?; + // All textureSize calls requires an lod argument + // except for multisampled samplers + if class.is_multisampled() { + write!(self.out, ", 0")?; + } + write!(self.out, ")")?; + if components != 1 || self.options.version.is_es() { + write!(self.out, ".{}", back::COMPONENTS[components])?; + } + } + crate::ImageQuery::NumSamples => { + let fun_name = match class { + ImageClass::Sampled { .. } | ImageClass::Depth { .. } => { + "textureSamples" + } + ImageClass::Storage { .. } => "imageSamples", + }; + write!(self.out, "{fun_name}(")?; + self.write_expr(image, ctx)?; + write!(self.out, ")",)?; + } + } + + write!(self.out, ")")?; + } + Expression::Unary { op, expr } => { + let operator_or_fn = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::LogicalNot => { + match *ctx.resolve_type(expr, &self.module.types) { + TypeInner::Vector { .. } => "not", + _ => "!", + } + } + crate::UnaryOperator::BitwiseNot => "~", + }; + write!(self.out, "{operator_or_fn}(")?; + + self.write_expr(expr, ctx)?; + + write!(self.out, ")")? + } + // `Binary` we just write `left op right`, except when dealing with + // comparison operations on vectors as they are implemented with + // builtin functions. + // Once again we wrap everything in parentheses to avoid precedence issues + Expression::Binary { + mut op, + left, + right, + } => { + // Holds `Some(function_name)` if the binary operation is + // implemented as a function call + use crate::{BinaryOperator as Bo, ScalarKind as Sk, TypeInner as Ti}; + + let left_inner = ctx.resolve_type(left, &self.module.types); + let right_inner = ctx.resolve_type(right, &self.module.types); + + let function = match (left_inner, right_inner) { + (&Ti::Vector { scalar, .. }, &Ti::Vector { .. }) => match op { + Bo::Less + | Bo::LessEqual + | Bo::Greater + | Bo::GreaterEqual + | Bo::Equal + | Bo::NotEqual => BinaryOperation::VectorCompare, + Bo::Modulo if scalar.kind == Sk::Float => BinaryOperation::Modulo, + Bo::And if scalar.kind == Sk::Bool => { + op = crate::BinaryOperator::LogicalAnd; + BinaryOperation::VectorComponentWise + } + Bo::InclusiveOr if scalar.kind == Sk::Bool => { + op = crate::BinaryOperator::LogicalOr; + BinaryOperation::VectorComponentWise + } + _ => BinaryOperation::Other, + }, + _ => match (left_inner.scalar_kind(), right_inner.scalar_kind()) { + (Some(Sk::Float), _) | (_, Some(Sk::Float)) => match op { + Bo::Modulo => BinaryOperation::Modulo, + _ => BinaryOperation::Other, + }, + (Some(Sk::Bool), Some(Sk::Bool)) => match op { + Bo::InclusiveOr => { + op = crate::BinaryOperator::LogicalOr; + BinaryOperation::Other + } + Bo::And => { + op = crate::BinaryOperator::LogicalAnd; + BinaryOperation::Other + } + _ => BinaryOperation::Other, + }, + _ => BinaryOperation::Other, + }, + }; + + match function { + BinaryOperation::VectorCompare => { + let op_str = match op { + Bo::Less => "lessThan(", + Bo::LessEqual => "lessThanEqual(", + Bo::Greater => "greaterThan(", + Bo::GreaterEqual => "greaterThanEqual(", + Bo::Equal => "equal(", + Bo::NotEqual => "notEqual(", + _ => unreachable!(), + }; + write!(self.out, "{op_str}")?; + self.write_expr(left, ctx)?; + write!(self.out, ", ")?; + self.write_expr(right, ctx)?; + write!(self.out, ")")?; + } + BinaryOperation::VectorComponentWise => { + self.write_value_type(left_inner)?; + write!(self.out, "(")?; + + let size = match *left_inner { + Ti::Vector { size, .. } => size, + _ => unreachable!(), + }; + + for i in 0..size as usize { + if i != 0 { + write!(self.out, ", ")?; + } + + self.write_expr(left, ctx)?; + write!(self.out, ".{}", back::COMPONENTS[i])?; + + write!(self.out, " {} ", back::binary_operation_str(op))?; + + self.write_expr(right, ctx)?; + write!(self.out, ".{}", back::COMPONENTS[i])?; + } + + write!(self.out, ")")?; + } + // TODO: handle undefined behavior of BinaryOperator::Modulo + // + // sint: + // if right == 0 return 0 + // if left == min(type_of(left)) && right == -1 return 0 + // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL + // + // uint: + // if right == 0 return 0 + // + // float: + // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798 + BinaryOperation::Modulo => { + write!(self.out, "(")?; + + // write `e1 - e2 * trunc(e1 / e2)` + self.write_expr(left, ctx)?; + write!(self.out, " - ")?; + self.write_expr(right, ctx)?; + write!(self.out, " * ")?; + write!(self.out, "trunc(")?; + self.write_expr(left, ctx)?; + write!(self.out, " / ")?; + self.write_expr(right, ctx)?; + write!(self.out, ")")?; + + write!(self.out, ")")?; + } + BinaryOperation::Other => { + write!(self.out, "(")?; + + self.write_expr(left, ctx)?; + write!(self.out, " {} ", back::binary_operation_str(op))?; + self.write_expr(right, ctx)?; + + write!(self.out, ")")?; + } + } + } + // `Select` is written as `condition ? accept : reject` + // We wrap everything in parentheses to avoid precedence issues + Expression::Select { + condition, + accept, + reject, + } => { + let cond_ty = ctx.resolve_type(condition, &self.module.types); + let vec_select = if let TypeInner::Vector { .. } = *cond_ty { + true + } else { + false + }; + + // TODO: Boolean mix on desktop required GL_EXT_shader_integer_mix + if vec_select { + // Glsl defines that for mix when the condition is a boolean the first element + // is picked if condition is false and the second if condition is true + write!(self.out, "mix(")?; + self.write_expr(reject, ctx)?; + write!(self.out, ", ")?; + self.write_expr(accept, ctx)?; + write!(self.out, ", ")?; + self.write_expr(condition, ctx)?; + } else { + write!(self.out, "(")?; + self.write_expr(condition, ctx)?; + write!(self.out, " ? ")?; + self.write_expr(accept, ctx)?; + write!(self.out, " : ")?; + self.write_expr(reject, ctx)?; + } + + write!(self.out, ")")? + } + // `Derivative` is a function call to a glsl provided function + Expression::Derivative { axis, ctrl, expr } => { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + let fun_name = if self.options.version.supports_derivative_control() { + match (axis, ctrl) { + (Axis::X, Ctrl::Coarse) => "dFdxCoarse", + (Axis::X, Ctrl::Fine) => "dFdxFine", + (Axis::X, Ctrl::None) => "dFdx", + (Axis::Y, Ctrl::Coarse) => "dFdyCoarse", + (Axis::Y, Ctrl::Fine) => "dFdyFine", + (Axis::Y, Ctrl::None) => "dFdy", + (Axis::Width, Ctrl::Coarse) => "fwidthCoarse", + (Axis::Width, Ctrl::Fine) => "fwidthFine", + (Axis::Width, Ctrl::None) => "fwidth", + } + } else { + match axis { + Axis::X => "dFdx", + Axis::Y => "dFdy", + Axis::Width => "fwidth", + } + }; + write!(self.out, "{fun_name}(")?; + self.write_expr(expr, ctx)?; + write!(self.out, ")")? + } + // `Relational` is a normal function call to some glsl provided functions + Expression::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + + let fun_name = match fun { + Rf::IsInf => "isinf", + Rf::IsNan => "isnan", + Rf::All => "all", + Rf::Any => "any", + }; + write!(self.out, "{fun_name}(")?; + + self.write_expr(argument, ctx)?; + + write!(self.out, ")")? + } + Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + + let fun_name = match fun { + // comparison + Mf::Abs => "abs", + Mf::Min => "min", + Mf::Max => "max", + Mf::Clamp => "clamp", + Mf::Saturate => { + write!(self.out, "clamp(")?; + + self.write_expr(arg, ctx)?; + + match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Vector { size, .. } => write!( + self.out, + ", vec{}(0.0), vec{0}(1.0)", + back::vector_size_str(size) + )?, + _ => write!(self.out, ", 0.0, 1.0")?, + } + + write!(self.out, ")")?; + + return Ok(()); + } + // trigonometry + Mf::Cos => "cos", + Mf::Cosh => "cosh", + Mf::Sin => "sin", + Mf::Sinh => "sinh", + Mf::Tan => "tan", + Mf::Tanh => "tanh", + Mf::Acos => "acos", + Mf::Asin => "asin", + Mf::Atan => "atan", + Mf::Asinh => "asinh", + Mf::Acosh => "acosh", + Mf::Atanh => "atanh", + Mf::Radians => "radians", + Mf::Degrees => "degrees", + // glsl doesn't have atan2 function + // use two-argument variation of the atan function + Mf::Atan2 => "atan", + // decomposition + Mf::Ceil => "ceil", + Mf::Floor => "floor", + Mf::Round => "roundEven", + Mf::Fract => "fract", + Mf::Trunc => "trunc", + Mf::Modf => MODF_FUNCTION, + Mf::Frexp => FREXP_FUNCTION, + Mf::Ldexp => "ldexp", + // exponent + Mf::Exp => "exp", + Mf::Exp2 => "exp2", + Mf::Log => "log", + Mf::Log2 => "log2", + Mf::Pow => "pow", + // geometry + Mf::Dot => match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Vector { + scalar: + crate::Scalar { + kind: crate::ScalarKind::Float, + .. + }, + .. + } => "dot", + crate::TypeInner::Vector { size, .. } => { + return self.write_dot_product(arg, arg1.unwrap(), size as usize, ctx) + } + _ => unreachable!( + "Correct TypeInner for dot product should be already validated" + ), + }, + Mf::Outer => "outerProduct", + Mf::Cross => "cross", + Mf::Distance => "distance", + Mf::Length => "length", + Mf::Normalize => "normalize", + Mf::FaceForward => "faceforward", + Mf::Reflect => "reflect", + Mf::Refract => "refract", + // computational + Mf::Sign => "sign", + Mf::Fma => { + if self.options.version.supports_fma_function() { + // Use the fma function when available + "fma" + } else { + // No fma support. Transform the function call into an arithmetic expression + write!(self.out, "(")?; + + self.write_expr(arg, ctx)?; + write!(self.out, " * ")?; + + let arg1 = + arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?; + self.write_expr(arg1, ctx)?; + write!(self.out, " + ")?; + + let arg2 = + arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?; + self.write_expr(arg2, ctx)?; + write!(self.out, ")")?; + + return Ok(()); + } + } + Mf::Mix => "mix", + Mf::Step => "step", + Mf::SmoothStep => "smoothstep", + Mf::Sqrt => "sqrt", + Mf::InverseSqrt => "inversesqrt", + Mf::Inverse => "inverse", + Mf::Transpose => "transpose", + Mf::Determinant => "determinant", + // bits + Mf::CountTrailingZeros => { + match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Vector { size, scalar, .. } => { + let s = back::vector_size_str(size); + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "min(uvec{s}(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), uvec{s}(32u))")?; + } else { + write!(self.out, "ivec{s}(min(uvec{s}(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), uvec{s}(32u)))")?; + } + } + crate::TypeInner::Scalar(scalar) => { + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "min(uint(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), 32u)")?; + } else { + write!(self.out, "int(min(uint(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), 32u))")?; + } + } + _ => unreachable!(), + }; + return Ok(()); + } + Mf::CountLeadingZeros => { + if self.options.version.supports_integer_functions() { + match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Vector { size, scalar } => { + let s = back::vector_size_str(size); + + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "uvec{s}(ivec{s}(31) - findMSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "mix(ivec{s}(31) - findMSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, "), ivec{s}(0), lessThan(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ", ivec{s}(0)))")?; + } + } + crate::TypeInner::Scalar(scalar) => { + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "uint(31 - findMSB(")?; + } else { + write!(self.out, "(")?; + self.write_expr(arg, ctx)?; + write!(self.out, " < 0 ? 0 : 31 - findMSB(")?; + } + + self.write_expr(arg, ctx)?; + write!(self.out, "))")?; + } + _ => unreachable!(), + }; + } else { + match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Vector { size, scalar } => { + let s = back::vector_size_str(size); + + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "uvec{s}(")?; + write!(self.out, "vec{s}(31.0) - floor(log2(vec{s}(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5)))")?; + } else { + write!(self.out, "ivec{s}(")?; + write!(self.out, "mix(vec{s}(31.0) - floor(log2(vec{s}(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5)), ")?; + write!(self.out, "vec{s}(0.0), lessThan(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ", ivec{s}(0u))))")?; + } + } + crate::TypeInner::Scalar(scalar) => { + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "uint(31.0 - floor(log2(float(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5)))")?; + } else { + write!(self.out, "(")?; + self.write_expr(arg, ctx)?; + write!(self.out, " < 0 ? 0 : int(")?; + write!(self.out, "31.0 - floor(log2(float(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5))))")?; + } + } + _ => unreachable!(), + }; + } + + return Ok(()); + } + Mf::CountOneBits => "bitCount", + Mf::ReverseBits => "bitfieldReverse", + Mf::ExtractBits => "bitfieldExtract", + Mf::InsertBits => "bitfieldInsert", + Mf::FindLsb => "findLSB", + Mf::FindMsb => "findMSB", + // data packing + Mf::Pack4x8snorm => "packSnorm4x8", + Mf::Pack4x8unorm => "packUnorm4x8", + Mf::Pack2x16snorm => "packSnorm2x16", + Mf::Pack2x16unorm => "packUnorm2x16", + Mf::Pack2x16float => "packHalf2x16", + // data unpacking + Mf::Unpack4x8snorm => "unpackSnorm4x8", + Mf::Unpack4x8unorm => "unpackUnorm4x8", + Mf::Unpack2x16snorm => "unpackSnorm2x16", + Mf::Unpack2x16unorm => "unpackUnorm2x16", + Mf::Unpack2x16float => "unpackHalf2x16", + }; + + let extract_bits = fun == Mf::ExtractBits; + let insert_bits = fun == Mf::InsertBits; + + // Some GLSL functions always return signed integers (like findMSB), + // so they need to be cast to uint if the argument is also an uint. + let ret_might_need_int_to_uint = + matches!(fun, Mf::FindLsb | Mf::FindMsb | Mf::CountOneBits | Mf::Abs); + + // Some GLSL functions only accept signed integers (like abs), + // so they need their argument cast from uint to int. + let arg_might_need_uint_to_int = matches!(fun, Mf::Abs); + + // Check if the argument is an unsigned integer and return the vector size + // in case it's a vector + let maybe_uint_size = match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + .. + }) => Some(None), + crate::TypeInner::Vector { + scalar: + crate::Scalar { + kind: crate::ScalarKind::Uint, + .. + }, + size, + } => Some(Some(size)), + _ => None, + }; + + // Cast to uint if the function needs it + if ret_might_need_int_to_uint { + if let Some(maybe_size) = maybe_uint_size { + match maybe_size { + Some(size) => write!(self.out, "uvec{}(", size as u8)?, + None => write!(self.out, "uint(")?, + } + } + } + + write!(self.out, "{fun_name}(")?; + + // Cast to int if the function needs it + if arg_might_need_uint_to_int { + if let Some(maybe_size) = maybe_uint_size { + match maybe_size { + Some(size) => write!(self.out, "ivec{}(", size as u8)?, + None => write!(self.out, "int(")?, + } + } + } + + self.write_expr(arg, ctx)?; + + // Close the cast from uint to int + if arg_might_need_uint_to_int && maybe_uint_size.is_some() { + write!(self.out, ")")? + } + + if let Some(arg) = arg1 { + write!(self.out, ", ")?; + if extract_bits { + write!(self.out, "int(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr(arg, ctx)?; + } + } + if let Some(arg) = arg2 { + write!(self.out, ", ")?; + if extract_bits || insert_bits { + write!(self.out, "int(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr(arg, ctx)?; + } + } + if let Some(arg) = arg3 { + write!(self.out, ", ")?; + if insert_bits { + write!(self.out, "int(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr(arg, ctx)?; + } + } + write!(self.out, ")")?; + + // Close the cast from int to uint + if ret_might_need_int_to_uint && maybe_uint_size.is_some() { + write!(self.out, ")")? + } + } + // `As` is always a call. + // If `convert` is true the function name is the type + // Else the function name is one of the glsl provided bitcast functions + Expression::As { + expr, + kind: target_kind, + convert, + } => { + let inner = ctx.resolve_type(expr, &self.module.types); + match convert { + Some(width) => { + // this is similar to `write_type`, but with the target kind + let scalar = glsl_scalar(crate::Scalar { + kind: target_kind, + width, + })?; + match *inner { + TypeInner::Matrix { columns, rows, .. } => write!( + self.out, + "{}mat{}x{}", + scalar.prefix, columns as u8, rows as u8 + )?, + TypeInner::Vector { size, .. } => { + write!(self.out, "{}vec{}", scalar.prefix, size as u8)? + } + _ => write!(self.out, "{}", scalar.full)?, + } + + write!(self.out, "(")?; + self.write_expr(expr, ctx)?; + write!(self.out, ")")? + } + None => { + use crate::ScalarKind as Sk; + + let target_vector_type = match *inner { + TypeInner::Vector { size, scalar } => Some(TypeInner::Vector { + size, + scalar: crate::Scalar { + kind: target_kind, + width: scalar.width, + }, + }), + _ => None, + }; + + let source_kind = inner.scalar_kind().unwrap(); + + match (source_kind, target_kind, target_vector_type) { + // No conversion needed + (Sk::Sint, Sk::Sint, _) + | (Sk::Uint, Sk::Uint, _) + | (Sk::Float, Sk::Float, _) + | (Sk::Bool, Sk::Bool, _) => { + self.write_expr(expr, ctx)?; + return Ok(()); + } + + // Cast to/from floats + (Sk::Float, Sk::Sint, _) => write!(self.out, "floatBitsToInt")?, + (Sk::Float, Sk::Uint, _) => write!(self.out, "floatBitsToUint")?, + (Sk::Sint, Sk::Float, _) => write!(self.out, "intBitsToFloat")?, + (Sk::Uint, Sk::Float, _) => write!(self.out, "uintBitsToFloat")?, + + // Cast between vector types + (_, _, Some(vector)) => { + self.write_value_type(&vector)?; + } + + // There is no way to bitcast between Uint/Sint in glsl. Use constructor conversion + (Sk::Uint | Sk::Bool, Sk::Sint, None) => write!(self.out, "int")?, + (Sk::Sint | Sk::Bool, Sk::Uint, None) => write!(self.out, "uint")?, + (Sk::Bool, Sk::Float, None) => write!(self.out, "float")?, + (Sk::Sint | Sk::Uint | Sk::Float, Sk::Bool, None) => { + write!(self.out, "bool")? + } + + (Sk::AbstractInt | Sk::AbstractFloat, _, _) + | (_, Sk::AbstractInt | Sk::AbstractFloat, _) => unreachable!(), + }; + + write!(self.out, "(")?; + self.write_expr(expr, ctx)?; + write!(self.out, ")")?; + } + } + } + // These expressions never show up in `Emit`. + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::RayQueryProceedResult + | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(), + // `ArrayLength` is written as `expr.length()` and we convert it to a uint + Expression::ArrayLength(expr) => { + write!(self.out, "uint(")?; + self.write_expr(expr, ctx)?; + write!(self.out, ".length())")? + } + // not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), + } + + Ok(()) + } + + /// Helper function to write the local holding the clamped lod + fn write_clamped_lod( + &mut self, + ctx: &back::FunctionCtx, + expr: Handle<crate::Expression>, + image: Handle<crate::Expression>, + level_expr: Handle<crate::Expression>, + ) -> Result<(), Error> { + // Define our local and start a call to `clamp` + write!( + self.out, + "int {}{}{} = clamp(", + back::BAKE_PREFIX, + expr.index(), + CLAMPED_LOD_SUFFIX + )?; + // Write the lod that will be clamped + self.write_expr(level_expr, ctx)?; + // Set the min value to 0 and start a call to `textureQueryLevels` to get + // the maximum value + write!(self.out, ", 0, textureQueryLevels(")?; + // Write the target image as an argument to `textureQueryLevels` + self.write_expr(image, ctx)?; + // Close the call to `textureQueryLevels` subtract 1 from it since + // the lod argument is 0 based, close the `clamp` call and end the + // local declaration statement. + writeln!(self.out, ") - 1);")?; + + Ok(()) + } + + // Helper method used to retrieve how many elements a coordinate vector + // for the images operations need. + fn get_coordinate_vector_size(&self, dim: crate::ImageDimension, arrayed: bool) -> u8 { + // openGL es doesn't have 1D images so we need workaround it + let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es(); + // Get how many components the coordinate vector needs for the dimensions only + let tex_coord_size = match dim { + crate::ImageDimension::D1 => 1, + crate::ImageDimension::D2 => 2, + crate::ImageDimension::D3 => 3, + crate::ImageDimension::Cube => 2, + }; + // Calculate the true size of the coordinate vector by adding 1 for arrayed images + // and another 1 if we need to workaround 1D images by making them 2D + tex_coord_size + tex_1d_hack as u8 + arrayed as u8 + } + + /// Helper method to write the coordinate vector for image operations + fn write_texture_coord( + &mut self, + ctx: &back::FunctionCtx, + vector_size: u8, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + // Emulate 1D images as 2D for profiles that don't support it (glsl es) + tex_1d_hack: bool, + ) -> Result<(), Error> { + match array_index { + // If the image needs an array indice we need to add it to the end of our + // coordinate vector, to do so we will use the `ivec(ivec, scalar)` + // constructor notation (NOTE: the inner `ivec` can also be a scalar, this + // is important for 1D arrayed images). + Some(layer_expr) => { + write!(self.out, "ivec{vector_size}(")?; + self.write_expr(coordinate, ctx)?; + write!(self.out, ", ")?; + // If we are replacing sampler1D with sampler2D we also need + // to add another zero to the coordinates vector for the y component + if tex_1d_hack { + write!(self.out, "0, ")?; + } + self.write_expr(layer_expr, ctx)?; + write!(self.out, ")")?; + } + // Otherwise write just the expression (and the 1D hack if needed) + None => { + let uvec_size = match *ctx.resolve_type(coordinate, &self.module.types) { + TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + .. + }) => Some(None), + TypeInner::Vector { + size, + scalar: + crate::Scalar { + kind: crate::ScalarKind::Uint, + .. + }, + } => Some(Some(size as u32)), + _ => None, + }; + if tex_1d_hack { + write!(self.out, "ivec2(")?; + } else if uvec_size.is_some() { + match uvec_size { + Some(None) => write!(self.out, "int(")?, + Some(Some(size)) => write!(self.out, "ivec{size}(")?, + _ => {} + } + } + self.write_expr(coordinate, ctx)?; + if tex_1d_hack { + write!(self.out, ", 0)")?; + } else if uvec_size.is_some() { + write!(self.out, ")")?; + } + } + } + + Ok(()) + } + + /// Helper method to write the `ImageStore` statement + fn write_image_store( + &mut self, + ctx: &back::FunctionCtx, + image: Handle<crate::Expression>, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + value: Handle<crate::Expression>, + ) -> Result<(), Error> { + use crate::ImageDimension as IDim; + + // NOTE: openGL requires that `imageStore`s have no effets when the texel is invalid + // so we don't need to generate bounds checks (OpenGL 4.2 Core §3.9.20) + + // This will only panic if the module is invalid + let dim = match *ctx.resolve_type(image, &self.module.types) { + TypeInner::Image { dim, .. } => dim, + _ => unreachable!(), + }; + + // Begin our call to `imageStore` + write!(self.out, "imageStore(")?; + self.write_expr(image, ctx)?; + // Separate the image argument from the coordinates + write!(self.out, ", ")?; + + // openGL es doesn't have 1D images so we need workaround it + let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es(); + // Write the coordinate vector + self.write_texture_coord( + ctx, + // Get the size of the coordinate vector + self.get_coordinate_vector_size(dim, array_index.is_some()), + coordinate, + array_index, + tex_1d_hack, + )?; + + // Separate the coordinate from the value to write and write the expression + // of the value to write. + write!(self.out, ", ")?; + self.write_expr(value, ctx)?; + // End the call to `imageStore` and the statement. + writeln!(self.out, ");")?; + + Ok(()) + } + + /// Helper method for writing an `ImageLoad` expression. + #[allow(clippy::too_many_arguments)] + fn write_image_load( + &mut self, + handle: Handle<crate::Expression>, + ctx: &back::FunctionCtx, + image: Handle<crate::Expression>, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + sample: Option<Handle<crate::Expression>>, + level: Option<Handle<crate::Expression>>, + ) -> Result<(), Error> { + use crate::ImageDimension as IDim; + + // `ImageLoad` is a bit complicated. + // There are two functions one for sampled + // images another for storage images, the former uses `texelFetch` and the + // latter uses `imageLoad`. + // + // Furthermore we have `level` which is always `Some` for sampled images + // and `None` for storage images, so we end up with two functions: + // - `texelFetch(image, coordinate, level)` for sampled images + // - `imageLoad(image, coordinate)` for storage images + // + // Finally we also have to consider bounds checking, for storage images + // this is easy since openGL requires that invalid texels always return + // 0, for sampled images we need to either verify that all arguments are + // in bounds (`ReadZeroSkipWrite`) or make them a valid texel (`Restrict`). + + // This will only panic if the module is invalid + let (dim, class) = match *ctx.resolve_type(image, &self.module.types) { + TypeInner::Image { + dim, + arrayed: _, + class, + } => (dim, class), + _ => unreachable!(), + }; + + // Get the name of the function to be used for the load operation + // and the policy to be used with it. + let (fun_name, policy) = match class { + // Sampled images inherit the policy from the user passed policies + crate::ImageClass::Sampled { .. } => ("texelFetch", self.policies.image_load), + crate::ImageClass::Storage { .. } => { + // OpenGL ES 3.1 mentions in Chapter "8.22 Texture Image Loads and Stores" that: + // "Invalid image loads will return a vector where the value of R, G, and B components + // is 0 and the value of the A component is undefined." + // + // OpenGL 4.2 Core mentions in Chapter "3.9.20 Texture Image Loads and Stores" that: + // "Invalid image loads will return zero." + // + // So, we only inject bounds checks for ES + let policy = if self.options.version.is_es() { + self.policies.image_load + } else { + proc::BoundsCheckPolicy::Unchecked + }; + ("imageLoad", policy) + } + // TODO: Is there even a function for this? + crate::ImageClass::Depth { multi: _ } => { + return Err(Error::Custom( + "WGSL `textureLoad` from depth textures is not supported in GLSL".to_string(), + )) + } + }; + + // openGL es doesn't have 1D images so we need workaround it + let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es(); + // Get the size of the coordinate vector + let vector_size = self.get_coordinate_vector_size(dim, array_index.is_some()); + + if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy { + // To write the bounds checks for `ReadZeroSkipWrite` we will use a + // ternary operator since we are in the middle of an expression and + // need to return a value. + // + // NOTE: glsl does short circuit when evaluating logical + // expressions so we can be sure that after we test a + // condition it will be true for the next ones + + // Write parentheses around the ternary operator to prevent problems with + // expressions emitted before or after it having more precedence + write!(self.out, "(",)?; + + // The lod check needs to precede the size check since we need + // to use the lod to get the size of the image at that level. + if let Some(level_expr) = level { + self.write_expr(level_expr, ctx)?; + write!(self.out, " < textureQueryLevels(",)?; + self.write_expr(image, ctx)?; + // Chain the next check + write!(self.out, ") && ")?; + } + + // Check that the sample arguments doesn't exceed the number of samples + if let Some(sample_expr) = sample { + self.write_expr(sample_expr, ctx)?; + write!(self.out, " < textureSamples(",)?; + self.write_expr(image, ctx)?; + // Chain the next check + write!(self.out, ") && ")?; + } + + // We now need to write the size checks for the coordinates and array index + // first we write the comparison function in case the image is 1D non arrayed + // (and no 1D to 2D hack was needed) we are comparing scalars so the less than + // operator will suffice, but otherwise we'll be comparing two vectors so we'll + // need to use the `lessThan` function but it returns a vector of booleans (one + // for each comparison) so we need to fold it all in one scalar boolean, since + // we want all comparisons to pass we use the `all` function which will only + // return `true` if all the elements of the boolean vector are also `true`. + // + // So we'll end with one of the following forms + // - `coord < textureSize(image, lod)` for 1D images + // - `all(lessThan(coord, textureSize(image, lod)))` for normal images + // - `all(lessThan(ivec(coord, array_index), textureSize(image, lod)))` + // for arrayed images + // - `all(lessThan(coord, textureSize(image)))` for multi sampled images + + if vector_size != 1 { + write!(self.out, "all(lessThan(")?; + } + + // Write the coordinate vector + self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?; + + if vector_size != 1 { + // If we used the `lessThan` function we need to separate the + // coordinates from the image size. + write!(self.out, ", ")?; + } else { + // If we didn't use it (ie. 1D images) we perform the comparison + // using the less than operator. + write!(self.out, " < ")?; + } + + // Call `textureSize` to get our image size + write!(self.out, "textureSize(")?; + self.write_expr(image, ctx)?; + // `textureSize` uses the lod as a second argument for mipmapped images + if let Some(level_expr) = level { + // Separate the image from the lod + write!(self.out, ", ")?; + self.write_expr(level_expr, ctx)?; + } + // Close the `textureSize` call + write!(self.out, ")")?; + + if vector_size != 1 { + // Close the `all` and `lessThan` calls + write!(self.out, "))")?; + } + + // Finally end the condition part of the ternary operator + write!(self.out, " ? ")?; + } + + // Begin the call to the function used to load the texel + write!(self.out, "{fun_name}(")?; + self.write_expr(image, ctx)?; + write!(self.out, ", ")?; + + // If we are using `Restrict` bounds checking we need to pass valid texel + // coordinates, to do so we use the `clamp` function to get a value between + // 0 and the image size - 1 (indexing begins at 0) + if let proc::BoundsCheckPolicy::Restrict = policy { + write!(self.out, "clamp(")?; + } + + // Write the coordinate vector + self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?; + + // If we are using `Restrict` bounds checking we need to write the rest of the + // clamp we initiated before writing the coordinates. + if let proc::BoundsCheckPolicy::Restrict = policy { + // Write the min value 0 + if vector_size == 1 { + write!(self.out, ", 0")?; + } else { + write!(self.out, ", ivec{vector_size}(0)")?; + } + // Start the `textureSize` call to use as the max value. + write!(self.out, ", textureSize(")?; + self.write_expr(image, ctx)?; + // If the image is mipmapped we need to add the lod argument to the + // `textureSize` call, but this needs to be the clamped lod, this should + // have been generated earlier and put in a local. + if class.is_mipmapped() { + write!( + self.out, + ", {}{}{}", + back::BAKE_PREFIX, + handle.index(), + CLAMPED_LOD_SUFFIX + )?; + } + // Close the `textureSize` call + write!(self.out, ")")?; + + // Subtract 1 from the `textureSize` call since the coordinates are zero based. + if vector_size == 1 { + write!(self.out, " - 1")?; + } else { + write!(self.out, " - ivec{vector_size}(1)")?; + } + + // Close the `clamp` call + write!(self.out, ")")?; + + // Add the clamped lod (if present) as the second argument to the + // image load function. + if level.is_some() { + write!( + self.out, + ", {}{}{}", + back::BAKE_PREFIX, + handle.index(), + CLAMPED_LOD_SUFFIX + )?; + } + + // If a sample argument is needed we need to clamp it between 0 and + // the number of samples the image has. + if let Some(sample_expr) = sample { + write!(self.out, ", clamp(")?; + self.write_expr(sample_expr, ctx)?; + // Set the min value to 0 and start the call to `textureSamples` + write!(self.out, ", 0, textureSamples(")?; + self.write_expr(image, ctx)?; + // Close the `textureSamples` call, subtract 1 from it since the sample + // argument is zero based, and close the `clamp` call + writeln!(self.out, ") - 1)")?; + } + } else if let Some(sample_or_level) = sample.or(level) { + // If no bounds checking is need just add the sample or level argument + // after the coordinates + write!(self.out, ", ")?; + self.write_expr(sample_or_level, ctx)?; + } + + // Close the image load function. + write!(self.out, ")")?; + + // If we were using the `ReadZeroSkipWrite` policy we need to end the first branch + // (which is taken if the condition is `true`) with a colon (`:`) and write the + // second branch which is just a 0 value. + if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy { + // Get the kind of the output value. + let kind = match class { + // Only sampled images can reach here since storage images + // don't need bounds checks and depth images aren't implemented + crate::ImageClass::Sampled { kind, .. } => kind, + _ => unreachable!(), + }; + + // End the first branch + write!(self.out, " : ")?; + // Write the 0 value + write!( + self.out, + "{}vec4(", + glsl_scalar(crate::Scalar { kind, width: 4 })?.prefix, + )?; + self.write_zero_init_scalar(kind)?; + // Close the zero value constructor + write!(self.out, ")")?; + // Close the parentheses surrounding our ternary + write!(self.out, ")")?; + } + + Ok(()) + } + + fn write_named_expr( + &mut self, + handle: Handle<crate::Expression>, + name: String, + // The expression which is being named. + // Generally, this is the same as handle, except in WorkGroupUniformLoad + named: Handle<crate::Expression>, + ctx: &back::FunctionCtx, + ) -> BackendResult { + match ctx.info[named].ty { + proc::TypeResolution::Handle(ty_handle) => match self.module.types[ty_handle].inner { + TypeInner::Struct { .. } => { + let ty_name = &self.names[&NameKey::Type(ty_handle)]; + write!(self.out, "{ty_name}")?; + } + _ => { + self.write_type(ty_handle)?; + } + }, + proc::TypeResolution::Value(ref inner) => { + self.write_value_type(inner)?; + } + } + + let resolved = ctx.resolve_type(named, &self.module.types); + + write!(self.out, " {name}")?; + if let TypeInner::Array { base, size, .. } = *resolved { + self.write_array_size(base, size)?; + } + write!(self.out, " = ")?; + self.write_expr(handle, ctx)?; + writeln!(self.out, ";")?; + self.named_expressions.insert(named, name); + + Ok(()) + } + + /// Helper function that write string with default zero initialization for supported types + fn write_zero_init_value(&mut self, ty: Handle<crate::Type>) -> BackendResult { + let inner = &self.module.types[ty].inner; + match *inner { + TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => { + self.write_zero_init_scalar(scalar.kind)?; + } + TypeInner::Vector { scalar, .. } => { + self.write_value_type(inner)?; + write!(self.out, "(")?; + self.write_zero_init_scalar(scalar.kind)?; + write!(self.out, ")")?; + } + TypeInner::Matrix { .. } => { + self.write_value_type(inner)?; + write!(self.out, "(")?; + self.write_zero_init_scalar(crate::ScalarKind::Float)?; + write!(self.out, ")")?; + } + TypeInner::Array { base, size, .. } => { + let count = match size + .to_indexable_length(self.module) + .expect("Bad array size") + { + proc::IndexableLength::Known(count) => count, + proc::IndexableLength::Dynamic => return Ok(()), + }; + self.write_type(base)?; + self.write_array_size(base, size)?; + write!(self.out, "(")?; + for _ in 1..count { + self.write_zero_init_value(base)?; + write!(self.out, ", ")?; + } + // write last parameter without comma and space + self.write_zero_init_value(base)?; + write!(self.out, ")")?; + } + TypeInner::Struct { ref members, .. } => { + let name = &self.names[&NameKey::Type(ty)]; + write!(self.out, "{name}(")?; + for (index, member) in members.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + self.write_zero_init_value(member.ty)?; + } + write!(self.out, ")")?; + } + _ => unreachable!(), + } + + Ok(()) + } + + /// Helper function that write string with zero initialization for scalar + fn write_zero_init_scalar(&mut self, kind: crate::ScalarKind) -> BackendResult { + match kind { + crate::ScalarKind::Bool => write!(self.out, "false")?, + crate::ScalarKind::Uint => write!(self.out, "0u")?, + crate::ScalarKind::Float => write!(self.out, "0.0")?, + crate::ScalarKind::Sint => write!(self.out, "0")?, + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".to_string(), + )) + } + } + + Ok(()) + } + + /// Issue a memory barrier. Please note that to ensure visibility, + /// OpenGL always requires a call to the `barrier()` function after a `memoryBarrier*()` + fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult { + if flags.contains(crate::Barrier::STORAGE) { + writeln!(self.out, "{level}memoryBarrierBuffer();")?; + } + if flags.contains(crate::Barrier::WORK_GROUP) { + writeln!(self.out, "{level}memoryBarrierShared();")?; + } + writeln!(self.out, "{level}barrier();")?; + Ok(()) + } + + /// Helper function that return the glsl storage access string of [`StorageAccess`](crate::StorageAccess) + /// + /// glsl allows adding both `readonly` and `writeonly` but this means that + /// they can only be used to query information about the resource which isn't what + /// we want here so when storage access is both `LOAD` and `STORE` add no modifiers + fn write_storage_access(&mut self, storage_access: crate::StorageAccess) -> BackendResult { + if !storage_access.contains(crate::StorageAccess::STORE) { + write!(self.out, "readonly ")?; + } + if !storage_access.contains(crate::StorageAccess::LOAD) { + write!(self.out, "writeonly ")?; + } + Ok(()) + } + + /// Helper method used to produce the reflection info that's returned to the user + fn collect_reflection_info(&mut self) -> Result<ReflectionInfo, Error> { + use std::collections::hash_map::Entry; + let info = self.info.get_entry_point(self.entry_point_idx as usize); + let mut texture_mapping = crate::FastHashMap::default(); + let mut uniforms = crate::FastHashMap::default(); + + for sampling in info.sampling_set.iter() { + let tex_name = self.reflection_names_globals[&sampling.image].clone(); + + match texture_mapping.entry(tex_name) { + Entry::Vacant(v) => { + v.insert(TextureMapping { + texture: sampling.image, + sampler: Some(sampling.sampler), + }); + } + Entry::Occupied(e) => { + if e.get().sampler != Some(sampling.sampler) { + log::error!("Conflicting samplers for {}", e.key()); + return Err(Error::ImageMultipleSamplers); + } + } + } + } + + let mut push_constant_info = None; + for (handle, var) in self.module.global_variables.iter() { + if info[handle].is_empty() { + continue; + } + match self.module.types[var.ty].inner { + crate::TypeInner::Image { .. } => { + let tex_name = self.reflection_names_globals[&handle].clone(); + match texture_mapping.entry(tex_name) { + Entry::Vacant(v) => { + v.insert(TextureMapping { + texture: handle, + sampler: None, + }); + } + Entry::Occupied(_) => { + // already used with a sampler, do nothing + } + } + } + _ => match var.space { + crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => { + let name = self.reflection_names_globals[&handle].clone(); + uniforms.insert(handle, name); + } + crate::AddressSpace::PushConstant => { + let name = self.reflection_names_globals[&handle].clone(); + push_constant_info = Some((name, var.ty)); + } + _ => (), + }, + } + } + + let mut push_constant_segments = Vec::new(); + let mut push_constant_items = vec![]; + + if let Some((name, ty)) = push_constant_info { + // We don't have a layouter available to us, so we need to create one. + // + // This is potentially a bit wasteful, but the set of types in the program + // shouldn't be too large. + let mut layouter = crate::proc::Layouter::default(); + layouter.update(self.module.to_ctx()).unwrap(); + + // We start with the name of the binding itself. + push_constant_segments.push(name); + + // We then recursively collect all the uniform fields of the push constant. + self.collect_push_constant_items( + ty, + &mut push_constant_segments, + &layouter, + &mut 0, + &mut push_constant_items, + ); + } + + Ok(ReflectionInfo { + texture_mapping, + uniforms, + varying: mem::take(&mut self.varying), + push_constant_items, + }) + } + + fn collect_push_constant_items( + &mut self, + ty: Handle<crate::Type>, + segments: &mut Vec<String>, + layouter: &crate::proc::Layouter, + offset: &mut u32, + items: &mut Vec<PushConstantItem>, + ) { + // At this point in the recursion, `segments` contains the path + // needed to access `ty` from the root. + + let layout = &layouter[ty]; + *offset = layout.alignment.round_up(*offset); + match self.module.types[ty].inner { + // All these types map directly to GL uniforms. + TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => { + // Build the full name, by combining all current segments. + let name: String = segments.iter().map(String::as_str).collect(); + items.push(PushConstantItem { + access_path: name, + offset: *offset, + ty, + }); + *offset += layout.size; + } + // Arrays are recursed into. + TypeInner::Array { base, size, .. } => { + let crate::ArraySize::Constant(count) = size else { + unreachable!("Cannot have dynamic arrays in push constants"); + }; + + for i in 0..count.get() { + // Add the array accessor and recurse. + segments.push(format!("[{}]", i)); + self.collect_push_constant_items(base, segments, layouter, offset, items); + segments.pop(); + } + + // Ensure the stride is kept by rounding up to the alignment. + *offset = layout.alignment.round_up(*offset) + } + TypeInner::Struct { ref members, .. } => { + for (index, member) in members.iter().enumerate() { + // Add struct accessor and recurse. + segments.push(format!( + ".{}", + self.names[&NameKey::StructMember(ty, index as u32)] + )); + self.collect_push_constant_items(member.ty, segments, layouter, offset, items); + segments.pop(); + } + + // Ensure ending padding is kept by rounding up to the alignment. + *offset = layout.alignment.round_up(*offset) + } + _ => unreachable!(), + } + } +} + +/// Structure returned by [`glsl_scalar`] +/// +/// It contains both a prefix used in other types and the full type name +struct ScalarString<'a> { + /// The prefix used to compose other types + prefix: &'a str, + /// The name of the scalar type + full: &'a str, +} + +/// Helper function that returns scalar related strings +/// +/// Check [`ScalarString`] for the information provided +/// +/// # Errors +/// If a [`Float`](crate::ScalarKind::Float) with an width that isn't 4 or 8 +const fn glsl_scalar(scalar: crate::Scalar) -> Result<ScalarString<'static>, Error> { + use crate::ScalarKind as Sk; + + Ok(match scalar.kind { + Sk::Sint => ScalarString { + prefix: "i", + full: "int", + }, + Sk::Uint => ScalarString { + prefix: "u", + full: "uint", + }, + Sk::Float => match scalar.width { + 4 => ScalarString { + prefix: "", + full: "float", + }, + 8 => ScalarString { + prefix: "d", + full: "double", + }, + _ => return Err(Error::UnsupportedScalar(scalar)), + }, + Sk::Bool => ScalarString { + prefix: "b", + full: "bool", + }, + Sk::AbstractInt | Sk::AbstractFloat => { + return Err(Error::UnsupportedScalar(scalar)); + } + }) +} + +/// Helper function that returns the glsl variable name for a builtin +const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'static str { + use crate::BuiltIn as Bi; + + match built_in { + Bi::Position { .. } => { + if options.output { + "gl_Position" + } else { + "gl_FragCoord" + } + } + Bi::ViewIndex if options.targeting_webgl => "int(gl_ViewID_OVR)", + Bi::ViewIndex => "gl_ViewIndex", + // vertex + Bi::BaseInstance => "uint(gl_BaseInstance)", + Bi::BaseVertex => "uint(gl_BaseVertex)", + Bi::ClipDistance => "gl_ClipDistance", + Bi::CullDistance => "gl_CullDistance", + Bi::InstanceIndex => { + if options.draw_parameters { + "(uint(gl_InstanceID) + uint(gl_BaseInstanceARB))" + } else { + // Must match FIRST_INSTANCE_BINDING + "(uint(gl_InstanceID) + naga_vs_first_instance)" + } + } + Bi::PointSize => "gl_PointSize", + Bi::VertexIndex => "uint(gl_VertexID)", + // fragment + Bi::FragDepth => "gl_FragDepth", + Bi::PointCoord => "gl_PointCoord", + Bi::FrontFacing => "gl_FrontFacing", + Bi::PrimitiveIndex => "uint(gl_PrimitiveID)", + Bi::SampleIndex => "gl_SampleID", + Bi::SampleMask => { + if options.output { + "gl_SampleMask" + } else { + "gl_SampleMaskIn" + } + } + // compute + Bi::GlobalInvocationId => "gl_GlobalInvocationID", + Bi::LocalInvocationId => "gl_LocalInvocationID", + Bi::LocalInvocationIndex => "gl_LocalInvocationIndex", + Bi::WorkGroupId => "gl_WorkGroupID", + Bi::WorkGroupSize => "gl_WorkGroupSize", + Bi::NumWorkGroups => "gl_NumWorkGroups", + } +} + +/// Helper function that returns the string corresponding to the address space +const fn glsl_storage_qualifier(space: crate::AddressSpace) -> Option<&'static str> { + use crate::AddressSpace as As; + + match space { + As::Function => None, + As::Private => None, + As::Storage { .. } => Some("buffer"), + As::Uniform => Some("uniform"), + As::Handle => Some("uniform"), + As::WorkGroup => Some("shared"), + As::PushConstant => Some("uniform"), + } +} + +/// Helper function that returns the string corresponding to the glsl interpolation qualifier +const fn glsl_interpolation(interpolation: crate::Interpolation) -> &'static str { + use crate::Interpolation as I; + + match interpolation { + I::Perspective => "smooth", + I::Linear => "noperspective", + I::Flat => "flat", + } +} + +/// Return the GLSL auxiliary qualifier for the given sampling value. +const fn glsl_sampling(sampling: crate::Sampling) -> Option<&'static str> { + use crate::Sampling as S; + + match sampling { + S::Center => None, + S::Centroid => Some("centroid"), + S::Sample => Some("sample"), + } +} + +/// Helper function that returns the glsl dimension string of [`ImageDimension`](crate::ImageDimension) +const fn glsl_dimension(dim: crate::ImageDimension) -> &'static str { + use crate::ImageDimension as IDim; + + match dim { + IDim::D1 => "1D", + IDim::D2 => "2D", + IDim::D3 => "3D", + IDim::Cube => "Cube", + } +} + +/// Helper function that returns the glsl storage format string of [`StorageFormat`](crate::StorageFormat) +fn glsl_storage_format(format: crate::StorageFormat) -> Result<&'static str, Error> { + use crate::StorageFormat as Sf; + + Ok(match format { + Sf::R8Unorm => "r8", + Sf::R8Snorm => "r8_snorm", + Sf::R8Uint => "r8ui", + Sf::R8Sint => "r8i", + Sf::R16Uint => "r16ui", + Sf::R16Sint => "r16i", + Sf::R16Float => "r16f", + Sf::Rg8Unorm => "rg8", + Sf::Rg8Snorm => "rg8_snorm", + Sf::Rg8Uint => "rg8ui", + Sf::Rg8Sint => "rg8i", + Sf::R32Uint => "r32ui", + Sf::R32Sint => "r32i", + Sf::R32Float => "r32f", + Sf::Rg16Uint => "rg16ui", + Sf::Rg16Sint => "rg16i", + Sf::Rg16Float => "rg16f", + Sf::Rgba8Unorm => "rgba8", + Sf::Rgba8Snorm => "rgba8_snorm", + Sf::Rgba8Uint => "rgba8ui", + Sf::Rgba8Sint => "rgba8i", + Sf::Rgb10a2Uint => "rgb10_a2ui", + Sf::Rgb10a2Unorm => "rgb10_a2", + Sf::Rg11b10Float => "r11f_g11f_b10f", + Sf::Rg32Uint => "rg32ui", + Sf::Rg32Sint => "rg32i", + Sf::Rg32Float => "rg32f", + Sf::Rgba16Uint => "rgba16ui", + Sf::Rgba16Sint => "rgba16i", + Sf::Rgba16Float => "rgba16f", + Sf::Rgba32Uint => "rgba32ui", + Sf::Rgba32Sint => "rgba32i", + Sf::Rgba32Float => "rgba32f", + Sf::R16Unorm => "r16", + Sf::R16Snorm => "r16_snorm", + Sf::Rg16Unorm => "rg16", + Sf::Rg16Snorm => "rg16_snorm", + Sf::Rgba16Unorm => "rgba16", + Sf::Rgba16Snorm => "rgba16_snorm", + + Sf::Bgra8Unorm => { + return Err(Error::Custom( + "Support format BGRA8 is not implemented".into(), + )) + } + }) +} + +fn is_value_init_supported(module: &crate::Module, ty: Handle<crate::Type>) -> bool { + match module.types[ty].inner { + TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => true, + TypeInner::Array { base, size, .. } => { + size != crate::ArraySize::Dynamic && is_value_init_supported(module, base) + } + TypeInner::Struct { ref members, .. } => members + .iter() + .all(|member| is_value_init_supported(module, member.ty)), + _ => false, + } +} diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs new file mode 100644 index 0000000000..b6918ddc42 --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/conv.rs @@ -0,0 +1,222 @@ +use std::borrow::Cow; + +use crate::proc::Alignment; + +use super::Error; + +impl crate::ScalarKind { + pub(super) fn to_hlsl_cast(self) -> &'static str { + match self { + Self::Float => "asfloat", + Self::Sint => "asint", + Self::Uint => "asuint", + Self::Bool | Self::AbstractInt | Self::AbstractFloat => unreachable!(), + } + } +} + +impl crate::Scalar { + /// Helper function that returns scalar related strings + /// + /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar> + pub(super) const fn to_hlsl_str(self) -> Result<&'static str, Error> { + match self.kind { + crate::ScalarKind::Sint => Ok("int"), + crate::ScalarKind::Uint => Ok("uint"), + crate::ScalarKind::Float => match self.width { + 2 => Ok("half"), + 4 => Ok("float"), + 8 => Ok("double"), + _ => Err(Error::UnsupportedScalar(self)), + }, + crate::ScalarKind::Bool => Ok("bool"), + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + Err(Error::UnsupportedScalar(self)) + } + } + } +} + +impl crate::TypeInner { + pub(super) const fn is_matrix(&self) -> bool { + match *self { + Self::Matrix { .. } => true, + _ => false, + } + } + + pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> u32 { + match *self { + Self::Matrix { + columns, + rows, + scalar, + } => { + let stride = Alignment::from(rows) * scalar.width as u32; + let last_row_size = rows as u32 * scalar.width as u32; + ((columns as u32 - 1) * stride) + last_row_size + } + Self::Array { base, size, stride } => { + let count = match size { + crate::ArraySize::Constant(size) => size.get(), + // A dynamically-sized array has to have at least one element + crate::ArraySize::Dynamic => 1, + }; + let last_el_size = gctx.types[base].inner.size_hlsl(gctx); + ((count - 1) * stride) + last_el_size + } + _ => self.size(gctx), + } + } + + /// Used to generate the name of the wrapped type constructor + pub(super) fn hlsl_type_id<'a>( + base: crate::Handle<crate::Type>, + gctx: crate::proc::GlobalCtx, + names: &'a crate::FastHashMap<crate::proc::NameKey, String>, + ) -> Result<Cow<'a, str>, Error> { + Ok(match gctx.types[base].inner { + crate::TypeInner::Scalar(scalar) => Cow::Borrowed(scalar.to_hlsl_str()?), + crate::TypeInner::Vector { size, scalar } => Cow::Owned(format!( + "{}{}", + scalar.to_hlsl_str()?, + crate::back::vector_size_str(size) + )), + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => Cow::Owned(format!( + "{}{}x{}", + scalar.to_hlsl_str()?, + crate::back::vector_size_str(columns), + crate::back::vector_size_str(rows), + )), + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + .. + } => Cow::Owned(format!( + "array{size}_{}_", + Self::hlsl_type_id(base, gctx, names)? + )), + crate::TypeInner::Struct { .. } => { + Cow::Borrowed(&names[&crate::proc::NameKey::Type(base)]) + } + _ => unreachable!(), + }) + } +} + +impl crate::StorageFormat { + pub(super) const fn to_hlsl_str(self) -> &'static str { + match self { + Self::R16Float => "float", + Self::R8Unorm | Self::R16Unorm => "unorm float", + Self::R8Snorm | Self::R16Snorm => "snorm float", + Self::R8Uint | Self::R16Uint => "uint", + Self::R8Sint | Self::R16Sint => "int", + + Self::Rg16Float => "float2", + Self::Rg8Unorm | Self::Rg16Unorm => "unorm float2", + Self::Rg8Snorm | Self::Rg16Snorm => "snorm float2", + + Self::Rg8Sint | Self::Rg16Sint => "int2", + Self::Rg8Uint | Self::Rg16Uint => "uint2", + + Self::Rg11b10Float => "float3", + + Self::Rgba16Float | Self::R32Float | Self::Rg32Float | Self::Rgba32Float => "float4", + Self::Rgba8Unorm | Self::Bgra8Unorm | Self::Rgba16Unorm | Self::Rgb10a2Unorm => { + "unorm float4" + } + Self::Rgba8Snorm | Self::Rgba16Snorm => "snorm float4", + + Self::Rgba8Uint + | Self::Rgba16Uint + | Self::R32Uint + | Self::Rg32Uint + | Self::Rgba32Uint + | Self::Rgb10a2Uint => "uint4", + Self::Rgba8Sint + | Self::Rgba16Sint + | Self::R32Sint + | Self::Rg32Sint + | Self::Rgba32Sint => "int4", + } + } +} + +impl crate::BuiltIn { + pub(super) fn to_hlsl_str(self) -> Result<&'static str, Error> { + Ok(match self { + Self::Position { .. } => "SV_Position", + // vertex + Self::ClipDistance => "SV_ClipDistance", + Self::CullDistance => "SV_CullDistance", + Self::InstanceIndex => "SV_InstanceID", + Self::VertexIndex => "SV_VertexID", + // fragment + Self::FragDepth => "SV_Depth", + Self::FrontFacing => "SV_IsFrontFace", + Self::PrimitiveIndex => "SV_PrimitiveID", + Self::SampleIndex => "SV_SampleIndex", + Self::SampleMask => "SV_Coverage", + // compute + Self::GlobalInvocationId => "SV_DispatchThreadID", + Self::LocalInvocationId => "SV_GroupThreadID", + Self::LocalInvocationIndex => "SV_GroupIndex", + Self::WorkGroupId => "SV_GroupID", + // The specific semantic we use here doesn't matter, because references + // to this field will get replaced with references to `SPECIAL_CBUF_VAR` + // in `Writer::write_expr`. + Self::NumWorkGroups => "SV_GroupID", + Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { + return Err(Error::Unimplemented(format!("builtin {self:?}"))) + } + Self::PointSize | Self::ViewIndex | Self::PointCoord => { + return Err(Error::Custom(format!("Unsupported builtin {self:?}"))) + } + }) + } +} + +impl crate::Interpolation { + /// Return the string corresponding to the HLSL interpolation qualifier. + pub(super) const fn to_hlsl_str(self) -> Option<&'static str> { + match self { + // Would be "linear", but it's the default interpolation in SM4 and up + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-struct#interpolation-modifiers-introduced-in-shader-model-4 + Self::Perspective => None, + Self::Linear => Some("noperspective"), + Self::Flat => Some("nointerpolation"), + } + } +} + +impl crate::Sampling { + /// Return the HLSL auxiliary qualifier for the given sampling value. + pub(super) const fn to_hlsl_str(self) -> Option<&'static str> { + match self { + Self::Center => None, + Self::Centroid => Some("centroid"), + Self::Sample => Some("sample"), + } + } +} + +impl crate::AtomicFunction { + /// Return the HLSL suffix for the `InterlockedXxx` method. + pub(super) const fn to_hlsl_suffix(self) -> &'static str { + match self { + Self::Add | Self::Subtract => "Add", + Self::And => "And", + Self::InclusiveOr => "Or", + Self::ExclusiveOr => "Xor", + Self::Min => "Min", + Self::Max => "Max", + Self::Exchange { compare: None } => "Exchange", + Self::Exchange { .. } => "", //TODO + } + } +} diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs new file mode 100644 index 0000000000..fa6062a1ad --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/help.rs @@ -0,0 +1,1138 @@ +/*! +Helpers for the hlsl backend + +Important note about `Expression::ImageQuery`/`Expression::ArrayLength` and hlsl backend: + +Due to implementation of `GetDimensions` function in hlsl (<https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>) +backend can't work with it as an expression. +Instead, it generates a unique wrapped function per `Expression::ImageQuery`, based on texture info and query function. +See `WrappedImageQuery` struct that represents a unique function and will be generated before writing all statements and expressions. +This allowed to works with `Expression::ImageQuery` as expression and write wrapped function. + +For example: +```wgsl +let dim_1d = textureDimensions(image_1d); +``` + +```hlsl +int NagaDimensions1D(Texture1D<float4>) +{ + uint4 ret; + image_1d.GetDimensions(ret.x); + return ret.x; +} + +int dim_1d = NagaDimensions1D(image_1d); +``` +*/ + +use super::{super::FunctionCtx, BackendResult}; +use crate::{arena::Handle, proc::NameKey}; +use std::fmt::Write; + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedArrayLength { + pub(super) writable: bool, +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedImageQuery { + pub(super) dim: crate::ImageDimension, + pub(super) arrayed: bool, + pub(super) class: crate::ImageClass, + pub(super) query: ImageQuery, +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedConstructor { + pub(super) ty: Handle<crate::Type>, +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedStructMatrixAccess { + pub(super) ty: Handle<crate::Type>, + pub(super) index: u32, +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedMatCx2 { + pub(super) columns: crate::VectorSize, +} + +/// HLSL backend requires its own `ImageQuery` enum. +/// +/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function. +/// IR version can't be unique per function, because it's store mipmap level as an expression. +/// +/// For example: +/// ```wgsl +/// let dim_cube_array_lod = textureDimensions(image_cube_array, 1); +/// let dim_cube_array_lod2 = textureDimensions(image_cube_array, 1); +/// ``` +/// +/// ```ir +/// ImageQuery { +/// image: [1], +/// query: Size { +/// level: Some( +/// [1], +/// ), +/// }, +/// }, +/// ImageQuery { +/// image: [1], +/// query: Size { +/// level: Some( +/// [2], +/// ), +/// }, +/// }, +/// ``` +/// +/// HLSL should generate only 1 function for this case. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) enum ImageQuery { + Size, + SizeLevel, + NumLevels, + NumLayers, + NumSamples, +} + +impl From<crate::ImageQuery> for ImageQuery { + fn from(q: crate::ImageQuery) -> Self { + use crate::ImageQuery as Iq; + match q { + Iq::Size { level: Some(_) } => ImageQuery::SizeLevel, + Iq::Size { level: None } => ImageQuery::Size, + Iq::NumLevels => ImageQuery::NumLevels, + Iq::NumLayers => ImageQuery::NumLayers, + Iq::NumSamples => ImageQuery::NumSamples, + } + } +} + +impl<'a, W: Write> super::Writer<'a, W> { + pub(super) fn write_image_type( + &mut self, + dim: crate::ImageDimension, + arrayed: bool, + class: crate::ImageClass, + ) -> BackendResult { + let access_str = match class { + crate::ImageClass::Storage { .. } => "RW", + _ => "", + }; + let dim_str = dim.to_hlsl_str(); + let arrayed_str = if arrayed { "Array" } else { "" }; + write!(self.out, "{access_str}Texture{dim_str}{arrayed_str}")?; + match class { + crate::ImageClass::Depth { multi } => { + let multi_str = if multi { "MS" } else { "" }; + write!(self.out, "{multi_str}<float>")? + } + crate::ImageClass::Sampled { kind, multi } => { + let multi_str = if multi { "MS" } else { "" }; + let scalar_kind_str = crate::Scalar { kind, width: 4 }.to_hlsl_str()?; + write!(self.out, "{multi_str}<{scalar_kind_str}4>")? + } + crate::ImageClass::Storage { format, .. } => { + let storage_format_str = format.to_hlsl_str(); + write!(self.out, "<{storage_format_str}>")? + } + } + Ok(()) + } + + pub(super) fn write_wrapped_array_length_function_name( + &mut self, + query: WrappedArrayLength, + ) -> BackendResult { + let access_str = if query.writable { "RW" } else { "" }; + write!(self.out, "NagaBufferLength{access_str}",)?; + + Ok(()) + } + + /// Helper function that write wrapped function for `Expression::ArrayLength` + /// + /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer-getdimensions> + pub(super) fn write_wrapped_array_length_function( + &mut self, + wal: WrappedArrayLength, + ) -> BackendResult { + use crate::back::INDENT; + + const ARGUMENT_VARIABLE_NAME: &str = "buffer"; + const RETURN_VARIABLE_NAME: &str = "ret"; + + // Write function return type and name + write!(self.out, "uint ")?; + self.write_wrapped_array_length_function_name(wal)?; + + // Write function parameters + write!(self.out, "(")?; + let access_str = if wal.writable { "RW" } else { "" }; + writeln!( + self.out, + "{access_str}ByteAddressBuffer {ARGUMENT_VARIABLE_NAME})" + )?; + // Write function body + writeln!(self.out, "{{")?; + + // Write `GetDimensions` function. + writeln!(self.out, "{INDENT}uint {RETURN_VARIABLE_NAME};")?; + writeln!( + self.out, + "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions({RETURN_VARIABLE_NAME});" + )?; + + // Write return value + writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_image_query_function_name( + &mut self, + query: WrappedImageQuery, + ) -> BackendResult { + let dim_str = query.dim.to_hlsl_str(); + let class_str = match query.class { + crate::ImageClass::Sampled { multi: true, .. } => "MS", + crate::ImageClass::Depth { multi: true } => "DepthMS", + crate::ImageClass::Depth { multi: false } => "Depth", + crate::ImageClass::Sampled { multi: false, .. } => "", + crate::ImageClass::Storage { .. } => "RW", + }; + let arrayed_str = if query.arrayed { "Array" } else { "" }; + let query_str = match query.query { + ImageQuery::Size => "Dimensions", + ImageQuery::SizeLevel => "MipDimensions", + ImageQuery::NumLevels => "NumLevels", + ImageQuery::NumLayers => "NumLayers", + ImageQuery::NumSamples => "NumSamples", + }; + + write!(self.out, "Naga{class_str}{query_str}{dim_str}{arrayed_str}")?; + + Ok(()) + } + + /// Helper function that write wrapped function for `Expression::ImageQuery` + /// + /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions> + pub(super) fn write_wrapped_image_query_function( + &mut self, + module: &crate::Module, + wiq: WrappedImageQuery, + expr_handle: Handle<crate::Expression>, + func_ctx: &FunctionCtx, + ) -> BackendResult { + use crate::{ + back::{COMPONENTS, INDENT}, + ImageDimension as IDim, + }; + + const ARGUMENT_VARIABLE_NAME: &str = "tex"; + const RETURN_VARIABLE_NAME: &str = "ret"; + const MIP_LEVEL_PARAM: &str = "mip_level"; + + // Write function return type and name + let ret_ty = func_ctx.resolve_type(expr_handle, &module.types); + self.write_value_type(module, ret_ty)?; + write!(self.out, " ")?; + self.write_wrapped_image_query_function_name(wiq)?; + + // Write function parameters + write!(self.out, "(")?; + // Texture always first parameter + self.write_image_type(wiq.dim, wiq.arrayed, wiq.class)?; + write!(self.out, " {ARGUMENT_VARIABLE_NAME}")?; + // Mipmap is a second parameter if exists + if let ImageQuery::SizeLevel = wiq.query { + write!(self.out, ", uint {MIP_LEVEL_PARAM}")?; + } + writeln!(self.out, ")")?; + + // Write function body + writeln!(self.out, "{{")?; + + let array_coords = usize::from(wiq.arrayed); + // extra parameter is the mip level count or the sample count + let extra_coords = match wiq.class { + crate::ImageClass::Storage { .. } => 0, + crate::ImageClass::Sampled { .. } | crate::ImageClass::Depth { .. } => 1, + }; + + // GetDimensions Overloaded Methods + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions#overloaded-methods + let (ret_swizzle, number_of_params) = match wiq.query { + ImageQuery::Size | ImageQuery::SizeLevel => { + let ret = match wiq.dim { + IDim::D1 => "x", + IDim::D2 => "xy", + IDim::D3 => "xyz", + IDim::Cube => "xy", + }; + (ret, ret.len() + array_coords + extra_coords) + } + ImageQuery::NumLevels | ImageQuery::NumSamples | ImageQuery::NumLayers => { + if wiq.arrayed || wiq.dim == IDim::D3 { + ("w", 4) + } else { + ("z", 3) + } + } + }; + + // Write `GetDimensions` function. + writeln!(self.out, "{INDENT}uint4 {RETURN_VARIABLE_NAME};")?; + write!(self.out, "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions(")?; + match wiq.query { + ImageQuery::SizeLevel => { + write!(self.out, "{MIP_LEVEL_PARAM}, ")?; + } + _ => match wiq.class { + crate::ImageClass::Sampled { multi: true, .. } + | crate::ImageClass::Depth { multi: true } + | crate::ImageClass::Storage { .. } => {} + _ => { + // Write zero mipmap level for supported types + write!(self.out, "0, ")?; + } + }, + } + + for component in COMPONENTS[..number_of_params - 1].iter() { + write!(self.out, "{RETURN_VARIABLE_NAME}.{component}, ")?; + } + + // write last parameter without comma and space for last parameter + write!( + self.out, + "{}.{}", + RETURN_VARIABLE_NAME, + COMPONENTS[number_of_params - 1] + )?; + + writeln!(self.out, ");")?; + + // Write return value + writeln!( + self.out, + "{INDENT}return {RETURN_VARIABLE_NAME}.{ret_swizzle};" + )?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_constructor_function_name( + &mut self, + module: &crate::Module, + constructor: WrappedConstructor, + ) -> BackendResult { + let name = crate::TypeInner::hlsl_type_id(constructor.ty, module.to_ctx(), &self.names)?; + write!(self.out, "Construct{name}")?; + Ok(()) + } + + /// Helper function that write wrapped function for `Expression::Compose` for structures. + pub(super) fn write_wrapped_constructor_function( + &mut self, + module: &crate::Module, + constructor: WrappedConstructor, + ) -> BackendResult { + use crate::back::INDENT; + + const ARGUMENT_VARIABLE_NAME: &str = "arg"; + const RETURN_VARIABLE_NAME: &str = "ret"; + + // Write function return type and name + if let crate::TypeInner::Array { base, size, .. } = module.types[constructor.ty].inner { + write!(self.out, "typedef ")?; + self.write_type(module, constructor.ty)?; + write!(self.out, " ret_")?; + self.write_wrapped_constructor_function_name(module, constructor)?; + self.write_array_size(module, base, size)?; + writeln!(self.out, ";")?; + + write!(self.out, "ret_")?; + self.write_wrapped_constructor_function_name(module, constructor)?; + } else { + self.write_type(module, constructor.ty)?; + } + write!(self.out, " ")?; + self.write_wrapped_constructor_function_name(module, constructor)?; + + // Write function parameters + write!(self.out, "(")?; + + let mut write_arg = |i, ty| -> BackendResult { + if i != 0 { + write!(self.out, ", ")?; + } + self.write_type(module, ty)?; + write!(self.out, " {ARGUMENT_VARIABLE_NAME}{i}")?; + if let crate::TypeInner::Array { base, size, .. } = module.types[ty].inner { + self.write_array_size(module, base, size)?; + } + Ok(()) + }; + + match module.types[constructor.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + for (i, member) in members.iter().enumerate() { + write_arg(i, member.ty)?; + } + } + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + .. + } => { + for i in 0..size.get() as usize { + write_arg(i, base)?; + } + } + _ => unreachable!(), + }; + + write!(self.out, ")")?; + + // Write function body + writeln!(self.out, " {{")?; + + match module.types[constructor.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + let struct_name = &self.names[&NameKey::Type(constructor.ty)]; + writeln!( + self.out, + "{INDENT}{struct_name} {RETURN_VARIABLE_NAME} = ({struct_name})0;" + )?; + for (i, member) in members.iter().enumerate() { + let field_name = &self.names[&NameKey::StructMember(constructor.ty, i as u32)]; + + match module.types[member.ty].inner { + crate::TypeInner::Matrix { + columns, + rows: crate::VectorSize::Bi, + .. + } if member.binding.is_none() => { + for j in 0..columns as u8 { + writeln!( + self.out, + "{INDENT}{RETURN_VARIABLE_NAME}.{field_name}_{j} = {ARGUMENT_VARIABLE_NAME}{i}[{j}];" + )?; + } + } + ref other => { + // We cast arrays of native HLSL `floatCx2`s to arrays of `matCx2`s + // (where the inner matrix is represented by a struct with C `float2` members). + // See the module-level block comment in mod.rs for details. + if let Some(super::writer::MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = super::writer::get_inner_matrix_data(module, member.ty) + { + write!( + self.out, + "{}{}.{} = (__mat{}x2", + INDENT, RETURN_VARIABLE_NAME, field_name, columns as u8 + )?; + if let crate::TypeInner::Array { base, size, .. } = *other { + self.write_array_size(module, base, size)?; + } + writeln!(self.out, "){ARGUMENT_VARIABLE_NAME}{i};",)?; + } else { + writeln!( + self.out, + "{INDENT}{RETURN_VARIABLE_NAME}.{field_name} = {ARGUMENT_VARIABLE_NAME}{i};", + )?; + } + } + } + } + } + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + .. + } => { + write!(self.out, "{INDENT}")?; + self.write_type(module, base)?; + write!(self.out, " {RETURN_VARIABLE_NAME}")?; + self.write_array_size(module, base, crate::ArraySize::Constant(size))?; + write!(self.out, " = {{ ")?; + for i in 0..size.get() { + if i != 0 { + write!(self.out, ", ")?; + } + write!(self.out, "{ARGUMENT_VARIABLE_NAME}{i}")?; + } + writeln!(self.out, " }};",)?; + } + _ => unreachable!(), + } + + // Write return value + writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_struct_matrix_get_function_name( + &mut self, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + let name = &self.names[&NameKey::Type(access.ty)]; + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + write!(self.out, "GetMat{field_name}On{name}")?; + Ok(()) + } + + /// Writes a function used to get a matCx2 from within a structure. + pub(super) fn write_wrapped_struct_matrix_get_function( + &mut self, + module: &crate::Module, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + use crate::back::INDENT; + + const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj"; + + // Write function return type and name + let member = match module.types[access.ty].inner { + crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize], + _ => unreachable!(), + }; + let ret_ty = &module.types[member.ty].inner; + self.write_value_type(module, ret_ty)?; + write!(self.out, " ")?; + self.write_wrapped_struct_matrix_get_function_name(access)?; + + // Write function parameters + write!(self.out, "(")?; + let struct_name = &self.names[&NameKey::Type(access.ty)]; + write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}")?; + + // Write function body + writeln!(self.out, ") {{")?; + + // Write return value + write!(self.out, "{INDENT}return ")?; + self.write_value_type(module, ret_ty)?; + write!(self.out, "(")?; + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + match module.types[member.ty].inner { + crate::TypeInner::Matrix { columns, .. } => { + for i in 0..columns as u8 { + if i != 0 { + write!(self.out, ", ")?; + } + write!(self.out, "{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}")?; + } + } + _ => unreachable!(), + } + writeln!(self.out, ");")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_struct_matrix_set_function_name( + &mut self, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + let name = &self.names[&NameKey::Type(access.ty)]; + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + write!(self.out, "SetMat{field_name}On{name}")?; + Ok(()) + } + + /// Writes a function used to set a matCx2 from within a structure. + pub(super) fn write_wrapped_struct_matrix_set_function( + &mut self, + module: &crate::Module, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + use crate::back::INDENT; + + const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj"; + const MATRIX_ARGUMENT_VARIABLE_NAME: &str = "mat"; + + // Write function return type and name + write!(self.out, "void ")?; + self.write_wrapped_struct_matrix_set_function_name(access)?; + + // Write function parameters + write!(self.out, "(")?; + let struct_name = &self.names[&NameKey::Type(access.ty)]; + write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?; + let member = match module.types[access.ty].inner { + crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize], + _ => unreachable!(), + }; + self.write_type(module, member.ty)?; + write!(self.out, " {MATRIX_ARGUMENT_VARIABLE_NAME}")?; + // Write function body + writeln!(self.out, ") {{")?; + + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + + match module.types[member.ty].inner { + crate::TypeInner::Matrix { columns, .. } => { + for i in 0..columns as u8 { + writeln!( + self.out, + "{INDENT}{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {MATRIX_ARGUMENT_VARIABLE_NAME}[{i}];" + )?; + } + } + _ => unreachable!(), + } + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_struct_matrix_set_vec_function_name( + &mut self, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + let name = &self.names[&NameKey::Type(access.ty)]; + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + write!(self.out, "SetMatVec{field_name}On{name}")?; + Ok(()) + } + + /// Writes a function used to set a vec2 on a matCx2 from within a structure. + pub(super) fn write_wrapped_struct_matrix_set_vec_function( + &mut self, + module: &crate::Module, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + use crate::back::INDENT; + + const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj"; + const VECTOR_ARGUMENT_VARIABLE_NAME: &str = "vec"; + const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx"; + + // Write function return type and name + write!(self.out, "void ")?; + self.write_wrapped_struct_matrix_set_vec_function_name(access)?; + + // Write function parameters + write!(self.out, "(")?; + let struct_name = &self.names[&NameKey::Type(access.ty)]; + write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?; + let member = match module.types[access.ty].inner { + crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize], + _ => unreachable!(), + }; + let vec_ty = match module.types[member.ty].inner { + crate::TypeInner::Matrix { rows, scalar, .. } => { + crate::TypeInner::Vector { size: rows, scalar } + } + _ => unreachable!(), + }; + self.write_value_type(module, &vec_ty)?; + write!( + self.out, + " {VECTOR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}" + )?; + + // Write function body + writeln!(self.out, ") {{")?; + + writeln!( + self.out, + "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{" + )?; + + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + + match module.types[member.ty].inner { + crate::TypeInner::Matrix { columns, .. } => { + for i in 0..columns as u8 { + writeln!( + self.out, + "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {VECTOR_ARGUMENT_VARIABLE_NAME}; break; }}" + )?; + } + } + _ => unreachable!(), + } + + writeln!(self.out, "{INDENT}}}")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_struct_matrix_set_scalar_function_name( + &mut self, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + let name = &self.names[&NameKey::Type(access.ty)]; + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + write!(self.out, "SetMatScalar{field_name}On{name}")?; + Ok(()) + } + + /// Writes a function used to set a float on a matCx2 from within a structure. + pub(super) fn write_wrapped_struct_matrix_set_scalar_function( + &mut self, + module: &crate::Module, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + use crate::back::INDENT; + + const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj"; + const SCALAR_ARGUMENT_VARIABLE_NAME: &str = "scalar"; + const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx"; + const VECTOR_INDEX_ARGUMENT_VARIABLE_NAME: &str = "vec_idx"; + + // Write function return type and name + write!(self.out, "void ")?; + self.write_wrapped_struct_matrix_set_scalar_function_name(access)?; + + // Write function parameters + write!(self.out, "(")?; + let struct_name = &self.names[&NameKey::Type(access.ty)]; + write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?; + let member = match module.types[access.ty].inner { + crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize], + _ => unreachable!(), + }; + let scalar_ty = match module.types[member.ty].inner { + crate::TypeInner::Matrix { scalar, .. } => crate::TypeInner::Scalar(scalar), + _ => unreachable!(), + }; + self.write_value_type(module, &scalar_ty)?; + write!( + self.out, + " {SCALAR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}, uint {VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}" + )?; + + // Write function body + writeln!(self.out, ") {{")?; + + writeln!( + self.out, + "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{" + )?; + + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + + match module.types[member.ty].inner { + crate::TypeInner::Matrix { columns, .. } => { + for i in 0..columns as u8 { + writeln!( + self.out, + "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}[{VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}] = {SCALAR_ARGUMENT_VARIABLE_NAME}; break; }}" + )?; + } + } + _ => unreachable!(), + } + + writeln!(self.out, "{INDENT}}}")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + /// Write functions to create special types. + pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult { + for (type_key, struct_ty) in module.special_types.predeclared_types.iter() { + match type_key { + &crate::PredeclaredType::ModfResult { size, width } + | &crate::PredeclaredType::FrexpResult { size, width } => { + let arg_type_name_owner; + let arg_type_name = if let Some(size) = size { + arg_type_name_owner = format!( + "{}{}", + if width == 8 { "double" } else { "float" }, + size as u8 + ); + &arg_type_name_owner + } else if width == 8 { + "double" + } else { + "float" + }; + + let (defined_func_name, called_func_name, second_field_name, sign_multiplier) = + if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) { + (super::writer::MODF_FUNCTION, "modf", "whole", "") + } else { + ( + super::writer::FREXP_FUNCTION, + "frexp", + "exp_", + "sign(arg) * ", + ) + }; + + let struct_name = &self.names[&NameKey::Type(*struct_ty)]; + + writeln!( + self.out, + "{struct_name} {defined_func_name}({arg_type_name} arg) {{ + {arg_type_name} other; + {struct_name} result; + result.fract = {sign_multiplier}{called_func_name}(arg, other); + result.{second_field_name} = other; + return result; +}}" + )?; + writeln!(self.out)?; + } + &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {} + } + } + + Ok(()) + } + + /// Helper function that writes compose wrapped functions + pub(super) fn write_wrapped_compose_functions( + &mut self, + module: &crate::Module, + expressions: &crate::Arena<crate::Expression>, + ) -> BackendResult { + for (handle, _) in expressions.iter() { + if let crate::Expression::Compose { ty, .. } = expressions[handle] { + match module.types[ty].inner { + crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => { + let constructor = WrappedConstructor { ty }; + if self.wrapped.constructors.insert(constructor) { + self.write_wrapped_constructor_function(module, constructor)?; + } + } + _ => {} + }; + } + } + Ok(()) + } + + /// Helper function that writes various wrapped functions + pub(super) fn write_wrapped_functions( + &mut self, + module: &crate::Module, + func_ctx: &FunctionCtx, + ) -> BackendResult { + self.write_wrapped_compose_functions(module, func_ctx.expressions)?; + + for (handle, _) in func_ctx.expressions.iter() { + match func_ctx.expressions[handle] { + crate::Expression::ArrayLength(expr) => { + let global_expr = match func_ctx.expressions[expr] { + crate::Expression::GlobalVariable(_) => expr, + crate::Expression::AccessIndex { base, index: _ } => base, + ref other => unreachable!("Array length of {:?}", other), + }; + let global_var = match func_ctx.expressions[global_expr] { + crate::Expression::GlobalVariable(var_handle) => { + &module.global_variables[var_handle] + } + ref other => unreachable!("Array length of base {:?}", other), + }; + let storage_access = match global_var.space { + crate::AddressSpace::Storage { access } => access, + _ => crate::StorageAccess::default(), + }; + let wal = WrappedArrayLength { + writable: storage_access.contains(crate::StorageAccess::STORE), + }; + + if self.wrapped.array_lengths.insert(wal) { + self.write_wrapped_array_length_function(wal)?; + } + } + crate::Expression::ImageQuery { image, query } => { + let wiq = match *func_ctx.resolve_type(image, &module.types) { + crate::TypeInner::Image { + dim, + arrayed, + class, + } => WrappedImageQuery { + dim, + arrayed, + class, + query: query.into(), + }, + _ => unreachable!("we only query images"), + }; + + if self.wrapped.image_queries.insert(wiq) { + self.write_wrapped_image_query_function(module, wiq, handle, func_ctx)?; + } + } + // Write `WrappedConstructor` for structs that are loaded from `AddressSpace::Storage` + // since they will later be used by the fn `write_storage_load` + crate::Expression::Load { pointer } => { + let pointer_space = func_ctx + .resolve_type(pointer, &module.types) + .pointer_space(); + + if let Some(crate::AddressSpace::Storage { .. }) = pointer_space { + if let Some(ty) = func_ctx.info[handle].ty.handle() { + write_wrapped_constructor(self, ty, module)?; + } + } + + fn write_wrapped_constructor<W: Write>( + writer: &mut super::Writer<'_, W>, + ty: Handle<crate::Type>, + module: &crate::Module, + ) -> BackendResult { + match module.types[ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + for member in members { + write_wrapped_constructor(writer, member.ty, module)?; + } + + let constructor = WrappedConstructor { ty }; + if writer.wrapped.constructors.insert(constructor) { + writer + .write_wrapped_constructor_function(module, constructor)?; + } + } + crate::TypeInner::Array { base, .. } => { + write_wrapped_constructor(writer, base, module)?; + + let constructor = WrappedConstructor { ty }; + if writer.wrapped.constructors.insert(constructor) { + writer + .write_wrapped_constructor_function(module, constructor)?; + } + } + _ => {} + }; + + Ok(()) + } + } + // We treat matrices of the form `matCx2` as a sequence of C `vec2`s + // (see top level module docs for details). + // + // The functions injected here are required to get the matrix accesses working. + crate::Expression::AccessIndex { base, index } => { + let base_ty_res = &func_ctx.info[base].ty; + let mut resolved = base_ty_res.inner_with(&module.types); + let base_ty_handle = match *resolved { + crate::TypeInner::Pointer { base, .. } => { + resolved = &module.types[base].inner; + Some(base) + } + _ => base_ty_res.handle(), + }; + if let crate::TypeInner::Struct { ref members, .. } = *resolved { + let member = &members[index as usize]; + + match module.types[member.ty].inner { + crate::TypeInner::Matrix { + rows: crate::VectorSize::Bi, + .. + } if member.binding.is_none() => { + let ty = base_ty_handle.unwrap(); + let access = WrappedStructMatrixAccess { ty, index }; + + if self.wrapped.struct_matrix_access.insert(access) { + self.write_wrapped_struct_matrix_get_function(module, access)?; + self.write_wrapped_struct_matrix_set_function(module, access)?; + self.write_wrapped_struct_matrix_set_vec_function( + module, access, + )?; + self.write_wrapped_struct_matrix_set_scalar_function( + module, access, + )?; + } + } + _ => {} + } + } + } + _ => {} + }; + } + + Ok(()) + } + + pub(super) fn write_texture_coordinates( + &mut self, + kind: &str, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + mip_level: Option<Handle<crate::Expression>>, + module: &crate::Module, + func_ctx: &FunctionCtx, + ) -> BackendResult { + // HLSL expects the array index to be merged with the coordinate + let extra = array_index.is_some() as usize + (mip_level.is_some()) as usize; + if extra == 0 { + self.write_expr(module, coordinate, func_ctx)?; + } else { + let num_coords = match *func_ctx.resolve_type(coordinate, &module.types) { + crate::TypeInner::Scalar { .. } => 1, + crate::TypeInner::Vector { size, .. } => size as usize, + _ => unreachable!(), + }; + write!(self.out, "{}{}(", kind, num_coords + extra)?; + self.write_expr(module, coordinate, func_ctx)?; + if let Some(expr) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + if let Some(expr) = mip_level { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + write!(self.out, ")")?; + } + Ok(()) + } + + pub(super) fn write_mat_cx2_typedef_and_functions( + &mut self, + WrappedMatCx2 { columns }: WrappedMatCx2, + ) -> BackendResult { + use crate::back::INDENT; + + // typedef + write!(self.out, "typedef struct {{ ")?; + for i in 0..columns as u8 { + write!(self.out, "float2 _{i}; ")?; + } + writeln!(self.out, "}} __mat{}x2;", columns as u8)?; + + // __get_col_of_mat + writeln!( + self.out, + "float2 __get_col_of_mat{}x2(__mat{}x2 mat, uint idx) {{", + columns as u8, columns as u8 + )?; + writeln!(self.out, "{INDENT}switch(idx) {{")?; + for i in 0..columns as u8 { + writeln!(self.out, "{INDENT}case {i}: {{ return mat._{i}; }}")?; + } + writeln!(self.out, "{INDENT}default: {{ return (float2)0; }}")?; + writeln!(self.out, "{INDENT}}}")?; + writeln!(self.out, "}}")?; + + // __set_col_of_mat + writeln!( + self.out, + "void __set_col_of_mat{}x2(__mat{}x2 mat, uint idx, float2 value) {{", + columns as u8, columns as u8 + )?; + writeln!(self.out, "{INDENT}switch(idx) {{")?; + for i in 0..columns as u8 { + writeln!(self.out, "{INDENT}case {i}: {{ mat._{i} = value; break; }}")?; + } + writeln!(self.out, "{INDENT}}}")?; + writeln!(self.out, "}}")?; + + // __set_el_of_mat + writeln!( + self.out, + "void __set_el_of_mat{}x2(__mat{}x2 mat, uint idx, uint vec_idx, float value) {{", + columns as u8, columns as u8 + )?; + writeln!(self.out, "{INDENT}switch(idx) {{")?; + for i in 0..columns as u8 { + writeln!( + self.out, + "{INDENT}case {i}: {{ mat._{i}[vec_idx] = value; break; }}" + )?; + } + writeln!(self.out, "{INDENT}}}")?; + writeln!(self.out, "}}")?; + + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_all_mat_cx2_typedefs_and_functions( + &mut self, + module: &crate::Module, + ) -> BackendResult { + for (handle, _) in module.global_variables.iter() { + let global = &module.global_variables[handle]; + + if global.space == crate::AddressSpace::Uniform { + if let Some(super::writer::MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = super::writer::get_inner_matrix_data(module, global.ty) + { + let entry = WrappedMatCx2 { columns }; + if self.wrapped.mat_cx2s.insert(entry) { + self.write_mat_cx2_typedef_and_functions(entry)?; + } + } + } + } + + for (_, ty) in module.types.iter() { + if let crate::TypeInner::Struct { ref members, .. } = ty.inner { + for member in members.iter() { + if let crate::TypeInner::Array { .. } = module.types[member.ty].inner { + if let Some(super::writer::MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = super::writer::get_inner_matrix_data(module, member.ty) + { + let entry = WrappedMatCx2 { columns }; + if self.wrapped.mat_cx2s.insert(entry) { + self.write_mat_cx2_typedef_and_functions(entry)?; + } + } + } + } + } + } + + Ok(()) + } +} diff --git a/third_party/rust/naga/src/back/hlsl/keywords.rs b/third_party/rust/naga/src/back/hlsl/keywords.rs new file mode 100644 index 0000000000..059e533ff7 --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/keywords.rs @@ -0,0 +1,904 @@ +// When compiling with FXC without strict mode, these keywords are actually case insensitive. +// If you compile with strict mode and specify a different casing like "Pass" instead in an identifier, FXC will give this error: +// "error X3086: alternate cases for 'pass' are deprecated in strict mode" +// This behavior is not documented anywhere, but as far as I can tell this is the full list. +pub const RESERVED_CASE_INSENSITIVE: &[&str] = &[ + "asm", + "decl", + "pass", + "technique", + "Texture1D", + "Texture2D", + "Texture3D", + "TextureCube", +]; + +pub const RESERVED: &[&str] = &[ + // FXC keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-keywords.md?plain=1#L99-L118 + "AppendStructuredBuffer", + "asm", + "asm_fragment", + "BlendState", + "bool", + "break", + "Buffer", + "ByteAddressBuffer", + "case", + "cbuffer", + "centroid", + "class", + "column_major", + "compile", + "compile_fragment", + "CompileShader", + "const", + "continue", + "ComputeShader", + "ConsumeStructuredBuffer", + "default", + "DepthStencilState", + "DepthStencilView", + "discard", + "do", + "double", + "DomainShader", + "dword", + "else", + "export", + "extern", + "false", + "float", + "for", + "fxgroup", + "GeometryShader", + "groupshared", + "half", + "Hullshader", + "if", + "in", + "inline", + "inout", + "InputPatch", + "int", + "interface", + "line", + "lineadj", + "linear", + "LineStream", + "matrix", + "min16float", + "min10float", + "min16int", + "min12int", + "min16uint", + "namespace", + "nointerpolation", + "noperspective", + "NULL", + "out", + "OutputPatch", + "packoffset", + "pass", + "pixelfragment", + "PixelShader", + "point", + "PointStream", + "precise", + "RasterizerState", + "RenderTargetView", + "return", + "register", + "row_major", + "RWBuffer", + "RWByteAddressBuffer", + "RWStructuredBuffer", + "RWTexture1D", + "RWTexture1DArray", + "RWTexture2D", + "RWTexture2DArray", + "RWTexture3D", + "sample", + "sampler", + "SamplerState", + "SamplerComparisonState", + "shared", + "snorm", + "stateblock", + "stateblock_state", + "static", + "string", + "struct", + "switch", + "StructuredBuffer", + "tbuffer", + "technique", + "technique10", + "technique11", + "texture", + "Texture1D", + "Texture1DArray", + "Texture2D", + "Texture2DArray", + "Texture2DMS", + "Texture2DMSArray", + "Texture3D", + "TextureCube", + "TextureCubeArray", + "true", + "typedef", + "triangle", + "triangleadj", + "TriangleStream", + "uint", + "uniform", + "unorm", + "unsigned", + "vector", + "vertexfragment", + "VertexShader", + "void", + "volatile", + "while", + // FXC reserved keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-reserved-words.md?plain=1#L19-L38 + "auto", + "case", + "catch", + "char", + "class", + "const_cast", + "default", + "delete", + "dynamic_cast", + "enum", + "explicit", + "friend", + "goto", + "long", + "mutable", + "new", + "operator", + "private", + "protected", + "public", + "reinterpret_cast", + "short", + "signed", + "sizeof", + "static_cast", + "template", + "this", + "throw", + "try", + "typename", + "union", + "unsigned", + "using", + "virtual", + // FXC intrinsics, from https://github.com/MicrosoftDocs/win32/blob/1682b99e203708f6f5eda972d966e30f3c1588de/desktop-src/direct3dhlsl/dx-graphics-hlsl-intrinsic-functions.md?plain=1#L26-L165 + "abort", + "abs", + "acos", + "all", + "AllMemoryBarrier", + "AllMemoryBarrierWithGroupSync", + "any", + "asdouble", + "asfloat", + "asin", + "asint", + "asuint", + "atan", + "atan2", + "ceil", + "CheckAccessFullyMapped", + "clamp", + "clip", + "cos", + "cosh", + "countbits", + "cross", + "D3DCOLORtoUBYTE4", + "ddx", + "ddx_coarse", + "ddx_fine", + "ddy", + "ddy_coarse", + "ddy_fine", + "degrees", + "determinant", + "DeviceMemoryBarrier", + "DeviceMemoryBarrierWithGroupSync", + "distance", + "dot", + "dst", + "errorf", + "EvaluateAttributeCentroid", + "EvaluateAttributeAtSample", + "EvaluateAttributeSnapped", + "exp", + "exp2", + "f16tof32", + "f32tof16", + "faceforward", + "firstbithigh", + "firstbitlow", + "floor", + "fma", + "fmod", + "frac", + "frexp", + "fwidth", + "GetRenderTargetSampleCount", + "GetRenderTargetSamplePosition", + "GroupMemoryBarrier", + "GroupMemoryBarrierWithGroupSync", + "InterlockedAdd", + "InterlockedAnd", + "InterlockedCompareExchange", + "InterlockedCompareStore", + "InterlockedExchange", + "InterlockedMax", + "InterlockedMin", + "InterlockedOr", + "InterlockedXor", + "isfinite", + "isinf", + "isnan", + "ldexp", + "length", + "lerp", + "lit", + "log", + "log10", + "log2", + "mad", + "max", + "min", + "modf", + "msad4", + "mul", + "noise", + "normalize", + "pow", + "printf", + "Process2DQuadTessFactorsAvg", + "Process2DQuadTessFactorsMax", + "Process2DQuadTessFactorsMin", + "ProcessIsolineTessFactors", + "ProcessQuadTessFactorsAvg", + "ProcessQuadTessFactorsMax", + "ProcessQuadTessFactorsMin", + "ProcessTriTessFactorsAvg", + "ProcessTriTessFactorsMax", + "ProcessTriTessFactorsMin", + "radians", + "rcp", + "reflect", + "refract", + "reversebits", + "round", + "rsqrt", + "saturate", + "sign", + "sin", + "sincos", + "sinh", + "smoothstep", + "sqrt", + "step", + "tan", + "tanh", + "tex1D", + "tex1Dbias", + "tex1Dgrad", + "tex1Dlod", + "tex1Dproj", + "tex2D", + "tex2Dbias", + "tex2Dgrad", + "tex2Dlod", + "tex2Dproj", + "tex3D", + "tex3Dbias", + "tex3Dgrad", + "tex3Dlod", + "tex3Dproj", + "texCUBE", + "texCUBEbias", + "texCUBEgrad", + "texCUBElod", + "texCUBEproj", + "transpose", + "trunc", + // DXC (reserved) keywords, from https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/include/clang/Basic/TokenKinds.def#L222-L648 + // with the KEYALL, KEYCXX, BOOLSUPPORT, WCHARSUPPORT, KEYHLSL options enabled (see https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/lib/Frontend/CompilerInvocation.cpp#L1199) + "auto", + "break", + "case", + "char", + "const", + "continue", + "default", + "do", + "double", + "else", + "enum", + "extern", + "float", + "for", + "goto", + "if", + "inline", + "int", + "long", + "register", + "return", + "short", + "signed", + "sizeof", + "static", + "struct", + "switch", + "typedef", + "union", + "unsigned", + "void", + "volatile", + "while", + "_Alignas", + "_Alignof", + "_Atomic", + "_Complex", + "_Generic", + "_Imaginary", + "_Noreturn", + "_Static_assert", + "_Thread_local", + "__func__", + "__objc_yes", + "__objc_no", + "asm", + "bool", + "catch", + "class", + "const_cast", + "delete", + "dynamic_cast", + "explicit", + "export", + "false", + "friend", + "mutable", + "namespace", + "new", + "operator", + "private", + "protected", + "public", + "reinterpret_cast", + "static_cast", + "template", + "this", + "throw", + "true", + "try", + "typename", + "typeid", + "using", + "virtual", + "wchar_t", + "_Decimal32", + "_Decimal64", + "_Decimal128", + "__null", + "__alignof", + "__attribute", + "__builtin_choose_expr", + "__builtin_offsetof", + "__builtin_va_arg", + "__extension__", + "__imag", + "__int128", + "__label__", + "__real", + "__thread", + "__FUNCTION__", + "__PRETTY_FUNCTION__", + "__is_nothrow_assignable", + "__is_constructible", + "__is_nothrow_constructible", + "__has_nothrow_assign", + "__has_nothrow_move_assign", + "__has_nothrow_copy", + "__has_nothrow_constructor", + "__has_trivial_assign", + "__has_trivial_move_assign", + "__has_trivial_copy", + "__has_trivial_constructor", + "__has_trivial_move_constructor", + "__has_trivial_destructor", + "__has_virtual_destructor", + "__is_abstract", + "__is_base_of", + "__is_class", + "__is_convertible_to", + "__is_empty", + "__is_enum", + "__is_final", + "__is_literal", + "__is_literal_type", + "__is_pod", + "__is_polymorphic", + "__is_trivial", + "__is_union", + "__is_trivially_constructible", + "__is_trivially_copyable", + "__is_trivially_assignable", + "__underlying_type", + "__is_lvalue_expr", + "__is_rvalue_expr", + "__is_arithmetic", + "__is_floating_point", + "__is_integral", + "__is_complete_type", + "__is_void", + "__is_array", + "__is_function", + "__is_reference", + "__is_lvalue_reference", + "__is_rvalue_reference", + "__is_fundamental", + "__is_object", + "__is_scalar", + "__is_compound", + "__is_pointer", + "__is_member_object_pointer", + "__is_member_function_pointer", + "__is_member_pointer", + "__is_const", + "__is_volatile", + "__is_standard_layout", + "__is_signed", + "__is_unsigned", + "__is_same", + "__is_convertible", + "__array_rank", + "__array_extent", + "__private_extern__", + "__module_private__", + "__declspec", + "__cdecl", + "__stdcall", + "__fastcall", + "__thiscall", + "__vectorcall", + "cbuffer", + "tbuffer", + "packoffset", + "linear", + "centroid", + "nointerpolation", + "noperspective", + "sample", + "column_major", + "row_major", + "in", + "out", + "inout", + "uniform", + "precise", + "center", + "shared", + "groupshared", + "discard", + "snorm", + "unorm", + "point", + "line", + "lineadj", + "triangle", + "triangleadj", + "globallycoherent", + "interface", + "sampler_state", + "technique", + "indices", + "vertices", + "primitives", + "payload", + "Technique", + "technique10", + "technique11", + "__builtin_omp_required_simd_align", + "__pascal", + "__fp16", + "__alignof__", + "__asm", + "__asm__", + "__attribute__", + "__complex", + "__complex__", + "__const", + "__const__", + "__decltype", + "__imag__", + "__inline", + "__inline__", + "__nullptr", + "__real__", + "__restrict", + "__restrict__", + "__signed", + "__signed__", + "__typeof", + "__typeof__", + "__volatile", + "__volatile__", + "_Nonnull", + "_Nullable", + "_Null_unspecified", + "__builtin_convertvector", + "__char16_t", + "__char32_t", + // DXC intrinsics, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/utils/hct/gen_intrin_main.txt#L86-L376 + "D3DCOLORtoUBYTE4", + "GetRenderTargetSampleCount", + "GetRenderTargetSamplePosition", + "abort", + "abs", + "acos", + "all", + "AllMemoryBarrier", + "AllMemoryBarrierWithGroupSync", + "any", + "asdouble", + "asfloat", + "asfloat16", + "asint16", + "asin", + "asint", + "asuint", + "asuint16", + "atan", + "atan2", + "ceil", + "clamp", + "clip", + "cos", + "cosh", + "countbits", + "cross", + "ddx", + "ddx_coarse", + "ddx_fine", + "ddy", + "ddy_coarse", + "ddy_fine", + "degrees", + "determinant", + "DeviceMemoryBarrier", + "DeviceMemoryBarrierWithGroupSync", + "distance", + "dot", + "dst", + "EvaluateAttributeAtSample", + "EvaluateAttributeCentroid", + "EvaluateAttributeSnapped", + "GetAttributeAtVertex", + "exp", + "exp2", + "f16tof32", + "f32tof16", + "faceforward", + "firstbithigh", + "firstbitlow", + "floor", + "fma", + "fmod", + "frac", + "frexp", + "fwidth", + "GroupMemoryBarrier", + "GroupMemoryBarrierWithGroupSync", + "InterlockedAdd", + "InterlockedMin", + "InterlockedMax", + "InterlockedAnd", + "InterlockedOr", + "InterlockedXor", + "InterlockedCompareStore", + "InterlockedExchange", + "InterlockedCompareExchange", + "InterlockedCompareStoreFloatBitwise", + "InterlockedCompareExchangeFloatBitwise", + "isfinite", + "isinf", + "isnan", + "ldexp", + "length", + "lerp", + "lit", + "log", + "log10", + "log2", + "mad", + "max", + "min", + "modf", + "msad4", + "mul", + "normalize", + "pow", + "printf", + "Process2DQuadTessFactorsAvg", + "Process2DQuadTessFactorsMax", + "Process2DQuadTessFactorsMin", + "ProcessIsolineTessFactors", + "ProcessQuadTessFactorsAvg", + "ProcessQuadTessFactorsMax", + "ProcessQuadTessFactorsMin", + "ProcessTriTessFactorsAvg", + "ProcessTriTessFactorsMax", + "ProcessTriTessFactorsMin", + "radians", + "rcp", + "reflect", + "refract", + "reversebits", + "round", + "rsqrt", + "saturate", + "sign", + "sin", + "sincos", + "sinh", + "smoothstep", + "source_mark", + "sqrt", + "step", + "tan", + "tanh", + "tex1D", + "tex1Dbias", + "tex1Dgrad", + "tex1Dlod", + "tex1Dproj", + "tex2D", + "tex2Dbias", + "tex2Dgrad", + "tex2Dlod", + "tex2Dproj", + "tex3D", + "tex3Dbias", + "tex3Dgrad", + "tex3Dlod", + "tex3Dproj", + "texCUBE", + "texCUBEbias", + "texCUBEgrad", + "texCUBElod", + "texCUBEproj", + "transpose", + "trunc", + "CheckAccessFullyMapped", + "AddUint64", + "NonUniformResourceIndex", + "WaveIsFirstLane", + "WaveGetLaneIndex", + "WaveGetLaneCount", + "WaveActiveAnyTrue", + "WaveActiveAllTrue", + "WaveActiveAllEqual", + "WaveActiveBallot", + "WaveReadLaneAt", + "WaveReadLaneFirst", + "WaveActiveCountBits", + "WaveActiveSum", + "WaveActiveProduct", + "WaveActiveBitAnd", + "WaveActiveBitOr", + "WaveActiveBitXor", + "WaveActiveMin", + "WaveActiveMax", + "WavePrefixCountBits", + "WavePrefixSum", + "WavePrefixProduct", + "WaveMatch", + "WaveMultiPrefixBitAnd", + "WaveMultiPrefixBitOr", + "WaveMultiPrefixBitXor", + "WaveMultiPrefixCountBits", + "WaveMultiPrefixProduct", + "WaveMultiPrefixSum", + "QuadReadLaneAt", + "QuadReadAcrossX", + "QuadReadAcrossY", + "QuadReadAcrossDiagonal", + "QuadAny", + "QuadAll", + "TraceRay", + "ReportHit", + "CallShader", + "IgnoreHit", + "AcceptHitAndEndSearch", + "DispatchRaysIndex", + "DispatchRaysDimensions", + "WorldRayOrigin", + "WorldRayDirection", + "ObjectRayOrigin", + "ObjectRayDirection", + "RayTMin", + "RayTCurrent", + "PrimitiveIndex", + "InstanceID", + "InstanceIndex", + "GeometryIndex", + "HitKind", + "RayFlags", + "ObjectToWorld", + "WorldToObject", + "ObjectToWorld3x4", + "WorldToObject3x4", + "ObjectToWorld4x3", + "WorldToObject4x3", + "dot4add_u8packed", + "dot4add_i8packed", + "dot2add", + "unpack_s8s16", + "unpack_u8u16", + "unpack_s8s32", + "unpack_u8u32", + "pack_s8", + "pack_u8", + "pack_clamp_s8", + "pack_clamp_u8", + "SetMeshOutputCounts", + "DispatchMesh", + "IsHelperLane", + "AllocateRayQuery", + "CreateResourceFromHeap", + "and", + "or", + "select", + // DXC resource and other types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/HlslTypes.cpp#L441-#L572 + "InputPatch", + "OutputPatch", + "PointStream", + "LineStream", + "TriangleStream", + "Texture1D", + "RWTexture1D", + "Texture2D", + "RWTexture2D", + "Texture2DMS", + "RWTexture2DMS", + "Texture3D", + "RWTexture3D", + "TextureCube", + "RWTextureCube", + "Texture1DArray", + "RWTexture1DArray", + "Texture2DArray", + "RWTexture2DArray", + "Texture2DMSArray", + "RWTexture2DMSArray", + "TextureCubeArray", + "RWTextureCubeArray", + "FeedbackTexture2D", + "FeedbackTexture2DArray", + "RasterizerOrderedTexture1D", + "RasterizerOrderedTexture2D", + "RasterizerOrderedTexture3D", + "RasterizerOrderedTexture1DArray", + "RasterizerOrderedTexture2DArray", + "RasterizerOrderedBuffer", + "RasterizerOrderedByteAddressBuffer", + "RasterizerOrderedStructuredBuffer", + "ByteAddressBuffer", + "RWByteAddressBuffer", + "StructuredBuffer", + "RWStructuredBuffer", + "AppendStructuredBuffer", + "ConsumeStructuredBuffer", + "Buffer", + "RWBuffer", + "SamplerState", + "SamplerComparisonState", + "ConstantBuffer", + "TextureBuffer", + "RaytracingAccelerationStructure", + // DXC templated types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp + // look for `BuiltinTypeDeclBuilder` + "matrix", + "vector", + "TextureBuffer", + "ConstantBuffer", + "RayQuery", + // Naga utilities + super::writer::MODF_FUNCTION, + super::writer::FREXP_FUNCTION, +]; + +// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254 +// + vector and matrix shorthands +pub const TYPES: &[&str] = &{ + const L: usize = 23 * (1 + 4 + 4 * 4); + let mut res = [""; L]; + let mut c = 0; + + /// For each scalar type, it will additionally generate vector and matrix shorthands + macro_rules! generate { + ([$($roots:literal),*], $x:tt) => { + $( + generate!(@inner push $roots); + generate!(@inner $roots, $x); + )* + }; + + (@inner $root:literal, [$($x:literal),*]) => { + generate!(@inner vector $root, $($x)*); + generate!(@inner matrix $root, $($x)*); + }; + + (@inner vector $root:literal, $($x:literal)*) => { + $( + generate!(@inner push concat!($root, $x)); + )* + }; + + (@inner matrix $root:literal, $($x:literal)*) => { + // Duplicate the list + generate!(@inner matrix $root, $($x)*; $($x)*); + }; + + // The head/tail recursion: pick the first element of the first list and recursively do it for the tail. + (@inner matrix $root:literal, $head:literal $($tail:literal)*; $($x:literal)*) => { + $( + generate!(@inner push concat!($root, $head, "x", $x)); + )* + generate!(@inner matrix $root, $($tail)*; $($x)*); + + }; + + // The end of iteration: we exhausted the list + (@inner matrix $root:literal, ; $($x:literal)*) => {}; + + (@inner push $v:expr) => { + res[c] = $v; + c += 1; + }; + } + + generate!( + [ + "bool", + "int", + "uint", + "dword", + "half", + "float", + "double", + "min10float", + "min16float", + "min12int", + "min16int", + "min16uint", + "int16_t", + "int32_t", + "int64_t", + "uint16_t", + "uint32_t", + "uint64_t", + "float16_t", + "float32_t", + "float64_t", + "int8_t4_packed", + "uint8_t4_packed" + ], + ["1", "2", "3", "4"] + ); + + debug_assert!(c == L); + + res +}; diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs new file mode 100644 index 0000000000..37ddbd3d67 --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/mod.rs @@ -0,0 +1,302 @@ +/*! +Backend for [HLSL][hlsl] (High-Level Shading Language). + +# Supported shader model versions: +- 5.0 +- 5.1 +- 6.0 + +# Layout of values in `uniform` buffers + +WGSL's ["Internal Layout of Values"][ilov] rules specify how each WGSL +type should be stored in `uniform` and `storage` buffers. The HLSL we +generate must access values in that form, even when it is not what +HLSL would use normally. + +The rules described here only apply to WGSL `uniform` variables. WGSL +`storage` buffers are translated as HLSL `ByteAddressBuffers`, for +which we generate `Load` and `Store` method calls with explicit byte +offsets. WGSL pipeline inputs must be scalars or vectors; they cannot +be matrices, which is where the interesting problems arise. + +## Row- and column-major ordering for matrices + +WGSL specifies that matrices in uniform buffers are stored in +column-major order. This matches HLSL's default, so one might expect +things to be straightforward. Unfortunately, WGSL and HLSL disagree on +what indexing a matrix means: in WGSL, `m[i]` retrieves the `i`'th +*column* of `m`, whereas in HLSL it retrieves the `i`'th *row*. We +want to avoid translating `m[i]` into some complicated reassembly of a +vector from individually fetched components, so this is a problem. + +However, with a bit of trickery, it is possible to use HLSL's `m[i]` +as the translation of WGSL's `m[i]`: + +- We declare all matrices in uniform buffers in HLSL with the + `row_major` qualifier, and transpose the row and column counts: a + WGSL `mat3x4<f32>`, say, becomes an HLSL `row_major float3x4`. (Note + that WGSL and HLSL type names put the row and column in reverse + order.) Since the HLSL type is the transpose of how WebGPU directs + the user to store the data, HLSL will load all matrices transposed. + +- Since matrices are transposed, an HLSL indexing expression retrieves + the "columns" of the intended WGSL value, as desired. + +- For vector-matrix multiplication, since `mul(transpose(m), v)` is + equivalent to `mul(v, m)` (note the reversal of the arguments), and + `mul(v, transpose(m))` is equivalent to `mul(m, v)`, we can + translate WGSL `m * v` and `v * m` to HLSL by simply reversing the + arguments to `mul`. + +## Padding in two-row matrices + +An HLSL `row_major floatKx2` matrix has padding between its rows that +the WGSL `matKx2<f32>` matrix it represents does not. HLSL stores all +matrix rows [aligned on 16-byte boundaries][16bb], whereas WGSL says +that the columns of a `matKx2<f32>` need only be [aligned as required +for `vec2<f32>`][ilov], which is [eight-byte alignment][8bb]. + +To compensate for this, any time a `matKx2<f32>` appears in a WGSL +`uniform` variable, whether directly as the variable's type or as part +of a struct/array, we actually emit `K` separate `float2` members, and +assemble/disassemble the matrix from its columns (in WGSL; rows in +HLSL) upon load and store. + +For example, the following WGSL struct type: + +```ignore +struct Baz { + m: mat3x2<f32>, +} +``` + +is rendered as the HLSL struct type: + +```ignore +struct Baz { + float2 m_0; float2 m_1; float2 m_2; +}; +``` + +The `wrapped_struct_matrix` functions in `help.rs` generate HLSL +helper functions to access such members, converting between the stored +form and the HLSL matrix types appropriately. For example, for reading +the member `m` of the `Baz` struct above, we emit: + +```ignore +float3x2 GetMatmOnBaz(Baz obj) { + return float3x2(obj.m_0, obj.m_1, obj.m_2); +} +``` + +We also emit an analogous `Set` function, as well as functions for +accessing individual columns by dynamic index. + +[hlsl]: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl +[ilov]: https://gpuweb.github.io/gpuweb/wgsl/#internal-value-layout +[16bb]: https://github.com/microsoft/DirectXShaderCompiler/wiki/Buffer-Packing#constant-buffer-packing +[8bb]: https://gpuweb.github.io/gpuweb/wgsl/#alignment-and-size +*/ + +mod conv; +mod help; +mod keywords; +mod storage; +mod writer; + +use std::fmt::Error as FmtError; +use thiserror::Error; + +use crate::{back, proc}; + +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct BindTarget { + pub space: u8, + pub register: u32, + /// If the binding is an unsized binding array, this overrides the size. + pub binding_array_size: Option<u32>, +} + +// Using `BTreeMap` instead of `HashMap` so that we can hash itself. +pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>; + +/// A HLSL shader model version. +#[allow(non_snake_case, non_camel_case_types)] +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum ShaderModel { + V5_0, + V5_1, + V6_0, +} + +impl ShaderModel { + pub const fn to_str(self) -> &'static str { + match self { + Self::V5_0 => "5_0", + Self::V5_1 => "5_1", + Self::V6_0 => "6_0", + } + } +} + +impl crate::ShaderStage { + pub const fn to_hlsl_str(self) -> &'static str { + match self { + Self::Vertex => "vs", + Self::Fragment => "ps", + Self::Compute => "cs", + } + } +} + +impl crate::ImageDimension { + const fn to_hlsl_str(self) -> &'static str { + match self { + Self::D1 => "1D", + Self::D2 => "2D", + Self::D3 => "3D", + Self::Cube => "Cube", + } + } +} + +/// Shorthand result used internally by the backend +type BackendResult = Result<(), Error>; + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum EntryPointError { + #[error("mapping of {0:?} is missing")] + MissingBinding(crate::ResourceBinding), +} + +/// Configuration used in the [`Writer`]. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Options { + /// The hlsl shader model to be used + pub shader_model: ShaderModel, + /// Map of resources association to binding locations. + pub binding_map: BindingMap, + /// Don't panic on missing bindings, instead generate any HLSL. + pub fake_missing_bindings: bool, + /// Add special constants to `SV_VertexIndex` and `SV_InstanceIndex`, + /// to make them work like in Vulkan/Metal, with help of the host. + pub special_constants_binding: Option<BindTarget>, + /// Bind target of the push constant buffer + pub push_constants_target: Option<BindTarget>, + /// Should workgroup variables be zero initialized (by polyfilling)? + pub zero_initialize_workgroup_memory: bool, +} + +impl Default for Options { + fn default() -> Self { + Options { + shader_model: ShaderModel::V5_1, + binding_map: BindingMap::default(), + fake_missing_bindings: true, + special_constants_binding: None, + push_constants_target: None, + zero_initialize_workgroup_memory: true, + } + } +} + +impl Options { + fn resolve_resource_binding( + &self, + res_binding: &crate::ResourceBinding, + ) -> Result<BindTarget, EntryPointError> { + match self.binding_map.get(res_binding) { + Some(target) => Ok(target.clone()), + None if self.fake_missing_bindings => Ok(BindTarget { + space: res_binding.group as u8, + register: res_binding.binding, + binding_array_size: None, + }), + None => Err(EntryPointError::MissingBinding(res_binding.clone())), + } + } +} + +/// Reflection info for entry point names. +#[derive(Default)] +pub struct ReflectionInfo { + /// Mapping of the entry point names. + /// + /// Each item in the array corresponds to an entry point index. The real entry point name may be different if one of the + /// reserved words are used. + /// + /// Note: Some entry points may fail translation because of missing bindings. + pub entry_point_names: Vec<Result<String, EntryPointError>>, +} + +#[derive(Error, Debug)] +pub enum Error { + #[error(transparent)] + IoError(#[from] FmtError), + #[error("A scalar with an unsupported width was requested: {0:?}")] + UnsupportedScalar(crate::Scalar), + #[error("{0}")] + Unimplemented(String), // TODO: Error used only during development + #[error("{0}")] + Custom(String), +} + +#[derive(Default)] +struct Wrapped { + array_lengths: crate::FastHashSet<help::WrappedArrayLength>, + image_queries: crate::FastHashSet<help::WrappedImageQuery>, + constructors: crate::FastHashSet<help::WrappedConstructor>, + struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>, + mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>, +} + +impl Wrapped { + fn clear(&mut self) { + self.array_lengths.clear(); + self.image_queries.clear(); + self.constructors.clear(); + self.struct_matrix_access.clear(); + self.mat_cx2s.clear(); + } +} + +pub struct Writer<'a, W> { + out: W, + names: crate::FastHashMap<proc::NameKey, String>, + namer: proc::Namer, + /// HLSL backend options + options: &'a Options, + /// Information about entry point arguments and result types. + entry_point_io: Vec<writer::EntryPointInterface>, + /// Set of expressions that have associated temporary variables + named_expressions: crate::NamedExpressions, + wrapped: Wrapped, + + /// A reference to some part of a global variable, lowered to a series of + /// byte offset calculations. + /// + /// See the [`storage`] module for background on why we need this. + /// + /// Each [`SubAccess`] in the vector is a lowering of some [`Access`] or + /// [`AccessIndex`] expression to the level of byte strides and offsets. See + /// [`SubAccess`] for details. + /// + /// This field is a member of [`Writer`] solely to allow re-use of + /// the `Vec`'s dynamic allocation. The value is no longer needed + /// once HLSL for the access has been generated. + /// + /// [`Storage`]: crate::AddressSpace::Storage + /// [`SubAccess`]: storage::SubAccess + /// [`Access`]: crate::Expression::Access + /// [`AccessIndex`]: crate::Expression::AccessIndex + temp_access_chain: Vec<storage::SubAccess>, + need_bake_expressions: back::NeedBakeExpressions, +} diff --git a/third_party/rust/naga/src/back/hlsl/storage.rs b/third_party/rust/naga/src/back/hlsl/storage.rs new file mode 100644 index 0000000000..1b8a6ec12d --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/storage.rs @@ -0,0 +1,494 @@ +/*! +Generating accesses to [`ByteAddressBuffer`] contents. + +Naga IR globals in the [`Storage`] address space are rendered as +[`ByteAddressBuffer`]s or [`RWByteAddressBuffer`]s in HLSL. These +buffers don't have HLSL types (structs, arrays, etc.); instead, they +are just raw blocks of bytes, with methods to load and store values of +specific types at particular byte offsets. This means that Naga must +translate chains of [`Access`] and [`AccessIndex`] expressions into +HLSL expressions that compute byte offsets into the buffer. + +To generate code for a [`Storage`] access: + +- Call [`Writer::fill_access_chain`] on the expression referring to + the value. This populates [`Writer::temp_access_chain`] with the + appropriate byte offset calculations, as a vector of [`SubAccess`] + values. + +- Call [`Writer::write_storage_address`] to emit an HLSL expression + for a given slice of [`SubAccess`] values. + +Naga IR expressions can operate on composite values of any type, but +[`ByteAddressBuffer`] and [`RWByteAddressBuffer`] have only a fixed +set of `Load` and `Store` methods, to access one through four +consecutive 32-bit values. To synthesize a Naga access, you can +initialize [`temp_access_chain`] to refer to the composite, and then +temporarily push and pop additional steps on +[`Writer::temp_access_chain`] to generate accesses to the individual +elements/members. + +The [`temp_access_chain`] field is a member of [`Writer`] solely to +allow re-use of the `Vec`'s dynamic allocation. Its value is no longer +needed once HLSL for the access has been generated. + +[`Storage`]: crate::AddressSpace::Storage +[`ByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-byteaddressbuffer +[`RWByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer +[`Access`]: crate::Expression::Access +[`AccessIndex`]: crate::Expression::AccessIndex +[`Writer::fill_access_chain`]: super::Writer::fill_access_chain +[`Writer::write_storage_address`]: super::Writer::write_storage_address +[`Writer::temp_access_chain`]: super::Writer::temp_access_chain +[`temp_access_chain`]: super::Writer::temp_access_chain +[`Writer`]: super::Writer +*/ + +use super::{super::FunctionCtx, BackendResult, Error}; +use crate::{ + proc::{Alignment, NameKey, TypeResolution}, + Handle, +}; + +use std::{fmt, mem}; + +const STORE_TEMP_NAME: &str = "_value"; + +/// One step in accessing a [`Storage`] global's component or element. +/// +/// [`Writer::temp_access_chain`] holds a series of these structures, +/// describing how to compute the byte offset of a particular element +/// or member of some global variable in the [`Storage`] address +/// space. +/// +/// [`Writer::temp_access_chain`]: super::Writer::temp_access_chain +/// [`Storage`]: crate::AddressSpace::Storage +#[derive(Debug)] +pub(super) enum SubAccess { + /// Add the given byte offset. This is used for struct members, or + /// known components of a vector or matrix. In all those cases, + /// the byte offset is a compile-time constant. + Offset(u32), + + /// Scale `value` by `stride`, and add that to the current byte + /// offset. This is used to compute the offset of an array element + /// whose index is computed at runtime. + Index { + value: Handle<crate::Expression>, + stride: u32, + }, +} + +pub(super) enum StoreValue { + Expression(Handle<crate::Expression>), + TempIndex { + depth: usize, + index: u32, + ty: TypeResolution, + }, + TempAccess { + depth: usize, + base: Handle<crate::Type>, + member_index: u32, + }, +} + +impl<W: fmt::Write> super::Writer<'_, W> { + pub(super) fn write_storage_address( + &mut self, + module: &crate::Module, + chain: &[SubAccess], + func_ctx: &FunctionCtx, + ) -> BackendResult { + if chain.is_empty() { + write!(self.out, "0")?; + } + for (i, access) in chain.iter().enumerate() { + if i != 0 { + write!(self.out, "+")?; + } + match *access { + SubAccess::Offset(offset) => { + write!(self.out, "{offset}")?; + } + SubAccess::Index { value, stride } => { + self.write_expr(module, value, func_ctx)?; + write!(self.out, "*{stride}")?; + } + } + } + Ok(()) + } + + fn write_storage_load_sequence<I: Iterator<Item = (TypeResolution, u32)>>( + &mut self, + module: &crate::Module, + var_handle: Handle<crate::GlobalVariable>, + sequence: I, + func_ctx: &FunctionCtx, + ) -> BackendResult { + for (i, (ty_resolution, offset)) in sequence.enumerate() { + // add the index temporarily + self.temp_access_chain.push(SubAccess::Offset(offset)); + if i != 0 { + write!(self.out, ", ")?; + }; + self.write_storage_load(module, var_handle, ty_resolution, func_ctx)?; + self.temp_access_chain.pop(); + } + Ok(()) + } + + /// Emit code to access a [`Storage`] global's component. + /// + /// Emit HLSL to access the component of `var_handle`, a global + /// variable in the [`Storage`] address space, whose type is + /// `result_ty` and whose location within the global is given by + /// [`self.temp_access_chain`]. See the [`storage`] module's + /// documentation for background. + /// + /// [`Storage`]: crate::AddressSpace::Storage + /// [`self.temp_access_chain`]: super::Writer::temp_access_chain + pub(super) fn write_storage_load( + &mut self, + module: &crate::Module, + var_handle: Handle<crate::GlobalVariable>, + result_ty: TypeResolution, + func_ctx: &FunctionCtx, + ) -> BackendResult { + match *result_ty.inner_with(&module.types) { + crate::TypeInner::Scalar(scalar) => { + // working around the borrow checker in `self.write_expr` + let chain = mem::take(&mut self.temp_access_chain); + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + let cast = scalar.kind.to_hlsl_cast(); + write!(self.out, "{cast}({var_name}.Load(")?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, "))")?; + self.temp_access_chain = chain; + } + crate::TypeInner::Vector { size, scalar } => { + // working around the borrow checker in `self.write_expr` + let chain = mem::take(&mut self.temp_access_chain); + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + let cast = scalar.kind.to_hlsl_cast(); + write!(self.out, "{}({}.Load{}(", cast, var_name, size as u8)?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, "))")?; + self.temp_access_chain = chain; + } + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => { + write!( + self.out, + "{}{}x{}(", + scalar.to_hlsl_str()?, + columns as u8, + rows as u8, + )?; + + // Note: Matrices containing vec3s, due to padding, act like they contain vec4s. + let row_stride = Alignment::from(rows) * scalar.width as u32; + let iter = (0..columns as u32).map(|i| { + let ty_inner = crate::TypeInner::Vector { size: rows, scalar }; + (TypeResolution::Value(ty_inner), i * row_stride) + }); + self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; + write!(self.out, ")")?; + } + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + stride, + } => { + let constructor = super::help::WrappedConstructor { + ty: result_ty.handle().unwrap(), + }; + self.write_wrapped_constructor_function_name(module, constructor)?; + write!(self.out, "(")?; + let iter = (0..size.get()).map(|i| (TypeResolution::Handle(base), stride * i)); + self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; + write!(self.out, ")")?; + } + crate::TypeInner::Struct { ref members, .. } => { + let constructor = super::help::WrappedConstructor { + ty: result_ty.handle().unwrap(), + }; + self.write_wrapped_constructor_function_name(module, constructor)?; + write!(self.out, "(")?; + let iter = members + .iter() + .map(|m| (TypeResolution::Handle(m.ty), m.offset)); + self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; + write!(self.out, ")")?; + } + _ => unreachable!(), + } + Ok(()) + } + + fn write_store_value( + &mut self, + module: &crate::Module, + value: &StoreValue, + func_ctx: &FunctionCtx, + ) -> BackendResult { + match *value { + StoreValue::Expression(expr) => self.write_expr(module, expr, func_ctx)?, + StoreValue::TempIndex { + depth, + index, + ty: _, + } => write!(self.out, "{STORE_TEMP_NAME}{depth}[{index}]")?, + StoreValue::TempAccess { + depth, + base, + member_index, + } => { + let name = &self.names[&NameKey::StructMember(base, member_index)]; + write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")? + } + } + Ok(()) + } + + /// Helper function to write down the Store operation on a `ByteAddressBuffer`. + pub(super) fn write_storage_store( + &mut self, + module: &crate::Module, + var_handle: Handle<crate::GlobalVariable>, + value: StoreValue, + func_ctx: &FunctionCtx, + level: crate::back::Level, + ) -> BackendResult { + let temp_resolution; + let ty_resolution = match value { + StoreValue::Expression(expr) => &func_ctx.info[expr].ty, + StoreValue::TempIndex { + depth: _, + index: _, + ref ty, + } => ty, + StoreValue::TempAccess { + depth: _, + base, + member_index, + } => { + let ty_handle = match module.types[base].inner { + crate::TypeInner::Struct { ref members, .. } => { + members[member_index as usize].ty + } + _ => unreachable!(), + }; + temp_resolution = TypeResolution::Handle(ty_handle); + &temp_resolution + } + }; + match *ty_resolution.inner_with(&module.types) { + crate::TypeInner::Scalar(_) => { + // working around the borrow checker in `self.write_expr` + let chain = mem::take(&mut self.temp_access_chain); + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + write!(self.out, "{level}{var_name}.Store(")?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", asuint(")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, "));")?; + self.temp_access_chain = chain; + } + crate::TypeInner::Vector { size, .. } => { + // working around the borrow checker in `self.write_expr` + let chain = mem::take(&mut self.temp_access_chain); + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", asuint(")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, "));")?; + self.temp_access_chain = chain; + } + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => { + // first, assign the value to a temporary + writeln!(self.out, "{level}{{")?; + let depth = level.0 + 1; + write!( + self.out, + "{}{}{}x{} {}{} = ", + level.next(), + scalar.to_hlsl_str()?, + columns as u8, + rows as u8, + STORE_TEMP_NAME, + depth, + )?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, ";")?; + + // Note: Matrices containing vec3s, due to padding, act like they contain vec4s. + let row_stride = Alignment::from(rows) * scalar.width as u32; + + // then iterate the stores + for i in 0..columns as u32 { + self.temp_access_chain + .push(SubAccess::Offset(i * row_stride)); + let ty_inner = crate::TypeInner::Vector { size: rows, scalar }; + let sv = StoreValue::TempIndex { + depth, + index: i, + ty: TypeResolution::Value(ty_inner), + }; + self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?; + self.temp_access_chain.pop(); + } + // done + writeln!(self.out, "{level}}}")?; + } + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + stride, + } => { + // first, assign the value to a temporary + writeln!(self.out, "{level}{{")?; + write!(self.out, "{}", level.next())?; + self.write_value_type(module, &module.types[base].inner)?; + let depth = level.next().0; + write!(self.out, " {STORE_TEMP_NAME}{depth}")?; + self.write_array_size(module, base, crate::ArraySize::Constant(size))?; + write!(self.out, " = ")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, ";")?; + // then iterate the stores + for i in 0..size.get() { + self.temp_access_chain.push(SubAccess::Offset(i * stride)); + let sv = StoreValue::TempIndex { + depth, + index: i, + ty: TypeResolution::Handle(base), + }; + self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?; + self.temp_access_chain.pop(); + } + // done + writeln!(self.out, "{level}}}")?; + } + crate::TypeInner::Struct { ref members, .. } => { + // first, assign the value to a temporary + writeln!(self.out, "{level}{{")?; + let depth = level.next().0; + let struct_ty = ty_resolution.handle().unwrap(); + let struct_name = &self.names[&NameKey::Type(struct_ty)]; + write!( + self.out, + "{}{} {}{} = ", + level.next(), + struct_name, + STORE_TEMP_NAME, + depth + )?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, ";")?; + // then iterate the stores + for (i, member) in members.iter().enumerate() { + self.temp_access_chain + .push(SubAccess::Offset(member.offset)); + let sv = StoreValue::TempAccess { + depth, + base: struct_ty, + member_index: i as u32, + }; + self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?; + self.temp_access_chain.pop(); + } + // done + writeln!(self.out, "{level}}}")?; + } + _ => unreachable!(), + } + Ok(()) + } + + /// Set [`temp_access_chain`] to compute the byte offset of `cur_expr`. + /// + /// The `cur_expr` expression must be a reference to a global + /// variable in the [`Storage`] address space, or a chain of + /// [`Access`] and [`AccessIndex`] expressions referring to some + /// component of such a global. + /// + /// [`temp_access_chain`]: super::Writer::temp_access_chain + /// [`Storage`]: crate::AddressSpace::Storage + /// [`Access`]: crate::Expression::Access + /// [`AccessIndex`]: crate::Expression::AccessIndex + pub(super) fn fill_access_chain( + &mut self, + module: &crate::Module, + mut cur_expr: Handle<crate::Expression>, + func_ctx: &FunctionCtx, + ) -> Result<Handle<crate::GlobalVariable>, Error> { + enum AccessIndex { + Expression(Handle<crate::Expression>), + Constant(u32), + } + enum Parent<'a> { + Array { stride: u32 }, + Struct(&'a [crate::StructMember]), + } + self.temp_access_chain.clear(); + + loop { + let (next_expr, access_index) = match func_ctx.expressions[cur_expr] { + crate::Expression::GlobalVariable(handle) => return Ok(handle), + crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)), + crate::Expression::AccessIndex { base, index } => { + (base, AccessIndex::Constant(index)) + } + ref other => { + return Err(Error::Unimplemented(format!("Pointer access of {other:?}"))) + } + }; + + let parent = match *func_ctx.resolve_type(next_expr, &module.types) { + crate::TypeInner::Pointer { base, .. } => match module.types[base].inner { + crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members), + crate::TypeInner::Array { stride, .. } => Parent::Array { stride }, + crate::TypeInner::Vector { scalar, .. } => Parent::Array { + stride: scalar.width as u32, + }, + crate::TypeInner::Matrix { rows, scalar, .. } => Parent::Array { + // The stride between matrices is the count of rows as this is how + // long each column is. + stride: Alignment::from(rows) * scalar.width as u32, + }, + _ => unreachable!(), + }, + crate::TypeInner::ValuePointer { scalar, .. } => Parent::Array { + stride: scalar.width as u32, + }, + _ => unreachable!(), + }; + + let sub = match (parent, access_index) { + (Parent::Array { stride }, AccessIndex::Expression(value)) => { + SubAccess::Index { value, stride } + } + (Parent::Array { stride }, AccessIndex::Constant(index)) => { + SubAccess::Offset(stride * index) + } + (Parent::Struct(members), AccessIndex::Constant(index)) => { + SubAccess::Offset(members[index as usize].offset) + } + (Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(), + }; + + self.temp_access_chain.push(sub); + cur_expr = next_expr; + } + } +} diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs new file mode 100644 index 0000000000..43f7212837 --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/writer.rs @@ -0,0 +1,3366 @@ +use super::{ + help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess}, + storage::StoreValue, + BackendResult, Error, Options, +}; +use crate::{ + back, + proc::{self, NameKey}, + valid, Handle, Module, ScalarKind, ShaderStage, TypeInner, +}; +use std::{fmt, mem}; + +const LOCATION_SEMANTIC: &str = "LOC"; +const SPECIAL_CBUF_TYPE: &str = "NagaConstants"; +const SPECIAL_CBUF_VAR: &str = "_NagaConstants"; +const SPECIAL_FIRST_VERTEX: &str = "first_vertex"; +const SPECIAL_FIRST_INSTANCE: &str = "first_instance"; +const SPECIAL_OTHER: &str = "other"; + +pub(crate) const MODF_FUNCTION: &str = "naga_modf"; +pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; + +struct EpStructMember { + name: String, + ty: Handle<crate::Type>, + // technically, this should always be `Some` + binding: Option<crate::Binding>, + index: u32, +} + +/// Structure contains information required for generating +/// wrapped structure of all entry points arguments +struct EntryPointBinding { + /// Name of the fake EP argument that contains the struct + /// with all the flattened input data. + arg_name: String, + /// Generated structure name + ty_name: String, + /// Members of generated structure + members: Vec<EpStructMember>, +} + +pub(super) struct EntryPointInterface { + /// If `Some`, the input of an entry point is gathered in a special + /// struct with members sorted by binding. + /// The `EntryPointBinding::members` array is sorted by index, + /// so that we can walk it in `write_ep_arguments_initialization`. + input: Option<EntryPointBinding>, + /// If `Some`, the output of an entry point is flattened. + /// The `EntryPointBinding::members` array is sorted by binding, + /// So that we can walk it in `Statement::Return` handler. + output: Option<EntryPointBinding>, +} + +#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)] +enum InterfaceKey { + Location(u32), + BuiltIn(crate::BuiltIn), + Other, +} + +impl InterfaceKey { + const fn new(binding: Option<&crate::Binding>) -> Self { + match binding { + Some(&crate::Binding::Location { location, .. }) => Self::Location(location), + Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in), + None => Self::Other, + } + } +} + +#[derive(Copy, Clone, PartialEq)] +enum Io { + Input, + Output, +} + +impl<'a, W: fmt::Write> super::Writer<'a, W> { + pub fn new(out: W, options: &'a Options) -> Self { + Self { + out, + names: crate::FastHashMap::default(), + namer: proc::Namer::default(), + options, + entry_point_io: Vec::new(), + named_expressions: crate::NamedExpressions::default(), + wrapped: super::Wrapped::default(), + temp_access_chain: Vec::new(), + need_bake_expressions: Default::default(), + } + } + + fn reset(&mut self, module: &Module) { + self.names.clear(); + self.namer.reset( + module, + super::keywords::RESERVED, + super::keywords::TYPES, + super::keywords::RESERVED_CASE_INSENSITIVE, + &[], + &mut self.names, + ); + self.entry_point_io.clear(); + self.named_expressions.clear(); + self.wrapped.clear(); + self.need_bake_expressions.clear(); + } + + /// Helper method used to find which expressions of a given function require baking + /// + /// # Notes + /// Clears `need_bake_expressions` set before adding to it + fn update_expressions_to_bake( + &mut self, + module: &Module, + func: &crate::Function, + info: &valid::FunctionInfo, + ) { + use crate::Expression; + self.need_bake_expressions.clear(); + for (fun_handle, expr) in func.expressions.iter() { + let expr_info = &info[fun_handle]; + let min_ref_count = func.expressions[fun_handle].bake_ref_count(); + if min_ref_count <= expr_info.ref_count { + self.need_bake_expressions.insert(fun_handle); + } + + if let Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } = *expr + { + match fun { + crate::MathFunction::Asinh + | crate::MathFunction::Acosh + | crate::MathFunction::Atanh + | crate::MathFunction::Unpack2x16float + | crate::MathFunction::Unpack2x16snorm + | crate::MathFunction::Unpack2x16unorm + | crate::MathFunction::Unpack4x8snorm + | crate::MathFunction::Unpack4x8unorm + | crate::MathFunction::Pack2x16float + | crate::MathFunction::Pack2x16snorm + | crate::MathFunction::Pack2x16unorm + | crate::MathFunction::Pack4x8snorm + | crate::MathFunction::Pack4x8unorm => { + self.need_bake_expressions.insert(arg); + } + crate::MathFunction::ExtractBits => { + self.need_bake_expressions.insert(arg); + self.need_bake_expressions.insert(arg1.unwrap()); + self.need_bake_expressions.insert(arg2.unwrap()); + } + crate::MathFunction::InsertBits => { + self.need_bake_expressions.insert(arg); + self.need_bake_expressions.insert(arg1.unwrap()); + self.need_bake_expressions.insert(arg2.unwrap()); + self.need_bake_expressions.insert(arg3.unwrap()); + } + crate::MathFunction::CountLeadingZeros => { + let inner = info[fun_handle].ty.inner_with(&module.types); + if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { + self.need_bake_expressions.insert(arg); + } + } + _ => {} + } + } + + if let Expression::Derivative { axis, ctrl, expr } = *expr { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) { + self.need_bake_expressions.insert(expr); + } + } + } + } + + pub fn write( + &mut self, + module: &Module, + module_info: &valid::ModuleInfo, + ) -> Result<super::ReflectionInfo, Error> { + self.reset(module); + + // Write special constants, if needed + if let Some(ref bt) = self.options.special_constants_binding { + writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?; + writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?; + writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?; + writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?; + writeln!(self.out, "}};")?; + write!( + self.out, + "ConstantBuffer<{}> {}: register(b{}", + SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register + )?; + if bt.space != 0 { + write!(self.out, ", space{}", bt.space)?; + } + writeln!(self.out, ");")?; + + // Extra newline for readability + writeln!(self.out)?; + } + + // Save all entry point output types + let ep_results = module + .entry_points + .iter() + .map(|ep| (ep.stage, ep.function.result.clone())) + .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>(); + + self.write_all_mat_cx2_typedefs_and_functions(module)?; + + // Write all structs + for (handle, ty) in module.types.iter() { + if let TypeInner::Struct { ref members, span } = ty.inner { + if module.types[members.last().unwrap().ty] + .inner + .is_dynamically_sized(&module.types) + { + // unsized arrays can only be in storage buffers, + // for which we use `ByteAddressBuffer` anyway. + continue; + } + + let ep_result = ep_results.iter().find(|e| { + if let Some(ref result) = e.1 { + result.ty == handle + } else { + false + } + }); + + self.write_struct( + module, + handle, + members, + span, + ep_result.map(|r| (r.0, Io::Output)), + )?; + writeln!(self.out)?; + } + } + + self.write_special_functions(module)?; + + self.write_wrapped_compose_functions(module, &module.const_expressions)?; + + // Write all named constants + let mut constants = module + .constants + .iter() + .filter(|&(_, c)| c.name.is_some()) + .peekable(); + while let Some((handle, _)) = constants.next() { + self.write_global_constant(module, handle)?; + // Add extra newline for readability on last iteration + if constants.peek().is_none() { + writeln!(self.out)?; + } + } + + // Write all globals + for (ty, _) in module.global_variables.iter() { + self.write_global(module, ty)?; + } + + if !module.global_variables.is_empty() { + // Add extra newline for readability + writeln!(self.out)?; + } + + // Write all entry points wrapped structs + for (index, ep) in module.entry_points.iter().enumerate() { + let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone(); + let ep_io = self.write_ep_interface(module, &ep.function, ep.stage, &ep_name)?; + self.entry_point_io.push(ep_io); + } + + // Write all regular functions + for (handle, function) in module.functions.iter() { + let info = &module_info[handle]; + + // Check if all of the globals are accessible + if !self.options.fake_missing_bindings { + if let Some((var_handle, _)) = + module + .global_variables + .iter() + .find(|&(var_handle, var)| match var.binding { + Some(ref binding) if !info[var_handle].is_empty() => { + self.options.resolve_resource_binding(binding).is_err() + } + _ => false, + }) + { + log::info!( + "Skipping function {:?} (name {:?}) because global {:?} is inaccessible", + handle, + function.name, + var_handle + ); + continue; + } + } + + let ctx = back::FunctionCtx { + ty: back::FunctionType::Function(handle), + info, + expressions: &function.expressions, + named_expressions: &function.named_expressions, + }; + let name = self.names[&NameKey::Function(handle)].clone(); + + self.write_wrapped_functions(module, &ctx)?; + + self.write_function(module, name.as_str(), function, &ctx, info)?; + + writeln!(self.out)?; + } + + let mut entry_point_names = Vec::with_capacity(module.entry_points.len()); + + // Write all entry points + for (index, ep) in module.entry_points.iter().enumerate() { + let info = module_info.get_entry_point(index); + + if !self.options.fake_missing_bindings { + let mut ep_error = None; + for (var_handle, var) in module.global_variables.iter() { + match var.binding { + Some(ref binding) if !info[var_handle].is_empty() => { + if let Err(err) = self.options.resolve_resource_binding(binding) { + ep_error = Some(err); + break; + } + } + _ => {} + } + } + if let Some(err) = ep_error { + entry_point_names.push(Err(err)); + continue; + } + } + + let ctx = back::FunctionCtx { + ty: back::FunctionType::EntryPoint(index as u16), + info, + expressions: &ep.function.expressions, + named_expressions: &ep.function.named_expressions, + }; + + self.write_wrapped_functions(module, &ctx)?; + + if ep.stage == ShaderStage::Compute { + // HLSL is calling workgroup size "num threads" + let num_threads = ep.workgroup_size; + writeln!( + self.out, + "[numthreads({}, {}, {})]", + num_threads[0], num_threads[1], num_threads[2] + )?; + } + + let name = self.names[&NameKey::EntryPoint(index as u16)].clone(); + self.write_function(module, &name, &ep.function, &ctx, info)?; + + if index < module.entry_points.len() - 1 { + writeln!(self.out)?; + } + + entry_point_names.push(Ok(name)); + } + + Ok(super::ReflectionInfo { entry_point_names }) + } + + fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult { + match *binding { + crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => { + write!(self.out, "precise ")?; + } + crate::Binding::Location { + interpolation, + sampling, + .. + } => { + if let Some(interpolation) = interpolation { + if let Some(string) = interpolation.to_hlsl_str() { + write!(self.out, "{string} ")? + } + } + + if let Some(sampling) = sampling { + if let Some(string) = sampling.to_hlsl_str() { + write!(self.out, "{string} ")? + } + } + } + crate::Binding::BuiltIn(_) => {} + } + + Ok(()) + } + + //TODO: we could force fragment outputs to always go through `entry_point_io.output` path + // if they are struct, so that the `stage` argument here could be omitted. + fn write_semantic( + &mut self, + binding: &crate::Binding, + stage: Option<(ShaderStage, Io)>, + ) -> BackendResult { + match *binding { + crate::Binding::BuiltIn(builtin) => { + let builtin_str = builtin.to_hlsl_str()?; + write!(self.out, " : {builtin_str}")?; + } + crate::Binding::Location { + second_blend_source: true, + .. + } => { + write!(self.out, " : SV_Target1")?; + } + crate::Binding::Location { + location, + second_blend_source: false, + .. + } => { + if stage == Some((crate::ShaderStage::Fragment, Io::Output)) { + write!(self.out, " : SV_Target{location}")?; + } else { + write!(self.out, " : {LOCATION_SEMANTIC}{location}")?; + } + } + } + + Ok(()) + } + + fn write_interface_struct( + &mut self, + module: &Module, + shader_stage: (ShaderStage, Io), + struct_name: String, + mut members: Vec<EpStructMember>, + ) -> Result<EntryPointBinding, Error> { + // Sort the members so that first come the user-defined varyings + // in ascending locations, and then built-ins. This allows VS and FS + // interfaces to match with regards to order. + members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref())); + + write!(self.out, "struct {struct_name}")?; + writeln!(self.out, " {{")?; + for m in members.iter() { + write!(self.out, "{}", back::INDENT)?; + if let Some(ref binding) = m.binding { + self.write_modifier(binding)?; + } + self.write_type(module, m.ty)?; + write!(self.out, " {}", &m.name)?; + if let Some(ref binding) = m.binding { + self.write_semantic(binding, Some(shader_stage))?; + } + writeln!(self.out, ";")?; + } + writeln!(self.out, "}};")?; + writeln!(self.out)?; + + match shader_stage.1 { + Io::Input => { + // bring back the original order + members.sort_by_key(|m| m.index); + } + Io::Output => { + // keep it sorted by binding + } + } + + Ok(EntryPointBinding { + arg_name: self.namer.call(struct_name.to_lowercase().as_str()), + ty_name: struct_name, + members, + }) + } + + /// Flatten all entry point arguments into a single struct. + /// This is needed since we need to re-order them: first placing user locations, + /// then built-ins. + fn write_ep_input_struct( + &mut self, + module: &Module, + func: &crate::Function, + stage: ShaderStage, + entry_point_name: &str, + ) -> Result<EntryPointBinding, Error> { + let struct_name = format!("{stage:?}Input_{entry_point_name}"); + + let mut fake_members = Vec::new(); + for arg in func.arguments.iter() { + match module.types[arg.ty].inner { + TypeInner::Struct { ref members, .. } => { + for member in members.iter() { + let name = self.namer.call_or(&member.name, "member"); + let index = fake_members.len() as u32; + fake_members.push(EpStructMember { + name, + ty: member.ty, + binding: member.binding.clone(), + index, + }); + } + } + _ => { + let member_name = self.namer.call_or(&arg.name, "member"); + let index = fake_members.len() as u32; + fake_members.push(EpStructMember { + name: member_name, + ty: arg.ty, + binding: arg.binding.clone(), + index, + }); + } + } + } + + self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members) + } + + /// Flatten all entry point results into a single struct. + /// This is needed since we need to re-order them: first placing user locations, + /// then built-ins. + fn write_ep_output_struct( + &mut self, + module: &Module, + result: &crate::FunctionResult, + stage: ShaderStage, + entry_point_name: &str, + ) -> Result<EntryPointBinding, Error> { + let struct_name = format!("{stage:?}Output_{entry_point_name}"); + + let mut fake_members = Vec::new(); + let empty = []; + let members = match module.types[result.ty].inner { + TypeInner::Struct { ref members, .. } => members, + ref other => { + log::error!("Unexpected {:?} output type without a binding", other); + &empty[..] + } + }; + + for member in members.iter() { + let member_name = self.namer.call_or(&member.name, "member"); + let index = fake_members.len() as u32; + fake_members.push(EpStructMember { + name: member_name, + ty: member.ty, + binding: member.binding.clone(), + index, + }); + } + + self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members) + } + + /// Writes special interface structures for an entry point. The special structures have + /// all the fields flattened into them and sorted by binding. They are only needed for + /// VS outputs and FS inputs, so that these interfaces match. + fn write_ep_interface( + &mut self, + module: &Module, + func: &crate::Function, + stage: ShaderStage, + ep_name: &str, + ) -> Result<EntryPointInterface, Error> { + Ok(EntryPointInterface { + input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment { + Some(self.write_ep_input_struct(module, func, stage, ep_name)?) + } else { + None + }, + output: match func.result { + Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => { + Some(self.write_ep_output_struct(module, fr, stage, ep_name)?) + } + _ => None, + }, + }) + } + + /// Write an entry point preface that initializes the arguments as specified in IR. + fn write_ep_arguments_initialization( + &mut self, + module: &Module, + func: &crate::Function, + ep_index: u16, + ) -> BackendResult { + let ep_input = match self.entry_point_io[ep_index as usize].input.take() { + Some(ep_input) => ep_input, + None => return Ok(()), + }; + let mut fake_iter = ep_input.members.iter(); + for (arg_index, arg) in func.arguments.iter().enumerate() { + write!(self.out, "{}", back::INDENT)?; + self.write_type(module, arg.ty)?; + let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)]; + write!(self.out, " {arg_name}")?; + match module.types[arg.ty].inner { + TypeInner::Array { base, size, .. } => { + self.write_array_size(module, base, size)?; + let fake_member = fake_iter.next().unwrap(); + writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?; + } + TypeInner::Struct { ref members, .. } => { + write!(self.out, " = {{ ")?; + for index in 0..members.len() { + if index != 0 { + write!(self.out, ", ")?; + } + let fake_member = fake_iter.next().unwrap(); + write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?; + } + writeln!(self.out, " }};")?; + } + _ => { + let fake_member = fake_iter.next().unwrap(); + writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?; + } + } + } + assert!(fake_iter.next().is_none()); + Ok(()) + } + + /// Helper method used to write global variables + /// # Notes + /// Always adds a newline + fn write_global( + &mut self, + module: &Module, + handle: Handle<crate::GlobalVariable>, + ) -> BackendResult { + let global = &module.global_variables[handle]; + let inner = &module.types[global.ty].inner; + + if let Some(ref binding) = global.binding { + if let Err(err) = self.options.resolve_resource_binding(binding) { + log::info!( + "Skipping global {:?} (name {:?}) for being inaccessible: {}", + handle, + global.name, + err, + ); + return Ok(()); + } + } + + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register + let register_ty = match global.space { + crate::AddressSpace::Function => unreachable!("Function address space"), + crate::AddressSpace::Private => { + write!(self.out, "static ")?; + self.write_type(module, global.ty)?; + "" + } + crate::AddressSpace::WorkGroup => { + write!(self.out, "groupshared ")?; + self.write_type(module, global.ty)?; + "" + } + crate::AddressSpace::Uniform => { + // constant buffer declarations are expected to be inlined, e.g. + // `cbuffer foo: register(b0) { field1: type1; }` + write!(self.out, "cbuffer")?; + "b" + } + crate::AddressSpace::Storage { access } => { + let (prefix, register) = if access.contains(crate::StorageAccess::STORE) { + ("RW", "u") + } else { + ("", "t") + }; + write!(self.out, "{prefix}ByteAddressBuffer")?; + register + } + crate::AddressSpace::Handle => { + let handle_ty = match *inner { + TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner, + _ => inner, + }; + + let register = match *handle_ty { + TypeInner::Sampler { .. } => "s", + // all storage textures are UAV, unconditionally + TypeInner::Image { + class: crate::ImageClass::Storage { .. }, + .. + } => "u", + _ => "t", + }; + self.write_type(module, global.ty)?; + register + } + crate::AddressSpace::PushConstant => { + // The type of the push constants will be wrapped in `ConstantBuffer` + write!(self.out, "ConstantBuffer<")?; + "b" + } + }; + + // If the global is a push constant write the type now because it will be a + // generic argument to `ConstantBuffer` + if global.space == crate::AddressSpace::PushConstant { + self.write_global_type(module, global.ty)?; + + // need to write the array size if the type was emitted with `write_type` + if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner { + self.write_array_size(module, base, size)?; + } + + // Close the angled brackets for the generic argument + write!(self.out, ">")?; + } + + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, " {name}")?; + + // Push constants need to be assigned a binding explicitly by the consumer + // since naga has no way to know the binding from the shader alone + if global.space == crate::AddressSpace::PushConstant { + let target = self + .options + .push_constants_target + .as_ref() + .expect("No bind target was defined for the push constants block"); + write!(self.out, ": register(b{}", target.register)?; + if target.space != 0 { + write!(self.out, ", space{}", target.space)?; + } + write!(self.out, ")")?; + } + + if let Some(ref binding) = global.binding { + // this was already resolved earlier when we started evaluating an entry point. + let bt = self.options.resolve_resource_binding(binding).unwrap(); + + // need to write the binding array size if the type was emitted with `write_type` + if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner { + if let Some(overridden_size) = bt.binding_array_size { + write!(self.out, "[{overridden_size}]")?; + } else { + self.write_array_size(module, base, size)?; + } + } + + write!(self.out, " : register({}{}", register_ty, bt.register)?; + if bt.space != 0 { + write!(self.out, ", space{}", bt.space)?; + } + write!(self.out, ")")?; + } else { + // need to write the array size if the type was emitted with `write_type` + if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner { + self.write_array_size(module, base, size)?; + } + if global.space == crate::AddressSpace::Private { + write!(self.out, " = ")?; + if let Some(init) = global.init { + self.write_const_expression(module, init)?; + } else { + self.write_default_init(module, global.ty)?; + } + } + } + + if global.space == crate::AddressSpace::Uniform { + write!(self.out, " {{ ")?; + + self.write_global_type(module, global.ty)?; + + write!( + self.out, + " {}", + &self.names[&NameKey::GlobalVariable(handle)] + )?; + + // need to write the array size if the type was emitted with `write_type` + if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner { + self.write_array_size(module, base, size)?; + } + + writeln!(self.out, "; }}")?; + } else { + writeln!(self.out, ";")?; + } + + Ok(()) + } + + /// Helper method used to write global constants + /// + /// # Notes + /// Ends in a newline + fn write_global_constant( + &mut self, + module: &Module, + handle: Handle<crate::Constant>, + ) -> BackendResult { + write!(self.out, "static const ")?; + let constant = &module.constants[handle]; + self.write_type(module, constant.ty)?; + let name = &self.names[&NameKey::Constant(handle)]; + write!(self.out, " {}", name)?; + // Write size for array type + if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner { + self.write_array_size(module, base, size)?; + } + write!(self.out, " = ")?; + self.write_const_expression(module, constant.init)?; + writeln!(self.out, ";")?; + Ok(()) + } + + pub(super) fn write_array_size( + &mut self, + module: &Module, + base: Handle<crate::Type>, + size: crate::ArraySize, + ) -> BackendResult { + write!(self.out, "[")?; + + match size { + crate::ArraySize::Constant(size) => { + write!(self.out, "{size}")?; + } + crate::ArraySize::Dynamic => unreachable!(), + } + + write!(self.out, "]")?; + + if let TypeInner::Array { + base: next_base, + size: next_size, + .. + } = module.types[base].inner + { + self.write_array_size(module, next_base, next_size)?; + } + + Ok(()) + } + + /// Helper method used to write structs + /// + /// # Notes + /// Ends in a newline + fn write_struct( + &mut self, + module: &Module, + handle: Handle<crate::Type>, + members: &[crate::StructMember], + span: u32, + shader_stage: Option<(ShaderStage, Io)>, + ) -> BackendResult { + // Write struct name + let struct_name = &self.names[&NameKey::Type(handle)]; + writeln!(self.out, "struct {struct_name} {{")?; + + let mut last_offset = 0; + for (index, member) in members.iter().enumerate() { + if member.binding.is_none() && member.offset > last_offset { + // using int as padding should work as long as the backend + // doesn't support a type that's less than 4 bytes in size + // (Error::UnsupportedScalar catches this) + let padding = (member.offset - last_offset) / 4; + for i in 0..padding { + writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?; + } + } + let ty_inner = &module.types[member.ty].inner; + last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx()); + + // The indentation is only for readability + write!(self.out, "{}", back::INDENT)?; + + match module.types[member.ty].inner { + TypeInner::Array { base, size, .. } => { + // HLSL arrays are written as `type name[size]` + + self.write_global_type(module, member.ty)?; + + // Write `name` + write!( + self.out, + " {}", + &self.names[&NameKey::StructMember(handle, index as u32)] + )?; + // Write [size] + self.write_array_size(module, base, size)?; + } + // We treat matrices of the form `matCx2` as a sequence of C `vec2`s. + // See the module-level block comment in mod.rs for details. + TypeInner::Matrix { + rows, + columns, + scalar, + } if member.binding.is_none() && rows == crate::VectorSize::Bi => { + let vec_ty = crate::TypeInner::Vector { size: rows, scalar }; + let field_name_key = NameKey::StructMember(handle, index as u32); + + for i in 0..columns as u8 { + if i != 0 { + write!(self.out, "; ")?; + } + self.write_value_type(module, &vec_ty)?; + write!(self.out, " {}_{}", &self.names[&field_name_key], i)?; + } + } + _ => { + // Write modifier before type + if let Some(ref binding) = member.binding { + self.write_modifier(binding)?; + } + + // Even though Naga IR matrices are column-major, we must describe + // matrices passed from the CPU as being in row-major order. + // See the module-level block comment in mod.rs for details. + if let TypeInner::Matrix { .. } = module.types[member.ty].inner { + write!(self.out, "row_major ")?; + } + + // Write the member type and name + self.write_type(module, member.ty)?; + write!( + self.out, + " {}", + &self.names[&NameKey::StructMember(handle, index as u32)] + )?; + } + } + + if let Some(ref binding) = member.binding { + self.write_semantic(binding, shader_stage)?; + }; + writeln!(self.out, ";")?; + } + + // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL + if members.last().unwrap().binding.is_none() && span > last_offset { + let padding = (span - last_offset) / 4; + for i in 0..padding { + writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?; + } + } + + writeln!(self.out, "}};")?; + Ok(()) + } + + /// Helper method used to write global/structs non image/sampler types + /// + /// # Notes + /// Adds no trailing or leading whitespace + pub(super) fn write_global_type( + &mut self, + module: &Module, + ty: Handle<crate::Type>, + ) -> BackendResult { + let matrix_data = get_inner_matrix_data(module, ty); + + // We treat matrices of the form `matCx2` as a sequence of C `vec2`s. + // See the module-level block comment in mod.rs for details. + if let Some(MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = matrix_data + { + write!(self.out, "__mat{}x2", columns as u8)?; + } else { + // Even though Naga IR matrices are column-major, we must describe + // matrices passed from the CPU as being in row-major order. + // See the module-level block comment in mod.rs for details. + if matrix_data.is_some() { + write!(self.out, "row_major ")?; + } + + self.write_type(module, ty)?; + } + + Ok(()) + } + + /// Helper method used to write non image/sampler types + /// + /// # Notes + /// Adds no trailing or leading whitespace + pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult { + let inner = &module.types[ty].inner; + match *inner { + TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?, + // hlsl array has the size separated from the base type + TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => { + self.write_type(module, base)? + } + ref other => self.write_value_type(module, other)?, + } + + Ok(()) + } + + /// Helper method used to write value types + /// + /// # Notes + /// Adds no trailing or leading whitespace + pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult { + match *inner { + TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => { + write!(self.out, "{}", scalar.to_hlsl_str()?)?; + } + TypeInner::Vector { size, scalar } => { + write!( + self.out, + "{}{}", + scalar.to_hlsl_str()?, + back::vector_size_str(size) + )?; + } + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + // The IR supports only float matrix + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix + + // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well. + write!( + self.out, + "{}{}x{}", + scalar.to_hlsl_str()?, + back::vector_size_str(columns), + back::vector_size_str(rows), + )?; + } + TypeInner::Image { + dim, + arrayed, + class, + } => { + self.write_image_type(dim, arrayed, class)?; + } + TypeInner::Sampler { comparison } => { + let sampler = if comparison { + "SamplerComparisonState" + } else { + "SamplerState" + }; + write!(self.out, "{sampler}")?; + } + // HLSL arrays are written as `type name[size]` + // Current code is written arrays only as `[size]` + // Base `type` and `name` should be written outside + TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => { + self.write_array_size(module, base, size)?; + } + _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))), + } + + Ok(()) + } + + /// Helper method used to write functions + /// # Notes + /// Ends in a newline + fn write_function( + &mut self, + module: &Module, + name: &str, + func: &crate::Function, + func_ctx: &back::FunctionCtx<'_>, + info: &valid::FunctionInfo, + ) -> BackendResult { + // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax + + self.update_expressions_to_bake(module, func, info); + + // Write modifier + if let Some(crate::FunctionResult { + binding: + Some( + ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { + invariant: true, + }), + ), + .. + }) = func.result + { + self.write_modifier(binding)?; + } + + // Write return type + if let Some(ref result) = func.result { + match func_ctx.ty { + back::FunctionType::Function(_) => { + self.write_type(module, result.ty)?; + } + back::FunctionType::EntryPoint(index) => { + if let Some(ref ep_output) = self.entry_point_io[index as usize].output { + write!(self.out, "{}", ep_output.ty_name)?; + } else { + self.write_type(module, result.ty)?; + } + } + } + } else { + write!(self.out, "void")?; + } + + // Write function name + write!(self.out, " {name}(")?; + + let need_workgroup_variables_initialization = + self.need_workgroup_variables_initialization(func_ctx, module); + + // Write function arguments for non entry point functions + match func_ctx.ty { + back::FunctionType::Function(handle) => { + for (index, arg) in func.arguments.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + // Write argument type + let arg_ty = match module.types[arg.ty].inner { + // pointers in function arguments are expected and resolve to `inout` + TypeInner::Pointer { base, .. } => { + //TODO: can we narrow this down to just `in` when possible? + write!(self.out, "inout ")?; + base + } + _ => arg.ty, + }; + self.write_type(module, arg_ty)?; + + let argument_name = + &self.names[&NameKey::FunctionArgument(handle, index as u32)]; + + // Write argument name. Space is important. + write!(self.out, " {argument_name}")?; + if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner { + self.write_array_size(module, base, size)?; + } + } + } + back::FunctionType::EntryPoint(ep_index) => { + if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input { + write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?; + } else { + let stage = module.entry_points[ep_index as usize].stage; + for (index, arg) in func.arguments.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + self.write_type(module, arg.ty)?; + + let argument_name = + &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]; + + write!(self.out, " {argument_name}")?; + if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner { + self.write_array_size(module, base, size)?; + } + + if let Some(ref binding) = arg.binding { + self.write_semantic(binding, Some((stage, Io::Input)))?; + } + } + + if need_workgroup_variables_initialization { + if !func.arguments.is_empty() { + write!(self.out, ", ")?; + } + write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?; + } + } + } + } + // Ends of arguments + write!(self.out, ")")?; + + // Write semantic if it present + if let back::FunctionType::EntryPoint(index) = func_ctx.ty { + let stage = module.entry_points[index as usize].stage; + if let Some(crate::FunctionResult { + binding: Some(ref binding), + .. + }) = func.result + { + self.write_semantic(binding, Some((stage, Io::Output)))?; + } + } + + // Function body start + writeln!(self.out)?; + writeln!(self.out, "{{")?; + + if need_workgroup_variables_initialization { + self.write_workgroup_variables_initialization(func_ctx, module)?; + } + + if let back::FunctionType::EntryPoint(index) = func_ctx.ty { + self.write_ep_arguments_initialization(module, func, index)?; + } + + // Write function local variables + for (handle, local) in func.local_variables.iter() { + // Write indentation (only for readability) + write!(self.out, "{}", back::INDENT)?; + + // Write the local name + // The leading space is important + self.write_type(module, local.ty)?; + write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?; + // Write size for array type + if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner { + self.write_array_size(module, base, size)?; + } + + write!(self.out, " = ")?; + // Write the local initializer if needed + if let Some(init) = local.init { + self.write_expr(module, init, func_ctx)?; + } else { + // Zero initialize local variables + self.write_default_init(module, local.ty)?; + } + + // Finish the local with `;` and add a newline (only for readability) + writeln!(self.out, ";")? + } + + if !func.local_variables.is_empty() { + writeln!(self.out)?; + } + + // Write the function body (statement list) + for sta in func.body.iter() { + // The indentation should always be 1 when writing the function body + self.write_stmt(module, sta, func_ctx, back::Level(1))?; + } + + writeln!(self.out, "}}")?; + + self.named_expressions.clear(); + + Ok(()) + } + + fn need_workgroup_variables_initialization( + &mut self, + func_ctx: &back::FunctionCtx, + module: &Module, + ) -> bool { + self.options.zero_initialize_workgroup_memory + && func_ctx.ty.is_compute_entry_point(module) + && module.global_variables.iter().any(|(handle, var)| { + !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + } + + fn write_workgroup_variables_initialization( + &mut self, + func_ctx: &back::FunctionCtx, + module: &Module, + ) -> BackendResult { + let level = back::Level(1); + + writeln!( + self.out, + "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{" + )?; + + let vars = module.global_variables.iter().filter(|&(handle, var)| { + !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }); + + for (handle, var) in vars { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{}{} = ", level.next(), name)?; + self.write_default_init(module, var.ty)?; + writeln!(self.out, ";")?; + } + + writeln!(self.out, "{level}}}")?; + self.write_barrier(crate::Barrier::WORK_GROUP, level) + } + + /// Helper method used to write statements + /// + /// # Notes + /// Always adds a newline + fn write_stmt( + &mut self, + module: &Module, + stmt: &crate::Statement, + func_ctx: &back::FunctionCtx<'_>, + level: back::Level, + ) -> BackendResult { + use crate::Statement; + + match *stmt { + Statement::Emit(ref range) => { + for handle in range.clone() { + let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space(); + let expr_name = if ptr_class.is_some() { + // HLSL can't save a pointer-valued expression in a variable, + // but we shouldn't ever need to: they should never be named expressions, + // and none of the expression types flagged by bake_ref_count can be pointer-valued. + None + } else if let Some(name) = func_ctx.named_expressions.get(&handle) { + // Front end provides names for all variables at the start of writing. + // But we write them to step by step. We need to recache them + // Otherwise, we could accidentally write variable name instead of full expression. + // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords. + Some(self.namer.call(name)) + } else if self.need_bake_expressions.contains(&handle) { + Some(format!("_expr{}", handle.index())) + } else { + None + }; + + if let Some(name) = expr_name { + write!(self.out, "{level}")?; + self.write_named_expr(module, handle, name, handle, func_ctx)?; + } + } + } + // TODO: copy-paste from glsl-out + Statement::Block(ref block) => { + write!(self.out, "{level}")?; + writeln!(self.out, "{{")?; + for sta in block.iter() { + // Increase the indentation to help with readability + self.write_stmt(module, sta, func_ctx, level.next())? + } + writeln!(self.out, "{level}}}")? + } + // TODO: copy-paste from glsl-out + Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{level}")?; + write!(self.out, "if (")?; + self.write_expr(module, condition, func_ctx)?; + writeln!(self.out, ") {{")?; + + let l2 = level.next(); + for sta in accept { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, l2)?; + } + + // If there are no statements in the reject block we skip writing it + // This is only for readability + if !reject.is_empty() { + writeln!(self.out, "{level}}} else {{")?; + + for sta in reject { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, l2)?; + } + } + + writeln!(self.out, "{level}}}")? + } + // TODO: copy-paste from glsl-out + Statement::Kill => writeln!(self.out, "{level}discard;")?, + Statement::Return { value: None } => { + writeln!(self.out, "{level}return;")?; + } + Statement::Return { value: Some(expr) } => { + let base_ty_res = &func_ctx.info[expr].ty; + let mut resolved = base_ty_res.inner_with(&module.types); + if let TypeInner::Pointer { base, space: _ } = *resolved { + resolved = &module.types[base].inner; + } + + if let TypeInner::Struct { .. } = *resolved { + // We can safely unwrap here, since we now we working with struct + let ty = base_ty_res.handle().unwrap(); + let struct_name = &self.names[&NameKey::Type(ty)]; + let variable_name = self.namer.call(&struct_name.to_lowercase()); + write!(self.out, "{level}const {struct_name} {variable_name} = ",)?; + self.write_expr(module, expr, func_ctx)?; + writeln!(self.out, ";")?; + + // for entry point returns, we may need to reshuffle the outputs into a different struct + let ep_output = match func_ctx.ty { + back::FunctionType::Function(_) => None, + back::FunctionType::EntryPoint(index) => { + self.entry_point_io[index as usize].output.as_ref() + } + }; + let final_name = match ep_output { + Some(ep_output) => { + let final_name = self.namer.call(&variable_name); + write!( + self.out, + "{}const {} {} = {{ ", + level, ep_output.ty_name, final_name, + )?; + for (index, m) in ep_output.members.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + let member_name = &self.names[&NameKey::StructMember(ty, m.index)]; + write!(self.out, "{variable_name}.{member_name}")?; + } + writeln!(self.out, " }};")?; + final_name + } + None => variable_name, + }; + writeln!(self.out, "{level}return {final_name};")?; + } else { + write!(self.out, "{level}return ")?; + self.write_expr(module, expr, func_ctx)?; + writeln!(self.out, ";")? + } + } + Statement::Store { pointer, value } => { + let ty_inner = func_ctx.resolve_type(pointer, &module.types); + if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() { + let var_handle = self.fill_access_chain(module, pointer, func_ctx)?; + self.write_storage_store( + module, + var_handle, + StoreValue::Expression(value), + func_ctx, + level, + )?; + } else { + // We treat matrices of the form `matCx2` as a sequence of C `vec2`s. + // See the module-level block comment in mod.rs for details. + // + // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars). + // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads). + struct MatrixAccess { + base: Handle<crate::Expression>, + index: u32, + } + enum Index { + Expression(Handle<crate::Expression>), + Static(u32), + } + + let get_members = |expr: Handle<crate::Expression>| { + let resolved = func_ctx.resolve_type(expr, &module.types); + match *resolved { + TypeInner::Pointer { base, .. } => match module.types[base].inner { + TypeInner::Struct { ref members, .. } => Some(members), + _ => None, + }, + _ => None, + } + }; + + let mut matrix = None; + let mut vector = None; + let mut scalar = None; + + let mut current_expr = pointer; + for _ in 0..3 { + let resolved = func_ctx.resolve_type(current_expr, &module.types); + + match (resolved, &func_ctx.expressions[current_expr]) { + ( + &TypeInner::Pointer { base: ty, .. }, + &crate::Expression::AccessIndex { base, index }, + ) if matches!( + module.types[ty].inner, + TypeInner::Matrix { + rows: crate::VectorSize::Bi, + .. + } + ) && get_members(base) + .map(|members| members[index as usize].binding.is_none()) + == Some(true) => + { + matrix = Some(MatrixAccess { base, index }); + break; + } + ( + &TypeInner::ValuePointer { + size: Some(crate::VectorSize::Bi), + .. + }, + &crate::Expression::Access { base, index }, + ) => { + vector = Some(Index::Expression(index)); + current_expr = base; + } + ( + &TypeInner::ValuePointer { + size: Some(crate::VectorSize::Bi), + .. + }, + &crate::Expression::AccessIndex { base, index }, + ) => { + vector = Some(Index::Static(index)); + current_expr = base; + } + ( + &TypeInner::ValuePointer { size: None, .. }, + &crate::Expression::Access { base, index }, + ) => { + scalar = Some(Index::Expression(index)); + current_expr = base; + } + ( + &TypeInner::ValuePointer { size: None, .. }, + &crate::Expression::AccessIndex { base, index }, + ) => { + scalar = Some(Index::Static(index)); + current_expr = base; + } + _ => break, + } + } + + write!(self.out, "{level}")?; + + if let Some(MatrixAccess { index, base }) = matrix { + let base_ty_res = &func_ctx.info[base].ty; + let resolved = base_ty_res.inner_with(&module.types); + let ty = match *resolved { + TypeInner::Pointer { base, .. } => base, + _ => base_ty_res.handle().unwrap(), + }; + + if let Some(Index::Static(vec_index)) = vector { + self.write_expr(module, base, func_ctx)?; + write!( + self.out, + ".{}_{}", + &self.names[&NameKey::StructMember(ty, index)], + vec_index + )?; + + if let Some(scalar_index) = scalar { + write!(self.out, "[")?; + match scalar_index { + Index::Static(index) => { + write!(self.out, "{index}")?; + } + Index::Expression(index) => { + self.write_expr(module, index, func_ctx)?; + } + } + write!(self.out, "]")?; + } + + write!(self.out, " = ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ";")?; + } else { + let access = WrappedStructMatrixAccess { ty, index }; + match (&vector, &scalar) { + (&Some(_), &Some(_)) => { + self.write_wrapped_struct_matrix_set_scalar_function_name( + access, + )?; + } + (&Some(_), &None) => { + self.write_wrapped_struct_matrix_set_vec_function_name(access)?; + } + (&None, _) => { + self.write_wrapped_struct_matrix_set_function_name(access)?; + } + } + + write!(self.out, "(")?; + self.write_expr(module, base, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + + if let Some(Index::Expression(vec_index)) = vector { + write!(self.out, ", ")?; + self.write_expr(module, vec_index, func_ctx)?; + + if let Some(scalar_index) = scalar { + write!(self.out, ", ")?; + match scalar_index { + Index::Static(index) => { + write!(self.out, "{index}")?; + } + Index::Expression(index) => { + self.write_expr(module, index, func_ctx)?; + } + } + } + } + writeln!(self.out, ");")?; + } + } else { + // We handle `Store`s to __matCx2 column vectors and scalar elements via + // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2. + struct MatrixData { + columns: crate::VectorSize, + base: Handle<crate::Expression>, + } + + enum Index { + Expression(Handle<crate::Expression>), + Static(u32), + } + + let mut matrix = None; + let mut vector = None; + let mut scalar = None; + + let mut current_expr = pointer; + for _ in 0..3 { + let resolved = func_ctx.resolve_type(current_expr, &module.types); + match (resolved, &func_ctx.expressions[current_expr]) { + ( + &TypeInner::ValuePointer { + size: Some(crate::VectorSize::Bi), + .. + }, + &crate::Expression::Access { base, index }, + ) => { + vector = Some(index); + current_expr = base; + } + ( + &TypeInner::ValuePointer { size: None, .. }, + &crate::Expression::Access { base, index }, + ) => { + scalar = Some(Index::Expression(index)); + current_expr = base; + } + ( + &TypeInner::ValuePointer { size: None, .. }, + &crate::Expression::AccessIndex { base, index }, + ) => { + scalar = Some(Index::Static(index)); + current_expr = base; + } + _ => { + if let Some(MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = get_inner_matrix_of_struct_array_member( + module, + current_expr, + func_ctx, + true, + ) { + matrix = Some(MatrixData { + columns, + base: current_expr, + }); + } + + break; + } + } + } + + if let (Some(MatrixData { columns, base }), Some(vec_index)) = + (matrix, vector) + { + if scalar.is_some() { + write!(self.out, "__set_el_of_mat{}x2", columns as u8)?; + } else { + write!(self.out, "__set_col_of_mat{}x2", columns as u8)?; + } + write!(self.out, "(")?; + self.write_expr(module, base, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, vec_index, func_ctx)?; + + if let Some(scalar_index) = scalar { + write!(self.out, ", ")?; + match scalar_index { + Index::Static(index) => { + write!(self.out, "{index}")?; + } + Index::Expression(index) => { + self.write_expr(module, index, func_ctx)?; + } + } + } + + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + + writeln!(self.out, ");")?; + } else { + self.write_expr(module, pointer, func_ctx)?; + write!(self.out, " = ")?; + + // We cast the RHS of this store in cases where the LHS + // is a struct member with type: + // - matCx2 or + // - a (possibly nested) array of matCx2's + if let Some(MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = get_inner_matrix_of_struct_array_member( + module, pointer, func_ctx, false, + ) { + let mut resolved = func_ctx.resolve_type(pointer, &module.types); + if let TypeInner::Pointer { base, .. } = *resolved { + resolved = &module.types[base].inner; + } + + write!(self.out, "(__mat{}x2", columns as u8)?; + if let TypeInner::Array { base, size, .. } = *resolved { + self.write_array_size(module, base, size)?; + } + write!(self.out, ")")?; + } + + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ";")? + } + } + } + } + Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + let l2 = level.next(); + if !continuing.is_empty() || break_if.is_some() { + let gate_name = self.namer.call("loop_init"); + writeln!(self.out, "{level}bool {gate_name} = true;")?; + writeln!(self.out, "{level}while(true) {{")?; + writeln!(self.out, "{l2}if (!{gate_name}) {{")?; + let l3 = l2.next(); + for sta in continuing.iter() { + self.write_stmt(module, sta, func_ctx, l3)?; + } + if let Some(condition) = break_if { + write!(self.out, "{l3}if (")?; + self.write_expr(module, condition, func_ctx)?; + writeln!(self.out, ") {{")?; + writeln!(self.out, "{}break;", l3.next())?; + writeln!(self.out, "{l3}}}")?; + } + writeln!(self.out, "{l2}}}")?; + writeln!(self.out, "{l2}{gate_name} = false;")?; + } else { + writeln!(self.out, "{level}while(true) {{")?; + } + + for sta in body.iter() { + self.write_stmt(module, sta, func_ctx, l2)?; + } + writeln!(self.out, "{level}}}")? + } + Statement::Break => writeln!(self.out, "{level}break;")?, + Statement::Continue => writeln!(self.out, "{level}continue;")?, + Statement::Barrier(barrier) => { + self.write_barrier(barrier, level)?; + } + Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + write!(self.out, "{level}")?; + self.write_expr(module, image, func_ctx)?; + + write!(self.out, "[")?; + if let Some(index) = array_index { + // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here + write!(self.out, "int3(")?; + self.write_expr(module, coordinate, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr(module, coordinate, func_ctx)?; + } + write!(self.out, "]")?; + + write!(self.out, " = ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ";")?; + } + Statement::Call { + function, + ref arguments, + result, + } => { + write!(self.out, "{level}")?; + if let Some(expr) = result { + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, expr.index()); + let expr_ty = &func_ctx.info[expr].ty; + match *expr_ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(expr, name); + } + let func_name = &self.names[&NameKey::Function(function)]; + write!(self.out, "{func_name}(")?; + for (index, argument) in arguments.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + self.write_expr(module, *argument, func_ctx)?; + } + writeln!(self.out, ");")? + } + Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + + // Validation ensures that `pointer` has a `Pointer` type. + let pointer_space = func_ctx + .resolve_type(pointer, &module.types) + .pointer_space() + .unwrap(); + + let fun_str = fun.to_hlsl_suffix(); + write!(self.out, " {res_name}; ")?; + match pointer_space { + crate::AddressSpace::WorkGroup => { + write!(self.out, "Interlocked{fun_str}(")?; + self.write_expr(module, pointer, func_ctx)?; + } + crate::AddressSpace::Storage { .. } => { + let var_handle = self.fill_access_chain(module, pointer, func_ctx)?; + // The call to `self.write_storage_address` wants + // mutable access to all of `self`, so temporarily take + // ownership of our reusable access chain buffer. + let chain = mem::take(&mut self.temp_access_chain); + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + write!(self.out, "{var_name}.Interlocked{fun_str}(")?; + self.write_storage_address(module, &chain, func_ctx)?; + self.temp_access_chain = chain; + } + ref other => { + return Err(Error::Custom(format!( + "invalid address space {other:?} for atomic statement" + ))) + } + } + write!(self.out, ", ")?; + // handle the special cases + match *fun { + crate::AtomicFunction::Subtract => { + // we just wrote `InterlockedAdd`, so negate the argument + write!(self.out, "-")?; + } + crate::AtomicFunction::Exchange { compare: Some(_) } => { + return Err(Error::Unimplemented("atomic CompareExchange".to_string())); + } + _ => {} + } + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ", {res_name});")?; + self.named_expressions.insert(result, res_name); + } + Statement::WorkGroupUniformLoad { pointer, result } => { + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + write!(self.out, "{level}")?; + let name = format!("_expr{}", result.index()); + self.write_named_expr(module, pointer, name, result, func_ctx)?; + + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + } + Statement::Switch { + selector, + ref cases, + } => { + // Start the switch + write!(self.out, "{level}")?; + write!(self.out, "switch(")?; + self.write_expr(module, selector, func_ctx)?; + writeln!(self.out, ") {{")?; + + // Write all cases + let indent_level_1 = level.next(); + let indent_level_2 = indent_level_1.next(); + + for (i, case) in cases.iter().enumerate() { + match case.value { + crate::SwitchValue::I32(value) => { + write!(self.out, "{indent_level_1}case {value}:")? + } + crate::SwitchValue::U32(value) => { + write!(self.out, "{indent_level_1}case {value}u:")? + } + crate::SwitchValue::Default => { + write!(self.out, "{indent_level_1}default:")? + } + } + + // The new block is not only stylistic, it plays a role here: + // We might end up having to write the same case body + // multiple times due to FXC not supporting fallthrough. + // Therefore, some `Expression`s written by `Statement::Emit` + // will end up having the same name (`_expr<handle_index>`). + // So we need to put each case in its own scope. + let write_block_braces = !(case.fall_through && case.body.is_empty()); + if write_block_braces { + writeln!(self.out, " {{")?; + } else { + writeln!(self.out)?; + } + + // Although FXC does support a series of case clauses before + // a block[^yes], it does not support fallthrough from a + // non-empty case block to the next[^no]. If this case has a + // non-empty body with a fallthrough, emulate that by + // duplicating the bodies of all the cases it would fall + // into as extensions of this case's own body. This makes + // the HLSL output potentially quadratic in the size of the + // Naga IR. + // + // [^yes]: ```hlsl + // case 1: + // case 2: do_stuff() + // ``` + // [^no]: ```hlsl + // case 1: do_this(); + // case 2: do_that(); + // ``` + if case.fall_through && !case.body.is_empty() { + let curr_len = i + 1; + let end_case_idx = curr_len + + cases + .iter() + .skip(curr_len) + .position(|case| !case.fall_through) + .unwrap(); + let indent_level_3 = indent_level_2.next(); + for case in &cases[i..=end_case_idx] { + writeln!(self.out, "{indent_level_2}{{")?; + let prev_len = self.named_expressions.len(); + for sta in case.body.iter() { + self.write_stmt(module, sta, func_ctx, indent_level_3)?; + } + // Clear all named expressions that were previously inserted by the statements in the block + self.named_expressions.truncate(prev_len); + writeln!(self.out, "{indent_level_2}}}")?; + } + + let last_case = &cases[end_case_idx]; + if last_case.body.last().map_or(true, |s| !s.is_terminator()) { + writeln!(self.out, "{indent_level_2}break;")?; + } + } else { + for sta in case.body.iter() { + self.write_stmt(module, sta, func_ctx, indent_level_2)?; + } + if !case.fall_through + && case.body.last().map_or(true, |s| !s.is_terminator()) + { + writeln!(self.out, "{indent_level_2}break;")?; + } + } + + if write_block_braces { + writeln!(self.out, "{indent_level_1}}}")?; + } + } + + writeln!(self.out, "{level}}}")? + } + Statement::RayQuery { .. } => unreachable!(), + } + + Ok(()) + } + + fn write_const_expression( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + ) -> BackendResult { + self.write_possibly_const_expression( + module, + expr, + &module.const_expressions, + |writer, expr| writer.write_const_expression(module, expr), + ) + } + + fn write_possibly_const_expression<E>( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + expressions: &crate::Arena<crate::Expression>, + write_expression: E, + ) -> BackendResult + where + E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult, + { + use crate::Expression; + + match expressions[expr] { + Expression::Literal(literal) => match literal { + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero + crate::Literal::F64(value) => write!(self.out, "{value:?}L")?, + crate::Literal::F32(value) => write!(self.out, "{value:?}")?, + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::I64(value) => write!(self.out, "{}L", value)?, + crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } + }, + Expression::Constant(handle) => { + let constant = &module.constants[handle]; + if constant.name.is_some() { + write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; + } else { + self.write_const_expression(module, constant.init)?; + } + } + Expression::ZeroValue(ty) => self.write_default_init(module, ty)?, + Expression::Compose { ty, ref components } => { + match module.types[ty].inner { + TypeInner::Struct { .. } | TypeInner::Array { .. } => { + self.write_wrapped_constructor_function_name( + module, + WrappedConstructor { ty }, + )?; + } + _ => { + self.write_type(module, ty)?; + } + }; + write!(self.out, "(")?; + for (index, component) in components.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + write_expression(self, *component)?; + } + write!(self.out, ")")?; + } + Expression::Splat { size, value } => { + // hlsl is not supported one value constructor + // if we write, for example, int4(0), dxc returns error: + // error: too few elements in vector initialization (expected 4 elements, have 1) + let number_of_components = match size { + crate::VectorSize::Bi => "xx", + crate::VectorSize::Tri => "xxx", + crate::VectorSize::Quad => "xxxx", + }; + write!(self.out, "(")?; + write_expression(self, value)?; + write!(self.out, ").{number_of_components}")? + } + _ => unreachable!(), + } + + Ok(()) + } + + /// Helper method to write expressions + /// + /// # Notes + /// Doesn't add any newlines or leading/trailing spaces + pub(super) fn write_expr( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + ) -> BackendResult { + use crate::Expression; + + // Handle the special semantics of vertex_index/instance_index + let ff_input = if self.options.special_constants_binding.is_some() { + func_ctx.is_fixed_function_input(expr, module) + } else { + None + }; + let closing_bracket = match ff_input { + Some(crate::BuiltIn::VertexIndex) => { + write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?; + ")" + } + Some(crate::BuiltIn::InstanceIndex) => { + write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?; + ")" + } + Some(crate::BuiltIn::NumWorkGroups) => { + // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`), + // in compute shaders the special constants contain the number + // of workgroups, which we are using here. + write!( + self.out, + "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})", + )?; + return Ok(()); + } + _ => "", + }; + + if let Some(name) = self.named_expressions.get(&expr) { + write!(self.out, "{name}{closing_bracket}")?; + return Ok(()); + } + + let expression = &func_ctx.expressions[expr]; + + match *expression { + Expression::Literal(_) + | Expression::Constant(_) + | Expression::ZeroValue(_) + | Expression::Compose { .. } + | Expression::Splat { .. } => { + self.write_possibly_const_expression( + module, + expr, + func_ctx.expressions, + |writer, expr| writer.write_expr(module, expr, func_ctx), + )?; + } + // All of the multiplication can be expressed as `mul`, + // except vector * vector, which needs to use the "*" operator. + Expression::Binary { + op: crate::BinaryOperator::Multiply, + left, + right, + } if func_ctx.resolve_type(left, &module.types).is_matrix() + || func_ctx.resolve_type(right, &module.types).is_matrix() => + { + // We intentionally flip the order of multiplication as our matrices are implicitly transposed. + write!(self.out, "mul(")?; + self.write_expr(module, right, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, left, func_ctx)?; + write!(self.out, ")")?; + } + + // TODO: handle undefined behavior of BinaryOperator::Modulo + // + // sint: + // if right == 0 return 0 + // if left == min(type_of(left)) && right == -1 return 0 + // if sign(left) != sign(right) return result as defined by WGSL + // + // uint: + // if right == 0 return 0 + // + // float: + // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798 + + // While HLSL supports float operands with the % operator it is only + // defined in cases where both sides are either positive or negative. + Expression::Binary { + op: crate::BinaryOperator::Modulo, + left, + right, + } if func_ctx.resolve_type(left, &module.types).scalar_kind() + == Some(crate::ScalarKind::Float) => + { + write!(self.out, "fmod(")?; + self.write_expr(module, left, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, right, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Binary { op, left, right } => { + write!(self.out, "(")?; + self.write_expr(module, left, func_ctx)?; + write!(self.out, " {} ", crate::back::binary_operation_str(op))?; + self.write_expr(module, right, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Access { base, index } => { + if let Some(crate::AddressSpace::Storage { .. }) = + func_ctx.resolve_type(expr, &module.types).pointer_space() + { + // do nothing, the chain is written on `Load`/`Store` + } else { + // We use the function __get_col_of_matCx2 here in cases + // where `base`s type resolves to a matCx2 and is part of a + // struct member with type of (possibly nested) array of matCx2's. + // + // Note that this only works for `Load`s and we handle + // `Store`s differently in `Statement::Store`. + if let Some(MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true) + { + write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?; + self.write_expr(module, base, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + write!(self.out, ")")?; + return Ok(()); + } + + let resolved = func_ctx.resolve_type(base, &module.types); + + let non_uniform_qualifier = match *resolved { + TypeInner::BindingArray { .. } => { + let uniformity = &func_ctx.info[index].uniformity; + + uniformity.non_uniform_result.is_some() + } + _ => false, + }; + + self.write_expr(module, base, func_ctx)?; + write!(self.out, "[")?; + if non_uniform_qualifier { + write!(self.out, "NonUniformResourceIndex(")?; + } + self.write_expr(module, index, func_ctx)?; + if non_uniform_qualifier { + write!(self.out, ")")?; + } + write!(self.out, "]")?; + } + } + Expression::AccessIndex { base, index } => { + if let Some(crate::AddressSpace::Storage { .. }) = + func_ctx.resolve_type(expr, &module.types).pointer_space() + { + // do nothing, the chain is written on `Load`/`Store` + } else { + fn write_access<W: fmt::Write>( + writer: &mut super::Writer<'_, W>, + resolved: &TypeInner, + base_ty_handle: Option<Handle<crate::Type>>, + index: u32, + ) -> BackendResult { + match *resolved { + // We specifically lift the ValuePointer to this case. While `[0]` is valid + // HLSL for any vector behind a value pointer, FXC completely miscompiles + // it and generates completely nonsensical DXBC. + // + // See https://github.com/gfx-rs/naga/issues/2095 for more details. + TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => { + // Write vector access as a swizzle + write!(writer.out, ".{}", back::COMPONENTS[index as usize])? + } + TypeInner::Matrix { .. } + | TypeInner::Array { .. } + | TypeInner::BindingArray { .. } => write!(writer.out, "[{index}]")?, + TypeInner::Struct { .. } => { + // This will never panic in case the type is a `Struct`, this is not true + // for other types so we can only check while inside this match arm + let ty = base_ty_handle.unwrap(); + + write!( + writer.out, + ".{}", + &writer.names[&NameKey::StructMember(ty, index)] + )? + } + ref other => { + return Err(Error::Custom(format!("Cannot index {other:?}"))) + } + } + Ok(()) + } + + // We write the matrix column access in a special way since + // the type of `base` is our special __matCx2 struct. + if let Some(MatrixType { + rows: crate::VectorSize::Bi, + width: 4, + .. + }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true) + { + self.write_expr(module, base, func_ctx)?; + write!(self.out, "._{index}")?; + return Ok(()); + } + + let base_ty_res = &func_ctx.info[base].ty; + let mut resolved = base_ty_res.inner_with(&module.types); + let base_ty_handle = match *resolved { + TypeInner::Pointer { base, .. } => { + resolved = &module.types[base].inner; + Some(base) + } + _ => base_ty_res.handle(), + }; + + // We treat matrices of the form `matCx2` as a sequence of C `vec2`s. + // See the module-level block comment in mod.rs for details. + // + // We handle matrix reconstruction here for Loads. + // Stores are handled directly by `Statement::Store`. + if let TypeInner::Struct { ref members, .. } = *resolved { + let member = &members[index as usize]; + + match module.types[member.ty].inner { + TypeInner::Matrix { + rows: crate::VectorSize::Bi, + .. + } if member.binding.is_none() => { + let ty = base_ty_handle.unwrap(); + self.write_wrapped_struct_matrix_get_function_name( + WrappedStructMatrixAccess { ty, index }, + )?; + write!(self.out, "(")?; + self.write_expr(module, base, func_ctx)?; + write!(self.out, ")")?; + return Ok(()); + } + _ => {} + } + } + + self.write_expr(module, base, func_ctx)?; + write_access(self, resolved, base_ty_handle, index)?; + } + } + Expression::FunctionArgument(pos) => { + let key = func_ctx.argument_key(pos); + let name = &self.names[&key]; + write!(self.out, "{name}")?; + } + Expression::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + use crate::SampleLevel as Sl; + const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"]; + + let (base_str, component_str) = match gather { + Some(component) => ("Gather", COMPONENTS[component as usize]), + None => ("Sample", ""), + }; + let cmp_str = match depth_ref { + Some(_) => "Cmp", + None => "", + }; + let level_str = match level { + Sl::Zero if gather.is_none() => "LevelZero", + Sl::Auto | Sl::Zero => "", + Sl::Exact(_) => "Level", + Sl::Bias(_) => "Bias", + Sl::Gradient { .. } => "Grad", + }; + + self.write_expr(module, image, func_ctx)?; + write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?; + self.write_expr(module, sampler, func_ctx)?; + write!(self.out, ", ")?; + self.write_texture_coordinates( + "float", + coordinate, + array_index, + None, + module, + func_ctx, + )?; + + if let Some(depth_ref) = depth_ref { + write!(self.out, ", ")?; + self.write_expr(module, depth_ref, func_ctx)?; + } + + match level { + Sl::Auto | Sl::Zero => {} + Sl::Exact(expr) => { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + Sl::Bias(expr) => { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + Sl::Gradient { x, y } => { + write!(self.out, ", ")?; + self.write_expr(module, x, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, y, func_ctx)?; + } + } + + if let Some(offset) = offset { + write!(self.out, ", ")?; + write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807 + self.write_const_expression(module, offset)?; + write!(self.out, ")")?; + } + + write!(self.out, ")")?; + } + Expression::ImageQuery { image, query } => { + // use wrapped image query function + if let TypeInner::Image { + dim, + arrayed, + class, + } = *func_ctx.resolve_type(image, &module.types) + { + let wrapped_image_query = WrappedImageQuery { + dim, + arrayed, + class, + query: query.into(), + }; + + self.write_wrapped_image_query_function_name(wrapped_image_query)?; + write!(self.out, "(")?; + // Image always first param + self.write_expr(module, image, func_ctx)?; + if let crate::ImageQuery::Size { level: Some(level) } = query { + write!(self.out, ", ")?; + self.write_expr(module, level, func_ctx)?; + } + write!(self.out, ")")?; + } + } + Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load + self.write_expr(module, image, func_ctx)?; + write!(self.out, ".Load(")?; + + self.write_texture_coordinates( + "int", + coordinate, + array_index, + level, + module, + func_ctx, + )?; + + if let Some(sample) = sample { + write!(self.out, ", ")?; + self.write_expr(module, sample, func_ctx)?; + } + + // close bracket for Load function + write!(self.out, ")")?; + + // return x component if return type is scalar + if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) { + write!(self.out, ".x")?; + } + } + Expression::GlobalVariable(handle) => match module.global_variables[handle].space { + crate::AddressSpace::Storage { .. } => {} + _ => { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{name}")?; + } + }, + Expression::LocalVariable(handle) => { + write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])? + } + Expression::Load { pointer } => { + match func_ctx + .resolve_type(pointer, &module.types) + .pointer_space() + { + Some(crate::AddressSpace::Storage { .. }) => { + let var_handle = self.fill_access_chain(module, pointer, func_ctx)?; + let result_ty = func_ctx.info[expr].ty.clone(); + self.write_storage_load(module, var_handle, result_ty, func_ctx)?; + } + _ => { + let mut close_paren = false; + + // We cast the value loaded to a native HLSL floatCx2 + // in cases where it is of type: + // - __matCx2 or + // - a (possibly nested) array of __matCx2's + if let Some(MatrixType { + rows: crate::VectorSize::Bi, + width: 4, + .. + }) = get_inner_matrix_of_struct_array_member( + module, pointer, func_ctx, false, + ) + .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx)) + { + let mut resolved = func_ctx.resolve_type(pointer, &module.types); + if let TypeInner::Pointer { base, .. } = *resolved { + resolved = &module.types[base].inner; + } + + write!(self.out, "((")?; + if let TypeInner::Array { base, size, .. } = *resolved { + self.write_type(module, base)?; + self.write_array_size(module, base, size)?; + } else { + self.write_value_type(module, resolved)?; + } + write!(self.out, ")")?; + close_paren = true; + } + + self.write_expr(module, pointer, func_ctx)?; + + if close_paren { + write!(self.out, ")")?; + } + } + } + } + Expression::Unary { op, expr } => { + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators + let op_str = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::LogicalNot => "!", + crate::UnaryOperator::BitwiseNot => "~", + }; + write!(self.out, "{op_str}(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + Expression::As { + expr, + kind, + convert, + } => { + let inner = func_ctx.resolve_type(expr, &module.types); + match convert { + Some(dst_width) => { + let scalar = crate::Scalar { + kind, + width: dst_width, + }; + match *inner { + TypeInner::Vector { size, .. } => { + write!( + self.out, + "{}{}(", + scalar.to_hlsl_str()?, + back::vector_size_str(size) + )?; + } + TypeInner::Scalar(_) => { + write!(self.out, "{}(", scalar.to_hlsl_str()?,)?; + } + TypeInner::Matrix { columns, rows, .. } => { + write!( + self.out, + "{}{}x{}(", + scalar.to_hlsl_str()?, + back::vector_size_str(columns), + back::vector_size_str(rows) + )?; + } + _ => { + return Err(Error::Unimplemented(format!( + "write_expr expression::as {inner:?}" + ))); + } + }; + } + None => { + write!(self.out, "{}(", kind.to_hlsl_cast(),)?; + } + } + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + + enum Function { + Asincosh { is_sin: bool }, + Atanh, + ExtractBits, + InsertBits, + Pack2x16float, + Pack2x16snorm, + Pack2x16unorm, + Pack4x8snorm, + Pack4x8unorm, + Unpack2x16float, + Unpack2x16snorm, + Unpack2x16unorm, + Unpack4x8snorm, + Unpack4x8unorm, + Regular(&'static str), + MissingIntOverload(&'static str), + MissingIntReturnType(&'static str), + CountTrailingZeros, + CountLeadingZeros, + } + + let fun = match fun { + // comparison + Mf::Abs => Function::Regular("abs"), + Mf::Min => Function::Regular("min"), + Mf::Max => Function::Regular("max"), + Mf::Clamp => Function::Regular("clamp"), + Mf::Saturate => Function::Regular("saturate"), + // trigonometry + Mf::Cos => Function::Regular("cos"), + Mf::Cosh => Function::Regular("cosh"), + Mf::Sin => Function::Regular("sin"), + Mf::Sinh => Function::Regular("sinh"), + Mf::Tan => Function::Regular("tan"), + Mf::Tanh => Function::Regular("tanh"), + Mf::Acos => Function::Regular("acos"), + Mf::Asin => Function::Regular("asin"), + Mf::Atan => Function::Regular("atan"), + Mf::Atan2 => Function::Regular("atan2"), + Mf::Asinh => Function::Asincosh { is_sin: true }, + Mf::Acosh => Function::Asincosh { is_sin: false }, + Mf::Atanh => Function::Atanh, + Mf::Radians => Function::Regular("radians"), + Mf::Degrees => Function::Regular("degrees"), + // decomposition + Mf::Ceil => Function::Regular("ceil"), + Mf::Floor => Function::Regular("floor"), + Mf::Round => Function::Regular("round"), + Mf::Fract => Function::Regular("frac"), + Mf::Trunc => Function::Regular("trunc"), + Mf::Modf => Function::Regular(MODF_FUNCTION), + Mf::Frexp => Function::Regular(FREXP_FUNCTION), + Mf::Ldexp => Function::Regular("ldexp"), + // exponent + Mf::Exp => Function::Regular("exp"), + Mf::Exp2 => Function::Regular("exp2"), + Mf::Log => Function::Regular("log"), + Mf::Log2 => Function::Regular("log2"), + Mf::Pow => Function::Regular("pow"), + // geometry + Mf::Dot => Function::Regular("dot"), + //Mf::Outer => , + Mf::Cross => Function::Regular("cross"), + Mf::Distance => Function::Regular("distance"), + Mf::Length => Function::Regular("length"), + Mf::Normalize => Function::Regular("normalize"), + Mf::FaceForward => Function::Regular("faceforward"), + Mf::Reflect => Function::Regular("reflect"), + Mf::Refract => Function::Regular("refract"), + // computational + Mf::Sign => Function::Regular("sign"), + Mf::Fma => Function::Regular("mad"), + Mf::Mix => Function::Regular("lerp"), + Mf::Step => Function::Regular("step"), + Mf::SmoothStep => Function::Regular("smoothstep"), + Mf::Sqrt => Function::Regular("sqrt"), + Mf::InverseSqrt => Function::Regular("rsqrt"), + //Mf::Inverse =>, + Mf::Transpose => Function::Regular("transpose"), + Mf::Determinant => Function::Regular("determinant"), + // bits + Mf::CountTrailingZeros => Function::CountTrailingZeros, + Mf::CountLeadingZeros => Function::CountLeadingZeros, + Mf::CountOneBits => Function::MissingIntOverload("countbits"), + Mf::ReverseBits => Function::MissingIntOverload("reversebits"), + Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"), + Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"), + Mf::ExtractBits => Function::ExtractBits, + Mf::InsertBits => Function::InsertBits, + // Data Packing + Mf::Pack2x16float => Function::Pack2x16float, + Mf::Pack2x16snorm => Function::Pack2x16snorm, + Mf::Pack2x16unorm => Function::Pack2x16unorm, + Mf::Pack4x8snorm => Function::Pack4x8snorm, + Mf::Pack4x8unorm => Function::Pack4x8unorm, + // Data Unpacking + Mf::Unpack2x16float => Function::Unpack2x16float, + Mf::Unpack2x16snorm => Function::Unpack2x16snorm, + Mf::Unpack2x16unorm => Function::Unpack2x16unorm, + Mf::Unpack4x8snorm => Function::Unpack4x8snorm, + Mf::Unpack4x8unorm => Function::Unpack4x8unorm, + _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))), + }; + + match fun { + Function::Asincosh { is_sin } => { + write!(self.out, "log(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " + sqrt(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " * ")?; + self.write_expr(module, arg, func_ctx)?; + match is_sin { + true => write!(self.out, " + 1.0))")?, + false => write!(self.out, " - 1.0))")?, + } + } + Function::Atanh => { + write!(self.out, "0.5 * log((1.0 + ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") / (1.0 - ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } + Function::ExtractBits => { + // e: T, + // offset: u32, + // count: u32 + // T is u32 or i32 or vecN<u32> or vecN<i32> + if let (Some(offset), Some(count)) = (arg1, arg2) { + let scalar_width: u8 = 32; + // Works for signed and unsigned + // (count == 0 ? 0 : (e << (32 - count - offset)) >> (32 - count)) + write!(self.out, "(")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, " == 0 ? 0 : (")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " << ({scalar_width} - ")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, " - ")?; + self.write_expr(module, offset, func_ctx)?; + write!(self.out, ")) >> ({scalar_width} - ")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, "))")?; + } + } + Function::InsertBits => { + // e: T, + // newbits: T, + // offset: u32, + // count: u32 + // returns T + // T is i32, u32, vecN<i32>, or vecN<u32> + if let (Some(newbits), Some(offset), Some(count)) = (arg1, arg2, arg3) { + let scalar_width: u8 = 32; + let scalar_max: u32 = 0xFFFFFFFF; + // mask = ((0xFFFFFFFFu >> (32 - count)) << offset) + // (count == 0 ? e : ((e & ~mask) | ((newbits << offset) & mask))) + write!(self.out, "(")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, " == 0 ? ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " : ")?; + write!(self.out, "(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " & ~")?; + // mask + write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, ")) << ")?; + self.write_expr(module, offset, func_ctx)?; + write!(self.out, ")")?; + // end mask + write!(self.out, ") | ((")?; + self.write_expr(module, newbits, func_ctx)?; + write!(self.out, " << ")?; + self.write_expr(module, offset, func_ctx)?; + write!(self.out, ") & ")?; + // // mask + write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, ")) << ")?; + self.write_expr(module, offset, func_ctx)?; + write!(self.out, ")")?; + // // end mask + write!(self.out, "))")?; + } + } + Function::Pack2x16float => { + write!(self.out, "(f32tof16(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[0]) | f32tof16(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[1]) << 16)")?; + } + Function::Pack2x16snorm => { + let scale = 32767; + + write!(self.out, "uint((int(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?; + } + Function::Pack2x16unorm => { + let scale = 65535; + + write!(self.out, "(uint(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?; + } + Function::Pack4x8snorm => { + let scale = 127; + + write!(self.out, "uint((int(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?; + } + Function::Pack4x8unorm => { + let scale = 255; + + write!(self.out, "(uint(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?; + } + + Function::Unpack2x16float => { + write!(self.out, "float2(f16tof32(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "), f16tof32((")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") >> 16))")?; + } + Function::Unpack2x16snorm => { + let scale = 32767; + + write!(self.out, "(float2(int2(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " << 16, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") >> 16) / {scale}.0)")?; + } + Function::Unpack2x16unorm => { + let scale = 65535; + + write!(self.out, "(float2(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " & 0xFFFF, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " >> 16) / {scale}.0)")?; + } + Function::Unpack4x8snorm => { + let scale = 127; + + write!(self.out, "(float4(int4(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " << 24, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " << 16, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " << 8, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") >> 24) / {scale}.0)")?; + } + Function::Unpack4x8unorm => { + let scale = 255; + + write!(self.out, "(float4(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " & 0xFF, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " >> 8 & 0xFF, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " >> 16 & 0xFF, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " >> 24) / {scale}.0)")?; + } + Function::Regular(fun_name) => { + write!(self.out, "{fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + if let Some(arg) = arg1 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + if let Some(arg) = arg2 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + if let Some(arg) = arg3 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + write!(self.out, ")")? + } + Function::MissingIntOverload(fun_name) => { + let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind(); + if let Some(ScalarKind::Sint) = scalar_kind { + write!(self.out, "asint({fun_name}(asuint(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; + } else { + write!(self.out, "{fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")")?; + } + } + Function::MissingIntReturnType(fun_name) => { + let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind(); + if let Some(ScalarKind::Sint) = scalar_kind { + write!(self.out, "asint({fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "{fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")")?; + } + } + Function::CountTrailingZeros => { + match *func_ctx.resolve_type(arg, &module.types) { + TypeInner::Vector { size, scalar } => { + let s = match size { + crate::VectorSize::Bi => ".xx", + crate::VectorSize::Tri => ".xxx", + crate::VectorSize::Quad => ".xxxx", + }; + + if let ScalarKind::Uint = scalar.kind { + write!(self.out, "min((32u){s}, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "asint(min((32u){s}, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; + } + } + TypeInner::Scalar(scalar) => { + if let ScalarKind::Uint = scalar.kind { + write!(self.out, "min(32u, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "asint(min(32u, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; + } + } + _ => unreachable!(), + } + + return Ok(()); + } + Function::CountLeadingZeros => { + match *func_ctx.resolve_type(arg, &module.types) { + TypeInner::Vector { size, scalar } => { + let s = match size { + crate::VectorSize::Bi => ".xx", + crate::VectorSize::Tri => ".xxx", + crate::VectorSize::Quad => ".xxxx", + }; + + if let ScalarKind::Uint = scalar.kind { + write!(self.out, "((31u){s} - firstbithigh(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "(")?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + " < (0){s} ? (0){s} : (31){s} - asint(firstbithigh(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; + } + } + TypeInner::Scalar(scalar) => { + if let ScalarKind::Uint = scalar.kind { + write!(self.out, "(31u - firstbithigh(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " < 0 ? 0 : 31 - asint(firstbithigh(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; + } + } + _ => unreachable!(), + } + + return Ok(()); + } + } + } + Expression::Swizzle { + size, + vector, + pattern, + } => { + self.write_expr(module, vector, func_ctx)?; + write!(self.out, ".")?; + for &sc in pattern[..size as usize].iter() { + self.out.write_char(back::COMPONENTS[sc as usize])?; + } + } + Expression::ArrayLength(expr) => { + let var_handle = match func_ctx.expressions[expr] { + Expression::AccessIndex { base, index: _ } => { + match func_ctx.expressions[base] { + Expression::GlobalVariable(handle) => handle, + _ => unreachable!(), + } + } + Expression::GlobalVariable(handle) => handle, + _ => unreachable!(), + }; + + let var = &module.global_variables[var_handle]; + let (offset, stride) = match module.types[var.ty].inner { + TypeInner::Array { stride, .. } => (0, stride), + TypeInner::Struct { ref members, .. } => { + let last = members.last().unwrap(); + let stride = match module.types[last.ty].inner { + TypeInner::Array { stride, .. } => stride, + _ => unreachable!(), + }; + (last.offset, stride) + } + _ => unreachable!(), + }; + + let storage_access = match var.space { + crate::AddressSpace::Storage { access } => access, + _ => crate::StorageAccess::default(), + }; + let wrapped_array_length = WrappedArrayLength { + writable: storage_access.contains(crate::StorageAccess::STORE), + }; + + write!(self.out, "((")?; + self.write_wrapped_array_length_function_name(wrapped_array_length)?; + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + write!(self.out, "({var_name}) - {offset}) / {stride})")? + } + Expression::Derivative { axis, ctrl, expr } => { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) { + let tail = match ctrl { + Ctrl::Coarse => "coarse", + Ctrl::Fine => "fine", + Ctrl::None => unreachable!(), + }; + write!(self.out, "abs(ddx_{tail}(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")) + abs(ddy_{tail}(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, "))")? + } else { + let fun_str = match (axis, ctrl) { + (Axis::X, Ctrl::Coarse) => "ddx_coarse", + (Axis::X, Ctrl::Fine) => "ddx_fine", + (Axis::X, Ctrl::None) => "ddx", + (Axis::Y, Ctrl::Coarse) => "ddy_coarse", + (Axis::Y, Ctrl::Fine) => "ddy_fine", + (Axis::Y, Ctrl::None) => "ddy", + (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(), + (Axis::Width, Ctrl::None) => "fwidth", + }; + write!(self.out, "{fun_str}(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")? + } + } + Expression::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + + let fun_str = match fun { + Rf::All => "all", + Rf::Any => "any", + Rf::IsNan => "isnan", + Rf::IsInf => "isinf", + }; + write!(self.out, "{fun_str}(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ")")? + } + Expression::Select { + condition, + accept, + reject, + } => { + write!(self.out, "(")?; + self.write_expr(module, condition, func_ctx)?; + write!(self.out, " ? ")?; + self.write_expr(module, accept, func_ctx)?; + write!(self.out, " : ")?; + self.write_expr(module, reject, func_ctx)?; + write!(self.out, ")")? + } + // Not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), + // Nothing to do here, since call expression already cached + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::WorkGroupUniformLoadResult { .. } + | Expression::RayQueryProceedResult => {} + } + + if !closing_bracket.is_empty() { + write!(self.out, "{closing_bracket}")?; + } + Ok(()) + } + + fn write_named_expr( + &mut self, + module: &Module, + handle: Handle<crate::Expression>, + name: String, + // The expression which is being named. + // Generally, this is the same as handle, except in WorkGroupUniformLoad + named: Handle<crate::Expression>, + ctx: &back::FunctionCtx, + ) -> BackendResult { + match ctx.info[named].ty { + proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner { + TypeInner::Struct { .. } => { + let ty_name = &self.names[&NameKey::Type(ty_handle)]; + write!(self.out, "{ty_name}")?; + } + _ => { + self.write_type(module, ty_handle)?; + } + }, + proc::TypeResolution::Value(ref inner) => { + self.write_value_type(module, inner)?; + } + } + + let resolved = ctx.resolve_type(named, &module.types); + + write!(self.out, " {name}")?; + // If rhs is a array type, we should write array size + if let TypeInner::Array { base, size, .. } = *resolved { + self.write_array_size(module, base, size)?; + } + write!(self.out, " = ")?; + self.write_expr(module, handle, ctx)?; + writeln!(self.out, ";")?; + self.named_expressions.insert(named, name); + + Ok(()) + } + + /// Helper function that write default zero initialization + fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult { + write!(self.out, "(")?; + self.write_type(module, ty)?; + if let TypeInner::Array { base, size, .. } = module.types[ty].inner { + self.write_array_size(module, base, size)?; + } + write!(self.out, ")0")?; + Ok(()) + } + + fn write_barrier(&mut self, barrier: crate::Barrier, level: back::Level) -> BackendResult { + if barrier.contains(crate::Barrier::STORAGE) { + writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?; + } + if barrier.contains(crate::Barrier::WORK_GROUP) { + writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; + } + Ok(()) + } +} + +pub(super) struct MatrixType { + pub(super) columns: crate::VectorSize, + pub(super) rows: crate::VectorSize, + pub(super) width: crate::Bytes, +} + +pub(super) fn get_inner_matrix_data( + module: &Module, + handle: Handle<crate::Type>, +) -> Option<MatrixType> { + match module.types[handle].inner { + TypeInner::Matrix { + columns, + rows, + scalar, + } => Some(MatrixType { + columns, + rows, + width: scalar.width, + }), + TypeInner::Array { base, .. } => get_inner_matrix_data(module, base), + _ => None, + } +} + +/// Returns the matrix data if the access chain starting at `base`: +/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true` +/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`] +/// - ends at an expression with resolved type of [`TypeInner::Struct`] +pub(super) fn get_inner_matrix_of_struct_array_member( + module: &Module, + base: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + direct: bool, +) -> Option<MatrixType> { + let mut mat_data = None; + let mut array_base = None; + + let mut current_base = base; + loop { + let mut resolved = func_ctx.resolve_type(current_base, &module.types); + if let TypeInner::Pointer { base, .. } = *resolved { + resolved = &module.types[base].inner; + }; + + match *resolved { + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + mat_data = Some(MatrixType { + columns, + rows, + width: scalar.width, + }) + } + TypeInner::Array { base, .. } => { + array_base = Some(base); + } + TypeInner::Struct { .. } => { + if let Some(array_base) = array_base { + if direct { + return mat_data; + } else { + return get_inner_matrix_data(module, array_base); + } + } + + break; + } + _ => break, + } + + current_base = match func_ctx.expressions[current_base] { + crate::Expression::Access { base, .. } => base, + crate::Expression::AccessIndex { base, .. } => base, + _ => break, + }; + } + None +} + +/// Returns the matrix data if the access chain starting at `base`: +/// - starts with an expression with resolved type of [`TypeInner::Matrix`] +/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`] +/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform) +fn get_inner_matrix_of_global_uniform( + module: &Module, + base: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, +) -> Option<MatrixType> { + let mut mat_data = None; + let mut array_base = None; + + let mut current_base = base; + loop { + let mut resolved = func_ctx.resolve_type(current_base, &module.types); + if let TypeInner::Pointer { base, .. } = *resolved { + resolved = &module.types[base].inner; + }; + + match *resolved { + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + mat_data = Some(MatrixType { + columns, + rows, + width: scalar.width, + }) + } + TypeInner::Array { base, .. } => { + array_base = Some(base); + } + _ => break, + } + + current_base = match func_ctx.expressions[current_base] { + crate::Expression::Access { base, .. } => base, + crate::Expression::AccessIndex { base, .. } => base, + crate::Expression::GlobalVariable(handle) + if module.global_variables[handle].space == crate::AddressSpace::Uniform => + { + return mat_data.or_else(|| { + array_base.and_then(|array_base| get_inner_matrix_data(module, array_base)) + }) + } + _ => break, + }; + } + None +} diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs new file mode 100644 index 0000000000..8100b930e9 --- /dev/null +++ b/third_party/rust/naga/src/back/mod.rs @@ -0,0 +1,273 @@ +/*! +Backend functions that export shader [`Module`](super::Module)s into binary and text formats. +*/ +#![allow(dead_code)] // can be dead if none of the enabled backends need it + +#[cfg(feature = "dot-out")] +pub mod dot; +#[cfg(feature = "glsl-out")] +pub mod glsl; +#[cfg(feature = "hlsl-out")] +pub mod hlsl; +#[cfg(feature = "msl-out")] +pub mod msl; +#[cfg(feature = "spv-out")] +pub mod spv; +#[cfg(feature = "wgsl-out")] +pub mod wgsl; + +const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; +const INDENT: &str = " "; +const BAKE_PREFIX: &str = "_e"; + +type NeedBakeExpressions = crate::FastHashSet<crate::Handle<crate::Expression>>; + +#[derive(Clone, Copy)] +struct Level(usize); + +impl Level { + const fn next(&self) -> Self { + Level(self.0 + 1) + } +} + +impl std::fmt::Display for Level { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + (0..self.0).try_for_each(|_| formatter.write_str(INDENT)) + } +} + +/// Whether we're generating an entry point or a regular function. +/// +/// Backend languages often require different code for a [`Function`] +/// depending on whether it represents an [`EntryPoint`] or not. +/// Backends can pass common code one of these values to select the +/// right behavior. +/// +/// These values also carry enough information to find the `Function` +/// in the [`Module`]: the `Handle` for a regular function, or the +/// index into [`Module::entry_points`] for an entry point. +/// +/// [`Function`]: crate::Function +/// [`EntryPoint`]: crate::EntryPoint +/// [`Module`]: crate::Module +/// [`Module::entry_points`]: crate::Module::entry_points +enum FunctionType { + /// A regular function. + Function(crate::Handle<crate::Function>), + /// An [`EntryPoint`], and its index in [`Module::entry_points`]. + /// + /// [`EntryPoint`]: crate::EntryPoint + /// [`Module::entry_points`]: crate::Module::entry_points + EntryPoint(crate::proc::EntryPointIndex), +} + +impl FunctionType { + fn is_compute_entry_point(&self, module: &crate::Module) -> bool { + match *self { + FunctionType::EntryPoint(index) => { + module.entry_points[index as usize].stage == crate::ShaderStage::Compute + } + FunctionType::Function(_) => false, + } + } +} + +/// Helper structure that stores data needed when writing the function +struct FunctionCtx<'a> { + /// The current function being written + ty: FunctionType, + /// Analysis about the function + info: &'a crate::valid::FunctionInfo, + /// The expression arena of the current function being written + expressions: &'a crate::Arena<crate::Expression>, + /// Map of expressions that have associated variable names + named_expressions: &'a crate::NamedExpressions, +} + +impl FunctionCtx<'_> { + fn resolve_type<'a>( + &'a self, + handle: crate::Handle<crate::Expression>, + types: &'a crate::UniqueArena<crate::Type>, + ) -> &'a crate::TypeInner { + self.info[handle].ty.inner_with(types) + } + + /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a local in the current function + const fn name_key(&self, local: crate::Handle<crate::LocalVariable>) -> crate::proc::NameKey { + match self.ty { + FunctionType::Function(handle) => crate::proc::NameKey::FunctionLocal(handle, local), + FunctionType::EntryPoint(idx) => crate::proc::NameKey::EntryPointLocal(idx, local), + } + } + + /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a function argument. + /// + /// # Panics + /// - If the function arguments are less or equal to `arg` + const fn argument_key(&self, arg: u32) -> crate::proc::NameKey { + match self.ty { + FunctionType::Function(handle) => crate::proc::NameKey::FunctionArgument(handle, arg), + FunctionType::EntryPoint(ep_index) => { + crate::proc::NameKey::EntryPointArgument(ep_index, arg) + } + } + } + + // Returns true if the given expression points to a fixed-function pipeline input. + fn is_fixed_function_input( + &self, + mut expression: crate::Handle<crate::Expression>, + module: &crate::Module, + ) -> Option<crate::BuiltIn> { + let ep_function = match self.ty { + FunctionType::Function(_) => return None, + FunctionType::EntryPoint(ep_index) => &module.entry_points[ep_index as usize].function, + }; + let mut built_in = None; + loop { + match self.expressions[expression] { + crate::Expression::FunctionArgument(arg_index) => { + return match ep_function.arguments[arg_index as usize].binding { + Some(crate::Binding::BuiltIn(bi)) => Some(bi), + _ => built_in, + }; + } + crate::Expression::AccessIndex { base, index } => { + match *self.resolve_type(base, &module.types) { + crate::TypeInner::Struct { ref members, .. } => { + if let Some(crate::Binding::BuiltIn(bi)) = + members[index as usize].binding + { + built_in = Some(bi); + } + } + _ => return None, + } + expression = base; + } + _ => return None, + } + } + } +} + +impl crate::Expression { + /// Returns the ref count, upon reaching which this expression + /// should be considered for baking. + /// + /// Note: we have to cache any expressions that depend on the control flow, + /// or otherwise they may be moved into a non-uniform control flow, accidentally. + /// See the [module-level documentation][emit] for details. + /// + /// [emit]: index.html#expression-evaluation-time + const fn bake_ref_count(&self) -> usize { + match *self { + // accesses are never cached, only loads are + crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => usize::MAX, + // sampling may use the control flow, and image ops look better by themselves + crate::Expression::ImageSample { .. } | crate::Expression::ImageLoad { .. } => 1, + // derivatives use the control flow + crate::Expression::Derivative { .. } => 1, + // TODO: We need a better fix for named `Load` expressions + // More info - https://github.com/gfx-rs/naga/pull/914 + // And https://github.com/gfx-rs/naga/issues/910 + crate::Expression::Load { .. } => 1, + // cache expressions that are referenced multiple times + _ => 2, + } + } +} + +/// Helper function that returns the string corresponding to the [`BinaryOperator`](crate::BinaryOperator) +/// # Notes +/// Used by `glsl-out`, `msl-out`, `wgsl-out`, `hlsl-out`. +const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { + use crate::BinaryOperator as Bo; + match op { + Bo::Add => "+", + Bo::Subtract => "-", + Bo::Multiply => "*", + Bo::Divide => "/", + Bo::Modulo => "%", + Bo::Equal => "==", + Bo::NotEqual => "!=", + Bo::Less => "<", + Bo::LessEqual => "<=", + Bo::Greater => ">", + Bo::GreaterEqual => ">=", + Bo::And => "&", + Bo::ExclusiveOr => "^", + Bo::InclusiveOr => "|", + Bo::LogicalAnd => "&&", + Bo::LogicalOr => "||", + Bo::ShiftLeft => "<<", + Bo::ShiftRight => ">>", + } +} + +/// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize) +/// # Notes +/// Used by `msl-out`, `wgsl-out`, `hlsl-out`. +const fn vector_size_str(size: crate::VectorSize) -> &'static str { + match size { + crate::VectorSize::Bi => "2", + crate::VectorSize::Tri => "3", + crate::VectorSize::Quad => "4", + } +} + +impl crate::TypeInner { + const fn is_handle(&self) -> bool { + match *self { + crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => true, + _ => false, + } + } +} + +impl crate::Statement { + /// Returns true if the statement directly terminates the current block. + /// + /// Used to decide whether case blocks require a explicit `break`. + pub const fn is_terminator(&self) -> bool { + match *self { + crate::Statement::Break + | crate::Statement::Continue + | crate::Statement::Return { .. } + | crate::Statement::Kill => true, + _ => false, + } + } +} + +bitflags::bitflags! { + /// Ray flags, for a [`RayDesc`]'s `flags` field. + /// + /// Note that these exactly correspond to the SPIR-V "Ray Flags" mask, and + /// the SPIR-V backend passes them directly through to the + /// `OpRayQueryInitializeKHR` instruction. (We have to choose something, so + /// we might as well make one back end's life easier.) + /// + /// [`RayDesc`]: crate::Module::generate_ray_desc_type + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] + pub struct RayFlag: u32 { + const OPAQUE = 0x01; + const NO_OPAQUE = 0x02; + const TERMINATE_ON_FIRST_HIT = 0x04; + const SKIP_CLOSEST_HIT_SHADER = 0x08; + const CULL_BACK_FACING = 0x10; + const CULL_FRONT_FACING = 0x20; + const CULL_OPAQUE = 0x40; + const CULL_NO_OPAQUE = 0x80; + const SKIP_TRIANGLES = 0x100; + const SKIP_AABBS = 0x200; + } +} + +#[repr(u32)] +enum RayIntersectionType { + Triangle = 1, + BoundingBox = 4, +} diff --git a/third_party/rust/naga/src/back/msl/keywords.rs b/third_party/rust/naga/src/back/msl/keywords.rs new file mode 100644 index 0000000000..f0025bf239 --- /dev/null +++ b/third_party/rust/naga/src/back/msl/keywords.rs @@ -0,0 +1,342 @@ +// MSLS - Metal Shading Language Specification: +// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf +// +// C++ - Standard for Programming Language C++ (N4431) +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf +pub const RESERVED: &[&str] = &[ + // Standard for Programming Language C++ (N4431): 2.5 Alternative tokens + "and", + "bitor", + "or", + "xor", + "compl", + "bitand", + "and_eq", + "or_eq", + "xor_eq", + "not", + "not_eq", + // Standard for Programming Language C++ (N4431): 2.11 Keywords + "alignas", + "alignof", + "asm", + "auto", + "bool", + "break", + "case", + "catch", + "char", + "char16_t", + "char32_t", + "class", + "const", + "constexpr", + "const_cast", + "continue", + "decltype", + "default", + "delete", + "do", + "double", + "dynamic_cast", + "else", + "enum", + "explicit", + "export", + "extern", + "false", + "float", + "for", + "friend", + "goto", + "if", + "inline", + "int", + "long", + "mutable", + "namespace", + "new", + "noexcept", + "nullptr", + "operator", + "private", + "protected", + "public", + "register", + "reinterpret_cast", + "return", + "short", + "signed", + "sizeof", + "static", + "static_assert", + "static_cast", + "struct", + "switch", + "template", + "this", + "thread_local", + "throw", + "true", + "try", + "typedef", + "typeid", + "typename", + "union", + "unsigned", + "using", + "virtual", + "void", + "volatile", + "wchar_t", + "while", + // Metal Shading Language Specification: 1.4.4 Restrictions + "main", + // Metal Shading Language Specification: 2.1 Scalar Data Types + "int8_t", + "uchar", + "uint8_t", + "int16_t", + "ushort", + "uint16_t", + "int32_t", + "uint", + "uint32_t", + "int64_t", + "uint64_t", + "half", + "bfloat", + "size_t", + "ptrdiff_t", + // Metal Shading Language Specification: 2.2 Vector Data Types + "bool2", + "bool3", + "bool4", + "char2", + "char3", + "char4", + "short2", + "short3", + "short4", + "int2", + "int3", + "int4", + "long2", + "long3", + "long4", + "uchar2", + "uchar3", + "uchar4", + "ushort2", + "ushort3", + "ushort4", + "uint2", + "uint3", + "uint4", + "ulong2", + "ulong3", + "ulong4", + "half2", + "half3", + "half4", + "bfloat2", + "bfloat3", + "bfloat4", + "float2", + "float3", + "float4", + "vec", + // Metal Shading Language Specification: 2.2.3 Packed Vector Types + "packed_bool2", + "packed_bool3", + "packed_bool4", + "packed_char2", + "packed_char3", + "packed_char4", + "packed_short2", + "packed_short3", + "packed_short4", + "packed_int2", + "packed_int3", + "packed_int4", + "packed_uchar2", + "packed_uchar3", + "packed_uchar4", + "packed_ushort2", + "packed_ushort3", + "packed_ushort4", + "packed_uint2", + "packed_uint3", + "packed_uint4", + "packed_half2", + "packed_half3", + "packed_half4", + "packed_bfloat2", + "packed_bfloat3", + "packed_bfloat4", + "packed_float2", + "packed_float3", + "packed_float4", + "packed_long2", + "packed_long3", + "packed_long4", + "packed_vec", + // Metal Shading Language Specification: 2.3 Matrix Data Types + "half2x2", + "half2x3", + "half2x4", + "half3x2", + "half3x3", + "half3x4", + "half4x2", + "half4x3", + "half4x4", + "float2x2", + "float2x3", + "float2x4", + "float3x2", + "float3x3", + "float3x4", + "float4x2", + "float4x3", + "float4x4", + "matrix", + // Metal Shading Language Specification: 2.6 Atomic Data Types + "atomic", + "atomic_int", + "atomic_uint", + "atomic_bool", + "atomic_ulong", + "atomic_float", + // Metal Shading Language Specification: 2.20 Type Conversions and Re-interpreting Data + "as_type", + // Metal Shading Language Specification: 4 Address Spaces + "device", + "constant", + "thread", + "threadgroup", + "threadgroup_imageblock", + "ray_data", + "object_data", + // Metal Shading Language Specification: 5.1 Functions + "vertex", + "fragment", + "kernel", + // Metal Shading Language Specification: 6.1 Namespace and Header Files + "metal", + // C99 / C++ extension: + "restrict", + // Metal reserved types in <metal_types>: + "llong", + "ullong", + "quad", + "complex", + "imaginary", + // Constants in <metal_types>: + "CHAR_BIT", + "SCHAR_MAX", + "SCHAR_MIN", + "UCHAR_MAX", + "CHAR_MAX", + "CHAR_MIN", + "USHRT_MAX", + "SHRT_MAX", + "SHRT_MIN", + "UINT_MAX", + "INT_MAX", + "INT_MIN", + "ULONG_MAX", + "LONG_MAX", + "LONG_MIN", + "ULLONG_MAX", + "LLONG_MAX", + "LLONG_MIN", + "FLT_DIG", + "FLT_MANT_DIG", + "FLT_MAX_10_EXP", + "FLT_MAX_EXP", + "FLT_MIN_10_EXP", + "FLT_MIN_EXP", + "FLT_RADIX", + "FLT_MAX", + "FLT_MIN", + "FLT_EPSILON", + "FLT_DECIMAL_DIG", + "FP_ILOGB0", + "FP_ILOGB0", + "FP_ILOGBNAN", + "FP_ILOGBNAN", + "MAXFLOAT", + "HUGE_VALF", + "INFINITY", + "NAN", + "M_E_F", + "M_LOG2E_F", + "M_LOG10E_F", + "M_LN2_F", + "M_LN10_F", + "M_PI_F", + "M_PI_2_F", + "M_PI_4_F", + "M_1_PI_F", + "M_2_PI_F", + "M_2_SQRTPI_F", + "M_SQRT2_F", + "M_SQRT1_2_F", + "HALF_DIG", + "HALF_MANT_DIG", + "HALF_MAX_10_EXP", + "HALF_MAX_EXP", + "HALF_MIN_10_EXP", + "HALF_MIN_EXP", + "HALF_RADIX", + "HALF_MAX", + "HALF_MIN", + "HALF_EPSILON", + "HALF_DECIMAL_DIG", + "MAXHALF", + "HUGE_VALH", + "M_E_H", + "M_LOG2E_H", + "M_LOG10E_H", + "M_LN2_H", + "M_LN10_H", + "M_PI_H", + "M_PI_2_H", + "M_PI_4_H", + "M_1_PI_H", + "M_2_PI_H", + "M_2_SQRTPI_H", + "M_SQRT2_H", + "M_SQRT1_2_H", + "DBL_DIG", + "DBL_MANT_DIG", + "DBL_MAX_10_EXP", + "DBL_MAX_EXP", + "DBL_MIN_10_EXP", + "DBL_MIN_EXP", + "DBL_RADIX", + "DBL_MAX", + "DBL_MIN", + "DBL_EPSILON", + "DBL_DECIMAL_DIG", + "MAXDOUBLE", + "HUGE_VAL", + "M_E", + "M_LOG2E", + "M_LOG10E", + "M_LN2", + "M_LN10", + "M_PI", + "M_PI_2", + "M_PI_4", + "M_1_PI", + "M_2_PI", + "M_2_SQRTPI", + "M_SQRT2", + "M_SQRT1_2", + // Naga utilities + "DefaultConstructible", + super::writer::FREXP_FUNCTION, + super::writer::MODF_FUNCTION, +]; diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs new file mode 100644 index 0000000000..5ef18730c9 --- /dev/null +++ b/third_party/rust/naga/src/back/msl/mod.rs @@ -0,0 +1,541 @@ +/*! +Backend for [MSL][msl] (Metal Shading Language). + +## Binding model + +Metal's bindings are flat per resource. Since there isn't an obvious mapping +from SPIR-V's descriptor sets, we require a separate mapping provided in the options. +This mapping may have one or more resource end points for each descriptor set + index +pair. + +## Entry points + +Even though MSL and our IR appear to be similar in that the entry points in both can +accept arguments and return values, the restrictions are different. +MSL allows the varyings to be either in separate arguments, or inside a single +`[[stage_in]]` struct. We gather input varyings and form this artificial structure. +We also add all the (non-Private) globals into the arguments. + +At the beginning of the entry point, we assign the local constants and re-compose +the arguments as they are declared on IR side, so that the rest of the logic can +pretend that MSL doesn't have all the restrictions it has. + +For the result type, if it's a structure, we re-compose it with a temporary value +holding the result. + +[msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf +*/ + +use crate::{arena::Handle, proc::index, valid::ModuleInfo}; +use std::fmt::{Error as FmtError, Write}; + +mod keywords; +pub mod sampler; +mod writer; + +pub use writer::Writer; + +pub type Slot = u8; +pub type InlineSamplerIndex = u8; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum BindSamplerTarget { + Resource(Slot), + Inline(InlineSamplerIndex), +} + +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))] +pub struct BindTarget { + pub buffer: Option<Slot>, + pub texture: Option<Slot>, + pub sampler: Option<BindSamplerTarget>, + /// If the binding is an unsized binding array, this overrides the size. + pub binding_array_size: Option<u32>, + pub mutable: bool, +} + +// Using `BTreeMap` instead of `HashMap` so that we can hash itself. +pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>; + +#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))] +pub struct EntryPointResources { + pub resources: BindingMap, + + pub push_constant_buffer: Option<Slot>, + + /// The slot of a buffer that contains an array of `u32`, + /// one for the size of each bound buffer that contains a runtime array, + /// in order of [`crate::GlobalVariable`] declarations. + pub sizes_buffer: Option<Slot>, +} + +pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>; + +enum ResolvedBinding { + BuiltIn(crate::BuiltIn), + Attribute(u32), + Color { + location: u32, + second_blend_source: bool, + }, + User { + prefix: &'static str, + index: u32, + interpolation: Option<ResolvedInterpolation>, + }, + Resource(BindTarget), +} + +#[derive(Copy, Clone)] +enum ResolvedInterpolation { + CenterPerspective, + CenterNoPerspective, + CentroidPerspective, + CentroidNoPerspective, + SamplePerspective, + SampleNoPerspective, + Flat, +} + +// Note: some of these should be removed in favor of proper IR validation. + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error(transparent)] + Format(#[from] FmtError), + #[error("bind target {0:?} is empty")] + UnimplementedBindTarget(BindTarget), + #[error("composing of {0:?} is not implemented yet")] + UnsupportedCompose(Handle<crate::Type>), + #[error("operation {0:?} is not implemented yet")] + UnsupportedBinaryOp(crate::BinaryOperator), + #[error("standard function '{0}' is not implemented yet")] + UnsupportedCall(String), + #[error("feature '{0}' is not implemented yet")] + FeatureNotImplemented(String), + #[error("module is not valid")] + Validation, + #[error("BuiltIn {0:?} is not supported")] + UnsupportedBuiltIn(crate::BuiltIn), + #[error("capability {0:?} is not supported")] + CapabilityNotSupported(crate::valid::Capabilities), + #[error("attribute '{0}' is not supported for target MSL version")] + UnsupportedAttribute(String), + #[error("function '{0}' is not supported for target MSL version")] + UnsupportedFunction(String), + #[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")] + UnsupportedWriteableStorageBuffer, + #[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")] + UnsupportedWriteableStorageTexture(crate::ShaderStage), + #[error("can not use read-write storage textures prior to MSL 1.2")] + UnsupportedRWStorageTexture, + #[error("array of '{0}' is not supported for target MSL version")] + UnsupportedArrayOf(String), + #[error("array of type '{0:?}' is not supported")] + UnsupportedArrayOfType(Handle<crate::Type>), + #[error("ray tracing is not supported prior to MSL 2.3")] + UnsupportedRayTracing, +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum EntryPointError { + #[error("global '{0}' doesn't have a binding")] + MissingBinding(String), + #[error("mapping of {0:?} is missing")] + MissingBindTarget(crate::ResourceBinding), + #[error("mapping for push constants is missing")] + MissingPushConstants, + #[error("mapping for sizes buffer is missing")] + MissingSizesBuffer, +} + +/// Points in the MSL code where we might emit a pipeline input or output. +/// +/// Note that, even though vertex shaders' outputs are always fragment +/// shaders' inputs, we still need to distinguish `VertexOutput` and +/// `FragmentInput`, since there are certain differences in the way +/// [`ResolvedBinding`s] are represented on either side. +/// +/// [`ResolvedBinding`s]: ResolvedBinding +#[derive(Clone, Copy, Debug)] +enum LocationMode { + /// Input to the vertex shader. + VertexInput, + + /// Output from the vertex shader. + VertexOutput, + + /// Input to the fragment shader. + FragmentInput, + + /// Output from the fragment shader. + FragmentOutput, + + /// Compute shader input or output. + Uniform, +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Options { + /// (Major, Minor) target version of the Metal Shading Language. + pub lang_version: (u8, u8), + /// Map of entry-point resources, indexed by entry point function name, to slots. + pub per_entry_point_map: EntryPointResourceMap, + /// Samplers to be inlined into the code. + pub inline_samplers: Vec<sampler::InlineSampler>, + /// Make it possible to link different stages via SPIRV-Cross. + pub spirv_cross_compatibility: bool, + /// Don't panic on missing bindings, instead generate invalid MSL. + pub fake_missing_bindings: bool, + /// Bounds checking policies. + #[cfg_attr(feature = "deserialize", serde(default))] + pub bounds_check_policies: index::BoundsCheckPolicies, + /// Should workgroup variables be zero initialized (by polyfilling)? + pub zero_initialize_workgroup_memory: bool, +} + +impl Default for Options { + fn default() -> Self { + Options { + lang_version: (1, 0), + per_entry_point_map: EntryPointResourceMap::default(), + inline_samplers: Vec::new(), + spirv_cross_compatibility: false, + fake_missing_bindings: true, + bounds_check_policies: index::BoundsCheckPolicies::default(), + zero_initialize_workgroup_memory: true, + } + } +} + +/// A subset of options that are meant to be changed per pipeline. +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct PipelineOptions { + /// Allow `BuiltIn::PointSize` and inject it if doesn't exist. + /// + /// Metal doesn't like this for non-point primitive topologies and requires it for + /// point primitive topologies. + /// + /// Enable this for vertex shaders with point primitive topologies. + pub allow_and_force_point_size: bool, +} + +impl Options { + fn resolve_local_binding( + &self, + binding: &crate::Binding, + mode: LocationMode, + ) -> Result<ResolvedBinding, Error> { + match *binding { + crate::Binding::BuiltIn(mut built_in) => { + match built_in { + crate::BuiltIn::Position { ref mut invariant } => { + if *invariant && self.lang_version < (2, 1) { + return Err(Error::UnsupportedAttribute("invariant".to_string())); + } + + // The 'invariant' attribute may only appear on vertex + // shader outputs, not fragment shader inputs. + if !matches!(mode, LocationMode::VertexOutput) { + *invariant = false; + } + } + crate::BuiltIn::BaseInstance if self.lang_version < (1, 2) => { + return Err(Error::UnsupportedAttribute("base_instance".to_string())); + } + crate::BuiltIn::InstanceIndex if self.lang_version < (1, 2) => { + return Err(Error::UnsupportedAttribute("instance_id".to_string())); + } + // macOS: Since Metal 2.2 + // iOS: Since Metal 2.3 (check depends on https://github.com/gfx-rs/naga/issues/2164) + crate::BuiltIn::PrimitiveIndex if self.lang_version < (2, 2) => { + return Err(Error::UnsupportedAttribute("primitive_id".to_string())); + } + _ => {} + } + + Ok(ResolvedBinding::BuiltIn(built_in)) + } + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source, + } => match mode { + LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)), + LocationMode::FragmentOutput => { + if second_blend_source && self.lang_version < (1, 2) { + return Err(Error::UnsupportedAttribute( + "second_blend_source".to_string(), + )); + } + Ok(ResolvedBinding::Color { + location, + second_blend_source, + }) + } + LocationMode::VertexOutput | LocationMode::FragmentInput => { + Ok(ResolvedBinding::User { + prefix: if self.spirv_cross_compatibility { + "locn" + } else { + "loc" + }, + index: location, + interpolation: { + // unwrap: The verifier ensures that vertex shader outputs and fragment + // shader inputs always have fully specified interpolation, and that + // sampling is `None` only for Flat interpolation. + let interpolation = interpolation.unwrap(); + let sampling = sampling.unwrap_or(crate::Sampling::Center); + Some(ResolvedInterpolation::from_binding(interpolation, sampling)) + }, + }) + } + LocationMode::Uniform => { + log::error!( + "Unexpected Binding::Location({}) for the Uniform mode", + location + ); + Err(Error::Validation) + } + }, + } + } + + fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> { + self.per_entry_point_map.get(&ep.name) + } + + fn get_resource_binding_target( + &self, + ep: &crate::EntryPoint, + res_binding: &crate::ResourceBinding, + ) -> Option<&BindTarget> { + self.get_entry_point_resources(ep) + .and_then(|res| res.resources.get(res_binding)) + } + + fn resolve_resource_binding( + &self, + ep: &crate::EntryPoint, + res_binding: &crate::ResourceBinding, + ) -> Result<ResolvedBinding, EntryPointError> { + let target = self.get_resource_binding_target(ep, res_binding); + match target { + Some(target) => Ok(ResolvedBinding::Resource(target.clone())), + None if self.fake_missing_bindings => Ok(ResolvedBinding::User { + prefix: "fake", + index: 0, + interpolation: None, + }), + None => Err(EntryPointError::MissingBindTarget(res_binding.clone())), + } + } + + fn resolve_push_constants( + &self, + ep: &crate::EntryPoint, + ) -> Result<ResolvedBinding, EntryPointError> { + let slot = self + .get_entry_point_resources(ep) + .and_then(|res| res.push_constant_buffer); + match slot { + Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { + buffer: Some(slot), + ..Default::default() + })), + None if self.fake_missing_bindings => Ok(ResolvedBinding::User { + prefix: "fake", + index: 0, + interpolation: None, + }), + None => Err(EntryPointError::MissingPushConstants), + } + } + + fn resolve_sizes_buffer( + &self, + ep: &crate::EntryPoint, + ) -> Result<ResolvedBinding, EntryPointError> { + let slot = self + .get_entry_point_resources(ep) + .and_then(|res| res.sizes_buffer); + match slot { + Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { + buffer: Some(slot), + ..Default::default() + })), + None if self.fake_missing_bindings => Ok(ResolvedBinding::User { + prefix: "fake", + index: 0, + interpolation: None, + }), + None => Err(EntryPointError::MissingSizesBuffer), + } + } +} + +impl ResolvedBinding { + fn as_inline_sampler<'a>(&self, options: &'a Options) -> Option<&'a sampler::InlineSampler> { + match *self { + Self::Resource(BindTarget { + sampler: Some(BindSamplerTarget::Inline(index)), + .. + }) => Some(&options.inline_samplers[index as usize]), + _ => None, + } + } + + const fn as_bind_target(&self) -> Option<&BindTarget> { + match *self { + Self::Resource(ref target) => Some(target), + _ => None, + } + } + + fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> { + write!(out, " [[")?; + match *self { + Self::BuiltIn(built_in) => { + use crate::BuiltIn as Bi; + let name = match built_in { + Bi::Position { invariant: false } => "position", + Bi::Position { invariant: true } => "position, invariant", + // vertex + Bi::BaseInstance => "base_instance", + Bi::BaseVertex => "base_vertex", + Bi::ClipDistance => "clip_distance", + Bi::InstanceIndex => "instance_id", + Bi::PointSize => "point_size", + Bi::VertexIndex => "vertex_id", + // fragment + Bi::FragDepth => "depth(any)", + Bi::PointCoord => "point_coord", + Bi::FrontFacing => "front_facing", + Bi::PrimitiveIndex => "primitive_id", + Bi::SampleIndex => "sample_id", + Bi::SampleMask => "sample_mask", + // compute + Bi::GlobalInvocationId => "thread_position_in_grid", + Bi::LocalInvocationId => "thread_position_in_threadgroup", + Bi::LocalInvocationIndex => "thread_index_in_threadgroup", + Bi::WorkGroupId => "threadgroup_position_in_grid", + Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", + Bi::NumWorkGroups => "threadgroups_per_grid", + Bi::CullDistance | Bi::ViewIndex => { + return Err(Error::UnsupportedBuiltIn(built_in)) + } + }; + write!(out, "{name}")?; + } + Self::Attribute(index) => write!(out, "attribute({index})")?, + Self::Color { + location, + second_blend_source, + } => { + if second_blend_source { + write!(out, "color({location}) index(1)")? + } else { + write!(out, "color({location})")? + } + } + Self::User { + prefix, + index, + interpolation, + } => { + write!(out, "user({prefix}{index})")?; + if let Some(interpolation) = interpolation { + write!(out, ", ")?; + interpolation.try_fmt(out)?; + } + } + Self::Resource(ref target) => { + if let Some(id) = target.buffer { + write!(out, "buffer({id})")?; + } else if let Some(id) = target.texture { + write!(out, "texture({id})")?; + } else if let Some(BindSamplerTarget::Resource(id)) = target.sampler { + write!(out, "sampler({id})")?; + } else { + return Err(Error::UnimplementedBindTarget(target.clone())); + } + } + } + write!(out, "]]")?; + Ok(()) + } +} + +impl ResolvedInterpolation { + const fn from_binding(interpolation: crate::Interpolation, sampling: crate::Sampling) -> Self { + use crate::Interpolation as I; + use crate::Sampling as S; + + match (interpolation, sampling) { + (I::Perspective, S::Center) => Self::CenterPerspective, + (I::Perspective, S::Centroid) => Self::CentroidPerspective, + (I::Perspective, S::Sample) => Self::SamplePerspective, + (I::Linear, S::Center) => Self::CenterNoPerspective, + (I::Linear, S::Centroid) => Self::CentroidNoPerspective, + (I::Linear, S::Sample) => Self::SampleNoPerspective, + (I::Flat, _) => Self::Flat, + } + } + + fn try_fmt<W: Write>(self, out: &mut W) -> Result<(), Error> { + let identifier = match self { + Self::CenterPerspective => "center_perspective", + Self::CenterNoPerspective => "center_no_perspective", + Self::CentroidPerspective => "centroid_perspective", + Self::CentroidNoPerspective => "centroid_no_perspective", + Self::SamplePerspective => "sample_perspective", + Self::SampleNoPerspective => "sample_no_perspective", + Self::Flat => "flat", + }; + out.write_str(identifier)?; + Ok(()) + } +} + +/// Information about a translated module that is required +/// for the use of the result. +pub struct TranslationInfo { + /// Mapping of the entry point names. Each item in the array + /// corresponds to an entry point index. + /// + ///Note: Some entry points may fail translation because of missing bindings. + pub entry_point_names: Vec<Result<String, EntryPointError>>, +} + +pub fn write_string( + module: &crate::Module, + info: &ModuleInfo, + options: &Options, + pipeline_options: &PipelineOptions, +) -> Result<(String, TranslationInfo), Error> { + let mut w = writer::Writer::new(String::new()); + let info = w.write(module, info, options, pipeline_options)?; + Ok((w.finish(), info)) +} + +#[test] +fn test_error_size() { + use std::mem::size_of; + assert_eq!(size_of::<Error>(), 32); +} diff --git a/third_party/rust/naga/src/back/msl/sampler.rs b/third_party/rust/naga/src/back/msl/sampler.rs new file mode 100644 index 0000000000..0bf987076d --- /dev/null +++ b/third_party/rust/naga/src/back/msl/sampler.rs @@ -0,0 +1,176 @@ +#[cfg(feature = "deserialize")] +use serde::Deserialize; +#[cfg(feature = "serialize")] +use serde::Serialize; +use std::{num::NonZeroU32, ops::Range}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum Coord { + Normalized, + Pixel, +} + +impl Default for Coord { + fn default() -> Self { + Self::Normalized + } +} + +impl Coord { + pub const fn as_str(&self) -> &'static str { + match *self { + Self::Normalized => "normalized", + Self::Pixel => "pixel", + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum Address { + Repeat, + MirroredRepeat, + ClampToEdge, + ClampToZero, + ClampToBorder, +} + +impl Default for Address { + fn default() -> Self { + Self::ClampToEdge + } +} + +impl Address { + pub const fn as_str(&self) -> &'static str { + match *self { + Self::Repeat => "repeat", + Self::MirroredRepeat => "mirrored_repeat", + Self::ClampToEdge => "clamp_to_edge", + Self::ClampToZero => "clamp_to_zero", + Self::ClampToBorder => "clamp_to_border", + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum BorderColor { + TransparentBlack, + OpaqueBlack, + OpaqueWhite, +} + +impl Default for BorderColor { + fn default() -> Self { + Self::TransparentBlack + } +} + +impl BorderColor { + pub const fn as_str(&self) -> &'static str { + match *self { + Self::TransparentBlack => "transparent_black", + Self::OpaqueBlack => "opaque_black", + Self::OpaqueWhite => "opaque_white", + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum Filter { + Nearest, + Linear, +} + +impl Filter { + pub const fn as_str(&self) -> &'static str { + match *self { + Self::Nearest => "nearest", + Self::Linear => "linear", + } + } +} + +impl Default for Filter { + fn default() -> Self { + Self::Nearest + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum CompareFunc { + Never, + Less, + LessEqual, + Greater, + GreaterEqual, + Equal, + NotEqual, + Always, +} + +impl Default for CompareFunc { + fn default() -> Self { + Self::Never + } +} + +impl CompareFunc { + pub const fn as_str(&self) -> &'static str { + match *self { + Self::Never => "never", + Self::Less => "less", + Self::LessEqual => "less_equal", + Self::Greater => "greater", + Self::GreaterEqual => "greater_equal", + Self::Equal => "equal", + Self::NotEqual => "not_equal", + Self::Always => "always", + } + } +} + +#[derive(Clone, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct InlineSampler { + pub coord: Coord, + pub address: [Address; 3], + pub border_color: BorderColor, + pub mag_filter: Filter, + pub min_filter: Filter, + pub mip_filter: Option<Filter>, + pub lod_clamp: Option<Range<f32>>, + pub max_anisotropy: Option<NonZeroU32>, + pub compare_func: CompareFunc, +} + +impl Eq for InlineSampler {} + +#[allow(renamed_and_removed_lints)] +#[allow(clippy::derive_hash_xor_eq)] +impl std::hash::Hash for InlineSampler { + fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) { + self.coord.hash(hasher); + self.address.hash(hasher); + self.border_color.hash(hasher); + self.mag_filter.hash(hasher); + self.min_filter.hash(hasher); + self.mip_filter.hash(hasher); + self.lod_clamp + .as_ref() + .map(|range| (range.start.to_bits(), range.end.to_bits())) + .hash(hasher); + self.max_anisotropy.hash(hasher); + self.compare_func.hash(hasher); + } +} diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs new file mode 100644 index 0000000000..1e496b5f50 --- /dev/null +++ b/third_party/rust/naga/src/back/msl/writer.rs @@ -0,0 +1,4659 @@ +use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo}; +use crate::{ + arena::Handle, + back, + proc::index, + proc::{self, NameKey, TypeResolution}, + valid, FastHashMap, FastHashSet, +}; +use bit_set::BitSet; +use std::{ + fmt::{Display, Error as FmtError, Formatter, Write}, + iter, +}; + +/// Shorthand result used internally by the backend +type BackendResult = Result<(), Error>; + +const NAMESPACE: &str = "metal"; +// The name of the array member of the Metal struct types we generate to +// represent Naga `Array` types. See the comments in `Writer::write_type_defs` +// for details. +const WRAPPED_ARRAY_FIELD: &str = "inner"; +// This is a hack: we need to pass a pointer to an atomic, +// but generally the backend isn't putting "&" in front of every pointer. +// Some more general handling of pointers is needed to be implemented here. +const ATOMIC_REFERENCE: &str = "&"; + +const RT_NAMESPACE: &str = "metal::raytracing"; +const RAY_QUERY_TYPE: &str = "_RayQuery"; +const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector"; +const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection"; +const RAY_QUERY_FIELD_READY: &str = "ready"; +const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type"; + +pub(crate) const MODF_FUNCTION: &str = "naga_modf"; +pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; + +/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix. +/// +/// The `sizes` slice determines whether this function writes a +/// scalar, vector, or matrix type: +/// +/// - An empty slice produces a scalar type. +/// - A one-element slice produces a vector type. +/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size. +fn put_numeric_type( + out: &mut impl Write, + scalar: crate::Scalar, + sizes: &[crate::VectorSize], +) -> Result<(), FmtError> { + match (scalar, sizes) { + (scalar, &[]) => { + write!(out, "{}", scalar.to_msl_name()) + } + (scalar, &[rows]) => { + write!( + out, + "{}::{}{}", + NAMESPACE, + scalar.to_msl_name(), + back::vector_size_str(rows) + ) + } + (scalar, &[rows, columns]) => { + write!( + out, + "{}::{}{}x{}", + NAMESPACE, + scalar.to_msl_name(), + back::vector_size_str(columns), + back::vector_size_str(rows) + ) + } + (_, _) => Ok(()), // not meaningful + } +} + +/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions. +const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e"; + +struct TypeContext<'a> { + handle: Handle<crate::Type>, + gctx: proc::GlobalCtx<'a>, + names: &'a FastHashMap<NameKey, String>, + access: crate::StorageAccess, + binding: Option<&'a super::ResolvedBinding>, + first_time: bool, +} + +impl<'a> Display for TypeContext<'a> { + fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> { + let ty = &self.gctx.types[self.handle]; + if ty.needs_alias() && !self.first_time { + let name = &self.names[&NameKey::Type(self.handle)]; + return write!(out, "{name}"); + } + + match ty.inner { + crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]), + crate::TypeInner::Atomic(scalar) => { + write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name()) + } + crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]), + crate::TypeInner::Matrix { columns, rows, .. } => { + put_numeric_type(out, crate::Scalar::F32, &[rows, columns]) + } + crate::TypeInner::Pointer { base, space } => { + let sub = Self { + handle: base, + first_time: false, + ..*self + }; + let space_name = match space.to_msl_name() { + Some(name) => name, + None => return Ok(()), + }; + write!(out, "{space_name} {sub}&") + } + crate::TypeInner::ValuePointer { + size, + scalar, + space, + } => { + match space.to_msl_name() { + Some(name) => write!(out, "{name} ")?, + None => return Ok(()), + }; + match size { + Some(rows) => put_numeric_type(out, scalar, &[rows])?, + None => put_numeric_type(out, scalar, &[])?, + }; + + write!(out, "&") + } + crate::TypeInner::Array { base, .. } => { + let sub = Self { + handle: base, + first_time: false, + ..*self + }; + // Array lengths go at the end of the type definition, + // so just print the element type here. + write!(out, "{sub}") + } + crate::TypeInner::Struct { .. } => unreachable!(), + crate::TypeInner::Image { + dim, + arrayed, + class, + } => { + let dim_str = match dim { + crate::ImageDimension::D1 => "1d", + crate::ImageDimension::D2 => "2d", + crate::ImageDimension::D3 => "3d", + crate::ImageDimension::Cube => "cube", + }; + let (texture_str, msaa_str, kind, access) = match class { + crate::ImageClass::Sampled { kind, multi } => { + let (msaa_str, access) = if multi { + ("_ms", "read") + } else { + ("", "sample") + }; + ("texture", msaa_str, kind, access) + } + crate::ImageClass::Depth { multi } => { + let (msaa_str, access) = if multi { + ("_ms", "read") + } else { + ("", "sample") + }; + ("depth", msaa_str, crate::ScalarKind::Float, access) + } + crate::ImageClass::Storage { format, .. } => { + let access = if self + .access + .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE) + { + "read_write" + } else if self.access.contains(crate::StorageAccess::STORE) { + "write" + } else if self.access.contains(crate::StorageAccess::LOAD) { + "read" + } else { + log::warn!( + "Storage access for {:?} (name '{}'): {:?}", + self.handle, + ty.name.as_deref().unwrap_or_default(), + self.access + ); + unreachable!("module is not valid"); + }; + ("texture", "", format.into(), access) + } + }; + let base_name = crate::Scalar { kind, width: 4 }.to_msl_name(); + let array_str = if arrayed { "_array" } else { "" }; + write!( + out, + "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>", + ) + } + crate::TypeInner::Sampler { comparison: _ } => { + write!(out, "{NAMESPACE}::sampler") + } + crate::TypeInner::AccelerationStructure => { + write!(out, "{RT_NAMESPACE}::instance_acceleration_structure") + } + crate::TypeInner::RayQuery => { + write!(out, "{RAY_QUERY_TYPE}") + } + crate::TypeInner::BindingArray { base, size } => { + let base_tyname = Self { + handle: base, + first_time: false, + ..*self + }; + + if let Some(&super::ResolvedBinding::Resource(super::BindTarget { + binding_array_size: Some(override_size), + .. + })) = self.binding + { + write!(out, "{NAMESPACE}::array<{base_tyname}, {override_size}>") + } else if let crate::ArraySize::Constant(size) = size { + write!(out, "{NAMESPACE}::array<{base_tyname}, {size}>") + } else { + unreachable!("metal requires all arrays be constant sized"); + } + } + } + } +} + +struct TypedGlobalVariable<'a> { + module: &'a crate::Module, + names: &'a FastHashMap<NameKey, String>, + handle: Handle<crate::GlobalVariable>, + usage: valid::GlobalUse, + binding: Option<&'a super::ResolvedBinding>, + reference: bool, +} + +impl<'a> TypedGlobalVariable<'a> { + fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult { + let var = &self.module.global_variables[self.handle]; + let name = &self.names[&NameKey::GlobalVariable(self.handle)]; + + let storage_access = match var.space { + crate::AddressSpace::Storage { access } => access, + _ => match self.module.types[var.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. }, + .. + } => access, + crate::TypeInner::BindingArray { base, .. } => { + match self.module.types[base].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. }, + .. + } => access, + _ => crate::StorageAccess::default(), + } + } + _ => crate::StorageAccess::default(), + }, + }; + let ty_name = TypeContext { + handle: var.ty, + gctx: self.module.to_ctx(), + names: self.names, + access: storage_access, + binding: self.binding, + first_time: false, + }; + + let (space, access, reference) = match var.space.to_msl_name() { + Some(space) if self.reference => { + let access = if var.space.needs_access_qualifier() + && !self.usage.contains(valid::GlobalUse::WRITE) + { + "const" + } else { + "" + }; + (space, access, "&") + } + _ => ("", "", ""), + }; + + Ok(write!( + out, + "{}{}{}{}{}{} {}", + space, + if space.is_empty() { "" } else { " " }, + ty_name, + if access.is_empty() { "" } else { " " }, + access, + reference, + name, + )?) + } +} + +pub struct Writer<W> { + out: W, + names: FastHashMap<NameKey, String>, + named_expressions: crate::NamedExpressions, + /// Set of expressions that need to be baked to avoid unnecessary repetition in output + need_bake_expressions: back::NeedBakeExpressions, + namer: proc::Namer, + #[cfg(test)] + put_expression_stack_pointers: FastHashSet<*const ()>, + #[cfg(test)] + put_block_stack_pointers: FastHashSet<*const ()>, + /// Set of (struct type, struct field index) denoting which fields require + /// padding inserted **before** them (i.e. between fields at index - 1 and index) + struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>, +} + +impl crate::Scalar { + const fn to_msl_name(self) -> &'static str { + use crate::ScalarKind as Sk; + match self { + Self { + kind: Sk::Float, + width: _, + } => "float", + Self { + kind: Sk::Sint, + width: _, + } => "int", + Self { + kind: Sk::Uint, + width: _, + } => "uint", + Self { + kind: Sk::Bool, + width: _, + } => "bool", + Self { + kind: Sk::AbstractInt | Sk::AbstractFloat, + width: _, + } => unreachable!(), + } + } +} + +const fn separate(need_separator: bool) -> &'static str { + if need_separator { + "," + } else { + "" + } +} + +fn should_pack_struct_member( + members: &[crate::StructMember], + span: u32, + index: usize, + module: &crate::Module, +) -> Option<crate::Scalar> { + let member = &members[index]; + + let ty_inner = &module.types[member.ty].inner; + let last_offset = member.offset + ty_inner.size(module.to_ctx()); + let next_offset = match members.get(index + 1) { + Some(next) => next.offset, + None => span, + }; + let is_tight = next_offset == last_offset; + + match *ty_inner { + crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + scalar: scalar @ crate::Scalar { width: 4, .. }, + } if is_tight => Some(scalar), + _ => None, + } +} + +fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool { + match arena[ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + if let Some(member) = members.last() { + if let crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } = arena[member.ty].inner + { + return true; + } + } + false + } + crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } => true, + _ => false, + } +} + +impl crate::AddressSpace { + /// Returns true if global variables in this address space are + /// passed in function arguments. These arguments need to be + /// passed through any functions called from the entry point. + const fn needs_pass_through(&self) -> bool { + match *self { + Self::Uniform + | Self::Storage { .. } + | Self::Private + | Self::WorkGroup + | Self::PushConstant + | Self::Handle => true, + Self::Function => false, + } + } + + /// Returns true if the address space may need a "const" qualifier. + const fn needs_access_qualifier(&self) -> bool { + match *self { + //Note: we are ignoring the storage access here, and instead + // rely on the actual use of a global by functions. This means we + // may end up with "const" even if the binding is read-write, + // and that should be OK. + Self::Storage { .. } => true, + // These should always be read-write. + Self::Private | Self::WorkGroup => false, + // These translate to `constant` address space, no need for qualifiers. + Self::Uniform | Self::PushConstant => false, + // Not applicable. + Self::Handle | Self::Function => false, + } + } + + const fn to_msl_name(self) -> Option<&'static str> { + match self { + Self::Handle => None, + Self::Uniform | Self::PushConstant => Some("constant"), + Self::Storage { .. } => Some("device"), + Self::Private | Self::Function => Some("thread"), + Self::WorkGroup => Some("threadgroup"), + } + } +} + +impl crate::Type { + // Returns `true` if we need to emit an alias for this type. + const fn needs_alias(&self) -> bool { + use crate::TypeInner as Ti; + + match self.inner { + // value types are concise enough, we only alias them if they are named + Ti::Scalar(_) + | Ti::Vector { .. } + | Ti::Matrix { .. } + | Ti::Atomic(_) + | Ti::Pointer { .. } + | Ti::ValuePointer { .. } => self.name.is_some(), + // composite types are better to be aliased, regardless of the name + Ti::Struct { .. } | Ti::Array { .. } => true, + // handle types may be different, depending on the global var access, so we always inline them + Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => false, + } + } +} + +enum FunctionOrigin { + Handle(Handle<crate::Function>), + EntryPoint(proc::EntryPointIndex), +} + +/// A level of detail argument. +/// +/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we +/// save the clamped level of detail in a temporary variable whose name is based +/// on the handle of the `ImageLoad` expression. But for other policies, we just +/// use the expression directly. +/// +/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict +/// [`ImageLoad`]: crate::Expression::ImageLoad +#[derive(Clone, Copy)] +enum LevelOfDetail { + Direct(Handle<crate::Expression>), + Restricted(Handle<crate::Expression>), +} + +/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`]. +/// +/// When this is used in code paths unconcerned with the `Restrict` bounds check +/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level` +/// will always be either `None` or `Some(Direct(_))`. But this turns out not to +/// be too awkward. If that changes, we can revisit. +/// +/// [`ImageLoad`]: crate::Expression::ImageLoad +/// [`ImageStore`]: crate::Statement::ImageStore +struct TexelAddress { + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + sample: Option<Handle<crate::Expression>>, + level: Option<LevelOfDetail>, +} + +struct ExpressionContext<'a> { + function: &'a crate::Function, + origin: FunctionOrigin, + info: &'a valid::FunctionInfo, + module: &'a crate::Module, + mod_info: &'a valid::ModuleInfo, + pipeline_options: &'a PipelineOptions, + lang_version: (u8, u8), + policies: index::BoundsCheckPolicies, + + /// A bitset containing the `Expression` handle indexes of expressions used + /// as indices in `ReadZeroSkipWrite`-policy accesses. These may need to be + /// cached in temporary variables. See `index::find_checked_indexes` for + /// details. + guarded_indices: BitSet, +} + +impl<'a> ExpressionContext<'a> { + fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner { + self.info[handle].ty.inner_with(&self.module.types) + } + + /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail. + /// + /// Only mipmapped images need to specify a level of detail. Since 1D + /// textures cannot have mipmaps, MSL requires that the level argument to + /// texture1d queries and accesses must be a constexpr 0. It's easiest + /// just to omit the level entirely for 1D textures. + fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool { + let image_ty = self.resolve_type(image); + if let crate::TypeInner::Image { dim, class, .. } = *image_ty { + class.is_mipmapped() && dim != crate::ImageDimension::D1 + } else { + false + } + } + + fn choose_bounds_check_policy( + &self, + pointer: Handle<crate::Expression>, + ) -> index::BoundsCheckPolicy { + self.policies + .choose_policy(pointer, &self.module.types, self.info) + } + + fn access_needs_check( + &self, + base: Handle<crate::Expression>, + index: index::GuardedIndex, + ) -> Option<index::IndexableLength> { + index::access_needs_check(base, index, self.module, self.function, self.info) + } + + fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> { + match self.function.expressions[expr_handle] { + crate::Expression::AccessIndex { base, index } => { + let ty = match *self.resolve_type(base) { + crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner, + ref ty => ty, + }; + match *ty { + crate::TypeInner::Struct { + ref members, span, .. + } => should_pack_struct_member(members, span, index as usize, self.module), + _ => None, + } + } + _ => None, + } + } +} + +struct StatementContext<'a> { + expression: ExpressionContext<'a>, + result_struct: Option<&'a str>, +} + +impl<W: Write> Writer<W> { + /// Creates a new `Writer` instance. + pub fn new(out: W) -> Self { + Writer { + out, + names: FastHashMap::default(), + named_expressions: Default::default(), + need_bake_expressions: Default::default(), + namer: proc::Namer::default(), + #[cfg(test)] + put_expression_stack_pointers: Default::default(), + #[cfg(test)] + put_block_stack_pointers: Default::default(), + struct_member_pads: FastHashSet::default(), + } + } + + /// Finishes writing and returns the output. + // See https://github.com/rust-lang/rust-clippy/issues/4979. + #[allow(clippy::missing_const_for_fn)] + pub fn finish(self) -> W { + self.out + } + + fn put_call_parameters( + &mut self, + parameters: impl Iterator<Item = Handle<crate::Expression>>, + context: &ExpressionContext, + ) -> BackendResult { + self.put_call_parameters_impl(parameters, context, |writer, context, expr| { + writer.put_expression(expr, context, true) + }) + } + + fn put_call_parameters_impl<C, E>( + &mut self, + parameters: impl Iterator<Item = Handle<crate::Expression>>, + ctx: &C, + put_expression: E, + ) -> BackendResult + where + E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult, + { + write!(self.out, "(")?; + for (i, handle) in parameters.enumerate() { + if i != 0 { + write!(self.out, ", ")?; + } + put_expression(self, ctx, handle)?; + } + write!(self.out, ")")?; + Ok(()) + } + + fn put_level_of_detail( + &mut self, + level: LevelOfDetail, + context: &ExpressionContext, + ) -> BackendResult { + match level { + LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?, + LevelOfDetail::Restricted(load) => { + write!(self.out, "{}{}", CLAMPED_LOD_LOAD_PREFIX, load.index())? + } + } + Ok(()) + } + + fn put_image_query( + &mut self, + image: Handle<crate::Expression>, + query: &str, + level: Option<LevelOfDetail>, + context: &ExpressionContext, + ) -> BackendResult { + self.put_expression(image, context, false)?; + write!(self.out, ".get_{query}(")?; + if let Some(level) = level { + self.put_level_of_detail(level, context)?; + } + write!(self.out, ")")?; + Ok(()) + } + + fn put_image_size_query( + &mut self, + image: Handle<crate::Expression>, + level: Option<LevelOfDetail>, + kind: crate::ScalarKind, + context: &ExpressionContext, + ) -> BackendResult { + //Note: MSL only has separate width/height/depth queries, + // so compose the result of them. + let dim = match *context.resolve_type(image) { + crate::TypeInner::Image { dim, .. } => dim, + ref other => unreachable!("Unexpected type {:?}", other), + }; + let scalar = crate::Scalar { kind, width: 4 }; + let coordinate_type = scalar.to_msl_name(); + match dim { + crate::ImageDimension::D1 => { + // Since 1D textures never have mipmaps, MSL requires that the + // `level` argument be a constexpr 0. It's simplest for us just + // to pass `None` and omit the level entirely. + if kind == crate::ScalarKind::Uint { + // No need to construct a vector. No cast needed. + self.put_image_query(image, "width", None, context)?; + } else { + // There's no definition for `int` in the `metal` namespace. + write!(self.out, "int(")?; + self.put_image_query(image, "width", None, context)?; + write!(self.out, ")")?; + } + } + crate::ImageDimension::D2 => { + write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?; + self.put_image_query(image, "width", level, context)?; + write!(self.out, ", ")?; + self.put_image_query(image, "height", level, context)?; + write!(self.out, ")")?; + } + crate::ImageDimension::D3 => { + write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?; + self.put_image_query(image, "width", level, context)?; + write!(self.out, ", ")?; + self.put_image_query(image, "height", level, context)?; + write!(self.out, ", ")?; + self.put_image_query(image, "depth", level, context)?; + write!(self.out, ")")?; + } + crate::ImageDimension::Cube => { + write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?; + self.put_image_query(image, "width", level, context)?; + write!(self.out, ")")?; + } + } + Ok(()) + } + + fn put_cast_to_uint_scalar_or_vector( + &mut self, + expr: Handle<crate::Expression>, + context: &ExpressionContext, + ) -> BackendResult { + // coordinates in IR are int, but Metal expects uint + match *context.resolve_type(expr) { + crate::TypeInner::Scalar(_) => { + put_numeric_type(&mut self.out, crate::Scalar::U32, &[])? + } + crate::TypeInner::Vector { size, .. } => { + put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])? + } + _ => return Err(Error::Validation), + }; + + write!(self.out, "(")?; + self.put_expression(expr, context, true)?; + write!(self.out, ")")?; + Ok(()) + } + + fn put_image_sample_level( + &mut self, + image: Handle<crate::Expression>, + level: crate::SampleLevel, + context: &ExpressionContext, + ) -> BackendResult { + let has_levels = context.image_needs_lod(image); + match level { + crate::SampleLevel::Auto => {} + crate::SampleLevel::Zero => { + //TODO: do we support Zero on `Sampled` image classes? + } + _ if !has_levels => { + log::warn!("1D image can't be sampled with level {:?}", level); + } + crate::SampleLevel::Exact(h) => { + write!(self.out, ", {NAMESPACE}::level(")?; + self.put_expression(h, context, true)?; + write!(self.out, ")")?; + } + crate::SampleLevel::Bias(h) => { + write!(self.out, ", {NAMESPACE}::bias(")?; + self.put_expression(h, context, true)?; + write!(self.out, ")")?; + } + crate::SampleLevel::Gradient { x, y } => { + write!(self.out, ", {NAMESPACE}::gradient2d(")?; + self.put_expression(x, context, true)?; + write!(self.out, ", ")?; + self.put_expression(y, context, true)?; + write!(self.out, ")")?; + } + } + Ok(()) + } + + fn put_image_coordinate_limits( + &mut self, + image: Handle<crate::Expression>, + level: Option<LevelOfDetail>, + context: &ExpressionContext, + ) -> BackendResult { + self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?; + write!(self.out, " - 1")?; + Ok(()) + } + + /// General function for writing restricted image indexes. + /// + /// This is used to produce restricted mip levels, array indices, and sample + /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the + /// [`Restrict`] bounds check policy. + /// + /// This function writes an expression of the form: + /// + /// ```ignore + /// + /// metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1) + /// + /// ``` + /// + /// [`ImageLoad`]: crate::Expression::ImageLoad + /// [`ImageStore`]: crate::Statement::ImageStore + /// [`Restrict`]: index::BoundsCheckPolicy::Restrict + fn put_restricted_scalar_image_index( + &mut self, + image: Handle<crate::Expression>, + index: Handle<crate::Expression>, + limit_method: &str, + context: &ExpressionContext, + ) -> BackendResult { + write!(self.out, "{NAMESPACE}::min(uint(")?; + self.put_expression(index, context, true)?; + write!(self.out, "), ")?; + self.put_expression(image, context, false)?; + write!(self.out, ".{limit_method}() - 1)")?; + Ok(()) + } + + fn put_restricted_texel_address( + &mut self, + image: Handle<crate::Expression>, + address: &TexelAddress, + context: &ExpressionContext, + ) -> BackendResult { + // Write the coordinate. + write!(self.out, "{NAMESPACE}::min(")?; + self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?; + write!(self.out, ", ")?; + self.put_image_coordinate_limits(image, address.level, context)?; + write!(self.out, ")")?; + + // Write the array index, if present. + if let Some(array_index) = address.array_index { + write!(self.out, ", ")?; + self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?; + } + + // Write the sample index, if present. + if let Some(sample) = address.sample { + write!(self.out, ", ")?; + self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?; + } + + // The level of detail should be clamped and cached by + // `put_cache_restricted_level`, so we don't need to clamp it here. + if let Some(level) = address.level { + write!(self.out, ", ")?; + self.put_level_of_detail(level, context)?; + } + + Ok(()) + } + + /// Write an expression that is true if the given image access is in bounds. + fn put_image_access_bounds_check( + &mut self, + image: Handle<crate::Expression>, + address: &TexelAddress, + context: &ExpressionContext, + ) -> BackendResult { + let mut conjunction = ""; + + // First, check the level of detail. Only if that is in bounds can we + // use it to find the appropriate bounds for the coordinates. + let level = if let Some(level) = address.level { + write!(self.out, "uint(")?; + self.put_level_of_detail(level, context)?; + write!(self.out, ") < ")?; + self.put_expression(image, context, true)?; + write!(self.out, ".get_num_mip_levels()")?; + conjunction = " && "; + Some(level) + } else { + None + }; + + // Check sample index, if present. + if let Some(sample) = address.sample { + write!(self.out, "uint(")?; + self.put_expression(sample, context, true)?; + write!(self.out, ") < ")?; + self.put_expression(image, context, true)?; + write!(self.out, ".get_num_samples()")?; + conjunction = " && "; + } + + // Check array index, if present. + if let Some(array_index) = address.array_index { + write!(self.out, "{conjunction}uint(")?; + self.put_expression(array_index, context, true)?; + write!(self.out, ") < ")?; + self.put_expression(image, context, true)?; + write!(self.out, ".get_array_size()")?; + conjunction = " && "; + } + + // Finally, check if the coordinates are within bounds. + let coord_is_vector = match *context.resolve_type(address.coordinate) { + crate::TypeInner::Vector { .. } => true, + _ => false, + }; + write!(self.out, "{conjunction}")?; + if coord_is_vector { + write!(self.out, "{NAMESPACE}::all(")?; + } + self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?; + write!(self.out, " < ")?; + self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?; + if coord_is_vector { + write!(self.out, ")")?; + } + + Ok(()) + } + + fn put_image_load( + &mut self, + load: Handle<crate::Expression>, + image: Handle<crate::Expression>, + mut address: TexelAddress, + context: &ExpressionContext, + ) -> BackendResult { + match context.policies.image_load { + proc::BoundsCheckPolicy::Restrict => { + // Use the cached restricted level of detail, if any. Omit the + // level altogether for 1D textures. + if address.level.is_some() { + address.level = if context.image_needs_lod(image) { + Some(LevelOfDetail::Restricted(load)) + } else { + None + } + } + + self.put_expression(image, context, false)?; + write!(self.out, ".read(")?; + self.put_restricted_texel_address(image, &address, context)?; + write!(self.out, ")")?; + } + proc::BoundsCheckPolicy::ReadZeroSkipWrite => { + write!(self.out, "(")?; + self.put_image_access_bounds_check(image, &address, context)?; + write!(self.out, " ? ")?; + self.put_unchecked_image_load(image, &address, context)?; + write!(self.out, ": DefaultConstructible())")?; + } + proc::BoundsCheckPolicy::Unchecked => { + self.put_unchecked_image_load(image, &address, context)?; + } + } + + Ok(()) + } + + fn put_unchecked_image_load( + &mut self, + image: Handle<crate::Expression>, + address: &TexelAddress, + context: &ExpressionContext, + ) -> BackendResult { + self.put_expression(image, context, false)?; + write!(self.out, ".read(")?; + // coordinates in IR are int, but Metal expects uint + self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?; + if let Some(expr) = address.array_index { + write!(self.out, ", ")?; + self.put_expression(expr, context, true)?; + } + if let Some(sample) = address.sample { + write!(self.out, ", ")?; + self.put_expression(sample, context, true)?; + } + if let Some(level) = address.level { + if context.image_needs_lod(image) { + write!(self.out, ", ")?; + self.put_level_of_detail(level, context)?; + } + } + write!(self.out, ")")?; + + Ok(()) + } + + fn put_image_store( + &mut self, + level: back::Level, + image: Handle<crate::Expression>, + address: &TexelAddress, + value: Handle<crate::Expression>, + context: &StatementContext, + ) -> BackendResult { + match context.expression.policies.image_store { + proc::BoundsCheckPolicy::Restrict => { + // We don't have a restricted level value, because we don't + // support writes to mipmapped textures. + debug_assert!(address.level.is_none()); + + write!(self.out, "{level}")?; + self.put_expression(image, &context.expression, false)?; + write!(self.out, ".write(")?; + self.put_expression(value, &context.expression, true)?; + write!(self.out, ", ")?; + self.put_restricted_texel_address(image, address, &context.expression)?; + writeln!(self.out, ");")?; + } + proc::BoundsCheckPolicy::ReadZeroSkipWrite => { + write!(self.out, "{level}if (")?; + self.put_image_access_bounds_check(image, address, &context.expression)?; + writeln!(self.out, ") {{")?; + self.put_unchecked_image_store(level.next(), image, address, value, context)?; + writeln!(self.out, "{level}}}")?; + } + proc::BoundsCheckPolicy::Unchecked => { + self.put_unchecked_image_store(level, image, address, value, context)?; + } + } + + Ok(()) + } + + fn put_unchecked_image_store( + &mut self, + level: back::Level, + image: Handle<crate::Expression>, + address: &TexelAddress, + value: Handle<crate::Expression>, + context: &StatementContext, + ) -> BackendResult { + write!(self.out, "{level}")?; + self.put_expression(image, &context.expression, false)?; + write!(self.out, ".write(")?; + self.put_expression(value, &context.expression, true)?; + write!(self.out, ", ")?; + // coordinates in IR are int, but Metal expects uint + self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?; + if let Some(expr) = address.array_index { + write!(self.out, ", ")?; + self.put_expression(expr, &context.expression, true)?; + } + writeln!(self.out, ");")?; + + Ok(()) + } + + /// Write the maximum valid index of the dynamically sized array at the end of `handle`. + /// + /// The 'maximum valid index' is simply one less than the array's length. + /// + /// This emits an expression of the form `a / b`, so the caller must + /// parenthesize its output if it will be applying operators of higher + /// precedence. + /// + /// `handle` must be the handle of a global variable whose final member is a + /// dynamically sized array. + fn put_dynamic_array_max_index( + &mut self, + handle: Handle<crate::GlobalVariable>, + context: &ExpressionContext, + ) -> BackendResult { + let global = &context.module.global_variables[handle]; + let (offset, array_ty) = match context.module.types[global.ty].inner { + crate::TypeInner::Struct { ref members, .. } => match members.last() { + Some(&crate::StructMember { offset, ty, .. }) => (offset, ty), + None => return Err(Error::Validation), + }, + crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } => (0, global.ty), + _ => return Err(Error::Validation), + }; + + let (size, stride) = match context.module.types[array_ty].inner { + crate::TypeInner::Array { base, stride, .. } => ( + context.module.types[base] + .inner + .size(context.module.to_ctx()), + stride, + ), + _ => return Err(Error::Validation), + }; + + // When the stride length is larger than the size, the final element's stride of + // bytes would have padding following the value. But the buffer size in + // `buffer_sizes.sizeN` may not include this padding - it only needs to be large + // enough to hold the actual values' bytes. + // + // So subtract off the size to get a byte size that falls at the start or within + // the final element. Then divide by the stride size, to get one less than the + // length, and then add one. This works even if the buffer size does include the + // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail + // if there are zero elements in the array, but the WebGPU `validating shader binding` + // rules, together with draw-time validation when `minBindingSize` is zero, + // prevent that. + write!( + self.out, + "(_buffer_sizes.size{idx} - {offset} - {size}) / {stride}", + idx = handle.index(), + offset = offset, + size = size, + stride = stride, + )?; + Ok(()) + } + + fn put_atomic_fetch( + &mut self, + pointer: Handle<crate::Expression>, + key: &str, + value: Handle<crate::Expression>, + context: &ExpressionContext, + ) -> BackendResult { + self.put_atomic_operation(pointer, "fetch_", key, value, context) + } + + fn put_atomic_operation( + &mut self, + pointer: Handle<crate::Expression>, + key1: &str, + key2: &str, + value: Handle<crate::Expression>, + context: &ExpressionContext, + ) -> BackendResult { + // If the pointer we're passing to the atomic operation needs to be conditional + // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and + // the pointer operand should be unchecked. + let policy = context.choose_bounds_check_policy(pointer); + let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite + && self.put_bounds_checks(pointer, context, back::Level(0), "")?; + + // If requested and successfully put bounds checks, continue the ternary expression. + if checked { + write!(self.out, " ? ")?; + } + + write!( + self.out, + "{NAMESPACE}::atomic_{key1}{key2}_explicit({ATOMIC_REFERENCE}" + )?; + self.put_access_chain(pointer, policy, context)?; + write!(self.out, ", ")?; + self.put_expression(value, context, true)?; + write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?; + + // Finish the ternary expression. + if checked { + write!(self.out, " : DefaultConstructible()")?; + } + + Ok(()) + } + + /// Emit code for the arithmetic expression of the dot product. + /// + fn put_dot_product( + &mut self, + arg: Handle<crate::Expression>, + arg1: Handle<crate::Expression>, + size: usize, + context: &ExpressionContext, + ) -> BackendResult { + // Write parentheses around the dot product expression to prevent operators + // with different precedences from applying earlier. + write!(self.out, "(")?; + + // Cycle trough all the components of the vector + for index in 0..size { + let component = back::COMPONENTS[index]; + // Write the addition to the previous product + // This will print an extra '+' at the beginning but that is fine in msl + write!(self.out, " + ")?; + // Write the first vector expression, this expression is marked to be + // cached so unless it can't be cached (for example, it's a Constant) + // it shouldn't produce large expressions. + self.put_expression(arg, context, true)?; + // Access the current component on the first vector + write!(self.out, ".{component} * ")?; + // Write the second vector expression, this expression is marked to be + // cached so unless it can't be cached (for example, it's a Constant) + // it shouldn't produce large expressions. + self.put_expression(arg1, context, true)?; + // Access the current component on the second vector + write!(self.out, ".{component}")?; + } + + write!(self.out, ")")?; + Ok(()) + } + + /// Emit code for the sign(i32) expression. + /// + fn put_isign( + &mut self, + arg: Handle<crate::Expression>, + context: &ExpressionContext, + ) -> BackendResult { + write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?; + match context.resolve_type(arg) { + &crate::TypeInner::Vector { size, .. } => { + let size = back::vector_size_str(size); + write!(self.out, "int{size}(-1), int{size}(1)")?; + } + _ => { + write!(self.out, "-1, 1")?; + } + } + write!(self.out, ", (")?; + self.put_expression(arg, context, true)?; + write!(self.out, " > 0)), 0, (")?; + self.put_expression(arg, context, true)?; + write!(self.out, " == 0))")?; + Ok(()) + } + + fn put_const_expression( + &mut self, + expr_handle: Handle<crate::Expression>, + module: &crate::Module, + mod_info: &valid::ModuleInfo, + ) -> BackendResult { + self.put_possibly_const_expression( + expr_handle, + &module.const_expressions, + module, + mod_info, + &(module, mod_info), + |&(_, mod_info), expr| &mod_info[expr], + |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info), + ) + } + + #[allow(clippy::too_many_arguments)] + fn put_possibly_const_expression<C, I, E>( + &mut self, + expr_handle: Handle<crate::Expression>, + expressions: &crate::Arena<crate::Expression>, + module: &crate::Module, + mod_info: &valid::ModuleInfo, + ctx: &C, + get_expr_ty: I, + put_expression: E, + ) -> BackendResult + where + I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution, + E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult, + { + match expressions[expr_handle] { + crate::Expression::Literal(literal) => match literal { + crate::Literal::F64(_) => { + return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64)) + } + crate::Literal::F32(value) => { + if value.is_infinite() { + let sign = if value.is_sign_negative() { "-" } else { "" }; + write!(self.out, "{sign}INFINITY")?; + } else if value.is_nan() { + write!(self.out, "NAN")?; + } else { + let suffix = if value.fract() == 0.0 { ".0" } else { "" }; + write!(self.out, "{value}{suffix}")?; + } + } + crate::Literal::U32(value) => { + write!(self.out, "{value}u")?; + } + crate::Literal::I32(value) => { + write!(self.out, "{value}")?; + } + crate::Literal::I64(value) => { + write!(self.out, "{value}L")?; + } + crate::Literal::Bool(value) => { + write!(self.out, "{value}")?; + } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Validation); + } + }, + crate::Expression::Constant(handle) => { + let constant = &module.constants[handle]; + if constant.name.is_some() { + write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; + } else { + self.put_const_expression(constant.init, module, mod_info)?; + } + } + crate::Expression::ZeroValue(ty) => { + let ty_name = TypeContext { + handle: ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + write!(self.out, "{ty_name} {{}}")?; + } + crate::Expression::Compose { ty, ref components } => { + let ty_name = TypeContext { + handle: ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + write!(self.out, "{ty_name}")?; + match module.types[ty].inner { + crate::TypeInner::Scalar(_) + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } => { + self.put_call_parameters_impl( + components.iter().copied(), + ctx, + put_expression, + )?; + } + crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => { + write!(self.out, " {{")?; + for (index, &component) in components.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + // insert padding initialization, if needed + if self.struct_member_pads.contains(&(ty, index as u32)) { + write!(self.out, "{{}}, ")?; + } + put_expression(self, ctx, component)?; + } + write!(self.out, "}}")?; + } + _ => return Err(Error::UnsupportedCompose(ty)), + } + } + crate::Expression::Splat { size, value } => { + let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) { + crate::TypeInner::Scalar(scalar) => scalar, + _ => return Err(Error::Validation), + }; + put_numeric_type(&mut self.out, scalar, &[size])?; + write!(self.out, "(")?; + put_expression(self, ctx, value)?; + write!(self.out, ")")?; + } + _ => unreachable!(), + } + + Ok(()) + } + + /// Emit code for the expression `expr_handle`. + /// + /// The `is_scoped` argument is true if the surrounding operators have the + /// precedence of the comma operator, or lower. So, for example: + /// + /// - Pass `true` for `is_scoped` when writing function arguments, an + /// expression statement, an initializer expression, or anything already + /// wrapped in parenthesis. + /// + /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really + /// almost anything else. + fn put_expression( + &mut self, + expr_handle: Handle<crate::Expression>, + context: &ExpressionContext, + is_scoped: bool, + ) -> BackendResult { + // Add to the set in order to track the stack size. + #[cfg(test)] + #[allow(trivial_casts)] + self.put_expression_stack_pointers + .insert(&expr_handle as *const _ as *const ()); + + if let Some(name) = self.named_expressions.get(&expr_handle) { + write!(self.out, "{name}")?; + return Ok(()); + } + + let expression = &context.function.expressions[expr_handle]; + log::trace!("expression {:?} = {:?}", expr_handle, expression); + match *expression { + crate::Expression::Literal(_) + | crate::Expression::Constant(_) + | crate::Expression::ZeroValue(_) + | crate::Expression::Compose { .. } + | crate::Expression::Splat { .. } => { + self.put_possibly_const_expression( + expr_handle, + &context.function.expressions, + context.module, + context.mod_info, + context, + |context, expr: Handle<crate::Expression>| &context.info[expr].ty, + |writer, context, expr| writer.put_expression(expr, context, true), + )?; + } + crate::Expression::Access { base, .. } + | crate::Expression::AccessIndex { base, .. } => { + // This is an acceptable place to generate a `ReadZeroSkipWrite` check. + // Since `put_bounds_checks` and `put_access_chain` handle an entire + // access chain at a time, recursing back through `put_expression` only + // for index expressions and the base object, we will never see intermediate + // `Access` or `AccessIndex` expressions here. + let policy = context.choose_bounds_check_policy(base); + if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite + && self.put_bounds_checks( + expr_handle, + context, + back::Level(0), + if is_scoped { "" } else { "(" }, + )? + { + write!(self.out, " ? ")?; + self.put_access_chain(expr_handle, policy, context)?; + write!(self.out, " : DefaultConstructible()")?; + + if !is_scoped { + write!(self.out, ")")?; + } + } else { + self.put_access_chain(expr_handle, policy, context)?; + } + } + crate::Expression::Swizzle { + size, + vector, + pattern, + } => { + self.put_wrapped_expression_for_packed_vec3_access(vector, context, false)?; + write!(self.out, ".")?; + for &sc in pattern[..size as usize].iter() { + write!(self.out, "{}", back::COMPONENTS[sc as usize])?; + } + } + crate::Expression::FunctionArgument(index) => { + let name_key = match context.origin { + FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index), + FunctionOrigin::EntryPoint(ep_index) => { + NameKey::EntryPointArgument(ep_index, index) + } + }; + let name = &self.names[&name_key]; + write!(self.out, "{name}")?; + } + crate::Expression::GlobalVariable(handle) => { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{name}")?; + } + crate::Expression::LocalVariable(handle) => { + let name_key = match context.origin { + FunctionOrigin::Handle(fun_handle) => { + NameKey::FunctionLocal(fun_handle, handle) + } + FunctionOrigin::EntryPoint(ep_index) => { + NameKey::EntryPointLocal(ep_index, handle) + } + }; + let name = &self.names[&name_key]; + write!(self.out, "{name}")?; + } + crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?, + crate::Expression::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + let main_op = match gather { + Some(_) => "gather", + None => "sample", + }; + let comparison_op = match depth_ref { + Some(_) => "_compare", + None => "", + }; + self.put_expression(image, context, false)?; + write!(self.out, ".{main_op}{comparison_op}(")?; + self.put_expression(sampler, context, true)?; + write!(self.out, ", ")?; + self.put_expression(coordinate, context, true)?; + if let Some(expr) = array_index { + write!(self.out, ", ")?; + self.put_expression(expr, context, true)?; + } + if let Some(dref) = depth_ref { + write!(self.out, ", ")?; + self.put_expression(dref, context, true)?; + } + + self.put_image_sample_level(image, level, context)?; + + if let Some(offset) = offset { + write!(self.out, ", ")?; + self.put_const_expression(offset, context.module, context.mod_info)?; + } + + match gather { + None | Some(crate::SwizzleComponent::X) => {} + Some(component) => { + let is_cube_map = match *context.resolve_type(image) { + crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + .. + } => true, + _ => false, + }; + // Offset always comes before the gather, except + // in cube maps where it's not applicable + if offset.is_none() && !is_cube_map { + write!(self.out, ", {NAMESPACE}::int2(0)")?; + } + let letter = back::COMPONENTS[component as usize]; + write!(self.out, ", {NAMESPACE}::component::{letter}")?; + } + } + write!(self.out, ")")?; + } + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + let address = TexelAddress { + coordinate, + array_index, + sample, + level: level.map(LevelOfDetail::Direct), + }; + self.put_image_load(expr_handle, image, address, context)?; + } + //Note: for all the queries, the signed integers are expected, + // so a conversion is needed. + crate::Expression::ImageQuery { image, query } => match query { + crate::ImageQuery::Size { level } => { + self.put_image_size_query( + image, + level.map(LevelOfDetail::Direct), + crate::ScalarKind::Uint, + context, + )?; + } + crate::ImageQuery::NumLevels => { + self.put_expression(image, context, false)?; + write!(self.out, ".get_num_mip_levels()")?; + } + crate::ImageQuery::NumLayers => { + self.put_expression(image, context, false)?; + write!(self.out, ".get_array_size()")?; + } + crate::ImageQuery::NumSamples => { + self.put_expression(image, context, false)?; + write!(self.out, ".get_num_samples()")?; + } + }, + crate::Expression::Unary { op, expr } => { + let op_str = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::LogicalNot => "!", + crate::UnaryOperator::BitwiseNot => "~", + }; + write!(self.out, "{op_str}(")?; + self.put_expression(expr, context, false)?; + write!(self.out, ")")?; + } + crate::Expression::Binary { op, left, right } => { + let op_str = crate::back::binary_operation_str(op); + let kind = context + .resolve_type(left) + .scalar_kind() + .ok_or(Error::UnsupportedBinaryOp(op))?; + + // TODO: handle undefined behavior of BinaryOperator::Modulo + // + // sint: + // if right == 0 return 0 + // if left == min(type_of(left)) && right == -1 return 0 + // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL + // + // uint: + // if right == 0 return 0 + // + // float: + // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798 + + if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float { + write!(self.out, "{NAMESPACE}::fmod(")?; + self.put_expression(left, context, true)?; + write!(self.out, ", ")?; + self.put_expression(right, context, true)?; + write!(self.out, ")")?; + } else { + if !is_scoped { + write!(self.out, "(")?; + } + + // Cast packed vector if necessary + // Packed vector - matrix multiplications are not supported in MSL + if op == crate::BinaryOperator::Multiply + && matches!( + context.resolve_type(right), + &crate::TypeInner::Matrix { .. } + ) + { + self.put_wrapped_expression_for_packed_vec3_access(left, context, false)?; + } else { + self.put_expression(left, context, false)?; + } + + write!(self.out, " {op_str} ")?; + + // See comment above + if op == crate::BinaryOperator::Multiply + && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. }) + { + self.put_wrapped_expression_for_packed_vec3_access(right, context, false)?; + } else { + self.put_expression(right, context, false)?; + } + + if !is_scoped { + write!(self.out, ")")?; + } + } + } + crate::Expression::Select { + condition, + accept, + reject, + } => match *context.resolve_type(condition) { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + .. + }) => { + if !is_scoped { + write!(self.out, "(")?; + } + self.put_expression(condition, context, false)?; + write!(self.out, " ? ")?; + self.put_expression(accept, context, false)?; + write!(self.out, " : ")?; + self.put_expression(reject, context, false)?; + if !is_scoped { + write!(self.out, ")")?; + } + } + crate::TypeInner::Vector { + scalar: + crate::Scalar { + kind: crate::ScalarKind::Bool, + .. + }, + .. + } => { + write!(self.out, "{NAMESPACE}::select(")?; + self.put_expression(reject, context, true)?; + write!(self.out, ", ")?; + self.put_expression(accept, context, true)?; + write!(self.out, ", ")?; + self.put_expression(condition, context, true)?; + write!(self.out, ")")?; + } + _ => return Err(Error::Validation), + }, + crate::Expression::Derivative { axis, expr, .. } => { + use crate::DerivativeAxis as Axis; + let op = match axis { + Axis::X => "dfdx", + Axis::Y => "dfdy", + Axis::Width => "fwidth", + }; + write!(self.out, "{NAMESPACE}::{op}")?; + self.put_call_parameters(iter::once(expr), context)?; + } + crate::Expression::Relational { fun, argument } => { + let op = match fun { + crate::RelationalFunction::Any => "any", + crate::RelationalFunction::All => "all", + crate::RelationalFunction::IsNan => "isnan", + crate::RelationalFunction::IsInf => "isinf", + }; + write!(self.out, "{NAMESPACE}::{op}")?; + self.put_call_parameters(iter::once(argument), context)?; + } + crate::Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + + let arg_type = context.resolve_type(arg); + let scalar_argument = match arg_type { + &crate::TypeInner::Scalar(_) => true, + _ => false, + }; + + let fun_name = match fun { + // comparison + Mf::Abs => "abs", + Mf::Min => "min", + Mf::Max => "max", + Mf::Clamp => "clamp", + Mf::Saturate => "saturate", + // trigonometry + Mf::Cos => "cos", + Mf::Cosh => "cosh", + Mf::Sin => "sin", + Mf::Sinh => "sinh", + Mf::Tan => "tan", + Mf::Tanh => "tanh", + Mf::Acos => "acos", + Mf::Asin => "asin", + Mf::Atan => "atan", + Mf::Atan2 => "atan2", + Mf::Asinh => "asinh", + Mf::Acosh => "acosh", + Mf::Atanh => "atanh", + Mf::Radians => "", + Mf::Degrees => "", + // decomposition + Mf::Ceil => "ceil", + Mf::Floor => "floor", + Mf::Round => "rint", + Mf::Fract => "fract", + Mf::Trunc => "trunc", + Mf::Modf => MODF_FUNCTION, + Mf::Frexp => FREXP_FUNCTION, + Mf::Ldexp => "ldexp", + // exponent + Mf::Exp => "exp", + Mf::Exp2 => "exp2", + Mf::Log => "log", + Mf::Log2 => "log2", + Mf::Pow => "pow", + // geometry + Mf::Dot => match *context.resolve_type(arg) { + crate::TypeInner::Vector { + scalar: + crate::Scalar { + kind: crate::ScalarKind::Float, + .. + }, + .. + } => "dot", + crate::TypeInner::Vector { size, .. } => { + return self.put_dot_product(arg, arg1.unwrap(), size as usize, context) + } + _ => unreachable!( + "Correct TypeInner for dot product should be already validated" + ), + }, + Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))), + Mf::Cross => "cross", + Mf::Distance => "distance", + Mf::Length if scalar_argument => "abs", + Mf::Length => "length", + Mf::Normalize => "normalize", + Mf::FaceForward => "faceforward", + Mf::Reflect => "reflect", + Mf::Refract => "refract", + // computational + Mf::Sign => match arg_type.scalar_kind() { + Some(crate::ScalarKind::Sint) => { + return self.put_isign(arg, context); + } + _ => "sign", + }, + Mf::Fma => "fma", + Mf::Mix => "mix", + Mf::Step => "step", + Mf::SmoothStep => "smoothstep", + Mf::Sqrt => "sqrt", + Mf::InverseSqrt => "rsqrt", + Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))), + Mf::Transpose => "transpose", + Mf::Determinant => "determinant", + // bits + Mf::CountTrailingZeros => "ctz", + Mf::CountLeadingZeros => "clz", + Mf::CountOneBits => "popcount", + Mf::ReverseBits => "reverse_bits", + Mf::ExtractBits => "extract_bits", + Mf::InsertBits => "insert_bits", + Mf::FindLsb => "", + Mf::FindMsb => "", + // data packing + Mf::Pack4x8snorm => "pack_float_to_snorm4x8", + Mf::Pack4x8unorm => "pack_float_to_unorm4x8", + Mf::Pack2x16snorm => "pack_float_to_snorm2x16", + Mf::Pack2x16unorm => "pack_float_to_unorm2x16", + Mf::Pack2x16float => "", + // data unpacking + Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float", + Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float", + Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float", + Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float", + Mf::Unpack2x16float => "", + }; + + match fun { + Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => { + // reverse_bits is listed as requiring MSL 2.1 but that + // is a copy/paste error. Looking at previous snapshots + // on web.archive.org it's present in MSL 1.2. + // + // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html + // also talks about MSL 1.2 adding "New integer + // functions to extract, insert, and reverse bits, as + // described in Integer Functions." + if context.lang_version < (1, 2) { + return Err(Error::UnsupportedFunction(fun_name.to_string())); + } + } + _ => {} + } + + if fun == Mf::Distance && scalar_argument { + write!(self.out, "{NAMESPACE}::abs(")?; + self.put_expression(arg, context, false)?; + write!(self.out, " - ")?; + self.put_expression(arg1.unwrap(), context, false)?; + write!(self.out, ")")?; + } else if fun == Mf::FindLsb { + write!(self.out, "((({NAMESPACE}::ctz(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ") + 1) % 33) - 1)")?; + } else if fun == Mf::FindMsb { + let inner = context.resolve_type(arg); + + write!(self.out, "{NAMESPACE}::select(31 - {NAMESPACE}::clz(")?; + + if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { + write!(self.out, "{NAMESPACE}::select(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ~")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " < 0)")?; + } else { + self.put_expression(arg, context, true)?; + } + + write!(self.out, "), ")?; + + // or metal will complain that select is ambiguous + match *inner { + crate::TypeInner::Vector { size, scalar } => { + let size = back::vector_size_str(size); + if let crate::ScalarKind::Sint = scalar.kind { + write!(self.out, "int{size}")?; + } else { + write!(self.out, "uint{size}")?; + } + } + crate::TypeInner::Scalar(scalar) => { + if let crate::ScalarKind::Sint = scalar.kind { + write!(self.out, "int")?; + } else { + write!(self.out, "uint")?; + } + } + _ => (), + } + + write!(self.out, "(-1), ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " == 0 || ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " == -1)")?; + } else if fun == Mf::Unpack2x16float { + write!(self.out, "float2(as_type<half2>(")?; + self.put_expression(arg, context, false)?; + write!(self.out, "))")?; + } else if fun == Mf::Pack2x16float { + write!(self.out, "as_type<uint>(half2(")?; + self.put_expression(arg, context, false)?; + write!(self.out, "))")?; + } else if fun == Mf::Radians { + write!(self.out, "((")?; + self.put_expression(arg, context, false)?; + write!(self.out, ") * 0.017453292519943295474)")?; + } else if fun == Mf::Degrees { + write!(self.out, "((")?; + self.put_expression(arg, context, false)?; + write!(self.out, ") * 57.295779513082322865)")?; + } else if fun == Mf::Modf || fun == Mf::Frexp { + write!(self.out, "{fun_name}")?; + self.put_call_parameters(iter::once(arg), context)?; + } else { + write!(self.out, "{NAMESPACE}::{fun_name}")?; + self.put_call_parameters( + iter::once(arg).chain(arg1).chain(arg2).chain(arg3), + context, + )?; + } + } + crate::Expression::As { + expr, + kind, + convert, + } => match *context.resolve_type(expr) { + crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => { + let target_scalar = crate::Scalar { + kind, + width: convert.unwrap_or(src.width), + }; + let is_bool_cast = + kind == crate::ScalarKind::Bool || src.kind == crate::ScalarKind::Bool; + let op = match convert { + Some(w) if w == src.width || is_bool_cast => "static_cast", + Some(8) if kind == crate::ScalarKind::Float => { + return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64)) + } + Some(_) => return Err(Error::Validation), + None => "as_type", + }; + write!(self.out, "{op}<")?; + match *context.resolve_type(expr) { + crate::TypeInner::Vector { size, .. } => { + put_numeric_type(&mut self.out, target_scalar, &[size])? + } + _ => put_numeric_type(&mut self.out, target_scalar, &[])?, + }; + write!(self.out, ">(")?; + self.put_expression(expr, context, true)?; + write!(self.out, ")")?; + } + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => { + let target_scalar = crate::Scalar { + kind, + width: convert.unwrap_or(scalar.width), + }; + put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?; + write!(self.out, "(")?; + self.put_expression(expr, context, true)?; + write!(self.out, ")")?; + } + _ => return Err(Error::Validation), + }, + // has to be a named expression + crate::Expression::CallResult(_) + | crate::Expression::AtomicResult { .. } + | crate::Expression::WorkGroupUniformLoadResult { .. } + | crate::Expression::RayQueryProceedResult => { + unreachable!() + } + crate::Expression::ArrayLength(expr) => { + // Find the global to which the array belongs. + let global = match context.function.expressions[expr] { + crate::Expression::AccessIndex { base, .. } => { + match context.function.expressions[base] { + crate::Expression::GlobalVariable(handle) => handle, + _ => return Err(Error::Validation), + } + } + crate::Expression::GlobalVariable(handle) => handle, + _ => return Err(Error::Validation), + }; + + if !is_scoped { + write!(self.out, "(")?; + } + write!(self.out, "1 + ")?; + self.put_dynamic_array_max_index(global, context)?; + if !is_scoped { + write!(self.out, ")")?; + } + } + crate::Expression::RayQueryGetIntersection { query, committed } => { + if context.lang_version < (2, 4) { + return Err(Error::UnsupportedRayTracing); + } + + if !committed { + unimplemented!() + } + let ty = context.module.special_types.ray_intersection.unwrap(); + let type_name = &self.names[&NameKey::Type(ty)]; + write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?; + self.put_expression(query, context, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?; + let fields = [ + "distance", + "user_instance_id", // req Metal 2.4 + "instance_id", + "", // SBT offset + "geometry_id", + "primitive_id", + "triangle_barycentric_coord", + "triangle_front_facing", + "", // padding + "object_to_world_transform", // req Metal 2.4 + "world_to_object_transform", // req Metal 2.4 + ]; + for field in fields { + write!(self.out, ", ")?; + if field.is_empty() { + write!(self.out, "{{}}")?; + } else { + self.put_expression(query, context, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?; + } + } + write!(self.out, "}}")?; + } + } + Ok(()) + } + + /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3 + fn put_wrapped_expression_for_packed_vec3_access( + &mut self, + expr_handle: Handle<crate::Expression>, + context: &ExpressionContext, + is_scoped: bool, + ) -> BackendResult { + if let Some(scalar) = context.get_packed_vec_kind(expr_handle) { + write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?; + self.put_expression(expr_handle, context, is_scoped)?; + write!(self.out, ")")?; + } else { + self.put_expression(expr_handle, context, is_scoped)?; + } + Ok(()) + } + + /// Write a `GuardedIndex` as a Metal expression. + fn put_index( + &mut self, + index: index::GuardedIndex, + context: &ExpressionContext, + is_scoped: bool, + ) -> BackendResult { + match index { + index::GuardedIndex::Expression(expr) => { + self.put_expression(expr, context, is_scoped)? + } + index::GuardedIndex::Known(value) => write!(self.out, "{value}")?, + } + Ok(()) + } + + /// Emit an index bounds check condition for `chain`, if required. + /// + /// `chain` is a subtree of `Access` and `AccessIndex` expressions, + /// operating either on a pointer to a value, or on a value directly. If we cannot + /// statically determine that all indexing operations in `chain` are within + /// bounds, then write a conditional expression to check them dynamically, + /// and return true. All accesses in the chain are checked by the generated + /// expression. + /// + /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`]. + /// + /// The text written is of the form: + /// + /// ```ignore + /// {level}{prefix}uint(i) < 4 && uint(j) < 10 + /// ``` + /// + /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`] + /// statements, presumably these arguments start an indented `if` statement; for + /// [`Load`] expressions, the caller is probably building up a ternary `?:` + /// expression. In either case, what is written is not a complete syntactic structure + /// in its own right, and the caller will have to finish it off if we return `true`. + /// + /// If no expression is written, return false. + /// + /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy + /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite + /// [`Store`]: crate::Statement::Store + /// [`Load`]: crate::Expression::Load + #[allow(unused_variables)] + fn put_bounds_checks( + &mut self, + mut chain: Handle<crate::Expression>, + context: &ExpressionContext, + level: back::Level, + prefix: &'static str, + ) -> Result<bool, Error> { + let mut check_written = false; + + // Iterate over the access chain, handling each expression. + loop { + // Produce a `GuardedIndex`, so we can shared code between the + // `Access` and `AccessIndex` cases. + let (base, guarded_index) = match context.function.expressions[chain] { + crate::Expression::Access { base, index } => { + (base, Some(index::GuardedIndex::Expression(index))) + } + crate::Expression::AccessIndex { base, index } => { + // Don't try to check indices into structs. Validation already took + // care of them, and index::needs_guard doesn't handle that case. + let mut base_inner = context.resolve_type(base); + if let crate::TypeInner::Pointer { base, .. } = *base_inner { + base_inner = &context.module.types[base].inner; + } + match *base_inner { + crate::TypeInner::Struct { .. } => (base, None), + _ => (base, Some(index::GuardedIndex::Known(index))), + } + } + _ => break, + }; + + if let Some(index) = guarded_index { + if let Some(length) = context.access_needs_check(base, index) { + if check_written { + write!(self.out, " && ")?; + } else { + write!(self.out, "{level}{prefix}")?; + check_written = true; + } + + // Check that the index falls within bounds. Do this with a single + // comparison, by casting the index to `uint` first, so that negative + // indices become large positive values. + write!(self.out, "uint(")?; + self.put_index(index, context, true)?; + self.out.write_str(") < ")?; + match length { + index::IndexableLength::Known(value) => write!(self.out, "{value}")?, + index::IndexableLength::Dynamic => { + let global = context + .function + .originating_global(base) + .ok_or(Error::Validation)?; + write!(self.out, "1 + ")?; + self.put_dynamic_array_max_index(global, context)? + } + } + } + } + + chain = base + } + + Ok(check_written) + } + + /// Write the access chain `chain`. + /// + /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions, + /// operating either on a pointer to a value, or on a value directly. + /// + /// Generate bounds checks code only if `policy` is [`Restrict`]. The + /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so + /// that must be handled in the caller. + /// + /// Handle the entire chain, recursing back into `put_expression` only for index + /// expressions and the base expression that originates the pointer or composite value + /// being accessed. This allows `put_expression` to assume that any `Access` or + /// `AccessIndex` expressions it sees are the top of a chain, so it can emit + /// `ReadZeroSkipWrite` checks. + /// + /// [`Access`]: crate::Expression::Access + /// [`AccessIndex`]: crate::Expression::AccessIndex + /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict + /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite + fn put_access_chain( + &mut self, + chain: Handle<crate::Expression>, + policy: index::BoundsCheckPolicy, + context: &ExpressionContext, + ) -> BackendResult { + match context.function.expressions[chain] { + crate::Expression::Access { base, index } => { + let mut base_ty = context.resolve_type(base); + + // Look through any pointers to see what we're really indexing. + if let crate::TypeInner::Pointer { base, space: _ } = *base_ty { + base_ty = &context.module.types[base].inner; + } + + self.put_subscripted_access_chain( + base, + base_ty, + index::GuardedIndex::Expression(index), + policy, + context, + )?; + } + crate::Expression::AccessIndex { base, index } => { + let base_resolution = &context.info[base].ty; + let mut base_ty = base_resolution.inner_with(&context.module.types); + let mut base_ty_handle = base_resolution.handle(); + + // Look through any pointers to see what we're really indexing. + if let crate::TypeInner::Pointer { base, space: _ } = *base_ty { + base_ty = &context.module.types[base].inner; + base_ty_handle = Some(base); + } + + // Handle structs and anything else that can use `.x` syntax here, so + // `put_subscripted_access_chain` won't have to handle the absurd case of + // indexing a struct with an expression. + match *base_ty { + crate::TypeInner::Struct { .. } => { + let base_ty = base_ty_handle.unwrap(); + self.put_access_chain(base, policy, context)?; + let name = &self.names[&NameKey::StructMember(base_ty, index)]; + write!(self.out, ".{name}")?; + } + crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => { + self.put_access_chain(base, policy, context)?; + // Prior to Metal v2.1 component access for packed vectors wasn't available + // however array indexing is + if context.get_packed_vec_kind(base).is_some() { + write!(self.out, "[{index}]")?; + } else { + write!(self.out, ".{}", back::COMPONENTS[index as usize])?; + } + } + _ => { + self.put_subscripted_access_chain( + base, + base_ty, + index::GuardedIndex::Known(index), + policy, + context, + )?; + } + } + } + _ => self.put_expression(chain, context, false)?, + } + + Ok(()) + } + + /// Write a `[]`-style access of `base` by `index`. + /// + /// If `policy` is [`Restrict`], then generate code as needed to force all index + /// values within bounds. + /// + /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or + /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`] + /// removed. Our callers often already have this handy. + /// + /// This only emits `[]` expressions; it doesn't handle struct member accesses or + /// referencing vector components by name. + /// + /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict + /// [`Array`]: crate::TypeInner::Array + /// [`Vector`]: crate::TypeInner::Vector + /// [`Pointer`]: crate::TypeInner::Pointer + fn put_subscripted_access_chain( + &mut self, + base: Handle<crate::Expression>, + base_ty: &crate::TypeInner, + index: index::GuardedIndex, + policy: index::BoundsCheckPolicy, + context: &ExpressionContext, + ) -> BackendResult { + let accessing_wrapped_array = match *base_ty { + crate::TypeInner::Array { + size: crate::ArraySize::Constant(_), + .. + } => true, + _ => false, + }; + + self.put_access_chain(base, policy, context)?; + if accessing_wrapped_array { + write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?; + } + write!(self.out, "[")?; + + // Decide whether this index needs to be clamped to fall within range. + let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict { + context.access_needs_check(base, index) + } else { + None + }; + if let Some(limit) = restriction_needed { + write!(self.out, "{NAMESPACE}::min(unsigned(")?; + self.put_index(index, context, true)?; + write!(self.out, "), ")?; + match limit { + index::IndexableLength::Known(limit) => { + write!(self.out, "{}u", limit - 1)?; + } + index::IndexableLength::Dynamic => { + let global = context + .function + .originating_global(base) + .ok_or(Error::Validation)?; + self.put_dynamic_array_max_index(global, context)?; + } + } + write!(self.out, ")")?; + } else { + self.put_index(index, context, true)?; + } + + write!(self.out, "]")?; + + Ok(()) + } + + fn put_load( + &mut self, + pointer: Handle<crate::Expression>, + context: &ExpressionContext, + is_scoped: bool, + ) -> BackendResult { + // Since access chains never cross between address spaces, we can just + // check the index bounds check policy once at the top. + let policy = context.choose_bounds_check_policy(pointer); + if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite + && self.put_bounds_checks( + pointer, + context, + back::Level(0), + if is_scoped { "" } else { "(" }, + )? + { + write!(self.out, " ? ")?; + self.put_unchecked_load(pointer, policy, context)?; + write!(self.out, " : DefaultConstructible()")?; + + if !is_scoped { + write!(self.out, ")")?; + } + } else { + self.put_unchecked_load(pointer, policy, context)?; + } + + Ok(()) + } + + fn put_unchecked_load( + &mut self, + pointer: Handle<crate::Expression>, + policy: index::BoundsCheckPolicy, + context: &ExpressionContext, + ) -> BackendResult { + let is_atomic_pointer = context + .resolve_type(pointer) + .is_atomic_pointer(&context.module.types); + + if is_atomic_pointer { + write!( + self.out, + "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}" + )?; + self.put_access_chain(pointer, policy, context)?; + write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?; + } else { + // We don't do any dereferencing with `*` here as pointer arguments to functions + // are done by `&` references and not `*` pointers. These do not need to be + // dereferenced. + self.put_access_chain(pointer, policy, context)?; + } + + Ok(()) + } + + fn put_return_value( + &mut self, + level: back::Level, + expr_handle: Handle<crate::Expression>, + result_struct: Option<&str>, + context: &ExpressionContext, + ) -> BackendResult { + match result_struct { + Some(struct_name) => { + let mut has_point_size = false; + let result_ty = context.function.result.as_ref().unwrap().ty; + match context.module.types[result_ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + let tmp = "_tmp"; + write!(self.out, "{level}const auto {tmp} = ")?; + self.put_expression(expr_handle, context, true)?; + writeln!(self.out, ";")?; + write!(self.out, "{level}return {struct_name} {{")?; + + let mut is_first = true; + + for (index, member) in members.iter().enumerate() { + if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) = + member.binding + { + has_point_size = true; + if !context.pipeline_options.allow_and_force_point_size { + continue; + } + } + + let comma = if is_first { "" } else { "," }; + is_first = false; + let name = &self.names[&NameKey::StructMember(result_ty, index as u32)]; + // HACK: we are forcefully deduplicating the expression here + // to convert from a wrapped struct to a raw array, e.g. + // `float gl_ClipDistance1 [[clip_distance]] [1];`. + if let crate::TypeInner::Array { + size: crate::ArraySize::Constant(size), + .. + } = context.module.types[member.ty].inner + { + write!(self.out, "{comma} {{")?; + for j in 0..size.get() { + if j != 0 { + write!(self.out, ",")?; + } + write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?; + } + write!(self.out, "}}")?; + } else { + write!(self.out, "{comma} {tmp}.{name}")?; + } + } + } + _ => { + write!(self.out, "{level}return {struct_name} {{ ")?; + self.put_expression(expr_handle, context, true)?; + } + } + + if let FunctionOrigin::EntryPoint(ep_index) = context.origin { + let stage = context.module.entry_points[ep_index as usize].stage; + if context.pipeline_options.allow_and_force_point_size + && stage == crate::ShaderStage::Vertex + && !has_point_size + { + // point size was injected and comes last + write!(self.out, ", 1.0")?; + } + } + write!(self.out, " }}")?; + } + None => { + write!(self.out, "{level}return ")?; + self.put_expression(expr_handle, context, true)?; + } + } + writeln!(self.out, ";")?; + Ok(()) + } + + /// Helper method used to find which expressions of a given function require baking + /// + /// # Notes + /// This function overwrites the contents of `self.need_bake_expressions` + fn update_expressions_to_bake( + &mut self, + func: &crate::Function, + info: &valid::FunctionInfo, + context: &ExpressionContext, + ) { + use crate::Expression; + self.need_bake_expressions.clear(); + + for (expr_handle, expr) in func.expressions.iter() { + // Expressions whose reference count is above the + // threshold should always be stored in temporaries. + let expr_info = &info[expr_handle]; + let min_ref_count = func.expressions[expr_handle].bake_ref_count(); + if min_ref_count <= expr_info.ref_count { + self.need_bake_expressions.insert(expr_handle); + } else { + match expr_info.ty { + // force ray desc to be baked: it's used multiple times internally + TypeResolution::Handle(h) + if Some(h) == context.module.special_types.ray_desc => + { + self.need_bake_expressions.insert(expr_handle); + } + _ => {} + } + } + + if let Expression::Math { fun, arg, arg1, .. } = *expr { + match fun { + crate::MathFunction::Dot => { + // WGSL's `dot` function works on any `vecN` type, but Metal's only + // works on floating-point vectors, so we emit inline code for + // integer vector `dot` calls. But that code uses each argument `N` + // times, once for each component (see `put_dot_product`), so to + // avoid duplicated evaluation, we must bake integer operands. + + // check what kind of product this is depending + // on the resolve type of the Dot function itself + let inner = context.resolve_type(expr_handle); + if let crate::TypeInner::Scalar(scalar) = *inner { + match scalar.kind { + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + self.need_bake_expressions.insert(arg); + self.need_bake_expressions.insert(arg1.unwrap()); + } + _ => {} + } + } + } + crate::MathFunction::FindMsb => { + self.need_bake_expressions.insert(arg); + } + crate::MathFunction::Sign => { + // WGSL's `sign` function works also on signed ints, but Metal's only + // works on floating points, so we emit inline code for integer `sign` + // calls. But that code uses each argument 2 times (see `put_isign`), + // so to avoid duplicated evaluation, we must bake the argument. + let inner = context.resolve_type(expr_handle); + if inner.scalar_kind() == Some(crate::ScalarKind::Sint) { + self.need_bake_expressions.insert(arg); + } + } + _ => {} + } + } + } + } + + fn start_baking_expression( + &mut self, + handle: Handle<crate::Expression>, + context: &ExpressionContext, + name: &str, + ) -> BackendResult { + match context.info[handle].ty { + TypeResolution::Handle(ty_handle) => { + let ty_name = TypeContext { + handle: ty_handle, + gctx: context.module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + write!(self.out, "{ty_name}")?; + } + TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => { + put_numeric_type(&mut self.out, scalar, &[])?; + } + TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => { + put_numeric_type(&mut self.out, scalar, &[size])?; + } + TypeResolution::Value(crate::TypeInner::Matrix { + columns, + rows, + scalar, + }) => { + put_numeric_type(&mut self.out, scalar, &[rows, columns])?; + } + TypeResolution::Value(ref other) => { + log::warn!("Type {:?} isn't a known local", other); //TEMP! + return Err(Error::FeatureNotImplemented("weird local type".to_string())); + } + } + + //TODO: figure out the naming scheme that wouldn't collide with user names. + write!(self.out, " {name} = ")?; + + Ok(()) + } + + /// Cache a clamped level of detail value, if necessary. + /// + /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a + /// properly clamped level of detail value both in the access itself, and + /// for fetching the size of the requested MIP level, needed to clamp the + /// coordinates. To avoid recomputing this clamped level of detail, we cache + /// it in a temporary variable, as part of the [`Emit`] statement covering + /// the [`ImageLoad`] expression. + /// + /// [`ImageLoad`]: crate::Expression::ImageLoad + /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict + /// [`Emit`]: crate::Statement::Emit + fn put_cache_restricted_level( + &mut self, + load: Handle<crate::Expression>, + image: Handle<crate::Expression>, + mip_level: Option<Handle<crate::Expression>>, + indent: back::Level, + context: &StatementContext, + ) -> BackendResult { + // Does this image access actually require (or even permit) a + // level-of-detail, and does the policy require us to restrict it? + let level_of_detail = match mip_level { + Some(level) => level, + None => return Ok(()), + }; + + if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict + || !context.expression.image_needs_lod(image) + { + return Ok(()); + } + + write!( + self.out, + "{}uint {}{} = ", + indent, + CLAMPED_LOD_LOAD_PREFIX, + load.index(), + )?; + self.put_restricted_scalar_image_index( + image, + level_of_detail, + "get_num_mip_levels", + &context.expression, + )?; + writeln!(self.out, ";")?; + + Ok(()) + } + + fn put_block( + &mut self, + level: back::Level, + statements: &[crate::Statement], + context: &StatementContext, + ) -> BackendResult { + // Add to the set in order to track the stack size. + #[cfg(test)] + #[allow(trivial_casts)] + self.put_block_stack_pointers + .insert(&level as *const _ as *const ()); + + for statement in statements { + log::trace!("statement[{}] {:?}", level.0, statement); + match *statement { + crate::Statement::Emit(ref range) => { + for handle in range.clone() { + // `ImageLoad` expressions covered by the `Restrict` bounds check policy + // may need to cache a clamped version of their level-of-detail argument. + if let crate::Expression::ImageLoad { + image, + level: mip_level, + .. + } = context.expression.function.expressions[handle] + { + self.put_cache_restricted_level( + handle, image, mip_level, level, context, + )?; + } + + let ptr_class = context.expression.resolve_type(handle).pointer_space(); + let expr_name = if ptr_class.is_some() { + None // don't bake pointer expressions (just yet) + } else if let Some(name) = + context.expression.function.named_expressions.get(&handle) + { + // The `crate::Function::named_expressions` table holds + // expressions that should be saved in temporaries once they + // are `Emit`ted. We only add them to `self.named_expressions` + // when we reach the `Emit` that covers them, so that we don't + // try to use their names before we've actually initialized + // the temporary that holds them. + // + // Don't assume the names in `named_expressions` are unique, + // or even valid. Use the `Namer`. + Some(self.namer.call(name)) + } else { + // If this expression is an index that we're going to first compare + // against a limit, and then actually use as an index, then we may + // want to cache it in a temporary, to avoid evaluating it twice. + let bake = + if context.expression.guarded_indices.contains(handle.index()) { + true + } else { + self.need_bake_expressions.contains(&handle) + }; + + if bake { + Some(format!("{}{}", back::BAKE_PREFIX, handle.index())) + } else { + None + } + }; + + if let Some(name) = expr_name { + write!(self.out, "{level}")?; + self.start_baking_expression(handle, &context.expression, &name)?; + self.put_expression(handle, &context.expression, true)?; + self.named_expressions.insert(handle, name); + writeln!(self.out, ";")?; + } + } + } + crate::Statement::Block(ref block) => { + if !block.is_empty() { + writeln!(self.out, "{level}{{")?; + self.put_block(level.next(), block, context)?; + writeln!(self.out, "{level}}}")?; + } + } + crate::Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{level}if (")?; + self.put_expression(condition, &context.expression, true)?; + writeln!(self.out, ") {{")?; + self.put_block(level.next(), accept, context)?; + if !reject.is_empty() { + writeln!(self.out, "{level}}} else {{")?; + self.put_block(level.next(), reject, context)?; + } + writeln!(self.out, "{level}}}")?; + } + crate::Statement::Switch { + selector, + ref cases, + } => { + write!(self.out, "{level}switch(")?; + self.put_expression(selector, &context.expression, true)?; + writeln!(self.out, ") {{")?; + let lcase = level.next(); + for case in cases.iter() { + match case.value { + crate::SwitchValue::I32(value) => { + write!(self.out, "{lcase}case {value}:")?; + } + crate::SwitchValue::U32(value) => { + write!(self.out, "{lcase}case {value}u:")?; + } + crate::SwitchValue::Default => { + write!(self.out, "{lcase}default:")?; + } + } + + let write_block_braces = !(case.fall_through && case.body.is_empty()); + if write_block_braces { + writeln!(self.out, " {{")?; + } else { + writeln!(self.out)?; + } + + self.put_block(lcase.next(), &case.body, context)?; + if !case.fall_through + && case.body.last().map_or(true, |s| !s.is_terminator()) + { + writeln!(self.out, "{}break;", lcase.next())?; + } + + if write_block_braces { + writeln!(self.out, "{lcase}}}")?; + } + } + writeln!(self.out, "{level}}}")?; + } + crate::Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + if !continuing.is_empty() || break_if.is_some() { + let gate_name = self.namer.call("loop_init"); + writeln!(self.out, "{level}bool {gate_name} = true;")?; + writeln!(self.out, "{level}while(true) {{")?; + let lif = level.next(); + let lcontinuing = lif.next(); + writeln!(self.out, "{lif}if (!{gate_name}) {{")?; + self.put_block(lcontinuing, continuing, context)?; + if let Some(condition) = break_if { + write!(self.out, "{lcontinuing}if (")?; + self.put_expression(condition, &context.expression, true)?; + writeln!(self.out, ") {{")?; + writeln!(self.out, "{}break;", lcontinuing.next())?; + writeln!(self.out, "{lcontinuing}}}")?; + } + writeln!(self.out, "{lif}}}")?; + writeln!(self.out, "{lif}{gate_name} = false;")?; + } else { + writeln!(self.out, "{level}while(true) {{")?; + } + self.put_block(level.next(), body, context)?; + writeln!(self.out, "{level}}}")?; + } + crate::Statement::Break => { + writeln!(self.out, "{level}break;")?; + } + crate::Statement::Continue => { + writeln!(self.out, "{level}continue;")?; + } + crate::Statement::Return { + value: Some(expr_handle), + } => { + self.put_return_value( + level, + expr_handle, + context.result_struct, + &context.expression, + )?; + } + crate::Statement::Return { value: None } => { + writeln!(self.out, "{level}return;")?; + } + crate::Statement::Kill => { + writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?; + } + crate::Statement::Barrier(flags) => { + self.write_barrier(flags, level)?; + } + crate::Statement::Store { pointer, value } => { + self.put_store(pointer, value, level, context)? + } + crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + let address = TexelAddress { + coordinate, + array_index, + sample: None, + level: None, + }; + self.put_image_store(level, image, &address, value, context)? + } + crate::Statement::Call { + function, + ref arguments, + result, + } => { + write!(self.out, "{level}")?; + if let Some(expr) = result { + let name = format!("{}{}", back::BAKE_PREFIX, expr.index()); + self.start_baking_expression(expr, &context.expression, &name)?; + self.named_expressions.insert(expr, name); + } + let fun_name = &self.names[&NameKey::Function(function)]; + write!(self.out, "{fun_name}(")?; + // first, write down the actual arguments + for (i, &handle) in arguments.iter().enumerate() { + if i != 0 { + write!(self.out, ", ")?; + } + self.put_expression(handle, &context.expression, true)?; + } + // follow-up with any global resources used + let mut separate = !arguments.is_empty(); + let fun_info = &context.expression.mod_info[function]; + let mut supports_array_length = false; + for (handle, var) in context.expression.module.global_variables.iter() { + if fun_info[handle].is_empty() { + continue; + } + if var.space.needs_pass_through() { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + if separate { + write!(self.out, ", ")?; + } else { + separate = true; + } + write!(self.out, "{name}")?; + } + supports_array_length |= + needs_array_length(var.ty, &context.expression.module.types); + } + if supports_array_length { + if separate { + write!(self.out, ", ")?; + } + write!(self.out, "_buffer_sizes")?; + } + + // done + writeln!(self.out, ");")?; + } + crate::Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_baking_expression(result, &context.expression, &res_name)?; + self.named_expressions.insert(result, res_name); + match *fun { + crate::AtomicFunction::Add => { + self.put_atomic_fetch(pointer, "add", value, &context.expression)?; + } + crate::AtomicFunction::Subtract => { + self.put_atomic_fetch(pointer, "sub", value, &context.expression)?; + } + crate::AtomicFunction::And => { + self.put_atomic_fetch(pointer, "and", value, &context.expression)?; + } + crate::AtomicFunction::InclusiveOr => { + self.put_atomic_fetch(pointer, "or", value, &context.expression)?; + } + crate::AtomicFunction::ExclusiveOr => { + self.put_atomic_fetch(pointer, "xor", value, &context.expression)?; + } + crate::AtomicFunction::Min => { + self.put_atomic_fetch(pointer, "min", value, &context.expression)?; + } + crate::AtomicFunction::Max => { + self.put_atomic_fetch(pointer, "max", value, &context.expression)?; + } + crate::AtomicFunction::Exchange { compare: None } => { + self.put_atomic_operation( + pointer, + "exchange", + "", + value, + &context.expression, + )?; + } + crate::AtomicFunction::Exchange { .. } => { + return Err(Error::FeatureNotImplemented( + "atomic CompareExchange".to_string(), + )); + } + } + // done + writeln!(self.out, ";")?; + } + crate::Statement::WorkGroupUniformLoad { pointer, result } => { + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.put_load(pointer, &context.expression, true)?; + self.named_expressions.insert(result, name); + + writeln!(self.out, ";")?; + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + } + crate::Statement::RayQuery { query, ref fun } => { + if context.expression.lang_version < (2, 4) { + return Err(Error::UnsupportedRayTracing); + } + + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + //TODO: how to deal with winding? + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?; + { + let f_opaque = back::RayFlag::CULL_OPAQUE.bits(); + let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode((" + )?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?; + writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?; + } + { + let f_opaque = back::RayFlag::OPAQUE.bits(); + let f_no_opaque = back::RayFlag::NO_OPAQUE.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?; + writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?; + } + { + let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection((" + )?; + self.put_expression(descriptor, &context.expression, true)?; + writeln!(self.out, ".flags & {flag}) != 0);")?; + } + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray(" + )?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".origin, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".dir, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".tmin, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".tmax), ")?; + self.put_expression(acceleration_structure, &context.expression, true)?; + write!(self.out, ", ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".cull_mask);")?; + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?; + } + crate::RayQueryFunction::Proceed { result } => { + write!(self.out, "{level}")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?; + //TODO: actually proceed? + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?; + } + crate::RayQueryFunction::Terminate => { + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?; + } + } + } + } + } + + // un-emit expressions + //TODO: take care of loop/continuing? + for statement in statements { + if let crate::Statement::Emit(ref range) = *statement { + for handle in range.clone() { + self.named_expressions.remove(&handle); + } + } + } + Ok(()) + } + + fn put_store( + &mut self, + pointer: Handle<crate::Expression>, + value: Handle<crate::Expression>, + level: back::Level, + context: &StatementContext, + ) -> BackendResult { + let policy = context.expression.choose_bounds_check_policy(pointer); + if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite + && self.put_bounds_checks(pointer, &context.expression, level, "if (")? + { + writeln!(self.out, ") {{")?; + self.put_unchecked_store(pointer, value, policy, level.next(), context)?; + writeln!(self.out, "{level}}}")?; + } else { + self.put_unchecked_store(pointer, value, policy, level, context)?; + } + + Ok(()) + } + + fn put_unchecked_store( + &mut self, + pointer: Handle<crate::Expression>, + value: Handle<crate::Expression>, + policy: index::BoundsCheckPolicy, + level: back::Level, + context: &StatementContext, + ) -> BackendResult { + let is_atomic_pointer = context + .expression + .resolve_type(pointer) + .is_atomic_pointer(&context.expression.module.types); + + if is_atomic_pointer { + write!( + self.out, + "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}" + )?; + self.put_access_chain(pointer, policy, &context.expression)?; + write!(self.out, ", ")?; + self.put_expression(value, &context.expression, true)?; + writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?; + } else { + write!(self.out, "{level}")?; + self.put_access_chain(pointer, policy, &context.expression)?; + write!(self.out, " = ")?; + self.put_expression(value, &context.expression, true)?; + writeln!(self.out, ";")?; + } + + Ok(()) + } + + pub fn write( + &mut self, + module: &crate::Module, + info: &valid::ModuleInfo, + options: &Options, + pipeline_options: &PipelineOptions, + ) -> Result<TranslationInfo, Error> { + self.names.clear(); + self.namer.reset( + module, + super::keywords::RESERVED, + &[], + &[], + &[CLAMPED_LOD_LOAD_PREFIX], + &mut self.names, + ); + self.struct_member_pads.clear(); + + writeln!( + self.out, + "// language: metal{}.{}", + options.lang_version.0, options.lang_version.1 + )?; + writeln!(self.out, "#include <metal_stdlib>")?; + writeln!(self.out, "#include <simd/simd.h>")?; + writeln!(self.out)?; + // Work around Metal bug where `uint` is not available by default + writeln!(self.out, "using {NAMESPACE}::uint;")?; + + let mut uses_ray_query = false; + for (_, ty) in module.types.iter() { + match ty.inner { + crate::TypeInner::AccelerationStructure => { + if options.lang_version < (2, 4) { + return Err(Error::UnsupportedRayTracing); + } + } + crate::TypeInner::RayQuery => { + if options.lang_version < (2, 4) { + return Err(Error::UnsupportedRayTracing); + } + uses_ray_query = true; + } + _ => (), + } + } + + if module.special_types.ray_desc.is_some() + || module.special_types.ray_intersection.is_some() + { + if options.lang_version < (2, 4) { + return Err(Error::UnsupportedRayTracing); + } + } + + if uses_ray_query { + self.put_ray_query_type()?; + } + + if options + .bounds_check_policies + .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite) + { + self.put_default_constructible()?; + } + writeln!(self.out)?; + + { + let mut indices = vec![]; + for (handle, var) in module.global_variables.iter() { + if needs_array_length(var.ty, &module.types) { + let idx = handle.index(); + indices.push(idx); + } + } + + if !indices.is_empty() { + writeln!(self.out, "struct _mslBufferSizes {{")?; + + for idx in indices { + writeln!(self.out, "{}uint size{};", back::INDENT, idx)?; + } + + writeln!(self.out, "}};")?; + writeln!(self.out)?; + } + }; + + self.write_type_defs(module)?; + self.write_global_constants(module, info)?; + self.write_functions(module, info, options, pipeline_options) + } + + /// Write the definition for the `DefaultConstructible` class. + /// + /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to + /// produce 'zero' values for any type, including structs, arrays, and so + /// on. We could do this by emitting default constructor applications, but + /// that would entail printing the name of the type, which is more trouble + /// than you'd think. Instead, we just construct this magic C++14 class that + /// can be converted to any type that can be default constructed, using + /// template parameter inference to detect which type is needed, so we don't + /// have to figure out the name. + /// + /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite + fn put_default_constructible(&mut self) -> BackendResult { + let tab = back::INDENT; + writeln!(self.out, "struct DefaultConstructible {{")?; + writeln!(self.out, "{tab}template<typename T>")?; + writeln!(self.out, "{tab}operator T() && {{")?; + writeln!(self.out, "{tab}{tab}return T {{}};")?; + writeln!(self.out, "{tab}}}")?; + writeln!(self.out, "}};")?; + Ok(()) + } + + fn put_ray_query_type(&mut self) -> BackendResult { + let tab = back::INDENT; + writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?; + let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>"); + writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?; + writeln!( + self.out, + "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};" + )?; + writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?; + writeln!(self.out, "}};")?; + writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?; + let v_triangle = back::RayIntersectionType::Triangle as u32; + let v_bbox = back::RayIntersectionType::BoundingBox as u32; + writeln!( + self.out, + "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : " + )?; + writeln!( + self.out, + "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;" + )?; + writeln!(self.out, "}}")?; + Ok(()) + } + + fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult { + for (handle, ty) in module.types.iter() { + if !ty.needs_alias() { + continue; + } + let name = &self.names[&NameKey::Type(handle)]; + match ty.inner { + // Naga IR can pass around arrays by value, but Metal, following + // C++, performs an array-to-pointer conversion (C++ [conv.array]) + // on expressions of array type, so assigning the array by value + // isn't possible. However, Metal *does* assign structs by + // value. So in our Metal output, we wrap all array types in + // synthetic struct types: + // + // struct type1 { + // float inner[10] + // }; + // + // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in + // any expression that actually wants access to the array. + crate::TypeInner::Array { + base, + size, + stride: _, + } => { + let base_name = TypeContext { + handle: base, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + + match size { + crate::ArraySize::Constant(size) => { + writeln!(self.out, "struct {name} {{")?; + writeln!( + self.out, + "{}{} {}[{}];", + back::INDENT, + base_name, + WRAPPED_ARRAY_FIELD, + size + )?; + writeln!(self.out, "}};")?; + } + crate::ArraySize::Dynamic => { + writeln!(self.out, "typedef {base_name} {name}[1];")?; + } + } + } + crate::TypeInner::Struct { + ref members, span, .. + } => { + writeln!(self.out, "struct {name} {{")?; + let mut last_offset = 0; + for (index, member) in members.iter().enumerate() { + if member.offset > last_offset { + self.struct_member_pads.insert((handle, index as u32)); + let pad = member.offset - last_offset; + writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?; + } + let ty_inner = &module.types[member.ty].inner; + last_offset = member.offset + ty_inner.size(module.to_ctx()); + + let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; + + // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector + match should_pack_struct_member(members, span, index, module) { + Some(scalar) => { + writeln!( + self.out, + "{}{}::packed_{}3 {};", + back::INDENT, + NAMESPACE, + scalar.to_msl_name(), + member_name + )?; + } + None => { + let base_name = TypeContext { + handle: member.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + writeln!( + self.out, + "{}{} {};", + back::INDENT, + base_name, + member_name + )?; + + // for 3-component vectors, add one component + if let crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + scalar, + } = *ty_inner + { + last_offset += scalar.width as u32; + } + } + } + } + writeln!(self.out, "}};")?; + } + _ => { + let ty_name = TypeContext { + handle, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: true, + }; + writeln!(self.out, "typedef {ty_name} {name};")?; + } + } + } + + // Write functions to create special types. + for (type_key, struct_ty) in module.special_types.predeclared_types.iter() { + match type_key { + &crate::PredeclaredType::ModfResult { size, width } + | &crate::PredeclaredType::FrexpResult { size, width } => { + let arg_type_name_owner; + let arg_type_name = if let Some(size) = size { + arg_type_name_owner = format!( + "{NAMESPACE}::{}{}", + if width == 8 { "double" } else { "float" }, + size as u8 + ); + &arg_type_name_owner + } else if width == 8 { + "double" + } else { + "float" + }; + + let other_type_name_owner; + let (defined_func_name, called_func_name, other_type_name) = + if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) { + (MODF_FUNCTION, "modf", arg_type_name) + } else { + let other_type_name = if let Some(size) = size { + other_type_name_owner = format!("int{}", size as u8); + &other_type_name_owner + } else { + "int" + }; + (FREXP_FUNCTION, "frexp", other_type_name) + }; + + let struct_name = &self.names[&NameKey::Type(*struct_ty)]; + + writeln!(self.out)?; + writeln!( + self.out, + "{} {defined_func_name}({arg_type_name} arg) {{ + {other_type_name} other; + {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other); + return {}{{ fract, other }}; +}}", + struct_name, struct_name + )?; + } + &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {} + } + } + + Ok(()) + } + + /// Writes all named constants + fn write_global_constants( + &mut self, + module: &crate::Module, + mod_info: &valid::ModuleInfo, + ) -> BackendResult { + let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some()); + + for (handle, constant) in constants { + let ty_name = TypeContext { + handle: constant.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + let name = &self.names[&NameKey::Constant(handle)]; + write!(self.out, "constant {ty_name} {name} = ")?; + self.put_const_expression(constant.init, module, mod_info)?; + writeln!(self.out, ";")?; + } + + Ok(()) + } + + fn put_inline_sampler_properties( + &mut self, + level: back::Level, + sampler: &sm::InlineSampler, + ) -> BackendResult { + for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) { + writeln!( + self.out, + "{}{}::{}_address::{},", + level, + NAMESPACE, + letter, + address.as_str(), + )?; + } + writeln!( + self.out, + "{}{}::mag_filter::{},", + level, + NAMESPACE, + sampler.mag_filter.as_str(), + )?; + writeln!( + self.out, + "{}{}::min_filter::{},", + level, + NAMESPACE, + sampler.min_filter.as_str(), + )?; + if let Some(filter) = sampler.mip_filter { + writeln!( + self.out, + "{}{}::mip_filter::{},", + level, + NAMESPACE, + filter.as_str(), + )?; + } + // avoid setting it on platforms that don't support it + if sampler.border_color != sm::BorderColor::TransparentBlack { + writeln!( + self.out, + "{}{}::border_color::{},", + level, + NAMESPACE, + sampler.border_color.as_str(), + )?; + } + //TODO: I'm not able to feed this in a way that MSL likes: + //>error: use of undeclared identifier 'lod_clamp' + //>error: no member named 'max_anisotropy' in namespace 'metal' + if false { + if let Some(ref lod) = sampler.lod_clamp { + writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?; + } + if let Some(aniso) = sampler.max_anisotropy { + writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?; + } + } + if sampler.compare_func != sm::CompareFunc::Never { + writeln!( + self.out, + "{}{}::compare_func::{},", + level, + NAMESPACE, + sampler.compare_func.as_str(), + )?; + } + writeln!( + self.out, + "{}{}::coord::{}", + level, + NAMESPACE, + sampler.coord.as_str() + )?; + Ok(()) + } + + // Returns the array of mapped entry point names. + fn write_functions( + &mut self, + module: &crate::Module, + mod_info: &valid::ModuleInfo, + options: &Options, + pipeline_options: &PipelineOptions, + ) -> Result<TranslationInfo, Error> { + let mut pass_through_globals = Vec::new(); + for (fun_handle, fun) in module.functions.iter() { + log::trace!( + "function {:?}, handle {:?}", + fun.name.as_deref().unwrap_or("(anonymous)"), + fun_handle + ); + + let fun_info = &mod_info[fun_handle]; + pass_through_globals.clear(); + let mut supports_array_length = false; + for (handle, var) in module.global_variables.iter() { + if !fun_info[handle].is_empty() { + if var.space.needs_pass_through() { + pass_through_globals.push(handle); + } + supports_array_length |= needs_array_length(var.ty, &module.types); + } + } + + writeln!(self.out)?; + let fun_name = &self.names[&NameKey::Function(fun_handle)]; + match fun.result { + Some(ref result) => { + let ty_name = TypeContext { + handle: result.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + write!(self.out, "{ty_name}")?; + } + None => { + write!(self.out, "void")?; + } + } + writeln!(self.out, " {fun_name}(")?; + + for (index, arg) in fun.arguments.iter().enumerate() { + let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)]; + let param_type_name = TypeContext { + handle: arg.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + let separator = separate( + !pass_through_globals.is_empty() + || index + 1 != fun.arguments.len() + || supports_array_length, + ); + writeln!( + self.out, + "{}{} {}{}", + back::INDENT, + param_type_name, + name, + separator + )?; + } + for (index, &handle) in pass_through_globals.iter().enumerate() { + let tyvar = TypedGlobalVariable { + module, + names: &self.names, + handle, + usage: fun_info[handle], + binding: None, + reference: true, + }; + let separator = + separate(index + 1 != pass_through_globals.len() || supports_array_length); + write!(self.out, "{}", back::INDENT)?; + tyvar.try_fmt(&mut self.out)?; + writeln!(self.out, "{separator}")?; + } + + if supports_array_length { + writeln!( + self.out, + "{}constant _mslBufferSizes& _buffer_sizes", + back::INDENT + )?; + } + + writeln!(self.out, ") {{")?; + + let guarded_indices = + index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies); + + let context = StatementContext { + expression: ExpressionContext { + function: fun, + origin: FunctionOrigin::Handle(fun_handle), + info: fun_info, + lang_version: options.lang_version, + policies: options.bounds_check_policies, + guarded_indices, + module, + mod_info, + pipeline_options, + }, + result_struct: None, + }; + + for (local_handle, local) in fun.local_variables.iter() { + let ty_name = TypeContext { + handle: local.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)]; + write!(self.out, "{}{} {}", back::INDENT, ty_name, local_name)?; + match local.init { + Some(value) => { + write!(self.out, " = ")?; + self.put_expression(value, &context.expression, true)?; + } + None => { + write!(self.out, " = {{}}")?; + } + }; + writeln!(self.out, ";")?; + } + + self.update_expressions_to_bake(fun, fun_info, &context.expression); + self.put_block(back::Level(1), &fun.body, &context)?; + writeln!(self.out, "}}")?; + self.named_expressions.clear(); + } + + let mut info = TranslationInfo { + entry_point_names: Vec::with_capacity(module.entry_points.len()), + }; + for (ep_index, ep) in module.entry_points.iter().enumerate() { + let fun = &ep.function; + let fun_info = mod_info.get_entry_point(ep_index); + let mut ep_error = None; + + log::trace!( + "entry point {:?}, index {:?}", + fun.name.as_deref().unwrap_or("(anonymous)"), + ep_index + ); + + // Is any global variable used by this entry point dynamically sized? + let supports_array_length = module + .global_variables + .iter() + .filter(|&(handle, _)| !fun_info[handle].is_empty()) + .any(|(_, var)| needs_array_length(var.ty, &module.types)); + + // skip this entry point if any global bindings are missing, + // or their types are incompatible. + if !options.fake_missing_bindings { + for (var_handle, var) in module.global_variables.iter() { + if fun_info[var_handle].is_empty() { + continue; + } + match var.space { + crate::AddressSpace::Uniform + | crate::AddressSpace::Storage { .. } + | crate::AddressSpace::Handle => { + let br = match var.binding { + Some(ref br) => br, + None => { + let var_name = var.name.clone().unwrap_or_default(); + ep_error = + Some(super::EntryPointError::MissingBinding(var_name)); + break; + } + }; + let target = options.get_resource_binding_target(ep, br); + let good = match target { + Some(target) => { + let binding_ty = match module.types[var.ty].inner { + crate::TypeInner::BindingArray { base, .. } => { + &module.types[base].inner + } + ref ty => ty, + }; + match *binding_ty { + crate::TypeInner::Image { .. } => target.texture.is_some(), + crate::TypeInner::Sampler { .. } => { + target.sampler.is_some() + } + _ => target.buffer.is_some(), + } + } + None => false, + }; + if !good { + ep_error = + Some(super::EntryPointError::MissingBindTarget(br.clone())); + break; + } + } + crate::AddressSpace::PushConstant => { + if let Err(e) = options.resolve_push_constants(ep) { + ep_error = Some(e); + break; + } + } + crate::AddressSpace::Function + | crate::AddressSpace::Private + | crate::AddressSpace::WorkGroup => {} + } + } + if supports_array_length { + if let Err(err) = options.resolve_sizes_buffer(ep) { + ep_error = Some(err); + } + } + } + + if let Some(err) = ep_error { + info.entry_point_names.push(Err(err)); + continue; + } + let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)]; + info.entry_point_names.push(Ok(fun_name.clone())); + + writeln!(self.out)?; + + let (em_str, in_mode, out_mode) = match ep.stage { + crate::ShaderStage::Vertex => ( + "vertex", + LocationMode::VertexInput, + LocationMode::VertexOutput, + ), + crate::ShaderStage::Fragment { .. } => ( + "fragment", + LocationMode::FragmentInput, + LocationMode::FragmentOutput, + ), + crate::ShaderStage::Compute { .. } => { + ("kernel", LocationMode::Uniform, LocationMode::Uniform) + } + }; + + // Since `Namer.reset` wasn't expecting struct members to be + // suddenly injected into another namespace like this, + // `self.names` doesn't keep them distinct from other variables. + // Generate fresh names for these arguments, and remember the + // mapping. + let mut flattened_member_names = FastHashMap::default(); + // Varyings' members get their own namespace + let mut varyings_namer = crate::proc::Namer::default(); + + // List all the Naga `EntryPoint`'s `Function`'s arguments, + // flattening structs into their members. In Metal, we will pass + // each of these values to the entry point as a separate argument— + // except for the varyings, handled next. + let mut flattened_arguments = Vec::new(); + for (arg_index, arg) in fun.arguments.iter().enumerate() { + match module.types[arg.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + for (member_index, member) in members.iter().enumerate() { + let member_index = member_index as u32; + flattened_arguments.push(( + NameKey::StructMember(arg.ty, member_index), + member.ty, + member.binding.as_ref(), + )); + let name_key = NameKey::StructMember(arg.ty, member_index); + let name = match member.binding { + Some(crate::Binding::Location { .. }) => { + varyings_namer.call(&self.names[&name_key]) + } + _ => self.namer.call(&self.names[&name_key]), + }; + flattened_member_names.insert(name_key, name); + } + } + _ => flattened_arguments.push(( + NameKey::EntryPointArgument(ep_index as _, arg_index as u32), + arg.ty, + arg.binding.as_ref(), + )), + } + } + + // Identify the varyings among the argument values, and emit a + // struct type named `<fun>Input` to hold them. + let stage_in_name = format!("{fun_name}Input"); + let varyings_member_name = self.namer.call("varyings"); + let mut has_varyings = false; + if !flattened_arguments.is_empty() { + writeln!(self.out, "struct {stage_in_name} {{")?; + for &(ref name_key, ty, binding) in flattened_arguments.iter() { + let binding = match binding { + Some(ref binding @ &crate::Binding::Location { .. }) => binding, + _ => continue, + }; + has_varyings = true; + let name = match *name_key { + NameKey::StructMember(..) => &flattened_member_names[name_key], + _ => &self.names[name_key], + }; + let ty_name = TypeContext { + handle: ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + let resolved = options.resolve_local_binding(binding, in_mode)?; + write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?; + resolved.try_fmt(&mut self.out)?; + writeln!(self.out, ";")?; + } + writeln!(self.out, "}};")?; + } + + // Define a struct type named for the return value, if any, named + // `<fun>Output`. + let stage_out_name = format!("{fun_name}Output"); + let result_member_name = self.namer.call("member"); + let result_type_name = match fun.result { + Some(ref result) => { + let mut result_members = Vec::new(); + if let crate::TypeInner::Struct { ref members, .. } = + module.types[result.ty].inner + { + for (member_index, member) in members.iter().enumerate() { + result_members.push(( + &self.names[&NameKey::StructMember(result.ty, member_index as u32)], + member.ty, + member.binding.as_ref(), + )); + } + } else { + result_members.push(( + &result_member_name, + result.ty, + result.binding.as_ref(), + )); + } + + writeln!(self.out, "struct {stage_out_name} {{")?; + let mut has_point_size = false; + for (name, ty, binding) in result_members { + let ty_name = TypeContext { + handle: ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: true, + }; + let binding = binding.ok_or(Error::Validation)?; + + if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding { + has_point_size = true; + if !pipeline_options.allow_and_force_point_size { + continue; + } + } + + let array_len = match module.types[ty].inner { + crate::TypeInner::Array { + size: crate::ArraySize::Constant(size), + .. + } => Some(size), + _ => None, + }; + let resolved = options.resolve_local_binding(binding, out_mode)?; + write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?; + if let Some(array_len) = array_len { + write!(self.out, " [{array_len}]")?; + } + resolved.try_fmt(&mut self.out)?; + writeln!(self.out, ";")?; + } + + if pipeline_options.allow_and_force_point_size + && ep.stage == crate::ShaderStage::Vertex + && !has_point_size + { + // inject the point size output last + writeln!( + self.out, + "{}float _point_size [[point_size]];", + back::INDENT + )?; + } + writeln!(self.out, "}};")?; + &stage_out_name + } + None => "void", + }; + + // Write the entry point function's name, and begin its argument list. + writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?; + let mut is_first_argument = true; + + // If we have produced a struct holding the `EntryPoint`'s + // `Function`'s arguments' varyings, pass that struct first. + if has_varyings { + writeln!( + self.out, + " {stage_in_name} {varyings_member_name} [[stage_in]]" + )?; + is_first_argument = false; + } + + let mut local_invocation_id = None; + + // Then pass the remaining arguments not included in the varyings + // struct. + for &(ref name_key, ty, binding) in flattened_arguments.iter() { + let binding = match binding { + Some(binding @ &crate::Binding::BuiltIn { .. }) => binding, + _ => continue, + }; + let name = match *name_key { + NameKey::StructMember(..) => &flattened_member_names[name_key], + _ => &self.names[name_key], + }; + + if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) { + local_invocation_id = Some(name_key); + } + + let ty_name = TypeContext { + handle: ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + let resolved = options.resolve_local_binding(binding, in_mode)?; + let separator = if is_first_argument { + is_first_argument = false; + ' ' + } else { + ',' + }; + write!(self.out, "{separator} {ty_name} {name}")?; + resolved.try_fmt(&mut self.out)?; + writeln!(self.out)?; + } + + let need_workgroup_variables_initialization = + self.need_workgroup_variables_initialization(options, ep, module, fun_info); + + if need_workgroup_variables_initialization && local_invocation_id.is_none() { + let separator = if is_first_argument { + is_first_argument = false; + ' ' + } else { + ',' + }; + writeln!( + self.out, + "{separator} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]" + )?; + } + + // Those global variables used by this entry point and its callees + // get passed as arguments. `Private` globals are an exception, they + // don't outlive this invocation, so we declare them below as locals + // within the entry point. + for (handle, var) in module.global_variables.iter() { + let usage = fun_info[handle]; + if usage.is_empty() || var.space == crate::AddressSpace::Private { + continue; + } + + if options.lang_version < (1, 2) { + match var.space { + // This restriction is not documented in the MSL spec + // but validation will fail if it is not upheld. + // + // We infer the required version from the "Function + // Buffer Read-Writes" section of [what's new], where + // the feature sets listed correspond with the ones + // supporting MSL 1.2. + // + // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html + crate::AddressSpace::Storage { access } + if access.contains(crate::StorageAccess::STORE) + && ep.stage == crate::ShaderStage::Fragment => + { + return Err(Error::UnsupportedWriteableStorageBuffer) + } + crate::AddressSpace::Handle => { + match module.types[var.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. }, + .. + } => { + // This restriction is not documented in the MSL spec + // but validation will fail if it is not upheld. + // + // We infer the required version from the "Function + // Texture Read-Writes" section of [what's new], where + // the feature sets listed correspond with the ones + // supporting MSL 1.2. + // + // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html + if access.contains(crate::StorageAccess::STORE) + && (ep.stage == crate::ShaderStage::Vertex + || ep.stage == crate::ShaderStage::Fragment) + { + return Err(Error::UnsupportedWriteableStorageTexture( + ep.stage, + )); + } + + if access.contains( + crate::StorageAccess::LOAD | crate::StorageAccess::STORE, + ) { + return Err(Error::UnsupportedRWStorageTexture); + } + } + _ => {} + } + } + _ => {} + } + } + + // Check min MSL version for binding arrays + match var.space { + crate::AddressSpace::Handle => match module.types[var.ty].inner { + crate::TypeInner::BindingArray { base, .. } => { + match module.types[base].inner { + crate::TypeInner::Sampler { .. } => { + if options.lang_version < (2, 0) { + return Err(Error::UnsupportedArrayOf( + "samplers".to_string(), + )); + } + } + crate::TypeInner::Image { class, .. } => match class { + crate::ImageClass::Sampled { .. } + | crate::ImageClass::Depth { .. } + | crate::ImageClass::Storage { + access: crate::StorageAccess::LOAD, + .. + } => { + // Array of textures since: + // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164) + // - macOS: Metal 2 + + if options.lang_version < (2, 0) { + return Err(Error::UnsupportedArrayOf( + "textures".to_string(), + )); + } + } + crate::ImageClass::Storage { + access: crate::StorageAccess::STORE, + .. + } => { + // Array of write-only textures since: + // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164) + // - macOS: Metal 2 + + if options.lang_version < (2, 0) { + return Err(Error::UnsupportedArrayOf( + "write-only textures".to_string(), + )); + } + } + crate::ImageClass::Storage { .. } => { + return Err(Error::UnsupportedArrayOf( + "read-write textures".to_string(), + )); + } + }, + _ => { + return Err(Error::UnsupportedArrayOfType(base)); + } + } + } + _ => {} + }, + _ => {} + } + + // the resolves have already been checked for `!fake_missing_bindings` case + let resolved = match var.space { + crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(), + crate::AddressSpace::WorkGroup => None, + _ => options + .resolve_resource_binding(ep, var.binding.as_ref().unwrap()) + .ok(), + }; + if let Some(ref resolved) = resolved { + // Inline samplers are be defined in the EP body + if resolved.as_inline_sampler(options).is_some() { + continue; + } + } + + let tyvar = TypedGlobalVariable { + module, + names: &self.names, + handle, + usage, + binding: resolved.as_ref(), + reference: true, + }; + let separator = if is_first_argument { + is_first_argument = false; + ' ' + } else { + ',' + }; + write!(self.out, "{separator} ")?; + tyvar.try_fmt(&mut self.out)?; + if let Some(resolved) = resolved { + resolved.try_fmt(&mut self.out)?; + } + if let Some(value) = var.init { + write!(self.out, " = ")?; + self.put_const_expression(value, module, mod_info)?; + } + writeln!(self.out)?; + } + + // If this entry uses any variable-length arrays, their sizes are + // passed as a final struct-typed argument. + if supports_array_length { + // this is checked earlier + let resolved = options.resolve_sizes_buffer(ep).unwrap(); + let separator = if module.global_variables.is_empty() { + ' ' + } else { + ',' + }; + write!( + self.out, + "{separator} constant _mslBufferSizes& _buffer_sizes", + )?; + resolved.try_fmt(&mut self.out)?; + writeln!(self.out)?; + } + + // end of the entry point argument list + writeln!(self.out, ") {{")?; + + if need_workgroup_variables_initialization { + self.write_workgroup_variables_initialization( + module, + mod_info, + fun_info, + local_invocation_id, + )?; + } + + // Metal doesn't support private mutable variables outside of functions, + // so we put them here, just like the locals. + for (handle, var) in module.global_variables.iter() { + let usage = fun_info[handle]; + if usage.is_empty() { + continue; + } + if var.space == crate::AddressSpace::Private { + let tyvar = TypedGlobalVariable { + module, + names: &self.names, + handle, + usage, + binding: None, + reference: false, + }; + write!(self.out, "{}", back::INDENT)?; + tyvar.try_fmt(&mut self.out)?; + match var.init { + Some(value) => { + write!(self.out, " = ")?; + self.put_const_expression(value, module, mod_info)?; + writeln!(self.out, ";")?; + } + None => { + writeln!(self.out, " = {{}};")?; + } + }; + } else if let Some(ref binding) = var.binding { + // write an inline sampler + let resolved = options.resolve_resource_binding(ep, binding).unwrap(); + if let Some(sampler) = resolved.as_inline_sampler(options) { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + writeln!( + self.out, + "{}constexpr {}::sampler {}(", + back::INDENT, + NAMESPACE, + name + )?; + self.put_inline_sampler_properties(back::Level(2), sampler)?; + writeln!(self.out, "{});", back::INDENT)?; + } + } + } + + // Now take the arguments that we gathered into structs, and the + // structs that we flattened into arguments, and emit local + // variables with initializers that put everything back the way the + // body code expects. + // + // If we had to generate fresh names for struct members passed as + // arguments, be sure to use those names when rebuilding the struct. + // + // "Each day, I change some zeros to ones, and some ones to zeros. + // The rest, I leave alone." + for (arg_index, arg) in fun.arguments.iter().enumerate() { + let arg_name = + &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)]; + match module.types[arg.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + let struct_name = &self.names[&NameKey::Type(arg.ty)]; + write!( + self.out, + "{}const {} {} = {{ ", + back::INDENT, + struct_name, + arg_name + )?; + for (member_index, member) in members.iter().enumerate() { + let key = NameKey::StructMember(arg.ty, member_index as u32); + let name = &flattened_member_names[&key]; + if member_index != 0 { + write!(self.out, ", ")?; + } + // insert padding initialization, if needed + if self + .struct_member_pads + .contains(&(arg.ty, member_index as u32)) + { + write!(self.out, "{{}}, ")?; + } + if let Some(crate::Binding::Location { .. }) = member.binding { + write!(self.out, "{varyings_member_name}.")?; + } + write!(self.out, "{name}")?; + } + writeln!(self.out, " }};")?; + } + _ => { + if let Some(crate::Binding::Location { .. }) = arg.binding { + writeln!( + self.out, + "{}const auto {} = {}.{};", + back::INDENT, + arg_name, + varyings_member_name, + arg_name + )?; + } + } + } + } + + let guarded_indices = + index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies); + + let context = StatementContext { + expression: ExpressionContext { + function: fun, + origin: FunctionOrigin::EntryPoint(ep_index as _), + info: fun_info, + lang_version: options.lang_version, + policies: options.bounds_check_policies, + guarded_indices, + module, + mod_info, + pipeline_options, + }, + result_struct: Some(&stage_out_name), + }; + + // Finally, declare all the local variables that we need + //TODO: we can postpone this till the relevant expressions are emitted + for (local_handle, local) in fun.local_variables.iter() { + let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)]; + let ty_name = TypeContext { + handle: local.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?; + match local.init { + Some(value) => { + write!(self.out, " = ")?; + self.put_expression(value, &context.expression, true)?; + } + None => { + write!(self.out, " = {{}}")?; + } + }; + writeln!(self.out, ";")?; + } + + self.update_expressions_to_bake(fun, fun_info, &context.expression); + self.put_block(back::Level(1), &fun.body, &context)?; + writeln!(self.out, "}}")?; + if ep_index + 1 != module.entry_points.len() { + writeln!(self.out)?; + } + self.named_expressions.clear(); + } + + Ok(info) + } + + fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult { + // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`, + // so we try to avoid it here. + if flags.is_empty() { + writeln!( + self.out, + "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);", + )?; + } + if flags.contains(crate::Barrier::STORAGE) { + writeln!( + self.out, + "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);", + )?; + } + if flags.contains(crate::Barrier::WORK_GROUP) { + writeln!( + self.out, + "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", + )?; + } + Ok(()) + } +} + +/// Initializing workgroup variables is more tricky for Metal because we have to deal +/// with atomics at the type-level (which don't have a copy constructor). +mod workgroup_mem_init { + use crate::EntryPoint; + + use super::*; + + enum Access { + GlobalVariable(Handle<crate::GlobalVariable>), + StructMember(Handle<crate::Type>, u32), + Array(usize), + } + + impl Access { + fn write<W: Write>( + &self, + writer: &mut W, + names: &FastHashMap<NameKey, String>, + ) -> Result<(), core::fmt::Error> { + match *self { + Access::GlobalVariable(handle) => { + write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)]) + } + Access::StructMember(handle, index) => { + write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)]) + } + Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"), + } + } + } + + struct AccessStack { + stack: Vec<Access>, + array_depth: usize, + } + + impl AccessStack { + const fn new() -> Self { + Self { + stack: Vec::new(), + array_depth: 0, + } + } + + fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R { + let array_depth = self.array_depth; + self.stack.push(Access::Array(array_depth)); + self.array_depth += 1; + let res = cb(self, array_depth); + self.stack.pop(); + self.array_depth -= 1; + res + } + + fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R { + self.stack.push(new); + let res = cb(self); + self.stack.pop(); + res + } + + fn write<W: Write>( + &self, + writer: &mut W, + names: &FastHashMap<NameKey, String>, + ) -> Result<(), core::fmt::Error> { + for next in self.stack.iter() { + next.write(writer, names)?; + } + Ok(()) + } + } + + impl<W: Write> Writer<W> { + pub(super) fn need_workgroup_variables_initialization( + &mut self, + options: &Options, + ep: &EntryPoint, + module: &crate::Module, + fun_info: &valid::FunctionInfo, + ) -> bool { + options.zero_initialize_workgroup_memory + && ep.stage == crate::ShaderStage::Compute + && module.global_variables.iter().any(|(handle, var)| { + !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + } + + pub(super) fn write_workgroup_variables_initialization( + &mut self, + module: &crate::Module, + module_info: &valid::ModuleInfo, + fun_info: &valid::FunctionInfo, + local_invocation_id: Option<&NameKey>, + ) -> BackendResult { + let level = back::Level(1); + + writeln!( + self.out, + "{}if ({}::all({} == {}::uint3(0u))) {{", + level, + NAMESPACE, + local_invocation_id + .map(|name_key| self.names[name_key].as_str()) + .unwrap_or("__local_invocation_id"), + NAMESPACE, + )?; + + let mut access_stack = AccessStack::new(); + + let vars = module.global_variables.iter().filter(|&(handle, var)| { + !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }); + + for (handle, var) in vars { + access_stack.enter(Access::GlobalVariable(handle), |access_stack| { + self.write_workgroup_variable_initialization( + module, + module_info, + var.ty, + access_stack, + level.next(), + ) + })?; + } + + writeln!(self.out, "{level}}}")?; + self.write_barrier(crate::Barrier::WORK_GROUP, level) + } + + fn write_workgroup_variable_initialization( + &mut self, + module: &crate::Module, + module_info: &valid::ModuleInfo, + ty: Handle<crate::Type>, + access_stack: &mut AccessStack, + level: back::Level, + ) -> BackendResult { + if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) { + write!(self.out, "{level}")?; + access_stack.write(&mut self.out, &self.names)?; + writeln!(self.out, " = {{}};")?; + } else { + match module.types[ty].inner { + crate::TypeInner::Atomic { .. } => { + write!( + self.out, + "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}" + )?; + access_stack.write(&mut self.out, &self.names)?; + writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?; + } + crate::TypeInner::Array { base, size, .. } => { + let count = match size.to_indexable_length(module).expect("Bad array size") + { + proc::IndexableLength::Known(count) => count, + proc::IndexableLength::Dynamic => unreachable!(), + }; + + access_stack.enter_array(|access_stack, array_depth| { + writeln!( + self.out, + "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{" + )?; + self.write_workgroup_variable_initialization( + module, + module_info, + base, + access_stack, + level.next(), + )?; + writeln!(self.out, "{level}}}")?; + BackendResult::Ok(()) + })?; + } + crate::TypeInner::Struct { ref members, .. } => { + for (index, member) in members.iter().enumerate() { + access_stack.enter( + Access::StructMember(ty, index as u32), + |access_stack| { + self.write_workgroup_variable_initialization( + module, + module_info, + member.ty, + access_stack, + level, + ) + }, + )?; + } + } + _ => unreachable!(), + } + } + + Ok(()) + } + } +} + +#[test] +fn test_stack_size() { + use crate::valid::{Capabilities, ValidationFlags}; + // create a module with at least one expression nested + let mut module = crate::Module::default(); + let mut fun = crate::Function::default(); + let const_expr = fun.expressions.append( + crate::Expression::Literal(crate::Literal::F32(1.0)), + Default::default(), + ); + let nested_expr = fun.expressions.append( + crate::Expression::Unary { + op: crate::UnaryOperator::Negate, + expr: const_expr, + }, + Default::default(), + ); + fun.body.push( + crate::Statement::Emit(fun.expressions.range_from(1)), + Default::default(), + ); + fun.body.push( + crate::Statement::If { + condition: nested_expr, + accept: crate::Block::new(), + reject: crate::Block::new(), + }, + Default::default(), + ); + let _ = module.functions.append(fun, Default::default()); + // analyse the module + let info = crate::valid::Validator::new(ValidationFlags::empty(), Capabilities::empty()) + .validate(&module) + .unwrap(); + // process the module + let mut writer = Writer::new(String::new()); + writer + .write(&module, &info, &Default::default(), &Default::default()) + .unwrap(); + + { + // check expression stack + let mut addresses_start = usize::MAX; + let mut addresses_end = 0usize; + for pointer in writer.put_expression_stack_pointers { + addresses_start = addresses_start.min(pointer as usize); + addresses_end = addresses_end.max(pointer as usize); + } + let stack_size = addresses_end - addresses_start; + // check the size (in debug only) + // last observed macOS value: 20528 (CI) + if !(11000..=25000).contains(&stack_size) { + panic!("`put_expression` stack size {stack_size} has changed!"); + } + } + + { + // check block stack + let mut addresses_start = usize::MAX; + let mut addresses_end = 0usize; + for pointer in writer.put_block_stack_pointers { + addresses_start = addresses_start.min(pointer as usize); + addresses_end = addresses_end.max(pointer as usize); + } + let stack_size = addresses_end - addresses_start; + // check the size (in debug only) + // last observed macOS value: 19152 (CI) + if !(9000..=20000).contains(&stack_size) { + panic!("`put_block` stack size {stack_size} has changed!"); + } + } +} diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs new file mode 100644 index 0000000000..6c96fa09e3 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/block.rs @@ -0,0 +1,2368 @@ +/*! +Implementations for `BlockContext` methods. +*/ + +use super::{ + helpers, index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, + Dimension, Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer, + WriterFlags, +}; +use crate::{arena::Handle, proc::TypeResolution, Statement}; +use spirv::Word; + +fn get_dimension(type_inner: &crate::TypeInner) -> Dimension { + match *type_inner { + crate::TypeInner::Scalar(_) => Dimension::Scalar, + crate::TypeInner::Vector { .. } => Dimension::Vector, + crate::TypeInner::Matrix { .. } => Dimension::Matrix, + _ => unreachable!(), + } +} + +/// The results of emitting code for a left-hand-side expression. +/// +/// On success, `write_expression_pointer` returns one of these. +enum ExpressionPointer { + /// The pointer to the expression's value is available, as the value of the + /// expression with the given id. + Ready { pointer_id: Word }, + + /// The access expression must be conditional on the value of `condition`, a boolean + /// expression that is true if all indices are in bounds. If `condition` is true, then + /// `access` is an `OpAccessChain` instruction that will compute a pointer to the + /// expression's value. If `condition` is false, then executing `access` would be + /// undefined behavior. + Conditional { + condition: Word, + access: Instruction, + }, +} + +/// The termination statement to be added to the end of the block +pub enum BlockExit { + /// Generates an OpReturn (void return) + Return, + /// Generates an OpBranch to the specified block + Branch { + /// The branch target block + target: Word, + }, + /// Translates a loop `break if` into an `OpBranchConditional` to the + /// merge block if true (the merge block is passed through [`LoopContext::break_id`] + /// or else to the loop header (passed through [`preamble_id`]) + /// + /// [`preamble_id`]: Self::BreakIf::preamble_id + BreakIf { + /// The condition of the `break if` + condition: Handle<crate::Expression>, + /// The loop header block id + preamble_id: Word, + }, +} + +#[derive(Debug)] +pub(crate) struct DebugInfoInner<'a> { + pub source_code: &'a str, + pub source_file_id: Word, +} + +impl Writer { + // Flip Y coordinate to adjust for coordinate space difference + // between SPIR-V and our IR. + // The `position_id` argument is a pointer to a `vecN<f32>`, + // whose `y` component we will negate. + fn write_epilogue_position_y_flip( + &mut self, + position_id: Word, + body: &mut Vec<Instruction>, + ) -> Result<(), Error> { + let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: Some(spirv::StorageClass::Output), + })); + let index_y_id = self.get_index_constant(1); + let access_id = self.id_gen.next(); + body.push(Instruction::access_chain( + float_ptr_type_id, + access_id, + position_id, + &[index_y_id], + )); + + let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let load_id = self.id_gen.next(); + body.push(Instruction::load(float_type_id, load_id, access_id, None)); + + let neg_id = self.id_gen.next(); + body.push(Instruction::unary( + spirv::Op::FNegate, + float_type_id, + neg_id, + load_id, + )); + + body.push(Instruction::store(access_id, neg_id, None)); + Ok(()) + } + + // Clamp fragment depth between 0 and 1. + fn write_epilogue_frag_depth_clamp( + &mut self, + frag_depth_id: Word, + body: &mut Vec<Instruction>, + ) -> Result<(), Error> { + let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0)); + let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0)); + + let original_id = self.id_gen.next(); + body.push(Instruction::load( + float_type_id, + original_id, + frag_depth_id, + None, + )); + + let clamp_id = self.id_gen.next(); + body.push(Instruction::ext_inst( + self.gl450_ext_inst_id, + spirv::GLOp::FClamp, + float_type_id, + clamp_id, + &[original_id, zero_scalar_id, one_scalar_id], + )); + + body.push(Instruction::store(frag_depth_id, clamp_id, None)); + Ok(()) + } + + fn write_entry_point_return( + &mut self, + value_id: Word, + ir_result: &crate::FunctionResult, + result_members: &[ResultMember], + body: &mut Vec<Instruction>, + ) -> Result<(), Error> { + for (index, res_member) in result_members.iter().enumerate() { + let member_value_id = match ir_result.binding { + Some(_) => value_id, + None => { + let member_value_id = self.id_gen.next(); + body.push(Instruction::composite_extract( + res_member.type_id, + member_value_id, + value_id, + &[index as u32], + )); + member_value_id + } + }; + + body.push(Instruction::store(res_member.id, member_value_id, None)); + + match res_member.built_in { + Some(crate::BuiltIn::Position { .. }) + if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) => + { + self.write_epilogue_position_y_flip(res_member.id, body)?; + } + Some(crate::BuiltIn::FragDepth) + if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) => + { + self.write_epilogue_frag_depth_clamp(res_member.id, body)?; + } + _ => {} + } + } + Ok(()) + } +} + +impl<'w> BlockContext<'w> { + /// Decide whether to put off emitting instructions for `expr_handle`. + /// + /// We would like to gather together chains of `Access` and `AccessIndex` + /// Naga expressions into a single `OpAccessChain` SPIR-V instruction. To do + /// this, we don't generate instructions for these exprs when we first + /// encounter them. Their ids in `self.writer.cached.ids` are left as zero. Then, + /// once we encounter a `Load` or `Store` expression that actually needs the + /// chain's value, we call `write_expression_pointer` to handle the whole + /// thing in one fell swoop. + fn is_intermediate(&self, expr_handle: Handle<crate::Expression>) -> bool { + match self.ir_function.expressions[expr_handle] { + crate::Expression::GlobalVariable(handle) => { + match self.ir_module.global_variables[handle].space { + crate::AddressSpace::Handle => false, + _ => true, + } + } + crate::Expression::LocalVariable(_) => true, + crate::Expression::FunctionArgument(index) => { + let arg = &self.ir_function.arguments[index as usize]; + self.ir_module.types[arg.ty].inner.pointer_space().is_some() + } + + // The chain rule: if this `Access...`'s `base` operand was + // previously omitted, then omit this one, too. + _ => self.cached.ids[expr_handle.index()] == 0, + } + } + + /// Cache an expression for a value. + pub(super) fn cache_expression_value( + &mut self, + expr_handle: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<(), Error> { + let is_named_expression = self + .ir_function + .named_expressions + .contains_key(&expr_handle); + + if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression { + return Ok(()); + } + + let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty); + let id = match self.ir_function.expressions[expr_handle] { + crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal), + crate::Expression::Constant(handle) => { + let init = self.ir_module.constants[handle].init; + self.writer.constant_ids[init.index()] + } + crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), + crate::Expression::Compose { ty, ref components } => { + self.temp_list.clear(); + if self.expression_constness.is_const(expr_handle) { + self.temp_list.extend( + crate::proc::flatten_compose( + ty, + components, + &self.ir_function.expressions, + &self.ir_module.types, + ) + .map(|component| self.cached[component]), + ); + self.writer + .get_constant_composite(LookupType::Handle(ty), &self.temp_list) + } else { + self.temp_list + .extend(components.iter().map(|&component| self.cached[component])); + + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + result_type_id, + id, + &self.temp_list, + )); + id + } + } + crate::Expression::Splat { size, value } => { + let value_id = self.cached[value]; + let components = &[value_id; 4][..size as usize]; + + if self.expression_constness.is_const(expr_handle) { + let ty = self + .writer + .get_expression_lookup_type(&self.fun_info[expr_handle].ty); + self.writer.get_constant_composite(ty, components) + } else { + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + result_type_id, + id, + components, + )); + id + } + } + crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => { + // See `is_intermediate`; we'll handle this later in + // `write_expression_pointer`. + 0 + } + crate::Expression::Access { base, index } => { + let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types); + match *base_ty_inner { + crate::TypeInner::Vector { .. } => { + self.write_vector_access(expr_handle, base, index, block)? + } + // Only binding arrays in the Handle address space will take this path (due to `is_intermediate`) + crate::TypeInner::BindingArray { + base: binding_type, .. + } => { + let space = match self.ir_function.expressions[base] { + crate::Expression::GlobalVariable(gvar) => { + self.ir_module.global_variables[gvar].space + } + _ => unreachable!(), + }; + let binding_array_false_pointer = LookupType::Local(LocalType::Pointer { + base: binding_type, + class: helpers::map_storage_class(space), + }); + + let result_id = match self.write_expression_pointer( + expr_handle, + block, + Some(binding_array_false_pointer), + )? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. } => { + return Err(Error::FeatureNotImplemented( + "Texture array out-of-bounds handling", + )); + } + }; + + let binding_type_id = self.get_type_id(LookupType::Handle(binding_type)); + + let load_id = self.gen_id(); + block.body.push(Instruction::load( + binding_type_id, + load_id, + result_id, + None, + )); + + // Subsequent image operations require the image/sampler to be decorated as NonUniform + // if the image/sampler binding array was accessed with a non-uniform index + // see VUID-RuntimeSpirv-NonUniform-06274 + if self.fun_info[index].uniformity.non_uniform_result.is_some() { + self.writer + .decorate_non_uniform_binding_array_access(load_id)?; + } + + load_id + } + ref other => { + log::error!( + "Unable to access base {:?} of type {:?}", + self.ir_function.expressions[base], + other + ); + return Err(Error::Validation( + "only vectors may be dynamically indexed by value", + )); + } + } + } + crate::Expression::AccessIndex { base, index: _ } if self.is_intermediate(base) => { + // See `is_intermediate`; we'll handle this later in + // `write_expression_pointer`. + 0 + } + crate::Expression::AccessIndex { base, index } => { + match *self.fun_info[base].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } + | crate::TypeInner::Array { .. } + | crate::TypeInner::Struct { .. } => { + // We never need bounds checks here: dynamically sized arrays can + // only appear behind pointers, and are thus handled by the + // `is_intermediate` case above. Everything else's size is + // statically known and checked in validation. + let id = self.gen_id(); + let base_id = self.cached[base]; + block.body.push(Instruction::composite_extract( + result_type_id, + id, + base_id, + &[index], + )); + id + } + // Only binding arrays in the Handle address space will take this path (due to `is_intermediate`) + crate::TypeInner::BindingArray { + base: binding_type, .. + } => { + let space = match self.ir_function.expressions[base] { + crate::Expression::GlobalVariable(gvar) => { + self.ir_module.global_variables[gvar].space + } + _ => unreachable!(), + }; + let binding_array_false_pointer = LookupType::Local(LocalType::Pointer { + base: binding_type, + class: helpers::map_storage_class(space), + }); + + let result_id = match self.write_expression_pointer( + expr_handle, + block, + Some(binding_array_false_pointer), + )? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. } => { + return Err(Error::FeatureNotImplemented( + "Texture array out-of-bounds handling", + )); + } + }; + + let binding_type_id = self.get_type_id(LookupType::Handle(binding_type)); + + let load_id = self.gen_id(); + block.body.push(Instruction::load( + binding_type_id, + load_id, + result_id, + None, + )); + + load_id + } + ref other => { + log::error!("Unable to access index of {:?}", other); + return Err(Error::FeatureNotImplemented("access index for type")); + } + } + } + crate::Expression::GlobalVariable(handle) => { + self.writer.global_variables[handle.index()].access_id + } + crate::Expression::Swizzle { + size, + vector, + pattern, + } => { + let vector_id = self.cached[vector]; + self.temp_list.clear(); + for &sc in pattern[..size as usize].iter() { + self.temp_list.push(sc as Word); + } + let id = self.gen_id(); + block.body.push(Instruction::vector_shuffle( + result_type_id, + id, + vector_id, + vector_id, + &self.temp_list, + )); + id + } + crate::Expression::Unary { op, expr } => { + let id = self.gen_id(); + let expr_id = self.cached[expr]; + let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types); + + let spirv_op = match op { + crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Float) => spirv::Op::FNegate, + Some(crate::ScalarKind::Sint) => spirv::Op::SNegate, + _ => return Err(Error::Validation("Unexpected kind for negation")), + }, + crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot, + crate::UnaryOperator::BitwiseNot => spirv::Op::Not, + }; + + block + .body + .push(Instruction::unary(spirv_op, result_type_id, id, expr_id)); + id + } + crate::Expression::Binary { op, left, right } => { + let id = self.gen_id(); + let left_id = self.cached[left]; + let right_id = self.cached[right]; + + let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types); + let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types); + + let left_dimension = get_dimension(left_ty_inner); + let right_dimension = get_dimension(right_ty_inner); + + let mut reverse_operands = false; + + let spirv_op = match op { + crate::BinaryOperator::Add => match *left_ty_inner { + crate::TypeInner::Scalar(scalar) + | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { + crate::ScalarKind::Float => spirv::Op::FAdd, + _ => spirv::Op::IAdd, + }, + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => { + self.write_matrix_matrix_column_op( + block, + id, + result_type_id, + left_id, + right_id, + columns, + rows, + scalar.width, + spirv::Op::FAdd, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + _ => unimplemented!(), + }, + crate::BinaryOperator::Subtract => match *left_ty_inner { + crate::TypeInner::Scalar(scalar) + | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { + crate::ScalarKind::Float => spirv::Op::FSub, + _ => spirv::Op::ISub, + }, + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => { + self.write_matrix_matrix_column_op( + block, + id, + result_type_id, + left_id, + right_id, + columns, + rows, + scalar.width, + spirv::Op::FSub, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + _ => unimplemented!(), + }, + crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) { + (Dimension::Scalar, Dimension::Vector) => { + self.write_vector_scalar_mult( + block, + id, + result_type_id, + right_id, + left_id, + right_ty_inner, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + (Dimension::Vector, Dimension::Scalar) => { + self.write_vector_scalar_mult( + block, + id, + result_type_id, + left_id, + right_id, + left_ty_inner, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + (Dimension::Vector, Dimension::Matrix) => spirv::Op::VectorTimesMatrix, + (Dimension::Matrix, Dimension::Scalar) => spirv::Op::MatrixTimesScalar, + (Dimension::Scalar, Dimension::Matrix) => { + reverse_operands = true; + spirv::Op::MatrixTimesScalar + } + (Dimension::Matrix, Dimension::Vector) => spirv::Op::MatrixTimesVector, + (Dimension::Matrix, Dimension::Matrix) => spirv::Op::MatrixTimesMatrix, + (Dimension::Vector, Dimension::Vector) + | (Dimension::Scalar, Dimension::Scalar) + if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) => + { + spirv::Op::FMul + } + (Dimension::Vector, Dimension::Vector) + | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul, + }, + crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SDiv, + Some(crate::ScalarKind::Uint) => spirv::Op::UDiv, + Some(crate::ScalarKind::Float) => spirv::Op::FDiv, + _ => unimplemented!(), + }, + crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() { + // TODO: handle undefined behavior + // if right == 0 return 0 + // if left == min(type_of(left)) && right == -1 return 0 + Some(crate::ScalarKind::Sint) => spirv::Op::SRem, + // TODO: handle undefined behavior + // if right == 0 return 0 + Some(crate::ScalarKind::Uint) => spirv::Op::UMod, + // TODO: handle undefined behavior + // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798 + Some(crate::ScalarKind::Float) => spirv::Op::FRem, + _ => unimplemented!(), + }, + crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { + spirv::Op::IEqual + } + Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual, + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { + spirv::Op::INotEqual + } + Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual, + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan, + Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan, + _ => unimplemented!(), + }, + crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual, + Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan, + Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan, + _ => unimplemented!(), + }, + crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual, + Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::And => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd, + _ => spirv::Op::BitwiseAnd, + }, + crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor, + crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr, + _ => spirv::Op::BitwiseOr, + }, + crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd, + crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr, + crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical, + crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic, + Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical, + _ => unimplemented!(), + }, + }; + + block.body.push(Instruction::binary( + spirv_op, + result_type_id, + id, + if reverse_operands { right_id } else { left_id }, + if reverse_operands { left_id } else { right_id }, + )); + id + } + crate::Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + enum MathOp { + Ext(spirv::GLOp), + Custom(Instruction), + } + + let arg0_id = self.cached[arg]; + let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types); + let arg_scalar_kind = arg_ty.scalar_kind(); + let arg1_id = match arg1 { + Some(handle) => self.cached[handle], + None => 0, + }; + let arg2_id = match arg2 { + Some(handle) => self.cached[handle], + None => 0, + }; + let arg3_id = match arg3 { + Some(handle) => self.cached[handle], + None => 0, + }; + + let id = self.gen_id(); + let math_op = match fun { + // comparison + Mf::Abs => { + match arg_scalar_kind { + Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs), + Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs), + Some(crate::ScalarKind::Uint) => { + MathOp::Custom(Instruction::unary( + spirv::Op::CopyObject, // do nothing + result_type_id, + id, + arg0_id, + )) + } + other => unimplemented!("Unexpected abs({:?})", other), + } + } + Mf::Min => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FMin, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin, + Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin, + other => unimplemented!("Unexpected min({:?})", other), + }), + Mf::Max => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FMax, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax, + Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax, + other => unimplemented!("Unexpected max({:?})", other), + }), + Mf::Clamp => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FClamp, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SClamp, + Some(crate::ScalarKind::Uint) => spirv::GLOp::UClamp, + other => unimplemented!("Unexpected max({:?})", other), + }), + Mf::Saturate => { + let (maybe_size, scalar) = match *arg_ty { + crate::TypeInner::Vector { size, scalar } => (Some(size), scalar), + crate::TypeInner::Scalar(scalar) => (None, scalar), + ref other => unimplemented!("Unexpected saturate({:?})", other), + }; + let scalar = crate::Scalar::float(scalar.width); + let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?; + let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?; + + if let Some(size) = maybe_size { + let ty = LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + } + .into(); + + self.temp_list.clear(); + self.temp_list.resize(size as _, arg1_id); + + arg1_id = self.writer.get_constant_composite(ty, &self.temp_list); + + self.temp_list.fill(arg2_id); + + arg2_id = self.writer.get_constant_composite(ty, &self.temp_list); + } + + MathOp::Custom(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FClamp, + result_type_id, + id, + &[arg0_id, arg1_id, arg2_id], + )) + } + // trigonometry + Mf::Sin => MathOp::Ext(spirv::GLOp::Sin), + Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh), + Mf::Asin => MathOp::Ext(spirv::GLOp::Asin), + Mf::Cos => MathOp::Ext(spirv::GLOp::Cos), + Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh), + Mf::Acos => MathOp::Ext(spirv::GLOp::Acos), + Mf::Tan => MathOp::Ext(spirv::GLOp::Tan), + Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh), + Mf::Atan => MathOp::Ext(spirv::GLOp::Atan), + Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2), + Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh), + Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh), + Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh), + Mf::Radians => MathOp::Ext(spirv::GLOp::Radians), + Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees), + // decomposition + Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil), + Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven), + Mf::Floor => MathOp::Ext(spirv::GLOp::Floor), + Mf::Fract => MathOp::Ext(spirv::GLOp::Fract), + Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc), + Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct), + Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct), + Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp), + // geometry + Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Vector { + scalar: + crate::Scalar { + kind: crate::ScalarKind::Float, + .. + }, + .. + } => MathOp::Custom(Instruction::binary( + spirv::Op::Dot, + result_type_id, + id, + arg0_id, + arg1_id, + )), + // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available + crate::TypeInner::Vector { size, .. } => { + self.write_dot_product( + id, + result_type_id, + arg0_id, + arg1_id, + size as u32, + block, + ); + self.cached[expr_handle] = id; + return Ok(()); + } + _ => unreachable!( + "Correct TypeInner for dot product should be already validated" + ), + }, + Mf::Outer => MathOp::Custom(Instruction::binary( + spirv::Op::OuterProduct, + result_type_id, + id, + arg0_id, + arg1_id, + )), + Mf::Cross => MathOp::Ext(spirv::GLOp::Cross), + Mf::Distance => MathOp::Ext(spirv::GLOp::Distance), + Mf::Length => MathOp::Ext(spirv::GLOp::Length), + Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize), + Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward), + Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect), + Mf::Refract => MathOp::Ext(spirv::GLOp::Refract), + // exponent + Mf::Exp => MathOp::Ext(spirv::GLOp::Exp), + Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2), + Mf::Log => MathOp::Ext(spirv::GLOp::Log), + Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2), + Mf::Pow => MathOp::Ext(spirv::GLOp::Pow), + // computational + Mf::Sign => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FSign, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign, + other => unimplemented!("Unexpected sign({:?})", other), + }), + Mf::Fma => MathOp::Ext(spirv::GLOp::Fma), + Mf::Mix => { + let selector = arg2.unwrap(); + let selector_ty = + self.fun_info[selector].ty.inner_with(&self.ir_module.types); + match (arg_ty, selector_ty) { + // if the selector is a scalar, we need to splat it + ( + &crate::TypeInner::Vector { size, .. }, + &crate::TypeInner::Scalar(scalar), + ) => { + let selector_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + })); + self.temp_list.clear(); + self.temp_list.resize(size as usize, arg2_id); + + let selector_id = self.gen_id(); + block.body.push(Instruction::composite_construct( + selector_type_id, + selector_id, + &self.temp_list, + )); + + MathOp::Custom(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FMix, + result_type_id, + id, + &[arg0_id, arg1_id, selector_id], + )) + } + _ => MathOp::Ext(spirv::GLOp::FMix), + } + } + Mf::Step => MathOp::Ext(spirv::GLOp::Step), + Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep), + Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt), + Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt), + Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse), + Mf::Transpose => MathOp::Custom(Instruction::unary( + spirv::Op::Transpose, + result_type_id, + id, + arg0_id, + )), + Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant), + Mf::ReverseBits => MathOp::Custom(Instruction::unary( + spirv::Op::BitReverse, + result_type_id, + id, + arg0_id, + )), + Mf::CountTrailingZeros => { + let uint_id = match *arg_ty { + crate::TypeInner::Vector { size, mut scalar } => { + scalar.kind = crate::ScalarKind::Uint; + let ty = LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + } + .into(); + + self.temp_list.clear(); + self.temp_list.resize( + size as _, + self.writer.get_constant_scalar_with(32, scalar)?, + ); + + self.writer.get_constant_composite(ty, &self.temp_list) + } + crate::TypeInner::Scalar(mut scalar) => { + scalar.kind = crate::ScalarKind::Uint; + self.writer.get_constant_scalar_with(32, scalar)? + } + _ => unreachable!(), + }; + + let lsb_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FindILsb, + result_type_id, + lsb_id, + &[arg0_id], + )); + + MathOp::Custom(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + result_type_id, + id, + &[uint_id, lsb_id], + )) + } + Mf::CountLeadingZeros => { + let (int_type_id, int_id) = match *arg_ty { + crate::TypeInner::Vector { size, mut scalar } => { + scalar.kind = crate::ScalarKind::Sint; + let ty = LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + } + .into(); + + self.temp_list.clear(); + self.temp_list.resize( + size as _, + self.writer.get_constant_scalar_with(31, scalar)?, + ); + + ( + self.get_type_id(ty), + self.writer.get_constant_composite(ty, &self.temp_list), + ) + } + crate::TypeInner::Scalar(mut scalar) => { + scalar.kind = crate::ScalarKind::Sint; + ( + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + })), + self.writer.get_constant_scalar_with(31, scalar)?, + ) + } + _ => unreachable!(), + }; + + let msb_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FindUMsb, + int_type_id, + msb_id, + &[arg0_id], + )); + + MathOp::Custom(Instruction::binary( + spirv::Op::ISub, + result_type_id, + id, + int_id, + msb_id, + )) + } + Mf::CountOneBits => MathOp::Custom(Instruction::unary( + spirv::Op::BitCount, + result_type_id, + id, + arg0_id, + )), + Mf::ExtractBits => { + let op = match arg_scalar_kind { + Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract, + Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract, + other => unimplemented!("Unexpected sign({:?})", other), + }; + MathOp::Custom(Instruction::ternary( + op, + result_type_id, + id, + arg0_id, + arg1_id, + arg2_id, + )) + } + Mf::InsertBits => MathOp::Custom(Instruction::quaternary( + spirv::Op::BitFieldInsert, + result_type_id, + id, + arg0_id, + arg1_id, + arg2_id, + arg3_id, + )), + Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb), + Mf::FindMsb => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb, + Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb, + other => unimplemented!("Unexpected findMSB({:?})", other), + }), + Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8), + Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8), + Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16), + Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16), + Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16), + Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8), + Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8), + Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16), + Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16), + Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16), + }; + + block.body.push(match math_op { + MathOp::Ext(op) => Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + op, + result_type_id, + id, + &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()], + ), + MathOp::Custom(inst) => inst, + }); + id + } + crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id, + crate::Expression::Load { pointer } => { + match self.write_expression_pointer(pointer, block, None)? { + ExpressionPointer::Ready { pointer_id } => { + let id = self.gen_id(); + let atomic_space = + match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Pointer { base, space } => { + match self.ir_module.types[base].inner { + crate::TypeInner::Atomic { .. } => Some(space), + _ => None, + } + } + _ => None, + }; + let instruction = if let Some(space) = atomic_space { + let (semantics, scope) = space.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + Instruction::atomic_load( + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + ) + } else { + Instruction::load(result_type_id, id, pointer_id, None) + }; + block.body.push(instruction); + id + } + ExpressionPointer::Conditional { condition, access } => { + //TODO: support atomics? + self.write_conditional_indexed_load( + result_type_id, + condition, + block, + move |id_gen, block| { + // The in-bounds path. Perform the access and the load. + let pointer_id = access.result_id.unwrap(); + let value_id = id_gen.next(); + block.body.push(access); + block.body.push(Instruction::load( + result_type_id, + value_id, + pointer_id, + None, + )); + value_id + }, + ) + } + } + } + crate::Expression::FunctionArgument(index) => self.function.parameter_id(index), + crate::Expression::CallResult(_) + | crate::Expression::AtomicResult { .. } + | crate::Expression::WorkGroupUniformLoadResult { .. } + | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], + crate::Expression::As { + expr, + kind, + convert, + } => { + use crate::ScalarKind as Sk; + + let expr_id = self.cached[expr]; + let (src_scalar, src_size, is_matrix) = + match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Scalar(scalar) => (scalar, None, false), + crate::TypeInner::Vector { scalar, size } => (scalar, Some(size), false), + crate::TypeInner::Matrix { scalar, .. } => (scalar, None, true), + ref other => { + log::error!("As source {:?}", other); + return Err(Error::Validation("Unexpected Expression::As source")); + } + }; + + enum Cast { + Identity, + Unary(spirv::Op), + Binary(spirv::Op, Word), + Ternary(spirv::Op, Word, Word), + } + + let cast = if is_matrix { + // we only support identity casts for matrices + Cast::Unary(spirv::Op::CopyObject) + } else { + match (src_scalar.kind, kind, convert) { + // Filter out identity casts. Some Adreno drivers are + // confused by no-op OpBitCast instructions. + (src_kind, kind, convert) + if src_kind == kind + && convert.filter(|&width| width != src_scalar.width).is_none() => + { + Cast::Identity + } + (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject), + (_, _, None) => Cast::Unary(spirv::Op::Bitcast), + // casting to a bool - generate `OpXxxNotEqual` + (_, Sk::Bool, Some(_)) => { + let op = match src_scalar.kind { + Sk::Sint | Sk::Uint => spirv::Op::INotEqual, + Sk::Float => spirv::Op::FUnordNotEqual, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(), + }; + let zero_scalar_id = + self.writer.get_constant_scalar_with(0, src_scalar)?; + let zero_id = match src_size { + Some(size) => { + let ty = LocalType::Value { + vector_size: Some(size), + scalar: src_scalar, + pointer_space: None, + } + .into(); + + self.temp_list.clear(); + self.temp_list.resize(size as _, zero_scalar_id); + + self.writer.get_constant_composite(ty, &self.temp_list) + } + None => zero_scalar_id, + }; + + Cast::Binary(op, zero_id) + } + // casting from a bool - generate `OpSelect` + (Sk::Bool, _, Some(dst_width)) => { + let dst_scalar = crate::Scalar { + kind, + width: dst_width, + }; + let zero_scalar_id = + self.writer.get_constant_scalar_with(0, dst_scalar)?; + let one_scalar_id = + self.writer.get_constant_scalar_with(1, dst_scalar)?; + let (accept_id, reject_id) = match src_size { + Some(size) => { + let ty = LocalType::Value { + vector_size: Some(size), + scalar: dst_scalar, + pointer_space: None, + } + .into(); + + self.temp_list.clear(); + self.temp_list.resize(size as _, zero_scalar_id); + + let vec0_id = + self.writer.get_constant_composite(ty, &self.temp_list); + + self.temp_list.fill(one_scalar_id); + + let vec1_id = + self.writer.get_constant_composite(ty, &self.temp_list); + + (vec1_id, vec0_id) + } + None => (one_scalar_id, zero_scalar_id), + }; + + Cast::Ternary(spirv::Op::Select, accept_id, reject_id) + } + (Sk::Float, Sk::Uint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToU), + (Sk::Float, Sk::Sint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToS), + (Sk::Float, Sk::Float, Some(dst_width)) + if src_scalar.width != dst_width => + { + Cast::Unary(spirv::Op::FConvert) + } + (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF), + (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => { + Cast::Unary(spirv::Op::SConvert) + } + (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF), + (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => { + Cast::Unary(spirv::Op::UConvert) + } + // We assume it's either an identity cast, or int-uint. + _ => Cast::Unary(spirv::Op::Bitcast), + } + }; + + let id = self.gen_id(); + let instruction = match cast { + Cast::Identity => None, + Cast::Unary(op) => Some(Instruction::unary(op, result_type_id, id, expr_id)), + Cast::Binary(op, operand) => Some(Instruction::binary( + op, + result_type_id, + id, + expr_id, + operand, + )), + Cast::Ternary(op, op1, op2) => Some(Instruction::ternary( + op, + result_type_id, + id, + expr_id, + op1, + op2, + )), + }; + if let Some(instruction) = instruction { + block.body.push(instruction); + id + } else { + expr_id + } + } + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => self.write_image_load( + result_type_id, + image, + coordinate, + array_index, + level, + sample, + block, + )?, + crate::Expression::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + } => self.write_image_sample( + result_type_id, + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + block, + )?, + crate::Expression::Select { + condition, + accept, + reject, + } => { + let id = self.gen_id(); + let mut condition_id = self.cached[condition]; + let accept_id = self.cached[accept]; + let reject_id = self.cached[reject]; + + let condition_ty = self.fun_info[condition] + .ty + .inner_with(&self.ir_module.types); + let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types); + + if let ( + &crate::TypeInner::Scalar( + condition_scalar @ crate::Scalar { + kind: crate::ScalarKind::Bool, + .. + }, + ), + &crate::TypeInner::Vector { size, .. }, + ) = (condition_ty, object_ty) + { + self.temp_list.clear(); + self.temp_list.resize(size as usize, condition_id); + + let bool_vector_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(size), + scalar: condition_scalar, + pointer_space: None, + })); + + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + bool_vector_type_id, + id, + &self.temp_list, + )); + condition_id = id + } + + let instruction = + Instruction::select(result_type_id, id, condition_id, accept_id, reject_id); + block.body.push(instruction); + id + } + crate::Expression::Derivative { axis, ctrl, expr } => { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + match ctrl { + Ctrl::Coarse | Ctrl::Fine => { + self.writer.require_any( + "DerivativeControl", + &[spirv::Capability::DerivativeControl], + )?; + } + Ctrl::None => {} + } + let id = self.gen_id(); + let expr_id = self.cached[expr]; + let op = match (axis, ctrl) { + (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse, + (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine, + (Axis::X, Ctrl::None) => spirv::Op::DPdx, + (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse, + (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine, + (Axis::Y, Ctrl::None) => spirv::Op::DPdy, + (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse, + (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine, + (Axis::Width, Ctrl::None) => spirv::Op::Fwidth, + }; + block + .body + .push(Instruction::derivative(op, result_type_id, id, expr_id)); + id + } + crate::Expression::ImageQuery { image, query } => { + self.write_image_query(result_type_id, image, query, block)? + } + crate::Expression::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + let arg_id = self.cached[argument]; + let op = match fun { + Rf::All => spirv::Op::All, + Rf::Any => spirv::Op::Any, + Rf::IsNan => spirv::Op::IsNan, + Rf::IsInf => spirv::Op::IsInf, + }; + let id = self.gen_id(); + block + .body + .push(Instruction::relational(op, result_type_id, id, arg_id)); + id + } + crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, + crate::Expression::RayQueryGetIntersection { query, committed } => { + if !committed { + return Err(Error::FeatureNotImplemented("candidate intersection")); + } + self.write_ray_query_get_intersection(query, block) + } + }; + + self.cached[expr_handle] = id; + Ok(()) + } + + /// Build an `OpAccessChain` instruction. + /// + /// Emit any needed bounds-checking expressions to `block`. + /// + /// Some cases we need to generate a different return type than what the IR gives us. + /// This is because pointers to binding arrays of handles (such as images or samplers) + /// don't exist in the IR, but we need to create them to create an access chain in SPIRV. + /// + /// On success, the return value is an [`ExpressionPointer`] value; see the + /// documentation for that type. + fn write_expression_pointer( + &mut self, + mut expr_handle: Handle<crate::Expression>, + block: &mut Block, + return_type_override: Option<LookupType>, + ) -> Result<ExpressionPointer, Error> { + let result_lookup_ty = match self.fun_info[expr_handle].ty { + TypeResolution::Handle(ty_handle) => match return_type_override { + // We use the return type override as a special case for handle binding arrays as the OpAccessChain + // needs to return a pointer, but indexing into a handle binding array just gives you the type of + // the binding in the IR. + Some(ty) => ty, + None => LookupType::Handle(ty_handle), + }, + TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()), + }; + let result_type_id = self.get_type_id(result_lookup_ty); + + // The id of the boolean `and` of all dynamic bounds checks up to this point. If + // `None`, then we haven't done any dynamic bounds checks yet. + // + // When we have a chain of bounds checks, we combine them with `OpLogicalAnd`, not + // a short-circuit branch. This means we might do comparisons we don't need to, + // but we expect these checks to almost always succeed, and keeping branches to a + // minimum is essential. + let mut accumulated_checks = None; + // Is true if we are accessing into a binding array with a non-uniform index. + let mut is_non_uniform_binding_array = false; + + self.temp_list.clear(); + let root_id = loop { + expr_handle = match self.ir_function.expressions[expr_handle] { + crate::Expression::Access { base, index } => { + if let crate::Expression::GlobalVariable(var_handle) = + self.ir_function.expressions[base] + { + // The access chain needs to be decorated as NonUniform + // see VUID-RuntimeSpirv-NonUniform-06274 + let gvar = &self.ir_module.global_variables[var_handle]; + if let crate::TypeInner::BindingArray { .. } = + self.ir_module.types[gvar.ty].inner + { + is_non_uniform_binding_array = + self.fun_info[index].uniformity.non_uniform_result.is_some(); + } + } + + let index_id = match self.write_bounds_check(base, index, block)? { + BoundsCheckResult::KnownInBounds(known_index) => { + // Even if the index is known, `OpAccessIndex` + // requires expression operands, not literals. + let scalar = crate::Literal::U32(known_index); + self.writer.get_constant_scalar(scalar) + } + BoundsCheckResult::Computed(computed_index_id) => computed_index_id, + BoundsCheckResult::Conditional(comparison_id) => { + match accumulated_checks { + Some(prior_checks) => { + let combined = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + self.writer.get_bool_type_id(), + combined, + prior_checks, + comparison_id, + )); + accumulated_checks = Some(combined); + } + None => { + // Start a fresh chain of checks. + accumulated_checks = Some(comparison_id); + } + } + + // Either way, the index to use is unchanged. + self.cached[index] + } + }; + self.temp_list.push(index_id); + base + } + crate::Expression::AccessIndex { base, index } => { + let const_id = self.get_index_constant(index); + self.temp_list.push(const_id); + base + } + crate::Expression::GlobalVariable(handle) => { + let gv = &self.writer.global_variables[handle.index()]; + break gv.access_id; + } + crate::Expression::LocalVariable(variable) => { + let local_var = &self.function.variables[&variable]; + break local_var.id; + } + crate::Expression::FunctionArgument(index) => { + break self.function.parameter_id(index); + } + ref other => unimplemented!("Unexpected pointer expression {:?}", other), + } + }; + + let (pointer_id, expr_pointer) = if self.temp_list.is_empty() { + ( + root_id, + ExpressionPointer::Ready { + pointer_id: root_id, + }, + ) + } else { + self.temp_list.reverse(); + let pointer_id = self.gen_id(); + let access = + Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list); + + // If we generated some bounds checks, we need to leave it to our + // caller to generate the branch, the access, the load or store, and + // the zero value (for loads). Otherwise, we can emit the access + // ourselves, and just hand them the id of the pointer. + let expr_pointer = match accumulated_checks { + Some(condition) => ExpressionPointer::Conditional { condition, access }, + None => { + block.body.push(access); + ExpressionPointer::Ready { pointer_id } + } + }; + (pointer_id, expr_pointer) + }; + // Subsequent load, store and atomic operations require the pointer to be decorated as NonUniform + // if the binding array was accessed with a non-uniform index + // see VUID-RuntimeSpirv-NonUniform-06274 + if is_non_uniform_binding_array { + self.writer + .decorate_non_uniform_binding_array_access(pointer_id)?; + } + + Ok(expr_pointer) + } + + /// Build the instructions for matrix - matrix column operations + #[allow(clippy::too_many_arguments)] + fn write_matrix_matrix_column_op( + &mut self, + block: &mut Block, + result_id: Word, + result_type_id: Word, + left_id: Word, + right_id: Word, + columns: crate::VectorSize, + rows: crate::VectorSize, + width: u8, + op: spirv::Op, + ) { + self.temp_list.clear(); + + let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(rows), + scalar: crate::Scalar::float(width), + pointer_space: None, + })); + + for index in 0..columns as u32 { + let column_id_left = self.gen_id(); + let column_id_right = self.gen_id(); + let column_id_res = self.gen_id(); + + block.body.push(Instruction::composite_extract( + vector_type_id, + column_id_left, + left_id, + &[index], + )); + block.body.push(Instruction::composite_extract( + vector_type_id, + column_id_right, + right_id, + &[index], + )); + block.body.push(Instruction::binary( + op, + vector_type_id, + column_id_res, + column_id_left, + column_id_right, + )); + + self.temp_list.push(column_id_res); + } + + block.body.push(Instruction::composite_construct( + result_type_id, + result_id, + &self.temp_list, + )); + } + + /// Build the instructions for vector - scalar multiplication + fn write_vector_scalar_mult( + &mut self, + block: &mut Block, + result_id: Word, + result_type_id: Word, + vector_id: Word, + scalar_id: Word, + vector: &crate::TypeInner, + ) { + let (size, kind) = match *vector { + crate::TypeInner::Vector { + size, + scalar: crate::Scalar { kind, .. }, + } => (size, kind), + _ => unreachable!(), + }; + + let (op, operand_id) = match kind { + crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id), + _ => { + let operand_id = self.gen_id(); + self.temp_list.clear(); + self.temp_list.resize(size as usize, scalar_id); + block.body.push(Instruction::composite_construct( + result_type_id, + operand_id, + &self.temp_list, + )); + (spirv::Op::IMul, operand_id) + } + }; + + block.body.push(Instruction::binary( + op, + result_type_id, + result_id, + vector_id, + operand_id, + )); + } + + /// Build the instructions for the arithmetic expression of a dot product + fn write_dot_product( + &mut self, + result_id: Word, + result_type_id: Word, + arg0_id: Word, + arg1_id: Word, + size: u32, + block: &mut Block, + ) { + let mut partial_sum = self.writer.get_constant_null(result_type_id); + let last_component = size - 1; + for index in 0..=last_component { + // compute the product of the current components + let a_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + a_id, + arg0_id, + &[index], + )); + let b_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + b_id, + arg1_id, + &[index], + )); + let prod_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::IMul, + result_type_id, + prod_id, + a_id, + b_id, + )); + + // choose the id for the next sum, depending on current index + let id = if index == last_component { + result_id + } else { + self.gen_id() + }; + + // sum the computed product with the partial sum + block.body.push(Instruction::binary( + spirv::Op::IAdd, + result_type_id, + id, + partial_sum, + prod_id, + )); + // set the id of the result as the previous partial sum + partial_sum = id; + } + } + + pub(super) fn write_block( + &mut self, + label_id: Word, + naga_block: &crate::Block, + exit: BlockExit, + loop_context: LoopContext, + debug_info: Option<&DebugInfoInner>, + ) -> Result<(), Error> { + let mut block = Block::new(label_id); + for (statement, span) in naga_block.span_iter() { + if let (Some(debug_info), false) = ( + debug_info, + matches!( + statement, + &(Statement::Block(..) + | Statement::Break + | Statement::Continue + | Statement::Kill + | Statement::Return { .. } + | Statement::Loop { .. }) + ), + ) { + let loc: crate::SourceLocation = span.location(debug_info.source_code); + block.body.push(Instruction::line( + debug_info.source_file_id, + loc.line_number, + loc.line_position, + )); + }; + match *statement { + crate::Statement::Emit(ref range) => { + for handle in range.clone() { + // omit const expressions as we've already cached those + if !self.expression_constness.is_const(handle) { + self.cache_expression_value(handle, &mut block)?; + } + } + } + crate::Statement::Block(ref block_statements) => { + let scope_id = self.gen_id(); + self.function.consume(block, Instruction::branch(scope_id)); + + let merge_id = self.gen_id(); + self.write_block( + scope_id, + block_statements, + BlockExit::Branch { target: merge_id }, + loop_context, + debug_info, + )?; + + block = Block::new(merge_id); + } + crate::Statement::If { + condition, + ref accept, + ref reject, + } => { + let condition_id = self.cached[condition]; + + let merge_id = self.gen_id(); + block.body.push(Instruction::selection_merge( + merge_id, + spirv::SelectionControl::NONE, + )); + + let accept_id = if accept.is_empty() { + None + } else { + Some(self.gen_id()) + }; + let reject_id = if reject.is_empty() { + None + } else { + Some(self.gen_id()) + }; + + self.function.consume( + block, + Instruction::branch_conditional( + condition_id, + accept_id.unwrap_or(merge_id), + reject_id.unwrap_or(merge_id), + ), + ); + + if let Some(block_id) = accept_id { + self.write_block( + block_id, + accept, + BlockExit::Branch { target: merge_id }, + loop_context, + debug_info, + )?; + } + if let Some(block_id) = reject_id { + self.write_block( + block_id, + reject, + BlockExit::Branch { target: merge_id }, + loop_context, + debug_info, + )?; + } + + block = Block::new(merge_id); + } + crate::Statement::Switch { + selector, + ref cases, + } => { + let selector_id = self.cached[selector]; + + let merge_id = self.gen_id(); + block.body.push(Instruction::selection_merge( + merge_id, + spirv::SelectionControl::NONE, + )); + + let mut default_id = None; + // id of previous empty fall-through case + let mut last_id = None; + + let mut raw_cases = Vec::with_capacity(cases.len()); + let mut case_ids = Vec::with_capacity(cases.len()); + for case in cases.iter() { + // take id of previous empty fall-through case or generate a new one + let label_id = last_id.take().unwrap_or_else(|| self.gen_id()); + + if case.fall_through && case.body.is_empty() { + last_id = Some(label_id); + } + + case_ids.push(label_id); + + match case.value { + crate::SwitchValue::I32(value) => { + raw_cases.push(super::instructions::Case { + value: value as Word, + label_id, + }); + } + crate::SwitchValue::U32(value) => { + raw_cases.push(super::instructions::Case { value, label_id }); + } + crate::SwitchValue::Default => { + default_id = Some(label_id); + } + } + } + + let default_id = default_id.unwrap(); + + self.function.consume( + block, + Instruction::switch(selector_id, default_id, &raw_cases), + ); + + let inner_context = LoopContext { + break_id: Some(merge_id), + ..loop_context + }; + + for (i, (case, label_id)) in cases + .iter() + .zip(case_ids.iter()) + .filter(|&(case, _)| !(case.fall_through && case.body.is_empty())) + .enumerate() + { + let case_finish_id = if case.fall_through { + case_ids[i + 1] + } else { + merge_id + }; + self.write_block( + *label_id, + &case.body, + BlockExit::Branch { + target: case_finish_id, + }, + inner_context, + debug_info, + )?; + } + + block = Block::new(merge_id); + } + crate::Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + let preamble_id = self.gen_id(); + self.function + .consume(block, Instruction::branch(preamble_id)); + + let merge_id = self.gen_id(); + let body_id = self.gen_id(); + let continuing_id = self.gen_id(); + + // SPIR-V requires the continuing to the `OpLoopMerge`, + // so we have to start a new block with it. + block = Block::new(preamble_id); + // HACK the loop statement is begin with branch instruction, + // so we need to put `OpLine` debug info before merge instruction + if let Some(debug_info) = debug_info { + let loc: crate::SourceLocation = span.location(debug_info.source_code); + block.body.push(Instruction::line( + debug_info.source_file_id, + loc.line_number, + loc.line_position, + )) + } + block.body.push(Instruction::loop_merge( + merge_id, + continuing_id, + spirv::SelectionControl::NONE, + )); + self.function.consume(block, Instruction::branch(body_id)); + + self.write_block( + body_id, + body, + BlockExit::Branch { + target: continuing_id, + }, + LoopContext { + continuing_id: Some(continuing_id), + break_id: Some(merge_id), + }, + debug_info, + )?; + + let exit = match break_if { + Some(condition) => BlockExit::BreakIf { + condition, + preamble_id, + }, + None => BlockExit::Branch { + target: preamble_id, + }, + }; + + self.write_block( + continuing_id, + continuing, + exit, + LoopContext { + continuing_id: None, + break_id: Some(merge_id), + }, + debug_info, + )?; + + block = Block::new(merge_id); + } + crate::Statement::Break => { + self.function + .consume(block, Instruction::branch(loop_context.break_id.unwrap())); + return Ok(()); + } + crate::Statement::Continue => { + self.function.consume( + block, + Instruction::branch(loop_context.continuing_id.unwrap()), + ); + return Ok(()); + } + crate::Statement::Return { value: Some(value) } => { + let value_id = self.cached[value]; + let instruction = match self.function.entry_point_context { + // If this is an entry point, and we need to return anything, + // let's instead store the output variables and return `void`. + Some(ref context) => { + self.writer.write_entry_point_return( + value_id, + self.ir_function.result.as_ref().unwrap(), + &context.results, + &mut block.body, + )?; + Instruction::return_void() + } + None => Instruction::return_value(value_id), + }; + self.function.consume(block, instruction); + return Ok(()); + } + crate::Statement::Return { value: None } => { + self.function.consume(block, Instruction::return_void()); + return Ok(()); + } + crate::Statement::Kill => { + self.function.consume(block, Instruction::kill()); + return Ok(()); + } + crate::Statement::Barrier(flags) => { + self.writer.write_barrier(flags, &mut block); + } + crate::Statement::Store { pointer, value } => { + let value_id = self.cached[value]; + match self.write_expression_pointer(pointer, &mut block, None)? { + ExpressionPointer::Ready { pointer_id } => { + let atomic_space = match *self.fun_info[pointer] + .ty + .inner_with(&self.ir_module.types) + { + crate::TypeInner::Pointer { base, space } => { + match self.ir_module.types[base].inner { + crate::TypeInner::Atomic { .. } => Some(space), + _ => None, + } + } + _ => None, + }; + let instruction = if let Some(space) = atomic_space { + let (semantics, scope) = space.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + Instruction::atomic_store( + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } else { + Instruction::store(pointer_id, value_id, None) + }; + block.body.push(instruction); + } + ExpressionPointer::Conditional { condition, access } => { + let mut selection = Selection::start(&mut block, ()); + selection.if_true(self, condition, ()); + + // The in-bounds path. Perform the access and the store. + let pointer_id = access.result_id.unwrap(); + selection.block().body.push(access); + selection + .block() + .body + .push(Instruction::store(pointer_id, value_id, None)); + + // Finish the in-bounds block and start the merge block. This + // is the block we'll leave current on return. + selection.finish(self, ()); + } + }; + } + crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => self.write_image_store(image, coordinate, array_index, value, &mut block)?, + crate::Statement::Call { + function: local_function, + ref arguments, + result, + } => { + let id = self.gen_id(); + self.temp_list.clear(); + for &argument in arguments { + self.temp_list.push(self.cached[argument]); + } + + let type_id = match result { + Some(expr) => { + self.cached[expr] = id; + self.get_expression_type_id(&self.fun_info[expr].ty) + } + None => self.writer.void_type, + }; + + block.body.push(Instruction::function_call( + type_id, + id, + self.writer.lookup_function[&local_function], + &self.temp_list, + )); + } + crate::Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + let id = self.gen_id(); + let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); + + self.cached[result] = id; + + let pointer_id = + match self.write_expression_pointer(pointer, &mut block, None)? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. } => { + return Err(Error::FeatureNotImplemented( + "Atomics out-of-bounds handling", + )); + } + }; + + let space = self.fun_info[pointer] + .ty + .inner_with(&self.ir_module.types) + .pointer_space() + .unwrap(); + let (semantics, scope) = space.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + let value_id = self.cached[value]; + let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types); + + let instruction = match *fun { + crate::AtomicFunction::Add => Instruction::atomic_binary( + spirv::Op::AtomicIAdd, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::Subtract => Instruction::atomic_binary( + spirv::Op::AtomicISub, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::And => Instruction::atomic_binary( + spirv::Op::AtomicAnd, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary( + spirv::Op::AtomicOr, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary( + spirv::Op::AtomicXor, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::Min => { + let spirv_op = match *value_inner { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + }) => spirv::Op::AtomicSMin, + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + }) => spirv::Op::AtomicUMin, + _ => unimplemented!(), + }; + Instruction::atomic_binary( + spirv_op, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } + crate::AtomicFunction::Max => { + let spirv_op = match *value_inner { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + }) => spirv::Op::AtomicSMax, + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + }) => spirv::Op::AtomicUMax, + _ => unimplemented!(), + }; + Instruction::atomic_binary( + spirv_op, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } + crate::AtomicFunction::Exchange { compare: None } => { + Instruction::atomic_binary( + spirv::Op::AtomicExchange, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } + crate::AtomicFunction::Exchange { compare: Some(cmp) } => { + let scalar_type_id = match *value_inner { + crate::TypeInner::Scalar(scalar) => { + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + })) + } + _ => unimplemented!(), + }; + let bool_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::BOOL, + pointer_space: None, + })); + + let cas_result_id = self.gen_id(); + let equality_result_id = self.gen_id(); + let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange); + cas_instr.set_type(scalar_type_id); + cas_instr.set_result(cas_result_id); + cas_instr.add_operand(pointer_id); + cas_instr.add_operand(scope_constant_id); + cas_instr.add_operand(semantics_id); // semantics if equal + cas_instr.add_operand(semantics_id); // semantics if not equal + cas_instr.add_operand(value_id); + cas_instr.add_operand(self.cached[cmp]); + block.body.push(cas_instr); + block.body.push(Instruction::binary( + spirv::Op::IEqual, + bool_type_id, + equality_result_id, + cas_result_id, + self.cached[cmp], + )); + Instruction::composite_construct( + result_type_id, + id, + &[cas_result_id, equality_result_id], + ) + } + }; + + block.body.push(instruction); + } + crate::Statement::WorkGroupUniformLoad { pointer, result } => { + self.writer + .write_barrier(crate::Barrier::WORK_GROUP, &mut block); + let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); + // Embed the body of + match self.write_expression_pointer(pointer, &mut block, None)? { + ExpressionPointer::Ready { pointer_id } => { + let id = self.gen_id(); + block.body.push(Instruction::load( + result_type_id, + id, + pointer_id, + None, + )); + self.cached[result] = id; + } + ExpressionPointer::Conditional { condition, access } => { + self.cached[result] = self.write_conditional_indexed_load( + result_type_id, + condition, + &mut block, + move |id_gen, block| { + // The in-bounds path. Perform the access and the load. + let pointer_id = access.result_id.unwrap(); + let value_id = id_gen.next(); + block.body.push(access); + block.body.push(Instruction::load( + result_type_id, + value_id, + pointer_id, + None, + )); + value_id + }, + ) + } + } + self.writer + .write_barrier(crate::Barrier::WORK_GROUP, &mut block); + } + crate::Statement::RayQuery { query, ref fun } => { + self.write_ray_query_function(query, fun, &mut block); + } + } + } + + let termination = match exit { + // We're generating code for the top-level Block of the function, so we + // need to end it with some kind of return instruction. + BlockExit::Return => match self.ir_function.result { + Some(ref result) if self.function.entry_point_context.is_none() => { + let type_id = self.get_type_id(LookupType::Handle(result.ty)); + let null_id = self.writer.get_constant_null(type_id); + Instruction::return_value(null_id) + } + _ => Instruction::return_void(), + }, + BlockExit::Branch { target } => Instruction::branch(target), + BlockExit::BreakIf { + condition, + preamble_id, + } => { + let condition_id = self.cached[condition]; + + Instruction::branch_conditional( + condition_id, + loop_context.break_id.unwrap(), + preamble_id, + ) + } + }; + + self.function.consume(block, termination); + Ok(()) + } +} diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs new file mode 100644 index 0000000000..5b6226db85 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/helpers.rs @@ -0,0 +1,109 @@ +use crate::{Handle, UniqueArena}; +use spirv::Word; + +pub(super) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> { + bytes + .chunks(4) + .map(|chars| chars.iter().rev().fold(0u32, |u, c| (u << 8) | *c as u32)) + .collect() +} + +pub(super) fn string_to_words(input: &str) -> Vec<Word> { + let bytes = input.as_bytes(); + let mut words = bytes_to_words(bytes); + + if bytes.len() % 4 == 0 { + // nul-termination + words.push(0x0u32); + } + + words +} + +pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::StorageClass { + match space { + crate::AddressSpace::Handle => spirv::StorageClass::UniformConstant, + crate::AddressSpace::Function => spirv::StorageClass::Function, + crate::AddressSpace::Private => spirv::StorageClass::Private, + crate::AddressSpace::Storage { .. } => spirv::StorageClass::StorageBuffer, + crate::AddressSpace::Uniform => spirv::StorageClass::Uniform, + crate::AddressSpace::WorkGroup => spirv::StorageClass::Workgroup, + crate::AddressSpace::PushConstant => spirv::StorageClass::PushConstant, + } +} + +pub(super) fn contains_builtin( + binding: Option<&crate::Binding>, + ty: Handle<crate::Type>, + arena: &UniqueArena<crate::Type>, + built_in: crate::BuiltIn, +) -> bool { + if let Some(&crate::Binding::BuiltIn(bi)) = binding { + bi == built_in + } else if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner { + members + .iter() + .any(|member| contains_builtin(member.binding.as_ref(), member.ty, arena, built_in)) + } else { + false // unreachable + } +} + +impl crate::AddressSpace { + pub(super) const fn to_spirv_semantics_and_scope( + self, + ) -> (spirv::MemorySemantics, spirv::Scope) { + match self { + Self::Storage { .. } => (spirv::MemorySemantics::UNIFORM_MEMORY, spirv::Scope::Device), + Self::WorkGroup => ( + spirv::MemorySemantics::WORKGROUP_MEMORY, + spirv::Scope::Workgroup, + ), + _ => (spirv::MemorySemantics::empty(), spirv::Scope::Invocation), + } + } +} + +/// Return true if the global requires a type decorated with `Block`. +/// +/// Vulkan spec v1.3 §15.6.2, "Descriptor Set Interface", says: +/// +/// > Variables identified with the `Uniform` storage class are used to +/// > access transparent buffer backed resources. Such variables must +/// > be: +/// > +/// > - typed as `OpTypeStruct`, or an array of this type, +/// > +/// > - identified with a `Block` or `BufferBlock` decoration, and +/// > +/// > - laid out explicitly using the `Offset`, `ArrayStride`, and +/// > `MatrixStride` decorations as specified in §15.6.4, "Offset +/// > and Stride Assignment." +// See `back::spv::GlobalVariable::access_id` for details. +pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariable) -> bool { + match var.space { + crate::AddressSpace::Uniform + | crate::AddressSpace::Storage { .. } + | crate::AddressSpace::PushConstant => {} + _ => return false, + }; + match ir_module.types[var.ty].inner { + crate::TypeInner::Struct { + ref members, + span: _, + } => match members.last() { + Some(member) => match ir_module.types[member.ty].inner { + // Structs with dynamically sized arrays can't be copied and can't be wrapped. + crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } => false, + _ => true, + }, + None => false, + }, + crate::TypeInner::BindingArray { .. } => false, + // if it's not a structure or a binding array, let's wrap it to be able to put "Block" + _ => true, + } +} diff --git a/third_party/rust/naga/src/back/spv/image.rs b/third_party/rust/naga/src/back/spv/image.rs new file mode 100644 index 0000000000..c0fc41cbb6 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/image.rs @@ -0,0 +1,1210 @@ +/*! +Generating SPIR-V for image operations. +*/ + +use super::{ + selection::{MergeTuple, Selection}, + Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType, +}; +use crate::arena::Handle; +use spirv::Word; + +/// Information about a vector of coordinates. +/// +/// The coordinate vectors expected by SPIR-V `OpImageRead` and `OpImageFetch` +/// supply the array index for arrayed images as an additional component at +/// the end, whereas Naga's `ImageLoad`, `ImageStore`, and `ImageSample` carry +/// the array index as a separate field. +/// +/// In the process of generating code to compute the combined vector, we also +/// produce SPIR-V types and vector lengths that are useful elsewhere. This +/// struct gathers that information into one place, with standard names. +struct ImageCoordinates { + /// The SPIR-V id of the combined coordinate/index vector value. + /// + /// Note: when indexing a non-arrayed 1D image, this will be a scalar. + value_id: Word, + + /// The SPIR-V id of the type of `value`. + type_id: Word, + + /// The number of components in `value`, if it is a vector, or `None` if it + /// is a scalar. + size: Option<crate::VectorSize>, +} + +/// A trait for image access (load or store) code generators. +/// +/// Types implementing this trait hold information about an `ImageStore` or +/// `ImageLoad` operation that is not affected by the bounds check policy. The +/// `generate` method emits code for the access, given the results of bounds +/// checking. +/// +/// The [`image`] bounds checks policy affects access coordinates, level of +/// detail, and sample index, but never the image id, result type (if any), or +/// the specific SPIR-V instruction used. Types that implement this trait gather +/// together the latter category, so we don't have to plumb them through the +/// bounds-checking code. +/// +/// [`image`]: crate::proc::BoundsCheckPolicies::index +trait Access { + /// The Rust type that represents SPIR-V values and types for this access. + /// + /// For operations like loads, this is `Word`. For operations like stores, + /// this is `()`. + /// + /// For `ReadZeroSkipWrite`, this will be the type of the selection + /// construct that performs the bounds checks, so it must implement + /// `MergeTuple`. + type Output: MergeTuple + Copy + Clone; + + /// Write an image access to `block`. + /// + /// Access the texel at `coordinates_id`. The optional `level_id` indicates + /// the level of detail, and `sample_id` is the index of the sample to + /// access in a multisampled texel. + /// + /// This method assumes that `coordinates_id` has already had the image array + /// index, if any, folded in, as done by `write_image_coordinates`. + /// + /// Return the value id produced by the instruction, if any. + /// + /// Use `id_gen` to generate SPIR-V ids as necessary. + fn generate( + &self, + id_gen: &mut IdGenerator, + coordinates_id: Word, + level_id: Option<Word>, + sample_id: Option<Word>, + block: &mut Block, + ) -> Self::Output; + + /// Return the SPIR-V type of the value produced by the code written by + /// `generate`. If the access does not produce a value, `Self::Output` + /// should be `()`. + fn result_type(&self) -> Self::Output; + + /// Construct the SPIR-V 'zero' value to be returned for an out-of-bounds + /// access under the `ReadZeroSkipWrite` policy. If the access does not + /// produce a value, `Self::Output` should be `()`. + fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Self::Output; +} + +/// Texel access information for an [`ImageLoad`] expression. +/// +/// [`ImageLoad`]: crate::Expression::ImageLoad +struct Load { + /// The specific opcode we'll use to perform the fetch. Storage images + /// require `OpImageRead`, while sampled images require `OpImageFetch`. + opcode: spirv::Op, + + /// The type id produced by the actual image access instruction. + type_id: Word, + + /// The id of the image being accessed. + image_id: Word, +} + +impl Load { + fn from_image_expr( + ctx: &mut BlockContext<'_>, + image_id: Word, + image_class: crate::ImageClass, + result_type_id: Word, + ) -> Result<Load, Error> { + let opcode = match image_class { + crate::ImageClass::Storage { .. } => spirv::Op::ImageRead, + crate::ImageClass::Depth { .. } | crate::ImageClass::Sampled { .. } => { + spirv::Op::ImageFetch + } + }; + + // `OpImageRead` and `OpImageFetch` instructions produce vec4<f32> + // values. Most of the time, we can just use `result_type_id` for + // this. The exception is that `Expression::ImageLoad` from a depth + // image produces a scalar `f32`, so in that case we need to find + // the right SPIR-V type for the access instruction here. + let type_id = match image_class { + crate::ImageClass::Depth { .. } => { + ctx.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + scalar: crate::Scalar::F32, + pointer_space: None, + })) + } + _ => result_type_id, + }; + + Ok(Load { + opcode, + type_id, + image_id, + }) + } +} + +impl Access for Load { + type Output = Word; + + /// Write an instruction to access a given texel of this image. + fn generate( + &self, + id_gen: &mut IdGenerator, + coordinates_id: Word, + level_id: Option<Word>, + sample_id: Option<Word>, + block: &mut Block, + ) -> Word { + let texel_id = id_gen.next(); + let mut instruction = Instruction::image_fetch_or_read( + self.opcode, + self.type_id, + texel_id, + self.image_id, + coordinates_id, + ); + + match (level_id, sample_id) { + (None, None) => {} + (Some(level_id), None) => { + instruction.add_operand(spirv::ImageOperands::LOD.bits()); + instruction.add_operand(level_id); + } + (None, Some(sample_id)) => { + instruction.add_operand(spirv::ImageOperands::SAMPLE.bits()); + instruction.add_operand(sample_id); + } + // There's no such thing as a multi-sampled mipmap. + (Some(_), Some(_)) => unreachable!(), + } + + block.body.push(instruction); + + texel_id + } + + fn result_type(&self) -> Word { + self.type_id + } + + fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Word { + ctx.writer.get_constant_null(self.type_id) + } +} + +/// Texel access information for a [`Store`] statement. +/// +/// [`Store`]: crate::Statement::Store +struct Store { + /// The id of the image being written to. + image_id: Word, + + /// The value we're going to write to the texel. + value_id: Word, +} + +impl Access for Store { + /// Stores don't generate any value. + type Output = (); + + fn generate( + &self, + _id_gen: &mut IdGenerator, + coordinates_id: Word, + _level_id: Option<Word>, + _sample_id: Option<Word>, + block: &mut Block, + ) { + block.body.push(Instruction::image_write( + self.image_id, + coordinates_id, + self.value_id, + )); + } + + /// Stores don't generate any value, so this just returns `()`. + fn result_type(&self) {} + + /// Stores don't generate any value, so this just returns `()`. + fn out_of_bounds_value(&self, _ctx: &mut BlockContext<'_>) {} +} + +impl<'w> BlockContext<'w> { + /// Extend image coordinates with an array index, if necessary. + /// + /// Whereas [`Expression::ImageLoad`] and [`ImageSample`] treat the array + /// index as a separate operand from the coordinates, SPIR-V image access + /// instructions include the array index in the `coordinates` operand. This + /// function builds a SPIR-V coordinate vector from a Naga coordinate vector + /// and array index, if one is supplied, and returns a `ImageCoordinates` + /// struct describing what it built. + /// + /// If `array_index` is `Some(expr)`, then this function constructs a new + /// vector that is `coordinates` with `array_index` concatenated onto the + /// end: a `vec2` becomes a `vec3`, a scalar becomes a `vec2`, and so on. + /// + /// If `array_index` is `None`, then the return value uses `coordinates` + /// unchanged. Note that, when indexing a non-arrayed 1D image, this will be + /// a scalar value. + /// + /// If needed, this function generates code to convert the array index, + /// always an integer scalar, to match the component type of `coordinates`. + /// Naga's `ImageLoad` and SPIR-V's `OpImageRead`, `OpImageFetch`, and + /// `OpImageWrite` all use integer coordinates, while Naga's `ImageSample` + /// and SPIR-V's `OpImageSample...` instructions all take floating-point + /// coordinate vectors. + /// + /// [`Expression::ImageLoad`]: crate::Expression::ImageLoad + /// [`ImageSample`]: crate::Expression::ImageSample + fn write_image_coordinates( + &mut self, + coordinates: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + block: &mut Block, + ) -> Result<ImageCoordinates, Error> { + use crate::TypeInner as Ti; + use crate::VectorSize as Vs; + + let coordinates_id = self.cached[coordinates]; + let ty = &self.fun_info[coordinates].ty; + let inner_ty = ty.inner_with(&self.ir_module.types); + + // If there's no array index, the image coordinates are exactly the + // `coordinate` field of the `Expression::ImageLoad`. No work is needed. + let array_index = match array_index { + None => { + let value_id = coordinates_id; + let type_id = self.get_expression_type_id(ty); + let size = match *inner_ty { + Ti::Scalar { .. } => None, + Ti::Vector { size, .. } => Some(size), + _ => return Err(Error::Validation("coordinate type")), + }; + return Ok(ImageCoordinates { + value_id, + type_id, + size, + }); + } + Some(ix) => ix, + }; + + // Find the component type of `coordinates`, and figure out the size the + // combined coordinate vector will have. + let (component_scalar, size) = match *inner_ty { + Ti::Scalar(scalar @ crate::Scalar { width: 4, .. }) => (scalar, Some(Vs::Bi)), + Ti::Vector { + scalar: scalar @ crate::Scalar { width: 4, .. }, + size: Vs::Bi, + } => (scalar, Some(Vs::Tri)), + Ti::Vector { + scalar: scalar @ crate::Scalar { width: 4, .. }, + size: Vs::Tri, + } => (scalar, Some(Vs::Quad)), + Ti::Vector { size: Vs::Quad, .. } => { + return Err(Error::Validation("extending vec4 coordinate")); + } + ref other => { + log::error!("wrong coordinate type {:?}", other); + return Err(Error::Validation("coordinate type")); + } + }; + + // Convert the index to the coordinate component type, if necessary. + let array_index_id = self.cached[array_index]; + let ty = &self.fun_info[array_index].ty; + let inner_ty = ty.inner_with(&self.ir_module.types); + let array_index_scalar = match *inner_ty { + Ti::Scalar( + scalar @ crate::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + width: 4, + }, + ) => scalar, + _ => unreachable!("we only allow i32 and u32"), + }; + let cast = match (component_scalar.kind, array_index_scalar.kind) { + (crate::ScalarKind::Sint, crate::ScalarKind::Sint) + | (crate::ScalarKind::Uint, crate::ScalarKind::Uint) => None, + (crate::ScalarKind::Sint, crate::ScalarKind::Uint) + | (crate::ScalarKind::Uint, crate::ScalarKind::Sint) => Some(spirv::Op::Bitcast), + (crate::ScalarKind::Float, crate::ScalarKind::Sint) => Some(spirv::Op::ConvertSToF), + (crate::ScalarKind::Float, crate::ScalarKind::Uint) => Some(spirv::Op::ConvertUToF), + (crate::ScalarKind::Bool, _) => unreachable!("we don't allow bool for component"), + (_, crate::ScalarKind::Bool | crate::ScalarKind::Float) => { + unreachable!("we don't allow bool or float for array index") + } + (crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat, _) + | (_, crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat) => { + unreachable!("abstract types should never reach backends") + } + }; + let reconciled_array_index_id = if let Some(cast) = cast { + let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: component_scalar, + pointer_space: None, + })); + let reconciled_id = self.gen_id(); + block.body.push(Instruction::unary( + cast, + component_ty_id, + reconciled_id, + array_index_id, + )); + reconciled_id + } else { + array_index_id + }; + + // Find the SPIR-V type for the combined coordinates/index vector. + let type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: size, + scalar: component_scalar, + pointer_space: None, + })); + + // Schmear the coordinates and index together. + let value_id = self.gen_id(); + block.body.push(Instruction::composite_construct( + type_id, + value_id, + &[coordinates_id, reconciled_array_index_id], + )); + Ok(ImageCoordinates { + value_id, + type_id, + size, + }) + } + + pub(super) fn get_handle_id(&mut self, expr_handle: Handle<crate::Expression>) -> Word { + let id = match self.ir_function.expressions[expr_handle] { + crate::Expression::GlobalVariable(handle) => { + self.writer.global_variables[handle.index()].handle_id + } + crate::Expression::FunctionArgument(i) => { + self.function.parameters[i as usize].handle_id + } + crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => { + self.cached[expr_handle] + } + ref other => unreachable!("Unexpected image expression {:?}", other), + }; + + if id == 0 { + unreachable!( + "Image expression {:?} doesn't have a handle ID", + expr_handle + ); + } + + id + } + + /// Generate a vector or scalar 'one' for arithmetic on `coordinates`. + /// + /// If `coordinates` is a scalar, return a scalar one. Otherwise, return + /// a vector of ones. + fn write_coordinate_one(&mut self, coordinates: &ImageCoordinates) -> Result<Word, Error> { + let one = self.get_scope_constant(1); + match coordinates.size { + None => Ok(one), + Some(vector_size) => { + let ones = [one; 4]; + let id = self.gen_id(); + Instruction::constant_composite( + coordinates.type_id, + id, + &ones[..vector_size as usize], + ) + .to_words(&mut self.writer.logical_layout.declarations); + Ok(id) + } + } + } + + /// Generate code to restrict `input` to fall between zero and one less than + /// `size_id`. + /// + /// Both must be 32-bit scalar integer values, whose type is given by + /// `type_id`. The computed value is also of type `type_id`. + fn restrict_scalar( + &mut self, + type_id: Word, + input_id: Word, + size_id: Word, + block: &mut Block, + ) -> Result<Word, Error> { + let i32_one_id = self.get_scope_constant(1); + + // Subtract one from `size` to get the largest valid value. + let limit_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::ISub, + type_id, + limit_id, + size_id, + i32_one_id, + )); + + // Use an unsigned minimum, to handle both positive out-of-range values + // and negative values in a single instruction: negative values of + // `input_id` get treated as very large positive values. + let restricted_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + type_id, + restricted_id, + &[input_id, limit_id], + )); + + Ok(restricted_id) + } + + /// Write instructions to query the size of an image. + /// + /// This takes care of selecting the right instruction depending on whether + /// a level of detail parameter is present. + fn write_coordinate_bounds( + &mut self, + type_id: Word, + image_id: Word, + level_id: Option<Word>, + block: &mut Block, + ) -> Word { + let coordinate_bounds_id = self.gen_id(); + match level_id { + Some(level_id) => { + // A level of detail was provided, so fetch the image size for + // that level. + let mut inst = Instruction::image_query( + spirv::Op::ImageQuerySizeLod, + type_id, + coordinate_bounds_id, + image_id, + ); + inst.add_operand(level_id); + block.body.push(inst); + } + _ => { + // No level of detail was given. + block.body.push(Instruction::image_query( + spirv::Op::ImageQuerySize, + type_id, + coordinate_bounds_id, + image_id, + )); + } + } + + coordinate_bounds_id + } + + /// Write code to restrict coordinates for an image reference. + /// + /// First, clamp the level of detail or sample index to fall within bounds. + /// Then, obtain the image size, possibly using the clamped level of detail. + /// Finally, use an unsigned minimum instruction to force all coordinates + /// into range. + /// + /// Return a triple `(COORDS, LEVEL, SAMPLE)`, where `COORDS` is a coordinate + /// vector (including the array index, if any), `LEVEL` is an optional level + /// of detail, and `SAMPLE` is an optional sample index, all guaranteed to + /// be in-bounds for `image_id`. + /// + /// The result is usually a vector, but it is a scalar when indexing + /// non-arrayed 1D images. + fn write_restricted_coordinates( + &mut self, + image_id: Word, + coordinates: ImageCoordinates, + level_id: Option<Word>, + sample_id: Option<Word>, + block: &mut Block, + ) -> Result<(Word, Option<Word>, Option<Word>), Error> { + self.writer.require_any( + "the `Restrict` image bounds check policy", + &[spirv::Capability::ImageQuery], + )?; + + let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::I32, + pointer_space: None, + })); + + // If `level` is `Some`, clamp it to fall within bounds. This must + // happen first, because we'll use it to query the image size for + // clamping the actual coordinates. + let level_id = level_id + .map(|level_id| { + // Find the number of mipmap levels in this image. + let num_levels_id = self.gen_id(); + block.body.push(Instruction::image_query( + spirv::Op::ImageQueryLevels, + i32_type_id, + num_levels_id, + image_id, + )); + + self.restrict_scalar(i32_type_id, level_id, num_levels_id, block) + }) + .transpose()?; + + // If `sample_id` is `Some`, clamp it to fall within bounds. + let sample_id = sample_id + .map(|sample_id| { + // Find the number of samples per texel. + let num_samples_id = self.gen_id(); + block.body.push(Instruction::image_query( + spirv::Op::ImageQuerySamples, + i32_type_id, + num_samples_id, + image_id, + )); + + self.restrict_scalar(i32_type_id, sample_id, num_samples_id, block) + }) + .transpose()?; + + // Obtain the image bounds, including the array element count. + let coordinate_bounds_id = + self.write_coordinate_bounds(coordinates.type_id, image_id, level_id, block); + + // Compute maximum valid values from the bounds. + let ones = self.write_coordinate_one(&coordinates)?; + let coordinate_limit_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::ISub, + coordinates.type_id, + coordinate_limit_id, + coordinate_bounds_id, + ones, + )); + + // Restrict the coordinates to fall within those bounds. + // + // Use an unsigned minimum, to handle both positive out-of-range values + // and negative values in a single instruction: negative values of + // `coordinates` get treated as very large positive values. + let restricted_coordinates_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + coordinates.type_id, + restricted_coordinates_id, + &[coordinates.value_id, coordinate_limit_id], + )); + + Ok((restricted_coordinates_id, level_id, sample_id)) + } + + fn write_conditional_image_access<A: Access>( + &mut self, + image_id: Word, + coordinates: ImageCoordinates, + level_id: Option<Word>, + sample_id: Option<Word>, + block: &mut Block, + access: &A, + ) -> Result<A::Output, Error> { + self.writer.require_any( + "the `ReadZeroSkipWrite` image bounds check policy", + &[spirv::Capability::ImageQuery], + )?; + + let bool_type_id = self.writer.get_bool_type_id(); + let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::I32, + pointer_space: None, + })); + + let null_id = access.out_of_bounds_value(self); + + let mut selection = Selection::start(block, access.result_type()); + + // If `level_id` is `Some`, check whether it is within bounds. This must + // happen first, because we'll be supplying this as an argument when we + // query the image size. + if let Some(level_id) = level_id { + // Find the number of mipmap levels in this image. + let num_levels_id = self.gen_id(); + selection.block().body.push(Instruction::image_query( + spirv::Op::ImageQueryLevels, + i32_type_id, + num_levels_id, + image_id, + )); + + let lod_cond_id = self.gen_id(); + selection.block().body.push(Instruction::binary( + spirv::Op::ULessThan, + bool_type_id, + lod_cond_id, + level_id, + num_levels_id, + )); + + selection.if_true(self, lod_cond_id, null_id); + } + + // If `sample_id` is `Some`, check whether it is in bounds. + if let Some(sample_id) = sample_id { + // Find the number of samples per texel. + let num_samples_id = self.gen_id(); + selection.block().body.push(Instruction::image_query( + spirv::Op::ImageQuerySamples, + i32_type_id, + num_samples_id, + image_id, + )); + + let samples_cond_id = self.gen_id(); + selection.block().body.push(Instruction::binary( + spirv::Op::ULessThan, + bool_type_id, + samples_cond_id, + sample_id, + num_samples_id, + )); + + selection.if_true(self, samples_cond_id, null_id); + } + + // Obtain the image bounds, including any array element count. + let coordinate_bounds_id = self.write_coordinate_bounds( + coordinates.type_id, + image_id, + level_id, + selection.block(), + ); + + // Compare the coordinates against the bounds. + let coords_bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: coordinates.size, + scalar: crate::Scalar::BOOL, + pointer_space: None, + })); + let coords_conds_id = self.gen_id(); + selection.block().body.push(Instruction::binary( + spirv::Op::ULessThan, + coords_bool_type_id, + coords_conds_id, + coordinates.value_id, + coordinate_bounds_id, + )); + + // If the comparison above was a vector comparison, then we need to + // check that all components of the comparison are true. + let coords_cond_id = if coords_bool_type_id != bool_type_id { + let id = self.gen_id(); + selection.block().body.push(Instruction::relational( + spirv::Op::All, + bool_type_id, + id, + coords_conds_id, + )); + id + } else { + coords_conds_id + }; + + selection.if_true(self, coords_cond_id, null_id); + + // All conditions are met. We can carry out the access. + let texel_id = access.generate( + &mut self.writer.id_gen, + coordinates.value_id, + level_id, + sample_id, + selection.block(), + ); + + // This, then, is the value of the 'true' branch. + Ok(selection.finish(self, texel_id)) + } + + /// Generate code for an `ImageLoad` expression. + /// + /// The arguments are the components of an `Expression::ImageLoad` variant. + #[allow(clippy::too_many_arguments)] + pub(super) fn write_image_load( + &mut self, + result_type_id: Word, + image: Handle<crate::Expression>, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + level: Option<Handle<crate::Expression>>, + sample: Option<Handle<crate::Expression>>, + block: &mut Block, + ) -> Result<Word, Error> { + let image_id = self.get_handle_id(image); + let image_type = self.fun_info[image].ty.inner_with(&self.ir_module.types); + let image_class = match *image_type { + crate::TypeInner::Image { class, .. } => class, + _ => return Err(Error::Validation("image type")), + }; + + let access = Load::from_image_expr(self, image_id, image_class, result_type_id)?; + let coordinates = self.write_image_coordinates(coordinate, array_index, block)?; + + let level_id = level.map(|expr| self.cached[expr]); + let sample_id = sample.map(|expr| self.cached[expr]); + + // Perform the access, according to the bounds check policy. + let access_id = match self.writer.bounds_check_policies.image_load { + crate::proc::BoundsCheckPolicy::Restrict => { + let (coords, level_id, sample_id) = self.write_restricted_coordinates( + image_id, + coordinates, + level_id, + sample_id, + block, + )?; + access.generate(&mut self.writer.id_gen, coords, level_id, sample_id, block) + } + crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => self + .write_conditional_image_access( + image_id, + coordinates, + level_id, + sample_id, + block, + &access, + )?, + crate::proc::BoundsCheckPolicy::Unchecked => access.generate( + &mut self.writer.id_gen, + coordinates.value_id, + level_id, + sample_id, + block, + ), + }; + + // For depth images, `ImageLoad` expressions produce a single f32, + // whereas the SPIR-V instructions always produce a vec4. So we may have + // to pull out the component we need. + let result_id = if result_type_id == access.result_type() { + // The instruction produced the type we expected. We can use + // its result as-is. + access_id + } else { + // For `ImageClass::Depth` images, SPIR-V gave us four components, + // but we only want the first one. + let component_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + component_id, + access_id, + &[0], + )); + component_id + }; + + Ok(result_id) + } + + /// Generate code for an `ImageSample` expression. + /// + /// The arguments are the components of an `Expression::ImageSample` variant. + #[allow(clippy::too_many_arguments)] + pub(super) fn write_image_sample( + &mut self, + result_type_id: Word, + image: Handle<crate::Expression>, + sampler: Handle<crate::Expression>, + gather: Option<crate::SwizzleComponent>, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + offset: Option<Handle<crate::Expression>>, + level: crate::SampleLevel, + depth_ref: Option<Handle<crate::Expression>>, + block: &mut Block, + ) -> Result<Word, Error> { + use super::instructions::SampleLod; + // image + let image_id = self.get_handle_id(image); + let image_type = self.fun_info[image].ty.handle().unwrap(); + // SPIR-V doesn't know about our `Depth` class, and it returns + // `vec4<f32>`, so we need to grab the first component out of it. + let needs_sub_access = match self.ir_module.types[image_type].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Depth { .. }, + .. + } => depth_ref.is_none() && gather.is_none(), + _ => false, + }; + let sample_result_type_id = if needs_sub_access { + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + scalar: crate::Scalar::F32, + pointer_space: None, + })) + } else { + result_type_id + }; + + // OpTypeSampledImage + let image_type_id = self.get_type_id(LookupType::Handle(image_type)); + let sampled_image_type_id = + self.get_type_id(LookupType::Local(LocalType::SampledImage { image_type_id })); + + let sampler_id = self.get_handle_id(sampler); + let coordinates_id = self + .write_image_coordinates(coordinate, array_index, block)? + .value_id; + + let sampled_image_id = self.gen_id(); + block.body.push(Instruction::sampled_image( + sampled_image_type_id, + sampled_image_id, + image_id, + sampler_id, + )); + let id = self.gen_id(); + + let depth_id = depth_ref.map(|handle| self.cached[handle]); + let mut mask = spirv::ImageOperands::empty(); + mask.set(spirv::ImageOperands::CONST_OFFSET, offset.is_some()); + + let mut main_instruction = match (level, gather) { + (_, Some(component)) => { + let component_id = self.get_index_constant(component as u32); + let mut inst = Instruction::image_gather( + sample_result_type_id, + id, + sampled_image_id, + coordinates_id, + component_id, + depth_id, + ); + if !mask.is_empty() { + inst.add_operand(mask.bits()); + } + inst + } + (crate::SampleLevel::Zero, None) => { + let mut inst = Instruction::image_sample( + sample_result_type_id, + id, + SampleLod::Explicit, + sampled_image_id, + coordinates_id, + depth_id, + ); + + let zero_id = self.writer.get_constant_scalar(crate::Literal::F32(0.0)); + + mask |= spirv::ImageOperands::LOD; + inst.add_operand(mask.bits()); + inst.add_operand(zero_id); + + inst + } + (crate::SampleLevel::Auto, None) => { + let mut inst = Instruction::image_sample( + sample_result_type_id, + id, + SampleLod::Implicit, + sampled_image_id, + coordinates_id, + depth_id, + ); + if !mask.is_empty() { + inst.add_operand(mask.bits()); + } + inst + } + (crate::SampleLevel::Exact(lod_handle), None) => { + let mut inst = Instruction::image_sample( + sample_result_type_id, + id, + SampleLod::Explicit, + sampled_image_id, + coordinates_id, + depth_id, + ); + + let lod_id = self.cached[lod_handle]; + mask |= spirv::ImageOperands::LOD; + inst.add_operand(mask.bits()); + inst.add_operand(lod_id); + + inst + } + (crate::SampleLevel::Bias(bias_handle), None) => { + let mut inst = Instruction::image_sample( + sample_result_type_id, + id, + SampleLod::Implicit, + sampled_image_id, + coordinates_id, + depth_id, + ); + + let bias_id = self.cached[bias_handle]; + mask |= spirv::ImageOperands::BIAS; + inst.add_operand(mask.bits()); + inst.add_operand(bias_id); + + inst + } + (crate::SampleLevel::Gradient { x, y }, None) => { + let mut inst = Instruction::image_sample( + sample_result_type_id, + id, + SampleLod::Explicit, + sampled_image_id, + coordinates_id, + depth_id, + ); + + let x_id = self.cached[x]; + let y_id = self.cached[y]; + mask |= spirv::ImageOperands::GRAD; + inst.add_operand(mask.bits()); + inst.add_operand(x_id); + inst.add_operand(y_id); + + inst + } + }; + + if let Some(offset_const) = offset { + let offset_id = self.writer.constant_ids[offset_const.index()]; + main_instruction.add_operand(offset_id); + } + + block.body.push(main_instruction); + + let id = if needs_sub_access { + let sub_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + sub_id, + id, + &[0], + )); + sub_id + } else { + id + }; + + Ok(id) + } + + /// Generate code for an `ImageQuery` expression. + /// + /// The arguments are the components of an `Expression::ImageQuery` variant. + pub(super) fn write_image_query( + &mut self, + result_type_id: Word, + image: Handle<crate::Expression>, + query: crate::ImageQuery, + block: &mut Block, + ) -> Result<Word, Error> { + use crate::{ImageClass as Ic, ImageDimension as Id, ImageQuery as Iq}; + + let image_id = self.get_handle_id(image); + let image_type = self.fun_info[image].ty.handle().unwrap(); + let (dim, arrayed, class) = match self.ir_module.types[image_type].inner { + crate::TypeInner::Image { + dim, + arrayed, + class, + } => (dim, arrayed, class), + _ => { + return Err(Error::Validation("image type")); + } + }; + + self.writer + .require_any("image queries", &[spirv::Capability::ImageQuery])?; + + let id = match query { + Iq::Size { level } => { + let dim_coords = match dim { + Id::D1 => 1, + Id::D2 | Id::Cube => 2, + Id::D3 => 3, + }; + let array_coords = usize::from(arrayed); + let vector_size = match dim_coords + array_coords { + 2 => Some(crate::VectorSize::Bi), + 3 => Some(crate::VectorSize::Tri), + 4 => Some(crate::VectorSize::Quad), + _ => None, + }; + let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size, + scalar: crate::Scalar::U32, + pointer_space: None, + })); + + let (query_op, level_id) = match class { + Ic::Sampled { multi: true, .. } + | Ic::Depth { multi: true } + | Ic::Storage { .. } => (spirv::Op::ImageQuerySize, None), + _ => { + let level_id = match level { + Some(expr) => self.cached[expr], + None => self.get_index_constant(0), + }; + (spirv::Op::ImageQuerySizeLod, Some(level_id)) + } + }; + + // The ID of the vector returned by SPIR-V, which contains the dimensions + // as well as the layer count. + let id_extended = self.gen_id(); + let mut inst = Instruction::image_query( + query_op, + extended_size_type_id, + id_extended, + image_id, + ); + if let Some(expr_id) = level_id { + inst.add_operand(expr_id); + } + block.body.push(inst); + + if result_type_id != extended_size_type_id { + let id = self.gen_id(); + let components = match dim { + // always pick the first component, and duplicate it for all 3 dimensions + Id::Cube => &[0u32, 0][..], + _ => &[0u32, 1, 2, 3][..dim_coords], + }; + block.body.push(Instruction::vector_shuffle( + result_type_id, + id, + id_extended, + id_extended, + components, + )); + + id + } else { + id_extended + } + } + Iq::NumLevels => { + let query_id = self.gen_id(); + block.body.push(Instruction::image_query( + spirv::Op::ImageQueryLevels, + result_type_id, + query_id, + image_id, + )); + + query_id + } + Iq::NumLayers => { + let vec_size = match dim { + Id::D1 => crate::VectorSize::Bi, + Id::D2 | Id::Cube => crate::VectorSize::Tri, + Id::D3 => crate::VectorSize::Quad, + }; + let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(vec_size), + scalar: crate::Scalar::U32, + pointer_space: None, + })); + let id_extended = self.gen_id(); + let mut inst = Instruction::image_query( + spirv::Op::ImageQuerySizeLod, + extended_size_type_id, + id_extended, + image_id, + ); + inst.add_operand(self.get_index_constant(0)); + block.body.push(inst); + + let extract_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + extract_id, + id_extended, + &[vec_size as u32 - 1], + )); + + extract_id + } + Iq::NumSamples => { + let query_id = self.gen_id(); + block.body.push(Instruction::image_query( + spirv::Op::ImageQuerySamples, + result_type_id, + query_id, + image_id, + )); + + query_id + } + }; + + Ok(id) + } + + pub(super) fn write_image_store( + &mut self, + image: Handle<crate::Expression>, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + value: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<(), Error> { + let image_id = self.get_handle_id(image); + let coordinates = self.write_image_coordinates(coordinate, array_index, block)?; + let value_id = self.cached[value]; + + let write = Store { image_id, value_id }; + + match *self.fun_info[image].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Image { + class: + crate::ImageClass::Storage { + format: crate::StorageFormat::Bgra8Unorm, + .. + }, + .. + } => self.writer.require_any( + "Bgra8Unorm storage write", + &[spirv::Capability::StorageImageWriteWithoutFormat], + )?, + _ => {} + } + + match self.writer.bounds_check_policies.image_store { + crate::proc::BoundsCheckPolicy::Restrict => { + let (coords, _, _) = + self.write_restricted_coordinates(image_id, coordinates, None, None, block)?; + write.generate(&mut self.writer.id_gen, coords, None, None, block); + } + crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => { + self.write_conditional_image_access( + image_id, + coordinates, + None, + None, + block, + &write, + )?; + } + crate::proc::BoundsCheckPolicy::Unchecked => { + write.generate( + &mut self.writer.id_gen, + coordinates.value_id, + None, + None, + block, + ); + } + } + + Ok(()) + } +} diff --git a/third_party/rust/naga/src/back/spv/index.rs b/third_party/rust/naga/src/back/spv/index.rs new file mode 100644 index 0000000000..92e0f88d9a --- /dev/null +++ b/third_party/rust/naga/src/back/spv/index.rs @@ -0,0 +1,421 @@ +/*! +Bounds-checking for SPIR-V output. +*/ + +use super::{ + helpers::global_needs_wrapper, selection::Selection, Block, BlockContext, Error, IdGenerator, + Instruction, Word, +}; +use crate::{arena::Handle, proc::BoundsCheckPolicy}; + +/// The results of performing a bounds check. +/// +/// On success, `write_bounds_check` returns a value of this type. +pub(super) enum BoundsCheckResult { + /// The index is statically known and in bounds, with the given value. + KnownInBounds(u32), + + /// The given instruction computes the index to be used. + Computed(Word), + + /// The given instruction computes a boolean condition which is true + /// if the index is in bounds. + Conditional(Word), +} + +/// A value that we either know at translation time, or need to compute at runtime. +pub(super) enum MaybeKnown<T> { + /// The value is known at shader translation time. + Known(T), + + /// The value is computed by the instruction with the given id. + Computed(Word), +} + +impl<'w> BlockContext<'w> { + /// Emit code to compute the length of a run-time array. + /// + /// Given `array`, an expression referring a runtime-sized array, return the + /// instruction id for the array's length. + pub(super) fn write_runtime_array_length( + &mut self, + array: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<Word, Error> { + // Naga IR permits runtime-sized arrays as global variables or as the + // final member of a struct that is a global variable. SPIR-V permits + // only the latter, so this back end wraps bare runtime-sized arrays + // in a made-up struct; see `helpers::global_needs_wrapper` and its uses. + // This code must handle both cases. + let (structure_id, last_member_index) = match self.ir_function.expressions[array] { + crate::Expression::AccessIndex { base, index } => { + match self.ir_function.expressions[base] { + crate::Expression::GlobalVariable(handle) => ( + self.writer.global_variables[handle.index()].access_id, + index, + ), + _ => return Err(Error::Validation("array length expression")), + } + } + crate::Expression::GlobalVariable(handle) => { + let global = &self.ir_module.global_variables[handle]; + if !global_needs_wrapper(self.ir_module, global) { + return Err(Error::Validation("array length expression")); + } + + (self.writer.global_variables[handle.index()].var_id, 0) + } + _ => return Err(Error::Validation("array length expression")), + }; + + let length_id = self.gen_id(); + block.body.push(Instruction::array_length( + self.writer.get_uint_type_id(), + length_id, + structure_id, + last_member_index, + )); + + Ok(length_id) + } + + /// Compute the length of a subscriptable value. + /// + /// Given `sequence`, an expression referring to some indexable type, return + /// its length. The result may either be computed by SPIR-V instructions, or + /// known at shader translation time. + /// + /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any + /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically + /// sized, or use a specializable constant as its length. + fn write_sequence_length( + &mut self, + sequence: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<MaybeKnown<u32>, Error> { + let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types); + match sequence_ty.indexable_length(self.ir_module) { + Ok(crate::proc::IndexableLength::Known(known_length)) => { + Ok(MaybeKnown::Known(known_length)) + } + Ok(crate::proc::IndexableLength::Dynamic) => { + let length_id = self.write_runtime_array_length(sequence, block)?; + Ok(MaybeKnown::Computed(length_id)) + } + Err(err) => { + log::error!("Sequence length for {:?} failed: {}", sequence, err); + Err(Error::Validation("indexable length")) + } + } + } + + /// Compute the maximum valid index of a subscriptable value. + /// + /// Given `sequence`, an expression referring to some indexable type, return + /// its maximum valid index - one less than its length. The result may + /// either be computed, or known at shader translation time. + /// + /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any + /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically + /// sized, or use a specializable constant as its length. + fn write_sequence_max_index( + &mut self, + sequence: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<MaybeKnown<u32>, Error> { + match self.write_sequence_length(sequence, block)? { + MaybeKnown::Known(known_length) => { + // We should have thrown out all attempts to subscript zero-length + // sequences during validation, so the following subtraction should never + // underflow. + assert!(known_length > 0); + // Compute the max index from the length now. + Ok(MaybeKnown::Known(known_length - 1)) + } + MaybeKnown::Computed(length_id) => { + // Emit code to compute the max index from the length. + let const_one_id = self.get_index_constant(1); + let max_index_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::ISub, + self.writer.get_uint_type_id(), + max_index_id, + length_id, + const_one_id, + )); + Ok(MaybeKnown::Computed(max_index_id)) + } + } + } + + /// Restrict an index to be in range for a vector, matrix, or array. + /// + /// This is used to implement `BoundsCheckPolicy::Restrict`. An in-bounds + /// index is left unchanged. An out-of-bounds index is replaced with some + /// arbitrary in-bounds index. Note,this is not necessarily clamping; for + /// example, negative indices might be changed to refer to the last element + /// of the sequence, not the first, as clamping would do. + /// + /// Either return the restricted index value, if known, or add instructions + /// to `block` to compute it, and return the id of the result. See the + /// documentation for `BoundsCheckResult` for details. + /// + /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a + /// `Pointer` to any of those, or a `ValuePointer`. An array may be + /// fixed-size, dynamically sized, or use a specializable constant as its + /// length. + pub(super) fn write_restricted_index( + &mut self, + sequence: Handle<crate::Expression>, + index: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<BoundsCheckResult, Error> { + let index_id = self.cached[index]; + + // Get the sequence's maximum valid index. Return early if we've already + // done the bounds check. + let max_index_id = match self.write_sequence_max_index(sequence, block)? { + MaybeKnown::Known(known_max_index) => { + if let Ok(known_index) = self + .ir_module + .to_ctx() + .eval_expr_to_u32_from(index, &self.ir_function.expressions) + { + // Both the index and length are known at compile time. + // + // In strict WGSL compliance mode, out-of-bounds indices cannot be + // reported at shader translation time, and must be replaced with + // in-bounds indices at run time. So we cannot assume that + // validation ensured the index was in bounds. Restrict now. + let restricted = std::cmp::min(known_index, known_max_index); + return Ok(BoundsCheckResult::KnownInBounds(restricted)); + } + + self.get_index_constant(known_max_index) + } + MaybeKnown::Computed(max_index_id) => max_index_id, + }; + + // One or the other of the index or length is dynamic, so emit code for + // BoundsCheckPolicy::Restrict. + let restricted_index_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + self.writer.get_uint_type_id(), + restricted_index_id, + &[index_id, max_index_id], + )); + Ok(BoundsCheckResult::Computed(restricted_index_id)) + } + + /// Write an index bounds comparison to `block`, if needed. + /// + /// If we're able to determine statically that `index` is in bounds for + /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual + /// value of the index. (In principle, one could know that the index is in + /// bounds without knowing its specific value, but in our simple-minded + /// situation, we always know it.) + /// + /// If instead we must generate code to perform the comparison at run time, + /// return `Conditional(comparison_id)`, where `comparison_id` is an + /// instruction producing a boolean value that is true if `index` is in + /// bounds for `sequence`. + /// + /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a + /// `Pointer` to any of those, or a `ValuePointer`. An array may be + /// fixed-size, dynamically sized, or use a specializable constant as its + /// length. + fn write_index_comparison( + &mut self, + sequence: Handle<crate::Expression>, + index: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<BoundsCheckResult, Error> { + let index_id = self.cached[index]; + + // Get the sequence's length. Return early if we've already done the + // bounds check. + let length_id = match self.write_sequence_length(sequence, block)? { + MaybeKnown::Known(known_length) => { + if let Ok(known_index) = self + .ir_module + .to_ctx() + .eval_expr_to_u32_from(index, &self.ir_function.expressions) + { + // Both the index and length are known at compile time. + // + // It would be nice to assume that, since we are using the + // `ReadZeroSkipWrite` policy, we are not in strict WGSL + // compliance mode, and thus we can count on the validator to have + // rejected any programs with known out-of-bounds indices, and + // thus just return `KnownInBounds` here without actually + // checking. + // + // But it's also reasonable to expect that bounds check policies + // and error reporting policies should be able to vary + // independently without introducing security holes. So, we should + // support the case where bad indices do not cause validation + // errors, and are handled via `ReadZeroSkipWrite`. + // + // In theory, when `known_index` is bad, we could return a new + // `KnownOutOfBounds` variant here. But it's simpler just to fall + // through and let the bounds check take place. The shader is + // broken anyway, so it doesn't make sense to invest in emitting + // the ideal code for it. + if known_index < known_length { + return Ok(BoundsCheckResult::KnownInBounds(known_index)); + } + } + + self.get_index_constant(known_length) + } + MaybeKnown::Computed(length_id) => length_id, + }; + + // Compare the index against the length. + let condition_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::ULessThan, + self.writer.get_bool_type_id(), + condition_id, + index_id, + length_id, + )); + + // Indicate that we did generate the check. + Ok(BoundsCheckResult::Conditional(condition_id)) + } + + /// Emit a conditional load for `BoundsCheckPolicy::ReadZeroSkipWrite`. + /// + /// Generate code to load a value of `result_type` if `condition` is true, + /// and generate a null value of that type if it is false. Call `emit_load` + /// to emit the instructions to perform the load. Return the id of the + /// merged value of the two branches. + pub(super) fn write_conditional_indexed_load<F>( + &mut self, + result_type: Word, + condition: Word, + block: &mut Block, + emit_load: F, + ) -> Word + where + F: FnOnce(&mut IdGenerator, &mut Block) -> Word, + { + // For the out-of-bounds case, we produce a zero value. + let null_id = self.writer.get_constant_null(result_type); + + let mut selection = Selection::start(block, result_type); + + // As it turns out, we don't actually need a full 'if-then-else' + // structure for this: SPIR-V constants are declared up front, so the + // 'else' block would have no instructions. Instead we emit something + // like this: + // + // result = zero; + // if in_bounds { + // result = do the load; + // } + // use result; + + // Continue only if the index was in bounds. Otherwise, branch to the + // merge block. + selection.if_true(self, condition, null_id); + + // The in-bounds path. Perform the access and the load. + let loaded_value = emit_load(&mut self.writer.id_gen, selection.block()); + + selection.finish(self, loaded_value) + } + + /// Emit code for bounds checks for an array, vector, or matrix access. + /// + /// This implements either `index_bounds_check_policy` or + /// `buffer_bounds_check_policy`, depending on the address space of the + /// pointer being accessed. + /// + /// Return a `BoundsCheckResult` indicating how the index should be + /// consumed. See that type's documentation for details. + pub(super) fn write_bounds_check( + &mut self, + base: Handle<crate::Expression>, + index: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<BoundsCheckResult, Error> { + let policy = self.writer.bounds_check_policies.choose_policy( + base, + &self.ir_module.types, + self.fun_info, + ); + + Ok(match policy { + BoundsCheckPolicy::Restrict => self.write_restricted_index(base, index, block)?, + BoundsCheckPolicy::ReadZeroSkipWrite => { + self.write_index_comparison(base, index, block)? + } + BoundsCheckPolicy::Unchecked => BoundsCheckResult::Computed(self.cached[index]), + }) + } + + /// Emit code to subscript a vector by value with a computed index. + /// + /// Return the id of the element value. + pub(super) fn write_vector_access( + &mut self, + expr_handle: Handle<crate::Expression>, + base: Handle<crate::Expression>, + index: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<Word, Error> { + let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty); + + let base_id = self.cached[base]; + let index_id = self.cached[index]; + + let result_id = match self.write_bounds_check(base, index, block)? { + BoundsCheckResult::KnownInBounds(known_index) => { + let result_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + result_id, + base_id, + &[known_index], + )); + result_id + } + BoundsCheckResult::Computed(computed_index_id) => { + let result_id = self.gen_id(); + block.body.push(Instruction::vector_extract_dynamic( + result_type_id, + result_id, + base_id, + computed_index_id, + )); + result_id + } + BoundsCheckResult::Conditional(comparison_id) => { + // Run-time bounds checks were required. Emit + // conditional load. + self.write_conditional_indexed_load( + result_type_id, + comparison_id, + block, + |id_gen, block| { + // The in-bounds path. Generate the access. + let element_id = id_gen.next(); + block.body.push(Instruction::vector_extract_dynamic( + result_type_id, + element_id, + base_id, + index_id, + )); + element_id + }, + ) + } + }; + + Ok(result_id) + } +} diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs new file mode 100644 index 0000000000..b963793ad3 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/instructions.rs @@ -0,0 +1,1100 @@ +use super::{block::DebugInfoInner, helpers}; +use spirv::{Op, Word}; + +pub(super) enum Signedness { + Unsigned = 0, + Signed = 1, +} + +pub(super) enum SampleLod { + Explicit, + Implicit, +} + +pub(super) struct Case { + pub value: Word, + pub label_id: Word, +} + +impl super::Instruction { + // + // Debug Instructions + // + + pub(super) fn string(name: &str, id: Word) -> Self { + let mut instruction = Self::new(Op::String); + instruction.set_result(id); + instruction.add_operands(helpers::string_to_words(name)); + instruction + } + + pub(super) fn source( + source_language: spirv::SourceLanguage, + version: u32, + source: &Option<DebugInfoInner>, + ) -> Self { + let mut instruction = Self::new(Op::Source); + instruction.add_operand(source_language as u32); + instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes())); + if let Some(source) = source.as_ref() { + instruction.add_operand(source.source_file_id); + instruction.add_operands(helpers::string_to_words(source.source_code)); + } + instruction + } + + pub(super) fn name(target_id: Word, name: &str) -> Self { + let mut instruction = Self::new(Op::Name); + instruction.add_operand(target_id); + instruction.add_operands(helpers::string_to_words(name)); + instruction + } + + pub(super) fn member_name(target_id: Word, member: Word, name: &str) -> Self { + let mut instruction = Self::new(Op::MemberName); + instruction.add_operand(target_id); + instruction.add_operand(member); + instruction.add_operands(helpers::string_to_words(name)); + instruction + } + + pub(super) fn line(file: Word, line: Word, column: Word) -> Self { + let mut instruction = Self::new(Op::Line); + instruction.add_operand(file); + instruction.add_operand(line); + instruction.add_operand(column); + instruction + } + + pub(super) const fn no_line() -> Self { + Self::new(Op::NoLine) + } + + // + // Annotation Instructions + // + + pub(super) fn decorate( + target_id: Word, + decoration: spirv::Decoration, + operands: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::Decorate); + instruction.add_operand(target_id); + instruction.add_operand(decoration as u32); + for operand in operands { + instruction.add_operand(*operand) + } + instruction + } + + pub(super) fn member_decorate( + target_id: Word, + member_index: Word, + decoration: spirv::Decoration, + operands: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::MemberDecorate); + instruction.add_operand(target_id); + instruction.add_operand(member_index); + instruction.add_operand(decoration as u32); + for operand in operands { + instruction.add_operand(*operand) + } + instruction + } + + // + // Extension Instructions + // + + pub(super) fn extension(name: &str) -> Self { + let mut instruction = Self::new(Op::Extension); + instruction.add_operands(helpers::string_to_words(name)); + instruction + } + + pub(super) fn ext_inst_import(id: Word, name: &str) -> Self { + let mut instruction = Self::new(Op::ExtInstImport); + instruction.set_result(id); + instruction.add_operands(helpers::string_to_words(name)); + instruction + } + + pub(super) fn ext_inst( + set_id: Word, + op: spirv::GLOp, + result_type_id: Word, + id: Word, + operands: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::ExtInst); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(set_id); + instruction.add_operand(op as u32); + for operand in operands { + instruction.add_operand(*operand) + } + instruction + } + + // + // Mode-Setting Instructions + // + + pub(super) fn memory_model( + addressing_model: spirv::AddressingModel, + memory_model: spirv::MemoryModel, + ) -> Self { + let mut instruction = Self::new(Op::MemoryModel); + instruction.add_operand(addressing_model as u32); + instruction.add_operand(memory_model as u32); + instruction + } + + pub(super) fn entry_point( + execution_model: spirv::ExecutionModel, + entry_point_id: Word, + name: &str, + interface_ids: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::EntryPoint); + instruction.add_operand(execution_model as u32); + instruction.add_operand(entry_point_id); + instruction.add_operands(helpers::string_to_words(name)); + + for interface_id in interface_ids { + instruction.add_operand(*interface_id); + } + + instruction + } + + pub(super) fn execution_mode( + entry_point_id: Word, + execution_mode: spirv::ExecutionMode, + args: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::ExecutionMode); + instruction.add_operand(entry_point_id); + instruction.add_operand(execution_mode as u32); + for arg in args { + instruction.add_operand(*arg); + } + instruction + } + + pub(super) fn capability(capability: spirv::Capability) -> Self { + let mut instruction = Self::new(Op::Capability); + instruction.add_operand(capability as u32); + instruction + } + + // + // Type-Declaration Instructions + // + + pub(super) fn type_void(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeVoid); + instruction.set_result(id); + instruction + } + + pub(super) fn type_bool(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeBool); + instruction.set_result(id); + instruction + } + + pub(super) fn type_int(id: Word, width: Word, signedness: Signedness) -> Self { + let mut instruction = Self::new(Op::TypeInt); + instruction.set_result(id); + instruction.add_operand(width); + instruction.add_operand(signedness as u32); + instruction + } + + pub(super) fn type_float(id: Word, width: Word) -> Self { + let mut instruction = Self::new(Op::TypeFloat); + instruction.set_result(id); + instruction.add_operand(width); + instruction + } + + pub(super) fn type_vector( + id: Word, + component_type_id: Word, + component_count: crate::VectorSize, + ) -> Self { + let mut instruction = Self::new(Op::TypeVector); + instruction.set_result(id); + instruction.add_operand(component_type_id); + instruction.add_operand(component_count as u32); + instruction + } + + pub(super) fn type_matrix( + id: Word, + column_type_id: Word, + column_count: crate::VectorSize, + ) -> Self { + let mut instruction = Self::new(Op::TypeMatrix); + instruction.set_result(id); + instruction.add_operand(column_type_id); + instruction.add_operand(column_count as u32); + instruction + } + + #[allow(clippy::too_many_arguments)] + pub(super) fn type_image( + id: Word, + sampled_type_id: Word, + dim: spirv::Dim, + flags: super::ImageTypeFlags, + image_format: spirv::ImageFormat, + ) -> Self { + let mut instruction = Self::new(Op::TypeImage); + instruction.set_result(id); + instruction.add_operand(sampled_type_id); + instruction.add_operand(dim as u32); + instruction.add_operand(flags.contains(super::ImageTypeFlags::DEPTH) as u32); + instruction.add_operand(flags.contains(super::ImageTypeFlags::ARRAYED) as u32); + instruction.add_operand(flags.contains(super::ImageTypeFlags::MULTISAMPLED) as u32); + instruction.add_operand(if flags.contains(super::ImageTypeFlags::SAMPLED) { + 1 + } else { + 2 + }); + instruction.add_operand(image_format as u32); + instruction + } + + pub(super) fn type_sampler(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeSampler); + instruction.set_result(id); + instruction + } + + pub(super) fn type_acceleration_structure(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeAccelerationStructureKHR); + instruction.set_result(id); + instruction + } + + pub(super) fn type_ray_query(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeRayQueryKHR); + instruction.set_result(id); + instruction + } + + pub(super) fn type_sampled_image(id: Word, image_type_id: Word) -> Self { + let mut instruction = Self::new(Op::TypeSampledImage); + instruction.set_result(id); + instruction.add_operand(image_type_id); + instruction + } + + pub(super) fn type_array(id: Word, element_type_id: Word, length_id: Word) -> Self { + let mut instruction = Self::new(Op::TypeArray); + instruction.set_result(id); + instruction.add_operand(element_type_id); + instruction.add_operand(length_id); + instruction + } + + pub(super) fn type_runtime_array(id: Word, element_type_id: Word) -> Self { + let mut instruction = Self::new(Op::TypeRuntimeArray); + instruction.set_result(id); + instruction.add_operand(element_type_id); + instruction + } + + pub(super) fn type_struct(id: Word, member_ids: &[Word]) -> Self { + let mut instruction = Self::new(Op::TypeStruct); + instruction.set_result(id); + + for member_id in member_ids { + instruction.add_operand(*member_id) + } + + instruction + } + + pub(super) fn type_pointer( + id: Word, + storage_class: spirv::StorageClass, + type_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::TypePointer); + instruction.set_result(id); + instruction.add_operand(storage_class as u32); + instruction.add_operand(type_id); + instruction + } + + pub(super) fn type_function(id: Word, return_type_id: Word, parameter_ids: &[Word]) -> Self { + let mut instruction = Self::new(Op::TypeFunction); + instruction.set_result(id); + instruction.add_operand(return_type_id); + + for parameter_id in parameter_ids { + instruction.add_operand(*parameter_id); + } + + instruction + } + + // + // Constant-Creation Instructions + // + + pub(super) fn constant_null(result_type_id: Word, id: Word) -> Self { + let mut instruction = Self::new(Op::ConstantNull); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction + } + + pub(super) fn constant_true(result_type_id: Word, id: Word) -> Self { + let mut instruction = Self::new(Op::ConstantTrue); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction + } + + pub(super) fn constant_false(result_type_id: Word, id: Word) -> Self { + let mut instruction = Self::new(Op::ConstantFalse); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction + } + + pub(super) fn constant_32bit(result_type_id: Word, id: Word, value: Word) -> Self { + Self::constant(result_type_id, id, &[value]) + } + + pub(super) fn constant_64bit(result_type_id: Word, id: Word, low: Word, high: Word) -> Self { + Self::constant(result_type_id, id, &[low, high]) + } + + pub(super) fn constant(result_type_id: Word, id: Word, values: &[Word]) -> Self { + let mut instruction = Self::new(Op::Constant); + instruction.set_type(result_type_id); + instruction.set_result(id); + + for value in values { + instruction.add_operand(*value); + } + + instruction + } + + pub(super) fn constant_composite( + result_type_id: Word, + id: Word, + constituent_ids: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::ConstantComposite); + instruction.set_type(result_type_id); + instruction.set_result(id); + + for constituent_id in constituent_ids { + instruction.add_operand(*constituent_id); + } + + instruction + } + + // + // Memory Instructions + // + + pub(super) fn variable( + result_type_id: Word, + id: Word, + storage_class: spirv::StorageClass, + initializer_id: Option<Word>, + ) -> Self { + let mut instruction = Self::new(Op::Variable); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(storage_class as u32); + + if let Some(initializer_id) = initializer_id { + instruction.add_operand(initializer_id); + } + + instruction + } + + pub(super) fn load( + result_type_id: Word, + id: Word, + pointer_id: Word, + memory_access: Option<spirv::MemoryAccess>, + ) -> Self { + let mut instruction = Self::new(Op::Load); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(pointer_id); + + if let Some(memory_access) = memory_access { + instruction.add_operand(memory_access.bits()); + } + + instruction + } + + pub(super) fn atomic_load( + result_type_id: Word, + id: Word, + pointer_id: Word, + scope_id: Word, + semantics_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::AtomicLoad); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(pointer_id); + instruction.add_operand(scope_id); + instruction.add_operand(semantics_id); + instruction + } + + pub(super) fn store( + pointer_id: Word, + value_id: Word, + memory_access: Option<spirv::MemoryAccess>, + ) -> Self { + let mut instruction = Self::new(Op::Store); + instruction.add_operand(pointer_id); + instruction.add_operand(value_id); + + if let Some(memory_access) = memory_access { + instruction.add_operand(memory_access.bits()); + } + + instruction + } + + pub(super) fn atomic_store( + pointer_id: Word, + scope_id: Word, + semantics_id: Word, + value_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::AtomicStore); + instruction.add_operand(pointer_id); + instruction.add_operand(scope_id); + instruction.add_operand(semantics_id); + instruction.add_operand(value_id); + instruction + } + + pub(super) fn access_chain( + result_type_id: Word, + id: Word, + base_id: Word, + index_ids: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::AccessChain); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(base_id); + + for index_id in index_ids { + instruction.add_operand(*index_id); + } + + instruction + } + + pub(super) fn array_length( + result_type_id: Word, + id: Word, + structure_id: Word, + array_member: Word, + ) -> Self { + let mut instruction = Self::new(Op::ArrayLength); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(structure_id); + instruction.add_operand(array_member); + instruction + } + + // + // Function Instructions + // + + pub(super) fn function( + return_type_id: Word, + id: Word, + function_control: spirv::FunctionControl, + function_type_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::Function); + instruction.set_type(return_type_id); + instruction.set_result(id); + instruction.add_operand(function_control.bits()); + instruction.add_operand(function_type_id); + instruction + } + + pub(super) fn function_parameter(result_type_id: Word, id: Word) -> Self { + let mut instruction = Self::new(Op::FunctionParameter); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction + } + + pub(super) const fn function_end() -> Self { + Self::new(Op::FunctionEnd) + } + + pub(super) fn function_call( + result_type_id: Word, + id: Word, + function_id: Word, + argument_ids: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::FunctionCall); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(function_id); + + for argument_id in argument_ids { + instruction.add_operand(*argument_id); + } + + instruction + } + + // + // Image Instructions + // + + pub(super) fn sampled_image( + result_type_id: Word, + id: Word, + image: Word, + sampler: Word, + ) -> Self { + let mut instruction = Self::new(Op::SampledImage); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(image); + instruction.add_operand(sampler); + instruction + } + + pub(super) fn image_sample( + result_type_id: Word, + id: Word, + lod: SampleLod, + sampled_image: Word, + coordinates: Word, + depth_ref: Option<Word>, + ) -> Self { + let op = match (lod, depth_ref) { + (SampleLod::Explicit, None) => Op::ImageSampleExplicitLod, + (SampleLod::Implicit, None) => Op::ImageSampleImplicitLod, + (SampleLod::Explicit, Some(_)) => Op::ImageSampleDrefExplicitLod, + (SampleLod::Implicit, Some(_)) => Op::ImageSampleDrefImplicitLod, + }; + + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(sampled_image); + instruction.add_operand(coordinates); + if let Some(dref) = depth_ref { + instruction.add_operand(dref); + } + + instruction + } + + pub(super) fn image_gather( + result_type_id: Word, + id: Word, + sampled_image: Word, + coordinates: Word, + component_id: Word, + depth_ref: Option<Word>, + ) -> Self { + let op = match depth_ref { + None => Op::ImageGather, + Some(_) => Op::ImageDrefGather, + }; + + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(sampled_image); + instruction.add_operand(coordinates); + if let Some(dref) = depth_ref { + instruction.add_operand(dref); + } else { + instruction.add_operand(component_id); + } + + instruction + } + + pub(super) fn image_fetch_or_read( + op: Op, + result_type_id: Word, + id: Word, + image: Word, + coordinates: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(image); + instruction.add_operand(coordinates); + instruction + } + + pub(super) fn image_write(image: Word, coordinates: Word, value: Word) -> Self { + let mut instruction = Self::new(Op::ImageWrite); + instruction.add_operand(image); + instruction.add_operand(coordinates); + instruction.add_operand(value); + instruction + } + + pub(super) fn image_query(op: Op, result_type_id: Word, id: Word, image: Word) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(image); + instruction + } + + // + // Ray Query Instructions + // + #[allow(clippy::too_many_arguments)] + pub(super) fn ray_query_initialize( + query: Word, + acceleration_structure: Word, + ray_flags: Word, + cull_mask: Word, + ray_origin: Word, + ray_tmin: Word, + ray_dir: Word, + ray_tmax: Word, + ) -> Self { + let mut instruction = Self::new(Op::RayQueryInitializeKHR); + instruction.add_operand(query); + instruction.add_operand(acceleration_structure); + instruction.add_operand(ray_flags); + instruction.add_operand(cull_mask); + instruction.add_operand(ray_origin); + instruction.add_operand(ray_tmin); + instruction.add_operand(ray_dir); + instruction.add_operand(ray_tmax); + instruction + } + + pub(super) fn ray_query_proceed(result_type_id: Word, id: Word, query: Word) -> Self { + let mut instruction = Self::new(Op::RayQueryProceedKHR); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(query); + instruction + } + + pub(super) fn ray_query_get_intersection( + op: Op, + result_type_id: Word, + id: Word, + query: Word, + intersection: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(query); + instruction.add_operand(intersection); + instruction + } + + // + // Conversion Instructions + // + pub(super) fn unary(op: Op, result_type_id: Word, id: Word, value: Word) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(value); + instruction + } + + // + // Composite Instructions + // + + pub(super) fn composite_construct( + result_type_id: Word, + id: Word, + constituent_ids: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::CompositeConstruct); + instruction.set_type(result_type_id); + instruction.set_result(id); + + for constituent_id in constituent_ids { + instruction.add_operand(*constituent_id); + } + + instruction + } + + pub(super) fn composite_extract( + result_type_id: Word, + id: Word, + composite_id: Word, + indices: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::CompositeExtract); + instruction.set_type(result_type_id); + instruction.set_result(id); + + instruction.add_operand(composite_id); + for index in indices { + instruction.add_operand(*index); + } + + instruction + } + + pub(super) fn vector_extract_dynamic( + result_type_id: Word, + id: Word, + vector_id: Word, + index_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::VectorExtractDynamic); + instruction.set_type(result_type_id); + instruction.set_result(id); + + instruction.add_operand(vector_id); + instruction.add_operand(index_id); + + instruction + } + + pub(super) fn vector_shuffle( + result_type_id: Word, + id: Word, + v1_id: Word, + v2_id: Word, + components: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::VectorShuffle); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(v1_id); + instruction.add_operand(v2_id); + + for &component in components { + instruction.add_operand(component); + } + + instruction + } + + // + // Arithmetic Instructions + // + pub(super) fn binary( + op: Op, + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(operand_1); + instruction.add_operand(operand_2); + instruction + } + + pub(super) fn ternary( + op: Op, + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, + operand_3: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(operand_1); + instruction.add_operand(operand_2); + instruction.add_operand(operand_3); + instruction + } + + pub(super) fn quaternary( + op: Op, + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, + operand_3: Word, + operand_4: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(operand_1); + instruction.add_operand(operand_2); + instruction.add_operand(operand_3); + instruction.add_operand(operand_4); + instruction + } + + pub(super) fn relational(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(expr_id); + instruction + } + + pub(super) fn atomic_binary( + op: Op, + result_type_id: Word, + id: Word, + pointer: Word, + scope_id: Word, + semantics_id: Word, + value: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(pointer); + instruction.add_operand(scope_id); + instruction.add_operand(semantics_id); + instruction.add_operand(value); + instruction + } + + // + // Bit Instructions + // + + // + // Relational and Logical Instructions + // + + // + // Derivative Instructions + // + + pub(super) fn derivative(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(expr_id); + instruction + } + + // + // Control-Flow Instructions + // + + pub(super) fn phi( + result_type_id: Word, + result_id: Word, + var_parent_pairs: &[(Word, Word)], + ) -> Self { + let mut instruction = Self::new(Op::Phi); + instruction.add_operand(result_type_id); + instruction.add_operand(result_id); + for &(variable, parent) in var_parent_pairs { + instruction.add_operand(variable); + instruction.add_operand(parent); + } + instruction + } + + pub(super) fn selection_merge( + merge_id: Word, + selection_control: spirv::SelectionControl, + ) -> Self { + let mut instruction = Self::new(Op::SelectionMerge); + instruction.add_operand(merge_id); + instruction.add_operand(selection_control.bits()); + instruction + } + + pub(super) fn loop_merge( + merge_id: Word, + continuing_id: Word, + selection_control: spirv::SelectionControl, + ) -> Self { + let mut instruction = Self::new(Op::LoopMerge); + instruction.add_operand(merge_id); + instruction.add_operand(continuing_id); + instruction.add_operand(selection_control.bits()); + instruction + } + + pub(super) fn label(id: Word) -> Self { + let mut instruction = Self::new(Op::Label); + instruction.set_result(id); + instruction + } + + pub(super) fn branch(id: Word) -> Self { + let mut instruction = Self::new(Op::Branch); + instruction.add_operand(id); + instruction + } + + // TODO Branch Weights not implemented. + pub(super) fn branch_conditional( + condition_id: Word, + true_label: Word, + false_label: Word, + ) -> Self { + let mut instruction = Self::new(Op::BranchConditional); + instruction.add_operand(condition_id); + instruction.add_operand(true_label); + instruction.add_operand(false_label); + instruction + } + + pub(super) fn switch(selector_id: Word, default_id: Word, cases: &[Case]) -> Self { + let mut instruction = Self::new(Op::Switch); + instruction.add_operand(selector_id); + instruction.add_operand(default_id); + for case in cases { + instruction.add_operand(case.value); + instruction.add_operand(case.label_id); + } + instruction + } + + pub(super) fn select( + result_type_id: Word, + id: Word, + condition_id: Word, + accept_id: Word, + reject_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::Select); + instruction.add_operand(result_type_id); + instruction.add_operand(id); + instruction.add_operand(condition_id); + instruction.add_operand(accept_id); + instruction.add_operand(reject_id); + instruction + } + + pub(super) const fn kill() -> Self { + Self::new(Op::Kill) + } + + pub(super) const fn return_void() -> Self { + Self::new(Op::Return) + } + + pub(super) fn return_value(value_id: Word) -> Self { + let mut instruction = Self::new(Op::ReturnValue); + instruction.add_operand(value_id); + instruction + } + + // + // Atomic Instructions + // + + // + // Primitive Instructions + // + + // Barriers + + pub(super) fn control_barrier( + exec_scope_id: Word, + mem_scope_id: Word, + semantics_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::ControlBarrier); + instruction.add_operand(exec_scope_id); + instruction.add_operand(mem_scope_id); + instruction.add_operand(semantics_id); + instruction + } +} + +impl From<crate::StorageFormat> for spirv::ImageFormat { + fn from(format: crate::StorageFormat) -> Self { + use crate::StorageFormat as Sf; + match format { + Sf::R8Unorm => Self::R8, + Sf::R8Snorm => Self::R8Snorm, + Sf::R8Uint => Self::R8ui, + Sf::R8Sint => Self::R8i, + Sf::R16Uint => Self::R16ui, + Sf::R16Sint => Self::R16i, + Sf::R16Float => Self::R16f, + Sf::Rg8Unorm => Self::Rg8, + Sf::Rg8Snorm => Self::Rg8Snorm, + Sf::Rg8Uint => Self::Rg8ui, + Sf::Rg8Sint => Self::Rg8i, + Sf::R32Uint => Self::R32ui, + Sf::R32Sint => Self::R32i, + Sf::R32Float => Self::R32f, + Sf::Rg16Uint => Self::Rg16ui, + Sf::Rg16Sint => Self::Rg16i, + Sf::Rg16Float => Self::Rg16f, + Sf::Rgba8Unorm => Self::Rgba8, + Sf::Rgba8Snorm => Self::Rgba8Snorm, + Sf::Rgba8Uint => Self::Rgba8ui, + Sf::Rgba8Sint => Self::Rgba8i, + Sf::Bgra8Unorm => Self::Unknown, + Sf::Rgb10a2Uint => Self::Rgb10a2ui, + Sf::Rgb10a2Unorm => Self::Rgb10A2, + Sf::Rg11b10Float => Self::R11fG11fB10f, + Sf::Rg32Uint => Self::Rg32ui, + Sf::Rg32Sint => Self::Rg32i, + Sf::Rg32Float => Self::Rg32f, + Sf::Rgba16Uint => Self::Rgba16ui, + Sf::Rgba16Sint => Self::Rgba16i, + Sf::Rgba16Float => Self::Rgba16f, + Sf::Rgba32Uint => Self::Rgba32ui, + Sf::Rgba32Sint => Self::Rgba32i, + Sf::Rgba32Float => Self::Rgba32f, + Sf::R16Unorm => Self::R16, + Sf::R16Snorm => Self::R16Snorm, + Sf::Rg16Unorm => Self::Rg16, + Sf::Rg16Snorm => Self::Rg16Snorm, + Sf::Rgba16Unorm => Self::Rgba16, + Sf::Rgba16Snorm => Self::Rgba16Snorm, + } + } +} + +impl From<crate::ImageDimension> for spirv::Dim { + fn from(dim: crate::ImageDimension) -> Self { + use crate::ImageDimension as Id; + match dim { + Id::D1 => Self::Dim1D, + Id::D2 => Self::Dim2D, + Id::D3 => Self::Dim3D, + Id::Cube => Self::DimCube, + } + } +} diff --git a/third_party/rust/naga/src/back/spv/layout.rs b/third_party/rust/naga/src/back/spv/layout.rs new file mode 100644 index 0000000000..39117a3d2a --- /dev/null +++ b/third_party/rust/naga/src/back/spv/layout.rs @@ -0,0 +1,210 @@ +use super::{Instruction, LogicalLayout, PhysicalLayout}; +use spirv::{Op, Word, MAGIC_NUMBER}; +use std::iter; + +// https://github.com/KhronosGroup/SPIRV-Headers/pull/195 +const GENERATOR: Word = 28; + +impl PhysicalLayout { + pub(super) const fn new(version: Word) -> Self { + PhysicalLayout { + magic_number: MAGIC_NUMBER, + version, + generator: GENERATOR, + bound: 0, + instruction_schema: 0x0u32, + } + } + + pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) { + sink.extend(iter::once(self.magic_number)); + sink.extend(iter::once(self.version)); + sink.extend(iter::once(self.generator)); + sink.extend(iter::once(self.bound)); + sink.extend(iter::once(self.instruction_schema)); + } +} + +impl super::recyclable::Recyclable for PhysicalLayout { + fn recycle(self) -> Self { + PhysicalLayout { + magic_number: self.magic_number, + version: self.version, + generator: self.generator, + instruction_schema: self.instruction_schema, + bound: 0, + } + } +} + +impl LogicalLayout { + pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) { + sink.extend(self.capabilities.iter().cloned()); + sink.extend(self.extensions.iter().cloned()); + sink.extend(self.ext_inst_imports.iter().cloned()); + sink.extend(self.memory_model.iter().cloned()); + sink.extend(self.entry_points.iter().cloned()); + sink.extend(self.execution_modes.iter().cloned()); + sink.extend(self.debugs.iter().cloned()); + sink.extend(self.annotations.iter().cloned()); + sink.extend(self.declarations.iter().cloned()); + sink.extend(self.function_declarations.iter().cloned()); + sink.extend(self.function_definitions.iter().cloned()); + } +} + +impl super::recyclable::Recyclable for LogicalLayout { + fn recycle(self) -> Self { + Self { + capabilities: self.capabilities.recycle(), + extensions: self.extensions.recycle(), + ext_inst_imports: self.ext_inst_imports.recycle(), + memory_model: self.memory_model.recycle(), + entry_points: self.entry_points.recycle(), + execution_modes: self.execution_modes.recycle(), + debugs: self.debugs.recycle(), + annotations: self.annotations.recycle(), + declarations: self.declarations.recycle(), + function_declarations: self.function_declarations.recycle(), + function_definitions: self.function_definitions.recycle(), + } + } +} + +impl Instruction { + pub(super) const fn new(op: Op) -> Self { + Instruction { + op, + wc: 1, // Always start at 1 for the first word (OP + WC), + type_id: None, + result_id: None, + operands: vec![], + } + } + + #[allow(clippy::panic)] + pub(super) fn set_type(&mut self, id: Word) { + assert!(self.type_id.is_none(), "Type can only be set once"); + self.type_id = Some(id); + self.wc += 1; + } + + #[allow(clippy::panic)] + pub(super) fn set_result(&mut self, id: Word) { + assert!(self.result_id.is_none(), "Result can only be set once"); + self.result_id = Some(id); + self.wc += 1; + } + + pub(super) fn add_operand(&mut self, operand: Word) { + self.operands.push(operand); + self.wc += 1; + } + + pub(super) fn add_operands(&mut self, operands: Vec<Word>) { + for operand in operands.into_iter() { + self.add_operand(operand) + } + } + + pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) { + sink.extend(Some(self.wc << 16 | self.op as u32)); + sink.extend(self.type_id); + sink.extend(self.result_id); + sink.extend(self.operands.iter().cloned()); + } +} + +impl Instruction { + #[cfg(test)] + fn validate(&self, words: &[Word]) { + let mut inst_index = 0; + let (wc, op) = ((words[inst_index] >> 16) as u16, words[inst_index] as u16); + inst_index += 1; + + assert_eq!(wc, words.len() as u16); + assert_eq!(op, self.op as u16); + + if self.type_id.is_some() { + assert_eq!(words[inst_index], self.type_id.unwrap()); + inst_index += 1; + } + + if self.result_id.is_some() { + assert_eq!(words[inst_index], self.result_id.unwrap()); + inst_index += 1; + } + + for (op_index, i) in (inst_index..wc as usize).enumerate() { + assert_eq!(words[i], self.operands[op_index]); + } + } +} + +#[test] +fn test_physical_layout_in_words() { + let bound = 5; + let version = 0x10203; + + let mut output = vec![]; + let mut layout = PhysicalLayout::new(version); + layout.bound = bound; + + layout.in_words(&mut output); + + assert_eq!(&output, &[MAGIC_NUMBER, version, GENERATOR, bound, 0,]); +} + +#[test] +fn test_logical_layout_in_words() { + let mut output = vec![]; + let mut layout = LogicalLayout::default(); + let layout_vectors = 11; + let mut instructions = Vec::with_capacity(layout_vectors); + + let vector_names = &[ + "Capabilities", + "Extensions", + "External Instruction Imports", + "Memory Model", + "Entry Points", + "Execution Modes", + "Debugs", + "Annotations", + "Declarations", + "Function Declarations", + "Function Definitions", + ]; + + for (i, _) in vector_names.iter().enumerate().take(layout_vectors) { + let mut dummy_instruction = Instruction::new(Op::Constant); + dummy_instruction.set_type((i + 1) as u32); + dummy_instruction.set_result((i + 2) as u32); + dummy_instruction.add_operand((i + 3) as u32); + dummy_instruction.add_operands(super::helpers::string_to_words( + format!("This is the vector: {}", vector_names[i]).as_str(), + )); + instructions.push(dummy_instruction); + } + + instructions[0].to_words(&mut layout.capabilities); + instructions[1].to_words(&mut layout.extensions); + instructions[2].to_words(&mut layout.ext_inst_imports); + instructions[3].to_words(&mut layout.memory_model); + instructions[4].to_words(&mut layout.entry_points); + instructions[5].to_words(&mut layout.execution_modes); + instructions[6].to_words(&mut layout.debugs); + instructions[7].to_words(&mut layout.annotations); + instructions[8].to_words(&mut layout.declarations); + instructions[9].to_words(&mut layout.function_declarations); + instructions[10].to_words(&mut layout.function_definitions); + + layout.in_words(&mut output); + + let mut index: usize = 0; + for instruction in instructions { + let wc = instruction.wc as usize; + instruction.validate(&output[index..index + wc]); + index += wc; + } +} diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs new file mode 100644 index 0000000000..b7d57be0d4 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/mod.rs @@ -0,0 +1,748 @@ +/*! +Backend for [SPIR-V][spv] (Standard Portable Intermediate Representation). + +[spv]: https://www.khronos.org/registry/SPIR-V/ +*/ + +mod block; +mod helpers; +mod image; +mod index; +mod instructions; +mod layout; +mod ray; +mod recyclable; +mod selection; +mod writer; + +pub use spirv::Capability; + +use crate::arena::Handle; +use crate::proc::{BoundsCheckPolicies, TypeResolution}; + +use spirv::Word; +use std::ops; +use thiserror::Error; + +#[derive(Clone)] +struct PhysicalLayout { + magic_number: Word, + version: Word, + generator: Word, + bound: Word, + instruction_schema: Word, +} + +#[derive(Default)] +struct LogicalLayout { + capabilities: Vec<Word>, + extensions: Vec<Word>, + ext_inst_imports: Vec<Word>, + memory_model: Vec<Word>, + entry_points: Vec<Word>, + execution_modes: Vec<Word>, + debugs: Vec<Word>, + annotations: Vec<Word>, + declarations: Vec<Word>, + function_declarations: Vec<Word>, + function_definitions: Vec<Word>, +} + +struct Instruction { + op: spirv::Op, + wc: u32, + type_id: Option<Word>, + result_id: Option<Word>, + operands: Vec<Word>, +} + +const BITS_PER_BYTE: crate::Bytes = 8; + +#[derive(Clone, Debug, Error)] +pub enum Error { + #[error("The requested entry point couldn't be found")] + EntryPointNotFound, + #[error("target SPIRV-{0}.{1} is not supported")] + UnsupportedVersion(u8, u8), + #[error("using {0} requires at least one of the capabilities {1:?}, but none are available")] + MissingCapabilities(&'static str, Vec<Capability>), + #[error("unimplemented {0}")] + FeatureNotImplemented(&'static str), + #[error("module is not validated properly: {0}")] + Validation(&'static str), +} + +#[derive(Default)] +struct IdGenerator(Word); + +impl IdGenerator { + fn next(&mut self) -> Word { + self.0 += 1; + self.0 + } +} + +#[derive(Debug, Clone)] +pub struct DebugInfo<'a> { + pub source_code: &'a str, + pub file_name: &'a std::path::Path, +} + +/// A SPIR-V block to which we are still adding instructions. +/// +/// A `Block` represents a SPIR-V block that does not yet have a termination +/// instruction like `OpBranch` or `OpReturn`. +/// +/// The `OpLabel` that starts the block is implicit. It will be emitted based on +/// `label_id` when we write the block to a `LogicalLayout`. +/// +/// To terminate a `Block`, pass the block and the termination instruction to +/// `Function::consume`. This takes ownership of the `Block` and transforms it +/// into a `TerminatedBlock`. +struct Block { + label_id: Word, + body: Vec<Instruction>, +} + +/// A SPIR-V block that ends with a termination instruction. +struct TerminatedBlock { + label_id: Word, + body: Vec<Instruction>, +} + +impl Block { + const fn new(label_id: Word) -> Self { + Block { + label_id, + body: Vec::new(), + } + } +} + +struct LocalVariable { + id: Word, + instruction: Instruction, +} + +struct ResultMember { + id: Word, + type_id: Word, + built_in: Option<crate::BuiltIn>, +} + +struct EntryPointContext { + argument_ids: Vec<Word>, + results: Vec<ResultMember>, +} + +#[derive(Default)] +struct Function { + signature: Option<Instruction>, + parameters: Vec<FunctionArgument>, + variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>, + blocks: Vec<TerminatedBlock>, + entry_point_context: Option<EntryPointContext>, +} + +impl Function { + fn consume(&mut self, mut block: Block, termination: Instruction) { + block.body.push(termination); + self.blocks.push(TerminatedBlock { + label_id: block.label_id, + body: block.body, + }) + } + + fn parameter_id(&self, index: u32) -> Word { + match self.entry_point_context { + Some(ref context) => context.argument_ids[index as usize], + None => self.parameters[index as usize] + .instruction + .result_id + .unwrap(), + } + } +} + +/// Characteristics of a SPIR-V `OpTypeImage` type. +/// +/// SPIR-V requires non-composite types to be unique, including images. Since we +/// use `LocalType` for this deduplication, it's essential that `LocalImageType` +/// be equal whenever the corresponding `OpTypeImage`s would be. To reduce the +/// likelihood of mistakes, we use fields that correspond exactly to the +/// operands of an `OpTypeImage` instruction, using the actual SPIR-V types +/// where practical. +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +struct LocalImageType { + sampled_type: crate::ScalarKind, + dim: spirv::Dim, + flags: ImageTypeFlags, + image_format: spirv::ImageFormat, +} + +bitflags::bitflags! { + /// Flags corresponding to the boolean(-ish) parameters to OpTypeImage. + #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] + pub struct ImageTypeFlags: u8 { + const DEPTH = 0x1; + const ARRAYED = 0x2; + const MULTISAMPLED = 0x4; + const SAMPLED = 0x8; + } +} + +impl LocalImageType { + /// Construct a `LocalImageType` from the fields of a `TypeInner::Image`. + fn from_inner(dim: crate::ImageDimension, arrayed: bool, class: crate::ImageClass) -> Self { + let make_flags = |multi: bool, other: ImageTypeFlags| -> ImageTypeFlags { + let mut flags = other; + flags.set(ImageTypeFlags::ARRAYED, arrayed); + flags.set(ImageTypeFlags::MULTISAMPLED, multi); + flags + }; + + let dim = spirv::Dim::from(dim); + + match class { + crate::ImageClass::Sampled { kind, multi } => LocalImageType { + sampled_type: kind, + dim, + flags: make_flags(multi, ImageTypeFlags::SAMPLED), + image_format: spirv::ImageFormat::Unknown, + }, + crate::ImageClass::Depth { multi } => LocalImageType { + sampled_type: crate::ScalarKind::Float, + dim, + flags: make_flags(multi, ImageTypeFlags::DEPTH | ImageTypeFlags::SAMPLED), + image_format: spirv::ImageFormat::Unknown, + }, + crate::ImageClass::Storage { format, access: _ } => LocalImageType { + sampled_type: crate::ScalarKind::from(format), + dim, + flags: make_flags(false, ImageTypeFlags::empty()), + image_format: format.into(), + }, + } + } +} + +/// A SPIR-V type constructed during code generation. +/// +/// This is the variant of [`LookupType`] used to represent types that might not +/// be available in the arena. Variants are present here for one of two reasons: +/// +/// - They represent types synthesized during code generation, as explained +/// in the documentation for [`LookupType`]. +/// +/// - They represent types for which SPIR-V forbids duplicate `OpType...` +/// instructions, requiring deduplication. +/// +/// This is not a complete copy of [`TypeInner`]: for example, SPIR-V generation +/// never synthesizes new struct types, so `LocalType` has nothing for that. +/// +/// Each `LocalType` variant should be handled identically to its analogous +/// `TypeInner` variant. You can use the [`make_local`] function to help with +/// this, by converting everything possible to a `LocalType` before inspecting +/// it. +/// +/// ## `Localtype` equality and SPIR-V `OpType` uniqueness +/// +/// The definition of `Eq` on `LocalType` is carefully chosen to help us follow +/// certain SPIR-V rules. SPIR-V §2.8 requires some classes of `OpType...` +/// instructions to be unique; for example, you can't have two `OpTypeInt 32 1` +/// instructions in the same module. All 32-bit signed integers must use the +/// same type id. +/// +/// All SPIR-V types that must be unique can be represented as a `LocalType`, +/// and two `LocalType`s are always `Eq` if SPIR-V would require them to use the +/// same `OpType...` instruction. This lets us avoid duplicates by recording the +/// ids of the type instructions we've already generated in a hash table, +/// [`Writer::lookup_type`], keyed by `LocalType`. +/// +/// As another example, [`LocalImageType`], stored in the `LocalType::Image` +/// variant, is designed to help us deduplicate `OpTypeImage` instructions. See +/// its documentation for details. +/// +/// `LocalType` also includes variants like `Pointer` that do not need to be +/// unique - but it is harmless to avoid the duplication. +/// +/// As it always must, the `Hash` implementation respects the `Eq` relation. +/// +/// [`TypeInner`]: crate::TypeInner +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +enum LocalType { + /// A scalar, vector, or pointer to one of those. + Value { + /// If `None`, this represents a scalar type. If `Some`, this represents + /// a vector type of the given size. + vector_size: Option<crate::VectorSize>, + scalar: crate::Scalar, + pointer_space: Option<spirv::StorageClass>, + }, + /// A matrix of floating-point values. + Matrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + width: crate::Bytes, + }, + Pointer { + base: Handle<crate::Type>, + class: spirv::StorageClass, + }, + Image(LocalImageType), + SampledImage { + image_type_id: Word, + }, + Sampler, + /// Equivalent to a [`LocalType::Pointer`] whose `base` is a Naga IR [`BindingArray`]. SPIR-V + /// permits duplicated `OpTypePointer` ids, so it's fine to have two different [`LocalType`] + /// representations for pointer types. + /// + /// [`BindingArray`]: crate::TypeInner::BindingArray + PointerToBindingArray { + base: Handle<crate::Type>, + size: u32, + space: crate::AddressSpace, + }, + BindingArray { + base: Handle<crate::Type>, + size: u32, + }, + AccelerationStructure, + RayQuery, +} + +/// A type encountered during SPIR-V generation. +/// +/// In the process of writing SPIR-V, we need to synthesize various types for +/// intermediate results and such: pointer types, vector/matrix component types, +/// or even booleans, which usually appear in SPIR-V code even when they're not +/// used by the module source. +/// +/// However, we can't use `crate::Type` or `crate::TypeInner` for these, as the +/// type arena may not contain what we need (it only contains types used +/// directly by other parts of the IR), and the IR module is immutable, so we +/// can't add anything to it. +/// +/// So for local use in the SPIR-V writer, we use this type, which holds either +/// a handle into the arena, or a [`LocalType`] containing something synthesized +/// locally. +/// +/// This is very similar to the [`proc::TypeResolution`] enum, with `LocalType` +/// playing the role of `TypeInner`. However, `LocalType` also has other +/// properties needed for SPIR-V generation; see the description of +/// [`LocalType`] for details. +/// +/// [`proc::TypeResolution`]: crate::proc::TypeResolution +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +enum LookupType { + Handle(Handle<crate::Type>), + Local(LocalType), +} + +impl From<LocalType> for LookupType { + fn from(local: LocalType) -> Self { + Self::Local(local) + } +} + +#[derive(Debug, PartialEq, Clone, Hash, Eq)] +struct LookupFunctionType { + parameter_type_ids: Vec<Word>, + return_type_id: Word, +} + +fn make_local(inner: &crate::TypeInner) -> Option<LocalType> { + Some(match *inner { + crate::TypeInner::Scalar(scalar) | crate::TypeInner::Atomic(scalar) => LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + }, + crate::TypeInner::Vector { size, scalar } => LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + }, + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => LocalType::Matrix { + columns, + rows, + width: scalar.width, + }, + crate::TypeInner::Pointer { base, space } => LocalType::Pointer { + base, + class: helpers::map_storage_class(space), + }, + crate::TypeInner::ValuePointer { + size, + scalar, + space, + } => LocalType::Value { + vector_size: size, + scalar, + pointer_space: Some(helpers::map_storage_class(space)), + }, + crate::TypeInner::Image { + dim, + arrayed, + class, + } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)), + crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler, + crate::TypeInner::AccelerationStructure => LocalType::AccelerationStructure, + crate::TypeInner::RayQuery => LocalType::RayQuery, + crate::TypeInner::Array { .. } + | crate::TypeInner::Struct { .. } + | crate::TypeInner::BindingArray { .. } => return None, + }) +} + +#[derive(Debug)] +enum Dimension { + Scalar, + Vector, + Matrix, +} + +/// A map from evaluated [`Expression`](crate::Expression)s to their SPIR-V ids. +/// +/// When we emit code to evaluate a given `Expression`, we record the +/// SPIR-V id of its value here, under its `Handle<Expression>` index. +/// +/// A `CachedExpressions` value can be indexed by a `Handle<Expression>` value. +/// +/// [emit]: index.html#expression-evaluation-time-and-scope +#[derive(Default)] +struct CachedExpressions { + ids: Vec<Word>, +} +impl CachedExpressions { + fn reset(&mut self, length: usize) { + self.ids.clear(); + self.ids.resize(length, 0); + } +} +impl ops::Index<Handle<crate::Expression>> for CachedExpressions { + type Output = Word; + fn index(&self, h: Handle<crate::Expression>) -> &Word { + let id = &self.ids[h.index()]; + if *id == 0 { + unreachable!("Expression {:?} is not cached!", h); + } + id + } +} +impl ops::IndexMut<Handle<crate::Expression>> for CachedExpressions { + fn index_mut(&mut self, h: Handle<crate::Expression>) -> &mut Word { + let id = &mut self.ids[h.index()]; + if *id != 0 { + unreachable!("Expression {:?} is already cached!", h); + } + id + } +} +impl recyclable::Recyclable for CachedExpressions { + fn recycle(self) -> Self { + CachedExpressions { + ids: self.ids.recycle(), + } + } +} + +#[derive(Eq, Hash, PartialEq)] +enum CachedConstant { + Literal(crate::Literal), + Composite { + ty: LookupType, + constituent_ids: Vec<Word>, + }, + ZeroValue(Word), +} + +#[derive(Clone)] +struct GlobalVariable { + /// ID of the OpVariable that declares the global. + /// + /// If you need the variable's value, use [`access_id`] instead of this + /// field. If we wrapped the Naga IR `GlobalVariable`'s type in a struct to + /// comply with Vulkan's requirements, then this points to the `OpVariable` + /// with the synthesized struct type, whereas `access_id` points to the + /// field of said struct that holds the variable's actual value. + /// + /// This is used to compute the `access_id` pointer in function prologues, + /// and used for `ArrayLength` expressions, which do need the struct. + /// + /// [`access_id`]: GlobalVariable::access_id + var_id: Word, + + /// For `AddressSpace::Handle` variables, this ID is recorded in the function + /// prelude block (and reset before every function) as `OpLoad` of the variable. + /// It is then used for all the global ops, such as `OpImageSample`. + handle_id: Word, + + /// Actual ID used to access this variable. + /// For wrapped buffer variables, this ID is `OpAccessChain` into the + /// wrapper. Otherwise, the same as `var_id`. + /// + /// Vulkan requires that globals in the `StorageBuffer` and `Uniform` storage + /// classes must be structs with the `Block` decoration, but WGSL and Naga IR + /// make no such requirement. So for such variables, we generate a wrapper struct + /// type with a single element of the type given by Naga, generate an + /// `OpAccessChain` for that member in the function prelude, and use that pointer + /// to refer to the global in the function body. This is the id of that access, + /// updated for each function in `write_function`. + access_id: Word, +} + +impl GlobalVariable { + const fn dummy() -> Self { + Self { + var_id: 0, + handle_id: 0, + access_id: 0, + } + } + + const fn new(id: Word) -> Self { + Self { + var_id: id, + handle_id: 0, + access_id: 0, + } + } + + /// Prepare `self` for use within a single function. + fn reset_for_function(&mut self) { + self.handle_id = 0; + self.access_id = 0; + } +} + +struct FunctionArgument { + /// Actual instruction of the argument. + instruction: Instruction, + handle_id: Word, +} + +/// General information needed to emit SPIR-V for Naga statements. +struct BlockContext<'w> { + /// The writer handling the module to which this code belongs. + writer: &'w mut Writer, + + /// The [`Module`](crate::Module) for which we're generating code. + ir_module: &'w crate::Module, + + /// The [`Function`](crate::Function) for which we're generating code. + ir_function: &'w crate::Function, + + /// Information module validation produced about + /// [`ir_function`](BlockContext::ir_function). + fun_info: &'w crate::valid::FunctionInfo, + + /// The [`spv::Function`](Function) to which we are contributing SPIR-V instructions. + function: &'w mut Function, + + /// SPIR-V ids for expressions we've evaluated. + cached: CachedExpressions, + + /// The `Writer`'s temporary vector, for convenience. + temp_list: Vec<Word>, + + /// Tracks the constness of `Expression`s residing in `self.ir_function.expressions` + expression_constness: crate::proc::ExpressionConstnessTracker, +} + +impl BlockContext<'_> { + fn gen_id(&mut self) -> Word { + self.writer.id_gen.next() + } + + fn get_type_id(&mut self, lookup_type: LookupType) -> Word { + self.writer.get_type_id(lookup_type) + } + + fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word { + self.writer.get_expression_type_id(tr) + } + + fn get_index_constant(&mut self, index: Word) -> Word { + self.writer.get_constant_scalar(crate::Literal::U32(index)) + } + + fn get_scope_constant(&mut self, scope: Word) -> Word { + self.writer + .get_constant_scalar(crate::Literal::I32(scope as _)) + } +} + +#[derive(Clone, Copy, Default)] +struct LoopContext { + continuing_id: Option<Word>, + break_id: Option<Word>, +} + +pub struct Writer { + physical_layout: PhysicalLayout, + logical_layout: LogicalLayout, + id_gen: IdGenerator, + + /// The set of capabilities modules are permitted to use. + /// + /// This is initialized from `Options::capabilities`. + capabilities_available: Option<crate::FastHashSet<Capability>>, + + /// The set of capabilities used by this module. + /// + /// If `capabilities_available` is `Some`, then this is always a subset of + /// that. + capabilities_used: crate::FastIndexSet<Capability>, + + /// The set of spirv extensions used. + extensions_used: crate::FastIndexSet<&'static str>, + + debugs: Vec<Instruction>, + annotations: Vec<Instruction>, + flags: WriterFlags, + bounds_check_policies: BoundsCheckPolicies, + zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode, + void_type: Word, + //TODO: convert most of these into vectors, addressable by handle indices + lookup_type: crate::FastHashMap<LookupType, Word>, + lookup_function: crate::FastHashMap<Handle<crate::Function>, Word>, + lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>, + /// Indexed by const-expression handle indexes + constant_ids: Vec<Word>, + cached_constants: crate::FastHashMap<CachedConstant, Word>, + global_variables: Vec<GlobalVariable>, + binding_map: BindingMap, + + // Cached expressions are only meaningful within a BlockContext, but we + // retain the table here between functions to save heap allocations. + saved_cached: CachedExpressions, + + gl450_ext_inst_id: Word, + + // Just a temporary list of SPIR-V ids + temp_list: Vec<Word>, +} + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct WriterFlags: u32 { + /// Include debug labels for everything. + const DEBUG = 0x1; + /// Flip Y coordinate of `BuiltIn::Position` output. + const ADJUST_COORDINATE_SPACE = 0x2; + /// Emit `OpName` for input/output locations. + /// Contrary to spec, some drivers treat it as semantic, not allowing + /// any conflicts. + const LABEL_VARYINGS = 0x4; + /// Emit `PointSize` output builtin to vertex shaders, which is + /// required for drawing with `PointList` topology. + const FORCE_POINT_SIZE = 0x8; + /// Clamp `BuiltIn::FragDepth` output between 0 and 1. + const CLAMP_FRAG_DEPTH = 0x10; + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct BindingInfo { + /// If the binding is an unsized binding array, this overrides the size. + pub binding_array_size: Option<u32>, +} + +// Using `BTreeMap` instead of `HashMap` so that we can hash itself. +pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindingInfo>; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ZeroInitializeWorkgroupMemoryMode { + /// Via `VK_KHR_zero_initialize_workgroup_memory` or Vulkan 1.3 + Native, + /// Via assignments + barrier + Polyfill, + None, +} + +#[derive(Debug, Clone)] +pub struct Options<'a> { + /// (Major, Minor) target version of the SPIR-V. + pub lang_version: (u8, u8), + + /// Configuration flags for the writer. + pub flags: WriterFlags, + + /// Map of resources to information about the binding. + pub binding_map: BindingMap, + + /// If given, the set of capabilities modules are allowed to use. Code that + /// requires capabilities beyond these is rejected with an error. + /// + /// If this is `None`, all capabilities are permitted. + pub capabilities: Option<crate::FastHashSet<Capability>>, + + /// How should generate code handle array, vector, matrix, or image texel + /// indices that are out of range? + pub bounds_check_policies: BoundsCheckPolicies, + + /// Dictates the way workgroup variables should be zero initialized + pub zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode, + + pub debug_info: Option<DebugInfo<'a>>, +} + +impl<'a> Default for Options<'a> { + fn default() -> Self { + let mut flags = WriterFlags::ADJUST_COORDINATE_SPACE + | WriterFlags::LABEL_VARYINGS + | WriterFlags::CLAMP_FRAG_DEPTH; + if cfg!(debug_assertions) { + flags |= WriterFlags::DEBUG; + } + Options { + lang_version: (1, 0), + flags, + binding_map: BindingMap::default(), + capabilities: None, + bounds_check_policies: crate::proc::BoundsCheckPolicies::default(), + zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill, + debug_info: None, + } + } +} + +// A subset of options meant to be changed per pipeline. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct PipelineOptions { + /// The stage of the entry point. + pub shader_stage: crate::ShaderStage, + /// The name of the entry point. + /// + /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown. + pub entry_point: String, +} + +pub fn write_vec( + module: &crate::Module, + info: &crate::valid::ModuleInfo, + options: &Options, + pipeline_options: Option<&PipelineOptions>, +) -> Result<Vec<u32>, Error> { + let mut words: Vec<u32> = Vec::new(); + let mut w = Writer::new(options)?; + + w.write( + module, + info, + pipeline_options, + &options.debug_info, + &mut words, + )?; + Ok(words) +} diff --git a/third_party/rust/naga/src/back/spv/ray.rs b/third_party/rust/naga/src/back/spv/ray.rs new file mode 100644 index 0000000000..bc2c4ce3c6 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/ray.rs @@ -0,0 +1,261 @@ +/*! +Generating SPIR-V for ray query operations. +*/ + +use super::{Block, BlockContext, Instruction, LocalType, LookupType}; +use crate::arena::Handle; + +impl<'w> BlockContext<'w> { + pub(super) fn write_ray_query_function( + &mut self, + query: Handle<crate::Expression>, + function: &crate::RayQueryFunction, + block: &mut Block, + ) { + let query_id = self.cached[query]; + match *function { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + //Note: composite extract indices and types must match `generate_ray_desc_type` + let desc_id = self.cached[descriptor]; + let acc_struct_id = self.get_handle_id(acceleration_structure); + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::U32, + pointer_space: None, + })); + let ray_flags_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + ray_flags_id, + desc_id, + &[0], + )); + let cull_mask_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + cull_mask_id, + desc_id, + &[1], + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let tmin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmin_id, + desc_id, + &[2], + )); + let tmax_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmax_id, + desc_id, + &[3], + )); + + let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let ray_origin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_origin_id, + desc_id, + &[4], + )); + let ray_dir_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_dir_id, + desc_id, + &[5], + )); + + block.body.push(Instruction::ray_query_initialize( + query_id, + acc_struct_id, + ray_flags_id, + cull_mask_id, + ray_origin_id, + tmin_id, + ray_dir_id, + tmax_id, + )); + } + crate::RayQueryFunction::Proceed { result } => { + let id = self.gen_id(); + self.cached[result] = id; + let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); + + block + .body + .push(Instruction::ray_query_proceed(result_type_id, id, query_id)); + } + crate::RayQueryFunction::Terminate => {} + } + } + + pub(super) fn write_ray_query_get_intersection( + &mut self, + query: Handle<crate::Expression>, + block: &mut Block, + ) -> spirv::Word { + let query_id = self.cached[query]; + let intersection_id = self.writer.get_constant_scalar(crate::Literal::U32( + spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, + )); + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::U32, + pointer_space: None, + })); + let kind_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTypeKHR, + flag_type_id, + kind_id, + query_id, + intersection_id, + )); + let instance_custom_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR, + flag_type_id, + instance_custom_index_id, + query_id, + intersection_id, + )); + let instance_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceIdKHR, + flag_type_id, + instance_id, + query_id, + intersection_id, + )); + let sbt_record_offset_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR, + flag_type_id, + sbt_record_offset_id, + query_id, + intersection_id, + )); + let geometry_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionGeometryIndexKHR, + flag_type_id, + geometry_index_id, + query_id, + intersection_id, + )); + let primitive_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR, + flag_type_id, + primitive_index_id, + query_id, + intersection_id, + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let t_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTKHR, + scalar_type_id, + t_id, + query_id, + intersection_id, + )); + + let barycentrics_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Bi), + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let barycentrics_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionBarycentricsKHR, + barycentrics_type_id, + barycentrics_id, + query_id, + intersection_id, + )); + + let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::BOOL, + pointer_space: None, + })); + let front_face_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionFrontFaceKHR, + bool_type_id, + front_face_id, + query_id, + intersection_id, + )); + + let transform_type_id = self.get_type_id(LookupType::Local(LocalType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width: 4, + })); + let object_to_world_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionObjectToWorldKHR, + transform_type_id, + object_to_world_id, + query_id, + intersection_id, + )); + let world_to_object_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionWorldToObjectKHR, + transform_type_id, + world_to_object_id, + query_id, + intersection_id, + )); + + let id = self.gen_id(); + let intersection_type_id = self.get_type_id(LookupType::Handle( + self.ir_module.special_types.ray_intersection.unwrap(), + )); + //Note: the arguments must match `generate_ray_intersection_type` layout + block.body.push(Instruction::composite_construct( + intersection_type_id, + id, + &[ + kind_id, + t_id, + instance_custom_index_id, + instance_id, + sbt_record_offset_id, + geometry_index_id, + primitive_index_id, + barycentrics_id, + front_face_id, + object_to_world_id, + world_to_object_id, + ], + )); + id + } +} diff --git a/third_party/rust/naga/src/back/spv/recyclable.rs b/third_party/rust/naga/src/back/spv/recyclable.rs new file mode 100644 index 0000000000..cd1466e3c7 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/recyclable.rs @@ -0,0 +1,67 @@ +/*! +Reusing collections' previous allocations. +*/ + +/// A value that can be reset to its initial state, retaining its current allocations. +/// +/// Naga attempts to lower the cost of SPIR-V generation by allowing clients to +/// reuse the same `Writer` for multiple Module translations. Reusing a `Writer` +/// means that the `Vec`s, `HashMap`s, and other heap-allocated structures the +/// `Writer` uses internally begin the translation with heap-allocated buffers +/// ready to use. +/// +/// But this approach introduces the risk of `Writer` state leaking from one +/// module to the next. When a developer adds fields to `Writer` or its internal +/// types, they must remember to reset their contents between modules. +/// +/// One trick to ensure that every field has been accounted for is to use Rust's +/// struct literal syntax to construct a new, reset value. If a developer adds a +/// field, but neglects to update the reset code, the compiler will complain +/// that a field is missing from the literal. This trait's `recycle` method +/// takes `self` by value, and returns `Self` by value, encouraging the use of +/// struct literal expressions in its implementation. +pub trait Recyclable { + /// Clear `self`, retaining its current memory allocations. + /// + /// Shrink the buffer if it's currently much larger than was actually used. + /// This prevents a module with exceptionally large allocations from causing + /// the `Writer` to retain more memory than it needs indefinitely. + fn recycle(self) -> Self; +} + +// Stock values for various collections. + +impl<T> Recyclable for Vec<T> { + fn recycle(mut self) -> Self { + self.clear(); + self + } +} + +impl<K, V, S: Clone> Recyclable for std::collections::HashMap<K, V, S> { + fn recycle(mut self) -> Self { + self.clear(); + self + } +} + +impl<K, S: Clone> Recyclable for std::collections::HashSet<K, S> { + fn recycle(mut self) -> Self { + self.clear(); + self + } +} + +impl<K, S: Clone> Recyclable for indexmap::IndexSet<K, S> { + fn recycle(mut self) -> Self { + self.clear(); + self + } +} + +impl<K: Ord, V> Recyclable for std::collections::BTreeMap<K, V> { + fn recycle(mut self) -> Self { + self.clear(); + self + } +} diff --git a/third_party/rust/naga/src/back/spv/selection.rs b/third_party/rust/naga/src/back/spv/selection.rs new file mode 100644 index 0000000000..788b1f10ab --- /dev/null +++ b/third_party/rust/naga/src/back/spv/selection.rs @@ -0,0 +1,257 @@ +/*! +Generate SPIR-V conditional structures. + +Builders for `if` structures with `and`s. + +The types in this module track the information needed to emit SPIR-V code +for complex conditional structures, like those whose conditions involve +short-circuiting 'and' and 'or' structures. These track labels and can emit +`OpPhi` instructions to merge values produced along different paths. + +This currently only supports exactly the forms Naga uses, so it doesn't +support `or` or `else`, and only supports zero or one merged values. + +Naga needs to emit code roughly like this: + +```ignore + + value = DEFAULT; + if COND1 && COND2 { + value = THEN_VALUE; + } + // use value + +``` + +Assuming `ctx` and `block` are a mutable references to a [`BlockContext`] +and the current [`Block`], and `merge_type` is the SPIR-V type for the +merged value `value`, we can build SPIR-V for the code above like so: + +```ignore + + let cond = Selection::start(block, merge_type); + // ... compute `cond1` ... + cond.if_true(ctx, cond1, DEFAULT); + // ... compute `cond2` ... + cond.if_true(ctx, cond2, DEFAULT); + // ... compute THEN_VALUE + let merged_value = cond.finish(ctx, THEN_VALUE); + +``` + +After this, `merged_value` is either `DEFAULT` or `THEN_VALUE`, depending on +the path by which the merged block was reached. + +This takes care of writing all branch instructions, including an +`OpSelectionMerge` annotation in the header block; starting new blocks and +assigning them labels; and emitting the `OpPhi` that gathers together the +right sources for the merged values, for every path through the selection +construct. + +When there is no merged value to produce, you can pass `()` for `merge_type` +and the merge values. In this case no `OpPhi` instructions are produced, and +the `finish` method returns `()`. + +To enforce proper nesting, a `Selection` takes ownership of the `&mut Block` +pointer for the duration of its lifetime. To obtain the block for generating +code in the selection's body, call the `Selection::block` method. +*/ + +use super::{Block, BlockContext, Instruction}; +use spirv::Word; + +/// A private struct recording what we know about the selection construct so far. +pub(super) struct Selection<'b, M: MergeTuple> { + /// The block pointer we're emitting code into. + block: &'b mut Block, + + /// The label of the selection construct's merge block, or `None` if we + /// haven't yet written the `OpSelectionMerge` merge instruction. + merge_label: Option<Word>, + + /// A set of `(VALUES, PARENT)` pairs, used to build `OpPhi` instructions in + /// the merge block. Each `PARENT` is the label of a predecessor block of + /// the merge block. The corresponding `VALUES` holds the ids of the values + /// that `PARENT` contributes to the merged values. + /// + /// We emit all branches to the merge block, so we know all its + /// predecessors. And we refuse to emit a branch unless we're given the + /// values the branching block contributes to the merge, so we always have + /// everything we need to emit the correct phis, by construction. + values: Vec<(M, Word)>, + + /// The types of the values in each element of `values`. + merge_types: M, +} + +impl<'b, M: MergeTuple> Selection<'b, M> { + /// Start a new selection construct. + /// + /// The `block` argument indicates the selection's header block. + /// + /// The `merge_types` argument should be a `Word` or tuple of `Word`s, each + /// value being the SPIR-V result type id of an `OpPhi` instruction that + /// will be written to the selection's merge block when this selection's + /// [`finish`] method is called. This argument may also be `()`, for + /// selections that produce no values. + /// + /// (This function writes no code to `block` itself; it simply constructs a + /// fresh `Selection`.) + /// + /// [`finish`]: Selection::finish + pub(super) fn start(block: &'b mut Block, merge_types: M) -> Self { + Selection { + block, + merge_label: None, + values: vec![], + merge_types, + } + } + + pub(super) fn block(&mut self) -> &mut Block { + self.block + } + + /// Branch to a successor block if `cond` is true, otherwise merge. + /// + /// If `cond` is false, branch to the merge block, using `values` as the + /// merged values. Otherwise, proceed to a new block. + /// + /// The `values` argument must be the same shape as the `merge_types` + /// argument passed to `Selection::start`. + pub(super) fn if_true(&mut self, ctx: &mut BlockContext, cond: Word, values: M) { + self.values.push((values, self.block.label_id)); + + let merge_label = self.make_merge_label(ctx); + let next_label = ctx.gen_id(); + ctx.function.consume( + std::mem::replace(self.block, Block::new(next_label)), + Instruction::branch_conditional(cond, next_label, merge_label), + ); + } + + /// Emit an unconditional branch to the merge block, and compute merged + /// values. + /// + /// Use `final_values` as the merged values contributed by the current + /// block, and transition to the merge block, emitting `OpPhi` instructions + /// to produce the merged values. This must be the same shape as the + /// `merge_types` argument passed to [`Selection::start`]. + /// + /// Return the SPIR-V ids of the merged values. This value has the same + /// shape as the `merge_types` argument passed to `Selection::start`. + pub(super) fn finish(self, ctx: &mut BlockContext, final_values: M) -> M { + match self { + Selection { + merge_label: None, .. + } => { + // We didn't actually emit any branches, so `self.values` must + // be empty, and `final_values` are the only sources we have for + // the merged values. Easy peasy. + final_values + } + + Selection { + block, + merge_label: Some(merge_label), + mut values, + merge_types, + } => { + // Emit the final branch and transition to the merge block. + values.push((final_values, block.label_id)); + ctx.function.consume( + std::mem::replace(block, Block::new(merge_label)), + Instruction::branch(merge_label), + ); + + // Now that we're in the merge block, build the phi instructions. + merge_types.write_phis(ctx, block, &values) + } + } + } + + /// Return the id of the merge block, writing a merge instruction if needed. + fn make_merge_label(&mut self, ctx: &mut BlockContext) -> Word { + match self.merge_label { + None => { + let merge_label = ctx.gen_id(); + self.block.body.push(Instruction::selection_merge( + merge_label, + spirv::SelectionControl::NONE, + )); + self.merge_label = Some(merge_label); + merge_label + } + Some(merge_label) => merge_label, + } + } +} + +/// A trait to help `Selection` manage any number of merged values. +/// +/// Some selection constructs, like a `ReadZeroSkipWrite` bounds check on a +/// [`Load`] expression, produce a single merged value. Others produce no merged +/// value, like a bounds check on a [`Store`] statement. +/// +/// To let `Selection` work nicely with both cases, we let the merge type +/// argument passed to [`Selection::start`] be any type that implements this +/// `MergeTuple` trait. `MergeTuple` is then implemented for `()`, `Word`, +/// `(Word, Word)`, and so on. +/// +/// A `MergeTuple` type can represent either a bunch of SPIR-V types or values; +/// the `merge_types` argument to `Selection::start` are type ids, whereas the +/// `values` arguments to the [`if_true`] and [`finish`] methods are value ids. +/// The set of merged value returned by `finish` is a tuple of value ids. +/// +/// In fact, since Naga only uses zero- and single-valued selection constructs +/// at present, we only implement `MergeTuple` for `()` and `Word`. But if you +/// add more cases, feel free to add more implementations. Once const generics +/// are available, we could have a single implementation of `MergeTuple` for all +/// lengths of arrays, and be done with it. +/// +/// [`Load`]: crate::Expression::Load +/// [`Store`]: crate::Statement::Store +/// [`if_true`]: Selection::if_true +/// [`finish`]: Selection::finish +pub(super) trait MergeTuple: Sized { + /// Write OpPhi instructions for the given set of predecessors. + /// + /// The `predecessors` vector should be a vector of `(LABEL, VALUES)` pairs, + /// where each `VALUES` holds the values contributed by the branch from + /// `LABEL`, which should be one of the current block's predecessors. + fn write_phis( + self, + ctx: &mut BlockContext, + block: &mut Block, + predecessors: &[(Self, Word)], + ) -> Self; +} + +/// Selections that produce a single merged value. +/// +/// For example, `ImageLoad` with `BoundsCheckPolicy::ReadZeroSkipWrite` either +/// returns a texel value or zeros. +impl MergeTuple for Word { + fn write_phis( + self, + ctx: &mut BlockContext, + block: &mut Block, + predecessors: &[(Word, Word)], + ) -> Word { + let merged_value = ctx.gen_id(); + block + .body + .push(Instruction::phi(self, merged_value, predecessors)); + merged_value + } +} + +/// Selections that produce no merged values. +/// +/// For example, `ImageStore` under `BoundsCheckPolicy::ReadZeroSkipWrite` +/// either does the store or skips it, but in neither case does it produce a +/// value. +impl MergeTuple for () { + /// No phis need to be generated. + fn write_phis(self, _: &mut BlockContext, _: &mut Block, _: &[((), Word)]) {} +} diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs new file mode 100644 index 0000000000..4db86c93a7 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/writer.rs @@ -0,0 +1,2063 @@ +use super::{ + block::DebugInfoInner, + helpers::{contains_builtin, global_needs_wrapper, map_storage_class}, + make_local, Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, + EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, + LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, LoopContext, Options, + PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE, +}; +use crate::{ + arena::{Handle, UniqueArena}, + back::spv::BindingInfo, + proc::{Alignment, TypeResolution}, + valid::{FunctionInfo, ModuleInfo}, +}; +use spirv::Word; +use std::collections::hash_map::Entry; + +struct FunctionInterface<'a> { + varying_ids: &'a mut Vec<Word>, + stage: crate::ShaderStage, +} + +impl Function { + fn to_words(&self, sink: &mut impl Extend<Word>) { + self.signature.as_ref().unwrap().to_words(sink); + for argument in self.parameters.iter() { + argument.instruction.to_words(sink); + } + for (index, block) in self.blocks.iter().enumerate() { + Instruction::label(block.label_id).to_words(sink); + if index == 0 { + for local_var in self.variables.values() { + local_var.instruction.to_words(sink); + } + } + for instruction in block.body.iter() { + instruction.to_words(sink); + } + } + } +} + +impl Writer { + pub fn new(options: &Options) -> Result<Self, Error> { + let (major, minor) = options.lang_version; + if major != 1 { + return Err(Error::UnsupportedVersion(major, minor)); + } + let raw_version = ((major as u32) << 16) | ((minor as u32) << 8); + + let mut capabilities_used = crate::FastIndexSet::default(); + capabilities_used.insert(spirv::Capability::Shader); + + let mut id_gen = IdGenerator::default(); + let gl450_ext_inst_id = id_gen.next(); + let void_type = id_gen.next(); + + Ok(Writer { + physical_layout: PhysicalLayout::new(raw_version), + logical_layout: LogicalLayout::default(), + id_gen, + capabilities_available: options.capabilities.clone(), + capabilities_used, + extensions_used: crate::FastIndexSet::default(), + debugs: vec![], + annotations: vec![], + flags: options.flags, + bounds_check_policies: options.bounds_check_policies, + zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory, + void_type, + lookup_type: crate::FastHashMap::default(), + lookup_function: crate::FastHashMap::default(), + lookup_function_type: crate::FastHashMap::default(), + constant_ids: Vec::new(), + cached_constants: crate::FastHashMap::default(), + global_variables: Vec::new(), + binding_map: options.binding_map.clone(), + saved_cached: CachedExpressions::default(), + gl450_ext_inst_id, + temp_list: Vec::new(), + }) + } + + /// Reset `Writer` to its initial state, retaining any allocations. + /// + /// Why not just implement `Recyclable` for `Writer`? By design, + /// `Recyclable::recycle` requires ownership of the value, not just + /// `&mut`; see the trait documentation. But we need to use this method + /// from functions like `Writer::write`, which only have `&mut Writer`. + /// Workarounds include unsafe code (`std::ptr::read`, then `write`, ugh) + /// or something like a `Default` impl that returns an oddly-initialized + /// `Writer`, which is worse. + fn reset(&mut self) { + use super::recyclable::Recyclable; + use std::mem::take; + + let mut id_gen = IdGenerator::default(); + let gl450_ext_inst_id = id_gen.next(); + let void_type = id_gen.next(); + + // Every field of the old writer that is not determined by the `Options` + // passed to `Writer::new` should be reset somehow. + let fresh = Writer { + // Copied from the old Writer: + flags: self.flags, + bounds_check_policies: self.bounds_check_policies, + zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory, + capabilities_available: take(&mut self.capabilities_available), + binding_map: take(&mut self.binding_map), + + // Initialized afresh: + id_gen, + void_type, + gl450_ext_inst_id, + + // Recycled: + capabilities_used: take(&mut self.capabilities_used).recycle(), + extensions_used: take(&mut self.extensions_used).recycle(), + physical_layout: self.physical_layout.clone().recycle(), + logical_layout: take(&mut self.logical_layout).recycle(), + debugs: take(&mut self.debugs).recycle(), + annotations: take(&mut self.annotations).recycle(), + lookup_type: take(&mut self.lookup_type).recycle(), + lookup_function: take(&mut self.lookup_function).recycle(), + lookup_function_type: take(&mut self.lookup_function_type).recycle(), + constant_ids: take(&mut self.constant_ids).recycle(), + cached_constants: take(&mut self.cached_constants).recycle(), + global_variables: take(&mut self.global_variables).recycle(), + saved_cached: take(&mut self.saved_cached).recycle(), + temp_list: take(&mut self.temp_list).recycle(), + }; + + *self = fresh; + + self.capabilities_used.insert(spirv::Capability::Shader); + } + + /// Indicate that the code requires any one of the listed capabilities. + /// + /// If nothing in `capabilities` appears in the available capabilities + /// specified in the [`Options`] from which this `Writer` was created, + /// return an error. The `what` string is used in the error message to + /// explain what provoked the requirement. (If no available capabilities were + /// given, assume everything is available.) + /// + /// The first acceptable capability will be added to this `Writer`'s + /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the + /// result. For this reason, more specific capabilities should be listed + /// before more general. + /// + /// [`capabilities_used`]: Writer::capabilities_used + pub(super) fn require_any( + &mut self, + what: &'static str, + capabilities: &[spirv::Capability], + ) -> Result<(), Error> { + match *capabilities { + [] => Ok(()), + [first, ..] => { + // Find the first acceptable capability, or return an error if + // there is none. + let selected = match self.capabilities_available { + None => first, + Some(ref available) => { + match capabilities.iter().find(|cap| available.contains(cap)) { + Some(&cap) => cap, + None => { + return Err(Error::MissingCapabilities(what, capabilities.to_vec())) + } + } + } + }; + self.capabilities_used.insert(selected); + Ok(()) + } + } + } + + /// Indicate that the code uses the given extension. + pub(super) fn use_extension(&mut self, extension: &'static str) { + self.extensions_used.insert(extension); + } + + pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word { + match self.lookup_type.entry(lookup_ty) { + Entry::Occupied(e) => *e.get(), + Entry::Vacant(e) => { + let local = match lookup_ty { + LookupType::Handle(_handle) => unreachable!("Handles are populated at start"), + LookupType::Local(local) => local, + }; + + let id = self.id_gen.next(); + e.insert(id); + self.write_type_declaration_local(id, local); + id + } + } + } + + pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType { + match *tr { + TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle), + TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()), + } + } + + pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word { + let lookup_ty = self.get_expression_lookup_type(tr); + self.get_type_id(lookup_ty) + } + + pub(super) fn get_pointer_id( + &mut self, + arena: &UniqueArena<crate::Type>, + handle: Handle<crate::Type>, + class: spirv::StorageClass, + ) -> Result<Word, Error> { + let ty_id = self.get_type_id(LookupType::Handle(handle)); + if let crate::TypeInner::Pointer { .. } = arena[handle].inner { + return Ok(ty_id); + } + let lookup_type = LookupType::Local(LocalType::Pointer { + base: handle, + class, + }); + Ok(if let Some(&id) = self.lookup_type.get(&lookup_type) { + id + } else { + let id = self.id_gen.next(); + let instruction = Instruction::type_pointer(id, class, ty_id); + instruction.to_words(&mut self.logical_layout.declarations); + self.lookup_type.insert(lookup_type, id); + id + }) + } + + pub(super) fn get_uint_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: None, + scalar: crate::Scalar::U32, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + + pub(super) fn get_float_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + + pub(super) fn get_uint3_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + scalar: crate::Scalar::U32, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + + pub(super) fn get_float_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { + let lookup_type = LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: Some(class), + }); + if let Some(&id) = self.lookup_type.get(&lookup_type) { + id + } else { + let id = self.id_gen.next(); + let ty_id = self.get_float_type_id(); + let instruction = Instruction::type_pointer(id, class, ty_id); + instruction.to_words(&mut self.logical_layout.declarations); + self.lookup_type.insert(lookup_type, id); + id + } + } + + pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { + let lookup_type = LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + scalar: crate::Scalar::U32, + pointer_space: Some(class), + }); + if let Some(&id) = self.lookup_type.get(&lookup_type) { + id + } else { + let id = self.id_gen.next(); + let ty_id = self.get_uint3_type_id(); + let instruction = Instruction::type_pointer(id, class, ty_id); + instruction.to_words(&mut self.logical_layout.declarations); + self.lookup_type.insert(lookup_type, id); + id + } + } + + pub(super) fn get_bool_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: None, + scalar: crate::Scalar::BOOL, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + + pub(super) fn get_bool3_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + scalar: crate::Scalar::BOOL, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + + pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) { + self.annotations + .push(Instruction::decorate(id, decoration, operands)); + } + + fn write_function( + &mut self, + ir_function: &crate::Function, + info: &FunctionInfo, + ir_module: &crate::Module, + mut interface: Option<FunctionInterface>, + debug_info: &Option<DebugInfoInner>, + ) -> Result<Word, Error> { + let mut function = Function::default(); + + let prelude_id = self.id_gen.next(); + let mut prelude = Block::new(prelude_id); + let mut ep_context = EntryPointContext { + argument_ids: Vec::new(), + results: Vec::new(), + }; + + let mut local_invocation_id = None; + + let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len()); + for argument in ir_function.arguments.iter() { + let class = spirv::StorageClass::Input; + let handle_ty = ir_module.types[argument.ty].inner.is_handle(); + let argument_type_id = match handle_ty { + true => self.get_pointer_id( + &ir_module.types, + argument.ty, + spirv::StorageClass::UniformConstant, + )?, + false => self.get_type_id(LookupType::Handle(argument.ty)), + }; + + if let Some(ref mut iface) = interface { + let id = if let Some(ref binding) = argument.binding { + let name = argument.name.as_deref(); + + let varying_id = self.write_varying( + ir_module, + iface.stage, + class, + name, + argument.ty, + binding, + )?; + iface.varying_ids.push(varying_id); + let id = self.id_gen.next(); + prelude + .body + .push(Instruction::load(argument_type_id, id, varying_id, None)); + + if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) { + local_invocation_id = Some(id); + } + + id + } else if let crate::TypeInner::Struct { ref members, .. } = + ir_module.types[argument.ty].inner + { + let struct_id = self.id_gen.next(); + let mut constituent_ids = Vec::with_capacity(members.len()); + for member in members { + let type_id = self.get_type_id(LookupType::Handle(member.ty)); + let name = member.name.as_deref(); + let binding = member.binding.as_ref().unwrap(); + let varying_id = self.write_varying( + ir_module, + iface.stage, + class, + name, + member.ty, + binding, + )?; + iface.varying_ids.push(varying_id); + let id = self.id_gen.next(); + prelude + .body + .push(Instruction::load(type_id, id, varying_id, None)); + constituent_ids.push(id); + + if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) { + local_invocation_id = Some(id); + } + } + prelude.body.push(Instruction::composite_construct( + argument_type_id, + struct_id, + &constituent_ids, + )); + struct_id + } else { + unreachable!("Missing argument binding on an entry point"); + }; + ep_context.argument_ids.push(id); + } else { + let argument_id = self.id_gen.next(); + let instruction = Instruction::function_parameter(argument_type_id, argument_id); + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = argument.name { + self.debugs.push(Instruction::name(argument_id, name)); + } + } + function.parameters.push(FunctionArgument { + instruction, + handle_id: if handle_ty { + let id = self.id_gen.next(); + prelude.body.push(Instruction::load( + self.get_type_id(LookupType::Handle(argument.ty)), + id, + argument_id, + None, + )); + id + } else { + 0 + }, + }); + parameter_type_ids.push(argument_type_id); + }; + } + + let return_type_id = match ir_function.result { + Some(ref result) => { + if let Some(ref mut iface) = interface { + let mut has_point_size = false; + let class = spirv::StorageClass::Output; + if let Some(ref binding) = result.binding { + has_point_size |= + *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize); + let type_id = self.get_type_id(LookupType::Handle(result.ty)); + let varying_id = self.write_varying( + ir_module, + iface.stage, + class, + None, + result.ty, + binding, + )?; + iface.varying_ids.push(varying_id); + ep_context.results.push(ResultMember { + id: varying_id, + type_id, + built_in: binding.to_built_in(), + }); + } else if let crate::TypeInner::Struct { ref members, .. } = + ir_module.types[result.ty].inner + { + for member in members { + let type_id = self.get_type_id(LookupType::Handle(member.ty)); + let name = member.name.as_deref(); + let binding = member.binding.as_ref().unwrap(); + has_point_size |= + *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize); + let varying_id = self.write_varying( + ir_module, + iface.stage, + class, + name, + member.ty, + binding, + )?; + iface.varying_ids.push(varying_id); + ep_context.results.push(ResultMember { + id: varying_id, + type_id, + built_in: binding.to_built_in(), + }); + } + } else { + unreachable!("Missing result binding on an entry point"); + } + + if self.flags.contains(WriterFlags::FORCE_POINT_SIZE) + && iface.stage == crate::ShaderStage::Vertex + && !has_point_size + { + // add point size artificially + let varying_id = self.id_gen.next(); + let pointer_type_id = self.get_float_pointer_type_id(class); + Instruction::variable(pointer_type_id, varying_id, class, None) + .to_words(&mut self.logical_layout.declarations); + self.decorate( + varying_id, + spirv::Decoration::BuiltIn, + &[spirv::BuiltIn::PointSize as u32], + ); + iface.varying_ids.push(varying_id); + + let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0)); + prelude + .body + .push(Instruction::store(varying_id, default_value_id, None)); + } + self.void_type + } else { + self.get_type_id(LookupType::Handle(result.ty)) + } + } + None => self.void_type, + }; + + let lookup_function_type = LookupFunctionType { + parameter_type_ids, + return_type_id, + }; + + let function_id = self.id_gen.next(); + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = ir_function.name { + self.debugs.push(Instruction::name(function_id, name)); + } + } + + let function_type = self.get_function_type(lookup_function_type); + function.signature = Some(Instruction::function( + return_type_id, + function_id, + spirv::FunctionControl::empty(), + function_type, + )); + + if interface.is_some() { + function.entry_point_context = Some(ep_context); + } + + // fill up the `GlobalVariable::access_id` + for gv in self.global_variables.iter_mut() { + gv.reset_for_function(); + } + for (handle, var) in ir_module.global_variables.iter() { + if info[handle].is_empty() { + continue; + } + + let mut gv = self.global_variables[handle.index()].clone(); + if let Some(ref mut iface) = interface { + // Have to include global variables in the interface + if self.physical_layout.version >= 0x10400 { + iface.varying_ids.push(gv.var_id); + } + } + + // Handle globals are pre-emitted and should be loaded automatically. + // + // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing. + let is_binding_array = match ir_module.types[var.ty].inner { + crate::TypeInner::BindingArray { .. } => true, + _ => false, + }; + + if var.space == crate::AddressSpace::Handle && !is_binding_array { + let var_type_id = self.get_type_id(LookupType::Handle(var.ty)); + let id = self.id_gen.next(); + prelude + .body + .push(Instruction::load(var_type_id, id, gv.var_id, None)); + gv.access_id = gv.var_id; + gv.handle_id = id; + } else if global_needs_wrapper(ir_module, var) { + let class = map_storage_class(var.space); + let pointer_type_id = self.get_pointer_id(&ir_module.types, var.ty, class)?; + let index_id = self.get_index_constant(0); + + let id = self.id_gen.next(); + prelude.body.push(Instruction::access_chain( + pointer_type_id, + id, + gv.var_id, + &[index_id], + )); + gv.access_id = id; + } else { + // by default, the variable ID is accessed as is + gv.access_id = gv.var_id; + }; + + // work around borrow checking in the presence of `self.xxx()` calls + self.global_variables[handle.index()] = gv; + } + + // Create a `BlockContext` for generating SPIR-V for the function's + // body. + let mut context = BlockContext { + ir_module, + ir_function, + fun_info: info, + function: &mut function, + // Re-use the cached expression table from prior functions. + cached: std::mem::take(&mut self.saved_cached), + + // Steal the Writer's temp list for a bit. + temp_list: std::mem::take(&mut self.temp_list), + writer: self, + expression_constness: crate::proc::ExpressionConstnessTracker::from_arena( + &ir_function.expressions, + ), + }; + + // fill up the pre-emitted and const expressions + context.cached.reset(ir_function.expressions.len()); + for (handle, expr) in ir_function.expressions.iter() { + if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_))) + || context.expression_constness.is_const(handle) + { + context.cache_expression_value(handle, &mut prelude)?; + } + } + + for (handle, variable) in ir_function.local_variables.iter() { + let id = context.gen_id(); + + if context.writer.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = variable.name { + context.writer.debugs.push(Instruction::name(id, name)); + } + } + + let init_word = variable.init.map(|constant| context.cached[constant]); + let pointer_type_id = context.writer.get_pointer_id( + &ir_module.types, + variable.ty, + spirv::StorageClass::Function, + )?; + let instruction = Instruction::variable( + pointer_type_id, + id, + spirv::StorageClass::Function, + init_word.or_else(|| match ir_module.types[variable.ty].inner { + crate::TypeInner::RayQuery => None, + _ => { + let type_id = context.get_type_id(LookupType::Handle(variable.ty)); + Some(context.writer.write_constant_null(type_id)) + } + }), + ); + context + .function + .variables + .insert(handle, LocalVariable { id, instruction }); + } + + // cache local variable expressions + for (handle, expr) in ir_function.expressions.iter() { + if matches!(*expr, crate::Expression::LocalVariable(_)) { + context.cache_expression_value(handle, &mut prelude)?; + } + } + + let next_id = context.gen_id(); + + context + .function + .consume(prelude, Instruction::branch(next_id)); + + let workgroup_vars_init_exit_block_id = + match (context.writer.zero_initialize_workgroup_memory, interface) { + ( + super::ZeroInitializeWorkgroupMemoryMode::Polyfill, + Some( + ref mut interface @ FunctionInterface { + stage: crate::ShaderStage::Compute, + .. + }, + ), + ) => context.writer.generate_workgroup_vars_init_block( + next_id, + ir_module, + info, + local_invocation_id, + interface, + context.function, + ), + _ => None, + }; + + let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id { + exit_id + } else { + next_id + }; + + context.write_block( + main_id, + &ir_function.body, + super::block::BlockExit::Return, + LoopContext::default(), + debug_info.as_ref(), + )?; + + // Consume the `BlockContext`, ending its borrows and letting the + // `Writer` steal back its cached expression table and temp_list. + let BlockContext { + cached, temp_list, .. + } = context; + self.saved_cached = cached; + self.temp_list = temp_list; + + function.to_words(&mut self.logical_layout.function_definitions); + Instruction::function_end().to_words(&mut self.logical_layout.function_definitions); + + Ok(function_id) + } + + fn write_execution_mode( + &mut self, + function_id: Word, + mode: spirv::ExecutionMode, + ) -> Result<(), Error> { + //self.check(mode.required_capabilities())?; + Instruction::execution_mode(function_id, mode, &[]) + .to_words(&mut self.logical_layout.execution_modes); + Ok(()) + } + + // TODO Move to instructions module + fn write_entry_point( + &mut self, + entry_point: &crate::EntryPoint, + info: &FunctionInfo, + ir_module: &crate::Module, + debug_info: &Option<DebugInfoInner>, + ) -> Result<Instruction, Error> { + let mut interface_ids = Vec::new(); + let function_id = self.write_function( + &entry_point.function, + info, + ir_module, + Some(FunctionInterface { + varying_ids: &mut interface_ids, + stage: entry_point.stage, + }), + debug_info, + )?; + + let exec_model = match entry_point.stage { + crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex, + crate::ShaderStage::Fragment => { + self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?; + if let Some(ref result) = entry_point.function.result { + if contains_builtin( + result.binding.as_ref(), + result.ty, + &ir_module.types, + crate::BuiltIn::FragDepth, + ) { + self.write_execution_mode( + function_id, + spirv::ExecutionMode::DepthReplacing, + )?; + } + } + spirv::ExecutionModel::Fragment + } + crate::ShaderStage::Compute => { + let execution_mode = spirv::ExecutionMode::LocalSize; + //self.check(execution_mode.required_capabilities())?; + Instruction::execution_mode( + function_id, + execution_mode, + &entry_point.workgroup_size, + ) + .to_words(&mut self.logical_layout.execution_modes); + spirv::ExecutionModel::GLCompute + } + }; + //self.check(exec_model.required_capabilities())?; + + Ok(Instruction::entry_point( + exec_model, + function_id, + &entry_point.name, + interface_ids.as_slice(), + )) + } + + fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction { + use crate::ScalarKind as Sk; + + let bits = (scalar.width * BITS_PER_BYTE) as u32; + match scalar.kind { + Sk::Sint | Sk::Uint => { + let signedness = if scalar.kind == Sk::Sint { + super::instructions::Signedness::Signed + } else { + super::instructions::Signedness::Unsigned + }; + let cap = match bits { + 8 => Some(spirv::Capability::Int8), + 16 => Some(spirv::Capability::Int16), + 64 => Some(spirv::Capability::Int64), + _ => None, + }; + if let Some(cap) = cap { + self.capabilities_used.insert(cap); + } + Instruction::type_int(id, bits, signedness) + } + Sk::Float => { + if bits == 64 { + self.capabilities_used.insert(spirv::Capability::Float64); + } + Instruction::type_float(id, bits) + } + Sk::Bool => Instruction::type_bool(id), + Sk::AbstractInt | Sk::AbstractFloat => { + unreachable!("abstract types should never reach the backend"); + } + } + } + + fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> { + match *inner { + crate::TypeInner::Image { + dim, + arrayed, + class, + } => { + let sampled = match class { + crate::ImageClass::Sampled { .. } => true, + crate::ImageClass::Depth { .. } => true, + crate::ImageClass::Storage { format, .. } => { + self.request_image_format_capabilities(format.into())?; + false + } + }; + + match dim { + crate::ImageDimension::D1 => { + if sampled { + self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?; + } else { + self.require_any("1D storage images", &[spirv::Capability::Image1D])?; + } + } + crate::ImageDimension::Cube if arrayed => { + if sampled { + self.require_any( + "sampled cube array images", + &[spirv::Capability::SampledCubeArray], + )?; + } else { + self.require_any( + "cube array storage images", + &[spirv::Capability::ImageCubeArray], + )?; + } + } + _ => {} + } + } + crate::TypeInner::AccelerationStructure => { + self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?; + } + crate::TypeInner::RayQuery => { + self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?; + } + _ => {} + } + Ok(()) + } + + fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) { + let instruction = match local_ty { + LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + } => self.make_scalar(id, scalar), + LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + } => { + let scalar_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + })); + Instruction::type_vector(id, scalar_id, size) + } + LocalType::Matrix { + columns, + rows, + width, + } => { + let vector_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(rows), + scalar: crate::Scalar::float(width), + pointer_space: None, + })); + Instruction::type_matrix(id, vector_id, columns) + } + LocalType::Pointer { base, class } => { + let type_id = self.get_type_id(LookupType::Handle(base)); + Instruction::type_pointer(id, class, type_id) + } + LocalType::Value { + vector_size, + scalar, + pointer_space: Some(class), + } => { + let type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size, + scalar, + pointer_space: None, + })); + Instruction::type_pointer(id, class, type_id) + } + LocalType::Image(image) => { + let local_type = LocalType::Value { + vector_size: None, + scalar: crate::Scalar { + kind: image.sampled_type, + width: 4, + }, + pointer_space: None, + }; + let type_id = self.get_type_id(LookupType::Local(local_type)); + Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format) + } + LocalType::Sampler => Instruction::type_sampler(id), + LocalType::SampledImage { image_type_id } => { + Instruction::type_sampled_image(id, image_type_id) + } + LocalType::BindingArray { base, size } => { + let inner_ty = self.get_type_id(LookupType::Handle(base)); + let scalar_id = self.get_constant_scalar(crate::Literal::U32(size)); + Instruction::type_array(id, inner_ty, scalar_id) + } + LocalType::PointerToBindingArray { base, size, space } => { + let inner_ty = + self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size })); + let class = map_storage_class(space); + Instruction::type_pointer(id, class, inner_ty) + } + LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id), + LocalType::RayQuery => Instruction::type_ray_query(id), + }; + + instruction.to_words(&mut self.logical_layout.declarations); + } + + fn write_type_declaration_arena( + &mut self, + arena: &UniqueArena<crate::Type>, + handle: Handle<crate::Type>, + ) -> Result<Word, Error> { + let ty = &arena[handle]; + let id = if let Some(local) = make_local(&ty.inner) { + // This type can be represented as a `LocalType`, so check if we've + // already written an instruction for it. If not, do so now, with + // `write_type_declaration_local`. + match self.lookup_type.entry(LookupType::Local(local)) { + // We already have an id for this `LocalType`. + Entry::Occupied(e) => *e.get(), + + // It's a type we haven't seen before. + Entry::Vacant(e) => { + let id = self.id_gen.next(); + e.insert(id); + + self.write_type_declaration_local(id, local); + + // If it's a type that needs SPIR-V capabilities, request them now, + // so write_type_declaration_local can stay infallible. + self.request_type_capabilities(&ty.inner)?; + + id + } + } + } else { + use spirv::Decoration; + + let id = self.id_gen.next(); + let instruction = match ty.inner { + crate::TypeInner::Array { base, size, stride } => { + self.decorate(id, Decoration::ArrayStride, &[stride]); + + let type_id = self.get_type_id(LookupType::Handle(base)); + match size { + crate::ArraySize::Constant(length) => { + let length_id = self.get_index_constant(length.get()); + Instruction::type_array(id, type_id, length_id) + } + crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id), + } + } + crate::TypeInner::BindingArray { base, size } => { + let type_id = self.get_type_id(LookupType::Handle(base)); + match size { + crate::ArraySize::Constant(length) => { + let length_id = self.get_index_constant(length.get()); + Instruction::type_array(id, type_id, length_id) + } + crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id), + } + } + crate::TypeInner::Struct { + ref members, + span: _, + } => { + let mut has_runtime_array = false; + let mut member_ids = Vec::with_capacity(members.len()); + for (index, member) in members.iter().enumerate() { + let member_ty = &arena[member.ty]; + match member_ty.inner { + crate::TypeInner::Array { + base: _, + size: crate::ArraySize::Dynamic, + stride: _, + } => { + has_runtime_array = true; + } + _ => (), + } + self.decorate_struct_member(id, index, member, arena)?; + let member_id = self.get_type_id(LookupType::Handle(member.ty)); + member_ids.push(member_id); + } + if has_runtime_array { + self.decorate(id, Decoration::Block, &[]); + } + Instruction::type_struct(id, member_ids.as_slice()) + } + + // These all have TypeLocal representations, so they should have been + // handled by `write_type_declaration_local` above. + crate::TypeInner::Scalar(_) + | crate::TypeInner::Atomic(_) + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } + | crate::TypeInner::Pointer { .. } + | crate::TypeInner::ValuePointer { .. } + | crate::TypeInner::Image { .. } + | crate::TypeInner::Sampler { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => unreachable!(), + }; + + instruction.to_words(&mut self.logical_layout.declarations); + id + }; + + // Add this handle as a new alias for that type. + self.lookup_type.insert(LookupType::Handle(handle), id); + + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = ty.name { + self.debugs.push(Instruction::name(id, name)); + } + } + + Ok(id) + } + + fn request_image_format_capabilities( + &mut self, + format: spirv::ImageFormat, + ) -> Result<(), Error> { + use spirv::ImageFormat as If; + match format { + If::Rg32f + | If::Rg16f + | If::R11fG11fB10f + | If::R16f + | If::Rgba16 + | If::Rgb10A2 + | If::Rg16 + | If::Rg8 + | If::R16 + | If::R8 + | If::Rgba16Snorm + | If::Rg16Snorm + | If::Rg8Snorm + | If::R16Snorm + | If::R8Snorm + | If::Rg32i + | If::Rg16i + | If::Rg8i + | If::R16i + | If::R8i + | If::Rgb10a2ui + | If::Rg32ui + | If::Rg16ui + | If::Rg8ui + | If::R16ui + | If::R8ui => self.require_any( + "storage image format", + &[spirv::Capability::StorageImageExtendedFormats], + ), + If::R64ui | If::R64i => self.require_any( + "64-bit integer storage image format", + &[spirv::Capability::Int64ImageEXT], + ), + If::Unknown + | If::Rgba32f + | If::Rgba16f + | If::R32f + | If::Rgba8 + | If::Rgba8Snorm + | If::Rgba32i + | If::Rgba16i + | If::Rgba8i + | If::R32i + | If::Rgba32ui + | If::Rgba16ui + | If::Rgba8ui + | If::R32ui => Ok(()), + } + } + + pub(super) fn get_index_constant(&mut self, index: Word) -> Word { + self.get_constant_scalar(crate::Literal::U32(index)) + } + + pub(super) fn get_constant_scalar_with( + &mut self, + value: u8, + scalar: crate::Scalar, + ) -> Result<Word, Error> { + Ok( + self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or( + Error::Validation("Unexpected kind and/or width for Literal"), + )?), + ) + } + + pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word { + let scalar = CachedConstant::Literal(value); + if let Some(&id) = self.cached_constants.get(&scalar) { + return id; + } + let id = self.id_gen.next(); + self.write_constant_scalar(id, &value, None); + self.cached_constants.insert(scalar, id); + id + } + + fn write_constant_scalar( + &mut self, + id: Word, + value: &crate::Literal, + debug_name: Option<&String>, + ) { + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(name) = debug_name { + self.debugs.push(Instruction::name(id, name)); + } + } + let type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: value.scalar(), + pointer_space: None, + })); + let instruction = match *value { + crate::Literal::F64(value) => { + let bits = value.to_bits(); + Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32) + } + crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()), + crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value), + crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32), + crate::Literal::I64(value) => { + Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32) + } + crate::Literal::Bool(true) => Instruction::constant_true(type_id, id), + crate::Literal::Bool(false) => Instruction::constant_false(type_id, id), + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + unreachable!("Abstract types should not appear in IR presented to backends"); + } + }; + + instruction.to_words(&mut self.logical_layout.declarations); + } + + pub(super) fn get_constant_composite( + &mut self, + ty: LookupType, + constituent_ids: &[Word], + ) -> Word { + let composite = CachedConstant::Composite { + ty, + constituent_ids: constituent_ids.to_vec(), + }; + if let Some(&id) = self.cached_constants.get(&composite) { + return id; + } + let id = self.id_gen.next(); + self.write_constant_composite(id, ty, constituent_ids, None); + self.cached_constants.insert(composite, id); + id + } + + fn write_constant_composite( + &mut self, + id: Word, + ty: LookupType, + constituent_ids: &[Word], + debug_name: Option<&String>, + ) { + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(name) = debug_name { + self.debugs.push(Instruction::name(id, name)); + } + } + let type_id = self.get_type_id(ty); + Instruction::constant_composite(type_id, id, constituent_ids) + .to_words(&mut self.logical_layout.declarations); + } + + pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word { + let null = CachedConstant::ZeroValue(type_id); + if let Some(&id) = self.cached_constants.get(&null) { + return id; + } + let id = self.write_constant_null(type_id); + self.cached_constants.insert(null, id); + id + } + + pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word { + let null_id = self.id_gen.next(); + Instruction::constant_null(type_id, null_id) + .to_words(&mut self.logical_layout.declarations); + null_id + } + + fn write_constant_expr( + &mut self, + handle: Handle<crate::Expression>, + ir_module: &crate::Module, + mod_info: &ModuleInfo, + ) -> Result<Word, Error> { + let id = match ir_module.const_expressions[handle] { + crate::Expression::Literal(literal) => self.get_constant_scalar(literal), + crate::Expression::Constant(constant) => { + let constant = &ir_module.constants[constant]; + self.constant_ids[constant.init.index()] + } + crate::Expression::ZeroValue(ty) => { + let type_id = self.get_type_id(LookupType::Handle(ty)); + self.get_constant_null(type_id) + } + crate::Expression::Compose { ty, ref components } => { + let component_ids: Vec<_> = crate::proc::flatten_compose( + ty, + components, + &ir_module.const_expressions, + &ir_module.types, + ) + .map(|component| self.constant_ids[component.index()]) + .collect(); + self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice()) + } + crate::Expression::Splat { size, value } => { + let value_id = self.constant_ids[value.index()]; + let component_ids = &[value_id; 4][..size as usize]; + + let ty = self.get_expression_lookup_type(&mod_info[handle]); + + self.get_constant_composite(ty, component_ids) + } + _ => unreachable!(), + }; + + self.constant_ids[handle.index()] = id; + + Ok(id) + } + + pub(super) fn write_barrier(&mut self, flags: crate::Barrier, block: &mut Block) { + let memory_scope = if flags.contains(crate::Barrier::STORAGE) { + spirv::Scope::Device + } else { + spirv::Scope::Workgroup + }; + let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE; + semantics.set( + spirv::MemorySemantics::UNIFORM_MEMORY, + flags.contains(crate::Barrier::STORAGE), + ); + semantics.set( + spirv::MemorySemantics::WORKGROUP_MEMORY, + flags.contains(crate::Barrier::WORK_GROUP), + ); + let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32); + let mem_scope_id = self.get_index_constant(memory_scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + block.body.push(Instruction::control_barrier( + exec_scope_id, + mem_scope_id, + semantics_id, + )); + } + + fn generate_workgroup_vars_init_block( + &mut self, + entry_id: Word, + ir_module: &crate::Module, + info: &FunctionInfo, + local_invocation_id: Option<Word>, + interface: &mut FunctionInterface, + function: &mut Function, + ) -> Option<Word> { + let body = ir_module + .global_variables + .iter() + .filter(|&(handle, var)| { + !info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + .map(|(handle, var)| { + // It's safe to use `var_id` here, not `access_id`, because only + // variables in the `Uniform` and `StorageBuffer` address spaces + // get wrapped, and we're initializing `WorkGroup` variables. + let var_id = self.global_variables[handle.index()].var_id; + let var_type_id = self.get_type_id(LookupType::Handle(var.ty)); + let init_word = self.get_constant_null(var_type_id); + Instruction::store(var_id, init_word, None) + }) + .collect::<Vec<_>>(); + + if body.is_empty() { + return None; + } + + let uint3_type_id = self.get_uint3_type_id(); + + let mut pre_if_block = Block::new(entry_id); + + let local_invocation_id = if let Some(local_invocation_id) = local_invocation_id { + local_invocation_id + } else { + let varying_id = self.id_gen.next(); + let class = spirv::StorageClass::Input; + let pointer_type_id = self.get_uint3_pointer_type_id(class); + + Instruction::variable(pointer_type_id, varying_id, class, None) + .to_words(&mut self.logical_layout.declarations); + + self.decorate( + varying_id, + spirv::Decoration::BuiltIn, + &[spirv::BuiltIn::LocalInvocationId as u32], + ); + + interface.varying_ids.push(varying_id); + let id = self.id_gen.next(); + pre_if_block + .body + .push(Instruction::load(uint3_type_id, id, varying_id, None)); + + id + }; + + let zero_id = self.get_constant_null(uint3_type_id); + let bool3_type_id = self.get_bool3_type_id(); + + let eq_id = self.id_gen.next(); + pre_if_block.body.push(Instruction::binary( + spirv::Op::IEqual, + bool3_type_id, + eq_id, + local_invocation_id, + zero_id, + )); + + let condition_id = self.id_gen.next(); + let bool_type_id = self.get_bool_type_id(); + pre_if_block.body.push(Instruction::relational( + spirv::Op::All, + bool_type_id, + condition_id, + eq_id, + )); + + let merge_id = self.id_gen.next(); + pre_if_block.body.push(Instruction::selection_merge( + merge_id, + spirv::SelectionControl::NONE, + )); + + let accept_id = self.id_gen.next(); + function.consume( + pre_if_block, + Instruction::branch_conditional(condition_id, accept_id, merge_id), + ); + + let accept_block = Block { + label_id: accept_id, + body, + }; + function.consume(accept_block, Instruction::branch(merge_id)); + + let mut post_if_block = Block::new(merge_id); + + self.write_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block); + + let next_id = self.id_gen.next(); + function.consume(post_if_block, Instruction::branch(next_id)); + Some(next_id) + } + + /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface. + /// + /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s + /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the + /// interface is represented by global variables in the `Input` and `Output` + /// storage classes, with decorations indicating which builtin or location + /// each variable corresponds to. + /// + /// This function emits a single global `OpVariable` for a single value from + /// the interface, and adds appropriate decorations to indicate which + /// builtin or location it represents, how it should be interpolated, and so + /// on. The `class` argument gives the variable's SPIR-V storage class, + /// which should be either [`Input`] or [`Output`]. + /// + /// [`Binding`]: crate::Binding + /// [`Function`]: crate::Function + /// [`EntryPoint`]: crate::EntryPoint + /// [`Input`]: spirv::StorageClass::Input + /// [`Output`]: spirv::StorageClass::Output + fn write_varying( + &mut self, + ir_module: &crate::Module, + stage: crate::ShaderStage, + class: spirv::StorageClass, + debug_name: Option<&str>, + ty: Handle<crate::Type>, + binding: &crate::Binding, + ) -> Result<Word, Error> { + let id = self.id_gen.next(); + let pointer_type_id = self.get_pointer_id(&ir_module.types, ty, class)?; + Instruction::variable(pointer_type_id, id, class, None) + .to_words(&mut self.logical_layout.declarations); + + if self + .flags + .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS) + { + if let Some(name) = debug_name { + self.debugs.push(Instruction::name(id, name)); + } + } + + use spirv::{BuiltIn, Decoration}; + + match *binding { + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source, + } => { + self.decorate(id, Decoration::Location, &[location]); + + let no_decorations = + // VUID-StandaloneSpirv-Flat-06202 + // > The Flat, NoPerspective, Sample, and Centroid decorations + // > must not be used on variables with the Input storage class in a vertex shader + (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) || + // VUID-StandaloneSpirv-Flat-06201 + // > The Flat, NoPerspective, Sample, and Centroid decorations + // > must not be used on variables with the Output storage class in a fragment shader + (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment); + + if !no_decorations { + match interpolation { + // Perspective-correct interpolation is the default in SPIR-V. + None | Some(crate::Interpolation::Perspective) => (), + Some(crate::Interpolation::Flat) => { + self.decorate(id, Decoration::Flat, &[]); + } + Some(crate::Interpolation::Linear) => { + self.decorate(id, Decoration::NoPerspective, &[]); + } + } + match sampling { + // Center sampling is the default in SPIR-V. + None | Some(crate::Sampling::Center) => (), + Some(crate::Sampling::Centroid) => { + self.decorate(id, Decoration::Centroid, &[]); + } + Some(crate::Sampling::Sample) => { + self.require_any( + "per-sample interpolation", + &[spirv::Capability::SampleRateShading], + )?; + self.decorate(id, Decoration::Sample, &[]); + } + } + } + if second_blend_source { + self.decorate(id, Decoration::Index, &[1]); + } + } + crate::Binding::BuiltIn(built_in) => { + use crate::BuiltIn as Bi; + let built_in = match built_in { + Bi::Position { invariant } => { + if invariant { + self.decorate(id, Decoration::Invariant, &[]); + } + + if class == spirv::StorageClass::Output { + BuiltIn::Position + } else { + BuiltIn::FragCoord + } + } + Bi::ViewIndex => { + self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?; + BuiltIn::ViewIndex + } + // vertex + Bi::BaseInstance => BuiltIn::BaseInstance, + Bi::BaseVertex => BuiltIn::BaseVertex, + Bi::ClipDistance => { + self.require_any( + "`clip_distance` built-in", + &[spirv::Capability::ClipDistance], + )?; + BuiltIn::ClipDistance + } + Bi::CullDistance => { + self.require_any( + "`cull_distance` built-in", + &[spirv::Capability::CullDistance], + )?; + BuiltIn::CullDistance + } + Bi::InstanceIndex => BuiltIn::InstanceIndex, + Bi::PointSize => BuiltIn::PointSize, + Bi::VertexIndex => BuiltIn::VertexIndex, + // fragment + Bi::FragDepth => BuiltIn::FragDepth, + Bi::PointCoord => BuiltIn::PointCoord, + Bi::FrontFacing => BuiltIn::FrontFacing, + Bi::PrimitiveIndex => { + self.require_any( + "`primitive_index` built-in", + &[spirv::Capability::Geometry], + )?; + BuiltIn::PrimitiveId + } + Bi::SampleIndex => { + self.require_any( + "`sample_index` built-in", + &[spirv::Capability::SampleRateShading], + )?; + + BuiltIn::SampleId + } + Bi::SampleMask => BuiltIn::SampleMask, + // compute + Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId, + Bi::LocalInvocationId => BuiltIn::LocalInvocationId, + Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex, + Bi::WorkGroupId => BuiltIn::WorkgroupId, + Bi::WorkGroupSize => BuiltIn::WorkgroupSize, + Bi::NumWorkGroups => BuiltIn::NumWorkgroups, + }; + + self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); + + use crate::ScalarKind as Sk; + + // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`: + // + // > Any variable with integer or double-precision floating- + // > point type and with Input storage class in a fragment + // > shader, must be decorated Flat + if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment { + let is_flat = match ir_module.types[ty].inner { + crate::TypeInner::Scalar(scalar) + | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { + Sk::Uint | Sk::Sint | Sk::Bool => true, + Sk::Float => false, + Sk::AbstractInt | Sk::AbstractFloat => { + return Err(Error::Validation( + "Abstract types should not appear in IR presented to backends", + )) + } + }, + _ => false, + }; + + if is_flat { + self.decorate(id, Decoration::Flat, &[]); + } + } + } + } + + Ok(id) + } + + fn write_global_variable( + &mut self, + ir_module: &crate::Module, + global_variable: &crate::GlobalVariable, + ) -> Result<Word, Error> { + use spirv::Decoration; + + let id = self.id_gen.next(); + let class = map_storage_class(global_variable.space); + + //self.check(class.required_capabilities())?; + + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = global_variable.name { + self.debugs.push(Instruction::name(id, name)); + } + } + + let storage_access = match global_variable.space { + crate::AddressSpace::Storage { access } => Some(access), + _ => match ir_module.types[global_variable.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. }, + .. + } => Some(access), + _ => None, + }, + }; + if let Some(storage_access) = storage_access { + if !storage_access.contains(crate::StorageAccess::LOAD) { + self.decorate(id, Decoration::NonReadable, &[]); + } + if !storage_access.contains(crate::StorageAccess::STORE) { + self.decorate(id, Decoration::NonWritable, &[]); + } + } + + // Note: we should be able to substitute `binding_array<Foo, 0>`, + // but there is still code that tries to register the pre-substituted type, + // and it is failing on 0. + let mut substitute_inner_type_lookup = None; + if let Some(ref res_binding) = global_variable.binding { + self.decorate(id, Decoration::DescriptorSet, &[res_binding.group]); + self.decorate(id, Decoration::Binding, &[res_binding.binding]); + + if let Some(&BindingInfo { + binding_array_size: Some(remapped_binding_array_size), + }) = self.binding_map.get(res_binding) + { + if let crate::TypeInner::BindingArray { base, .. } = + ir_module.types[global_variable.ty].inner + { + substitute_inner_type_lookup = + Some(LookupType::Local(LocalType::PointerToBindingArray { + base, + size: remapped_binding_array_size, + space: global_variable.space, + })) + } + } + }; + + let init_word = global_variable + .init + .map(|constant| self.constant_ids[constant.index()]); + let inner_type_id = self.get_type_id( + substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)), + ); + + // generate the wrapping structure if needed + let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) { + let wrapper_type_id = self.id_gen.next(); + + self.decorate(wrapper_type_id, Decoration::Block, &[]); + let member = crate::StructMember { + name: None, + ty: global_variable.ty, + binding: None, + offset: 0, + }; + self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?; + + Instruction::type_struct(wrapper_type_id, &[inner_type_id]) + .to_words(&mut self.logical_layout.declarations); + + let pointer_type_id = self.id_gen.next(); + Instruction::type_pointer(pointer_type_id, class, wrapper_type_id) + .to_words(&mut self.logical_layout.declarations); + + pointer_type_id + } else { + // This is a global variable in the Storage address space. The only + // way it could have `global_needs_wrapper() == false` is if it has + // a runtime-sized or binding array. + // Runtime-sized arrays were decorated when iterating through struct content. + // Now binding arrays require Block decorating. + if let crate::AddressSpace::Storage { .. } = global_variable.space { + match ir_module.types[global_variable.ty].inner { + crate::TypeInner::BindingArray { base, .. } => { + let decorated_id = self.get_type_id(LookupType::Handle(base)); + self.decorate(decorated_id, Decoration::Block, &[]); + } + _ => (), + }; + } + if substitute_inner_type_lookup.is_some() { + inner_type_id + } else { + self.get_pointer_id(&ir_module.types, global_variable.ty, class)? + } + }; + + let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) { + (crate::AddressSpace::Private, _) + | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => { + init_word.or_else(|| Some(self.get_constant_null(inner_type_id))) + } + _ => init_word, + }; + + Instruction::variable(pointer_type_id, id, class, init_word) + .to_words(&mut self.logical_layout.declarations); + Ok(id) + } + + /// Write the necessary decorations for a struct member. + /// + /// Emit decorations for the `index`'th member of the struct type + /// designated by `struct_id`, described by `member`. + fn decorate_struct_member( + &mut self, + struct_id: Word, + index: usize, + member: &crate::StructMember, + arena: &UniqueArena<crate::Type>, + ) -> Result<(), Error> { + use spirv::Decoration; + + self.annotations.push(Instruction::member_decorate( + struct_id, + index as u32, + Decoration::Offset, + &[member.offset], + )); + + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = member.name { + self.debugs + .push(Instruction::member_name(struct_id, index as u32, name)); + } + } + + // Matrices and arrays of matrices both require decorations, + // so "see through" an array to determine if they're needed. + let member_array_subty_inner = match arena[member.ty].inner { + crate::TypeInner::Array { base, .. } => &arena[base].inner, + ref other => other, + }; + if let crate::TypeInner::Matrix { + columns: _, + rows, + scalar, + } = *member_array_subty_inner + { + let byte_stride = Alignment::from(rows) * scalar.width as u32; + self.annotations.push(Instruction::member_decorate( + struct_id, + index as u32, + Decoration::ColMajor, + &[], + )); + self.annotations.push(Instruction::member_decorate( + struct_id, + index as u32, + Decoration::MatrixStride, + &[byte_stride], + )); + } + + Ok(()) + } + + fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word { + match self + .lookup_function_type + .entry(lookup_function_type.clone()) + { + Entry::Occupied(e) => *e.get(), + Entry::Vacant(_) => { + let id = self.id_gen.next(); + let instruction = Instruction::type_function( + id, + lookup_function_type.return_type_id, + &lookup_function_type.parameter_type_ids, + ); + instruction.to_words(&mut self.logical_layout.declarations); + self.lookup_function_type.insert(lookup_function_type, id); + id + } + } + } + + fn write_physical_layout(&mut self) { + self.physical_layout.bound = self.id_gen.0 + 1; + } + + fn write_logical_layout( + &mut self, + ir_module: &crate::Module, + mod_info: &ModuleInfo, + ep_index: Option<usize>, + debug_info: &Option<DebugInfo>, + ) -> Result<(), Error> { + fn has_view_index_check( + ir_module: &crate::Module, + binding: Option<&crate::Binding>, + ty: Handle<crate::Type>, + ) -> bool { + match ir_module.types[ty].inner { + crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| { + has_view_index_check(ir_module, member.binding.as_ref(), member.ty) + }), + _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)), + } + } + + let has_storage_buffers = + ir_module + .global_variables + .iter() + .any(|(_, var)| match var.space { + crate::AddressSpace::Storage { .. } => true, + _ => false, + }); + let has_view_index = ir_module + .entry_points + .iter() + .flat_map(|entry| entry.function.arguments.iter()) + .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty)); + let has_ray_query = ir_module.special_types.ray_desc.is_some() + | ir_module.special_types.ray_intersection.is_some(); + + if self.physical_layout.version < 0x10300 && has_storage_buffers { + // enable the storage buffer class on < SPV-1.3 + Instruction::extension("SPV_KHR_storage_buffer_storage_class") + .to_words(&mut self.logical_layout.extensions); + } + if has_view_index { + Instruction::extension("SPV_KHR_multiview") + .to_words(&mut self.logical_layout.extensions) + } + if has_ray_query { + Instruction::extension("SPV_KHR_ray_query") + .to_words(&mut self.logical_layout.extensions) + } + Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations); + Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450") + .to_words(&mut self.logical_layout.ext_inst_imports); + + let mut debug_info_inner = None; + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(debug_info) = debug_info.as_ref() { + let source_file_id = self.id_gen.next(); + self.debugs.push(Instruction::string( + &debug_info.file_name.display().to_string(), + source_file_id, + )); + + debug_info_inner = Some(DebugInfoInner { + source_code: debug_info.source_code, + source_file_id, + }); + self.debugs.push(Instruction::source( + spirv::SourceLanguage::Unknown, + 0, + &debug_info_inner, + )); + } + } + + // write all types + for (handle, _) in ir_module.types.iter() { + self.write_type_declaration_arena(&ir_module.types, handle)?; + } + + // write all const-expressions as constants + self.constant_ids + .resize(ir_module.const_expressions.len(), 0); + for (handle, _) in ir_module.const_expressions.iter() { + self.write_constant_expr(handle, ir_module, mod_info)?; + } + debug_assert!(self.constant_ids.iter().all(|&id| id != 0)); + + // write the name of constants on their respective const-expression initializer + if self.flags.contains(WriterFlags::DEBUG) { + for (_, constant) in ir_module.constants.iter() { + if let Some(ref name) = constant.name { + let id = self.constant_ids[constant.init.index()]; + self.debugs.push(Instruction::name(id, name)); + } + } + } + + // write all global variables + for (handle, var) in ir_module.global_variables.iter() { + // If a single entry point was specified, only write `OpVariable` instructions + // for the globals it actually uses. Emit dummies for the others, + // to preserve the indices in `global_variables`. + let gvar = match ep_index { + Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => { + GlobalVariable::dummy() + } + _ => { + let id = self.write_global_variable(ir_module, var)?; + GlobalVariable::new(id) + } + }; + self.global_variables.push(gvar); + } + + // write all functions + for (handle, ir_function) in ir_module.functions.iter() { + let info = &mod_info[handle]; + if let Some(index) = ep_index { + let ep_info = mod_info.get_entry_point(index); + // If this function uses globals that we omitted from the SPIR-V + // because the entry point and its callees didn't use them, + // then we must skip it. + if !ep_info.dominates_global_use(info) { + log::info!("Skip function {:?}", ir_function.name); + continue; + } + + // Skip functions that that are not compatible with this entry point's stage. + // + // When validation is enabled, it rejects modules whose entry points try to call + // incompatible functions, so if we got this far, then any functions incompatible + // with our selected entry point must not be used. + // + // When validation is disabled, `fun_info.available_stages` is always just + // `ShaderStages::all()`, so this will write all functions in the module, and + // the downstream GLSL compiler will catch any problems. + if !info.available_stages.contains(ep_info.available_stages) { + continue; + } + } + let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?; + self.lookup_function.insert(handle, id); + } + + // write all or one entry points + for (index, ir_ep) in ir_module.entry_points.iter().enumerate() { + if ep_index.is_some() && ep_index != Some(index) { + continue; + } + let info = mod_info.get_entry_point(index); + let ep_instruction = + self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?; + ep_instruction.to_words(&mut self.logical_layout.entry_points); + } + + for capability in self.capabilities_used.iter() { + Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities); + } + for extension in self.extensions_used.iter() { + Instruction::extension(extension).to_words(&mut self.logical_layout.extensions); + } + if ir_module.entry_points.is_empty() { + // SPIR-V doesn't like modules without entry points + Instruction::capability(spirv::Capability::Linkage) + .to_words(&mut self.logical_layout.capabilities); + } + + let addressing_model = spirv::AddressingModel::Logical; + let memory_model = spirv::MemoryModel::GLSL450; + //self.check(addressing_model.required_capabilities())?; + //self.check(memory_model.required_capabilities())?; + + Instruction::memory_model(addressing_model, memory_model) + .to_words(&mut self.logical_layout.memory_model); + + if self.flags.contains(WriterFlags::DEBUG) { + for debug in self.debugs.iter() { + debug.to_words(&mut self.logical_layout.debugs); + } + } + + for annotation in self.annotations.iter() { + annotation.to_words(&mut self.logical_layout.annotations); + } + + Ok(()) + } + + pub fn write( + &mut self, + ir_module: &crate::Module, + info: &ModuleInfo, + pipeline_options: Option<&PipelineOptions>, + debug_info: &Option<DebugInfo>, + words: &mut Vec<Word>, + ) -> Result<(), Error> { + self.reset(); + + // Try to find the entry point and corresponding index + let ep_index = match pipeline_options { + Some(po) => { + let index = ir_module + .entry_points + .iter() + .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name) + .ok_or(Error::EntryPointNotFound)?; + Some(index) + } + None => None, + }; + + self.write_logical_layout(ir_module, info, ep_index, debug_info)?; + self.write_physical_layout(); + + self.physical_layout.in_words(words); + self.logical_layout.in_words(words); + Ok(()) + } + + /// Return the set of capabilities the last module written used. + pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> { + &self.capabilities_used + } + + pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> { + self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?; + self.use_extension("SPV_EXT_descriptor_indexing"); + self.decorate(id, spirv::Decoration::NonUniform, &[]); + Ok(()) + } +} + +#[test] +fn test_write_physical_layout() { + let mut writer = Writer::new(&Options::default()).unwrap(); + assert_eq!(writer.physical_layout.bound, 0); + writer.write_physical_layout(); + assert_eq!(writer.physical_layout.bound, 3); +} diff --git a/third_party/rust/naga/src/back/wgsl/mod.rs b/third_party/rust/naga/src/back/wgsl/mod.rs new file mode 100644 index 0000000000..d731b1ca0c --- /dev/null +++ b/third_party/rust/naga/src/back/wgsl/mod.rs @@ -0,0 +1,52 @@ +/*! +Backend for [WGSL][wgsl] (WebGPU Shading Language). + +[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html +*/ + +mod writer; + +use thiserror::Error; + +pub use writer::{Writer, WriterFlags}; + +#[derive(Error, Debug)] +pub enum Error { + #[error(transparent)] + FmtError(#[from] std::fmt::Error), + #[error("{0}")] + Custom(String), + #[error("{0}")] + Unimplemented(String), // TODO: Error used only during development + #[error("Unsupported math function: {0:?}")] + UnsupportedMathFunction(crate::MathFunction), + #[error("Unsupported relational function: {0:?}")] + UnsupportedRelationalFunction(crate::RelationalFunction), +} + +pub fn write_string( + module: &crate::Module, + info: &crate::valid::ModuleInfo, + flags: WriterFlags, +) -> Result<String, Error> { + let mut w = Writer::new(String::new(), flags); + w.write(module, info)?; + let output = w.finish(); + Ok(output) +} + +impl crate::AtomicFunction { + const fn to_wgsl(self) -> &'static str { + match self { + Self::Add => "Add", + Self::Subtract => "Sub", + Self::And => "And", + Self::InclusiveOr => "Or", + Self::ExclusiveOr => "Xor", + Self::Min => "Min", + Self::Max => "Max", + Self::Exchange { compare: None } => "Exchange", + Self::Exchange { .. } => "CompareExchangeWeak", + } + } +} diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs new file mode 100644 index 0000000000..c737934f5e --- /dev/null +++ b/third_party/rust/naga/src/back/wgsl/writer.rs @@ -0,0 +1,1961 @@ +use super::Error; +use crate::{ + back, + proc::{self, NameKey}, + valid, Handle, Module, ShaderStage, TypeInner, +}; +use std::fmt::Write; + +/// Shorthand result used internally by the backend +type BackendResult = Result<(), Error>; + +/// WGSL [attribute](https://gpuweb.github.io/gpuweb/wgsl/#attributes) +enum Attribute { + Binding(u32), + BuiltIn(crate::BuiltIn), + Group(u32), + Invariant, + Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>), + Location(u32), + SecondBlendSource, + Stage(ShaderStage), + WorkGroupSize([u32; 3]), +} + +/// The WGSL form that `write_expr_with_indirection` should use to render a Naga +/// expression. +/// +/// Sometimes a Naga `Expression` alone doesn't provide enough information to +/// choose the right rendering for it in WGSL. For example, one natural WGSL +/// rendering of a Naga `LocalVariable(x)` expression might be `&x`, since +/// `LocalVariable` produces a pointer to the local variable's storage. But when +/// rendering a `Store` statement, the `pointer` operand must be the left hand +/// side of a WGSL assignment, so the proper rendering is `x`. +/// +/// The caller of `write_expr_with_indirection` must provide an `Expected` value +/// to indicate how ambiguous expressions should be rendered. +#[derive(Clone, Copy, Debug)] +enum Indirection { + /// Render pointer-construction expressions as WGSL `ptr`-typed expressions. + /// + /// This is the right choice for most cases. Whenever a Naga pointer + /// expression is not the `pointer` operand of a `Load` or `Store`, it + /// must be a WGSL pointer expression. + Ordinary, + + /// Render pointer-construction expressions as WGSL reference-typed + /// expressions. + /// + /// For example, this is the right choice for the `pointer` operand when + /// rendering a `Store` statement as a WGSL assignment. + Reference, +} + +bitflags::bitflags! { + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct WriterFlags: u32 { + /// Always annotate the type information instead of inferring. + const EXPLICIT_TYPES = 0x1; + } +} + +pub struct Writer<W> { + out: W, + flags: WriterFlags, + names: crate::FastHashMap<NameKey, String>, + namer: proc::Namer, + named_expressions: crate::NamedExpressions, + ep_results: Vec<(ShaderStage, Handle<crate::Type>)>, +} + +impl<W: Write> Writer<W> { + pub fn new(out: W, flags: WriterFlags) -> Self { + Writer { + out, + flags, + names: crate::FastHashMap::default(), + namer: proc::Namer::default(), + named_expressions: crate::NamedExpressions::default(), + ep_results: vec![], + } + } + + fn reset(&mut self, module: &Module) { + self.names.clear(); + self.namer.reset( + module, + crate::keywords::wgsl::RESERVED, + // an identifier must not start with two underscore + &[], + &[], + &["__"], + &mut self.names, + ); + self.named_expressions.clear(); + self.ep_results.clear(); + } + + fn is_builtin_wgsl_struct(&self, module: &Module, handle: Handle<crate::Type>) -> bool { + module + .special_types + .predeclared_types + .values() + .any(|t| *t == handle) + } + + pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult { + self.reset(module); + + // Save all ep result types + for (_, ep) in module.entry_points.iter().enumerate() { + if let Some(ref result) = ep.function.result { + self.ep_results.push((ep.stage, result.ty)); + } + } + + // Write all structs + for (handle, ty) in module.types.iter() { + if let TypeInner::Struct { ref members, .. } = ty.inner { + { + if !self.is_builtin_wgsl_struct(module, handle) { + self.write_struct(module, handle, members)?; + writeln!(self.out)?; + } + } + } + } + + // Write all named constants + let mut constants = module + .constants + .iter() + .filter(|&(_, c)| c.name.is_some()) + .peekable(); + while let Some((handle, _)) = constants.next() { + self.write_global_constant(module, handle)?; + // Add extra newline for readability on last iteration + if constants.peek().is_none() { + writeln!(self.out)?; + } + } + + // Write all globals + for (ty, global) in module.global_variables.iter() { + self.write_global(module, global, ty)?; + } + + if !module.global_variables.is_empty() { + // Add extra newline for readability + writeln!(self.out)?; + } + + // Write all regular functions + for (handle, function) in module.functions.iter() { + let fun_info = &info[handle]; + + let func_ctx = back::FunctionCtx { + ty: back::FunctionType::Function(handle), + info: fun_info, + expressions: &function.expressions, + named_expressions: &function.named_expressions, + }; + + // Write the function + self.write_function(module, function, &func_ctx)?; + + writeln!(self.out)?; + } + + // Write all entry points + for (index, ep) in module.entry_points.iter().enumerate() { + let attributes = match ep.stage { + ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)], + ShaderStage::Compute => vec![ + Attribute::Stage(ShaderStage::Compute), + Attribute::WorkGroupSize(ep.workgroup_size), + ], + }; + + self.write_attributes(&attributes)?; + // Add a newline after attribute + writeln!(self.out)?; + + let func_ctx = back::FunctionCtx { + ty: back::FunctionType::EntryPoint(index as u16), + info: info.get_entry_point(index), + expressions: &ep.function.expressions, + named_expressions: &ep.function.named_expressions, + }; + self.write_function(module, &ep.function, &func_ctx)?; + + if index < module.entry_points.len() - 1 { + writeln!(self.out)?; + } + } + + Ok(()) + } + + /// Helper method used to write struct name + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_struct_name(&mut self, module: &Module, handle: Handle<crate::Type>) -> BackendResult { + if module.types[handle].name.is_none() { + if let Some(&(stage, _)) = self.ep_results.iter().find(|&&(_, ty)| ty == handle) { + let name = match stage { + ShaderStage::Compute => "ComputeOutput", + ShaderStage::Fragment => "FragmentOutput", + ShaderStage::Vertex => "VertexOutput", + }; + + write!(self.out, "{name}")?; + return Ok(()); + } + } + + write!(self.out, "{}", self.names[&NameKey::Type(handle)])?; + + Ok(()) + } + + /// Helper method used to write + /// [functions](https://gpuweb.github.io/gpuweb/wgsl/#functions) + /// + /// # Notes + /// Ends in a newline + fn write_function( + &mut self, + module: &Module, + func: &crate::Function, + func_ctx: &back::FunctionCtx<'_>, + ) -> BackendResult { + let func_name = match func_ctx.ty { + back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)], + back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)], + }; + + // Write function name + write!(self.out, "fn {func_name}(")?; + + // Write function arguments + for (index, arg) in func.arguments.iter().enumerate() { + // Write argument attribute if a binding is present + if let Some(ref binding) = arg.binding { + self.write_attributes(&map_binding_to_attribute(binding))?; + } + // Write argument name + let argument_name = &self.names[&func_ctx.argument_key(index as u32)]; + + write!(self.out, "{argument_name}: ")?; + // Write argument type + self.write_type(module, arg.ty)?; + if index < func.arguments.len() - 1 { + // Add a separator between args + write!(self.out, ", ")?; + } + } + + write!(self.out, ")")?; + + // Write function return type + if let Some(ref result) = func.result { + write!(self.out, " -> ")?; + if let Some(ref binding) = result.binding { + self.write_attributes(&map_binding_to_attribute(binding))?; + } + self.write_type(module, result.ty)?; + } + + write!(self.out, " {{")?; + writeln!(self.out)?; + + // Write function local variables + for (handle, local) in func.local_variables.iter() { + // Write indentation (only for readability) + write!(self.out, "{}", back::INDENT)?; + + // Write the local name + // The leading space is important + write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?; + + // Write the local type + self.write_type(module, local.ty)?; + + // Write the local initializer if needed + if let Some(init) = local.init { + // Put the equal signal only if there's a initializer + // The leading and trailing spaces aren't needed but help with readability + write!(self.out, " = ")?; + + // Write the constant + // `write_constant` adds no trailing or leading space/newline + self.write_expr(module, init, func_ctx)?; + } + + // Finish the local with `;` and add a newline (only for readability) + writeln!(self.out, ";")? + } + + if !func.local_variables.is_empty() { + writeln!(self.out)?; + } + + // Write the function body (statement list) + for sta in func.body.iter() { + // The indentation should always be 1 when writing the function body + self.write_stmt(module, sta, func_ctx, back::Level(1))?; + } + + writeln!(self.out, "}}")?; + + self.named_expressions.clear(); + + Ok(()) + } + + /// Helper method to write a attribute + fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult { + for attribute in attributes { + match *attribute { + Attribute::Location(id) => write!(self.out, "@location({id}) ")?, + Attribute::SecondBlendSource => write!(self.out, "@second_blend_source ")?, + Attribute::BuiltIn(builtin_attrib) => { + let builtin = builtin_str(builtin_attrib)?; + write!(self.out, "@builtin({builtin}) ")?; + } + Attribute::Stage(shader_stage) => { + let stage_str = match shader_stage { + ShaderStage::Vertex => "vertex", + ShaderStage::Fragment => "fragment", + ShaderStage::Compute => "compute", + }; + write!(self.out, "@{stage_str} ")?; + } + Attribute::WorkGroupSize(size) => { + write!( + self.out, + "@workgroup_size({}, {}, {}) ", + size[0], size[1], size[2] + )?; + } + Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?, + Attribute::Group(id) => write!(self.out, "@group({id}) ")?, + Attribute::Invariant => write!(self.out, "@invariant ")?, + Attribute::Interpolate(interpolation, sampling) => { + if sampling.is_some() && sampling != Some(crate::Sampling::Center) { + write!( + self.out, + "@interpolate({}, {}) ", + interpolation_str( + interpolation.unwrap_or(crate::Interpolation::Perspective) + ), + sampling_str(sampling.unwrap_or(crate::Sampling::Center)) + )?; + } else if interpolation.is_some() + && interpolation != Some(crate::Interpolation::Perspective) + { + write!( + self.out, + "@interpolate({}) ", + interpolation_str( + interpolation.unwrap_or(crate::Interpolation::Perspective) + ) + )?; + } + } + }; + } + Ok(()) + } + + /// Helper method used to write structs + /// + /// # Notes + /// Ends in a newline + fn write_struct( + &mut self, + module: &Module, + handle: Handle<crate::Type>, + members: &[crate::StructMember], + ) -> BackendResult { + write!(self.out, "struct ")?; + self.write_struct_name(module, handle)?; + write!(self.out, " {{")?; + writeln!(self.out)?; + for (index, member) in members.iter().enumerate() { + // The indentation is only for readability + write!(self.out, "{}", back::INDENT)?; + if let Some(ref binding) = member.binding { + self.write_attributes(&map_binding_to_attribute(binding))?; + } + // Write struct member name and type + let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; + write!(self.out, "{member_name}: ")?; + self.write_type(module, member.ty)?; + write!(self.out, ",")?; + writeln!(self.out)?; + } + + write!(self.out, "}}")?; + + writeln!(self.out)?; + + Ok(()) + } + + /// Helper method used to write non image/sampler types + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult { + let inner = &module.types[ty].inner; + match *inner { + TypeInner::Struct { .. } => self.write_struct_name(module, ty)?, + ref other => self.write_value_type(module, other)?, + } + + Ok(()) + } + + /// Helper method used to write value types + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult { + match *inner { + TypeInner::Vector { size, scalar } => write!( + self.out, + "vec{}<{}>", + back::vector_size_str(size), + scalar_kind_str(scalar), + )?, + TypeInner::Sampler { comparison: false } => { + write!(self.out, "sampler")?; + } + TypeInner::Sampler { comparison: true } => { + write!(self.out, "sampler_comparison")?; + } + TypeInner::Image { + dim, + arrayed, + class, + } => { + // More about texture types: https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type + use crate::ImageClass as Ic; + + let dim_str = image_dimension_str(dim); + let arrayed_str = if arrayed { "_array" } else { "" }; + let (class_str, multisampled_str, format_str, storage_str) = match class { + Ic::Sampled { kind, multi } => ( + "", + if multi { "multisampled_" } else { "" }, + scalar_kind_str(crate::Scalar { kind, width: 4 }), + "", + ), + Ic::Depth { multi } => { + ("depth_", if multi { "multisampled_" } else { "" }, "", "") + } + Ic::Storage { format, access } => ( + "storage_", + "", + storage_format_str(format), + if access.contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE) + { + ",read_write" + } else if access.contains(crate::StorageAccess::LOAD) { + ",read" + } else { + ",write" + }, + ), + }; + write!( + self.out, + "texture_{class_str}{multisampled_str}{dim_str}{arrayed_str}" + )?; + + if !format_str.is_empty() { + write!(self.out, "<{format_str}{storage_str}>")?; + } + } + TypeInner::Scalar(scalar) => { + write!(self.out, "{}", scalar_kind_str(scalar))?; + } + TypeInner::Atomic(scalar) => { + write!(self.out, "atomic<{}>", scalar_kind_str(scalar))?; + } + TypeInner::Array { + base, + size, + stride: _, + } => { + // More info https://gpuweb.github.io/gpuweb/wgsl/#array-types + // array<A, 3> -- Constant array + // array<A> -- Dynamic array + write!(self.out, "array<")?; + match size { + crate::ArraySize::Constant(len) => { + self.write_type(module, base)?; + write!(self.out, ", {len}")?; + } + crate::ArraySize::Dynamic => { + self.write_type(module, base)?; + } + } + write!(self.out, ">")?; + } + TypeInner::BindingArray { base, size } => { + // More info https://github.com/gpuweb/gpuweb/issues/2105 + write!(self.out, "binding_array<")?; + match size { + crate::ArraySize::Constant(len) => { + self.write_type(module, base)?; + write!(self.out, ", {len}")?; + } + crate::ArraySize::Dynamic => { + self.write_type(module, base)?; + } + } + write!(self.out, ">")?; + } + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + write!( + self.out, + "mat{}x{}<{}>", + back::vector_size_str(columns), + back::vector_size_str(rows), + scalar_kind_str(scalar) + )?; + } + TypeInner::Pointer { base, space } => { + let (address, maybe_access) = address_space_str(space); + // Everything but `AddressSpace::Handle` gives us a `address` name, but + // Naga IR never produces pointers to handles, so it doesn't matter much + // how we write such a type. Just write it as the base type alone. + if let Some(space) = address { + write!(self.out, "ptr<{space}, ")?; + } + self.write_type(module, base)?; + if address.is_some() { + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } + } + TypeInner::ValuePointer { + size: None, + scalar, + space, + } => { + let (address, maybe_access) = address_space_str(space); + if let Some(space) = address { + write!(self.out, "ptr<{}, {}", space, scalar_kind_str(scalar))?; + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } else { + return Err(Error::Unimplemented(format!( + "ValuePointer to AddressSpace::Handle {inner:?}" + ))); + } + } + TypeInner::ValuePointer { + size: Some(size), + scalar, + space, + } => { + let (address, maybe_access) = address_space_str(space); + if let Some(space) = address { + write!( + self.out, + "ptr<{}, vec{}<{}>", + space, + back::vector_size_str(size), + scalar_kind_str(scalar) + )?; + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } else { + return Err(Error::Unimplemented(format!( + "ValuePointer to AddressSpace::Handle {inner:?}" + ))); + } + write!(self.out, ">")?; + } + _ => { + return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))); + } + } + + Ok(()) + } + /// Helper method used to write statements + /// + /// # Notes + /// Always adds a newline + fn write_stmt( + &mut self, + module: &Module, + stmt: &crate::Statement, + func_ctx: &back::FunctionCtx<'_>, + level: back::Level, + ) -> BackendResult { + use crate::{Expression, Statement}; + + match *stmt { + Statement::Emit(ref range) => { + for handle in range.clone() { + let info = &func_ctx.info[handle]; + let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) { + // Front end provides names for all variables at the start of writing. + // But we write them to step by step. We need to recache them + // Otherwise, we could accidentally write variable name instead of full expression. + // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords. + Some(self.namer.call(name)) + } else { + let expr = &func_ctx.expressions[handle]; + let min_ref_count = expr.bake_ref_count(); + // Forcefully creating baking expressions in some cases to help with readability + let required_baking_expr = match *expr { + Expression::ImageLoad { .. } + | Expression::ImageQuery { .. } + | Expression::ImageSample { .. } => true, + _ => false, + }; + if min_ref_count <= info.ref_count || required_baking_expr { + Some(format!("{}{}", back::BAKE_PREFIX, handle.index())) + } else { + None + } + }; + + if let Some(name) = expr_name { + write!(self.out, "{level}")?; + self.start_named_expr(module, handle, func_ctx, &name)?; + self.write_expr(module, handle, func_ctx)?; + self.named_expressions.insert(handle, name); + writeln!(self.out, ";")?; + } + } + } + // TODO: copy-paste from glsl-out + Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{level}")?; + write!(self.out, "if ")?; + self.write_expr(module, condition, func_ctx)?; + writeln!(self.out, " {{")?; + + let l2 = level.next(); + for sta in accept { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, l2)?; + } + + // If there are no statements in the reject block we skip writing it + // This is only for readability + if !reject.is_empty() { + writeln!(self.out, "{level}}} else {{")?; + + for sta in reject { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, l2)?; + } + } + + writeln!(self.out, "{level}}}")? + } + Statement::Return { value } => { + write!(self.out, "{level}")?; + write!(self.out, "return")?; + if let Some(return_value) = value { + // The leading space is important + write!(self.out, " ")?; + self.write_expr(module, return_value, func_ctx)?; + } + writeln!(self.out, ";")?; + } + // TODO: copy-paste from glsl-out + Statement::Kill => { + write!(self.out, "{level}")?; + writeln!(self.out, "discard;")? + } + Statement::Store { pointer, value } => { + write!(self.out, "{level}")?; + + let is_atomic_pointer = func_ctx + .resolve_type(pointer, &module.types) + .is_atomic_pointer(&module.types); + + if is_atomic_pointer { + write!(self.out, "atomicStore(")?; + self.write_expr(module, pointer, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr_with_indirection( + module, + pointer, + func_ctx, + Indirection::Reference, + )?; + write!(self.out, " = ")?; + self.write_expr(module, value, func_ctx)?; + } + writeln!(self.out, ";")? + } + Statement::Call { + function, + ref arguments, + result, + } => { + write!(self.out, "{level}")?; + if let Some(expr) = result { + let name = format!("{}{}", back::BAKE_PREFIX, expr.index()); + self.start_named_expr(module, expr, func_ctx, &name)?; + self.named_expressions.insert(expr, name); + } + let func_name = &self.names[&NameKey::Function(function)]; + write!(self.out, "{func_name}(")?; + for (index, &argument) in arguments.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + self.write_expr(module, argument, func_ctx)?; + } + writeln!(self.out, ");")? + } + Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + let fun_str = fun.to_wgsl(); + write!(self.out, "atomic{fun_str}(")?; + self.write_expr(module, pointer, func_ctx)?; + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + write!(self.out, ", ")?; + self.write_expr(module, cmp, func_ctx)?; + } + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ");")? + } + Statement::WorkGroupUniformLoad { pointer, result } => { + write!(self.out, "{level}")?; + // TODO: Obey named expressions here. + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + write!(self.out, "workgroupUniformLoad(")?; + self.write_expr(module, pointer, func_ctx)?; + writeln!(self.out, ");")?; + } + Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + write!(self.out, "{level}")?; + write!(self.out, "textureStore(")?; + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + if let Some(array_index_expr) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index_expr, func_ctx)?; + } + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ");")?; + } + // TODO: copy-paste from glsl-out + Statement::Block(ref block) => { + write!(self.out, "{level}")?; + writeln!(self.out, "{{")?; + for sta in block.iter() { + // Increase the indentation to help with readability + self.write_stmt(module, sta, func_ctx, level.next())? + } + writeln!(self.out, "{level}}}")? + } + Statement::Switch { + selector, + ref cases, + } => { + // Start the switch + write!(self.out, "{level}")?; + write!(self.out, "switch ")?; + self.write_expr(module, selector, func_ctx)?; + writeln!(self.out, " {{")?; + + let l2 = level.next(); + let mut new_case = true; + for case in cases { + if case.fall_through && !case.body.is_empty() { + // TODO: we could do the same workaround as we did for the HLSL backend + return Err(Error::Unimplemented( + "fall-through switch case block".into(), + )); + } + + match case.value { + crate::SwitchValue::I32(value) => { + if new_case { + write!(self.out, "{l2}case ")?; + } + write!(self.out, "{value}")?; + } + crate::SwitchValue::U32(value) => { + if new_case { + write!(self.out, "{l2}case ")?; + } + write!(self.out, "{value}u")?; + } + crate::SwitchValue::Default => { + if new_case { + if case.fall_through { + write!(self.out, "{l2}case ")?; + } else { + write!(self.out, "{l2}")?; + } + } + write!(self.out, "default")?; + } + } + + new_case = !case.fall_through; + + if case.fall_through { + write!(self.out, ", ")?; + } else { + writeln!(self.out, ": {{")?; + } + + for sta in case.body.iter() { + self.write_stmt(module, sta, func_ctx, l2.next())?; + } + + if !case.fall_through { + writeln!(self.out, "{l2}}}")?; + } + } + + writeln!(self.out, "{level}}}")? + } + Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + write!(self.out, "{level}")?; + writeln!(self.out, "loop {{")?; + + let l2 = level.next(); + for sta in body.iter() { + self.write_stmt(module, sta, func_ctx, l2)?; + } + + // The continuing is optional so we don't need to write it if + // it is empty, but the `break if` counts as a continuing statement + // so even if `continuing` is empty we must generate it if a + // `break if` exists + if !continuing.is_empty() || break_if.is_some() { + writeln!(self.out, "{l2}continuing {{")?; + for sta in continuing.iter() { + self.write_stmt(module, sta, func_ctx, l2.next())?; + } + + // The `break if` is always the last + // statement of the `continuing` block + if let Some(condition) = break_if { + // The trailing space is important + write!(self.out, "{}break if ", l2.next())?; + self.write_expr(module, condition, func_ctx)?; + // Close the `break if` statement + writeln!(self.out, ";")?; + } + + writeln!(self.out, "{l2}}}")?; + } + + writeln!(self.out, "{level}}}")? + } + Statement::Break => { + writeln!(self.out, "{level}break;")?; + } + Statement::Continue => { + writeln!(self.out, "{level}continue;")?; + } + Statement::Barrier(barrier) => { + if barrier.contains(crate::Barrier::STORAGE) { + writeln!(self.out, "{level}storageBarrier();")?; + } + + if barrier.contains(crate::Barrier::WORK_GROUP) { + writeln!(self.out, "{level}workgroupBarrier();")?; + } + } + Statement::RayQuery { .. } => unreachable!(), + } + + Ok(()) + } + + /// Return the sort of indirection that `expr`'s plain form evaluates to. + /// + /// An expression's 'plain form' is the most general rendition of that + /// expression into WGSL, lacking `&` or `*` operators: + /// + /// - The plain form of `LocalVariable(x)` is simply `x`, which is a reference + /// to the local variable's storage. + /// + /// - The plain form of `GlobalVariable(g)` is simply `g`, which is usually a + /// reference to the global variable's storage. However, globals in the + /// `Handle` address space are immutable, and `GlobalVariable` expressions for + /// those produce the value directly, not a pointer to it. Such + /// `GlobalVariable` expressions are `Ordinary`. + /// + /// - `Access` and `AccessIndex` are `Reference` when their `base` operand is a + /// pointer. If they are applied directly to a composite value, they are + /// `Ordinary`. + /// + /// Note that `FunctionArgument` expressions are never `Reference`, even when + /// the argument's type is `Pointer`. `FunctionArgument` always evaluates to the + /// argument's value directly, so any pointer it produces is merely the value + /// passed by the caller. + fn plain_form_indirection( + &self, + expr: Handle<crate::Expression>, + module: &Module, + func_ctx: &back::FunctionCtx<'_>, + ) -> Indirection { + use crate::Expression as Ex; + + // Named expressions are `let` expressions, which apply the Load Rule, + // so if their type is a Naga pointer, then that must be a WGSL pointer + // as well. + if self.named_expressions.contains_key(&expr) { + return Indirection::Ordinary; + } + + match func_ctx.expressions[expr] { + Ex::LocalVariable(_) => Indirection::Reference, + Ex::GlobalVariable(handle) => { + let global = &module.global_variables[handle]; + match global.space { + crate::AddressSpace::Handle => Indirection::Ordinary, + _ => Indirection::Reference, + } + } + Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => { + let base_ty = func_ctx.resolve_type(base, &module.types); + match *base_ty { + crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => { + Indirection::Reference + } + _ => Indirection::Ordinary, + } + } + _ => Indirection::Ordinary, + } + } + + fn start_named_expr( + &mut self, + module: &Module, + handle: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx, + name: &str, + ) -> BackendResult { + // Write variable name + write!(self.out, "let {name}")?; + if self.flags.contains(WriterFlags::EXPLICIT_TYPES) { + write!(self.out, ": ")?; + let ty = &func_ctx.info[handle].ty; + // Write variable type + match *ty { + proc::TypeResolution::Handle(handle) => { + self.write_type(module, handle)?; + } + proc::TypeResolution::Value(ref inner) => { + self.write_value_type(module, inner)?; + } + } + } + + write!(self.out, " = ")?; + Ok(()) + } + + /// Write the ordinary WGSL form of `expr`. + /// + /// See `write_expr_with_indirection` for details. + fn write_expr( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + ) -> BackendResult { + self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary) + } + + /// Write `expr` as a WGSL expression with the requested indirection. + /// + /// In terms of the WGSL grammar, the resulting expression is a + /// `singular_expression`. It may be parenthesized. This makes it suitable + /// for use as the operand of a unary or binary operator without worrying + /// about precedence. + /// + /// This does not produce newlines or indentation. + /// + /// The `requested` argument indicates (roughly) whether Naga + /// `Pointer`-valued expressions represent WGSL references or pointers. See + /// `Indirection` for details. + fn write_expr_with_indirection( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + requested: Indirection, + ) -> BackendResult { + // If the plain form of the expression is not what we need, emit the + // operator necessary to correct that. + let plain = self.plain_form_indirection(expr, module, func_ctx); + match (requested, plain) { + (Indirection::Ordinary, Indirection::Reference) => { + write!(self.out, "(&")?; + self.write_expr_plain_form(module, expr, func_ctx, plain)?; + write!(self.out, ")")?; + } + (Indirection::Reference, Indirection::Ordinary) => { + write!(self.out, "(*")?; + self.write_expr_plain_form(module, expr, func_ctx, plain)?; + write!(self.out, ")")?; + } + (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?, + } + + Ok(()) + } + + fn write_const_expression( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + ) -> BackendResult { + self.write_possibly_const_expression( + module, + expr, + &module.const_expressions, + |writer, expr| writer.write_const_expression(module, expr), + ) + } + + fn write_possibly_const_expression<E>( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + expressions: &crate::Arena<crate::Expression>, + write_expression: E, + ) -> BackendResult + where + E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult, + { + use crate::Expression; + + match expressions[expr] { + Expression::Literal(literal) => match literal { + crate::Literal::F32(value) => write!(self.out, "{}f", value)?, + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + crate::Literal::I32(value) => { + // `-2147483648i` is not valid WGSL. The most negative `i32` + // value can only be expressed in WGSL using AbstractInt and + // a unary negation operator. + if value == i32::MIN { + write!(self.out, "i32(-2147483648)")?; + } else { + write!(self.out, "{}i", value)?; + } + } + crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + crate::Literal::F64(value) => write!(self.out, "{:?}lf", value)?, + crate::Literal::I64(_) => { + return Err(Error::Custom("unsupported i64 literal".to_string())); + } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } + }, + Expression::Constant(handle) => { + let constant = &module.constants[handle]; + if constant.name.is_some() { + write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; + } else { + self.write_const_expression(module, constant.init)?; + } + } + Expression::ZeroValue(ty) => { + self.write_type(module, ty)?; + write!(self.out, "()")?; + } + Expression::Compose { ty, ref components } => { + self.write_type(module, ty)?; + write!(self.out, "(")?; + for (index, component) in components.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + write_expression(self, *component)?; + } + write!(self.out, ")")? + } + Expression::Splat { size, value } => { + let size = back::vector_size_str(size); + write!(self.out, "vec{size}(")?; + write_expression(self, value)?; + write!(self.out, ")")?; + } + _ => unreachable!(), + } + + Ok(()) + } + + /// Write the 'plain form' of `expr`. + /// + /// An expression's 'plain form' is the most general rendition of that + /// expression into WGSL, lacking `&` or `*` operators. The plain forms of + /// `LocalVariable(x)` and `GlobalVariable(g)` are simply `x` and `g`. Such + /// Naga expressions represent both WGSL pointers and references; it's the + /// caller's responsibility to distinguish those cases appropriately. + fn write_expr_plain_form( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + indirection: Indirection, + ) -> BackendResult { + use crate::Expression; + + if let Some(name) = self.named_expressions.get(&expr) { + write!(self.out, "{name}")?; + return Ok(()); + } + + let expression = &func_ctx.expressions[expr]; + + // Write the plain WGSL form of a Naga expression. + // + // The plain form of `LocalVariable` and `GlobalVariable` expressions is + // simply the variable name; `*` and `&` operators are never emitted. + // + // The plain form of `Access` and `AccessIndex` expressions are WGSL + // `postfix_expression` forms for member/component access and + // subscripting. + match *expression { + Expression::Literal(_) + | Expression::Constant(_) + | Expression::ZeroValue(_) + | Expression::Compose { .. } + | Expression::Splat { .. } => { + self.write_possibly_const_expression( + module, + expr, + func_ctx.expressions, + |writer, expr| writer.write_expr(module, expr, func_ctx), + )?; + } + Expression::FunctionArgument(pos) => { + let name_key = func_ctx.argument_key(pos); + let name = &self.names[&name_key]; + write!(self.out, "{name}")?; + } + Expression::Binary { op, left, right } => { + write!(self.out, "(")?; + self.write_expr(module, left, func_ctx)?; + write!(self.out, " {} ", back::binary_operation_str(op))?; + self.write_expr(module, right, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Access { base, index } => { + self.write_expr_with_indirection(module, base, func_ctx, indirection)?; + write!(self.out, "[")?; + self.write_expr(module, index, func_ctx)?; + write!(self.out, "]")? + } + Expression::AccessIndex { base, index } => { + let base_ty_res = &func_ctx.info[base].ty; + let mut resolved = base_ty_res.inner_with(&module.types); + + self.write_expr_with_indirection(module, base, func_ctx, indirection)?; + + let base_ty_handle = match *resolved { + TypeInner::Pointer { base, space: _ } => { + resolved = &module.types[base].inner; + Some(base) + } + _ => base_ty_res.handle(), + }; + + match *resolved { + TypeInner::Vector { .. } => { + // Write vector access as a swizzle + write!(self.out, ".{}", back::COMPONENTS[index as usize])? + } + TypeInner::Matrix { .. } + | TypeInner::Array { .. } + | TypeInner::BindingArray { .. } + | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?, + TypeInner::Struct { .. } => { + // This will never panic in case the type is a `Struct`, this is not true + // for other types so we can only check while inside this match arm + let ty = base_ty_handle.unwrap(); + + write!( + self.out, + ".{}", + &self.names[&NameKey::StructMember(ty, index)] + )? + } + ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))), + } + } + Expression::ImageSample { + image, + sampler, + gather: None, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + use crate::SampleLevel as Sl; + + let suffix_cmp = match depth_ref { + Some(_) => "Compare", + None => "", + }; + let suffix_level = match level { + Sl::Auto => "", + Sl::Zero | Sl::Exact(_) => "Level", + Sl::Bias(_) => "Bias", + Sl::Gradient { .. } => "Grad", + }; + + write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?; + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, sampler, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + + if let Some(array_index) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index, func_ctx)?; + } + + if let Some(depth_ref) = depth_ref { + write!(self.out, ", ")?; + self.write_expr(module, depth_ref, func_ctx)?; + } + + match level { + Sl::Auto => {} + Sl::Zero => { + // Level 0 is implied for depth comparison + if depth_ref.is_none() { + write!(self.out, ", 0.0")?; + } + } + Sl::Exact(expr) => { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + Sl::Bias(expr) => { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + Sl::Gradient { x, y } => { + write!(self.out, ", ")?; + self.write_expr(module, x, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, y, func_ctx)?; + } + } + + if let Some(offset) = offset { + write!(self.out, ", ")?; + self.write_const_expression(module, offset)?; + } + + write!(self.out, ")")?; + } + + Expression::ImageSample { + image, + sampler, + gather: Some(component), + coordinate, + array_index, + offset, + level: _, + depth_ref, + } => { + let suffix_cmp = match depth_ref { + Some(_) => "Compare", + None => "", + }; + + write!(self.out, "textureGather{suffix_cmp}(")?; + match *func_ctx.resolve_type(image, &module.types) { + TypeInner::Image { + class: crate::ImageClass::Depth { multi: _ }, + .. + } => {} + _ => { + write!(self.out, "{}, ", component as u8)?; + } + } + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, sampler, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + + if let Some(array_index) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index, func_ctx)?; + } + + if let Some(depth_ref) = depth_ref { + write!(self.out, ", ")?; + self.write_expr(module, depth_ref, func_ctx)?; + } + + if let Some(offset) = offset { + write!(self.out, ", ")?; + self.write_const_expression(module, offset)?; + } + + write!(self.out, ")")?; + } + Expression::ImageQuery { image, query } => { + use crate::ImageQuery as Iq; + + let texture_function = match query { + Iq::Size { .. } => "textureDimensions", + Iq::NumLevels => "textureNumLevels", + Iq::NumLayers => "textureNumLayers", + Iq::NumSamples => "textureNumSamples", + }; + + write!(self.out, "{texture_function}(")?; + self.write_expr(module, image, func_ctx)?; + if let Iq::Size { level: Some(level) } = query { + write!(self.out, ", ")?; + self.write_expr(module, level, func_ctx)?; + }; + write!(self.out, ")")?; + } + + Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + write!(self.out, "textureLoad(")?; + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + if let Some(array_index) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index, func_ctx)?; + } + if let Some(index) = sample.or(level) { + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + write!(self.out, ")")?; + } + Expression::GlobalVariable(handle) => { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{name}")?; + } + + Expression::As { + expr, + kind, + convert, + } => { + let inner = func_ctx.resolve_type(expr, &module.types); + match *inner { + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + let scalar = crate::Scalar { + kind, + width: convert.unwrap_or(scalar.width), + }; + let scalar_kind_str = scalar_kind_str(scalar); + write!( + self.out, + "mat{}x{}<{}>", + back::vector_size_str(columns), + back::vector_size_str(rows), + scalar_kind_str + )?; + } + TypeInner::Vector { + size, + scalar: crate::Scalar { width, .. }, + } => { + let scalar = crate::Scalar { + kind, + width: convert.unwrap_or(width), + }; + let vector_size_str = back::vector_size_str(size); + let scalar_kind_str = scalar_kind_str(scalar); + if convert.is_some() { + write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?; + } else { + write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?; + } + } + TypeInner::Scalar(crate::Scalar { width, .. }) => { + let scalar = crate::Scalar { + kind, + width: convert.unwrap_or(width), + }; + let scalar_kind_str = scalar_kind_str(scalar); + if convert.is_some() { + write!(self.out, "{scalar_kind_str}")? + } else { + write!(self.out, "bitcast<{scalar_kind_str}>")? + } + } + _ => { + return Err(Error::Unimplemented(format!( + "write_expr expression::as {inner:?}" + ))); + } + }; + write!(self.out, "(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Load { pointer } => { + let is_atomic_pointer = func_ctx + .resolve_type(pointer, &module.types) + .is_atomic_pointer(&module.types); + + if is_atomic_pointer { + write!(self.out, "atomicLoad(")?; + self.write_expr(module, pointer, func_ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr_with_indirection( + module, + pointer, + func_ctx, + Indirection::Reference, + )?; + } + } + Expression::LocalVariable(handle) => { + write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])? + } + Expression::ArrayLength(expr) => { + write!(self.out, "arrayLength(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + + Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + + enum Function { + Regular(&'static str), + } + + let function = match fun { + Mf::Abs => Function::Regular("abs"), + Mf::Min => Function::Regular("min"), + Mf::Max => Function::Regular("max"), + Mf::Clamp => Function::Regular("clamp"), + Mf::Saturate => Function::Regular("saturate"), + // trigonometry + Mf::Cos => Function::Regular("cos"), + Mf::Cosh => Function::Regular("cosh"), + Mf::Sin => Function::Regular("sin"), + Mf::Sinh => Function::Regular("sinh"), + Mf::Tan => Function::Regular("tan"), + Mf::Tanh => Function::Regular("tanh"), + Mf::Acos => Function::Regular("acos"), + Mf::Asin => Function::Regular("asin"), + Mf::Atan => Function::Regular("atan"), + Mf::Atan2 => Function::Regular("atan2"), + Mf::Asinh => Function::Regular("asinh"), + Mf::Acosh => Function::Regular("acosh"), + Mf::Atanh => Function::Regular("atanh"), + Mf::Radians => Function::Regular("radians"), + Mf::Degrees => Function::Regular("degrees"), + // decomposition + Mf::Ceil => Function::Regular("ceil"), + Mf::Floor => Function::Regular("floor"), + Mf::Round => Function::Regular("round"), + Mf::Fract => Function::Regular("fract"), + Mf::Trunc => Function::Regular("trunc"), + Mf::Modf => Function::Regular("modf"), + Mf::Frexp => Function::Regular("frexp"), + Mf::Ldexp => Function::Regular("ldexp"), + // exponent + Mf::Exp => Function::Regular("exp"), + Mf::Exp2 => Function::Regular("exp2"), + Mf::Log => Function::Regular("log"), + Mf::Log2 => Function::Regular("log2"), + Mf::Pow => Function::Regular("pow"), + // geometry + Mf::Dot => Function::Regular("dot"), + Mf::Cross => Function::Regular("cross"), + Mf::Distance => Function::Regular("distance"), + Mf::Length => Function::Regular("length"), + Mf::Normalize => Function::Regular("normalize"), + Mf::FaceForward => Function::Regular("faceForward"), + Mf::Reflect => Function::Regular("reflect"), + Mf::Refract => Function::Regular("refract"), + // computational + Mf::Sign => Function::Regular("sign"), + Mf::Fma => Function::Regular("fma"), + Mf::Mix => Function::Regular("mix"), + Mf::Step => Function::Regular("step"), + Mf::SmoothStep => Function::Regular("smoothstep"), + Mf::Sqrt => Function::Regular("sqrt"), + Mf::InverseSqrt => Function::Regular("inverseSqrt"), + Mf::Transpose => Function::Regular("transpose"), + Mf::Determinant => Function::Regular("determinant"), + // bits + Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"), + Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"), + Mf::CountOneBits => Function::Regular("countOneBits"), + Mf::ReverseBits => Function::Regular("reverseBits"), + Mf::ExtractBits => Function::Regular("extractBits"), + Mf::InsertBits => Function::Regular("insertBits"), + Mf::FindLsb => Function::Regular("firstTrailingBit"), + Mf::FindMsb => Function::Regular("firstLeadingBit"), + // data packing + Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"), + Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"), + Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"), + Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"), + Mf::Pack2x16float => Function::Regular("pack2x16float"), + // data unpacking + Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"), + Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"), + Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"), + Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"), + Mf::Unpack2x16float => Function::Regular("unpack2x16float"), + Mf::Inverse | Mf::Outer => { + return Err(Error::UnsupportedMathFunction(fun)); + } + }; + + match function { + Function::Regular(fun_name) => { + write!(self.out, "{fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + write!(self.out, ")")? + } + } + } + + Expression::Swizzle { + size, + vector, + pattern, + } => { + self.write_expr(module, vector, func_ctx)?; + write!(self.out, ".")?; + for &sc in pattern[..size as usize].iter() { + self.out.write_char(back::COMPONENTS[sc as usize])?; + } + } + Expression::Unary { op, expr } => { + let unary = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::LogicalNot => "!", + crate::UnaryOperator::BitwiseNot => "~", + }; + + write!(self.out, "{unary}(")?; + self.write_expr(module, expr, func_ctx)?; + + write!(self.out, ")")? + } + + Expression::Select { + condition, + accept, + reject, + } => { + write!(self.out, "select(")?; + self.write_expr(module, reject, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, accept, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, condition, func_ctx)?; + write!(self.out, ")")? + } + Expression::Derivative { axis, ctrl, expr } => { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + let op = match (axis, ctrl) { + (Axis::X, Ctrl::Coarse) => "dpdxCoarse", + (Axis::X, Ctrl::Fine) => "dpdxFine", + (Axis::X, Ctrl::None) => "dpdx", + (Axis::Y, Ctrl::Coarse) => "dpdyCoarse", + (Axis::Y, Ctrl::Fine) => "dpdyFine", + (Axis::Y, Ctrl::None) => "dpdy", + (Axis::Width, Ctrl::Coarse) => "fwidthCoarse", + (Axis::Width, Ctrl::Fine) => "fwidthFine", + (Axis::Width, Ctrl::None) => "fwidth", + }; + write!(self.out, "{op}(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")? + } + Expression::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + + let fun_name = match fun { + Rf::All => "all", + Rf::Any => "any", + _ => return Err(Error::UnsupportedRelationalFunction(fun)), + }; + write!(self.out, "{fun_name}(")?; + + self.write_expr(module, argument, func_ctx)?; + + write!(self.out, ")")? + } + // Not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), + // Nothing to do here, since call expression already cached + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::RayQueryProceedResult + | Expression::WorkGroupUniformLoadResult { .. } => {} + } + + Ok(()) + } + + /// Helper method used to write global variables + /// # Notes + /// Always adds a newline + fn write_global( + &mut self, + module: &Module, + global: &crate::GlobalVariable, + handle: Handle<crate::GlobalVariable>, + ) -> BackendResult { + // Write group and binding attributes if present + if let Some(ref binding) = global.binding { + self.write_attributes(&[ + Attribute::Group(binding.group), + Attribute::Binding(binding.binding), + ])?; + writeln!(self.out)?; + } + + // First write global name and address space if supported + write!(self.out, "var")?; + let (address, maybe_access) = address_space_str(global.space); + if let Some(space) = address { + write!(self.out, "<{space}")?; + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } + write!( + self.out, + " {}: ", + &self.names[&NameKey::GlobalVariable(handle)] + )?; + + // Write global type + self.write_type(module, global.ty)?; + + // Write initializer + if let Some(init) = global.init { + write!(self.out, " = ")?; + self.write_const_expression(module, init)?; + } + + // End with semicolon + writeln!(self.out, ";")?; + + Ok(()) + } + + /// Helper method used to write global constants + /// + /// # Notes + /// Ends in a newline + fn write_global_constant( + &mut self, + module: &Module, + handle: Handle<crate::Constant>, + ) -> BackendResult { + let name = &self.names[&NameKey::Constant(handle)]; + // First write only constant name + write!(self.out, "const {name}: ")?; + self.write_type(module, module.constants[handle].ty)?; + write!(self.out, " = ")?; + let init = module.constants[handle].init; + self.write_const_expression(module, init)?; + writeln!(self.out, ";")?; + + Ok(()) + } + + // See https://github.com/rust-lang/rust-clippy/issues/4979. + #[allow(clippy::missing_const_for_fn)] + pub fn finish(self) -> W { + self.out + } +} + +fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { + use crate::BuiltIn as Bi; + + Ok(match built_in { + Bi::VertexIndex => "vertex_index", + Bi::InstanceIndex => "instance_index", + Bi::Position { .. } => "position", + Bi::FrontFacing => "front_facing", + Bi::FragDepth => "frag_depth", + Bi::LocalInvocationId => "local_invocation_id", + Bi::LocalInvocationIndex => "local_invocation_index", + Bi::GlobalInvocationId => "global_invocation_id", + Bi::WorkGroupId => "workgroup_id", + Bi::NumWorkGroups => "num_workgroups", + Bi::SampleIndex => "sample_index", + Bi::SampleMask => "sample_mask", + Bi::PrimitiveIndex => "primitive_index", + Bi::ViewIndex => "view_index", + Bi::BaseInstance + | Bi::BaseVertex + | Bi::ClipDistance + | Bi::CullDistance + | Bi::PointSize + | Bi::PointCoord + | Bi::WorkGroupSize => { + return Err(Error::Custom(format!("Unsupported builtin {built_in:?}"))) + } + }) +} + +const fn image_dimension_str(dim: crate::ImageDimension) -> &'static str { + use crate::ImageDimension as IDim; + + match dim { + IDim::D1 => "1d", + IDim::D2 => "2d", + IDim::D3 => "3d", + IDim::Cube => "cube", + } +} + +const fn scalar_kind_str(scalar: crate::Scalar) -> &'static str { + use crate::Scalar; + use crate::ScalarKind as Sk; + + match scalar { + Scalar { + kind: Sk::Float, + width: 8, + } => "f64", + Scalar { + kind: Sk::Float, + width: 4, + } => "f32", + Scalar { + kind: Sk::Sint, + width: 4, + } => "i32", + Scalar { + kind: Sk::Uint, + width: 4, + } => "u32", + Scalar { + kind: Sk::Bool, + width: 1, + } => "bool", + _ => unreachable!(), + } +} + +const fn storage_format_str(format: crate::StorageFormat) -> &'static str { + use crate::StorageFormat as Sf; + + match format { + Sf::R8Unorm => "r8unorm", + Sf::R8Snorm => "r8snorm", + Sf::R8Uint => "r8uint", + Sf::R8Sint => "r8sint", + Sf::R16Uint => "r16uint", + Sf::R16Sint => "r16sint", + Sf::R16Float => "r16float", + Sf::Rg8Unorm => "rg8unorm", + Sf::Rg8Snorm => "rg8snorm", + Sf::Rg8Uint => "rg8uint", + Sf::Rg8Sint => "rg8sint", + Sf::R32Uint => "r32uint", + Sf::R32Sint => "r32sint", + Sf::R32Float => "r32float", + Sf::Rg16Uint => "rg16uint", + Sf::Rg16Sint => "rg16sint", + Sf::Rg16Float => "rg16float", + Sf::Rgba8Unorm => "rgba8unorm", + Sf::Rgba8Snorm => "rgba8snorm", + Sf::Rgba8Uint => "rgba8uint", + Sf::Rgba8Sint => "rgba8sint", + Sf::Bgra8Unorm => "bgra8unorm", + Sf::Rgb10a2Uint => "rgb10a2uint", + Sf::Rgb10a2Unorm => "rgb10a2unorm", + Sf::Rg11b10Float => "rg11b10float", + Sf::Rg32Uint => "rg32uint", + Sf::Rg32Sint => "rg32sint", + Sf::Rg32Float => "rg32float", + Sf::Rgba16Uint => "rgba16uint", + Sf::Rgba16Sint => "rgba16sint", + Sf::Rgba16Float => "rgba16float", + Sf::Rgba32Uint => "rgba32uint", + Sf::Rgba32Sint => "rgba32sint", + Sf::Rgba32Float => "rgba32float", + Sf::R16Unorm => "r16unorm", + Sf::R16Snorm => "r16snorm", + Sf::Rg16Unorm => "rg16unorm", + Sf::Rg16Snorm => "rg16snorm", + Sf::Rgba16Unorm => "rgba16unorm", + Sf::Rgba16Snorm => "rgba16snorm", + } +} + +/// Helper function that returns the string corresponding to the WGSL interpolation qualifier +const fn interpolation_str(interpolation: crate::Interpolation) -> &'static str { + use crate::Interpolation as I; + + match interpolation { + I::Perspective => "perspective", + I::Linear => "linear", + I::Flat => "flat", + } +} + +/// Return the WGSL auxiliary qualifier for the given sampling value. +const fn sampling_str(sampling: crate::Sampling) -> &'static str { + use crate::Sampling as S; + + match sampling { + S::Center => "", + S::Centroid => "centroid", + S::Sample => "sample", + } +} + +const fn address_space_str( + space: crate::AddressSpace, +) -> (Option<&'static str>, Option<&'static str>) { + use crate::AddressSpace as As; + + ( + Some(match space { + As::Private => "private", + As::Uniform => "uniform", + As::Storage { access } => { + if access.contains(crate::StorageAccess::STORE) { + return (Some("storage"), Some("read_write")); + } else { + "storage" + } + } + As::PushConstant => "push_constant", + As::WorkGroup => "workgroup", + As::Handle => return (None, None), + As::Function => "function", + }), + None, + ) +} + +fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> { + match *binding { + crate::Binding::BuiltIn(built_in) => { + if let crate::BuiltIn::Position { invariant: true } = built_in { + vec![Attribute::BuiltIn(built_in), Attribute::Invariant] + } else { + vec![Attribute::BuiltIn(built_in)] + } + } + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source: false, + } => vec![ + Attribute::Location(location), + Attribute::Interpolate(interpolation, sampling), + ], + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source: true, + } => vec![ + Attribute::Location(location), + Attribute::SecondBlendSource, + Attribute::Interpolate(interpolation, sampling), + ], + } +} diff --git a/third_party/rust/naga/src/block.rs b/third_party/rust/naga/src/block.rs new file mode 100644 index 0000000000..0abda9da7c --- /dev/null +++ b/third_party/rust/naga/src/block.rs @@ -0,0 +1,123 @@ +use crate::{Span, Statement}; +use std::ops::{Deref, DerefMut, RangeBounds}; + +/// A code block is a vector of statements, with maybe a vector of spans. +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "serialize", serde(transparent))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct Block { + body: Vec<Statement>, + #[cfg_attr(feature = "serialize", serde(skip))] + span_info: Vec<Span>, +} + +impl Block { + pub const fn new() -> Self { + Self { + body: Vec::new(), + span_info: Vec::new(), + } + } + + pub fn from_vec(body: Vec<Statement>) -> Self { + let span_info = std::iter::repeat(Span::default()) + .take(body.len()) + .collect(); + Self { body, span_info } + } + + pub fn with_capacity(capacity: usize) -> Self { + Self { + body: Vec::with_capacity(capacity), + span_info: Vec::with_capacity(capacity), + } + } + + #[allow(unused_variables)] + pub fn push(&mut self, end: Statement, span: Span) { + self.body.push(end); + self.span_info.push(span); + } + + pub fn extend(&mut self, item: Option<(Statement, Span)>) { + if let Some((end, span)) = item { + self.push(end, span) + } + } + + pub fn extend_block(&mut self, other: Self) { + self.span_info.extend(other.span_info); + self.body.extend(other.body); + } + + pub fn append(&mut self, other: &mut Self) { + self.span_info.append(&mut other.span_info); + self.body.append(&mut other.body); + } + + pub fn cull<R: RangeBounds<usize> + Clone>(&mut self, range: R) { + self.span_info.drain(range.clone()); + self.body.drain(range); + } + + pub fn splice<R: RangeBounds<usize> + Clone>(&mut self, range: R, other: Self) { + self.span_info.splice(range.clone(), other.span_info); + self.body.splice(range, other.body); + } + pub fn span_iter(&self) -> impl Iterator<Item = (&Statement, &Span)> { + let span_iter = self.span_info.iter(); + self.body.iter().zip(span_iter) + } + + pub fn span_iter_mut(&mut self) -> impl Iterator<Item = (&mut Statement, Option<&mut Span>)> { + let span_iter = self.span_info.iter_mut().map(Some); + self.body.iter_mut().zip(span_iter) + } + + pub fn is_empty(&self) -> bool { + self.body.is_empty() + } + + pub fn len(&self) -> usize { + self.body.len() + } +} + +impl Deref for Block { + type Target = [Statement]; + fn deref(&self) -> &[Statement] { + &self.body + } +} + +impl DerefMut for Block { + fn deref_mut(&mut self) -> &mut [Statement] { + &mut self.body + } +} + +impl<'a> IntoIterator for &'a Block { + type Item = &'a Statement; + type IntoIter = std::slice::Iter<'a, Statement>; + + fn into_iter(self) -> std::slice::Iter<'a, Statement> { + self.iter() + } +} + +#[cfg(feature = "deserialize")] +impl<'de> serde::Deserialize<'de> for Block { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + Ok(Self::from_vec(Vec::deserialize(deserializer)?)) + } +} + +impl From<Vec<Statement>> for Block { + fn from(body: Vec<Statement>) -> Self { + Self::from_vec(body) + } +} diff --git a/third_party/rust/naga/src/compact/expressions.rs b/third_party/rust/naga/src/compact/expressions.rs new file mode 100644 index 0000000000..301bbe3240 --- /dev/null +++ b/third_party/rust/naga/src/compact/expressions.rs @@ -0,0 +1,389 @@ +use super::{HandleMap, HandleSet, ModuleMap}; +use crate::arena::{Arena, Handle}; + +pub struct ExpressionTracer<'tracer> { + pub constants: &'tracer Arena<crate::Constant>, + + /// The arena in which we are currently tracing expressions. + pub expressions: &'tracer Arena<crate::Expression>, + + /// The used map for `types`. + pub types_used: &'tracer mut HandleSet<crate::Type>, + + /// The used map for `constants`. + pub constants_used: &'tracer mut HandleSet<crate::Constant>, + + /// The used set for `arena`. + /// + /// This points to whatever arena holds the expressions we are + /// currently tracing: either a function's expression arena, or + /// the module's constant expression arena. + pub expressions_used: &'tracer mut HandleSet<crate::Expression>, + + /// The used set for the module's `const_expressions` arena. + /// + /// If `None`, we are already tracing the constant expressions, + /// and `expressions_used` already refers to their handle set. + pub const_expressions_used: Option<&'tracer mut HandleSet<crate::Expression>>, +} + +impl<'tracer> ExpressionTracer<'tracer> { + /// Propagate usage through `self.expressions`, starting with `self.expressions_used`. + /// + /// Treat `self.expressions_used` as the initial set of "known + /// live" expressions, and follow through to identify all + /// transitively used expressions. + /// + /// Mark types, constants, and constant expressions used directly + /// by `self.expressions` as used. Items used indirectly are not + /// marked. + /// + /// [fe]: crate::Function::expressions + /// [ce]: crate::Module::const_expressions + pub fn trace_expressions(&mut self) { + log::trace!( + "entering trace_expression of {}", + if self.const_expressions_used.is_some() { + "function expressions" + } else { + "const expressions" + } + ); + + // We don't need recursion or a work list. Because an + // expression may only refer to other expressions that precede + // it in the arena, it suffices to make a single pass over the + // arena from back to front, marking the referents of used + // expressions as used themselves. + for (handle, expr) in self.expressions.iter().rev() { + // If this expression isn't used, it doesn't matter what it uses. + if !self.expressions_used.contains(handle) { + continue; + } + + log::trace!("tracing new expression {:?}", expr); + + use crate::Expression as Ex; + match *expr { + // Expressions that do not contain handles that need to be traced. + Ex::Literal(_) + | Ex::FunctionArgument(_) + | Ex::GlobalVariable(_) + | Ex::LocalVariable(_) + | Ex::CallResult(_) + | Ex::RayQueryProceedResult => {} + + Ex::Constant(handle) => { + self.constants_used.insert(handle); + // Constants and expressions are mutually recursive, which + // complicates our nice one-pass algorithm. However, since + // constants don't refer to each other, we can get around + // this by looking *through* each constant and marking its + // initializer as used. Since `expr` refers to the constant, + // and the constant refers to the initializer, it must + // precede `expr` in the arena. + let init = self.constants[handle].init; + match self.const_expressions_used { + Some(ref mut used) => used.insert(init), + None => self.expressions_used.insert(init), + } + } + Ex::ZeroValue(ty) => self.types_used.insert(ty), + Ex::Compose { ty, ref components } => { + self.types_used.insert(ty); + self.expressions_used + .insert_iter(components.iter().cloned()); + } + Ex::Access { base, index } => self.expressions_used.insert_iter([base, index]), + Ex::AccessIndex { base, index: _ } => self.expressions_used.insert(base), + Ex::Splat { size: _, value } => self.expressions_used.insert(value), + Ex::Swizzle { + size: _, + vector, + pattern: _, + } => self.expressions_used.insert(vector), + Ex::Load { pointer } => self.expressions_used.insert(pointer), + Ex::ImageSample { + image, + sampler, + gather: _, + coordinate, + array_index, + offset, + ref level, + depth_ref, + } => { + self.expressions_used + .insert_iter([image, sampler, coordinate]); + self.expressions_used.insert_iter(array_index); + match self.const_expressions_used { + Some(ref mut used) => used.insert_iter(offset), + None => self.expressions_used.insert_iter(offset), + } + use crate::SampleLevel as Sl; + match *level { + Sl::Auto | Sl::Zero => {} + Sl::Exact(expr) | Sl::Bias(expr) => self.expressions_used.insert(expr), + Sl::Gradient { x, y } => self.expressions_used.insert_iter([x, y]), + } + self.expressions_used.insert_iter(depth_ref); + } + Ex::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + self.expressions_used.insert(image); + self.expressions_used.insert(coordinate); + self.expressions_used.insert_iter(array_index); + self.expressions_used.insert_iter(sample); + self.expressions_used.insert_iter(level); + } + Ex::ImageQuery { image, ref query } => { + self.expressions_used.insert(image); + use crate::ImageQuery as Iq; + match *query { + Iq::Size { level } => self.expressions_used.insert_iter(level), + Iq::NumLevels | Iq::NumLayers | Iq::NumSamples => {} + } + } + Ex::Unary { op: _, expr } => self.expressions_used.insert(expr), + Ex::Binary { op: _, left, right } => { + self.expressions_used.insert_iter([left, right]); + } + Ex::Select { + condition, + accept, + reject, + } => self + .expressions_used + .insert_iter([condition, accept, reject]), + Ex::Derivative { + axis: _, + ctrl: _, + expr, + } => self.expressions_used.insert(expr), + Ex::Relational { fun: _, argument } => self.expressions_used.insert(argument), + Ex::Math { + fun: _, + arg, + arg1, + arg2, + arg3, + } => { + self.expressions_used.insert(arg); + self.expressions_used.insert_iter(arg1); + self.expressions_used.insert_iter(arg2); + self.expressions_used.insert_iter(arg3); + } + Ex::As { + expr, + kind: _, + convert: _, + } => self.expressions_used.insert(expr), + Ex::AtomicResult { ty, comparison: _ } => self.types_used.insert(ty), + Ex::WorkGroupUniformLoadResult { ty } => self.types_used.insert(ty), + Ex::ArrayLength(expr) => self.expressions_used.insert(expr), + Ex::RayQueryGetIntersection { + query, + committed: _, + } => self.expressions_used.insert(query), + } + } + } +} + +impl ModuleMap { + /// Fix up all handles in `expr`. + /// + /// Use the expression handle remappings in `operand_map`, and all + /// other mappings from `self`. + pub fn adjust_expression( + &self, + expr: &mut crate::Expression, + operand_map: &HandleMap<crate::Expression>, + ) { + let adjust = |expr: &mut Handle<crate::Expression>| { + operand_map.adjust(expr); + }; + + use crate::Expression as Ex; + match *expr { + // Expressions that do not contain handles that need to be adjusted. + Ex::Literal(_) + | Ex::FunctionArgument(_) + | Ex::GlobalVariable(_) + | Ex::LocalVariable(_) + | Ex::CallResult(_) + | Ex::RayQueryProceedResult => {} + + // Expressions that contain handles that need to be adjusted. + Ex::Constant(ref mut constant) => self.constants.adjust(constant), + Ex::ZeroValue(ref mut ty) => self.types.adjust(ty), + Ex::Compose { + ref mut ty, + ref mut components, + } => { + self.types.adjust(ty); + for component in components { + adjust(component); + } + } + Ex::Access { + ref mut base, + ref mut index, + } => { + adjust(base); + adjust(index); + } + Ex::AccessIndex { + ref mut base, + index: _, + } => adjust(base), + Ex::Splat { + size: _, + ref mut value, + } => adjust(value), + Ex::Swizzle { + size: _, + ref mut vector, + pattern: _, + } => adjust(vector), + Ex::Load { ref mut pointer } => adjust(pointer), + Ex::ImageSample { + ref mut image, + ref mut sampler, + gather: _, + ref mut coordinate, + ref mut array_index, + ref mut offset, + ref mut level, + ref mut depth_ref, + } => { + adjust(image); + adjust(sampler); + adjust(coordinate); + operand_map.adjust_option(array_index); + if let Some(ref mut offset) = *offset { + self.const_expressions.adjust(offset); + } + self.adjust_sample_level(level, operand_map); + operand_map.adjust_option(depth_ref); + } + Ex::ImageLoad { + ref mut image, + ref mut coordinate, + ref mut array_index, + ref mut sample, + ref mut level, + } => { + adjust(image); + adjust(coordinate); + operand_map.adjust_option(array_index); + operand_map.adjust_option(sample); + operand_map.adjust_option(level); + } + Ex::ImageQuery { + ref mut image, + ref mut query, + } => { + adjust(image); + self.adjust_image_query(query, operand_map); + } + Ex::Unary { + op: _, + ref mut expr, + } => adjust(expr), + Ex::Binary { + op: _, + ref mut left, + ref mut right, + } => { + adjust(left); + adjust(right); + } + Ex::Select { + ref mut condition, + ref mut accept, + ref mut reject, + } => { + adjust(condition); + adjust(accept); + adjust(reject); + } + Ex::Derivative { + axis: _, + ctrl: _, + ref mut expr, + } => adjust(expr), + Ex::Relational { + fun: _, + ref mut argument, + } => adjust(argument), + Ex::Math { + fun: _, + ref mut arg, + ref mut arg1, + ref mut arg2, + ref mut arg3, + } => { + adjust(arg); + operand_map.adjust_option(arg1); + operand_map.adjust_option(arg2); + operand_map.adjust_option(arg3); + } + Ex::As { + ref mut expr, + kind: _, + convert: _, + } => adjust(expr), + Ex::AtomicResult { + ref mut ty, + comparison: _, + } => self.types.adjust(ty), + Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty), + Ex::ArrayLength(ref mut expr) => adjust(expr), + Ex::RayQueryGetIntersection { + ref mut query, + committed: _, + } => adjust(query), + } + } + + fn adjust_sample_level( + &self, + level: &mut crate::SampleLevel, + operand_map: &HandleMap<crate::Expression>, + ) { + let adjust = |expr: &mut Handle<crate::Expression>| operand_map.adjust(expr); + + use crate::SampleLevel as Sl; + match *level { + Sl::Auto | Sl::Zero => {} + Sl::Exact(ref mut expr) => adjust(expr), + Sl::Bias(ref mut expr) => adjust(expr), + Sl::Gradient { + ref mut x, + ref mut y, + } => { + adjust(x); + adjust(y); + } + } + } + + fn adjust_image_query( + &self, + query: &mut crate::ImageQuery, + operand_map: &HandleMap<crate::Expression>, + ) { + use crate::ImageQuery as Iq; + + match *query { + Iq::Size { ref mut level } => operand_map.adjust_option(level), + Iq::NumLevels | Iq::NumLayers | Iq::NumSamples => {} + } + } +} diff --git a/third_party/rust/naga/src/compact/functions.rs b/third_party/rust/naga/src/compact/functions.rs new file mode 100644 index 0000000000..b0d08c7e96 --- /dev/null +++ b/third_party/rust/naga/src/compact/functions.rs @@ -0,0 +1,106 @@ +use super::handle_set_map::HandleSet; +use super::{FunctionMap, ModuleMap}; + +pub struct FunctionTracer<'a> { + pub function: &'a crate::Function, + pub constants: &'a crate::Arena<crate::Constant>, + + pub types_used: &'a mut HandleSet<crate::Type>, + pub constants_used: &'a mut HandleSet<crate::Constant>, + pub const_expressions_used: &'a mut HandleSet<crate::Expression>, + + /// Function-local expressions used. + pub expressions_used: HandleSet<crate::Expression>, +} + +impl<'a> FunctionTracer<'a> { + pub fn trace(&mut self) { + for argument in self.function.arguments.iter() { + self.types_used.insert(argument.ty); + } + + if let Some(ref result) = self.function.result { + self.types_used.insert(result.ty); + } + + for (_, local) in self.function.local_variables.iter() { + self.types_used.insert(local.ty); + if let Some(init) = local.init { + self.expressions_used.insert(init); + } + } + + // Treat named expressions as alive, for the sake of our test suite, + // which uses `let blah = expr;` to exercise lots of things. + for (&value, _name) in &self.function.named_expressions { + self.expressions_used.insert(value); + } + + self.trace_block(&self.function.body); + + // Given that `trace_block` has marked the expressions used + // directly by statements, walk the arena to find all + // expressions used, directly or indirectly. + self.as_expression().trace_expressions(); + } + + fn as_expression(&mut self) -> super::expressions::ExpressionTracer { + super::expressions::ExpressionTracer { + constants: self.constants, + expressions: &self.function.expressions, + + types_used: self.types_used, + constants_used: self.constants_used, + expressions_used: &mut self.expressions_used, + const_expressions_used: Some(&mut self.const_expressions_used), + } + } +} + +impl FunctionMap { + pub fn compact( + &self, + function: &mut crate::Function, + module_map: &ModuleMap, + reuse: &mut crate::NamedExpressions, + ) { + assert!(reuse.is_empty()); + + for argument in function.arguments.iter_mut() { + module_map.types.adjust(&mut argument.ty); + } + + if let Some(ref mut result) = function.result { + module_map.types.adjust(&mut result.ty); + } + + for (_, local) in function.local_variables.iter_mut() { + log::trace!("adjusting local variable {:?}", local.name); + module_map.types.adjust(&mut local.ty); + if let Some(ref mut init) = local.init { + self.expressions.adjust(init); + } + } + + // Drop unused expressions, reusing existing storage. + function.expressions.retain_mut(|handle, expr| { + if self.expressions.used(handle) { + module_map.adjust_expression(expr, &self.expressions); + true + } else { + false + } + }); + + // Adjust named expressions. + for (mut handle, name) in function.named_expressions.drain(..) { + self.expressions.adjust(&mut handle); + reuse.insert(handle, name); + } + std::mem::swap(&mut function.named_expressions, reuse); + assert!(reuse.is_empty()); + + // Adjust statements. + self.adjust_body(function); + } +} diff --git a/third_party/rust/naga/src/compact/handle_set_map.rs b/third_party/rust/naga/src/compact/handle_set_map.rs new file mode 100644 index 0000000000..c716ca8294 --- /dev/null +++ b/third_party/rust/naga/src/compact/handle_set_map.rs @@ -0,0 +1,170 @@ +use crate::arena::{Arena, Handle, Range, UniqueArena}; + +type Index = std::num::NonZeroU32; + +/// A set of `Handle<T>` values. +pub struct HandleSet<T> { + /// Bound on zero-based indexes of handles stored in this set. + len: usize, + + /// `members[i]` is true if the handle with zero-based index `i` + /// is a member. + members: bit_set::BitSet, + + /// This type is indexed by values of type `T`. + as_keys: std::marker::PhantomData<T>, +} + +impl<T> HandleSet<T> { + pub fn for_arena(arena: &impl ArenaType<T>) -> Self { + let len = arena.len(); + Self { + len, + members: bit_set::BitSet::with_capacity(len), + as_keys: std::marker::PhantomData, + } + } + + /// Add `handle` to the set. + pub fn insert(&mut self, handle: Handle<T>) { + // Note that, oddly, `Handle::index` does not return a 1-based + // `Index`, but rather a zero-based `usize`. + self.members.insert(handle.index()); + } + + /// Add handles from `iter` to the set. + pub fn insert_iter(&mut self, iter: impl IntoIterator<Item = Handle<T>>) { + for handle in iter { + self.insert(handle); + } + } + + pub fn contains(&self, handle: Handle<T>) -> bool { + // Note that, oddly, `Handle::index` does not return a 1-based + // `Index`, but rather a zero-based `usize`. + self.members.contains(handle.index()) + } +} + +pub trait ArenaType<T> { + fn len(&self) -> usize; +} + +impl<T> ArenaType<T> for Arena<T> { + fn len(&self) -> usize { + self.len() + } +} + +impl<T: std::hash::Hash + Eq> ArenaType<T> for UniqueArena<T> { + fn len(&self) -> usize { + self.len() + } +} + +/// A map from old handle indices to new, compressed handle indices. +pub struct HandleMap<T> { + /// The indices assigned to handles in the compacted module. + /// + /// If `new_index[i]` is `Some(n)`, then `n` is the 1-based + /// `Index` of the compacted `Handle` corresponding to the + /// pre-compacted `Handle` whose zero-based index is `i`. ("Clear + /// as mud.") + new_index: Vec<Option<Index>>, + + /// This type is indexed by values of type `T`. + as_keys: std::marker::PhantomData<T>, +} + +impl<T: 'static> HandleMap<T> { + pub fn from_set(set: HandleSet<T>) -> Self { + let mut next_index = Index::new(1).unwrap(); + Self { + new_index: (0..set.len) + .map(|zero_based_index| { + if set.members.contains(zero_based_index) { + // This handle will be retained in the compacted version, + // so assign it a new index. + let this = next_index; + next_index = next_index.checked_add(1).unwrap(); + Some(this) + } else { + // This handle will be omitted in the compacted version. + None + } + }) + .collect(), + as_keys: std::marker::PhantomData, + } + } + + /// Return true if `old` is used in the compacted module. + pub fn used(&self, old: Handle<T>) -> bool { + self.new_index[old.index()].is_some() + } + + /// Return the counterpart to `old` in the compacted module. + /// + /// If we thought `old` wouldn't be used in the compacted module, return + /// `None`. + pub fn try_adjust(&self, old: Handle<T>) -> Option<Handle<T>> { + log::trace!( + "adjusting {} handle [{}] -> [{:?}]", + std::any::type_name::<T>(), + old.index() + 1, + self.new_index[old.index()] + ); + // Note that `Handle::index` returns a zero-based index, + // but `Handle::new` accepts a 1-based `Index`. + self.new_index[old.index()].map(Handle::new) + } + + /// Return the counterpart to `old` in the compacted module. + /// + /// If we thought `old` wouldn't be used in the compacted module, panic. + pub fn adjust(&self, handle: &mut Handle<T>) { + *handle = self.try_adjust(*handle).unwrap(); + } + + /// Like `adjust`, but for optional handles. + pub fn adjust_option(&self, handle: &mut Option<Handle<T>>) { + if let Some(ref mut handle) = *handle { + self.adjust(handle); + } + } + + /// Shrink `range` to include only used handles. + /// + /// Fortunately, compaction doesn't arbitrarily scramble the expressions + /// in the arena, but instead preserves the order of the elements while + /// squeezing out unused ones. That means that a contiguous range in the + /// pre-compacted arena always maps to a contiguous range in the + /// post-compacted arena. So we just need to adjust the endpoints. + /// + /// Compaction may have eliminated the endpoints themselves. + /// + /// Use `compacted_arena` to bounds-check the result. + pub fn adjust_range(&self, range: &mut Range<T>, compacted_arena: &Arena<T>) { + let mut index_range = range.zero_based_index_range(); + let compacted; + // Remember that the indices we retrieve from `new_index` are 1-based + // compacted indices, but the index range we're computing is zero-based + // compacted indices. + if let Some(first1) = index_range.find_map(|i| self.new_index[i as usize]) { + // The first call to `find_map` mutated `index_range` to hold the + // remainder of original range, which is exactly the range we need + // to search for the new last handle. + if let Some(last1) = index_range.rev().find_map(|i| self.new_index[i as usize]) { + // Build a zero-based end-exclusive range, given one-based handle indices. + compacted = first1.get() - 1..last1.get(); + } else { + // The range contains only a single live handle, which + // we identified with the first `find_map` call. + compacted = first1.get() - 1..first1.get(); + } + } else { + compacted = 0..0; + }; + *range = Range::from_zero_based_index_range(compacted, compacted_arena); + } +} diff --git a/third_party/rust/naga/src/compact/mod.rs b/third_party/rust/naga/src/compact/mod.rs new file mode 100644 index 0000000000..b4e57ed5c9 --- /dev/null +++ b/third_party/rust/naga/src/compact/mod.rs @@ -0,0 +1,307 @@ +mod expressions; +mod functions; +mod handle_set_map; +mod statements; +mod types; + +use crate::{arena, compact::functions::FunctionTracer}; +use handle_set_map::{HandleMap, HandleSet}; + +/// Remove unused types, expressions, and constants from `module`. +/// +/// Assuming that all globals, named constants, special types, +/// functions and entry points in `module` are used, determine which +/// types, constants, and expressions (both function-local and global +/// constant expressions) are actually used, and remove the rest, +/// adjusting all handles as necessary. The result should be a module +/// functionally identical to the original. +/// +/// This may be useful to apply to modules generated in the snapshot +/// tests. Our backends often generate temporary names based on handle +/// indices, which means that adding or removing unused arena entries +/// can affect the output even though they have no semantic effect. +/// Such meaningless changes add noise to snapshot diffs, making +/// accurate patch review difficult. Compacting the modules before +/// generating snapshots makes the output independent of unused arena +/// entries. +/// +/// # Panics +/// +/// If `module` has not passed validation, this may panic. +pub fn compact(module: &mut crate::Module) { + let mut module_tracer = ModuleTracer::new(module); + + // We treat all globals as used by definition. + log::trace!("tracing global variables"); + { + for (_, global) in module.global_variables.iter() { + log::trace!("tracing global {:?}", global.name); + module_tracer.types_used.insert(global.ty); + if let Some(init) = global.init { + module_tracer.const_expressions_used.insert(init); + } + } + } + + // We treat all special types as used by definition. + module_tracer.trace_special_types(&module.special_types); + + // We treat all named constants as used by definition. + for (handle, constant) in module.constants.iter() { + if constant.name.is_some() { + module_tracer.constants_used.insert(handle); + module_tracer.const_expressions_used.insert(constant.init); + } + } + + // We assume that all functions are used. + // + // Observe which types, constant expressions, constants, and + // expressions each function uses, and produce maps for each + // function from pre-compaction to post-compaction expression + // handles. + log::trace!("tracing functions"); + let function_maps: Vec<FunctionMap> = module + .functions + .iter() + .map(|(_, f)| { + log::trace!("tracing function {:?}", f.name); + let mut function_tracer = module_tracer.as_function(f); + function_tracer.trace(); + FunctionMap::from(function_tracer) + }) + .collect(); + + // Similarly, observe what each entry point actually uses. + log::trace!("tracing entry points"); + let entry_point_maps: Vec<FunctionMap> = module + .entry_points + .iter() + .map(|e| { + log::trace!("tracing entry point {:?}", e.function.name); + let mut used = module_tracer.as_function(&e.function); + used.trace(); + FunctionMap::from(used) + }) + .collect(); + + // Given that the above steps have marked all the constant + // expressions used directly by globals, constants, functions, and + // entry points, walk the constant expression arena to find all + // constant expressions used, directly or indirectly. + module_tracer.as_const_expression().trace_expressions(); + + // Constants' initializers are taken care of already, because + // expression tracing sees through constants. But we still need to + // note type usage. + for (handle, constant) in module.constants.iter() { + if module_tracer.constants_used.contains(handle) { + module_tracer.types_used.insert(constant.ty); + } + } + + // Treat all named types as used. + for (handle, ty) in module.types.iter() { + log::trace!("tracing type {:?}, name {:?}", handle, ty.name); + if ty.name.is_some() { + module_tracer.types_used.insert(handle); + } + } + + // Propagate usage through types. + module_tracer.as_type().trace_types(); + + // Now that we know what is used and what is never touched, + // produce maps from the `Handle`s that appear in `module` now to + // the corresponding `Handle`s that will refer to the same items + // in the compacted module. + let module_map = ModuleMap::from(module_tracer); + + // Drop unused types from the type arena. + // + // `FastIndexSet`s don't have an underlying Vec<T> that we can + // steal, compact in place, and then rebuild the `FastIndexSet` + // from. So we have to rebuild the type arena from scratch. + log::trace!("compacting types"); + let mut new_types = arena::UniqueArena::new(); + for (old_handle, mut ty, span) in module.types.drain_all() { + if let Some(expected_new_handle) = module_map.types.try_adjust(old_handle) { + module_map.adjust_type(&mut ty); + let actual_new_handle = new_types.insert(ty, span); + assert_eq!(actual_new_handle, expected_new_handle); + } + } + module.types = new_types; + log::trace!("adjusting special types"); + module_map.adjust_special_types(&mut module.special_types); + + // Drop unused constant expressions, reusing existing storage. + log::trace!("adjusting constant expressions"); + module.const_expressions.retain_mut(|handle, expr| { + if module_map.const_expressions.used(handle) { + module_map.adjust_expression(expr, &module_map.const_expressions); + true + } else { + false + } + }); + + // Drop unused constants in place, reusing existing storage. + log::trace!("adjusting constants"); + module.constants.retain_mut(|handle, constant| { + if module_map.constants.used(handle) { + module_map.types.adjust(&mut constant.ty); + module_map.const_expressions.adjust(&mut constant.init); + true + } else { + false + } + }); + + // Adjust global variables' types and initializers. + log::trace!("adjusting global variables"); + for (_, global) in module.global_variables.iter_mut() { + log::trace!("adjusting global {:?}", global.name); + module_map.types.adjust(&mut global.ty); + if let Some(ref mut init) = global.init { + module_map.const_expressions.adjust(init); + } + } + + // Temporary storage to help us reuse allocations of existing + // named expression tables. + let mut reused_named_expressions = crate::NamedExpressions::default(); + + // Compact each function. + for ((_, function), map) in module.functions.iter_mut().zip(function_maps.iter()) { + log::trace!("compacting function {:?}", function.name); + map.compact(function, &module_map, &mut reused_named_expressions); + } + + // Compact each entry point. + for (entry, map) in module.entry_points.iter_mut().zip(entry_point_maps.iter()) { + log::trace!("compacting entry point {:?}", entry.function.name); + map.compact( + &mut entry.function, + &module_map, + &mut reused_named_expressions, + ); + } +} + +struct ModuleTracer<'module> { + module: &'module crate::Module, + types_used: HandleSet<crate::Type>, + constants_used: HandleSet<crate::Constant>, + const_expressions_used: HandleSet<crate::Expression>, +} + +impl<'module> ModuleTracer<'module> { + fn new(module: &'module crate::Module) -> Self { + Self { + module, + types_used: HandleSet::for_arena(&module.types), + constants_used: HandleSet::for_arena(&module.constants), + const_expressions_used: HandleSet::for_arena(&module.const_expressions), + } + } + + fn trace_special_types(&mut self, special_types: &crate::SpecialTypes) { + let crate::SpecialTypes { + ref ray_desc, + ref ray_intersection, + ref predeclared_types, + } = *special_types; + + if let Some(ray_desc) = *ray_desc { + self.types_used.insert(ray_desc); + } + if let Some(ray_intersection) = *ray_intersection { + self.types_used.insert(ray_intersection); + } + for (_, &handle) in predeclared_types { + self.types_used.insert(handle); + } + } + + fn as_type(&mut self) -> types::TypeTracer { + types::TypeTracer { + types: &self.module.types, + types_used: &mut self.types_used, + } + } + + fn as_const_expression(&mut self) -> expressions::ExpressionTracer { + expressions::ExpressionTracer { + expressions: &self.module.const_expressions, + constants: &self.module.constants, + types_used: &mut self.types_used, + constants_used: &mut self.constants_used, + expressions_used: &mut self.const_expressions_used, + const_expressions_used: None, + } + } + + pub fn as_function<'tracer>( + &'tracer mut self, + function: &'tracer crate::Function, + ) -> FunctionTracer<'tracer> { + FunctionTracer { + function, + constants: &self.module.constants, + types_used: &mut self.types_used, + constants_used: &mut self.constants_used, + const_expressions_used: &mut self.const_expressions_used, + expressions_used: HandleSet::for_arena(&function.expressions), + } + } +} + +struct ModuleMap { + types: HandleMap<crate::Type>, + constants: HandleMap<crate::Constant>, + const_expressions: HandleMap<crate::Expression>, +} + +impl From<ModuleTracer<'_>> for ModuleMap { + fn from(used: ModuleTracer) -> Self { + ModuleMap { + types: HandleMap::from_set(used.types_used), + constants: HandleMap::from_set(used.constants_used), + const_expressions: HandleMap::from_set(used.const_expressions_used), + } + } +} + +impl ModuleMap { + fn adjust_special_types(&self, special: &mut crate::SpecialTypes) { + let crate::SpecialTypes { + ref mut ray_desc, + ref mut ray_intersection, + ref mut predeclared_types, + } = *special; + + if let Some(ref mut ray_desc) = *ray_desc { + self.types.adjust(ray_desc); + } + if let Some(ref mut ray_intersection) = *ray_intersection { + self.types.adjust(ray_intersection); + } + + for handle in predeclared_types.values_mut() { + self.types.adjust(handle); + } + } +} + +struct FunctionMap { + expressions: HandleMap<crate::Expression>, +} + +impl From<FunctionTracer<'_>> for FunctionMap { + fn from(used: FunctionTracer) -> Self { + FunctionMap { + expressions: HandleMap::from_set(used.expressions_used), + } + } +} diff --git a/third_party/rust/naga/src/compact/statements.rs b/third_party/rust/naga/src/compact/statements.rs new file mode 100644 index 0000000000..0698b57258 --- /dev/null +++ b/third_party/rust/naga/src/compact/statements.rs @@ -0,0 +1,300 @@ +use super::functions::FunctionTracer; +use super::FunctionMap; +use crate::arena::Handle; + +impl FunctionTracer<'_> { + pub fn trace_block(&mut self, block: &[crate::Statement]) { + let mut worklist: Vec<&[crate::Statement]> = vec![block]; + while let Some(last) = worklist.pop() { + for stmt in last { + use crate::Statement as St; + match *stmt { + St::Emit(ref _range) => { + // If we come across a statement that actually uses an + // expression in this range, it'll get traced from + // there. But since evaluating expressions has no + // effect, we don't need to assume that everything + // emitted is live. + } + St::Block(ref block) => worklist.push(block), + St::If { + condition, + ref accept, + ref reject, + } => { + self.expressions_used.insert(condition); + worklist.push(accept); + worklist.push(reject); + } + St::Switch { + selector, + ref cases, + } => { + self.expressions_used.insert(selector); + for case in cases { + worklist.push(&case.body); + } + } + St::Loop { + ref body, + ref continuing, + break_if, + } => { + if let Some(break_if) = break_if { + self.expressions_used.insert(break_if); + } + worklist.push(body); + worklist.push(continuing); + } + St::Return { value: Some(value) } => { + self.expressions_used.insert(value); + } + St::Store { pointer, value } => { + self.expressions_used.insert(pointer); + self.expressions_used.insert(value); + } + St::ImageStore { + image, + coordinate, + array_index, + value, + } => { + self.expressions_used.insert(image); + self.expressions_used.insert(coordinate); + if let Some(array_index) = array_index { + self.expressions_used.insert(array_index); + } + self.expressions_used.insert(value); + } + St::Atomic { + pointer, + ref fun, + value, + result, + } => { + self.expressions_used.insert(pointer); + self.trace_atomic_function(fun); + self.expressions_used.insert(value); + self.expressions_used.insert(result); + } + St::WorkGroupUniformLoad { pointer, result } => { + self.expressions_used.insert(pointer); + self.expressions_used.insert(result); + } + St::Call { + function: _, + ref arguments, + result, + } => { + for expr in arguments { + self.expressions_used.insert(*expr); + } + if let Some(result) = result { + self.expressions_used.insert(result); + } + } + St::RayQuery { query, ref fun } => { + self.expressions_used.insert(query); + self.trace_ray_query_function(fun); + } + + // Trivial statements. + St::Break + | St::Continue + | St::Kill + | St::Barrier(_) + | St::Return { value: None } => {} + } + } + } + } + + fn trace_atomic_function(&mut self, fun: &crate::AtomicFunction) { + use crate::AtomicFunction as Af; + match *fun { + Af::Exchange { + compare: Some(expr), + } => { + self.expressions_used.insert(expr); + } + Af::Exchange { compare: None } + | Af::Add + | Af::Subtract + | Af::And + | Af::ExclusiveOr + | Af::InclusiveOr + | Af::Min + | Af::Max => {} + } + } + + fn trace_ray_query_function(&mut self, fun: &crate::RayQueryFunction) { + use crate::RayQueryFunction as Qf; + match *fun { + Qf::Initialize { + acceleration_structure, + descriptor, + } => { + self.expressions_used.insert(acceleration_structure); + self.expressions_used.insert(descriptor); + } + Qf::Proceed { result } => { + self.expressions_used.insert(result); + } + Qf::Terminate => {} + } + } +} + +impl FunctionMap { + pub fn adjust_body(&self, function: &mut crate::Function) { + let block = &mut function.body; + let mut worklist: Vec<&mut [crate::Statement]> = vec![block]; + let adjust = |handle: &mut Handle<crate::Expression>| { + self.expressions.adjust(handle); + }; + while let Some(last) = worklist.pop() { + for stmt in last { + use crate::Statement as St; + match *stmt { + St::Emit(ref mut range) => { + self.expressions.adjust_range(range, &function.expressions); + } + St::Block(ref mut block) => worklist.push(block), + St::If { + ref mut condition, + ref mut accept, + ref mut reject, + } => { + adjust(condition); + worklist.push(accept); + worklist.push(reject); + } + St::Switch { + ref mut selector, + ref mut cases, + } => { + adjust(selector); + for case in cases { + worklist.push(&mut case.body); + } + } + St::Loop { + ref mut body, + ref mut continuing, + ref mut break_if, + } => { + if let Some(ref mut break_if) = *break_if { + adjust(break_if); + } + worklist.push(body); + worklist.push(continuing); + } + St::Return { + value: Some(ref mut value), + } => adjust(value), + St::Store { + ref mut pointer, + ref mut value, + } => { + adjust(pointer); + adjust(value); + } + St::ImageStore { + ref mut image, + ref mut coordinate, + ref mut array_index, + ref mut value, + } => { + adjust(image); + adjust(coordinate); + if let Some(ref mut array_index) = *array_index { + adjust(array_index); + } + adjust(value); + } + St::Atomic { + ref mut pointer, + ref mut fun, + ref mut value, + ref mut result, + } => { + adjust(pointer); + self.adjust_atomic_function(fun); + adjust(value); + adjust(result); + } + St::WorkGroupUniformLoad { + ref mut pointer, + ref mut result, + } => { + adjust(pointer); + adjust(result); + } + St::Call { + function: _, + ref mut arguments, + ref mut result, + } => { + for expr in arguments { + adjust(expr); + } + if let Some(ref mut result) = *result { + adjust(result); + } + } + St::RayQuery { + ref mut query, + ref mut fun, + } => { + adjust(query); + self.adjust_ray_query_function(fun); + } + + // Trivial statements. + St::Break + | St::Continue + | St::Kill + | St::Barrier(_) + | St::Return { value: None } => {} + } + } + } + } + + fn adjust_atomic_function(&self, fun: &mut crate::AtomicFunction) { + use crate::AtomicFunction as Af; + match *fun { + Af::Exchange { + compare: Some(ref mut expr), + } => { + self.expressions.adjust(expr); + } + Af::Exchange { compare: None } + | Af::Add + | Af::Subtract + | Af::And + | Af::ExclusiveOr + | Af::InclusiveOr + | Af::Min + | Af::Max => {} + } + } + + fn adjust_ray_query_function(&self, fun: &mut crate::RayQueryFunction) { + use crate::RayQueryFunction as Qf; + match *fun { + Qf::Initialize { + ref mut acceleration_structure, + ref mut descriptor, + } => { + self.expressions.adjust(acceleration_structure); + self.expressions.adjust(descriptor); + } + Qf::Proceed { ref mut result } => { + self.expressions.adjust(result); + } + Qf::Terminate => {} + } + } +} diff --git a/third_party/rust/naga/src/compact/types.rs b/third_party/rust/naga/src/compact/types.rs new file mode 100644 index 0000000000..b78619d9a8 --- /dev/null +++ b/third_party/rust/naga/src/compact/types.rs @@ -0,0 +1,102 @@ +use super::{HandleSet, ModuleMap}; +use crate::{Handle, UniqueArena}; + +pub struct TypeTracer<'a> { + pub types: &'a UniqueArena<crate::Type>, + pub types_used: &'a mut HandleSet<crate::Type>, +} + +impl<'a> TypeTracer<'a> { + /// Propagate usage through `self.types`, starting with `self.types_used`. + /// + /// Treat `self.types_used` as the initial set of "known + /// live" types, and follow through to identify all + /// transitively used types. + pub fn trace_types(&mut self) { + // We don't need recursion or a work list. Because an + // expression may only refer to other expressions that precede + // it in the arena, it suffices to make a single pass over the + // arena from back to front, marking the referents of used + // expressions as used themselves. + for (handle, ty) in self.types.iter().rev() { + // If this type isn't used, it doesn't matter what it uses. + if !self.types_used.contains(handle) { + continue; + } + + use crate::TypeInner as Ti; + match ty.inner { + // Types that do not contain handles. + Ti::Scalar { .. } + | Ti::Vector { .. } + | Ti::Matrix { .. } + | Ti::Atomic { .. } + | Ti::ValuePointer { .. } + | Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery => {} + + // Types that do contain handles. + Ti::Pointer { base, space: _ } + | Ti::Array { + base, + size: _, + stride: _, + } + | Ti::BindingArray { base, size: _ } => self.types_used.insert(base), + Ti::Struct { + ref members, + span: _, + } => { + self.types_used.insert_iter(members.iter().map(|m| m.ty)); + } + } + } + } +} + +impl ModuleMap { + pub fn adjust_type(&self, ty: &mut crate::Type) { + let adjust = |ty: &mut Handle<crate::Type>| self.types.adjust(ty); + + use crate::TypeInner as Ti; + match ty.inner { + // Types that do not contain handles. + Ti::Scalar(_) + | Ti::Vector { .. } + | Ti::Matrix { .. } + | Ti::Atomic(_) + | Ti::ValuePointer { .. } + | Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery => {} + + // Types that do contain handles. + Ti::Pointer { + ref mut base, + space: _, + } => adjust(base), + Ti::Array { + ref mut base, + size: _, + stride: _, + } => adjust(base), + Ti::Struct { + ref mut members, + span: _, + } => { + for member in members { + self.types.adjust(&mut member.ty); + } + } + Ti::BindingArray { + ref mut base, + size: _, + } => { + adjust(base); + } + }; + } +} diff --git a/third_party/rust/naga/src/front/glsl/ast.rs b/third_party/rust/naga/src/front/glsl/ast.rs new file mode 100644 index 0000000000..96b676dd6d --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/ast.rs @@ -0,0 +1,394 @@ +use std::{borrow::Cow, fmt}; + +use super::{builtins::MacroCall, context::ExprPos, Span}; +use crate::{ + AddressSpace, BinaryOperator, Binding, Constant, Expression, Function, GlobalVariable, Handle, + Interpolation, Literal, Sampling, StorageAccess, Type, UnaryOperator, +}; + +#[derive(Debug, Clone, Copy)] +pub enum GlobalLookupKind { + Variable(Handle<GlobalVariable>), + Constant(Handle<Constant>, Handle<Type>), + BlockSelect(Handle<GlobalVariable>, u32), +} + +#[derive(Debug, Clone, Copy)] +pub struct GlobalLookup { + pub kind: GlobalLookupKind, + pub entry_arg: Option<usize>, + pub mutable: bool, +} + +#[derive(Debug, Clone)] +pub struct ParameterInfo { + pub qualifier: ParameterQualifier, + /// Whether the parameter should be treated as a depth image instead of a + /// sampled image. + pub depth: bool, +} + +/// How the function is implemented +#[derive(Clone, Copy)] +pub enum FunctionKind { + /// The function is user defined + Call(Handle<Function>), + /// The function is a builtin + Macro(MacroCall), +} + +impl fmt::Debug for FunctionKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Self::Call(_) => write!(f, "Call"), + Self::Macro(_) => write!(f, "Macro"), + } + } +} + +#[derive(Debug)] +pub struct Overload { + /// Normalized function parameters, modifiers are not applied + pub parameters: Vec<Handle<Type>>, + pub parameters_info: Vec<ParameterInfo>, + /// How the function is implemented + pub kind: FunctionKind, + /// Whether this function was already defined or is just a prototype + pub defined: bool, + /// Whether this overload is the one provided by the language or has + /// been redeclared by the user (builtins only) + pub internal: bool, + /// Whether or not this function returns void (nothing) + pub void: bool, +} + +bitflags::bitflags! { + /// Tracks the variations of the builtin already generated, this is needed because some + /// builtins overloads can't be generated unless explicitly used, since they might cause + /// unneeded capabilities to be requested + #[derive(Default)] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct BuiltinVariations: u32 { + /// Request the standard overloads + const STANDARD = 1 << 0; + /// Request overloads that use the double type + const DOUBLE = 1 << 1; + /// Request overloads that use samplerCubeArray(Shadow) + const CUBE_TEXTURES_ARRAY = 1 << 2; + /// Request overloads that use sampler2DMSArray + const D2_MULTI_TEXTURES_ARRAY = 1 << 3; + } +} + +#[derive(Debug, Default)] +pub struct FunctionDeclaration { + pub overloads: Vec<Overload>, + /// Tracks the builtin overload variations that were already generated + pub variations: BuiltinVariations, +} + +#[derive(Debug)] +pub struct EntryArg { + pub name: Option<String>, + pub binding: Binding, + pub handle: Handle<GlobalVariable>, + pub storage: StorageQualifier, +} + +#[derive(Debug, Clone)] +pub struct VariableReference { + pub expr: Handle<Expression>, + /// Whether the variable is of a pointer type (and needs loading) or not + pub load: bool, + /// Whether the value of the variable can be changed or not + pub mutable: bool, + pub constant: Option<(Handle<Constant>, Handle<Type>)>, + pub entry_arg: Option<usize>, +} + +#[derive(Debug, Clone)] +pub struct HirExpr { + pub kind: HirExprKind, + pub meta: Span, +} + +#[derive(Debug, Clone)] +pub enum HirExprKind { + Access { + base: Handle<HirExpr>, + index: Handle<HirExpr>, + }, + Select { + base: Handle<HirExpr>, + field: String, + }, + Literal(Literal), + Binary { + left: Handle<HirExpr>, + op: BinaryOperator, + right: Handle<HirExpr>, + }, + Unary { + op: UnaryOperator, + expr: Handle<HirExpr>, + }, + Variable(VariableReference), + Call(FunctionCall), + /// Represents the ternary operator in glsl (`:?`) + Conditional { + /// The expression that will decide which branch to take, must evaluate to a boolean + condition: Handle<HirExpr>, + /// The expression that will be evaluated if [`condition`] returns `true` + /// + /// [`condition`]: Self::Conditional::condition + accept: Handle<HirExpr>, + /// The expression that will be evaluated if [`condition`] returns `false` + /// + /// [`condition`]: Self::Conditional::condition + reject: Handle<HirExpr>, + }, + Assign { + tgt: Handle<HirExpr>, + value: Handle<HirExpr>, + }, + /// A prefix/postfix operator like `++` + PrePostfix { + /// The operation to be performed + op: BinaryOperator, + /// Whether this is a postfix or a prefix + postfix: bool, + /// The target expression + expr: Handle<HirExpr>, + }, + /// A method call like `what.something(a, b, c)` + Method { + /// expression the method call applies to (`what` in the example) + expr: Handle<HirExpr>, + /// the method name (`something` in the example) + name: String, + /// the arguments to the method (`a`, `b`, and `c` in the example) + args: Vec<Handle<HirExpr>>, + }, +} + +#[derive(Debug, Hash, PartialEq, Eq)] +pub enum QualifierKey<'a> { + String(Cow<'a, str>), + /// Used for `std140` and `std430` layout qualifiers + Layout, + /// Used for image formats + Format, +} + +#[derive(Debug)] +pub enum QualifierValue { + None, + Uint(u32), + Layout(StructLayout), + Format(crate::StorageFormat), +} + +#[derive(Debug, Default)] +pub struct TypeQualifiers<'a> { + pub span: Span, + pub storage: (StorageQualifier, Span), + pub invariant: Option<Span>, + pub interpolation: Option<(Interpolation, Span)>, + pub precision: Option<(Precision, Span)>, + pub sampling: Option<(Sampling, Span)>, + /// Memory qualifiers used in the declaration to set the storage access to be used + /// in declarations that support it (storage images and buffers) + pub storage_access: Option<(StorageAccess, Span)>, + pub layout_qualifiers: crate::FastHashMap<QualifierKey<'a>, (QualifierValue, Span)>, +} + +impl<'a> TypeQualifiers<'a> { + /// Appends `errors` with errors for all unused qualifiers + pub fn unused_errors(&self, errors: &mut Vec<super::Error>) { + if let Some(meta) = self.invariant { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError( + "Invariant qualifier can only be used in in/out variables".into(), + ), + meta, + }); + } + + if let Some((_, meta)) = self.interpolation { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError( + "Interpolation qualifiers can only be used in in/out variables".into(), + ), + meta, + }); + } + + if let Some((_, meta)) = self.sampling { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError( + "Sampling qualifiers can only be used in in/out variables".into(), + ), + meta, + }); + } + + if let Some((_, meta)) = self.storage_access { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError( + "Memory qualifiers can only be used in storage variables".into(), + ), + meta, + }); + } + + for &(_, meta) in self.layout_qualifiers.values() { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError("Unexpected qualifier".into()), + meta, + }); + } + } + + /// Removes the layout qualifier with `name`, if it exists and adds an error if it isn't + /// a [`QualifierValue::Uint`] + pub fn uint_layout_qualifier( + &mut self, + name: &'a str, + errors: &mut Vec<super::Error>, + ) -> Option<u32> { + match self + .layout_qualifiers + .remove(&QualifierKey::String(name.into())) + { + Some((QualifierValue::Uint(v), _)) => Some(v), + Some((_, meta)) => { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError("Qualifier expects a uint value".into()), + meta, + }); + // Return a dummy value instead of `None` to differentiate from + // the qualifier not existing, since some parts might require the + // qualifier to exist and throwing another error that it doesn't + // exist would be unhelpful + Some(0) + } + _ => None, + } + } + + /// Removes the layout qualifier with `name`, if it exists and adds an error if it isn't + /// a [`QualifierValue::None`] + pub fn none_layout_qualifier(&mut self, name: &'a str, errors: &mut Vec<super::Error>) -> bool { + match self + .layout_qualifiers + .remove(&QualifierKey::String(name.into())) + { + Some((QualifierValue::None, _)) => true, + Some((_, meta)) => { + errors.push(super::Error { + kind: super::ErrorKind::SemanticError( + "Qualifier doesn't expect a value".into(), + ), + meta, + }); + // Return a `true` to since the qualifier is defined and adding + // another error for it not being defined would be unhelpful + true + } + _ => false, + } + } +} + +#[derive(Debug, Clone)] +pub enum FunctionCallKind { + TypeConstructor(Handle<Type>), + Function(String), +} + +#[derive(Debug, Clone)] +pub struct FunctionCall { + pub kind: FunctionCallKind, + pub args: Vec<Handle<HirExpr>>, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum StorageQualifier { + AddressSpace(AddressSpace), + Input, + Output, + Const, +} + +impl Default for StorageQualifier { + fn default() -> Self { + StorageQualifier::AddressSpace(AddressSpace::Function) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StructLayout { + Std140, + Std430, +} + +// TODO: Encode precision hints in the IR +/// A precision hint used in GLSL declarations. +/// +/// Precision hints can be used to either speed up shader execution or control +/// the precision of arithmetic operations. +/// +/// To use a precision hint simply add it before the type in the declaration. +/// ```glsl +/// mediump float a; +/// ``` +/// +/// The default when no precision is declared is `highp` which means that all +/// operations operate with the type defined width. +/// +/// For `mediump` and `lowp` operations follow the spir-v +/// [`RelaxedPrecision`][RelaxedPrecision] decoration semantics. +/// +/// [RelaxedPrecision]: https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#_a_id_relaxedprecisionsection_a_relaxed_precision +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum Precision { + /// `lowp` precision + Low, + /// `mediump` precision + Medium, + /// `highp` precision + High, +} + +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum ParameterQualifier { + In, + Out, + InOut, + Const, +} + +impl ParameterQualifier { + /// Returns true if the argument should be passed as a lhs expression + pub const fn is_lhs(&self) -> bool { + match *self { + ParameterQualifier::Out | ParameterQualifier::InOut => true, + _ => false, + } + } + + /// Converts from a parameter qualifier into a [`ExprPos`] + pub const fn as_pos(&self) -> ExprPos { + match *self { + ParameterQualifier::Out | ParameterQualifier::InOut => ExprPos::Lhs, + _ => ExprPos::Rhs, + } + } +} + +/// The GLSL profile used by a shader. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Profile { + /// The `core` profile, default when no profile is specified. + Core, +} diff --git a/third_party/rust/naga/src/front/glsl/builtins.rs b/third_party/rust/naga/src/front/glsl/builtins.rs new file mode 100644 index 0000000000..9e3a578c6b --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/builtins.rs @@ -0,0 +1,2314 @@ +use super::{ + ast::{ + BuiltinVariations, FunctionDeclaration, FunctionKind, Overload, ParameterInfo, + ParameterQualifier, + }, + context::Context, + Error, ErrorKind, Frontend, Result, +}; +use crate::{ + BinaryOperator, DerivativeAxis as Axis, DerivativeControl as Ctrl, Expression, Handle, + ImageClass, ImageDimension as Dim, ImageQuery, MathFunction, Module, RelationalFunction, + SampleLevel, Scalar, ScalarKind as Sk, Span, Type, TypeInner, UnaryOperator, VectorSize, +}; + +impl crate::ScalarKind { + const fn dummy_storage_format(&self) -> crate::StorageFormat { + match *self { + Sk::Sint => crate::StorageFormat::R16Sint, + Sk::Uint => crate::StorageFormat::R16Uint, + _ => crate::StorageFormat::R16Float, + } + } +} + +impl Module { + /// Helper function, to create a function prototype for a builtin + fn add_builtin(&mut self, args: Vec<TypeInner>, builtin: MacroCall) -> Overload { + let mut parameters = Vec::with_capacity(args.len()); + let mut parameters_info = Vec::with_capacity(args.len()); + + for arg in args { + parameters.push(self.types.insert( + Type { + name: None, + inner: arg, + }, + Span::default(), + )); + parameters_info.push(ParameterInfo { + qualifier: ParameterQualifier::In, + depth: false, + }); + } + + Overload { + parameters, + parameters_info, + kind: FunctionKind::Macro(builtin), + defined: false, + internal: true, + void: false, + } + } +} + +const fn make_coords_arg(number_of_components: usize, kind: Sk) -> TypeInner { + let scalar = Scalar { kind, width: 4 }; + + match number_of_components { + 1 => TypeInner::Scalar(scalar), + _ => TypeInner::Vector { + size: match number_of_components { + 2 => VectorSize::Bi, + 3 => VectorSize::Tri, + _ => VectorSize::Quad, + }, + scalar, + }, + } +} + +/// Inject builtins into the declaration +/// +/// This is done to not add a large startup cost and not increase memory +/// usage if it isn't needed. +pub fn inject_builtin( + declaration: &mut FunctionDeclaration, + module: &mut Module, + name: &str, + mut variations: BuiltinVariations, +) { + log::trace!( + "{} variations: {:?} {:?}", + name, + variations, + declaration.variations + ); + // Don't regeneate variations + variations.remove(declaration.variations); + declaration.variations |= variations; + + if variations.contains(BuiltinVariations::STANDARD) { + inject_standard_builtins(declaration, module, name) + } + + if variations.contains(BuiltinVariations::DOUBLE) { + inject_double_builtin(declaration, module, name) + } + + match name { + "texture" + | "textureGrad" + | "textureGradOffset" + | "textureLod" + | "textureLodOffset" + | "textureOffset" + | "textureProj" + | "textureProjGrad" + | "textureProjGradOffset" + | "textureProjLod" + | "textureProjLodOffset" + | "textureProjOffset" => { + let f = |kind, dim, arrayed, multi, shadow| { + for bits in 0..=0b11 { + let variant = bits & 0b1 != 0; + let bias = bits & 0b10 != 0; + + let (proj, offset, level_type) = match name { + // texture(gsampler, gvec P, [float bias]); + "texture" => (false, false, TextureLevelType::None), + // textureGrad(gsampler, gvec P, gvec dPdx, gvec dPdy); + "textureGrad" => (false, false, TextureLevelType::Grad), + // textureGradOffset(gsampler, gvec P, gvec dPdx, gvec dPdy, ivec offset); + "textureGradOffset" => (false, true, TextureLevelType::Grad), + // textureLod(gsampler, gvec P, float lod); + "textureLod" => (false, false, TextureLevelType::Lod), + // textureLodOffset(gsampler, gvec P, float lod, ivec offset); + "textureLodOffset" => (false, true, TextureLevelType::Lod), + // textureOffset(gsampler, gvec+1 P, ivec offset, [float bias]); + "textureOffset" => (false, true, TextureLevelType::None), + // textureProj(gsampler, gvec+1 P, [float bias]); + "textureProj" => (true, false, TextureLevelType::None), + // textureProjGrad(gsampler, gvec+1 P, gvec dPdx, gvec dPdy); + "textureProjGrad" => (true, false, TextureLevelType::Grad), + // textureProjGradOffset(gsampler, gvec+1 P, gvec dPdx, gvec dPdy, ivec offset); + "textureProjGradOffset" => (true, true, TextureLevelType::Grad), + // textureProjLod(gsampler, gvec+1 P, float lod); + "textureProjLod" => (true, false, TextureLevelType::Lod), + // textureProjLodOffset(gsampler, gvec+1 P, gvec dPdx, gvec dPdy, ivec offset); + "textureProjLodOffset" => (true, true, TextureLevelType::Lod), + // textureProjOffset(gsampler, gvec+1 P, ivec offset, [float bias]); + "textureProjOffset" => (true, true, TextureLevelType::None), + _ => unreachable!(), + }; + + let builtin = MacroCall::Texture { + proj, + offset, + shadow, + level_type, + }; + + // Parse out the variant settings. + let grad = level_type == TextureLevelType::Grad; + let lod = level_type == TextureLevelType::Lod; + + let supports_variant = proj && !shadow; + if variant && !supports_variant { + continue; + } + + if bias && !matches!(level_type, TextureLevelType::None) { + continue; + } + + // Proj doesn't work with arrayed or Cube + if proj && (arrayed || dim == Dim::Cube) { + continue; + } + + // texture operations with offset are not supported for cube maps + if dim == Dim::Cube && offset { + continue; + } + + // sampler2DArrayShadow can't be used in textureLod or in texture with bias + if (lod || bias) && arrayed && shadow && dim == Dim::D2 { + continue; + } + + // TODO: glsl supports using bias with depth samplers but naga doesn't + if bias && shadow { + continue; + } + + let class = match shadow { + true => ImageClass::Depth { multi }, + false => ImageClass::Sampled { kind, multi }, + }; + + let image = TypeInner::Image { + dim, + arrayed, + class, + }; + + let num_coords_from_dim = image_dims_to_coords_size(dim).min(3); + let mut num_coords = num_coords_from_dim; + + if shadow && proj { + num_coords = 4; + } else if dim == Dim::D1 && shadow { + num_coords = 3; + } else if shadow { + num_coords += 1; + } else if proj { + if variant && num_coords == 4 { + // Normal form already has 4 components, no need to have a variant form. + continue; + } else if variant { + num_coords = 4; + } else { + num_coords += 1; + } + } + + if !(dim == Dim::D1 && shadow) { + num_coords += arrayed as usize; + } + + // Special case: texture(gsamplerCubeArrayShadow) kicks the shadow compare ref to a separate argument, + // since it would otherwise take five arguments. It also can't take a bias, nor can it be proj/grad/lod/offset + // (presumably because nobody asked for it, and implementation complexity?) + if num_coords >= 5 { + if lod || grad || offset || proj || bias { + continue; + } + debug_assert!(dim == Dim::Cube && shadow && arrayed); + } + debug_assert!(num_coords <= 5); + + let vector = make_coords_arg(num_coords, Sk::Float); + let mut args = vec![image, vector]; + + if num_coords == 5 { + args.push(TypeInner::Scalar(Scalar::F32)); + } + + match level_type { + TextureLevelType::Lod => { + args.push(TypeInner::Scalar(Scalar::F32)); + } + TextureLevelType::Grad => { + args.push(make_coords_arg(num_coords_from_dim, Sk::Float)); + args.push(make_coords_arg(num_coords_from_dim, Sk::Float)); + } + TextureLevelType::None => {} + }; + + if offset { + args.push(make_coords_arg(num_coords_from_dim, Sk::Sint)); + } + + if bias { + args.push(TypeInner::Scalar(Scalar::F32)); + } + + declaration + .overloads + .push(module.add_builtin(args, builtin)); + } + }; + + texture_args_generator(TextureArgsOptions::SHADOW | variations.into(), f) + } + "textureSize" => { + let f = |kind, dim, arrayed, multi, shadow| { + let class = match shadow { + true => ImageClass::Depth { multi }, + false => ImageClass::Sampled { kind, multi }, + }; + + let image = TypeInner::Image { + dim, + arrayed, + class, + }; + + let mut args = vec![image]; + + if !multi { + args.push(TypeInner::Scalar(Scalar::I32)) + } + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::TextureSize { arrayed })) + }; + + texture_args_generator( + TextureArgsOptions::SHADOW | TextureArgsOptions::MULTI | variations.into(), + f, + ) + } + "texelFetch" | "texelFetchOffset" => { + let offset = "texelFetchOffset" == name; + let f = |kind, dim, arrayed, multi, _shadow| { + // Cube images aren't supported + if let Dim::Cube = dim { + return; + } + + let image = TypeInner::Image { + dim, + arrayed, + class: ImageClass::Sampled { kind, multi }, + }; + + let dim_value = image_dims_to_coords_size(dim); + let coordinates = make_coords_arg(dim_value + arrayed as usize, Sk::Sint); + + let mut args = vec![image, coordinates, TypeInner::Scalar(Scalar::I32)]; + + if offset { + args.push(make_coords_arg(dim_value, Sk::Sint)); + } + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::ImageLoad { multi })) + }; + + // Don't generate shadow images since they aren't supported + texture_args_generator(TextureArgsOptions::MULTI | variations.into(), f) + } + "imageSize" => { + let f = |kind: Sk, dim, arrayed, _, _| { + // Naga doesn't support cube images and it's usefulness + // is questionable, so they won't be supported for now + if dim == Dim::Cube { + return; + } + + let image = TypeInner::Image { + dim, + arrayed, + class: ImageClass::Storage { + format: kind.dummy_storage_format(), + access: crate::StorageAccess::empty(), + }, + }; + + declaration + .overloads + .push(module.add_builtin(vec![image], MacroCall::TextureSize { arrayed })) + }; + + texture_args_generator(variations.into(), f) + } + "imageLoad" => { + let f = |kind: Sk, dim, arrayed, _, _| { + // Naga doesn't support cube images and it's usefulness + // is questionable, so they won't be supported for now + if dim == Dim::Cube { + return; + } + + let image = TypeInner::Image { + dim, + arrayed, + class: ImageClass::Storage { + format: kind.dummy_storage_format(), + access: crate::StorageAccess::LOAD, + }, + }; + + let dim_value = image_dims_to_coords_size(dim); + let mut coord_size = dim_value + arrayed as usize; + // > Every OpenGL API call that operates on cubemap array + // > textures takes layer-faces, not array layers + // + // So this means that imageCubeArray only takes a three component + // vector coordinate and the third component is a layer index. + if Dim::Cube == dim && arrayed { + coord_size = 3 + } + let coordinates = make_coords_arg(coord_size, Sk::Sint); + + let args = vec![image, coordinates]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::ImageLoad { multi: false })) + }; + + // Don't generate shadow nor multisampled images since they aren't supported + texture_args_generator(variations.into(), f) + } + "imageStore" => { + let f = |kind: Sk, dim, arrayed, _, _| { + // Naga doesn't support cube images and it's usefulness + // is questionable, so they won't be supported for now + if dim == Dim::Cube { + return; + } + + let image = TypeInner::Image { + dim, + arrayed, + class: ImageClass::Storage { + format: kind.dummy_storage_format(), + access: crate::StorageAccess::STORE, + }, + }; + + let dim_value = image_dims_to_coords_size(dim); + let mut coord_size = dim_value + arrayed as usize; + // > Every OpenGL API call that operates on cubemap array + // > textures takes layer-faces, not array layers + // + // So this means that imageCubeArray only takes a three component + // vector coordinate and the third component is a layer index. + if Dim::Cube == dim && arrayed { + coord_size = 3 + } + let coordinates = make_coords_arg(coord_size, Sk::Sint); + + let args = vec![ + image, + coordinates, + TypeInner::Vector { + size: VectorSize::Quad, + scalar: Scalar { kind, width: 4 }, + }, + ]; + + let mut overload = module.add_builtin(args, MacroCall::ImageStore); + overload.void = true; + declaration.overloads.push(overload) + }; + + // Don't generate shadow nor multisampled images since they aren't supported + texture_args_generator(variations.into(), f) + } + _ => {} + } +} + +/// Injects the builtins into declaration that don't need any special variations +fn inject_standard_builtins( + declaration: &mut FunctionDeclaration, + module: &mut Module, + name: &str, +) { + match name { + "sampler1D" | "sampler1DArray" | "sampler2D" | "sampler2DArray" | "sampler2DMS" + | "sampler2DMSArray" | "sampler3D" | "samplerCube" | "samplerCubeArray" => { + declaration.overloads.push(module.add_builtin( + vec![ + TypeInner::Image { + dim: match name { + "sampler1D" | "sampler1DArray" => Dim::D1, + "sampler2D" | "sampler2DArray" | "sampler2DMS" | "sampler2DMSArray" => { + Dim::D2 + } + "sampler3D" => Dim::D3, + _ => Dim::Cube, + }, + arrayed: matches!( + name, + "sampler1DArray" + | "sampler2DArray" + | "sampler2DMSArray" + | "samplerCubeArray" + ), + class: ImageClass::Sampled { + kind: Sk::Float, + multi: matches!(name, "sampler2DMS" | "sampler2DMSArray"), + }, + }, + TypeInner::Sampler { comparison: false }, + ], + MacroCall::Sampler, + )) + } + "sampler1DShadow" + | "sampler1DArrayShadow" + | "sampler2DShadow" + | "sampler2DArrayShadow" + | "samplerCubeShadow" + | "samplerCubeArrayShadow" => { + let dim = match name { + "sampler1DShadow" | "sampler1DArrayShadow" => Dim::D1, + "sampler2DShadow" | "sampler2DArrayShadow" => Dim::D2, + _ => Dim::Cube, + }; + let arrayed = matches!( + name, + "sampler1DArrayShadow" | "sampler2DArrayShadow" | "samplerCubeArrayShadow" + ); + + for i in 0..2 { + let ty = TypeInner::Image { + dim, + arrayed, + class: match i { + 0 => ImageClass::Sampled { + kind: Sk::Float, + multi: false, + }, + _ => ImageClass::Depth { multi: false }, + }, + }; + + declaration.overloads.push(module.add_builtin( + vec![ty, TypeInner::Sampler { comparison: true }], + MacroCall::SamplerShadow, + )) + } + } + "sin" | "exp" | "exp2" | "sinh" | "cos" | "cosh" | "tan" | "tanh" | "acos" | "asin" + | "log" | "log2" | "radians" | "degrees" | "asinh" | "acosh" | "atanh" + | "floatBitsToInt" | "floatBitsToUint" | "dFdx" | "dFdxFine" | "dFdxCoarse" | "dFdy" + | "dFdyFine" | "dFdyCoarse" | "fwidth" | "fwidthFine" | "fwidthCoarse" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = Scalar::F32; + + declaration.overloads.push(module.add_builtin( + vec![match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }], + match name { + "sin" => MacroCall::MathFunction(MathFunction::Sin), + "exp" => MacroCall::MathFunction(MathFunction::Exp), + "exp2" => MacroCall::MathFunction(MathFunction::Exp2), + "sinh" => MacroCall::MathFunction(MathFunction::Sinh), + "cos" => MacroCall::MathFunction(MathFunction::Cos), + "cosh" => MacroCall::MathFunction(MathFunction::Cosh), + "tan" => MacroCall::MathFunction(MathFunction::Tan), + "tanh" => MacroCall::MathFunction(MathFunction::Tanh), + "acos" => MacroCall::MathFunction(MathFunction::Acos), + "asin" => MacroCall::MathFunction(MathFunction::Asin), + "log" => MacroCall::MathFunction(MathFunction::Log), + "log2" => MacroCall::MathFunction(MathFunction::Log2), + "asinh" => MacroCall::MathFunction(MathFunction::Asinh), + "acosh" => MacroCall::MathFunction(MathFunction::Acosh), + "atanh" => MacroCall::MathFunction(MathFunction::Atanh), + "radians" => MacroCall::MathFunction(MathFunction::Radians), + "degrees" => MacroCall::MathFunction(MathFunction::Degrees), + "floatBitsToInt" => MacroCall::BitCast(Sk::Sint), + "floatBitsToUint" => MacroCall::BitCast(Sk::Uint), + "dFdxCoarse" => MacroCall::Derivate(Axis::X, Ctrl::Coarse), + "dFdyCoarse" => MacroCall::Derivate(Axis::Y, Ctrl::Coarse), + "fwidthCoarse" => MacroCall::Derivate(Axis::Width, Ctrl::Coarse), + "dFdxFine" => MacroCall::Derivate(Axis::X, Ctrl::Fine), + "dFdyFine" => MacroCall::Derivate(Axis::Y, Ctrl::Fine), + "fwidthFine" => MacroCall::Derivate(Axis::Width, Ctrl::Fine), + "dFdx" => MacroCall::Derivate(Axis::X, Ctrl::None), + "dFdy" => MacroCall::Derivate(Axis::Y, Ctrl::None), + "fwidth" => MacroCall::Derivate(Axis::Width, Ctrl::None), + _ => unreachable!(), + }, + )) + } + } + "intBitsToFloat" | "uintBitsToFloat" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = match name { + "intBitsToFloat" => Scalar::I32, + _ => Scalar::U32, + }; + + declaration.overloads.push(module.add_builtin( + vec![match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }], + MacroCall::BitCast(Sk::Float), + )) + } + } + "pow" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = Scalar::F32; + let ty = || match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + + declaration.overloads.push( + module + .add_builtin(vec![ty(), ty()], MacroCall::MathFunction(MathFunction::Pow)), + ) + } + } + "abs" | "sign" => { + // bits layout + // bit 0 through 1 - dims + // bit 2 - float/sint + for bits in 0..0b1000 { + let size = match bits & 0b11 { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = match bits >> 2 { + 0b0 => Scalar::F32, + _ => Scalar::I32, + }; + + let args = vec![match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }]; + + declaration.overloads.push(module.add_builtin( + args, + MacroCall::MathFunction(match name { + "abs" => MathFunction::Abs, + "sign" => MathFunction::Sign, + _ => unreachable!(), + }), + )) + } + } + "bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" | "findLSB" + | "findMSB" => { + let fun = match name { + "bitCount" => MathFunction::CountOneBits, + "bitfieldReverse" => MathFunction::ReverseBits, + "bitfieldExtract" => MathFunction::ExtractBits, + "bitfieldInsert" => MathFunction::InsertBits, + "findLSB" => MathFunction::FindLsb, + "findMSB" => MathFunction::FindMsb, + _ => unreachable!(), + }; + + let mc = match fun { + MathFunction::ExtractBits => MacroCall::BitfieldExtract, + MathFunction::InsertBits => MacroCall::BitfieldInsert, + _ => MacroCall::MathFunction(fun), + }; + + // bits layout + // bit 0 - int/uint + // bit 1 through 2 - dims + for bits in 0..0b1000 { + let scalar = match bits & 0b1 { + 0b0 => Scalar::I32, + _ => Scalar::U32, + }; + let size = match bits >> 1 { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + + let ty = || match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + + let mut args = vec![ty()]; + + match fun { + MathFunction::ExtractBits => { + args.push(TypeInner::Scalar(Scalar::I32)); + args.push(TypeInner::Scalar(Scalar::I32)); + } + MathFunction::InsertBits => { + args.push(ty()); + args.push(TypeInner::Scalar(Scalar::I32)); + args.push(TypeInner::Scalar(Scalar::I32)); + } + _ => {} + } + + // we need to cast the return type of findLsb / findMsb + let mc = if scalar.kind == Sk::Uint { + match mc { + MacroCall::MathFunction(MathFunction::FindLsb) => MacroCall::FindLsbUint, + MacroCall::MathFunction(MathFunction::FindMsb) => MacroCall::FindMsbUint, + mc => mc, + } + } else { + mc + }; + + declaration.overloads.push(module.add_builtin(args, mc)) + } + } + "packSnorm4x8" | "packUnorm4x8" | "packSnorm2x16" | "packUnorm2x16" | "packHalf2x16" => { + let fun = match name { + "packSnorm4x8" => MathFunction::Pack4x8snorm, + "packUnorm4x8" => MathFunction::Pack4x8unorm, + "packSnorm2x16" => MathFunction::Pack2x16unorm, + "packUnorm2x16" => MathFunction::Pack2x16snorm, + "packHalf2x16" => MathFunction::Pack2x16float, + _ => unreachable!(), + }; + + let ty = match fun { + MathFunction::Pack4x8snorm | MathFunction::Pack4x8unorm => TypeInner::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar::F32, + }, + MathFunction::Pack2x16unorm + | MathFunction::Pack2x16snorm + | MathFunction::Pack2x16float => TypeInner::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar::F32, + }, + _ => unreachable!(), + }; + + let args = vec![ty]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(fun))); + } + "unpackSnorm4x8" | "unpackUnorm4x8" | "unpackSnorm2x16" | "unpackUnorm2x16" + | "unpackHalf2x16" => { + let fun = match name { + "unpackSnorm4x8" => MathFunction::Unpack4x8snorm, + "unpackUnorm4x8" => MathFunction::Unpack4x8unorm, + "unpackSnorm2x16" => MathFunction::Unpack2x16snorm, + "unpackUnorm2x16" => MathFunction::Unpack2x16unorm, + "unpackHalf2x16" => MathFunction::Unpack2x16float, + _ => unreachable!(), + }; + + let args = vec![TypeInner::Scalar(Scalar::U32)]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(fun))); + } + "atan" => { + // bits layout + // bit 0 - atan/atan2 + // bit 1 through 2 - dims + for bits in 0..0b1000 { + let fun = match bits & 0b1 { + 0b0 => MathFunction::Atan, + _ => MathFunction::Atan2, + }; + let size = match bits >> 1 { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = Scalar::F32; + let ty = || match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + + let mut args = vec![ty()]; + + if fun == MathFunction::Atan2 { + args.push(ty()) + } + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(fun))) + } + } + "all" | "any" | "not" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b11 { + let size = match bits { + 0b00 => VectorSize::Bi, + 0b01 => VectorSize::Tri, + _ => VectorSize::Quad, + }; + + let args = vec![TypeInner::Vector { + size, + scalar: Scalar::BOOL, + }]; + + let fun = match name { + "all" => MacroCall::Relational(RelationalFunction::All), + "any" => MacroCall::Relational(RelationalFunction::Any), + "not" => MacroCall::Unary(UnaryOperator::LogicalNot), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" => { + for bits in 0..0b1001 { + let (size, scalar) = match bits { + 0b0000 => (VectorSize::Bi, Scalar::F32), + 0b0001 => (VectorSize::Tri, Scalar::F32), + 0b0010 => (VectorSize::Quad, Scalar::F32), + 0b0011 => (VectorSize::Bi, Scalar::I32), + 0b0100 => (VectorSize::Tri, Scalar::I32), + 0b0101 => (VectorSize::Quad, Scalar::I32), + 0b0110 => (VectorSize::Bi, Scalar::U32), + 0b0111 => (VectorSize::Tri, Scalar::U32), + _ => (VectorSize::Quad, Scalar::U32), + }; + + let ty = || TypeInner::Vector { size, scalar }; + let args = vec![ty(), ty()]; + + let fun = MacroCall::Binary(match name { + "lessThan" => BinaryOperator::Less, + "greaterThan" => BinaryOperator::Greater, + "lessThanEqual" => BinaryOperator::LessEqual, + "greaterThanEqual" => BinaryOperator::GreaterEqual, + _ => unreachable!(), + }); + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "equal" | "notEqual" => { + for bits in 0..0b1100 { + let (size, scalar) = match bits { + 0b0000 => (VectorSize::Bi, Scalar::F32), + 0b0001 => (VectorSize::Tri, Scalar::F32), + 0b0010 => (VectorSize::Quad, Scalar::F32), + 0b0011 => (VectorSize::Bi, Scalar::I32), + 0b0100 => (VectorSize::Tri, Scalar::I32), + 0b0101 => (VectorSize::Quad, Scalar::I32), + 0b0110 => (VectorSize::Bi, Scalar::U32), + 0b0111 => (VectorSize::Tri, Scalar::U32), + 0b1000 => (VectorSize::Quad, Scalar::U32), + 0b1001 => (VectorSize::Bi, Scalar::BOOL), + 0b1010 => (VectorSize::Tri, Scalar::BOOL), + _ => (VectorSize::Quad, Scalar::BOOL), + }; + + let ty = || TypeInner::Vector { size, scalar }; + let args = vec![ty(), ty()]; + + let fun = MacroCall::Binary(match name { + "equal" => BinaryOperator::Equal, + "notEqual" => BinaryOperator::NotEqual, + _ => unreachable!(), + }); + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "min" | "max" => { + // bits layout + // bit 0 through 1 - scalar kind + // bit 2 through 4 - dims + for bits in 0..0b11100 { + let scalar = match bits & 0b11 { + 0b00 => Scalar::F32, + 0b01 => Scalar::I32, + 0b10 => Scalar::U32, + _ => continue, + }; + let (size, second_size) = match bits >> 2 { + 0b000 => (None, None), + 0b001 => (Some(VectorSize::Bi), None), + 0b010 => (Some(VectorSize::Tri), None), + 0b011 => (Some(VectorSize::Quad), None), + 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)), + 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)), + _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)), + }; + + let args = vec![ + match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }, + match second_size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }, + ]; + + let fun = match name { + "max" => MacroCall::Splatted(MathFunction::Max, size, 1), + "min" => MacroCall::Splatted(MathFunction::Min, size, 1), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "mix" => { + // bits layout + // bit 0 through 1 - dims + // bit 2 through 4 - types + // + // 0b10011 is the last element since splatted single elements + // were already added + for bits in 0..0b10011 { + let size = match bits & 0b11 { + 0b00 => Some(VectorSize::Bi), + 0b01 => Some(VectorSize::Tri), + 0b10 => Some(VectorSize::Quad), + _ => None, + }; + let (scalar, splatted, boolean) = match bits >> 2 { + 0b000 => (Scalar::I32, false, true), + 0b001 => (Scalar::U32, false, true), + 0b010 => (Scalar::F32, false, true), + 0b011 => (Scalar::F32, false, false), + _ => (Scalar::F32, true, false), + }; + + let ty = |scalar| match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + let args = vec![ + ty(scalar), + ty(scalar), + match (boolean, splatted) { + (true, _) => ty(Scalar::BOOL), + (_, false) => TypeInner::Scalar(scalar), + _ => ty(scalar), + }, + ]; + + declaration.overloads.push(module.add_builtin( + args, + match boolean { + true => MacroCall::MixBoolean, + false => MacroCall::Splatted(MathFunction::Mix, size, 2), + }, + )) + } + } + "clamp" => { + // bits layout + // bit 0 through 1 - float/int/uint + // bit 2 through 3 - dims + // bit 4 - splatted + // + // 0b11010 is the last element since splatted single elements + // were already added + for bits in 0..0b11011 { + let scalar = match bits & 0b11 { + 0b00 => Scalar::F32, + 0b01 => Scalar::I32, + 0b10 => Scalar::U32, + _ => continue, + }; + let size = match (bits >> 2) & 0b11 { + 0b00 => Some(VectorSize::Bi), + 0b01 => Some(VectorSize::Tri), + 0b10 => Some(VectorSize::Quad), + _ => None, + }; + let splatted = bits & 0b10000 == 0b10000; + + let base_ty = || match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + let limit_ty = || match splatted { + true => TypeInner::Scalar(scalar), + false => base_ty(), + }; + + let args = vec![base_ty(), limit_ty(), limit_ty()]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::Clamp(size))) + } + } + "barrier" => declaration + .overloads + .push(module.add_builtin(Vec::new(), MacroCall::Barrier)), + // Add common builtins with floats + _ => inject_common_builtin(declaration, module, name, 4), + } +} + +/// Injects the builtins into declaration that need doubles +fn inject_double_builtin(declaration: &mut FunctionDeclaration, module: &mut Module, name: &str) { + match name { + "abs" | "sign" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let scalar = Scalar::F64; + + let args = vec![match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }]; + + declaration.overloads.push(module.add_builtin( + args, + MacroCall::MathFunction(match name { + "abs" => MathFunction::Abs, + "sign" => MathFunction::Sign, + _ => unreachable!(), + }), + )) + } + } + "min" | "max" => { + // bits layout + // bit 0 through 2 - dims + for bits in 0..0b111 { + let (size, second_size) = match bits { + 0b000 => (None, None), + 0b001 => (Some(VectorSize::Bi), None), + 0b010 => (Some(VectorSize::Tri), None), + 0b011 => (Some(VectorSize::Quad), None), + 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)), + 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)), + _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)), + }; + let scalar = Scalar::F64; + + let args = vec![ + match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }, + match second_size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }, + ]; + + let fun = match name { + "max" => MacroCall::Splatted(MathFunction::Max, size, 1), + "min" => MacroCall::Splatted(MathFunction::Min, size, 1), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "mix" => { + // bits layout + // bit 0 through 1 - dims + // bit 2 through 3 - splatted/boolean + // + // 0b1010 is the last element since splatted with single elements + // is equal to normal single elements + for bits in 0..0b1011 { + let size = match bits & 0b11 { + 0b00 => Some(VectorSize::Quad), + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => None, + }; + let scalar = Scalar::F64; + let (splatted, boolean) = match bits >> 2 { + 0b00 => (false, false), + 0b01 => (false, true), + _ => (true, false), + }; + + let ty = |scalar| match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + let args = vec![ + ty(scalar), + ty(scalar), + match (boolean, splatted) { + (true, _) => ty(Scalar::BOOL), + (_, false) => TypeInner::Scalar(scalar), + _ => ty(scalar), + }, + ]; + + declaration.overloads.push(module.add_builtin( + args, + match boolean { + true => MacroCall::MixBoolean, + false => MacroCall::Splatted(MathFunction::Mix, size, 2), + }, + )) + } + } + "clamp" => { + // bits layout + // bit 0 through 1 - dims + // bit 2 - splatted + // + // 0b110 is the last element since splatted with single elements + // is equal to normal single elements + for bits in 0..0b111 { + let scalar = Scalar::F64; + let size = match bits & 0b11 { + 0b00 => Some(VectorSize::Bi), + 0b01 => Some(VectorSize::Tri), + 0b10 => Some(VectorSize::Quad), + _ => None, + }; + let splatted = bits & 0b100 == 0b100; + + let base_ty = || match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + let limit_ty = || match splatted { + true => TypeInner::Scalar(scalar), + false => base_ty(), + }; + + let args = vec![base_ty(), limit_ty(), limit_ty()]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::Clamp(size))) + } + } + "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" | "equal" + | "notEqual" => { + let scalar = Scalar::F64; + for bits in 0..0b11 { + let size = match bits { + 0b00 => VectorSize::Bi, + 0b01 => VectorSize::Tri, + _ => VectorSize::Quad, + }; + + let ty = || TypeInner::Vector { size, scalar }; + let args = vec![ty(), ty()]; + + let fun = MacroCall::Binary(match name { + "lessThan" => BinaryOperator::Less, + "greaterThan" => BinaryOperator::Greater, + "lessThanEqual" => BinaryOperator::LessEqual, + "greaterThanEqual" => BinaryOperator::GreaterEqual, + "equal" => BinaryOperator::Equal, + "notEqual" => BinaryOperator::NotEqual, + _ => unreachable!(), + }); + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + // Add common builtins with doubles + _ => inject_common_builtin(declaration, module, name, 8), + } +} + +/// Injects the builtins into declaration that can used either float or doubles +fn inject_common_builtin( + declaration: &mut FunctionDeclaration, + module: &mut Module, + name: &str, + float_width: crate::Bytes, +) { + let float_scalar = Scalar { + kind: Sk::Float, + width: float_width, + }; + match name { + "ceil" | "round" | "roundEven" | "floor" | "fract" | "trunc" | "sqrt" | "inversesqrt" + | "normalize" | "length" | "isinf" | "isnan" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + + let args = vec![match size { + Some(size) => TypeInner::Vector { + size, + scalar: float_scalar, + }, + None => TypeInner::Scalar(float_scalar), + }]; + + let fun = match name { + "ceil" => MacroCall::MathFunction(MathFunction::Ceil), + "round" | "roundEven" => MacroCall::MathFunction(MathFunction::Round), + "floor" => MacroCall::MathFunction(MathFunction::Floor), + "fract" => MacroCall::MathFunction(MathFunction::Fract), + "trunc" => MacroCall::MathFunction(MathFunction::Trunc), + "sqrt" => MacroCall::MathFunction(MathFunction::Sqrt), + "inversesqrt" => MacroCall::MathFunction(MathFunction::InverseSqrt), + "normalize" => MacroCall::MathFunction(MathFunction::Normalize), + "length" => MacroCall::MathFunction(MathFunction::Length), + "isinf" => MacroCall::Relational(RelationalFunction::IsInf), + "isnan" => MacroCall::Relational(RelationalFunction::IsNan), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "dot" | "reflect" | "distance" | "ldexp" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + let ty = |scalar| match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + + let fun = match name { + "dot" => MacroCall::MathFunction(MathFunction::Dot), + "reflect" => MacroCall::MathFunction(MathFunction::Reflect), + "distance" => MacroCall::MathFunction(MathFunction::Distance), + "ldexp" => MacroCall::MathFunction(MathFunction::Ldexp), + _ => unreachable!(), + }; + + let second_scalar = match fun { + MacroCall::MathFunction(MathFunction::Ldexp) => Scalar::I32, + _ => float_scalar, + }; + + declaration + .overloads + .push(module.add_builtin(vec![ty(float_scalar), ty(second_scalar)], fun)) + } + } + "transpose" => { + // bits layout + // bit 0 through 3 - dims + for bits in 0..0b1001 { + let (rows, columns) = match bits { + 0b0000 => (VectorSize::Bi, VectorSize::Bi), + 0b0001 => (VectorSize::Bi, VectorSize::Tri), + 0b0010 => (VectorSize::Bi, VectorSize::Quad), + 0b0011 => (VectorSize::Tri, VectorSize::Bi), + 0b0100 => (VectorSize::Tri, VectorSize::Tri), + 0b0101 => (VectorSize::Tri, VectorSize::Quad), + 0b0110 => (VectorSize::Quad, VectorSize::Bi), + 0b0111 => (VectorSize::Quad, VectorSize::Tri), + _ => (VectorSize::Quad, VectorSize::Quad), + }; + + declaration.overloads.push(module.add_builtin( + vec![TypeInner::Matrix { + columns, + rows, + scalar: float_scalar, + }], + MacroCall::MathFunction(MathFunction::Transpose), + )) + } + } + "inverse" | "determinant" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b11 { + let (rows, columns) = match bits { + 0b00 => (VectorSize::Bi, VectorSize::Bi), + 0b01 => (VectorSize::Tri, VectorSize::Tri), + _ => (VectorSize::Quad, VectorSize::Quad), + }; + + let args = vec![TypeInner::Matrix { + columns, + rows, + scalar: float_scalar, + }]; + + declaration.overloads.push(module.add_builtin( + args, + MacroCall::MathFunction(match name { + "inverse" => MathFunction::Inverse, + "determinant" => MathFunction::Determinant, + _ => unreachable!(), + }), + )) + } + } + "mod" | "step" => { + // bits layout + // bit 0 through 2 - dims + for bits in 0..0b111 { + let (size, second_size) = match bits { + 0b000 => (None, None), + 0b001 => (Some(VectorSize::Bi), None), + 0b010 => (Some(VectorSize::Tri), None), + 0b011 => (Some(VectorSize::Quad), None), + 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)), + 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)), + _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)), + }; + + let mut args = Vec::with_capacity(2); + let step = name == "step"; + + for i in 0..2 { + let maybe_size = match i == step as u32 { + true => size, + false => second_size, + }; + + args.push(match maybe_size { + Some(size) => TypeInner::Vector { + size, + scalar: float_scalar, + }, + None => TypeInner::Scalar(float_scalar), + }) + } + + let fun = match name { + "mod" => MacroCall::Mod(size), + "step" => MacroCall::Splatted(MathFunction::Step, size, 0), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + // TODO: https://github.com/gfx-rs/naga/issues/2526 + // "modf" | "frexp" => { ... } + "cross" => { + let args = vec![ + TypeInner::Vector { + size: VectorSize::Tri, + scalar: float_scalar, + }, + TypeInner::Vector { + size: VectorSize::Tri, + scalar: float_scalar, + }, + ]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Cross))) + } + "outerProduct" => { + // bits layout + // bit 0 through 3 - dims + for bits in 0..0b1001 { + let (size1, size2) = match bits { + 0b0000 => (VectorSize::Bi, VectorSize::Bi), + 0b0001 => (VectorSize::Bi, VectorSize::Tri), + 0b0010 => (VectorSize::Bi, VectorSize::Quad), + 0b0011 => (VectorSize::Tri, VectorSize::Bi), + 0b0100 => (VectorSize::Tri, VectorSize::Tri), + 0b0101 => (VectorSize::Tri, VectorSize::Quad), + 0b0110 => (VectorSize::Quad, VectorSize::Bi), + 0b0111 => (VectorSize::Quad, VectorSize::Tri), + _ => (VectorSize::Quad, VectorSize::Quad), + }; + + let args = vec![ + TypeInner::Vector { + size: size1, + scalar: float_scalar, + }, + TypeInner::Vector { + size: size2, + scalar: float_scalar, + }, + ]; + + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Outer))) + } + } + "faceforward" | "fma" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + + let ty = || match size { + Some(size) => TypeInner::Vector { + size, + scalar: float_scalar, + }, + None => TypeInner::Scalar(float_scalar), + }; + let args = vec![ty(), ty(), ty()]; + + let fun = match name { + "faceforward" => MacroCall::MathFunction(MathFunction::FaceForward), + "fma" => MacroCall::MathFunction(MathFunction::Fma), + _ => unreachable!(), + }; + + declaration.overloads.push(module.add_builtin(args, fun)) + } + } + "refract" => { + // bits layout + // bit 0 through 1 - dims + for bits in 0..0b100 { + let size = match bits { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + + let ty = || match size { + Some(size) => TypeInner::Vector { + size, + scalar: float_scalar, + }, + None => TypeInner::Scalar(float_scalar), + }; + let args = vec![ty(), ty(), TypeInner::Scalar(Scalar::F32)]; + declaration + .overloads + .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Refract))) + } + } + "smoothstep" => { + // bit 0 - splatted + // bit 1 through 2 - dims + for bits in 0..0b1000 { + let splatted = bits & 0b1 == 0b1; + let size = match bits >> 1 { + 0b00 => None, + 0b01 => Some(VectorSize::Bi), + 0b10 => Some(VectorSize::Tri), + _ => Some(VectorSize::Quad), + }; + + if splatted && size.is_none() { + continue; + } + + let base_ty = || match size { + Some(size) => TypeInner::Vector { + size, + scalar: float_scalar, + }, + None => TypeInner::Scalar(float_scalar), + }; + let ty = || match splatted { + true => TypeInner::Scalar(float_scalar), + false => base_ty(), + }; + declaration.overloads.push(module.add_builtin( + vec![ty(), ty(), base_ty()], + MacroCall::SmoothStep { splatted: size }, + )) + } + } + // The function isn't a builtin or we don't yet support it + _ => {} + } +} + +#[derive(Clone, Copy, PartialEq, Debug)] +pub enum TextureLevelType { + None, + Lod, + Grad, +} + +/// A compiler defined builtin function +#[derive(Clone, Copy, PartialEq, Debug)] +pub enum MacroCall { + Sampler, + SamplerShadow, + Texture { + proj: bool, + offset: bool, + shadow: bool, + level_type: TextureLevelType, + }, + TextureSize { + arrayed: bool, + }, + ImageLoad { + multi: bool, + }, + ImageStore, + MathFunction(MathFunction), + FindLsbUint, + FindMsbUint, + BitfieldExtract, + BitfieldInsert, + Relational(RelationalFunction), + Unary(UnaryOperator), + Binary(BinaryOperator), + Mod(Option<VectorSize>), + Splatted(MathFunction, Option<VectorSize>, usize), + MixBoolean, + Clamp(Option<VectorSize>), + BitCast(Sk), + Derivate(Axis, Ctrl), + Barrier, + /// SmoothStep needs a separate variant because it might need it's inputs + /// to be splatted depending on the overload + SmoothStep { + /// The size of the splat operation if some + splatted: Option<VectorSize>, + }, +} + +impl MacroCall { + /// Adds the necessary expressions and statements to the passed body and + /// finally returns the final expression with the correct result + pub fn call( + &self, + frontend: &mut Frontend, + ctx: &mut Context, + args: &mut [Handle<Expression>], + meta: Span, + ) -> Result<Option<Handle<Expression>>> { + Ok(Some(match *self { + MacroCall::Sampler => { + ctx.samplers.insert(args[0], args[1]); + args[0] + } + MacroCall::SamplerShadow => { + sampled_to_depth(ctx, args[0], meta, &mut frontend.errors); + ctx.invalidate_expression(args[0], meta)?; + ctx.samplers.insert(args[0], args[1]); + args[0] + } + MacroCall::Texture { + proj, + offset, + shadow, + level_type, + } => { + let mut coords = args[1]; + + if proj { + let size = match *ctx.resolve_type(coords, meta)? { + TypeInner::Vector { size, .. } => size, + _ => unreachable!(), + }; + let mut right = ctx.add_expression( + Expression::AccessIndex { + base: coords, + index: size as u32 - 1, + }, + Span::default(), + )?; + let left = if let VectorSize::Bi = size { + ctx.add_expression( + Expression::AccessIndex { + base: coords, + index: 0, + }, + Span::default(), + )? + } else { + let size = match size { + VectorSize::Tri => VectorSize::Bi, + _ => VectorSize::Tri, + }; + right = ctx.add_expression( + Expression::Splat { size, value: right }, + Span::default(), + )?; + ctx.vector_resize(size, coords, Span::default())? + }; + coords = ctx.add_expression( + Expression::Binary { + op: BinaryOperator::Divide, + left, + right, + }, + Span::default(), + )?; + } + + let extra = args.get(2).copied(); + let comps = frontend.coordinate_components(ctx, args[0], coords, extra, meta)?; + + let mut num_args = 2; + + if comps.used_extra { + num_args += 1; + }; + + // Parse out explicit texture level. + let mut level = match level_type { + TextureLevelType::None => SampleLevel::Auto, + + TextureLevelType::Lod => { + num_args += 1; + + if shadow { + log::warn!("Assuming LOD {:?} is zero", args[2],); + + SampleLevel::Zero + } else { + SampleLevel::Exact(args[2]) + } + } + + TextureLevelType::Grad => { + num_args += 2; + + if shadow { + log::warn!( + "Assuming gradients {:?} and {:?} are not greater than 1", + args[2], + args[3], + ); + SampleLevel::Zero + } else { + SampleLevel::Gradient { + x: args[2], + y: args[3], + } + } + } + }; + + let texture_offset = match offset { + true => { + let offset_arg = args[num_args]; + num_args += 1; + match ctx.lift_up_const_expression(offset_arg) { + Ok(v) => Some(v), + Err(e) => { + frontend.errors.push(e); + None + } + } + } + false => None, + }; + + // Now go back and look for optional bias arg (if available) + if let TextureLevelType::None = level_type { + level = args + .get(num_args) + .copied() + .map_or(SampleLevel::Auto, SampleLevel::Bias); + } + + texture_call(ctx, args[0], level, comps, texture_offset, meta)? + } + + MacroCall::TextureSize { arrayed } => { + let mut expr = ctx.add_expression( + Expression::ImageQuery { + image: args[0], + query: ImageQuery::Size { + level: args.get(1).copied(), + }, + }, + Span::default(), + )?; + + if arrayed { + let mut components = Vec::with_capacity(4); + + let size = match *ctx.resolve_type(expr, meta)? { + TypeInner::Vector { size: ori_size, .. } => { + for index in 0..(ori_size as u32) { + components.push(ctx.add_expression( + Expression::AccessIndex { base: expr, index }, + Span::default(), + )?) + } + + match ori_size { + VectorSize::Bi => VectorSize::Tri, + _ => VectorSize::Quad, + } + } + _ => { + components.push(expr); + VectorSize::Bi + } + }; + + components.push(ctx.add_expression( + Expression::ImageQuery { + image: args[0], + query: ImageQuery::NumLayers, + }, + Span::default(), + )?); + + let ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size, + scalar: Scalar::U32, + }, + }, + Span::default(), + ); + + expr = ctx.add_expression(Expression::Compose { components, ty }, meta)? + } + + ctx.add_expression( + Expression::As { + expr, + kind: Sk::Sint, + convert: Some(4), + }, + Span::default(), + )? + } + MacroCall::ImageLoad { multi } => { + let comps = frontend.coordinate_components(ctx, args[0], args[1], None, meta)?; + let (sample, level) = match (multi, args.get(2)) { + (_, None) => (None, None), + (true, Some(&arg)) => (Some(arg), None), + (false, Some(&arg)) => (None, Some(arg)), + }; + ctx.add_expression( + Expression::ImageLoad { + image: args[0], + coordinate: comps.coordinate, + array_index: comps.array_index, + sample, + level, + }, + Span::default(), + )? + } + MacroCall::ImageStore => { + let comps = frontend.coordinate_components(ctx, args[0], args[1], None, meta)?; + ctx.emit_restart(); + ctx.body.push( + crate::Statement::ImageStore { + image: args[0], + coordinate: comps.coordinate, + array_index: comps.array_index, + value: args[2], + }, + meta, + ); + return Ok(None); + } + MacroCall::MathFunction(fun) => ctx.add_expression( + Expression::Math { + fun, + arg: args[0], + arg1: args.get(1).copied(), + arg2: args.get(2).copied(), + arg3: args.get(3).copied(), + }, + Span::default(), + )?, + mc @ (MacroCall::FindLsbUint | MacroCall::FindMsbUint) => { + let fun = match mc { + MacroCall::FindLsbUint => MathFunction::FindLsb, + MacroCall::FindMsbUint => MathFunction::FindMsb, + _ => unreachable!(), + }; + let res = ctx.add_expression( + Expression::Math { + fun, + arg: args[0], + arg1: None, + arg2: None, + arg3: None, + }, + Span::default(), + )?; + ctx.add_expression( + Expression::As { + expr: res, + kind: Sk::Sint, + convert: Some(4), + }, + Span::default(), + )? + } + MacroCall::BitfieldInsert => { + let conv_arg_2 = ctx.add_expression( + Expression::As { + expr: args[2], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + )?; + let conv_arg_3 = ctx.add_expression( + Expression::As { + expr: args[3], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + )?; + ctx.add_expression( + Expression::Math { + fun: MathFunction::InsertBits, + arg: args[0], + arg1: Some(args[1]), + arg2: Some(conv_arg_2), + arg3: Some(conv_arg_3), + }, + Span::default(), + )? + } + MacroCall::BitfieldExtract => { + let conv_arg_1 = ctx.add_expression( + Expression::As { + expr: args[1], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + )?; + let conv_arg_2 = ctx.add_expression( + Expression::As { + expr: args[2], + kind: Sk::Uint, + convert: Some(4), + }, + Span::default(), + )?; + ctx.add_expression( + Expression::Math { + fun: MathFunction::ExtractBits, + arg: args[0], + arg1: Some(conv_arg_1), + arg2: Some(conv_arg_2), + arg3: None, + }, + Span::default(), + )? + } + MacroCall::Relational(fun) => ctx.add_expression( + Expression::Relational { + fun, + argument: args[0], + }, + Span::default(), + )?, + MacroCall::Unary(op) => { + ctx.add_expression(Expression::Unary { op, expr: args[0] }, Span::default())? + } + MacroCall::Binary(op) => ctx.add_expression( + Expression::Binary { + op, + left: args[0], + right: args[1], + }, + Span::default(), + )?, + MacroCall::Mod(size) => { + ctx.implicit_splat(&mut args[1], meta, size)?; + + // x - y * floor(x / y) + + let div = ctx.add_expression( + Expression::Binary { + op: BinaryOperator::Divide, + left: args[0], + right: args[1], + }, + Span::default(), + )?; + let floor = ctx.add_expression( + Expression::Math { + fun: MathFunction::Floor, + arg: div, + arg1: None, + arg2: None, + arg3: None, + }, + Span::default(), + )?; + let mult = ctx.add_expression( + Expression::Binary { + op: BinaryOperator::Multiply, + left: floor, + right: args[1], + }, + Span::default(), + )?; + ctx.add_expression( + Expression::Binary { + op: BinaryOperator::Subtract, + left: args[0], + right: mult, + }, + Span::default(), + )? + } + MacroCall::Splatted(fun, size, i) => { + ctx.implicit_splat(&mut args[i], meta, size)?; + + ctx.add_expression( + Expression::Math { + fun, + arg: args[0], + arg1: args.get(1).copied(), + arg2: args.get(2).copied(), + arg3: args.get(3).copied(), + }, + Span::default(), + )? + } + MacroCall::MixBoolean => ctx.add_expression( + Expression::Select { + condition: args[2], + accept: args[1], + reject: args[0], + }, + Span::default(), + )?, + MacroCall::Clamp(size) => { + ctx.implicit_splat(&mut args[1], meta, size)?; + ctx.implicit_splat(&mut args[2], meta, size)?; + + ctx.add_expression( + Expression::Math { + fun: MathFunction::Clamp, + arg: args[0], + arg1: args.get(1).copied(), + arg2: args.get(2).copied(), + arg3: args.get(3).copied(), + }, + Span::default(), + )? + } + MacroCall::BitCast(kind) => ctx.add_expression( + Expression::As { + expr: args[0], + kind, + convert: None, + }, + Span::default(), + )?, + MacroCall::Derivate(axis, ctrl) => ctx.add_expression( + Expression::Derivative { + axis, + ctrl, + expr: args[0], + }, + Span::default(), + )?, + MacroCall::Barrier => { + ctx.emit_restart(); + ctx.body + .push(crate::Statement::Barrier(crate::Barrier::all()), meta); + return Ok(None); + } + MacroCall::SmoothStep { splatted } => { + ctx.implicit_splat(&mut args[0], meta, splatted)?; + ctx.implicit_splat(&mut args[1], meta, splatted)?; + + ctx.add_expression( + Expression::Math { + fun: MathFunction::SmoothStep, + arg: args[0], + arg1: args.get(1).copied(), + arg2: args.get(2).copied(), + arg3: None, + }, + Span::default(), + )? + } + })) + } +} + +fn texture_call( + ctx: &mut Context, + image: Handle<Expression>, + level: SampleLevel, + comps: CoordComponents, + offset: Option<Handle<Expression>>, + meta: Span, +) -> Result<Handle<Expression>> { + if let Some(sampler) = ctx.samplers.get(&image).copied() { + let mut array_index = comps.array_index; + + if let Some(ref mut array_index_expr) = array_index { + ctx.conversion(array_index_expr, meta, Scalar::I32)?; + } + + Ok(ctx.add_expression( + Expression::ImageSample { + image, + sampler, + gather: None, //TODO + coordinate: comps.coordinate, + array_index, + offset, + level, + depth_ref: comps.depth_ref, + }, + meta, + )?) + } else { + Err(Error { + kind: ErrorKind::SemanticError("Bad call".into()), + meta, + }) + } +} + +/// Helper struct for texture calls with the separate components from the vector argument +/// +/// Obtained by calling [`coordinate_components`](Frontend::coordinate_components) +#[derive(Debug)] +struct CoordComponents { + coordinate: Handle<Expression>, + depth_ref: Option<Handle<Expression>>, + array_index: Option<Handle<Expression>>, + used_extra: bool, +} + +impl Frontend { + /// Helper function for texture calls, splits the vector argument into it's components + fn coordinate_components( + &mut self, + ctx: &mut Context, + image: Handle<Expression>, + coord: Handle<Expression>, + extra: Option<Handle<Expression>>, + meta: Span, + ) -> Result<CoordComponents> { + if let TypeInner::Image { + dim, + arrayed, + class, + } = *ctx.resolve_type(image, meta)? + { + let image_size = match dim { + Dim::D1 => None, + Dim::D2 => Some(VectorSize::Bi), + Dim::D3 => Some(VectorSize::Tri), + Dim::Cube => Some(VectorSize::Tri), + }; + let coord_size = match *ctx.resolve_type(coord, meta)? { + TypeInner::Vector { size, .. } => Some(size), + _ => None, + }; + let (shadow, storage) = match class { + ImageClass::Depth { .. } => (true, false), + ImageClass::Storage { .. } => (false, true), + ImageClass::Sampled { .. } => (false, false), + }; + + let coordinate = match (image_size, coord_size) { + (Some(size), Some(coord_s)) if size != coord_s => { + ctx.vector_resize(size, coord, Span::default())? + } + (None, Some(_)) => ctx.add_expression( + Expression::AccessIndex { + base: coord, + index: 0, + }, + Span::default(), + )?, + _ => coord, + }; + + let mut coord_index = image_size.map_or(1, |s| s as u32); + + let array_index = if arrayed && !(storage && dim == Dim::Cube) { + let index = coord_index; + coord_index += 1; + + Some(ctx.add_expression( + Expression::AccessIndex { base: coord, index }, + Span::default(), + )?) + } else { + None + }; + let mut used_extra = false; + let depth_ref = match shadow { + true => { + let index = coord_index; + + if index == 4 { + used_extra = true; + extra + } else { + Some(ctx.add_expression( + Expression::AccessIndex { base: coord, index }, + Span::default(), + )?) + } + } + false => None, + }; + + Ok(CoordComponents { + coordinate, + depth_ref, + array_index, + used_extra, + }) + } else { + self.errors.push(Error { + kind: ErrorKind::SemanticError("Type is not an image".into()), + meta, + }); + + Ok(CoordComponents { + coordinate: coord, + depth_ref: None, + array_index: None, + used_extra: false, + }) + } + } +} + +/// Helper function to cast a expression holding a sampled image to a +/// depth image. +pub fn sampled_to_depth( + ctx: &mut Context, + image: Handle<Expression>, + meta: Span, + errors: &mut Vec<Error>, +) { + // Get the a mutable type handle of the underlying image storage + let ty = match ctx[image] { + Expression::GlobalVariable(handle) => &mut ctx.module.global_variables.get_mut(handle).ty, + Expression::FunctionArgument(i) => { + // Mark the function argument as carrying a depth texture + ctx.parameters_info[i as usize].depth = true; + // NOTE: We need to later also change the parameter type + &mut ctx.arguments[i as usize].ty + } + _ => { + // Only globals and function arguments are allowed to carry an image + return errors.push(Error { + kind: ErrorKind::SemanticError("Not a valid texture expression".into()), + meta, + }); + } + }; + + match ctx.module.types[*ty].inner { + // Update the image class to depth in case it already isn't + TypeInner::Image { + class, + dim, + arrayed, + } => match class { + ImageClass::Sampled { multi, .. } => { + *ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Image { + dim, + arrayed, + class: ImageClass::Depth { multi }, + }, + }, + Span::default(), + ) + } + ImageClass::Depth { .. } => {} + // Other image classes aren't allowed to be transformed to depth + ImageClass::Storage { .. } => errors.push(Error { + kind: ErrorKind::SemanticError("Not a texture".into()), + meta, + }), + }, + _ => errors.push(Error { + kind: ErrorKind::SemanticError("Not a texture".into()), + meta, + }), + }; + + // Copy the handle to allow borrowing the `ctx` again + let ty = *ty; + + // If the image was passed through a function argument we also need to change + // the corresponding parameter + if let Expression::FunctionArgument(i) = ctx[image] { + ctx.parameters[i as usize] = ty; + } +} + +bitflags::bitflags! { + /// Influences the operation `texture_args_generator` + struct TextureArgsOptions: u32 { + /// Generates multisampled variants of images + const MULTI = 1 << 0; + /// Generates shadow variants of images + const SHADOW = 1 << 1; + /// Generates standard images + const STANDARD = 1 << 2; + /// Generates cube arrayed images + const CUBE_ARRAY = 1 << 3; + /// Generates cube arrayed images + const D2_MULTI_ARRAY = 1 << 4; + } +} + +impl From<BuiltinVariations> for TextureArgsOptions { + fn from(variations: BuiltinVariations) -> Self { + let mut options = TextureArgsOptions::empty(); + if variations.contains(BuiltinVariations::STANDARD) { + options |= TextureArgsOptions::STANDARD + } + if variations.contains(BuiltinVariations::CUBE_TEXTURES_ARRAY) { + options |= TextureArgsOptions::CUBE_ARRAY + } + if variations.contains(BuiltinVariations::D2_MULTI_TEXTURES_ARRAY) { + options |= TextureArgsOptions::D2_MULTI_ARRAY + } + options + } +} + +/// Helper function to generate the image components for texture/image builtins +/// +/// Calls the passed function `f` with: +/// ```text +/// f(ScalarKind, ImageDimension, arrayed, multi, shadow) +/// ``` +/// +/// `options` controls extra image variants generation like multisampling and depth, +/// see the struct documentation +fn texture_args_generator( + options: TextureArgsOptions, + mut f: impl FnMut(crate::ScalarKind, Dim, bool, bool, bool), +) { + for kind in [Sk::Float, Sk::Uint, Sk::Sint].iter().copied() { + for dim in [Dim::D1, Dim::D2, Dim::D3, Dim::Cube].iter().copied() { + for arrayed in [false, true].iter().copied() { + if dim == Dim::Cube && arrayed { + if !options.contains(TextureArgsOptions::CUBE_ARRAY) { + continue; + } + } else if Dim::D2 == dim + && options.contains(TextureArgsOptions::MULTI) + && arrayed + && options.contains(TextureArgsOptions::D2_MULTI_ARRAY) + { + // multisampling for sampler2DMSArray + f(kind, dim, arrayed, true, false); + } else if !options.contains(TextureArgsOptions::STANDARD) { + continue; + } + + f(kind, dim, arrayed, false, false); + + // 3D images can't be neither arrayed nor shadow + // so we break out early, this way arrayed will always + // be false and we won't hit the shadow branch + if let Dim::D3 = dim { + break; + } + + if Dim::D2 == dim && options.contains(TextureArgsOptions::MULTI) && !arrayed { + // multisampling + f(kind, dim, arrayed, true, false); + } + + if Sk::Float == kind && options.contains(TextureArgsOptions::SHADOW) { + // shadow + f(kind, dim, arrayed, false, true); + } + } + } + } +} + +/// Helper functions used to convert from a image dimension into a integer representing the +/// number of components needed for the coordinates vector (1 means scalar instead of vector) +const fn image_dims_to_coords_size(dim: Dim) -> usize { + match dim { + Dim::D1 => 1, + Dim::D2 => 2, + _ => 3, + } +} diff --git a/third_party/rust/naga/src/front/glsl/context.rs b/third_party/rust/naga/src/front/glsl/context.rs new file mode 100644 index 0000000000..f26c57965d --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/context.rs @@ -0,0 +1,1506 @@ +use super::{ + ast::{ + GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier, + VariableReference, + }, + error::{Error, ErrorKind}, + types::{scalar_components, type_power}, + Frontend, Result, +}; +use crate::{ + front::Typifier, proc::Emitter, AddressSpace, Arena, BinaryOperator, Block, Expression, + FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction, Scalar, + Span, Statement, Type, TypeInner, VectorSize, +}; +use std::ops::Index; + +/// The position at which an expression is, used while lowering +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum ExprPos { + /// The expression is in the left hand side of an assignment + Lhs, + /// The expression is in the right hand side of an assignment + Rhs, + /// The expression is an array being indexed, needed to allow constant + /// arrays to be dynamically indexed + AccessBase { + /// The index is a constant + constant_index: bool, + }, +} + +impl ExprPos { + /// Returns an lhs position if the current position is lhs otherwise AccessBase + const fn maybe_access_base(&self, constant_index: bool) -> Self { + match *self { + ExprPos::Lhs + | ExprPos::AccessBase { + constant_index: false, + } => *self, + _ => ExprPos::AccessBase { constant_index }, + } + } +} + +#[derive(Debug)] +pub struct Context<'a> { + pub expressions: Arena<Expression>, + pub locals: Arena<LocalVariable>, + + /// The [`FunctionArgument`]s for the final [`crate::Function`]. + /// + /// Parameters with the `out` and `inout` qualifiers have [`Pointer`] types + /// here. For example, an `inout vec2 a` argument would be a [`Pointer`] to + /// a [`Vector`]. + /// + /// [`Pointer`]: crate::TypeInner::Pointer + /// [`Vector`]: crate::TypeInner::Vector + pub arguments: Vec<FunctionArgument>, + + /// The parameter types given in the source code. + /// + /// The `out` and `inout` qualifiers don't affect the types that appear + /// here. For example, an `inout vec2 a` argument would simply be a + /// [`Vector`], not a pointer to one. + /// + /// [`Vector`]: crate::TypeInner::Vector + pub parameters: Vec<Handle<Type>>, + pub parameters_info: Vec<ParameterInfo>, + + pub symbol_table: crate::front::SymbolTable<String, VariableReference>, + pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>, + + pub const_typifier: Typifier, + pub typifier: Typifier, + emitter: Emitter, + stmt_ctx: Option<StmtContext>, + pub body: Block, + pub module: &'a mut crate::Module, + pub is_const: bool, + /// Tracks the constness of `Expression`s residing in `self.expressions` + pub expression_constness: crate::proc::ExpressionConstnessTracker, +} + +impl<'a> Context<'a> { + pub fn new(frontend: &Frontend, module: &'a mut crate::Module, is_const: bool) -> Result<Self> { + let mut this = Context { + expressions: Arena::new(), + locals: Arena::new(), + arguments: Vec::new(), + + parameters: Vec::new(), + parameters_info: Vec::new(), + + symbol_table: crate::front::SymbolTable::default(), + samplers: FastHashMap::default(), + + const_typifier: Typifier::new(), + typifier: Typifier::new(), + emitter: Emitter::default(), + stmt_ctx: Some(StmtContext::new()), + body: Block::new(), + module, + is_const: false, + expression_constness: crate::proc::ExpressionConstnessTracker::new(), + }; + + this.emit_start(); + + for &(ref name, lookup) in frontend.global_variables.iter() { + this.add_global(name, lookup)? + } + this.is_const = is_const; + + Ok(this) + } + + pub fn new_body<F>(&mut self, cb: F) -> Result<Block> + where + F: FnOnce(&mut Self) -> Result<()>, + { + self.new_body_with_ret(cb).map(|(b, _)| b) + } + + pub fn new_body_with_ret<F, R>(&mut self, cb: F) -> Result<(Block, R)> + where + F: FnOnce(&mut Self) -> Result<R>, + { + self.emit_restart(); + let old_body = std::mem::replace(&mut self.body, Block::new()); + let res = cb(self); + self.emit_restart(); + let new_body = std::mem::replace(&mut self.body, old_body); + res.map(|r| (new_body, r)) + } + + pub fn with_body<F>(&mut self, body: Block, cb: F) -> Result<Block> + where + F: FnOnce(&mut Self) -> Result<()>, + { + self.emit_restart(); + let old_body = std::mem::replace(&mut self.body, body); + let res = cb(self); + self.emit_restart(); + let body = std::mem::replace(&mut self.body, old_body); + res.map(|_| body) + } + + pub fn add_global( + &mut self, + name: &str, + GlobalLookup { + kind, + entry_arg, + mutable, + }: GlobalLookup, + ) -> Result<()> { + let (expr, load, constant) = match kind { + GlobalLookupKind::Variable(v) => { + let span = self.module.global_variables.get_span(v); + ( + self.add_expression(Expression::GlobalVariable(v), span)?, + self.module.global_variables[v].space != AddressSpace::Handle, + None, + ) + } + GlobalLookupKind::BlockSelect(handle, index) => { + let span = self.module.global_variables.get_span(handle); + let base = self.add_expression(Expression::GlobalVariable(handle), span)?; + let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?; + + ( + expr, + { + let ty = self.module.global_variables[handle].ty; + + match self.module.types[ty].inner { + TypeInner::Struct { ref members, .. } => { + if let TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } = self.module.types[members[index as usize].ty].inner + { + false + } else { + true + } + } + _ => true, + } + }, + None, + ) + } + GlobalLookupKind::Constant(v, ty) => { + let span = self.module.constants.get_span(v); + ( + self.add_expression(Expression::Constant(v), span)?, + false, + Some((v, ty)), + ) + } + }; + + let var = VariableReference { + expr, + load, + mutable, + constant, + entry_arg, + }; + + self.symbol_table.add(name.into(), var); + + Ok(()) + } + + /// Starts the expression emitter + /// + /// # Panics + /// + /// - If called twice in a row without calling [`emit_end`][Self::emit_end]. + #[inline] + pub fn emit_start(&mut self) { + self.emitter.start(&self.expressions) + } + + /// Emits all the expressions captured by the emitter to the current body + /// + /// # Panics + /// + /// - If called before calling [`emit_start`]. + /// - If called twice in a row without calling [`emit_start`]. + /// + /// [`emit_start`]: Self::emit_start + pub fn emit_end(&mut self) { + self.body.extend(self.emitter.finish(&self.expressions)) + } + + /// Emits all the expressions captured by the emitter to the current body + /// and starts the emitter again + /// + /// # Panics + /// + /// - If called before calling [`emit_start`][Self::emit_start]. + pub fn emit_restart(&mut self) { + self.emit_end(); + self.emit_start() + } + + pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> { + let mut eval = if self.is_const { + crate::proc::ConstantEvaluator::for_glsl_module(self.module) + } else { + crate::proc::ConstantEvaluator::for_glsl_function( + self.module, + &mut self.expressions, + &mut self.expression_constness, + &mut self.emitter, + &mut self.body, + ) + }; + + let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error { + kind: e.into(), + meta, + }); + + match res { + Ok(expr) => Ok(expr), + Err(e) => { + if self.is_const { + Err(e) + } else { + let needs_pre_emit = expr.needs_pre_emit(); + if needs_pre_emit { + self.body.extend(self.emitter.finish(&self.expressions)); + } + let h = self.expressions.append(expr, meta); + if needs_pre_emit { + self.emitter.start(&self.expressions); + } + Ok(h) + } + } + } + } + + /// Add variable to current scope + /// + /// Returns a variable if a variable with the same name was already defined, + /// otherwise returns `None` + pub fn add_local_var( + &mut self, + name: String, + expr: Handle<Expression>, + mutable: bool, + ) -> Option<VariableReference> { + let var = VariableReference { + expr, + load: true, + mutable, + constant: None, + entry_arg: None, + }; + + self.symbol_table.add(name, var) + } + + /// Add function argument to current scope + pub fn add_function_arg( + &mut self, + name_meta: Option<(String, Span)>, + ty: Handle<Type>, + qualifier: ParameterQualifier, + ) -> Result<()> { + let index = self.arguments.len(); + let mut arg = FunctionArgument { + name: name_meta.as_ref().map(|&(ref name, _)| name.clone()), + ty, + binding: None, + }; + self.parameters.push(ty); + + let opaque = match self.module.types[ty].inner { + TypeInner::Image { .. } | TypeInner::Sampler { .. } => true, + _ => false, + }; + + if qualifier.is_lhs() { + let span = self.module.types.get_span(arg.ty); + arg.ty = self.module.types.insert( + Type { + name: None, + inner: TypeInner::Pointer { + base: arg.ty, + space: AddressSpace::Function, + }, + }, + span, + ) + } + + self.arguments.push(arg); + + self.parameters_info.push(ParameterInfo { + qualifier, + depth: false, + }); + + if let Some((name, meta)) = name_meta { + let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?; + let mutable = qualifier != ParameterQualifier::Const && !opaque; + let load = qualifier.is_lhs(); + + let var = if mutable && !load { + let handle = self.locals.append( + LocalVariable { + name: Some(name.clone()), + ty, + init: None, + }, + meta, + ); + let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?; + + self.emit_restart(); + + self.body.push( + Statement::Store { + pointer: local_expr, + value: expr, + }, + meta, + ); + + VariableReference { + expr: local_expr, + load: true, + mutable, + constant: None, + entry_arg: None, + } + } else { + VariableReference { + expr, + load, + mutable, + constant: None, + entry_arg: None, + } + }; + + self.symbol_table.add(name, var); + } + + Ok(()) + } + + /// Returns a [`StmtContext`] to be used in parsing and lowering + /// + /// # Panics + /// + /// - If more than one [`StmtContext`] are active at the same time or if the + /// previous call didn't use it in lowering. + #[must_use] + pub fn stmt_ctx(&mut self) -> StmtContext { + self.stmt_ctx.take().unwrap() + } + + /// Lowers a [`HirExpr`] which might produce a [`Expression`]. + /// + /// consumes a [`StmtContext`] returning it to the context so that it can be + /// used again later. + pub fn lower( + &mut self, + mut stmt: StmtContext, + frontend: &mut Frontend, + expr: Handle<HirExpr>, + pos: ExprPos, + ) -> Result<(Option<Handle<Expression>>, Span)> { + let res = self.lower_inner(&stmt, frontend, expr, pos); + + stmt.hir_exprs.clear(); + self.stmt_ctx = Some(stmt); + + res + } + + /// Similar to [`lower`](Self::lower) but returns an error if the expression + /// returns void (ie. doesn't produce a [`Expression`]). + /// + /// consumes a [`StmtContext`] returning it to the context so that it can be + /// used again later. + pub fn lower_expect( + &mut self, + mut stmt: StmtContext, + frontend: &mut Frontend, + expr: Handle<HirExpr>, + pos: ExprPos, + ) -> Result<(Handle<Expression>, Span)> { + let res = self.lower_expect_inner(&stmt, frontend, expr, pos); + + stmt.hir_exprs.clear(); + self.stmt_ctx = Some(stmt); + + res + } + + /// internal implementation of [`lower_expect`](Self::lower_expect) + /// + /// this method is only public because it's used in + /// [`function_call`](Frontend::function_call), unless you know what + /// you're doing use [`lower_expect`](Self::lower_expect) + pub fn lower_expect_inner( + &mut self, + stmt: &StmtContext, + frontend: &mut Frontend, + expr: Handle<HirExpr>, + pos: ExprPos, + ) -> Result<(Handle<Expression>, Span)> { + let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?; + + let expr = match maybe_expr { + Some(e) => e, + None => { + return Err(Error { + kind: ErrorKind::SemanticError("Expression returns void".into()), + meta, + }) + } + }; + + Ok((expr, meta)) + } + + fn lower_store( + &mut self, + pointer: Handle<Expression>, + value: Handle<Expression>, + meta: Span, + ) -> Result<()> { + if let Expression::Swizzle { + size, + mut vector, + pattern, + } = self.expressions[pointer] + { + // Stores to swizzled values are not directly supported, + // lower them as series of per-component stores. + let size = match size { + VectorSize::Bi => 2, + VectorSize::Tri => 3, + VectorSize::Quad => 4, + }; + + if let Expression::Load { pointer } = self.expressions[vector] { + vector = pointer; + } + + #[allow(clippy::needless_range_loop)] + for index in 0..size { + let dst = self.add_expression( + Expression::AccessIndex { + base: vector, + index: pattern[index].index(), + }, + meta, + )?; + let src = self.add_expression( + Expression::AccessIndex { + base: value, + index: index as u32, + }, + meta, + )?; + + self.emit_restart(); + + self.body.push( + Statement::Store { + pointer: dst, + value: src, + }, + meta, + ); + } + } else { + self.emit_restart(); + + self.body.push(Statement::Store { pointer, value }, meta); + } + + Ok(()) + } + + /// Internal implementation of [`lower`](Self::lower) + fn lower_inner( + &mut self, + stmt: &StmtContext, + frontend: &mut Frontend, + expr: Handle<HirExpr>, + pos: ExprPos, + ) -> Result<(Option<Handle<Expression>>, Span)> { + let HirExpr { ref kind, meta } = stmt.hir_exprs[expr]; + + log::debug!("Lowering {:?} (kind {:?}, pos {:?})", expr, kind, pos); + + let handle = match *kind { + HirExprKind::Access { base, index } => { + let (index, _) = self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?; + let maybe_constant_index = match pos { + // Don't try to generate `AccessIndex` if in a LHS position, since it + // wouldn't produce a pointer. + ExprPos::Lhs => None, + _ => self + .module + .to_ctx() + .eval_expr_to_u32_from(index, &self.expressions) + .ok(), + }; + + let base = self + .lower_expect_inner( + stmt, + frontend, + base, + pos.maybe_access_base(maybe_constant_index.is_some()), + )? + .0; + + let pointer = maybe_constant_index + .map(|index| self.add_expression(Expression::AccessIndex { base, index }, meta)) + .unwrap_or_else(|| { + self.add_expression(Expression::Access { base, index }, meta) + })?; + + if ExprPos::Rhs == pos { + let resolved = self.resolve_type(pointer, meta)?; + if resolved.pointer_space().is_some() { + return Ok(( + Some(self.add_expression(Expression::Load { pointer }, meta)?), + meta, + )); + } + } + + pointer + } + HirExprKind::Select { base, ref field } => { + let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0; + + frontend.field_selection(self, pos, base, field, meta)? + } + HirExprKind::Literal(literal) if pos != ExprPos::Lhs => { + self.add_expression(Expression::Literal(literal), meta)? + } + HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => { + let (mut left, left_meta) = + self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?; + let (mut right, right_meta) = + self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?; + + match op { + BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => { + self.implicit_conversion(&mut right, right_meta, Scalar::U32)? + } + _ => self + .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?, + } + + self.typifier_grow(left, left_meta)?; + self.typifier_grow(right, right_meta)?; + + let left_inner = self.get_type(left); + let right_inner = self.get_type(right); + + match (left_inner, right_inner) { + ( + &TypeInner::Matrix { + columns: left_columns, + rows: left_rows, + scalar: left_scalar, + }, + &TypeInner::Matrix { + columns: right_columns, + rows: right_rows, + scalar: right_scalar, + }, + ) => { + let dimensions_ok = if op == BinaryOperator::Multiply { + left_columns == right_rows + } else { + left_columns == right_columns && left_rows == right_rows + }; + + // Check that the two arguments have the same dimensions + if !dimensions_ok || left_scalar != right_scalar { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "Cannot apply operation to {left_inner:?} and {right_inner:?}" + ) + .into(), + ), + meta, + }) + } + + match op { + BinaryOperator::Divide => { + // Naga IR doesn't support matrix division so we need to + // divide the columns individually and reassemble the matrix + let mut components = Vec::with_capacity(left_columns as usize); + + for index in 0..left_columns as u32 { + // Get the column vectors + let left_vector = self.add_expression( + Expression::AccessIndex { base: left, index }, + meta, + )?; + let right_vector = self.add_expression( + Expression::AccessIndex { base: right, index }, + meta, + )?; + + // Divide the vectors + let column = self.add_expression( + Expression::Binary { + op, + left: left_vector, + right: right_vector, + }, + meta, + )?; + + components.push(column) + } + + let ty = self.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns: left_columns, + rows: left_rows, + scalar: left_scalar, + }, + }, + Span::default(), + ); + + // Rebuild the matrix from the divided vectors + self.add_expression(Expression::Compose { ty, components }, meta)? + } + BinaryOperator::Equal | BinaryOperator::NotEqual => { + // Naga IR doesn't support matrix comparisons so we need to + // compare the columns individually and then fold them together + // + // The folding is done using a logical and for equality and + // a logical or for inequality + let equals = op == BinaryOperator::Equal; + + let (op, combine, fun) = match equals { + true => ( + BinaryOperator::Equal, + BinaryOperator::LogicalAnd, + RelationalFunction::All, + ), + false => ( + BinaryOperator::NotEqual, + BinaryOperator::LogicalOr, + RelationalFunction::Any, + ), + }; + + let mut root = None; + + for index in 0..left_columns as u32 { + // Get the column vectors + let left_vector = self.add_expression( + Expression::AccessIndex { base: left, index }, + meta, + )?; + let right_vector = self.add_expression( + Expression::AccessIndex { base: right, index }, + meta, + )?; + + let argument = self.add_expression( + Expression::Binary { + op, + left: left_vector, + right: right_vector, + }, + meta, + )?; + + // The result of comparing two vectors is a boolean vector + // so use a relational function like all to get a single + // boolean value + let compare = self.add_expression( + Expression::Relational { fun, argument }, + meta, + )?; + + // Fold the result + root = Some(match root { + Some(right) => self.add_expression( + Expression::Binary { + op: combine, + left: compare, + right, + }, + meta, + )?, + None => compare, + }); + } + + root.unwrap() + } + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta)? + } + } + } + (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op { + BinaryOperator::Equal | BinaryOperator::NotEqual => { + let equals = op == BinaryOperator::Equal; + + let (op, fun) = match equals { + true => (BinaryOperator::Equal, RelationalFunction::All), + false => (BinaryOperator::NotEqual, RelationalFunction::Any), + }; + + let argument = + self.add_expression(Expression::Binary { op, left, right }, meta)?; + + self.add_expression(Expression::Relational { fun, argument }, meta)? + } + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, + }, + (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op { + BinaryOperator::Add + | BinaryOperator::Subtract + | BinaryOperator::Divide + | BinaryOperator::And + | BinaryOperator::ExclusiveOr + | BinaryOperator::InclusiveOr + | BinaryOperator::ShiftLeft + | BinaryOperator::ShiftRight => { + let scalar_vector = self + .add_expression(Expression::Splat { size, value: right }, meta)?; + + self.add_expression( + Expression::Binary { + op, + left, + right: scalar_vector, + }, + meta, + )? + } + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, + }, + (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op { + BinaryOperator::Add + | BinaryOperator::Subtract + | BinaryOperator::Divide + | BinaryOperator::And + | BinaryOperator::ExclusiveOr + | BinaryOperator::InclusiveOr => { + let scalar_vector = + self.add_expression(Expression::Splat { size, value: left }, meta)?; + + self.add_expression( + Expression::Binary { + op, + left: scalar_vector, + right, + }, + meta, + )? + } + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, + }, + ( + &TypeInner::Scalar(left_scalar), + &TypeInner::Matrix { + rows, + columns, + scalar: right_scalar, + }, + ) => { + // Check that the two arguments have the same scalar type + if left_scalar != right_scalar { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "Cannot apply operation to {left_inner:?} and {right_inner:?}" + ) + .into(), + ), + meta, + }) + } + + match op { + BinaryOperator::Divide + | BinaryOperator::Add + | BinaryOperator::Subtract => { + // Naga IR doesn't support all matrix by scalar operations so + // we need for some to turn the scalar into a vector by + // splatting it and then for each column vector apply the + // operation and finally reconstruct the matrix + let scalar_vector = self.add_expression( + Expression::Splat { + size: rows, + value: left, + }, + meta, + )?; + + let mut components = Vec::with_capacity(columns as usize); + + for index in 0..columns as u32 { + // Get the column vector + let matrix_column = self.add_expression( + Expression::AccessIndex { base: right, index }, + meta, + )?; + + // Apply the operation to the splatted vector and + // the column vector + let column = self.add_expression( + Expression::Binary { + op, + left: scalar_vector, + right: matrix_column, + }, + meta, + )?; + + components.push(column) + } + + let ty = self.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + scalar: left_scalar, + }, + }, + Span::default(), + ); + + // Rebuild the matrix from the operation result vectors + self.add_expression(Expression::Compose { ty, components }, meta)? + } + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta)? + } + } + } + ( + &TypeInner::Matrix { + rows, + columns, + scalar: left_scalar, + }, + &TypeInner::Scalar(right_scalar), + ) => { + // Check that the two arguments have the same scalar type + if left_scalar != right_scalar { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "Cannot apply operation to {left_inner:?} and {right_inner:?}" + ) + .into(), + ), + meta, + }) + } + + match op { + BinaryOperator::Divide + | BinaryOperator::Add + | BinaryOperator::Subtract => { + // Naga IR doesn't support all matrix by scalar operations so + // we need for some to turn the scalar into a vector by + // splatting it and then for each column vector apply the + // operation and finally reconstruct the matrix + + let scalar_vector = self.add_expression( + Expression::Splat { + size: rows, + value: right, + }, + meta, + )?; + + let mut components = Vec::with_capacity(columns as usize); + + for index in 0..columns as u32 { + // Get the column vector + let matrix_column = self.add_expression( + Expression::AccessIndex { base: left, index }, + meta, + )?; + + // Apply the operation to the splatted vector and + // the column vector + let column = self.add_expression( + Expression::Binary { + op, + left: matrix_column, + right: scalar_vector, + }, + meta, + )?; + + components.push(column) + } + + let ty = self.module.types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + scalar: left_scalar, + }, + }, + Span::default(), + ); + + // Rebuild the matrix from the operation result vectors + self.add_expression(Expression::Compose { ty, components }, meta)? + } + _ => { + self.add_expression(Expression::Binary { left, op, right }, meta)? + } + } + } + _ => self.add_expression(Expression::Binary { left, op, right }, meta)?, + } + } + HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => { + let expr = self + .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)? + .0; + + self.add_expression(Expression::Unary { op, expr }, meta)? + } + HirExprKind::Variable(ref var) => match pos { + ExprPos::Lhs => { + if !var.mutable { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Variable cannot be used in LHS position".into(), + ), + meta, + }) + } + + var.expr + } + ExprPos::AccessBase { constant_index } => { + // If the index isn't constant all accesses backed by a constant base need + // to be done through a proxy local variable, since constants have a non + // pointer type which is required for dynamic indexing + if !constant_index { + if let Some((constant, ty)) = var.constant { + let init = self + .add_expression(Expression::Constant(constant), Span::default())?; + let local = self.locals.append( + LocalVariable { + name: None, + ty, + init: Some(init), + }, + Span::default(), + ); + + self.add_expression(Expression::LocalVariable(local), Span::default())? + } else { + var.expr + } + } else { + var.expr + } + } + _ if var.load => { + self.add_expression(Expression::Load { pointer: var.expr }, meta)? + } + ExprPos::Rhs => { + if let Some((constant, _)) = self.is_const.then_some(var.constant).flatten() { + self.add_expression(Expression::Constant(constant), meta)? + } else { + var.expr + } + } + }, + HirExprKind::Call(ref call) if pos != ExprPos::Lhs => { + let maybe_expr = frontend.function_or_constructor_call( + self, + stmt, + call.kind.clone(), + &call.args, + meta, + )?; + return Ok((maybe_expr, meta)); + } + // `HirExprKind::Conditional` represents the ternary operator in glsl (`:?`) + // + // The ternary operator is defined to only evaluate one of the two possible + // expressions which means that it's behavior is that of an `if` statement, + // and it's merely syntactic sugar for it. + HirExprKind::Conditional { + condition, + accept, + reject, + } if ExprPos::Lhs != pos => { + // Given an expression `a ? b : c`, we need to produce a Naga + // statement roughly like: + // + // var temp; + // if a { + // temp = convert(b); + // } else { + // temp = convert(c); + // } + // + // where `convert` stands for type conversions to bring `b` and `c` to + // the same type, and then use `temp` to represent the value of the whole + // conditional expression in subsequent code. + + // Lower the condition first to the current bodyy + let condition = self + .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)? + .0; + + let (mut accept_body, (mut accept, accept_meta)) = + self.new_body_with_ret(|ctx| { + // Lower the `true` branch + ctx.lower_expect_inner(stmt, frontend, accept, pos) + })?; + + let (mut reject_body, (mut reject, reject_meta)) = + self.new_body_with_ret(|ctx| { + // Lower the `false` branch + ctx.lower_expect_inner(stmt, frontend, reject, pos) + })?; + + // We need to do some custom implicit conversions since the two target expressions + // are in different bodies + if let (Some((accept_power, accept_scalar)), Some((reject_power, reject_scalar))) = ( + // Get the components of both branches and calculate the type power + self.expr_scalar_components(accept, accept_meta)? + .and_then(|scalar| Some((type_power(scalar)?, scalar))), + self.expr_scalar_components(reject, reject_meta)? + .and_then(|scalar| Some((type_power(scalar)?, scalar))), + ) { + match accept_power.cmp(&reject_power) { + std::cmp::Ordering::Less => { + accept_body = self.with_body(accept_body, |ctx| { + ctx.conversion(&mut accept, accept_meta, reject_scalar)?; + Ok(()) + })?; + } + std::cmp::Ordering::Equal => {} + std::cmp::Ordering::Greater => { + reject_body = self.with_body(reject_body, |ctx| { + ctx.conversion(&mut reject, reject_meta, accept_scalar)?; + Ok(()) + })?; + } + } + } + + // We need to get the type of the resulting expression to create the local, + // this must be done after implicit conversions to ensure both branches have + // the same type. + let ty = self.resolve_type_handle(accept, accept_meta)?; + + // Add the local that will hold the result of our conditional + let local = self.locals.append( + LocalVariable { + name: None, + ty, + init: None, + }, + meta, + ); + + let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?; + + // Add to each the store to the result variable + accept_body.push( + Statement::Store { + pointer: local_expr, + value: accept, + }, + accept_meta, + ); + reject_body.push( + Statement::Store { + pointer: local_expr, + value: reject, + }, + reject_meta, + ); + + // Finally add the `If` to the main body with the `condition` we lowered + // earlier and the branches we prepared. + self.body.push( + Statement::If { + condition, + accept: accept_body, + reject: reject_body, + }, + meta, + ); + + // Note: `Expression::Load` must be emitted before it's used so make + // sure the emitter is active here. + self.add_expression( + Expression::Load { + pointer: local_expr, + }, + meta, + )? + } + HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => { + let (pointer, ptr_meta) = + self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?; + let (mut value, value_meta) = + self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?; + + let ty = match *self.resolve_type(pointer, ptr_meta)? { + TypeInner::Pointer { base, .. } => &self.module.types[base].inner, + ref ty => ty, + }; + + if let Some(scalar) = scalar_components(ty) { + self.implicit_conversion(&mut value, value_meta, scalar)?; + } + + self.lower_store(pointer, value, meta)?; + + value + } + HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => { + let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?; + let left = if let Expression::Swizzle { .. } = self.expressions[pointer] { + pointer + } else { + self.add_expression(Expression::Load { pointer }, meta)? + }; + + let res = match *self.resolve_type(left, meta)? { + TypeInner::Scalar(scalar) => { + let ty = TypeInner::Scalar(scalar); + Literal::one(scalar).map(|i| (ty, i, None, None)) + } + TypeInner::Vector { size, scalar } => { + let ty = TypeInner::Vector { size, scalar }; + Literal::one(scalar).map(|i| (ty, i, Some(size), None)) + } + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + let ty = TypeInner::Matrix { + columns, + rows, + scalar, + }; + Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns))) + } + _ => None, + }; + let (ty_inner, literal, rows, columns) = match res { + Some(res) => res, + None => { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Increment/decrement only works on scalar/vector/matrix".into(), + ), + meta, + }); + return Ok((Some(left), meta)); + } + }; + + let mut right = self.add_expression(Expression::Literal(literal), meta)?; + + // Glsl allows pre/postfixes operations on vectors and matrices, so if the + // target is either of them change the right side of the addition to be splatted + // to the same size as the target, furthermore if the target is a matrix + // use a composed matrix using the splatted value. + if let Some(size) = rows { + right = self.add_expression(Expression::Splat { size, value: right }, meta)?; + + if let Some(cols) = columns { + let ty = self.module.types.insert( + Type { + name: None, + inner: ty_inner, + }, + meta, + ); + + right = self.add_expression( + Expression::Compose { + ty, + components: std::iter::repeat(right).take(cols as usize).collect(), + }, + meta, + )?; + } + } + + let value = self.add_expression(Expression::Binary { op, left, right }, meta)?; + + self.lower_store(pointer, value, meta)?; + + if postfix { + left + } else { + value + } + } + HirExprKind::Method { + expr: object, + ref name, + ref args, + } if ExprPos::Lhs != pos => { + let args = args + .iter() + .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs)) + .collect::<Result<Vec<_>>>()?; + match name.as_ref() { + "length" => { + if !args.is_empty() { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + ".length() doesn't take any arguments".into(), + ), + meta, + }); + } + let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0; + let array_type = self.resolve_type(lowered_array, meta)?; + + match *array_type { + TypeInner::Array { + size: crate::ArraySize::Constant(size), + .. + } => { + let mut array_length = self.add_expression( + Expression::Literal(Literal::U32(size.get())), + meta, + )?; + self.forced_conversion(&mut array_length, meta, Scalar::I32)?; + array_length + } + // let the error be handled in type checking if it's not a dynamic array + _ => { + let mut array_length = self + .add_expression(Expression::ArrayLength(lowered_array), meta)?; + self.conversion(&mut array_length, meta, Scalar::I32)?; + array_length + } + } + } + _ => { + return Err(Error { + kind: ErrorKind::SemanticError( + format!("unknown method '{name}'").into(), + ), + meta, + }); + } + } + } + _ => { + return Err(Error { + kind: ErrorKind::SemanticError( + format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr]) + .into(), + ), + meta, + }) + } + }; + + log::trace!( + "Lowered {:?}\n\tKind = {:?}\n\tPos = {:?}\n\tResult = {:?}", + expr, + kind, + pos, + handle + ); + + Ok((Some(handle), meta)) + } + + pub fn expr_scalar_components( + &mut self, + expr: Handle<Expression>, + meta: Span, + ) -> Result<Option<Scalar>> { + let ty = self.resolve_type(expr, meta)?; + Ok(scalar_components(ty)) + } + + pub fn expr_power(&mut self, expr: Handle<Expression>, meta: Span) -> Result<Option<u32>> { + Ok(self + .expr_scalar_components(expr, meta)? + .and_then(type_power)) + } + + pub fn conversion( + &mut self, + expr: &mut Handle<Expression>, + meta: Span, + scalar: Scalar, + ) -> Result<()> { + *expr = self.add_expression( + Expression::As { + expr: *expr, + kind: scalar.kind, + convert: Some(scalar.width), + }, + meta, + )?; + + Ok(()) + } + + pub fn implicit_conversion( + &mut self, + expr: &mut Handle<Expression>, + meta: Span, + scalar: Scalar, + ) -> Result<()> { + if let (Some(tgt_power), Some(expr_power)) = + (type_power(scalar), self.expr_power(*expr, meta)?) + { + if tgt_power > expr_power { + self.conversion(expr, meta, scalar)?; + } + } + + Ok(()) + } + + pub fn forced_conversion( + &mut self, + expr: &mut Handle<Expression>, + meta: Span, + scalar: Scalar, + ) -> Result<()> { + if let Some(expr_scalar) = self.expr_scalar_components(*expr, meta)? { + if expr_scalar != scalar { + self.conversion(expr, meta, scalar)?; + } + } + + Ok(()) + } + + pub fn binary_implicit_conversion( + &mut self, + left: &mut Handle<Expression>, + left_meta: Span, + right: &mut Handle<Expression>, + right_meta: Span, + ) -> Result<()> { + let left_components = self.expr_scalar_components(*left, left_meta)?; + let right_components = self.expr_scalar_components(*right, right_meta)?; + + if let (Some((left_power, left_scalar)), Some((right_power, right_scalar))) = ( + left_components.and_then(|scalar| Some((type_power(scalar)?, scalar))), + right_components.and_then(|scalar| Some((type_power(scalar)?, scalar))), + ) { + match left_power.cmp(&right_power) { + std::cmp::Ordering::Less => { + self.conversion(left, left_meta, right_scalar)?; + } + std::cmp::Ordering::Equal => {} + std::cmp::Ordering::Greater => { + self.conversion(right, right_meta, left_scalar)?; + } + } + } + + Ok(()) + } + + pub fn implicit_splat( + &mut self, + expr: &mut Handle<Expression>, + meta: Span, + vector_size: Option<VectorSize>, + ) -> Result<()> { + let expr_type = self.resolve_type(*expr, meta)?; + + if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) { + *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)? + } + + Ok(()) + } + + pub fn vector_resize( + &mut self, + size: VectorSize, + vector: Handle<Expression>, + meta: Span, + ) -> Result<Handle<Expression>> { + self.add_expression( + Expression::Swizzle { + size, + vector, + pattern: crate::SwizzleComponent::XYZW, + }, + meta, + ) + } +} + +impl Index<Handle<Expression>> for Context<'_> { + type Output = Expression; + + fn index(&self, index: Handle<Expression>) -> &Self::Output { + if self.is_const { + &self.module.const_expressions[index] + } else { + &self.expressions[index] + } + } +} + +/// Helper struct passed when parsing expressions +/// +/// This struct should only be obtained through [`stmt_ctx`](Context::stmt_ctx) +/// and only one of these may be active at any time per context. +#[derive(Debug)] +pub struct StmtContext { + /// A arena of high level expressions which can be lowered through a + /// [`Context`] to Naga's [`Expression`]s + pub hir_exprs: Arena<HirExpr>, +} + +impl StmtContext { + const fn new() -> Self { + StmtContext { + hir_exprs: Arena::new(), + } + } +} diff --git a/third_party/rust/naga/src/front/glsl/error.rs b/third_party/rust/naga/src/front/glsl/error.rs new file mode 100644 index 0000000000..bd16ee30bc --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/error.rs @@ -0,0 +1,191 @@ +use super::token::TokenValue; +use crate::{proc::ConstantEvaluatorError, Span}; +use codespan_reporting::diagnostic::{Diagnostic, Label}; +use codespan_reporting::files::SimpleFile; +use codespan_reporting::term; +use pp_rs::token::PreprocessorError; +use std::borrow::Cow; +use termcolor::{NoColor, WriteColor}; +use thiserror::Error; + +fn join_with_comma(list: &[ExpectedToken]) -> String { + let mut string = "".to_string(); + for (i, val) in list.iter().enumerate() { + string.push_str(&val.to_string()); + match i { + i if i == list.len() - 1 => {} + i if i == list.len() - 2 => string.push_str(" or "), + _ => string.push_str(", "), + } + } + string +} + +/// One of the expected tokens returned in [`InvalidToken`](ErrorKind::InvalidToken). +#[derive(Clone, Debug, PartialEq)] +pub enum ExpectedToken { + /// A specific token was expected. + Token(TokenValue), + /// A type was expected. + TypeName, + /// An identifier was expected. + Identifier, + /// An integer literal was expected. + IntLiteral, + /// A float literal was expected. + FloatLiteral, + /// A boolean literal was expected. + BoolLiteral, + /// The end of file was expected. + Eof, +} +impl From<TokenValue> for ExpectedToken { + fn from(token: TokenValue) -> Self { + ExpectedToken::Token(token) + } +} +impl std::fmt::Display for ExpectedToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match *self { + ExpectedToken::Token(ref token) => write!(f, "{token:?}"), + ExpectedToken::TypeName => write!(f, "a type"), + ExpectedToken::Identifier => write!(f, "identifier"), + ExpectedToken::IntLiteral => write!(f, "integer literal"), + ExpectedToken::FloatLiteral => write!(f, "float literal"), + ExpectedToken::BoolLiteral => write!(f, "bool literal"), + ExpectedToken::Eof => write!(f, "end of file"), + } + } +} + +/// Information about the cause of an error. +#[derive(Clone, Debug, Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum ErrorKind { + /// Whilst parsing as encountered an unexpected EOF. + #[error("Unexpected end of file")] + EndOfFile, + /// The shader specified an unsupported or invalid profile. + #[error("Invalid profile: {0}")] + InvalidProfile(String), + /// The shader requested an unsupported or invalid version. + #[error("Invalid version: {0}")] + InvalidVersion(u64), + /// Whilst parsing an unexpected token was encountered. + /// + /// A list of expected tokens is also returned. + #[error("Expected {}, found {0:?}", join_with_comma(.1))] + InvalidToken(TokenValue, Vec<ExpectedToken>), + /// A specific feature is not yet implemented. + /// + /// To help prioritize work please open an issue in the github issue tracker + /// if none exist already or react to the already existing one. + #[error("Not implemented: {0}")] + NotImplemented(&'static str), + /// A reference to a variable that wasn't declared was used. + #[error("Unknown variable: {0}")] + UnknownVariable(String), + /// A reference to a type that wasn't declared was used. + #[error("Unknown type: {0}")] + UnknownType(String), + /// A reference to a non existent member of a type was made. + #[error("Unknown field: {0}")] + UnknownField(String), + /// An unknown layout qualifier was used. + /// + /// If the qualifier does exist please open an issue in the github issue tracker + /// if none exist already or react to the already existing one to help + /// prioritize work. + #[error("Unknown layout qualifier: {0}")] + UnknownLayoutQualifier(String), + /// Unsupported matrix of the form matCx2 + /// + /// Our IR expects matrices of the form matCx2 to have a stride of 8 however + /// matrices in the std140 layout have a stride of at least 16 + #[error("unsupported matrix of the form matCx2 in std140 block layout")] + UnsupportedMatrixTypeInStd140, + /// A variable with the same name already exists in the current scope. + #[error("Variable already declared: {0}")] + VariableAlreadyDeclared(String), + /// A semantic error was detected in the shader. + #[error("{0}")] + SemanticError(Cow<'static, str>), + /// An error was returned by the preprocessor. + #[error("{0:?}")] + PreprocessorError(PreprocessorError), + /// The parser entered an illegal state and exited + /// + /// This obviously is a bug and as such should be reported in the github issue tracker + #[error("Internal error: {0}")] + InternalError(&'static str), +} + +impl From<ConstantEvaluatorError> for ErrorKind { + fn from(err: ConstantEvaluatorError) -> Self { + ErrorKind::SemanticError(err.to_string().into()) + } +} + +/// Error returned during shader parsing. +#[derive(Clone, Debug, Error)] +#[error("{kind}")] +#[cfg_attr(test, derive(PartialEq))] +pub struct Error { + /// Holds the information about the error itself. + pub kind: ErrorKind, + /// Holds information about the range of the source code where the error happened. + pub meta: Span, +} + +/// A collection of errors returned during shader parsing. +#[derive(Clone, Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub struct ParseError { + pub errors: Vec<Error>, +} + +impl ParseError { + pub fn emit_to_writer(&self, writer: &mut impl WriteColor, source: &str) { + self.emit_to_writer_with_path(writer, source, "glsl"); + } + + pub fn emit_to_writer_with_path(&self, writer: &mut impl WriteColor, source: &str, path: &str) { + let path = path.to_string(); + let files = SimpleFile::new(path, source); + let config = codespan_reporting::term::Config::default(); + + for err in &self.errors { + let mut diagnostic = Diagnostic::error().with_message(err.kind.to_string()); + + if let Some(range) = err.meta.to_range() { + diagnostic = diagnostic.with_labels(vec![Label::primary((), range)]); + } + + term::emit(writer, &config, &files, &diagnostic).expect("cannot write error"); + } + } + + pub fn emit_to_string(&self, source: &str) -> String { + let mut writer = NoColor::new(Vec::new()); + self.emit_to_writer(&mut writer, source); + String::from_utf8(writer.into_inner()).unwrap() + } +} + +impl std::fmt::Display for ParseError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.errors.iter().try_for_each(|e| write!(f, "{e:?}")) + } +} + +impl std::error::Error for ParseError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +impl From<Vec<Error>> for ParseError { + fn from(errors: Vec<Error>) -> Self { + Self { errors } + } +} diff --git a/third_party/rust/naga/src/front/glsl/functions.rs b/third_party/rust/naga/src/front/glsl/functions.rs new file mode 100644 index 0000000000..df8cc8a30e --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/functions.rs @@ -0,0 +1,1602 @@ +use super::{ + ast::*, + builtins::{inject_builtin, sampled_to_depth}, + context::{Context, ExprPos, StmtContext}, + error::{Error, ErrorKind}, + types::scalar_components, + Frontend, Result, +}; +use crate::{ + front::glsl::types::type_power, proc::ensure_block_returns, AddressSpace, Block, EntryPoint, + Expression, Function, FunctionArgument, FunctionResult, Handle, Literal, LocalVariable, Scalar, + ScalarKind, Span, Statement, StructMember, Type, TypeInner, +}; +use std::iter; + +/// Struct detailing a store operation that must happen after a function call +struct ProxyWrite { + /// The store target + target: Handle<Expression>, + /// A pointer to read the value of the store + value: Handle<Expression>, + /// An optional conversion to be applied + convert: Option<Scalar>, +} + +impl Frontend { + pub(crate) fn function_or_constructor_call( + &mut self, + ctx: &mut Context, + stmt: &StmtContext, + fc: FunctionCallKind, + raw_args: &[Handle<HirExpr>], + meta: Span, + ) -> Result<Option<Handle<Expression>>> { + let args: Vec<_> = raw_args + .iter() + .map(|e| ctx.lower_expect_inner(stmt, self, *e, ExprPos::Rhs)) + .collect::<Result<_>>()?; + + match fc { + FunctionCallKind::TypeConstructor(ty) => { + if args.len() == 1 { + self.constructor_single(ctx, ty, args[0], meta).map(Some) + } else { + self.constructor_many(ctx, ty, args, meta).map(Some) + } + } + FunctionCallKind::Function(name) => { + self.function_call(ctx, stmt, name, args, raw_args, meta) + } + } + } + + fn constructor_single( + &mut self, + ctx: &mut Context, + ty: Handle<Type>, + (mut value, expr_meta): (Handle<Expression>, Span), + meta: Span, + ) -> Result<Handle<Expression>> { + let expr_type = ctx.resolve_type(value, expr_meta)?; + + let vector_size = match *expr_type { + TypeInner::Vector { size, .. } => Some(size), + _ => None, + }; + + let expr_is_bool = expr_type.scalar_kind() == Some(ScalarKind::Bool); + + // Special case: if casting from a bool, we need to use Select and not As. + match ctx.module.types[ty].inner.scalar() { + Some(result_scalar) if expr_is_bool && result_scalar.kind != ScalarKind::Bool => { + let result_scalar = Scalar { + width: 4, + ..result_scalar + }; + let l0 = Literal::zero(result_scalar).unwrap(); + let l1 = Literal::one(result_scalar).unwrap(); + let mut reject = ctx.add_expression(Expression::Literal(l0), expr_meta)?; + let mut accept = ctx.add_expression(Expression::Literal(l1), expr_meta)?; + + ctx.implicit_splat(&mut reject, meta, vector_size)?; + ctx.implicit_splat(&mut accept, meta, vector_size)?; + + let h = ctx.add_expression( + Expression::Select { + accept, + reject, + condition: value, + }, + expr_meta, + )?; + + return Ok(h); + } + _ => {} + } + + Ok(match ctx.module.types[ty].inner { + TypeInner::Vector { size, scalar } if vector_size.is_none() => { + ctx.forced_conversion(&mut value, expr_meta, scalar)?; + + if let TypeInner::Scalar { .. } = *ctx.resolve_type(value, expr_meta)? { + ctx.add_expression(Expression::Splat { size, value }, meta)? + } else { + self.vector_constructor(ctx, ty, size, scalar, &[(value, expr_meta)], meta)? + } + } + TypeInner::Scalar(scalar) => { + let mut expr = value; + if let TypeInner::Vector { .. } | TypeInner::Matrix { .. } = + *ctx.resolve_type(value, expr_meta)? + { + expr = ctx.add_expression( + Expression::AccessIndex { + base: expr, + index: 0, + }, + meta, + )?; + } + + if let TypeInner::Matrix { .. } = *ctx.resolve_type(value, expr_meta)? { + expr = ctx.add_expression( + Expression::AccessIndex { + base: expr, + index: 0, + }, + meta, + )?; + } + + ctx.add_expression( + Expression::As { + kind: scalar.kind, + expr, + convert: Some(scalar.width), + }, + meta, + )? + } + TypeInner::Vector { size, scalar } => { + if vector_size.map_or(true, |s| s != size) { + value = ctx.vector_resize(size, value, expr_meta)?; + } + + ctx.add_expression( + Expression::As { + kind: scalar.kind, + expr: value, + convert: Some(scalar.width), + }, + meta, + )? + } + TypeInner::Matrix { + columns, + rows, + scalar, + } => self.matrix_one_arg(ctx, ty, columns, rows, scalar, (value, expr_meta), meta)?, + TypeInner::Struct { ref members, .. } => { + let scalar_components = members + .get(0) + .and_then(|member| scalar_components(&ctx.module.types[member.ty].inner)); + if let Some(scalar) = scalar_components { + ctx.implicit_conversion(&mut value, expr_meta, scalar)?; + } + + ctx.add_expression( + Expression::Compose { + ty, + components: vec![value], + }, + meta, + )? + } + + TypeInner::Array { base, .. } => { + let scalar_components = scalar_components(&ctx.module.types[base].inner); + if let Some(scalar) = scalar_components { + ctx.implicit_conversion(&mut value, expr_meta, scalar)?; + } + + ctx.add_expression( + Expression::Compose { + ty, + components: vec![value], + }, + meta, + )? + } + _ => { + self.errors.push(Error { + kind: ErrorKind::SemanticError("Bad type constructor".into()), + meta, + }); + + value + } + }) + } + + #[allow(clippy::too_many_arguments)] + fn matrix_one_arg( + &mut self, + ctx: &mut Context, + ty: Handle<Type>, + columns: crate::VectorSize, + rows: crate::VectorSize, + element_scalar: Scalar, + (mut value, expr_meta): (Handle<Expression>, Span), + meta: Span, + ) -> Result<Handle<Expression>> { + let mut components = Vec::with_capacity(columns as usize); + // TODO: casts + // `Expression::As` doesn't support matrix width + // casts so we need to do some extra work for casts + + ctx.forced_conversion(&mut value, expr_meta, element_scalar)?; + match *ctx.resolve_type(value, expr_meta)? { + TypeInner::Scalar(_) => { + // If a matrix is constructed with a single scalar value, then that + // value is used to initialize all the values along the diagonal of + // the matrix; the rest are given zeros. + let vector_ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: rows, + scalar: element_scalar, + }, + }, + meta, + ); + + let zero_literal = Literal::zero(element_scalar).unwrap(); + let zero = ctx.add_expression(Expression::Literal(zero_literal), meta)?; + + for i in 0..columns as u32 { + components.push( + ctx.add_expression( + Expression::Compose { + ty: vector_ty, + components: (0..rows as u32) + .map(|r| match r == i { + true => value, + false => zero, + }) + .collect(), + }, + meta, + )?, + ) + } + } + TypeInner::Matrix { + rows: ori_rows, + columns: ori_cols, + .. + } => { + // If a matrix is constructed from a matrix, then each component + // (column i, row j) in the result that has a corresponding component + // (column i, row j) in the argument will be initialized from there. All + // other components will be initialized to the identity matrix. + + let zero_literal = Literal::zero(element_scalar).unwrap(); + let one_literal = Literal::one(element_scalar).unwrap(); + + let zero = ctx.add_expression(Expression::Literal(zero_literal), meta)?; + let one = ctx.add_expression(Expression::Literal(one_literal), meta)?; + + let vector_ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: rows, + scalar: element_scalar, + }, + }, + meta, + ); + + for i in 0..columns as u32 { + if i < ori_cols as u32 { + use std::cmp::Ordering; + + let vector = ctx.add_expression( + Expression::AccessIndex { + base: value, + index: i, + }, + meta, + )?; + + components.push(match ori_rows.cmp(&rows) { + Ordering::Less => { + let components = (0..rows as u32) + .map(|r| { + if r < ori_rows as u32 { + ctx.add_expression( + Expression::AccessIndex { + base: vector, + index: r, + }, + meta, + ) + } else if r == i { + Ok(one) + } else { + Ok(zero) + } + }) + .collect::<Result<_>>()?; + + ctx.add_expression( + Expression::Compose { + ty: vector_ty, + components, + }, + meta, + )? + } + Ordering::Equal => vector, + Ordering::Greater => ctx.vector_resize(rows, vector, meta)?, + }) + } else { + let compose_expr = Expression::Compose { + ty: vector_ty, + components: (0..rows as u32) + .map(|r| match r == i { + true => one, + false => zero, + }) + .collect(), + }; + + let vec = ctx.add_expression(compose_expr, meta)?; + + components.push(vec) + } + } + } + _ => { + components = iter::repeat(value).take(columns as usize).collect(); + } + } + + ctx.add_expression(Expression::Compose { ty, components }, meta) + } + + #[allow(clippy::too_many_arguments)] + fn vector_constructor( + &mut self, + ctx: &mut Context, + ty: Handle<Type>, + size: crate::VectorSize, + scalar: Scalar, + args: &[(Handle<Expression>, Span)], + meta: Span, + ) -> Result<Handle<Expression>> { + let mut components = Vec::with_capacity(size as usize); + + for (mut arg, expr_meta) in args.iter().copied() { + ctx.forced_conversion(&mut arg, expr_meta, scalar)?; + + if components.len() >= size as usize { + break; + } + + match *ctx.resolve_type(arg, expr_meta)? { + TypeInner::Scalar { .. } => components.push(arg), + TypeInner::Matrix { rows, columns, .. } => { + components.reserve(rows as usize * columns as usize); + for c in 0..(columns as u32) { + let base = ctx.add_expression( + Expression::AccessIndex { + base: arg, + index: c, + }, + expr_meta, + )?; + for r in 0..(rows as u32) { + components.push(ctx.add_expression( + Expression::AccessIndex { base, index: r }, + expr_meta, + )?) + } + } + } + TypeInner::Vector { size: ori_size, .. } => { + components.reserve(ori_size as usize); + for index in 0..(ori_size as u32) { + components.push(ctx.add_expression( + Expression::AccessIndex { base: arg, index }, + expr_meta, + )?) + } + } + _ => components.push(arg), + } + } + + components.truncate(size as usize); + + ctx.add_expression(Expression::Compose { ty, components }, meta) + } + + fn constructor_many( + &mut self, + ctx: &mut Context, + ty: Handle<Type>, + args: Vec<(Handle<Expression>, Span)>, + meta: Span, + ) -> Result<Handle<Expression>> { + let mut components = Vec::with_capacity(args.len()); + + let struct_member_data = match ctx.module.types[ty].inner { + TypeInner::Matrix { + columns, + rows, + scalar: element_scalar, + } => { + let mut flattened = Vec::with_capacity(columns as usize * rows as usize); + + for (mut arg, meta) in args.iter().copied() { + ctx.forced_conversion(&mut arg, meta, element_scalar)?; + + match *ctx.resolve_type(arg, meta)? { + TypeInner::Vector { size, .. } => { + for i in 0..(size as u32) { + flattened.push(ctx.add_expression( + Expression::AccessIndex { + base: arg, + index: i, + }, + meta, + )?) + } + } + _ => flattened.push(arg), + } + } + + let ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: rows, + scalar: element_scalar, + }, + }, + meta, + ); + + for chunk in flattened.chunks(rows as usize) { + components.push(ctx.add_expression( + Expression::Compose { + ty, + components: Vec::from(chunk), + }, + meta, + )?) + } + None + } + TypeInner::Vector { size, scalar } => { + return self.vector_constructor(ctx, ty, size, scalar, &args, meta) + } + TypeInner::Array { base, .. } => { + for (mut arg, meta) in args.iter().copied() { + let scalar_components = scalar_components(&ctx.module.types[base].inner); + if let Some(scalar) = scalar_components { + ctx.implicit_conversion(&mut arg, meta, scalar)?; + } + + components.push(arg) + } + None + } + TypeInner::Struct { ref members, .. } => Some( + members + .iter() + .map(|member| scalar_components(&ctx.module.types[member.ty].inner)) + .collect::<Vec<_>>(), + ), + _ => { + return Err(Error { + kind: ErrorKind::SemanticError("Constructor: Too many arguments".into()), + meta, + }) + } + }; + + if let Some(struct_member_data) = struct_member_data { + for ((mut arg, meta), scalar_components) in + args.iter().copied().zip(struct_member_data.iter().copied()) + { + if let Some(scalar) = scalar_components { + ctx.implicit_conversion(&mut arg, meta, scalar)?; + } + + components.push(arg) + } + } + + ctx.add_expression(Expression::Compose { ty, components }, meta) + } + + #[allow(clippy::too_many_arguments)] + fn function_call( + &mut self, + ctx: &mut Context, + stmt: &StmtContext, + name: String, + args: Vec<(Handle<Expression>, Span)>, + raw_args: &[Handle<HirExpr>], + meta: Span, + ) -> Result<Option<Handle<Expression>>> { + // Grow the typifier to be able to index it later without needing + // to hold the context mutably + for &(expr, span) in args.iter() { + ctx.typifier_grow(expr, span)?; + } + + // Check if the passed arguments require any special variations + let mut variations = + builtin_required_variations(args.iter().map(|&(expr, _)| ctx.get_type(expr))); + + // Initiate the declaration if it wasn't previously initialized and inject builtins + let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| { + variations |= BuiltinVariations::STANDARD; + Default::default() + }); + inject_builtin(declaration, ctx.module, &name, variations); + + // Borrow again but without mutability, at this point a declaration is guaranteed + let declaration = self.lookup_function.get(&name).unwrap(); + + // Possibly contains the overload to be used in the call + let mut maybe_overload = None; + // The conversions needed for the best analyzed overload, this is initialized all to + // `NONE` to make sure that conversions always pass the first time without ambiguity + let mut old_conversions = vec![Conversion::None; args.len()]; + // Tracks whether the comparison between overloads lead to an ambiguity + let mut ambiguous = false; + + // Iterate over all the available overloads to select either an exact match or a + // overload which has suitable implicit conversions + 'outer: for (overload_idx, overload) in declaration.overloads.iter().enumerate() { + // If the overload and the function call don't have the same number of arguments + // continue to the next overload + if args.len() != overload.parameters.len() { + continue; + } + + log::trace!("Testing overload {}", overload_idx); + + // Stores whether the current overload matches exactly the function call + let mut exact = true; + // State of the selection + // If None we still don't know what is the best overload + // If Some(true) the new overload is better + // If Some(false) the old overload is better + let mut superior = None; + // Store the conversions for the current overload so that later they can replace the + // conversions used for querying the best overload + let mut new_conversions = vec![Conversion::None; args.len()]; + + // Loop through the overload parameters and check if the current overload is better + // compared to the previous best overload. + for (i, overload_parameter) in overload.parameters.iter().enumerate() { + let call_argument = &args[i]; + let parameter_info = &overload.parameters_info[i]; + + // If the image is used in the overload as a depth texture convert it + // before comparing, otherwise exact matches wouldn't be reported + if parameter_info.depth { + sampled_to_depth(ctx, call_argument.0, call_argument.1, &mut self.errors); + ctx.invalidate_expression(call_argument.0, call_argument.1)? + } + + ctx.typifier_grow(call_argument.0, call_argument.1)?; + + let overload_param_ty = &ctx.module.types[*overload_parameter].inner; + let call_arg_ty = ctx.get_type(call_argument.0); + + log::trace!( + "Testing parameter {}\n\tOverload = {:?}\n\tCall = {:?}", + i, + overload_param_ty, + call_arg_ty + ); + + // Storage images cannot be directly compared since while the access is part of the + // type in naga's IR, in glsl they are a qualifier and don't enter in the match as + // long as the access needed is satisfied. + if let ( + &TypeInner::Image { + class: + crate::ImageClass::Storage { + format: overload_format, + access: overload_access, + }, + dim: overload_dim, + arrayed: overload_arrayed, + }, + &TypeInner::Image { + class: + crate::ImageClass::Storage { + format: call_format, + access: call_access, + }, + dim: call_dim, + arrayed: call_arrayed, + }, + ) = (overload_param_ty, call_arg_ty) + { + // Images size must match otherwise the overload isn't what we want + let good_size = call_dim == overload_dim && call_arrayed == overload_arrayed; + // Glsl requires the formats to strictly match unless you are builtin + // function overload and have not been replaced, in which case we only + // check that the format scalar kind matches + let good_format = overload_format == call_format + || (overload.internal + && ScalarKind::from(overload_format) == ScalarKind::from(call_format)); + if !(good_size && good_format) { + continue 'outer; + } + + // While storage access mismatch is an error it isn't one that causes + // the overload matching to fail so we defer the error and consider + // that the images match exactly + if !call_access.contains(overload_access) { + self.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "'{name}': image needs {overload_access:?} access but only {call_access:?} was provided" + ) + .into(), + ), + meta, + }); + } + + // The images satisfy the conditions to be considered as an exact match + new_conversions[i] = Conversion::Exact; + continue; + } else if overload_param_ty == call_arg_ty { + // If the types match there's no need to check for conversions so continue + new_conversions[i] = Conversion::Exact; + continue; + } + + // Glsl defines that inout follows both the conversions for input parameters and + // output parameters, this means that the type must have a conversion from both the + // call argument to the function parameter and the function parameter to the call + // argument, the only way this is possible is for the conversion to be an identity + // (i.e. call argument = function parameter) + if let ParameterQualifier::InOut = parameter_info.qualifier { + continue 'outer; + } + + // The function call argument and the function definition + // parameter are not equal at this point, so we need to try + // implicit conversions. + // + // Now there are two cases, the argument is defined as a normal + // parameter (`in` or `const`), in this case an implicit + // conversion is made from the calling argument to the + // definition argument. If the parameter is `out` the + // opposite needs to be done, so the implicit conversion is made + // from the definition argument to the calling argument. + let maybe_conversion = if parameter_info.qualifier.is_lhs() { + conversion(call_arg_ty, overload_param_ty) + } else { + conversion(overload_param_ty, call_arg_ty) + }; + + let conversion = match maybe_conversion { + Some(info) => info, + None => continue 'outer, + }; + + // At this point a conversion will be needed so the overload no longer + // exactly matches the call arguments + exact = false; + + // Compare the conversions needed for this overload parameter to that of the + // last overload analyzed respective parameter, the value is: + // - `true` when the new overload argument has a better conversion + // - `false` when the old overload argument has a better conversion + let best_arg = match (conversion, old_conversions[i]) { + // An exact match is always better, we don't need to check this for the + // current overload since it was checked earlier + (_, Conversion::Exact) => false, + // No overload was yet analyzed so this one is the best yet + (_, Conversion::None) => true, + // A conversion from a float to a double is the best possible conversion + (Conversion::FloatToDouble, _) => true, + (_, Conversion::FloatToDouble) => false, + // A conversion from a float to an integer is preferred than one + // from double to an integer + (Conversion::IntToFloat, Conversion::IntToDouble) => true, + (Conversion::IntToDouble, Conversion::IntToFloat) => false, + // This case handles things like no conversion and exact which were already + // treated and other cases which no conversion is better than the other + _ => continue, + }; + + // Check if the best parameter corresponds to the current selected overload + // to pass to the next comparison, if this isn't true mark it as ambiguous + match best_arg { + true => match superior { + Some(false) => ambiguous = true, + _ => { + superior = Some(true); + new_conversions[i] = conversion + } + }, + false => match superior { + Some(true) => ambiguous = true, + _ => superior = Some(false), + }, + } + } + + // The overload matches exactly the function call so there's no ambiguity (since + // repeated overload aren't allowed) and the current overload is selected, no + // further querying is needed. + if exact { + maybe_overload = Some(overload); + ambiguous = false; + break; + } + + match superior { + // New overload is better keep it + Some(true) => { + maybe_overload = Some(overload); + // Replace the conversions + old_conversions = new_conversions; + } + // Old overload is better do nothing + Some(false) => {} + // No overload was better than the other this can be caused + // when all conversions are ambiguous in which the overloads themselves are + // ambiguous. + None => { + ambiguous = true; + // Assign the new overload, this helps ensures that in this case of + // ambiguity the parsing won't end immediately and allow for further + // collection of errors. + maybe_overload = Some(overload); + } + } + } + + if ambiguous { + self.errors.push(Error { + kind: ErrorKind::SemanticError( + format!("Ambiguous best function for '{name}'").into(), + ), + meta, + }) + } + + let overload = maybe_overload.ok_or_else(|| Error { + kind: ErrorKind::SemanticError(format!("Unknown function '{name}'").into()), + meta, + })?; + + let parameters_info = overload.parameters_info.clone(); + let parameters = overload.parameters.clone(); + let is_void = overload.void; + let kind = overload.kind; + + let mut arguments = Vec::with_capacity(args.len()); + let mut proxy_writes = Vec::new(); + + // Iterate through the function call arguments applying transformations as needed + for (((parameter_info, call_argument), expr), parameter) in parameters_info + .iter() + .zip(&args) + .zip(raw_args) + .zip(¶meters) + { + let (mut handle, meta) = + ctx.lower_expect_inner(stmt, self, *expr, parameter_info.qualifier.as_pos())?; + + if parameter_info.qualifier.is_lhs() { + self.process_lhs_argument( + ctx, + meta, + *parameter, + parameter_info, + handle, + call_argument, + &mut proxy_writes, + &mut arguments, + )?; + + continue; + } + + let scalar_comps = scalar_components(&ctx.module.types[*parameter].inner); + + // Apply implicit conversions as needed + if let Some(scalar) = scalar_comps { + ctx.implicit_conversion(&mut handle, meta, scalar)?; + } + + arguments.push(handle) + } + + match kind { + FunctionKind::Call(function) => { + ctx.emit_end(); + + let result = if !is_void { + Some(ctx.add_expression(Expression::CallResult(function), meta)?) + } else { + None + }; + + ctx.body.push( + crate::Statement::Call { + function, + arguments, + result, + }, + meta, + ); + + ctx.emit_start(); + + // Write back all the variables that were scheduled to their original place + for proxy_write in proxy_writes { + let mut value = ctx.add_expression( + Expression::Load { + pointer: proxy_write.value, + }, + meta, + )?; + + if let Some(scalar) = proxy_write.convert { + ctx.conversion(&mut value, meta, scalar)?; + } + + ctx.emit_restart(); + + ctx.body.push( + Statement::Store { + pointer: proxy_write.target, + value, + }, + meta, + ); + } + + Ok(result) + } + FunctionKind::Macro(builtin) => builtin.call(self, ctx, arguments.as_mut_slice(), meta), + } + } + + /// Processes a function call argument that appears in place of an output + /// parameter. + #[allow(clippy::too_many_arguments)] + fn process_lhs_argument( + &mut self, + ctx: &mut Context, + meta: Span, + parameter_ty: Handle<Type>, + parameter_info: &ParameterInfo, + original: Handle<Expression>, + call_argument: &(Handle<Expression>, Span), + proxy_writes: &mut Vec<ProxyWrite>, + arguments: &mut Vec<Handle<Expression>>, + ) -> Result<()> { + let original_ty = ctx.resolve_type(original, meta)?; + let original_pointer_space = original_ty.pointer_space(); + + // The type of a possible spill variable needed for a proxy write + let mut maybe_ty = match *original_ty { + // If the argument is to be passed as a pointer but the type of the + // expression returns a vector it must mean that it was for example + // swizzled and it must be spilled into a local before calling + TypeInner::Vector { size, scalar } => Some(ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Vector { size, scalar }, + }, + Span::default(), + )), + // If the argument is a pointer whose address space isn't `Function`, an + // indirection through a local variable is needed to align the address + // spaces of the call argument and the overload parameter. + TypeInner::Pointer { base, space } if space != AddressSpace::Function => Some(base), + TypeInner::ValuePointer { + size, + scalar, + space, + } if space != AddressSpace::Function => { + let inner = match size { + Some(size) => TypeInner::Vector { size, scalar }, + None => TypeInner::Scalar(scalar), + }; + + Some( + ctx.module + .types + .insert(Type { name: None, inner }, Span::default()), + ) + } + _ => None, + }; + + // Since the original expression might be a pointer and we want a value + // for the proxy writes, we might need to load the pointer. + let value = if original_pointer_space.is_some() { + ctx.add_expression(Expression::Load { pointer: original }, Span::default())? + } else { + original + }; + + ctx.typifier_grow(call_argument.0, call_argument.1)?; + + let overload_param_ty = &ctx.module.types[parameter_ty].inner; + let call_arg_ty = ctx.get_type(call_argument.0); + let needs_conversion = call_arg_ty != overload_param_ty; + + let arg_scalar_comps = scalar_components(call_arg_ty); + + // Since output parameters also allow implicit conversions from the + // parameter to the argument, we need to spill the conversion to a + // variable and create a proxy write for the original variable. + if needs_conversion { + maybe_ty = Some(parameter_ty); + } + + if let Some(ty) = maybe_ty { + // Create the spill variable + let spill_var = ctx.locals.append( + LocalVariable { + name: None, + ty, + init: None, + }, + Span::default(), + ); + let spill_expr = + ctx.add_expression(Expression::LocalVariable(spill_var), Span::default())?; + + // If the argument is also copied in we must store the value of the + // original variable to the spill variable. + if let ParameterQualifier::InOut = parameter_info.qualifier { + ctx.body.push( + Statement::Store { + pointer: spill_expr, + value, + }, + Span::default(), + ); + } + + // Add the spill variable as an argument to the function call + arguments.push(spill_expr); + + let convert = if needs_conversion { + arg_scalar_comps + } else { + None + }; + + // Register the temporary local to be written back to it's original + // place after the function call + if let Expression::Swizzle { + size, + mut vector, + pattern, + } = ctx.expressions[original] + { + if let Expression::Load { pointer } = ctx.expressions[vector] { + vector = pointer; + } + + for (i, component) in pattern.iter().take(size as usize).enumerate() { + let original = ctx.add_expression( + Expression::AccessIndex { + base: vector, + index: *component as u32, + }, + Span::default(), + )?; + + let spill_component = ctx.add_expression( + Expression::AccessIndex { + base: spill_expr, + index: i as u32, + }, + Span::default(), + )?; + + proxy_writes.push(ProxyWrite { + target: original, + value: spill_component, + convert, + }); + } + } else { + proxy_writes.push(ProxyWrite { + target: original, + value: spill_expr, + convert, + }); + } + } else { + arguments.push(original); + } + + Ok(()) + } + + pub(crate) fn add_function( + &mut self, + mut ctx: Context, + name: String, + result: Option<FunctionResult>, + meta: Span, + ) { + ensure_block_returns(&mut ctx.body); + + let void = result.is_none(); + + // Check if the passed arguments require any special variations + let mut variations = builtin_required_variations( + ctx.parameters + .iter() + .map(|&arg| &ctx.module.types[arg].inner), + ); + + // Initiate the declaration if it wasn't previously initialized and inject builtins + let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| { + variations |= BuiltinVariations::STANDARD; + Default::default() + }); + inject_builtin(declaration, ctx.module, &name, variations); + + let Context { + expressions, + locals, + arguments, + parameters, + parameters_info, + body, + module, + .. + } = ctx; + + let function = Function { + name: Some(name), + arguments, + result, + local_variables: locals, + expressions, + named_expressions: crate::NamedExpressions::default(), + body, + }; + + 'outer: for decl in declaration.overloads.iter_mut() { + if parameters.len() != decl.parameters.len() { + continue; + } + + for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) { + let new_inner = &module.types[*new_parameter].inner; + let old_inner = &module.types[*old_parameter].inner; + + if new_inner != old_inner { + continue 'outer; + } + } + + if decl.defined { + return self.errors.push(Error { + kind: ErrorKind::SemanticError("Function already defined".into()), + meta, + }); + } + + decl.defined = true; + decl.parameters_info = parameters_info; + match decl.kind { + FunctionKind::Call(handle) => *module.functions.get_mut(handle) = function, + FunctionKind::Macro(_) => { + let handle = module.functions.append(function, meta); + decl.kind = FunctionKind::Call(handle) + } + } + return; + } + + let handle = module.functions.append(function, meta); + declaration.overloads.push(Overload { + parameters, + parameters_info, + kind: FunctionKind::Call(handle), + defined: true, + internal: false, + void, + }); + } + + pub(crate) fn add_prototype( + &mut self, + ctx: Context, + name: String, + result: Option<FunctionResult>, + meta: Span, + ) { + let void = result.is_none(); + + // Check if the passed arguments require any special variations + let mut variations = builtin_required_variations( + ctx.parameters + .iter() + .map(|&arg| &ctx.module.types[arg].inner), + ); + + // Initiate the declaration if it wasn't previously initialized and inject builtins + let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| { + variations |= BuiltinVariations::STANDARD; + Default::default() + }); + inject_builtin(declaration, ctx.module, &name, variations); + + let Context { + arguments, + parameters, + parameters_info, + module, + .. + } = ctx; + + let function = Function { + name: Some(name), + arguments, + result, + ..Default::default() + }; + + 'outer: for decl in declaration.overloads.iter() { + if parameters.len() != decl.parameters.len() { + continue; + } + + for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) { + let new_inner = &module.types[*new_parameter].inner; + let old_inner = &module.types[*old_parameter].inner; + + if new_inner != old_inner { + continue 'outer; + } + } + + return self.errors.push(Error { + kind: ErrorKind::SemanticError("Prototype already defined".into()), + meta, + }); + } + + let handle = module.functions.append(function, meta); + declaration.overloads.push(Overload { + parameters, + parameters_info, + kind: FunctionKind::Call(handle), + defined: false, + internal: false, + void, + }); + } + + /// Create a Naga [`EntryPoint`] that calls the GLSL `main` function. + /// + /// We compile the GLSL `main` function as an ordinary Naga [`Function`]. + /// This function synthesizes a Naga [`EntryPoint`] to call that. + /// + /// Each GLSL input and output variable (including builtins) becomes a Naga + /// [`GlobalVariable`]s in the [`Private`] address space, which `main` can + /// access in the usual way. + /// + /// The `EntryPoint` we synthesize here has an argument for each GLSL input + /// variable, and returns a struct with a member for each GLSL output + /// variable. The entry point contains code to: + /// + /// - copy its arguments into the Naga globals representing the GLSL input + /// variables, + /// + /// - call the Naga `Function` representing the GLSL `main` function, and then + /// + /// - build its return value from whatever values the GLSL `main` left in + /// the Naga globals representing GLSL `output` variables. + /// + /// Upon entry, [`ctx.body`] should contain code, accumulated by prior calls + /// to [`ParsingContext::parse_external_declaration`][pxd], to initialize + /// private global variables as needed. This code gets spliced into the + /// entry point before the call to `main`. + /// + /// [`GlobalVariable`]: crate::GlobalVariable + /// [`Private`]: crate::AddressSpace::Private + /// [`ctx.body`]: Context::body + /// [pxd]: super::ParsingContext::parse_external_declaration + pub(crate) fn add_entry_point( + &mut self, + function: Handle<Function>, + mut ctx: Context, + ) -> Result<()> { + let mut arguments = Vec::new(); + + let body = Block::with_capacity( + // global init body + ctx.body.len() + + // prologue and epilogue + self.entry_args.len() * 2 + // Call, Emit for composing struct and return + + 3, + ); + + let global_init_body = std::mem::replace(&mut ctx.body, body); + + for arg in self.entry_args.iter() { + if arg.storage != StorageQualifier::Input { + continue; + } + + let pointer = ctx + .expressions + .append(Expression::GlobalVariable(arg.handle), Default::default()); + + let ty = ctx.module.global_variables[arg.handle].ty; + + ctx.arg_type_walker( + arg.name.clone(), + arg.binding.clone(), + pointer, + ty, + &mut |ctx, name, pointer, ty, binding| { + let idx = arguments.len() as u32; + + arguments.push(FunctionArgument { + name, + ty, + binding: Some(binding), + }); + + let value = ctx + .expressions + .append(Expression::FunctionArgument(idx), Default::default()); + ctx.body + .push(Statement::Store { pointer, value }, Default::default()); + }, + )? + } + + ctx.body.extend_block(global_init_body); + + ctx.body.push( + Statement::Call { + function, + arguments: Vec::new(), + result: None, + }, + Default::default(), + ); + + let mut span = 0; + let mut members = Vec::new(); + let mut components = Vec::new(); + + for arg in self.entry_args.iter() { + if arg.storage != StorageQualifier::Output { + continue; + } + + let pointer = ctx + .expressions + .append(Expression::GlobalVariable(arg.handle), Default::default()); + + let ty = ctx.module.global_variables[arg.handle].ty; + + ctx.arg_type_walker( + arg.name.clone(), + arg.binding.clone(), + pointer, + ty, + &mut |ctx, name, pointer, ty, binding| { + members.push(StructMember { + name, + ty, + binding: Some(binding), + offset: span, + }); + + span += ctx.module.types[ty].inner.size(ctx.module.to_ctx()); + + let len = ctx.expressions.len(); + let load = ctx + .expressions + .append(Expression::Load { pointer }, Default::default()); + ctx.body.push( + Statement::Emit(ctx.expressions.range_from(len)), + Default::default(), + ); + components.push(load) + }, + )? + } + + let (ty, value) = if !components.is_empty() { + let ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Struct { members, span }, + }, + Default::default(), + ); + + let len = ctx.expressions.len(); + let res = ctx + .expressions + .append(Expression::Compose { ty, components }, Default::default()); + ctx.body.push( + Statement::Emit(ctx.expressions.range_from(len)), + Default::default(), + ); + + (Some(ty), Some(res)) + } else { + (None, None) + }; + + ctx.body + .push(Statement::Return { value }, Default::default()); + + let Context { + body, expressions, .. + } = ctx; + + ctx.module.entry_points.push(EntryPoint { + name: "main".to_string(), + stage: self.meta.stage, + early_depth_test: Some(crate::EarlyDepthTest { conservative: None }) + .filter(|_| self.meta.early_fragment_tests), + workgroup_size: self.meta.workgroup_size, + function: Function { + arguments, + expressions, + body, + result: ty.map(|ty| FunctionResult { ty, binding: None }), + ..Default::default() + }, + }); + + Ok(()) + } +} + +impl Context<'_> { + /// Helper function for building the input/output interface of the entry point + /// + /// Calls `f` with the data of the entry point argument, flattening composite types + /// recursively + /// + /// The passed arguments to the callback are: + /// - The ctx + /// - The name + /// - The pointer expression to the global storage + /// - The handle to the type of the entry point argument + /// - The binding of the entry point argument + fn arg_type_walker( + &mut self, + name: Option<String>, + binding: crate::Binding, + pointer: Handle<Expression>, + ty: Handle<Type>, + f: &mut impl FnMut( + &mut Context, + Option<String>, + Handle<Expression>, + Handle<Type>, + crate::Binding, + ), + ) -> Result<()> { + match self.module.types[ty].inner { + // TODO: Better error reporting + // right now we just don't walk the array if the size isn't known at + // compile time and let validation catch it + TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + .. + } => { + let mut location = match binding { + crate::Binding::Location { location, .. } => location, + crate::Binding::BuiltIn(_) => return Ok(()), + }; + + let interpolation = + self.module.types[base] + .inner + .scalar_kind() + .map(|kind| match kind { + ScalarKind::Float => crate::Interpolation::Perspective, + _ => crate::Interpolation::Flat, + }); + + for index in 0..size.get() { + let member_pointer = self.add_expression( + Expression::AccessIndex { + base: pointer, + index, + }, + crate::Span::default(), + )?; + + let binding = crate::Binding::Location { + location, + interpolation, + sampling: None, + second_blend_source: false, + }; + location += 1; + + self.arg_type_walker(name.clone(), binding, member_pointer, base, f)? + } + } + TypeInner::Struct { ref members, .. } => { + let mut location = match binding { + crate::Binding::Location { location, .. } => location, + crate::Binding::BuiltIn(_) => return Ok(()), + }; + + for (i, member) in members.clone().into_iter().enumerate() { + let member_pointer = self.add_expression( + Expression::AccessIndex { + base: pointer, + index: i as u32, + }, + crate::Span::default(), + )?; + + let binding = match member.binding { + Some(binding) => binding, + None => { + let interpolation = self.module.types[member.ty] + .inner + .scalar_kind() + .map(|kind| match kind { + ScalarKind::Float => crate::Interpolation::Perspective, + _ => crate::Interpolation::Flat, + }); + let binding = crate::Binding::Location { + location, + interpolation, + sampling: None, + second_blend_source: false, + }; + location += 1; + binding + } + }; + + self.arg_type_walker(member.name, binding, member_pointer, member.ty, f)? + } + } + _ => f(self, name, pointer, ty, binding), + } + + Ok(()) + } +} + +/// Helper enum containing the type of conversion need for a call +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +enum Conversion { + /// No conversion needed + Exact, + /// Float to double conversion needed + FloatToDouble, + /// Int or uint to float conversion needed + IntToFloat, + /// Int or uint to double conversion needed + IntToDouble, + /// Other type of conversion needed + Other, + /// No conversion was yet registered + None, +} + +/// Helper function, returns the type of conversion from `source` to `target`, if a +/// conversion is not possible returns None. +fn conversion(target: &TypeInner, source: &TypeInner) -> Option<Conversion> { + use ScalarKind::*; + + // Gather the `ScalarKind` and scalar width from both the target and the source + let (target_scalar, source_scalar) = match (target, source) { + // Conversions between scalars are allowed + (&TypeInner::Scalar(tgt_scalar), &TypeInner::Scalar(src_scalar)) => { + (tgt_scalar, src_scalar) + } + // Conversions between vectors of the same size are allowed + ( + &TypeInner::Vector { + size: tgt_size, + scalar: tgt_scalar, + }, + &TypeInner::Vector { + size: src_size, + scalar: src_scalar, + }, + ) if tgt_size == src_size => (tgt_scalar, src_scalar), + // Conversions between matrices of the same size are allowed + ( + &TypeInner::Matrix { + rows: tgt_rows, + columns: tgt_cols, + scalar: tgt_scalar, + }, + &TypeInner::Matrix { + rows: src_rows, + columns: src_cols, + scalar: src_scalar, + }, + ) if tgt_cols == src_cols && tgt_rows == src_rows => (tgt_scalar, src_scalar), + _ => return None, + }; + + // Check if source can be converted into target, if this is the case then the type + // power of target must be higher than that of source + let target_power = type_power(target_scalar); + let source_power = type_power(source_scalar); + if target_power < source_power { + return None; + } + + Some(match (target_scalar, source_scalar) { + // A conversion from a float to a double is special + (Scalar::F64, Scalar::F32) => Conversion::FloatToDouble, + // A conversion from an integer to a float is special + ( + Scalar::F32, + Scalar { + kind: Sint | Uint, + width: _, + }, + ) => Conversion::IntToFloat, + // A conversion from an integer to a double is special + ( + Scalar::F64, + Scalar { + kind: Sint | Uint, + width: _, + }, + ) => Conversion::IntToDouble, + _ => Conversion::Other, + }) +} + +/// Helper method returning all the non standard builtin variations needed +/// to process the function call with the passed arguments +fn builtin_required_variations<'a>(args: impl Iterator<Item = &'a TypeInner>) -> BuiltinVariations { + let mut variations = BuiltinVariations::empty(); + + for ty in args { + match *ty { + TypeInner::ValuePointer { scalar, .. } + | TypeInner::Scalar(scalar) + | TypeInner::Vector { scalar, .. } + | TypeInner::Matrix { scalar, .. } => { + if scalar == Scalar::F64 { + variations |= BuiltinVariations::DOUBLE + } + } + TypeInner::Image { + dim, + arrayed, + class, + } => { + if dim == crate::ImageDimension::Cube && arrayed { + variations |= BuiltinVariations::CUBE_TEXTURES_ARRAY + } + + if dim == crate::ImageDimension::D2 && arrayed && class.is_multisampled() { + variations |= BuiltinVariations::D2_MULTI_TEXTURES_ARRAY + } + } + _ => {} + } + } + + variations +} diff --git a/third_party/rust/naga/src/front/glsl/lex.rs b/third_party/rust/naga/src/front/glsl/lex.rs new file mode 100644 index 0000000000..1b59a9bf3e --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/lex.rs @@ -0,0 +1,301 @@ +use super::{ + ast::Precision, + token::{Directive, DirectiveKind, Token, TokenValue}, + types::parse_type, +}; +use crate::{FastHashMap, Span, StorageAccess}; +use pp_rs::{ + pp::Preprocessor, + token::{PreprocessorError, Punct, TokenValue as PPTokenValue}, +}; + +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub struct LexerResult { + pub kind: LexerResultKind, + pub meta: Span, +} + +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub enum LexerResultKind { + Token(Token), + Directive(Directive), + Error(PreprocessorError), +} + +pub struct Lexer<'a> { + pp: Preprocessor<'a>, +} + +impl<'a> Lexer<'a> { + pub fn new(input: &'a str, defines: &'a FastHashMap<String, String>) -> Self { + let mut pp = Preprocessor::new(input); + for (define, value) in defines { + pp.add_define(define, value).unwrap(); //TODO: handle error + } + Lexer { pp } + } +} + +impl<'a> Iterator for Lexer<'a> { + type Item = LexerResult; + fn next(&mut self) -> Option<Self::Item> { + let pp_token = match self.pp.next()? { + Ok(t) => t, + Err((err, loc)) => { + return Some(LexerResult { + kind: LexerResultKind::Error(err), + meta: loc.into(), + }); + } + }; + + let meta = pp_token.location.into(); + let value = match pp_token.value { + PPTokenValue::Extension(extension) => { + return Some(LexerResult { + kind: LexerResultKind::Directive(Directive { + kind: DirectiveKind::Extension, + tokens: extension.tokens, + }), + meta, + }) + } + PPTokenValue::Float(float) => TokenValue::FloatConstant(float), + PPTokenValue::Ident(ident) => { + match ident.as_str() { + // Qualifiers + "layout" => TokenValue::Layout, + "in" => TokenValue::In, + "out" => TokenValue::Out, + "uniform" => TokenValue::Uniform, + "buffer" => TokenValue::Buffer, + "shared" => TokenValue::Shared, + "invariant" => TokenValue::Invariant, + "flat" => TokenValue::Interpolation(crate::Interpolation::Flat), + "noperspective" => TokenValue::Interpolation(crate::Interpolation::Linear), + "smooth" => TokenValue::Interpolation(crate::Interpolation::Perspective), + "centroid" => TokenValue::Sampling(crate::Sampling::Centroid), + "sample" => TokenValue::Sampling(crate::Sampling::Sample), + "const" => TokenValue::Const, + "inout" => TokenValue::InOut, + "precision" => TokenValue::Precision, + "highp" => TokenValue::PrecisionQualifier(Precision::High), + "mediump" => TokenValue::PrecisionQualifier(Precision::Medium), + "lowp" => TokenValue::PrecisionQualifier(Precision::Low), + "restrict" => TokenValue::Restrict, + "readonly" => TokenValue::MemoryQualifier(StorageAccess::LOAD), + "writeonly" => TokenValue::MemoryQualifier(StorageAccess::STORE), + // values + "true" => TokenValue::BoolConstant(true), + "false" => TokenValue::BoolConstant(false), + // jump statements + "continue" => TokenValue::Continue, + "break" => TokenValue::Break, + "return" => TokenValue::Return, + "discard" => TokenValue::Discard, + // selection statements + "if" => TokenValue::If, + "else" => TokenValue::Else, + "switch" => TokenValue::Switch, + "case" => TokenValue::Case, + "default" => TokenValue::Default, + // iteration statements + "while" => TokenValue::While, + "do" => TokenValue::Do, + "for" => TokenValue::For, + // types + "void" => TokenValue::Void, + "struct" => TokenValue::Struct, + word => match parse_type(word) { + Some(t) => TokenValue::TypeName(t), + None => TokenValue::Identifier(String::from(word)), + }, + } + } + PPTokenValue::Integer(integer) => TokenValue::IntConstant(integer), + PPTokenValue::Punct(punct) => match punct { + // Compound assignments + Punct::AddAssign => TokenValue::AddAssign, + Punct::SubAssign => TokenValue::SubAssign, + Punct::MulAssign => TokenValue::MulAssign, + Punct::DivAssign => TokenValue::DivAssign, + Punct::ModAssign => TokenValue::ModAssign, + Punct::LeftShiftAssign => TokenValue::LeftShiftAssign, + Punct::RightShiftAssign => TokenValue::RightShiftAssign, + Punct::AndAssign => TokenValue::AndAssign, + Punct::XorAssign => TokenValue::XorAssign, + Punct::OrAssign => TokenValue::OrAssign, + + // Two character punctuation + Punct::Increment => TokenValue::Increment, + Punct::Decrement => TokenValue::Decrement, + Punct::LogicalAnd => TokenValue::LogicalAnd, + Punct::LogicalOr => TokenValue::LogicalOr, + Punct::LogicalXor => TokenValue::LogicalXor, + Punct::LessEqual => TokenValue::LessEqual, + Punct::GreaterEqual => TokenValue::GreaterEqual, + Punct::EqualEqual => TokenValue::Equal, + Punct::NotEqual => TokenValue::NotEqual, + Punct::LeftShift => TokenValue::LeftShift, + Punct::RightShift => TokenValue::RightShift, + + // Parenthesis or similar + Punct::LeftBrace => TokenValue::LeftBrace, + Punct::RightBrace => TokenValue::RightBrace, + Punct::LeftParen => TokenValue::LeftParen, + Punct::RightParen => TokenValue::RightParen, + Punct::LeftBracket => TokenValue::LeftBracket, + Punct::RightBracket => TokenValue::RightBracket, + + // Other one character punctuation + Punct::LeftAngle => TokenValue::LeftAngle, + Punct::RightAngle => TokenValue::RightAngle, + Punct::Semicolon => TokenValue::Semicolon, + Punct::Comma => TokenValue::Comma, + Punct::Colon => TokenValue::Colon, + Punct::Dot => TokenValue::Dot, + Punct::Equal => TokenValue::Assign, + Punct::Bang => TokenValue::Bang, + Punct::Minus => TokenValue::Dash, + Punct::Tilde => TokenValue::Tilde, + Punct::Plus => TokenValue::Plus, + Punct::Star => TokenValue::Star, + Punct::Slash => TokenValue::Slash, + Punct::Percent => TokenValue::Percent, + Punct::Pipe => TokenValue::VerticalBar, + Punct::Caret => TokenValue::Caret, + Punct::Ampersand => TokenValue::Ampersand, + Punct::Question => TokenValue::Question, + }, + PPTokenValue::Pragma(pragma) => { + return Some(LexerResult { + kind: LexerResultKind::Directive(Directive { + kind: DirectiveKind::Pragma, + tokens: pragma.tokens, + }), + meta, + }) + } + PPTokenValue::Version(version) => { + return Some(LexerResult { + kind: LexerResultKind::Directive(Directive { + kind: DirectiveKind::Version { + is_first_directive: version.is_first_directive, + }, + tokens: version.tokens, + }), + meta, + }) + } + }; + + Some(LexerResult { + kind: LexerResultKind::Token(Token { value, meta }), + meta, + }) + } +} + +#[cfg(test)] +mod tests { + use pp_rs::token::{Integer, Location, Token as PPToken, TokenValue as PPTokenValue}; + + use super::{ + super::token::{Directive, DirectiveKind, Token, TokenValue}, + Lexer, LexerResult, LexerResultKind, + }; + use crate::Span; + + #[test] + fn lex_tokens() { + let defines = crate::FastHashMap::default(); + + // line comments + let mut lex = Lexer::new("#version 450\nvoid main () {}", &defines); + let mut location = Location::default(); + location.start = 9; + location.end = 12; + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Directive(Directive { + kind: DirectiveKind::Version { + is_first_directive: true + }, + tokens: vec![PPToken { + value: PPTokenValue::Integer(Integer { + signed: true, + value: 450, + width: 32 + }), + location + }] + }), + meta: Span::new(1, 8) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::Void, + meta: Span::new(13, 17) + }), + meta: Span::new(13, 17) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::Identifier("main".into()), + meta: Span::new(18, 22) + }), + meta: Span::new(18, 22) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::LeftParen, + meta: Span::new(23, 24) + }), + meta: Span::new(23, 24) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::RightParen, + meta: Span::new(24, 25) + }), + meta: Span::new(24, 25) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::LeftBrace, + meta: Span::new(26, 27) + }), + meta: Span::new(26, 27) + } + ); + assert_eq!( + lex.next().unwrap(), + LexerResult { + kind: LexerResultKind::Token(Token { + value: TokenValue::RightBrace, + meta: Span::new(27, 28) + }), + meta: Span::new(27, 28) + } + ); + assert_eq!(lex.next(), None); + } +} diff --git a/third_party/rust/naga/src/front/glsl/mod.rs b/third_party/rust/naga/src/front/glsl/mod.rs new file mode 100644 index 0000000000..75f3929db4 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/mod.rs @@ -0,0 +1,232 @@ +/*! +Frontend for [GLSL][glsl] (OpenGL Shading Language). + +To begin, take a look at the documentation for the [`Frontend`]. + +# Supported versions +## Vulkan +- 440 (partial) +- 450 +- 460 + +[glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php +*/ + +pub use ast::{Precision, Profile}; +pub use error::{Error, ErrorKind, ExpectedToken, ParseError}; +pub use token::TokenValue; + +use crate::{proc::Layouter, FastHashMap, FastHashSet, Handle, Module, ShaderStage, Span, Type}; +use ast::{EntryArg, FunctionDeclaration, GlobalLookup}; +use parser::ParsingContext; + +mod ast; +mod builtins; +mod context; +mod error; +mod functions; +mod lex; +mod offset; +mod parser; +#[cfg(test)] +mod parser_tests; +mod token; +mod types; +mod variables; + +type Result<T> = std::result::Result<T, Error>; + +/// Per-shader options passed to [`parse`](Frontend::parse). +/// +/// The [`From`] trait is implemented for [`ShaderStage`] to provide a quick way +/// to create an `Options` instance. +/// +/// ```rust +/// # use naga::ShaderStage; +/// # use naga::front::glsl::Options; +/// Options::from(ShaderStage::Vertex); +/// ``` +#[derive(Debug)] +pub struct Options { + /// The shader stage in the pipeline. + pub stage: ShaderStage, + /// Preprocessor definitions to be used, akin to having + /// ```glsl + /// #define key value + /// ``` + /// for each key value pair in the map. + pub defines: FastHashMap<String, String>, +} + +impl From<ShaderStage> for Options { + fn from(stage: ShaderStage) -> Self { + Options { + stage, + defines: FastHashMap::default(), + } + } +} + +/// Additional information about the GLSL shader. +/// +/// Stores additional information about the GLSL shader which might not be +/// stored in the shader [`Module`]. +#[derive(Debug)] +pub struct ShaderMetadata { + /// The GLSL version specified in the shader through the use of the + /// `#version` preprocessor directive. + pub version: u16, + /// The GLSL profile specified in the shader through the use of the + /// `#version` preprocessor directive. + pub profile: Profile, + /// The shader stage in the pipeline, passed to the [`parse`](Frontend::parse) + /// method via the [`Options`] struct. + pub stage: ShaderStage, + + /// The workgroup size for compute shaders, defaults to `[1; 3]` for + /// compute shaders and `[0; 3]` for non compute shaders. + pub workgroup_size: [u32; 3], + /// Whether or not early fragment tests where requested by the shader. + /// Defaults to `false`. + pub early_fragment_tests: bool, + + /// The shader can request extensions via the + /// `#extension` preprocessor directive, in the directive a behavior + /// parameter is used to control whether the extension should be disabled, + /// warn on usage, enabled if possible or required. + /// + /// This field only stores extensions which were required or requested to + /// be enabled if possible and they are supported. + pub extensions: FastHashSet<String>, +} + +impl ShaderMetadata { + fn reset(&mut self, stage: ShaderStage) { + self.version = 0; + self.profile = Profile::Core; + self.stage = stage; + self.workgroup_size = [u32::from(stage == ShaderStage::Compute); 3]; + self.early_fragment_tests = false; + self.extensions.clear(); + } +} + +impl Default for ShaderMetadata { + fn default() -> Self { + ShaderMetadata { + version: 0, + profile: Profile::Core, + stage: ShaderStage::Vertex, + workgroup_size: [0; 3], + early_fragment_tests: false, + extensions: FastHashSet::default(), + } + } +} + +/// The `Frontend` is the central structure of the GLSL frontend. +/// +/// To instantiate a new `Frontend` the [`Default`] trait is used, so a +/// call to the associated function [`Frontend::default`](Frontend::default) will +/// return a new `Frontend` instance. +/// +/// To parse a shader simply call the [`parse`](Frontend::parse) method with a +/// [`Options`] struct and a [`&str`](str) holding the glsl code. +/// +/// The `Frontend` also provides the [`metadata`](Frontend::metadata) to get some +/// further information about the previously parsed shader, like version and +/// extensions used (see the documentation for +/// [`ShaderMetadata`] to see all the returned information) +/// +/// # Example usage +/// ```rust +/// use naga::ShaderStage; +/// use naga::front::glsl::{Frontend, Options}; +/// +/// let glsl = r#" +/// #version 450 core +/// +/// void main() {} +/// "#; +/// +/// let mut frontend = Frontend::default(); +/// let options = Options::from(ShaderStage::Vertex); +/// frontend.parse(&options, glsl); +/// ``` +/// +/// # Reusability +/// +/// If there's a need to parse more than one shader reusing the same `Frontend` +/// instance may be beneficial since internal allocations will be reused. +/// +/// Calling the [`parse`](Frontend::parse) method multiple times will reset the +/// `Frontend` so no extra care is needed when reusing. +#[derive(Debug, Default)] +pub struct Frontend { + meta: ShaderMetadata, + + lookup_function: FastHashMap<String, FunctionDeclaration>, + lookup_type: FastHashMap<String, Handle<Type>>, + + global_variables: Vec<(String, GlobalLookup)>, + + entry_args: Vec<EntryArg>, + + layouter: Layouter, + + errors: Vec<Error>, +} + +impl Frontend { + fn reset(&mut self, stage: ShaderStage) { + self.meta.reset(stage); + + self.lookup_function.clear(); + self.lookup_type.clear(); + self.global_variables.clear(); + self.entry_args.clear(); + self.layouter.clear(); + } + + /// Parses a shader either outputting a shader [`Module`] or a list of + /// [`Error`]s. + /// + /// Multiple calls using the same `Frontend` and different shaders are supported. + pub fn parse( + &mut self, + options: &Options, + source: &str, + ) -> std::result::Result<Module, ParseError> { + self.reset(options.stage); + + let lexer = lex::Lexer::new(source, &options.defines); + let mut ctx = ParsingContext::new(lexer); + + match ctx.parse(self) { + Ok(module) => { + if self.errors.is_empty() { + Ok(module) + } else { + Err(std::mem::take(&mut self.errors).into()) + } + } + Err(e) => { + self.errors.push(e); + Err(std::mem::take(&mut self.errors).into()) + } + } + } + + /// Returns additional information about the parsed shader which might not + /// be stored in the [`Module`], see the documentation for + /// [`ShaderMetadata`] for more information about the returned data. + /// + /// # Notes + /// + /// Following an unsuccessful parsing the state of the returned information + /// is undefined, it might contain only partial information about the + /// current shader, the previous shader or both. + pub const fn metadata(&self) -> &ShaderMetadata { + &self.meta + } +} diff --git a/third_party/rust/naga/src/front/glsl/offset.rs b/third_party/rust/naga/src/front/glsl/offset.rs new file mode 100644 index 0000000000..c88c46598d --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/offset.rs @@ -0,0 +1,173 @@ +/*! +Module responsible for calculating the offset and span for types. + +There exists two types of layouts std140 and std430 (there's technically +two more layouts, shared and packed. Shared is not supported by spirv. Packed is +implementation dependent and for now it's just implemented as an alias to +std140). + +The OpenGl spec (the layout rules are defined by the OpenGl spec in section +7.6.2.2 as opposed to the GLSL spec) uses the term basic machine units which are +equivalent to bytes. +*/ + +use super::{ + ast::StructLayout, + error::{Error, ErrorKind}, + Span, +}; +use crate::{proc::Alignment, Handle, Scalar, Type, TypeInner, UniqueArena}; + +/// Struct with information needed for defining a struct member. +/// +/// Returned by [`calculate_offset`]. +#[derive(Debug)] +pub struct TypeAlignSpan { + /// The handle to the type, this might be the same handle passed to + /// [`calculate_offset`] or a new such a new array type with a different + /// stride set. + pub ty: Handle<Type>, + /// The alignment required by the type. + pub align: Alignment, + /// The size of the type. + pub span: u32, +} + +/// Returns the type, alignment and span of a struct member according to a [`StructLayout`]. +/// +/// The functions returns a [`TypeAlignSpan`] which has a `ty` member this +/// should be used as the struct member type because for example arrays may have +/// to change the stride and as such need to have a different type. +pub fn calculate_offset( + mut ty: Handle<Type>, + meta: Span, + layout: StructLayout, + types: &mut UniqueArena<Type>, + errors: &mut Vec<Error>, +) -> TypeAlignSpan { + // When using the std430 storage layout, shader storage blocks will be laid out in buffer storage + // identically to uniform and shader storage blocks using the std140 layout, except + // that the base alignment and stride of arrays of scalars and vectors in rule 4 and of + // structures in rule 9 are not rounded up a multiple of the base alignment of a vec4. + + let (align, span) = match types[ty].inner { + // 1. If the member is a scalar consuming N basic machine units, + // the base alignment is N. + TypeInner::Scalar(Scalar { width, .. }) => (Alignment::from_width(width), width as u32), + // 2. If the member is a two- or four-component vector with components + // consuming N basic machine units, the base alignment is 2N or 4N, respectively. + // 3. If the member is a three-component vector with components consuming N + // basic machine units, the base alignment is 4N. + TypeInner::Vector { + size, + scalar: Scalar { width, .. }, + } => ( + Alignment::from(size) * Alignment::from_width(width), + size as u32 * width as u32, + ), + // 4. If the member is an array of scalars or vectors, the base alignment and array + // stride are set to match the base alignment of a single array element, according + // to rules (1), (2), and (3), and rounded up to the base alignment of a vec4. + // TODO: Matrices array + TypeInner::Array { base, size, .. } => { + let info = calculate_offset(base, meta, layout, types, errors); + + let name = types[ty].name.clone(); + + // See comment at the beginning of the function + let (align, stride) = if StructLayout::Std430 == layout { + (info.align, info.align.round_up(info.span)) + } else { + let align = info.align.max(Alignment::MIN_UNIFORM); + (align, align.round_up(info.span)) + }; + + let span = match size { + crate::ArraySize::Constant(size) => size.get() * stride, + crate::ArraySize::Dynamic => stride, + }; + + let ty_span = types.get_span(ty); + ty = types.insert( + Type { + name, + inner: TypeInner::Array { + base: info.ty, + size, + stride, + }, + }, + ty_span, + ); + + (align, span) + } + // 5. If the member is a column-major matrix with C columns and R rows, the + // matrix is stored identically to an array of C column vectors with R + // components each, according to rule (4) + // TODO: Row major matrices + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + let mut align = Alignment::from(rows) * Alignment::from_width(scalar.width); + + // See comment at the beginning of the function + if StructLayout::Std430 != layout { + align = align.max(Alignment::MIN_UNIFORM); + } + + // See comment on the error kind + if StructLayout::Std140 == layout && rows == crate::VectorSize::Bi { + errors.push(Error { + kind: ErrorKind::UnsupportedMatrixTypeInStd140, + meta, + }); + } + + (align, align * columns as u32) + } + TypeInner::Struct { ref members, .. } => { + let mut span = 0; + let mut align = Alignment::ONE; + let mut members = members.clone(); + let name = types[ty].name.clone(); + + for member in members.iter_mut() { + let info = calculate_offset(member.ty, meta, layout, types, errors); + + let member_alignment = info.align; + span = member_alignment.round_up(span); + align = member_alignment.max(align); + + member.ty = info.ty; + member.offset = span; + + span += info.span; + } + + span = align.round_up(span); + + let ty_span = types.get_span(ty); + ty = types.insert( + Type { + name, + inner: TypeInner::Struct { members, span }, + }, + ty_span, + ); + + (align, span) + } + _ => { + errors.push(Error { + kind: ErrorKind::SemanticError("Invalid struct member type".into()), + meta, + }); + (Alignment::ONE, 0) + } + }; + + TypeAlignSpan { ty, align, span } +} diff --git a/third_party/rust/naga/src/front/glsl/parser.rs b/third_party/rust/naga/src/front/glsl/parser.rs new file mode 100644 index 0000000000..851d2e1d79 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser.rs @@ -0,0 +1,431 @@ +use super::{ + ast::{FunctionKind, Profile, TypeQualifiers}, + context::{Context, ExprPos}, + error::ExpectedToken, + error::{Error, ErrorKind}, + lex::{Lexer, LexerResultKind}, + token::{Directive, DirectiveKind}, + token::{Token, TokenValue}, + variables::{GlobalOrConstant, VarDeclaration}, + Frontend, Result, +}; +use crate::{arena::Handle, proc::U32EvalError, Expression, Module, Span, Type}; +use pp_rs::token::{PreprocessorError, Token as PPToken, TokenValue as PPTokenValue}; +use std::iter::Peekable; + +mod declarations; +mod expressions; +mod functions; +mod types; + +pub struct ParsingContext<'source> { + lexer: Peekable<Lexer<'source>>, + /// Used to store tokens already consumed by the parser but that need to be backtracked + backtracked_token: Option<Token>, + last_meta: Span, +} + +impl<'source> ParsingContext<'source> { + pub fn new(lexer: Lexer<'source>) -> Self { + ParsingContext { + lexer: lexer.peekable(), + backtracked_token: None, + last_meta: Span::default(), + } + } + + /// Helper method for backtracking from a consumed token + /// + /// This method should always be used instead of assigning to `backtracked_token` since + /// it validates that backtracking hasn't occurred more than one time in a row + /// + /// # Panics + /// - If the parser already backtracked without bumping in between + pub fn backtrack(&mut self, token: Token) -> Result<()> { + // This should never happen + if let Some(ref prev_token) = self.backtracked_token { + return Err(Error { + kind: ErrorKind::InternalError("The parser tried to backtrack twice in a row"), + meta: prev_token.meta, + }); + } + + self.backtracked_token = Some(token); + + Ok(()) + } + + pub fn expect_ident(&mut self, frontend: &mut Frontend) -> Result<(String, Span)> { + let token = self.bump(frontend)?; + + match token.value { + TokenValue::Identifier(name) => Ok((name, token.meta)), + _ => Err(Error { + kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]), + meta: token.meta, + }), + } + } + + pub fn expect(&mut self, frontend: &mut Frontend, value: TokenValue) -> Result<Token> { + let token = self.bump(frontend)?; + + if token.value != value { + Err(Error { + kind: ErrorKind::InvalidToken(token.value, vec![value.into()]), + meta: token.meta, + }) + } else { + Ok(token) + } + } + + pub fn next(&mut self, frontend: &mut Frontend) -> Option<Token> { + loop { + if let Some(token) = self.backtracked_token.take() { + self.last_meta = token.meta; + break Some(token); + } + + let res = self.lexer.next()?; + + match res.kind { + LexerResultKind::Token(token) => { + self.last_meta = token.meta; + break Some(token); + } + LexerResultKind::Directive(directive) => { + frontend.handle_directive(directive, res.meta) + } + LexerResultKind::Error(error) => frontend.errors.push(Error { + kind: ErrorKind::PreprocessorError(error), + meta: res.meta, + }), + } + } + } + + pub fn bump(&mut self, frontend: &mut Frontend) -> Result<Token> { + self.next(frontend).ok_or(Error { + kind: ErrorKind::EndOfFile, + meta: self.last_meta, + }) + } + + /// Returns None on the end of the file rather than an error like other methods + pub fn bump_if(&mut self, frontend: &mut Frontend, value: TokenValue) -> Option<Token> { + if self.peek(frontend).filter(|t| t.value == value).is_some() { + self.bump(frontend).ok() + } else { + None + } + } + + pub fn peek(&mut self, frontend: &mut Frontend) -> Option<&Token> { + loop { + if let Some(ref token) = self.backtracked_token { + break Some(token); + } + + match self.lexer.peek()?.kind { + LexerResultKind::Token(_) => { + let res = self.lexer.peek()?; + + match res.kind { + LexerResultKind::Token(ref token) => break Some(token), + _ => unreachable!(), + } + } + LexerResultKind::Error(_) | LexerResultKind::Directive(_) => { + let res = self.lexer.next()?; + + match res.kind { + LexerResultKind::Directive(directive) => { + frontend.handle_directive(directive, res.meta) + } + LexerResultKind::Error(error) => frontend.errors.push(Error { + kind: ErrorKind::PreprocessorError(error), + meta: res.meta, + }), + LexerResultKind::Token(_) => unreachable!(), + } + } + } + } + } + + pub fn expect_peek(&mut self, frontend: &mut Frontend) -> Result<&Token> { + let meta = self.last_meta; + self.peek(frontend).ok_or(Error { + kind: ErrorKind::EndOfFile, + meta, + }) + } + + pub fn parse(&mut self, frontend: &mut Frontend) -> Result<Module> { + let mut module = Module::default(); + + // Body and expression arena for global initialization + let mut ctx = Context::new(frontend, &mut module, false)?; + + while self.peek(frontend).is_some() { + self.parse_external_declaration(frontend, &mut ctx)?; + } + + // Add an `EntryPoint` to `parser.module` for `main`, if a + // suitable overload exists. Error out if we can't find one. + if let Some(declaration) = frontend.lookup_function.get("main") { + for decl in declaration.overloads.iter() { + if let FunctionKind::Call(handle) = decl.kind { + if decl.defined && decl.parameters.is_empty() { + frontend.add_entry_point(handle, ctx)?; + return Ok(module); + } + } + } + } + + Err(Error { + kind: ErrorKind::SemanticError("Missing entry point".into()), + meta: Span::default(), + }) + } + + fn parse_uint_constant( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<(u32, Span)> { + let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?; + + let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr); + + let int = match res { + Ok(value) => Ok(value), + Err(U32EvalError::Negative) => Err(Error { + kind: ErrorKind::SemanticError("int constant overflows".into()), + meta, + }), + Err(U32EvalError::NonConst) => Err(Error { + kind: ErrorKind::SemanticError("Expected a uint constant".into()), + meta, + }), + }?; + + Ok((int, meta)) + } + + fn parse_constant_expression( + &mut self, + frontend: &mut Frontend, + module: &mut Module, + ) -> Result<(Handle<Expression>, Span)> { + let mut ctx = Context::new(frontend, module, true)?; + + let mut stmt_ctx = ctx.stmt_ctx(); + let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?; + let (root, meta) = ctx.lower_expect(stmt_ctx, frontend, expr, ExprPos::Rhs)?; + + Ok((root, meta)) + } +} + +impl Frontend { + fn handle_directive(&mut self, directive: Directive, meta: Span) { + let mut tokens = directive.tokens.into_iter(); + + match directive.kind { + DirectiveKind::Version { is_first_directive } => { + if !is_first_directive { + self.errors.push(Error { + kind: ErrorKind::SemanticError( + "#version must occur first in shader".into(), + ), + meta, + }) + } + + match tokens.next() { + Some(PPToken { + value: PPTokenValue::Integer(int), + location, + }) => match int.value { + 440 | 450 | 460 => self.meta.version = int.value as u16, + _ => self.errors.push(Error { + kind: ErrorKind::InvalidVersion(int.value), + meta: location.into(), + }), + }, + Some(PPToken { value, location }) => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }), + None => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine), + meta, + }), + }; + + match tokens.next() { + Some(PPToken { + value: PPTokenValue::Ident(name), + location, + }) => match name.as_str() { + "core" => self.meta.profile = Profile::Core, + _ => self.errors.push(Error { + kind: ErrorKind::InvalidProfile(name), + meta: location.into(), + }), + }, + Some(PPToken { value, location }) => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }), + None => {} + }; + + if let Some(PPToken { value, location }) = tokens.next() { + self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }) + } + } + DirectiveKind::Extension => { + // TODO: Proper extension handling + // - Checking for extension support in the compiler + // - Handle behaviors such as warn + // - Handle the all extension + let name = match tokens.next() { + Some(PPToken { + value: PPTokenValue::Ident(name), + .. + }) => Some(name), + Some(PPToken { value, location }) => { + self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }); + + None + } + None => { + self.errors.push(Error { + kind: ErrorKind::PreprocessorError( + PreprocessorError::UnexpectedNewLine, + ), + meta, + }); + + None + } + }; + + match tokens.next() { + Some(PPToken { + value: PPTokenValue::Punct(pp_rs::token::Punct::Colon), + .. + }) => {} + Some(PPToken { value, location }) => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }), + None => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine), + meta, + }), + }; + + match tokens.next() { + Some(PPToken { + value: PPTokenValue::Ident(behavior), + location, + }) => match behavior.as_str() { + "require" | "enable" | "warn" | "disable" => { + if let Some(name) = name { + self.meta.extensions.insert(name); + } + } + _ => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + PPTokenValue::Ident(behavior), + )), + meta: location.into(), + }), + }, + Some(PPToken { value, location }) => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }), + None => self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine), + meta, + }), + } + + if let Some(PPToken { value, location }) = tokens.next() { + self.errors.push(Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken( + value, + )), + meta: location.into(), + }) + } + } + DirectiveKind::Pragma => { + // TODO: handle some common pragmas? + } + } + } +} + +pub struct DeclarationContext<'ctx, 'qualifiers, 'a> { + qualifiers: TypeQualifiers<'qualifiers>, + /// Indicates a global declaration + external: bool, + is_inside_loop: bool, + ctx: &'ctx mut Context<'a>, +} + +impl<'ctx, 'qualifiers, 'a> DeclarationContext<'ctx, 'qualifiers, 'a> { + fn add_var( + &mut self, + frontend: &mut Frontend, + ty: Handle<Type>, + name: String, + init: Option<Handle<Expression>>, + meta: Span, + ) -> Result<Handle<Expression>> { + let decl = VarDeclaration { + qualifiers: &mut self.qualifiers, + ty, + name: Some(name), + init, + meta, + }; + + match self.external { + true => { + let global = frontend.add_global_var(self.ctx, decl)?; + let expr = match global { + GlobalOrConstant::Global(handle) => Expression::GlobalVariable(handle), + GlobalOrConstant::Constant(handle) => Expression::Constant(handle), + }; + Ok(self.ctx.add_expression(expr, meta)?) + } + false => frontend.add_local_var(self.ctx, decl), + } + } +} diff --git a/third_party/rust/naga/src/front/glsl/parser/declarations.rs b/third_party/rust/naga/src/front/glsl/parser/declarations.rs new file mode 100644 index 0000000000..f5e38fb016 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser/declarations.rs @@ -0,0 +1,677 @@ +use crate::{ + front::glsl::{ + ast::{ + GlobalLookup, GlobalLookupKind, Precision, QualifierKey, QualifierValue, + StorageQualifier, StructLayout, TypeQualifiers, + }, + context::{Context, ExprPos}, + error::ExpectedToken, + offset, + token::{Token, TokenValue}, + types::scalar_components, + variables::{GlobalOrConstant, VarDeclaration}, + Error, ErrorKind, Frontend, Span, + }, + proc::Alignment, + AddressSpace, Expression, FunctionResult, Handle, Scalar, ScalarKind, Statement, StructMember, + Type, TypeInner, +}; + +use super::{DeclarationContext, ParsingContext, Result}; + +/// Helper method used to retrieve the child type of `ty` at +/// index `i`. +/// +/// # Note +/// +/// Does not check if the index is valid and returns the same type +/// when indexing out-of-bounds a struct or indexing a non indexable +/// type. +fn element_or_member_type( + ty: Handle<Type>, + i: usize, + types: &mut crate::UniqueArena<Type>, +) -> Handle<Type> { + match types[ty].inner { + // The child type of a vector is a scalar of the same kind and width + TypeInner::Vector { scalar, .. } => types.insert( + Type { + name: None, + inner: TypeInner::Scalar(scalar), + }, + Default::default(), + ), + // The child type of a matrix is a vector of floats with the same + // width and the size of the matrix rows. + TypeInner::Matrix { rows, scalar, .. } => types.insert( + Type { + name: None, + inner: TypeInner::Vector { size: rows, scalar }, + }, + Default::default(), + ), + // The child type of an array is the base type of the array + TypeInner::Array { base, .. } => base, + // The child type of a struct at index `i` is the type of it's + // member at that same index. + // + // In case the index is out of bounds the same type is returned + TypeInner::Struct { ref members, .. } => { + members.get(i).map(|member| member.ty).unwrap_or(ty) + } + // The type isn't indexable, the same type is returned + _ => ty, + } +} + +impl<'source> ParsingContext<'source> { + pub fn parse_external_declaration( + &mut self, + frontend: &mut Frontend, + global_ctx: &mut Context, + ) -> Result<()> { + if self + .parse_declaration(frontend, global_ctx, true, false)? + .is_none() + { + let token = self.bump(frontend)?; + match token.value { + TokenValue::Semicolon if frontend.meta.version == 460 => Ok(()), + _ => { + let expected = match frontend.meta.version { + 460 => vec![TokenValue::Semicolon.into(), ExpectedToken::Eof], + _ => vec![ExpectedToken::Eof], + }; + Err(Error { + kind: ErrorKind::InvalidToken(token.value, expected), + meta: token.meta, + }) + } + } + } else { + Ok(()) + } + } + + pub fn parse_initializer( + &mut self, + frontend: &mut Frontend, + ty: Handle<Type>, + ctx: &mut Context, + ) -> Result<(Handle<Expression>, Span)> { + // initializer: + // assignment_expression + // LEFT_BRACE initializer_list RIGHT_BRACE + // LEFT_BRACE initializer_list COMMA RIGHT_BRACE + // + // initializer_list: + // initializer + // initializer_list COMMA initializer + if let Some(Token { mut meta, .. }) = self.bump_if(frontend, TokenValue::LeftBrace) { + // initializer_list + let mut components = Vec::new(); + loop { + // The type expected to be parsed inside the initializer list + let new_ty = element_or_member_type(ty, components.len(), &mut ctx.module.types); + + components.push(self.parse_initializer(frontend, new_ty, ctx)?.0); + + let token = self.bump(frontend)?; + match token.value { + TokenValue::Comma => { + if let Some(Token { meta: end_meta, .. }) = + self.bump_if(frontend, TokenValue::RightBrace) + { + meta.subsume(end_meta); + break; + } + } + TokenValue::RightBrace => { + meta.subsume(token.meta); + break; + } + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![TokenValue::Comma.into(), TokenValue::RightBrace.into()], + ), + meta: token.meta, + }) + } + } + } + + Ok(( + ctx.add_expression(Expression::Compose { ty, components }, meta)?, + meta, + )) + } else { + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_assignment(frontend, ctx, &mut stmt)?; + let (mut init, init_meta) = ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; + + let scalar_components = scalar_components(&ctx.module.types[ty].inner); + if let Some(scalar) = scalar_components { + ctx.implicit_conversion(&mut init, init_meta, scalar)?; + } + + Ok((init, init_meta)) + } + } + + // Note: caller preparsed the type and qualifiers + // Note: caller skips this if the fallthrough token is not expected to be consumed here so this + // produced Error::InvalidToken if it isn't consumed + pub fn parse_init_declarator_list( + &mut self, + frontend: &mut Frontend, + mut ty: Handle<Type>, + ctx: &mut DeclarationContext, + ) -> Result<()> { + // init_declarator_list: + // single_declaration + // init_declarator_list COMMA IDENTIFIER + // init_declarator_list COMMA IDENTIFIER array_specifier + // init_declarator_list COMMA IDENTIFIER array_specifier EQUAL initializer + // init_declarator_list COMMA IDENTIFIER EQUAL initializer + // + // single_declaration: + // fully_specified_type + // fully_specified_type IDENTIFIER + // fully_specified_type IDENTIFIER array_specifier + // fully_specified_type IDENTIFIER array_specifier EQUAL initializer + // fully_specified_type IDENTIFIER EQUAL initializer + + // Consume any leading comma, e.g. this is valid: `float, a=1;` + if self + .peek(frontend) + .map_or(false, |t| t.value == TokenValue::Comma) + { + self.next(frontend); + } + + loop { + let token = self.bump(frontend)?; + let name = match token.value { + TokenValue::Semicolon => break, + TokenValue::Identifier(name) => name, + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()], + ), + meta: token.meta, + }) + } + }; + let mut meta = token.meta; + + // array_specifier + // array_specifier EQUAL initializer + // EQUAL initializer + + // parse an array specifier if it exists + // NOTE: unlike other parse methods this one doesn't expect an array specifier and + // returns Ok(None) rather than an error if there is not one + self.parse_array_specifier(frontend, ctx.ctx, &mut meta, &mut ty)?; + + let is_global_const = + ctx.qualifiers.storage.0 == StorageQualifier::Const && ctx.external; + + let init = self + .bump_if(frontend, TokenValue::Assign) + .map::<Result<_>, _>(|_| { + let prev_const = ctx.ctx.is_const; + ctx.ctx.is_const = is_global_const; + + let (mut expr, init_meta) = self.parse_initializer(frontend, ty, ctx.ctx)?; + + let scalar_components = scalar_components(&ctx.ctx.module.types[ty].inner); + if let Some(scalar) = scalar_components { + ctx.ctx.implicit_conversion(&mut expr, init_meta, scalar)?; + } + + ctx.ctx.is_const = prev_const; + + meta.subsume(init_meta); + + Ok(expr) + }) + .transpose()?; + + let decl_initializer; + let late_initializer; + if is_global_const { + decl_initializer = init; + late_initializer = None; + } else if ctx.external { + decl_initializer = + init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok()); + late_initializer = None; + } else if let Some(init) = init { + if ctx.is_inside_loop || !ctx.ctx.expression_constness.is_const(init) { + decl_initializer = None; + late_initializer = Some(init); + } else { + decl_initializer = Some(init); + late_initializer = None; + } + } else { + decl_initializer = None; + late_initializer = None; + }; + + let pointer = ctx.add_var(frontend, ty, name, decl_initializer, meta)?; + + if let Some(value) = late_initializer { + ctx.ctx.emit_restart(); + ctx.ctx.body.push(Statement::Store { pointer, value }, meta); + } + + let token = self.bump(frontend)?; + match token.value { + TokenValue::Semicolon => break, + TokenValue::Comma => {} + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![TokenValue::Comma.into(), TokenValue::Semicolon.into()], + ), + meta: token.meta, + }) + } + } + } + + Ok(()) + } + + /// `external` whether or not we are in a global or local context + pub fn parse_declaration( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + external: bool, + is_inside_loop: bool, + ) -> Result<Option<Span>> { + //declaration: + // function_prototype SEMICOLON + // + // init_declarator_list SEMICOLON + // PRECISION precision_qualifier type_specifier SEMICOLON + // + // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE SEMICOLON + // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER SEMICOLON + // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER array_specifier SEMICOLON + // type_qualifier SEMICOLON type_qualifier IDENTIFIER SEMICOLON + // type_qualifier IDENTIFIER identifier_list SEMICOLON + + if self.peek_type_qualifier(frontend) || self.peek_type_name(frontend) { + let mut qualifiers = self.parse_type_qualifiers(frontend, ctx)?; + + if self.peek_type_name(frontend) { + // This branch handles variables and function prototypes and if + // external is true also function definitions + let (ty, mut meta) = self.parse_type(frontend, ctx)?; + + let token = self.bump(frontend)?; + let token_fallthrough = match token.value { + TokenValue::Identifier(name) => match self.expect_peek(frontend)?.value { + TokenValue::LeftParen => { + // This branch handles function definition and prototypes + self.bump(frontend)?; + + let result = ty.map(|ty| FunctionResult { ty, binding: None }); + + let mut context = Context::new(frontend, ctx.module, false)?; + + self.parse_function_args(frontend, &mut context)?; + + let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta; + meta.subsume(end_meta); + + let token = self.bump(frontend)?; + return match token.value { + TokenValue::Semicolon => { + // This branch handles function prototypes + frontend.add_prototype(context, name, result, meta); + + Ok(Some(meta)) + } + TokenValue::LeftBrace if external => { + // This branch handles function definitions + // as you can see by the guard this branch + // only happens if external is also true + + // parse the body + self.parse_compound_statement( + token.meta, + frontend, + &mut context, + &mut None, + false, + )?; + + frontend.add_function(context, name, result, meta); + + Ok(Some(meta)) + } + _ if external => Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ + TokenValue::LeftBrace.into(), + TokenValue::Semicolon.into(), + ], + ), + meta: token.meta, + }), + _ => Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![TokenValue::Semicolon.into()], + ), + meta: token.meta, + }), + }; + } + // Pass the token to the init_declarator_list parser + _ => Token { + value: TokenValue::Identifier(name), + meta: token.meta, + }, + }, + // Pass the token to the init_declarator_list parser + _ => token, + }; + + // If program execution has reached here then this will be a + // init_declarator_list + // token_fallthrough will have a token that was already bumped + if let Some(ty) = ty { + let mut ctx = DeclarationContext { + qualifiers, + external, + is_inside_loop, + ctx, + }; + + self.backtrack(token_fallthrough)?; + self.parse_init_declarator_list(frontend, ty, &mut ctx)?; + } else { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError("Declaration cannot have void type".into()), + meta, + }) + } + + Ok(Some(meta)) + } else { + // This branch handles struct definitions and modifiers like + // ```glsl + // layout(early_fragment_tests); + // ``` + let token = self.bump(frontend)?; + match token.value { + TokenValue::Identifier(ty_name) => { + if self.bump_if(frontend, TokenValue::LeftBrace).is_some() { + self.parse_block_declaration( + frontend, + ctx, + &mut qualifiers, + ty_name, + token.meta, + ) + .map(Some) + } else { + if qualifiers.invariant.take().is_some() { + frontend.make_variable_invariant(ctx, &ty_name, token.meta)?; + + qualifiers.unused_errors(&mut frontend.errors); + self.expect(frontend, TokenValue::Semicolon)?; + return Ok(Some(qualifiers.span)); + } + + //TODO: declaration + // type_qualifier IDENTIFIER SEMICOLON + // type_qualifier IDENTIFIER identifier_list SEMICOLON + Err(Error { + kind: ErrorKind::NotImplemented("variable qualifier"), + meta: token.meta, + }) + } + } + TokenValue::Semicolon => { + if let Some(value) = + qualifiers.uint_layout_qualifier("local_size_x", &mut frontend.errors) + { + frontend.meta.workgroup_size[0] = value; + } + if let Some(value) = + qualifiers.uint_layout_qualifier("local_size_y", &mut frontend.errors) + { + frontend.meta.workgroup_size[1] = value; + } + if let Some(value) = + qualifiers.uint_layout_qualifier("local_size_z", &mut frontend.errors) + { + frontend.meta.workgroup_size[2] = value; + } + + frontend.meta.early_fragment_tests |= qualifiers + .none_layout_qualifier("early_fragment_tests", &mut frontend.errors); + + qualifiers.unused_errors(&mut frontend.errors); + + Ok(Some(qualifiers.span)) + } + _ => Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()], + ), + meta: token.meta, + }), + } + } + } else { + match self.peek(frontend).map(|t| &t.value) { + Some(&TokenValue::Precision) => { + // PRECISION precision_qualifier type_specifier SEMICOLON + self.bump(frontend)?; + + let token = self.bump(frontend)?; + let _ = match token.value { + TokenValue::PrecisionQualifier(p) => p, + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ + TokenValue::PrecisionQualifier(Precision::High).into(), + TokenValue::PrecisionQualifier(Precision::Medium).into(), + TokenValue::PrecisionQualifier(Precision::Low).into(), + ], + ), + meta: token.meta, + }) + } + }; + + let (ty, meta) = self.parse_type_non_void(frontend, ctx)?; + + match ctx.module.types[ty].inner { + TypeInner::Scalar(Scalar { + kind: ScalarKind::Float | ScalarKind::Sint, + .. + }) => {} + _ => frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Precision statement can only work on floats and ints".into(), + ), + meta, + }), + } + + self.expect(frontend, TokenValue::Semicolon)?; + + Ok(Some(meta)) + } + _ => Ok(None), + } + } + } + + pub fn parse_block_declaration( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + qualifiers: &mut TypeQualifiers, + ty_name: String, + mut meta: Span, + ) -> Result<Span> { + let layout = match qualifiers.layout_qualifiers.remove(&QualifierKey::Layout) { + Some((QualifierValue::Layout(l), _)) => l, + None => { + if let StorageQualifier::AddressSpace(AddressSpace::Storage { .. }) = + qualifiers.storage.0 + { + StructLayout::Std430 + } else { + StructLayout::Std140 + } + } + _ => unreachable!(), + }; + + let mut members = Vec::new(); + let span = self.parse_struct_declaration_list(frontend, ctx, &mut members, layout)?; + self.expect(frontend, TokenValue::RightBrace)?; + + let mut ty = ctx.module.types.insert( + Type { + name: Some(ty_name), + inner: TypeInner::Struct { + members: members.clone(), + span, + }, + }, + Default::default(), + ); + + let token = self.bump(frontend)?; + let name = match token.value { + TokenValue::Semicolon => None, + TokenValue::Identifier(name) => { + self.parse_array_specifier(frontend, ctx, &mut meta, &mut ty)?; + + self.expect(frontend, TokenValue::Semicolon)?; + + Some(name) + } + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()], + ), + meta: token.meta, + }) + } + }; + + let global = frontend.add_global_var( + ctx, + VarDeclaration { + qualifiers, + ty, + name, + init: None, + meta, + }, + )?; + + for (i, k, ty) in members.into_iter().enumerate().filter_map(|(i, m)| { + let ty = m.ty; + m.name.map(|s| (i as u32, s, ty)) + }) { + let lookup = GlobalLookup { + kind: match global { + GlobalOrConstant::Global(handle) => GlobalLookupKind::BlockSelect(handle, i), + GlobalOrConstant::Constant(handle) => GlobalLookupKind::Constant(handle, ty), + }, + entry_arg: None, + mutable: true, + }; + ctx.add_global(&k, lookup)?; + + frontend.global_variables.push((k, lookup)); + } + + Ok(meta) + } + + // TODO: Accept layout arguments + pub fn parse_struct_declaration_list( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + members: &mut Vec<StructMember>, + layout: StructLayout, + ) -> Result<u32> { + let mut span = 0; + let mut align = Alignment::ONE; + + loop { + // TODO: type_qualifier + + let (base_ty, mut meta) = self.parse_type_non_void(frontend, ctx)?; + + loop { + let (name, name_meta) = self.expect_ident(frontend)?; + let mut ty = base_ty; + self.parse_array_specifier(frontend, ctx, &mut meta, &mut ty)?; + + meta.subsume(name_meta); + + let info = offset::calculate_offset( + ty, + meta, + layout, + &mut ctx.module.types, + &mut frontend.errors, + ); + + let member_alignment = info.align; + span = member_alignment.round_up(span); + align = member_alignment.max(align); + + members.push(StructMember { + name: Some(name), + ty: info.ty, + binding: None, + offset: span, + }); + + span += info.span; + + if self.bump_if(frontend, TokenValue::Comma).is_none() { + break; + } + } + + self.expect(frontend, TokenValue::Semicolon)?; + + if let TokenValue::RightBrace = self.expect_peek(frontend)?.value { + break; + } + } + + span = align.round_up(span); + + Ok(span) + } +} diff --git a/third_party/rust/naga/src/front/glsl/parser/expressions.rs b/third_party/rust/naga/src/front/glsl/parser/expressions.rs new file mode 100644 index 0000000000..1b8febce90 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser/expressions.rs @@ -0,0 +1,542 @@ +use std::num::NonZeroU32; + +use crate::{ + front::glsl::{ + ast::{FunctionCall, FunctionCallKind, HirExpr, HirExprKind}, + context::{Context, StmtContext}, + error::{ErrorKind, ExpectedToken}, + parser::ParsingContext, + token::{Token, TokenValue}, + Error, Frontend, Result, Span, + }, + ArraySize, BinaryOperator, Handle, Literal, Type, TypeInner, UnaryOperator, +}; + +impl<'source> ParsingContext<'source> { + pub fn parse_primary( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + ) -> Result<Handle<HirExpr>> { + let mut token = self.bump(frontend)?; + + let literal = match token.value { + TokenValue::IntConstant(int) => { + if int.width != 32 { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError("Unsupported non-32bit integer".into()), + meta: token.meta, + }); + } + if int.signed { + Literal::I32(int.value as i32) + } else { + Literal::U32(int.value as u32) + } + } + TokenValue::FloatConstant(float) => { + if float.width != 32 { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError("Unsupported floating-point value (expected single-precision floating-point number)".into()), + meta: token.meta, + }); + } + Literal::F32(float.value) + } + TokenValue::BoolConstant(value) => Literal::Bool(value), + TokenValue::LeftParen => { + let expr = self.parse_expression(frontend, ctx, stmt)?; + let meta = self.expect(frontend, TokenValue::RightParen)?.meta; + + token.meta.subsume(meta); + + return Ok(expr); + } + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ + TokenValue::LeftParen.into(), + ExpectedToken::IntLiteral, + ExpectedToken::FloatLiteral, + ExpectedToken::BoolLiteral, + ], + ), + meta: token.meta, + }); + } + }; + + Ok(stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Literal(literal), + meta: token.meta, + }, + Default::default(), + )) + } + + pub fn parse_function_call_args( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + meta: &mut Span, + ) -> Result<Vec<Handle<HirExpr>>> { + let mut args = Vec::new(); + if let Some(token) = self.bump_if(frontend, TokenValue::RightParen) { + meta.subsume(token.meta); + } else { + loop { + args.push(self.parse_assignment(frontend, ctx, stmt)?); + + let token = self.bump(frontend)?; + match token.value { + TokenValue::Comma => {} + TokenValue::RightParen => { + meta.subsume(token.meta); + break; + } + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![TokenValue::Comma.into(), TokenValue::RightParen.into()], + ), + meta: token.meta, + }); + } + } + } + } + + Ok(args) + } + + pub fn parse_postfix( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + ) -> Result<Handle<HirExpr>> { + let mut base = if self.peek_type_name(frontend) { + let (mut handle, mut meta) = self.parse_type_non_void(frontend, ctx)?; + + self.expect(frontend, TokenValue::LeftParen)?; + let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?; + + if let TypeInner::Array { + size: ArraySize::Dynamic, + stride, + base, + } = ctx.module.types[handle].inner + { + let span = ctx.module.types.get_span(handle); + + let size = u32::try_from(args.len()) + .ok() + .and_then(NonZeroU32::new) + .ok_or(Error { + kind: ErrorKind::SemanticError( + "There must be at least one argument".into(), + ), + meta, + })?; + + handle = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Array { + stride, + base, + size: ArraySize::Constant(size), + }, + }, + span, + ) + } + + stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Call(FunctionCall { + kind: FunctionCallKind::TypeConstructor(handle), + args, + }), + meta, + }, + Default::default(), + ) + } else if let TokenValue::Identifier(_) = self.expect_peek(frontend)?.value { + let (name, mut meta) = self.expect_ident(frontend)?; + + let expr = if self.bump_if(frontend, TokenValue::LeftParen).is_some() { + let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?; + + let kind = match frontend.lookup_type.get(&name) { + Some(ty) => FunctionCallKind::TypeConstructor(*ty), + None => FunctionCallKind::Function(name), + }; + + HirExpr { + kind: HirExprKind::Call(FunctionCall { kind, args }), + meta, + } + } else { + let var = match frontend.lookup_variable(ctx, &name, meta)? { + Some(var) => var, + None => { + return Err(Error { + kind: ErrorKind::UnknownVariable(name), + meta, + }) + } + }; + + HirExpr { + kind: HirExprKind::Variable(var), + meta, + } + }; + + stmt.hir_exprs.append(expr, Default::default()) + } else { + self.parse_primary(frontend, ctx, stmt)? + }; + + while let TokenValue::LeftBracket + | TokenValue::Dot + | TokenValue::Increment + | TokenValue::Decrement = self.expect_peek(frontend)?.value + { + let Token { value, mut meta } = self.bump(frontend)?; + + match value { + TokenValue::LeftBracket => { + let index = self.parse_expression(frontend, ctx, stmt)?; + let end_meta = self.expect(frontend, TokenValue::RightBracket)?.meta; + + meta.subsume(end_meta); + base = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Access { base, index }, + meta, + }, + Default::default(), + ) + } + TokenValue::Dot => { + let (field, end_meta) = self.expect_ident(frontend)?; + + if self.bump_if(frontend, TokenValue::LeftParen).is_some() { + let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?; + + base = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Method { + expr: base, + name: field, + args, + }, + meta, + }, + Default::default(), + ); + continue; + } + + meta.subsume(end_meta); + base = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Select { base, field }, + meta, + }, + Default::default(), + ) + } + TokenValue::Increment | TokenValue::Decrement => { + base = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::PrePostfix { + op: match value { + TokenValue::Increment => crate::BinaryOperator::Add, + _ => crate::BinaryOperator::Subtract, + }, + postfix: true, + expr: base, + }, + meta, + }, + Default::default(), + ) + } + _ => unreachable!(), + } + } + + Ok(base) + } + + pub fn parse_unary( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + ) -> Result<Handle<HirExpr>> { + Ok(match self.expect_peek(frontend)?.value { + TokenValue::Plus | TokenValue::Dash | TokenValue::Bang | TokenValue::Tilde => { + let Token { value, mut meta } = self.bump(frontend)?; + + let expr = self.parse_unary(frontend, ctx, stmt)?; + let end_meta = stmt.hir_exprs[expr].meta; + + let kind = match value { + TokenValue::Dash => HirExprKind::Unary { + op: UnaryOperator::Negate, + expr, + }, + TokenValue::Bang => HirExprKind::Unary { + op: UnaryOperator::LogicalNot, + expr, + }, + TokenValue::Tilde => HirExprKind::Unary { + op: UnaryOperator::BitwiseNot, + expr, + }, + _ => return Ok(expr), + }; + + meta.subsume(end_meta); + stmt.hir_exprs + .append(HirExpr { kind, meta }, Default::default()) + } + TokenValue::Increment | TokenValue::Decrement => { + let Token { value, meta } = self.bump(frontend)?; + + let expr = self.parse_unary(frontend, ctx, stmt)?; + + stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::PrePostfix { + op: match value { + TokenValue::Increment => crate::BinaryOperator::Add, + _ => crate::BinaryOperator::Subtract, + }, + postfix: false, + expr, + }, + meta, + }, + Default::default(), + ) + } + _ => self.parse_postfix(frontend, ctx, stmt)?, + }) + } + + pub fn parse_binary( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + passthrough: Option<Handle<HirExpr>>, + min_bp: u8, + ) -> Result<Handle<HirExpr>> { + let mut left = passthrough + .ok_or(ErrorKind::EndOfFile /* Dummy error */) + .or_else(|_| self.parse_unary(frontend, ctx, stmt))?; + let mut meta = stmt.hir_exprs[left].meta; + + while let Some((l_bp, r_bp)) = binding_power(&self.expect_peek(frontend)?.value) { + if l_bp < min_bp { + break; + } + + let Token { value, .. } = self.bump(frontend)?; + + let right = self.parse_binary(frontend, ctx, stmt, None, r_bp)?; + let end_meta = stmt.hir_exprs[right].meta; + + meta.subsume(end_meta); + left = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Binary { + left, + op: match value { + TokenValue::LogicalOr => BinaryOperator::LogicalOr, + TokenValue::LogicalXor => BinaryOperator::NotEqual, + TokenValue::LogicalAnd => BinaryOperator::LogicalAnd, + TokenValue::VerticalBar => BinaryOperator::InclusiveOr, + TokenValue::Caret => BinaryOperator::ExclusiveOr, + TokenValue::Ampersand => BinaryOperator::And, + TokenValue::Equal => BinaryOperator::Equal, + TokenValue::NotEqual => BinaryOperator::NotEqual, + TokenValue::GreaterEqual => BinaryOperator::GreaterEqual, + TokenValue::LessEqual => BinaryOperator::LessEqual, + TokenValue::LeftAngle => BinaryOperator::Less, + TokenValue::RightAngle => BinaryOperator::Greater, + TokenValue::LeftShift => BinaryOperator::ShiftLeft, + TokenValue::RightShift => BinaryOperator::ShiftRight, + TokenValue::Plus => BinaryOperator::Add, + TokenValue::Dash => BinaryOperator::Subtract, + TokenValue::Star => BinaryOperator::Multiply, + TokenValue::Slash => BinaryOperator::Divide, + TokenValue::Percent => BinaryOperator::Modulo, + _ => unreachable!(), + }, + right, + }, + meta, + }, + Default::default(), + ) + } + + Ok(left) + } + + pub fn parse_conditional( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + passthrough: Option<Handle<HirExpr>>, + ) -> Result<Handle<HirExpr>> { + let mut condition = self.parse_binary(frontend, ctx, stmt, passthrough, 0)?; + let mut meta = stmt.hir_exprs[condition].meta; + + if self.bump_if(frontend, TokenValue::Question).is_some() { + let accept = self.parse_expression(frontend, ctx, stmt)?; + self.expect(frontend, TokenValue::Colon)?; + let reject = self.parse_assignment(frontend, ctx, stmt)?; + let end_meta = stmt.hir_exprs[reject].meta; + + meta.subsume(end_meta); + condition = stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Conditional { + condition, + accept, + reject, + }, + meta, + }, + Default::default(), + ) + } + + Ok(condition) + } + + pub fn parse_assignment( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + ) -> Result<Handle<HirExpr>> { + let tgt = self.parse_unary(frontend, ctx, stmt)?; + let mut meta = stmt.hir_exprs[tgt].meta; + + Ok(match self.expect_peek(frontend)?.value { + TokenValue::Assign => { + self.bump(frontend)?; + let value = self.parse_assignment(frontend, ctx, stmt)?; + let end_meta = stmt.hir_exprs[value].meta; + + meta.subsume(end_meta); + stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Assign { tgt, value }, + meta, + }, + Default::default(), + ) + } + TokenValue::OrAssign + | TokenValue::AndAssign + | TokenValue::AddAssign + | TokenValue::DivAssign + | TokenValue::ModAssign + | TokenValue::SubAssign + | TokenValue::MulAssign + | TokenValue::LeftShiftAssign + | TokenValue::RightShiftAssign + | TokenValue::XorAssign => { + let token = self.bump(frontend)?; + let right = self.parse_assignment(frontend, ctx, stmt)?; + let end_meta = stmt.hir_exprs[right].meta; + + meta.subsume(end_meta); + let value = stmt.hir_exprs.append( + HirExpr { + meta, + kind: HirExprKind::Binary { + left: tgt, + op: match token.value { + TokenValue::OrAssign => BinaryOperator::InclusiveOr, + TokenValue::AndAssign => BinaryOperator::And, + TokenValue::AddAssign => BinaryOperator::Add, + TokenValue::DivAssign => BinaryOperator::Divide, + TokenValue::ModAssign => BinaryOperator::Modulo, + TokenValue::SubAssign => BinaryOperator::Subtract, + TokenValue::MulAssign => BinaryOperator::Multiply, + TokenValue::LeftShiftAssign => BinaryOperator::ShiftLeft, + TokenValue::RightShiftAssign => BinaryOperator::ShiftRight, + TokenValue::XorAssign => BinaryOperator::ExclusiveOr, + _ => unreachable!(), + }, + right, + }, + }, + Default::default(), + ); + + stmt.hir_exprs.append( + HirExpr { + kind: HirExprKind::Assign { tgt, value }, + meta, + }, + Default::default(), + ) + } + _ => self.parse_conditional(frontend, ctx, stmt, Some(tgt))?, + }) + } + + pub fn parse_expression( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + stmt: &mut StmtContext, + ) -> Result<Handle<HirExpr>> { + let mut expr = self.parse_assignment(frontend, ctx, stmt)?; + + while let TokenValue::Comma = self.expect_peek(frontend)?.value { + self.bump(frontend)?; + expr = self.parse_assignment(frontend, ctx, stmt)?; + } + + Ok(expr) + } +} + +const fn binding_power(value: &TokenValue) -> Option<(u8, u8)> { + Some(match *value { + TokenValue::LogicalOr => (1, 2), + TokenValue::LogicalXor => (3, 4), + TokenValue::LogicalAnd => (5, 6), + TokenValue::VerticalBar => (7, 8), + TokenValue::Caret => (9, 10), + TokenValue::Ampersand => (11, 12), + TokenValue::Equal | TokenValue::NotEqual => (13, 14), + TokenValue::GreaterEqual + | TokenValue::LessEqual + | TokenValue::LeftAngle + | TokenValue::RightAngle => (15, 16), + TokenValue::LeftShift | TokenValue::RightShift => (17, 18), + TokenValue::Plus | TokenValue::Dash => (19, 20), + TokenValue::Star | TokenValue::Slash | TokenValue::Percent => (21, 22), + _ => return None, + }) +} diff --git a/third_party/rust/naga/src/front/glsl/parser/functions.rs b/third_party/rust/naga/src/front/glsl/parser/functions.rs new file mode 100644 index 0000000000..38184eedf7 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser/functions.rs @@ -0,0 +1,656 @@ +use crate::front::glsl::context::ExprPos; +use crate::front::glsl::Span; +use crate::Literal; +use crate::{ + front::glsl::{ + ast::ParameterQualifier, + context::Context, + parser::ParsingContext, + token::{Token, TokenValue}, + variables::VarDeclaration, + Error, ErrorKind, Frontend, Result, + }, + Block, Expression, Statement, SwitchCase, UnaryOperator, +}; + +impl<'source> ParsingContext<'source> { + pub fn peek_parameter_qualifier(&mut self, frontend: &mut Frontend) -> bool { + self.peek(frontend).map_or(false, |t| match t.value { + TokenValue::In | TokenValue::Out | TokenValue::InOut | TokenValue::Const => true, + _ => false, + }) + } + + /// Returns the parsed `ParameterQualifier` or `ParameterQualifier::In` + pub fn parse_parameter_qualifier(&mut self, frontend: &mut Frontend) -> ParameterQualifier { + if self.peek_parameter_qualifier(frontend) { + match self.bump(frontend).unwrap().value { + TokenValue::In => ParameterQualifier::In, + TokenValue::Out => ParameterQualifier::Out, + TokenValue::InOut => ParameterQualifier::InOut, + TokenValue::Const => ParameterQualifier::Const, + _ => unreachable!(), + } + } else { + ParameterQualifier::In + } + } + + pub fn parse_statement( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + terminator: &mut Option<usize>, + is_inside_loop: bool, + ) -> Result<Option<Span>> { + // Type qualifiers always identify a declaration statement + if self.peek_type_qualifier(frontend) { + return self.parse_declaration(frontend, ctx, false, is_inside_loop); + } + + // Type names can identify either declaration statements or type constructors + // depending on whether the token following the type name is a `(` (LeftParen) + if self.peek_type_name(frontend) { + // Start by consuming the type name so that we can peek the token after it + let token = self.bump(frontend)?; + // Peek the next token and check if it's a `(` (LeftParen) if so the statement + // is a constructor, otherwise it's a declaration. We need to do the check + // beforehand and not in the if since we will backtrack before the if + let declaration = TokenValue::LeftParen != self.expect_peek(frontend)?.value; + + self.backtrack(token)?; + + if declaration { + return self.parse_declaration(frontend, ctx, false, is_inside_loop); + } + } + + let new_break = || { + let mut block = Block::new(); + block.push(Statement::Break, crate::Span::default()); + block + }; + + let &Token { + ref value, + mut meta, + } = self.expect_peek(frontend)?; + + let meta_rest = match *value { + TokenValue::Continue => { + let meta = self.bump(frontend)?.meta; + ctx.body.push(Statement::Continue, meta); + terminator.get_or_insert(ctx.body.len()); + self.expect(frontend, TokenValue::Semicolon)?.meta + } + TokenValue::Break => { + let meta = self.bump(frontend)?.meta; + ctx.body.push(Statement::Break, meta); + terminator.get_or_insert(ctx.body.len()); + self.expect(frontend, TokenValue::Semicolon)?.meta + } + TokenValue::Return => { + self.bump(frontend)?; + let (value, meta) = match self.expect_peek(frontend)?.value { + TokenValue::Semicolon => (None, self.bump(frontend)?.meta), + _ => { + // TODO: Implicit conversions + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + self.expect(frontend, TokenValue::Semicolon)?; + let (handle, meta) = + ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; + (Some(handle), meta) + } + }; + + ctx.emit_restart(); + + ctx.body.push(Statement::Return { value }, meta); + terminator.get_or_insert(ctx.body.len()); + + meta + } + TokenValue::Discard => { + let meta = self.bump(frontend)?.meta; + ctx.body.push(Statement::Kill, meta); + terminator.get_or_insert(ctx.body.len()); + + self.expect(frontend, TokenValue::Semicolon)?.meta + } + TokenValue::If => { + let mut meta = self.bump(frontend)?.meta; + + self.expect(frontend, TokenValue::LeftParen)?; + let condition = { + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + let (handle, more_meta) = + ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; + meta.subsume(more_meta); + handle + }; + self.expect(frontend, TokenValue::RightParen)?; + + let accept = ctx.new_body(|ctx| { + if let Some(more_meta) = + self.parse_statement(frontend, ctx, &mut None, is_inside_loop)? + { + meta.subsume(more_meta); + } + Ok(()) + })?; + + let reject = ctx.new_body(|ctx| { + if self.bump_if(frontend, TokenValue::Else).is_some() { + if let Some(more_meta) = + self.parse_statement(frontend, ctx, &mut None, is_inside_loop)? + { + meta.subsume(more_meta); + } + } + Ok(()) + })?; + + ctx.body.push( + Statement::If { + condition, + accept, + reject, + }, + meta, + ); + + meta + } + TokenValue::Switch => { + let mut meta = self.bump(frontend)?.meta; + let end_meta; + + self.expect(frontend, TokenValue::LeftParen)?; + + let (selector, uint) = { + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + let (root, meta) = ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?; + let uint = ctx.resolve_type(root, meta)?.scalar_kind() + == Some(crate::ScalarKind::Uint); + (root, uint) + }; + + self.expect(frontend, TokenValue::RightParen)?; + + ctx.emit_restart(); + + let mut cases = Vec::new(); + // Track if any default case is present in the switch statement. + let mut default_present = false; + + self.expect(frontend, TokenValue::LeftBrace)?; + loop { + let value = match self.expect_peek(frontend)?.value { + TokenValue::Case => { + self.bump(frontend)?; + + let (const_expr, meta) = + self.parse_constant_expression(frontend, ctx.module)?; + + match ctx.module.const_expressions[const_expr] { + Expression::Literal(Literal::I32(value)) => match uint { + // This unchecked cast isn't good, but since + // we only reach this code when the selector + // is unsigned but the case label is signed, + // verification will reject the module + // anyway (which also matches GLSL's rules). + true => crate::SwitchValue::U32(value as u32), + false => crate::SwitchValue::I32(value), + }, + Expression::Literal(Literal::U32(value)) => { + crate::SwitchValue::U32(value) + } + _ => { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Case values can only be integers".into(), + ), + meta, + }); + + crate::SwitchValue::I32(0) + } + } + } + TokenValue::Default => { + self.bump(frontend)?; + default_present = true; + crate::SwitchValue::Default + } + TokenValue::RightBrace => { + end_meta = self.bump(frontend)?.meta; + break; + } + _ => { + let Token { value, meta } = self.bump(frontend)?; + return Err(Error { + kind: ErrorKind::InvalidToken( + value, + vec![ + TokenValue::Case.into(), + TokenValue::Default.into(), + TokenValue::RightBrace.into(), + ], + ), + meta, + }); + } + }; + + self.expect(frontend, TokenValue::Colon)?; + + let mut fall_through = true; + + let body = ctx.new_body(|ctx| { + let mut case_terminator = None; + loop { + match self.expect_peek(frontend)?.value { + TokenValue::Case | TokenValue::Default | TokenValue::RightBrace => { + break + } + _ => { + self.parse_statement( + frontend, + ctx, + &mut case_terminator, + is_inside_loop, + )?; + } + } + } + + if let Some(mut idx) = case_terminator { + if let Statement::Break = ctx.body[idx - 1] { + fall_through = false; + idx -= 1; + } + + ctx.body.cull(idx..) + } + + Ok(()) + })?; + + cases.push(SwitchCase { + value, + body, + fall_through, + }) + } + + meta.subsume(end_meta); + + // NOTE: do not unwrap here since a switch statement isn't required + // to have any cases. + if let Some(case) = cases.last_mut() { + // GLSL requires that the last case not be empty, so we check + // that here and produce an error otherwise (fall_through must + // also be checked because `break`s count as statements but + // they aren't added to the body) + if case.body.is_empty() && case.fall_through { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "last case/default label must be followed by statements".into(), + ), + meta, + }) + } + + // GLSL allows the last case to not have any `break` statement, + // this would mark it as fall through but naga's IR requires that + // the last case must not be fall through, so we mark need to mark + // the last case as not fall through always. + case.fall_through = false; + } + + // Add an empty default case in case non was present, this is needed because + // naga's IR requires that all switch statements must have a default case but + // GLSL doesn't require that, so we might need to add an empty default case. + if !default_present { + cases.push(SwitchCase { + value: crate::SwitchValue::Default, + body: Block::new(), + fall_through: false, + }) + } + + ctx.body.push(Statement::Switch { selector, cases }, meta); + + meta + } + TokenValue::While => { + let mut meta = self.bump(frontend)?.meta; + + let loop_body = ctx.new_body(|ctx| { + let mut stmt = ctx.stmt_ctx(); + self.expect(frontend, TokenValue::LeftParen)?; + let root = self.parse_expression(frontend, ctx, &mut stmt)?; + meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta); + + let (expr, expr_meta) = ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?; + let condition = ctx.add_expression( + Expression::Unary { + op: UnaryOperator::LogicalNot, + expr, + }, + expr_meta, + )?; + + ctx.emit_restart(); + + ctx.body.push( + Statement::If { + condition, + accept: new_break(), + reject: Block::new(), + }, + crate::Span::default(), + ); + + meta.subsume(expr_meta); + + if let Some(body_meta) = self.parse_statement(frontend, ctx, &mut None, true)? { + meta.subsume(body_meta); + } + Ok(()) + })?; + + ctx.body.push( + Statement::Loop { + body: loop_body, + continuing: Block::new(), + break_if: None, + }, + meta, + ); + + meta + } + TokenValue::Do => { + let mut meta = self.bump(frontend)?.meta; + + let loop_body = ctx.new_body(|ctx| { + let mut terminator = None; + self.parse_statement(frontend, ctx, &mut terminator, true)?; + + let mut stmt = ctx.stmt_ctx(); + + self.expect(frontend, TokenValue::While)?; + self.expect(frontend, TokenValue::LeftParen)?; + let root = self.parse_expression(frontend, ctx, &mut stmt)?; + let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta; + + meta.subsume(end_meta); + + let (expr, expr_meta) = ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?; + let condition = ctx.add_expression( + Expression::Unary { + op: UnaryOperator::LogicalNot, + expr, + }, + expr_meta, + )?; + + ctx.emit_restart(); + + ctx.body.push( + Statement::If { + condition, + accept: new_break(), + reject: Block::new(), + }, + crate::Span::default(), + ); + + if let Some(idx) = terminator { + ctx.body.cull(idx..) + } + Ok(()) + })?; + + ctx.body.push( + Statement::Loop { + body: loop_body, + continuing: Block::new(), + break_if: None, + }, + meta, + ); + + meta + } + TokenValue::For => { + let mut meta = self.bump(frontend)?.meta; + + ctx.symbol_table.push_scope(); + self.expect(frontend, TokenValue::LeftParen)?; + + if self.bump_if(frontend, TokenValue::Semicolon).is_none() { + if self.peek_type_name(frontend) || self.peek_type_qualifier(frontend) { + self.parse_declaration(frontend, ctx, false, false)?; + } else { + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower(stmt, frontend, expr, ExprPos::Rhs)?; + self.expect(frontend, TokenValue::Semicolon)?; + } + } + + let loop_body = ctx.new_body(|ctx| { + if self.bump_if(frontend, TokenValue::Semicolon).is_none() { + let (expr, expr_meta) = if self.peek_type_name(frontend) + || self.peek_type_qualifier(frontend) + { + let mut qualifiers = self.parse_type_qualifiers(frontend, ctx)?; + let (ty, mut meta) = self.parse_type_non_void(frontend, ctx)?; + let name = self.expect_ident(frontend)?.0; + + self.expect(frontend, TokenValue::Assign)?; + + let (value, end_meta) = self.parse_initializer(frontend, ty, ctx)?; + meta.subsume(end_meta); + + let decl = VarDeclaration { + qualifiers: &mut qualifiers, + ty, + name: Some(name), + init: None, + meta, + }; + + let pointer = frontend.add_local_var(ctx, decl)?; + + ctx.emit_restart(); + + ctx.body.push(Statement::Store { pointer, value }, meta); + + (value, end_meta) + } else { + let mut stmt = ctx.stmt_ctx(); + let root = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)? + }; + + let condition = ctx.add_expression( + Expression::Unary { + op: UnaryOperator::LogicalNot, + expr, + }, + expr_meta, + )?; + + ctx.emit_restart(); + + ctx.body.push( + Statement::If { + condition, + accept: new_break(), + reject: Block::new(), + }, + crate::Span::default(), + ); + + self.expect(frontend, TokenValue::Semicolon)?; + } + Ok(()) + })?; + + let continuing = ctx.new_body(|ctx| { + match self.expect_peek(frontend)?.value { + TokenValue::RightParen => {} + _ => { + let mut stmt = ctx.stmt_ctx(); + let rest = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower(stmt, frontend, rest, ExprPos::Rhs)?; + } + } + Ok(()) + })?; + + meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta); + + let loop_body = ctx.with_body(loop_body, |ctx| { + if let Some(stmt_meta) = self.parse_statement(frontend, ctx, &mut None, true)? { + meta.subsume(stmt_meta); + } + Ok(()) + })?; + + ctx.body.push( + Statement::Loop { + body: loop_body, + continuing, + break_if: None, + }, + meta, + ); + + ctx.symbol_table.pop_scope(); + + meta + } + TokenValue::LeftBrace => { + let mut meta = self.bump(frontend)?.meta; + + let mut block_terminator = None; + + let block = ctx.new_body(|ctx| { + let block_meta = self.parse_compound_statement( + meta, + frontend, + ctx, + &mut block_terminator, + is_inside_loop, + )?; + meta.subsume(block_meta); + Ok(()) + })?; + + ctx.body.push(Statement::Block(block), meta); + if block_terminator.is_some() { + terminator.get_or_insert(ctx.body.len()); + } + + meta + } + TokenValue::Semicolon => self.bump(frontend)?.meta, + _ => { + // Attempt to force expression parsing for remainder of the + // tokens. Unknown or invalid tokens will be caught there and + // turned into an error. + let mut stmt = ctx.stmt_ctx(); + let expr = self.parse_expression(frontend, ctx, &mut stmt)?; + ctx.lower(stmt, frontend, expr, ExprPos::Rhs)?; + self.expect(frontend, TokenValue::Semicolon)?.meta + } + }; + + meta.subsume(meta_rest); + Ok(Some(meta)) + } + + pub fn parse_compound_statement( + &mut self, + mut meta: Span, + frontend: &mut Frontend, + ctx: &mut Context, + terminator: &mut Option<usize>, + is_inside_loop: bool, + ) -> Result<Span> { + ctx.symbol_table.push_scope(); + + loop { + if let Some(Token { + meta: brace_meta, .. + }) = self.bump_if(frontend, TokenValue::RightBrace) + { + meta.subsume(brace_meta); + break; + } + + let stmt = self.parse_statement(frontend, ctx, terminator, is_inside_loop)?; + + if let Some(stmt_meta) = stmt { + meta.subsume(stmt_meta); + } + } + + if let Some(idx) = *terminator { + ctx.body.cull(idx..) + } + + ctx.symbol_table.pop_scope(); + + Ok(meta) + } + + pub fn parse_function_args( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<()> { + if self.bump_if(frontend, TokenValue::Void).is_some() { + return Ok(()); + } + + loop { + if self.peek_type_name(frontend) || self.peek_parameter_qualifier(frontend) { + let qualifier = self.parse_parameter_qualifier(frontend); + let mut ty = self.parse_type_non_void(frontend, ctx)?.0; + + match self.expect_peek(frontend)?.value { + TokenValue::Comma => { + self.bump(frontend)?; + ctx.add_function_arg(None, ty, qualifier)?; + continue; + } + TokenValue::Identifier(_) => { + let mut name = self.expect_ident(frontend)?; + self.parse_array_specifier(frontend, ctx, &mut name.1, &mut ty)?; + + ctx.add_function_arg(Some(name), ty, qualifier)?; + + if self.bump_if(frontend, TokenValue::Comma).is_some() { + continue; + } + + break; + } + _ => break, + } + } + + break; + } + + Ok(()) + } +} diff --git a/third_party/rust/naga/src/front/glsl/parser/types.rs b/third_party/rust/naga/src/front/glsl/parser/types.rs new file mode 100644 index 0000000000..1b612b298d --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser/types.rs @@ -0,0 +1,443 @@ +use std::num::NonZeroU32; + +use crate::{ + front::glsl::{ + ast::{QualifierKey, QualifierValue, StorageQualifier, StructLayout, TypeQualifiers}, + context::Context, + error::ExpectedToken, + parser::ParsingContext, + token::{Token, TokenValue}, + Error, ErrorKind, Frontend, Result, + }, + AddressSpace, ArraySize, Handle, Span, Type, TypeInner, +}; + +impl<'source> ParsingContext<'source> { + /// Parses an optional array_specifier returning whether or not it's present + /// and modifying the type handle if it exists + pub fn parse_array_specifier( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + span: &mut Span, + ty: &mut Handle<Type>, + ) -> Result<()> { + while self.parse_array_specifier_single(frontend, ctx, span, ty)? {} + Ok(()) + } + + /// Implementation of [`Self::parse_array_specifier`] for a single array_specifier + fn parse_array_specifier_single( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + span: &mut Span, + ty: &mut Handle<Type>, + ) -> Result<bool> { + if self.bump_if(frontend, TokenValue::LeftBracket).is_some() { + let size = if let Some(Token { meta, .. }) = + self.bump_if(frontend, TokenValue::RightBracket) + { + span.subsume(meta); + ArraySize::Dynamic + } else { + let (value, constant_span) = self.parse_uint_constant(frontend, ctx)?; + let size = NonZeroU32::new(value).ok_or(Error { + kind: ErrorKind::SemanticError("Array size must be greater than zero".into()), + meta: constant_span, + })?; + let end_span = self.expect(frontend, TokenValue::RightBracket)?.meta; + span.subsume(end_span); + ArraySize::Constant(size) + }; + + frontend.layouter.update(ctx.module.to_ctx()).unwrap(); + let stride = frontend.layouter[*ty].to_stride(); + *ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Array { + base: *ty, + size, + stride, + }, + }, + *span, + ); + + Ok(true) + } else { + Ok(false) + } + } + + pub fn parse_type( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<(Option<Handle<Type>>, Span)> { + let token = self.bump(frontend)?; + let mut handle = match token.value { + TokenValue::Void => return Ok((None, token.meta)), + TokenValue::TypeName(ty) => ctx.module.types.insert(ty, token.meta), + TokenValue::Struct => { + let mut meta = token.meta; + let ty_name = self.expect_ident(frontend)?.0; + self.expect(frontend, TokenValue::LeftBrace)?; + let mut members = Vec::new(); + let span = self.parse_struct_declaration_list( + frontend, + ctx, + &mut members, + StructLayout::Std140, + )?; + let end_meta = self.expect(frontend, TokenValue::RightBrace)?.meta; + meta.subsume(end_meta); + let ty = ctx.module.types.insert( + Type { + name: Some(ty_name.clone()), + inner: TypeInner::Struct { members, span }, + }, + meta, + ); + frontend.lookup_type.insert(ty_name, ty); + ty + } + TokenValue::Identifier(ident) => match frontend.lookup_type.get(&ident) { + Some(ty) => *ty, + None => { + return Err(Error { + kind: ErrorKind::UnknownType(ident), + meta: token.meta, + }) + } + }, + _ => { + return Err(Error { + kind: ErrorKind::InvalidToken( + token.value, + vec![ + TokenValue::Void.into(), + TokenValue::Struct.into(), + ExpectedToken::TypeName, + ], + ), + meta: token.meta, + }); + } + }; + + let mut span = token.meta; + self.parse_array_specifier(frontend, ctx, &mut span, &mut handle)?; + Ok((Some(handle), span)) + } + + pub fn parse_type_non_void( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<(Handle<Type>, Span)> { + let (maybe_ty, meta) = self.parse_type(frontend, ctx)?; + let ty = maybe_ty.ok_or_else(|| Error { + kind: ErrorKind::SemanticError("Type can't be void".into()), + meta, + })?; + + Ok((ty, meta)) + } + + pub fn peek_type_qualifier(&mut self, frontend: &mut Frontend) -> bool { + self.peek(frontend).map_or(false, |t| match t.value { + TokenValue::Invariant + | TokenValue::Interpolation(_) + | TokenValue::Sampling(_) + | TokenValue::PrecisionQualifier(_) + | TokenValue::Const + | TokenValue::In + | TokenValue::Out + | TokenValue::Uniform + | TokenValue::Shared + | TokenValue::Buffer + | TokenValue::Restrict + | TokenValue::MemoryQualifier(_) + | TokenValue::Layout => true, + _ => false, + }) + } + + pub fn parse_type_qualifiers<'a>( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + ) -> Result<TypeQualifiers<'a>> { + let mut qualifiers = TypeQualifiers::default(); + + while self.peek_type_qualifier(frontend) { + let token = self.bump(frontend)?; + + // Handle layout qualifiers outside the match since this can push multiple values + if token.value == TokenValue::Layout { + self.parse_layout_qualifier_id_list(frontend, ctx, &mut qualifiers)?; + continue; + } + + qualifiers.span.subsume(token.meta); + + match token.value { + TokenValue::Invariant => { + if qualifiers.invariant.is_some() { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Cannot use more than one invariant qualifier per declaration" + .into(), + ), + meta: token.meta, + }) + } + + qualifiers.invariant = Some(token.meta); + } + TokenValue::Interpolation(i) => { + if qualifiers.interpolation.is_some() { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Cannot use more than one interpolation qualifier per declaration" + .into(), + ), + meta: token.meta, + }) + } + + qualifiers.interpolation = Some((i, token.meta)); + } + TokenValue::Const + | TokenValue::In + | TokenValue::Out + | TokenValue::Uniform + | TokenValue::Shared + | TokenValue::Buffer => { + let storage = match token.value { + TokenValue::Const => StorageQualifier::Const, + TokenValue::In => StorageQualifier::Input, + TokenValue::Out => StorageQualifier::Output, + TokenValue::Uniform => { + StorageQualifier::AddressSpace(AddressSpace::Uniform) + } + TokenValue::Shared => { + StorageQualifier::AddressSpace(AddressSpace::WorkGroup) + } + TokenValue::Buffer => { + StorageQualifier::AddressSpace(AddressSpace::Storage { + access: crate::StorageAccess::all(), + }) + } + _ => unreachable!(), + }; + + if StorageQualifier::AddressSpace(AddressSpace::Function) + != qualifiers.storage.0 + { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Cannot use more than one storage qualifier per declaration".into(), + ), + meta: token.meta, + }); + } + + qualifiers.storage = (storage, token.meta); + } + TokenValue::Sampling(s) => { + if qualifiers.sampling.is_some() { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Cannot use more than one sampling qualifier per declaration" + .into(), + ), + meta: token.meta, + }) + } + + qualifiers.sampling = Some((s, token.meta)); + } + TokenValue::PrecisionQualifier(p) => { + if qualifiers.precision.is_some() { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "Cannot use more than one precision qualifier per declaration" + .into(), + ), + meta: token.meta, + }) + } + + qualifiers.precision = Some((p, token.meta)); + } + TokenValue::MemoryQualifier(access) => { + let storage_access = qualifiers + .storage_access + .get_or_insert((crate::StorageAccess::all(), Span::default())); + if !storage_access.0.contains(!access) { + frontend.errors.push(Error { + kind: ErrorKind::SemanticError( + "The same memory qualifier can only be used once".into(), + ), + meta: token.meta, + }) + } + + storage_access.0 &= access; + storage_access.1.subsume(token.meta); + } + TokenValue::Restrict => continue, + _ => unreachable!(), + }; + } + + Ok(qualifiers) + } + + pub fn parse_layout_qualifier_id_list( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + qualifiers: &mut TypeQualifiers, + ) -> Result<()> { + self.expect(frontend, TokenValue::LeftParen)?; + loop { + self.parse_layout_qualifier_id(frontend, ctx, &mut qualifiers.layout_qualifiers)?; + + if self.bump_if(frontend, TokenValue::Comma).is_some() { + continue; + } + + break; + } + let token = self.expect(frontend, TokenValue::RightParen)?; + qualifiers.span.subsume(token.meta); + + Ok(()) + } + + pub fn parse_layout_qualifier_id( + &mut self, + frontend: &mut Frontend, + ctx: &mut Context, + qualifiers: &mut crate::FastHashMap<QualifierKey, (QualifierValue, Span)>, + ) -> Result<()> { + // layout_qualifier_id: + // IDENTIFIER + // IDENTIFIER EQUAL constant_expression + // SHARED + let mut token = self.bump(frontend)?; + match token.value { + TokenValue::Identifier(name) => { + let (key, value) = match name.as_str() { + "std140" => ( + QualifierKey::Layout, + QualifierValue::Layout(StructLayout::Std140), + ), + "std430" => ( + QualifierKey::Layout, + QualifierValue::Layout(StructLayout::Std430), + ), + word => { + if let Some(format) = map_image_format(word) { + (QualifierKey::Format, QualifierValue::Format(format)) + } else { + let key = QualifierKey::String(name.into()); + let value = if self.bump_if(frontend, TokenValue::Assign).is_some() { + let (value, end_meta) = + match self.parse_uint_constant(frontend, ctx) { + Ok(v) => v, + Err(e) => { + frontend.errors.push(e); + (0, Span::default()) + } + }; + token.meta.subsume(end_meta); + + QualifierValue::Uint(value) + } else { + QualifierValue::None + }; + + (key, value) + } + } + }; + + qualifiers.insert(key, (value, token.meta)); + } + _ => frontend.errors.push(Error { + kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]), + meta: token.meta, + }), + } + + Ok(()) + } + + pub fn peek_type_name(&mut self, frontend: &mut Frontend) -> bool { + self.peek(frontend).map_or(false, |t| match t.value { + TokenValue::TypeName(_) | TokenValue::Void => true, + TokenValue::Struct => true, + TokenValue::Identifier(ref ident) => frontend.lookup_type.contains_key(ident), + _ => false, + }) + } +} + +fn map_image_format(word: &str) -> Option<crate::StorageFormat> { + use crate::StorageFormat as Sf; + + let format = match word { + // float-image-format-qualifier: + "rgba32f" => Sf::Rgba32Float, + "rgba16f" => Sf::Rgba16Float, + "rg32f" => Sf::Rg32Float, + "rg16f" => Sf::Rg16Float, + "r11f_g11f_b10f" => Sf::Rg11b10Float, + "r32f" => Sf::R32Float, + "r16f" => Sf::R16Float, + "rgba16" => Sf::Rgba16Unorm, + "rgb10_a2ui" => Sf::Rgb10a2Uint, + "rgb10_a2" => Sf::Rgb10a2Unorm, + "rgba8" => Sf::Rgba8Unorm, + "rg16" => Sf::Rg16Unorm, + "rg8" => Sf::Rg8Unorm, + "r16" => Sf::R16Unorm, + "r8" => Sf::R8Unorm, + "rgba16_snorm" => Sf::Rgba16Snorm, + "rgba8_snorm" => Sf::Rgba8Snorm, + "rg16_snorm" => Sf::Rg16Snorm, + "rg8_snorm" => Sf::Rg8Snorm, + "r16_snorm" => Sf::R16Snorm, + "r8_snorm" => Sf::R8Snorm, + // int-image-format-qualifier: + "rgba32i" => Sf::Rgba32Sint, + "rgba16i" => Sf::Rgba16Sint, + "rgba8i" => Sf::Rgba8Sint, + "rg32i" => Sf::Rg32Sint, + "rg16i" => Sf::Rg16Sint, + "rg8i" => Sf::Rg8Sint, + "r32i" => Sf::R32Sint, + "r16i" => Sf::R16Sint, + "r8i" => Sf::R8Sint, + // uint-image-format-qualifier: + "rgba32ui" => Sf::Rgba32Uint, + "rgba16ui" => Sf::Rgba16Uint, + "rgba8ui" => Sf::Rgba8Uint, + "rg32ui" => Sf::Rg32Uint, + "rg16ui" => Sf::Rg16Uint, + "rg8ui" => Sf::Rg8Uint, + "r32ui" => Sf::R32Uint, + "r16ui" => Sf::R16Uint, + "r8ui" => Sf::R8Uint, + // TODO: These next ones seem incorrect to me + // "rgb10_a2ui" => Sf::Rgb10a2Unorm, + _ => return None, + }; + + Some(format) +} diff --git a/third_party/rust/naga/src/front/glsl/parser_tests.rs b/third_party/rust/naga/src/front/glsl/parser_tests.rs new file mode 100644 index 0000000000..259052cd27 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser_tests.rs @@ -0,0 +1,858 @@ +use super::{ + ast::Profile, + error::ExpectedToken, + error::{Error, ErrorKind, ParseError}, + token::TokenValue, + Frontend, Options, Span, +}; +use crate::ShaderStage; +use pp_rs::token::PreprocessorError; + +#[test] +fn version() { + let mut frontend = Frontend::default(); + + // invalid versions + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 99000\n void main(){}", + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::InvalidVersion(99000), + meta: Span::new(9, 14) + }], + }, + ); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 449\n void main(){}", + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::InvalidVersion(449), + meta: Span::new(9, 12) + }] + }, + ); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 450 smart\n void main(){}", + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::InvalidProfile("smart".into()), + meta: Span::new(13, 18), + }] + }, + ); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 450\nvoid main(){} #version 450", + ) + .err() + .unwrap(), + ParseError { + errors: vec![ + Error { + kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedHash,), + meta: Span::new(27, 28), + }, + Error { + kind: ErrorKind::InvalidToken( + TokenValue::Identifier("version".into()), + vec![ExpectedToken::Eof] + ), + meta: Span::new(28, 35) + } + ] + }, + ); + + // valid versions + frontend + .parse( + &Options::from(ShaderStage::Vertex), + " # version 450\nvoid main() {}", + ) + .unwrap(); + assert_eq!( + (frontend.metadata().version, frontend.metadata().profile), + (450, Profile::Core) + ); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 450\nvoid main() {}", + ) + .unwrap(); + assert_eq!( + (frontend.metadata().version, frontend.metadata().profile), + (450, Profile::Core) + ); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + "#version 450 core\nvoid main(void) {}", + ) + .unwrap(); + assert_eq!( + (frontend.metadata().version, frontend.metadata().profile), + (450, Profile::Core) + ); +} + +#[test] +fn control_flow() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + if (true) { + return 1; + } else { + return 2; + } + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + if (true) { + return 1; + } + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + int x; + int y = 3; + switch (5) { + case 2: + x = 2; + case 5: + x = 5; + y = 2; + break; + default: + x = 0; + } + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + int x = 0; + while(x < 5) { + x = x + 1; + } + do { + x = x - 1; + } while(x >= 4) + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + int x = 0; + for(int i = 0; i < 10;) { + x = x + 2; + } + for(;;); + return x; + } + "#, + ) + .unwrap(); +} + +#[test] +fn declarations() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(location = 0) in vec2 v_uv; + layout(location = 0) out vec4 o_color; + layout(set = 1, binding = 1) uniform texture2D tex; + layout(set = 1, binding = 2) uniform sampler tex_sampler; + + layout(early_fragment_tests) in; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(std140, set = 2, binding = 0) + uniform u_locals { + vec3 model_offs; + float load_time; + ivec4 atlas_offs; + }; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(push_constant) + uniform u_locals { + vec3 model_offs; + float load_time; + ivec4 atlas_offs; + }; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(std430, set = 2, binding = 0) + uniform u_locals { + vec3 model_offs; + float load_time; + ivec4 atlas_offs; + }; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(std140, set = 2, binding = 0) + uniform u_locals { + vec3 model_offs; + float load_time; + } block_var; + + void main() { + load_time * model_offs; + block_var.load_time * block_var.model_offs; + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + float vector = vec4(1.0 / 17.0, 9.0 / 17.0, 3.0 / 17.0, 11.0 / 17.0); + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + precision highp float; + + void main() {} + "#, + ) + .unwrap(); +} + +#[test] +fn textures() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + #version 450 + layout(location = 0) in vec2 v_uv; + layout(location = 0) out vec4 o_color; + layout(set = 1, binding = 1) uniform texture2D tex; + layout(set = 1, binding = 2) uniform sampler tex_sampler; + void main() { + o_color = texture(sampler2D(tex, tex_sampler), v_uv); + o_color.a = texture(sampler2D(tex, tex_sampler), v_uv, 2.0).a; + } + "#, + ) + .unwrap(); +} + +#[test] +fn functions() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void test1(float); + void test1(float) {} + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void test2(float a) {} + void test3(float a, float b) {} + void test4(float, float) {} + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + float test(float a) { return a; } + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + float test(vec4 p) { + return p.x; + } + + void main() {} + "#, + ) + .unwrap(); + + // Function overloading + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + float test(vec2 p); + float test(vec3 p); + float test(vec4 p); + + float test(vec2 p) { + return p.x; + } + + float test(vec3 p) { + return p.x; + } + + float test(vec4 p) { + return p.x; + } + + void main() {} + "#, + ) + .unwrap(); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + int test(vec4 p) { + return p.x; + } + + float test(vec4 p) { + return p.x; + } + + void main() {} + "#, + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::SemanticError("Function already defined".into()), + meta: Span::new(134, 152), + }] + }, + ); + + println!(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + float callee(uint q) { + return float(q); + } + + float caller() { + callee(1u); + } + + void main() {} + "#, + ) + .unwrap(); + + // Nested function call + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + layout(set = 0, binding = 1) uniform texture2D t_noise; + layout(set = 0, binding = 2) uniform sampler s_noise; + + void main() { + textureLod(sampler2D(t_noise, s_noise), vec2(1.0), 0); + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void fun(vec2 in_parameter, out float out_parameter) { + ivec2 _ = ivec2(in_parameter); + } + + void main() { + float a; + fun(vec2(1.0), a); + } + "#, + ) + .unwrap(); +} + +#[test] +fn constants() { + use crate::{Constant, Expression, Type, TypeInner}; + + let mut frontend = Frontend::default(); + + let module = frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + const float a = 1.0; + float global = a; + const float b = a; + + void main() {} + "#, + ) + .unwrap(); + + let mut types = module.types.iter(); + let mut constants = module.constants.iter(); + let mut const_expressions = module.const_expressions.iter(); + + let (ty_handle, ty) = types.next().unwrap(); + assert_eq!( + ty, + &Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::F32) + } + ); + + let (init_handle, init) = const_expressions.next().unwrap(); + assert_eq!(init, &Expression::Literal(crate::Literal::F32(1.0))); + + assert_eq!( + constants.next().unwrap().1, + &Constant { + name: Some("a".to_owned()), + r#override: crate::Override::None, + ty: ty_handle, + init: init_handle + } + ); + + assert_eq!( + constants.next().unwrap().1, + &Constant { + name: Some("b".to_owned()), + r#override: crate::Override::None, + ty: ty_handle, + init: init_handle + } + ); + + assert!(constants.next().is_none()); +} + +#[test] +fn function_overloading() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + + float saturate(float v) { return clamp(v, 0.0, 1.0); } + vec2 saturate(vec2 v) { return clamp(v, vec2(0.0), vec2(1.0)); } + vec3 saturate(vec3 v) { return clamp(v, vec3(0.0), vec3(1.0)); } + vec4 saturate(vec4 v) { return clamp(v, vec4(0.0), vec4(1.0)); } + + void main() { + float v1 = saturate(1.5); + vec2 v2 = saturate(vec2(0.5, 1.5)); + vec3 v3 = saturate(vec3(0.5, 1.5, 2.5)); + vec3 v4 = saturate(vec4(0.5, 1.5, 2.5, 3.5)); + } + "#, + ) + .unwrap(); +} + +#[test] +fn implicit_conversions() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + mat4 a = mat4(1); + float b = 1u; + float c = 1 + 2.0; + } + "#, + ) + .unwrap(); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void test(int a) {} + void test(uint a) {} + + void main() { + test(1.0); + } + "#, + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::SemanticError("Unknown function \'test\'".into()), + meta: Span::new(156, 165), + }] + }, + ); + + assert_eq!( + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void test(float a) {} + void test(uint a) {} + + void main() { + test(1); + } + "#, + ) + .err() + .unwrap(), + ParseError { + errors: vec![Error { + kind: ErrorKind::SemanticError("Ambiguous best function for \'test\'".into()), + meta: Span::new(158, 165), + }] + } + ); +} + +#[test] +fn structs() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + Test { + vec4 pos; + } xx; + + void main() {} + "#, + ) + .unwrap_err(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + struct Test { + vec4 pos; + }; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + const int NUM_VECS = 42; + struct Test { + vec4 vecs[NUM_VECS]; + }; + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + struct Hello { + vec4 test; + } test() { + return Hello( vec4(1.0) ); + } + + void main() {} + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + struct Test {}; + + void main() {} + "#, + ) + .unwrap_err(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + inout struct Test { + vec4 x; + }; + + void main() {} + "#, + ) + .unwrap_err(); +} + +#[test] +fn swizzles() { + let mut frontend = Frontend::default(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + vec4 v = vec4(1); + v.xyz = vec3(2); + v.x = 5.0; + v.xyz.zxy.yx.xy = vec2(5.0, 1.0); + } + "#, + ) + .unwrap(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + vec4 v = vec4(1); + v.xx = vec2(5.0); + } + "#, + ) + .unwrap_err(); + + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + vec3 v = vec3(1); + v.w = 2.0; + } + "#, + ) + .unwrap_err(); +} + +#[test] +fn expressions() { + let mut frontend = Frontend::default(); + + // Vector indexing + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + float test(int index) { + vec4 v = vec4(1.0, 2.0, 3.0, 4.0); + return v[index] + 1.0; + } + + void main() {} + "#, + ) + .unwrap(); + + // Prefix increment/decrement + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + uint index = 0; + + --index; + ++index; + } + "#, + ) + .unwrap(); + + // Dynamic indexing of array + frontend + .parse( + &Options::from(ShaderStage::Vertex), + r#" + # version 450 + void main() { + const vec4 positions[1] = { vec4(0) }; + + gl_Position = positions[gl_VertexIndex]; + } + "#, + ) + .unwrap(); +} diff --git a/third_party/rust/naga/src/front/glsl/token.rs b/third_party/rust/naga/src/front/glsl/token.rs new file mode 100644 index 0000000000..303723a27b --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/token.rs @@ -0,0 +1,137 @@ +pub use pp_rs::token::{Float, Integer, Location, Token as PPToken}; + +use super::ast::Precision; +use crate::{Interpolation, Sampling, Span, Type}; + +impl From<Location> for Span { + fn from(loc: Location) -> Self { + Span::new(loc.start, loc.end) + } +} + +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub struct Token { + pub value: TokenValue, + pub meta: Span, +} + +/// A token passed from the lexing used in the parsing. +/// +/// This type is exported since it's returned in the +/// [`InvalidToken`](super::ErrorKind::InvalidToken) error. +#[derive(Clone, Debug, PartialEq)] +pub enum TokenValue { + Identifier(String), + + FloatConstant(Float), + IntConstant(Integer), + BoolConstant(bool), + + Layout, + In, + Out, + InOut, + Uniform, + Buffer, + Const, + Shared, + + Restrict, + /// A `glsl` memory qualifier such as `writeonly` + /// + /// The associated [`crate::StorageAccess`] is the access being allowed + /// (for example `writeonly` has an associated value of [`crate::StorageAccess::STORE`]) + MemoryQualifier(crate::StorageAccess), + + Invariant, + Interpolation(Interpolation), + Sampling(Sampling), + Precision, + PrecisionQualifier(Precision), + + Continue, + Break, + Return, + Discard, + + If, + Else, + Switch, + Case, + Default, + While, + Do, + For, + + Void, + Struct, + TypeName(Type), + + Assign, + AddAssign, + SubAssign, + MulAssign, + DivAssign, + ModAssign, + LeftShiftAssign, + RightShiftAssign, + AndAssign, + XorAssign, + OrAssign, + + Increment, + Decrement, + + LogicalOr, + LogicalAnd, + LogicalXor, + + LessEqual, + GreaterEqual, + Equal, + NotEqual, + + LeftShift, + RightShift, + + LeftBrace, + RightBrace, + LeftParen, + RightParen, + LeftBracket, + RightBracket, + LeftAngle, + RightAngle, + + Comma, + Semicolon, + Colon, + Dot, + Bang, + Dash, + Tilde, + Plus, + Star, + Slash, + Percent, + VerticalBar, + Caret, + Ampersand, + Question, +} + +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub struct Directive { + pub kind: DirectiveKind, + pub tokens: Vec<PPToken>, +} + +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub enum DirectiveKind { + Version { is_first_directive: bool }, + Extension, + Pragma, +} diff --git a/third_party/rust/naga/src/front/glsl/types.rs b/third_party/rust/naga/src/front/glsl/types.rs new file mode 100644 index 0000000000..e87d76fffc --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/types.rs @@ -0,0 +1,360 @@ +use super::{context::Context, Error, ErrorKind, Result, Span}; +use crate::{ + proc::ResolveContext, Expression, Handle, ImageClass, ImageDimension, Scalar, ScalarKind, Type, + TypeInner, VectorSize, +}; + +pub fn parse_type(type_name: &str) -> Option<Type> { + match type_name { + "bool" => Some(Type { + name: None, + inner: TypeInner::Scalar(Scalar::BOOL), + }), + "float" => Some(Type { + name: None, + inner: TypeInner::Scalar(Scalar::F32), + }), + "double" => Some(Type { + name: None, + inner: TypeInner::Scalar(Scalar::F64), + }), + "int" => Some(Type { + name: None, + inner: TypeInner::Scalar(Scalar::I32), + }), + "uint" => Some(Type { + name: None, + inner: TypeInner::Scalar(Scalar::U32), + }), + "sampler" | "samplerShadow" => Some(Type { + name: None, + inner: TypeInner::Sampler { + comparison: type_name == "samplerShadow", + }, + }), + word => { + fn kind_width_parse(ty: &str) -> Option<Scalar> { + Some(match ty { + "" => Scalar::F32, + "b" => Scalar::BOOL, + "i" => Scalar::I32, + "u" => Scalar::U32, + "d" => Scalar::F64, + _ => return None, + }) + } + + fn size_parse(n: &str) -> Option<VectorSize> { + Some(match n { + "2" => VectorSize::Bi, + "3" => VectorSize::Tri, + "4" => VectorSize::Quad, + _ => return None, + }) + } + + let vec_parse = |word: &str| { + let mut iter = word.split("vec"); + + let kind = iter.next()?; + let size = iter.next()?; + let scalar = kind_width_parse(kind)?; + let size = size_parse(size)?; + + Some(Type { + name: None, + inner: TypeInner::Vector { size, scalar }, + }) + }; + + let mat_parse = |word: &str| { + let mut iter = word.split("mat"); + + let kind = iter.next()?; + let size = iter.next()?; + let scalar = kind_width_parse(kind)?; + + let (columns, rows) = if let Some(size) = size_parse(size) { + (size, size) + } else { + let mut iter = size.split('x'); + match (iter.next()?, iter.next()?, iter.next()) { + (col, row, None) => (size_parse(col)?, size_parse(row)?), + _ => return None, + } + }; + + Some(Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + scalar, + }, + }) + }; + + let texture_parse = |word: &str| { + let mut iter = word.split("texture"); + + let texture_kind = |ty| { + Some(match ty { + "" => ScalarKind::Float, + "i" => ScalarKind::Sint, + "u" => ScalarKind::Uint, + _ => return None, + }) + }; + + let kind = iter.next()?; + let size = iter.next()?; + let kind = texture_kind(kind)?; + + let sampled = |multi| ImageClass::Sampled { kind, multi }; + + let (dim, arrayed, class) = match size { + "1D" => (ImageDimension::D1, false, sampled(false)), + "1DArray" => (ImageDimension::D1, true, sampled(false)), + "2D" => (ImageDimension::D2, false, sampled(false)), + "2DArray" => (ImageDimension::D2, true, sampled(false)), + "2DMS" => (ImageDimension::D2, false, sampled(true)), + "2DMSArray" => (ImageDimension::D2, true, sampled(true)), + "3D" => (ImageDimension::D3, false, sampled(false)), + "Cube" => (ImageDimension::Cube, false, sampled(false)), + "CubeArray" => (ImageDimension::Cube, true, sampled(false)), + _ => return None, + }; + + Some(Type { + name: None, + inner: TypeInner::Image { + dim, + arrayed, + class, + }, + }) + }; + + let image_parse = |word: &str| { + let mut iter = word.split("image"); + + let texture_kind = |ty| { + Some(match ty { + "" => ScalarKind::Float, + "i" => ScalarKind::Sint, + "u" => ScalarKind::Uint, + _ => return None, + }) + }; + + let kind = iter.next()?; + let size = iter.next()?; + // TODO: Check that the texture format and the kind match + let _ = texture_kind(kind)?; + + let class = ImageClass::Storage { + format: crate::StorageFormat::R8Uint, + access: crate::StorageAccess::all(), + }; + + // TODO: glsl support multisampled storage images, naga doesn't + let (dim, arrayed) = match size { + "1D" => (ImageDimension::D1, false), + "1DArray" => (ImageDimension::D1, true), + "2D" => (ImageDimension::D2, false), + "2DArray" => (ImageDimension::D2, true), + "3D" => (ImageDimension::D3, false), + // Naga doesn't support cube images and it's usefulness + // is questionable, so they won't be supported for now + // "Cube" => (ImageDimension::Cube, false), + // "CubeArray" => (ImageDimension::Cube, true), + _ => return None, + }; + + Some(Type { + name: None, + inner: TypeInner::Image { + dim, + arrayed, + class, + }, + }) + }; + + vec_parse(word) + .or_else(|| mat_parse(word)) + .or_else(|| texture_parse(word)) + .or_else(|| image_parse(word)) + } + } +} + +pub const fn scalar_components(ty: &TypeInner) -> Option<Scalar> { + match *ty { + TypeInner::Scalar(scalar) + | TypeInner::Vector { scalar, .. } + | TypeInner::ValuePointer { scalar, .. } + | TypeInner::Matrix { scalar, .. } => Some(scalar), + _ => None, + } +} + +pub const fn type_power(scalar: Scalar) -> Option<u32> { + Some(match scalar.kind { + ScalarKind::Sint => 0, + ScalarKind::Uint => 1, + ScalarKind::Float if scalar.width == 4 => 2, + ScalarKind::Float => 3, + ScalarKind::Bool | ScalarKind::AbstractInt | ScalarKind::AbstractFloat => return None, + }) +} + +impl Context<'_> { + /// Resolves the types of the expressions until `expr` (inclusive) + /// + /// This needs to be done before the [`typifier`] can be queried for + /// the types of the expressions in the range between the last grow and `expr`. + /// + /// # Note + /// + /// The `resolve_type*` methods (like [`resolve_type`]) automatically + /// grow the [`typifier`] so calling this method is not necessary when using + /// them. + /// + /// [`typifier`]: Context::typifier + /// [`resolve_type`]: Self::resolve_type + pub(crate) fn typifier_grow(&mut self, expr: Handle<Expression>, meta: Span) -> Result<()> { + let resolve_ctx = ResolveContext::with_locals(self.module, &self.locals, &self.arguments); + + let typifier = if self.is_const { + &mut self.const_typifier + } else { + &mut self.typifier + }; + + let expressions = if self.is_const { + &self.module.const_expressions + } else { + &self.expressions + }; + + typifier + .grow(expr, expressions, &resolve_ctx) + .map_err(|error| Error { + kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()), + meta, + }) + } + + pub(crate) fn get_type(&self, expr: Handle<Expression>) -> &TypeInner { + let typifier = if self.is_const { + &self.const_typifier + } else { + &self.typifier + }; + + typifier.get(expr, &self.module.types) + } + + /// Gets the type for the result of the `expr` expression + /// + /// Automatically grows the [`typifier`] to `expr` so calling + /// [`typifier_grow`] is not necessary + /// + /// [`typifier`]: Context::typifier + /// [`typifier_grow`]: Self::typifier_grow + pub(crate) fn resolve_type( + &mut self, + expr: Handle<Expression>, + meta: Span, + ) -> Result<&TypeInner> { + self.typifier_grow(expr, meta)?; + Ok(self.get_type(expr)) + } + + /// Gets the type handle for the result of the `expr` expression + /// + /// Automatically grows the [`typifier`] to `expr` so calling + /// [`typifier_grow`] is not necessary + /// + /// # Note + /// + /// Consider using [`resolve_type`] whenever possible + /// since it doesn't require adding each type to the [`types`] arena + /// and it doesn't need to mutably borrow the [`Parser`][Self] + /// + /// [`types`]: crate::Module::types + /// [`typifier`]: Context::typifier + /// [`typifier_grow`]: Self::typifier_grow + /// [`resolve_type`]: Self::resolve_type + pub(crate) fn resolve_type_handle( + &mut self, + expr: Handle<Expression>, + meta: Span, + ) -> Result<Handle<Type>> { + self.typifier_grow(expr, meta)?; + + let typifier = if self.is_const { + &mut self.const_typifier + } else { + &mut self.typifier + }; + + Ok(typifier.register_type(expr, &mut self.module.types)) + } + + /// Invalidates the cached type resolution for `expr` forcing a recomputation + pub(crate) fn invalidate_expression( + &mut self, + expr: Handle<Expression>, + meta: Span, + ) -> Result<()> { + let resolve_ctx = ResolveContext::with_locals(self.module, &self.locals, &self.arguments); + + let typifier = if self.is_const { + &mut self.const_typifier + } else { + &mut self.typifier + }; + + typifier + .invalidate(expr, &self.expressions, &resolve_ctx) + .map_err(|error| Error { + kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()), + meta, + }) + } + + pub(crate) fn lift_up_const_expression( + &mut self, + expr: Handle<Expression>, + ) -> Result<Handle<Expression>> { + let meta = self.expressions.get_span(expr); + Ok(match self.expressions[expr] { + ref expr @ (Expression::Literal(_) + | Expression::Constant(_) + | Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta), + Expression::Compose { ty, ref components } => { + let mut components = components.clone(); + for component in &mut components { + *component = self.lift_up_const_expression(*component)?; + } + self.module + .const_expressions + .append(Expression::Compose { ty, components }, meta) + } + Expression::Splat { size, value } => { + let value = self.lift_up_const_expression(value)?; + self.module + .const_expressions + .append(Expression::Splat { size, value }, meta) + } + _ => { + return Err(Error { + kind: ErrorKind::SemanticError("Expression is not const-expression".into()), + meta, + }) + } + }) + } +} diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs new file mode 100644 index 0000000000..5af2b228f0 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/variables.rs @@ -0,0 +1,646 @@ +use super::{ + ast::*, + context::{Context, ExprPos}, + error::{Error, ErrorKind}, + Frontend, Result, Span, +}; +use crate::{ + AddressSpace, Binding, BuiltIn, Constant, Expression, GlobalVariable, Handle, Interpolation, + LocalVariable, ResourceBinding, Scalar, ScalarKind, ShaderStage, SwizzleComponent, Type, + TypeInner, VectorSize, +}; + +pub struct VarDeclaration<'a, 'key> { + pub qualifiers: &'a mut TypeQualifiers<'key>, + pub ty: Handle<Type>, + pub name: Option<String>, + pub init: Option<Handle<Expression>>, + pub meta: Span, +} + +/// Information about a builtin used in [`add_builtin`](Frontend::add_builtin). +struct BuiltInData { + /// The type of the builtin. + inner: TypeInner, + /// The associated builtin class. + builtin: BuiltIn, + /// Whether the builtin can be written to or not. + mutable: bool, + /// The storage used for the builtin. + storage: StorageQualifier, +} + +pub enum GlobalOrConstant { + Global(Handle<GlobalVariable>), + Constant(Handle<Constant>), +} + +impl Frontend { + /// Adds a builtin and returns a variable reference to it + fn add_builtin( + &mut self, + ctx: &mut Context, + name: &str, + data: BuiltInData, + meta: Span, + ) -> Result<Option<VariableReference>> { + let ty = ctx.module.types.insert( + Type { + name: None, + inner: data.inner, + }, + meta, + ); + + let handle = ctx.module.global_variables.append( + GlobalVariable { + name: Some(name.into()), + space: AddressSpace::Private, + binding: None, + ty, + init: None, + }, + meta, + ); + + let idx = self.entry_args.len(); + self.entry_args.push(EntryArg { + name: None, + binding: Binding::BuiltIn(data.builtin), + handle, + storage: data.storage, + }); + + self.global_variables.push(( + name.into(), + GlobalLookup { + kind: GlobalLookupKind::Variable(handle), + entry_arg: Some(idx), + mutable: data.mutable, + }, + )); + + let expr = ctx.add_expression(Expression::GlobalVariable(handle), meta)?; + + let var = VariableReference { + expr, + load: true, + mutable: data.mutable, + constant: None, + entry_arg: Some(idx), + }; + + ctx.symbol_table.add_root(name.into(), var.clone()); + + Ok(Some(var)) + } + + pub(crate) fn lookup_variable( + &mut self, + ctx: &mut Context, + name: &str, + meta: Span, + ) -> Result<Option<VariableReference>> { + if let Some(var) = ctx.symbol_table.lookup(name).cloned() { + return Ok(Some(var)); + } + + let data = match name { + "gl_Position" => BuiltInData { + inner: TypeInner::Vector { + size: VectorSize::Quad, + scalar: Scalar::F32, + }, + builtin: BuiltIn::Position { invariant: false }, + mutable: true, + storage: StorageQualifier::Output, + }, + "gl_FragCoord" => BuiltInData { + inner: TypeInner::Vector { + size: VectorSize::Quad, + scalar: Scalar::F32, + }, + builtin: BuiltIn::Position { invariant: false }, + mutable: false, + storage: StorageQualifier::Input, + }, + "gl_PointCoord" => BuiltInData { + inner: TypeInner::Vector { + size: VectorSize::Bi, + scalar: Scalar::F32, + }, + builtin: BuiltIn::PointCoord, + mutable: false, + storage: StorageQualifier::Input, + }, + "gl_GlobalInvocationID" + | "gl_NumWorkGroups" + | "gl_WorkGroupSize" + | "gl_WorkGroupID" + | "gl_LocalInvocationID" => BuiltInData { + inner: TypeInner::Vector { + size: VectorSize::Tri, + scalar: Scalar::U32, + }, + builtin: match name { + "gl_GlobalInvocationID" => BuiltIn::GlobalInvocationId, + "gl_NumWorkGroups" => BuiltIn::NumWorkGroups, + "gl_WorkGroupSize" => BuiltIn::WorkGroupSize, + "gl_WorkGroupID" => BuiltIn::WorkGroupId, + "gl_LocalInvocationID" => BuiltIn::LocalInvocationId, + _ => unreachable!(), + }, + mutable: false, + storage: StorageQualifier::Input, + }, + "gl_FrontFacing" => BuiltInData { + inner: TypeInner::Scalar(Scalar::BOOL), + builtin: BuiltIn::FrontFacing, + mutable: false, + storage: StorageQualifier::Input, + }, + "gl_PointSize" | "gl_FragDepth" => BuiltInData { + inner: TypeInner::Scalar(Scalar::F32), + builtin: match name { + "gl_PointSize" => BuiltIn::PointSize, + "gl_FragDepth" => BuiltIn::FragDepth, + _ => unreachable!(), + }, + mutable: true, + storage: StorageQualifier::Output, + }, + "gl_ClipDistance" | "gl_CullDistance" => { + let base = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Scalar(Scalar::F32), + }, + meta, + ); + + BuiltInData { + inner: TypeInner::Array { + base, + size: crate::ArraySize::Dynamic, + stride: 4, + }, + builtin: match name { + "gl_ClipDistance" => BuiltIn::ClipDistance, + "gl_CullDistance" => BuiltIn::CullDistance, + _ => unreachable!(), + }, + mutable: self.meta.stage == ShaderStage::Vertex, + storage: StorageQualifier::Output, + } + } + _ => { + let builtin = match name { + "gl_BaseVertex" => BuiltIn::BaseVertex, + "gl_BaseInstance" => BuiltIn::BaseInstance, + "gl_PrimitiveID" => BuiltIn::PrimitiveIndex, + "gl_InstanceIndex" => BuiltIn::InstanceIndex, + "gl_VertexIndex" => BuiltIn::VertexIndex, + "gl_SampleID" => BuiltIn::SampleIndex, + "gl_LocalInvocationIndex" => BuiltIn::LocalInvocationIndex, + _ => return Ok(None), + }; + + BuiltInData { + inner: TypeInner::Scalar(Scalar::U32), + builtin, + mutable: false, + storage: StorageQualifier::Input, + } + } + }; + + self.add_builtin(ctx, name, data, meta) + } + + pub(crate) fn make_variable_invariant( + &mut self, + ctx: &mut Context, + name: &str, + meta: Span, + ) -> Result<()> { + if let Some(var) = self.lookup_variable(ctx, name, meta)? { + if let Some(index) = var.entry_arg { + if let Binding::BuiltIn(BuiltIn::Position { ref mut invariant }) = + self.entry_args[index].binding + { + *invariant = true; + } + } + } + Ok(()) + } + + pub(crate) fn field_selection( + &mut self, + ctx: &mut Context, + pos: ExprPos, + expression: Handle<Expression>, + name: &str, + meta: Span, + ) -> Result<Handle<Expression>> { + let (ty, is_pointer) = match *ctx.resolve_type(expression, meta)? { + TypeInner::Pointer { base, .. } => (&ctx.module.types[base].inner, true), + ref ty => (ty, false), + }; + match *ty { + TypeInner::Struct { ref members, .. } => { + let index = members + .iter() + .position(|m| m.name == Some(name.into())) + .ok_or_else(|| Error { + kind: ErrorKind::UnknownField(name.into()), + meta, + })?; + let pointer = ctx.add_expression( + Expression::AccessIndex { + base: expression, + index: index as u32, + }, + meta, + )?; + + Ok(match pos { + ExprPos::Rhs if is_pointer => { + ctx.add_expression(Expression::Load { pointer }, meta)? + } + _ => pointer, + }) + } + // swizzles (xyzw, rgba, stpq) + TypeInner::Vector { size, .. } => { + let check_swizzle_components = |comps: &str| { + name.chars() + .map(|c| { + comps + .find(c) + .filter(|i| *i < size as usize) + .map(|i| SwizzleComponent::from_index(i as u32)) + }) + .collect::<Option<Vec<SwizzleComponent>>>() + }; + + let components = check_swizzle_components("xyzw") + .or_else(|| check_swizzle_components("rgba")) + .or_else(|| check_swizzle_components("stpq")); + + if let Some(components) = components { + if let ExprPos::Lhs = pos { + let not_unique = (1..components.len()) + .any(|i| components[i..].contains(&components[i - 1])); + if not_unique { + self.errors.push(Error { + kind: + ErrorKind::SemanticError( + format!( + "swizzle cannot have duplicate components in left-hand-side expression for \"{name:?}\"" + ) + .into(), + ), + meta , + }) + } + } + + let mut pattern = [SwizzleComponent::X; 4]; + for (pat, component) in pattern.iter_mut().zip(&components) { + *pat = *component; + } + + // flatten nested swizzles (vec.zyx.xy.x => vec.z) + let mut expression = expression; + if let Expression::Swizzle { + size: _, + vector, + pattern: ref src_pattern, + } = ctx[expression] + { + expression = vector; + for pat in &mut pattern { + *pat = src_pattern[pat.index() as usize]; + } + } + + let size = match components.len() { + // Swizzles with just one component are accesses and not swizzles + 1 => { + match pos { + // If the position is in the right hand side and the base + // vector is a pointer, load it, otherwise the swizzle would + // produce a pointer + ExprPos::Rhs if is_pointer => { + expression = ctx.add_expression( + Expression::Load { + pointer: expression, + }, + meta, + )?; + } + _ => {} + }; + return ctx.add_expression( + Expression::AccessIndex { + base: expression, + index: pattern[0].index(), + }, + meta, + ); + } + 2 => VectorSize::Bi, + 3 => VectorSize::Tri, + 4 => VectorSize::Quad, + _ => { + self.errors.push(Error { + kind: ErrorKind::SemanticError( + format!("Bad swizzle size for \"{name:?}\"").into(), + ), + meta, + }); + + VectorSize::Quad + } + }; + + if is_pointer { + // NOTE: for lhs expression, this extra load ends up as an unused expr, because the + // assignment will extract the pointer and use it directly anyway. Unfortunately we + // need it for validation to pass, as swizzles cannot operate on pointer values. + expression = ctx.add_expression( + Expression::Load { + pointer: expression, + }, + meta, + )?; + } + + Ok(ctx.add_expression( + Expression::Swizzle { + size, + vector: expression, + pattern, + }, + meta, + )?) + } else { + Err(Error { + kind: ErrorKind::SemanticError( + format!("Invalid swizzle for vector \"{name}\"").into(), + ), + meta, + }) + } + } + _ => Err(Error { + kind: ErrorKind::SemanticError( + format!("Can't lookup field on this type \"{name}\"").into(), + ), + meta, + }), + } + } + + pub(crate) fn add_global_var( + &mut self, + ctx: &mut Context, + VarDeclaration { + qualifiers, + mut ty, + name, + init, + meta, + }: VarDeclaration, + ) -> Result<GlobalOrConstant> { + let storage = qualifiers.storage.0; + let (ret, lookup) = match storage { + StorageQualifier::Input | StorageQualifier::Output => { + let input = storage == StorageQualifier::Input; + // TODO: glslang seems to use a counter for variables without + // explicit location (even if that causes collisions) + let location = qualifiers + .uint_layout_qualifier("location", &mut self.errors) + .unwrap_or(0); + let interpolation = qualifiers.interpolation.take().map(|(i, _)| i).or_else(|| { + let kind = ctx.module.types[ty].inner.scalar_kind()?; + Some(match kind { + ScalarKind::Float => Interpolation::Perspective, + _ => Interpolation::Flat, + }) + }); + let sampling = qualifiers.sampling.take().map(|(s, _)| s); + + let handle = ctx.module.global_variables.append( + GlobalVariable { + name: name.clone(), + space: AddressSpace::Private, + binding: None, + ty, + init, + }, + meta, + ); + + let idx = self.entry_args.len(); + self.entry_args.push(EntryArg { + name: name.clone(), + binding: Binding::Location { + location, + interpolation, + sampling, + second_blend_source: false, + }, + handle, + storage, + }); + + let lookup = GlobalLookup { + kind: GlobalLookupKind::Variable(handle), + entry_arg: Some(idx), + mutable: !input, + }; + + (GlobalOrConstant::Global(handle), lookup) + } + StorageQualifier::Const => { + let init = init.ok_or_else(|| Error { + kind: ErrorKind::SemanticError("const values must have an initializer".into()), + meta, + })?; + + let constant = Constant { + name: name.clone(), + r#override: crate::Override::None, + ty, + init, + }; + let handle = ctx.module.constants.fetch_or_append(constant, meta); + + let lookup = GlobalLookup { + kind: GlobalLookupKind::Constant(handle, ty), + entry_arg: None, + mutable: false, + }; + + (GlobalOrConstant::Constant(handle), lookup) + } + StorageQualifier::AddressSpace(mut space) => { + match space { + AddressSpace::Storage { ref mut access } => { + if let Some((allowed_access, _)) = qualifiers.storage_access.take() { + *access = allowed_access; + } + } + AddressSpace::Uniform => match ctx.module.types[ty].inner { + TypeInner::Image { + class, + dim, + arrayed, + } => { + if let crate::ImageClass::Storage { + mut access, + mut format, + } = class + { + if let Some((allowed_access, _)) = qualifiers.storage_access.take() + { + access = allowed_access; + } + + match qualifiers.layout_qualifiers.remove(&QualifierKey::Format) { + Some((QualifierValue::Format(f), _)) => format = f, + // TODO: glsl supports images without format qualifier + // if they are `writeonly` + None => self.errors.push(Error { + kind: ErrorKind::SemanticError( + "image types require a format layout qualifier".into(), + ), + meta, + }), + _ => unreachable!(), + } + + ty = ctx.module.types.insert( + Type { + name: None, + inner: TypeInner::Image { + dim, + arrayed, + class: crate::ImageClass::Storage { format, access }, + }, + }, + meta, + ); + } + + space = AddressSpace::Handle + } + TypeInner::Sampler { .. } => space = AddressSpace::Handle, + _ => { + if qualifiers.none_layout_qualifier("push_constant", &mut self.errors) { + space = AddressSpace::PushConstant + } + } + }, + AddressSpace::Function => space = AddressSpace::Private, + _ => {} + }; + + let binding = match space { + AddressSpace::Uniform | AddressSpace::Storage { .. } | AddressSpace::Handle => { + let binding = qualifiers.uint_layout_qualifier("binding", &mut self.errors); + if binding.is_none() { + self.errors.push(Error { + kind: ErrorKind::SemanticError( + "uniform/buffer blocks require layout(binding=X)".into(), + ), + meta, + }); + } + let set = qualifiers.uint_layout_qualifier("set", &mut self.errors); + binding.map(|binding| ResourceBinding { + group: set.unwrap_or(0), + binding, + }) + } + _ => None, + }; + + let handle = ctx.module.global_variables.append( + GlobalVariable { + name: name.clone(), + space, + binding, + ty, + init, + }, + meta, + ); + + let lookup = GlobalLookup { + kind: GlobalLookupKind::Variable(handle), + entry_arg: None, + mutable: true, + }; + + (GlobalOrConstant::Global(handle), lookup) + } + }; + + if let Some(name) = name { + ctx.add_global(&name, lookup)?; + + self.global_variables.push((name, lookup)); + } + + qualifiers.unused_errors(&mut self.errors); + + Ok(ret) + } + + pub(crate) fn add_local_var( + &mut self, + ctx: &mut Context, + decl: VarDeclaration, + ) -> Result<Handle<Expression>> { + let storage = decl.qualifiers.storage; + let mutable = match storage.0 { + StorageQualifier::AddressSpace(AddressSpace::Function) => true, + StorageQualifier::Const => false, + _ => { + self.errors.push(Error { + kind: ErrorKind::SemanticError("Locals cannot have a storage qualifier".into()), + meta: storage.1, + }); + true + } + }; + + let handle = ctx.locals.append( + LocalVariable { + name: decl.name.clone(), + ty: decl.ty, + init: decl.init, + }, + decl.meta, + ); + let expr = ctx.add_expression(Expression::LocalVariable(handle), decl.meta)?; + + if let Some(name) = decl.name { + let maybe_var = ctx.add_local_var(name.clone(), expr, mutable); + + if maybe_var.is_some() { + self.errors.push(Error { + kind: ErrorKind::VariableAlreadyDeclared(name), + meta: decl.meta, + }) + } + } + + decl.qualifiers.unused_errors(&mut self.errors); + + Ok(expr) + } +} diff --git a/third_party/rust/naga/src/front/interpolator.rs b/third_party/rust/naga/src/front/interpolator.rs new file mode 100644 index 0000000000..0196a2254d --- /dev/null +++ b/third_party/rust/naga/src/front/interpolator.rs @@ -0,0 +1,62 @@ +/*! +Interpolation defaults. +*/ + +impl crate::Binding { + /// Apply the usual default interpolation for `ty` to `binding`. + /// + /// This function is a utility front ends may use to satisfy the Naga IR's + /// requirement, meant to ensure that input languages' policies have been + /// applied appropriately, that all I/O `Binding`s from the vertex shader to the + /// fragment shader must have non-`None` `interpolation` values. + /// + /// All the shader languages Naga supports have similar rules: + /// perspective-correct, center-sampled interpolation is the default for any + /// binding that can vary, and everything else either defaults to flat, or + /// requires an explicit flat qualifier/attribute/what-have-you. + /// + /// If `binding` is not a [`Location`] binding, or if its [`interpolation`] is + /// already set, then make no changes. Otherwise, set `binding`'s interpolation + /// and sampling to reasonable defaults depending on `ty`, the type of the value + /// being interpolated: + /// + /// - If `ty` is a floating-point scalar, vector, or matrix type, then + /// default to [`Perspective`] interpolation and [`Center`] sampling. + /// + /// - If `ty` is an integral scalar or vector, then default to [`Flat`] + /// interpolation, which has no associated sampling. + /// + /// - For any other types, make no change. Such types are not permitted as + /// user-defined IO values, and will probably be flagged by the verifier + /// + /// When structs appear in input or output types, each member ought to have its + /// own [`Binding`], so structs are simply covered by the third case. + /// + /// [`Binding`]: crate::Binding + /// [`Location`]: crate::Binding::Location + /// [`interpolation`]: crate::Binding::Location::interpolation + /// [`Perspective`]: crate::Interpolation::Perspective + /// [`Flat`]: crate::Interpolation::Flat + /// [`Center`]: crate::Sampling::Center + pub fn apply_default_interpolation(&mut self, ty: &crate::TypeInner) { + if let crate::Binding::Location { + location: _, + interpolation: ref mut interpolation @ None, + ref mut sampling, + second_blend_source: _, + } = *self + { + match ty.scalar_kind() { + Some(crate::ScalarKind::Float) => { + *interpolation = Some(crate::Interpolation::Perspective); + *sampling = Some(crate::Sampling::Center); + } + Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { + *interpolation = Some(crate::Interpolation::Flat); + *sampling = None; + } + Some(_) | None => {} + } + } + } +} diff --git a/third_party/rust/naga/src/front/mod.rs b/third_party/rust/naga/src/front/mod.rs new file mode 100644 index 0000000000..e1f99452e1 --- /dev/null +++ b/third_party/rust/naga/src/front/mod.rs @@ -0,0 +1,328 @@ +/*! +Frontend parsers that consume binary and text shaders and load them into [`Module`](super::Module)s. +*/ + +mod interpolator; +mod type_gen; + +#[cfg(feature = "glsl-in")] +pub mod glsl; +#[cfg(feature = "spv-in")] +pub mod spv; +#[cfg(feature = "wgsl-in")] +pub mod wgsl; + +use crate::{ + arena::{Arena, Handle, UniqueArena}, + proc::{ResolveContext, ResolveError, TypeResolution}, + FastHashMap, +}; +use std::ops; + +/// A table of types for an `Arena<Expression>`. +/// +/// A front end can use a `Typifier` to get types for an arena's expressions +/// while it is still contributing expressions to it. At any point, you can call +/// [`typifier.grow(expr, arena, ctx)`], where `expr` is a `Handle<Expression>` +/// referring to something in `arena`, and the `Typifier` will resolve the types +/// of all the expressions up to and including `expr`. Then you can write +/// `typifier[handle]` to get the type of any handle at or before `expr`. +/// +/// Note that `Typifier` does *not* build an `Arena<Type>` as a part of its +/// usual operation. Ideally, a module's type arena should only contain types +/// actually needed by `Handle<Type>`s elsewhere in the module — functions, +/// variables, [`Compose`] expressions, other types, and so on — so we don't +/// want every little thing that occurs as the type of some intermediate +/// expression to show up there. +/// +/// Instead, `Typifier` accumulates a [`TypeResolution`] for each expression, +/// which refers to the `Arena<Type>` in the [`ResolveContext`] passed to `grow` +/// as needed. [`TypeResolution`] is a lightweight representation for +/// intermediate types like this; see its documentation for details. +/// +/// If you do need to register a `Typifier`'s conclusion in an `Arena<Type>` +/// (say, for a [`LocalVariable`] whose type you've inferred), you can use +/// [`register_type`] to do so. +/// +/// [`typifier.grow(expr, arena)`]: Typifier::grow +/// [`register_type`]: Typifier::register_type +/// [`Compose`]: crate::Expression::Compose +/// [`LocalVariable`]: crate::LocalVariable +#[derive(Debug, Default)] +pub struct Typifier { + resolutions: Vec<TypeResolution>, +} + +impl Typifier { + pub const fn new() -> Self { + Typifier { + resolutions: Vec::new(), + } + } + + pub fn reset(&mut self) { + self.resolutions.clear() + } + + pub fn get<'a>( + &'a self, + expr_handle: Handle<crate::Expression>, + types: &'a UniqueArena<crate::Type>, + ) -> &'a crate::TypeInner { + self.resolutions[expr_handle.index()].inner_with(types) + } + + /// Add an expression's type to an `Arena<Type>`. + /// + /// Add the type of `expr_handle` to `types`, and return a `Handle<Type>` + /// referring to it. + /// + /// # Note + /// + /// If you just need a [`TypeInner`] for `expr_handle`'s type, consider + /// using `typifier[expression].inner_with(types)` instead. Calling + /// [`TypeResolution::inner_with`] often lets us avoid adding anything to + /// the arena, which can significantly reduce the number of types that end + /// up in the final module. + /// + /// [`TypeInner`]: crate::TypeInner + pub fn register_type( + &self, + expr_handle: Handle<crate::Expression>, + types: &mut UniqueArena<crate::Type>, + ) -> Handle<crate::Type> { + match self[expr_handle].clone() { + TypeResolution::Handle(handle) => handle, + TypeResolution::Value(inner) => { + types.insert(crate::Type { name: None, inner }, crate::Span::UNDEFINED) + } + } + } + + /// Grow this typifier until it contains a type for `expr_handle`. + pub fn grow( + &mut self, + expr_handle: Handle<crate::Expression>, + expressions: &Arena<crate::Expression>, + ctx: &ResolveContext, + ) -> Result<(), ResolveError> { + if self.resolutions.len() <= expr_handle.index() { + for (eh, expr) in expressions.iter().skip(self.resolutions.len()) { + //Note: the closure can't `Err` by construction + let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?; + log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution); + self.resolutions.push(resolution); + } + } + Ok(()) + } + + /// Recompute the type resolution for `expr_handle`. + /// + /// If the type of `expr_handle` hasn't yet been calculated, call + /// [`grow`](Self::grow) to ensure it is covered. + /// + /// In either case, when this returns, `self[expr_handle]` should be an + /// updated type resolution for `expr_handle`. + pub fn invalidate( + &mut self, + expr_handle: Handle<crate::Expression>, + expressions: &Arena<crate::Expression>, + ctx: &ResolveContext, + ) -> Result<(), ResolveError> { + if self.resolutions.len() <= expr_handle.index() { + self.grow(expr_handle, expressions, ctx) + } else { + let expr = &expressions[expr_handle]; + //Note: the closure can't `Err` by construction + let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?; + self.resolutions[expr_handle.index()] = resolution; + Ok(()) + } + } +} + +impl ops::Index<Handle<crate::Expression>> for Typifier { + type Output = TypeResolution; + fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output { + &self.resolutions[handle.index()] + } +} + +/// Type representing a lexical scope, associating a name to a single variable +/// +/// The scope is generic over the variable representation and name representation +/// in order to allow larger flexibility on the frontends on how they might +/// represent them. +type Scope<Name, Var> = FastHashMap<Name, Var>; + +/// Structure responsible for managing variable lookups and keeping track of +/// lexical scopes +/// +/// The symbol table is generic over the variable representation and its name +/// to allow larger flexibility on the frontends on how they might represent them. +/// +/// ``` +/// use naga::front::SymbolTable; +/// +/// // Create a new symbol table with `u32`s representing the variable +/// let mut symbol_table: SymbolTable<&str, u32> = SymbolTable::default(); +/// +/// // Add two variables named `var1` and `var2` with 0 and 2 respectively +/// symbol_table.add("var1", 0); +/// symbol_table.add("var2", 2); +/// +/// // Check that `var1` exists and is `0` +/// assert_eq!(symbol_table.lookup("var1"), Some(&0)); +/// +/// // Push a new scope and add a variable to it named `var1` shadowing the +/// // variable of our previous scope +/// symbol_table.push_scope(); +/// symbol_table.add("var1", 1); +/// +/// // Check that `var1` now points to the new value of `1` and `var2` still +/// // exists with its value of `2` +/// assert_eq!(symbol_table.lookup("var1"), Some(&1)); +/// assert_eq!(symbol_table.lookup("var2"), Some(&2)); +/// +/// // Pop the scope +/// symbol_table.pop_scope(); +/// +/// // Check that `var1` now refers to our initial variable with value `0` +/// assert_eq!(symbol_table.lookup("var1"), Some(&0)); +/// ``` +/// +/// Scopes are ordered as a LIFO stack so a variable defined in a later scope +/// with the same name as another variable defined in a earlier scope will take +/// precedence in the lookup. Scopes can be added with [`push_scope`] and +/// removed with [`pop_scope`]. +/// +/// A root scope is added when the symbol table is created and must always be +/// present. Trying to pop it will result in a panic. +/// +/// Variables can be added with [`add`] and looked up with [`lookup`]. Adding a +/// variable will do so in the currently active scope and as mentioned +/// previously a lookup will search from the current scope to the root scope. +/// +/// [`push_scope`]: Self::push_scope +/// [`pop_scope`]: Self::push_scope +/// [`add`]: Self::add +/// [`lookup`]: Self::lookup +pub struct SymbolTable<Name, Var> { + /// Stack of lexical scopes. Not all scopes are active; see [`cursor`]. + /// + /// [`cursor`]: Self::cursor + scopes: Vec<Scope<Name, Var>>, + /// Limit of the [`scopes`] stack (exclusive). By using a separate value for + /// the stack length instead of `Vec`'s own internal length, the scopes can + /// be reused to cache memory allocations. + /// + /// [`scopes`]: Self::scopes + cursor: usize, +} + +impl<Name, Var> SymbolTable<Name, Var> { + /// Adds a new lexical scope. + /// + /// All variables declared after this point will be added to this scope + /// until another scope is pushed or [`pop_scope`] is called, causing this + /// scope to be removed along with all variables added to it. + /// + /// [`pop_scope`]: Self::pop_scope + pub fn push_scope(&mut self) { + // If the cursor is equal to the scope's stack length then we need to + // push another empty scope. Otherwise we can reuse the already existing + // scope. + if self.scopes.len() == self.cursor { + self.scopes.push(FastHashMap::default()) + } else { + self.scopes[self.cursor].clear(); + } + + self.cursor += 1; + } + + /// Removes the current lexical scope and all its variables + /// + /// # PANICS + /// - If the current lexical scope is the root scope + pub fn pop_scope(&mut self) { + // Despite the method title, the variables are only deleted when the + // scope is reused. This is because while a clear is inevitable if the + // scope needs to be reused, there are cases where the scope might be + // popped and not reused, i.e. if another scope with the same nesting + // level is never pushed again. + assert!(self.cursor != 1, "Tried to pop the root scope"); + + self.cursor -= 1; + } +} + +impl<Name, Var> SymbolTable<Name, Var> +where + Name: std::hash::Hash + Eq, +{ + /// Perform a lookup for a variable named `name`. + /// + /// As stated in the struct level documentation the lookup will proceed from + /// the current scope to the root scope, returning `Some` when a variable is + /// found or `None` if there doesn't exist a variable with `name` in any + /// scope. + pub fn lookup<Q: ?Sized>(&self, name: &Q) -> Option<&Var> + where + Name: std::borrow::Borrow<Q>, + Q: std::hash::Hash + Eq, + { + // Iterate backwards trough the scopes and try to find the variable + for scope in self.scopes[..self.cursor].iter().rev() { + if let Some(var) = scope.get(name) { + return Some(var); + } + } + + None + } + + /// Adds a new variable to the current scope. + /// + /// Returns the previous variable with the same name in this scope if it + /// exists, so that the frontend might handle it in case variable shadowing + /// is disallowed. + pub fn add(&mut self, name: Name, var: Var) -> Option<Var> { + self.scopes[self.cursor - 1].insert(name, var) + } + + /// Adds a new variable to the root scope. + /// + /// This is used in GLSL for builtins which aren't known in advance and only + /// when used for the first time, so there must be a way to add those + /// declarations to the root unconditionally from the current scope. + /// + /// Returns the previous variable with the same name in the root scope if it + /// exists, so that the frontend might handle it in case variable shadowing + /// is disallowed. + pub fn add_root(&mut self, name: Name, var: Var) -> Option<Var> { + self.scopes[0].insert(name, var) + } +} + +impl<Name, Var> Default for SymbolTable<Name, Var> { + /// Constructs a new symbol table with a root scope + fn default() -> Self { + Self { + scopes: vec![FastHashMap::default()], + cursor: 1, + } + } +} + +use std::fmt; + +impl<Name: fmt::Debug, Var: fmt::Debug> fmt::Debug for SymbolTable<Name, Var> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("SymbolTable ")?; + f.debug_list() + .entries(self.scopes[..self.cursor].iter()) + .finish() + } +} diff --git a/third_party/rust/naga/src/front/spv/convert.rs b/third_party/rust/naga/src/front/spv/convert.rs new file mode 100644 index 0000000000..f0a714fbeb --- /dev/null +++ b/third_party/rust/naga/src/front/spv/convert.rs @@ -0,0 +1,179 @@ +use super::error::Error; +use std::convert::TryInto; + +pub(super) const fn map_binary_operator(word: spirv::Op) -> Result<crate::BinaryOperator, Error> { + use crate::BinaryOperator; + use spirv::Op; + + match word { + // Arithmetic Instructions +, -, *, /, % + Op::IAdd | Op::FAdd => Ok(BinaryOperator::Add), + Op::ISub | Op::FSub => Ok(BinaryOperator::Subtract), + Op::IMul | Op::FMul => Ok(BinaryOperator::Multiply), + Op::UDiv | Op::SDiv | Op::FDiv => Ok(BinaryOperator::Divide), + Op::SRem => Ok(BinaryOperator::Modulo), + // Relational and Logical Instructions + Op::IEqual | Op::FOrdEqual | Op::FUnordEqual | Op::LogicalEqual => { + Ok(BinaryOperator::Equal) + } + Op::INotEqual | Op::FOrdNotEqual | Op::FUnordNotEqual | Op::LogicalNotEqual => { + Ok(BinaryOperator::NotEqual) + } + Op::ULessThan | Op::SLessThan | Op::FOrdLessThan | Op::FUnordLessThan => { + Ok(BinaryOperator::Less) + } + Op::ULessThanEqual + | Op::SLessThanEqual + | Op::FOrdLessThanEqual + | Op::FUnordLessThanEqual => Ok(BinaryOperator::LessEqual), + Op::UGreaterThan | Op::SGreaterThan | Op::FOrdGreaterThan | Op::FUnordGreaterThan => { + Ok(BinaryOperator::Greater) + } + Op::UGreaterThanEqual + | Op::SGreaterThanEqual + | Op::FOrdGreaterThanEqual + | Op::FUnordGreaterThanEqual => Ok(BinaryOperator::GreaterEqual), + Op::BitwiseOr => Ok(BinaryOperator::InclusiveOr), + Op::BitwiseXor => Ok(BinaryOperator::ExclusiveOr), + Op::BitwiseAnd => Ok(BinaryOperator::And), + _ => Err(Error::UnknownBinaryOperator(word)), + } +} + +pub(super) const fn map_relational_fun( + word: spirv::Op, +) -> Result<crate::RelationalFunction, Error> { + use crate::RelationalFunction as Rf; + use spirv::Op; + + match word { + Op::All => Ok(Rf::All), + Op::Any => Ok(Rf::Any), + Op::IsNan => Ok(Rf::IsNan), + Op::IsInf => Ok(Rf::IsInf), + _ => Err(Error::UnknownRelationalFunction(word)), + } +} + +pub(super) const fn map_vector_size(word: spirv::Word) -> Result<crate::VectorSize, Error> { + match word { + 2 => Ok(crate::VectorSize::Bi), + 3 => Ok(crate::VectorSize::Tri), + 4 => Ok(crate::VectorSize::Quad), + _ => Err(Error::InvalidVectorSize(word)), + } +} + +pub(super) fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> { + use spirv::Dim as D; + match D::from_u32(word) { + Some(D::Dim1D) => Ok(crate::ImageDimension::D1), + Some(D::Dim2D) => Ok(crate::ImageDimension::D2), + Some(D::Dim3D) => Ok(crate::ImageDimension::D3), + Some(D::DimCube) => Ok(crate::ImageDimension::Cube), + _ => Err(Error::UnsupportedImageDim(word)), + } +} + +pub(super) fn map_image_format(word: spirv::Word) -> Result<crate::StorageFormat, Error> { + match spirv::ImageFormat::from_u32(word) { + Some(spirv::ImageFormat::R8) => Ok(crate::StorageFormat::R8Unorm), + Some(spirv::ImageFormat::R8Snorm) => Ok(crate::StorageFormat::R8Snorm), + Some(spirv::ImageFormat::R8ui) => Ok(crate::StorageFormat::R8Uint), + Some(spirv::ImageFormat::R8i) => Ok(crate::StorageFormat::R8Sint), + Some(spirv::ImageFormat::R16) => Ok(crate::StorageFormat::R16Unorm), + Some(spirv::ImageFormat::R16Snorm) => Ok(crate::StorageFormat::R16Snorm), + Some(spirv::ImageFormat::R16ui) => Ok(crate::StorageFormat::R16Uint), + Some(spirv::ImageFormat::R16i) => Ok(crate::StorageFormat::R16Sint), + Some(spirv::ImageFormat::R16f) => Ok(crate::StorageFormat::R16Float), + Some(spirv::ImageFormat::Rg8) => Ok(crate::StorageFormat::Rg8Unorm), + Some(spirv::ImageFormat::Rg8Snorm) => Ok(crate::StorageFormat::Rg8Snorm), + Some(spirv::ImageFormat::Rg8ui) => Ok(crate::StorageFormat::Rg8Uint), + Some(spirv::ImageFormat::Rg8i) => Ok(crate::StorageFormat::Rg8Sint), + Some(spirv::ImageFormat::R32ui) => Ok(crate::StorageFormat::R32Uint), + Some(spirv::ImageFormat::R32i) => Ok(crate::StorageFormat::R32Sint), + Some(spirv::ImageFormat::R32f) => Ok(crate::StorageFormat::R32Float), + Some(spirv::ImageFormat::Rg16) => Ok(crate::StorageFormat::Rg16Unorm), + Some(spirv::ImageFormat::Rg16Snorm) => Ok(crate::StorageFormat::Rg16Snorm), + Some(spirv::ImageFormat::Rg16ui) => Ok(crate::StorageFormat::Rg16Uint), + Some(spirv::ImageFormat::Rg16i) => Ok(crate::StorageFormat::Rg16Sint), + Some(spirv::ImageFormat::Rg16f) => Ok(crate::StorageFormat::Rg16Float), + Some(spirv::ImageFormat::Rgba8) => Ok(crate::StorageFormat::Rgba8Unorm), + Some(spirv::ImageFormat::Rgba8Snorm) => Ok(crate::StorageFormat::Rgba8Snorm), + Some(spirv::ImageFormat::Rgba8ui) => Ok(crate::StorageFormat::Rgba8Uint), + Some(spirv::ImageFormat::Rgba8i) => Ok(crate::StorageFormat::Rgba8Sint), + Some(spirv::ImageFormat::Rgb10a2ui) => Ok(crate::StorageFormat::Rgb10a2Uint), + Some(spirv::ImageFormat::Rgb10A2) => Ok(crate::StorageFormat::Rgb10a2Unorm), + Some(spirv::ImageFormat::R11fG11fB10f) => Ok(crate::StorageFormat::Rg11b10Float), + Some(spirv::ImageFormat::Rg32ui) => Ok(crate::StorageFormat::Rg32Uint), + Some(spirv::ImageFormat::Rg32i) => Ok(crate::StorageFormat::Rg32Sint), + Some(spirv::ImageFormat::Rg32f) => Ok(crate::StorageFormat::Rg32Float), + Some(spirv::ImageFormat::Rgba16) => Ok(crate::StorageFormat::Rgba16Unorm), + Some(spirv::ImageFormat::Rgba16Snorm) => Ok(crate::StorageFormat::Rgba16Snorm), + Some(spirv::ImageFormat::Rgba16ui) => Ok(crate::StorageFormat::Rgba16Uint), + Some(spirv::ImageFormat::Rgba16i) => Ok(crate::StorageFormat::Rgba16Sint), + Some(spirv::ImageFormat::Rgba16f) => Ok(crate::StorageFormat::Rgba16Float), + Some(spirv::ImageFormat::Rgba32ui) => Ok(crate::StorageFormat::Rgba32Uint), + Some(spirv::ImageFormat::Rgba32i) => Ok(crate::StorageFormat::Rgba32Sint), + Some(spirv::ImageFormat::Rgba32f) => Ok(crate::StorageFormat::Rgba32Float), + _ => Err(Error::UnsupportedImageFormat(word)), + } +} + +pub(super) fn map_width(word: spirv::Word) -> Result<crate::Bytes, Error> { + (word >> 3) // bits to bytes + .try_into() + .map_err(|_| Error::InvalidTypeWidth(word)) +} + +pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::BuiltIn, Error> { + use spirv::BuiltIn as Bi; + Ok(match spirv::BuiltIn::from_u32(word) { + Some(Bi::Position | Bi::FragCoord) => crate::BuiltIn::Position { invariant }, + Some(Bi::ViewIndex) => crate::BuiltIn::ViewIndex, + // vertex + Some(Bi::BaseInstance) => crate::BuiltIn::BaseInstance, + Some(Bi::BaseVertex) => crate::BuiltIn::BaseVertex, + Some(Bi::ClipDistance) => crate::BuiltIn::ClipDistance, + Some(Bi::CullDistance) => crate::BuiltIn::CullDistance, + Some(Bi::InstanceIndex) => crate::BuiltIn::InstanceIndex, + Some(Bi::PointSize) => crate::BuiltIn::PointSize, + Some(Bi::VertexIndex) => crate::BuiltIn::VertexIndex, + // fragment + Some(Bi::FragDepth) => crate::BuiltIn::FragDepth, + Some(Bi::PointCoord) => crate::BuiltIn::PointCoord, + Some(Bi::FrontFacing) => crate::BuiltIn::FrontFacing, + Some(Bi::PrimitiveId) => crate::BuiltIn::PrimitiveIndex, + Some(Bi::SampleId) => crate::BuiltIn::SampleIndex, + Some(Bi::SampleMask) => crate::BuiltIn::SampleMask, + // compute + Some(Bi::GlobalInvocationId) => crate::BuiltIn::GlobalInvocationId, + Some(Bi::LocalInvocationId) => crate::BuiltIn::LocalInvocationId, + Some(Bi::LocalInvocationIndex) => crate::BuiltIn::LocalInvocationIndex, + Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId, + Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize, + Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups, + _ => return Err(Error::UnsupportedBuiltIn(word)), + }) +} + +pub(super) fn map_storage_class(word: spirv::Word) -> Result<super::ExtendedClass, Error> { + use super::ExtendedClass as Ec; + use spirv::StorageClass as Sc; + Ok(match Sc::from_u32(word) { + Some(Sc::Function) => Ec::Global(crate::AddressSpace::Function), + Some(Sc::Input) => Ec::Input, + Some(Sc::Output) => Ec::Output, + Some(Sc::Private) => Ec::Global(crate::AddressSpace::Private), + Some(Sc::UniformConstant) => Ec::Global(crate::AddressSpace::Handle), + Some(Sc::StorageBuffer) => Ec::Global(crate::AddressSpace::Storage { + //Note: this is restricted by decorations later + access: crate::StorageAccess::all(), + }), + // we expect the `Storage` case to be filtered out before calling this function. + Some(Sc::Uniform) => Ec::Global(crate::AddressSpace::Uniform), + Some(Sc::Workgroup) => Ec::Global(crate::AddressSpace::WorkGroup), + Some(Sc::PushConstant) => Ec::Global(crate::AddressSpace::PushConstant), + _ => return Err(Error::UnsupportedStorageClass(word)), + }) +} diff --git a/third_party/rust/naga/src/front/spv/error.rs b/third_party/rust/naga/src/front/spv/error.rs new file mode 100644 index 0000000000..af025636c0 --- /dev/null +++ b/third_party/rust/naga/src/front/spv/error.rs @@ -0,0 +1,154 @@ +use super::ModuleState; +use crate::arena::Handle; +use codespan_reporting::diagnostic::Diagnostic; +use codespan_reporting::files::SimpleFile; +use codespan_reporting::term; +use termcolor::{NoColor, WriteColor}; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("invalid header")] + InvalidHeader, + #[error("invalid word count")] + InvalidWordCount, + #[error("unknown instruction {0}")] + UnknownInstruction(u16), + #[error("unknown capability %{0}")] + UnknownCapability(spirv::Word), + #[error("unsupported instruction {1:?} at {0:?}")] + UnsupportedInstruction(ModuleState, spirv::Op), + #[error("unsupported capability {0:?}")] + UnsupportedCapability(spirv::Capability), + #[error("unsupported extension {0}")] + UnsupportedExtension(String), + #[error("unsupported extension set {0}")] + UnsupportedExtSet(String), + #[error("unsupported extension instantiation set %{0}")] + UnsupportedExtInstSet(spirv::Word), + #[error("unsupported extension instantiation %{0}")] + UnsupportedExtInst(spirv::Word), + #[error("unsupported type {0:?}")] + UnsupportedType(Handle<crate::Type>), + #[error("unsupported execution model %{0}")] + UnsupportedExecutionModel(spirv::Word), + #[error("unsupported execution mode %{0}")] + UnsupportedExecutionMode(spirv::Word), + #[error("unsupported storage class %{0}")] + UnsupportedStorageClass(spirv::Word), + #[error("unsupported image dimension %{0}")] + UnsupportedImageDim(spirv::Word), + #[error("unsupported image format %{0}")] + UnsupportedImageFormat(spirv::Word), + #[error("unsupported builtin %{0}")] + UnsupportedBuiltIn(spirv::Word), + #[error("unsupported control flow %{0}")] + UnsupportedControlFlow(spirv::Word), + #[error("unsupported binary operator %{0}")] + UnsupportedBinaryOperator(spirv::Word), + #[error("Naga supports OpTypeRuntimeArray in the StorageBuffer storage class only")] + UnsupportedRuntimeArrayStorageClass, + #[error("unsupported matrix stride {stride} for a {columns}x{rows} matrix with scalar width={width}")] + UnsupportedMatrixStride { + stride: u32, + columns: u8, + rows: u8, + width: u8, + }, + #[error("unknown binary operator {0:?}")] + UnknownBinaryOperator(spirv::Op), + #[error("unknown relational function {0:?}")] + UnknownRelationalFunction(spirv::Op), + #[error("invalid parameter {0:?}")] + InvalidParameter(spirv::Op), + #[error("invalid operand count {1} for {0:?}")] + InvalidOperandCount(spirv::Op, u16), + #[error("invalid operand")] + InvalidOperand, + #[error("invalid id %{0}")] + InvalidId(spirv::Word), + #[error("invalid decoration %{0}")] + InvalidDecoration(spirv::Word), + #[error("invalid type width %{0}")] + InvalidTypeWidth(spirv::Word), + #[error("invalid sign %{0}")] + InvalidSign(spirv::Word), + #[error("invalid inner type %{0}")] + InvalidInnerType(spirv::Word), + #[error("invalid vector size %{0}")] + InvalidVectorSize(spirv::Word), + #[error("invalid access type %{0}")] + InvalidAccessType(spirv::Word), + #[error("invalid access {0:?}")] + InvalidAccess(crate::Expression), + #[error("invalid access index %{0}")] + InvalidAccessIndex(spirv::Word), + #[error("invalid index type %{0}")] + InvalidIndexType(spirv::Word), + #[error("invalid binding %{0}")] + InvalidBinding(spirv::Word), + #[error("invalid global var {0:?}")] + InvalidGlobalVar(crate::Expression), + #[error("invalid image/sampler expression {0:?}")] + InvalidImageExpression(crate::Expression), + #[error("invalid image base type {0:?}")] + InvalidImageBaseType(Handle<crate::Type>), + #[error("invalid image {0:?}")] + InvalidImage(Handle<crate::Type>), + #[error("invalid as type {0:?}")] + InvalidAsType(Handle<crate::Type>), + #[error("invalid vector type {0:?}")] + InvalidVectorType(Handle<crate::Type>), + #[error("inconsistent comparison sampling {0:?}")] + InconsistentComparisonSampling(Handle<crate::GlobalVariable>), + #[error("wrong function result type %{0}")] + WrongFunctionResultType(spirv::Word), + #[error("wrong function argument type %{0}")] + WrongFunctionArgumentType(spirv::Word), + #[error("missing decoration {0:?}")] + MissingDecoration(spirv::Decoration), + #[error("bad string")] + BadString, + #[error("incomplete data")] + IncompleteData, + #[error("invalid terminator")] + InvalidTerminator, + #[error("invalid edge classification")] + InvalidEdgeClassification, + #[error("cycle detected in the CFG during traversal at {0}")] + ControlFlowGraphCycle(crate::front::spv::BlockId), + #[error("recursive function call %{0}")] + FunctionCallCycle(spirv::Word), + #[error("invalid array size {0:?}")] + InvalidArraySize(Handle<crate::Constant>), + #[error("invalid barrier scope %{0}")] + InvalidBarrierScope(spirv::Word), + #[error("invalid barrier memory semantics %{0}")] + InvalidBarrierMemorySemantics(spirv::Word), + #[error( + "arrays of images / samplers are supported only through bindings for \ + now (i.e. you can't create an array of images or samplers that doesn't \ + come from a binding)" + )] + NonBindingArrayOfImageOrSamplers, +} + +impl Error { + pub fn emit_to_writer(&self, writer: &mut impl WriteColor, source: &str) { + self.emit_to_writer_with_path(writer, source, "glsl"); + } + + pub fn emit_to_writer_with_path(&self, writer: &mut impl WriteColor, source: &str, path: &str) { + let path = path.to_string(); + let files = SimpleFile::new(path, source); + let config = codespan_reporting::term::Config::default(); + let diagnostic = Diagnostic::error().with_message(format!("{self:?}")); + + term::emit(writer, &config, &files, &diagnostic).expect("cannot write error"); + } + + pub fn emit_to_string(&self, source: &str) -> String { + let mut writer = NoColor::new(Vec::new()); + self.emit_to_writer(&mut writer, source); + String::from_utf8(writer.into_inner()).unwrap() + } +} diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs new file mode 100644 index 0000000000..198d9c52dd --- /dev/null +++ b/third_party/rust/naga/src/front/spv/function.rs @@ -0,0 +1,674 @@ +use crate::{ + arena::{Arena, Handle}, + front::spv::{BlockContext, BodyIndex}, +}; + +use super::{Error, Instruction, LookupExpression, LookupHelper as _}; +use crate::proc::Emitter; + +pub type BlockId = u32; + +#[derive(Copy, Clone, Debug)] +pub struct MergeInstruction { + pub merge_block_id: BlockId, + pub continue_block_id: Option<BlockId>, +} + +impl<I: Iterator<Item = u32>> super::Frontend<I> { + // Registers a function call. It will generate a dummy handle to call, which + // gets resolved after all the functions are processed. + pub(super) fn add_call( + &mut self, + from: spirv::Word, + to: spirv::Word, + ) -> Handle<crate::Function> { + let dummy_handle = self + .dummy_functions + .append(crate::Function::default(), Default::default()); + self.deferred_function_calls.push(to); + self.function_call_graph.add_edge(from, to, ()); + dummy_handle + } + + pub(super) fn parse_function(&mut self, module: &mut crate::Module) -> Result<(), Error> { + let start = self.data_offset; + self.lookup_expression.clear(); + self.lookup_load_override.clear(); + self.lookup_sampled_image.clear(); + + let result_type_id = self.next()?; + let fun_id = self.next()?; + let _fun_control = self.next()?; + let fun_type_id = self.next()?; + + let mut fun = { + let ft = self.lookup_function_type.lookup(fun_type_id)?; + if ft.return_type_id != result_type_id { + return Err(Error::WrongFunctionResultType(result_type_id)); + } + crate::Function { + name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name), + arguments: Vec::with_capacity(ft.parameter_type_ids.len()), + result: if self.lookup_void_type == Some(result_type_id) { + None + } else { + let lookup_result_ty = self.lookup_type.lookup(result_type_id)?; + Some(crate::FunctionResult { + ty: lookup_result_ty.handle, + binding: None, + }) + }, + local_variables: Arena::new(), + expressions: self + .make_expression_storage(&module.global_variables, &module.constants), + named_expressions: crate::NamedExpressions::default(), + body: crate::Block::new(), + } + }; + + // read parameters + for i in 0..fun.arguments.capacity() { + let start = self.data_offset; + match self.next_inst()? { + Instruction { + op: spirv::Op::FunctionParameter, + wc: 3, + } => { + let type_id = self.next()?; + let id = self.next()?; + let handle = fun.expressions.append( + crate::Expression::FunctionArgument(i as u32), + self.span_from(start), + ); + self.lookup_expression.insert( + id, + LookupExpression { + handle, + type_id, + // Setting this to an invalid id will cause get_expr_handle + // to default to the main body making sure no load/stores + // are added. + block_id: 0, + }, + ); + //Note: we redo the lookup in order to work around `self` borrowing + + if type_id + != self + .lookup_function_type + .lookup(fun_type_id)? + .parameter_type_ids[i] + { + return Err(Error::WrongFunctionArgumentType(type_id)); + } + let ty = self.lookup_type.lookup(type_id)?.handle; + let decor = self.future_decor.remove(&id).unwrap_or_default(); + fun.arguments.push(crate::FunctionArgument { + name: decor.name, + ty, + binding: None, + }); + } + Instruction { op, .. } => return Err(Error::InvalidParameter(op)), + } + } + + // Read body + self.function_call_graph.add_node(fun_id); + let mut parameters_sampling = + vec![super::image::SamplingFlags::empty(); fun.arguments.len()]; + + let mut block_ctx = BlockContext { + phis: Default::default(), + blocks: Default::default(), + body_for_label: Default::default(), + mergers: Default::default(), + bodies: Default::default(), + function_id: fun_id, + expressions: &mut fun.expressions, + local_arena: &mut fun.local_variables, + const_arena: &mut module.constants, + const_expressions: &mut module.const_expressions, + type_arena: &module.types, + global_arena: &module.global_variables, + arguments: &fun.arguments, + parameter_sampling: &mut parameters_sampling, + }; + // Insert the main body whose parent is also himself + block_ctx.bodies.push(super::Body::with_parent(0)); + + // Scan the blocks and add them as nodes + loop { + let fun_inst = self.next_inst()?; + log::debug!("{:?}", fun_inst.op); + match fun_inst.op { + spirv::Op::Line => { + fun_inst.expect(4)?; + let _file_id = self.next()?; + let _row_id = self.next()?; + let _col_id = self.next()?; + } + spirv::Op::Label => { + // Read the label ID + fun_inst.expect(2)?; + let block_id = self.next()?; + + self.next_block(block_id, &mut block_ctx)?; + } + spirv::Op::FunctionEnd => { + fun_inst.expect(1)?; + break; + } + _ => { + return Err(Error::UnsupportedInstruction(self.state, fun_inst.op)); + } + } + } + + if let Some(ref prefix) = self.options.block_ctx_dump_prefix { + let dump_suffix = match self.lookup_entry_point.get(&fun_id) { + Some(ep) => format!("block_ctx.{:?}-{}.txt", ep.stage, ep.name), + None => format!("block_ctx.Fun-{}.txt", module.functions.len()), + }; + let dest = prefix.join(dump_suffix); + let dump = format!("{block_ctx:#?}"); + if let Err(e) = std::fs::write(&dest, dump) { + log::error!("Unable to dump the block context into {:?}: {}", dest, e); + } + } + + // Emit `Store` statements to properly initialize all the local variables we + // created for `phi` expressions. + // + // Note that get_expr_handle also contributes slightly odd entries to this table, + // to get the spill. + for phi in block_ctx.phis.iter() { + // Get a pointer to the local variable for the phi's value. + let phi_pointer = block_ctx.expressions.append( + crate::Expression::LocalVariable(phi.local), + crate::Span::default(), + ); + + // At the end of each of `phi`'s predecessor blocks, store the corresponding + // source value in the phi's local variable. + for &(source, predecessor) in phi.expressions.iter() { + let source_lexp = &self.lookup_expression[&source]; + let predecessor_body_idx = block_ctx.body_for_label[&predecessor]; + // If the expression is a global/argument it will have a 0 block + // id so we must use a default value instead of panicking + let source_body_idx = block_ctx + .body_for_label + .get(&source_lexp.block_id) + .copied() + .unwrap_or(0); + + // If the Naga `Expression` generated for `source` is in scope, then we + // can simply store that in the phi's local variable. + // + // Otherwise, spill the source value to a local variable in the block that + // defines it. (We know this store dominates the predecessor; otherwise, + // the phi wouldn't have been able to refer to that source expression in + // the first place.) Then, the predecessor block can count on finding the + // source's value in that local variable. + let value = if super::is_parent(predecessor_body_idx, source_body_idx, &block_ctx) { + source_lexp.handle + } else { + // The source SPIR-V expression is not defined in the phi's + // predecessor block, nor is it a globally available expression. So it + // must be defined off in some other block that merely dominates the + // predecessor. This means that the corresponding Naga `Expression` + // may not be in scope in the predecessor block. + // + // In the block that defines `source`, spill it to a fresh local + // variable, to ensure we can still use it at the end of the + // predecessor. + let ty = self.lookup_type[&source_lexp.type_id].handle; + let local = block_ctx.local_arena.append( + crate::LocalVariable { + name: None, + ty, + init: None, + }, + crate::Span::default(), + ); + + let pointer = block_ctx.expressions.append( + crate::Expression::LocalVariable(local), + crate::Span::default(), + ); + + // Get the spilled value of the source expression. + let start = block_ctx.expressions.len(); + let expr = block_ctx + .expressions + .append(crate::Expression::Load { pointer }, crate::Span::default()); + let range = block_ctx.expressions.range_from(start); + + block_ctx + .blocks + .get_mut(&predecessor) + .unwrap() + .push(crate::Statement::Emit(range), crate::Span::default()); + + // At the end of the block that defines it, spill the source + // expression's value. + block_ctx + .blocks + .get_mut(&source_lexp.block_id) + .unwrap() + .push( + crate::Statement::Store { + pointer, + value: source_lexp.handle, + }, + crate::Span::default(), + ); + + expr + }; + + // At the end of the phi predecessor block, store the source + // value in the phi's value. + block_ctx.blocks.get_mut(&predecessor).unwrap().push( + crate::Statement::Store { + pointer: phi_pointer, + value, + }, + crate::Span::default(), + ) + } + } + + fun.body = block_ctx.lower(); + + // done + let fun_handle = module.functions.append(fun, self.span_from_with_op(start)); + self.lookup_function.insert( + fun_id, + super::LookupFunction { + handle: fun_handle, + parameters_sampling, + }, + ); + + if let Some(ep) = self.lookup_entry_point.remove(&fun_id) { + // create a wrapping function + let mut function = crate::Function { + name: Some(format!("{}_wrap", ep.name)), + arguments: Vec::new(), + result: None, + local_variables: Arena::new(), + expressions: Arena::new(), + named_expressions: crate::NamedExpressions::default(), + body: crate::Block::new(), + }; + + // 1. copy the inputs from arguments to privates + for &v_id in ep.variable_ids.iter() { + let lvar = self.lookup_variable.lookup(v_id)?; + if let super::Variable::Input(ref arg) = lvar.inner { + let span = module.global_variables.get_span(lvar.handle); + let arg_expr = function.expressions.append( + crate::Expression::FunctionArgument(function.arguments.len() as u32), + span, + ); + let load_expr = if arg.ty == module.global_variables[lvar.handle].ty { + arg_expr + } else { + // The only case where the type is different is if we need to treat + // unsigned integer as signed. + let mut emitter = Emitter::default(); + emitter.start(&function.expressions); + let handle = function.expressions.append( + crate::Expression::As { + expr: arg_expr, + kind: crate::ScalarKind::Sint, + convert: Some(4), + }, + span, + ); + function.body.extend(emitter.finish(&function.expressions)); + handle + }; + function.body.push( + crate::Statement::Store { + pointer: function + .expressions + .append(crate::Expression::GlobalVariable(lvar.handle), span), + value: load_expr, + }, + span, + ); + + let mut arg = arg.clone(); + if ep.stage == crate::ShaderStage::Fragment { + if let Some(ref mut binding) = arg.binding { + binding.apply_default_interpolation(&module.types[arg.ty].inner); + } + } + function.arguments.push(arg); + } + } + // 2. call the wrapped function + let fake_id = !(module.entry_points.len() as u32); // doesn't matter, as long as it's not a collision + let dummy_handle = self.add_call(fake_id, fun_id); + function.body.push( + crate::Statement::Call { + function: dummy_handle, + arguments: Vec::new(), + result: None, + }, + crate::Span::default(), + ); + + // 3. copy the outputs from privates to the result + let mut members = Vec::new(); + let mut components = Vec::new(); + for &v_id in ep.variable_ids.iter() { + let lvar = self.lookup_variable.lookup(v_id)?; + if let super::Variable::Output(ref result) = lvar.inner { + let span = module.global_variables.get_span(lvar.handle); + let expr_handle = function + .expressions + .append(crate::Expression::GlobalVariable(lvar.handle), span); + + // Cull problematic builtins of gl_PerVertex. + // See the docs for `Frontend::gl_per_vertex_builtin_access`. + { + let ty = &module.types[result.ty]; + match ty.inner { + crate::TypeInner::Struct { + members: ref original_members, + span, + } if ty.name.as_deref() == Some("gl_PerVertex") => { + let mut new_members = original_members.clone(); + for member in &mut new_members { + if let Some(crate::Binding::BuiltIn(built_in)) = member.binding + { + if !self.gl_per_vertex_builtin_access.contains(&built_in) { + member.binding = None + } + } + } + if &new_members != original_members { + module.types.replace( + result.ty, + crate::Type { + name: ty.name.clone(), + inner: crate::TypeInner::Struct { + members: new_members, + span, + }, + }, + ); + } + } + _ => {} + } + } + + match module.types[result.ty].inner { + crate::TypeInner::Struct { + members: ref sub_members, + .. + } => { + for (index, sm) in sub_members.iter().enumerate() { + if sm.binding.is_none() { + continue; + } + let mut sm = sm.clone(); + + if let Some(ref mut binding) = sm.binding { + if ep.stage == crate::ShaderStage::Vertex { + binding.apply_default_interpolation( + &module.types[sm.ty].inner, + ); + } + } + + members.push(sm); + + components.push(function.expressions.append( + crate::Expression::AccessIndex { + base: expr_handle, + index: index as u32, + }, + span, + )); + } + } + ref inner => { + let mut binding = result.binding.clone(); + if let Some(ref mut binding) = binding { + if ep.stage == crate::ShaderStage::Vertex { + binding.apply_default_interpolation(inner); + } + } + + members.push(crate::StructMember { + name: None, + ty: result.ty, + binding, + offset: 0, + }); + // populate just the globals first, then do `Load` in a + // separate step, so that we can get a range. + components.push(expr_handle); + } + } + } + } + + for (member_index, member) in members.iter().enumerate() { + match member.binding { + Some(crate::Binding::BuiltIn(crate::BuiltIn::Position { .. })) + if self.options.adjust_coordinate_space => + { + let mut emitter = Emitter::default(); + emitter.start(&function.expressions); + let global_expr = components[member_index]; + let span = function.expressions.get_span(global_expr); + let access_expr = function.expressions.append( + crate::Expression::AccessIndex { + base: global_expr, + index: 1, + }, + span, + ); + let load_expr = function.expressions.append( + crate::Expression::Load { + pointer: access_expr, + }, + span, + ); + let neg_expr = function.expressions.append( + crate::Expression::Unary { + op: crate::UnaryOperator::Negate, + expr: load_expr, + }, + span, + ); + function.body.extend(emitter.finish(&function.expressions)); + function.body.push( + crate::Statement::Store { + pointer: access_expr, + value: neg_expr, + }, + span, + ); + } + _ => {} + } + } + + let mut emitter = Emitter::default(); + emitter.start(&function.expressions); + for component in components.iter_mut() { + let load_expr = crate::Expression::Load { + pointer: *component, + }; + let span = function.expressions.get_span(*component); + *component = function.expressions.append(load_expr, span); + } + + match members[..] { + [] => {} + [ref member] => { + function.body.extend(emitter.finish(&function.expressions)); + let span = function.expressions.get_span(components[0]); + function.body.push( + crate::Statement::Return { + value: components.first().cloned(), + }, + span, + ); + function.result = Some(crate::FunctionResult { + ty: member.ty, + binding: member.binding.clone(), + }); + } + _ => { + let span = crate::Span::total_span( + components.iter().map(|h| function.expressions.get_span(*h)), + ); + let ty = module.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Struct { + members, + span: 0xFFFF, // shouldn't matter + }, + }, + span, + ); + let result_expr = function + .expressions + .append(crate::Expression::Compose { ty, components }, span); + function.body.extend(emitter.finish(&function.expressions)); + function.body.push( + crate::Statement::Return { + value: Some(result_expr), + }, + span, + ); + function.result = Some(crate::FunctionResult { ty, binding: None }); + } + } + + module.entry_points.push(crate::EntryPoint { + name: ep.name, + stage: ep.stage, + early_depth_test: ep.early_depth_test, + workgroup_size: ep.workgroup_size, + function, + }); + } + + Ok(()) + } +} + +impl<'function> BlockContext<'function> { + pub(super) fn gctx(&self) -> crate::proc::GlobalCtx { + crate::proc::GlobalCtx { + types: self.type_arena, + constants: self.const_arena, + const_expressions: self.const_expressions, + } + } + + /// Consumes the `BlockContext` producing a Ir [`Block`](crate::Block) + fn lower(mut self) -> crate::Block { + fn lower_impl( + blocks: &mut crate::FastHashMap<spirv::Word, crate::Block>, + bodies: &[super::Body], + body_idx: BodyIndex, + ) -> crate::Block { + let mut block = crate::Block::new(); + + for item in bodies[body_idx].data.iter() { + match *item { + super::BodyFragment::BlockId(id) => block.append(blocks.get_mut(&id).unwrap()), + super::BodyFragment::If { + condition, + accept, + reject, + } => { + let accept = lower_impl(blocks, bodies, accept); + let reject = lower_impl(blocks, bodies, reject); + + block.push( + crate::Statement::If { + condition, + accept, + reject, + }, + crate::Span::default(), + ) + } + super::BodyFragment::Loop { + body, + continuing, + break_if, + } => { + let body = lower_impl(blocks, bodies, body); + let continuing = lower_impl(blocks, bodies, continuing); + + block.push( + crate::Statement::Loop { + body, + continuing, + break_if, + }, + crate::Span::default(), + ) + } + super::BodyFragment::Switch { + selector, + ref cases, + default, + } => { + let mut ir_cases: Vec<_> = cases + .iter() + .map(|&(value, body_idx)| { + let body = lower_impl(blocks, bodies, body_idx); + + // Handle simple cases that would make a fallthrough statement unreachable code + let fall_through = body.last().map_or(true, |s| !s.is_terminator()); + + crate::SwitchCase { + value: crate::SwitchValue::I32(value), + body, + fall_through, + } + }) + .collect(); + ir_cases.push(crate::SwitchCase { + value: crate::SwitchValue::Default, + body: lower_impl(blocks, bodies, default), + fall_through: false, + }); + + block.push( + crate::Statement::Switch { + selector, + cases: ir_cases, + }, + crate::Span::default(), + ) + } + super::BodyFragment::Break => { + block.push(crate::Statement::Break, crate::Span::default()) + } + super::BodyFragment::Continue => { + block.push(crate::Statement::Continue, crate::Span::default()) + } + } + } + + block + } + + lower_impl(&mut self.blocks, &self.bodies, 0) + } +} diff --git a/third_party/rust/naga/src/front/spv/image.rs b/third_party/rust/naga/src/front/spv/image.rs new file mode 100644 index 0000000000..0f25dd626b --- /dev/null +++ b/third_party/rust/naga/src/front/spv/image.rs @@ -0,0 +1,767 @@ +use crate::{ + arena::{Handle, UniqueArena}, + Scalar, +}; + +use super::{Error, LookupExpression, LookupHelper as _}; + +#[derive(Clone, Debug)] +pub(super) struct LookupSampledImage { + image: Handle<crate::Expression>, + sampler: Handle<crate::Expression>, +} + +bitflags::bitflags! { + /// Flags describing sampling method. + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct SamplingFlags: u32 { + /// Regular sampling. + const REGULAR = 0x1; + /// Comparison sampling. + const COMPARISON = 0x2; + } +} + +impl<'function> super::BlockContext<'function> { + fn get_image_expr_ty( + &self, + handle: Handle<crate::Expression>, + ) -> Result<Handle<crate::Type>, Error> { + match self.expressions[handle] { + crate::Expression::GlobalVariable(handle) => Ok(self.global_arena[handle].ty), + crate::Expression::FunctionArgument(i) => Ok(self.arguments[i as usize].ty), + ref other => Err(Error::InvalidImageExpression(other.clone())), + } + } +} + +/// Options of a sampling operation. +#[derive(Debug)] +pub struct SamplingOptions { + /// Projection sampling: the division by W is expected to happen + /// in the texture unit. + pub project: bool, + /// Depth comparison sampling with a reference value. + pub compare: bool, +} + +enum ExtraCoordinate { + ArrayLayer, + Projection, + Garbage, +} + +/// Return the texture coordinates separated from the array layer, +/// and/or divided by the projection term. +/// +/// The Proj sampling ops expect an extra coordinate for the W. +/// The arrayed (can't be Proj!) images expect an extra coordinate for the layer. +fn extract_image_coordinates( + image_dim: crate::ImageDimension, + extra_coordinate: ExtraCoordinate, + base: Handle<crate::Expression>, + coordinate_ty: Handle<crate::Type>, + ctx: &mut super::BlockContext, +) -> (Handle<crate::Expression>, Option<Handle<crate::Expression>>) { + let (given_size, kind) = match ctx.type_arena[coordinate_ty].inner { + crate::TypeInner::Scalar(Scalar { kind, .. }) => (None, kind), + crate::TypeInner::Vector { + size, + scalar: Scalar { kind, .. }, + } => (Some(size), kind), + ref other => unreachable!("Unexpected texture coordinate {:?}", other), + }; + + let required_size = image_dim.required_coordinate_size(); + let required_ty = required_size.map(|size| { + ctx.type_arena + .get(&crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size, + scalar: Scalar { kind, width: 4 }, + }, + }) + .expect("Required coordinate type should have been set up by `parse_type_image`!") + }); + let extra_expr = crate::Expression::AccessIndex { + base, + index: required_size.map_or(1, |size| size as u32), + }; + + let base_span = ctx.expressions.get_span(base); + + match extra_coordinate { + ExtraCoordinate::ArrayLayer => { + let extracted = match required_size { + None => ctx + .expressions + .append(crate::Expression::AccessIndex { base, index: 0 }, base_span), + Some(size) => { + let mut components = Vec::with_capacity(size as usize); + for index in 0..size as u32 { + let comp = ctx + .expressions + .append(crate::Expression::AccessIndex { base, index }, base_span); + components.push(comp); + } + ctx.expressions.append( + crate::Expression::Compose { + ty: required_ty.unwrap(), + components, + }, + base_span, + ) + } + }; + let array_index_f32 = ctx.expressions.append(extra_expr, base_span); + let array_index = ctx.expressions.append( + crate::Expression::As { + kind: crate::ScalarKind::Sint, + expr: array_index_f32, + convert: Some(4), + }, + base_span, + ); + (extracted, Some(array_index)) + } + ExtraCoordinate::Projection => { + let projection = ctx.expressions.append(extra_expr, base_span); + let divided = match required_size { + None => { + let temp = ctx + .expressions + .append(crate::Expression::AccessIndex { base, index: 0 }, base_span); + ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Divide, + left: temp, + right: projection, + }, + base_span, + ) + } + Some(size) => { + let mut components = Vec::with_capacity(size as usize); + for index in 0..size as u32 { + let temp = ctx + .expressions + .append(crate::Expression::AccessIndex { base, index }, base_span); + let comp = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Divide, + left: temp, + right: projection, + }, + base_span, + ); + components.push(comp); + } + ctx.expressions.append( + crate::Expression::Compose { + ty: required_ty.unwrap(), + components, + }, + base_span, + ) + } + }; + (divided, None) + } + ExtraCoordinate::Garbage if given_size == required_size => (base, None), + ExtraCoordinate::Garbage => { + use crate::SwizzleComponent as Sc; + let cut_expr = match required_size { + None => crate::Expression::AccessIndex { base, index: 0 }, + Some(size) => crate::Expression::Swizzle { + size, + vector: base, + pattern: [Sc::X, Sc::Y, Sc::Z, Sc::W], + }, + }; + (ctx.expressions.append(cut_expr, base_span), None) + } + } +} + +pub(super) fn patch_comparison_type( + flags: SamplingFlags, + var: &mut crate::GlobalVariable, + arena: &mut UniqueArena<crate::Type>, +) -> bool { + if !flags.contains(SamplingFlags::COMPARISON) { + return true; + } + if flags == SamplingFlags::all() { + return false; + } + + log::debug!("Flipping comparison for {:?}", var); + let original_ty = &arena[var.ty]; + let original_ty_span = arena.get_span(var.ty); + let ty_inner = match original_ty.inner { + crate::TypeInner::Image { + class: crate::ImageClass::Sampled { multi, .. }, + dim, + arrayed, + } => crate::TypeInner::Image { + class: crate::ImageClass::Depth { multi }, + dim, + arrayed, + }, + crate::TypeInner::Sampler { .. } => crate::TypeInner::Sampler { comparison: true }, + ref other => unreachable!("Unexpected type for comparison mutation: {:?}", other), + }; + + let name = original_ty.name.clone(); + var.ty = arena.insert( + crate::Type { + name, + inner: ty_inner, + }, + original_ty_span, + ); + true +} + +impl<I: Iterator<Item = u32>> super::Frontend<I> { + pub(super) fn parse_image_couple(&mut self) -> Result<(), Error> { + let _result_type_id = self.next()?; + let result_id = self.next()?; + let image_id = self.next()?; + let sampler_id = self.next()?; + let image_lexp = self.lookup_expression.lookup(image_id)?; + let sampler_lexp = self.lookup_expression.lookup(sampler_id)?; + self.lookup_sampled_image.insert( + result_id, + LookupSampledImage { + image: image_lexp.handle, + sampler: sampler_lexp.handle, + }, + ); + Ok(()) + } + + pub(super) fn parse_image_uncouple(&mut self, block_id: spirv::Word) -> Result<(), Error> { + let result_type_id = self.next()?; + let result_id = self.next()?; + let sampled_image_id = self.next()?; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: self.lookup_sampled_image.lookup(sampled_image_id)?.image, + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + pub(super) fn parse_image_write( + &mut self, + words_left: u16, + ctx: &mut super::BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + body_idx: usize, + ) -> Result<crate::Statement, Error> { + let image_id = self.next()?; + let coordinate_id = self.next()?; + let value_id = self.next()?; + + let image_ops = if words_left != 0 { self.next()? } else { 0 }; + + if image_ops != 0 { + let other = spirv::ImageOperands::from_bits_truncate(image_ops); + log::warn!("Unknown image write ops {:?}", other); + for _ in 1..words_left { + self.next()?; + } + } + + let image_lexp = self.lookup_expression.lookup(image_id)?; + let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?; + + let coord_lexp = self.lookup_expression.lookup(coordinate_id)?; + let coord_handle = + self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx); + let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle; + let (coordinate, array_index) = match ctx.type_arena[image_ty].inner { + crate::TypeInner::Image { + dim, + arrayed, + class: _, + } => extract_image_coordinates( + dim, + if arrayed { + ExtraCoordinate::ArrayLayer + } else { + ExtraCoordinate::Garbage + }, + coord_handle, + coord_type_handle, + ctx, + ), + _ => return Err(Error::InvalidImage(image_ty)), + }; + + let value_lexp = self.lookup_expression.lookup(value_id)?; + let value = self.get_expr_handle(value_id, value_lexp, ctx, emitter, block, body_idx); + + Ok(crate::Statement::ImageStore { + image: image_lexp.handle, + coordinate, + array_index, + value, + }) + } + + pub(super) fn parse_image_load( + &mut self, + mut words_left: u16, + ctx: &mut super::BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let image_id = self.next()?; + let coordinate_id = self.next()?; + + let mut image_ops = if words_left != 0 { + words_left -= 1; + self.next()? + } else { + 0 + }; + + let mut sample = None; + let mut level = None; + while image_ops != 0 { + let bit = 1 << image_ops.trailing_zeros(); + match spirv::ImageOperands::from_bits_truncate(bit) { + spirv::ImageOperands::LOD => { + let lod_expr = self.next()?; + let lod_lexp = self.lookup_expression.lookup(lod_expr)?; + let lod_handle = + self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx); + level = Some(lod_handle); + words_left -= 1; + } + spirv::ImageOperands::SAMPLE => { + let sample_expr = self.next()?; + let sample_handle = self.lookup_expression.lookup(sample_expr)?.handle; + sample = Some(sample_handle); + words_left -= 1; + } + other => { + log::warn!("Unknown image load op {:?}", other); + for _ in 0..words_left { + self.next()?; + } + break; + } + } + image_ops ^= bit; + } + + // No need to call get_expr_handle here since only globals/arguments are + // allowed as images and they are always in the root scope + let image_lexp = self.lookup_expression.lookup(image_id)?; + let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?; + + let coord_lexp = self.lookup_expression.lookup(coordinate_id)?; + let coord_handle = + self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx); + let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle; + let (coordinate, array_index) = match ctx.type_arena[image_ty].inner { + crate::TypeInner::Image { + dim, + arrayed, + class: _, + } => extract_image_coordinates( + dim, + if arrayed { + ExtraCoordinate::ArrayLayer + } else { + ExtraCoordinate::Garbage + }, + coord_handle, + coord_type_handle, + ctx, + ), + _ => return Err(Error::InvalidImage(image_ty)), + }; + + let expr = crate::Expression::ImageLoad { + image: image_lexp.handle, + coordinate, + array_index, + sample, + level, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + pub(super) fn parse_image_sample( + &mut self, + mut words_left: u16, + options: SamplingOptions, + ctx: &mut super::BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let sampled_image_id = self.next()?; + let coordinate_id = self.next()?; + let dref_id = if options.compare { + Some(self.next()?) + } else { + None + }; + + let mut image_ops = if words_left != 0 { + words_left -= 1; + self.next()? + } else { + 0 + }; + + let mut level = crate::SampleLevel::Auto; + let mut offset = None; + while image_ops != 0 { + let bit = 1 << image_ops.trailing_zeros(); + match spirv::ImageOperands::from_bits_truncate(bit) { + spirv::ImageOperands::BIAS => { + let bias_expr = self.next()?; + let bias_lexp = self.lookup_expression.lookup(bias_expr)?; + let bias_handle = + self.get_expr_handle(bias_expr, bias_lexp, ctx, emitter, block, body_idx); + level = crate::SampleLevel::Bias(bias_handle); + words_left -= 1; + } + spirv::ImageOperands::LOD => { + let lod_expr = self.next()?; + let lod_lexp = self.lookup_expression.lookup(lod_expr)?; + let lod_handle = + self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx); + level = if options.compare { + log::debug!("Assuming {:?} is zero", lod_handle); + crate::SampleLevel::Zero + } else { + crate::SampleLevel::Exact(lod_handle) + }; + words_left -= 1; + } + spirv::ImageOperands::GRAD => { + let grad_x_expr = self.next()?; + let grad_x_lexp = self.lookup_expression.lookup(grad_x_expr)?; + let grad_x_handle = self.get_expr_handle( + grad_x_expr, + grad_x_lexp, + ctx, + emitter, + block, + body_idx, + ); + let grad_y_expr = self.next()?; + let grad_y_lexp = self.lookup_expression.lookup(grad_y_expr)?; + let grad_y_handle = self.get_expr_handle( + grad_y_expr, + grad_y_lexp, + ctx, + emitter, + block, + body_idx, + ); + level = if options.compare { + log::debug!( + "Assuming gradients {:?} and {:?} are not greater than 1", + grad_x_handle, + grad_y_handle + ); + crate::SampleLevel::Zero + } else { + crate::SampleLevel::Gradient { + x: grad_x_handle, + y: grad_y_handle, + } + }; + words_left -= 2; + } + spirv::ImageOperands::CONST_OFFSET => { + let offset_constant = self.next()?; + let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle; + let offset_handle = ctx.const_expressions.append( + crate::Expression::Constant(offset_handle), + Default::default(), + ); + offset = Some(offset_handle); + words_left -= 1; + } + other => { + log::warn!("Unknown image sample operand {:?}", other); + for _ in 0..words_left { + self.next()?; + } + break; + } + } + image_ops ^= bit; + } + + let si_lexp = self.lookup_sampled_image.lookup(sampled_image_id)?; + let coord_lexp = self.lookup_expression.lookup(coordinate_id)?; + let coord_handle = + self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx); + let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle; + + let sampling_bit = if options.compare { + SamplingFlags::COMPARISON + } else { + SamplingFlags::REGULAR + }; + + let image_ty = match ctx.expressions[si_lexp.image] { + crate::Expression::GlobalVariable(handle) => { + if let Some(flags) = self.handle_sampling.get_mut(&handle) { + *flags |= sampling_bit; + } + + ctx.global_arena[handle].ty + } + + crate::Expression::FunctionArgument(i) => { + ctx.parameter_sampling[i as usize] |= sampling_bit; + ctx.arguments[i as usize].ty + } + + crate::Expression::Access { base, .. } => match ctx.expressions[base] { + crate::Expression::GlobalVariable(handle) => { + if let Some(flags) = self.handle_sampling.get_mut(&handle) { + *flags |= sampling_bit; + } + + match ctx.type_arena[ctx.global_arena[handle].ty].inner { + crate::TypeInner::BindingArray { base, .. } => base, + _ => return Err(Error::InvalidGlobalVar(ctx.expressions[base].clone())), + } + } + + ref other => return Err(Error::InvalidGlobalVar(other.clone())), + }, + + ref other => return Err(Error::InvalidGlobalVar(other.clone())), + }; + + match ctx.expressions[si_lexp.sampler] { + crate::Expression::GlobalVariable(handle) => { + *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit; + } + + crate::Expression::FunctionArgument(i) => { + ctx.parameter_sampling[i as usize] |= sampling_bit; + } + + crate::Expression::Access { base, .. } => match ctx.expressions[base] { + crate::Expression::GlobalVariable(handle) => { + *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit; + } + + ref other => return Err(Error::InvalidGlobalVar(other.clone())), + }, + + ref other => return Err(Error::InvalidGlobalVar(other.clone())), + } + + let ((coordinate, array_index), depth_ref) = match ctx.type_arena[image_ty].inner { + crate::TypeInner::Image { + dim, + arrayed, + class: _, + } => ( + extract_image_coordinates( + dim, + if options.project { + ExtraCoordinate::Projection + } else if arrayed { + ExtraCoordinate::ArrayLayer + } else { + ExtraCoordinate::Garbage + }, + coord_handle, + coord_type_handle, + ctx, + ), + { + match dref_id { + Some(id) => { + let expr_lexp = self.lookup_expression.lookup(id)?; + let mut expr = + self.get_expr_handle(id, expr_lexp, ctx, emitter, block, body_idx); + + if options.project { + let required_size = dim.required_coordinate_size(); + let right = ctx.expressions.append( + crate::Expression::AccessIndex { + base: coord_handle, + index: required_size.map_or(1, |size| size as u32), + }, + crate::Span::default(), + ); + expr = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Divide, + left: expr, + right, + }, + crate::Span::default(), + ) + }; + Some(expr) + } + None => None, + } + }, + ), + _ => return Err(Error::InvalidImage(image_ty)), + }; + + let expr = crate::Expression::ImageSample { + image: si_lexp.image, + sampler: si_lexp.sampler, + gather: None, //TODO + coordinate, + array_index, + offset, + level, + depth_ref, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + pub(super) fn parse_image_query_size( + &mut self, + at_level: bool, + ctx: &mut super::BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let image_id = self.next()?; + let level = if at_level { + let level_id = self.next()?; + let level_lexp = self.lookup_expression.lookup(level_id)?; + Some(self.get_expr_handle(level_id, level_lexp, ctx, emitter, block, body_idx)) + } else { + None + }; + + // No need to call get_expr_handle here since only globals/arguments are + // allowed as images and they are always in the root scope + //TODO: handle arrays and cubes + let image_lexp = self.lookup_expression.lookup(image_id)?; + + let expr = crate::Expression::ImageQuery { + image: image_lexp.handle, + query: crate::ImageQuery::Size { level }, + }; + + let result_type_handle = self.lookup_type.lookup(result_type_id)?.handle; + let maybe_scalar_kind = ctx.type_arena[result_type_handle].inner.scalar_kind(); + + let expr = if maybe_scalar_kind == Some(crate::ScalarKind::Sint) { + crate::Expression::As { + expr: ctx.expressions.append(expr, self.span_from_with_op(start)), + kind: crate::ScalarKind::Sint, + convert: Some(4), + } + } else { + expr + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + + Ok(()) + } + + pub(super) fn parse_image_query_other( + &mut self, + query: crate::ImageQuery, + ctx: &mut super::BlockContext, + block_id: spirv::Word, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let image_id = self.next()?; + + // No need to call get_expr_handle here since only globals/arguments are + // allowed as images and they are always in the root scope + let image_lexp = self.lookup_expression.lookup(image_id)?.clone(); + + let expr = crate::Expression::ImageQuery { + image: image_lexp.handle, + query, + }; + + let result_type_handle = self.lookup_type.lookup(result_type_id)?.handle; + let maybe_scalar_kind = ctx.type_arena[result_type_handle].inner.scalar_kind(); + + let expr = if maybe_scalar_kind == Some(crate::ScalarKind::Sint) { + crate::Expression::As { + expr: ctx.expressions.append(expr, self.span_from_with_op(start)), + kind: crate::ScalarKind::Sint, + convert: Some(4), + } + } else { + expr + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + + Ok(()) + } +} diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs new file mode 100644 index 0000000000..8b1c854358 --- /dev/null +++ b/third_party/rust/naga/src/front/spv/mod.rs @@ -0,0 +1,5356 @@ +/*! +Frontend for [SPIR-V][spv] (Standard Portable Intermediate Representation). + +## ID lookups + +Our IR links to everything with `Handle`, while SPIR-V uses IDs. +In order to keep track of the associations, the parser has many lookup tables. +There map `spv::Word` into a specific IR handle, plus potentially a bit of +extra info, such as the related SPIR-V type ID. +TODO: would be nice to find ways that avoid looking up as much + +## Inputs/Outputs + +We create a private variable for each input/output. The relevant inputs are +populated at the start of an entry point. The outputs are saved at the end. + +The function associated with an entry point is wrapped in another function, +such that we can handle any `Return` statements without problems. + +## Row-major matrices + +We don't handle them natively, since the IR only expects column majority. +Instead, we detect when such matrix is accessed in the `OpAccessChain`, +and we generate a parallel expression that loads the value, but transposed. +This value then gets used instead of `OpLoad` result later on. + +[spv]: https://www.khronos.org/registry/SPIR-V/ +*/ + +mod convert; +mod error; +mod function; +mod image; +mod null; + +use convert::*; +pub use error::Error; +use function::*; + +use crate::{ + arena::{Arena, Handle, UniqueArena}, + proc::{Alignment, Layouter}, + FastHashMap, FastHashSet, FastIndexMap, +}; + +use petgraph::graphmap::GraphMap; +use std::{convert::TryInto, mem, num::NonZeroU32, path::PathBuf}; + +pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[ + spirv::Capability::Shader, + spirv::Capability::VulkanMemoryModel, + spirv::Capability::ClipDistance, + spirv::Capability::CullDistance, + spirv::Capability::SampleRateShading, + spirv::Capability::DerivativeControl, + spirv::Capability::Matrix, + spirv::Capability::ImageQuery, + spirv::Capability::Sampled1D, + spirv::Capability::Image1D, + spirv::Capability::SampledCubeArray, + spirv::Capability::ImageCubeArray, + spirv::Capability::StorageImageExtendedFormats, + spirv::Capability::Int8, + spirv::Capability::Int16, + spirv::Capability::Int64, + spirv::Capability::Float16, + spirv::Capability::Float64, + spirv::Capability::Geometry, + spirv::Capability::MultiView, + // tricky ones + spirv::Capability::UniformBufferArrayDynamicIndexing, + spirv::Capability::StorageBufferArrayDynamicIndexing, +]; +pub const SUPPORTED_EXTENSIONS: &[&str] = &[ + "SPV_KHR_storage_buffer_storage_class", + "SPV_KHR_vulkan_memory_model", + "SPV_KHR_multiview", +]; +pub const SUPPORTED_EXT_SETS: &[&str] = &["GLSL.std.450"]; + +#[derive(Copy, Clone)] +pub struct Instruction { + op: spirv::Op, + wc: u16, +} + +impl Instruction { + const fn expect(self, count: u16) -> Result<(), Error> { + if self.wc == count { + Ok(()) + } else { + Err(Error::InvalidOperandCount(self.op, self.wc)) + } + } + + fn expect_at_least(self, count: u16) -> Result<u16, Error> { + self.wc + .checked_sub(count) + .ok_or(Error::InvalidOperandCount(self.op, self.wc)) + } +} + +impl crate::TypeInner { + fn can_comparison_sample(&self, module: &crate::Module) -> bool { + match *self { + crate::TypeInner::Image { + class: + crate::ImageClass::Sampled { + kind: crate::ScalarKind::Float, + multi: false, + }, + .. + } => true, + crate::TypeInner::Sampler { .. } => true, + crate::TypeInner::BindingArray { base, .. } => { + module.types[base].inner.can_comparison_sample(module) + } + _ => false, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] +pub enum ModuleState { + Empty, + Capability, + Extension, + ExtInstImport, + MemoryModel, + EntryPoint, + ExecutionMode, + Source, + Name, + ModuleProcessed, + Annotation, + Type, + Function, +} + +trait LookupHelper { + type Target; + fn lookup(&self, key: spirv::Word) -> Result<&Self::Target, Error>; +} + +impl<T> LookupHelper for FastHashMap<spirv::Word, T> { + type Target = T; + fn lookup(&self, key: spirv::Word) -> Result<&T, Error> { + self.get(&key).ok_or(Error::InvalidId(key)) + } +} + +impl crate::ImageDimension { + const fn required_coordinate_size(&self) -> Option<crate::VectorSize> { + match *self { + crate::ImageDimension::D1 => None, + crate::ImageDimension::D2 => Some(crate::VectorSize::Bi), + crate::ImageDimension::D3 => Some(crate::VectorSize::Tri), + crate::ImageDimension::Cube => Some(crate::VectorSize::Tri), + } + } +} + +type MemberIndex = u32; + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug, Default)] + struct DecorationFlags: u32 { + const NON_READABLE = 0x1; + const NON_WRITABLE = 0x2; + } +} + +impl DecorationFlags { + fn to_storage_access(self) -> crate::StorageAccess { + let mut access = crate::StorageAccess::all(); + if self.contains(DecorationFlags::NON_READABLE) { + access &= !crate::StorageAccess::LOAD; + } + if self.contains(DecorationFlags::NON_WRITABLE) { + access &= !crate::StorageAccess::STORE; + } + access + } +} + +#[derive(Debug, PartialEq)] +enum Majority { + Column, + Row, +} + +#[derive(Debug, Default)] +struct Decoration { + name: Option<String>, + built_in: Option<spirv::Word>, + location: Option<spirv::Word>, + desc_set: Option<spirv::Word>, + desc_index: Option<spirv::Word>, + specialization: Option<spirv::Word>, + storage_buffer: bool, + offset: Option<spirv::Word>, + array_stride: Option<NonZeroU32>, + matrix_stride: Option<NonZeroU32>, + matrix_major: Option<Majority>, + invariant: bool, + interpolation: Option<crate::Interpolation>, + sampling: Option<crate::Sampling>, + flags: DecorationFlags, +} + +impl Decoration { + fn debug_name(&self) -> &str { + match self.name { + Some(ref name) => name.as_str(), + None => "?", + } + } + + fn specialization(&self) -> crate::Override { + self.specialization + .map_or(crate::Override::None, crate::Override::ByNameOrId) + } + + const fn resource_binding(&self) -> Option<crate::ResourceBinding> { + match *self { + Decoration { + desc_set: Some(group), + desc_index: Some(binding), + .. + } => Some(crate::ResourceBinding { group, binding }), + _ => None, + } + } + + fn io_binding(&self) -> Result<crate::Binding, Error> { + match *self { + Decoration { + built_in: Some(built_in), + location: None, + invariant, + .. + } => Ok(crate::Binding::BuiltIn(map_builtin(built_in, invariant)?)), + Decoration { + built_in: None, + location: Some(location), + interpolation, + sampling, + .. + } => Ok(crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source: false, + }), + _ => Err(Error::MissingDecoration(spirv::Decoration::Location)), + } + } +} + +#[derive(Debug)] +struct LookupFunctionType { + parameter_type_ids: Vec<spirv::Word>, + return_type_id: spirv::Word, +} + +struct LookupFunction { + handle: Handle<crate::Function>, + parameters_sampling: Vec<image::SamplingFlags>, +} + +#[derive(Debug)] +struct EntryPoint { + stage: crate::ShaderStage, + name: String, + early_depth_test: Option<crate::EarlyDepthTest>, + workgroup_size: [u32; 3], + variable_ids: Vec<spirv::Word>, +} + +#[derive(Clone, Debug)] +struct LookupType { + handle: Handle<crate::Type>, + base_id: Option<spirv::Word>, +} + +#[derive(Debug)] +struct LookupConstant { + handle: Handle<crate::Constant>, + type_id: spirv::Word, +} + +#[derive(Debug)] +enum Variable { + Global, + Input(crate::FunctionArgument), + Output(crate::FunctionResult), +} + +#[derive(Debug)] +struct LookupVariable { + inner: Variable, + handle: Handle<crate::GlobalVariable>, + type_id: spirv::Word, +} + +/// Information about SPIR-V result ids, stored in `Parser::lookup_expression`. +#[derive(Clone, Debug)] +struct LookupExpression { + /// The `Expression` constructed for this result. + /// + /// Note that, while a SPIR-V result id can be used in any block dominated + /// by its definition, a Naga `Expression` is only in scope for the rest of + /// its subtree. `Parser::get_expr_handle` takes care of spilling the result + /// to a `LocalVariable` which can then be used anywhere. + handle: Handle<crate::Expression>, + + /// The SPIR-V type of this result. + type_id: spirv::Word, + + /// The label id of the block that defines this expression. + /// + /// This is zero for globals, constants, and function parameters, since they + /// originate outside any function's block. + block_id: spirv::Word, +} + +#[derive(Debug)] +struct LookupMember { + type_id: spirv::Word, + // This is true for either matrices, or arrays of matrices (yikes). + row_major: bool, +} + +#[derive(Clone, Debug)] +enum LookupLoadOverride { + /// For arrays of matrices, we track them but not loading yet. + Pending, + /// For matrices, vectors, and scalars, we pre-load the data. + Loaded(Handle<crate::Expression>), +} + +#[derive(PartialEq)] +enum ExtendedClass { + Global(crate::AddressSpace), + Input, + Output, +} + +#[derive(Clone, Debug)] +pub struct Options { + /// The IR coordinate space matches all the APIs except SPIR-V, + /// so by default we flip the Y coordinate of the `BuiltIn::Position`. + /// This flag can be used to avoid this. + pub adjust_coordinate_space: bool, + /// Only allow shaders with the known set of capabilities. + pub strict_capabilities: bool, + pub block_ctx_dump_prefix: Option<PathBuf>, +} + +impl Default for Options { + fn default() -> Self { + Options { + adjust_coordinate_space: true, + strict_capabilities: false, + block_ctx_dump_prefix: None, + } + } +} + +/// An index into the `BlockContext::bodies` table. +type BodyIndex = usize; + +/// An intermediate representation of a Naga [`Statement`]. +/// +/// `Body` and `BodyFragment` values form a tree: the `BodyIndex` fields of the +/// variants are indices of the child `Body` values in [`BlockContext::bodies`]. +/// The `lower` function assembles the final `Statement` tree from this `Body` +/// tree. See [`BlockContext`] for details. +/// +/// [`Statement`]: crate::Statement +#[derive(Debug)] +enum BodyFragment { + BlockId(spirv::Word), + If { + condition: Handle<crate::Expression>, + accept: BodyIndex, + reject: BodyIndex, + }, + Loop { + /// The body of the loop. Its [`Body::parent`] is the block containing + /// this `Loop` fragment. + body: BodyIndex, + + /// The loop's continuing block. This is a grandchild: its + /// [`Body::parent`] is the loop body block, whose index is above. + continuing: BodyIndex, + + /// If the SPIR-V loop's back-edge branch is conditional, this is the + /// expression that must be `false` for the back-edge to be taken, with + /// `true` being for the "loop merge" (which breaks out of the loop). + break_if: Option<Handle<crate::Expression>>, + }, + Switch { + selector: Handle<crate::Expression>, + cases: Vec<(i32, BodyIndex)>, + default: BodyIndex, + }, + Break, + Continue, +} + +/// An intermediate representation of a Naga [`Block`]. +/// +/// This will be assembled into a `Block` once we've added spills for phi nodes +/// and out-of-scope expressions. See [`BlockContext`] for details. +/// +/// [`Block`]: crate::Block +#[derive(Debug)] +struct Body { + /// The index of the direct parent of this body + parent: usize, + data: Vec<BodyFragment>, +} + +impl Body { + /// Creates a new empty `Body` with the specified `parent` + pub const fn with_parent(parent: usize) -> Self { + Body { + parent, + data: Vec::new(), + } + } +} + +#[derive(Debug)] +struct PhiExpression { + /// The local variable used for the phi node + local: Handle<crate::LocalVariable>, + /// List of (expression, block) + expressions: Vec<(spirv::Word, spirv::Word)>, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum MergeBlockInformation { + LoopMerge, + LoopContinue, + SelectionMerge, + SwitchMerge, +} + +/// Fragments of Naga IR, to be assembled into `Statements` once data flow is +/// resolved. +/// +/// We can't build a Naga `Statement` tree directly from SPIR-V blocks for three +/// main reasons: +/// +/// - We parse a function's SPIR-V blocks in the order they appear in the file. +/// Within a function, SPIR-V requires that a block must precede any blocks it +/// structurally dominates, but doesn't say much else about the order in which +/// they must appear. So while we know we'll see control flow header blocks +/// before their child constructs and merge blocks, those children and the +/// merge blocks may appear in any order - perhaps even intermingled with +/// children of other constructs. +/// +/// - A SPIR-V expression can be used in any SPIR-V block dominated by its +/// definition, whereas Naga expressions are scoped to the rest of their +/// subtree. This means that discovering an expression use later in the +/// function retroactively requires us to have spilled that expression into a +/// local variable back before we left its scope. +/// +/// - We translate SPIR-V OpPhi expressions as Naga local variables in which we +/// store the appropriate value before jumping to the OpPhi's block. +/// +/// All these cases require us to go back and amend previously generated Naga IR +/// based on things we discover later. But modifying old blocks in arbitrary +/// spots in a `Statement` tree is awkward. +/// +/// Instead, as we iterate through the function's body, we accumulate +/// control-flow-free fragments of Naga IR in the [`blocks`] table, while +/// building a skeleton of the Naga `Statement` tree in [`bodies`]. We note any +/// spills and temporaries we must introduce in [`phis`]. +/// +/// Finally, once we've processed the entire function, we add temporaries and +/// spills to the fragmentary `Blocks` as directed by `phis`, and assemble them +/// into the final Naga `Statement` tree as directed by `bodies`. +/// +/// [`blocks`]: BlockContext::blocks +/// [`bodies`]: BlockContext::bodies +/// [`phis`]: BlockContext::phis +/// [`lower`]: function::lower +#[derive(Debug)] +struct BlockContext<'function> { + /// Phi nodes encountered when parsing the function, used to generate spills + /// to local variables. + phis: Vec<PhiExpression>, + + /// Fragments of control-flow-free Naga IR. + /// + /// These will be stitched together into a proper [`Statement`] tree according + /// to `bodies`, once parsing is complete. + /// + /// [`Statement`]: crate::Statement + blocks: FastHashMap<spirv::Word, crate::Block>, + + /// Map from each SPIR-V block's label id to the index of the [`Body`] in + /// [`bodies`] the block should append its contents to. + /// + /// Since each statement in a Naga [`Block`] dominates the next, we are sure + /// to encounter their SPIR-V blocks in order. Thus, by having this table + /// map a SPIR-V structured control flow construct's merge block to the same + /// body index as its header block, when we encounter the merge block, we + /// will simply pick up building the [`Body`] where the header left off. + /// + /// A function's first block is special: it is the only block we encounter + /// without having seen its label mentioned in advance. (It's simply the + /// first `OpLabel` after the `OpFunction`.) We thus assume that any block + /// missing an entry here must be the first block, which always has body + /// index zero. + /// + /// [`bodies`]: BlockContext::bodies + /// [`Block`]: crate::Block + body_for_label: FastHashMap<spirv::Word, BodyIndex>, + + /// SPIR-V metadata about merge/continue blocks. + mergers: FastHashMap<spirv::Word, MergeBlockInformation>, + + /// A table of `Body` values, each representing a block in the final IR. + /// + /// The first element is always the function's top-level block. + bodies: Vec<Body>, + + /// Id of the function currently being processed + function_id: spirv::Word, + /// Expression arena of the function currently being processed + expressions: &'function mut Arena<crate::Expression>, + /// Local variables arena of the function currently being processed + local_arena: &'function mut Arena<crate::LocalVariable>, + /// Constants arena of the module being processed + const_arena: &'function mut Arena<crate::Constant>, + const_expressions: &'function mut Arena<crate::Expression>, + /// Type arena of the module being processed + type_arena: &'function UniqueArena<crate::Type>, + /// Global arena of the module being processed + global_arena: &'function Arena<crate::GlobalVariable>, + /// Arguments of the function currently being processed + arguments: &'function [crate::FunctionArgument], + /// Metadata about the usage of function parameters as sampling objects + parameter_sampling: &'function mut [image::SamplingFlags], +} + +enum SignAnchor { + Result, + Operand, +} + +pub struct Frontend<I> { + data: I, + data_offset: usize, + state: ModuleState, + layouter: Layouter, + temp_bytes: Vec<u8>, + ext_glsl_id: Option<spirv::Word>, + future_decor: FastHashMap<spirv::Word, Decoration>, + future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>, + lookup_member: FastHashMap<(Handle<crate::Type>, MemberIndex), LookupMember>, + handle_sampling: FastHashMap<Handle<crate::GlobalVariable>, image::SamplingFlags>, + lookup_type: FastHashMap<spirv::Word, LookupType>, + lookup_void_type: Option<spirv::Word>, + lookup_storage_buffer_types: FastHashMap<Handle<crate::Type>, crate::StorageAccess>, + // Lookup for samplers and sampled images, storing flags on how they are used. + lookup_constant: FastHashMap<spirv::Word, LookupConstant>, + lookup_variable: FastHashMap<spirv::Word, LookupVariable>, + lookup_expression: FastHashMap<spirv::Word, LookupExpression>, + // Load overrides are used to work around row-major matrices + lookup_load_override: FastHashMap<spirv::Word, LookupLoadOverride>, + lookup_sampled_image: FastHashMap<spirv::Word, image::LookupSampledImage>, + lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>, + lookup_function: FastHashMap<spirv::Word, LookupFunction>, + lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>, + //Note: each `OpFunctionCall` gets a single entry here, indexed by the + // dummy `Handle<crate::Function>` of the call site. + deferred_function_calls: Vec<spirv::Word>, + dummy_functions: Arena<crate::Function>, + // Graph of all function calls through the module. + // It's used to sort the functions (as nodes) topologically, + // so that in the IR any called function is already known. + function_call_graph: GraphMap<spirv::Word, (), petgraph::Directed>, + options: Options, + + /// Maps for a switch from a case target to the respective body and associated literals that + /// use that target block id. + /// + /// Used to preserve allocations between instruction parsing. + switch_cases: FastIndexMap<spirv::Word, (BodyIndex, Vec<i32>)>, + + /// Tracks access to gl_PerVertex's builtins, it is used to cull unused builtins since initializing those can + /// affect performance and the mere presence of some of these builtins might cause backends to error since they + /// might be unsupported. + /// + /// The problematic builtins are: PointSize, ClipDistance and CullDistance. + /// + /// glslang declares those by default even though they are never written to + /// (see <https://github.com/KhronosGroup/glslang/issues/1868>) + gl_per_vertex_builtin_access: FastHashSet<crate::BuiltIn>, +} + +impl<I: Iterator<Item = u32>> Frontend<I> { + pub fn new(data: I, options: &Options) -> Self { + Frontend { + data, + data_offset: 0, + state: ModuleState::Empty, + layouter: Layouter::default(), + temp_bytes: Vec::new(), + ext_glsl_id: None, + future_decor: FastHashMap::default(), + future_member_decor: FastHashMap::default(), + handle_sampling: FastHashMap::default(), + lookup_member: FastHashMap::default(), + lookup_type: FastHashMap::default(), + lookup_void_type: None, + lookup_storage_buffer_types: FastHashMap::default(), + lookup_constant: FastHashMap::default(), + lookup_variable: FastHashMap::default(), + lookup_expression: FastHashMap::default(), + lookup_load_override: FastHashMap::default(), + lookup_sampled_image: FastHashMap::default(), + lookup_function_type: FastHashMap::default(), + lookup_function: FastHashMap::default(), + lookup_entry_point: FastHashMap::default(), + deferred_function_calls: Vec::default(), + dummy_functions: Arena::new(), + function_call_graph: GraphMap::new(), + options: options.clone(), + switch_cases: FastIndexMap::default(), + gl_per_vertex_builtin_access: FastHashSet::default(), + } + } + + fn span_from(&self, from: usize) -> crate::Span { + crate::Span::from(from..self.data_offset) + } + + fn span_from_with_op(&self, from: usize) -> crate::Span { + crate::Span::from((from - 4)..self.data_offset) + } + + fn next(&mut self) -> Result<u32, Error> { + if let Some(res) = self.data.next() { + self.data_offset += 4; + Ok(res) + } else { + Err(Error::IncompleteData) + } + } + + fn next_inst(&mut self) -> Result<Instruction, Error> { + let word = self.next()?; + let (wc, opcode) = ((word >> 16) as u16, (word & 0xffff) as u16); + if wc == 0 { + return Err(Error::InvalidWordCount); + } + let op = spirv::Op::from_u32(opcode as u32).ok_or(Error::UnknownInstruction(opcode))?; + + Ok(Instruction { op, wc }) + } + + fn next_string(&mut self, mut count: u16) -> Result<(String, u16), Error> { + self.temp_bytes.clear(); + loop { + if count == 0 { + return Err(Error::BadString); + } + count -= 1; + let chars = self.next()?.to_le_bytes(); + let pos = chars.iter().position(|&c| c == 0).unwrap_or(4); + self.temp_bytes.extend_from_slice(&chars[..pos]); + if pos < 4 { + break; + } + } + std::str::from_utf8(&self.temp_bytes) + .map(|s| (s.to_owned(), count)) + .map_err(|_| Error::BadString) + } + + fn next_decoration( + &mut self, + inst: Instruction, + base_words: u16, + dec: &mut Decoration, + ) -> Result<(), Error> { + let raw = self.next()?; + let dec_typed = spirv::Decoration::from_u32(raw).ok_or(Error::InvalidDecoration(raw))?; + log::trace!("\t\t{}: {:?}", dec.debug_name(), dec_typed); + match dec_typed { + spirv::Decoration::BuiltIn => { + inst.expect(base_words + 2)?; + dec.built_in = Some(self.next()?); + } + spirv::Decoration::Location => { + inst.expect(base_words + 2)?; + dec.location = Some(self.next()?); + } + spirv::Decoration::DescriptorSet => { + inst.expect(base_words + 2)?; + dec.desc_set = Some(self.next()?); + } + spirv::Decoration::Binding => { + inst.expect(base_words + 2)?; + dec.desc_index = Some(self.next()?); + } + spirv::Decoration::BufferBlock => { + dec.storage_buffer = true; + } + spirv::Decoration::Offset => { + inst.expect(base_words + 2)?; + dec.offset = Some(self.next()?); + } + spirv::Decoration::ArrayStride => { + inst.expect(base_words + 2)?; + dec.array_stride = NonZeroU32::new(self.next()?); + } + spirv::Decoration::MatrixStride => { + inst.expect(base_words + 2)?; + dec.matrix_stride = NonZeroU32::new(self.next()?); + } + spirv::Decoration::Invariant => { + dec.invariant = true; + } + spirv::Decoration::NoPerspective => { + dec.interpolation = Some(crate::Interpolation::Linear); + } + spirv::Decoration::Flat => { + dec.interpolation = Some(crate::Interpolation::Flat); + } + spirv::Decoration::Centroid => { + dec.sampling = Some(crate::Sampling::Centroid); + } + spirv::Decoration::Sample => { + dec.sampling = Some(crate::Sampling::Sample); + } + spirv::Decoration::NonReadable => { + dec.flags |= DecorationFlags::NON_READABLE; + } + spirv::Decoration::NonWritable => { + dec.flags |= DecorationFlags::NON_WRITABLE; + } + spirv::Decoration::ColMajor => { + dec.matrix_major = Some(Majority::Column); + } + spirv::Decoration::RowMajor => { + dec.matrix_major = Some(Majority::Row); + } + spirv::Decoration::SpecId => { + dec.specialization = Some(self.next()?); + } + other => { + log::warn!("Unknown decoration {:?}", other); + for _ in base_words + 1..inst.wc { + let _var = self.next()?; + } + } + } + Ok(()) + } + + /// Return the Naga `Expression` for a given SPIR-V result `id`. + /// + /// `lookup` must be the `LookupExpression` for `id`. + /// + /// SPIR-V result ids can be used by any block dominated by the id's + /// definition, but Naga `Expressions` are only in scope for the remainder + /// of their `Statement` subtree. This means that the `Expression` generated + /// for `id` may no longer be in scope. In such cases, this function takes + /// care of spilling the value of `id` to a `LocalVariable` which can then + /// be used anywhere. The SPIR-V domination rule ensures that the + /// `LocalVariable` has been initialized before it is used. + /// + /// The `body_idx` argument should be the index of the `Body` that hopes to + /// use `id`'s `Expression`. + fn get_expr_handle( + &self, + id: spirv::Word, + lookup: &LookupExpression, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + body_idx: BodyIndex, + ) -> Handle<crate::Expression> { + // What `Body` was `id` defined in? + let expr_body_idx = ctx + .body_for_label + .get(&lookup.block_id) + .copied() + .unwrap_or(0); + + // Don't need to do a load/store if the expression is in the main body + // or if the expression is in the same body as where the query was + // requested. The body_idx might actually not be the final one if a loop + // or conditional occurs but in those cases we know that the new body + // will be a subscope of the body that was passed so we can still reuse + // the handle and not issue a load/store. + if is_parent(body_idx, expr_body_idx, ctx) { + lookup.handle + } else { + // Add a temporary variable of the same type which will be used to + // store the original expression and used in the current block + let ty = self.lookup_type[&lookup.type_id].handle; + let local = ctx.local_arena.append( + crate::LocalVariable { + name: None, + ty, + init: None, + }, + crate::Span::default(), + ); + + block.extend(emitter.finish(ctx.expressions)); + let pointer = ctx.expressions.append( + crate::Expression::LocalVariable(local), + crate::Span::default(), + ); + emitter.start(ctx.expressions); + let expr = ctx + .expressions + .append(crate::Expression::Load { pointer }, crate::Span::default()); + + // Add a slightly odd entry to the phi table, so that while `id`'s + // `Expression` is still in scope, the usual phi processing will + // spill its value to `local`, where we can find it later. + // + // This pretends that the block in which `id` is defined is the + // predecessor of some other block with a phi in it that cites id as + // one of its sources, and uses `local` as its variable. There is no + // such phi, but nobody needs to know that. + ctx.phis.push(PhiExpression { + local, + expressions: vec![(id, lookup.block_id)], + }); + + expr + } + } + + fn parse_expr_unary_op( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::UnaryOperator, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p_id = self.next()?; + + let p_lexp = self.lookup_expression.lookup(p_id)?; + let handle = self.get_expr_handle(p_id, p_lexp, ctx, emitter, block, body_idx); + + let expr = crate::Expression::Unary { op, expr: handle }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + fn parse_expr_binary_op( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::BinaryOperator, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx); + + let expr = crate::Expression::Binary { op, left, right }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + /// A more complicated version of the unary op, + /// where we force the operand to have the same type as the result. + fn parse_expr_unary_op_sign_adjusted( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::UnaryOperator, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let span = self.span_from_with_op(start); + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx); + + let result_lookup_ty = self.lookup_type.lookup(result_type_id)?; + let kind = ctx.type_arena[result_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + + let expr = crate::Expression::Unary { + op, + expr: if p1_lexp.type_id == result_type_id { + left + } else { + ctx.expressions.append( + crate::Expression::As { + expr: left, + kind, + convert: None, + }, + span, + ) + }, + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + /// A more complicated version of the binary op, + /// where we force the operand to have the same type as the result. + /// This is mostly needed for "i++" and "i--" coming from GLSL. + #[allow(clippy::too_many_arguments)] + fn parse_expr_binary_op_sign_adjusted( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::BinaryOperator, + // For arithmetic operations, we need the sign of operands to match the result. + // For boolean operations, however, the operands need to match the signs, but + // result is always different - a boolean. + anchor: SignAnchor, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + let span = self.span_from_with_op(start); + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx); + + let expected_type_id = match anchor { + SignAnchor::Result => result_type_id, + SignAnchor::Operand => p1_lexp.type_id, + }; + let expected_lookup_ty = self.lookup_type.lookup(expected_type_id)?; + let kind = ctx.type_arena[expected_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + + let expr = crate::Expression::Binary { + op, + left: if p1_lexp.type_id == expected_type_id { + left + } else { + ctx.expressions.append( + crate::Expression::As { + expr: left, + kind, + convert: None, + }, + span, + ) + }, + right: if p2_lexp.type_id == expected_type_id { + right + } else { + ctx.expressions.append( + crate::Expression::As { + expr: right, + kind, + convert: None, + }, + span, + ) + }, + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + /// A version of the binary op where one or both of the arguments might need to be casted to a + /// specific integer kind (unsigned or signed), used for operations like OpINotEqual or + /// OpUGreaterThan. + #[allow(clippy::too_many_arguments)] + fn parse_expr_int_comparison( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::BinaryOperator, + kind: crate::ScalarKind, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + let span = self.span_from_with_op(start); + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx); + let p1_lookup_ty = self.lookup_type.lookup(p1_lexp.type_id)?; + let p1_kind = ctx.type_arena[p1_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx); + let p2_lookup_ty = self.lookup_type.lookup(p2_lexp.type_id)?; + let p2_kind = ctx.type_arena[p2_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + + let expr = crate::Expression::Binary { + op, + left: if p1_kind == kind { + left + } else { + ctx.expressions.append( + crate::Expression::As { + expr: left, + kind, + convert: None, + }, + span, + ) + }, + right: if p2_kind == kind { + right + } else { + ctx.expressions.append( + crate::Expression::As { + expr: right, + kind, + convert: None, + }, + span, + ) + }, + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + fn parse_expr_shift_op( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + op: crate::BinaryOperator, + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + + let span = self.span_from_with_op(start); + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let p2_handle = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx); + // convert the shift to Uint + let right = ctx.expressions.append( + crate::Expression::As { + expr: p2_handle, + kind: crate::ScalarKind::Uint, + convert: None, + }, + span, + ); + + let expr = crate::Expression::Binary { op, left, right }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + fn parse_expr_derivative( + &mut self, + ctx: &mut BlockContext, + emitter: &mut crate::proc::Emitter, + block: &mut crate::Block, + block_id: spirv::Word, + body_idx: usize, + (axis, ctrl): (crate::DerivativeAxis, crate::DerivativeControl), + ) -> Result<(), Error> { + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let arg_id = self.next()?; + + let arg_lexp = self.lookup_expression.lookup(arg_id)?; + let arg_handle = self.get_expr_handle(arg_id, arg_lexp, ctx, emitter, block, body_idx); + + let expr = crate::Expression::Derivative { + axis, + ctrl, + expr: arg_handle, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, self.span_from_with_op(start)), + type_id: result_type_id, + block_id, + }, + ); + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + fn insert_composite( + &self, + root_expr: Handle<crate::Expression>, + root_type_id: spirv::Word, + object_expr: Handle<crate::Expression>, + selections: &[spirv::Word], + type_arena: &UniqueArena<crate::Type>, + expressions: &mut Arena<crate::Expression>, + span: crate::Span, + ) -> Result<Handle<crate::Expression>, Error> { + let selection = match selections.first() { + Some(&index) => index, + None => return Ok(object_expr), + }; + let root_span = expressions.get_span(root_expr); + let root_lookup = self.lookup_type.lookup(root_type_id)?; + + let (count, child_type_id) = match type_arena[root_lookup.handle].inner { + crate::TypeInner::Struct { ref members, .. } => { + let child_member = self + .lookup_member + .get(&(root_lookup.handle, selection)) + .ok_or(Error::InvalidAccessType(root_type_id))?; + (members.len(), child_member.type_id) + } + crate::TypeInner::Array { size, .. } => { + let size = match size { + crate::ArraySize::Constant(size) => size.get(), + // A runtime sized array is not a composite type + crate::ArraySize::Dynamic => { + return Err(Error::InvalidAccessType(root_type_id)) + } + }; + + let child_type_id = root_lookup + .base_id + .ok_or(Error::InvalidAccessType(root_type_id))?; + + (size as usize, child_type_id) + } + crate::TypeInner::Vector { size, .. } + | crate::TypeInner::Matrix { columns: size, .. } => { + let child_type_id = root_lookup + .base_id + .ok_or(Error::InvalidAccessType(root_type_id))?; + (size as usize, child_type_id) + } + _ => return Err(Error::InvalidAccessType(root_type_id)), + }; + + let mut components = Vec::with_capacity(count); + for index in 0..count as u32 { + let expr = expressions.append( + crate::Expression::AccessIndex { + base: root_expr, + index, + }, + if index == selection { span } else { root_span }, + ); + components.push(expr); + } + components[selection as usize] = self.insert_composite( + components[selection as usize], + child_type_id, + object_expr, + &selections[1..], + type_arena, + expressions, + span, + )?; + + Ok(expressions.append( + crate::Expression::Compose { + ty: root_lookup.handle, + components, + }, + span, + )) + } + + /// Add the next SPIR-V block's contents to `block_ctx`. + /// + /// Except for the function's entry block, `block_id` should be the label of + /// a block we've seen mentioned before, with an entry in + /// `block_ctx.body_for_label` to tell us which `Body` it contributes to. + fn next_block(&mut self, block_id: spirv::Word, ctx: &mut BlockContext) -> Result<(), Error> { + // Extend `body` with the correct form for a branch to `target`. + fn merger(body: &mut Body, target: &MergeBlockInformation) { + body.data.push(match *target { + MergeBlockInformation::LoopContinue => BodyFragment::Continue, + MergeBlockInformation::LoopMerge | MergeBlockInformation::SwitchMerge => { + BodyFragment::Break + } + + // Finishing a selection merge means just falling off the end of + // the `accept` or `reject` block of the `If` statement. + MergeBlockInformation::SelectionMerge => return, + }) + } + + let mut emitter = crate::proc::Emitter::default(); + emitter.start(ctx.expressions); + + // Find the `Body` to which this block contributes. + // + // If this is some SPIR-V structured control flow construct's merge + // block, then `body_idx` will refer to the same `Body` as the header, + // so that we simply pick up accumulating the `Body` where the header + // left off. Each of the statements in a block dominates the next, so + // we're sure to encounter their SPIR-V blocks in order, ensuring that + // the `Body` will be assembled in the proper order. + // + // Note that, unlike every other kind of SPIR-V block, we don't know the + // function's first block's label in advance. Thus, we assume that if + // this block has no entry in `ctx.body_for_label`, it must be the + // function's first block. This always has body index zero. + let mut body_idx = *ctx.body_for_label.entry(block_id).or_default(); + + // The Naga IR block this call builds. This will end up as + // `ctx.blocks[&block_id]`, and `ctx.bodies[body_idx]` will refer to it + // via a `BodyFragment::BlockId`. + let mut block = crate::Block::new(); + + // Stores the merge block as defined by a `OpSelectionMerge` otherwise is `None` + // + // This is used in `OpSwitch` to promote the `MergeBlockInformation` from + // `SelectionMerge` to `SwitchMerge` to allow `Break`s this isn't desirable for + // `LoopMerge`s because otherwise `Continue`s wouldn't be allowed + let mut selection_merge_block = None; + + macro_rules! get_expr_handle { + ($id:expr, $lexp:expr) => { + self.get_expr_handle($id, $lexp, ctx, &mut emitter, &mut block, body_idx) + }; + } + macro_rules! parse_expr_op { + ($op:expr, BINARY) => { + self.parse_expr_binary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op) + }; + + ($op:expr, SHIFT) => { + self.parse_expr_shift_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op) + }; + ($op:expr, UNARY) => { + self.parse_expr_unary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op) + }; + ($axis:expr, $ctrl:expr, DERIVATIVE) => { + self.parse_expr_derivative( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + ($axis, $ctrl), + ) + }; + } + + let terminator = loop { + use spirv::Op; + let start = self.data_offset; + let inst = self.next_inst()?; + let span = crate::Span::from(start..(start + 4 * (inst.wc as usize))); + log::debug!("\t\t{:?} [{}]", inst.op, inst.wc); + + match inst.op { + Op::Line => { + inst.expect(4)?; + let _file_id = self.next()?; + let _row_id = self.next()?; + let _col_id = self.next()?; + } + Op::NoLine => inst.expect(1)?, + Op::Undef => { + inst.expect(3)?; + let type_id = self.next()?; + let id = self.next()?; + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + + self.lookup_expression.insert( + id, + LookupExpression { + handle: ctx + .expressions + .append(crate::Expression::ZeroValue(ty), span), + type_id, + block_id, + }, + ); + } + Op::Variable => { + inst.expect_at_least(4)?; + block.extend(emitter.finish(ctx.expressions)); + + let result_type_id = self.next()?; + let result_id = self.next()?; + let _storage_class = self.next()?; + let init = if inst.wc > 4 { + inst.expect(5)?; + let init_id = self.next()?; + let lconst = self.lookup_constant.lookup(init_id)?; + Some( + ctx.expressions + .append(crate::Expression::Constant(lconst.handle), span), + ) + } else { + None + }; + + let name = self + .future_decor + .remove(&result_id) + .and_then(|decor| decor.name); + if let Some(ref name) = name { + log::debug!("\t\t\tid={} name={}", result_id, name); + } + let lookup_ty = self.lookup_type.lookup(result_type_id)?; + let var_handle = ctx.local_arena.append( + crate::LocalVariable { + name, + ty: match ctx.type_arena[lookup_ty.handle].inner { + crate::TypeInner::Pointer { base, .. } => base, + _ => lookup_ty.handle, + }, + init, + }, + span, + ); + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx + .expressions + .append(crate::Expression::LocalVariable(var_handle), span), + type_id: result_type_id, + block_id, + }, + ); + emitter.start(ctx.expressions); + } + Op::Phi => { + inst.expect_at_least(3)?; + block.extend(emitter.finish(ctx.expressions)); + + let result_type_id = self.next()?; + let result_id = self.next()?; + + let name = format!("phi_{result_id}"); + let local = ctx.local_arena.append( + crate::LocalVariable { + name: Some(name), + ty: self.lookup_type.lookup(result_type_id)?.handle, + init: None, + }, + self.span_from(start), + ); + let pointer = ctx + .expressions + .append(crate::Expression::LocalVariable(local), span); + + let in_count = (inst.wc - 3) / 2; + let mut phi = PhiExpression { + local, + expressions: Vec::with_capacity(in_count as usize), + }; + for _ in 0..in_count { + let expr = self.next()?; + let block = self.next()?; + phi.expressions.push((expr, block)); + } + + ctx.phis.push(phi); + emitter.start(ctx.expressions); + + // Associate the lookup with an actual value, which is emitted + // into the current block. + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx + .expressions + .append(crate::Expression::Load { pointer }, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::AccessChain | Op::InBoundsAccessChain => { + struct AccessExpression { + base_handle: Handle<crate::Expression>, + type_id: spirv::Word, + load_override: Option<LookupLoadOverride>, + } + + inst.expect_at_least(4)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + log::trace!("\t\t\tlooking up expr {:?}", base_id); + + let mut acex = { + let lexp = self.lookup_expression.lookup(base_id)?; + let lty = self.lookup_type.lookup(lexp.type_id)?; + + // HACK `OpAccessChain` and `OpInBoundsAccessChain` + // require for the result type to be a pointer, but if + // we're given a pointer to an image / sampler, it will + // be *already* dereferenced, since we do that early + // during `parse_type_pointer()`. + // + // This can happen only through `BindingArray`, since + // that's the only case where one can obtain a pointer + // to an image / sampler, and so let's match on that: + let dereference = match ctx.type_arena[lty.handle].inner { + crate::TypeInner::BindingArray { .. } => false, + _ => true, + }; + + let type_id = if dereference { + lty.base_id.ok_or(Error::InvalidAccessType(lexp.type_id))? + } else { + lexp.type_id + }; + + AccessExpression { + base_handle: get_expr_handle!(base_id, lexp), + type_id, + load_override: self.lookup_load_override.get(&base_id).cloned(), + } + }; + + for _ in 4..inst.wc { + let access_id = self.next()?; + log::trace!("\t\t\tlooking up index expr {:?}", access_id); + let index_expr = self.lookup_expression.lookup(access_id)?.clone(); + let index_expr_handle = get_expr_handle!(access_id, &index_expr); + let index_expr_data = &ctx.expressions[index_expr.handle]; + let index_maybe = match *index_expr_data { + crate::Expression::Constant(const_handle) => Some( + ctx.gctx() + .eval_expr_to_u32(ctx.const_arena[const_handle].init) + .map_err(|_| { + Error::InvalidAccess(crate::Expression::Constant( + const_handle, + )) + })?, + ), + _ => None, + }; + + log::trace!("\t\t\tlooking up type {:?}", acex.type_id); + let type_lookup = self.lookup_type.lookup(acex.type_id)?; + let ty = &ctx.type_arena[type_lookup.handle]; + acex = match ty.inner { + // can only index a struct with a constant + crate::TypeInner::Struct { ref members, .. } => { + let index = index_maybe + .ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?; + + let lookup_member = self + .lookup_member + .get(&(type_lookup.handle, index)) + .ok_or(Error::InvalidAccessType(acex.type_id))?; + let base_handle = ctx.expressions.append( + crate::Expression::AccessIndex { + base: acex.base_handle, + index, + }, + span, + ); + + if ty.name.as_deref() == Some("gl_PerVertex") { + if let Some(crate::Binding::BuiltIn(built_in)) = + members[index as usize].binding + { + self.gl_per_vertex_builtin_access.insert(built_in); + } + } + + AccessExpression { + base_handle, + type_id: lookup_member.type_id, + load_override: if lookup_member.row_major { + debug_assert!(acex.load_override.is_none()); + let sub_type_lookup = + self.lookup_type.lookup(lookup_member.type_id)?; + Some(match ctx.type_arena[sub_type_lookup.handle].inner { + // load it transposed, to match column major expectations + crate::TypeInner::Matrix { .. } => { + let loaded = ctx.expressions.append( + crate::Expression::Load { + pointer: base_handle, + }, + span, + ); + let transposed = ctx.expressions.append( + crate::Expression::Math { + fun: crate::MathFunction::Transpose, + arg: loaded, + arg1: None, + arg2: None, + arg3: None, + }, + span, + ); + LookupLoadOverride::Loaded(transposed) + } + _ => LookupLoadOverride::Pending, + }) + } else { + None + }, + } + } + crate::TypeInner::Matrix { .. } => { + let load_override = match acex.load_override { + // We are indexing inside a row-major matrix + Some(LookupLoadOverride::Loaded(load_expr)) => { + let index = index_maybe.ok_or_else(|| { + Error::InvalidAccess(index_expr_data.clone()) + })?; + let sub_handle = ctx.expressions.append( + crate::Expression::AccessIndex { + base: load_expr, + index, + }, + span, + ); + Some(LookupLoadOverride::Loaded(sub_handle)) + } + _ => None, + }; + let sub_expr = match index_maybe { + Some(index) => crate::Expression::AccessIndex { + base: acex.base_handle, + index, + }, + None => crate::Expression::Access { + base: acex.base_handle, + index: index_expr_handle, + }, + }; + AccessExpression { + base_handle: ctx.expressions.append(sub_expr, span), + type_id: type_lookup + .base_id + .ok_or(Error::InvalidAccessType(acex.type_id))?, + load_override, + } + } + // This must be a vector or an array. + _ => { + let base_handle = ctx.expressions.append( + crate::Expression::Access { + base: acex.base_handle, + index: index_expr_handle, + }, + span, + ); + let load_override = match acex.load_override { + // If there is a load override in place, then we always end up + // with a side-loaded value here. + Some(lookup_load_override) => { + let sub_expr = match lookup_load_override { + // We must be indexing into the array of row-major matrices. + // Let's load the result of indexing and transpose it. + LookupLoadOverride::Pending => { + let loaded = ctx.expressions.append( + crate::Expression::Load { + pointer: base_handle, + }, + span, + ); + ctx.expressions.append( + crate::Expression::Math { + fun: crate::MathFunction::Transpose, + arg: loaded, + arg1: None, + arg2: None, + arg3: None, + }, + span, + ) + } + // We are indexing inside a row-major matrix. + LookupLoadOverride::Loaded(load_expr) => { + ctx.expressions.append( + crate::Expression::Access { + base: load_expr, + index: index_expr_handle, + }, + span, + ) + } + }; + Some(LookupLoadOverride::Loaded(sub_expr)) + } + None => None, + }; + AccessExpression { + base_handle, + type_id: type_lookup + .base_id + .ok_or(Error::InvalidAccessType(acex.type_id))?, + load_override, + } + } + }; + } + + if let Some(load_expr) = acex.load_override { + self.lookup_load_override.insert(result_id, load_expr); + } + let lookup_expression = LookupExpression { + handle: acex.base_handle, + type_id: result_type_id, + block_id, + }; + self.lookup_expression.insert(result_id, lookup_expression); + } + Op::VectorExtractDynamic => { + inst.expect(5)?; + + let result_type_id = self.next()?; + let id = self.next()?; + let composite_id = self.next()?; + let index_id = self.next()?; + + let root_lexp = self.lookup_expression.lookup(composite_id)?; + let root_handle = get_expr_handle!(composite_id, root_lexp); + let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?; + let index_lexp = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lexp); + let index_type = self.lookup_type.lookup(index_lexp.type_id)?.handle; + + let num_components = match ctx.type_arena[root_type_lookup.handle].inner { + crate::TypeInner::Vector { size, .. } => size as u32, + _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)), + }; + + let mut make_index = |ctx: &mut BlockContext, index: u32| { + make_index_literal( + ctx, + index, + &mut block, + &mut emitter, + index_type, + index_lexp.type_id, + span, + ) + }; + + let index_expr = make_index(ctx, 0)?; + let mut handle = ctx.expressions.append( + crate::Expression::Access { + base: root_handle, + index: index_expr, + }, + span, + ); + for index in 1..num_components { + let index_expr = make_index(ctx, index)?; + let access_expr = ctx.expressions.append( + crate::Expression::Access { + base: root_handle, + index: index_expr, + }, + span, + ); + let cond = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Equal, + left: index_expr, + right: index_handle, + }, + span, + ); + handle = ctx.expressions.append( + crate::Expression::Select { + condition: cond, + accept: access_expr, + reject: handle, + }, + span, + ); + } + + self.lookup_expression.insert( + id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + } + Op::VectorInsertDynamic => { + inst.expect(6)?; + + let result_type_id = self.next()?; + let id = self.next()?; + let composite_id = self.next()?; + let object_id = self.next()?; + let index_id = self.next()?; + + let object_lexp = self.lookup_expression.lookup(object_id)?; + let object_handle = get_expr_handle!(object_id, object_lexp); + let root_lexp = self.lookup_expression.lookup(composite_id)?; + let root_handle = get_expr_handle!(composite_id, root_lexp); + let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?; + let index_lexp = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lexp); + let index_type = self.lookup_type.lookup(index_lexp.type_id)?.handle; + + let num_components = match ctx.type_arena[root_type_lookup.handle].inner { + crate::TypeInner::Vector { size, .. } => size as u32, + _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)), + }; + + let mut components = Vec::with_capacity(num_components as usize); + for index in 0..num_components { + let index_expr = make_index_literal( + ctx, + index, + &mut block, + &mut emitter, + index_type, + index_lexp.type_id, + span, + )?; + let access_expr = ctx.expressions.append( + crate::Expression::Access { + base: root_handle, + index: index_expr, + }, + span, + ); + let cond = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Equal, + left: index_expr, + right: index_handle, + }, + span, + ); + let handle = ctx.expressions.append( + crate::Expression::Select { + condition: cond, + accept: object_handle, + reject: access_expr, + }, + span, + ); + components.push(handle); + } + let handle = ctx.expressions.append( + crate::Expression::Compose { + ty: root_type_lookup.handle, + components, + }, + span, + ); + + self.lookup_expression.insert( + id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + } + Op::CompositeExtract => { + inst.expect_at_least(4)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + log::trace!("\t\t\tlooking up expr {:?}", base_id); + let mut lexp = self.lookup_expression.lookup(base_id)?.clone(); + lexp.handle = get_expr_handle!(base_id, &lexp); + for _ in 4..inst.wc { + let index = self.next()?; + log::trace!("\t\t\tlooking up type {:?}", lexp.type_id); + let type_lookup = self.lookup_type.lookup(lexp.type_id)?; + let type_id = match ctx.type_arena[type_lookup.handle].inner { + crate::TypeInner::Struct { .. } => { + self.lookup_member + .get(&(type_lookup.handle, index)) + .ok_or(Error::InvalidAccessType(lexp.type_id))? + .type_id + } + crate::TypeInner::Array { .. } + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } => type_lookup + .base_id + .ok_or(Error::InvalidAccessType(lexp.type_id))?, + ref other => { + log::warn!("composite type {:?}", other); + return Err(Error::UnsupportedType(type_lookup.handle)); + } + }; + lexp = LookupExpression { + handle: ctx.expressions.append( + crate::Expression::AccessIndex { + base: lexp.handle, + index, + }, + span, + ), + type_id, + block_id, + }; + } + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: lexp.handle, + type_id: result_type_id, + block_id, + }, + ); + } + Op::CompositeInsert => { + inst.expect_at_least(5)?; + + let result_type_id = self.next()?; + let id = self.next()?; + let object_id = self.next()?; + let composite_id = self.next()?; + let mut selections = Vec::with_capacity(inst.wc as usize - 5); + for _ in 5..inst.wc { + selections.push(self.next()?); + } + + let object_lexp = self.lookup_expression.lookup(object_id)?.clone(); + let object_handle = get_expr_handle!(object_id, &object_lexp); + let root_lexp = self.lookup_expression.lookup(composite_id)?.clone(); + let root_handle = get_expr_handle!(composite_id, &root_lexp); + let handle = self.insert_composite( + root_handle, + result_type_id, + object_handle, + &selections, + ctx.type_arena, + ctx.expressions, + span, + )?; + + self.lookup_expression.insert( + id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + } + Op::CompositeConstruct => { + inst.expect_at_least(3)?; + + let result_type_id = self.next()?; + let id = self.next()?; + let mut components = Vec::with_capacity(inst.wc as usize - 2); + for _ in 3..inst.wc { + let comp_id = self.next()?; + log::trace!("\t\t\tlooking up expr {:?}", comp_id); + let lexp = self.lookup_expression.lookup(comp_id)?; + let handle = get_expr_handle!(comp_id, lexp); + components.push(handle); + } + let ty = self.lookup_type.lookup(result_type_id)?.handle; + let first = components[0]; + let expr = match ctx.type_arena[ty].inner { + // this is an optimization to detect the splat + crate::TypeInner::Vector { size, .. } + if components.len() == size as usize + && components[1..].iter().all(|&c| c == first) => + { + crate::Expression::Splat { size, value: first } + } + _ => crate::Expression::Compose { ty, components }, + }; + self.lookup_expression.insert( + id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::Load => { + inst.expect_at_least(4)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let pointer_id = self.next()?; + if inst.wc != 4 { + inst.expect(5)?; + let _memory_access = self.next()?; + } + + let base_lexp = self.lookup_expression.lookup(pointer_id)?; + let base_handle = get_expr_handle!(pointer_id, base_lexp); + let type_lookup = self.lookup_type.lookup(base_lexp.type_id)?; + let handle = match ctx.type_arena[type_lookup.handle].inner { + crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => { + base_handle + } + _ => match self.lookup_load_override.get(&pointer_id) { + Some(&LookupLoadOverride::Loaded(handle)) => handle, + //Note: we aren't handling `LookupLoadOverride::Pending` properly here + _ => ctx.expressions.append( + crate::Expression::Load { + pointer: base_handle, + }, + span, + ), + }, + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + } + Op::Store => { + inst.expect_at_least(3)?; + + let pointer_id = self.next()?; + let value_id = self.next()?; + if inst.wc != 3 { + inst.expect(4)?; + let _memory_access = self.next()?; + } + let base_expr = self.lookup_expression.lookup(pointer_id)?; + let base_handle = get_expr_handle!(pointer_id, base_expr); + let value_expr = self.lookup_expression.lookup(value_id)?; + let value_handle = get_expr_handle!(value_id, value_expr); + + block.extend(emitter.finish(ctx.expressions)); + block.push( + crate::Statement::Store { + pointer: base_handle, + value: value_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } + // Arithmetic Instructions +, -, *, /, % + Op::SNegate | Op::FNegate => { + inst.expect(4)?; + self.parse_expr_unary_op_sign_adjusted( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + crate::UnaryOperator::Negate, + )?; + } + Op::IAdd + | Op::ISub + | Op::IMul + | Op::BitwiseOr + | Op::BitwiseXor + | Op::BitwiseAnd + | Op::SDiv + | Op::SRem => { + inst.expect(5)?; + let operator = map_binary_operator(inst.op)?; + self.parse_expr_binary_op_sign_adjusted( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + operator, + SignAnchor::Result, + )?; + } + Op::IEqual | Op::INotEqual => { + inst.expect(5)?; + let operator = map_binary_operator(inst.op)?; + self.parse_expr_binary_op_sign_adjusted( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + operator, + SignAnchor::Operand, + )?; + } + Op::FAdd => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Add, BINARY)?; + } + Op::FSub => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Subtract, BINARY)?; + } + Op::FMul => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?; + } + Op::UDiv | Op::FDiv => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Divide, BINARY)?; + } + Op::UMod | Op::FRem => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Modulo, BINARY)?; + } + Op::SMod => { + inst.expect(5)?; + + // x - y * int(floor(float(x) / float(y))) + + let start = self.data_offset; + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + let span = self.span_from_with_op(start); + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle( + p1_id, + p1_lexp, + ctx, + &mut emitter, + &mut block, + body_idx, + ); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let right = self.get_expr_handle( + p2_id, + p2_lexp, + ctx, + &mut emitter, + &mut block, + body_idx, + ); + + let result_ty = self.lookup_type.lookup(result_type_id)?; + let inner = &ctx.type_arena[result_ty.handle].inner; + let kind = inner.scalar_kind().unwrap(); + let size = inner.size(ctx.gctx()) as u8; + + let left_cast = ctx.expressions.append( + crate::Expression::As { + expr: left, + kind: crate::ScalarKind::Float, + convert: Some(size), + }, + span, + ); + let right_cast = ctx.expressions.append( + crate::Expression::As { + expr: right, + kind: crate::ScalarKind::Float, + convert: Some(size), + }, + span, + ); + let div = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Divide, + left: left_cast, + right: right_cast, + }, + span, + ); + let floor = ctx.expressions.append( + crate::Expression::Math { + fun: crate::MathFunction::Floor, + arg: div, + arg1: None, + arg2: None, + arg3: None, + }, + span, + ); + let cast = ctx.expressions.append( + crate::Expression::As { + expr: floor, + kind, + convert: Some(size), + }, + span, + ); + let mult = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Multiply, + left: cast, + right, + }, + span, + ); + let sub = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Subtract, + left, + right: mult, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: sub, + type_id: result_type_id, + block_id, + }, + ); + } + Op::FMod => { + inst.expect(5)?; + + // x - y * floor(x / y) + + let start = self.data_offset; + let span = self.span_from_with_op(start); + + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let left = self.get_expr_handle( + p1_id, + p1_lexp, + ctx, + &mut emitter, + &mut block, + body_idx, + ); + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + let right = self.get_expr_handle( + p2_id, + p2_lexp, + ctx, + &mut emitter, + &mut block, + body_idx, + ); + + let div = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Divide, + left, + right, + }, + span, + ); + let floor = ctx.expressions.append( + crate::Expression::Math { + fun: crate::MathFunction::Floor, + arg: div, + arg1: None, + arg2: None, + arg3: None, + }, + span, + ); + let mult = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Multiply, + left: floor, + right, + }, + span, + ); + let sub = ctx.expressions.append( + crate::Expression::Binary { + op: crate::BinaryOperator::Subtract, + left, + right: mult, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: sub, + type_id: result_type_id, + block_id, + }, + ); + } + Op::VectorTimesScalar + | Op::VectorTimesMatrix + | Op::MatrixTimesScalar + | Op::MatrixTimesVector + | Op::MatrixTimesMatrix => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?; + } + Op::Transpose => { + inst.expect(4)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let matrix_id = self.next()?; + let matrix_lexp = self.lookup_expression.lookup(matrix_id)?; + let matrix_handle = get_expr_handle!(matrix_id, matrix_lexp); + let expr = crate::Expression::Math { + fun: crate::MathFunction::Transpose, + arg: matrix_handle, + arg1: None, + arg2: None, + arg3: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::Dot => { + inst.expect(5)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let left_id = self.next()?; + let right_id = self.next()?; + let left_lexp = self.lookup_expression.lookup(left_id)?; + let left_handle = get_expr_handle!(left_id, left_lexp); + let right_lexp = self.lookup_expression.lookup(right_id)?; + let right_handle = get_expr_handle!(right_id, right_lexp); + let expr = crate::Expression::Math { + fun: crate::MathFunction::Dot, + arg: left_handle, + arg1: Some(right_handle), + arg2: None, + arg3: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::BitFieldInsert => { + inst.expect(7)?; + + let start = self.data_offset; + let span = self.span_from_with_op(start); + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + let insert_id = self.next()?; + let offset_id = self.next()?; + let count_id = self.next()?; + let base_lexp = self.lookup_expression.lookup(base_id)?; + let base_handle = get_expr_handle!(base_id, base_lexp); + let insert_lexp = self.lookup_expression.lookup(insert_id)?; + let insert_handle = get_expr_handle!(insert_id, insert_lexp); + let offset_lexp = self.lookup_expression.lookup(offset_id)?; + let offset_handle = get_expr_handle!(offset_id, offset_lexp); + let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?; + let count_lexp = self.lookup_expression.lookup(count_id)?; + let count_handle = get_expr_handle!(count_id, count_lexp); + let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?; + + let offset_kind = ctx.type_arena[offset_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + let count_kind = ctx.type_arena[count_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + + let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint { + ctx.expressions.append( + crate::Expression::As { + expr: offset_handle, + kind: crate::ScalarKind::Uint, + convert: None, + }, + span, + ) + } else { + offset_handle + }; + + let count_cast_handle = if count_kind != crate::ScalarKind::Uint { + ctx.expressions.append( + crate::Expression::As { + expr: count_handle, + kind: crate::ScalarKind::Uint, + convert: None, + }, + span, + ) + } else { + count_handle + }; + + let expr = crate::Expression::Math { + fun: crate::MathFunction::InsertBits, + arg: base_handle, + arg1: Some(insert_handle), + arg2: Some(offset_cast_handle), + arg3: Some(count_cast_handle), + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::BitFieldSExtract | Op::BitFieldUExtract => { + inst.expect(6)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + let offset_id = self.next()?; + let count_id = self.next()?; + let base_lexp = self.lookup_expression.lookup(base_id)?; + let base_handle = get_expr_handle!(base_id, base_lexp); + let offset_lexp = self.lookup_expression.lookup(offset_id)?; + let offset_handle = get_expr_handle!(offset_id, offset_lexp); + let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?; + let count_lexp = self.lookup_expression.lookup(count_id)?; + let count_handle = get_expr_handle!(count_id, count_lexp); + let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?; + + let offset_kind = ctx.type_arena[offset_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + let count_kind = ctx.type_arena[count_lookup_ty.handle] + .inner + .scalar_kind() + .unwrap(); + + let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint { + ctx.expressions.append( + crate::Expression::As { + expr: offset_handle, + kind: crate::ScalarKind::Uint, + convert: None, + }, + span, + ) + } else { + offset_handle + }; + + let count_cast_handle = if count_kind != crate::ScalarKind::Uint { + ctx.expressions.append( + crate::Expression::As { + expr: count_handle, + kind: crate::ScalarKind::Uint, + convert: None, + }, + span, + ) + } else { + count_handle + }; + + let expr = crate::Expression::Math { + fun: crate::MathFunction::ExtractBits, + arg: base_handle, + arg1: Some(offset_cast_handle), + arg2: Some(count_cast_handle), + arg3: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::BitReverse | Op::BitCount => { + inst.expect(4)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + let base_lexp = self.lookup_expression.lookup(base_id)?; + let base_handle = get_expr_handle!(base_id, base_lexp); + let expr = crate::Expression::Math { + fun: match inst.op { + Op::BitReverse => crate::MathFunction::ReverseBits, + Op::BitCount => crate::MathFunction::CountOneBits, + _ => unreachable!(), + }, + arg: base_handle, + arg1: None, + arg2: None, + arg3: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::OuterProduct => { + inst.expect(5)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let left_id = self.next()?; + let right_id = self.next()?; + let left_lexp = self.lookup_expression.lookup(left_id)?; + let left_handle = get_expr_handle!(left_id, left_lexp); + let right_lexp = self.lookup_expression.lookup(right_id)?; + let right_handle = get_expr_handle!(right_id, right_lexp); + let expr = crate::Expression::Math { + fun: crate::MathFunction::Outer, + arg: left_handle, + arg1: Some(right_handle), + arg2: None, + arg3: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + // Bitwise instructions + Op::Not => { + inst.expect(4)?; + self.parse_expr_unary_op_sign_adjusted( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + crate::UnaryOperator::BitwiseNot, + )?; + } + Op::ShiftRightLogical => { + inst.expect(5)?; + //TODO: convert input and result to unsigned + parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?; + } + Op::ShiftRightArithmetic => { + inst.expect(5)?; + //TODO: convert input and result to signed + parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?; + } + Op::ShiftLeftLogical => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::ShiftLeft, SHIFT)?; + } + // Sampling + Op::Image => { + inst.expect(4)?; + self.parse_image_uncouple(block_id)?; + } + Op::SampledImage => { + inst.expect(5)?; + self.parse_image_couple()?; + } + Op::ImageWrite => { + let extra = inst.expect_at_least(4)?; + let stmt = + self.parse_image_write(extra, ctx, &mut emitter, &mut block, body_idx)?; + block.extend(emitter.finish(ctx.expressions)); + block.push(stmt, span); + emitter.start(ctx.expressions); + } + Op::ImageFetch | Op::ImageRead => { + let extra = inst.expect_at_least(5)?; + self.parse_image_load( + extra, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageSampleImplicitLod | Op::ImageSampleExplicitLod => { + let extra = inst.expect_at_least(5)?; + let options = image::SamplingOptions { + compare: false, + project: false, + }; + self.parse_image_sample( + extra, + options, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageSampleProjImplicitLod | Op::ImageSampleProjExplicitLod => { + let extra = inst.expect_at_least(5)?; + let options = image::SamplingOptions { + compare: false, + project: true, + }; + self.parse_image_sample( + extra, + options, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageSampleDrefImplicitLod | Op::ImageSampleDrefExplicitLod => { + let extra = inst.expect_at_least(6)?; + let options = image::SamplingOptions { + compare: true, + project: false, + }; + self.parse_image_sample( + extra, + options, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageSampleProjDrefImplicitLod | Op::ImageSampleProjDrefExplicitLod => { + let extra = inst.expect_at_least(6)?; + let options = image::SamplingOptions { + compare: true, + project: true, + }; + self.parse_image_sample( + extra, + options, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageQuerySize => { + inst.expect(4)?; + self.parse_image_query_size( + false, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageQuerySizeLod => { + inst.expect(5)?; + self.parse_image_query_size( + true, + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + )?; + } + Op::ImageQueryLevels => { + inst.expect(4)?; + self.parse_image_query_other(crate::ImageQuery::NumLevels, ctx, block_id)?; + } + Op::ImageQuerySamples => { + inst.expect(4)?; + self.parse_image_query_other(crate::ImageQuery::NumSamples, ctx, block_id)?; + } + // other ops + Op::Select => { + inst.expect(6)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let condition = self.next()?; + let o1_id = self.next()?; + let o2_id = self.next()?; + + let cond_lexp = self.lookup_expression.lookup(condition)?; + let cond_handle = get_expr_handle!(condition, cond_lexp); + let o1_lexp = self.lookup_expression.lookup(o1_id)?; + let o1_handle = get_expr_handle!(o1_id, o1_lexp); + let o2_lexp = self.lookup_expression.lookup(o2_id)?; + let o2_handle = get_expr_handle!(o2_id, o2_lexp); + + let expr = crate::Expression::Select { + condition: cond_handle, + accept: o1_handle, + reject: o2_handle, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::VectorShuffle => { + inst.expect_at_least(5)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let v1_id = self.next()?; + let v2_id = self.next()?; + + let v1_lexp = self.lookup_expression.lookup(v1_id)?; + let v1_lty = self.lookup_type.lookup(v1_lexp.type_id)?; + let v1_handle = get_expr_handle!(v1_id, v1_lexp); + let n1 = match ctx.type_arena[v1_lty.handle].inner { + crate::TypeInner::Vector { size, .. } => size as u32, + _ => return Err(Error::InvalidInnerType(v1_lexp.type_id)), + }; + let v2_lexp = self.lookup_expression.lookup(v2_id)?; + let v2_lty = self.lookup_type.lookup(v2_lexp.type_id)?; + let v2_handle = get_expr_handle!(v2_id, v2_lexp); + let n2 = match ctx.type_arena[v2_lty.handle].inner { + crate::TypeInner::Vector { size, .. } => size as u32, + _ => return Err(Error::InvalidInnerType(v2_lexp.type_id)), + }; + + self.temp_bytes.clear(); + let mut max_component = 0; + for _ in 5..inst.wc as usize { + let mut index = self.next()?; + if index == u32::MAX { + // treat Undefined as X + index = 0; + } + max_component = max_component.max(index); + self.temp_bytes.push(index as u8); + } + + // Check for swizzle first. + let expr = if max_component < n1 { + use crate::SwizzleComponent as Sc; + let size = match self.temp_bytes.len() { + 2 => crate::VectorSize::Bi, + 3 => crate::VectorSize::Tri, + _ => crate::VectorSize::Quad, + }; + let mut pattern = [Sc::X; 4]; + for (pat, index) in pattern.iter_mut().zip(self.temp_bytes.drain(..)) { + *pat = match index { + 0 => Sc::X, + 1 => Sc::Y, + 2 => Sc::Z, + _ => Sc::W, + }; + } + crate::Expression::Swizzle { + size, + vector: v1_handle, + pattern, + } + } else { + // Fall back to access + compose + let mut components = Vec::with_capacity(self.temp_bytes.len()); + for index in self.temp_bytes.drain(..).map(|i| i as u32) { + let expr = if index < n1 { + crate::Expression::AccessIndex { + base: v1_handle, + index, + } + } else if index < n1 + n2 { + crate::Expression::AccessIndex { + base: v2_handle, + index: index - n1, + } + } else { + return Err(Error::InvalidAccessIndex(index)); + }; + components.push(ctx.expressions.append(expr, span)); + } + crate::Expression::Compose { + ty: self.lookup_type.lookup(result_type_id)?.handle, + components, + } + }; + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::Bitcast + | Op::ConvertSToF + | Op::ConvertUToF + | Op::ConvertFToU + | Op::ConvertFToS + | Op::FConvert + | Op::UConvert + | Op::SConvert => { + inst.expect(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let value_id = self.next()?; + + let value_lexp = self.lookup_expression.lookup(value_id)?; + let ty_lookup = self.lookup_type.lookup(result_type_id)?; + let scalar = match ctx.type_arena[ty_lookup.handle].inner { + crate::TypeInner::Scalar(scalar) + | crate::TypeInner::Vector { scalar, .. } + | crate::TypeInner::Matrix { scalar, .. } => scalar, + _ => return Err(Error::InvalidAsType(ty_lookup.handle)), + }; + + let expr = crate::Expression::As { + expr: get_expr_handle!(value_id, value_lexp), + kind: scalar.kind, + convert: if scalar.kind == crate::ScalarKind::Bool { + Some(crate::BOOL_WIDTH) + } else if inst.op == Op::Bitcast { + None + } else { + Some(scalar.width) + }, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::FunctionCall => { + inst.expect_at_least(4)?; + block.extend(emitter.finish(ctx.expressions)); + + let result_type_id = self.next()?; + let result_id = self.next()?; + let func_id = self.next()?; + + let mut arguments = Vec::with_capacity(inst.wc as usize - 4); + for _ in 0..arguments.capacity() { + let arg_id = self.next()?; + let lexp = self.lookup_expression.lookup(arg_id)?; + arguments.push(get_expr_handle!(arg_id, lexp)); + } + + // We just need an unique handle here, nothing more. + let function = self.add_call(ctx.function_id, func_id); + + let result = if self.lookup_void_type == Some(result_type_id) { + None + } else { + let expr_handle = ctx + .expressions + .append(crate::Expression::CallResult(function), span); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expr_handle, + type_id: result_type_id, + block_id, + }, + ); + Some(expr_handle) + }; + block.push( + crate::Statement::Call { + function, + arguments, + result, + }, + span, + ); + emitter.start(ctx.expressions); + } + Op::ExtInst => { + use crate::MathFunction as Mf; + use spirv::GLOp as Glo; + + let base_wc = 5; + inst.expect_at_least(base_wc)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + let set_id = self.next()?; + if Some(set_id) != self.ext_glsl_id { + return Err(Error::UnsupportedExtInstSet(set_id)); + } + let inst_id = self.next()?; + let gl_op = Glo::from_u32(inst_id).ok_or(Error::UnsupportedExtInst(inst_id))?; + + let fun = match gl_op { + Glo::Round => Mf::Round, + Glo::RoundEven => Mf::Round, + Glo::Trunc => Mf::Trunc, + Glo::FAbs | Glo::SAbs => Mf::Abs, + Glo::FSign | Glo::SSign => Mf::Sign, + Glo::Floor => Mf::Floor, + Glo::Ceil => Mf::Ceil, + Glo::Fract => Mf::Fract, + Glo::Sin => Mf::Sin, + Glo::Cos => Mf::Cos, + Glo::Tan => Mf::Tan, + Glo::Asin => Mf::Asin, + Glo::Acos => Mf::Acos, + Glo::Atan => Mf::Atan, + Glo::Sinh => Mf::Sinh, + Glo::Cosh => Mf::Cosh, + Glo::Tanh => Mf::Tanh, + Glo::Atan2 => Mf::Atan2, + Glo::Asinh => Mf::Asinh, + Glo::Acosh => Mf::Acosh, + Glo::Atanh => Mf::Atanh, + Glo::Radians => Mf::Radians, + Glo::Degrees => Mf::Degrees, + Glo::Pow => Mf::Pow, + Glo::Exp => Mf::Exp, + Glo::Log => Mf::Log, + Glo::Exp2 => Mf::Exp2, + Glo::Log2 => Mf::Log2, + Glo::Sqrt => Mf::Sqrt, + Glo::InverseSqrt => Mf::InverseSqrt, + Glo::MatrixInverse => Mf::Inverse, + Glo::Determinant => Mf::Determinant, + Glo::ModfStruct => Mf::Modf, + Glo::FMin | Glo::UMin | Glo::SMin | Glo::NMin => Mf::Min, + Glo::FMax | Glo::UMax | Glo::SMax | Glo::NMax => Mf::Max, + Glo::FClamp | Glo::UClamp | Glo::SClamp | Glo::NClamp => Mf::Clamp, + Glo::FMix => Mf::Mix, + Glo::Step => Mf::Step, + Glo::SmoothStep => Mf::SmoothStep, + Glo::Fma => Mf::Fma, + Glo::FrexpStruct => Mf::Frexp, + Glo::Ldexp => Mf::Ldexp, + Glo::Length => Mf::Length, + Glo::Distance => Mf::Distance, + Glo::Cross => Mf::Cross, + Glo::Normalize => Mf::Normalize, + Glo::FaceForward => Mf::FaceForward, + Glo::Reflect => Mf::Reflect, + Glo::Refract => Mf::Refract, + Glo::PackUnorm4x8 => Mf::Pack4x8unorm, + Glo::PackSnorm4x8 => Mf::Pack4x8snorm, + Glo::PackHalf2x16 => Mf::Pack2x16float, + Glo::PackUnorm2x16 => Mf::Pack2x16unorm, + Glo::PackSnorm2x16 => Mf::Pack2x16snorm, + Glo::UnpackUnorm4x8 => Mf::Unpack4x8unorm, + Glo::UnpackSnorm4x8 => Mf::Unpack4x8snorm, + Glo::UnpackHalf2x16 => Mf::Unpack2x16float, + Glo::UnpackUnorm2x16 => Mf::Unpack2x16unorm, + Glo::UnpackSnorm2x16 => Mf::Unpack2x16snorm, + Glo::FindILsb => Mf::FindLsb, + Glo::FindUMsb | Glo::FindSMsb => Mf::FindMsb, + // TODO: https://github.com/gfx-rs/naga/issues/2526 + Glo::Modf | Glo::Frexp => return Err(Error::UnsupportedExtInst(inst_id)), + Glo::IMix + | Glo::PackDouble2x32 + | Glo::UnpackDouble2x32 + | Glo::InterpolateAtCentroid + | Glo::InterpolateAtSample + | Glo::InterpolateAtOffset => { + return Err(Error::UnsupportedExtInst(inst_id)) + } + }; + + let arg_count = fun.argument_count(); + inst.expect(base_wc + arg_count as u16)?; + let arg = { + let arg_id = self.next()?; + let lexp = self.lookup_expression.lookup(arg_id)?; + get_expr_handle!(arg_id, lexp) + }; + let arg1 = if arg_count > 1 { + let arg_id = self.next()?; + let lexp = self.lookup_expression.lookup(arg_id)?; + Some(get_expr_handle!(arg_id, lexp)) + } else { + None + }; + let arg2 = if arg_count > 2 { + let arg_id = self.next()?; + let lexp = self.lookup_expression.lookup(arg_id)?; + Some(get_expr_handle!(arg_id, lexp)) + } else { + None + }; + let arg3 = if arg_count > 3 { + let arg_id = self.next()?; + let lexp = self.lookup_expression.lookup(arg_id)?; + Some(get_expr_handle!(arg_id, lexp)) + } else { + None + }; + + let expr = crate::Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + // Relational and Logical Instructions + Op::LogicalNot => { + inst.expect(4)?; + parse_expr_op!(crate::UnaryOperator::LogicalNot, UNARY)?; + } + Op::LogicalOr => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::LogicalOr, BINARY)?; + } + Op::LogicalAnd => { + inst.expect(5)?; + parse_expr_op!(crate::BinaryOperator::LogicalAnd, BINARY)?; + } + Op::SGreaterThan | Op::SGreaterThanEqual | Op::SLessThan | Op::SLessThanEqual => { + inst.expect(5)?; + self.parse_expr_int_comparison( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + map_binary_operator(inst.op)?, + crate::ScalarKind::Sint, + )?; + } + Op::UGreaterThan | Op::UGreaterThanEqual | Op::ULessThan | Op::ULessThanEqual => { + inst.expect(5)?; + self.parse_expr_int_comparison( + ctx, + &mut emitter, + &mut block, + block_id, + body_idx, + map_binary_operator(inst.op)?, + crate::ScalarKind::Uint, + )?; + } + Op::FOrdEqual + | Op::FUnordEqual + | Op::FOrdNotEqual + | Op::FUnordNotEqual + | Op::FOrdLessThan + | Op::FUnordLessThan + | Op::FOrdGreaterThan + | Op::FUnordGreaterThan + | Op::FOrdLessThanEqual + | Op::FUnordLessThanEqual + | Op::FOrdGreaterThanEqual + | Op::FUnordGreaterThanEqual + | Op::LogicalEqual + | Op::LogicalNotEqual => { + inst.expect(5)?; + let operator = map_binary_operator(inst.op)?; + parse_expr_op!(operator, BINARY)?; + } + Op::Any | Op::All | Op::IsNan | Op::IsInf | Op::IsFinite | Op::IsNormal => { + inst.expect(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let arg_id = self.next()?; + + let arg_lexp = self.lookup_expression.lookup(arg_id)?; + let arg_handle = get_expr_handle!(arg_id, arg_lexp); + + let expr = crate::Expression::Relational { + fun: map_relational_fun(inst.op)?, + argument: arg_handle, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: ctx.expressions.append(expr, span), + type_id: result_type_id, + block_id, + }, + ); + } + Op::Kill => { + inst.expect(1)?; + break Some(crate::Statement::Kill); + } + Op::Unreachable => { + inst.expect(1)?; + break None; + } + Op::Return => { + inst.expect(1)?; + break Some(crate::Statement::Return { value: None }); + } + Op::ReturnValue => { + inst.expect(2)?; + let value_id = self.next()?; + let value_lexp = self.lookup_expression.lookup(value_id)?; + let value_handle = get_expr_handle!(value_id, value_lexp); + break Some(crate::Statement::Return { + value: Some(value_handle), + }); + } + Op::Branch => { + inst.expect(2)?; + let target_id = self.next()?; + + // If this is a branch to a merge or continue block, then + // that ends the current body. + // + // Why can we count on finding an entry here when it's + // needed? SPIR-V requires dominators to appear before + // blocks they dominate, so we will have visited a + // structured control construct's header block before + // anything that could exit it. + if let Some(info) = ctx.mergers.get(&target_id) { + block.extend(emitter.finish(ctx.expressions)); + ctx.blocks.insert(block_id, block); + let body = &mut ctx.bodies[body_idx]; + body.data.push(BodyFragment::BlockId(block_id)); + + merger(body, info); + + return Ok(()); + } + + // If `target_id` has no entry in `ctx.body_for_label`, then + // this must be the only branch to it: + // + // - We've already established that it's not anybody's merge + // block. + // + // - It can't be a switch case. Only switch header blocks + // and other switch cases can branch to a switch case. + // Switch header blocks must dominate all their cases, so + // they must appear in the file before them, and when we + // see `Op::Switch` we populate `ctx.body_for_label` for + // every switch case. + // + // Thus, `target_id` must be a simple extension of the + // current block, which we dominate, so we know we'll + // encounter it later in the file. + ctx.body_for_label.entry(target_id).or_insert(body_idx); + + break None; + } + Op::BranchConditional => { + inst.expect_at_least(4)?; + + let condition = { + let condition_id = self.next()?; + let lexp = self.lookup_expression.lookup(condition_id)?; + get_expr_handle!(condition_id, lexp) + }; + + // HACK(eddyb) Naga doesn't seem to have this helper, + // so it's declared on the fly here for convenience. + #[derive(Copy, Clone)] + struct BranchTarget { + label_id: spirv::Word, + merge_info: Option<MergeBlockInformation>, + } + let branch_target = |label_id| BranchTarget { + label_id, + merge_info: ctx.mergers.get(&label_id).copied(), + }; + + let true_target = branch_target(self.next()?); + let false_target = branch_target(self.next()?); + + // Consume branch weights + for _ in 4..inst.wc { + let _ = self.next()?; + } + + // Handle `OpBranchConditional`s used at the end of a loop + // body's "continuing" section as a "conditional backedge", + // i.e. a `do`-`while` condition, or `break if` in WGSL. + + // HACK(eddyb) this has to go to the parent *twice*, because + // `OpLoopMerge` left the "continuing" section nested in the + // loop body in terms of `parent`, but not `BodyFragment`. + let parent_body_idx = ctx.bodies[body_idx].parent; + let parent_parent_body_idx = ctx.bodies[parent_body_idx].parent; + match ctx.bodies[parent_parent_body_idx].data[..] { + // The `OpLoopMerge`'s `continuing` block and the loop's + // backedge block may not be the same, but they'll both + // belong to the same body. + [.., BodyFragment::Loop { + body: loop_body_idx, + continuing: loop_continuing_idx, + break_if: ref mut break_if_slot @ None, + }] if body_idx == loop_continuing_idx => { + // Try both orderings of break-vs-backedge, because + // SPIR-V is symmetrical here, unlike WGSL `break if`. + let break_if_cond = [true, false].into_iter().find_map(|true_breaks| { + let (break_candidate, backedge_candidate) = if true_breaks { + (true_target, false_target) + } else { + (false_target, true_target) + }; + + if break_candidate.merge_info + != Some(MergeBlockInformation::LoopMerge) + { + return None; + } + + // HACK(eddyb) since Naga doesn't explicitly track + // backedges, this is checking for the outcome of + // `OpLoopMerge` below (even if it looks weird). + let backedge_candidate_is_backedge = + backedge_candidate.merge_info.is_none() + && ctx.body_for_label.get(&backedge_candidate.label_id) + == Some(&loop_body_idx); + if !backedge_candidate_is_backedge { + return None; + } + + Some(if true_breaks { + condition + } else { + ctx.expressions.append( + crate::Expression::Unary { + op: crate::UnaryOperator::LogicalNot, + expr: condition, + }, + span, + ) + }) + }); + + if let Some(break_if_cond) = break_if_cond { + *break_if_slot = Some(break_if_cond); + + // This `OpBranchConditional` ends the "continuing" + // section of the loop body as normal, with the + // `break if` condition having been stashed above. + break None; + } + } + _ => {} + } + + block.extend(emitter.finish(ctx.expressions)); + ctx.blocks.insert(block_id, block); + let body = &mut ctx.bodies[body_idx]; + body.data.push(BodyFragment::BlockId(block_id)); + + let same_target = true_target.label_id == false_target.label_id; + + // Start a body block for the `accept` branch. + let accept = ctx.bodies.len(); + let mut accept_block = Body::with_parent(body_idx); + + // If the `OpBranchConditional` target is somebody else's + // merge or continue block, then put a `Break` or `Continue` + // statement in this new body block. + if let Some(info) = true_target.merge_info { + merger( + match same_target { + true => &mut ctx.bodies[body_idx], + false => &mut accept_block, + }, + &info, + ) + } else { + // Note the body index for the block we're branching to. + let prev = ctx.body_for_label.insert( + true_target.label_id, + match same_target { + true => body_idx, + false => accept, + }, + ); + debug_assert!(prev.is_none()); + } + + if same_target { + return Ok(()); + } + + ctx.bodies.push(accept_block); + + // Handle the `reject` branch just like the `accept` block. + let reject = ctx.bodies.len(); + let mut reject_block = Body::with_parent(body_idx); + + if let Some(info) = false_target.merge_info { + merger(&mut reject_block, &info) + } else { + let prev = ctx.body_for_label.insert(false_target.label_id, reject); + debug_assert!(prev.is_none()); + } + + ctx.bodies.push(reject_block); + + let body = &mut ctx.bodies[body_idx]; + body.data.push(BodyFragment::If { + condition, + accept, + reject, + }); + + return Ok(()); + } + Op::Switch => { + inst.expect_at_least(3)?; + let selector = self.next()?; + let default_id = self.next()?; + + // If the previous instruction was a `OpSelectionMerge` then we must + // promote the `MergeBlockInformation` to a `SwitchMerge` + if let Some(merge) = selection_merge_block { + ctx.mergers + .insert(merge, MergeBlockInformation::SwitchMerge); + } + + let default = ctx.bodies.len(); + ctx.bodies.push(Body::with_parent(body_idx)); + ctx.body_for_label.entry(default_id).or_insert(default); + + let selector_lexp = &self.lookup_expression[&selector]; + let selector_lty = self.lookup_type.lookup(selector_lexp.type_id)?; + let selector_handle = get_expr_handle!(selector, selector_lexp); + let selector = match ctx.type_arena[selector_lty.handle].inner { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + }) => { + // IR expects a signed integer, so do a bitcast + ctx.expressions.append( + crate::Expression::As { + kind: crate::ScalarKind::Sint, + expr: selector_handle, + convert: None, + }, + span, + ) + } + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + }) => selector_handle, + ref other => unimplemented!("Unexpected selector {:?}", other), + }; + + // Clear past switch cases to prevent them from entering this one + self.switch_cases.clear(); + + for _ in 0..(inst.wc - 3) / 2 { + let literal = self.next()?; + let target = self.next()?; + + let case_body_idx = ctx.bodies.len(); + + // Check if any previous case already used this target block id, if so + // group them together to reorder them later so that no weird + // fallthrough cases happen. + if let Some(&mut (_, ref mut literals)) = self.switch_cases.get_mut(&target) + { + literals.push(literal as i32); + continue; + } + + let mut body = Body::with_parent(body_idx); + + if let Some(info) = ctx.mergers.get(&target) { + merger(&mut body, info); + } + + ctx.bodies.push(body); + ctx.body_for_label.entry(target).or_insert(case_body_idx); + + // Register this target block id as already having been processed and + // the respective body index assigned and the first case value + self.switch_cases + .insert(target, (case_body_idx, vec![literal as i32])); + } + + // Loop trough the collected target blocks creating a new case for each + // literal pointing to it, only one case will have the true body and all the + // others will be empty fallthrough so that they all execute the same body + // without duplicating code. + // + // Since `switch_cases` is an indexmap the order of insertion is preserved + // this is needed because spir-v defines fallthrough order in the switch + // instruction. + let mut cases = Vec::with_capacity((inst.wc as usize - 3) / 2); + for &(case_body_idx, ref literals) in self.switch_cases.values() { + let value = literals[0]; + + for &literal in literals.iter().skip(1) { + let empty_body_idx = ctx.bodies.len(); + let body = Body::with_parent(body_idx); + + ctx.bodies.push(body); + + cases.push((literal, empty_body_idx)); + } + + cases.push((value, case_body_idx)); + } + + block.extend(emitter.finish(ctx.expressions)); + + let body = &mut ctx.bodies[body_idx]; + ctx.blocks.insert(block_id, block); + // Make sure the vector has space for at least two more allocations + body.data.reserve(2); + body.data.push(BodyFragment::BlockId(block_id)); + body.data.push(BodyFragment::Switch { + selector, + cases, + default, + }); + + return Ok(()); + } + Op::SelectionMerge => { + inst.expect(3)?; + let merge_block_id = self.next()?; + // TODO: Selection Control Mask + let _selection_control = self.next()?; + + // Indicate that the merge block is a continuation of the + // current `Body`. + ctx.body_for_label.entry(merge_block_id).or_insert(body_idx); + + // Let subsequent branches to the merge block know that + // they've reached the end of the selection construct. + ctx.mergers + .insert(merge_block_id, MergeBlockInformation::SelectionMerge); + + selection_merge_block = Some(merge_block_id); + } + Op::LoopMerge => { + inst.expect_at_least(4)?; + let merge_block_id = self.next()?; + let continuing = self.next()?; + + // TODO: Loop Control Parameters + for _ in 0..inst.wc - 3 { + self.next()?; + } + + // Indicate that the merge block is a continuation of the + // current `Body`. + ctx.body_for_label.entry(merge_block_id).or_insert(body_idx); + // Let subsequent branches to the merge block know that + // they're `Break` statements. + ctx.mergers + .insert(merge_block_id, MergeBlockInformation::LoopMerge); + + let loop_body_idx = ctx.bodies.len(); + ctx.bodies.push(Body::with_parent(body_idx)); + + let continue_idx = ctx.bodies.len(); + // The continue block inherits the scope of the loop body + ctx.bodies.push(Body::with_parent(loop_body_idx)); + ctx.body_for_label.entry(continuing).or_insert(continue_idx); + // Let subsequent branches to the continue block know that + // they're `Continue` statements. + ctx.mergers + .insert(continuing, MergeBlockInformation::LoopContinue); + + // The loop header always belongs to the loop body + ctx.body_for_label.insert(block_id, loop_body_idx); + + let parent_body = &mut ctx.bodies[body_idx]; + parent_body.data.push(BodyFragment::Loop { + body: loop_body_idx, + continuing: continue_idx, + break_if: None, + }); + body_idx = loop_body_idx; + } + Op::DPdxCoarse => { + parse_expr_op!( + crate::DerivativeAxis::X, + crate::DerivativeControl::Coarse, + DERIVATIVE + )?; + } + Op::DPdyCoarse => { + parse_expr_op!( + crate::DerivativeAxis::Y, + crate::DerivativeControl::Coarse, + DERIVATIVE + )?; + } + Op::FwidthCoarse => { + parse_expr_op!( + crate::DerivativeAxis::Width, + crate::DerivativeControl::Coarse, + DERIVATIVE + )?; + } + Op::DPdxFine => { + parse_expr_op!( + crate::DerivativeAxis::X, + crate::DerivativeControl::Fine, + DERIVATIVE + )?; + } + Op::DPdyFine => { + parse_expr_op!( + crate::DerivativeAxis::Y, + crate::DerivativeControl::Fine, + DERIVATIVE + )?; + } + Op::FwidthFine => { + parse_expr_op!( + crate::DerivativeAxis::Width, + crate::DerivativeControl::Fine, + DERIVATIVE + )?; + } + Op::DPdx => { + parse_expr_op!( + crate::DerivativeAxis::X, + crate::DerivativeControl::None, + DERIVATIVE + )?; + } + Op::DPdy => { + parse_expr_op!( + crate::DerivativeAxis::Y, + crate::DerivativeControl::None, + DERIVATIVE + )?; + } + Op::Fwidth => { + parse_expr_op!( + crate::DerivativeAxis::Width, + crate::DerivativeControl::None, + DERIVATIVE + )?; + } + Op::ArrayLength => { + inst.expect(5)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let structure_id = self.next()?; + let member_index = self.next()?; + + // We're assuming that the validation pass, if it's run, will catch if the + // wrong types or parameters are supplied here. + + let structure_ptr = self.lookup_expression.lookup(structure_id)?; + let structure_handle = get_expr_handle!(structure_id, structure_ptr); + + let member_ptr = ctx.expressions.append( + crate::Expression::AccessIndex { + base: structure_handle, + index: member_index, + }, + span, + ); + + let length = ctx + .expressions + .append(crate::Expression::ArrayLength(member_ptr), span); + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: length, + type_id: result_type_id, + block_id, + }, + ); + } + Op::CopyMemory => { + inst.expect_at_least(3)?; + let target_id = self.next()?; + let source_id = self.next()?; + let _memory_access = if inst.wc != 3 { + inst.expect(4)?; + spirv::MemoryAccess::from_bits(self.next()?) + .ok_or(Error::InvalidParameter(Op::CopyMemory))? + } else { + spirv::MemoryAccess::NONE + }; + + // TODO: check if the source and target types are the same? + let target = self.lookup_expression.lookup(target_id)?; + let target_handle = get_expr_handle!(target_id, target); + let source = self.lookup_expression.lookup(source_id)?; + let source_handle = get_expr_handle!(source_id, source); + + // This operation is practically the same as loading and then storing, I think. + let value_expr = ctx.expressions.append( + crate::Expression::Load { + pointer: source_handle, + }, + span, + ); + + block.extend(emitter.finish(ctx.expressions)); + block.push( + crate::Statement::Store { + pointer: target_handle, + value: value_expr, + }, + span, + ); + + emitter.start(ctx.expressions); + } + Op::ControlBarrier => { + inst.expect(4)?; + let exec_scope_id = self.next()?; + let _mem_scope_raw = self.next()?; + let semantics_id = self.next()?; + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let semantics_const = self.lookup_constant.lookup(semantics_id)?; + + let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + let semantics = resolve_constant(ctx.gctx(), semantics_const.handle) + .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; + + if exec_scope == spirv::Scope::Workgroup as u32 { + let mut flags = crate::Barrier::empty(); + flags.set( + crate::Barrier::STORAGE, + semantics & spirv::MemorySemantics::UNIFORM_MEMORY.bits() != 0, + ); + flags.set( + crate::Barrier::WORK_GROUP, + semantics + & (spirv::MemorySemantics::SUBGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY) + .bits() + != 0, + ); + block.push(crate::Statement::Barrier(flags), span); + } else { + log::warn!("Unsupported barrier execution scope: {}", exec_scope); + } + } + Op::CopyObject => { + inst.expect(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let operand_id = self.next()?; + + let lookup = self.lookup_expression.lookup(operand_id)?; + let handle = get_expr_handle!(operand_id, lookup); + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle, + type_id: result_type_id, + block_id, + }, + ); + } + _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), + } + }; + + block.extend(emitter.finish(ctx.expressions)); + if let Some(stmt) = terminator { + block.push(stmt, crate::Span::default()); + } + + // Save this block fragment in `block_ctx.blocks`, and mark it to be + // incorporated into the current body at `Statement` assembly time. + ctx.blocks.insert(block_id, block); + let body = &mut ctx.bodies[body_idx]; + body.data.push(BodyFragment::BlockId(block_id)); + Ok(()) + } + + fn make_expression_storage( + &mut self, + globals: &Arena<crate::GlobalVariable>, + constants: &Arena<crate::Constant>, + ) -> Arena<crate::Expression> { + let mut expressions = Arena::new(); + #[allow(clippy::panic)] + { + assert!(self.lookup_expression.is_empty()); + } + // register global variables + for (&id, var) in self.lookup_variable.iter() { + let span = globals.get_span(var.handle); + let handle = expressions.append(crate::Expression::GlobalVariable(var.handle), span); + self.lookup_expression.insert( + id, + LookupExpression { + type_id: var.type_id, + handle, + // Setting this to an invalid id will cause get_expr_handle + // to default to the main body making sure no load/stores + // are added. + block_id: 0, + }, + ); + } + // register constants + for (&id, con) in self.lookup_constant.iter() { + let span = constants.get_span(con.handle); + let handle = expressions.append(crate::Expression::Constant(con.handle), span); + self.lookup_expression.insert( + id, + LookupExpression { + type_id: con.type_id, + handle, + // Setting this to an invalid id will cause get_expr_handle + // to default to the main body making sure no load/stores + // are added. + block_id: 0, + }, + ); + } + // done + expressions + } + + fn switch(&mut self, state: ModuleState, op: spirv::Op) -> Result<(), Error> { + if state < self.state { + Err(Error::UnsupportedInstruction(self.state, op)) + } else { + self.state = state; + Ok(()) + } + } + + /// Walk the statement tree and patch it in the following cases: + /// 1. Function call targets are replaced by `deferred_function_calls` map + fn patch_statements( + &mut self, + statements: &mut crate::Block, + expressions: &mut Arena<crate::Expression>, + fun_parameter_sampling: &mut [image::SamplingFlags], + ) -> Result<(), Error> { + use crate::Statement as S; + let mut i = 0usize; + while i < statements.len() { + match statements[i] { + S::Emit(_) => {} + S::Block(ref mut block) => { + self.patch_statements(block, expressions, fun_parameter_sampling)?; + } + S::If { + condition: _, + ref mut accept, + ref mut reject, + } => { + self.patch_statements(reject, expressions, fun_parameter_sampling)?; + self.patch_statements(accept, expressions, fun_parameter_sampling)?; + } + S::Switch { + selector: _, + ref mut cases, + } => { + for case in cases.iter_mut() { + self.patch_statements(&mut case.body, expressions, fun_parameter_sampling)?; + } + } + S::Loop { + ref mut body, + ref mut continuing, + break_if: _, + } => { + self.patch_statements(body, expressions, fun_parameter_sampling)?; + self.patch_statements(continuing, expressions, fun_parameter_sampling)?; + } + S::Break + | S::Continue + | S::Return { .. } + | S::Kill + | S::Barrier(_) + | S::Store { .. } + | S::ImageStore { .. } + | S::Atomic { .. } + | S::RayQuery { .. } => {} + S::Call { + function: ref mut callee, + ref arguments, + .. + } => { + let fun_id = self.deferred_function_calls[callee.index()]; + let fun_lookup = self.lookup_function.lookup(fun_id)?; + *callee = fun_lookup.handle; + + // Patch sampling flags + for (arg_index, arg) in arguments.iter().enumerate() { + let flags = match fun_lookup.parameters_sampling.get(arg_index) { + Some(&flags) if !flags.is_empty() => flags, + _ => continue, + }; + + match expressions[*arg] { + crate::Expression::GlobalVariable(handle) => { + if let Some(sampling) = self.handle_sampling.get_mut(&handle) { + *sampling |= flags + } + } + crate::Expression::FunctionArgument(i) => { + fun_parameter_sampling[i as usize] |= flags; + } + ref other => return Err(Error::InvalidGlobalVar(other.clone())), + } + } + } + S::WorkGroupUniformLoad { .. } => unreachable!(), + } + i += 1; + } + Ok(()) + } + + fn patch_function( + &mut self, + handle: Option<Handle<crate::Function>>, + fun: &mut crate::Function, + ) -> Result<(), Error> { + // Note: this search is a bit unfortunate + let (fun_id, mut parameters_sampling) = match handle { + Some(h) => { + let (&fun_id, lookup) = self + .lookup_function + .iter_mut() + .find(|&(_, ref lookup)| lookup.handle == h) + .unwrap(); + (fun_id, mem::take(&mut lookup.parameters_sampling)) + } + None => (0, Vec::new()), + }; + + for (_, expr) in fun.expressions.iter_mut() { + if let crate::Expression::CallResult(ref mut function) = *expr { + let fun_id = self.deferred_function_calls[function.index()]; + *function = self.lookup_function.lookup(fun_id)?.handle; + } + } + + self.patch_statements( + &mut fun.body, + &mut fun.expressions, + &mut parameters_sampling, + )?; + + if let Some(lookup) = self.lookup_function.get_mut(&fun_id) { + lookup.parameters_sampling = parameters_sampling; + } + Ok(()) + } + + pub fn parse(mut self) -> Result<crate::Module, Error> { + let mut module = { + if self.next()? != spirv::MAGIC_NUMBER { + return Err(Error::InvalidHeader); + } + let version_raw = self.next()?; + let generator = self.next()?; + let _bound = self.next()?; + let _schema = self.next()?; + log::info!("Generated by {} version {:x}", generator, version_raw); + crate::Module::default() + }; + + self.layouter.clear(); + self.dummy_functions = Arena::new(); + self.lookup_function.clear(); + self.function_call_graph.clear(); + + loop { + use spirv::Op; + + let inst = match self.next_inst() { + Ok(inst) => inst, + Err(Error::IncompleteData) => break, + Err(other) => return Err(other), + }; + log::debug!("\t{:?} [{}]", inst.op, inst.wc); + + match inst.op { + Op::Capability => self.parse_capability(inst), + Op::Extension => self.parse_extension(inst), + Op::ExtInstImport => self.parse_ext_inst_import(inst), + Op::MemoryModel => self.parse_memory_model(inst), + Op::EntryPoint => self.parse_entry_point(inst), + Op::ExecutionMode => self.parse_execution_mode(inst), + Op::String => self.parse_string(inst), + Op::Source => self.parse_source(inst), + Op::SourceExtension => self.parse_source_extension(inst), + Op::Name => self.parse_name(inst), + Op::MemberName => self.parse_member_name(inst), + Op::ModuleProcessed => self.parse_module_processed(inst), + Op::Decorate => self.parse_decorate(inst), + Op::MemberDecorate => self.parse_member_decorate(inst), + Op::TypeVoid => self.parse_type_void(inst), + Op::TypeBool => self.parse_type_bool(inst, &mut module), + Op::TypeInt => self.parse_type_int(inst, &mut module), + Op::TypeFloat => self.parse_type_float(inst, &mut module), + Op::TypeVector => self.parse_type_vector(inst, &mut module), + Op::TypeMatrix => self.parse_type_matrix(inst, &mut module), + Op::TypeFunction => self.parse_type_function(inst), + Op::TypePointer => self.parse_type_pointer(inst, &mut module), + Op::TypeArray => self.parse_type_array(inst, &mut module), + Op::TypeRuntimeArray => self.parse_type_runtime_array(inst, &mut module), + Op::TypeStruct => self.parse_type_struct(inst, &mut module), + Op::TypeImage => self.parse_type_image(inst, &mut module), + Op::TypeSampledImage => self.parse_type_sampled_image(inst), + Op::TypeSampler => self.parse_type_sampler(inst, &mut module), + Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), + Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), + Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module), + Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module), + Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module), + Op::Variable => self.parse_global_variable(inst, &mut module), + Op::Function => { + self.switch(ModuleState::Function, inst.op)?; + inst.expect(5)?; + self.parse_function(&mut module) + } + _ => Err(Error::UnsupportedInstruction(self.state, inst.op)), //TODO + }?; + } + + log::info!("Patching..."); + { + let mut nodes = petgraph::algo::toposort(&self.function_call_graph, None) + .map_err(|cycle| Error::FunctionCallCycle(cycle.node_id()))?; + nodes.reverse(); // we need dominated first + let mut functions = mem::take(&mut module.functions); + for fun_id in nodes { + if fun_id > !(functions.len() as u32) { + // skip all the fake IDs registered for the entry points + continue; + } + let lookup = self.lookup_function.get_mut(&fun_id).unwrap(); + // take out the function from the old array + let fun = mem::take(&mut functions[lookup.handle]); + // add it to the newly formed arena, and adjust the lookup + lookup.handle = module + .functions + .append(fun, functions.get_span(lookup.handle)); + } + } + // patch all the functions + for (handle, fun) in module.functions.iter_mut() { + self.patch_function(Some(handle), fun)?; + } + for ep in module.entry_points.iter_mut() { + self.patch_function(None, &mut ep.function)?; + } + + // Check all the images and samplers to have consistent comparison property. + for (handle, flags) in self.handle_sampling.drain() { + if !image::patch_comparison_type( + flags, + module.global_variables.get_mut(handle), + &mut module.types, + ) { + return Err(Error::InconsistentComparisonSampling(handle)); + } + } + + if !self.future_decor.is_empty() { + log::warn!("Unused item decorations: {:?}", self.future_decor); + self.future_decor.clear(); + } + if !self.future_member_decor.is_empty() { + log::warn!("Unused member decorations: {:?}", self.future_member_decor); + self.future_member_decor.clear(); + } + + Ok(module) + } + + fn parse_capability(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Capability, inst.op)?; + inst.expect(2)?; + let capability = self.next()?; + let cap = + spirv::Capability::from_u32(capability).ok_or(Error::UnknownCapability(capability))?; + if !SUPPORTED_CAPABILITIES.contains(&cap) { + if self.options.strict_capabilities { + return Err(Error::UnsupportedCapability(cap)); + } else { + log::warn!("Unknown capability {:?}", cap); + } + } + Ok(()) + } + + fn parse_extension(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Extension, inst.op)?; + inst.expect_at_least(2)?; + let (name, left) = self.next_string(inst.wc - 1)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + if !SUPPORTED_EXTENSIONS.contains(&name.as_str()) { + return Err(Error::UnsupportedExtension(name)); + } + Ok(()) + } + + fn parse_ext_inst_import(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Extension, inst.op)?; + inst.expect_at_least(3)?; + let result_id = self.next()?; + let (name, left) = self.next_string(inst.wc - 2)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + if !SUPPORTED_EXT_SETS.contains(&name.as_str()) { + return Err(Error::UnsupportedExtSet(name)); + } + self.ext_glsl_id = Some(result_id); + Ok(()) + } + + fn parse_memory_model(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::MemoryModel, inst.op)?; + inst.expect(3)?; + let _addressing_model = self.next()?; + let _memory_model = self.next()?; + Ok(()) + } + + fn parse_entry_point(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::EntryPoint, inst.op)?; + inst.expect_at_least(4)?; + let exec_model = self.next()?; + let exec_model = spirv::ExecutionModel::from_u32(exec_model) + .ok_or(Error::UnsupportedExecutionModel(exec_model))?; + let function_id = self.next()?; + let (name, left) = self.next_string(inst.wc - 3)?; + let ep = EntryPoint { + stage: match exec_model { + spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex, + spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment, + spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute, + _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)), + }, + name, + early_depth_test: None, + workgroup_size: [0; 3], + variable_ids: self.data.by_ref().take(left as usize).collect(), + }; + self.lookup_entry_point.insert(function_id, ep); + Ok(()) + } + + fn parse_execution_mode(&mut self, inst: Instruction) -> Result<(), Error> { + use spirv::ExecutionMode; + + self.switch(ModuleState::ExecutionMode, inst.op)?; + inst.expect_at_least(3)?; + + let ep_id = self.next()?; + let mode_id = self.next()?; + let args: Vec<spirv::Word> = self.data.by_ref().take(inst.wc as usize - 3).collect(); + + let ep = self + .lookup_entry_point + .get_mut(&ep_id) + .ok_or(Error::InvalidId(ep_id))?; + let mode = spirv::ExecutionMode::from_u32(mode_id) + .ok_or(Error::UnsupportedExecutionMode(mode_id))?; + + match mode { + ExecutionMode::EarlyFragmentTests => { + if ep.early_depth_test.is_none() { + ep.early_depth_test = Some(crate::EarlyDepthTest { conservative: None }); + } + } + ExecutionMode::DepthUnchanged => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::Unchanged), + }); + } + ExecutionMode::DepthGreater => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::GreaterEqual), + }); + } + ExecutionMode::DepthLess => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::LessEqual), + }); + } + ExecutionMode::DepthReplacing => { + // Ignored because it can be deduced from the IR. + } + ExecutionMode::OriginUpperLeft => { + // Ignored because the other option (OriginLowerLeft) is not valid in Vulkan mode. + } + ExecutionMode::LocalSize => { + ep.workgroup_size = [args[0], args[1], args[2]]; + } + _ => { + return Err(Error::UnsupportedExecutionMode(mode_id)); + } + } + + Ok(()) + } + + fn parse_string(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Source, inst.op)?; + inst.expect_at_least(3)?; + let _id = self.next()?; + let (_name, _) = self.next_string(inst.wc - 2)?; + Ok(()) + } + + fn parse_source(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Source, inst.op)?; + for _ in 1..inst.wc { + let _ = self.next()?; + } + Ok(()) + } + + fn parse_source_extension(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Source, inst.op)?; + inst.expect_at_least(2)?; + let (_name, _) = self.next_string(inst.wc - 1)?; + Ok(()) + } + + fn parse_name(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Name, inst.op)?; + inst.expect_at_least(3)?; + let id = self.next()?; + let (name, left) = self.next_string(inst.wc - 2)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + self.future_decor.entry(id).or_default().name = Some(name); + Ok(()) + } + + fn parse_member_name(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Name, inst.op)?; + inst.expect_at_least(4)?; + let id = self.next()?; + let member = self.next()?; + let (name, left) = self.next_string(inst.wc - 3)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + + self.future_member_decor + .entry((id, member)) + .or_default() + .name = Some(name); + Ok(()) + } + + fn parse_module_processed(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Name, inst.op)?; + inst.expect_at_least(2)?; + let (_info, left) = self.next_string(inst.wc - 1)?; + //Note: string is ignored + if left != 0 { + return Err(Error::InvalidOperand); + } + Ok(()) + } + + fn parse_decorate(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Annotation, inst.op)?; + inst.expect_at_least(3)?; + let id = self.next()?; + let mut dec = self.future_decor.remove(&id).unwrap_or_default(); + self.next_decoration(inst, 2, &mut dec)?; + self.future_decor.insert(id, dec); + Ok(()) + } + + fn parse_member_decorate(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Annotation, inst.op)?; + inst.expect_at_least(4)?; + let id = self.next()?; + let member = self.next()?; + + let mut dec = self + .future_member_decor + .remove(&(id, member)) + .unwrap_or_default(); + self.next_decoration(inst, 3, &mut dec)?; + self.future_member_decor.insert((id, member), dec); + Ok(()) + } + + fn parse_type_void(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(2)?; + let id = self.next()?; + self.lookup_void_type = Some(id); + Ok(()) + } + + fn parse_type_bool( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(2)?; + let id = self.next()?; + let inner = crate::TypeInner::Scalar(crate::Scalar::BOOL); + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ), + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_int( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let width = self.next()?; + let sign = self.next()?; + let inner = crate::TypeInner::Scalar(crate::Scalar { + kind: match sign { + 0 => crate::ScalarKind::Uint, + 1 => crate::ScalarKind::Sint, + _ => return Err(Error::InvalidSign(sign)), + }, + width: map_width(width)?, + }); + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ), + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_float( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let id = self.next()?; + let width = self.next()?; + let inner = crate::TypeInner::Scalar(crate::Scalar::float(map_width(width)?)); + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ), + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_vector( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let type_id = self.next()?; + let type_lookup = self.lookup_type.lookup(type_id)?; + let scalar = match module.types[type_lookup.handle].inner { + crate::TypeInner::Scalar(scalar) => scalar, + _ => return Err(Error::InvalidInnerType(type_id)), + }; + let component_count = self.next()?; + let inner = crate::TypeInner::Vector { + size: map_vector_size(component_count)?, + scalar, + }; + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ), + base_id: Some(type_id), + }, + ); + Ok(()) + } + + fn parse_type_matrix( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let vector_type_id = self.next()?; + let num_columns = self.next()?; + let decor = self.future_decor.remove(&id); + + let vector_type_lookup = self.lookup_type.lookup(vector_type_id)?; + let inner = match module.types[vector_type_lookup.handle].inner { + crate::TypeInner::Vector { size, scalar } => crate::TypeInner::Matrix { + columns: map_vector_size(num_columns)?, + rows: size, + scalar, + }, + _ => return Err(Error::InvalidInnerType(vector_type_id)), + }; + + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: decor.and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ), + base_id: Some(vector_type_id), + }, + ); + Ok(()) + } + + fn parse_type_function(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(3)?; + let id = self.next()?; + let return_type_id = self.next()?; + let parameter_type_ids = self.data.by_ref().take(inst.wc as usize - 3).collect(); + self.lookup_function_type.insert( + id, + LookupFunctionType { + parameter_type_ids, + return_type_id, + }, + ); + Ok(()) + } + + fn parse_type_pointer( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let storage_class = self.next()?; + let type_id = self.next()?; + + let decor = self.future_decor.remove(&id); + let base_lookup_ty = self.lookup_type.lookup(type_id)?; + let base_inner = &module.types[base_lookup_ty.handle].inner; + + let space = if let Some(space) = base_inner.pointer_space() { + space + } else if self + .lookup_storage_buffer_types + .contains_key(&base_lookup_ty.handle) + { + crate::AddressSpace::Storage { + access: crate::StorageAccess::default(), + } + } else { + match map_storage_class(storage_class)? { + ExtendedClass::Global(space) => space, + ExtendedClass::Input | ExtendedClass::Output => crate::AddressSpace::Private, + } + }; + + // We don't support pointers to runtime-sized arrays in the `Uniform` + // storage class with the `BufferBlock` decoration. Runtime-sized arrays + // should be in the StorageBuffer class. + if let crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } = *base_inner + { + match space { + crate::AddressSpace::Storage { .. } => {} + _ => { + return Err(Error::UnsupportedRuntimeArrayStorageClass); + } + } + } + + // Don't bother with pointer stuff for `Handle` types. + let lookup_ty = if space == crate::AddressSpace::Handle { + base_lookup_ty.clone() + } else { + LookupType { + handle: module.types.insert( + crate::Type { + name: decor.and_then(|dec| dec.name), + inner: crate::TypeInner::Pointer { + base: base_lookup_ty.handle, + space, + }, + }, + self.span_from_with_op(start), + ), + base_id: Some(type_id), + } + }; + self.lookup_type.insert(id, lookup_ty); + Ok(()) + } + + fn parse_type_array( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let type_id = self.next()?; + let length_id = self.next()?; + let length_const = self.lookup_constant.lookup(length_id)?; + + let size = resolve_constant(module.to_ctx(), length_const.handle) + .and_then(NonZeroU32::new) + .ok_or(Error::InvalidArraySize(length_const.handle))?; + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + let base = self.lookup_type.lookup(type_id)?.handle; + + self.layouter.update(module.to_ctx()).unwrap(); + + // HACK if the underlying type is an image or a sampler, let's assume + // that we're dealing with a binding-array + // + // Note that it's not a strictly correct assumption, but rather a trade + // off caused by an impedance mismatch between SPIR-V's and Naga's type + // systems - Naga distinguishes between arrays and binding-arrays via + // types (i.e. both kinds of arrays are just different types), while + // SPIR-V distinguishes between them through usage - e.g. given: + // + // ``` + // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f + // %uint_256 = OpConstant %uint 256 + // %image_array = OpTypeArray %image %uint_256 + // ``` + // + // ``` + // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f + // %uint_256 = OpConstant %uint 256 + // %image_array = OpTypeArray %image %uint_256 + // %image_array_ptr = OpTypePointer UniformConstant %image_array + // ``` + // + // ... in the first case, `%image_array` should technically correspond + // to `TypeInner::Array`, while in the second case it should say + // `TypeInner::BindingArray` (kinda, depending on whether `%image_array` + // is ever used as a freestanding type or rather always through the + // pointer-indirection). + // + // Anyway, at the moment we don't support other kinds of image / sampler + // arrays than those binding-based, so this assumption is pretty safe + // for now. + let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } = + module.types[base].inner + { + crate::TypeInner::BindingArray { + base, + size: crate::ArraySize::Constant(size), + } + } else { + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + stride: match decor.array_stride { + Some(stride) => stride.get(), + None => self.layouter[base].to_stride(), + }, + } + }; + + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: decor.name, + inner, + }, + self.span_from_with_op(start), + ), + base_id: Some(type_id), + }, + ); + Ok(()) + } + + fn parse_type_runtime_array( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let id = self.next()?; + let type_id = self.next()?; + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + let base = self.lookup_type.lookup(type_id)?.handle; + + self.layouter.update(module.to_ctx()).unwrap(); + + // HACK same case as in `parse_type_array()` + let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } = + module.types[base].inner + { + crate::TypeInner::BindingArray { + base: self.lookup_type.lookup(type_id)?.handle, + size: crate::ArraySize::Dynamic, + } + } else { + crate::TypeInner::Array { + base: self.lookup_type.lookup(type_id)?.handle, + size: crate::ArraySize::Dynamic, + stride: match decor.array_stride { + Some(stride) => stride.get(), + None => self.layouter[base].to_stride(), + }, + } + }; + + self.lookup_type.insert( + id, + LookupType { + handle: module.types.insert( + crate::Type { + name: decor.name, + inner, + }, + self.span_from_with_op(start), + ), + base_id: Some(type_id), + }, + ); + Ok(()) + } + + fn parse_type_struct( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(2)?; + let id = self.next()?; + let parent_decor = self.future_decor.remove(&id); + let is_storage_buffer = parent_decor + .as_ref() + .map_or(false, |decor| decor.storage_buffer); + + self.layouter.update(module.to_ctx()).unwrap(); + + let mut members = Vec::<crate::StructMember>::with_capacity(inst.wc as usize - 2); + let mut member_lookups = Vec::with_capacity(members.capacity()); + let mut storage_access = crate::StorageAccess::empty(); + let mut span = 0; + let mut alignment = Alignment::ONE; + for i in 0..u32::from(inst.wc) - 2 { + let type_id = self.next()?; + let ty = self.lookup_type.lookup(type_id)?.handle; + let decor = self + .future_member_decor + .remove(&(id, i)) + .unwrap_or_default(); + + storage_access |= decor.flags.to_storage_access(); + + member_lookups.push(LookupMember { + type_id, + row_major: decor.matrix_major == Some(Majority::Row), + }); + + let member_alignment = self.layouter[ty].alignment; + span = member_alignment.round_up(span); + alignment = member_alignment.max(alignment); + + let binding = decor.io_binding().ok(); + if let Some(offset) = decor.offset { + span = offset; + } + let offset = span; + + span += self.layouter[ty].size; + + let inner = &module.types[ty].inner; + if let crate::TypeInner::Matrix { + columns, + rows, + scalar, + } = *inner + { + if let Some(stride) = decor.matrix_stride { + let expected_stride = Alignment::from(rows) * scalar.width as u32; + if stride.get() != expected_stride { + return Err(Error::UnsupportedMatrixStride { + stride: stride.get(), + columns: columns as u8, + rows: rows as u8, + width: scalar.width, + }); + } + } + } + + members.push(crate::StructMember { + name: decor.name, + ty, + binding, + offset, + }); + } + + span = alignment.round_up(span); + + let inner = crate::TypeInner::Struct { span, members }; + + let ty_handle = module.types.insert( + crate::Type { + name: parent_decor.and_then(|dec| dec.name), + inner, + }, + self.span_from_with_op(start), + ); + + if is_storage_buffer { + self.lookup_storage_buffer_types + .insert(ty_handle, storage_access); + } + for (i, member_lookup) in member_lookups.into_iter().enumerate() { + self.lookup_member + .insert((ty_handle, i as u32), member_lookup); + } + self.lookup_type.insert( + id, + LookupType { + handle: ty_handle, + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_image( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(9)?; + + let id = self.next()?; + let sample_type_id = self.next()?; + let dim = self.next()?; + let is_depth = self.next()?; + let is_array = self.next()? != 0; + let is_msaa = self.next()? != 0; + let _is_sampled = self.next()?; + let format = self.next()?; + + let dim = map_image_dim(dim)?; + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + // ensure there is a type for texture coordinate without extra components + module.types.insert( + crate::Type { + name: None, + inner: { + let scalar = crate::Scalar::F32; + match dim.required_coordinate_size() { + None => crate::TypeInner::Scalar(scalar), + Some(size) => crate::TypeInner::Vector { size, scalar }, + } + }, + }, + Default::default(), + ); + + let base_handle = self.lookup_type.lookup(sample_type_id)?.handle; + let kind = module.types[base_handle] + .inner + .scalar_kind() + .ok_or(Error::InvalidImageBaseType(base_handle))?; + + let inner = crate::TypeInner::Image { + class: if is_depth == 1 { + crate::ImageClass::Depth { multi: is_msaa } + } else if format != 0 { + crate::ImageClass::Storage { + format: map_image_format(format)?, + access: crate::StorageAccess::default(), + } + } else { + crate::ImageClass::Sampled { + kind, + multi: is_msaa, + } + }, + dim, + arrayed: is_array, + }; + + let handle = module.types.insert( + crate::Type { + name: decor.name, + inner, + }, + self.span_from_with_op(start), + ); + + self.lookup_type.insert( + id, + LookupType { + handle, + base_id: Some(sample_type_id), + }, + ); + Ok(()) + } + + fn parse_type_sampled_image(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let id = self.next()?; + let image_id = self.next()?; + self.lookup_type.insert( + id, + LookupType { + handle: self.lookup_type.lookup(image_id)?.handle, + base_id: Some(image_id), + }, + ); + Ok(()) + } + + fn parse_type_sampler( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(2)?; + let id = self.next()?; + let decor = self.future_decor.remove(&id).unwrap_or_default(); + let handle = module.types.insert( + crate::Type { + name: decor.name, + inner: crate::TypeInner::Sampler { comparison: false }, + }, + self.span_from_with_op(start), + ); + self.lookup_type.insert( + id, + LookupType { + handle, + base_id: None, + }, + ); + Ok(()) + } + + fn parse_constant( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(4)?; + let type_id = self.next()?; + let id = self.next()?; + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + + let literal = match module.types[ty].inner { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + width, + }) => { + let low = self.next()?; + match width { + 4 => crate::Literal::U32(low), + _ => return Err(Error::InvalidTypeWidth(width as u32)), + } + } + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint, + width, + }) => { + let low = self.next()?; + match width { + 4 => crate::Literal::I32(low as i32), + 8 => { + inst.expect(5)?; + let high = self.next()?; + crate::Literal::I64((u64::from(high) << 32 | u64::from(low)) as i64) + } + _ => return Err(Error::InvalidTypeWidth(width as u32)), + } + } + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Float, + width, + }) => { + let low = self.next()?; + match width { + 4 => crate::Literal::F32(f32::from_bits(low)), + 8 => { + inst.expect(5)?; + let high = self.next()?; + crate::Literal::F64(f64::from_bits( + (u64::from(high) << 32) | u64::from(low), + )) + } + _ => return Err(Error::InvalidTypeWidth(width as u32)), + } + } + _ => return Err(Error::UnsupportedType(type_lookup.handle)), + }; + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let span = self.span_from_with_op(start); + + let init = module + .const_expressions + .append(crate::Expression::Literal(literal), span); + self.lookup_constant.insert( + id, + LookupConstant { + handle: module.constants.append( + crate::Constant { + r#override: decor.specialization(), + name: decor.name, + ty, + init, + }, + span, + ), + type_id, + }, + ); + Ok(()) + } + + fn parse_composite_constant( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(3)?; + let type_id = self.next()?; + let id = self.next()?; + + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + + let mut components = Vec::with_capacity(inst.wc as usize - 3); + for _ in 0..components.capacity() { + let start = self.data_offset; + let component_id = self.next()?; + let span = self.span_from_with_op(start); + let constant = self.lookup_constant.lookup(component_id)?; + let expr = module + .const_expressions + .append(crate::Expression::Constant(constant.handle), span); + components.push(expr); + } + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let span = self.span_from_with_op(start); + + let init = module + .const_expressions + .append(crate::Expression::Compose { ty, components }, span); + self.lookup_constant.insert( + id, + LookupConstant { + handle: module.constants.append( + crate::Constant { + r#override: decor.specialization(), + name: decor.name, + ty, + init, + }, + span, + ), + type_id, + }, + ); + Ok(()) + } + + fn parse_null_constant( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let type_id = self.next()?; + let id = self.next()?; + let span = self.span_from_with_op(start); + + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let init = module + .const_expressions + .append(crate::Expression::ZeroValue(ty), span); + let handle = module.constants.append( + crate::Constant { + r#override: decor.specialization(), + name: decor.name, + ty, + init, + }, + span, + ); + self.lookup_constant + .insert(id, LookupConstant { handle, type_id }); + Ok(()) + } + + fn parse_bool_constant( + &mut self, + inst: Instruction, + value: bool, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let type_id = self.next()?; + let id = self.next()?; + let span = self.span_from_with_op(start); + + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let init = module.const_expressions.append( + crate::Expression::Literal(crate::Literal::Bool(value)), + span, + ); + self.lookup_constant.insert( + id, + LookupConstant { + handle: module.constants.append( + crate::Constant { + r#override: decor.specialization(), + name: decor.name, + ty, + init, + }, + span, + ), + type_id, + }, + ); + Ok(()) + } + + fn parse_global_variable( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + let start = self.data_offset; + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(4)?; + let type_id = self.next()?; + let id = self.next()?; + let storage_class = self.next()?; + let init = if inst.wc > 4 { + inst.expect(5)?; + let start = self.data_offset; + let init_id = self.next()?; + let span = self.span_from_with_op(start); + let lconst = self.lookup_constant.lookup(init_id)?; + let expr = module + .const_expressions + .append(crate::Expression::Constant(lconst.handle), span); + Some(expr) + } else { + None + }; + let span = self.span_from_with_op(start); + let mut dec = self.future_decor.remove(&id).unwrap_or_default(); + + let original_ty = self.lookup_type.lookup(type_id)?.handle; + let mut ty = original_ty; + + if let crate::TypeInner::Pointer { base, space: _ } = module.types[original_ty].inner { + ty = base; + } + + if let crate::TypeInner::BindingArray { .. } = module.types[original_ty].inner { + // Inside `parse_type_array()` we guess that an array of images or + // samplers must be a binding array, and here we validate that guess + if dec.desc_set.is_none() || dec.desc_index.is_none() { + return Err(Error::NonBindingArrayOfImageOrSamplers); + } + } + + if let crate::TypeInner::Image { + dim, + arrayed, + class: crate::ImageClass::Storage { format, access: _ }, + } = module.types[ty].inner + { + // Storage image types in IR have to contain the access, but not in the SPIR-V. + // The same image type in SPIR-V can be used (and has to be used) for multiple images. + // So we copy the type out and apply the variable access decorations. + let access = dec.flags.to_storage_access(); + + ty = module.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Image { + dim, + arrayed, + class: crate::ImageClass::Storage { format, access }, + }, + }, + Default::default(), + ); + } + + let ext_class = match self.lookup_storage_buffer_types.get(&ty) { + Some(&access) => ExtendedClass::Global(crate::AddressSpace::Storage { access }), + None => map_storage_class(storage_class)?, + }; + + // Fix empty name for gl_PerVertex struct generated by glslang + if let crate::TypeInner::Pointer { .. } = module.types[original_ty].inner { + if ext_class == ExtendedClass::Input || ext_class == ExtendedClass::Output { + if let Some(ref dec_name) = dec.name { + if dec_name.is_empty() { + dec.name = Some("perVertexStruct".to_string()) + } + } + } + } + + let (inner, var) = match ext_class { + ExtendedClass::Global(mut space) => { + if let crate::AddressSpace::Storage { ref mut access } = space { + *access &= dec.flags.to_storage_access(); + } + let var = crate::GlobalVariable { + binding: dec.resource_binding(), + name: dec.name, + space, + ty, + init, + }; + (Variable::Global, var) + } + ExtendedClass::Input => { + let binding = dec.io_binding()?; + let mut unsigned_ty = ty; + if let crate::Binding::BuiltIn(built_in) = binding { + let needs_inner_uint = match built_in { + crate::BuiltIn::BaseInstance + | crate::BuiltIn::BaseVertex + | crate::BuiltIn::InstanceIndex + | crate::BuiltIn::SampleIndex + | crate::BuiltIn::VertexIndex + | crate::BuiltIn::PrimitiveIndex + | crate::BuiltIn::LocalInvocationIndex => { + Some(crate::TypeInner::Scalar(crate::Scalar::U32)) + } + crate::BuiltIn::GlobalInvocationId + | crate::BuiltIn::LocalInvocationId + | crate::BuiltIn::WorkGroupId + | crate::BuiltIn::WorkGroupSize => Some(crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + scalar: crate::Scalar::U32, + }), + _ => None, + }; + if let (Some(inner), Some(crate::ScalarKind::Sint)) = + (needs_inner_uint, module.types[ty].inner.scalar_kind()) + { + unsigned_ty = module + .types + .insert(crate::Type { name: None, inner }, Default::default()); + } + } + + let var = crate::GlobalVariable { + name: dec.name.clone(), + space: crate::AddressSpace::Private, + binding: None, + ty, + init: None, + }; + + let inner = Variable::Input(crate::FunctionArgument { + name: dec.name, + ty: unsigned_ty, + binding: Some(binding), + }); + (inner, var) + } + ExtendedClass::Output => { + // For output interface blocks, this would be a structure. + let binding = dec.io_binding().ok(); + let init = match binding { + Some(crate::Binding::BuiltIn(built_in)) => { + match null::generate_default_built_in( + Some(built_in), + ty, + &mut module.const_expressions, + span, + ) { + Ok(handle) => Some(handle), + Err(e) => { + log::warn!("Failed to initialize output built-in: {}", e); + None + } + } + } + Some(crate::Binding::Location { .. }) => None, + None => match module.types[ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + let mut components = Vec::with_capacity(members.len()); + for member in members.iter() { + let built_in = match member.binding { + Some(crate::Binding::BuiltIn(built_in)) => Some(built_in), + _ => None, + }; + let handle = null::generate_default_built_in( + built_in, + member.ty, + &mut module.const_expressions, + span, + )?; + components.push(handle); + } + Some( + module + .const_expressions + .append(crate::Expression::Compose { ty, components }, span), + ) + } + _ => None, + }, + }; + + let var = crate::GlobalVariable { + name: dec.name, + space: crate::AddressSpace::Private, + binding: None, + ty, + init, + }; + let inner = Variable::Output(crate::FunctionResult { ty, binding }); + (inner, var) + } + }; + + let handle = module.global_variables.append(var, span); + + if module.types[ty].inner.can_comparison_sample(module) { + log::debug!("\t\ttracking {:?} for sampling properties", handle); + + self.handle_sampling + .insert(handle, image::SamplingFlags::empty()); + } + + self.lookup_variable.insert( + id, + LookupVariable { + inner, + handle, + type_id, + }, + ); + Ok(()) + } +} + +fn make_index_literal( + ctx: &mut BlockContext, + index: u32, + block: &mut crate::Block, + emitter: &mut crate::proc::Emitter, + index_type: Handle<crate::Type>, + index_type_id: spirv::Word, + span: crate::Span, +) -> Result<Handle<crate::Expression>, Error> { + block.extend(emitter.finish(ctx.expressions)); + + let literal = match ctx.type_arena[index_type].inner.scalar_kind() { + Some(crate::ScalarKind::Uint) => crate::Literal::U32(index), + Some(crate::ScalarKind::Sint) => crate::Literal::I32(index as i32), + _ => return Err(Error::InvalidIndexType(index_type_id)), + }; + let expr = ctx + .expressions + .append(crate::Expression::Literal(literal), span); + + emitter.start(ctx.expressions); + Ok(expr) +} + +fn resolve_constant( + gctx: crate::proc::GlobalCtx, + constant: Handle<crate::Constant>, +) -> Option<u32> { + match gctx.const_expressions[gctx.constants[constant].init] { + crate::Expression::Literal(crate::Literal::U32(id)) => Some(id), + crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32), + _ => None, + } +} + +pub fn parse_u8_slice(data: &[u8], options: &Options) -> Result<crate::Module, Error> { + if data.len() % 4 != 0 { + return Err(Error::IncompleteData); + } + + let words = data + .chunks(4) + .map(|c| u32::from_le_bytes(c.try_into().unwrap())); + Frontend::new(words, options).parse() +} + +#[cfg(test)] +mod test { + #[test] + fn parse() { + let bin = vec![ + // Magic number. Version number: 1.0. + 0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00, + // Generator number: 0. Bound: 0. + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Reserved word: 0. + 0x00, 0x00, 0x00, 0x00, // OpMemoryModel. Logical. + 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // GLSL450. + 0x01, 0x00, 0x00, 0x00, + ]; + let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap(); + } +} + +/// Helper function to check if `child` is in the scope of `parent` +fn is_parent(mut child: usize, parent: usize, block_ctx: &BlockContext) -> bool { + loop { + if child == parent { + // The child is in the scope parent + break true; + } else if child == 0 { + // Searched finished at the root the child isn't in the parent's body + break false; + } + + child = block_ctx.bodies[child].parent; + } +} diff --git a/third_party/rust/naga/src/front/spv/null.rs b/third_party/rust/naga/src/front/spv/null.rs new file mode 100644 index 0000000000..42cccca80a --- /dev/null +++ b/third_party/rust/naga/src/front/spv/null.rs @@ -0,0 +1,31 @@ +use super::Error; +use crate::arena::{Arena, Handle}; + +/// Create a default value for an output built-in. +pub fn generate_default_built_in( + built_in: Option<crate::BuiltIn>, + ty: Handle<crate::Type>, + const_expressions: &mut Arena<crate::Expression>, + span: crate::Span, +) -> Result<Handle<crate::Expression>, Error> { + let expr = match built_in { + Some(crate::BuiltIn::Position { .. }) => { + let zero = const_expressions + .append(crate::Expression::Literal(crate::Literal::F32(0.0)), span); + let one = const_expressions + .append(crate::Expression::Literal(crate::Literal::F32(1.0)), span); + crate::Expression::Compose { + ty, + components: vec![zero, zero, zero, one], + } + } + Some(crate::BuiltIn::PointSize) => crate::Expression::Literal(crate::Literal::F32(1.0)), + Some(crate::BuiltIn::FragDepth) => crate::Expression::Literal(crate::Literal::F32(0.0)), + Some(crate::BuiltIn::SampleMask) => { + crate::Expression::Literal(crate::Literal::U32(u32::MAX)) + } + // Note: `crate::BuiltIn::ClipDistance` is intentionally left for the default path + _ => crate::Expression::ZeroValue(ty), + }; + Ok(const_expressions.append(expr, span)) +} diff --git a/third_party/rust/naga/src/front/type_gen.rs b/third_party/rust/naga/src/front/type_gen.rs new file mode 100644 index 0000000000..34730c1db5 --- /dev/null +++ b/third_party/rust/naga/src/front/type_gen.rs @@ -0,0 +1,437 @@ +/*! +Type generators. +*/ + +use crate::{arena::Handle, span::Span}; + +impl crate::Module { + /// Populate this module's [`SpecialTypes::ray_desc`] type. + /// + /// [`SpecialTypes::ray_desc`] is the type of the [`descriptor`] operand of + /// an [`Initialize`] [`RayQuery`] statement. In WGSL, it is a struct type + /// referred to as `RayDesc`. + /// + /// Backends consume values of this type to drive platform APIs, so if you + /// change any its fields, you must update the backends to match. Look for + /// backend code dealing with [`RayQueryFunction::Initialize`]. + /// + /// [`SpecialTypes::ray_desc`]: crate::SpecialTypes::ray_desc + /// [`descriptor`]: crate::RayQueryFunction::Initialize::descriptor + /// [`Initialize`]: crate::RayQueryFunction::Initialize + /// [`RayQuery`]: crate::Statement::RayQuery + /// [`RayQueryFunction::Initialize`]: crate::RayQueryFunction::Initialize + pub fn generate_ray_desc_type(&mut self) -> Handle<crate::Type> { + if let Some(handle) = self.special_types.ray_desc { + return handle; + } + + let ty_flag = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::U32), + }, + Span::UNDEFINED, + ); + let ty_scalar = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::F32), + }, + Span::UNDEFINED, + ); + let ty_vector = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + scalar: crate::Scalar::F32, + }, + }, + Span::UNDEFINED, + ); + + let handle = self.types.insert( + crate::Type { + name: Some("RayDesc".to_string()), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("flags".to_string()), + ty: ty_flag, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("cull_mask".to_string()), + ty: ty_flag, + binding: None, + offset: 4, + }, + crate::StructMember { + name: Some("tmin".to_string()), + ty: ty_scalar, + binding: None, + offset: 8, + }, + crate::StructMember { + name: Some("tmax".to_string()), + ty: ty_scalar, + binding: None, + offset: 12, + }, + crate::StructMember { + name: Some("origin".to_string()), + ty: ty_vector, + binding: None, + offset: 16, + }, + crate::StructMember { + name: Some("dir".to_string()), + ty: ty_vector, + binding: None, + offset: 32, + }, + ], + span: 48, + }, + }, + Span::UNDEFINED, + ); + + self.special_types.ray_desc = Some(handle); + handle + } + + /// Populate this module's [`SpecialTypes::ray_intersection`] type. + /// + /// [`SpecialTypes::ray_intersection`] is the type of a + /// `RayQueryGetIntersection` expression. In WGSL, it is a struct type + /// referred to as `RayIntersection`. + /// + /// Backends construct values of this type based on platform APIs, so if you + /// change any its fields, you must update the backends to match. Look for + /// the backend's handling for [`Expression::RayQueryGetIntersection`]. + /// + /// [`SpecialTypes::ray_intersection`]: crate::SpecialTypes::ray_intersection + /// [`Expression::RayQueryGetIntersection`]: crate::Expression::RayQueryGetIntersection + pub fn generate_ray_intersection_type(&mut self) -> Handle<crate::Type> { + if let Some(handle) = self.special_types.ray_intersection { + return handle; + } + + let ty_flag = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::U32), + }, + Span::UNDEFINED, + ); + let ty_scalar = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::F32), + }, + Span::UNDEFINED, + ); + let ty_barycentrics = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size: crate::VectorSize::Bi, + scalar: crate::Scalar::F32, + }, + }, + Span::UNDEFINED, + ); + let ty_bool = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::BOOL), + }, + Span::UNDEFINED, + ); + let ty_transform = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + scalar: crate::Scalar::F32, + }, + }, + Span::UNDEFINED, + ); + + let handle = self.types.insert( + crate::Type { + name: Some("RayIntersection".to_string()), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("kind".to_string()), + ty: ty_flag, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("t".to_string()), + ty: ty_scalar, + binding: None, + offset: 4, + }, + crate::StructMember { + name: Some("instance_custom_index".to_string()), + ty: ty_flag, + binding: None, + offset: 8, + }, + crate::StructMember { + name: Some("instance_id".to_string()), + ty: ty_flag, + binding: None, + offset: 12, + }, + crate::StructMember { + name: Some("sbt_record_offset".to_string()), + ty: ty_flag, + binding: None, + offset: 16, + }, + crate::StructMember { + name: Some("geometry_index".to_string()), + ty: ty_flag, + binding: None, + offset: 20, + }, + crate::StructMember { + name: Some("primitive_index".to_string()), + ty: ty_flag, + binding: None, + offset: 24, + }, + crate::StructMember { + name: Some("barycentrics".to_string()), + ty: ty_barycentrics, + binding: None, + offset: 28, + }, + crate::StructMember { + name: Some("front_face".to_string()), + ty: ty_bool, + binding: None, + offset: 36, + }, + crate::StructMember { + name: Some("object_to_world".to_string()), + ty: ty_transform, + binding: None, + offset: 48, + }, + crate::StructMember { + name: Some("world_to_object".to_string()), + ty: ty_transform, + binding: None, + offset: 112, + }, + ], + span: 176, + }, + }, + Span::UNDEFINED, + ); + + self.special_types.ray_intersection = Some(handle); + handle + } + + /// Populate this module's [`SpecialTypes::predeclared_types`] type and return the handle. + /// + /// [`SpecialTypes::predeclared_types`]: crate::SpecialTypes::predeclared_types + pub fn generate_predeclared_type( + &mut self, + special_type: crate::PredeclaredType, + ) -> Handle<crate::Type> { + use std::fmt::Write; + + if let Some(value) = self.special_types.predeclared_types.get(&special_type) { + return *value; + } + + let ty = match special_type { + crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => { + let bool_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::BOOL), + }, + Span::UNDEFINED, + ); + let scalar_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(scalar), + }, + Span::UNDEFINED, + ); + + crate::Type { + name: Some(format!( + "__atomic_compare_exchange_result<{:?},{}>", + scalar.kind, scalar.width, + )), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("old_value".to_string()), + ty: scalar_ty, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("exchanged".to_string()), + ty: bool_ty, + binding: None, + offset: 4, + }, + ], + span: 8, + }, + } + } + crate::PredeclaredType::ModfResult { size, width } => { + let float_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::float(width)), + }, + Span::UNDEFINED, + ); + + let (member_ty, second_offset) = if let Some(size) = size { + let vec_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size, + scalar: crate::Scalar::float(width), + }, + }, + Span::UNDEFINED, + ); + (vec_ty, size as u32 * width as u32) + } else { + (float_ty, width as u32) + }; + + let mut type_name = "__modf_result_".to_string(); + if let Some(size) = size { + let _ = write!(type_name, "vec{}_", size as u8); + } + let _ = write!(type_name, "f{}", width * 8); + + crate::Type { + name: Some(type_name), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("fract".to_string()), + ty: member_ty, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("whole".to_string()), + ty: member_ty, + binding: None, + offset: second_offset, + }, + ], + span: second_offset * 2, + }, + } + } + crate::PredeclaredType::FrexpResult { size, width } => { + let float_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar::float(width)), + }, + Span::UNDEFINED, + ); + + let int_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint, + width, + }), + }, + Span::UNDEFINED, + ); + + let (fract_member_ty, exp_member_ty, second_offset) = if let Some(size) = size { + let vec_float_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size, + scalar: crate::Scalar::float(width), + }, + }, + Span::UNDEFINED, + ); + let vec_int_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size, + scalar: crate::Scalar { + kind: crate::ScalarKind::Sint, + width, + }, + }, + }, + Span::UNDEFINED, + ); + (vec_float_ty, vec_int_ty, size as u32 * width as u32) + } else { + (float_ty, int_ty, width as u32) + }; + + let mut type_name = "__frexp_result_".to_string(); + if let Some(size) = size { + let _ = write!(type_name, "vec{}_", size as u8); + } + let _ = write!(type_name, "f{}", width * 8); + + crate::Type { + name: Some(type_name), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("fract".to_string()), + ty: fract_member_ty, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("exp".to_string()), + ty: exp_member_ty, + binding: None, + offset: second_offset, + }, + ], + span: second_offset * 2, + }, + } + } + }; + + let handle = self.types.insert(ty, Span::UNDEFINED); + self.special_types + .predeclared_types + .insert(special_type, handle); + handle + } +} diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs new file mode 100644 index 0000000000..07e68f8dd9 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/error.rs @@ -0,0 +1,775 @@ +use crate::front::wgsl::parse::lexer::Token; +use crate::front::wgsl::Scalar; +use crate::proc::{Alignment, ConstantEvaluatorError, ResolveError}; +use crate::{SourceLocation, Span}; +use codespan_reporting::diagnostic::{Diagnostic, Label}; +use codespan_reporting::files::SimpleFile; +use codespan_reporting::term; +use std::borrow::Cow; +use std::ops::Range; +use termcolor::{ColorChoice, NoColor, StandardStream}; +use thiserror::Error; + +#[derive(Clone, Debug)] +pub struct ParseError { + message: String, + labels: Vec<(Span, Cow<'static, str>)>, + notes: Vec<String>, +} + +impl ParseError { + pub fn labels(&self) -> impl ExactSizeIterator<Item = (Span, &str)> + '_ { + self.labels + .iter() + .map(|&(span, ref msg)| (span, msg.as_ref())) + } + + pub fn message(&self) -> &str { + &self.message + } + + fn diagnostic(&self) -> Diagnostic<()> { + let diagnostic = Diagnostic::error() + .with_message(self.message.to_string()) + .with_labels( + self.labels + .iter() + .filter_map(|label| label.0.to_range().map(|range| (label, range))) + .map(|(label, range)| { + Label::primary((), range).with_message(label.1.to_string()) + }) + .collect(), + ) + .with_notes( + self.notes + .iter() + .map(|note| format!("note: {note}")) + .collect(), + ); + diagnostic + } + + /// Emits a summary of the error to standard error stream. + pub fn emit_to_stderr(&self, source: &str) { + self.emit_to_stderr_with_path(source, "wgsl") + } + + /// Emits a summary of the error to standard error stream. + pub fn emit_to_stderr_with_path<P>(&self, source: &str, path: P) + where + P: AsRef<std::path::Path>, + { + let path = path.as_ref().display().to_string(); + let files = SimpleFile::new(path, source); + let config = codespan_reporting::term::Config::default(); + let writer = StandardStream::stderr(ColorChoice::Auto); + term::emit(&mut writer.lock(), &config, &files, &self.diagnostic()) + .expect("cannot write error"); + } + + /// Emits a summary of the error to a string. + pub fn emit_to_string(&self, source: &str) -> String { + self.emit_to_string_with_path(source, "wgsl") + } + + /// Emits a summary of the error to a string. + pub fn emit_to_string_with_path<P>(&self, source: &str, path: P) -> String + where + P: AsRef<std::path::Path>, + { + let path = path.as_ref().display().to_string(); + let files = SimpleFile::new(path, source); + let config = codespan_reporting::term::Config::default(); + let mut writer = NoColor::new(Vec::new()); + term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error"); + String::from_utf8(writer.into_inner()).unwrap() + } + + /// Returns a [`SourceLocation`] for the first label in the error message. + pub fn location(&self, source: &str) -> Option<SourceLocation> { + self.labels.get(0).map(|label| label.0.location(source)) + } +} + +impl std::fmt::Display for ParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for ParseError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum ExpectedToken<'a> { + Token(Token<'a>), + Identifier, + /// Expected: constant, parenthesized expression, identifier + PrimaryExpression, + /// Expected: assignment, increment/decrement expression + Assignment, + /// Expected: 'case', 'default', '}' + SwitchItem, + /// Expected: ',', ')' + WorkgroupSizeSeparator, + /// Expected: 'struct', 'let', 'var', 'type', ';', 'fn', eof + GlobalItem, + /// Expected a type. + Type, + /// Access of `var`, `let`, `const`. + Variable, + /// Access of a function + Function, +} + +#[derive(Clone, Copy, Debug, Error, PartialEq)] +pub enum NumberError { + #[error("invalid numeric literal format")] + Invalid, + #[error("numeric literal not representable by target type")] + NotRepresentable, + #[error("unimplemented f16 type")] + UnimplementedF16, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum InvalidAssignmentType { + Other, + Swizzle, + ImmutableBinding(Span), +} + +#[derive(Clone, Debug)] +pub enum Error<'a> { + Unexpected(Span, ExpectedToken<'a>), + UnexpectedComponents(Span), + UnexpectedOperationInConstContext(Span), + BadNumber(Span, NumberError), + BadMatrixScalarKind(Span, Scalar), + BadAccessor(Span), + BadTexture(Span), + BadTypeCast { + span: Span, + from_type: String, + to_type: String, + }, + BadTextureSampleType { + span: Span, + scalar: Scalar, + }, + BadIncrDecrReferenceType(Span), + InvalidResolve(ResolveError), + InvalidForInitializer(Span), + /// A break if appeared outside of a continuing block + InvalidBreakIf(Span), + InvalidGatherComponent(Span), + InvalidConstructorComponentType(Span, i32), + InvalidIdentifierUnderscore(Span), + ReservedIdentifierPrefix(Span), + UnknownAddressSpace(Span), + RepeatedAttribute(Span), + UnknownAttribute(Span), + UnknownBuiltin(Span), + UnknownAccess(Span), + UnknownIdent(Span, &'a str), + UnknownScalarType(Span), + UnknownType(Span), + UnknownStorageFormat(Span), + UnknownConservativeDepth(Span), + SizeAttributeTooLow(Span, u32), + AlignAttributeTooLow(Span, Alignment), + NonPowerOfTwoAlignAttribute(Span), + InconsistentBinding(Span), + TypeNotConstructible(Span), + TypeNotInferable(Span), + InitializationTypeMismatch { + name: Span, + expected: String, + got: String, + }, + MissingType(Span), + MissingAttribute(&'static str, Span), + InvalidAtomicPointer(Span), + InvalidAtomicOperandType(Span), + InvalidRayQueryPointer(Span), + Pointer(&'static str, Span), + NotPointer(Span), + NotReference(&'static str, Span), + InvalidAssignment { + span: Span, + ty: InvalidAssignmentType, + }, + ReservedKeyword(Span), + /// Redefinition of an identifier (used for both module-scope and local redefinitions). + Redefinition { + /// Span of the identifier in the previous definition. + previous: Span, + + /// Span of the identifier in the new definition. + current: Span, + }, + /// A declaration refers to itself directly. + RecursiveDeclaration { + /// The location of the name of the declaration. + ident: Span, + + /// The point at which it is used. + usage: Span, + }, + /// A declaration refers to itself indirectly, through one or more other + /// definitions. + CyclicDeclaration { + /// The location of the name of some declaration in the cycle. + ident: Span, + + /// The edges of the cycle of references. + /// + /// Each `(decl, reference)` pair indicates that the declaration whose + /// name is `decl` has an identifier at `reference` whose definition is + /// the next declaration in the cycle. The last pair's `reference` is + /// the same identifier as `ident`, above. + path: Vec<(Span, Span)>, + }, + InvalidSwitchValue { + uint: bool, + span: Span, + }, + CalledEntryPoint(Span), + WrongArgumentCount { + span: Span, + expected: Range<u32>, + found: u32, + }, + FunctionReturnsVoid(Span), + InvalidWorkGroupUniformLoad(Span), + Internal(&'static str), + ExpectedConstExprConcreteIntegerScalar(Span), + ExpectedNonNegative(Span), + ExpectedPositiveArrayLength(Span), + MissingWorkgroupSize(Span), + ConstantEvaluatorError(ConstantEvaluatorError, Span), + AutoConversion { + dest_span: Span, + dest_type: String, + source_span: Span, + source_type: String, + }, + AutoConversionLeafScalar { + dest_span: Span, + dest_scalar: String, + source_span: Span, + source_type: String, + }, + ConcretizationFailed { + expr_span: Span, + expr_type: String, + scalar: String, + inner: ConstantEvaluatorError, + }, +} + +impl<'a> Error<'a> { + pub(crate) fn as_parse_error(&self, source: &'a str) -> ParseError { + match *self { + Error::Unexpected(unexpected_span, expected) => { + let expected_str = match expected { + ExpectedToken::Token(token) => { + match token { + Token::Separator(c) => format!("'{c}'"), + Token::Paren(c) => format!("'{c}'"), + Token::Attribute => "@".to_string(), + Token::Number(_) => "number".to_string(), + Token::Word(s) => s.to_string(), + Token::Operation(c) => format!("operation ('{c}')"), + Token::LogicalOperation(c) => format!("logical operation ('{c}')"), + Token::ShiftOperation(c) => format!("bitshift ('{c}{c}')"), + Token::AssignmentOperation(c) if c=='<' || c=='>' => format!("bitshift ('{c}{c}=')"), + Token::AssignmentOperation(c) => format!("operation ('{c}=')"), + Token::IncrementOperation => "increment operation".to_string(), + Token::DecrementOperation => "decrement operation".to_string(), + Token::Arrow => "->".to_string(), + Token::Unknown(c) => format!("unknown ('{c}')"), + Token::Trivia => "trivia".to_string(), + Token::End => "end".to_string(), + } + } + ExpectedToken::Identifier => "identifier".to_string(), + ExpectedToken::PrimaryExpression => "expression".to_string(), + ExpectedToken::Assignment => "assignment or increment/decrement".to_string(), + ExpectedToken::SwitchItem => "switch item ('case' or 'default') or a closing curly bracket to signify the end of the switch statement ('}')".to_string(), + ExpectedToken::WorkgroupSizeSeparator => "workgroup size separator (',') or a closing parenthesis".to_string(), + ExpectedToken::GlobalItem => "global item ('struct', 'const', 'var', 'alias', ';', 'fn') or the end of the file".to_string(), + ExpectedToken::Type => "type".to_string(), + ExpectedToken::Variable => "variable access".to_string(), + ExpectedToken::Function => "function name".to_string(), + }; + ParseError { + message: format!( + "expected {}, found '{}'", + expected_str, &source[unexpected_span], + ), + labels: vec![(unexpected_span, format!("expected {expected_str}").into())], + notes: vec![], + } + } + Error::UnexpectedComponents(bad_span) => ParseError { + message: "unexpected components".to_string(), + labels: vec![(bad_span, "unexpected components".into())], + notes: vec![], + }, + Error::UnexpectedOperationInConstContext(span) => ParseError { + message: "this operation is not supported in a const context".to_string(), + labels: vec![(span, "operation not supported here".into())], + notes: vec![], + }, + Error::BadNumber(bad_span, ref err) => ParseError { + message: format!("{}: `{}`", err, &source[bad_span],), + labels: vec![(bad_span, err.to_string().into())], + notes: vec![], + }, + Error::BadMatrixScalarKind(span, scalar) => ParseError { + message: format!( + "matrix scalar type must be floating-point, but found `{}`", + scalar.to_wgsl() + ), + labels: vec![(span, "must be floating-point (e.g. `f32`)".into())], + notes: vec![], + }, + Error::BadAccessor(accessor_span) => ParseError { + message: format!("invalid field accessor `{}`", &source[accessor_span],), + labels: vec![(accessor_span, "invalid accessor".into())], + notes: vec![], + }, + Error::UnknownIdent(ident_span, ident) => ParseError { + message: format!("no definition in scope for identifier: '{ident}'"), + labels: vec![(ident_span, "unknown identifier".into())], + notes: vec![], + }, + Error::UnknownScalarType(bad_span) => ParseError { + message: format!("unknown scalar type: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown scalar type".into())], + notes: vec!["Valid scalar types are f32, f64, i32, u32, bool".into()], + }, + Error::BadTextureSampleType { span, scalar } => ParseError { + message: format!( + "texture sample type must be one of f32, i32 or u32, but found {}", + scalar.to_wgsl() + ), + labels: vec![(span, "must be one of f32, i32 or u32".into())], + notes: vec![], + }, + Error::BadIncrDecrReferenceType(span) => ParseError { + message: + "increment/decrement operation requires reference type to be one of i32 or u32" + .to_string(), + labels: vec![(span, "must be a reference type of i32 or u32".into())], + notes: vec![], + }, + Error::BadTexture(bad_span) => ParseError { + message: format!( + "expected an image, but found '{}' which is not an image", + &source[bad_span] + ), + labels: vec![(bad_span, "not an image".into())], + notes: vec![], + }, + Error::BadTypeCast { + span, + ref from_type, + ref to_type, + } => { + let msg = format!("cannot cast a {from_type} to a {to_type}"); + ParseError { + message: msg.clone(), + labels: vec![(span, msg.into())], + notes: vec![], + } + } + Error::InvalidResolve(ref resolve_error) => ParseError { + message: resolve_error.to_string(), + labels: vec![], + notes: vec![], + }, + Error::InvalidForInitializer(bad_span) => ParseError { + message: format!( + "for(;;) initializer is not an assignment or a function call: '{}'", + &source[bad_span] + ), + labels: vec![(bad_span, "not an assignment or function call".into())], + notes: vec![], + }, + Error::InvalidBreakIf(bad_span) => ParseError { + message: "A break if is only allowed in a continuing block".to_string(), + labels: vec![(bad_span, "not in a continuing block".into())], + notes: vec![], + }, + Error::InvalidGatherComponent(bad_span) => ParseError { + message: format!( + "textureGather component '{}' doesn't exist, must be 0, 1, 2, or 3", + &source[bad_span] + ), + labels: vec![(bad_span, "invalid component".into())], + notes: vec![], + }, + Error::InvalidConstructorComponentType(bad_span, component) => ParseError { + message: format!("invalid type for constructor component at index [{component}]"), + labels: vec![(bad_span, "invalid component type".into())], + notes: vec![], + }, + Error::InvalidIdentifierUnderscore(bad_span) => ParseError { + message: "Identifier can't be '_'".to_string(), + labels: vec![(bad_span, "invalid identifier".into())], + notes: vec![ + "Use phony assignment instead ('_ =' notice the absence of 'let' or 'var')" + .to_string(), + ], + }, + Error::ReservedIdentifierPrefix(bad_span) => ParseError { + message: format!( + "Identifier starts with a reserved prefix: '{}'", + &source[bad_span] + ), + labels: vec![(bad_span, "invalid identifier".into())], + notes: vec![], + }, + Error::UnknownAddressSpace(bad_span) => ParseError { + message: format!("unknown address space: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown address space".into())], + notes: vec![], + }, + Error::RepeatedAttribute(bad_span) => ParseError { + message: format!("repeated attribute: '{}'", &source[bad_span]), + labels: vec![(bad_span, "repeated attribute".into())], + notes: vec![], + }, + Error::UnknownAttribute(bad_span) => ParseError { + message: format!("unknown attribute: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown attribute".into())], + notes: vec![], + }, + Error::UnknownBuiltin(bad_span) => ParseError { + message: format!("unknown builtin: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown builtin".into())], + notes: vec![], + }, + Error::UnknownAccess(bad_span) => ParseError { + message: format!("unknown access: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown access".into())], + notes: vec![], + }, + Error::UnknownStorageFormat(bad_span) => ParseError { + message: format!("unknown storage format: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown storage format".into())], + notes: vec![], + }, + Error::UnknownConservativeDepth(bad_span) => ParseError { + message: format!("unknown conservative depth: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown conservative depth".into())], + notes: vec![], + }, + Error::UnknownType(bad_span) => ParseError { + message: format!("unknown type: '{}'", &source[bad_span]), + labels: vec![(bad_span, "unknown type".into())], + notes: vec![], + }, + Error::SizeAttributeTooLow(bad_span, min_size) => ParseError { + message: format!("struct member size must be at least {min_size}"), + labels: vec![(bad_span, format!("must be at least {min_size}").into())], + notes: vec![], + }, + Error::AlignAttributeTooLow(bad_span, min_align) => ParseError { + message: format!("struct member alignment must be at least {min_align}"), + labels: vec![(bad_span, format!("must be at least {min_align}").into())], + notes: vec![], + }, + Error::NonPowerOfTwoAlignAttribute(bad_span) => ParseError { + message: "struct member alignment must be a power of 2".to_string(), + labels: vec![(bad_span, "must be a power of 2".into())], + notes: vec![], + }, + Error::InconsistentBinding(span) => ParseError { + message: "input/output binding is not consistent".to_string(), + labels: vec![(span, "input/output binding is not consistent".into())], + notes: vec![], + }, + Error::TypeNotConstructible(span) => ParseError { + message: format!("type `{}` is not constructible", &source[span]), + labels: vec![(span, "type is not constructible".into())], + notes: vec![], + }, + Error::TypeNotInferable(span) => ParseError { + message: "type can't be inferred".to_string(), + labels: vec![(span, "type can't be inferred".into())], + notes: vec![], + }, + Error::InitializationTypeMismatch { name, ref expected, ref got } => { + ParseError { + message: format!( + "the type of `{}` is expected to be `{}`, but got `{}`", + &source[name], expected, got, + ), + labels: vec![( + name, + format!("definition of `{}`", &source[name]).into(), + )], + notes: vec![], + } + } + Error::MissingType(name_span) => ParseError { + message: format!("variable `{}` needs a type", &source[name_span]), + labels: vec![( + name_span, + format!("definition of `{}`", &source[name_span]).into(), + )], + notes: vec![], + }, + Error::MissingAttribute(name, name_span) => ParseError { + message: format!( + "variable `{}` needs a '{}' attribute", + &source[name_span], name + ), + labels: vec![( + name_span, + format!("definition of `{}`", &source[name_span]).into(), + )], + notes: vec![], + }, + Error::InvalidAtomicPointer(span) => ParseError { + message: "atomic operation is done on a pointer to a non-atomic".to_string(), + labels: vec![(span, "atomic pointer is invalid".into())], + notes: vec![], + }, + Error::InvalidAtomicOperandType(span) => ParseError { + message: "atomic operand type is inconsistent with the operation".to_string(), + labels: vec![(span, "atomic operand type is invalid".into())], + notes: vec![], + }, + Error::InvalidRayQueryPointer(span) => ParseError { + message: "ray query operation is done on a pointer to a non-ray-query".to_string(), + labels: vec![(span, "ray query pointer is invalid".into())], + notes: vec![], + }, + Error::NotPointer(span) => ParseError { + message: "the operand of the `*` operator must be a pointer".to_string(), + labels: vec![(span, "expression is not a pointer".into())], + notes: vec![], + }, + Error::NotReference(what, span) => ParseError { + message: format!("{what} must be a reference"), + labels: vec![(span, "expression is not a reference".into())], + notes: vec![], + }, + Error::InvalidAssignment { span, ty } => { + let (extra_label, notes) = match ty { + InvalidAssignmentType::Swizzle => ( + None, + vec![ + "WGSL does not support assignments to swizzles".into(), + "consider assigning each component individually".into(), + ], + ), + InvalidAssignmentType::ImmutableBinding(binding_span) => ( + Some((binding_span, "this is an immutable binding".into())), + vec![format!( + "consider declaring '{}' with `var` instead of `let`", + &source[binding_span] + )], + ), + InvalidAssignmentType::Other => (None, vec![]), + }; + + ParseError { + message: "invalid left-hand side of assignment".into(), + labels: std::iter::once((span, "cannot assign to this expression".into())) + .chain(extra_label) + .collect(), + notes, + } + } + Error::Pointer(what, span) => ParseError { + message: format!("{what} must not be a pointer"), + labels: vec![(span, "expression is a pointer".into())], + notes: vec![], + }, + Error::ReservedKeyword(name_span) => ParseError { + message: format!("name `{}` is a reserved keyword", &source[name_span]), + labels: vec![( + name_span, + format!("definition of `{}`", &source[name_span]).into(), + )], + notes: vec![], + }, + Error::Redefinition { previous, current } => ParseError { + message: format!("redefinition of `{}`", &source[current]), + labels: vec![ + ( + current, + format!("redefinition of `{}`", &source[current]).into(), + ), + ( + previous, + format!("previous definition of `{}`", &source[previous]).into(), + ), + ], + notes: vec![], + }, + Error::RecursiveDeclaration { ident, usage } => ParseError { + message: format!("declaration of `{}` is recursive", &source[ident]), + labels: vec![(ident, "".into()), (usage, "uses itself here".into())], + notes: vec![], + }, + Error::CyclicDeclaration { ident, ref path } => ParseError { + message: format!("declaration of `{}` is cyclic", &source[ident]), + labels: path + .iter() + .enumerate() + .flat_map(|(i, &(ident, usage))| { + [ + (ident, "".into()), + ( + usage, + if i == path.len() - 1 { + "ending the cycle".into() + } else { + format!("uses `{}`", &source[ident]).into() + }, + ), + ] + }) + .collect(), + notes: vec![], + }, + Error::InvalidSwitchValue { uint, span } => ParseError { + message: "invalid switch value".to_string(), + labels: vec![( + span, + if uint { + "expected unsigned integer" + } else { + "expected signed integer" + } + .into(), + )], + notes: vec![if uint { + format!("suffix the integer with a `u`: '{}u'", &source[span]) + } else { + let span = span.to_range().unwrap(); + format!( + "remove the `u` suffix: '{}'", + &source[span.start..span.end - 1] + ) + }], + }, + Error::CalledEntryPoint(span) => ParseError { + message: "entry point cannot be called".to_string(), + labels: vec![(span, "entry point cannot be called".into())], + notes: vec![], + }, + Error::WrongArgumentCount { + span, + ref expected, + found, + } => ParseError { + message: format!( + "wrong number of arguments: expected {}, found {}", + if expected.len() < 2 { + format!("{}", expected.start) + } else { + format!("{}..{}", expected.start, expected.end) + }, + found + ), + labels: vec![(span, "wrong number of arguments".into())], + notes: vec![], + }, + Error::FunctionReturnsVoid(span) => ParseError { + message: "function does not return any value".to_string(), + labels: vec![(span, "".into())], + notes: vec![ + "perhaps you meant to call the function in a separate statement?".into(), + ], + }, + Error::InvalidWorkGroupUniformLoad(span) => ParseError { + message: "incorrect type passed to workgroupUniformLoad".into(), + labels: vec![(span, "".into())], + notes: vec!["passed type must be a workgroup pointer".into()], + }, + Error::Internal(message) => ParseError { + message: "internal WGSL front end error".to_string(), + labels: vec![], + notes: vec![message.into()], + }, + Error::ExpectedConstExprConcreteIntegerScalar(span) => ParseError { + message: "must be a const-expression that resolves to a concrete integer scalar (u32 or i32)".to_string(), + labels: vec![(span, "must resolve to u32 or i32".into())], + notes: vec![], + }, + Error::ExpectedNonNegative(span) => ParseError { + message: "must be non-negative (>= 0)".to_string(), + labels: vec![(span, "must be non-negative".into())], + notes: vec![], + }, + Error::ExpectedPositiveArrayLength(span) => ParseError { + message: "array element count must be positive (> 0)".to_string(), + labels: vec![(span, "must be positive".into())], + notes: vec![], + }, + Error::ConstantEvaluatorError(ref e, span) => ParseError { + message: e.to_string(), + labels: vec![(span, "see msg".into())], + notes: vec![], + }, + Error::MissingWorkgroupSize(span) => ParseError { + message: "workgroup size is missing on compute shader entry point".to_string(), + labels: vec![( + span, + "must be paired with a @workgroup_size attribute".into(), + )], + notes: vec![], + }, + Error::AutoConversion { dest_span, ref dest_type, source_span, ref source_type } => ParseError { + message: format!("automatic conversions cannot convert `{source_type}` to `{dest_type}`"), + labels: vec![ + ( + dest_span, + format!("a value of type {dest_type} is required here").into(), + ), + ( + source_span, + format!("this expression has type {source_type}").into(), + ) + ], + notes: vec![], + }, + Error::AutoConversionLeafScalar { dest_span, ref dest_scalar, source_span, ref source_type } => ParseError { + message: format!("automatic conversions cannot convert elements of `{source_type}` to `{dest_scalar}`"), + labels: vec![ + ( + dest_span, + format!("a value with elements of type {dest_scalar} is required here").into(), + ), + ( + source_span, + format!("this expression has type {source_type}").into(), + ) + ], + notes: vec![], + }, + Error::ConcretizationFailed { expr_span, ref expr_type, ref scalar, ref inner } => ParseError { + message: format!("failed to convert expression to a concrete type: {}", inner), + labels: vec![ + ( + expr_span, + format!("this expression has type {}", expr_type).into(), + ) + ], + notes: vec![ + format!("the expression should have been converted to have {} scalar type", scalar), + ] + }, + } + } +} diff --git a/third_party/rust/naga/src/front/wgsl/index.rs b/third_party/rust/naga/src/front/wgsl/index.rs new file mode 100644 index 0000000000..a5524fe8f1 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/index.rs @@ -0,0 +1,193 @@ +use super::Error; +use crate::front::wgsl::parse::ast; +use crate::{FastHashMap, Handle, Span}; + +/// A `GlobalDecl` list in which each definition occurs before all its uses. +pub struct Index<'a> { + dependency_order: Vec<Handle<ast::GlobalDecl<'a>>>, +} + +impl<'a> Index<'a> { + /// Generate an `Index` for the given translation unit. + /// + /// Perform a topological sort on `tu`'s global declarations, placing + /// referents before the definitions that refer to them. + /// + /// Return an error if the graph of references between declarations contains + /// any cycles. + pub fn generate(tu: &ast::TranslationUnit<'a>) -> Result<Self, Error<'a>> { + // Produce a map from global definitions' names to their `Handle<GlobalDecl>`s. + // While doing so, reject conflicting definitions. + let mut globals = FastHashMap::with_capacity_and_hasher(tu.decls.len(), Default::default()); + for (handle, decl) in tu.decls.iter() { + let ident = decl_ident(decl); + let name = ident.name; + if let Some(old) = globals.insert(name, handle) { + return Err(Error::Redefinition { + previous: decl_ident(&tu.decls[old]).span, + current: ident.span, + }); + } + } + + let len = tu.decls.len(); + let solver = DependencySolver { + globals: &globals, + module: tu, + visited: vec![false; len], + temp_visited: vec![false; len], + path: Vec::new(), + out: Vec::with_capacity(len), + }; + let dependency_order = solver.solve()?; + + Ok(Self { dependency_order }) + } + + /// Iterate over `GlobalDecl`s, visiting each definition before all its uses. + /// + /// Produce handles for all of the `GlobalDecl`s of the `TranslationUnit` + /// passed to `Index::generate`, ordered so that a given declaration is + /// produced before any other declaration that uses it. + pub fn visit_ordered(&self) -> impl Iterator<Item = Handle<ast::GlobalDecl<'a>>> + '_ { + self.dependency_order.iter().copied() + } +} + +/// An edge from a reference to its referent in the current depth-first +/// traversal. +/// +/// This is like `ast::Dependency`, except that we've determined which +/// `GlobalDecl` it refers to. +struct ResolvedDependency<'a> { + /// The referent of some identifier used in the current declaration. + decl: Handle<ast::GlobalDecl<'a>>, + + /// Where that use occurs within the current declaration. + usage: Span, +} + +/// Local state for ordering a `TranslationUnit`'s module-scope declarations. +/// +/// Values of this type are used temporarily by `Index::generate` +/// to perform a depth-first sort on the declarations. +/// Technically, what we want is a topological sort, but a depth-first sort +/// has one key benefit - it's much more efficient in storing +/// the path of each node for error generation. +struct DependencySolver<'source, 'temp> { + /// A map from module-scope definitions' names to their handles. + globals: &'temp FastHashMap<&'source str, Handle<ast::GlobalDecl<'source>>>, + + /// The translation unit whose declarations we're ordering. + module: &'temp ast::TranslationUnit<'source>, + + /// For each handle, whether we have pushed it onto `out` yet. + visited: Vec<bool>, + + /// For each handle, whether it is an predecessor in the current depth-first + /// traversal. This is used to detect cycles in the reference graph. + temp_visited: Vec<bool>, + + /// The current path in our depth-first traversal. Used for generating + /// error messages for non-trivial reference cycles. + path: Vec<ResolvedDependency<'source>>, + + /// The list of declaration handles, with declarations before uses. + out: Vec<Handle<ast::GlobalDecl<'source>>>, +} + +impl<'a> DependencySolver<'a, '_> { + /// Produce the sorted list of declaration handles, and check for cycles. + fn solve(mut self) -> Result<Vec<Handle<ast::GlobalDecl<'a>>>, Error<'a>> { + for (id, _) in self.module.decls.iter() { + if self.visited[id.index()] { + continue; + } + + self.dfs(id)?; + } + + Ok(self.out) + } + + /// Ensure that all declarations used by `id` have been added to the + /// ordering, and then append `id` itself. + fn dfs(&mut self, id: Handle<ast::GlobalDecl<'a>>) -> Result<(), Error<'a>> { + let decl = &self.module.decls[id]; + let id_usize = id.index(); + + self.temp_visited[id_usize] = true; + for dep in decl.dependencies.iter() { + if let Some(&dep_id) = self.globals.get(dep.ident) { + self.path.push(ResolvedDependency { + decl: dep_id, + usage: dep.usage, + }); + let dep_id_usize = dep_id.index(); + + if self.temp_visited[dep_id_usize] { + // Found a cycle. + return if dep_id == id { + // A declaration refers to itself directly. + Err(Error::RecursiveDeclaration { + ident: decl_ident(decl).span, + usage: dep.usage, + }) + } else { + // A declaration refers to itself indirectly, through + // one or more other definitions. Report the entire path + // of references. + let start_at = self + .path + .iter() + .rev() + .enumerate() + .find_map(|(i, dep)| (dep.decl == dep_id).then_some(i)) + .unwrap_or(0); + + Err(Error::CyclicDeclaration { + ident: decl_ident(&self.module.decls[dep_id]).span, + path: self.path[start_at..] + .iter() + .map(|curr_dep| { + let curr_id = curr_dep.decl; + let curr_decl = &self.module.decls[curr_id]; + + (decl_ident(curr_decl).span, curr_dep.usage) + }) + .collect(), + }) + }; + } else if !self.visited[dep_id_usize] { + self.dfs(dep_id)?; + } + + // Remove this edge from the current path. + self.path.pop(); + } + + // Ignore unresolved identifiers; they may be predeclared objects. + } + + // Remove this node from the current path. + self.temp_visited[id_usize] = false; + + // Now everything this declaration uses has been visited, and is already + // present in `out`. That means we we can append this one to the + // ordering, and mark it as visited. + self.out.push(id); + self.visited[id_usize] = true; + + Ok(()) + } +} + +const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> { + match decl.kind { + ast::GlobalDeclKind::Fn(ref f) => f.name, + ast::GlobalDeclKind::Var(ref v) => v.name, + ast::GlobalDeclKind::Const(ref c) => c.name, + ast::GlobalDeclKind::Struct(ref s) => s.name, + ast::GlobalDeclKind::Type(ref t) => t.name, + } +} diff --git a/third_party/rust/naga/src/front/wgsl/lower/construction.rs b/third_party/rust/naga/src/front/wgsl/lower/construction.rs new file mode 100644 index 0000000000..de0d11d227 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/lower/construction.rs @@ -0,0 +1,616 @@ +use std::num::NonZeroU32; + +use crate::front::wgsl::parse::ast; +use crate::{Handle, Span}; + +use crate::front::wgsl::error::Error; +use crate::front::wgsl::lower::{ExpressionContext, Lowerer}; + +/// A cooked form of `ast::ConstructorType` that uses Naga types whenever +/// possible. +enum Constructor<T> { + /// A vector construction whose component type is inferred from the + /// argument: `vec3(1.0)`. + PartialVector { size: crate::VectorSize }, + + /// A matrix construction whose component type is inferred from the + /// argument: `mat2x2(1,2,3,4)`. + PartialMatrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + }, + + /// An array whose component type and size are inferred from the arguments: + /// `array(3,4,5)`. + PartialArray, + + /// A known Naga type. + /// + /// When we match on this type, we need to see the `TypeInner` here, but at + /// the point that we build this value we'll still need mutable access to + /// the module later. To avoid borrowing from the module, the type parameter + /// `T` is `Handle<Type>` initially. Then we use `borrow_inner` to produce a + /// version holding a tuple `(Handle<Type>, &TypeInner)`. + Type(T), +} + +impl Constructor<Handle<crate::Type>> { + /// Return an equivalent `Constructor` value that includes borrowed + /// `TypeInner` values alongside any type handles. + /// + /// The returned form is more convenient to match on, since the patterns + /// can actually see what the handle refers to. + fn borrow_inner( + self, + module: &crate::Module, + ) -> Constructor<(Handle<crate::Type>, &crate::TypeInner)> { + match self { + Constructor::PartialVector { size } => Constructor::PartialVector { size }, + Constructor::PartialMatrix { columns, rows } => { + Constructor::PartialMatrix { columns, rows } + } + Constructor::PartialArray => Constructor::PartialArray, + Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)), + } + } +} + +impl Constructor<(Handle<crate::Type>, &crate::TypeInner)> { + fn to_error_string(&self, ctx: &ExpressionContext) -> String { + match *self { + Self::PartialVector { size } => { + format!("vec{}<?>", size as u32,) + } + Self::PartialMatrix { columns, rows } => { + format!("mat{}x{}<?>", columns as u32, rows as u32,) + } + Self::PartialArray => "array<?, ?>".to_string(), + Self::Type((handle, _inner)) => handle.to_wgsl(&ctx.module.to_ctx()), + } + } +} + +enum Components<'a> { + None, + One { + component: Handle<crate::Expression>, + span: Span, + ty_inner: &'a crate::TypeInner, + }, + Many { + components: Vec<Handle<crate::Expression>>, + spans: Vec<Span>, + }, +} + +impl Components<'_> { + fn into_components_vec(self) -> Vec<Handle<crate::Expression>> { + match self { + Self::None => vec![], + Self::One { component, .. } => vec![component], + Self::Many { components, .. } => components, + } + } +} + +impl<'source, 'temp> Lowerer<'source, 'temp> { + /// Generate Naga IR for a type constructor expression. + /// + /// The `constructor` value represents the head of the constructor + /// expression, which is at least a hint of which type is being built; if + /// it's one of the `Partial` variants, we need to consider the argument + /// types as well. + /// + /// This is used for [`Construct`] expressions, but also for [`Call`] + /// expressions, once we've determined that the "callable" (in WGSL spec + /// terms) is actually a type. + /// + /// [`Construct`]: ast::Expression::Construct + /// [`Call`]: ast::Expression::Call + pub fn construct( + &mut self, + span: Span, + constructor: &ast::ConstructorType<'source>, + ty_span: Span, + components: &[Handle<ast::Expression<'source>>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + use crate::proc::TypeResolution as Tr; + + let constructor_h = self.constructor(constructor, ctx)?; + + let components = match *components { + [] => Components::None, + [component] => { + let span = ctx.ast_expressions.get_span(component); + let component = self.expression_for_abstract(component, ctx)?; + let ty_inner = super::resolve_inner!(ctx, component); + + Components::One { + component, + span, + ty_inner, + } + } + ref ast_components @ [_, _, ..] => { + let components = ast_components + .iter() + .map(|&expr| self.expression_for_abstract(expr, ctx)) + .collect::<Result<_, _>>()?; + let spans = ast_components + .iter() + .map(|&expr| ctx.ast_expressions.get_span(expr)) + .collect(); + + for &component in &components { + ctx.grow_types(component)?; + } + + Components::Many { components, spans } + } + }; + + // Even though we computed `constructor` above, wait until now to borrow + // a reference to the `TypeInner`, so that the component-handling code + // above can have mutable access to the type arena. + let constructor = constructor_h.borrow_inner(ctx.module); + + let expr; + match (components, constructor) { + // Empty constructor + (Components::None, dst_ty) => match dst_ty { + Constructor::Type((result_ty, _)) => { + return ctx.append_expression(crate::Expression::ZeroValue(result_ty), span) + } + Constructor::PartialVector { .. } + | Constructor::PartialMatrix { .. } + | Constructor::PartialArray => { + // We have no arguments from which to infer the result type, so + // partial constructors aren't acceptable here. + return Err(Error::TypeNotInferable(ty_span)); + } + }, + + // Scalar constructor & conversion (scalar -> scalar) + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Scalar { .. }, + .. + }, + Constructor::Type((_, &crate::TypeInner::Scalar(scalar))), + ) => { + expr = crate::Expression::As { + expr: component, + kind: scalar.kind, + convert: Some(scalar.width), + }; + } + + // Vector conversion (vector -> vector) + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, + .. + }, + Constructor::Type(( + _, + &crate::TypeInner::Vector { + size: dst_size, + scalar: dst_scalar, + }, + )), + ) if dst_size == src_size => { + expr = crate::Expression::As { + expr: component, + kind: dst_scalar.kind, + convert: Some(dst_scalar.width), + }; + } + + // Vector conversion (vector -> vector) - partial + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Vector { size: src_size, .. }, + .. + }, + Constructor::PartialVector { size: dst_size }, + ) if dst_size == src_size => { + // This is a trivial conversion: the sizes match, and a Partial + // constructor doesn't specify a scalar type, so nothing can + // possibly happen. + return Ok(component); + } + + // Matrix conversion (matrix -> matrix) + ( + Components::One { + component, + ty_inner: + &crate::TypeInner::Matrix { + columns: src_columns, + rows: src_rows, + .. + }, + .. + }, + Constructor::Type(( + _, + &crate::TypeInner::Matrix { + columns: dst_columns, + rows: dst_rows, + scalar: dst_scalar, + }, + )), + ) if dst_columns == src_columns && dst_rows == src_rows => { + expr = crate::Expression::As { + expr: component, + kind: dst_scalar.kind, + convert: Some(dst_scalar.width), + }; + } + + // Matrix conversion (matrix -> matrix) - partial + ( + Components::One { + component, + ty_inner: + &crate::TypeInner::Matrix { + columns: src_columns, + rows: src_rows, + .. + }, + .. + }, + Constructor::PartialMatrix { + columns: dst_columns, + rows: dst_rows, + }, + ) if dst_columns == src_columns && dst_rows == src_rows => { + // This is a trivial conversion: the sizes match, and a Partial + // constructor doesn't specify a scalar type, so nothing can + // possibly happen. + return Ok(component); + } + + // Vector constructor (splat) - infer type + ( + Components::One { + component, + ty_inner: &crate::TypeInner::Scalar { .. }, + .. + }, + Constructor::PartialVector { size }, + ) => { + expr = crate::Expression::Splat { + size, + value: component, + }; + } + + // Vector constructor (splat) + ( + Components::One { + mut component, + ty_inner: &crate::TypeInner::Scalar(_), + .. + }, + Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })), + ) => { + ctx.convert_slice_to_common_leaf_scalar( + std::slice::from_mut(&mut component), + scalar, + )?; + expr = crate::Expression::Splat { + size, + value: component, + }; + } + + // Vector constructor (by elements), partial + ( + Components::Many { + mut components, + spans, + }, + Constructor::PartialVector { size }, + ) => { + let consensus_scalar = + ctx.automatic_conversion_consensus(&components) + .map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + let inner = consensus_scalar.to_inner_vector(size); + let ty = ctx.ensure_type_exists(inner); + expr = crate::Expression::Compose { ty, components }; + } + + // Vector constructor (by elements), full type given + ( + Components::Many { mut components, .. }, + Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })), + ) => { + ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?; + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by elements), partial + ( + Components::Many { + mut components, + spans, + }, + Constructor::PartialMatrix { columns, rows }, + ) if components.len() == columns as usize * rows as usize => { + let consensus_scalar = + ctx.automatic_conversion_consensus(&components) + .map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + // We actually only accept floating-point elements. + let consensus_scalar = consensus_scalar + .automatic_conversion_combine(crate::Scalar::ABSTRACT_FLOAT) + .unwrap_or(consensus_scalar); + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows)); + + let components = components + .chunks(rows as usize) + .map(|vec_components| { + ctx.append_expression( + crate::Expression::Compose { + ty: vec_ty, + components: Vec::from(vec_components), + }, + Default::default(), + ) + }) + .collect::<Result<Vec<_>, _>>()?; + + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: consensus_scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by elements), type given + ( + Components::Many { mut components, .. }, + Constructor::Type(( + _, + &crate::TypeInner::Matrix { + columns, + rows, + scalar, + }, + )), + ) if components.len() == columns as usize * rows as usize => { + let element = Tr::Value(crate::TypeInner::Scalar(scalar)); + ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?; + let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows)); + + let components = components + .chunks(rows as usize) + .map(|vec_components| { + ctx.append_expression( + crate::Expression::Compose { + ty: vec_ty, + components: Vec::from(vec_components), + }, + Default::default(), + ) + }) + .collect::<Result<Vec<_>, _>>()?; + + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by columns), partial + ( + Components::Many { + mut components, + spans, + }, + Constructor::PartialMatrix { columns, rows }, + ) => { + let consensus_scalar = + ctx.automatic_conversion_consensus(&components) + .map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: consensus_scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by columns), type given + ( + Components::Many { mut components, .. }, + Constructor::Type(( + ty, + &crate::TypeInner::Matrix { + columns: _, + rows, + scalar, + }, + )), + ) => { + let component_ty = crate::TypeInner::Vector { size: rows, scalar }; + ctx.try_automatic_conversions_slice( + &mut components, + &Tr::Value(component_ty), + ty_span, + )?; + expr = crate::Expression::Compose { ty, components }; + } + + // Array constructor - infer type + (components, Constructor::PartialArray) => { + let mut components = components.into_components_vec(); + if let Ok(consensus_scalar) = ctx.automatic_conversion_consensus(&components) { + // Note that this will *not* necessarily convert all the + // components to the same type! The `automatic_conversion_consensus` + // method only considers the parameters' leaf scalar + // types; the parameters themselves could be any mix of + // vectors, matrices, and scalars. + // + // But *if* it is possible for this array construction + // expression to be well-typed at all, then all the + // parameters must have the same type constructors (vec, + // matrix, scalar) applied to their leaf scalars, so + // reconciling their scalars is always the right thing to + // do. And if this array construction is not well-typed, + // these conversions will not make it so, and we can let + // validation catch the error. + ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?; + } else { + // There's no consensus scalar. Emit the `Compose` + // expression anyway, and let validation catch the problem. + } + + let base = ctx.register_type(components[0])?; + + let inner = crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant( + NonZeroU32::new(u32::try_from(components.len()).unwrap()).unwrap(), + ), + stride: { + self.layouter.update(ctx.module.to_ctx()).unwrap(); + self.layouter[base].to_stride() + }, + }; + let ty = ctx.ensure_type_exists(inner); + + expr = crate::Expression::Compose { ty, components }; + } + + // Array constructor, explicit type + (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => { + let mut components = components.into_components_vec(); + ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), ty_span)?; + expr = crate::Expression::Compose { ty, components }; + } + + // Struct constructor + ( + components, + Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })), + ) => { + let mut components = components.into_components_vec(); + let struct_ty_span = ctx.module.types.get_span(ty); + + // Make a vector of the members' type handles in advance, to + // avoid borrowing `members` from `ctx` while we generate + // new code. + let members: Vec<Handle<crate::Type>> = members.iter().map(|m| m.ty).collect(); + + for (component, &ty) in components.iter_mut().zip(&members) { + *component = + ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?; + } + expr = crate::Expression::Compose { ty, components }; + } + + // ERRORS + + // Bad conversion (type cast) + (Components::One { span, ty_inner, .. }, constructor) => { + let from_type = ty_inner.to_wgsl(&ctx.module.to_ctx()); + return Err(Error::BadTypeCast { + span, + from_type, + to_type: constructor.to_error_string(ctx), + }); + } + + // Too many parameters for scalar constructor + ( + Components::Many { spans, .. }, + Constructor::Type((_, &crate::TypeInner::Scalar { .. })), + ) => { + let span = spans[1].until(spans.last().unwrap()); + return Err(Error::UnexpectedComponents(span)); + } + + // Other types can't be constructed + _ => return Err(Error::TypeNotConstructible(ty_span)), + } + + let expr = ctx.append_expression(expr, span)?; + Ok(expr) + } + + /// Build a [`Constructor`] for a WGSL construction expression. + /// + /// If `constructor` conveys enough information to determine which Naga [`Type`] + /// we're actually building (i.e., it's not a partial constructor), then + /// ensure the `Type` exists in [`ctx.module`], and return + /// [`Constructor::Type`]. + /// + /// Otherwise, return the [`Constructor`] partial variant corresponding to + /// `constructor`. + /// + /// [`Type`]: crate::Type + /// [`ctx.module`]: ExpressionContext::module + fn constructor<'out>( + &mut self, + constructor: &ast::ConstructorType<'source>, + ctx: &mut ExpressionContext<'source, '_, 'out>, + ) -> Result<Constructor<Handle<crate::Type>>, Error<'source>> { + let handle = match *constructor { + ast::ConstructorType::Scalar(scalar) => { + let ty = ctx.ensure_type_exists(scalar.to_inner_scalar()); + Constructor::Type(ty) + } + ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size }, + ast::ConstructorType::Vector { size, scalar } => { + let ty = ctx.ensure_type_exists(scalar.to_inner_vector(size)); + Constructor::Type(ty) + } + ast::ConstructorType::PartialMatrix { columns, rows } => { + Constructor::PartialMatrix { columns, rows } + } + ast::ConstructorType::Matrix { + rows, + columns, + width, + } => { + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: crate::Scalar::float(width), + }); + Constructor::Type(ty) + } + ast::ConstructorType::PartialArray => Constructor::PartialArray, + ast::ConstructorType::Array { base, size } => { + let base = self.resolve_ast_type(base, &mut ctx.as_global())?; + let size = self.array_size(size, &mut ctx.as_global())?; + + self.layouter.update(ctx.module.to_ctx()).unwrap(); + let stride = self.layouter[base].to_stride(); + + let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride }); + Constructor::Type(ty) + } + ast::ConstructorType::Type(ty) => Constructor::Type(ty), + }; + + Ok(handle) + } +} diff --git a/third_party/rust/naga/src/front/wgsl/lower/conversion.rs b/third_party/rust/naga/src/front/wgsl/lower/conversion.rs new file mode 100644 index 0000000000..2a2690f096 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/lower/conversion.rs @@ -0,0 +1,503 @@ +//! WGSL's automatic conversions for abstract types. + +use crate::{Handle, Span}; + +impl<'source, 'temp, 'out> super::ExpressionContext<'source, 'temp, 'out> { + /// Try to use WGSL's automatic conversions to convert `expr` to `goal_ty`. + /// + /// If no conversions are necessary, return `expr` unchanged. + /// + /// If automatic conversions cannot convert `expr` to `goal_ty`, return an + /// [`AutoConversion`] error. + /// + /// Although the Load Rule is one of the automatic conversions, this + /// function assumes it has already been applied if appropriate, as + /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`. + /// + /// [`AutoConversion`]: super::Error::AutoConversion + pub fn try_automatic_conversions( + &mut self, + expr: Handle<crate::Expression>, + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result<Handle<crate::Expression>, super::Error<'source>> { + let expr_span = self.get_expression_span(expr); + // Keep the TypeResolution so we can get type names for + // structs in error messages. + let expr_resolution = super::resolve!(self, expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + let goal_inner = goal_ty.inner_with(types); + + // If `expr` already has the requested type, we're done. + if expr_inner.equivalent(goal_inner, types) { + return Ok(expr); + } + + let (_expr_scalar, goal_scalar) = + match expr_inner.automatically_converts_to(goal_inner, types) { + Some(scalars) => scalars, + None => { + let gctx = &self.module.to_ctx(); + let source_type = expr_resolution.to_wgsl(gctx); + let dest_type = goal_ty.to_wgsl(gctx); + + return Err(super::Error::AutoConversion { + dest_span: goal_span, + dest_type, + source_span: expr_span, + source_type, + }); + } + }; + + self.convert_leaf_scalar(expr, expr_span, goal_scalar) + } + + /// Try to convert `expr`'s leaf scalar to `goal` using automatic conversions. + /// + /// If no conversions are necessary, return `expr` unchanged. + /// + /// If automatic conversions cannot convert `expr` to `goal_scalar`, return + /// an [`AutoConversionLeafScalar`] error. + /// + /// Although the Load Rule is one of the automatic conversions, this + /// function assumes it has already been applied if appropriate, as + /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`. + /// + /// [`AutoConversionLeafScalar`]: super::Error::AutoConversionLeafScalar + pub fn try_automatic_conversion_for_leaf_scalar( + &mut self, + expr: Handle<crate::Expression>, + goal_scalar: crate::Scalar, + goal_span: Span, + ) -> Result<Handle<crate::Expression>, super::Error<'source>> { + let expr_span = self.get_expression_span(expr); + let expr_resolution = super::resolve!(self, expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + + let make_error = || { + let gctx = &self.module.to_ctx(); + let source_type = expr_resolution.to_wgsl(gctx); + super::Error::AutoConversionLeafScalar { + dest_span: goal_span, + dest_scalar: goal_scalar.to_wgsl(), + source_span: expr_span, + source_type, + } + }; + + let expr_scalar = match expr_inner.scalar() { + Some(scalar) => scalar, + None => return Err(make_error()), + }; + + if expr_scalar == goal_scalar { + return Ok(expr); + } + + if !expr_scalar.automatically_converts_to(goal_scalar) { + return Err(make_error()); + } + + assert!(expr_scalar.is_abstract()); + + self.convert_leaf_scalar(expr, expr_span, goal_scalar) + } + + fn convert_leaf_scalar( + &mut self, + expr: Handle<crate::Expression>, + expr_span: Span, + goal_scalar: crate::Scalar, + ) -> Result<Handle<crate::Expression>, super::Error<'source>> { + let expr_inner = super::resolve_inner!(self, expr); + if let crate::TypeInner::Array { .. } = *expr_inner { + self.as_const_evaluator() + .cast_array(expr, goal_scalar, expr_span) + .map_err(|err| super::Error::ConstantEvaluatorError(err, expr_span)) + } else { + let cast = crate::Expression::As { + expr, + kind: goal_scalar.kind, + convert: Some(goal_scalar.width), + }; + self.append_expression(cast, expr_span) + } + } + + /// Try to convert `exprs` to `goal_ty` using WGSL's automatic conversions. + pub fn try_automatic_conversions_slice( + &mut self, + exprs: &mut [Handle<crate::Expression>], + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + *expr = self.try_automatic_conversions(*expr, goal_ty, goal_span)?; + } + + Ok(()) + } + + /// Apply WGSL's automatic conversions to a vector constructor's arguments. + /// + /// When calling a vector constructor like `vec3<f32>(...)`, the parameters + /// can be a mix of scalars and vectors, with the latter being spread out to + /// contribute each of their components as a component of the new value. + /// When the element type is explicit, as with `<f32>` in the example above, + /// WGSL's automatic conversions should convert abstract scalar and vector + /// parameters to the constructor's required scalar type. + pub fn try_automatic_conversions_for_vector( + &mut self, + exprs: &mut [Handle<crate::Expression>], + goal_scalar: crate::Scalar, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + use crate::proc::TypeResolution as Tr; + use crate::TypeInner as Ti; + let goal_scalar_res = Tr::Value(Ti::Scalar(goal_scalar)); + + for (i, expr) in exprs.iter_mut().enumerate() { + // Keep the TypeResolution so we can get full type names + // in error messages. + let expr_resolution = super::resolve!(self, *expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + + match *expr_inner { + Ti::Scalar(_) => { + *expr = self.try_automatic_conversions(*expr, &goal_scalar_res, goal_span)?; + } + Ti::Vector { size, scalar: _ } => { + let goal_vector_res = Tr::Value(Ti::Vector { + size, + scalar: goal_scalar, + }); + *expr = self.try_automatic_conversions(*expr, &goal_vector_res, goal_span)?; + } + _ => { + let span = self.get_expression_span(*expr); + return Err(super::Error::InvalidConstructorComponentType( + span, i as i32, + )); + } + } + } + + Ok(()) + } + + /// Convert `expr` to the leaf scalar type `scalar`. + pub fn convert_to_leaf_scalar( + &mut self, + expr: &mut Handle<crate::Expression>, + goal: crate::Scalar, + ) -> Result<(), super::Error<'source>> { + let inner = super::resolve_inner!(self, *expr); + // Do nothing if `inner` doesn't even have leaf scalars; + // it's a type error that validation will catch. + if inner.scalar() != Some(goal) { + let cast = crate::Expression::As { + expr: *expr, + kind: goal.kind, + convert: Some(goal.width), + }; + let expr_span = self.get_expression_span(*expr); + *expr = self.append_expression(cast, expr_span)?; + } + + Ok(()) + } + + /// Convert all expressions in `exprs` to a common scalar type. + /// + /// Note that the caller is responsible for making sure these + /// conversions are actually justified. This function simply + /// generates `As` expressions, regardless of whether they are + /// permitted WGSL automatic conversions. Callers intending to + /// implement automatic conversions need to determine for + /// themselves whether the casts we we generate are justified, + /// perhaps by calling `TypeInner::automatically_converts_to` or + /// `Scalar::automatic_conversion_combine`. + pub fn convert_slice_to_common_leaf_scalar( + &mut self, + exprs: &mut [Handle<crate::Expression>], + goal: crate::Scalar, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + self.convert_to_leaf_scalar(expr, goal)?; + } + + Ok(()) + } + + /// Return an expression for the concretized value of `expr`. + /// + /// If `expr` is already concrete, return it unchanged. + pub fn concretize( + &mut self, + mut expr: Handle<crate::Expression>, + ) -> Result<Handle<crate::Expression>, super::Error<'source>> { + let inner = super::resolve_inner!(self, expr); + if let Some(scalar) = inner.automatically_convertible_scalar(&self.module.types) { + let concretized = scalar.concretize(); + if concretized != scalar { + assert!(scalar.is_abstract()); + let expr_span = self.get_expression_span(expr); + expr = self + .as_const_evaluator() + .cast_array(expr, concretized, expr_span) + .map_err(|err| { + // A `TypeResolution` includes the type's full name, if + // it has one. Also, avoid holding the borrow of `inner` + // across the call to `cast_array`. + let expr_type = &self.typifier()[expr]; + super::Error::ConcretizationFailed { + expr_span, + expr_type: expr_type.to_wgsl(&self.module.to_ctx()), + scalar: concretized.to_wgsl(), + inner: err, + } + })?; + } + } + + Ok(expr) + } + + /// Find the consensus scalar of `components` under WGSL's automatic + /// conversions. + /// + /// If `components` can all be converted to any common scalar via + /// WGSL's automatic conversions, return the best such scalar. + /// + /// The `components` slice must not be empty. All elements' types must + /// have been resolved. + /// + /// If `components` are definitely not acceptable as arguments to such + /// constructors, return `Err(i)`, where `i` is the index in + /// `components` of some problematic argument. + /// + /// This function doesn't fully type-check the arguments - it only + /// considers their leaf scalar types. This means it may return `Ok` + /// even when the Naga validator will reject the resulting + /// construction expression later. + pub fn automatic_conversion_consensus<'handle, I>( + &self, + components: I, + ) -> Result<crate::Scalar, usize> + where + I: IntoIterator<Item = &'handle Handle<crate::Expression>>, + I::IntoIter: Clone, // for debugging + { + let types = &self.module.types; + let mut inners = components + .into_iter() + .map(|&c| self.typifier()[c].inner_with(types)); + log::debug!( + "wgsl automatic_conversion_consensus: {:?}", + inners + .clone() + .map(|inner| inner.to_wgsl(&self.module.to_ctx())) + .collect::<Vec<String>>() + ); + let mut best = inners.next().unwrap().scalar().ok_or(0_usize)?; + for (inner, i) in inners.zip(1..) { + let scalar = inner.scalar().ok_or(i)?; + match best.automatic_conversion_combine(scalar) { + Some(new_best) => { + best = new_best; + } + None => return Err(i), + } + } + + log::debug!(" consensus: {:?}", best.to_wgsl()); + Ok(best) + } +} + +impl crate::TypeInner { + /// Determine whether `self` automatically converts to `goal`. + /// + /// If WGSL's automatic conversions (excluding the Load Rule) will + /// convert `self` to `goal`, then return a pair `(from, to)`, + /// where `from` and `to` are the scalar types of the leaf values + /// of `self` and `goal`. + /// + /// This function assumes that `self` and `goal` are different + /// types. Callers should first check whether any conversion is + /// needed at all. + /// + /// If the automatic conversions cannot convert `self` to `goal`, + /// return `None`. + fn automatically_converts_to( + &self, + goal: &Self, + types: &crate::UniqueArena<crate::Type>, + ) -> Option<(crate::Scalar, crate::Scalar)> { + use crate::ScalarKind as Sk; + use crate::TypeInner as Ti; + + // Automatic conversions only change the scalar type of a value's leaves + // (e.g., `vec4<AbstractFloat>` to `vec4<f32>`), never the type + // constructors applied to those scalar types (e.g., never scalar to + // `vec4`, or `vec2` to `vec3`). So first we check that the type + // constructors match, extracting the leaf scalar types in the process. + let expr_scalar; + let goal_scalar; + match (self, goal) { + (&Ti::Scalar(expr), &Ti::Scalar(goal)) => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Vector { + size: expr_size, + scalar: expr, + }, + &Ti::Vector { + size: goal_size, + scalar: goal, + }, + ) if expr_size == goal_size => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Matrix { + rows: expr_rows, + columns: expr_columns, + scalar: expr, + }, + &Ti::Matrix { + rows: goal_rows, + columns: goal_columns, + scalar: goal, + }, + ) if expr_rows == goal_rows && expr_columns == goal_columns => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Array { + base: expr_base, + size: expr_size, + stride: _, + }, + &Ti::Array { + base: goal_base, + size: goal_size, + stride: _, + }, + ) if expr_size == goal_size => { + return types[expr_base] + .inner + .automatically_converts_to(&types[goal_base].inner, types); + } + _ => return None, + } + + match (expr_scalar.kind, goal_scalar.kind) { + (Sk::AbstractFloat, Sk::Float) => {} + (Sk::AbstractInt, Sk::Sint | Sk::Uint | Sk::AbstractFloat | Sk::Float) => {} + _ => return None, + } + + log::trace!(" okay: expr {expr_scalar:?}, goal {goal_scalar:?}"); + Some((expr_scalar, goal_scalar)) + } + + fn automatically_convertible_scalar( + &self, + types: &crate::UniqueArena<crate::Type>, + ) -> Option<crate::Scalar> { + use crate::TypeInner as Ti; + match *self { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { + Some(scalar) + } + Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types), + Ti::Atomic(_) + | Ti::Pointer { .. } + | Ti::ValuePointer { .. } + | Ti::Struct { .. } + | Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => None, + } + } +} + +impl crate::Scalar { + /// Find the common type of `self` and `other` under WGSL's + /// automatic conversions. + /// + /// If there are any scalars to which WGSL's automatic conversions + /// will convert both `self` and `other`, return the best such + /// scalar. Otherwise, return `None`. + pub const fn automatic_conversion_combine(self, other: Self) -> Option<crate::Scalar> { + use crate::ScalarKind as Sk; + + match (self.kind, other.kind) { + // When the kinds match... + (Sk::AbstractFloat, Sk::AbstractFloat) + | (Sk::AbstractInt, Sk::AbstractInt) + | (Sk::Sint, Sk::Sint) + | (Sk::Uint, Sk::Uint) + | (Sk::Float, Sk::Float) + | (Sk::Bool, Sk::Bool) => { + if self.width == other.width { + // ... either no conversion is necessary ... + Some(self) + } else { + // ... or no conversion is possible. + // We never convert concrete to concrete, and + // abstract types should have only one size. + None + } + } + + // AbstractInt converts to AbstractFloat. + (Sk::AbstractFloat, Sk::AbstractInt) => Some(self), + (Sk::AbstractInt, Sk::AbstractFloat) => Some(other), + + // AbstractFloat converts to Float. + (Sk::AbstractFloat, Sk::Float) => Some(other), + (Sk::Float, Sk::AbstractFloat) => Some(self), + + // AbstractInt converts to concrete integer or float. + (Sk::AbstractInt, Sk::Uint | Sk::Sint | Sk::Float) => Some(other), + (Sk::Uint | Sk::Sint | Sk::Float, Sk::AbstractInt) => Some(self), + + // AbstractFloat can't be reconciled with concrete integer types. + (Sk::AbstractFloat, Sk::Uint | Sk::Sint) | (Sk::Uint | Sk::Sint, Sk::AbstractFloat) => { + None + } + + // Nothing can be reconciled with `bool`. + (Sk::Bool, _) | (_, Sk::Bool) => None, + + // Different concrete types cannot be reconciled. + (Sk::Sint | Sk::Uint | Sk::Float, Sk::Sint | Sk::Uint | Sk::Float) => None, + } + } + + /// Return `true` if automatic conversions will covert `self` to `goal`. + pub fn automatically_converts_to(self, goal: Self) -> bool { + self.automatic_conversion_combine(goal) == Some(goal) + } + + const fn concretize(self) -> Self { + use crate::ScalarKind as Sk; + match self.kind { + Sk::Sint | Sk::Uint | Sk::Float | Sk::Bool => self, + Sk::AbstractInt => Self::I32, + Sk::AbstractFloat => Self::F32, + } + } +} diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs new file mode 100644 index 0000000000..ba9b49e135 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs @@ -0,0 +1,2760 @@ +use std::num::NonZeroU32; + +use crate::front::wgsl::error::{Error, ExpectedToken, InvalidAssignmentType}; +use crate::front::wgsl::index::Index; +use crate::front::wgsl::parse::number::Number; +use crate::front::wgsl::parse::{ast, conv}; +use crate::front::Typifier; +use crate::proc::{ + ensure_block_returns, Alignment, ConstantEvaluator, Emitter, Layouter, ResolveContext, +}; +use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span}; + +mod construction; +mod conversion; + +/// Resolves the inner type of a given expression. +/// +/// Expects a &mut [`ExpressionContext`] and a [`Handle<Expression>`]. +/// +/// Returns a &[`crate::TypeInner`]. +/// +/// Ideally, we would simply have a function that takes a `&mut ExpressionContext` +/// and returns a `&TypeResolution`. Unfortunately, this leads the borrow checker +/// to conclude that the mutable borrow lasts for as long as we are using the +/// `&TypeResolution`, so we can't use the `ExpressionContext` for anything else - +/// like, say, resolving another operand's type. Using a macro that expands to +/// two separate calls, only the first of which needs a `&mut`, +/// lets the borrow checker see that the mutable borrow is over. +macro_rules! resolve_inner { + ($ctx:ident, $expr:expr) => {{ + $ctx.grow_types($expr)?; + $ctx.typifier()[$expr].inner_with(&$ctx.module.types) + }}; +} +pub(super) use resolve_inner; + +/// Resolves the inner types of two given expressions. +/// +/// Expects a &mut [`ExpressionContext`] and two [`Handle<Expression>`]s. +/// +/// Returns a tuple containing two &[`crate::TypeInner`]. +/// +/// See the documentation of [`resolve_inner!`] for why this macro is necessary. +macro_rules! resolve_inner_binary { + ($ctx:ident, $left:expr, $right:expr) => {{ + $ctx.grow_types($left)?; + $ctx.grow_types($right)?; + ( + $ctx.typifier()[$left].inner_with(&$ctx.module.types), + $ctx.typifier()[$right].inner_with(&$ctx.module.types), + ) + }}; +} + +/// Resolves the type of a given expression. +/// +/// Expects a &mut [`ExpressionContext`] and a [`Handle<Expression>`]. +/// +/// Returns a &[`TypeResolution`]. +/// +/// See the documentation of [`resolve_inner!`] for why this macro is necessary. +/// +/// [`TypeResolution`]: crate::proc::TypeResolution +macro_rules! resolve { + ($ctx:ident, $expr:expr) => {{ + $ctx.grow_types($expr)?; + &$ctx.typifier()[$expr] + }}; +} +pub(super) use resolve; + +/// State for constructing a `crate::Module`. +pub struct GlobalContext<'source, 'temp, 'out> { + /// The `TranslationUnit`'s expressions arena. + ast_expressions: &'temp Arena<ast::Expression<'source>>, + + /// The `TranslationUnit`'s types arena. + types: &'temp Arena<ast::Type<'source>>, + + // Naga IR values. + /// The map from the names of module-scope declarations to the Naga IR + /// `Handle`s we have built for them, owned by `Lowerer::lower`. + globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, + + /// The module we're constructing. + module: &'out mut crate::Module, + + const_typifier: &'temp mut Typifier, +} + +impl<'source> GlobalContext<'source, '_, '_> { + fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { + ExpressionContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + expr_type: ExpressionContextType::Constant, + } + } + + fn ensure_type_exists( + &mut self, + name: Option<String>, + inner: crate::TypeInner, + ) -> Handle<crate::Type> { + self.module + .types + .insert(crate::Type { inner, name }, Span::UNDEFINED) + } +} + +/// State for lowering a statement within a function. +pub struct StatementContext<'source, 'temp, 'out> { + // WGSL AST values. + /// A reference to [`TranslationUnit::expressions`] for the translation unit + /// we're lowering. + /// + /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions + ast_expressions: &'temp Arena<ast::Expression<'source>>, + + /// A reference to [`TranslationUnit::types`] for the translation unit + /// we're lowering. + /// + /// [`TranslationUnit::types`]: ast::TranslationUnit::types + types: &'temp Arena<ast::Type<'source>>, + + // Naga IR values. + /// The map from the names of module-scope declarations to the Naga IR + /// `Handle`s we have built for them, owned by `Lowerer::lower`. + globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, + + /// A map from each `ast::Local` handle to the Naga expression + /// we've built for it: + /// + /// - WGSL function arguments become Naga [`FunctionArgument`] expressions. + /// + /// - WGSL `var` declarations become Naga [`LocalVariable`] expressions. + /// + /// - WGSL `let` declararations become arbitrary Naga expressions. + /// + /// This always borrows the `local_table` local variable in + /// [`Lowerer::function`]. + /// + /// [`LocalVariable`]: crate::Expression::LocalVariable + /// [`FunctionArgument`]: crate::Expression::FunctionArgument + local_table: &'temp mut FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>, + + const_typifier: &'temp mut Typifier, + typifier: &'temp mut Typifier, + function: &'out mut crate::Function, + /// Stores the names of expressions that are assigned in `let` statement + /// Also stores the spans of the names, for use in errors. + named_expressions: &'out mut FastIndexMap<Handle<crate::Expression>, (String, Span)>, + module: &'out mut crate::Module, + + /// Which `Expression`s in `self.naga_expressions` are const expressions, in + /// the WGSL sense. + /// + /// According to the WGSL spec, a const expression must not refer to any + /// `let` declarations, even if those declarations' initializers are + /// themselves const expressions. So this tracker is not simply concerned + /// with the form of the expressions; it is also tracking whether WGSL says + /// we should consider them to be const. See the use of `force_non_const` in + /// the code for lowering `let` bindings. + expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, +} + +impl<'a, 'temp> StatementContext<'a, 'temp, '_> { + fn as_expression<'t>( + &'t mut self, + block: &'t mut crate::Block, + emitter: &'t mut Emitter, + ) -> ExpressionContext<'a, 't, '_> + where + 'temp: 't, + { + ExpressionContext { + globals: self.globals, + types: self.types, + ast_expressions: self.ast_expressions, + const_typifier: self.const_typifier, + module: self.module, + expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext { + local_table: self.local_table, + function: self.function, + block, + emitter, + typifier: self.typifier, + expression_constness: self.expression_constness, + }), + } + } + + fn as_global(&mut self) -> GlobalContext<'a, '_, '_> { + GlobalContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + } + } + + fn invalid_assignment_type(&self, expr: Handle<crate::Expression>) -> InvalidAssignmentType { + if let Some(&(_, span)) = self.named_expressions.get(&expr) { + InvalidAssignmentType::ImmutableBinding(span) + } else { + match self.function.expressions[expr] { + crate::Expression::Swizzle { .. } => InvalidAssignmentType::Swizzle, + crate::Expression::Access { base, .. } => self.invalid_assignment_type(base), + crate::Expression::AccessIndex { base, .. } => self.invalid_assignment_type(base), + _ => InvalidAssignmentType::Other, + } + } + } +} + +pub struct RuntimeExpressionContext<'temp, 'out> { + /// A map from [`ast::Local`] handles to the Naga expressions we've built for them. + /// + /// This is always [`StatementContext::local_table`] for the + /// enclosing statement; see that documentation for details. + local_table: &'temp FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>, + + function: &'out mut crate::Function, + block: &'temp mut crate::Block, + emitter: &'temp mut Emitter, + typifier: &'temp mut Typifier, + + /// Which `Expression`s in `self.naga_expressions` are const expressions, in + /// the WGSL sense. + /// + /// See [`StatementContext::expression_constness`] for details. + expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, +} + +/// The type of Naga IR expression we are lowering an [`ast::Expression`] to. +pub enum ExpressionContextType<'temp, 'out> { + /// We are lowering to an arbitrary runtime expression, to be + /// included in a function's body. + /// + /// The given [`RuntimeExpressionContext`] holds information about local + /// variables, arguments, and other definitions available only to runtime + /// expressions, not constant or override expressions. + Runtime(RuntimeExpressionContext<'temp, 'out>), + + /// We are lowering to a constant expression, to be included in the module's + /// constant expression arena. + /// + /// Everything constant expressions are allowed to refer to is + /// available in the [`ExpressionContext`], so this variant + /// carries no further information. + Constant, +} + +/// State for lowering an [`ast::Expression`] to Naga IR. +/// +/// [`ExpressionContext`]s come in two kinds, distinguished by +/// the value of the [`expr_type`] field: +/// +/// - A [`Runtime`] context contributes [`naga::Expression`]s to a [`naga::Function`]'s +/// runtime expression arena. +/// +/// - A [`Constant`] context contributes [`naga::Expression`]s to a [`naga::Module`]'s +/// constant expression arena. +/// +/// [`ExpressionContext`]s are constructed in restricted ways: +/// +/// - To get a [`Runtime`] [`ExpressionContext`], call +/// [`StatementContext::as_expression`]. +/// +/// - To get a [`Constant`] [`ExpressionContext`], call +/// [`GlobalContext::as_const`]. +/// +/// - You can demote a [`Runtime`] context to a [`Constant`] context +/// by calling [`as_const`], but there's no way to go in the other +/// direction, producing a runtime context from a constant one. This +/// is because runtime expressions can refer to constant +/// expressions, via [`Expression::Constant`], but constant +/// expressions can't refer to a function's expressions. +/// +/// Not to be confused with `wgsl::parse::ExpressionContext`, which is +/// for parsing the `ast::Expression` in the first place. +/// +/// [`expr_type`]: ExpressionContext::expr_type +/// [`Runtime`]: ExpressionContextType::Runtime +/// [`naga::Expression`]: crate::Expression +/// [`naga::Function`]: crate::Function +/// [`Constant`]: ExpressionContextType::Constant +/// [`naga::Module`]: crate::Module +/// [`as_const`]: ExpressionContext::as_const +/// [`Expression::Constant`]: crate::Expression::Constant +pub struct ExpressionContext<'source, 'temp, 'out> { + // WGSL AST values. + ast_expressions: &'temp Arena<ast::Expression<'source>>, + types: &'temp Arena<ast::Type<'source>>, + + // Naga IR values. + /// The map from the names of module-scope declarations to the Naga IR + /// `Handle`s we have built for them, owned by `Lowerer::lower`. + globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>, + + /// The IR [`Module`] we're constructing. + /// + /// [`Module`]: crate::Module + module: &'out mut crate::Module, + + /// Type judgments for [`module::const_expressions`]. + /// + /// [`module::const_expressions`]: crate::Module::const_expressions + const_typifier: &'temp mut Typifier, + + /// Whether we are lowering a constant expression or a general + /// runtime expression, and the data needed in each case. + expr_type: ExpressionContextType<'temp, 'out>, +} + +impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { + fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> { + ExpressionContext { + globals: self.globals, + types: self.types, + ast_expressions: self.ast_expressions, + const_typifier: self.const_typifier, + module: self.module, + expr_type: ExpressionContextType::Constant, + } + } + + fn as_global(&mut self) -> GlobalContext<'source, '_, '_> { + GlobalContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + } + } + + fn as_const_evaluator(&mut self) -> ConstantEvaluator { + match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function( + self.module, + &mut rctx.function.expressions, + rctx.expression_constness, + rctx.emitter, + rctx.block, + ), + ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module), + } + } + + fn append_expression( + &mut self, + expr: crate::Expression, + span: Span, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let mut eval = self.as_const_evaluator(); + match eval.try_eval_and_append(&expr, span) { + Ok(expr) => Ok(expr), + + // `expr` is not a constant expression. This is fine as + // long as we're not building `Module::const_expressions`. + Err(err) => match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => { + Ok(rctx.function.expressions.append(expr, span)) + } + ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)), + }, + } + } + + fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => { + if !ctx.expression_constness.is_const(handle) { + return None; + } + + self.module + .to_ctx() + .eval_expr_to_u32_from(handle, &ctx.function.expressions) + .ok() + } + ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(), + } + } + + fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle), + ExpressionContextType::Constant => self.module.const_expressions.get_span(handle), + } + } + + fn typifier(&self) -> &Typifier { + match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => ctx.typifier, + ExpressionContextType::Constant => self.const_typifier, + } + } + + fn runtime_expression_ctx( + &mut self, + span: Span, + ) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> { + match self.expr_type { + ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx), + ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)), + } + } + + fn gather_component( + &mut self, + expr: Handle<crate::Expression>, + component_span: Span, + gather_span: Span, + ) -> Result<crate::SwizzleComponent, Error<'source>> { + match self.expr_type { + ExpressionContextType::Runtime(ref rctx) => { + if !rctx.expression_constness.is_const(expr) { + return Err(Error::ExpectedConstExprConcreteIntegerScalar( + component_span, + )); + } + + let index = self + .module + .to_ctx() + .eval_expr_to_u32_from(expr, &rctx.function.expressions) + .map_err(|err| match err { + crate::proc::U32EvalError::NonConst => { + Error::ExpectedConstExprConcreteIntegerScalar(component_span) + } + crate::proc::U32EvalError::Negative => { + Error::ExpectedNonNegative(component_span) + } + })?; + crate::SwizzleComponent::XYZW + .get(index as usize) + .copied() + .ok_or(Error::InvalidGatherComponent(component_span)) + } + // This means a `gather` operation appeared in a constant expression. + // This error refers to the `gather` itself, not its "component" argument. + ExpressionContextType::Constant => { + Err(Error::UnexpectedOperationInConstContext(gather_span)) + } + } + } + + /// Determine the type of `handle`, and add it to the module's arena. + /// + /// If you just need a `TypeInner` for `handle`'s type, use the + /// [`resolve_inner!`] macro instead. This function + /// should only be used when the type of `handle` needs to appear + /// in the module's final `Arena<Type>`, for example, if you're + /// creating a [`LocalVariable`] whose type is inferred from its + /// initializer. + /// + /// [`LocalVariable`]: crate::LocalVariable + fn register_type( + &mut self, + handle: Handle<crate::Expression>, + ) -> Result<Handle<crate::Type>, Error<'source>> { + self.grow_types(handle)?; + // This is equivalent to calling ExpressionContext::typifier(), + // except that this lets the borrow checker see that it's okay + // to also borrow self.module.types mutably below. + let typifier = match self.expr_type { + ExpressionContextType::Runtime(ref ctx) => ctx.typifier, + ExpressionContextType::Constant => &*self.const_typifier, + }; + Ok(typifier.register_type(handle, &mut self.module.types)) + } + + /// Resolve the types of all expressions up through `handle`. + /// + /// Ensure that [`self.typifier`] has a [`TypeResolution`] for + /// every expression in [`self.function.expressions`]. + /// + /// This does not add types to any arena. The [`Typifier`] + /// documentation explains the steps we take to avoid filling + /// arenas with intermediate types. + /// + /// This function takes `&mut self`, so it can't conveniently + /// return a shared reference to the resulting `TypeResolution`: + /// the shared reference would extend the mutable borrow, and you + /// wouldn't be able to use `self` for anything else. Instead, you + /// should use [`register_type`] or one of [`resolve!`], + /// [`resolve_inner!`] or [`resolve_inner_binary!`]. + /// + /// [`self.typifier`]: ExpressionContext::typifier + /// [`TypeResolution`]: crate::proc::TypeResolution + /// [`register_type`]: Self::register_type + /// [`Typifier`]: Typifier + fn grow_types( + &mut self, + handle: Handle<crate::Expression>, + ) -> Result<&mut Self, Error<'source>> { + let empty_arena = Arena::new(); + let resolve_ctx; + let typifier; + let expressions; + match self.expr_type { + ExpressionContextType::Runtime(ref mut ctx) => { + resolve_ctx = ResolveContext::with_locals( + self.module, + &ctx.function.local_variables, + &ctx.function.arguments, + ); + typifier = &mut *ctx.typifier; + expressions = &ctx.function.expressions; + } + ExpressionContextType::Constant => { + resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]); + typifier = self.const_typifier; + expressions = &self.module.const_expressions; + } + }; + typifier + .grow(handle, expressions, &resolve_ctx) + .map_err(Error::InvalidResolve)?; + + Ok(self) + } + + fn image_data( + &mut self, + image: Handle<crate::Expression>, + span: Span, + ) -> Result<(crate::ImageClass, bool), Error<'source>> { + match *resolve_inner!(self, image) { + crate::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)), + _ => Err(Error::BadTexture(span)), + } + } + + fn prepare_args<'b>( + &mut self, + args: &'b [Handle<ast::Expression<'source>>], + min_args: u32, + span: Span, + ) -> ArgumentContext<'b, 'source> { + ArgumentContext { + args: args.iter(), + min_args, + args_used: 0, + total_args: args.len() as u32, + span, + } + } + + /// Insert splats, if needed by the non-'*' operations. + /// + /// See the "Binary arithmetic expressions with mixed scalar and vector operands" + /// table in the WebGPU Shading Language specification for relevant operators. + /// + /// Multiply is not handled here as backends are expected to handle vec*scalar + /// operations, so inserting splats into the IR increases size needlessly. + fn binary_op_splat( + &mut self, + op: crate::BinaryOperator, + left: &mut Handle<crate::Expression>, + right: &mut Handle<crate::Expression>, + ) -> Result<(), Error<'source>> { + if matches!( + op, + crate::BinaryOperator::Add + | crate::BinaryOperator::Subtract + | crate::BinaryOperator::Divide + | crate::BinaryOperator::Modulo + ) { + match resolve_inner_binary!(self, *left, *right) { + (&crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar { .. }) => { + *right = self.append_expression( + crate::Expression::Splat { + size, + value: *right, + }, + self.get_expression_span(*right), + )?; + } + (&crate::TypeInner::Scalar { .. }, &crate::TypeInner::Vector { size, .. }) => { + *left = self.append_expression( + crate::Expression::Splat { size, value: *left }, + self.get_expression_span(*left), + )?; + } + _ => {} + } + } + + Ok(()) + } + + /// Add a single expression to the expression table that is not covered by `self.emitter`. + /// + /// This is useful for `CallResult` and `AtomicResult` expressions, which should not be covered by + /// `Emit` statements. + fn interrupt_emitter( + &mut self, + expression: crate::Expression, + span: Span, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => { + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + } + ExpressionContextType::Constant => {} + } + let result = self.append_expression(expression, span); + match self.expr_type { + ExpressionContextType::Runtime(ref mut rctx) => { + rctx.emitter.start(&rctx.function.expressions); + } + ExpressionContextType::Constant => {} + } + result + } + + /// Apply the WGSL Load Rule to `expr`. + /// + /// If `expr` is has type `ref<SC, T, A>`, perform a load to produce a value of type + /// `T`. Otherwise, return `expr` unchanged. + fn apply_load_rule( + &mut self, + expr: Typed<Handle<crate::Expression>>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + match expr { + Typed::Reference(pointer) => { + let load = crate::Expression::Load { pointer }; + let span = self.get_expression_span(pointer); + self.append_expression(load, span) + } + Typed::Plain(handle) => Ok(handle), + } + } + + fn ensure_type_exists(&mut self, inner: crate::TypeInner) -> Handle<crate::Type> { + self.as_global().ensure_type_exists(None, inner) + } +} + +struct ArgumentContext<'ctx, 'source> { + args: std::slice::Iter<'ctx, Handle<ast::Expression<'source>>>, + min_args: u32, + args_used: u32, + total_args: u32, + span: Span, +} + +impl<'source> ArgumentContext<'_, 'source> { + pub fn finish(self) -> Result<(), Error<'source>> { + if self.args.len() == 0 { + Ok(()) + } else { + Err(Error::WrongArgumentCount { + found: self.total_args, + expected: self.min_args..self.args_used + 1, + span: self.span, + }) + } + } + + pub fn next(&mut self) -> Result<Handle<ast::Expression<'source>>, Error<'source>> { + match self.args.next().copied() { + Some(arg) => { + self.args_used += 1; + Ok(arg) + } + None => Err(Error::WrongArgumentCount { + found: self.total_args, + expected: self.min_args..self.args_used + 1, + span: self.span, + }), + } + } +} + +/// WGSL type annotations on expressions, types, values, etc. +/// +/// Naga and WGSL types are very close, but Naga lacks WGSL's `ref` types, which +/// we need to know to apply the Load Rule. This enum carries some WGSL or Naga +/// datum along with enough information to determine its corresponding WGSL +/// type. +/// +/// The `T` type parameter can be any expression-like thing: +/// +/// - `Typed<Handle<crate::Type>>` can represent a full WGSL type. For example, +/// given some Naga `Pointer` type `ptr`, a WGSL reference type is a +/// `Typed::Reference(ptr)` whereas a WGSL pointer type is a +/// `Typed::Plain(ptr)`. +/// +/// - `Typed<crate::Expression>` or `Typed<Handle<crate::Expression>>` can +/// represent references similarly. +/// +/// Use the `map` and `try_map` methods to convert from one expression +/// representation to another. +/// +/// [`Expression`]: crate::Expression +#[derive(Debug, Copy, Clone)] +enum Typed<T> { + /// A WGSL reference. + Reference(T), + + /// A WGSL plain type. + Plain(T), +} + +impl<T> Typed<T> { + fn map<U>(self, mut f: impl FnMut(T) -> U) -> Typed<U> { + match self { + Self::Reference(v) => Typed::Reference(f(v)), + Self::Plain(v) => Typed::Plain(f(v)), + } + } + + fn try_map<U, E>(self, mut f: impl FnMut(T) -> Result<U, E>) -> Result<Typed<U>, E> { + Ok(match self { + Self::Reference(expr) => Typed::Reference(f(expr)?), + Self::Plain(expr) => Typed::Plain(f(expr)?), + }) + } +} + +/// A single vector component or swizzle. +/// +/// This represents the things that can appear after the `.` in a vector access +/// expression: either a single component name, or a series of them, +/// representing a swizzle. +enum Components { + Single(u32), + Swizzle { + size: crate::VectorSize, + pattern: [crate::SwizzleComponent; 4], + }, +} + +impl Components { + const fn letter_component(letter: char) -> Option<crate::SwizzleComponent> { + use crate::SwizzleComponent as Sc; + match letter { + 'x' | 'r' => Some(Sc::X), + 'y' | 'g' => Some(Sc::Y), + 'z' | 'b' => Some(Sc::Z), + 'w' | 'a' => Some(Sc::W), + _ => None, + } + } + + fn single_component(name: &str, name_span: Span) -> Result<u32, Error> { + let ch = name.chars().next().ok_or(Error::BadAccessor(name_span))?; + match Self::letter_component(ch) { + Some(sc) => Ok(sc as u32), + None => Err(Error::BadAccessor(name_span)), + } + } + + /// Construct a `Components` value from a 'member' name, like `"wzy"` or `"x"`. + /// + /// Use `name_span` for reporting errors in parsing the component string. + fn new(name: &str, name_span: Span) -> Result<Self, Error> { + let size = match name.len() { + 1 => return Ok(Components::Single(Self::single_component(name, name_span)?)), + 2 => crate::VectorSize::Bi, + 3 => crate::VectorSize::Tri, + 4 => crate::VectorSize::Quad, + _ => return Err(Error::BadAccessor(name_span)), + }; + + let mut pattern = [crate::SwizzleComponent::X; 4]; + for (comp, ch) in pattern.iter_mut().zip(name.chars()) { + *comp = Self::letter_component(ch).ok_or(Error::BadAccessor(name_span))?; + } + + Ok(Components::Swizzle { size, pattern }) + } +} + +/// An `ast::GlobalDecl` for which we have built the Naga IR equivalent. +enum LoweredGlobalDecl { + Function(Handle<crate::Function>), + Var(Handle<crate::GlobalVariable>), + Const(Handle<crate::Constant>), + Type(Handle<crate::Type>), + EntryPoint, +} + +enum Texture { + Gather, + GatherCompare, + + Sample, + SampleBias, + SampleCompare, + SampleCompareLevel, + SampleGrad, + SampleLevel, + // SampleBaseClampToEdge, +} + +impl Texture { + pub fn map(word: &str) -> Option<Self> { + Some(match word { + "textureGather" => Self::Gather, + "textureGatherCompare" => Self::GatherCompare, + + "textureSample" => Self::Sample, + "textureSampleBias" => Self::SampleBias, + "textureSampleCompare" => Self::SampleCompare, + "textureSampleCompareLevel" => Self::SampleCompareLevel, + "textureSampleGrad" => Self::SampleGrad, + "textureSampleLevel" => Self::SampleLevel, + // "textureSampleBaseClampToEdge" => Some(Self::SampleBaseClampToEdge), + _ => return None, + }) + } + + pub const fn min_argument_count(&self) -> u32 { + match *self { + Self::Gather => 3, + Self::GatherCompare => 4, + + Self::Sample => 3, + Self::SampleBias => 5, + Self::SampleCompare => 5, + Self::SampleCompareLevel => 5, + Self::SampleGrad => 6, + Self::SampleLevel => 5, + // Self::SampleBaseClampToEdge => 3, + } + } +} + +pub struct Lowerer<'source, 'temp> { + index: &'temp Index<'source>, + layouter: Layouter, +} + +impl<'source, 'temp> Lowerer<'source, 'temp> { + pub fn new(index: &'temp Index<'source>) -> Self { + Self { + index, + layouter: Layouter::default(), + } + } + + pub fn lower( + &mut self, + tu: &'temp ast::TranslationUnit<'source>, + ) -> Result<crate::Module, Error<'source>> { + let mut module = crate::Module::default(); + + let mut ctx = GlobalContext { + ast_expressions: &tu.expressions, + globals: &mut FastHashMap::default(), + types: &tu.types, + module: &mut module, + const_typifier: &mut Typifier::new(), + }; + + for decl_handle in self.index.visit_ordered() { + let span = tu.decls.get_span(decl_handle); + let decl = &tu.decls[decl_handle]; + + match decl.kind { + ast::GlobalDeclKind::Fn(ref f) => { + let lowered_decl = self.function(f, span, &mut ctx)?; + ctx.globals.insert(f.name.name, lowered_decl); + } + ast::GlobalDeclKind::Var(ref v) => { + let ty = self.resolve_ast_type(v.ty, &mut ctx)?; + + let init; + if let Some(init_ast) = v.init { + let mut ectx = ctx.as_const(); + let lowered = self.expression_for_abstract(init_ast, &mut ectx)?; + let ty_res = crate::proc::TypeResolution::Handle(ty); + let converted = ectx + .try_automatic_conversions(lowered, &ty_res, v.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: v.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + init = Some(converted); + } else { + init = None; + } + + let binding = if let Some(ref binding) = v.binding { + Some(crate::ResourceBinding { + group: self.const_u32(binding.group, &mut ctx.as_const())?.0, + binding: self.const_u32(binding.binding, &mut ctx.as_const())?.0, + }) + } else { + None + }; + + let handle = ctx.module.global_variables.append( + crate::GlobalVariable { + name: Some(v.name.name.to_string()), + space: v.space, + binding, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(v.name.name, LoweredGlobalDecl::Var(handle)); + } + ast::GlobalDeclKind::Const(ref c) => { + let mut ectx = ctx.as_const(); + let mut init = self.expression_for_abstract(c.init, &mut ectx)?; + + let ty; + if let Some(explicit_ty) = c.ty { + let explicit_ty = + self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?; + let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty); + init = ectx + .try_automatic_conversions(init, &explicit_ty_res, c.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: c.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + ty = explicit_ty; + } else { + init = ectx.concretize(init)?; + ty = ectx.register_type(init)?; + } + + let handle = ctx.module.constants.append( + crate::Constant { + name: Some(c.name.name.to_string()), + r#override: crate::Override::None, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(c.name.name, LoweredGlobalDecl::Const(handle)); + } + ast::GlobalDeclKind::Struct(ref s) => { + let handle = self.r#struct(s, span, &mut ctx)?; + ctx.globals + .insert(s.name.name, LoweredGlobalDecl::Type(handle)); + } + ast::GlobalDeclKind::Type(ref alias) => { + let ty = self.resolve_named_ast_type( + alias.ty, + Some(alias.name.name.to_string()), + &mut ctx, + )?; + ctx.globals + .insert(alias.name.name, LoweredGlobalDecl::Type(ty)); + } + } + } + + // Constant evaluation may leave abstract-typed literals and + // compositions in expression arenas, so we need to compact the module + // to remove unused expressions and types. + crate::compact::compact(&mut module); + + Ok(module) + } + + fn function( + &mut self, + f: &ast::Function<'source>, + span: Span, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<LoweredGlobalDecl, Error<'source>> { + let mut local_table = FastHashMap::default(); + let mut expressions = Arena::new(); + let mut named_expressions = FastIndexMap::default(); + + let arguments = f + .arguments + .iter() + .enumerate() + .map(|(i, arg)| { + let ty = self.resolve_ast_type(arg.ty, ctx)?; + let expr = expressions + .append(crate::Expression::FunctionArgument(i as u32), arg.name.span); + local_table.insert(arg.handle, Typed::Plain(expr)); + named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span)); + + Ok(crate::FunctionArgument { + name: Some(arg.name.name.to_string()), + ty, + binding: self.binding(&arg.binding, ty, ctx)?, + }) + }) + .collect::<Result<Vec<_>, _>>()?; + + let result = f + .result + .as_ref() + .map(|res| { + let ty = self.resolve_ast_type(res.ty, ctx)?; + Ok(crate::FunctionResult { + ty, + binding: self.binding(&res.binding, ty, ctx)?, + }) + }) + .transpose()?; + + let mut function = crate::Function { + name: Some(f.name.name.to_string()), + arguments, + result, + local_variables: Arena::new(), + expressions, + named_expressions: crate::NamedExpressions::default(), + body: crate::Block::default(), + }; + + let mut typifier = Typifier::default(); + let mut stmt_ctx = StatementContext { + local_table: &mut local_table, + globals: ctx.globals, + ast_expressions: ctx.ast_expressions, + const_typifier: ctx.const_typifier, + typifier: &mut typifier, + function: &mut function, + named_expressions: &mut named_expressions, + types: ctx.types, + module: ctx.module, + expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(), + }; + let mut body = self.block(&f.body, false, &mut stmt_ctx)?; + ensure_block_returns(&mut body); + + function.body = body; + function.named_expressions = named_expressions + .into_iter() + .map(|(key, (name, _))| (key, name)) + .collect(); + + if let Some(ref entry) = f.entry_point { + let workgroup_size = if let Some(workgroup_size) = entry.workgroup_size { + // TODO: replace with try_map once stabilized + let mut workgroup_size_out = [1; 3]; + for (i, size) in workgroup_size.into_iter().enumerate() { + if let Some(size_expr) = size { + workgroup_size_out[i] = self.const_u32(size_expr, &mut ctx.as_const())?.0; + } + } + workgroup_size_out + } else { + [0; 3] + }; + + ctx.module.entry_points.push(crate::EntryPoint { + name: f.name.name.to_string(), + stage: entry.stage, + early_depth_test: entry.early_depth_test, + workgroup_size, + function, + }); + Ok(LoweredGlobalDecl::EntryPoint) + } else { + let handle = ctx.module.functions.append(function, span); + Ok(LoweredGlobalDecl::Function(handle)) + } + } + + fn block( + &mut self, + b: &ast::Block<'source>, + is_inside_loop: bool, + ctx: &mut StatementContext<'source, '_, '_>, + ) -> Result<crate::Block, Error<'source>> { + let mut block = crate::Block::default(); + + for stmt in b.stmts.iter() { + self.statement(stmt, &mut block, is_inside_loop, ctx)?; + } + + Ok(block) + } + + fn statement( + &mut self, + stmt: &ast::Statement<'source>, + block: &mut crate::Block, + is_inside_loop: bool, + ctx: &mut StatementContext<'source, '_, '_>, + ) -> Result<(), Error<'source>> { + let out = match stmt.kind { + ast::StatementKind::Block(ref block) => { + let block = self.block(block, is_inside_loop, ctx)?; + crate::Statement::Block(block) + } + ast::StatementKind::LocalDecl(ref decl) => match *decl { + ast::LocalDecl::Let(ref l) => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let value = + self.expression(l.init, &mut ctx.as_expression(block, &mut emitter))?; + + // The WGSL spec says that any expression that refers to a + // `let`-bound variable is not a const expression. This + // affects when errors must be reported, so we can't even + // treat suitable `let` bindings as constant as an + // optimization. + ctx.expression_constness.force_non_const(value); + + let explicit_ty = + l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global())) + .transpose()?; + + if let Some(ty) = explicit_ty { + let mut ctx = ctx.as_expression(block, &mut emitter); + let init_ty = ctx.register_type(value)?; + if !ctx.module.types[ty] + .inner + .equivalent(&ctx.module.types[init_ty].inner, &ctx.module.types) + { + let gctx = &ctx.module.to_ctx(); + return Err(Error::InitializationTypeMismatch { + name: l.name.span, + expected: ty.to_wgsl(gctx), + got: init_ty.to_wgsl(gctx), + }); + } + } + + block.extend(emitter.finish(&ctx.function.expressions)); + ctx.local_table.insert(l.handle, Typed::Plain(value)); + ctx.named_expressions + .insert(value, (l.name.name.to_string(), l.name.span)); + + return Ok(()); + } + ast::LocalDecl::Var(ref v) => { + let explicit_ty = + v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global())) + .transpose()?; + + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + let mut ectx = ctx.as_expression(block, &mut emitter); + + let ty; + let initializer; + match (v.init, explicit_ty) { + (Some(init), Some(explicit_ty)) => { + let init = self.expression_for_abstract(init, &mut ectx)?; + let ty_res = crate::proc::TypeResolution::Handle(explicit_ty); + let init = ectx + .try_automatic_conversions(init, &ty_res, v.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: v.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + ty = explicit_ty; + initializer = Some(init); + } + (Some(init), None) => { + let concretized = self.expression(init, &mut ectx)?; + ty = ectx.register_type(concretized)?; + initializer = Some(concretized); + } + (None, Some(explicit_ty)) => { + ty = explicit_ty; + initializer = None; + } + (None, None) => return Err(Error::MissingType(v.name.span)), + } + + let (const_initializer, initializer) = { + match initializer { + Some(init) => { + // It's not correct to hoist the initializer up + // to the top of the function if: + // - the initialization is inside a loop, and should + // take place on every iteration, or + // - the initialization is not a constant + // expression, so its value depends on the + // state at the point of initialization. + if is_inside_loop || !ctx.expression_constness.is_const(init) { + (None, Some(init)) + } else { + (Some(init), None) + } + } + None => (None, None), + } + }; + + let var = ctx.function.local_variables.append( + crate::LocalVariable { + name: Some(v.name.name.to_string()), + ty, + init: const_initializer, + }, + stmt.span, + ); + + let handle = ctx.as_expression(block, &mut emitter).interrupt_emitter( + crate::Expression::LocalVariable(var), + Span::UNDEFINED, + )?; + block.extend(emitter.finish(&ctx.function.expressions)); + ctx.local_table.insert(v.handle, Typed::Reference(handle)); + + match initializer { + Some(initializer) => crate::Statement::Store { + pointer: handle, + value: initializer, + }, + None => return Ok(()), + } + } + }, + ast::StatementKind::If { + condition, + ref accept, + ref reject, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let condition = + self.expression(condition, &mut ctx.as_expression(block, &mut emitter))?; + block.extend(emitter.finish(&ctx.function.expressions)); + + let accept = self.block(accept, is_inside_loop, ctx)?; + let reject = self.block(reject, is_inside_loop, ctx)?; + + crate::Statement::If { + condition, + accept, + reject, + } + } + ast::StatementKind::Switch { + selector, + ref cases, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let mut ectx = ctx.as_expression(block, &mut emitter); + let selector = self.expression(selector, &mut ectx)?; + + let uint = + resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint); + block.extend(emitter.finish(&ctx.function.expressions)); + + let cases = cases + .iter() + .map(|case| { + Ok(crate::SwitchCase { + value: match case.value { + ast::SwitchValue::Expr(expr) => { + let span = ctx.ast_expressions.get_span(expr); + let expr = + self.expression(expr, &mut ctx.as_global().as_const())?; + match ctx.module.to_ctx().eval_expr_to_literal(expr) { + Some(crate::Literal::I32(value)) if !uint => { + crate::SwitchValue::I32(value) + } + Some(crate::Literal::U32(value)) if uint => { + crate::SwitchValue::U32(value) + } + _ => { + return Err(Error::InvalidSwitchValue { uint, span }); + } + } + } + ast::SwitchValue::Default => crate::SwitchValue::Default, + }, + body: self.block(&case.body, is_inside_loop, ctx)?, + fall_through: case.fall_through, + }) + }) + .collect::<Result<_, _>>()?; + + crate::Statement::Switch { selector, cases } + } + ast::StatementKind::Loop { + ref body, + ref continuing, + break_if, + } => { + let body = self.block(body, true, ctx)?; + let mut continuing = self.block(continuing, true, ctx)?; + + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + let break_if = break_if + .map(|expr| { + self.expression(expr, &mut ctx.as_expression(&mut continuing, &mut emitter)) + }) + .transpose()?; + continuing.extend(emitter.finish(&ctx.function.expressions)); + + crate::Statement::Loop { + body, + continuing, + break_if, + } + } + ast::StatementKind::Break => crate::Statement::Break, + ast::StatementKind::Continue => crate::Statement::Continue, + ast::StatementKind::Return { value } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let value = value + .map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter))) + .transpose()?; + block.extend(emitter.finish(&ctx.function.expressions)); + + crate::Statement::Return { value } + } + ast::StatementKind::Kill => crate::Statement::Kill, + ast::StatementKind::Call { + ref function, + ref arguments, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let _ = self.call( + stmt.span, + function, + arguments, + &mut ctx.as_expression(block, &mut emitter), + )?; + block.extend(emitter.finish(&ctx.function.expressions)); + return Ok(()); + } + ast::StatementKind::Assign { + target: ast_target, + op, + value, + } => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let target = self.expression_for_reference( + ast_target, + &mut ctx.as_expression(block, &mut emitter), + )?; + let mut value = + self.expression(value, &mut ctx.as_expression(block, &mut emitter))?; + + let target_handle = match target { + Typed::Reference(handle) => handle, + Typed::Plain(handle) => { + let ty = ctx.invalid_assignment_type(handle); + return Err(Error::InvalidAssignment { + span: ctx.ast_expressions.get_span(ast_target), + ty, + }); + } + }; + + let value = match op { + Some(op) => { + let mut ctx = ctx.as_expression(block, &mut emitter); + let mut left = ctx.apply_load_rule(target)?; + ctx.binary_op_splat(op, &mut left, &mut value)?; + ctx.append_expression( + crate::Expression::Binary { + op, + left, + right: value, + }, + stmt.span, + )? + } + None => value, + }; + block.extend(emitter.finish(&ctx.function.expressions)); + + crate::Statement::Store { + pointer: target_handle, + value, + } + } + ast::StatementKind::Increment(value) | ast::StatementKind::Decrement(value) => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let op = match stmt.kind { + ast::StatementKind::Increment(_) => crate::BinaryOperator::Add, + ast::StatementKind::Decrement(_) => crate::BinaryOperator::Subtract, + _ => unreachable!(), + }; + + let value_span = ctx.ast_expressions.get_span(value); + let target = self + .expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?; + let target_handle = match target { + Typed::Reference(handle) => handle, + Typed::Plain(_) => return Err(Error::BadIncrDecrReferenceType(value_span)), + }; + + let mut ectx = ctx.as_expression(block, &mut emitter); + let scalar = match *resolve_inner!(ectx, target_handle) { + crate::TypeInner::ValuePointer { + size: None, scalar, .. + } => scalar, + crate::TypeInner::Pointer { base, .. } => match ectx.module.types[base].inner { + crate::TypeInner::Scalar(scalar) => scalar, + _ => return Err(Error::BadIncrDecrReferenceType(value_span)), + }, + _ => return Err(Error::BadIncrDecrReferenceType(value_span)), + }; + let literal = match scalar.kind { + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + crate::Literal::one(scalar) + .ok_or(Error::BadIncrDecrReferenceType(value_span))? + } + _ => return Err(Error::BadIncrDecrReferenceType(value_span)), + }; + + let right = + ectx.interrupt_emitter(crate::Expression::Literal(literal), Span::UNDEFINED)?; + let rctx = ectx.runtime_expression_ctx(stmt.span)?; + let left = rctx.function.expressions.append( + crate::Expression::Load { + pointer: target_handle, + }, + value_span, + ); + let value = rctx + .function + .expressions + .append(crate::Expression::Binary { op, left, right }, stmt.span); + + block.extend(emitter.finish(&ctx.function.expressions)); + crate::Statement::Store { + pointer: target_handle, + value, + } + } + ast::StatementKind::Ignore(expr) => { + let mut emitter = Emitter::default(); + emitter.start(&ctx.function.expressions); + + let _ = self.expression(expr, &mut ctx.as_expression(block, &mut emitter))?; + block.extend(emitter.finish(&ctx.function.expressions)); + return Ok(()); + } + }; + + block.push(out, stmt.span); + + Ok(()) + } + + /// Lower `expr` and apply the Load Rule if possible. + /// + /// For the time being, this concretizes abstract values, to support + /// consumers that haven't been adapted to consume them yet. Consumers + /// prepared for abstract values can call [`expression_for_abstract`]. + /// + /// [`expression_for_abstract`]: Lowerer::expression_for_abstract + fn expression( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let expr = self.expression_for_abstract(expr, ctx)?; + ctx.concretize(expr) + } + + fn expression_for_abstract( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let expr = self.expression_for_reference(expr, ctx)?; + ctx.apply_load_rule(expr) + } + + fn expression_for_reference( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Typed<Handle<crate::Expression>>, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let expr = &ctx.ast_expressions[expr]; + + let expr: Typed<crate::Expression> = match *expr { + ast::Expression::Literal(literal) => { + let literal = match literal { + ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f), + ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i), + ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u), + ast::Literal::Number(Number::F64(f)) => crate::Literal::F64(f), + ast::Literal::Number(Number::AbstractInt(i)) => crate::Literal::AbstractInt(i), + ast::Literal::Number(Number::AbstractFloat(f)) => { + crate::Literal::AbstractFloat(f) + } + ast::Literal::Bool(b) => crate::Literal::Bool(b), + }; + let handle = ctx.interrupt_emitter(crate::Expression::Literal(literal), span)?; + return Ok(Typed::Plain(handle)); + } + ast::Expression::Ident(ast::IdentExpr::Local(local)) => { + let rctx = ctx.runtime_expression_ctx(span)?; + return Ok(rctx.local_table[&local]); + } + ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => { + let global = ctx + .globals + .get(name) + .ok_or(Error::UnknownIdent(span, name))?; + let expr = match *global { + LoweredGlobalDecl::Var(handle) => { + let expr = crate::Expression::GlobalVariable(handle); + match ctx.module.global_variables[handle].space { + crate::AddressSpace::Handle => Typed::Plain(expr), + _ => Typed::Reference(expr), + } + } + LoweredGlobalDecl::Const(handle) => { + Typed::Plain(crate::Expression::Constant(handle)) + } + _ => { + return Err(Error::Unexpected(span, ExpectedToken::Variable)); + } + }; + + return expr.try_map(|handle| ctx.interrupt_emitter(handle, span)); + } + ast::Expression::Construct { + ref ty, + ty_span, + ref components, + } => { + let handle = self.construct(span, ty, ty_span, components, ctx)?; + return Ok(Typed::Plain(handle)); + } + ast::Expression::Unary { op, expr } => { + let expr = self.expression_for_abstract(expr, ctx)?; + Typed::Plain(crate::Expression::Unary { op, expr }) + } + ast::Expression::AddrOf(expr) => { + // The `&` operator simply converts a reference to a pointer. And since a + // reference is required, the Load Rule is not applied. + match self.expression_for_reference(expr, ctx)? { + Typed::Reference(handle) => { + // No code is generated. We just declare the reference a pointer now. + return Ok(Typed::Plain(handle)); + } + Typed::Plain(_) => { + return Err(Error::NotReference("the operand of the `&` operator", span)); + } + } + } + ast::Expression::Deref(expr) => { + // The pointer we dereference must be loaded. + let pointer = self.expression(expr, ctx)?; + + if resolve_inner!(ctx, pointer).pointer_space().is_none() { + return Err(Error::NotPointer(span)); + } + + // No code is generated. We just declare the pointer a reference now. + return Ok(Typed::Reference(pointer)); + } + ast::Expression::Binary { op, left, right } => { + self.binary(op, left, right, span, ctx)? + } + ast::Expression::Call { + ref function, + ref arguments, + } => { + let handle = self + .call(span, function, arguments, ctx)? + .ok_or(Error::FunctionReturnsVoid(function.span))?; + return Ok(Typed::Plain(handle)); + } + ast::Expression::Index { base, index } => { + let lowered_base = self.expression_for_reference(base, ctx)?; + let index = self.expression(index, ctx)?; + + if let Typed::Plain(handle) = lowered_base { + if resolve_inner!(ctx, handle).pointer_space().is_some() { + return Err(Error::Pointer( + "the value indexed by a `[]` subscripting expression", + ctx.ast_expressions.get_span(base), + )); + } + } + + lowered_base.map(|base| match ctx.const_access(index) { + Some(index) => crate::Expression::AccessIndex { base, index }, + None => crate::Expression::Access { base, index }, + }) + } + ast::Expression::Member { base, ref field } => { + let lowered_base = self.expression_for_reference(base, ctx)?; + + let temp_inner; + let composite_type: &crate::TypeInner = match lowered_base { + Typed::Reference(handle) => { + let inner = resolve_inner!(ctx, handle); + match *inner { + crate::TypeInner::Pointer { base, .. } => &ctx.module.types[base].inner, + crate::TypeInner::ValuePointer { + size: None, scalar, .. + } => { + temp_inner = crate::TypeInner::Scalar(scalar); + &temp_inner + } + crate::TypeInner::ValuePointer { + size: Some(size), + scalar, + .. + } => { + temp_inner = crate::TypeInner::Vector { size, scalar }; + &temp_inner + } + _ => unreachable!( + "In Typed::Reference(handle), handle must be a Naga pointer" + ), + } + } + + Typed::Plain(handle) => { + let inner = resolve_inner!(ctx, handle); + if let crate::TypeInner::Pointer { .. } + | crate::TypeInner::ValuePointer { .. } = *inner + { + return Err(Error::Pointer( + "the value accessed by a `.member` expression", + ctx.ast_expressions.get_span(base), + )); + } + inner + } + }; + + let access = match *composite_type { + crate::TypeInner::Struct { ref members, .. } => { + let index = members + .iter() + .position(|m| m.name.as_deref() == Some(field.name)) + .ok_or(Error::BadAccessor(field.span))? + as u32; + + lowered_base.map(|base| crate::Expression::AccessIndex { base, index }) + } + crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => { + match Components::new(field.name, field.span)? { + Components::Swizzle { size, pattern } => { + // Swizzles aren't allowed on matrices, but + // validation will catch that. + Typed::Plain(crate::Expression::Swizzle { + size, + vector: ctx.apply_load_rule(lowered_base)?, + pattern, + }) + } + Components::Single(index) => lowered_base + .map(|base| crate::Expression::AccessIndex { base, index }), + } + } + _ => return Err(Error::BadAccessor(field.span)), + }; + + access + } + ast::Expression::Bitcast { expr, to, ty_span } => { + let expr = self.expression(expr, ctx)?; + let to_resolved = self.resolve_ast_type(to, &mut ctx.as_global())?; + + let element_scalar = match ctx.module.types[to_resolved].inner { + crate::TypeInner::Scalar(scalar) => scalar, + crate::TypeInner::Vector { scalar, .. } => scalar, + _ => { + let ty = resolve!(ctx, expr); + let gctx = &ctx.module.to_ctx(); + return Err(Error::BadTypeCast { + from_type: ty.to_wgsl(gctx), + span: ty_span, + to_type: to_resolved.to_wgsl(gctx), + }); + } + }; + + Typed::Plain(crate::Expression::As { + expr, + kind: element_scalar.kind, + convert: None, + }) + } + }; + + expr.try_map(|handle| ctx.append_expression(handle, span)) + } + + fn binary( + &mut self, + op: crate::BinaryOperator, + left: Handle<ast::Expression<'source>>, + right: Handle<ast::Expression<'source>>, + span: Span, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Typed<crate::Expression>, Error<'source>> { + // Load both operands. + let mut left = self.expression_for_abstract(left, ctx)?; + let mut right = self.expression_for_abstract(right, ctx)?; + + // Convert `scalar op vector` to `vector op vector` by introducing + // `Splat` expressions. + ctx.binary_op_splat(op, &mut left, &mut right)?; + + // Apply automatic conversions. + match op { + // Shift operators require the right operand to be `u32` or + // `vecN<u32>`. We can let the validator sort out vector length + // issues, but the right operand must be, or convert to, a u32 leaf + // scalar. + crate::BinaryOperator::ShiftLeft | crate::BinaryOperator::ShiftRight => { + right = + ctx.try_automatic_conversion_for_leaf_scalar(right, crate::Scalar::U32, span)?; + } + + // All other operators follow the same pattern: reconcile the + // scalar leaf types. If there's no reconciliation possible, + // leave the expressions as they are: validation will report the + // problem. + _ => { + ctx.grow_types(left)?; + ctx.grow_types(right)?; + if let Ok(consensus_scalar) = + ctx.automatic_conversion_consensus([left, right].iter()) + { + ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?; + ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?; + } + } + } + + Ok(Typed::Plain(crate::Expression::Binary { op, left, right })) + } + + /// Generate Naga IR for call expressions and statements, and type + /// constructor expressions. + /// + /// The "function" being called is simply an `Ident` that we know refers to + /// some module-scope definition. + /// + /// - If it is the name of a type, then the expression is a type constructor + /// expression: either constructing a value from components, a conversion + /// expression, or a zero value expression. + /// + /// - If it is the name of a function, then we're generating a [`Call`] + /// statement. We may be in the midst of generating code for an + /// expression, in which case we must generate an `Emit` statement to + /// force evaluation of the IR expressions we've generated so far, add the + /// `Call` statement to the current block, and then resume generating + /// expressions. + /// + /// [`Call`]: crate::Statement::Call + fn call( + &mut self, + span: Span, + function: &ast::Ident<'source>, + arguments: &[Handle<ast::Expression<'source>>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Option<Handle<crate::Expression>>, Error<'source>> { + match ctx.globals.get(function.name) { + Some(&LoweredGlobalDecl::Type(ty)) => { + let handle = self.construct( + span, + &ast::ConstructorType::Type(ty), + function.span, + arguments, + ctx, + )?; + Ok(Some(handle)) + } + Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => { + Err(Error::Unexpected(function.span, ExpectedToken::Function)) + } + Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)), + Some(&LoweredGlobalDecl::Function(function)) => { + let arguments = arguments + .iter() + .map(|&arg| self.expression(arg, ctx)) + .collect::<Result<Vec<_>, _>>()?; + + let has_result = ctx.module.functions[function].result.is_some(); + let rctx = ctx.runtime_expression_ctx(span)?; + // we need to always do this before a fn call since all arguments need to be emitted before the fn call + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + let result = has_result.then(|| { + rctx.function + .expressions + .append(crate::Expression::CallResult(function), span) + }); + rctx.emitter.start(&rctx.function.expressions); + rctx.block.push( + crate::Statement::Call { + function, + arguments, + result, + }, + span, + ); + + Ok(result) + } + None => { + let span = function.span; + let expr = if let Some(fun) = conv::map_relational_fun(function.name) { + let mut args = ctx.prepare_args(arguments, 1, span); + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + // Check for no-op all(bool) and any(bool): + let argument_unmodified = matches!( + fun, + crate::RelationalFunction::All | crate::RelationalFunction::Any + ) && { + matches!( + resolve_inner!(ctx, argument), + &crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + .. + }) + ) + }; + + if argument_unmodified { + return Ok(Some(argument)); + } else { + crate::Expression::Relational { fun, argument } + } + } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::Derivative { axis, ctrl, expr } + } else if let Some(fun) = conv::map_standard_fun(function.name) { + let expected = fun.argument_count() as _; + let mut args = ctx.prepare_args(arguments, expected, span); + + let arg = self.expression(args.next()?, ctx)?; + let arg1 = args + .next() + .map(|x| self.expression(x, ctx)) + .ok() + .transpose()?; + let arg2 = args + .next() + .map(|x| self.expression(x, ctx)) + .ok() + .transpose()?; + let arg3 = args + .next() + .map(|x| self.expression(x, ctx)) + .ok() + .transpose()?; + + args.finish()?; + + if fun == crate::MathFunction::Modf || fun == crate::MathFunction::Frexp { + if let Some((size, width)) = match *resolve_inner!(ctx, arg) { + crate::TypeInner::Scalar(crate::Scalar { width, .. }) => { + Some((None, width)) + } + crate::TypeInner::Vector { + size, + scalar: crate::Scalar { width, .. }, + .. + } => Some((Some(size), width)), + _ => None, + } { + ctx.module.generate_predeclared_type( + if fun == crate::MathFunction::Modf { + crate::PredeclaredType::ModfResult { size, width } + } else { + crate::PredeclaredType::FrexpResult { size, width } + }, + ); + } + } + + crate::Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } + } else if let Some(fun) = Texture::map(function.name) { + self.texture_sample_helper(fun, arguments, span, ctx)? + } else { + match function.name { + "select" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let reject = self.expression(args.next()?, ctx)?; + let accept = self.expression(args.next()?, ctx)?; + let condition = self.expression(args.next()?, ctx)?; + + args.finish()?; + + crate::Expression::Select { + reject, + accept, + condition, + } + } + "arrayLength" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ArrayLength(expr) + } + "atomicLoad" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let pointer = self.atomic_pointer(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::Load { pointer } + } + "atomicStore" => { + let mut args = ctx.prepare_args(arguments, 2, span); + let pointer = self.atomic_pointer(args.next()?, ctx)?; + let value = self.expression(args.next()?, ctx)?; + args.finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + rctx.block + .push(crate::Statement::Store { pointer, value }, span); + return Ok(None); + } + "atomicAdd" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Add, + arguments, + ctx, + )?)) + } + "atomicSub" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Subtract, + arguments, + ctx, + )?)) + } + "atomicAnd" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::And, + arguments, + ctx, + )?)) + } + "atomicOr" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::InclusiveOr, + arguments, + ctx, + )?)) + } + "atomicXor" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::ExclusiveOr, + arguments, + ctx, + )?)) + } + "atomicMin" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Min, + arguments, + ctx, + )?)) + } + "atomicMax" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Max, + arguments, + ctx, + )?)) + } + "atomicExchange" => { + return Ok(Some(self.atomic_helper( + span, + crate::AtomicFunction::Exchange { compare: None }, + arguments, + ctx, + )?)) + } + "atomicCompareExchangeWeak" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let pointer = self.atomic_pointer(args.next()?, ctx)?; + + let compare = self.expression(args.next()?, ctx)?; + + let value = args.next()?; + let value_span = ctx.ast_expressions.get_span(value); + let value = self.expression(value, ctx)?; + + args.finish()?; + + let expression = match *resolve_inner!(ctx, value) { + crate::TypeInner::Scalar(scalar) => { + crate::Expression::AtomicResult { + ty: ctx.module.generate_predeclared_type( + crate::PredeclaredType::AtomicCompareExchangeWeakResult( + scalar, + ), + ), + comparison: true, + } + } + _ => return Err(Error::InvalidAtomicOperandType(value_span)), + }; + + let result = ctx.interrupt_emitter(expression, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::Atomic { + pointer, + fun: crate::AtomicFunction::Exchange { + compare: Some(compare), + }, + value, + result, + }, + span, + ); + return Ok(Some(result)); + } + "storageBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::STORAGE), span); + return Ok(None); + } + "workgroupBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span); + return Ok(None); + } + "workgroupUniformLoad" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let expr = args.next()?; + args.finish()?; + + let pointer = self.expression(expr, ctx)?; + let result_ty = match *resolve_inner!(ctx, pointer) { + crate::TypeInner::Pointer { + base, + space: crate::AddressSpace::WorkGroup, + } => base, + ref other => { + log::error!("Type {other:?} passed to workgroupUniformLoad"); + let span = ctx.ast_expressions.get_span(expr); + return Err(Error::InvalidWorkGroupUniformLoad(span)); + } + }; + let result = ctx.interrupt_emitter( + crate::Expression::WorkGroupUniformLoadResult { ty: result_ty }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::WorkGroupUniformLoad { pointer, result }, + span, + ); + + return Ok(Some(result)); + } + "textureStore" => { + let mut args = ctx.prepare_args(arguments, 3, span); + + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = self.expression(image, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (_, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let value = self.expression(args.next()?, ctx)?; + + args.finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + let stmt = crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + }; + rctx.block.push(stmt, span); + return Ok(None); + } + "textureLoad" => { + let mut args = ctx.prepare_args(arguments, 2, span); + + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = self.expression(image, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (class, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let level = class + .is_mipmapped() + .then(|| { + args.min_args += 1; + self.expression(args.next()?, ctx) + }) + .transpose()?; + + let sample = class + .is_multisampled() + .then(|| self.expression(args.next()?, ctx)) + .transpose()?; + + args.finish()?; + + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + level, + sample, + } + } + "textureDimensions" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + let level = args + .next() + .map(|arg| self.expression(arg, ctx)) + .ok() + .transpose()?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::Size { level }, + } + } + "textureNumLevels" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumLevels, + } + } + "textureNumLayers" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumLayers, + } + } + "textureNumSamples" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let image = self.expression(args.next()?, ctx)?; + args.finish()?; + + crate::Expression::ImageQuery { + image, + query: crate::ImageQuery::NumSamples, + } + } + "rayQueryInitialize" => { + let mut args = ctx.prepare_args(arguments, 3, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + let acceleration_structure = self.expression(args.next()?, ctx)?; + let descriptor = self.expression(args.next()?, ctx)?; + args.finish()?; + + let _ = ctx.module.generate_ray_desc_type(); + let fun = crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + }; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.emitter.start(&rctx.function.expressions); + rctx.block + .push(crate::Statement::RayQuery { query, fun }, span); + return Ok(None); + } + "rayQueryProceed" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let result = ctx.interrupt_emitter( + crate::Expression::RayQueryProceedResult, + span, + )?; + let fun = crate::RayQueryFunction::Proceed { result }; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::RayQuery { query, fun }, span); + return Ok(Some(result)); + } + "rayQueryGetCommittedIntersection" => { + let mut args = ctx.prepare_args(arguments, 1, span); + let query = self.ray_query_pointer(args.next()?, ctx)?; + args.finish()?; + + let _ = ctx.module.generate_ray_intersection_type(); + + crate::Expression::RayQueryGetIntersection { + query, + committed: true, + } + } + "RayDesc" => { + let ty = ctx.module.generate_ray_desc_type(); + let handle = self.construct( + span, + &ast::ConstructorType::Type(ty), + function.span, + arguments, + ctx, + )?; + return Ok(Some(handle)); + } + _ => return Err(Error::UnknownIdent(function.span, function.name)), + } + }; + + let expr = ctx.append_expression(expr, span)?; + Ok(Some(expr)) + } + } + } + + fn atomic_pointer( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let pointer = self.expression(expr, ctx)?; + + match *resolve_inner!(ctx, pointer) { + crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { + crate::TypeInner::Atomic { .. } => Ok(pointer), + ref other => { + log::error!("Pointer type to {:?} passed to atomic op", other); + Err(Error::InvalidAtomicPointer(span)) + } + }, + ref other => { + log::error!("Type {:?} passed to atomic op", other); + Err(Error::InvalidAtomicPointer(span)) + } + } + } + + fn atomic_helper( + &mut self, + span: Span, + fun: crate::AtomicFunction, + args: &[Handle<ast::Expression<'source>>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let mut args = ctx.prepare_args(args, 2, span); + + let pointer = self.atomic_pointer(args.next()?, ctx)?; + + let value = args.next()?; + let value = self.expression(value, ctx)?; + let ty = ctx.register_type(value)?; + + args.finish()?; + + let result = ctx.interrupt_emitter( + crate::Expression::AtomicResult { + ty, + comparison: false, + }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::Atomic { + pointer, + fun, + value, + result, + }, + span, + ); + Ok(result) + } + + fn texture_sample_helper( + &mut self, + fun: Texture, + args: &[Handle<ast::Expression<'source>>], + span: Span, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<crate::Expression, Error<'source>> { + let mut args = ctx.prepare_args(args, fun.min_argument_count(), span); + + fn get_image_and_span<'source>( + lowerer: &mut Lowerer<'source, '_>, + args: &mut ArgumentContext<'_, 'source>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<(Handle<crate::Expression>, Span), Error<'source>> { + let image = args.next()?; + let image_span = ctx.ast_expressions.get_span(image); + let image = lowerer.expression(image, ctx)?; + Ok((image, image_span)) + } + + let (image, image_span, gather) = match fun { + Texture::Gather => { + let image_or_component = args.next()?; + let image_or_component_span = ctx.ast_expressions.get_span(image_or_component); + // Gathers from depth textures don't take an initial `component` argument. + let lowered_image_or_component = self.expression(image_or_component, ctx)?; + + match *resolve_inner!(ctx, lowered_image_or_component) { + crate::TypeInner::Image { + class: crate::ImageClass::Depth { .. }, + .. + } => ( + lowered_image_or_component, + image_or_component_span, + Some(crate::SwizzleComponent::X), + ), + _ => { + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; + ( + image, + image_span, + Some(ctx.gather_component( + lowered_image_or_component, + image_or_component_span, + span, + )?), + ) + } + } + } + Texture::GatherCompare => { + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; + (image, image_span, Some(crate::SwizzleComponent::X)) + } + + _ => { + let (image, image_span) = get_image_and_span(self, &mut args, ctx)?; + (image, image_span, None) + } + }; + + let sampler = self.expression(args.next()?, ctx)?; + + let coordinate = self.expression(args.next()?, ctx)?; + + let (_, arrayed) = ctx.image_data(image, image_span)?; + let array_index = arrayed + .then(|| self.expression(args.next()?, ctx)) + .transpose()?; + + let (level, depth_ref) = match fun { + Texture::Gather => (crate::SampleLevel::Zero, None), + Texture::GatherCompare => { + let reference = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Zero, Some(reference)) + } + + Texture::Sample => (crate::SampleLevel::Auto, None), + Texture::SampleBias => { + let bias = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Bias(bias), None) + } + Texture::SampleCompare => { + let reference = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Auto, Some(reference)) + } + Texture::SampleCompareLevel => { + let reference = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Zero, Some(reference)) + } + Texture::SampleGrad => { + let x = self.expression(args.next()?, ctx)?; + let y = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Gradient { x, y }, None) + } + Texture::SampleLevel => { + let level = self.expression(args.next()?, ctx)?; + (crate::SampleLevel::Exact(level), None) + } + }; + + let offset = args + .next() + .map(|arg| self.expression(arg, &mut ctx.as_const())) + .ok() + .transpose()?; + + args.finish()?; + + Ok(crate::Expression::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + }) + } + + fn r#struct( + &mut self, + s: &ast::Struct<'source>, + span: Span, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<Handle<crate::Type>, Error<'source>> { + let mut offset = 0; + let mut struct_alignment = Alignment::ONE; + let mut members = Vec::with_capacity(s.members.len()); + + for member in s.members.iter() { + let ty = self.resolve_ast_type(member.ty, ctx)?; + + self.layouter.update(ctx.module.to_ctx()).unwrap(); + + let member_min_size = self.layouter[ty].size; + let member_min_alignment = self.layouter[ty].alignment; + + let member_size = if let Some(size_expr) = member.size { + let (size, span) = self.const_u32(size_expr, &mut ctx.as_const())?; + if size < member_min_size { + return Err(Error::SizeAttributeTooLow(span, member_min_size)); + } else { + size + } + } else { + member_min_size + }; + + let member_alignment = if let Some(align_expr) = member.align { + let (align, span) = self.const_u32(align_expr, &mut ctx.as_const())?; + if let Some(alignment) = Alignment::new(align) { + if alignment < member_min_alignment { + return Err(Error::AlignAttributeTooLow(span, member_min_alignment)); + } else { + alignment + } + } else { + return Err(Error::NonPowerOfTwoAlignAttribute(span)); + } + } else { + member_min_alignment + }; + + let binding = self.binding(&member.binding, ty, ctx)?; + + offset = member_alignment.round_up(offset); + struct_alignment = struct_alignment.max(member_alignment); + + members.push(crate::StructMember { + name: Some(member.name.name.to_owned()), + ty, + binding, + offset, + }); + + offset += member_size; + } + + let size = struct_alignment.round_up(offset); + let inner = crate::TypeInner::Struct { + members, + span: size, + }; + + let handle = ctx.module.types.insert( + crate::Type { + name: Some(s.name.name.to_string()), + inner, + }, + span, + ); + Ok(handle) + } + + fn const_u32( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<(u32, Span), Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let expr = self.expression(expr, ctx)?; + let value = ctx + .module + .to_ctx() + .eval_expr_to_u32(expr) + .map_err(|err| match err { + crate::proc::U32EvalError::NonConst => { + Error::ExpectedConstExprConcreteIntegerScalar(span) + } + crate::proc::U32EvalError::Negative => Error::ExpectedNonNegative(span), + })?; + Ok((value, span)) + } + + fn array_size( + &mut self, + size: ast::ArraySize<'source>, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<crate::ArraySize, Error<'source>> { + Ok(match size { + ast::ArraySize::Constant(expr) => { + let span = ctx.ast_expressions.get_span(expr); + let const_expr = self.expression(expr, &mut ctx.as_const())?; + let len = + ctx.module + .to_ctx() + .eval_expr_to_u32(const_expr) + .map_err(|err| match err { + crate::proc::U32EvalError::NonConst => { + Error::ExpectedConstExprConcreteIntegerScalar(span) + } + crate::proc::U32EvalError::Negative => { + Error::ExpectedPositiveArrayLength(span) + } + })?; + let size = NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?; + crate::ArraySize::Constant(size) + } + ast::ArraySize::Dynamic => crate::ArraySize::Dynamic, + }) + } + + /// Build the Naga equivalent of a named AST type. + /// + /// Return a Naga `Handle<Type>` representing the front-end type + /// `handle`, which should be named `name`, if given. + /// + /// If `handle` refers to a type cached in [`SpecialTypes`], + /// `name` may be ignored. + /// + /// [`SpecialTypes`]: crate::SpecialTypes + fn resolve_named_ast_type( + &mut self, + handle: Handle<ast::Type<'source>>, + name: Option<String>, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<Handle<crate::Type>, Error<'source>> { + let inner = match ctx.types[handle] { + ast::Type::Scalar(scalar) => scalar.to_inner_scalar(), + ast::Type::Vector { size, scalar } => scalar.to_inner_vector(size), + ast::Type::Matrix { + rows, + columns, + width, + } => crate::TypeInner::Matrix { + columns, + rows, + scalar: crate::Scalar::float(width), + }, + ast::Type::Atomic(scalar) => scalar.to_inner_atomic(), + ast::Type::Pointer { base, space } => { + let base = self.resolve_ast_type(base, ctx)?; + crate::TypeInner::Pointer { base, space } + } + ast::Type::Array { base, size } => { + let base = self.resolve_ast_type(base, ctx)?; + let size = self.array_size(size, ctx)?; + + self.layouter.update(ctx.module.to_ctx()).unwrap(); + let stride = self.layouter[base].to_stride(); + + crate::TypeInner::Array { base, size, stride } + } + ast::Type::Image { + dim, + arrayed, + class, + } => crate::TypeInner::Image { + dim, + arrayed, + class, + }, + ast::Type::Sampler { comparison } => crate::TypeInner::Sampler { comparison }, + ast::Type::AccelerationStructure => crate::TypeInner::AccelerationStructure, + ast::Type::RayQuery => crate::TypeInner::RayQuery, + ast::Type::BindingArray { base, size } => { + let base = self.resolve_ast_type(base, ctx)?; + let size = self.array_size(size, ctx)?; + crate::TypeInner::BindingArray { base, size } + } + ast::Type::RayDesc => { + return Ok(ctx.module.generate_ray_desc_type()); + } + ast::Type::RayIntersection => { + return Ok(ctx.module.generate_ray_intersection_type()); + } + ast::Type::User(ref ident) => { + return match ctx.globals.get(ident.name) { + Some(&LoweredGlobalDecl::Type(handle)) => Ok(handle), + Some(_) => Err(Error::Unexpected(ident.span, ExpectedToken::Type)), + None => Err(Error::UnknownType(ident.span)), + } + } + }; + + Ok(ctx.ensure_type_exists(name, inner)) + } + + /// Return a Naga `Handle<Type>` representing the front-end type `handle`. + fn resolve_ast_type( + &mut self, + handle: Handle<ast::Type<'source>>, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<Handle<crate::Type>, Error<'source>> { + self.resolve_named_ast_type(handle, None, ctx) + } + + fn binding( + &mut self, + binding: &Option<ast::Binding<'source>>, + ty: Handle<crate::Type>, + ctx: &mut GlobalContext<'source, '_, '_>, + ) -> Result<Option<crate::Binding>, Error<'source>> { + Ok(match *binding { + Some(ast::Binding::BuiltIn(b)) => Some(crate::Binding::BuiltIn(b)), + Some(ast::Binding::Location { + location, + second_blend_source, + interpolation, + sampling, + }) => { + let mut binding = crate::Binding::Location { + location: self.const_u32(location, &mut ctx.as_const())?.0, + second_blend_source, + interpolation, + sampling, + }; + binding.apply_default_interpolation(&ctx.module.types[ty].inner); + Some(binding) + } + None => None, + }) + } + + fn ray_query_pointer( + &mut self, + expr: Handle<ast::Expression<'source>>, + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let span = ctx.ast_expressions.get_span(expr); + let pointer = self.expression(expr, ctx)?; + + match *resolve_inner!(ctx, pointer) { + crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner { + crate::TypeInner::RayQuery => Ok(pointer), + ref other => { + log::error!("Pointer type to {:?} passed to ray query op", other); + Err(Error::InvalidRayQueryPointer(span)) + } + }, + ref other => { + log::error!("Type {:?} passed to ray query op", other); + Err(Error::InvalidRayQueryPointer(span)) + } + } + } +} diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs new file mode 100644 index 0000000000..b6151fe1c0 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/mod.rs @@ -0,0 +1,49 @@ +/*! +Frontend for [WGSL][wgsl] (WebGPU Shading Language). + +[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html +*/ + +mod error; +mod index; +mod lower; +mod parse; +#[cfg(test)] +mod tests; +mod to_wgsl; + +use crate::front::wgsl::error::Error; +use crate::front::wgsl::parse::Parser; +use thiserror::Error; + +pub use crate::front::wgsl::error::ParseError; +use crate::front::wgsl::lower::Lowerer; +use crate::Scalar; + +pub struct Frontend { + parser: Parser, +} + +impl Frontend { + pub const fn new() -> Self { + Self { + parser: Parser::new(), + } + } + + pub fn parse(&mut self, source: &str) -> Result<crate::Module, ParseError> { + self.inner(source).map_err(|x| x.as_parse_error(source)) + } + + fn inner<'a>(&mut self, source: &'a str) -> Result<crate::Module, Error<'a>> { + let tu = self.parser.parse(source)?; + let index = index::Index::generate(&tu)?; + let module = Lowerer::new(&index).lower(&tu)?; + + Ok(module) + } +} + +pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> { + Frontend::new().parse(source) +} diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs new file mode 100644 index 0000000000..dbaac523cb --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs @@ -0,0 +1,491 @@ +use crate::front::wgsl::parse::number::Number; +use crate::front::wgsl::Scalar; +use crate::{Arena, FastIndexSet, Handle, Span}; +use std::hash::Hash; + +#[derive(Debug, Default)] +pub struct TranslationUnit<'a> { + pub decls: Arena<GlobalDecl<'a>>, + /// The common expressions arena for the entire translation unit. + /// + /// All functions, global initializers, array lengths, etc. store their + /// expressions here. We apportion these out to individual Naga + /// [`Function`]s' expression arenas at lowering time. Keeping them all in a + /// single arena simplifies handling of things like array lengths (which are + /// effectively global and thus don't clearly belong to any function) and + /// initializers (which can appear in both function-local and module-scope + /// contexts). + /// + /// [`Function`]: crate::Function + pub expressions: Arena<Expression<'a>>, + + /// Non-user-defined types, like `vec4<f32>` or `array<i32, 10>`. + /// + /// These are referred to by `Handle<ast::Type<'a>>` values. + /// User-defined types are referred to by name until lowering. + pub types: Arena<Type<'a>>, +} + +#[derive(Debug, Clone, Copy)] +pub struct Ident<'a> { + pub name: &'a str, + pub span: Span, +} + +#[derive(Debug)] +pub enum IdentExpr<'a> { + Unresolved(&'a str), + Local(Handle<Local>), +} + +/// A reference to a module-scope definition or predeclared object. +/// +/// Each [`GlobalDecl`] holds a set of these values, to be resolved to +/// specific definitions later. To support de-duplication, `Eq` and +/// `Hash` on a `Dependency` value consider only the name, not the +/// source location at which the reference occurs. +#[derive(Debug)] +pub struct Dependency<'a> { + /// The name referred to. + pub ident: &'a str, + + /// The location at which the reference to that name occurs. + pub usage: Span, +} + +impl Hash for Dependency<'_> { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + self.ident.hash(state); + } +} + +impl PartialEq for Dependency<'_> { + fn eq(&self, other: &Self) -> bool { + self.ident == other.ident + } +} + +impl Eq for Dependency<'_> {} + +/// A module-scope declaration. +#[derive(Debug)] +pub struct GlobalDecl<'a> { + pub kind: GlobalDeclKind<'a>, + + /// Names of all module-scope or predeclared objects this + /// declaration uses. + pub dependencies: FastIndexSet<Dependency<'a>>, +} + +#[derive(Debug)] +pub enum GlobalDeclKind<'a> { + Fn(Function<'a>), + Var(GlobalVariable<'a>), + Const(Const<'a>), + Struct(Struct<'a>), + Type(TypeAlias<'a>), +} + +#[derive(Debug)] +pub struct FunctionArgument<'a> { + pub name: Ident<'a>, + pub ty: Handle<Type<'a>>, + pub binding: Option<Binding<'a>>, + pub handle: Handle<Local>, +} + +#[derive(Debug)] +pub struct FunctionResult<'a> { + pub ty: Handle<Type<'a>>, + pub binding: Option<Binding<'a>>, +} + +#[derive(Debug)] +pub struct EntryPoint<'a> { + pub stage: crate::ShaderStage, + pub early_depth_test: Option<crate::EarlyDepthTest>, + pub workgroup_size: Option<[Option<Handle<Expression<'a>>>; 3]>, +} + +#[cfg(doc)] +use crate::front::wgsl::lower::{RuntimeExpressionContext, StatementContext}; + +#[derive(Debug)] +pub struct Function<'a> { + pub entry_point: Option<EntryPoint<'a>>, + pub name: Ident<'a>, + pub arguments: Vec<FunctionArgument<'a>>, + pub result: Option<FunctionResult<'a>>, + + /// Local variable and function argument arena. + /// + /// Note that the `Local` here is actually a zero-sized type. The AST keeps + /// all the detailed information about locals - names, types, etc. - in + /// [`LocalDecl`] statements. For arguments, that information is kept in + /// [`arguments`]. This `Arena`'s only role is to assign a unique `Handle` + /// to each of them, and track their definitions' spans for use in + /// diagnostics. + /// + /// In the AST, when an [`Ident`] expression refers to a local variable or + /// argument, its [`IdentExpr`] holds the referent's `Handle<Local>` in this + /// arena. + /// + /// During lowering, [`LocalDecl`] statements add entries to a per-function + /// table that maps `Handle<Local>` values to their Naga representations, + /// accessed via [`StatementContext::local_table`] and + /// [`RuntimeExpressionContext::local_table`]. This table is then consulted when + /// lowering subsequent [`Ident`] expressions. + /// + /// [`LocalDecl`]: StatementKind::LocalDecl + /// [`arguments`]: Function::arguments + /// [`Ident`]: Expression::Ident + /// [`StatementContext::local_table`]: StatementContext::local_table + /// [`RuntimeExpressionContext::local_table`]: RuntimeExpressionContext::local_table + pub locals: Arena<Local>, + + pub body: Block<'a>, +} + +#[derive(Debug)] +pub enum Binding<'a> { + BuiltIn(crate::BuiltIn), + Location { + location: Handle<Expression<'a>>, + second_blend_source: bool, + interpolation: Option<crate::Interpolation>, + sampling: Option<crate::Sampling>, + }, +} + +#[derive(Debug)] +pub struct ResourceBinding<'a> { + pub group: Handle<Expression<'a>>, + pub binding: Handle<Expression<'a>>, +} + +#[derive(Debug)] +pub struct GlobalVariable<'a> { + pub name: Ident<'a>, + pub space: crate::AddressSpace, + pub binding: Option<ResourceBinding<'a>>, + pub ty: Handle<Type<'a>>, + pub init: Option<Handle<Expression<'a>>>, +} + +#[derive(Debug)] +pub struct StructMember<'a> { + pub name: Ident<'a>, + pub ty: Handle<Type<'a>>, + pub binding: Option<Binding<'a>>, + pub align: Option<Handle<Expression<'a>>>, + pub size: Option<Handle<Expression<'a>>>, +} + +#[derive(Debug)] +pub struct Struct<'a> { + pub name: Ident<'a>, + pub members: Vec<StructMember<'a>>, +} + +#[derive(Debug)] +pub struct TypeAlias<'a> { + pub name: Ident<'a>, + pub ty: Handle<Type<'a>>, +} + +#[derive(Debug)] +pub struct Const<'a> { + pub name: Ident<'a>, + pub ty: Option<Handle<Type<'a>>>, + pub init: Handle<Expression<'a>>, +} + +/// The size of an [`Array`] or [`BindingArray`]. +/// +/// [`Array`]: Type::Array +/// [`BindingArray`]: Type::BindingArray +#[derive(Debug, Copy, Clone)] +pub enum ArraySize<'a> { + /// The length as a constant expression. + Constant(Handle<Expression<'a>>), + Dynamic, +} + +#[derive(Debug)] +pub enum Type<'a> { + Scalar(Scalar), + Vector { + size: crate::VectorSize, + scalar: Scalar, + }, + Matrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + width: crate::Bytes, + }, + Atomic(Scalar), + Pointer { + base: Handle<Type<'a>>, + space: crate::AddressSpace, + }, + Array { + base: Handle<Type<'a>>, + size: ArraySize<'a>, + }, + Image { + dim: crate::ImageDimension, + arrayed: bool, + class: crate::ImageClass, + }, + Sampler { + comparison: bool, + }, + AccelerationStructure, + RayQuery, + RayDesc, + RayIntersection, + BindingArray { + base: Handle<Type<'a>>, + size: ArraySize<'a>, + }, + + /// A user-defined type, like a struct or a type alias. + User(Ident<'a>), +} + +#[derive(Debug, Default)] +pub struct Block<'a> { + pub stmts: Vec<Statement<'a>>, +} + +#[derive(Debug)] +pub struct Statement<'a> { + pub kind: StatementKind<'a>, + pub span: Span, +} + +#[derive(Debug)] +pub enum StatementKind<'a> { + LocalDecl(LocalDecl<'a>), + Block(Block<'a>), + If { + condition: Handle<Expression<'a>>, + accept: Block<'a>, + reject: Block<'a>, + }, + Switch { + selector: Handle<Expression<'a>>, + cases: Vec<SwitchCase<'a>>, + }, + Loop { + body: Block<'a>, + continuing: Block<'a>, + break_if: Option<Handle<Expression<'a>>>, + }, + Break, + Continue, + Return { + value: Option<Handle<Expression<'a>>>, + }, + Kill, + Call { + function: Ident<'a>, + arguments: Vec<Handle<Expression<'a>>>, + }, + Assign { + target: Handle<Expression<'a>>, + op: Option<crate::BinaryOperator>, + value: Handle<Expression<'a>>, + }, + Increment(Handle<Expression<'a>>), + Decrement(Handle<Expression<'a>>), + Ignore(Handle<Expression<'a>>), +} + +#[derive(Debug)] +pub enum SwitchValue<'a> { + Expr(Handle<Expression<'a>>), + Default, +} + +#[derive(Debug)] +pub struct SwitchCase<'a> { + pub value: SwitchValue<'a>, + pub body: Block<'a>, + pub fall_through: bool, +} + +/// A type at the head of a [`Construct`] expression. +/// +/// WGSL has two types of [`type constructor expressions`]: +/// +/// - Those that fully specify the type being constructed, like +/// `vec3<f32>(x,y,z)`, which obviously constructs a `vec3<f32>`. +/// +/// - Those that leave the component type of the composite being constructed +/// implicit, to be inferred from the argument types, like `vec3(x,y,z)`, +/// which constructs a `vec3<T>` where `T` is the type of `x`, `y`, and `z`. +/// +/// This enum represents the head type of both cases. The `PartialFoo` variants +/// represent the second case, where the component type is implicit. +/// +/// This does not cover structs or types referred to by type aliases. See the +/// documentation for [`Construct`] and [`Call`] expressions for details. +/// +/// [`Construct`]: Expression::Construct +/// [`type constructor expressions`]: https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr +/// [`Call`]: Expression::Call +#[derive(Debug)] +pub enum ConstructorType<'a> { + /// A scalar type or conversion: `f32(1)`. + Scalar(Scalar), + + /// A vector construction whose component type is inferred from the + /// argument: `vec3(1.0)`. + PartialVector { size: crate::VectorSize }, + + /// A vector construction whose component type is written out: + /// `vec3<f32>(1.0)`. + Vector { + size: crate::VectorSize, + scalar: Scalar, + }, + + /// A matrix construction whose component type is inferred from the + /// argument: `mat2x2(1,2,3,4)`. + PartialMatrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + }, + + /// A matrix construction whose component type is written out: + /// `mat2x2<f32>(1,2,3,4)`. + Matrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + width: crate::Bytes, + }, + + /// An array whose component type and size are inferred from the arguments: + /// `array(3,4,5)`. + PartialArray, + + /// An array whose component type and size are written out: + /// `array<u32, 4>(3,4,5)`. + Array { + base: Handle<Type<'a>>, + size: ArraySize<'a>, + }, + + /// Constructing a value of a known Naga IR type. + /// + /// This variant is produced only during lowering, when we have Naga types + /// available, never during parsing. + Type(Handle<crate::Type>), +} + +#[derive(Debug, Copy, Clone)] +pub enum Literal { + Bool(bool), + Number(Number), +} + +#[cfg(doc)] +use crate::front::wgsl::lower::Lowerer; + +#[derive(Debug)] +pub enum Expression<'a> { + Literal(Literal), + Ident(IdentExpr<'a>), + + /// A type constructor expression. + /// + /// This is only used for expressions like `KEYWORD(EXPR...)` and + /// `KEYWORD<PARAM>(EXPR...)`, where `KEYWORD` is a [type-defining keyword] like + /// `vec3`. These keywords cannot be shadowed by user definitions, so we can + /// tell that such an expression is a construction immediately. + /// + /// For ordinary identifiers, we can't tell whether an expression like + /// `IDENTIFIER(EXPR, ...)` is a construction expression or a function call + /// until we know `IDENTIFIER`'s definition, so we represent those as + /// [`Call`] expressions. + /// + /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords + /// [`Call`]: Expression::Call + Construct { + ty: ConstructorType<'a>, + ty_span: Span, + components: Vec<Handle<Expression<'a>>>, + }, + Unary { + op: crate::UnaryOperator, + expr: Handle<Expression<'a>>, + }, + AddrOf(Handle<Expression<'a>>), + Deref(Handle<Expression<'a>>), + Binary { + op: crate::BinaryOperator, + left: Handle<Expression<'a>>, + right: Handle<Expression<'a>>, + }, + + /// A function call or type constructor expression. + /// + /// We can't tell whether an expression like `IDENTIFIER(EXPR, ...)` is a + /// construction expression or a function call until we know `IDENTIFIER`'s + /// definition, so we represent everything of that form as one of these + /// expressions until lowering. At that point, [`Lowerer::call`] has + /// everything's definition in hand, and can decide whether to emit a Naga + /// [`Constant`], [`As`], [`Splat`], or [`Compose`] expression. + /// + /// [`Lowerer::call`]: Lowerer::call + /// [`Constant`]: crate::Expression::Constant + /// [`As`]: crate::Expression::As + /// [`Splat`]: crate::Expression::Splat + /// [`Compose`]: crate::Expression::Compose + Call { + function: Ident<'a>, + arguments: Vec<Handle<Expression<'a>>>, + }, + Index { + base: Handle<Expression<'a>>, + index: Handle<Expression<'a>>, + }, + Member { + base: Handle<Expression<'a>>, + field: Ident<'a>, + }, + Bitcast { + expr: Handle<Expression<'a>>, + to: Handle<Type<'a>>, + ty_span: Span, + }, +} + +#[derive(Debug)] +pub struct LocalVariable<'a> { + pub name: Ident<'a>, + pub ty: Option<Handle<Type<'a>>>, + pub init: Option<Handle<Expression<'a>>>, + pub handle: Handle<Local>, +} + +#[derive(Debug)] +pub struct Let<'a> { + pub name: Ident<'a>, + pub ty: Option<Handle<Type<'a>>>, + pub init: Handle<Expression<'a>>, + pub handle: Handle<Local>, +} + +#[derive(Debug)] +pub enum LocalDecl<'a> { + Var(LocalVariable<'a>), + Let(Let<'a>), +} + +#[derive(Debug)] +/// A placeholder for a local variable declaration. +/// +/// See [`Function::locals`] for more information. +pub struct Local; diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs new file mode 100644 index 0000000000..08f1e39285 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs @@ -0,0 +1,254 @@ +use super::Error; +use crate::front::wgsl::Scalar; +use crate::Span; + +pub fn map_address_space(word: &str, span: Span) -> Result<crate::AddressSpace, Error<'_>> { + match word { + "private" => Ok(crate::AddressSpace::Private), + "workgroup" => Ok(crate::AddressSpace::WorkGroup), + "uniform" => Ok(crate::AddressSpace::Uniform), + "storage" => Ok(crate::AddressSpace::Storage { + access: crate::StorageAccess::default(), + }), + "push_constant" => Ok(crate::AddressSpace::PushConstant), + "function" => Ok(crate::AddressSpace::Function), + _ => Err(Error::UnknownAddressSpace(span)), + } +} + +pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>> { + Ok(match word { + "position" => crate::BuiltIn::Position { invariant: false }, + // vertex + "vertex_index" => crate::BuiltIn::VertexIndex, + "instance_index" => crate::BuiltIn::InstanceIndex, + "view_index" => crate::BuiltIn::ViewIndex, + // fragment + "front_facing" => crate::BuiltIn::FrontFacing, + "frag_depth" => crate::BuiltIn::FragDepth, + "primitive_index" => crate::BuiltIn::PrimitiveIndex, + "sample_index" => crate::BuiltIn::SampleIndex, + "sample_mask" => crate::BuiltIn::SampleMask, + // compute + "global_invocation_id" => crate::BuiltIn::GlobalInvocationId, + "local_invocation_id" => crate::BuiltIn::LocalInvocationId, + "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex, + "workgroup_id" => crate::BuiltIn::WorkGroupId, + "num_workgroups" => crate::BuiltIn::NumWorkGroups, + _ => return Err(Error::UnknownBuiltin(span)), + }) +} + +pub fn map_interpolation(word: &str, span: Span) -> Result<crate::Interpolation, Error<'_>> { + match word { + "linear" => Ok(crate::Interpolation::Linear), + "flat" => Ok(crate::Interpolation::Flat), + "perspective" => Ok(crate::Interpolation::Perspective), + _ => Err(Error::UnknownAttribute(span)), + } +} + +pub fn map_sampling(word: &str, span: Span) -> Result<crate::Sampling, Error<'_>> { + match word { + "center" => Ok(crate::Sampling::Center), + "centroid" => Ok(crate::Sampling::Centroid), + "sample" => Ok(crate::Sampling::Sample), + _ => Err(Error::UnknownAttribute(span)), + } +} + +pub fn map_storage_format(word: &str, span: Span) -> Result<crate::StorageFormat, Error<'_>> { + use crate::StorageFormat as Sf; + Ok(match word { + "r8unorm" => Sf::R8Unorm, + "r8snorm" => Sf::R8Snorm, + "r8uint" => Sf::R8Uint, + "r8sint" => Sf::R8Sint, + "r16unorm" => Sf::R16Unorm, + "r16snorm" => Sf::R16Snorm, + "r16uint" => Sf::R16Uint, + "r16sint" => Sf::R16Sint, + "r16float" => Sf::R16Float, + "rg8unorm" => Sf::Rg8Unorm, + "rg8snorm" => Sf::Rg8Snorm, + "rg8uint" => Sf::Rg8Uint, + "rg8sint" => Sf::Rg8Sint, + "r32uint" => Sf::R32Uint, + "r32sint" => Sf::R32Sint, + "r32float" => Sf::R32Float, + "rg16unorm" => Sf::Rg16Unorm, + "rg16snorm" => Sf::Rg16Snorm, + "rg16uint" => Sf::Rg16Uint, + "rg16sint" => Sf::Rg16Sint, + "rg16float" => Sf::Rg16Float, + "rgba8unorm" => Sf::Rgba8Unorm, + "rgba8snorm" => Sf::Rgba8Snorm, + "rgba8uint" => Sf::Rgba8Uint, + "rgba8sint" => Sf::Rgba8Sint, + "rgb10a2uint" => Sf::Rgb10a2Uint, + "rgb10a2unorm" => Sf::Rgb10a2Unorm, + "rg11b10float" => Sf::Rg11b10Float, + "rg32uint" => Sf::Rg32Uint, + "rg32sint" => Sf::Rg32Sint, + "rg32float" => Sf::Rg32Float, + "rgba16unorm" => Sf::Rgba16Unorm, + "rgba16snorm" => Sf::Rgba16Snorm, + "rgba16uint" => Sf::Rgba16Uint, + "rgba16sint" => Sf::Rgba16Sint, + "rgba16float" => Sf::Rgba16Float, + "rgba32uint" => Sf::Rgba32Uint, + "rgba32sint" => Sf::Rgba32Sint, + "rgba32float" => Sf::Rgba32Float, + "bgra8unorm" => Sf::Bgra8Unorm, + _ => return Err(Error::UnknownStorageFormat(span)), + }) +} + +pub fn get_scalar_type(word: &str) -> Option<Scalar> { + use crate::ScalarKind as Sk; + match word { + // "f16" => Some(Scalar { kind: Sk::Float, width: 2 }), + "f32" => Some(Scalar { + kind: Sk::Float, + width: 4, + }), + "f64" => Some(Scalar { + kind: Sk::Float, + width: 8, + }), + "i32" => Some(Scalar { + kind: Sk::Sint, + width: 4, + }), + "u32" => Some(Scalar { + kind: Sk::Uint, + width: 4, + }), + "bool" => Some(Scalar { + kind: Sk::Bool, + width: crate::BOOL_WIDTH, + }), + _ => None, + } +} + +pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + match word { + "dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)), + "dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)), + "fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)), + "dpdxFine" => Some((Axis::X, Ctrl::Fine)), + "dpdyFine" => Some((Axis::Y, Ctrl::Fine)), + "fwidthFine" => Some((Axis::Width, Ctrl::Fine)), + "dpdx" => Some((Axis::X, Ctrl::None)), + "dpdy" => Some((Axis::Y, Ctrl::None)), + "fwidth" => Some((Axis::Width, Ctrl::None)), + _ => None, + } +} + +pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> { + match word { + "any" => Some(crate::RelationalFunction::Any), + "all" => Some(crate::RelationalFunction::All), + _ => None, + } +} + +pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> { + use crate::MathFunction as Mf; + Some(match word { + // comparison + "abs" => Mf::Abs, + "min" => Mf::Min, + "max" => Mf::Max, + "clamp" => Mf::Clamp, + "saturate" => Mf::Saturate, + // trigonometry + "cos" => Mf::Cos, + "cosh" => Mf::Cosh, + "sin" => Mf::Sin, + "sinh" => Mf::Sinh, + "tan" => Mf::Tan, + "tanh" => Mf::Tanh, + "acos" => Mf::Acos, + "acosh" => Mf::Acosh, + "asin" => Mf::Asin, + "asinh" => Mf::Asinh, + "atan" => Mf::Atan, + "atanh" => Mf::Atanh, + "atan2" => Mf::Atan2, + "radians" => Mf::Radians, + "degrees" => Mf::Degrees, + // decomposition + "ceil" => Mf::Ceil, + "floor" => Mf::Floor, + "round" => Mf::Round, + "fract" => Mf::Fract, + "trunc" => Mf::Trunc, + "modf" => Mf::Modf, + "frexp" => Mf::Frexp, + "ldexp" => Mf::Ldexp, + // exponent + "exp" => Mf::Exp, + "exp2" => Mf::Exp2, + "log" => Mf::Log, + "log2" => Mf::Log2, + "pow" => Mf::Pow, + // geometry + "dot" => Mf::Dot, + "cross" => Mf::Cross, + "distance" => Mf::Distance, + "length" => Mf::Length, + "normalize" => Mf::Normalize, + "faceForward" => Mf::FaceForward, + "reflect" => Mf::Reflect, + "refract" => Mf::Refract, + // computational + "sign" => Mf::Sign, + "fma" => Mf::Fma, + "mix" => Mf::Mix, + "step" => Mf::Step, + "smoothstep" => Mf::SmoothStep, + "sqrt" => Mf::Sqrt, + "inverseSqrt" => Mf::InverseSqrt, + "transpose" => Mf::Transpose, + "determinant" => Mf::Determinant, + // bits + "countTrailingZeros" => Mf::CountTrailingZeros, + "countLeadingZeros" => Mf::CountLeadingZeros, + "countOneBits" => Mf::CountOneBits, + "reverseBits" => Mf::ReverseBits, + "extractBits" => Mf::ExtractBits, + "insertBits" => Mf::InsertBits, + "firstTrailingBit" => Mf::FindLsb, + "firstLeadingBit" => Mf::FindMsb, + // data packing + "pack4x8snorm" => Mf::Pack4x8snorm, + "pack4x8unorm" => Mf::Pack4x8unorm, + "pack2x16snorm" => Mf::Pack2x16snorm, + "pack2x16unorm" => Mf::Pack2x16unorm, + "pack2x16float" => Mf::Pack2x16float, + // data unpacking + "unpack4x8snorm" => Mf::Unpack4x8snorm, + "unpack4x8unorm" => Mf::Unpack4x8unorm, + "unpack2x16snorm" => Mf::Unpack2x16snorm, + "unpack2x16unorm" => Mf::Unpack2x16unorm, + "unpack2x16float" => Mf::Unpack2x16float, + _ => return None, + }) +} + +pub fn map_conservative_depth( + word: &str, + span: Span, +) -> Result<crate::ConservativeDepth, Error<'_>> { + use crate::ConservativeDepth as Cd; + match word { + "greater_equal" => Ok(Cd::GreaterEqual), + "less_equal" => Ok(Cd::LessEqual), + "unchanged" => Ok(Cd::Unchanged), + _ => Err(Error::UnknownConservativeDepth(span)), + } +} diff --git a/third_party/rust/naga/src/front/wgsl/parse/lexer.rs b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs new file mode 100644 index 0000000000..d03a448561 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs @@ -0,0 +1,739 @@ +use super::{number::consume_number, Error, ExpectedToken}; +use crate::front::wgsl::error::NumberError; +use crate::front::wgsl::parse::{conv, Number}; +use crate::front::wgsl::Scalar; +use crate::Span; + +type TokenSpan<'a> = (Token<'a>, Span); + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Token<'a> { + Separator(char), + Paren(char), + Attribute, + Number(Result<Number, NumberError>), + Word(&'a str), + Operation(char), + LogicalOperation(char), + ShiftOperation(char), + AssignmentOperation(char), + IncrementOperation, + DecrementOperation, + Arrow, + Unknown(char), + Trivia, + End, +} + +fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) { + let pos = input.find(|c| !what(c)).unwrap_or(input.len()); + input.split_at(pos) +} + +/// Return the token at the start of `input`. +/// +/// If `generic` is `false`, then the bit shift operators `>>` or `<<` +/// are valid lookahead tokens for the current parser state (see [§3.1 +/// Parsing] in the WGSL specification). In other words: +/// +/// - If `generic` is `true`, then we are expecting an angle bracket +/// around a generic type parameter, like the `<` and `>` in +/// `vec3<f32>`, so interpret `<` and `>` as `Token::Paren` tokens, +/// even if they're part of `<<` or `>>` sequences. +/// +/// - Otherwise, interpret `<<` and `>>` as shift operators: +/// `Token::LogicalOperation` tokens. +/// +/// [§3.1 Parsing]: https://gpuweb.github.io/gpuweb/wgsl/#parsing +fn consume_token(input: &str, generic: bool) -> (Token<'_>, &str) { + let mut chars = input.chars(); + let cur = match chars.next() { + Some(c) => c, + None => return (Token::End, ""), + }; + match cur { + ':' | ';' | ',' => (Token::Separator(cur), chars.as_str()), + '.' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('0'..='9') => consume_number(input), + _ => (Token::Separator(cur), og_chars), + } + } + '@' => (Token::Attribute, chars.as_str()), + '(' | ')' | '{' | '}' | '[' | ']' => (Token::Paren(cur), chars.as_str()), + '<' | '>' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') if !generic => (Token::LogicalOperation(cur), chars.as_str()), + Some(c) if c == cur && !generic => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::ShiftOperation(cur), og_chars), + } + } + _ => (Token::Paren(cur), og_chars), + } + } + '0'..='9' => consume_number(input), + '/' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('/') => { + let _ = chars.position(is_comment_end); + (Token::Trivia, chars.as_str()) + } + Some('*') => { + let mut depth = 1; + let mut prev = None; + + for c in &mut chars { + match (prev, c) { + (Some('*'), '/') => { + prev = None; + depth -= 1; + if depth == 0 { + return (Token::Trivia, chars.as_str()); + } + } + (Some('/'), '*') => { + prev = None; + depth += 1; + } + _ => { + prev = Some(c); + } + } + } + + (Token::End, "") + } + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '-' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('>') => (Token::Arrow, chars.as_str()), + Some('-') => (Token::DecrementOperation, chars.as_str()), + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '+' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('+') => (Token::IncrementOperation, chars.as_str()), + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '*' | '%' | '^' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '~' => (Token::Operation(cur), chars.as_str()), + '=' | '!' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('=') => (Token::LogicalOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + '&' | '|' => { + let og_chars = chars.as_str(); + match chars.next() { + Some(c) if c == cur => (Token::LogicalOperation(cur), chars.as_str()), + Some('=') => (Token::AssignmentOperation(cur), chars.as_str()), + _ => (Token::Operation(cur), og_chars), + } + } + _ if is_blankspace(cur) => { + let (_, rest) = consume_any(input, is_blankspace); + (Token::Trivia, rest) + } + _ if is_word_start(cur) => { + let (word, rest) = consume_any(input, is_word_part); + (Token::Word(word), rest) + } + _ => (Token::Unknown(cur), chars.as_str()), + } +} + +/// Returns whether or not a char is a comment end +/// (Unicode Pattern_White_Space excluding U+0020, U+0009, U+200E and U+200F) +const fn is_comment_end(c: char) -> bool { + match c { + '\u{000a}'..='\u{000d}' | '\u{0085}' | '\u{2028}' | '\u{2029}' => true, + _ => false, + } +} + +/// Returns whether or not a char is a blankspace (Unicode Pattern_White_Space) +const fn is_blankspace(c: char) -> bool { + match c { + '\u{0020}' + | '\u{0009}'..='\u{000d}' + | '\u{0085}' + | '\u{200e}' + | '\u{200f}' + | '\u{2028}' + | '\u{2029}' => true, + _ => false, + } +} + +/// Returns whether or not a char is a word start (Unicode XID_Start + '_') +fn is_word_start(c: char) -> bool { + c == '_' || unicode_xid::UnicodeXID::is_xid_start(c) +} + +/// Returns whether or not a char is a word part (Unicode XID_Continue) +fn is_word_part(c: char) -> bool { + unicode_xid::UnicodeXID::is_xid_continue(c) +} + +#[derive(Clone)] +pub(in crate::front::wgsl) struct Lexer<'a> { + input: &'a str, + pub(in crate::front::wgsl) source: &'a str, + // The byte offset of the end of the last non-trivia token. + last_end_offset: usize, +} + +impl<'a> Lexer<'a> { + pub(in crate::front::wgsl) const fn new(input: &'a str) -> Self { + Lexer { + input, + source: input, + last_end_offset: 0, + } + } + + /// Calls the function with a lexer and returns the result of the function as well as the span for everything the function parsed + /// + /// # Examples + /// ```ignore + /// let lexer = Lexer::new("5"); + /// let (value, span) = lexer.capture_span(Lexer::next_uint_literal); + /// assert_eq!(value, 5); + /// ``` + #[inline] + pub fn capture_span<T, E>( + &mut self, + inner: impl FnOnce(&mut Self) -> Result<T, E>, + ) -> Result<(T, Span), E> { + let start = self.current_byte_offset(); + let res = inner(self)?; + let end = self.current_byte_offset(); + Ok((res, Span::from(start..end))) + } + + pub(in crate::front::wgsl) fn start_byte_offset(&mut self) -> usize { + loop { + // Eat all trivia because `next` doesn't eat trailing trivia. + let (token, rest) = consume_token(self.input, false); + if let Token::Trivia = token { + self.input = rest; + } else { + return self.current_byte_offset(); + } + } + } + + fn peek_token_and_rest(&mut self) -> (TokenSpan<'a>, &'a str) { + let mut cloned = self.clone(); + let token = cloned.next(); + let rest = cloned.input; + (token, rest) + } + + const fn current_byte_offset(&self) -> usize { + self.source.len() - self.input.len() + } + + pub(in crate::front::wgsl) fn span_from(&self, offset: usize) -> Span { + Span::from(offset..self.last_end_offset) + } + + /// Return the next non-whitespace token from `self`. + /// + /// Assume we are a parse state where bit shift operators may + /// occur, but not angle brackets. + #[must_use] + pub(in crate::front::wgsl) fn next(&mut self) -> TokenSpan<'a> { + self.next_impl(false) + } + + /// Return the next non-whitespace token from `self`. + /// + /// Assume we are in a parse state where angle brackets may occur, + /// but not bit shift operators. + #[must_use] + pub(in crate::front::wgsl) fn next_generic(&mut self) -> TokenSpan<'a> { + self.next_impl(true) + } + + /// Return the next non-whitespace token from `self`, with a span. + /// + /// See [`consume_token`] for the meaning of `generic`. + fn next_impl(&mut self, generic: bool) -> TokenSpan<'a> { + let mut start_byte_offset = self.current_byte_offset(); + loop { + let (token, rest) = consume_token(self.input, generic); + self.input = rest; + match token { + Token::Trivia => start_byte_offset = self.current_byte_offset(), + _ => { + self.last_end_offset = self.current_byte_offset(); + return (token, self.span_from(start_byte_offset)); + } + } + } + } + + #[must_use] + pub(in crate::front::wgsl) fn peek(&mut self) -> TokenSpan<'a> { + let (token, _) = self.peek_token_and_rest(); + token + } + + pub(in crate::front::wgsl) fn expect_span( + &mut self, + expected: Token<'a>, + ) -> Result<Span, Error<'a>> { + let next = self.next(); + if next.0 == expected { + Ok(next.1) + } else { + Err(Error::Unexpected(next.1, ExpectedToken::Token(expected))) + } + } + + pub(in crate::front::wgsl) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> { + self.expect_span(expected)?; + Ok(()) + } + + pub(in crate::front::wgsl) fn expect_generic_paren( + &mut self, + expected: char, + ) -> Result<(), Error<'a>> { + let next = self.next_generic(); + if next.0 == Token::Paren(expected) { + Ok(()) + } else { + Err(Error::Unexpected( + next.1, + ExpectedToken::Token(Token::Paren(expected)), + )) + } + } + + /// If the next token matches it is skipped and true is returned + pub(in crate::front::wgsl) fn skip(&mut self, what: Token<'_>) -> bool { + let (peeked_token, rest) = self.peek_token_and_rest(); + if peeked_token.0 == what { + self.input = rest; + true + } else { + false + } + } + + pub(in crate::front::wgsl) fn next_ident_with_span( + &mut self, + ) -> Result<(&'a str, Span), Error<'a>> { + match self.next() { + (Token::Word("_"), span) => Err(Error::InvalidIdentifierUnderscore(span)), + (Token::Word(word), span) if word.starts_with("__") => { + Err(Error::ReservedIdentifierPrefix(span)) + } + (Token::Word(word), span) => Ok((word, span)), + other => Err(Error::Unexpected(other.1, ExpectedToken::Identifier)), + } + } + + pub(in crate::front::wgsl) fn next_ident( + &mut self, + ) -> Result<super::ast::Ident<'a>, Error<'a>> { + let ident = self + .next_ident_with_span() + .map(|(name, span)| super::ast::Ident { name, span })?; + + if crate::keywords::wgsl::RESERVED.contains(&ident.name) { + return Err(Error::ReservedKeyword(ident.span)); + } + + Ok(ident) + } + + /// Parses a generic scalar type, for example `<f32>`. + pub(in crate::front::wgsl) fn next_scalar_generic(&mut self) -> Result<Scalar, Error<'a>> { + self.expect_generic_paren('<')?; + let pair = match self.next() { + (Token::Word(word), span) => { + conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(span)) + } + (_, span) => Err(Error::UnknownScalarType(span)), + }?; + self.expect_generic_paren('>')?; + Ok(pair) + } + + /// Parses a generic scalar type, for example `<f32>`. + /// + /// Returns the span covering the inner type, excluding the brackets. + pub(in crate::front::wgsl) fn next_scalar_generic_with_span( + &mut self, + ) -> Result<(Scalar, Span), Error<'a>> { + self.expect_generic_paren('<')?; + let pair = match self.next() { + (Token::Word(word), span) => conv::get_scalar_type(word) + .map(|scalar| (scalar, span)) + .ok_or(Error::UnknownScalarType(span)), + (_, span) => Err(Error::UnknownScalarType(span)), + }?; + self.expect_generic_paren('>')?; + Ok(pair) + } + + pub(in crate::front::wgsl) fn next_storage_access( + &mut self, + ) -> Result<crate::StorageAccess, Error<'a>> { + let (ident, span) = self.next_ident_with_span()?; + match ident { + "read" => Ok(crate::StorageAccess::LOAD), + "write" => Ok(crate::StorageAccess::STORE), + "read_write" => Ok(crate::StorageAccess::LOAD | crate::StorageAccess::STORE), + _ => Err(Error::UnknownAccess(span)), + } + } + + pub(in crate::front::wgsl) fn next_format_generic( + &mut self, + ) -> Result<(crate::StorageFormat, crate::StorageAccess), Error<'a>> { + self.expect(Token::Paren('<'))?; + let (ident, ident_span) = self.next_ident_with_span()?; + let format = conv::map_storage_format(ident, ident_span)?; + self.expect(Token::Separator(','))?; + let access = self.next_storage_access()?; + self.expect(Token::Paren('>'))?; + Ok((format, access)) + } + + pub(in crate::front::wgsl) fn open_arguments(&mut self) -> Result<(), Error<'a>> { + self.expect(Token::Paren('(')) + } + + pub(in crate::front::wgsl) fn close_arguments(&mut self) -> Result<(), Error<'a>> { + let _ = self.skip(Token::Separator(',')); + self.expect(Token::Paren(')')) + } + + pub(in crate::front::wgsl) fn next_argument(&mut self) -> Result<bool, Error<'a>> { + let paren = Token::Paren(')'); + if self.skip(Token::Separator(',')) { + Ok(!self.skip(paren)) + } else { + self.expect(paren).map(|()| false) + } + } +} + +#[cfg(test)] +#[track_caller] +fn sub_test(source: &str, expected_tokens: &[Token]) { + let mut lex = Lexer::new(source); + for &token in expected_tokens { + assert_eq!(lex.next().0, token); + } + assert_eq!(lex.next().0, Token::End); +} + +#[test] +fn test_numbers() { + // WGSL spec examples // + + // decimal integer + sub_test( + "0x123 0X123u 1u 123 0 0i 0x3f", + &[ + Token::Number(Ok(Number::AbstractInt(291))), + Token::Number(Ok(Number::U32(291))), + Token::Number(Ok(Number::U32(1))), + Token::Number(Ok(Number::AbstractInt(123))), + Token::Number(Ok(Number::AbstractInt(0))), + Token::Number(Ok(Number::I32(0))), + Token::Number(Ok(Number::AbstractInt(63))), + ], + ); + // decimal floating point + sub_test( + "0.e+4f 01. .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h", + &[ + Token::Number(Ok(Number::F32(0.))), + Token::Number(Ok(Number::AbstractFloat(1.))), + Token::Number(Ok(Number::AbstractFloat(0.01))), + Token::Number(Ok(Number::AbstractFloat(12.34))), + Token::Number(Ok(Number::F32(0.))), + Token::Number(Err(NumberError::UnimplementedF16)), + Token::Number(Ok(Number::AbstractFloat(0.001))), + Token::Number(Ok(Number::AbstractFloat(43.75))), + Token::Number(Ok(Number::F32(16.))), + Token::Number(Ok(Number::AbstractFloat(0.1875))), + Token::Number(Err(NumberError::UnimplementedF16)), + Token::Number(Ok(Number::AbstractFloat(0.12109375))), + Token::Number(Err(NumberError::UnimplementedF16)), + ], + ); + + // MIN / MAX // + + // min / max decimal integer + sub_test( + "0i 2147483647i 2147483648i", + &[ + Token::Number(Ok(Number::I32(0))), + Token::Number(Ok(Number::I32(i32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + // min / max decimal unsigned integer + sub_test( + "0u 4294967295u 4294967296u", + &[ + Token::Number(Ok(Number::U32(u32::MIN))), + Token::Number(Ok(Number::U32(u32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + + // min / max hexadecimal signed integer + sub_test( + "0x0i 0x7FFFFFFFi 0x80000000i", + &[ + Token::Number(Ok(Number::I32(0))), + Token::Number(Ok(Number::I32(i32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + // min / max hexadecimal unsigned integer + sub_test( + "0x0u 0xFFFFFFFFu 0x100000000u", + &[ + Token::Number(Ok(Number::U32(u32::MIN))), + Token::Number(Ok(Number::U32(u32::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + + // min/max decimal abstract int + sub_test( + "0 9223372036854775807 9223372036854775808", + &[ + Token::Number(Ok(Number::AbstractInt(0))), + Token::Number(Ok(Number::AbstractInt(i64::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + + // min/max hexadecimal abstract int + sub_test( + "0 0x7fffffffffffffff 0x8000000000000000", + &[ + Token::Number(Ok(Number::AbstractInt(0))), + Token::Number(Ok(Number::AbstractInt(i64::MAX))), + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); + + /// ≈ 2^-126 * 2^−23 (= 2^−149) + const SMALLEST_POSITIVE_SUBNORMAL_F32: f32 = 1e-45; + /// ≈ 2^-126 * (1 − 2^−23) + const LARGEST_SUBNORMAL_F32: f32 = 1.1754942e-38; + /// ≈ 2^-126 + const SMALLEST_POSITIVE_NORMAL_F32: f32 = f32::MIN_POSITIVE; + /// ≈ 1 − 2^−24 + const LARGEST_F32_LESS_THAN_ONE: f32 = 0.99999994; + /// ≈ 1 + 2^−23 + const SMALLEST_F32_LARGER_THAN_ONE: f32 = 1.0000001; + /// ≈ 2^127 * (2 − 2^−23) + const LARGEST_NORMAL_F32: f32 = f32::MAX; + + // decimal floating point + sub_test( + "1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f 3.40282347e+38f", + &[ + Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))), + Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))), + Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))), + Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))), + Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))), + Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))), + ], + ); + sub_test( + "3.40282367e+38f", + &[ + Token::Number(Err(NumberError::NotRepresentable)), // ≈ 2^128 + ], + ); + + // hexadecimal floating point + sub_test( + "0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f 0xFFFFFFp+104f", + &[ + Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))), + Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))), + Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))), + Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))), + Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))), + Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))), + ], + ); + sub_test( + "0x1p128f 0x1.000001p0f", + &[ + Token::Number(Err(NumberError::NotRepresentable)), // = 2^128 + Token::Number(Err(NumberError::NotRepresentable)), + ], + ); +} + +#[test] +fn double_floats() { + sub_test( + "0x1.2p4lf 0x1p8lf 0.0625lf 625e-4lf 10lf 10l", + &[ + Token::Number(Ok(Number::F64(18.0))), + Token::Number(Ok(Number::F64(256.0))), + Token::Number(Ok(Number::F64(0.0625))), + Token::Number(Ok(Number::F64(0.0625))), + Token::Number(Ok(Number::F64(10.0))), + Token::Number(Ok(Number::AbstractInt(10))), + Token::Word("l"), + ], + ) +} + +#[test] +fn test_tokens() { + sub_test("id123_OK", &[Token::Word("id123_OK")]); + sub_test( + "92No", + &[ + Token::Number(Ok(Number::AbstractInt(92))), + Token::Word("No"), + ], + ); + sub_test( + "2u3o", + &[ + Token::Number(Ok(Number::U32(2))), + Token::Number(Ok(Number::AbstractInt(3))), + Token::Word("o"), + ], + ); + sub_test( + "2.4f44po", + &[ + Token::Number(Ok(Number::F32(2.4))), + Token::Number(Ok(Number::AbstractInt(44))), + Token::Word("po"), + ], + ); + sub_test( + "Δέλτα réflexion Кызыл 𐰓𐰏𐰇 朝焼け سلام 검정 שָׁלוֹם गुलाबी փիրուզ", + &[ + Token::Word("Δέλτα"), + Token::Word("réflexion"), + Token::Word("Кызыл"), + Token::Word("𐰓𐰏𐰇"), + Token::Word("朝焼け"), + Token::Word("سلام"), + Token::Word("검정"), + Token::Word("שָׁלוֹם"), + Token::Word("गुलाबी"), + Token::Word("փիրուզ"), + ], + ); + sub_test("æNoø", &[Token::Word("æNoø")]); + sub_test("No¾", &[Token::Word("No"), Token::Unknown('¾')]); + sub_test("No好", &[Token::Word("No好")]); + sub_test("_No", &[Token::Word("_No")]); + sub_test( + "*/*/***/*//=/*****//", + &[ + Token::Operation('*'), + Token::AssignmentOperation('/'), + Token::Operation('/'), + ], + ); + + // Type suffixes are only allowed on hex float literals + // if you provided an exponent. + sub_test( + "0x1.2f 0x1.2f 0x1.2h 0x1.2H 0x1.2lf", + &[ + // The 'f' suffixes are taken as a hex digit: + // the fractional part is 0x2f / 256. + Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))), + Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))), + Token::Number(Ok(Number::AbstractFloat(1.125))), + Token::Word("h"), + Token::Number(Ok(Number::AbstractFloat(1.125))), + Token::Word("H"), + Token::Number(Ok(Number::AbstractFloat(1.125))), + Token::Word("lf"), + ], + ) +} + +#[test] +fn test_variable_decl() { + sub_test( + "@group(0 ) var< uniform> texture: texture_multisampled_2d <f32 >;", + &[ + Token::Attribute, + Token::Word("group"), + Token::Paren('('), + Token::Number(Ok(Number::AbstractInt(0))), + Token::Paren(')'), + Token::Word("var"), + Token::Paren('<'), + Token::Word("uniform"), + Token::Paren('>'), + Token::Word("texture"), + Token::Separator(':'), + Token::Word("texture_multisampled_2d"), + Token::Paren('<'), + Token::Word("f32"), + Token::Paren('>'), + Token::Separator(';'), + ], + ); + sub_test( + "var<storage,read_write> buffer: array<u32>;", + &[ + Token::Word("var"), + Token::Paren('<'), + Token::Word("storage"), + Token::Separator(','), + Token::Word("read_write"), + Token::Paren('>'), + Token::Word("buffer"), + Token::Separator(':'), + Token::Word("array"), + Token::Paren('<'), + Token::Word("u32"), + Token::Paren('>'), + Token::Separator(';'), + ], + ); +} diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs new file mode 100644 index 0000000000..51fc2f013b --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs @@ -0,0 +1,2350 @@ +use crate::front::wgsl::error::{Error, ExpectedToken}; +use crate::front::wgsl::parse::lexer::{Lexer, Token}; +use crate::front::wgsl::parse::number::Number; +use crate::front::wgsl::Scalar; +use crate::front::SymbolTable; +use crate::{Arena, FastIndexSet, Handle, ShaderStage, Span}; + +pub mod ast; +pub mod conv; +pub mod lexer; +pub mod number; + +/// State for constructing an AST expression. +/// +/// Not to be confused with [`lower::ExpressionContext`], which is for producing +/// Naga IR from the AST we produce here. +/// +/// [`lower::ExpressionContext`]: super::lower::ExpressionContext +struct ExpressionContext<'input, 'temp, 'out> { + /// The [`TranslationUnit::expressions`] arena to which we should contribute + /// expressions. + /// + /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions + expressions: &'out mut Arena<ast::Expression<'input>>, + + /// The [`TranslationUnit::types`] arena to which we should contribute new + /// types. + /// + /// [`TranslationUnit::types`]: ast::TranslationUnit::types + types: &'out mut Arena<ast::Type<'input>>, + + /// A map from identifiers in scope to the locals/arguments they represent. + /// + /// The handles refer to the [`Function::locals`] area; see that field's + /// documentation for details. + /// + /// [`Function::locals`]: ast::Function::locals + local_table: &'temp mut SymbolTable<&'input str, Handle<ast::Local>>, + + /// The [`Function::locals`] arena for the function we're building. + /// + /// [`Function::locals`]: ast::Function::locals + locals: &'out mut Arena<ast::Local>, + + /// Identifiers used by the current global declaration that have no local definition. + /// + /// This becomes the [`GlobalDecl`]'s [`dependencies`] set. + /// + /// Note that we don't know at parse time what kind of [`GlobalDecl`] the + /// name refers to. We can't look up names until we've seen the entire + /// translation unit. + /// + /// [`GlobalDecl`]: ast::GlobalDecl + /// [`dependencies`]: ast::GlobalDecl::dependencies + unresolved: &'out mut FastIndexSet<ast::Dependency<'input>>, +} + +impl<'a> ExpressionContext<'a, '_, '_> { + fn parse_binary_op( + &mut self, + lexer: &mut Lexer<'a>, + classifier: impl Fn(Token<'a>) -> Option<crate::BinaryOperator>, + mut parser: impl FnMut( + &mut Lexer<'a>, + &mut Self, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + let start = lexer.start_byte_offset(); + let mut accumulator = parser(lexer, self)?; + while let Some(op) = classifier(lexer.peek().0) { + let _ = lexer.next(); + let left = accumulator; + let right = parser(lexer, self)?; + accumulator = self.expressions.append( + ast::Expression::Binary { op, left, right }, + lexer.span_from(start), + ); + } + Ok(accumulator) + } + + fn declare_local(&mut self, name: ast::Ident<'a>) -> Result<Handle<ast::Local>, Error<'a>> { + let handle = self.locals.append(ast::Local, name.span); + if let Some(old) = self.local_table.add(name.name, handle) { + Err(Error::Redefinition { + previous: self.locals.get_span(old), + current: name.span, + }) + } else { + Ok(handle) + } + } +} + +/// Which grammar rule we are in the midst of parsing. +/// +/// This is used for error checking. `Parser` maintains a stack of +/// these and (occasionally) checks that it is being pushed and popped +/// as expected. +#[derive(Clone, Debug, PartialEq)] +enum Rule { + Attribute, + VariableDecl, + TypeDecl, + FunctionDecl, + Block, + Statement, + PrimaryExpr, + SingularExpr, + UnaryExpr, + GeneralExpr, +} + +struct ParsedAttribute<T> { + value: Option<T>, +} + +impl<T> Default for ParsedAttribute<T> { + fn default() -> Self { + Self { value: None } + } +} + +impl<T> ParsedAttribute<T> { + fn set(&mut self, value: T, name_span: Span) -> Result<(), Error<'static>> { + if self.value.is_some() { + return Err(Error::RepeatedAttribute(name_span)); + } + self.value = Some(value); + Ok(()) + } +} + +#[derive(Default)] +struct BindingParser<'a> { + location: ParsedAttribute<Handle<ast::Expression<'a>>>, + second_blend_source: ParsedAttribute<bool>, + built_in: ParsedAttribute<crate::BuiltIn>, + interpolation: ParsedAttribute<crate::Interpolation>, + sampling: ParsedAttribute<crate::Sampling>, + invariant: ParsedAttribute<bool>, +} + +impl<'a> BindingParser<'a> { + fn parse( + &mut self, + parser: &mut Parser, + lexer: &mut Lexer<'a>, + name: &'a str, + name_span: Span, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<(), Error<'a>> { + match name { + "location" => { + lexer.expect(Token::Paren('('))?; + self.location + .set(parser.general_expression(lexer, ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + "builtin" => { + lexer.expect(Token::Paren('('))?; + let (raw, span) = lexer.next_ident_with_span()?; + self.built_in + .set(conv::map_built_in(raw, span)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + "interpolate" => { + lexer.expect(Token::Paren('('))?; + let (raw, span) = lexer.next_ident_with_span()?; + self.interpolation + .set(conv::map_interpolation(raw, span)?, name_span)?; + if lexer.skip(Token::Separator(',')) { + let (raw, span) = lexer.next_ident_with_span()?; + self.sampling + .set(conv::map_sampling(raw, span)?, name_span)?; + } + lexer.expect(Token::Paren(')'))?; + } + "second_blend_source" => { + self.second_blend_source.set(true, name_span)?; + } + "invariant" => { + self.invariant.set(true, name_span)?; + } + _ => return Err(Error::UnknownAttribute(name_span)), + } + Ok(()) + } + + fn finish(self, span: Span) -> Result<Option<ast::Binding<'a>>, Error<'a>> { + match ( + self.location.value, + self.built_in.value, + self.interpolation.value, + self.sampling.value, + self.invariant.value.unwrap_or_default(), + ) { + (None, None, None, None, false) => Ok(None), + (Some(location), None, interpolation, sampling, false) => { + // Before handing over the completed `Module`, we call + // `apply_default_interpolation` to ensure that the interpolation and + // sampling have been explicitly specified on all vertex shader output and fragment + // shader input user bindings, so leaving them potentially `None` here is fine. + Ok(Some(ast::Binding::Location { + location, + interpolation, + sampling, + second_blend_source: self.second_blend_source.value.unwrap_or(false), + })) + } + (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant) => { + Ok(Some(ast::Binding::BuiltIn(crate::BuiltIn::Position { + invariant, + }))) + } + (None, Some(built_in), None, None, false) => Ok(Some(ast::Binding::BuiltIn(built_in))), + (_, _, _, _, _) => Err(Error::InconsistentBinding(span)), + } + } +} + +pub struct Parser { + rules: Vec<(Rule, usize)>, +} + +impl Parser { + pub const fn new() -> Self { + Parser { rules: Vec::new() } + } + + fn reset(&mut self) { + self.rules.clear(); + } + + fn push_rule_span(&mut self, rule: Rule, lexer: &mut Lexer<'_>) { + self.rules.push((rule, lexer.start_byte_offset())); + } + + fn pop_rule_span(&mut self, lexer: &Lexer<'_>) -> Span { + let (_, initial) = self.rules.pop().unwrap(); + lexer.span_from(initial) + } + + fn peek_rule_span(&mut self, lexer: &Lexer<'_>) -> Span { + let &(_, initial) = self.rules.last().unwrap(); + lexer.span_from(initial) + } + + fn switch_value<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<ast::SwitchValue<'a>, Error<'a>> { + if let Token::Word("default") = lexer.peek().0 { + let _ = lexer.next(); + return Ok(ast::SwitchValue::Default); + } + + let expr = self.general_expression(lexer, ctx)?; + Ok(ast::SwitchValue::Expr(expr)) + } + + /// Decide if we're looking at a construction expression, and return its + /// type if so. + /// + /// If the identifier `word` is a [type-defining keyword], then return a + /// [`ConstructorType`] value describing the type to build. Return an error + /// if the type is not constructible (like `sampler`). + /// + /// If `word` isn't a type name, then return `None`. + /// + /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords + /// [`ConstructorType`]: ast::ConstructorType + fn constructor_type<'a>( + &mut self, + lexer: &mut Lexer<'a>, + word: &'a str, + span: Span, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Option<ast::ConstructorType<'a>>, Error<'a>> { + if let Some(scalar) = conv::get_scalar_type(word) { + return Ok(Some(ast::ConstructorType::Scalar(scalar))); + } + + let partial = match word { + "vec2" => ast::ConstructorType::PartialVector { + size: crate::VectorSize::Bi, + }, + "vec2i" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }, + })) + } + "vec2u" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + })) + } + "vec2f" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar::F32, + })) + } + "vec3" => ast::ConstructorType::PartialVector { + size: crate::VectorSize::Tri, + }, + "vec3i" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar::I32, + })) + } + "vec3u" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar::U32, + })) + } + "vec3f" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar::F32, + })) + } + "vec4" => ast::ConstructorType::PartialVector { + size: crate::VectorSize::Quad, + }, + "vec4i" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar::I32, + })) + } + "vec4u" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar::U32, + })) + } + "vec4f" => { + return Ok(Some(ast::ConstructorType::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar::F32, + })) + } + "mat2x2" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Bi, + }, + "mat2x2f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Bi, + width: 4, + })) + } + "mat2x3" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Tri, + }, + "mat2x3f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Tri, + width: 4, + })) + } + "mat2x4" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Quad, + }, + "mat2x4f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Quad, + width: 4, + })) + } + "mat3x2" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Bi, + }, + "mat3x2f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Bi, + width: 4, + })) + } + "mat3x3" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Tri, + }, + "mat3x3f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Tri, + width: 4, + })) + } + "mat3x4" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Quad, + }, + "mat3x4f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Quad, + width: 4, + })) + } + "mat4x2" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Bi, + }, + "mat4x2f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Bi, + width: 4, + })) + } + "mat4x3" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + }, + "mat4x3f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width: 4, + })) + } + "mat4x4" => ast::ConstructorType::PartialMatrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Quad, + }, + "mat4x4f" => { + return Ok(Some(ast::ConstructorType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Quad, + width: 4, + })) + } + "array" => ast::ConstructorType::PartialArray, + "atomic" + | "binding_array" + | "sampler" + | "sampler_comparison" + | "texture_1d" + | "texture_1d_array" + | "texture_2d" + | "texture_2d_array" + | "texture_3d" + | "texture_cube" + | "texture_cube_array" + | "texture_multisampled_2d" + | "texture_multisampled_2d_array" + | "texture_depth_2d" + | "texture_depth_2d_array" + | "texture_depth_cube" + | "texture_depth_cube_array" + | "texture_depth_multisampled_2d" + | "texture_storage_1d" + | "texture_storage_1d_array" + | "texture_storage_2d" + | "texture_storage_2d_array" + | "texture_storage_3d" => return Err(Error::TypeNotConstructible(span)), + _ => return Ok(None), + }; + + // parse component type if present + match (lexer.peek().0, partial) { + (Token::Paren('<'), ast::ConstructorType::PartialVector { size }) => { + let scalar = lexer.next_scalar_generic()?; + Ok(Some(ast::ConstructorType::Vector { size, scalar })) + } + (Token::Paren('<'), ast::ConstructorType::PartialMatrix { columns, rows }) => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + match scalar.kind { + crate::ScalarKind::Float => Ok(Some(ast::ConstructorType::Matrix { + columns, + rows, + width: scalar.width, + })), + _ => Err(Error::BadMatrixScalarKind(span, scalar)), + } + } + (Token::Paren('<'), ast::ConstructorType::PartialArray) => { + lexer.expect_generic_paren('<')?; + let base = self.type_decl(lexer, ctx)?; + let size = if lexer.skip(Token::Separator(',')) { + let expr = self.unary_expression(lexer, ctx)?; + ast::ArraySize::Constant(expr) + } else { + ast::ArraySize::Dynamic + }; + lexer.expect_generic_paren('>')?; + + Ok(Some(ast::ConstructorType::Array { base, size })) + } + (_, partial) => Ok(Some(partial)), + } + } + + /// Expects `name` to be consumed (not in lexer). + fn arguments<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Vec<Handle<ast::Expression<'a>>>, Error<'a>> { + lexer.open_arguments()?; + let mut arguments = Vec::new(); + loop { + if !arguments.is_empty() { + if !lexer.next_argument()? { + break; + } + } else if lexer.skip(Token::Paren(')')) { + break; + } + let arg = self.general_expression(lexer, ctx)?; + arguments.push(arg); + } + + Ok(arguments) + } + + /// Expects [`Rule::PrimaryExpr`] or [`Rule::SingularExpr`] on top; does not pop it. + /// Expects `name` to be consumed (not in lexer). + fn function_call<'a>( + &mut self, + lexer: &mut Lexer<'a>, + name: &'a str, + name_span: Span, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + assert!(self.rules.last().is_some()); + + let expr = match name { + // bitcast looks like a function call, but it's an operator and must be handled differently. + "bitcast" => { + lexer.expect_generic_paren('<')?; + let start = lexer.start_byte_offset(); + let to = self.type_decl(lexer, ctx)?; + let span = lexer.span_from(start); + lexer.expect_generic_paren('>')?; + + lexer.open_arguments()?; + let expr = self.general_expression(lexer, ctx)?; + lexer.close_arguments()?; + + ast::Expression::Bitcast { + expr, + to, + ty_span: span, + } + } + // everything else must be handled later, since they can be hidden by user-defined functions. + _ => { + let arguments = self.arguments(lexer, ctx)?; + ctx.unresolved.insert(ast::Dependency { + ident: name, + usage: name_span, + }); + ast::Expression::Call { + function: ast::Ident { + name, + span: name_span, + }, + arguments, + } + } + }; + + let span = self.peek_rule_span(lexer); + let expr = ctx.expressions.append(expr, span); + Ok(expr) + } + + fn ident_expr<'a>( + &mut self, + name: &'a str, + name_span: Span, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> ast::IdentExpr<'a> { + match ctx.local_table.lookup(name) { + Some(&local) => ast::IdentExpr::Local(local), + None => { + ctx.unresolved.insert(ast::Dependency { + ident: name, + usage: name_span, + }); + ast::IdentExpr::Unresolved(name) + } + } + } + + fn primary_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + self.push_rule_span(Rule::PrimaryExpr, lexer); + + let expr = match lexer.peek() { + (Token::Paren('('), _) => { + let _ = lexer.next(); + let expr = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Paren(')'))?; + self.pop_rule_span(lexer); + return Ok(expr); + } + (Token::Word("true"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Bool(true)) + } + (Token::Word("false"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Bool(false)) + } + (Token::Number(res), span) => { + let _ = lexer.next(); + let num = res.map_err(|err| Error::BadNumber(span, err))?; + ast::Expression::Literal(ast::Literal::Number(num)) + } + (Token::Word("RAY_FLAG_NONE"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(0))) + } + (Token::Word("RAY_FLAG_TERMINATE_ON_FIRST_HIT"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(4))) + } + (Token::Word("RAY_QUERY_INTERSECTION_NONE"), _) => { + let _ = lexer.next(); + ast::Expression::Literal(ast::Literal::Number(Number::U32(0))) + } + (Token::Word(word), span) => { + let start = lexer.start_byte_offset(); + let _ = lexer.next(); + + if let Some(ty) = self.constructor_type(lexer, word, span, ctx)? { + let ty_span = lexer.span_from(start); + let components = self.arguments(lexer, ctx)?; + ast::Expression::Construct { + ty, + ty_span, + components, + } + } else if let Token::Paren('(') = lexer.peek().0 { + self.pop_rule_span(lexer); + return self.function_call(lexer, word, span, ctx); + } else if word == "bitcast" { + self.pop_rule_span(lexer); + return self.function_call(lexer, word, span, ctx); + } else { + let ident = self.ident_expr(word, span, ctx); + ast::Expression::Ident(ident) + } + } + other => return Err(Error::Unexpected(other.1, ExpectedToken::PrimaryExpression)), + }; + + let span = self.pop_rule_span(lexer); + let expr = ctx.expressions.append(expr, span); + Ok(expr) + } + + fn postfix<'a>( + &mut self, + span_start: usize, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + expr: Handle<ast::Expression<'a>>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + let mut expr = expr; + + loop { + let expression = match lexer.peek().0 { + Token::Separator('.') => { + let _ = lexer.next(); + let field = lexer.next_ident()?; + + ast::Expression::Member { base: expr, field } + } + Token::Paren('[') => { + let _ = lexer.next(); + let index = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Paren(']'))?; + + ast::Expression::Index { base: expr, index } + } + _ => break, + }; + + let span = lexer.span_from(span_start); + expr = ctx.expressions.append(expression, span); + } + + Ok(expr) + } + + /// Parse a `unary_expression`. + fn unary_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + self.push_rule_span(Rule::UnaryExpr, lexer); + //TODO: refactor this to avoid backing up + let expr = match lexer.peek().0 { + Token::Operation('-') => { + let _ = lexer.next(); + let expr = self.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Unary { + op: crate::UnaryOperator::Negate, + expr, + }; + let span = self.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('!') => { + let _ = lexer.next(); + let expr = self.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Unary { + op: crate::UnaryOperator::LogicalNot, + expr, + }; + let span = self.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('~') => { + let _ = lexer.next(); + let expr = self.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Unary { + op: crate::UnaryOperator::BitwiseNot, + expr, + }; + let span = self.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('*') => { + let _ = lexer.next(); + let expr = self.unary_expression(lexer, ctx)?; + let expr = ast::Expression::Deref(expr); + let span = self.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + Token::Operation('&') => { + let _ = lexer.next(); + let expr = self.unary_expression(lexer, ctx)?; + let expr = ast::Expression::AddrOf(expr); + let span = self.peek_rule_span(lexer); + ctx.expressions.append(expr, span) + } + _ => self.singular_expression(lexer, ctx)?, + }; + + self.pop_rule_span(lexer); + Ok(expr) + } + + /// Parse a `singular_expression`. + fn singular_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + let start = lexer.start_byte_offset(); + self.push_rule_span(Rule::SingularExpr, lexer); + let primary_expr = self.primary_expression(lexer, ctx)?; + let singular_expr = self.postfix(start, lexer, ctx, primary_expr)?; + self.pop_rule_span(lexer); + + Ok(singular_expr) + } + + fn equality_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + context: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + // equality_expression + context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('=') => Some(crate::BinaryOperator::Equal), + Token::LogicalOperation('!') => Some(crate::BinaryOperator::NotEqual), + _ => None, + }, + // relational_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Paren('<') => Some(crate::BinaryOperator::Less), + Token::Paren('>') => Some(crate::BinaryOperator::Greater), + Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual), + Token::LogicalOperation('>') => Some(crate::BinaryOperator::GreaterEqual), + _ => None, + }, + // shift_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::ShiftOperation('<') => { + Some(crate::BinaryOperator::ShiftLeft) + } + Token::ShiftOperation('>') => { + Some(crate::BinaryOperator::ShiftRight) + } + _ => None, + }, + // additive_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('+') => Some(crate::BinaryOperator::Add), + Token::Operation('-') => { + Some(crate::BinaryOperator::Subtract) + } + _ => None, + }, + // multiplicative_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('*') => { + Some(crate::BinaryOperator::Multiply) + } + Token::Operation('/') => { + Some(crate::BinaryOperator::Divide) + } + Token::Operation('%') => { + Some(crate::BinaryOperator::Modulo) + } + _ => None, + }, + |lexer, context| self.unary_expression(lexer, context), + ) + }, + ) + }, + ) + }, + ) + }, + ) + } + + fn general_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> { + self.general_expression_with_span(lexer, ctx) + .map(|(expr, _)| expr) + } + + fn general_expression_with_span<'a>( + &mut self, + lexer: &mut Lexer<'a>, + context: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<(Handle<ast::Expression<'a>>, Span), Error<'a>> { + self.push_rule_span(Rule::GeneralExpr, lexer); + // logical_or_expression + let handle = context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr), + _ => None, + }, + // logical_and_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd), + _ => None, + }, + // inclusive_or_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('|') => Some(crate::BinaryOperator::InclusiveOr), + _ => None, + }, + // exclusive_or_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('^') => { + Some(crate::BinaryOperator::ExclusiveOr) + } + _ => None, + }, + // and_expression + |lexer, context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('&') => { + Some(crate::BinaryOperator::And) + } + _ => None, + }, + |lexer, context| { + self.equality_expression(lexer, context) + }, + ) + }, + ) + }, + ) + }, + ) + }, + )?; + Ok((handle, self.pop_rule_span(lexer))) + } + + fn variable_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<ast::GlobalVariable<'a>, Error<'a>> { + self.push_rule_span(Rule::VariableDecl, lexer); + let mut space = crate::AddressSpace::Handle; + + if lexer.skip(Token::Paren('<')) { + let (class_str, span) = lexer.next_ident_with_span()?; + space = match class_str { + "storage" => { + let access = if lexer.skip(Token::Separator(',')) { + lexer.next_storage_access()? + } else { + // defaulting to `read` + crate::StorageAccess::LOAD + }; + crate::AddressSpace::Storage { access } + } + _ => conv::map_address_space(class_str, span)?, + }; + lexer.expect(Token::Paren('>'))?; + } + let name = lexer.next_ident()?; + lexer.expect(Token::Separator(':'))?; + let ty = self.type_decl(lexer, ctx)?; + + let init = if lexer.skip(Token::Operation('=')) { + let handle = self.general_expression(lexer, ctx)?; + Some(handle) + } else { + None + }; + lexer.expect(Token::Separator(';'))?; + self.pop_rule_span(lexer); + + Ok(ast::GlobalVariable { + name, + space, + binding: None, + ty, + init, + }) + } + + fn struct_body<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Vec<ast::StructMember<'a>>, Error<'a>> { + let mut members = Vec::new(); + + lexer.expect(Token::Paren('{'))?; + let mut ready = true; + while !lexer.skip(Token::Paren('}')) { + if !ready { + return Err(Error::Unexpected( + lexer.next().1, + ExpectedToken::Token(Token::Separator(',')), + )); + } + let (mut size, mut align) = (ParsedAttribute::default(), ParsedAttribute::default()); + self.push_rule_span(Rule::Attribute, lexer); + let mut bind_parser = BindingParser::default(); + while lexer.skip(Token::Attribute) { + match lexer.next_ident_with_span()? { + ("size", name_span) => { + lexer.expect(Token::Paren('('))?; + let expr = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Paren(')'))?; + size.set(expr, name_span)?; + } + ("align", name_span) => { + lexer.expect(Token::Paren('('))?; + let expr = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Paren(')'))?; + align.set(expr, name_span)?; + } + (word, word_span) => bind_parser.parse(self, lexer, word, word_span, ctx)?, + } + } + + let bind_span = self.pop_rule_span(lexer); + let binding = bind_parser.finish(bind_span)?; + + let name = lexer.next_ident()?; + lexer.expect(Token::Separator(':'))?; + let ty = self.type_decl(lexer, ctx)?; + ready = lexer.skip(Token::Separator(',')); + + members.push(ast::StructMember { + name, + ty, + binding, + size: size.value, + align: align.value, + }); + } + + Ok(members) + } + + fn matrix_scalar_type<'a>( + &mut self, + lexer: &mut Lexer<'a>, + columns: crate::VectorSize, + rows: crate::VectorSize, + ) -> Result<ast::Type<'a>, Error<'a>> { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + match scalar.kind { + crate::ScalarKind::Float => Ok(ast::Type::Matrix { + columns, + rows, + width: scalar.width, + }), + _ => Err(Error::BadMatrixScalarKind(span, scalar)), + } + } + + fn type_decl_impl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + word: &'a str, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Option<ast::Type<'a>>, Error<'a>> { + if let Some(scalar) = conv::get_scalar_type(word) { + return Ok(Some(ast::Type::Scalar(scalar))); + } + + Ok(Some(match word { + "vec2" => { + let scalar = lexer.next_scalar_generic()?; + ast::Type::Vector { + size: crate::VectorSize::Bi, + scalar, + } + } + "vec2i" => ast::Type::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }, + }, + "vec2u" => ast::Type::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + }, + "vec2f" => ast::Type::Vector { + size: crate::VectorSize::Bi, + scalar: Scalar::F32, + }, + "vec3" => { + let scalar = lexer.next_scalar_generic()?; + ast::Type::Vector { + size: crate::VectorSize::Tri, + scalar, + } + } + "vec3i" => ast::Type::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }, + }, + "vec3u" => ast::Type::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + }, + "vec3f" => ast::Type::Vector { + size: crate::VectorSize::Tri, + scalar: Scalar::F32, + }, + "vec4" => { + let scalar = lexer.next_scalar_generic()?; + ast::Type::Vector { + size: crate::VectorSize::Quad, + scalar, + } + } + "vec4i" => ast::Type::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }, + }, + "vec4u" => ast::Type::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + }, + "vec4f" => ast::Type::Vector { + size: crate::VectorSize::Quad, + scalar: Scalar::F32, + }, + "mat2x2" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Bi)? + } + "mat2x2f" => ast::Type::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Bi, + width: 4, + }, + "mat2x3" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Tri)? + } + "mat2x3f" => ast::Type::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Tri, + width: 4, + }, + "mat2x4" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Quad)? + } + "mat2x4f" => ast::Type::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Quad, + width: 4, + }, + "mat3x2" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Bi)? + } + "mat3x2f" => ast::Type::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Bi, + width: 4, + }, + "mat3x3" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Tri)? + } + "mat3x3f" => ast::Type::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Tri, + width: 4, + }, + "mat3x4" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Quad)? + } + "mat3x4f" => ast::Type::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Quad, + width: 4, + }, + "mat4x2" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Bi)? + } + "mat4x2f" => ast::Type::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Bi, + width: 4, + }, + "mat4x3" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Tri)? + } + "mat4x3f" => ast::Type::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width: 4, + }, + "mat4x4" => { + self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Quad)? + } + "mat4x4f" => ast::Type::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Quad, + width: 4, + }, + "atomic" => { + let scalar = lexer.next_scalar_generic()?; + ast::Type::Atomic(scalar) + } + "ptr" => { + lexer.expect_generic_paren('<')?; + let (ident, span) = lexer.next_ident_with_span()?; + let mut space = conv::map_address_space(ident, span)?; + lexer.expect(Token::Separator(','))?; + let base = self.type_decl(lexer, ctx)?; + if let crate::AddressSpace::Storage { ref mut access } = space { + *access = if lexer.skip(Token::Separator(',')) { + lexer.next_storage_access()? + } else { + crate::StorageAccess::LOAD + }; + } + lexer.expect_generic_paren('>')?; + ast::Type::Pointer { base, space } + } + "array" => { + lexer.expect_generic_paren('<')?; + let base = self.type_decl(lexer, ctx)?; + let size = if lexer.skip(Token::Separator(',')) { + let size = self.unary_expression(lexer, ctx)?; + ast::ArraySize::Constant(size) + } else { + ast::ArraySize::Dynamic + }; + lexer.expect_generic_paren('>')?; + + ast::Type::Array { base, size } + } + "binding_array" => { + lexer.expect_generic_paren('<')?; + let base = self.type_decl(lexer, ctx)?; + let size = if lexer.skip(Token::Separator(',')) { + let size = self.unary_expression(lexer, ctx)?; + ast::ArraySize::Constant(size) + } else { + ast::ArraySize::Dynamic + }; + lexer.expect_generic_paren('>')?; + + ast::Type::BindingArray { base, size } + } + "sampler" => ast::Type::Sampler { comparison: false }, + "sampler_comparison" => ast::Type::Sampler { comparison: true }, + "texture_1d" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D1, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_1d_array" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D1, + arrayed: true, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_2d" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_2d_array" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_3d" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D3, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_cube" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::Cube, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_cube_array" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::Cube, + arrayed: true, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: false, + }, + } + } + "texture_multisampled_2d" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: true, + }, + } + } + "texture_multisampled_2d_array" => { + let (scalar, span) = lexer.next_scalar_generic_with_span()?; + Self::check_texture_sample_type(scalar, span)?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Sampled { + kind: scalar.kind, + multi: true, + }, + } + } + "texture_depth_2d" => ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_2d_array" => ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_cube" => ast::Type::Image { + dim: crate::ImageDimension::Cube, + arrayed: false, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_cube_array" => ast::Type::Image { + dim: crate::ImageDimension::Cube, + arrayed: true, + class: crate::ImageClass::Depth { multi: false }, + }, + "texture_depth_multisampled_2d" => ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Depth { multi: true }, + }, + "texture_storage_1d" => { + let (format, access) = lexer.next_format_generic()?; + ast::Type::Image { + dim: crate::ImageDimension::D1, + arrayed: false, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_1d_array" => { + let (format, access) = lexer.next_format_generic()?; + ast::Type::Image { + dim: crate::ImageDimension::D1, + arrayed: true, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_2d" => { + let (format, access) = lexer.next_format_generic()?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_2d_array" => { + let (format, access) = lexer.next_format_generic()?; + ast::Type::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Storage { format, access }, + } + } + "texture_storage_3d" => { + let (format, access) = lexer.next_format_generic()?; + ast::Type::Image { + dim: crate::ImageDimension::D3, + arrayed: false, + class: crate::ImageClass::Storage { format, access }, + } + } + "acceleration_structure" => ast::Type::AccelerationStructure, + "ray_query" => ast::Type::RayQuery, + "RayDesc" => ast::Type::RayDesc, + "RayIntersection" => ast::Type::RayIntersection, + _ => return Ok(None), + })) + } + + const fn check_texture_sample_type(scalar: Scalar, span: Span) -> Result<(), Error<'static>> { + use crate::ScalarKind::*; + // Validate according to https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type + match scalar { + Scalar { + kind: Float | Sint | Uint, + width: 4, + } => Ok(()), + _ => Err(Error::BadTextureSampleType { span, scalar }), + } + } + + /// Parse type declaration of a given name. + fn type_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<ast::Type<'a>>, Error<'a>> { + self.push_rule_span(Rule::TypeDecl, lexer); + + let (name, span) = lexer.next_ident_with_span()?; + + let ty = match self.type_decl_impl(lexer, name, ctx)? { + Some(ty) => ty, + None => { + ctx.unresolved.insert(ast::Dependency { + ident: name, + usage: span, + }); + ast::Type::User(ast::Ident { name, span }) + } + }; + + self.pop_rule_span(lexer); + + let handle = ctx.types.append(ty, Span::UNDEFINED); + Ok(handle) + } + + fn assignment_op_and_rhs<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, + target: Handle<ast::Expression<'a>>, + span_start: usize, + ) -> Result<(), Error<'a>> { + use crate::BinaryOperator as Bo; + + let op = lexer.next(); + let (op, value) = match op { + (Token::Operation('='), _) => { + let value = self.general_expression(lexer, ctx)?; + (None, value) + } + (Token::AssignmentOperation(c), _) => { + let op = match c { + '<' => Bo::ShiftLeft, + '>' => Bo::ShiftRight, + '+' => Bo::Add, + '-' => Bo::Subtract, + '*' => Bo::Multiply, + '/' => Bo::Divide, + '%' => Bo::Modulo, + '&' => Bo::And, + '|' => Bo::InclusiveOr, + '^' => Bo::ExclusiveOr, + // Note: `consume_token` shouldn't produce any other assignment ops + _ => unreachable!(), + }; + + let value = self.general_expression(lexer, ctx)?; + (Some(op), value) + } + token @ (Token::IncrementOperation | Token::DecrementOperation, _) => { + let op = match token.0 { + Token::IncrementOperation => ast::StatementKind::Increment, + Token::DecrementOperation => ast::StatementKind::Decrement, + _ => unreachable!(), + }; + + let span = lexer.span_from(span_start); + block.stmts.push(ast::Statement { + kind: op(target), + span, + }); + return Ok(()); + } + _ => return Err(Error::Unexpected(op.1, ExpectedToken::Assignment)), + }; + + let span = lexer.span_from(span_start); + block.stmts.push(ast::Statement { + kind: ast::StatementKind::Assign { target, op, value }, + span, + }); + Ok(()) + } + + /// Parse an assignment statement (will also parse increment and decrement statements) + fn assignment_statement<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, + ) -> Result<(), Error<'a>> { + let span_start = lexer.start_byte_offset(); + let target = self.general_expression(lexer, ctx)?; + self.assignment_op_and_rhs(lexer, ctx, block, target, span_start) + } + + /// Parse a function call statement. + /// Expects `ident` to be consumed (not in the lexer). + fn function_statement<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ident: &'a str, + ident_span: Span, + span_start: usize, + context: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, + ) -> Result<(), Error<'a>> { + self.push_rule_span(Rule::SingularExpr, lexer); + + context.unresolved.insert(ast::Dependency { + ident, + usage: ident_span, + }); + let arguments = self.arguments(lexer, context)?; + let span = lexer.span_from(span_start); + + block.stmts.push(ast::Statement { + kind: ast::StatementKind::Call { + function: ast::Ident { + name: ident, + span: ident_span, + }, + arguments, + }, + span, + }); + + self.pop_rule_span(lexer); + + Ok(()) + } + + fn function_call_or_assignment_statement<'a>( + &mut self, + lexer: &mut Lexer<'a>, + context: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, + ) -> Result<(), Error<'a>> { + let span_start = lexer.start_byte_offset(); + match lexer.peek() { + (Token::Word(name), span) => { + // A little hack for 2 token lookahead. + let cloned = lexer.clone(); + let _ = lexer.next(); + match lexer.peek() { + (Token::Paren('('), _) => { + self.function_statement(lexer, name, span, span_start, context, block) + } + _ => { + *lexer = cloned; + self.assignment_statement(lexer, context, block) + } + } + } + _ => self.assignment_statement(lexer, context, block), + } + } + + fn statement<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + block: &mut ast::Block<'a>, + ) -> Result<(), Error<'a>> { + self.push_rule_span(Rule::Statement, lexer); + match lexer.peek() { + (Token::Separator(';'), _) => { + let _ = lexer.next(); + self.pop_rule_span(lexer); + return Ok(()); + } + (Token::Paren('{'), _) => { + let (inner, span) = self.block(lexer, ctx)?; + block.stmts.push(ast::Statement { + kind: ast::StatementKind::Block(inner), + span, + }); + self.pop_rule_span(lexer); + return Ok(()); + } + (Token::Word(word), _) => { + let kind = match word { + "_" => { + let _ = lexer.next(); + lexer.expect(Token::Operation('='))?; + let expr = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Separator(';'))?; + + ast::StatementKind::Ignore(expr) + } + "let" => { + let _ = lexer.next(); + let name = lexer.next_ident()?; + + let given_ty = if lexer.skip(Token::Separator(':')) { + let ty = self.type_decl(lexer, ctx)?; + Some(ty) + } else { + None + }; + lexer.expect(Token::Operation('='))?; + let expr_id = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Separator(';'))?; + + let handle = ctx.declare_local(name)?; + ast::StatementKind::LocalDecl(ast::LocalDecl::Let(ast::Let { + name, + ty: given_ty, + init: expr_id, + handle, + })) + } + "var" => { + let _ = lexer.next(); + + let name = lexer.next_ident()?; + let ty = if lexer.skip(Token::Separator(':')) { + let ty = self.type_decl(lexer, ctx)?; + Some(ty) + } else { + None + }; + + let init = if lexer.skip(Token::Operation('=')) { + let init = self.general_expression(lexer, ctx)?; + Some(init) + } else { + None + }; + + lexer.expect(Token::Separator(';'))?; + + let handle = ctx.declare_local(name)?; + ast::StatementKind::LocalDecl(ast::LocalDecl::Var(ast::LocalVariable { + name, + ty, + init, + handle, + })) + } + "return" => { + let _ = lexer.next(); + let value = if lexer.peek().0 != Token::Separator(';') { + let handle = self.general_expression(lexer, ctx)?; + Some(handle) + } else { + None + }; + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Return { value } + } + "if" => { + let _ = lexer.next(); + let condition = self.general_expression(lexer, ctx)?; + + let accept = self.block(lexer, ctx)?.0; + + let mut elsif_stack = Vec::new(); + let mut elseif_span_start = lexer.start_byte_offset(); + let mut reject = loop { + if !lexer.skip(Token::Word("else")) { + break ast::Block::default(); + } + + if !lexer.skip(Token::Word("if")) { + // ... else { ... } + break self.block(lexer, ctx)?.0; + } + + // ... else if (...) { ... } + let other_condition = self.general_expression(lexer, ctx)?; + let other_block = self.block(lexer, ctx)?; + elsif_stack.push((elseif_span_start, other_condition, other_block)); + elseif_span_start = lexer.start_byte_offset(); + }; + + // reverse-fold the else-if blocks + //Note: we may consider uplifting this to the IR + for (other_span_start, other_cond, other_block) in + elsif_stack.into_iter().rev() + { + let sub_stmt = ast::StatementKind::If { + condition: other_cond, + accept: other_block.0, + reject, + }; + reject = ast::Block::default(); + let span = lexer.span_from(other_span_start); + reject.stmts.push(ast::Statement { + kind: sub_stmt, + span, + }) + } + + ast::StatementKind::If { + condition, + accept, + reject, + } + } + "switch" => { + let _ = lexer.next(); + let selector = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Paren('{'))?; + let mut cases = Vec::new(); + + loop { + // cases + default + match lexer.next() { + (Token::Word("case"), _) => { + // parse a list of values + let value = loop { + let value = self.switch_value(lexer, ctx)?; + if lexer.skip(Token::Separator(',')) { + if lexer.skip(Token::Separator(':')) { + break value; + } + } else { + lexer.skip(Token::Separator(':')); + break value; + } + cases.push(ast::SwitchCase { + value, + body: ast::Block::default(), + fall_through: true, + }); + }; + + let body = self.block(lexer, ctx)?.0; + + cases.push(ast::SwitchCase { + value, + body, + fall_through: false, + }); + } + (Token::Word("default"), _) => { + lexer.skip(Token::Separator(':')); + let body = self.block(lexer, ctx)?.0; + cases.push(ast::SwitchCase { + value: ast::SwitchValue::Default, + body, + fall_through: false, + }); + } + (Token::Paren('}'), _) => break, + (_, span) => { + return Err(Error::Unexpected(span, ExpectedToken::SwitchItem)) + } + } + } + + ast::StatementKind::Switch { selector, cases } + } + "loop" => self.r#loop(lexer, ctx)?, + "while" => { + let _ = lexer.next(); + let mut body = ast::Block::default(); + + let (condition, span) = lexer.capture_span(|lexer| { + let condition = self.general_expression(lexer, ctx)?; + Ok(condition) + })?; + let mut reject = ast::Block::default(); + reject.stmts.push(ast::Statement { + kind: ast::StatementKind::Break, + span, + }); + + body.stmts.push(ast::Statement { + kind: ast::StatementKind::If { + condition, + accept: ast::Block::default(), + reject, + }, + span, + }); + + let (block, span) = self.block(lexer, ctx)?; + body.stmts.push(ast::Statement { + kind: ast::StatementKind::Block(block), + span, + }); + + ast::StatementKind::Loop { + body, + continuing: ast::Block::default(), + break_if: None, + } + } + "for" => { + let _ = lexer.next(); + lexer.expect(Token::Paren('('))?; + + ctx.local_table.push_scope(); + + if !lexer.skip(Token::Separator(';')) { + let num_statements = block.stmts.len(); + let (_, span) = { + let ctx = &mut *ctx; + let block = &mut *block; + lexer.capture_span(|lexer| self.statement(lexer, ctx, block))? + }; + + if block.stmts.len() != num_statements { + match block.stmts.last().unwrap().kind { + ast::StatementKind::Call { .. } + | ast::StatementKind::Assign { .. } + | ast::StatementKind::LocalDecl(_) => {} + _ => return Err(Error::InvalidForInitializer(span)), + } + } + }; + + let mut body = ast::Block::default(); + if !lexer.skip(Token::Separator(';')) { + let (condition, span) = lexer.capture_span(|lexer| { + let condition = self.general_expression(lexer, ctx)?; + lexer.expect(Token::Separator(';'))?; + Ok(condition) + })?; + let mut reject = ast::Block::default(); + reject.stmts.push(ast::Statement { + kind: ast::StatementKind::Break, + span, + }); + body.stmts.push(ast::Statement { + kind: ast::StatementKind::If { + condition, + accept: ast::Block::default(), + reject, + }, + span, + }); + }; + + let mut continuing = ast::Block::default(); + if !lexer.skip(Token::Paren(')')) { + self.function_call_or_assignment_statement( + lexer, + ctx, + &mut continuing, + )?; + lexer.expect(Token::Paren(')'))?; + } + + let (block, span) = self.block(lexer, ctx)?; + body.stmts.push(ast::Statement { + kind: ast::StatementKind::Block(block), + span, + }); + + ctx.local_table.pop_scope(); + + ast::StatementKind::Loop { + body, + continuing, + break_if: None, + } + } + "break" => { + let (_, span) = lexer.next(); + // Check if the next token is an `if`, this indicates + // that the user tried to type out a `break if` which + // is illegal in this position. + let (peeked_token, peeked_span) = lexer.peek(); + if let Token::Word("if") = peeked_token { + let span = span.until(&peeked_span); + return Err(Error::InvalidBreakIf(span)); + } + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Break + } + "continue" => { + let _ = lexer.next(); + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Continue + } + "discard" => { + let _ = lexer.next(); + lexer.expect(Token::Separator(';'))?; + ast::StatementKind::Kill + } + // assignment or a function call + _ => { + self.function_call_or_assignment_statement(lexer, ctx, block)?; + lexer.expect(Token::Separator(';'))?; + self.pop_rule_span(lexer); + return Ok(()); + } + }; + + let span = self.pop_rule_span(lexer); + block.stmts.push(ast::Statement { kind, span }); + } + _ => { + self.assignment_statement(lexer, ctx, block)?; + lexer.expect(Token::Separator(';'))?; + self.pop_rule_span(lexer); + } + } + Ok(()) + } + + fn r#loop<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<ast::StatementKind<'a>, Error<'a>> { + let _ = lexer.next(); + let mut body = ast::Block::default(); + let mut continuing = ast::Block::default(); + let mut break_if = None; + + lexer.expect(Token::Paren('{'))?; + + ctx.local_table.push_scope(); + + loop { + if lexer.skip(Token::Word("continuing")) { + // Branch for the `continuing` block, this must be + // the last thing in the loop body + + // Expect a opening brace to start the continuing block + lexer.expect(Token::Paren('{'))?; + loop { + if lexer.skip(Token::Word("break")) { + // Branch for the `break if` statement, this statement + // has the form `break if <expr>;` and must be the last + // statement in a continuing block + + // The break must be followed by an `if` to form + // the break if + lexer.expect(Token::Word("if"))?; + + let condition = self.general_expression(lexer, ctx)?; + // Set the condition of the break if to the newly parsed + // expression + break_if = Some(condition); + + // Expect a semicolon to close the statement + lexer.expect(Token::Separator(';'))?; + // Expect a closing brace to close the continuing block, + // since the break if must be the last statement + lexer.expect(Token::Paren('}'))?; + // Stop parsing the continuing block + break; + } else if lexer.skip(Token::Paren('}')) { + // If we encounter a closing brace it means we have reached + // the end of the continuing block and should stop processing + break; + } else { + // Otherwise try to parse a statement + self.statement(lexer, ctx, &mut continuing)?; + } + } + // Since the continuing block must be the last part of the loop body, + // we expect to see a closing brace to end the loop body + lexer.expect(Token::Paren('}'))?; + break; + } + if lexer.skip(Token::Paren('}')) { + // If we encounter a closing brace it means we have reached + // the end of the loop body and should stop processing + break; + } + // Otherwise try to parse a statement + self.statement(lexer, ctx, &mut body)?; + } + + ctx.local_table.pop_scope(); + + Ok(ast::StatementKind::Loop { + body, + continuing, + break_if, + }) + } + + /// compound_statement + fn block<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<(ast::Block<'a>, Span), Error<'a>> { + self.push_rule_span(Rule::Block, lexer); + + ctx.local_table.push_scope(); + + lexer.expect(Token::Paren('{'))?; + let mut block = ast::Block::default(); + while !lexer.skip(Token::Paren('}')) { + self.statement(lexer, ctx, &mut block)?; + } + + ctx.local_table.pop_scope(); + + let span = self.pop_rule_span(lexer); + Ok((block, span)) + } + + fn varying_binding<'a>( + &mut self, + lexer: &mut Lexer<'a>, + ctx: &mut ExpressionContext<'a, '_, '_>, + ) -> Result<Option<ast::Binding<'a>>, Error<'a>> { + let mut bind_parser = BindingParser::default(); + self.push_rule_span(Rule::Attribute, lexer); + + while lexer.skip(Token::Attribute) { + let (word, span) = lexer.next_ident_with_span()?; + bind_parser.parse(self, lexer, word, span, ctx)?; + } + + let span = self.pop_rule_span(lexer); + bind_parser.finish(span) + } + + fn function_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + out: &mut ast::TranslationUnit<'a>, + dependencies: &mut FastIndexSet<ast::Dependency<'a>>, + ) -> Result<ast::Function<'a>, Error<'a>> { + self.push_rule_span(Rule::FunctionDecl, lexer); + // read function name + let fun_name = lexer.next_ident()?; + + let mut locals = Arena::new(); + + let mut ctx = ExpressionContext { + expressions: &mut out.expressions, + local_table: &mut SymbolTable::default(), + locals: &mut locals, + types: &mut out.types, + unresolved: dependencies, + }; + + // start a scope that contains arguments as well as the function body + ctx.local_table.push_scope(); + + // read parameter list + let mut arguments = Vec::new(); + lexer.expect(Token::Paren('('))?; + let mut ready = true; + while !lexer.skip(Token::Paren(')')) { + if !ready { + return Err(Error::Unexpected( + lexer.next().1, + ExpectedToken::Token(Token::Separator(',')), + )); + } + let binding = self.varying_binding(lexer, &mut ctx)?; + + let param_name = lexer.next_ident()?; + + lexer.expect(Token::Separator(':'))?; + let param_type = self.type_decl(lexer, &mut ctx)?; + + let handle = ctx.declare_local(param_name)?; + arguments.push(ast::FunctionArgument { + name: param_name, + ty: param_type, + binding, + handle, + }); + ready = lexer.skip(Token::Separator(',')); + } + // read return type + let result = if lexer.skip(Token::Arrow) && !lexer.skip(Token::Word("void")) { + let binding = self.varying_binding(lexer, &mut ctx)?; + let ty = self.type_decl(lexer, &mut ctx)?; + Some(ast::FunctionResult { ty, binding }) + } else { + None + }; + + // do not use `self.block` here, since we must not push a new scope + lexer.expect(Token::Paren('{'))?; + let mut body = ast::Block::default(); + while !lexer.skip(Token::Paren('}')) { + self.statement(lexer, &mut ctx, &mut body)?; + } + + ctx.local_table.pop_scope(); + + let fun = ast::Function { + entry_point: None, + name: fun_name, + arguments, + result, + body, + locals, + }; + + // done + self.pop_rule_span(lexer); + + Ok(fun) + } + + fn global_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + out: &mut ast::TranslationUnit<'a>, + ) -> Result<(), Error<'a>> { + // read attributes + let mut binding = None; + let mut stage = ParsedAttribute::default(); + let mut compute_span = Span::new(0, 0); + let mut workgroup_size = ParsedAttribute::default(); + let mut early_depth_test = ParsedAttribute::default(); + let (mut bind_index, mut bind_group) = + (ParsedAttribute::default(), ParsedAttribute::default()); + + let mut dependencies = FastIndexSet::default(); + let mut ctx = ExpressionContext { + expressions: &mut out.expressions, + local_table: &mut SymbolTable::default(), + locals: &mut Arena::new(), + types: &mut out.types, + unresolved: &mut dependencies, + }; + + self.push_rule_span(Rule::Attribute, lexer); + while lexer.skip(Token::Attribute) { + match lexer.next_ident_with_span()? { + ("binding", name_span) => { + lexer.expect(Token::Paren('('))?; + bind_index.set(self.general_expression(lexer, &mut ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + ("group", name_span) => { + lexer.expect(Token::Paren('('))?; + bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + ("vertex", name_span) => { + stage.set(crate::ShaderStage::Vertex, name_span)?; + } + ("fragment", name_span) => { + stage.set(crate::ShaderStage::Fragment, name_span)?; + } + ("compute", name_span) => { + stage.set(crate::ShaderStage::Compute, name_span)?; + compute_span = name_span; + } + ("workgroup_size", name_span) => { + lexer.expect(Token::Paren('('))?; + let mut new_workgroup_size = [None; 3]; + for (i, size) in new_workgroup_size.iter_mut().enumerate() { + *size = Some(self.general_expression(lexer, &mut ctx)?); + match lexer.next() { + (Token::Paren(')'), _) => break, + (Token::Separator(','), _) if i != 2 => (), + other => { + return Err(Error::Unexpected( + other.1, + ExpectedToken::WorkgroupSizeSeparator, + )) + } + } + } + workgroup_size.set(new_workgroup_size, name_span)?; + } + ("early_depth_test", name_span) => { + let conservative = if lexer.skip(Token::Paren('(')) { + let (ident, ident_span) = lexer.next_ident_with_span()?; + let value = conv::map_conservative_depth(ident, ident_span)?; + lexer.expect(Token::Paren(')'))?; + Some(value) + } else { + None + }; + early_depth_test.set(crate::EarlyDepthTest { conservative }, name_span)?; + } + (_, word_span) => return Err(Error::UnknownAttribute(word_span)), + } + } + + let attrib_span = self.pop_rule_span(lexer); + match (bind_group.value, bind_index.value) { + (Some(group), Some(index)) => { + binding = Some(ast::ResourceBinding { + group, + binding: index, + }); + } + (Some(_), None) => return Err(Error::MissingAttribute("binding", attrib_span)), + (None, Some(_)) => return Err(Error::MissingAttribute("group", attrib_span)), + (None, None) => {} + } + + // read item + let start = lexer.start_byte_offset(); + let kind = match lexer.next() { + (Token::Separator(';'), _) => None, + (Token::Word("struct"), _) => { + let name = lexer.next_ident()?; + + let members = self.struct_body(lexer, &mut ctx)?; + Some(ast::GlobalDeclKind::Struct(ast::Struct { name, members })) + } + (Token::Word("alias"), _) => { + let name = lexer.next_ident()?; + + lexer.expect(Token::Operation('='))?; + let ty = self.type_decl(lexer, &mut ctx)?; + lexer.expect(Token::Separator(';'))?; + Some(ast::GlobalDeclKind::Type(ast::TypeAlias { name, ty })) + } + (Token::Word("const"), _) => { + let name = lexer.next_ident()?; + + let ty = if lexer.skip(Token::Separator(':')) { + let ty = self.type_decl(lexer, &mut ctx)?; + Some(ty) + } else { + None + }; + + lexer.expect(Token::Operation('='))?; + let init = self.general_expression(lexer, &mut ctx)?; + lexer.expect(Token::Separator(';'))?; + + Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init })) + } + (Token::Word("var"), _) => { + let mut var = self.variable_decl(lexer, &mut ctx)?; + var.binding = binding.take(); + Some(ast::GlobalDeclKind::Var(var)) + } + (Token::Word("fn"), _) => { + let function = self.function_decl(lexer, out, &mut dependencies)?; + Some(ast::GlobalDeclKind::Fn(ast::Function { + entry_point: if let Some(stage) = stage.value { + if stage == ShaderStage::Compute && workgroup_size.value.is_none() { + return Err(Error::MissingWorkgroupSize(compute_span)); + } + Some(ast::EntryPoint { + stage, + early_depth_test: early_depth_test.value, + workgroup_size: workgroup_size.value, + }) + } else { + None + }, + ..function + })) + } + (Token::End, _) => return Ok(()), + other => return Err(Error::Unexpected(other.1, ExpectedToken::GlobalItem)), + }; + + if let Some(kind) = kind { + out.decls.append( + ast::GlobalDecl { kind, dependencies }, + lexer.span_from(start), + ); + } + + if !self.rules.is_empty() { + log::error!("Reached the end of global decl, but rule stack is not empty"); + log::error!("Rules: {:?}", self.rules); + return Err(Error::Internal("rule stack is not empty")); + }; + + match binding { + None => Ok(()), + Some(_) => Err(Error::Internal("we had the attribute but no var?")), + } + } + + pub fn parse<'a>(&mut self, source: &'a str) -> Result<ast::TranslationUnit<'a>, Error<'a>> { + self.reset(); + + let mut lexer = Lexer::new(source); + let mut tu = ast::TranslationUnit::default(); + loop { + match self.global_decl(&mut lexer, &mut tu) { + Err(error) => return Err(error), + Ok(()) => { + if lexer.peek().0 == Token::End { + break; + } + } + } + } + + Ok(tu) + } +} diff --git a/third_party/rust/naga/src/front/wgsl/parse/number.rs b/third_party/rust/naga/src/front/wgsl/parse/number.rs new file mode 100644 index 0000000000..7b09ac59bb --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/parse/number.rs @@ -0,0 +1,420 @@ +use crate::front::wgsl::error::NumberError; +use crate::front::wgsl::parse::lexer::Token; + +/// When using this type assume no Abstract Int/Float for now +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Number { + /// Abstract Int (-2^63 ≤ i < 2^63) + AbstractInt(i64), + /// Abstract Float (IEEE-754 binary64) + AbstractFloat(f64), + /// Concrete i32 + I32(i32), + /// Concrete u32 + U32(u32), + /// Concrete f32 + F32(f32), + /// Concrete f64 + F64(f64), +} + +pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) { + let (result, rest) = parse(input); + (Token::Number(result), rest) +} + +enum Kind { + Int(IntKind), + Float(FloatKind), +} + +enum IntKind { + I32, + U32, +} + +#[derive(Debug)] +enum FloatKind { + F16, + F32, + F64, +} + +// The following regexes (from the WGSL spec) will be matched: + +// int_literal: +// | / 0 [iu]? / +// | / [1-9][0-9]* [iu]? / +// | / 0[xX][0-9a-fA-F]+ [iu]? / + +// decimal_float_literal: +// | / 0 [fh] / +// | / [1-9][0-9]* [fh] / +// | / [0-9]* \.[0-9]+ ([eE][+-]?[0-9]+)? [fh]? / +// | / [0-9]+ \.[0-9]* ([eE][+-]?[0-9]+)? [fh]? / +// | / [0-9]+ [eE][+-]?[0-9]+ [fh]? / + +// hex_float_literal: +// | / 0[xX][0-9a-fA-F]* \.[0-9a-fA-F]+ ([pP][+-]?[0-9]+ [fh]?)? / +// | / 0[xX][0-9a-fA-F]+ \.[0-9a-fA-F]* ([pP][+-]?[0-9]+ [fh]?)? / +// | / 0[xX][0-9a-fA-F]+ [pP][+-]?[0-9]+ [fh]? / + +// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing +// (?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?)) + +// Leading signs are handled as unary operators. + +fn parse(input: &str) -> (Result<Number, NumberError>, &str) { + /// returns `true` and consumes `X` bytes from the given byte buffer + /// if the given `X` nr of patterns are found at the start of the buffer + macro_rules! consume { + ($bytes:ident, $($pattern:pat),*) => { + match $bytes { + &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true }, + _ => false, + } + }; + } + + /// consumes one byte from the given byte buffer + /// if one of the given patterns are found at the start of the buffer + /// returning the corresponding expr for the matched pattern + macro_rules! consume_map { + ($bytes:ident, [$( $($pattern:pat_param),* => $to:expr),* $(,)?]) => { + match $bytes { + $( &[ $($pattern),*, ref rest @ ..] => { $bytes = rest; Some($to) }, )* + _ => None, + } + }; + } + + /// consumes all consecutive bytes matched by the `0-9` pattern from the given byte buffer + /// returning the number of consumed bytes + macro_rules! consume_dec_digits { + ($bytes:ident) => {{ + let start_len = $bytes.len(); + while let &[b'0'..=b'9', ref rest @ ..] = $bytes { + $bytes = rest; + } + start_len - $bytes.len() + }}; + } + + /// consumes all consecutive bytes matched by the `0-9 | a-f | A-F` pattern from the given byte buffer + /// returning the number of consumed bytes + macro_rules! consume_hex_digits { + ($bytes:ident) => {{ + let start_len = $bytes.len(); + while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes { + $bytes = rest; + } + start_len - $bytes.len() + }}; + } + + macro_rules! consume_float_suffix { + ($bytes:ident) => { + consume_map!($bytes, [ + b'h' => FloatKind::F16, + b'f' => FloatKind::F32, + b'l', b'f' => FloatKind::F64, + ]) + }; + } + + /// maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str` + macro_rules! rest_to_str { + ($bytes:ident) => { + &input[input.len() - $bytes.len()..] + }; + } + + struct ExtractSubStr<'a>(&'a str); + + impl<'a> ExtractSubStr<'a> { + /// given an `input` and a `start` (tail of the `input`) + /// creates a new [`ExtractSubStr`](`Self`) + fn start(input: &'a str, start: &'a [u8]) -> Self { + let start = input.len() - start.len(); + Self(&input[start..]) + } + /// given an `end` (tail of the initial `input`) + /// returns a substring of `input` + fn end(&self, end: &'a [u8]) -> &'a str { + let end = self.0.len() - end.len(); + &self.0[..end] + } + } + + let mut bytes = input.as_bytes(); + + let general_extract = ExtractSubStr::start(input, bytes); + + if consume!(bytes, b'0', b'x' | b'X') { + let digits_extract = ExtractSubStr::start(input, bytes); + + let consumed = consume_hex_digits!(bytes); + + if consume!(bytes, b'.') { + let consumed_after_period = consume_hex_digits!(bytes); + + if consumed + consumed_after_period == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let significand = general_extract.end(bytes); + + if consume!(bytes, b'p' | b'P') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let number = general_extract.end(bytes); + + let kind = consume_float_suffix!(bytes); + + (parse_hex_float(number, kind), rest_to_str!(bytes)) + } else { + ( + parse_hex_float_missing_exponent(significand, None), + rest_to_str!(bytes), + ) + } + } else { + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let significand = general_extract.end(bytes); + let digits = digits_extract.end(bytes); + + let exp_extract = ExtractSubStr::start(input, bytes); + + if consume!(bytes, b'p' | b'P') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let exponent = exp_extract.end(bytes); + + let kind = consume_float_suffix!(bytes); + + ( + parse_hex_float_missing_period(significand, exponent, kind), + rest_to_str!(bytes), + ) + } else { + let kind = consume_map!(bytes, [b'i' => IntKind::I32, b'u' => IntKind::U32]); + + (parse_hex_int(digits, kind), rest_to_str!(bytes)) + } + } + } else { + let is_first_zero = bytes.first() == Some(&b'0'); + + let consumed = consume_dec_digits!(bytes); + + if consume!(bytes, b'.') { + let consumed_after_period = consume_dec_digits!(bytes); + + if consumed + consumed_after_period == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + if consume!(bytes, b'e' | b'E') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + } + + let number = general_extract.end(bytes); + + let kind = consume_float_suffix!(bytes); + + (parse_dec_float(number, kind), rest_to_str!(bytes)) + } else { + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + if consume!(bytes, b'e' | b'E') { + consume!(bytes, b'+' | b'-'); + let consumed = consume_dec_digits!(bytes); + + if consumed == 0 { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let number = general_extract.end(bytes); + + let kind = consume_float_suffix!(bytes); + + (parse_dec_float(number, kind), rest_to_str!(bytes)) + } else { + // make sure the multi-digit numbers don't start with zero + if consumed > 1 && is_first_zero { + return (Err(NumberError::Invalid), rest_to_str!(bytes)); + } + + let digits = general_extract.end(bytes); + + let kind = consume_map!(bytes, [ + b'i' => Kind::Int(IntKind::I32), + b'u' => Kind::Int(IntKind::U32), + b'h' => Kind::Float(FloatKind::F16), + b'f' => Kind::Float(FloatKind::F32), + b'l', b'f' => Kind::Float(FloatKind::F64), + ]); + + (parse_dec(digits, kind), rest_to_str!(bytes)) + } + } + } +} + +fn parse_hex_float_missing_exponent( + // format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) + significand: &str, + kind: Option<FloatKind>, +) -> Result<Number, NumberError> { + let hexf_input = format!("{}{}", significand, "p0"); + parse_hex_float(&hexf_input, kind) +} + +fn parse_hex_float_missing_period( + // format: 0[xX] [0-9a-fA-F]+ + significand: &str, + // format: [pP][+-]?[0-9]+ + exponent: &str, + kind: Option<FloatKind>, +) -> Result<Number, NumberError> { + let hexf_input = format!("{significand}.{exponent}"); + parse_hex_float(&hexf_input, kind) +} + +fn parse_hex_int( + // format: [0-9a-fA-F]+ + digits: &str, + kind: Option<IntKind>, +) -> Result<Number, NumberError> { + parse_int(digits, kind, 16) +} + +fn parse_dec( + // format: ( [0-9] | [1-9][0-9]+ ) + digits: &str, + kind: Option<Kind>, +) -> Result<Number, NumberError> { + match kind { + None => parse_int(digits, None, 10), + Some(Kind::Int(kind)) => parse_int(digits, Some(kind), 10), + Some(Kind::Float(kind)) => parse_dec_float(digits, Some(kind)), + } +} + +// Float parsing notes + +// The following chapters of IEEE 754-2019 are relevant: +// +// 7.4 Overflow (largest finite number is exceeded by what would have been +// the rounded floating-point result were the exponent range unbounded) +// +// 7.5 Underflow (tiny non-zero result is detected; +// for decimal formats tininess is detected before rounding when a non-zero result +// computed as though both the exponent range and the precision were unbounded +// would lie strictly between 2^−126) +// +// 7.6 Inexact (rounded result differs from what would have been computed +// were both exponent range and precision unbounded) + +// The WGSL spec requires us to error: +// on overflow for decimal floating point literals +// on overflow and inexact for hexadecimal floating point literals +// (underflow is not mentioned) + +// hexf_parse errors on overflow, underflow, inexact +// rust std lib float from str handles overflow, underflow, inexact transparently (rounds and will not error) + +// Therefore we only check for overflow manually for decimal floating point literals + +// input format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+ +fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> { + match kind { + None => match hexf_parse::parse_hexf64(input, false) { + Ok(num) => Ok(Number::AbstractFloat(num)), + // can only be ParseHexfErrorKind::Inexact but we can't check since it's private + _ => Err(NumberError::NotRepresentable), + }, + Some(FloatKind::F16) => Err(NumberError::UnimplementedF16), + Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) { + Ok(num) => Ok(Number::F32(num)), + // can only be ParseHexfErrorKind::Inexact but we can't check since it's private + _ => Err(NumberError::NotRepresentable), + }, + Some(FloatKind::F64) => match hexf_parse::parse_hexf64(input, false) { + Ok(num) => Ok(Number::F64(num)), + // can only be ParseHexfErrorKind::Inexact but we can't check since it's private + _ => Err(NumberError::NotRepresentable), + }, + } +} + +// input format: ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)? +// | [0-9]+ [eE][+-]?[0-9]+ +fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> { + match kind { + None => { + let num = input.parse::<f64>().unwrap(); // will never fail + num.is_finite() + .then_some(Number::AbstractFloat(num)) + .ok_or(NumberError::NotRepresentable) + } + Some(FloatKind::F32) => { + let num = input.parse::<f32>().unwrap(); // will never fail + num.is_finite() + .then_some(Number::F32(num)) + .ok_or(NumberError::NotRepresentable) + } + Some(FloatKind::F64) => { + let num = input.parse::<f64>().unwrap(); // will never fail + num.is_finite() + .then_some(Number::F64(num)) + .ok_or(NumberError::NotRepresentable) + } + Some(FloatKind::F16) => Err(NumberError::UnimplementedF16), + } +} + +fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, NumberError> { + fn map_err(e: core::num::ParseIntError) -> NumberError { + match *e.kind() { + core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => { + NumberError::NotRepresentable + } + _ => unreachable!(), + } + } + match kind { + None => match i64::from_str_radix(input, radix) { + Ok(num) => Ok(Number::AbstractInt(num)), + Err(e) => Err(map_err(e)), + }, + Some(IntKind::I32) => match i32::from_str_radix(input, radix) { + Ok(num) => Ok(Number::I32(num)), + Err(e) => Err(map_err(e)), + }, + Some(IntKind::U32) => match u32::from_str_radix(input, radix) { + Ok(num) => Ok(Number::U32(num)), + Err(e) => Err(map_err(e)), + }, + } +} diff --git a/third_party/rust/naga/src/front/wgsl/tests.rs b/third_party/rust/naga/src/front/wgsl/tests.rs new file mode 100644 index 0000000000..eb2f8a2eb3 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/tests.rs @@ -0,0 +1,637 @@ +use super::parse_str; + +#[test] +fn parse_comment() { + parse_str( + "// + //// + ///////////////////////////////////////////////////////// asda + //////////////////// dad ////////// / + ///////////////////////////////////////////////////////////////////////////////////////////////////// + // + ", + ) + .unwrap(); +} + +#[test] +fn parse_types() { + parse_str("const a : i32 = 2;").unwrap(); + assert!(parse_str("const a : x32 = 2;").is_err()); + parse_str("var t: texture_2d<f32>;").unwrap(); + parse_str("var t: texture_cube_array<i32>;").unwrap(); + parse_str("var t: texture_multisampled_2d<u32>;").unwrap(); + parse_str("var t: texture_storage_1d<rgba8uint,write>;").unwrap(); + parse_str("var t: texture_storage_3d<r32float,read>;").unwrap(); +} + +#[test] +fn parse_type_inference() { + parse_str( + " + fn foo() { + let a = 2u; + let b: u32 = a; + var x = 3.; + var y = vec2<f32>(1, 2); + }", + ) + .unwrap(); + assert!(parse_str( + " + fn foo() { let c : i32 = 2.0; }", + ) + .is_err()); +} + +#[test] +fn parse_type_cast() { + parse_str( + " + const a : i32 = 2; + fn main() { + var x: f32 = f32(a); + x = f32(i32(a + 1) / 2); + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + let x: vec2<f32> = vec2<f32>(1.0, 2.0); + let y: vec2<u32> = vec2<u32>(x); + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + let x: vec2<f32> = vec2<f32>(0.0); + } + ", + ) + .unwrap(); + assert!(parse_str( + " + fn main() { + let x: vec2<f32> = vec2<f32>(0i, 0i); + } + ", + ) + .is_err()); +} + +#[test] +fn parse_struct() { + parse_str( + " + struct Foo { x: i32 } + struct Bar { + @size(16) x: vec2<i32>, + @align(16) y: f32, + @size(32) @align(128) z: vec3<f32>, + }; + struct Empty {} + var<storage,read_write> s: Foo; + ", + ) + .unwrap(); +} + +#[test] +fn parse_standard_fun() { + parse_str( + " + fn main() { + var x: i32 = min(max(1, 2), 3); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_statement() { + parse_str( + " + fn main() { + ; + {} + {;} + } + ", + ) + .unwrap(); + + parse_str( + " + fn foo() {} + fn bar() { foo(); } + ", + ) + .unwrap(); +} + +#[test] +fn parse_if() { + parse_str( + " + fn main() { + if true { + discard; + } else {} + if 0 != 1 {} + if false { + return; + } else if true { + return; + } else {} + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_parentheses_if() { + parse_str( + " + fn main() { + if (true) { + discard; + } else {} + if (0 != 1) {} + if (false) { + return; + } else if (true) { + return; + } else {} + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_loop() { + parse_str( + " + fn main() { + var i: i32 = 0; + loop { + if i == 1 { break; } + continuing { i = 1; } + } + loop { + if i == 0 { continue; } + break; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + var found: bool = false; + var i: i32 = 0; + while !found { + if i == 10 { + found = true; + } + + i = i + 1; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + while true { + break; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + var a: i32 = 0; + for(var i: i32 = 0; i < 4; i = i + 1) { + a = a + 2; + } + } + ", + ) + .unwrap(); + parse_str( + " + fn main() { + for(;;) { + break; + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_switch() { + parse_str( + " + fn main() { + var pos: f32; + switch (3) { + case 0, 1: { pos = 0.0; } + case 2: { pos = 1.0; } + default: { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_switch_optional_colon_in_case() { + parse_str( + " + fn main() { + var pos: f32; + switch (3) { + case 0, 1 { pos = 0.0; } + case 2 { pos = 1.0; } + default { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_switch_default_in_case() { + parse_str( + " + fn main() { + var pos: f32; + switch (3) { + case 0, 1: { pos = 0.0; } + case 2: {} + case default, 3: { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_parentheses_switch() { + parse_str( + " + fn main() { + var pos: f32; + switch pos > 1.0 { + default: { pos = 3.0; } + } + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_texture_load() { + parse_str( + " + var t: texture_3d<u32>; + fn foo() { + let r: vec4<u32> = textureLoad(t, vec3<u32>(0u, 1u, 2u), 1); + } + ", + ) + .unwrap(); + parse_str( + " + var t: texture_multisampled_2d_array<i32>; + fn foo() { + let r: vec4<i32> = textureLoad(t, vec2<i32>(10, 20), 2, 3); + } + ", + ) + .unwrap(); + parse_str( + " + var t: texture_storage_1d_array<r32float,read>; + fn foo() { + let r: vec4<f32> = textureLoad(t, 10, 2); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_texture_store() { + parse_str( + " + var t: texture_storage_2d<rgba8unorm,write>; + fn foo() { + textureStore(t, vec2<i32>(10, 20), vec4<f32>(0.0, 1.0, 2.0, 3.0)); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_texture_query() { + parse_str( + " + var t: texture_multisampled_2d_array<f32>; + fn foo() { + var dim: vec2<u32> = textureDimensions(t); + dim = textureDimensions(t, 0); + let layers: u32 = textureNumLayers(t); + let samples: u32 = textureNumSamples(t); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_postfix() { + parse_str( + "fn foo() { + let x: f32 = vec4<f32>(1.0, 2.0, 3.0, 4.0).xyz.rgbr.aaaa.wz.g; + let y: f32 = fract(vec2<f32>(0.5, x)).x; + }", + ) + .unwrap(); +} + +#[test] +fn parse_expressions() { + parse_str("fn foo() { + let x: f32 = select(0.0, 1.0, true); + let y: vec2<f32> = select(vec2<f32>(1.0, 1.0), vec2<f32>(x, x), vec2<bool>(x < 0.5, x > 0.5)); + let z: bool = !(0.0 == 1.0); + }").unwrap(); +} + +#[test] +fn binary_expression_mixed_scalar_and_vector_operands() { + for (operand, expect_splat) in [ + ('<', false), + ('>', false), + ('&', false), + ('|', false), + ('+', true), + ('-', true), + ('*', false), + ('/', true), + ('%', true), + ] { + let module = parse_str(&format!( + " + @fragment + fn main(@location(0) some_vec: vec3<f32>) -> @location(0) vec4<f32> {{ + if (all(1.0 {operand} some_vec)) {{ + return vec4(0.0); + }} + return vec4(1.0); + }} + " + )) + .unwrap(); + + let expressions = &&module.entry_points[0].function.expressions; + + let found_expressions = expressions + .iter() + .filter(|&(_, e)| { + if let crate::Expression::Binary { left, .. } = *e { + matches!( + (expect_splat, &expressions[left]), + (false, &crate::Expression::Literal(crate::Literal::F32(..))) + | (true, &crate::Expression::Splat { .. }) + ) + } else { + false + } + }) + .count(); + + assert_eq!( + found_expressions, + 1, + "expected `{operand}` expression {} splat", + if expect_splat { "with" } else { "without" } + ); + } + + let module = parse_str( + "@fragment + fn main(mat: mat3x3<f32>) { + let vec = vec3<f32>(1.0, 1.0, 1.0); + let result = mat / vec; + }", + ) + .unwrap(); + let expressions = &&module.entry_points[0].function.expressions; + let found_splat = expressions.iter().any(|(_, e)| { + if let crate::Expression::Binary { left, .. } = *e { + matches!(&expressions[left], &crate::Expression::Splat { .. }) + } else { + false + } + }); + assert!(!found_splat, "'mat / vec' should not be splatted"); +} + +#[test] +fn parse_pointers() { + parse_str( + "fn foo(a: ptr<private, f32>) -> f32 { return *a; } + fn bar() { + var x: f32 = 1.0; + let px = &x; + let py = foo(px); + }", + ) + .unwrap(); +} + +#[test] +fn parse_struct_instantiation() { + parse_str( + " + struct Foo { + a: f32, + b: vec3<f32>, + } + + @fragment + fn fs_main() { + var foo: Foo = Foo(0.0, vec3<f32>(0.0, 1.0, 42.0)); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_array_length() { + parse_str( + " + struct Foo { + data: array<u32> + } // this is used as both input and output for convenience + + @group(0) @binding(0) + var<storage> foo: Foo; + + @group(0) @binding(1) + var<storage> bar: array<u32>; + + fn baz() { + var x: u32 = arrayLength(foo.data); + var y: u32 = arrayLength(bar); + } + ", + ) + .unwrap(); +} + +#[test] +fn parse_storage_buffers() { + parse_str( + " + @group(0) @binding(0) + var<storage> foo: array<u32>; + ", + ) + .unwrap(); + parse_str( + " + @group(0) @binding(0) + var<storage,read> foo: array<u32>; + ", + ) + .unwrap(); + parse_str( + " + @group(0) @binding(0) + var<storage,write> foo: array<u32>; + ", + ) + .unwrap(); + parse_str( + " + @group(0) @binding(0) + var<storage,read_write> foo: array<u32>; + ", + ) + .unwrap(); +} + +#[test] +fn parse_alias() { + parse_str( + " + alias Vec4 = vec4<f32>; + ", + ) + .unwrap(); +} + +#[test] +fn parse_texture_load_store_expecting_four_args() { + for (func, texture) in [ + ( + "textureStore", + "texture_storage_2d_array<rg11b10float, write>", + ), + ("textureLoad", "texture_2d_array<i32>"), + ] { + let error = parse_str(&format!( + " + @group(0) @binding(0) var tex_los_res: {texture}; + @compute + @workgroup_size(1) + fn main(@builtin(global_invocation_id) id: vec3<u32>) {{ + var color = vec4(1, 1, 1, 1); + {func}(tex_los_res, id, color); + }} + " + )) + .unwrap_err(); + assert_eq!( + error.message(), + "wrong number of arguments: expected 4, found 3" + ); + } +} + +#[test] +fn parse_repeated_attributes() { + use crate::{ + front::wgsl::{error::Error, Frontend}, + Span, + }; + + let template_vs = "@vertex fn vs() -> __REPLACE__ vec4<f32> { return vec4<f32>(0.0); }"; + let template_struct = "struct A { __REPLACE__ data: vec3<f32> }"; + let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array<i32>;"; + let template_stage = "__REPLACE__ fn vs() -> vec4<f32> { return vec4<f32>(0.0); }"; + for (attribute, template) in [ + ("align(16)", template_struct), + ("binding(0)", template_resource), + ("builtin(position)", template_vs), + ("compute", template_stage), + ("fragment", template_stage), + ("group(0)", template_resource), + ("interpolate(flat)", template_vs), + ("invariant", template_vs), + ("location(0)", template_vs), + ("size(16)", template_struct), + ("vertex", template_stage), + ("early_depth_test(less_equal)", template_resource), + ("workgroup_size(1)", template_stage), + ] { + let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}")); + let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32; + let span_start = shader.rfind(attribute).unwrap() as u32; + let span_end = span_start + name_length; + let expected_span = Span::new(span_start, span_end); + + let result = Frontend::new().inner(&shader); + assert!(matches!( + result.unwrap_err(), + Error::RepeatedAttribute(span) if span == expected_span + )); + } +} + +#[test] +fn parse_missing_workgroup_size() { + use crate::{ + front::wgsl::{error::Error, Frontend}, + Span, + }; + + let shader = "@compute fn vs() -> vec4<f32> { return vec4<f32>(0.0); }"; + let result = Frontend::new().inner(shader); + assert!(matches!( + result.unwrap_err(), + Error::MissingWorkgroupSize(span) if span == Span::new(1, 8) + )); +} diff --git a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs new file mode 100644 index 0000000000..c8331ace09 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs @@ -0,0 +1,283 @@ +//! Producing the WGSL forms of types, for use in error messages. + +use crate::proc::GlobalCtx; +use crate::Handle; + +impl crate::proc::TypeResolution { + pub fn to_wgsl(&self, gctx: &GlobalCtx) -> String { + match *self { + crate::proc::TypeResolution::Handle(handle) => handle.to_wgsl(gctx), + crate::proc::TypeResolution::Value(ref inner) => inner.to_wgsl(gctx), + } + } +} + +impl Handle<crate::Type> { + /// Formats the type as it is written in wgsl. + /// + /// For example `vec3<f32>`. + pub fn to_wgsl(self, gctx: &GlobalCtx) -> String { + let ty = &gctx.types[self]; + match ty.name { + Some(ref name) => name.clone(), + None => ty.inner.to_wgsl(gctx), + } + } +} + +impl crate::TypeInner { + /// Formats the type as it is written in wgsl. + /// + /// For example `vec3<f32>`. + /// + /// Note: `TypeInner::Struct` doesn't include the name of the + /// struct type. Therefore this method will simply return "struct" + /// for them. + pub fn to_wgsl(&self, gctx: &GlobalCtx) -> String { + use crate::TypeInner as Ti; + + match *self { + Ti::Scalar(scalar) => scalar.to_wgsl(), + Ti::Vector { size, scalar } => { + format!("vec{}<{}>", size as u32, scalar.to_wgsl()) + } + Ti::Matrix { + columns, + rows, + scalar, + } => { + format!( + "mat{}x{}<{}>", + columns as u32, + rows as u32, + scalar.to_wgsl(), + ) + } + Ti::Atomic(scalar) => { + format!("atomic<{}>", scalar.to_wgsl()) + } + Ti::Pointer { base, .. } => { + let name = base.to_wgsl(gctx); + format!("ptr<{name}>") + } + Ti::ValuePointer { scalar, .. } => { + format!("ptr<{}>", scalar.to_wgsl()) + } + Ti::Array { base, size, .. } => { + let base = base.to_wgsl(gctx); + match size { + crate::ArraySize::Constant(size) => format!("array<{base}, {size}>"), + crate::ArraySize::Dynamic => format!("array<{base}>"), + } + } + Ti::Struct { .. } => { + // TODO: Actually output the struct? + "struct".to_string() + } + Ti::Image { + dim, + arrayed, + class, + } => { + let dim_suffix = match dim { + crate::ImageDimension::D1 => "_1d", + crate::ImageDimension::D2 => "_2d", + crate::ImageDimension::D3 => "_3d", + crate::ImageDimension::Cube => "_cube", + }; + let array_suffix = if arrayed { "_array" } else { "" }; + + let class_suffix = match class { + crate::ImageClass::Sampled { multi: true, .. } => "_multisampled", + crate::ImageClass::Depth { multi: false } => "_depth", + crate::ImageClass::Depth { multi: true } => "_depth_multisampled", + crate::ImageClass::Sampled { multi: false, .. } + | crate::ImageClass::Storage { .. } => "", + }; + + let type_in_brackets = match class { + crate::ImageClass::Sampled { kind, .. } => { + // Note: The only valid widths are 4 bytes wide. + // The lexer has already verified this, so we can safely assume it here. + // https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type + let element_type = crate::Scalar { kind, width: 4 }.to_wgsl(); + format!("<{element_type}>") + } + crate::ImageClass::Depth { multi: _ } => String::new(), + crate::ImageClass::Storage { format, access } => { + if access.contains(crate::StorageAccess::STORE) { + format!("<{},write>", format.to_wgsl()) + } else { + format!("<{}>", format.to_wgsl()) + } + } + }; + + format!("texture{class_suffix}{dim_suffix}{array_suffix}{type_in_brackets}") + } + Ti::Sampler { .. } => "sampler".to_string(), + Ti::AccelerationStructure => "acceleration_structure".to_string(), + Ti::RayQuery => "ray_query".to_string(), + Ti::BindingArray { base, size, .. } => { + let member_type = &gctx.types[base]; + let base = member_type.name.as_deref().unwrap_or("unknown"); + match size { + crate::ArraySize::Constant(size) => format!("binding_array<{base}, {size}>"), + crate::ArraySize::Dynamic => format!("binding_array<{base}>"), + } + } + } + } +} + +impl crate::Scalar { + /// Format a scalar kind+width as a type is written in wgsl. + /// + /// Examples: `f32`, `u64`, `bool`. + pub fn to_wgsl(self) -> String { + let prefix = match self.kind { + crate::ScalarKind::Sint => "i", + crate::ScalarKind::Uint => "u", + crate::ScalarKind::Float => "f", + crate::ScalarKind::Bool => return "bool".to_string(), + crate::ScalarKind::AbstractInt => return "{AbstractInt}".to_string(), + crate::ScalarKind::AbstractFloat => return "{AbstractFloat}".to_string(), + }; + format!("{}{}", prefix, self.width * 8) + } +} + +impl crate::StorageFormat { + pub const fn to_wgsl(self) -> &'static str { + use crate::StorageFormat as Sf; + match self { + Sf::R8Unorm => "r8unorm", + Sf::R8Snorm => "r8snorm", + Sf::R8Uint => "r8uint", + Sf::R8Sint => "r8sint", + Sf::R16Uint => "r16uint", + Sf::R16Sint => "r16sint", + Sf::R16Float => "r16float", + Sf::Rg8Unorm => "rg8unorm", + Sf::Rg8Snorm => "rg8snorm", + Sf::Rg8Uint => "rg8uint", + Sf::Rg8Sint => "rg8sint", + Sf::R32Uint => "r32uint", + Sf::R32Sint => "r32sint", + Sf::R32Float => "r32float", + Sf::Rg16Uint => "rg16uint", + Sf::Rg16Sint => "rg16sint", + Sf::Rg16Float => "rg16float", + Sf::Rgba8Unorm => "rgba8unorm", + Sf::Rgba8Snorm => "rgba8snorm", + Sf::Rgba8Uint => "rgba8uint", + Sf::Rgba8Sint => "rgba8sint", + Sf::Bgra8Unorm => "bgra8unorm", + Sf::Rgb10a2Uint => "rgb10a2uint", + Sf::Rgb10a2Unorm => "rgb10a2unorm", + Sf::Rg11b10Float => "rg11b10float", + Sf::Rg32Uint => "rg32uint", + Sf::Rg32Sint => "rg32sint", + Sf::Rg32Float => "rg32float", + Sf::Rgba16Uint => "rgba16uint", + Sf::Rgba16Sint => "rgba16sint", + Sf::Rgba16Float => "rgba16float", + Sf::Rgba32Uint => "rgba32uint", + Sf::Rgba32Sint => "rgba32sint", + Sf::Rgba32Float => "rgba32float", + Sf::R16Unorm => "r16unorm", + Sf::R16Snorm => "r16snorm", + Sf::Rg16Unorm => "rg16unorm", + Sf::Rg16Snorm => "rg16snorm", + Sf::Rgba16Unorm => "rgba16unorm", + Sf::Rgba16Snorm => "rgba16snorm", + } + } +} + +mod tests { + #[test] + fn to_wgsl() { + use std::num::NonZeroU32; + + let mut types = crate::UniqueArena::new(); + + let mytype1 = types.insert( + crate::Type { + name: Some("MyType1".to_string()), + inner: crate::TypeInner::Struct { + members: vec![], + span: 0, + }, + }, + Default::default(), + ); + let mytype2 = types.insert( + crate::Type { + name: Some("MyType2".to_string()), + inner: crate::TypeInner::Struct { + members: vec![], + span: 0, + }, + }, + Default::default(), + ); + + let gctx = crate::proc::GlobalCtx { + types: &types, + constants: &crate::Arena::new(), + const_expressions: &crate::Arena::new(), + }; + let array = crate::TypeInner::Array { + base: mytype1, + stride: 4, + size: crate::ArraySize::Constant(unsafe { NonZeroU32::new_unchecked(32) }), + }; + assert_eq!(array.to_wgsl(&gctx), "array<MyType1, 32>"); + + let mat = crate::TypeInner::Matrix { + rows: crate::VectorSize::Quad, + columns: crate::VectorSize::Bi, + scalar: crate::Scalar::F64, + }; + assert_eq!(mat.to_wgsl(&gctx), "mat2x4<f64>"); + + let ptr = crate::TypeInner::Pointer { + base: mytype2, + space: crate::AddressSpace::Storage { + access: crate::StorageAccess::default(), + }, + }; + assert_eq!(ptr.to_wgsl(&gctx), "ptr<MyType2>"); + + let img1 = crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: crate::ScalarKind::Float, + multi: true, + }, + }; + assert_eq!(img1.to_wgsl(&gctx), "texture_multisampled_2d<f32>"); + + let img2 = crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + arrayed: true, + class: crate::ImageClass::Depth { multi: false }, + }; + assert_eq!(img2.to_wgsl(&gctx), "texture_depth_cube_array"); + + let img3 = crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Depth { multi: true }, + }; + assert_eq!(img3.to_wgsl(&gctx), "texture_depth_multisampled_2d"); + + let array = crate::TypeInner::BindingArray { + base: mytype1, + size: crate::ArraySize::Constant(unsafe { NonZeroU32::new_unchecked(32) }), + }; + assert_eq!(array.to_wgsl(&gctx), "binding_array<MyType1, 32>"); + } +} diff --git a/third_party/rust/naga/src/keywords/mod.rs b/third_party/rust/naga/src/keywords/mod.rs new file mode 100644 index 0000000000..d54a1704f7 --- /dev/null +++ b/third_party/rust/naga/src/keywords/mod.rs @@ -0,0 +1,6 @@ +/*! +Lists of reserved keywords for each shading language with a [frontend][crate::front] or [backend][crate::back]. +*/ + +#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))] +pub mod wgsl; diff --git a/third_party/rust/naga/src/keywords/wgsl.rs b/third_party/rust/naga/src/keywords/wgsl.rs new file mode 100644 index 0000000000..7b47a13128 --- /dev/null +++ b/third_party/rust/naga/src/keywords/wgsl.rs @@ -0,0 +1,229 @@ +/*! +Keywords for [WGSL][wgsl] (WebGPU Shading Language). + +[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html +*/ + +// https://gpuweb.github.io/gpuweb/wgsl/#keyword-summary +// last sync: https://github.com/gpuweb/gpuweb/blob/39f2321f547c8f0b7f473cf1d47fba30b1691303/wgsl/index.bs +pub const RESERVED: &[&str] = &[ + // Type-defining Keywords + "array", + "atomic", + "bool", + "f32", + "f16", + "i32", + "mat2x2", + "mat2x3", + "mat2x4", + "mat3x2", + "mat3x3", + "mat3x4", + "mat4x2", + "mat4x3", + "mat4x4", + "ptr", + "sampler", + "sampler_comparison", + "texture_1d", + "texture_2d", + "texture_2d_array", + "texture_3d", + "texture_cube", + "texture_cube_array", + "texture_multisampled_2d", + "texture_storage_1d", + "texture_storage_2d", + "texture_storage_2d_array", + "texture_storage_3d", + "texture_depth_2d", + "texture_depth_2d_array", + "texture_depth_cube", + "texture_depth_cube_array", + "texture_depth_multisampled_2d", + "u32", + "vec2", + "vec3", + "vec4", + // Other Keywords + "alias", + "bitcast", + "break", + "case", + "const", + "continue", + "continuing", + "default", + "discard", + "else", + "enable", + "false", + "fn", + "for", + "if", + "let", + "loop", + "override", + "return", + "static_assert", + "struct", + "switch", + "true", + "type", + "var", + "while", + // Reserved Words + "CompileShader", + "ComputeShader", + "DomainShader", + "GeometryShader", + "Hullshader", + "NULL", + "Self", + "abstract", + "active", + "alignas", + "alignof", + "as", + "asm", + "asm_fragment", + "async", + "attribute", + "auto", + "await", + "become", + "binding_array", + "cast", + "catch", + "class", + "co_await", + "co_return", + "co_yield", + "coherent", + "column_major", + "common", + "compile", + "compile_fragment", + "concept", + "const_cast", + "consteval", + "constexpr", + "constinit", + "crate", + "debugger", + "decltype", + "delete", + "demote", + "demote_to_helper", + "do", + "dynamic_cast", + "enum", + "explicit", + "export", + "extends", + "extern", + "external", + "fallthrough", + "filter", + "final", + "finally", + "friend", + "from", + "fxgroup", + "get", + "goto", + "groupshared", + "handle", + "highp", + "impl", + "implements", + "import", + "inline", + "inout", + "instanceof", + "interface", + "layout", + "lowp", + "macro", + "macro_rules", + "match", + "mediump", + "meta", + "mod", + "module", + "move", + "mut", + "mutable", + "namespace", + "new", + "nil", + "noexcept", + "noinline", + "nointerpolation", + "noperspective", + "null", + "nullptr", + "of", + "operator", + "package", + "packoffset", + "partition", + "pass", + "patch", + "pixelfragment", + "precise", + "precision", + "premerge", + "priv", + "protected", + "pub", + "public", + "readonly", + "ref", + "regardless", + "register", + "reinterpret_cast", + "requires", + "resource", + "restrict", + "self", + "set", + "shared", + "signed", + "sizeof", + "smooth", + "snorm", + "static", + "static_assert", + "static_cast", + "std", + "subroutine", + "super", + "target", + "template", + "this", + "thread_local", + "throw", + "trait", + "try", + "typedef", + "typeid", + "typename", + "typeof", + "union", + "unless", + "unorm", + "unsafe", + "unsized", + "use", + "using", + "varying", + "virtual", + "volatile", + "wgsl", + "where", + "with", + "writeonly", + "yield", +]; diff --git a/third_party/rust/naga/src/lib.rs b/third_party/rust/naga/src/lib.rs new file mode 100644 index 0000000000..d6b9c6a7f4 --- /dev/null +++ b/third_party/rust/naga/src/lib.rs @@ -0,0 +1,2072 @@ +/*! Universal shader translator. + +The central structure of the crate is [`Module`]. A `Module` contains: + +- [`Function`]s, which have arguments, a return type, local variables, and a body, + +- [`EntryPoint`]s, which are specialized functions that can serve as the entry + point for pipeline stages like vertex shading or fragment shading, + +- [`Constant`]s and [`GlobalVariable`]s used by `EntryPoint`s and `Function`s, and + +- [`Type`]s used by the above. + +The body of an `EntryPoint` or `Function` is represented using two types: + +- An [`Expression`] produces a value, but has no side effects or control flow. + `Expressions` include variable references, unary and binary operators, and so + on. + +- A [`Statement`] can have side effects and structured control flow. + `Statement`s do not produce a value, other than by storing one in some + designated place. `Statements` include blocks, conditionals, and loops, but also + operations that have side effects, like stores and function calls. + +`Statement`s form a tree, with pointers into the DAG of `Expression`s. + +Restricting side effects to statements simplifies analysis and code generation. +A Naga backend can generate code to evaluate an `Expression` however and +whenever it pleases, as long as it is certain to observe the side effects of all +previously executed `Statement`s. + +Many `Statement` variants use the [`Block`] type, which is `Vec<Statement>`, +with optional span info, representing a series of statements executed in order. The body of an +`EntryPoint`s or `Function` is a `Block`, and `Statement` has a +[`Block`][Statement::Block] variant. + +If the `clone` feature is enabled, [`Arena`], [`UniqueArena`], [`Type`], [`TypeInner`], +[`Constant`], [`Function`], [`EntryPoint`] and [`Module`] can be cloned. + +## Arenas + +To improve translator performance and reduce memory usage, most structures are +stored in an [`Arena`]. An `Arena<T>` stores a series of `T` values, indexed by +[`Handle<T>`](Handle) values, which are just wrappers around integer indexes. +For example, a `Function`'s expressions are stored in an `Arena<Expression>`, +and compound expressions refer to their sub-expressions via `Handle<Expression>` +values. (When examining the serialized form of a `Module`, note that the first +element of an `Arena` has an index of 1, not 0.) + +A [`UniqueArena`] is just like an `Arena`, except that it stores only a single +instance of each value. The value type must implement `Eq` and `Hash`. Like an +`Arena`, inserting a value into a `UniqueArena` returns a `Handle` which can be +used to efficiently access the value, without a hash lookup. Inserting a value +multiple times returns the same `Handle`. + +If the `span` feature is enabled, both `Arena` and `UniqueArena` can associate a +source code span with each element. + +## Function Calls + +Naga's representation of function calls is unusual. Most languages treat +function calls as expressions, but because calls may have side effects, Naga +represents them as a kind of statement, [`Statement::Call`]. If the function +returns a value, a call statement designates a particular [`Expression::CallResult`] +expression to represent its return value, for use by subsequent statements and +expressions. + +## `Expression` evaluation time + +It is essential to know when an [`Expression`] should be evaluated, because its +value may depend on previous [`Statement`]s' effects. But whereas the order of +execution for a tree of `Statement`s is apparent from its structure, it is not +so clear for `Expressions`, since an expression may be referred to by any number +of `Statement`s and other `Expression`s. + +Naga's rules for when `Expression`s are evaluated are as follows: + +- [`Literal`], [`Constant`], and [`ZeroValue`] expressions are + considered to be implicitly evaluated before execution begins. + +- [`FunctionArgument`] and [`LocalVariable`] expressions are considered + implicitly evaluated upon entry to the function to which they belong. + Function arguments cannot be assigned to, and `LocalVariable` expressions + produce a *pointer to* the variable's value (for use with [`Load`] and + [`Store`]). Neither varies while the function executes, so it suffices to + consider these expressions evaluated once on entry. + +- Similarly, [`GlobalVariable`] expressions are considered implicitly + evaluated before execution begins, since their value does not change while + code executes, for one of two reasons: + + - Most `GlobalVariable` expressions produce a pointer to the variable's + value, for use with [`Load`] and [`Store`], as `LocalVariable` + expressions do. Although the variable's value may change, its address + does not. + + - A `GlobalVariable` expression referring to a global in the + [`AddressSpace::Handle`] address space produces the value directly, not + a pointer. Such global variables hold opaque types like shaders or + images, and cannot be assigned to. + +- A [`CallResult`] expression that is the `result` of a [`Statement::Call`], + representing the call's return value, is evaluated when the `Call` statement + is executed. + +- Similarly, an [`AtomicResult`] expression that is the `result` of an + [`Atomic`] statement, representing the result of the atomic operation, is + evaluated when the `Atomic` statement is executed. + +- A [`RayQueryProceedResult`] expression, which is a boolean + indicating if the ray query is finished, is evaluated when the + [`RayQuery`] statement whose [`Proceed::result`] points to it is + executed. + +- All other expressions are evaluated when the (unique) [`Statement::Emit`] + statement that covers them is executed. + +Now, strictly speaking, not all `Expression` variants actually care when they're +evaluated. For example, you can evaluate a [`BinaryOperator::Add`] expression +any time you like, as long as you give it the right operands. It's really only a +very small set of expressions that are affected by timing: + +- [`Load`], [`ImageSample`], and [`ImageLoad`] expressions are influenced by + stores to the variables or images they access, and must execute at the + proper time relative to them. + +- [`Derivative`] expressions are sensitive to control flow uniformity: they + must not be moved out of an area of uniform control flow into a non-uniform + area. + +- More generally, any expression that's used by more than one other expression + or statement should probably be evaluated only once, and then stored in a + variable to be cited at each point of use. + +Naga tries to help back ends handle all these cases correctly in a somewhat +circuitous way. The [`ModuleInfo`] structure returned by [`Validator::validate`] +provides a reference count for each expression in each function in the module. +Naturally, any expression with a reference count of two or more deserves to be +evaluated and stored in a temporary variable at the point that the `Emit` +statement covering it is executed. But if we selectively lower the reference +count threshold to _one_ for the sensitive expression types listed above, so +that we _always_ generate a temporary variable and save their value, then the +same code that manages multiply referenced expressions will take care of +introducing temporaries for time-sensitive expressions as well. The +`Expression::bake_ref_count` method (private to the back ends) is meant to help +with this. + +## `Expression` scope + +Each `Expression` has a *scope*, which is the region of the function within +which it can be used by `Statement`s and other `Expression`s. It is a validation +error to use an `Expression` outside its scope. + +An expression's scope is defined as follows: + +- The scope of a [`Constant`], [`GlobalVariable`], [`FunctionArgument`] or + [`LocalVariable`] expression covers the entire `Function` in which it + occurs. + +- The scope of an expression evaluated by an [`Emit`] statement covers the + subsequent expressions in that `Emit`, the subsequent statements in the `Block` + to which that `Emit` belongs (if any) and their sub-statements (if any). + +- The `result` expression of a [`Call`] or [`Atomic`] statement has a scope + covering the subsequent statements in the `Block` in which the statement + occurs (if any) and their sub-statements (if any). + +For example, this implies that an expression evaluated by some statement in a +nested `Block` is not available in the `Block`'s parents. Such a value would +need to be stored in a local variable to be carried upwards in the statement +tree. + +## Constant expressions + +A Naga *constant expression* is one of the following [`Expression`] +variants, whose operands (if any) are also constant expressions: +- [`Literal`] +- [`Constant`], for [`Constant`s][const_type] whose [`override`] is [`None`] +- [`ZeroValue`], for fixed-size types +- [`Compose`] +- [`Access`] +- [`AccessIndex`] +- [`Splat`] +- [`Swizzle`] +- [`Unary`] +- [`Binary`] +- [`Select`] +- [`Relational`] +- [`Math`] +- [`As`] + +A constant expression can be evaluated at module translation time. + +## Override expressions + +A Naga *override expression* is the same as a [constant expression], +except that it is also allowed to refer to [`Constant`s][const_type] +whose [`override`] is something other than [`None`]. + +An override expression can be evaluated at pipeline creation time. + +[`AtomicResult`]: Expression::AtomicResult +[`RayQueryProceedResult`]: Expression::RayQueryProceedResult +[`CallResult`]: Expression::CallResult +[`Constant`]: Expression::Constant +[`ZeroValue`]: Expression::ZeroValue +[`Literal`]: Expression::Literal +[`Derivative`]: Expression::Derivative +[`FunctionArgument`]: Expression::FunctionArgument +[`GlobalVariable`]: Expression::GlobalVariable +[`ImageLoad`]: Expression::ImageLoad +[`ImageSample`]: Expression::ImageSample +[`Load`]: Expression::Load +[`LocalVariable`]: Expression::LocalVariable + +[`Atomic`]: Statement::Atomic +[`Call`]: Statement::Call +[`Emit`]: Statement::Emit +[`Store`]: Statement::Store +[`RayQuery`]: Statement::RayQuery + +[`Proceed::result`]: RayQueryFunction::Proceed::result + +[`Validator::validate`]: valid::Validator::validate +[`ModuleInfo`]: valid::ModuleInfo + +[`Literal`]: Expression::Literal +[`ZeroValue`]: Expression::ZeroValue +[`Compose`]: Expression::Compose +[`Access`]: Expression::Access +[`AccessIndex`]: Expression::AccessIndex +[`Splat`]: Expression::Splat +[`Swizzle`]: Expression::Swizzle +[`Unary`]: Expression::Unary +[`Binary`]: Expression::Binary +[`Select`]: Expression::Select +[`Relational`]: Expression::Relational +[`Math`]: Expression::Math +[`As`]: Expression::As + +[const_type]: Constant +[`override`]: Constant::override +[`None`]: Override::None + +[constant expression]: index.html#constant-expressions +*/ + +#![allow( + clippy::new_without_default, + clippy::unneeded_field_pattern, + clippy::match_like_matches_macro, + clippy::collapsible_if, + clippy::derive_partial_eq_without_eq, + clippy::needless_borrowed_reference, + clippy::single_match +)] +#![warn( + trivial_casts, + trivial_numeric_casts, + unused_extern_crates, + unused_qualifications, + clippy::pattern_type_mismatch, + clippy::missing_const_for_fn, + clippy::rest_pat_in_fully_bound_structs, + clippy::match_wildcard_for_single_variants +)] +#![deny(clippy::exit)] +#![cfg_attr( + not(test), + warn( + clippy::dbg_macro, + clippy::panic, + clippy::print_stderr, + clippy::print_stdout, + clippy::todo + ) +)] + +mod arena; +pub mod back; +mod block; +#[cfg(feature = "compact")] +pub mod compact; +pub mod front; +pub mod keywords; +pub mod proc; +mod span; +pub mod valid; + +pub use crate::arena::{Arena, Handle, Range, UniqueArena}; + +pub use crate::span::{SourceLocation, Span, SpanContext, WithSpan}; +#[cfg(feature = "arbitrary")] +use arbitrary::Arbitrary; +#[cfg(feature = "deserialize")] +use serde::Deserialize; +#[cfg(feature = "serialize")] +use serde::Serialize; + +/// Width of a boolean type, in bytes. +pub const BOOL_WIDTH: Bytes = 1; + +/// Width of abstract types, in bytes. +pub const ABSTRACT_WIDTH: Bytes = 8; + +/// Hash map that is faster but not resilient to DoS attacks. +pub type FastHashMap<K, T> = rustc_hash::FxHashMap<K, T>; +/// Hash set that is faster but not resilient to DoS attacks. +pub type FastHashSet<K> = rustc_hash::FxHashSet<K>; + +/// Insertion-order-preserving hash set (`IndexSet<K>`), but with the same +/// hasher as `FastHashSet<K>` (faster but not resilient to DoS attacks). +pub type FastIndexSet<K> = + indexmap::IndexSet<K, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>; + +/// Insertion-order-preserving hash map (`IndexMap<K, V>`), but with the same +/// hasher as `FastHashMap<K, V>` (faster but not resilient to DoS attacks). +pub type FastIndexMap<K, V> = + indexmap::IndexMap<K, V, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>; + +/// Map of expressions that have associated variable names +pub(crate) type NamedExpressions = FastIndexMap<Handle<Expression>, String>; + +/// Early fragment tests. +/// +/// In a standard situation, if a driver determines that it is possible to switch on early depth test, it will. +/// +/// Typical situations when early depth test is switched off: +/// - Calling `discard` in a shader. +/// - Writing to the depth buffer, unless ConservativeDepth is enabled. +/// +/// To use in a shader: +/// - GLSL: `layout(early_fragment_tests) in;` +/// - HLSL: `Attribute earlydepthstencil` +/// - SPIR-V: `ExecutionMode EarlyFragmentTests` +/// - WGSL: `@early_depth_test` +/// +/// For more, see: +/// - <https://www.khronos.org/opengl/wiki/Early_Fragment_Test#Explicit_specification> +/// - <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-earlydepthstencil> +/// - <https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#Execution_Mode> +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct EarlyDepthTest { + pub conservative: Option<ConservativeDepth>, +} +/// Enables adjusting depth without disabling early Z. +/// +/// To use in a shader: +/// - GLSL: `layout (depth_<greater/less/unchanged/any>) out float gl_FragDepth;` +/// - `depth_any` option behaves as if the layout qualifier was not present. +/// - HLSL: `SV_DepthGreaterEqual`/`SV_DepthLessEqual`/`SV_Depth` +/// - SPIR-V: `ExecutionMode Depth<Greater/Less/Unchanged>` +/// - WGSL: `@early_depth_test(greater_equal/less_equal/unchanged)` +/// +/// For more, see: +/// - <https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt> +/// - <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-semantics#system-value-semantics> +/// - <https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#Execution_Mode> +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum ConservativeDepth { + /// Shader may rewrite depth only with a value greater than calculated. + GreaterEqual, + + /// Shader may rewrite depth smaller than one that would have been written without the modification. + LessEqual, + + /// Shader may not rewrite depth value. + Unchanged, +} + +/// Stage of the programmable pipeline. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[allow(missing_docs)] // The names are self evident +pub enum ShaderStage { + Vertex, + Fragment, + Compute, +} + +/// Addressing space of variables. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum AddressSpace { + /// Function locals. + Function, + /// Private data, per invocation, mutable. + Private, + /// Workgroup shared data, mutable. + WorkGroup, + /// Uniform buffer data. + Uniform, + /// Storage buffer data, potentially mutable. + Storage { access: StorageAccess }, + /// Opaque handles, such as samplers and images. + Handle, + /// Push constants. + PushConstant, +} + +/// Built-in inputs and outputs. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum BuiltIn { + Position { invariant: bool }, + ViewIndex, + // vertex + BaseInstance, + BaseVertex, + ClipDistance, + CullDistance, + InstanceIndex, + PointSize, + VertexIndex, + // fragment + FragDepth, + PointCoord, + FrontFacing, + PrimitiveIndex, + SampleIndex, + SampleMask, + // compute + GlobalInvocationId, + LocalInvocationId, + LocalInvocationIndex, + WorkGroupId, + WorkGroupSize, + NumWorkGroups, +} + +/// Number of bytes per scalar. +pub type Bytes = u8; + +/// Number of components in a vector. +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum VectorSize { + /// 2D vector + Bi = 2, + /// 3D vector + Tri = 3, + /// 4D vector + Quad = 4, +} + +impl VectorSize { + const MAX: usize = Self::Quad as u8 as usize; +} + +/// Primitive type for a scalar. +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum ScalarKind { + /// Signed integer type. + Sint, + /// Unsigned integer type. + Uint, + /// Floating point type. + Float, + /// Boolean type. + Bool, + + /// WGSL abstract integer type. + /// + /// These are forbidden by validation, and should never reach backends. + AbstractInt, + + /// Abstract floating-point type. + /// + /// These are forbidden by validation, and should never reach backends. + AbstractFloat, +} + +/// Characteristics of a scalar type. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct Scalar { + /// How the value's bits are to be interpreted. + pub kind: ScalarKind, + + /// This size of the value in bytes. + pub width: Bytes, +} + +/// Size of an array. +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum ArraySize { + /// The array size is constant. + Constant(std::num::NonZeroU32), + /// The array size can change at runtime. + Dynamic, +} + +/// The interpolation qualifier of a binding or struct field. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Interpolation { + /// The value will be interpolated in a perspective-correct fashion. + /// Also known as "smooth" in glsl. + Perspective, + /// Indicates that linear, non-perspective, correct + /// interpolation must be used. + /// Also known as "no_perspective" in glsl. + Linear, + /// Indicates that no interpolation will be performed. + Flat, +} + +/// The sampling qualifiers of a binding or struct field. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Sampling { + /// Interpolate the value at the center of the pixel. + Center, + + /// Interpolate the value at a point that lies within all samples covered by + /// the fragment within the current primitive. In multisampling, use a + /// single value for all samples in the primitive. + Centroid, + + /// Interpolate the value at each sample location. In multisampling, invoke + /// the fragment shader once per sample. + Sample, +} + +/// Member of a user-defined structure. +// Clone is used only for error reporting and is not intended for end users +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct StructMember { + pub name: Option<String>, + /// Type of the field. + pub ty: Handle<Type>, + /// For I/O structs, defines the binding. + pub binding: Option<Binding>, + /// Offset from the beginning from the struct. + pub offset: u32, +} + +/// The number of dimensions an image has. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum ImageDimension { + /// 1D image + D1, + /// 2D image + D2, + /// 3D image + D3, + /// Cube map + Cube, +} + +bitflags::bitflags! { + /// Flags describing an image. + #[cfg_attr(feature = "serialize", derive(Serialize))] + #[cfg_attr(feature = "deserialize", derive(Deserialize))] + #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] + #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] + pub struct StorageAccess: u32 { + /// Storage can be used as a source for load ops. + const LOAD = 0x1; + /// Storage can be used as a target for store ops. + const STORE = 0x2; + } +} + +/// Image storage format. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum StorageFormat { + // 8-bit formats + R8Unorm, + R8Snorm, + R8Uint, + R8Sint, + + // 16-bit formats + R16Uint, + R16Sint, + R16Float, + Rg8Unorm, + Rg8Snorm, + Rg8Uint, + Rg8Sint, + + // 32-bit formats + R32Uint, + R32Sint, + R32Float, + Rg16Uint, + Rg16Sint, + Rg16Float, + Rgba8Unorm, + Rgba8Snorm, + Rgba8Uint, + Rgba8Sint, + Bgra8Unorm, + + // Packed 32-bit formats + Rgb10a2Uint, + Rgb10a2Unorm, + Rg11b10Float, + + // 64-bit formats + Rg32Uint, + Rg32Sint, + Rg32Float, + Rgba16Uint, + Rgba16Sint, + Rgba16Float, + + // 128-bit formats + Rgba32Uint, + Rgba32Sint, + Rgba32Float, + + // Normalized 16-bit per channel formats + R16Unorm, + R16Snorm, + Rg16Unorm, + Rg16Snorm, + Rgba16Unorm, + Rgba16Snorm, +} + +/// Sub-class of the image type. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum ImageClass { + /// Regular sampled image. + Sampled { + /// Kind of values to sample. + kind: ScalarKind, + /// Multi-sampled image. + /// + /// A multi-sampled image holds several samples per texel. Multi-sampled + /// images cannot have mipmaps. + multi: bool, + }, + /// Depth comparison image. + Depth { + /// Multi-sampled depth image. + multi: bool, + }, + /// Storage image. + Storage { + format: StorageFormat, + access: StorageAccess, + }, +} + +/// A data type declared in the module. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct Type { + /// The name of the type, if any. + pub name: Option<String>, + /// Inner structure that depends on the kind of the type. + pub inner: TypeInner, +} + +/// Enum with additional information, depending on the kind of type. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum TypeInner { + /// Number of integral or floating-point kind. + Scalar(Scalar), + /// Vector of numbers. + Vector { size: VectorSize, scalar: Scalar }, + /// Matrix of numbers. + Matrix { + columns: VectorSize, + rows: VectorSize, + scalar: Scalar, + }, + /// Atomic scalar. + Atomic(Scalar), + /// Pointer to another type. + /// + /// Pointers to scalars and vectors should be treated as equivalent to + /// [`ValuePointer`] types. Use the [`TypeInner::equivalent`] method to + /// compare types in a way that treats pointers correctly. + /// + /// ## Pointers to non-`SIZED` types + /// + /// The `base` type of a pointer may be a non-[`SIZED`] type like a + /// dynamically-sized [`Array`], or a [`Struct`] whose last member is a + /// dynamically sized array. Such pointers occur as the types of + /// [`GlobalVariable`] or [`AccessIndex`] expressions referring to + /// dynamically-sized arrays. + /// + /// However, among pointers to non-`SIZED` types, only pointers to `Struct`s + /// are [`DATA`]. Pointers to dynamically sized `Array`s cannot be passed as + /// arguments, stored in variables, or held in arrays or structures. Their + /// only use is as the types of `AccessIndex` expressions. + /// + /// [`SIZED`]: valid::TypeFlags::SIZED + /// [`DATA`]: valid::TypeFlags::DATA + /// [`Array`]: TypeInner::Array + /// [`Struct`]: TypeInner::Struct + /// [`ValuePointer`]: TypeInner::ValuePointer + /// [`GlobalVariable`]: Expression::GlobalVariable + /// [`AccessIndex`]: Expression::AccessIndex + Pointer { + base: Handle<Type>, + space: AddressSpace, + }, + + /// Pointer to a scalar or vector. + /// + /// A `ValuePointer` type is equivalent to a `Pointer` whose `base` is a + /// `Scalar` or `Vector` type. This is for use in [`TypeResolution::Value`] + /// variants; see the documentation for [`TypeResolution`] for details. + /// + /// Use the [`TypeInner::equivalent`] method to compare types that could be + /// pointers, to ensure that `Pointer` and `ValuePointer` types are + /// recognized as equivalent. + /// + /// [`TypeResolution`]: proc::TypeResolution + /// [`TypeResolution::Value`]: proc::TypeResolution::Value + ValuePointer { + size: Option<VectorSize>, + scalar: Scalar, + space: AddressSpace, + }, + + /// Homogeneous list of elements. + /// + /// The `base` type must be a [`SIZED`], [`DATA`] type. + /// + /// ## Dynamically sized arrays + /// + /// An `Array` is [`SIZED`] unless its `size` is [`Dynamic`]. + /// Dynamically-sized arrays may only appear in a few situations: + /// + /// - They may appear as the type of a [`GlobalVariable`], or as the last + /// member of a [`Struct`]. + /// + /// - They may appear as the base type of a [`Pointer`]. An + /// [`AccessIndex`] expression referring to a struct's final + /// unsized array member would have such a pointer type. However, such + /// pointer types may only appear as the types of such intermediate + /// expressions. They are not [`DATA`], and cannot be stored in + /// variables, held in arrays or structs, or passed as parameters. + /// + /// [`SIZED`]: crate::valid::TypeFlags::SIZED + /// [`DATA`]: crate::valid::TypeFlags::DATA + /// [`Dynamic`]: ArraySize::Dynamic + /// [`Struct`]: TypeInner::Struct + /// [`Pointer`]: TypeInner::Pointer + /// [`AccessIndex`]: Expression::AccessIndex + Array { + base: Handle<Type>, + size: ArraySize, + stride: u32, + }, + + /// User-defined structure. + /// + /// There must always be at least one member. + /// + /// A `Struct` type is [`DATA`], and the types of its members must be + /// `DATA` as well. + /// + /// Member types must be [`SIZED`], except for the final member of a + /// struct, which may be a dynamically sized [`Array`]. The + /// `Struct` type itself is `SIZED` when all its members are `SIZED`. + /// + /// [`DATA`]: crate::valid::TypeFlags::DATA + /// [`SIZED`]: crate::valid::TypeFlags::SIZED + /// [`Array`]: TypeInner::Array + Struct { + members: Vec<StructMember>, + //TODO: should this be unaligned? + span: u32, + }, + /// Possibly multidimensional array of texels. + Image { + dim: ImageDimension, + arrayed: bool, + //TODO: consider moving `multisampled: bool` out + class: ImageClass, + }, + /// Can be used to sample values from images. + Sampler { comparison: bool }, + + /// Opaque object representing an acceleration structure of geometry. + AccelerationStructure, + + /// Locally used handle for ray queries. + RayQuery, + + /// Array of bindings. + /// + /// A `BindingArray` represents an array where each element draws its value + /// from a separate bound resource. The array's element type `base` may be + /// [`Image`], [`Sampler`], or any type that would be permitted for a global + /// in the [`Uniform`] or [`Storage`] address spaces. Only global variables + /// may be binding arrays; on the host side, their values are provided by + /// [`TextureViewArray`], [`SamplerArray`], or [`BufferArray`] + /// bindings. + /// + /// Since each element comes from a distinct resource, a binding array of + /// images could have images of varying sizes (but not varying dimensions; + /// they must all have the same `Image` type). Or, a binding array of + /// buffers could have elements that are dynamically sized arrays, each with + /// a different length. + /// + /// Binding arrays are in the same address spaces as their underlying type. + /// As such, referring to an array of images produces an [`Image`] value + /// directly (as opposed to a pointer). The only operation permitted on + /// `BindingArray` values is indexing, which works transparently: indexing + /// a binding array of samplers yields a [`Sampler`], indexing a pointer to the + /// binding array of storage buffers produces a pointer to the storage struct. + /// + /// Unlike textures and samplers, binding arrays are not [`ARGUMENT`], so + /// they cannot be passed as arguments to functions. + /// + /// Naga's WGSL front end supports binding arrays with the type syntax + /// `binding_array<T, N>`. + /// + /// [`Image`]: TypeInner::Image + /// [`Sampler`]: TypeInner::Sampler + /// [`Uniform`]: AddressSpace::Uniform + /// [`Storage`]: AddressSpace::Storage + /// [`TextureViewArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.TextureViewArray + /// [`SamplerArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.SamplerArray + /// [`BufferArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.BufferArray + /// [`DATA`]: crate::valid::TypeFlags::DATA + /// [`ARGUMENT`]: crate::valid::TypeFlags::ARGUMENT + /// [naga#1864]: https://github.com/gfx-rs/naga/issues/1864 + BindingArray { base: Handle<Type>, size: ArraySize }, +} + +#[derive(Debug, Clone, Copy, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Literal { + /// May not be NaN or infinity. + F64(f64), + /// May not be NaN or infinity. + F32(f32), + U32(u32), + I32(i32), + I64(i64), + Bool(bool), + AbstractInt(i64), + AbstractFloat(f64), +} + +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Override { + None, + ByName, + ByNameOrId(u32), +} + +/// Constant value. +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct Constant { + pub name: Option<String>, + pub r#override: Override, + pub ty: Handle<Type>, + + /// The value of the constant. + /// + /// This [`Handle`] refers to [`Module::const_expressions`], not + /// any [`Function::expressions`] arena. + /// + /// If [`override`] is [`None`], then this must be a Naga + /// [constant expression]. Otherwise, this may be a Naga + /// [override expression] or [constant expression]. + /// + /// [`override`]: Constant::override + /// [`None`]: Override::None + /// [constant expression]: index.html#constant-expressions + /// [override expression]: index.html#override-expressions + pub init: Handle<Expression>, +} + +/// Describes how an input/output variable is to be bound. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Binding { + /// Built-in shader variable. + BuiltIn(BuiltIn), + + /// Indexed location. + /// + /// Values passed from the [`Vertex`] stage to the [`Fragment`] stage must + /// have their `interpolation` defaulted (i.e. not `None`) by the front end + /// as appropriate for that language. + /// + /// For other stages, we permit interpolations even though they're ignored. + /// When a front end is parsing a struct type, it usually doesn't know what + /// stages will be using it for IO, so it's easiest if it can apply the + /// defaults to anything with a `Location` binding, just in case. + /// + /// For anything other than floating-point scalars and vectors, the + /// interpolation must be `Flat`. + /// + /// [`Vertex`]: crate::ShaderStage::Vertex + /// [`Fragment`]: crate::ShaderStage::Fragment + Location { + location: u32, + /// Indicates the 2nd input to the blender when dual-source blending. + second_blend_source: bool, + interpolation: Option<Interpolation>, + sampling: Option<Sampling>, + }, +} + +/// Pipeline binding information for global resources. +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct ResourceBinding { + /// The bind group index. + pub group: u32, + /// Binding number within the group. + pub binding: u32, +} + +/// Variable defined at module level. +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct GlobalVariable { + /// Name of the variable, if any. + pub name: Option<String>, + /// How this variable is to be stored. + pub space: AddressSpace, + /// For resources, defines the binding point. + pub binding: Option<ResourceBinding>, + /// The type of this variable. + pub ty: Handle<Type>, + /// Initial value for this variable. + /// + /// Expression handle lives in const_expressions + pub init: Option<Handle<Expression>>, +} + +/// Variable defined at function level. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct LocalVariable { + /// Name of the variable, if any. + pub name: Option<String>, + /// The type of this variable. + pub ty: Handle<Type>, + /// Initial value for this variable. + /// + /// This handle refers to this `LocalVariable`'s function's + /// [`expressions`] arena, but it is required to be an evaluated + /// constant expression. + /// + /// [`expressions`]: Function::expressions + pub init: Option<Handle<Expression>>, +} + +/// Operation that can be applied on a single value. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum UnaryOperator { + Negate, + LogicalNot, + BitwiseNot, +} + +/// Operation that can be applied on two values. +/// +/// ## Arithmetic type rules +/// +/// The arithmetic operations `Add`, `Subtract`, `Multiply`, `Divide`, and +/// `Modulo` can all be applied to [`Scalar`] types other than [`Bool`], or +/// [`Vector`]s thereof. Both operands must have the same type. +/// +/// `Add` and `Subtract` can also be applied to [`Matrix`] values. Both operands +/// must have the same type. +/// +/// `Multiply` supports additional cases: +/// +/// - A [`Matrix`] or [`Vector`] can be multiplied by a scalar [`Float`], +/// either on the left or the right. +/// +/// - A [`Matrix`] on the left can be multiplied by a [`Vector`] on the right +/// if the matrix has as many columns as the vector has components (`matCxR +/// * VecC`). +/// +/// - A [`Vector`] on the left can be multiplied by a [`Matrix`] on the right +/// if the matrix has as many rows as the vector has components (`VecR * +/// matCxR`). +/// +/// - Two matrices can be multiplied if the left operand has as many columns +/// as the right operand has rows (`matNxR * matCxN`). +/// +/// In all the above `Multiply` cases, the byte widths of the underlying scalar +/// types of both operands must be the same. +/// +/// Note that `Multiply` supports mixed vector and scalar operations directly, +/// whereas the other arithmetic operations require an explicit [`Splat`] for +/// mixed-type use. +/// +/// [`Scalar`]: TypeInner::Scalar +/// [`Vector`]: TypeInner::Vector +/// [`Matrix`]: TypeInner::Matrix +/// [`Float`]: ScalarKind::Float +/// [`Bool`]: ScalarKind::Bool +/// [`Splat`]: Expression::Splat +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum BinaryOperator { + Add, + Subtract, + Multiply, + Divide, + /// Equivalent of the WGSL's `%` operator or SPIR-V's `OpFRem` + Modulo, + Equal, + NotEqual, + Less, + LessEqual, + Greater, + GreaterEqual, + And, + ExclusiveOr, + InclusiveOr, + LogicalAnd, + LogicalOr, + ShiftLeft, + /// Right shift carries the sign of signed integers only. + ShiftRight, +} + +/// Function on an atomic value. +/// +/// Note: these do not include load/store, which use the existing +/// [`Expression::Load`] and [`Statement::Store`]. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum AtomicFunction { + Add, + Subtract, + And, + ExclusiveOr, + InclusiveOr, + Min, + Max, + Exchange { compare: Option<Handle<Expression>> }, +} + +/// Hint at which precision to compute a derivative. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum DerivativeControl { + Coarse, + Fine, + None, +} + +/// Axis on which to compute a derivative. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum DerivativeAxis { + X, + Y, + Width, +} + +/// Built-in shader function for testing relation between values. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum RelationalFunction { + All, + Any, + IsNan, + IsInf, +} + +/// Built-in shader function for math. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum MathFunction { + // comparison + Abs, + Min, + Max, + Clamp, + Saturate, + // trigonometry + Cos, + Cosh, + Sin, + Sinh, + Tan, + Tanh, + Acos, + Asin, + Atan, + Atan2, + Asinh, + Acosh, + Atanh, + Radians, + Degrees, + // decomposition + Ceil, + Floor, + Round, + Fract, + Trunc, + Modf, + Frexp, + Ldexp, + // exponent + Exp, + Exp2, + Log, + Log2, + Pow, + // geometry + Dot, + Outer, + Cross, + Distance, + Length, + Normalize, + FaceForward, + Reflect, + Refract, + // computational + Sign, + Fma, + Mix, + Step, + SmoothStep, + Sqrt, + InverseSqrt, + Inverse, + Transpose, + Determinant, + // bits + CountTrailingZeros, + CountLeadingZeros, + CountOneBits, + ReverseBits, + ExtractBits, + InsertBits, + FindLsb, + FindMsb, + // data packing + Pack4x8snorm, + Pack4x8unorm, + Pack2x16snorm, + Pack2x16unorm, + Pack2x16float, + // data unpacking + Unpack4x8snorm, + Unpack4x8unorm, + Unpack2x16snorm, + Unpack2x16unorm, + Unpack2x16float, +} + +/// Sampling modifier to control the level of detail. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum SampleLevel { + Auto, + Zero, + Exact(Handle<Expression>), + Bias(Handle<Expression>), + Gradient { + x: Handle<Expression>, + y: Handle<Expression>, + }, +} + +/// Type of an image query. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum ImageQuery { + /// Get the size at the specified level. + Size { + /// If `None`, the base level is considered. + level: Option<Handle<Expression>>, + }, + /// Get the number of mipmap levels. + NumLevels, + /// Get the number of array layers. + NumLayers, + /// Get the number of samples. + NumSamples, +} + +/// Component selection for a vector swizzle. +#[repr(u8)] +#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum SwizzleComponent { + /// + X = 0, + /// + Y = 1, + /// + Z = 2, + /// + W = 3, +} + +bitflags::bitflags! { + /// Memory barrier flags. + #[cfg_attr(feature = "serialize", derive(Serialize))] + #[cfg_attr(feature = "deserialize", derive(Deserialize))] + #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] + pub struct Barrier: u32 { + /// Barrier affects all `AddressSpace::Storage` accesses. + const STORAGE = 0x1; + /// Barrier affects all `AddressSpace::WorkGroup` accesses. + const WORK_GROUP = 0x2; + } +} + +/// An expression that can be evaluated to obtain a value. +/// +/// This is a Single Static Assignment (SSA) scheme similar to SPIR-V. +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Expression { + /// Literal. + Literal(Literal), + /// Constant value. + Constant(Handle<Constant>), + /// Zero value of a type. + ZeroValue(Handle<Type>), + /// Composite expression. + Compose { + ty: Handle<Type>, + components: Vec<Handle<Expression>>, + }, + + /// Array access with a computed index. + /// + /// ## Typing rules + /// + /// The `base` operand must be some composite type: [`Vector`], [`Matrix`], + /// [`Array`], a [`Pointer`] to one of those, or a [`ValuePointer`] with a + /// `size`. + /// + /// The `index` operand must be an integer, signed or unsigned. + /// + /// Indexing a [`Vector`] or [`Array`] produces a value of its element type. + /// Indexing a [`Matrix`] produces a [`Vector`]. + /// + /// Indexing a [`Pointer`] to any of the above produces a pointer to the + /// element/component type, in the same [`space`]. In the case of [`Array`], + /// the result is an actual [`Pointer`], but for vectors and matrices, there + /// may not be any type in the arena representing the component's type, so + /// those produce [`ValuePointer`] types equivalent to the appropriate + /// [`Pointer`]. + /// + /// ## Dynamic indexing restrictions + /// + /// To accommodate restrictions in some of the shader languages that Naga + /// targets, it is not permitted to subscript a matrix or array with a + /// dynamically computed index unless that matrix or array appears behind a + /// pointer. In other words, if the inner type of `base` is [`Array`] or + /// [`Matrix`], then `index` must be a constant. But if the type of `base` + /// is a [`Pointer`] to an array or matrix or a [`ValuePointer`] with a + /// `size`, then the index may be any expression of integer type. + /// + /// You can use the [`Expression::is_dynamic_index`] method to determine + /// whether a given index expression requires matrix or array base operands + /// to be behind a pointer. + /// + /// (It would be simpler to always require the use of `AccessIndex` when + /// subscripting arrays and matrices that are not behind pointers, but to + /// accommodate existing front ends, Naga also permits `Access`, with a + /// restricted `index`.) + /// + /// [`Vector`]: TypeInner::Vector + /// [`Matrix`]: TypeInner::Matrix + /// [`Array`]: TypeInner::Array + /// [`Pointer`]: TypeInner::Pointer + /// [`space`]: TypeInner::Pointer::space + /// [`ValuePointer`]: TypeInner::ValuePointer + /// [`Float`]: ScalarKind::Float + Access { + base: Handle<Expression>, + index: Handle<Expression>, + }, + /// Access the same types as [`Access`], plus [`Struct`] with a known index. + /// + /// [`Access`]: Expression::Access + /// [`Struct`]: TypeInner::Struct + AccessIndex { + base: Handle<Expression>, + index: u32, + }, + /// Splat scalar into a vector. + Splat { + size: VectorSize, + value: Handle<Expression>, + }, + /// Vector swizzle. + Swizzle { + size: VectorSize, + vector: Handle<Expression>, + pattern: [SwizzleComponent; 4], + }, + + /// Reference a function parameter, by its index. + /// + /// A `FunctionArgument` expression evaluates to a pointer to the argument's + /// value. You must use a [`Load`] expression to retrieve its value, or a + /// [`Store`] statement to assign it a new value. + /// + /// [`Load`]: Expression::Load + /// [`Store`]: Statement::Store + FunctionArgument(u32), + + /// Reference a global variable. + /// + /// If the given `GlobalVariable`'s [`space`] is [`AddressSpace::Handle`], + /// then the variable stores some opaque type like a sampler or an image, + /// and a `GlobalVariable` expression referring to it produces the + /// variable's value directly. + /// + /// For any other address space, a `GlobalVariable` expression produces a + /// pointer to the variable's value. You must use a [`Load`] expression to + /// retrieve its value, or a [`Store`] statement to assign it a new value. + /// + /// [`space`]: GlobalVariable::space + /// [`Load`]: Expression::Load + /// [`Store`]: Statement::Store + GlobalVariable(Handle<GlobalVariable>), + + /// Reference a local variable. + /// + /// A `LocalVariable` expression evaluates to a pointer to the variable's value. + /// You must use a [`Load`](Expression::Load) expression to retrieve its value, + /// or a [`Store`](Statement::Store) statement to assign it a new value. + LocalVariable(Handle<LocalVariable>), + + /// Load a value indirectly. + /// + /// For [`TypeInner::Atomic`] the result is a corresponding scalar. + /// For other types behind the `pointer<T>`, the result is `T`. + Load { pointer: Handle<Expression> }, + /// Sample a point from a sampled or a depth image. + ImageSample { + image: Handle<Expression>, + sampler: Handle<Expression>, + /// If Some(), this operation is a gather operation + /// on the selected component. + gather: Option<SwizzleComponent>, + coordinate: Handle<Expression>, + array_index: Option<Handle<Expression>>, + /// Expression handle lives in const_expressions + offset: Option<Handle<Expression>>, + level: SampleLevel, + depth_ref: Option<Handle<Expression>>, + }, + + /// Load a texel from an image. + /// + /// For most images, this returns a four-element vector of the same + /// [`ScalarKind`] as the image. If the format of the image does not have + /// four components, default values are provided: the first three components + /// (typically R, G, and B) default to zero, and the final component + /// (typically alpha) defaults to one. + /// + /// However, if the image's [`class`] is [`Depth`], then this returns a + /// [`Float`] scalar value. + /// + /// [`ScalarKind`]: ScalarKind + /// [`class`]: TypeInner::Image::class + /// [`Depth`]: ImageClass::Depth + /// [`Float`]: ScalarKind::Float + ImageLoad { + /// The image to load a texel from. This must have type [`Image`]. (This + /// will necessarily be a [`GlobalVariable`] or [`FunctionArgument`] + /// expression, since no other expressions are allowed to have that + /// type.) + /// + /// [`Image`]: TypeInner::Image + /// [`GlobalVariable`]: Expression::GlobalVariable + /// [`FunctionArgument`]: Expression::FunctionArgument + image: Handle<Expression>, + + /// The coordinate of the texel we wish to load. This must be a scalar + /// for [`D1`] images, a [`Bi`] vector for [`D2`] images, and a [`Tri`] + /// vector for [`D3`] images. (Array indices, sample indices, and + /// explicit level-of-detail values are supplied separately.) Its + /// component type must be [`Sint`]. + /// + /// [`D1`]: ImageDimension::D1 + /// [`D2`]: ImageDimension::D2 + /// [`D3`]: ImageDimension::D3 + /// [`Bi`]: VectorSize::Bi + /// [`Tri`]: VectorSize::Tri + /// [`Sint`]: ScalarKind::Sint + coordinate: Handle<Expression>, + + /// The index into an arrayed image. If the [`arrayed`] flag in + /// `image`'s type is `true`, then this must be `Some(expr)`, where + /// `expr` is a [`Sint`] scalar. Otherwise, it must be `None`. + /// + /// [`arrayed`]: TypeInner::Image::arrayed + /// [`Sint`]: ScalarKind::Sint + array_index: Option<Handle<Expression>>, + + /// A sample index, for multisampled [`Sampled`] and [`Depth`] images. + /// + /// [`Sampled`]: ImageClass::Sampled + /// [`Depth`]: ImageClass::Depth + sample: Option<Handle<Expression>>, + + /// A level of detail, for mipmapped images. + /// + /// This must be present when accessing non-multisampled + /// [`Sampled`] and [`Depth`] images, even if only the + /// full-resolution level is present (in which case the only + /// valid level is zero). + /// + /// [`Sampled`]: ImageClass::Sampled + /// [`Depth`]: ImageClass::Depth + level: Option<Handle<Expression>>, + }, + + /// Query information from an image. + ImageQuery { + image: Handle<Expression>, + query: ImageQuery, + }, + /// Apply an unary operator. + Unary { + op: UnaryOperator, + expr: Handle<Expression>, + }, + /// Apply a binary operator. + Binary { + op: BinaryOperator, + left: Handle<Expression>, + right: Handle<Expression>, + }, + /// Select between two values based on a condition. + /// + /// Note that, because expressions have no side effects, it is unobservable + /// whether the non-selected branch is evaluated. + Select { + /// Boolean expression + condition: Handle<Expression>, + accept: Handle<Expression>, + reject: Handle<Expression>, + }, + /// Compute the derivative on an axis. + Derivative { + axis: DerivativeAxis, + ctrl: DerivativeControl, + expr: Handle<Expression>, + }, + /// Call a relational function. + Relational { + fun: RelationalFunction, + argument: Handle<Expression>, + }, + /// Call a math function + Math { + fun: MathFunction, + arg: Handle<Expression>, + arg1: Option<Handle<Expression>>, + arg2: Option<Handle<Expression>>, + arg3: Option<Handle<Expression>>, + }, + /// Cast a simple type to another kind. + As { + /// Source expression, which can only be a scalar or a vector. + expr: Handle<Expression>, + /// Target scalar kind. + kind: ScalarKind, + /// If provided, converts to the specified byte width. + /// Otherwise, bitcast. + convert: Option<Bytes>, + }, + /// Result of calling another function. + CallResult(Handle<Function>), + /// Result of an atomic operation. + AtomicResult { ty: Handle<Type>, comparison: bool }, + /// Result of a [`WorkGroupUniformLoad`] statement. + /// + /// [`WorkGroupUniformLoad`]: Statement::WorkGroupUniformLoad + WorkGroupUniformLoadResult { + /// The type of the result + ty: Handle<Type>, + }, + /// Get the length of an array. + /// The expression must resolve to a pointer to an array with a dynamic size. + /// + /// This doesn't match the semantics of spirv's `OpArrayLength`, which must be passed + /// a pointer to a structure containing a runtime array in its' last field. + ArrayLength(Handle<Expression>), + + /// Result of a [`Proceed`] [`RayQuery`] statement. + /// + /// [`Proceed`]: RayQueryFunction::Proceed + /// [`RayQuery`]: Statement::RayQuery + RayQueryProceedResult, + + /// Return an intersection found by `query`. + /// + /// If `committed` is true, return the committed result available when + RayQueryGetIntersection { + query: Handle<Expression>, + committed: bool, + }, +} + +pub use block::Block; + +/// The value of the switch case. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum SwitchValue { + I32(i32), + U32(u32), + Default, +} + +/// A case for a switch statement. +// Clone is used only for error reporting and is not intended for end users +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct SwitchCase { + /// Value, upon which the case is considered true. + pub value: SwitchValue, + /// Body of the case. + pub body: Block, + /// If true, the control flow continues to the next case in the list, + /// or default. + pub fall_through: bool, +} + +/// An operation that a [`RayQuery` statement] applies to its [`query`] operand. +/// +/// [`RayQuery` statement]: Statement::RayQuery +/// [`query`]: Statement::RayQuery::query +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum RayQueryFunction { + /// Initialize the `RayQuery` object. + Initialize { + /// The acceleration structure within which this query should search for hits. + /// + /// The expression must be an [`AccelerationStructure`]. + /// + /// [`AccelerationStructure`]: TypeInner::AccelerationStructure + acceleration_structure: Handle<Expression>, + + #[allow(rustdoc::private_intra_doc_links)] + /// A struct of detailed parameters for the ray query. + /// + /// This expression should have the struct type given in + /// [`SpecialTypes::ray_desc`]. This is available in the WGSL + /// front end as the `RayDesc` type. + descriptor: Handle<Expression>, + }, + + /// Start or continue the query given by the statement's [`query`] operand. + /// + /// After executing this statement, the `result` expression is a + /// [`Bool`] scalar indicating whether there are more intersection + /// candidates to consider. + /// + /// [`query`]: Statement::RayQuery::query + /// [`Bool`]: ScalarKind::Bool + Proceed { + result: Handle<Expression>, + }, + + Terminate, +} + +//TODO: consider removing `Clone`. It's not valid to clone `Statement::Emit` anyway. +/// Instructions which make up an executable block. +// Clone is used only for error reporting and is not intended for end users +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Statement { + /// Emit a range of expressions, visible to all statements that follow in this block. + /// + /// See the [module-level documentation][emit] for details. + /// + /// [emit]: index.html#expression-evaluation-time + Emit(Range<Expression>), + /// A block containing more statements, to be executed sequentially. + Block(Block), + /// Conditionally executes one of two blocks, based on the value of the condition. + If { + condition: Handle<Expression>, //bool + accept: Block, + reject: Block, + }, + /// Conditionally executes one of multiple blocks, based on the value of the selector. + /// + /// Each case must have a distinct [`value`], exactly one of which must be + /// [`Default`]. The `Default` may appear at any position, and covers all + /// values not explicitly appearing in other cases. A `Default` appearing in + /// the midst of the list of cases does not shadow the cases that follow. + /// + /// Some backend languages don't support fallthrough (HLSL due to FXC, + /// WGSL), and may translate fallthrough cases in the IR by duplicating + /// code. However, all backend languages do support cases selected by + /// multiple values, like `case 1: case 2: case 3: { ... }`. This is + /// represented in the IR as a series of fallthrough cases with empty + /// bodies, except for the last. + /// + /// [`value`]: SwitchCase::value + /// [`body`]: SwitchCase::body + /// [`Default`]: SwitchValue::Default + Switch { + selector: Handle<Expression>, + cases: Vec<SwitchCase>, + }, + + /// Executes a block repeatedly. + /// + /// Each iteration of the loop executes the `body` block, followed by the + /// `continuing` block. + /// + /// Executing a [`Break`], [`Return`] or [`Kill`] statement exits the loop. + /// + /// A [`Continue`] statement in `body` jumps to the `continuing` block. The + /// `continuing` block is meant to be used to represent structures like the + /// third expression of a C-style `for` loop head, to which `continue` + /// statements in the loop's body jump. + /// + /// The `continuing` block and its substatements must not contain `Return` + /// or `Kill` statements, or any `Break` or `Continue` statements targeting + /// this loop. (It may have `Break` and `Continue` statements targeting + /// loops or switches nested within the `continuing` block.) Expressions + /// emitted in `body` are in scope in `continuing`. + /// + /// If present, `break_if` is an expression which is evaluated after the + /// continuing block. Expressions emitted in `body` or `continuing` are + /// considered to be in scope. If the expression's value is true, control + /// continues after the `Loop` statement, rather than branching back to the + /// top of body as usual. The `break_if` expression corresponds to a "break + /// if" statement in WGSL, or a loop whose back edge is an + /// `OpBranchConditional` instruction in SPIR-V. + /// + /// [`Break`]: Statement::Break + /// [`Continue`]: Statement::Continue + /// [`Kill`]: Statement::Kill + /// [`Return`]: Statement::Return + /// [`break if`]: Self::Loop::break_if + Loop { + body: Block, + continuing: Block, + break_if: Option<Handle<Expression>>, + }, + + /// Exits the innermost enclosing [`Loop`] or [`Switch`]. + /// + /// A `Break` statement may only appear within a [`Loop`] or [`Switch`] + /// statement. It may not break out of a [`Loop`] from within the loop's + /// `continuing` block. + /// + /// [`Loop`]: Statement::Loop + /// [`Switch`]: Statement::Switch + Break, + + /// Skips to the `continuing` block of the innermost enclosing [`Loop`]. + /// + /// A `Continue` statement may only appear within the `body` block of the + /// innermost enclosing [`Loop`] statement. It must not appear within that + /// loop's `continuing` block. + /// + /// [`Loop`]: Statement::Loop + Continue, + + /// Returns from the function (possibly with a value). + /// + /// `Return` statements are forbidden within the `continuing` block of a + /// [`Loop`] statement. + /// + /// [`Loop`]: Statement::Loop + Return { value: Option<Handle<Expression>> }, + + /// Aborts the current shader execution. + /// + /// `Kill` statements are forbidden within the `continuing` block of a + /// [`Loop`] statement. + /// + /// [`Loop`]: Statement::Loop + Kill, + + /// Synchronize invocations within the work group. + /// The `Barrier` flags control which memory accesses should be synchronized. + /// If empty, this becomes purely an execution barrier. + Barrier(Barrier), + /// Stores a value at an address. + /// + /// For [`TypeInner::Atomic`] type behind the pointer, the value + /// has to be a corresponding scalar. + /// For other types behind the `pointer<T>`, the value is `T`. + /// + /// This statement is a barrier for any operations on the + /// `Expression::LocalVariable` or `Expression::GlobalVariable` + /// that is the destination of an access chain, started + /// from the `pointer`. + Store { + pointer: Handle<Expression>, + value: Handle<Expression>, + }, + /// Stores a texel value to an image. + /// + /// The `image`, `coordinate`, and `array_index` fields have the same + /// meanings as the corresponding operands of an [`ImageLoad`] expression; + /// see that documentation for details. Storing into multisampled images or + /// images with mipmaps is not supported, so there are no `level` or + /// `sample` operands. + /// + /// This statement is a barrier for any operations on the corresponding + /// [`Expression::GlobalVariable`] for this image. + /// + /// [`ImageLoad`]: Expression::ImageLoad + ImageStore { + image: Handle<Expression>, + coordinate: Handle<Expression>, + array_index: Option<Handle<Expression>>, + value: Handle<Expression>, + }, + /// Atomic function. + Atomic { + /// Pointer to an atomic value. + pointer: Handle<Expression>, + /// Function to run on the atomic. + fun: AtomicFunction, + /// Value to use in the function. + value: Handle<Expression>, + /// [`AtomicResult`] expression representing this function's result. + /// + /// [`AtomicResult`]: crate::Expression::AtomicResult + result: Handle<Expression>, + }, + /// Load uniformly from a uniform pointer in the workgroup address space. + /// + /// Corresponds to the [`workgroupUniformLoad`](https://www.w3.org/TR/WGSL/#workgroupUniformLoad-builtin) + /// built-in function of wgsl, and has the same barrier semantics + WorkGroupUniformLoad { + /// This must be of type [`Pointer`] in the [`WorkGroup`] address space + /// + /// [`Pointer`]: TypeInner::Pointer + /// [`WorkGroup`]: AddressSpace::WorkGroup + pointer: Handle<Expression>, + /// The [`WorkGroupUniformLoadResult`] expression representing this load's result. + /// + /// [`WorkGroupUniformLoadResult`]: Expression::WorkGroupUniformLoadResult + result: Handle<Expression>, + }, + /// Calls a function. + /// + /// If the `result` is `Some`, the corresponding expression has to be + /// `Expression::CallResult`, and this statement serves as a barrier for any + /// operations on that expression. + Call { + function: Handle<Function>, + arguments: Vec<Handle<Expression>>, + result: Option<Handle<Expression>>, + }, + RayQuery { + /// The [`RayQuery`] object this statement operates on. + /// + /// [`RayQuery`]: TypeInner::RayQuery + query: Handle<Expression>, + + /// The specific operation we're performing on `query`. + fun: RayQueryFunction, + }, +} + +/// A function argument. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct FunctionArgument { + /// Name of the argument, if any. + pub name: Option<String>, + /// Type of the argument. + pub ty: Handle<Type>, + /// For entry points, an argument has to have a binding + /// unless it's a structure. + pub binding: Option<Binding>, +} + +/// A function result. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct FunctionResult { + /// Type of the result. + pub ty: Handle<Type>, + /// For entry points, the result has to have a binding + /// unless it's a structure. + pub binding: Option<Binding>, +} + +/// A function defined in the module. +#[derive(Debug, Default)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct Function { + /// Name of the function, if any. + pub name: Option<String>, + /// Information about function argument. + pub arguments: Vec<FunctionArgument>, + /// The result of this function, if any. + pub result: Option<FunctionResult>, + /// Local variables defined and used in the function. + pub local_variables: Arena<LocalVariable>, + /// Expressions used inside this function. + /// + /// An `Expression` must occur before all other `Expression`s that use its + /// value. + pub expressions: Arena<Expression>, + /// Map of expressions that have associated variable names + pub named_expressions: NamedExpressions, + /// Block of instructions comprising the body of the function. + pub body: Block, +} + +/// The main function for a pipeline stage. +/// +/// An [`EntryPoint`] is a [`Function`] that serves as the main function for a +/// graphics or compute pipeline stage. For example, an `EntryPoint` whose +/// [`stage`] is [`ShaderStage::Vertex`] can serve as a graphics pipeline's +/// vertex shader. +/// +/// Since an entry point is called directly by the graphics or compute pipeline, +/// not by other WGSL functions, you must specify what the pipeline should pass +/// as the entry point's arguments, and what values it will return. For example, +/// a vertex shader needs a vertex's attributes as its arguments, but if it's +/// used for instanced draw calls, it will also want to know the instance id. +/// The vertex shader's return value will usually include an output vertex +/// position, and possibly other attributes to be interpolated and passed along +/// to a fragment shader. +/// +/// To specify this, the arguments and result of an `EntryPoint`'s [`function`] +/// must each have a [`Binding`], or be structs whose members all have +/// `Binding`s. This associates every value passed to or returned from the entry +/// point with either a [`BuiltIn`] or a [`Location`]: +/// +/// - A [`BuiltIn`] has special semantics, usually specific to its pipeline +/// stage. For example, the result of a vertex shader can include a +/// [`BuiltIn::Position`] value, which determines the position of a vertex +/// of a rendered primitive. Or, a compute shader might take an argument +/// whose binding is [`BuiltIn::WorkGroupSize`], through which the compute +/// pipeline would pass the number of invocations in your workgroup. +/// +/// - A [`Location`] indicates user-defined IO to be passed from one pipeline +/// stage to the next. For example, a vertex shader might also produce a +/// `uv` texture location as a user-defined IO value. +/// +/// In other words, the pipeline stage's input and output interface are +/// determined by the bindings of the arguments and result of the `EntryPoint`'s +/// [`function`]. +/// +/// [`Function`]: crate::Function +/// [`Location`]: Binding::Location +/// [`function`]: EntryPoint::function +/// [`stage`]: EntryPoint::stage +#[derive(Debug)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct EntryPoint { + /// Name of this entry point, visible externally. + /// + /// Entry point names for a given `stage` must be distinct within a module. + pub name: String, + /// Shader stage. + pub stage: ShaderStage, + /// Early depth test for fragment stages. + pub early_depth_test: Option<EarlyDepthTest>, + /// Workgroup size for compute stages + pub workgroup_size: [u32; 3], + /// The entrance function. + pub function: Function, +} + +/// Return types predeclared for the frexp, modf, and atomicCompareExchangeWeak built-in functions. +/// +/// These cannot be spelled in WGSL source. +/// +/// Stored in [`SpecialTypes::predeclared_types`] and created by [`Module::generate_predeclared_type`]. +#[derive(Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum PredeclaredType { + AtomicCompareExchangeWeakResult(Scalar), + ModfResult { + size: Option<VectorSize>, + width: Bytes, + }, + FrexpResult { + size: Option<VectorSize>, + width: Bytes, + }, +} + +/// Set of special types that can be optionally generated by the frontends. +#[derive(Debug, Default)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct SpecialTypes { + /// Type for `RayDesc`. + /// + /// Call [`Module::generate_ray_desc_type`] to populate this if + /// needed and return the handle. + pub ray_desc: Option<Handle<Type>>, + + /// Type for `RayIntersection`. + /// + /// Call [`Module::generate_ray_intersection_type`] to populate + /// this if needed and return the handle. + pub ray_intersection: Option<Handle<Type>>, + + /// Types for predeclared wgsl types instantiated on demand. + /// + /// Call [`Module::generate_predeclared_type`] to populate this if + /// needed and return the handle. + pub predeclared_types: FastIndexMap<PredeclaredType, Handle<Type>>, +} + +/// Shader module. +/// +/// A module is a set of constants, global variables and functions, as well as +/// the types required to define them. +/// +/// Some functions are marked as entry points, to be used in a certain shader stage. +/// +/// To create a new module, use the `Default` implementation. +/// Alternatively, you can load an existing shader using one of the [available front ends][front]. +/// +/// When finished, you can export modules using one of the [available backends][back]. +#[derive(Debug, Default)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct Module { + /// Arena for the types defined in this module. + pub types: UniqueArena<Type>, + /// Dictionary of special type handles. + pub special_types: SpecialTypes, + /// Arena for the constants defined in this module. + pub constants: Arena<Constant>, + /// Arena for the global variables defined in this module. + pub global_variables: Arena<GlobalVariable>, + /// [Constant expressions] and [override expressions] used by this module. + /// + /// Each `Expression` must occur in the arena before any + /// `Expression` that uses its value. + /// + /// [Constant expressions]: index.html#constant-expressions + /// [override expressions]: index.html#override-expressions + pub const_expressions: Arena<Expression>, + /// Arena for the functions defined in this module. + /// + /// Each function must appear in this arena strictly before all its callers. + /// Recursion is not supported. + pub functions: Arena<Function>, + /// Entry points. + pub entry_points: Vec<EntryPoint>, +} diff --git a/third_party/rust/naga/src/proc/constant_evaluator.rs b/third_party/rust/naga/src/proc/constant_evaluator.rs new file mode 100644 index 0000000000..b3884b04b1 --- /dev/null +++ b/third_party/rust/naga/src/proc/constant_evaluator.rs @@ -0,0 +1,2475 @@ +use std::iter; + +use arrayvec::ArrayVec; + +use crate::{ + arena::{Arena, Handle, UniqueArena}, + ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner, + UnaryOperator, +}; + +/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating +/// `macro_rules!` items that, in turn, emit their own `macro_rules!` items. +/// +/// Technique stolen directly from +/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>. +macro_rules! with_dollar_sign { + ($($body:tt)*) => { + macro_rules! __with_dollar_sign { $($body)* } + __with_dollar_sign!($); + } +} + +macro_rules! gen_component_wise_extractor { + ( + $ident:ident -> $target:ident, + literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?], + scalar_kinds: [$( $scalar_kind:ident ),* $(,)?], + ) => { + /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins. + enum $target<const N: usize> { + $( + #[doc = concat!( + "Maps to [`Literal::", + stringify!($mapping), + "`]", + )] + $mapping([$ty; N]), + )+ + } + + impl From<$target<1>> for Expression { + fn from(value: $target<1>) -> Self { + match value { + $( + $target::$mapping([value]) => { + Expression::Literal(Literal::$literal(value)) + } + )+ + } + } + } + + #[doc = concat!( + "Attempts to evaluate multiple `exprs` as a combined [`", + stringify!($target), + "`] to pass to `handler`. ", + )] + /// If `exprs` are vectors of the same length, `handler` is called for each corresponding + /// component of each vector. + /// + /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the + /// same length, a new vector expression is registered, composed of each component emitted + /// by `handler`. + fn $ident<const N: usize, const M: usize, F>( + eval: &mut ConstantEvaluator<'_>, + span: Span, + exprs: [Handle<Expression>; N], + mut handler: F, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> + where + $target<M>: Into<Expression>, + F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone, + { + assert!(N > 0); + let err = ConstantEvaluatorError::InvalidMathArg; + let mut exprs = exprs.into_iter(); + + macro_rules! sanitize { + ($expr:expr) => { + eval.eval_zero_value_and_splat($expr, span) + .map(|expr| &eval.expressions[expr]) + }; + } + + let new_expr = match sanitize!(exprs.next().unwrap())? { + $( + &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x)) + .chain(exprs.map(|expr| { + sanitize!(expr).and_then(|expr| match expr { + &Expression::Literal(Literal::$literal(x)) => Ok(x), + _ => Err(err.clone()), + }) + })) + .collect::<Result<ArrayVec<_, N>, _>>() + .map(|a| a.into_inner().unwrap()) + .map($target::$mapping) + .and_then(|comps| Ok(handler(comps)?.into())), + )+ + &Expression::Compose { ty, ref components } => match &eval.types[ty].inner { + &TypeInner::Vector { size, scalar } => match scalar.kind { + $(ScalarKind::$scalar_kind)|* => { + let first_ty = ty; + let mut component_groups = + ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new(); + component_groups.push(crate::proc::flatten_compose( + first_ty, + components, + eval.expressions, + eval.types, + ).collect()); + component_groups.extend( + exprs + .map(|expr| { + sanitize!(expr).and_then(|expr| match expr { + &Expression::Compose { ty, ref components } + if &eval.types[ty].inner + == &eval.types[first_ty].inner => + { + Ok(crate::proc::flatten_compose( + ty, + components, + eval.expressions, + eval.types, + ).collect()) + } + _ => Err(err.clone()), + }) + }) + .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>( + )?, + ); + let component_groups = component_groups.into_inner().unwrap(); + let mut new_components = + ArrayVec::<_, { crate::VectorSize::MAX }>::new(); + for idx in 0..(size as u8).into() { + let group = component_groups + .iter() + .map(|cs| cs[idx]) + .collect::<ArrayVec<_, N>>() + .into_inner() + .unwrap(); + new_components.push($ident( + eval, + span, + group, + handler.clone(), + )?); + } + Ok(Expression::Compose { + ty: first_ty, + components: new_components.into_iter().collect(), + }) + } + _ => return Err(err), + }, + _ => return Err(err), + }, + _ => return Err(err), + }?; + eval.register_evaluated_expr(new_expr, span) + } + + with_dollar_sign! { + ($d:tt) => { + #[allow(unused)] + #[doc = concat!( + "A convenience macro for using the same RHS for each [`", + stringify!($target), + "`] variant in a call to [`", + stringify!($ident), + "`].", + )] + macro_rules! $ident { + ( + $eval:expr, + $span:expr, + [$d ($d expr:expr),+ $d (,)?], + |$d ($d arg:ident),+| $d tt:tt + ) => { + $ident($eval, $span, [$d ($d expr),+], |args| match args { + $( + $target::$mapping([$d ($d arg),+]) => { + let res = $d tt; + Result::map(res, $target::$mapping) + }, + )+ + }) + }; + } + }; + } + }; +} + +gen_component_wise_extractor! { + component_wise_scalar -> Scalar, + literals: [ + AbstractFloat => AbstractFloat: f64, + F32 => F32: f32, + AbstractInt => AbstractInt: i64, + U32 => U32: u32, + I32 => I32: i32, + ], + scalar_kinds: [ + Float, + AbstractFloat, + Sint, + Uint, + AbstractInt, + ], +} + +gen_component_wise_extractor! { + component_wise_float -> Float, + literals: [ + AbstractFloat => Abstract: f64, + F32 => F32: f32, + ], + scalar_kinds: [ + Float, + AbstractFloat, + ], +} + +gen_component_wise_extractor! { + component_wise_concrete_int -> ConcreteInt, + literals: [ + U32 => U32: u32, + I32 => I32: i32, + ], + scalar_kinds: [ + Sint, + Uint, + ], +} + +gen_component_wise_extractor! { + component_wise_signed -> Signed, + literals: [ + AbstractFloat => AbstractFloat: f64, + AbstractInt => AbstractInt: i64, + F32 => F32: f32, + I32 => I32: i32, + ], + scalar_kinds: [ + Sint, + AbstractInt, + Float, + AbstractFloat, + ], +} + +#[derive(Debug)] +enum Behavior { + Wgsl, + Glsl, +} + +/// A context for evaluating constant expressions. +/// +/// A `ConstantEvaluator` points at an expression arena to which it can append +/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind +/// of Naga [`Expression`] you like, and if its value can be computed at compile +/// time, `try_eval_and_append` appends an expression representing the computed +/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`] +/// expressions - to the arena. See the [`try_eval_and_append`] method for details. +/// +/// A `ConstantEvaluator` also holds whatever information we need to carry out +/// that evaluation: types, other constants, and so on. +/// +/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append +/// [`Compose`]: Expression::Compose +/// [`ZeroValue`]: Expression::ZeroValue +/// [`Literal`]: Expression::Literal +/// [`Swizzle`]: Expression::Swizzle +#[derive(Debug)] +pub struct ConstantEvaluator<'a> { + /// Which language's evaluation rules we should follow. + behavior: Behavior, + + /// The module's type arena. + /// + /// Because expressions like [`Splat`] contain type handles, we need to be + /// able to add new types to produce those expressions. + /// + /// [`Splat`]: Expression::Splat + types: &'a mut UniqueArena<Type>, + + /// The module's constant arena. + constants: &'a Arena<Constant>, + + /// The arena to which we are contributing expressions. + expressions: &'a mut Arena<Expression>, + + /// When `self.expressions` refers to a function's local expression + /// arena, this needs to be populated + function_local_data: Option<FunctionLocalData<'a>>, +} + +#[derive(Debug)] +struct FunctionLocalData<'a> { + /// Global constant expressions + const_expressions: &'a Arena<Expression>, + /// Tracks the constness of expressions residing in `ConstantEvaluator.expressions` + expression_constness: &'a mut ExpressionConstnessTracker, + emitter: &'a mut super::Emitter, + block: &'a mut crate::Block, +} + +#[derive(Debug)] +pub struct ExpressionConstnessTracker { + inner: bit_set::BitSet, +} + +impl ExpressionConstnessTracker { + pub fn new() -> Self { + Self { + inner: bit_set::BitSet::new(), + } + } + + /// Forces the the expression to not be const + pub fn force_non_const(&mut self, value: Handle<Expression>) { + self.inner.remove(value.index()); + } + + fn insert(&mut self, value: Handle<Expression>) { + self.inner.insert(value.index()); + } + + pub fn is_const(&self, value: Handle<Expression>) -> bool { + self.inner.contains(value.index()) + } + + pub fn from_arena(arena: &Arena<Expression>) -> Self { + let mut tracker = Self::new(); + for (handle, expr) in arena.iter() { + let insert = match *expr { + crate::Expression::Literal(_) + | crate::Expression::ZeroValue(_) + | crate::Expression::Constant(_) => true, + crate::Expression::Compose { ref components, .. } => { + components.iter().all(|h| tracker.is_const(*h)) + } + crate::Expression::Splat { value, .. } => tracker.is_const(value), + _ => false, + }; + if insert { + tracker.insert(handle); + } + } + tracker + } +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum ConstantEvaluatorError { + #[error("Constants cannot access function arguments")] + FunctionArg, + #[error("Constants cannot access global variables")] + GlobalVariable, + #[error("Constants cannot access local variables")] + LocalVariable, + #[error("Cannot get the array length of a non array type")] + InvalidArrayLengthArg, + #[error("Constants cannot get the array length of a dynamically sized array")] + ArrayLengthDynamic, + #[error("Constants cannot call functions")] + Call, + #[error("Constants don't support workGroupUniformLoad")] + WorkGroupUniformLoadResult, + #[error("Constants don't support atomic functions")] + Atomic, + #[error("Constants don't support derivative functions")] + Derivative, + #[error("Constants don't support load expressions")] + Load, + #[error("Constants don't support image expressions")] + ImageExpression, + #[error("Constants don't support ray query expressions")] + RayQueryExpression, + #[error("Cannot access the type")] + InvalidAccessBase, + #[error("Cannot access at the index")] + InvalidAccessIndex, + #[error("Cannot access with index of type")] + InvalidAccessIndexTy, + #[error("Constants don't support array length expressions")] + ArrayLength, + #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")] + InvalidCastArg { from: String, to: String }, + #[error("Cannot apply the unary op to the argument")] + InvalidUnaryOpArg, + #[error("Cannot apply the binary op to the arguments")] + InvalidBinaryOpArgs, + #[error("Cannot apply math function to type")] + InvalidMathArg, + #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")] + InvalidMathArgCount(crate::MathFunction, usize, usize), + #[error("value of `low` is greater than `high` for clamp built-in function")] + InvalidClamp, + #[error("Splat is defined only on scalar values")] + SplatScalarOnly, + #[error("Can only swizzle vector constants")] + SwizzleVectorOnly, + #[error("swizzle component not present in source expression")] + SwizzleOutOfBounds, + #[error("Type is not constructible")] + TypeNotConstructible, + #[error("Subexpression(s) are not constant")] + SubexpressionsAreNotConstant, + #[error("Not implemented as constant expression: {0}")] + NotImplemented(String), + #[error("{0} operation overflowed")] + Overflow(String), + #[error( + "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately" + )] + AutomaticConversionLossy { + value: String, + to_type: &'static str, + }, + #[error("abstract floating-point values cannot be automatically converted to integers")] + AutomaticConversionFloatToInt { to_type: &'static str }, + #[error("Division by zero")] + DivisionByZero, + #[error("Remainder by zero")] + RemainderByZero, + #[error("RHS of shift operation is greater than or equal to 32")] + ShiftedMoreThan32Bits, + #[error(transparent)] + Literal(#[from] crate::valid::LiteralError), +} + +impl<'a> ConstantEvaluator<'a> { + /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s + /// constant expression arena. + /// + /// Report errors according to WGSL's rules for constant evaluation. + pub fn for_wgsl_module(module: &'a mut crate::Module) -> Self { + Self::for_module(Behavior::Wgsl, module) + } + + /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s + /// constant expression arena. + /// + /// Report errors according to GLSL's rules for constant evaluation. + pub fn for_glsl_module(module: &'a mut crate::Module) -> Self { + Self::for_module(Behavior::Glsl, module) + } + + fn for_module(behavior: Behavior, module: &'a mut crate::Module) -> Self { + Self { + behavior, + types: &mut module.types, + constants: &module.constants, + expressions: &mut module.const_expressions, + function_local_data: None, + } + } + + /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s + /// expression arena. + /// + /// Report errors according to WGSL's rules for constant evaluation. + pub fn for_wgsl_function( + module: &'a mut crate::Module, + expressions: &'a mut Arena<Expression>, + expression_constness: &'a mut ExpressionConstnessTracker, + emitter: &'a mut super::Emitter, + block: &'a mut crate::Block, + ) -> Self { + Self::for_function( + Behavior::Wgsl, + module, + expressions, + expression_constness, + emitter, + block, + ) + } + + /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s + /// expression arena. + /// + /// Report errors according to GLSL's rules for constant evaluation. + pub fn for_glsl_function( + module: &'a mut crate::Module, + expressions: &'a mut Arena<Expression>, + expression_constness: &'a mut ExpressionConstnessTracker, + emitter: &'a mut super::Emitter, + block: &'a mut crate::Block, + ) -> Self { + Self::for_function( + Behavior::Glsl, + module, + expressions, + expression_constness, + emitter, + block, + ) + } + + fn for_function( + behavior: Behavior, + module: &'a mut crate::Module, + expressions: &'a mut Arena<Expression>, + expression_constness: &'a mut ExpressionConstnessTracker, + emitter: &'a mut super::Emitter, + block: &'a mut crate::Block, + ) -> Self { + Self { + behavior, + types: &mut module.types, + constants: &module.constants, + expressions, + function_local_data: Some(FunctionLocalData { + const_expressions: &module.const_expressions, + expression_constness, + emitter, + block, + }), + } + } + + pub fn to_ctx(&self) -> crate::proc::GlobalCtx { + crate::proc::GlobalCtx { + types: self.types, + constants: self.constants, + const_expressions: match self.function_local_data { + Some(ref data) => data.const_expressions, + None => self.expressions, + }, + } + } + + fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> { + if let Some(ref function_local_data) = self.function_local_data { + if !function_local_data.expression_constness.is_const(expr) { + log::debug!("check: SubexpressionsAreNotConstant"); + return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); + } + } + Ok(()) + } + + fn check_and_get( + &mut self, + expr: Handle<Expression>, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + match self.expressions[expr] { + Expression::Constant(c) => { + // Are we working in a function's expression arena, or the + // module's constant expression arena? + if let Some(ref function_local_data) = self.function_local_data { + // Deep-copy the constant's value into our arena. + self.copy_from( + self.constants[c].init, + function_local_data.const_expressions, + ) + } else { + // "See through" the constant and use its initializer. + Ok(self.constants[c].init) + } + } + _ => { + self.check(expr)?; + Ok(expr) + } + } + } + + /// Try to evaluate `expr` at compile time. + /// + /// The `expr` argument can be any sort of Naga [`Expression`] you like. If + /// we can determine its value at compile time, we append an expression + /// representing its value - a tree of [`Literal`], [`Compose`], + /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena + /// `self` contributes to. + /// + /// If `expr`'s value cannot be determined at compile time, return a an + /// error. If it's acceptable to evaluate `expr` at runtime, this error can + /// be ignored, and the caller can append `expr` to the arena itself. + /// + /// We only consider `expr` itself, without recursing into its operands. Its + /// operands must all have been produced by prior calls to + /// `try_eval_and_append`, to ensure that they have already been reduced to + /// an evaluated form if possible. + /// + /// [`Literal`]: Expression::Literal + /// [`Compose`]: Expression::Compose + /// [`ZeroValue`]: Expression::ZeroValue + /// [`Swizzle`]: Expression::Swizzle + pub fn try_eval_and_append( + &mut self, + expr: &Expression, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + log::trace!("try_eval_and_append: {:?}", expr); + match *expr { + Expression::Constant(c) if self.function_local_data.is_none() => { + // "See through" the constant and use its initializer. + // This is mainly done to avoid having constants pointing to other constants. + Ok(self.constants[c].init) + } + Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { + self.register_evaluated_expr(expr.clone(), span) + } + Expression::Compose { ty, ref components } => { + let components = components + .iter() + .map(|component| self.check_and_get(*component)) + .collect::<Result<Vec<_>, _>>()?; + self.register_evaluated_expr(Expression::Compose { ty, components }, span) + } + Expression::Splat { size, value } => { + let value = self.check_and_get(value)?; + self.register_evaluated_expr(Expression::Splat { size, value }, span) + } + Expression::AccessIndex { base, index } => { + let base = self.check_and_get(base)?; + + self.access(base, index as usize, span) + } + Expression::Access { base, index } => { + let base = self.check_and_get(base)?; + let index = self.check_and_get(index)?; + + self.access(base, self.constant_index(index)?, span) + } + Expression::Swizzle { + size, + vector, + pattern, + } => { + let vector = self.check_and_get(vector)?; + + self.swizzle(size, span, vector, pattern) + } + Expression::Unary { expr, op } => { + let expr = self.check_and_get(expr)?; + + self.unary_op(op, expr, span) + } + Expression::Binary { left, right, op } => { + let left = self.check_and_get(left)?; + let right = self.check_and_get(right)?; + + self.binary_op(op, left, right, span) + } + Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + let arg = self.check_and_get(arg)?; + let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?; + let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?; + let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?; + + self.math(arg, arg1, arg2, arg3, fun, span) + } + Expression::As { + convert, + expr, + kind, + } => { + let expr = self.check_and_get(expr)?; + + match convert { + Some(width) => self.cast(expr, crate::Scalar { kind, width }, span), + None => Err(ConstantEvaluatorError::NotImplemented( + "bitcast built-in function".into(), + )), + } + } + Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented( + "select built-in function".into(), + )), + Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented( + format!("{fun:?} built-in function"), + )), + Expression::ArrayLength(expr) => match self.behavior { + Behavior::Wgsl => Err(ConstantEvaluatorError::ArrayLength), + Behavior::Glsl => { + let expr = self.check_and_get(expr)?; + self.array_length(expr, span) + } + }, + Expression::Load { .. } => Err(ConstantEvaluatorError::Load), + Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable), + Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative), + Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call), + Expression::WorkGroupUniformLoadResult { .. } => { + Err(ConstantEvaluatorError::WorkGroupUniformLoadResult) + } + Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic), + Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg), + Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable), + Expression::ImageSample { .. } + | Expression::ImageLoad { .. } + | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression), + Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { + Err(ConstantEvaluatorError::RayQueryExpression) + } + } + } + + /// Splat `value` to `size`, without using [`Splat`] expressions. + /// + /// This constructs [`Compose`] or [`ZeroValue`] expressions to + /// build a vector with the given `size` whose components are all + /// `value`. + /// + /// Use `span` as the span of the inserted expressions and + /// resulting types. + /// + /// [`Splat`]: Expression::Splat + /// [`Compose`]: Expression::Compose + /// [`ZeroValue`]: Expression::ZeroValue + fn splat( + &mut self, + value: Handle<Expression>, + size: crate::VectorSize, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + match self.expressions[value] { + Expression::Literal(literal) => { + let scalar = literal.scalar(); + let ty = self.types.insert( + Type { + name: None, + inner: TypeInner::Vector { size, scalar }, + }, + span, + ); + let expr = Expression::Compose { + ty, + components: vec![value; size as usize], + }; + self.register_evaluated_expr(expr, span) + } + Expression::ZeroValue(ty) => { + let inner = match self.types[ty].inner { + TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar }, + _ => return Err(ConstantEvaluatorError::SplatScalarOnly), + }; + let res_ty = self.types.insert(Type { name: None, inner }, span); + let expr = Expression::ZeroValue(res_ty); + self.register_evaluated_expr(expr, span) + } + _ => Err(ConstantEvaluatorError::SplatScalarOnly), + } + } + + fn swizzle( + &mut self, + size: crate::VectorSize, + span: Span, + src_constant: Handle<Expression>, + pattern: [crate::SwizzleComponent; 4], + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + let mut get_dst_ty = |ty| match self.types[ty].inner { + crate::TypeInner::Vector { size: _, scalar } => Ok(self.types.insert( + Type { + name: None, + inner: crate::TypeInner::Vector { size, scalar }, + }, + span, + )), + _ => Err(ConstantEvaluatorError::SwizzleVectorOnly), + }; + + match self.expressions[src_constant] { + Expression::ZeroValue(ty) => { + let dst_ty = get_dst_ty(ty)?; + let expr = Expression::ZeroValue(dst_ty); + self.register_evaluated_expr(expr, span) + } + Expression::Splat { value, .. } => { + let expr = Expression::Splat { size, value }; + self.register_evaluated_expr(expr, span) + } + Expression::Compose { ty, ref components } => { + let dst_ty = get_dst_ty(ty)?; + + let mut flattened = [src_constant; 4]; // dummy value + let len = + crate::proc::flatten_compose(ty, components, self.expressions, self.types) + .zip(flattened.iter_mut()) + .map(|(component, elt)| *elt = component) + .count(); + let flattened = &flattened[..len]; + + let swizzled_components = pattern[..size as usize] + .iter() + .map(|&sc| { + let sc = sc as usize; + if let Some(elt) = flattened.get(sc) { + Ok(*elt) + } else { + Err(ConstantEvaluatorError::SwizzleOutOfBounds) + } + }) + .collect::<Result<Vec<Handle<Expression>>, _>>()?; + let expr = Expression::Compose { + ty: dst_ty, + components: swizzled_components, + }; + self.register_evaluated_expr(expr, span) + } + _ => Err(ConstantEvaluatorError::SwizzleVectorOnly), + } + } + + fn math( + &mut self, + arg: Handle<Expression>, + arg1: Option<Handle<Expression>>, + arg2: Option<Handle<Expression>>, + arg3: Option<Handle<Expression>>, + fun: crate::MathFunction, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + let expected = fun.argument_count(); + let given = Some(arg) + .into_iter() + .chain(arg1) + .chain(arg2) + .chain(arg3) + .count(); + if expected != given { + return Err(ConstantEvaluatorError::InvalidMathArgCount( + fun, expected, given, + )); + } + + // NOTE: We try to match the declaration order of `MathFunction` here. + match fun { + // comparison + crate::MathFunction::Abs => { + component_wise_scalar(self, span, [arg], |args| match args { + Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])), + Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])), + Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])), + Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])), + Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz + }) + } + crate::MathFunction::Min => { + component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| { + Ok([e1.min(e2)]) + }) + } + crate::MathFunction::Max => { + component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| { + Ok([e1.max(e2)]) + }) + } + crate::MathFunction::Clamp => { + component_wise_scalar!( + self, + span, + [arg, arg1.unwrap(), arg2.unwrap()], + |e, low, high| { + if low > high { + Err(ConstantEvaluatorError::InvalidClamp) + } else { + Ok([e.clamp(low, high)]) + } + } + ) + } + crate::MathFunction::Saturate => { + component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) }) + } + + // trigonometry + crate::MathFunction::Cos => { + component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) }) + } + crate::MathFunction::Cosh => { + component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) }) + } + crate::MathFunction::Sin => { + component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) }) + } + crate::MathFunction::Sinh => { + component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) }) + } + crate::MathFunction::Tan => { + component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) }) + } + crate::MathFunction::Tanh => { + component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) }) + } + crate::MathFunction::Acos => { + component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) }) + } + crate::MathFunction::Asin => { + component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) }) + } + crate::MathFunction::Atan => { + component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) }) + } + crate::MathFunction::Asinh => { + component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) }) + } + crate::MathFunction::Acosh => { + component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) }) + } + crate::MathFunction::Atanh => { + component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) }) + } + crate::MathFunction::Radians => { + component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) }) + } + crate::MathFunction::Degrees => { + component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) }) + } + + // decomposition + crate::MathFunction::Ceil => { + component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) }) + } + crate::MathFunction::Floor => { + component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) }) + } + crate::MathFunction::Round => { + // TODO: Use `f{32,64}.round_ties_even()` when available on stable. This polyfill + // is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source], + // which has licensing compatible with ours. See also + // <https://github.com/rust-lang/rust/issues/96710>. + // + // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98 + fn round_ties_even(x: f64) -> f64 { + let i = x as i64; + let f = (x - i as f64).abs(); + if f == 0.5 { + if i & 1 == 1 { + // -1.5, 1.5, 3.5, ... + (x.abs() + 0.5).copysign(x) + } else { + (x.abs() - 0.5).copysign(x) + } + } else { + x.round() + } + } + component_wise_float(self, span, [arg], |e| match e { + Float::Abstract([e]) => Ok(Float::Abstract([round_ties_even(e)])), + Float::F32([e]) => Ok(Float::F32([(round_ties_even(e as f64) as f32)])), + }) + } + crate::MathFunction::Fract => { + component_wise_float!(self, span, [arg], |e| { + // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that + // here. + Ok([e - e.floor()]) + }) + } + crate::MathFunction::Trunc => { + component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) }) + } + + // exponent + crate::MathFunction::Exp => { + component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) }) + } + crate::MathFunction::Exp2 => { + component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) }) + } + crate::MathFunction::Log => { + component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) }) + } + crate::MathFunction::Log2 => { + component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) }) + } + crate::MathFunction::Pow => { + component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| { + Ok([e1.powf(e2)]) + }) + } + + // computational + crate::MathFunction::Sign => { + component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) }) + } + crate::MathFunction::Fma => { + component_wise_float!( + self, + span, + [arg, arg1.unwrap(), arg2.unwrap()], + |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) } + ) + } + crate::MathFunction::Step => { + component_wise_float!(self, span, [arg, arg1.unwrap()], |edge, x| { + Ok([if edge <= x { 1.0 } else { 0.0 }]) + }) + } + crate::MathFunction::Sqrt => { + component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) }) + } + crate::MathFunction::InverseSqrt => { + component_wise_float!(self, span, [arg], |e| { Ok([1. / e.sqrt()]) }) + } + + // bits + crate::MathFunction::CountTrailingZeros => { + component_wise_concrete_int!(self, span, [arg], |e| { + #[allow(clippy::useless_conversion)] + Ok([e + .trailing_zeros() + .try_into() + .expect("bit count overflowed 32 bits, somehow!?")]) + }) + } + crate::MathFunction::CountLeadingZeros => { + component_wise_concrete_int!(self, span, [arg], |e| { + #[allow(clippy::useless_conversion)] + Ok([e + .leading_zeros() + .try_into() + .expect("bit count overflowed 32 bits, somehow!?")]) + }) + } + crate::MathFunction::CountOneBits => { + component_wise_concrete_int!(self, span, [arg], |e| { + #[allow(clippy::useless_conversion)] + Ok([e + .count_ones() + .try_into() + .expect("bit count overflowed 32 bits, somehow!?")]) + }) + } + crate::MathFunction::ReverseBits => { + component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) }) + } + + fun => Err(ConstantEvaluatorError::NotImplemented(format!( + "{fun:?} built-in function" + ))), + } + } + + fn array_length( + &mut self, + array: Handle<Expression>, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + match self.expressions[array] { + Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => { + match self.types[ty].inner { + TypeInner::Array { size, .. } => match size { + crate::ArraySize::Constant(len) => { + let expr = Expression::Literal(Literal::U32(len.get())); + self.register_evaluated_expr(expr, span) + } + crate::ArraySize::Dynamic => { + Err(ConstantEvaluatorError::ArrayLengthDynamic) + } + }, + _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg), + } + } + _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg), + } + } + + fn access( + &mut self, + base: Handle<Expression>, + index: usize, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + match self.expressions[base] { + Expression::ZeroValue(ty) => { + let ty_inner = &self.types[ty].inner; + let components = ty_inner + .components() + .ok_or(ConstantEvaluatorError::InvalidAccessBase)?; + + if index >= components as usize { + Err(ConstantEvaluatorError::InvalidAccessBase) + } else { + let ty_res = ty_inner + .component_type(index) + .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?; + let ty = match ty_res { + crate::proc::TypeResolution::Handle(ty) => ty, + crate::proc::TypeResolution::Value(inner) => { + self.types.insert(Type { name: None, inner }, span) + } + }; + self.register_evaluated_expr(Expression::ZeroValue(ty), span) + } + } + Expression::Splat { size, value } => { + if index >= size as usize { + Err(ConstantEvaluatorError::InvalidAccessBase) + } else { + Ok(value) + } + } + Expression::Compose { ty, ref components } => { + let _ = self.types[ty] + .inner + .components() + .ok_or(ConstantEvaluatorError::InvalidAccessBase)?; + + crate::proc::flatten_compose(ty, components, self.expressions, self.types) + .nth(index) + .ok_or(ConstantEvaluatorError::InvalidAccessIndex) + } + _ => Err(ConstantEvaluatorError::InvalidAccessBase), + } + } + + fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> { + match self.expressions[expr] { + Expression::ZeroValue(ty) + if matches!( + self.types[ty].inner, + crate::TypeInner::Scalar(crate::Scalar { + kind: ScalarKind::Uint, + .. + }) + ) => + { + Ok(0) + } + Expression::Literal(Literal::U32(index)) => Ok(index as usize), + _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy), + } + } + + /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions. + /// + /// [`ZeroValue`]: Expression::ZeroValue + /// [`Splat`]: Expression::Splat + /// [`Literal`]: Expression::Literal + /// [`Compose`]: Expression::Compose + fn eval_zero_value_and_splat( + &mut self, + expr: Handle<Expression>, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + match self.expressions[expr] { + Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span), + Expression::Splat { size, value } => self.splat(value, size, span), + _ => Ok(expr), + } + } + + /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions. + /// + /// [`ZeroValue`]: Expression::ZeroValue + /// [`Literal`]: Expression::Literal + /// [`Compose`]: Expression::Compose + fn eval_zero_value( + &mut self, + expr: Handle<Expression>, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + match self.expressions[expr] { + Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span), + _ => Ok(expr), + } + } + + /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions. + /// + /// [`ZeroValue`]: Expression::ZeroValue + /// [`Literal`]: Expression::Literal + /// [`Compose`]: Expression::Compose + fn eval_zero_value_impl( + &mut self, + ty: Handle<Type>, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + match self.types[ty].inner { + TypeInner::Scalar(scalar) => { + let expr = Expression::Literal( + Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?, + ); + self.register_evaluated_expr(expr, span) + } + TypeInner::Vector { size, scalar } => { + let scalar_ty = self.types.insert( + Type { + name: None, + inner: TypeInner::Scalar(scalar), + }, + span, + ); + let el = self.eval_zero_value_impl(scalar_ty, span)?; + let expr = Expression::Compose { + ty, + components: vec![el; size as usize], + }; + self.register_evaluated_expr(expr, span) + } + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + let vec_ty = self.types.insert( + Type { + name: None, + inner: TypeInner::Vector { size: rows, scalar }, + }, + span, + ); + let el = self.eval_zero_value_impl(vec_ty, span)?; + let expr = Expression::Compose { + ty, + components: vec![el; columns as usize], + }; + self.register_evaluated_expr(expr, span) + } + TypeInner::Array { + base, + size: ArraySize::Constant(size), + .. + } => { + let el = self.eval_zero_value_impl(base, span)?; + let expr = Expression::Compose { + ty, + components: vec![el; size.get() as usize], + }; + self.register_evaluated_expr(expr, span) + } + TypeInner::Struct { ref members, .. } => { + let types: Vec<_> = members.iter().map(|m| m.ty).collect(); + let mut components = Vec::with_capacity(members.len()); + for ty in types { + components.push(self.eval_zero_value_impl(ty, span)?); + } + let expr = Expression::Compose { ty, components }; + self.register_evaluated_expr(expr, span) + } + _ => Err(ConstantEvaluatorError::TypeNotConstructible), + } + } + + /// Convert the scalar components of `expr` to `target`. + /// + /// Treat `span` as the location of the resulting expression. + pub fn cast( + &mut self, + expr: Handle<Expression>, + target: crate::Scalar, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + use crate::Scalar as Sc; + + let expr = self.eval_zero_value(expr, span)?; + + let make_error = || -> Result<_, ConstantEvaluatorError> { + let from = format!("{:?} {:?}", expr, self.expressions[expr]); + + #[cfg(feature = "wgsl-in")] + let to = target.to_wgsl(); + + #[cfg(not(feature = "wgsl-in"))] + let to = format!("{target:?}"); + + Err(ConstantEvaluatorError::InvalidCastArg { from, to }) + }; + + let expr = match self.expressions[expr] { + Expression::Literal(literal) => { + let literal = match target { + Sc::I32 => Literal::I32(match literal { + Literal::I32(v) => v, + Literal::U32(v) => v as i32, + Literal::F32(v) => v as i32, + Literal::Bool(v) => v as i32, + Literal::F64(_) | Literal::I64(_) => { + return make_error(); + } + Literal::AbstractInt(v) => i32::try_from_abstract(v)?, + Literal::AbstractFloat(v) => i32::try_from_abstract(v)?, + }), + Sc::U32 => Literal::U32(match literal { + Literal::I32(v) => v as u32, + Literal::U32(v) => v, + Literal::F32(v) => v as u32, + Literal::Bool(v) => v as u32, + Literal::F64(_) | Literal::I64(_) => { + return make_error(); + } + Literal::AbstractInt(v) => u32::try_from_abstract(v)?, + Literal::AbstractFloat(v) => u32::try_from_abstract(v)?, + }), + Sc::F32 => Literal::F32(match literal { + Literal::I32(v) => v as f32, + Literal::U32(v) => v as f32, + Literal::F32(v) => v, + Literal::Bool(v) => v as u32 as f32, + Literal::F64(_) | Literal::I64(_) => { + return make_error(); + } + Literal::AbstractInt(v) => f32::try_from_abstract(v)?, + Literal::AbstractFloat(v) => f32::try_from_abstract(v)?, + }), + Sc::F64 => Literal::F64(match literal { + Literal::I32(v) => v as f64, + Literal::U32(v) => v as f64, + Literal::F32(v) => v as f64, + Literal::F64(v) => v, + Literal::Bool(v) => v as u32 as f64, + Literal::I64(_) => return make_error(), + Literal::AbstractInt(v) => f64::try_from_abstract(v)?, + Literal::AbstractFloat(v) => f64::try_from_abstract(v)?, + }), + Sc::BOOL => Literal::Bool(match literal { + Literal::I32(v) => v != 0, + Literal::U32(v) => v != 0, + Literal::F32(v) => v != 0.0, + Literal::Bool(v) => v, + Literal::F64(_) + | Literal::I64(_) + | Literal::AbstractInt(_) + | Literal::AbstractFloat(_) => { + return make_error(); + } + }), + Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal { + Literal::AbstractInt(v) => { + // Overflow is forbidden, but inexact conversions + // are fine. The range of f64 is far larger than + // that of i64, so we don't have to check anything + // here. + v as f64 + } + Literal::AbstractFloat(v) => v, + _ => return make_error(), + }), + _ => { + log::debug!("Constant evaluator refused to convert value to {target:?}"); + return make_error(); + } + }; + Expression::Literal(literal) + } + Expression::Compose { + ty, + components: ref src_components, + } => { + let ty_inner = match self.types[ty].inner { + TypeInner::Vector { size, .. } => TypeInner::Vector { + size, + scalar: target, + }, + TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix { + columns, + rows, + scalar: target, + }, + _ => return make_error(), + }; + + let mut components = src_components.clone(); + for component in &mut components { + *component = self.cast(*component, target, span)?; + } + + let ty = self.types.insert( + Type { + name: None, + inner: ty_inner, + }, + span, + ); + + Expression::Compose { ty, components } + } + Expression::Splat { size, value } => { + let value_span = self.expressions.get_span(value); + let cast_value = self.cast(value, target, value_span)?; + Expression::Splat { + size, + value: cast_value, + } + } + _ => return make_error(), + }; + + self.register_evaluated_expr(expr, span) + } + + /// Convert the scalar leaves of `expr` to `target`, handling arrays. + /// + /// `expr` must be a `Compose` expression whose type is a scalar, vector, + /// matrix, or nested arrays of such. + /// + /// This is basically the same as the [`cast`] method, except that that + /// should only handle Naga [`As`] expressions, which cannot convert arrays. + /// + /// Treat `span` as the location of the resulting expression. + /// + /// [`cast`]: ConstantEvaluator::cast + /// [`As`]: crate::Expression::As + pub fn cast_array( + &mut self, + expr: Handle<Expression>, + target: crate::Scalar, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + let Expression::Compose { ty, ref components } = self.expressions[expr] else { + return self.cast(expr, target, span); + }; + + let crate::TypeInner::Array { + base: _, + size, + stride: _, + } = self.types[ty].inner + else { + return self.cast(expr, target, span); + }; + + let mut components = components.clone(); + for component in &mut components { + *component = self.cast_array(*component, target, span)?; + } + + let first = components.first().unwrap(); + let new_base = match self.resolve_type(*first)? { + crate::proc::TypeResolution::Handle(ty) => ty, + crate::proc::TypeResolution::Value(inner) => { + self.types.insert(Type { name: None, inner }, span) + } + }; + let new_base_stride = self.types[new_base].inner.size(self.to_ctx()); + let new_array_ty = self.types.insert( + Type { + name: None, + inner: TypeInner::Array { + base: new_base, + size, + stride: new_base_stride, + }, + }, + span, + ); + + let compose = Expression::Compose { + ty: new_array_ty, + components, + }; + self.register_evaluated_expr(compose, span) + } + + fn unary_op( + &mut self, + op: UnaryOperator, + expr: Handle<Expression>, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + let expr = self.eval_zero_value_and_splat(expr, span)?; + + let expr = match self.expressions[expr] { + Expression::Literal(value) => Expression::Literal(match op { + UnaryOperator::Negate => match value { + Literal::I32(v) => Literal::I32(v.wrapping_neg()), + Literal::F32(v) => Literal::F32(-v), + Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()), + Literal::AbstractFloat(v) => Literal::AbstractFloat(-v), + _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg), + }, + UnaryOperator::LogicalNot => match value { + Literal::Bool(v) => Literal::Bool(!v), + _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg), + }, + UnaryOperator::BitwiseNot => match value { + Literal::I32(v) => Literal::I32(!v), + Literal::U32(v) => Literal::U32(!v), + Literal::AbstractInt(v) => Literal::AbstractInt(!v), + _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg), + }, + }), + Expression::Compose { + ty, + components: ref src_components, + } => { + match self.types[ty].inner { + TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (), + _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg), + } + + let mut components = src_components.clone(); + for component in &mut components { + *component = self.unary_op(op, *component, span)?; + } + + Expression::Compose { ty, components } + } + _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg), + }; + + self.register_evaluated_expr(expr, span) + } + + fn binary_op( + &mut self, + op: BinaryOperator, + left: Handle<Expression>, + right: Handle<Expression>, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + let left = self.eval_zero_value_and_splat(left, span)?; + let right = self.eval_zero_value_and_splat(right, span)?; + + let expr = match (&self.expressions[left], &self.expressions[right]) { + (&Expression::Literal(left_value), &Expression::Literal(right_value)) => { + let literal = match op { + BinaryOperator::Equal => Literal::Bool(left_value == right_value), + BinaryOperator::NotEqual => Literal::Bool(left_value != right_value), + BinaryOperator::Less => Literal::Bool(left_value < right_value), + BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value), + BinaryOperator::Greater => Literal::Bool(left_value > right_value), + BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value), + + _ => match (left_value, right_value) { + (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op { + BinaryOperator::Add => a.checked_add(b).ok_or_else(|| { + ConstantEvaluatorError::Overflow("addition".into()) + })?, + BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| { + ConstantEvaluatorError::Overflow("subtraction".into()) + })?, + BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| { + ConstantEvaluatorError::Overflow("multiplication".into()) + })?, + BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| { + if b == 0 { + ConstantEvaluatorError::DivisionByZero + } else { + ConstantEvaluatorError::Overflow("division".into()) + } + })?, + BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| { + if b == 0 { + ConstantEvaluatorError::RemainderByZero + } else { + ConstantEvaluatorError::Overflow("remainder".into()) + } + })?, + BinaryOperator::And => a & b, + BinaryOperator::ExclusiveOr => a ^ b, + BinaryOperator::InclusiveOr => a | b, + _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), + }), + (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op { + BinaryOperator::ShiftLeft => a + .checked_shl(b) + .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?, + BinaryOperator::ShiftRight => a + .checked_shr(b) + .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?, + _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), + }), + (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op { + BinaryOperator::Add => a.checked_add(b).ok_or_else(|| { + ConstantEvaluatorError::Overflow("addition".into()) + })?, + BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| { + ConstantEvaluatorError::Overflow("subtraction".into()) + })?, + BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| { + ConstantEvaluatorError::Overflow("multiplication".into()) + })?, + BinaryOperator::Divide => a + .checked_div(b) + .ok_or(ConstantEvaluatorError::DivisionByZero)?, + BinaryOperator::Modulo => a + .checked_rem(b) + .ok_or(ConstantEvaluatorError::RemainderByZero)?, + BinaryOperator::And => a & b, + BinaryOperator::ExclusiveOr => a ^ b, + BinaryOperator::InclusiveOr => a | b, + BinaryOperator::ShiftLeft => a + .checked_shl(b) + .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?, + BinaryOperator::ShiftRight => a + .checked_shr(b) + .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?, + _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), + }), + (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op { + BinaryOperator::Add => a + b, + BinaryOperator::Subtract => a - b, + BinaryOperator::Multiply => a * b, + BinaryOperator::Divide => a / b, + BinaryOperator::Modulo => a % b, + _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), + }), + (Literal::AbstractInt(a), Literal::AbstractInt(b)) => { + Literal::AbstractInt(match op { + BinaryOperator::Add => a.checked_add(b).ok_or_else(|| { + ConstantEvaluatorError::Overflow("addition".into()) + })?, + BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| { + ConstantEvaluatorError::Overflow("subtraction".into()) + })?, + BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| { + ConstantEvaluatorError::Overflow("multiplication".into()) + })?, + BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| { + if b == 0 { + ConstantEvaluatorError::DivisionByZero + } else { + ConstantEvaluatorError::Overflow("division".into()) + } + })?, + BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| { + if b == 0 { + ConstantEvaluatorError::RemainderByZero + } else { + ConstantEvaluatorError::Overflow("remainder".into()) + } + })?, + BinaryOperator::And => a & b, + BinaryOperator::ExclusiveOr => a ^ b, + BinaryOperator::InclusiveOr => a | b, + _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), + }) + } + (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => { + Literal::AbstractFloat(match op { + BinaryOperator::Add => a + b, + BinaryOperator::Subtract => a - b, + BinaryOperator::Multiply => a * b, + BinaryOperator::Divide => a / b, + BinaryOperator::Modulo => a % b, + _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), + }) + } + (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op { + BinaryOperator::LogicalAnd => a && b, + BinaryOperator::LogicalOr => a || b, + _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), + }), + _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), + }, + }; + Expression::Literal(literal) + } + ( + &Expression::Compose { + components: ref src_components, + ty, + }, + &Expression::Literal(_), + ) => { + let mut components = src_components.clone(); + for component in &mut components { + *component = self.binary_op(op, *component, right, span)?; + } + Expression::Compose { ty, components } + } + ( + &Expression::Literal(_), + &Expression::Compose { + components: ref src_components, + ty, + }, + ) => { + let mut components = src_components.clone(); + for component in &mut components { + *component = self.binary_op(op, left, *component, span)?; + } + Expression::Compose { ty, components } + } + ( + &Expression::Compose { + components: ref left_components, + ty: left_ty, + }, + &Expression::Compose { + components: ref right_components, + ty: right_ty, + }, + ) => { + // We have to make a copy of the component lists, because the + // call to `binary_op_vector` needs `&mut self`, but `self` owns + // the component lists. + let left_flattened = crate::proc::flatten_compose( + left_ty, + left_components, + self.expressions, + self.types, + ); + let right_flattened = crate::proc::flatten_compose( + right_ty, + right_components, + self.expressions, + self.types, + ); + + // `flatten_compose` doesn't return an `ExactSizeIterator`, so + // make a reasonable guess of the capacity we'll need. + let mut flattened = Vec::with_capacity(left_components.len()); + flattened.extend(left_flattened.zip(right_flattened)); + + match (&self.types[left_ty].inner, &self.types[right_ty].inner) { + ( + &TypeInner::Vector { + size: left_size, .. + }, + &TypeInner::Vector { + size: right_size, .. + }, + ) if left_size == right_size => { + self.binary_op_vector(op, left_size, &flattened, left_ty, span)? + } + _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), + } + } + _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs), + }; + + self.register_evaluated_expr(expr, span) + } + + fn binary_op_vector( + &mut self, + op: BinaryOperator, + size: crate::VectorSize, + components: &[(Handle<Expression>, Handle<Expression>)], + left_ty: Handle<Type>, + span: Span, + ) -> Result<Expression, ConstantEvaluatorError> { + let ty = match op { + // Relational operators produce vectors of booleans. + BinaryOperator::Equal + | BinaryOperator::NotEqual + | BinaryOperator::Less + | BinaryOperator::LessEqual + | BinaryOperator::Greater + | BinaryOperator::GreaterEqual => self.types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size, + scalar: crate::Scalar::BOOL, + }, + }, + span, + ), + + // Other operators produce the same type as their left + // operand. + BinaryOperator::Add + | BinaryOperator::Subtract + | BinaryOperator::Multiply + | BinaryOperator::Divide + | BinaryOperator::Modulo + | BinaryOperator::And + | BinaryOperator::ExclusiveOr + | BinaryOperator::InclusiveOr + | BinaryOperator::LogicalAnd + | BinaryOperator::LogicalOr + | BinaryOperator::ShiftLeft + | BinaryOperator::ShiftRight => left_ty, + }; + + let components = components + .iter() + .map(|&(left, right)| self.binary_op(op, left, right, span)) + .collect::<Result<Vec<_>, _>>()?; + + Ok(Expression::Compose { ty, components }) + } + + /// Deep copy `expr` from `expressions` into `self.expressions`. + /// + /// Return the root of the new copy. + /// + /// This is used when we're evaluating expressions in a function's + /// expression arena that refer to a constant: we need to copy the + /// constant's value into the function's arena so we can operate on it. + fn copy_from( + &mut self, + expr: Handle<Expression>, + expressions: &Arena<Expression>, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + let span = expressions.get_span(expr); + match expressions[expr] { + ref expr @ (Expression::Literal(_) + | Expression::Constant(_) + | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span), + Expression::Compose { ty, ref components } => { + let mut components = components.clone(); + for component in &mut components { + *component = self.copy_from(*component, expressions)?; + } + self.register_evaluated_expr(Expression::Compose { ty, components }, span) + } + Expression::Splat { size, value } => { + let value = self.copy_from(value, expressions)?; + self.register_evaluated_expr(Expression::Splat { size, value }, span) + } + _ => { + log::debug!("copy_from: SubexpressionsAreNotConstant"); + Err(ConstantEvaluatorError::SubexpressionsAreNotConstant) + } + } + } + + fn register_evaluated_expr( + &mut self, + expr: Expression, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + // It suffices to only check literals, since we only register one + // expression at a time, `Compose` expressions can only refer to other + // expressions, and `ZeroValue` expressions are always okay. + if let Expression::Literal(literal) = expr { + crate::valid::check_literal_value(literal)?; + } + + if let Some(FunctionLocalData { + ref mut emitter, + ref mut block, + ref mut expression_constness, + .. + }) = self.function_local_data + { + let is_running = emitter.is_running(); + let needs_pre_emit = expr.needs_pre_emit(); + if is_running && needs_pre_emit { + block.extend(emitter.finish(self.expressions)); + let h = self.expressions.append(expr, span); + emitter.start(self.expressions); + expression_constness.insert(h); + Ok(h) + } else { + let h = self.expressions.append(expr, span); + expression_constness.insert(h); + Ok(h) + } + } else { + Ok(self.expressions.append(expr, span)) + } + } + + fn resolve_type( + &self, + expr: Handle<Expression>, + ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> { + use crate::proc::TypeResolution as Tr; + use crate::Expression as Ex; + let resolution = match self.expressions[expr] { + Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()), + Ex::Constant(c) => Tr::Handle(self.constants[c].ty), + Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty), + Ex::Splat { size, value } => { + let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else { + return Err(ConstantEvaluatorError::SplatScalarOnly); + }; + Tr::Value(TypeInner::Vector { scalar, size }) + } + _ => { + log::debug!("resolve_type: SubexpressionsAreNotConstant"); + return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); + } + }; + + Ok(resolution) + } +} + +#[cfg(test)] +mod tests { + use std::vec; + + use crate::{ + Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator, + UniqueArena, VectorSize, + }; + + use super::{Behavior, ConstantEvaluator}; + + #[test] + fn unary_op() { + let mut types = UniqueArena::new(); + let mut constants = Arena::new(); + let mut const_expressions = Arena::new(); + + let scalar_ty = types.insert( + Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::I32), + }, + Default::default(), + ); + + let vec_ty = types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: VectorSize::Bi, + scalar: crate::Scalar::I32, + }, + }, + Default::default(), + ); + + let h = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: scalar_ty, + init: const_expressions + .append(Expression::Literal(Literal::I32(4)), Default::default()), + }, + Default::default(), + ); + + let h1 = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: scalar_ty, + init: const_expressions + .append(Expression::Literal(Literal::I32(8)), Default::default()), + }, + Default::default(), + ); + + let vec_h = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: vec_ty, + init: const_expressions.append( + Expression::Compose { + ty: vec_ty, + components: vec![constants[h].init, constants[h1].init], + }, + Default::default(), + ), + }, + Default::default(), + ); + + let expr = const_expressions.append(Expression::Constant(h), Default::default()); + let expr1 = const_expressions.append(Expression::Constant(vec_h), Default::default()); + + let expr2 = Expression::Unary { + op: UnaryOperator::Negate, + expr, + }; + + let expr3 = Expression::Unary { + op: UnaryOperator::BitwiseNot, + expr, + }; + + let expr4 = Expression::Unary { + op: UnaryOperator::BitwiseNot, + expr: expr1, + }; + + let mut solver = ConstantEvaluator { + behavior: Behavior::Wgsl, + types: &mut types, + constants: &constants, + expressions: &mut const_expressions, + function_local_data: None, + }; + + let res1 = solver + .try_eval_and_append(&expr2, Default::default()) + .unwrap(); + let res2 = solver + .try_eval_and_append(&expr3, Default::default()) + .unwrap(); + let res3 = solver + .try_eval_and_append(&expr4, Default::default()) + .unwrap(); + + assert_eq!( + const_expressions[res1], + Expression::Literal(Literal::I32(-4)) + ); + + assert_eq!( + const_expressions[res2], + Expression::Literal(Literal::I32(!4)) + ); + + let res3_inner = &const_expressions[res3]; + + match *res3_inner { + Expression::Compose { + ref ty, + ref components, + } => { + assert_eq!(*ty, vec_ty); + let mut components_iter = components.iter().copied(); + assert_eq!( + const_expressions[components_iter.next().unwrap()], + Expression::Literal(Literal::I32(!4)) + ); + assert_eq!( + const_expressions[components_iter.next().unwrap()], + Expression::Literal(Literal::I32(!8)) + ); + assert!(components_iter.next().is_none()); + } + _ => panic!("Expected vector"), + } + } + + #[test] + fn cast() { + let mut types = UniqueArena::new(); + let mut constants = Arena::new(); + let mut const_expressions = Arena::new(); + + let scalar_ty = types.insert( + Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::I32), + }, + Default::default(), + ); + + let h = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: scalar_ty, + init: const_expressions + .append(Expression::Literal(Literal::I32(4)), Default::default()), + }, + Default::default(), + ); + + let expr = const_expressions.append(Expression::Constant(h), Default::default()); + + let root = Expression::As { + expr, + kind: ScalarKind::Bool, + convert: Some(crate::BOOL_WIDTH), + }; + + let mut solver = ConstantEvaluator { + behavior: Behavior::Wgsl, + types: &mut types, + constants: &constants, + expressions: &mut const_expressions, + function_local_data: None, + }; + + let res = solver + .try_eval_and_append(&root, Default::default()) + .unwrap(); + + assert_eq!( + const_expressions[res], + Expression::Literal(Literal::Bool(true)) + ); + } + + #[test] + fn access() { + let mut types = UniqueArena::new(); + let mut constants = Arena::new(); + let mut const_expressions = Arena::new(); + + let matrix_ty = types.insert( + Type { + name: None, + inner: TypeInner::Matrix { + columns: VectorSize::Bi, + rows: VectorSize::Tri, + scalar: crate::Scalar::F32, + }, + }, + Default::default(), + ); + + let vec_ty = types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: VectorSize::Tri, + scalar: crate::Scalar::F32, + }, + }, + Default::default(), + ); + + let mut vec1_components = Vec::with_capacity(3); + let mut vec2_components = Vec::with_capacity(3); + + for i in 0..3 { + let h = const_expressions.append( + Expression::Literal(Literal::F32(i as f32)), + Default::default(), + ); + + vec1_components.push(h) + } + + for i in 3..6 { + let h = const_expressions.append( + Expression::Literal(Literal::F32(i as f32)), + Default::default(), + ); + + vec2_components.push(h) + } + + let vec1 = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: vec_ty, + init: const_expressions.append( + Expression::Compose { + ty: vec_ty, + components: vec1_components, + }, + Default::default(), + ), + }, + Default::default(), + ); + + let vec2 = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: vec_ty, + init: const_expressions.append( + Expression::Compose { + ty: vec_ty, + components: vec2_components, + }, + Default::default(), + ), + }, + Default::default(), + ); + + let h = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: matrix_ty, + init: const_expressions.append( + Expression::Compose { + ty: matrix_ty, + components: vec![constants[vec1].init, constants[vec2].init], + }, + Default::default(), + ), + }, + Default::default(), + ); + + let base = const_expressions.append(Expression::Constant(h), Default::default()); + + let mut solver = ConstantEvaluator { + behavior: Behavior::Wgsl, + types: &mut types, + constants: &constants, + expressions: &mut const_expressions, + function_local_data: None, + }; + + let root1 = Expression::AccessIndex { base, index: 1 }; + + let res1 = solver + .try_eval_and_append(&root1, Default::default()) + .unwrap(); + + let root2 = Expression::AccessIndex { + base: res1, + index: 2, + }; + + let res2 = solver + .try_eval_and_append(&root2, Default::default()) + .unwrap(); + + match const_expressions[res1] { + Expression::Compose { + ref ty, + ref components, + } => { + assert_eq!(*ty, vec_ty); + let mut components_iter = components.iter().copied(); + assert_eq!( + const_expressions[components_iter.next().unwrap()], + Expression::Literal(Literal::F32(3.)) + ); + assert_eq!( + const_expressions[components_iter.next().unwrap()], + Expression::Literal(Literal::F32(4.)) + ); + assert_eq!( + const_expressions[components_iter.next().unwrap()], + Expression::Literal(Literal::F32(5.)) + ); + assert!(components_iter.next().is_none()); + } + _ => panic!("Expected vector"), + } + + assert_eq!( + const_expressions[res2], + Expression::Literal(Literal::F32(5.)) + ); + } + + #[test] + fn compose_of_constants() { + let mut types = UniqueArena::new(); + let mut constants = Arena::new(); + let mut const_expressions = Arena::new(); + + let i32_ty = types.insert( + Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::I32), + }, + Default::default(), + ); + + let vec2_i32_ty = types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: VectorSize::Bi, + scalar: crate::Scalar::I32, + }, + }, + Default::default(), + ); + + let h = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: i32_ty, + init: const_expressions + .append(Expression::Literal(Literal::I32(4)), Default::default()), + }, + Default::default(), + ); + + let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + + let mut solver = ConstantEvaluator { + behavior: Behavior::Wgsl, + types: &mut types, + constants: &constants, + expressions: &mut const_expressions, + function_local_data: None, + }; + + let solved_compose = solver + .try_eval_and_append( + &Expression::Compose { + ty: vec2_i32_ty, + components: vec![h_expr, h_expr], + }, + Default::default(), + ) + .unwrap(); + let solved_negate = solver + .try_eval_and_append( + &Expression::Unary { + op: UnaryOperator::Negate, + expr: solved_compose, + }, + Default::default(), + ) + .unwrap(); + + let pass = match const_expressions[solved_negate] { + Expression::Compose { ty, ref components } => { + ty == vec2_i32_ty + && components.iter().all(|&component| { + let component = &const_expressions[component]; + matches!(*component, Expression::Literal(Literal::I32(-4))) + }) + } + _ => false, + }; + if !pass { + panic!("unexpected evaluation result") + } + } + + #[test] + fn splat_of_constant() { + let mut types = UniqueArena::new(); + let mut constants = Arena::new(); + let mut const_expressions = Arena::new(); + + let i32_ty = types.insert( + Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::I32), + }, + Default::default(), + ); + + let vec2_i32_ty = types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: VectorSize::Bi, + scalar: crate::Scalar::I32, + }, + }, + Default::default(), + ); + + let h = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: i32_ty, + init: const_expressions + .append(Expression::Literal(Literal::I32(4)), Default::default()), + }, + Default::default(), + ); + + let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + + let mut solver = ConstantEvaluator { + behavior: Behavior::Wgsl, + types: &mut types, + constants: &constants, + expressions: &mut const_expressions, + function_local_data: None, + }; + + let solved_compose = solver + .try_eval_and_append( + &Expression::Splat { + size: VectorSize::Bi, + value: h_expr, + }, + Default::default(), + ) + .unwrap(); + let solved_negate = solver + .try_eval_and_append( + &Expression::Unary { + op: UnaryOperator::Negate, + expr: solved_compose, + }, + Default::default(), + ) + .unwrap(); + + let pass = match const_expressions[solved_negate] { + Expression::Compose { ty, ref components } => { + ty == vec2_i32_ty + && components.iter().all(|&component| { + let component = &const_expressions[component]; + matches!(*component, Expression::Literal(Literal::I32(-4))) + }) + } + _ => false, + }; + if !pass { + panic!("unexpected evaluation result") + } + } +} + +/// Trait for conversions of abstract values to concrete types. +trait TryFromAbstract<T>: Sized { + /// Convert an abstract literal `value` to `Self`. + /// + /// Since Naga's `AbstractInt` and `AbstractFloat` exist to support + /// WGSL, we follow WGSL's conversion rules here: + /// + /// - WGSL §6.1.2. Conversion Rank says that automatic conversions + /// to integers are either lossless or an error. + /// + /// - WGSL §14.6.4 Floating Point Conversion says that conversions + /// to floating point in constant expressions and override + /// expressions are errors if the value is out of range for the + /// destination type, but rounding is okay. + /// + /// [`AbstractInt`]: crate::Literal::AbstractInt + /// [`Float`]: crate::Literal::Float + fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>; +} + +impl TryFromAbstract<i64> for i32 { + fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> { + i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "i32", + }) + } +} + +impl TryFromAbstract<i64> for u32 { + fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> { + u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "u32", + }) + } +} + +impl TryFromAbstract<i64> for f32 { + fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> { + let f = value as f32; + // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of + // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for + // overflow here. + Ok(f) + } +} + +impl TryFromAbstract<f64> for f32 { + fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> { + let f = value as f32; + if f.is_infinite() { + return Err(ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "f32", + }); + } + Ok(f) + } +} + +impl TryFromAbstract<i64> for f64 { + fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> { + let f = value as f64; + // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of + // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for + // overflow here. + Ok(f) + } +} + +impl TryFromAbstract<f64> for f64 { + fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> { + Ok(value) + } +} + +impl TryFromAbstract<f64> for i32 { + fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> { + Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i32" }) + } +} + +impl TryFromAbstract<f64> for u32 { + fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> { + Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u32" }) + } +} diff --git a/third_party/rust/naga/src/proc/emitter.rs b/third_party/rust/naga/src/proc/emitter.rs new file mode 100644 index 0000000000..0df804fff2 --- /dev/null +++ b/third_party/rust/naga/src/proc/emitter.rs @@ -0,0 +1,39 @@ +use crate::arena::Arena; + +/// Helper class to emit expressions +#[allow(dead_code)] +#[derive(Default, Debug)] +pub struct Emitter { + start_len: Option<usize>, +} + +#[allow(dead_code)] +impl Emitter { + pub fn start(&mut self, arena: &Arena<crate::Expression>) { + if self.start_len.is_some() { + unreachable!("Emitting has already started!"); + } + self.start_len = Some(arena.len()); + } + pub const fn is_running(&self) -> bool { + self.start_len.is_some() + } + #[must_use] + pub fn finish( + &mut self, + arena: &Arena<crate::Expression>, + ) -> Option<(crate::Statement, crate::span::Span)> { + let start_len = self.start_len.take().unwrap(); + if start_len != arena.len() { + #[allow(unused_mut)] + let mut span = crate::span::Span::default(); + let range = arena.range_from(start_len); + for handle in range.clone() { + span.subsume(arena.get_span(handle)) + } + Some((crate::Statement::Emit(range), span)) + } else { + None + } + } +} diff --git a/third_party/rust/naga/src/proc/index.rs b/third_party/rust/naga/src/proc/index.rs new file mode 100644 index 0000000000..af3221c0fe --- /dev/null +++ b/third_party/rust/naga/src/proc/index.rs @@ -0,0 +1,435 @@ +/*! +Definitions for index bounds checking. +*/ + +use crate::{valid, Handle, UniqueArena}; +use bit_set::BitSet; + +/// How should code generated by Naga do bounds checks? +/// +/// When a vector, matrix, or array index is out of bounds—either negative, or +/// greater than or equal to the number of elements in the type—WGSL requires +/// that some other index of the implementation's choice that is in bounds is +/// used instead. (There are no types with zero elements.) +/// +/// Similarly, when out-of-bounds coordinates, array indices, or sample indices +/// are presented to the WGSL `textureLoad` and `textureStore` operations, the +/// operation is redirected to do something safe. +/// +/// Different users of Naga will prefer different defaults: +/// +/// - When used as part of a WebGPU implementation, the WGSL specification +/// requires the `Restrict` behavior for array, vector, and matrix accesses, +/// and either the `Restrict` or `ReadZeroSkipWrite` behaviors for texture +/// accesses. +/// +/// - When used by the `wgpu` crate for native development, `wgpu` selects +/// `ReadZeroSkipWrite` as its default. +/// +/// - Naga's own default is `Unchecked`, so that shader translations +/// are as faithful to the original as possible. +/// +/// Sometimes the underlying hardware and drivers can perform bounds checks +/// themselves, in a way that performs better than the checks Naga would inject. +/// If you're using native checks like this, then having Naga inject its own +/// checks as well would be redundant, and the `Unchecked` policy is +/// appropriate. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum BoundsCheckPolicy { + /// Replace out-of-bounds indexes with some arbitrary in-bounds index. + /// + /// (This does not necessarily mean clamping. For example, interpreting the + /// index as unsigned and taking the minimum with the largest valid index + /// would also be a valid implementation. That would map negative indices to + /// the last element, not the first.) + Restrict, + + /// Out-of-bounds reads return zero, and writes have no effect. + /// + /// When applied to a chain of accesses, like `a[i][j].b[k]`, all index + /// expressions are evaluated, regardless of whether prior or later index + /// expressions were in bounds. But all the accesses per se are skipped + /// if any index is out of bounds. + ReadZeroSkipWrite, + + /// Naga adds no checks to indexing operations. Generate the fastest code + /// possible. This is the default for Naga, as a translator, but consumers + /// should consider defaulting to a safer behavior. + Unchecked, +} + +/// Policies for injecting bounds checks during code generation. +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct BoundsCheckPolicies { + /// How should the generated code handle array, vector, or matrix indices + /// that are out of range? + #[cfg_attr(feature = "deserialize", serde(default))] + pub index: BoundsCheckPolicy, + + /// How should the generated code handle array, vector, or matrix indices + /// that are out of range, when those values live in a [`GlobalVariable`] in + /// the [`Storage`] or [`Uniform`] address spaces? + /// + /// Some graphics hardware provides "robust buffer access", a feature that + /// ensures that using a pointer cannot access memory outside the 'buffer' + /// that it was derived from. In Naga terms, this means that the hardware + /// ensures that pointers computed by applying [`Access`] and + /// [`AccessIndex`] expressions to a [`GlobalVariable`] whose [`space`] is + /// [`Storage`] or [`Uniform`] will never read or write memory outside that + /// global variable. + /// + /// When hardware offers such a feature, it is probably undesirable to have + /// Naga inject bounds checking code for such accesses, since the hardware + /// can probably provide the same protection more efficiently. However, + /// bounds checks are still needed on accesses to indexable values that do + /// not live in buffers, like local variables. + /// + /// So, this option provides a separate policy that applies only to accesses + /// to storage and uniform globals. When depending on hardware bounds + /// checking, this policy can be `Unchecked` to avoid unnecessary overhead. + /// + /// When special hardware support is not available, this should probably be + /// the same as `index_bounds_check_policy`. + /// + /// [`GlobalVariable`]: crate::GlobalVariable + /// [`space`]: crate::GlobalVariable::space + /// [`Restrict`]: crate::back::BoundsCheckPolicy::Restrict + /// [`ReadZeroSkipWrite`]: crate::back::BoundsCheckPolicy::ReadZeroSkipWrite + /// [`Access`]: crate::Expression::Access + /// [`AccessIndex`]: crate::Expression::AccessIndex + /// [`Storage`]: crate::AddressSpace::Storage + /// [`Uniform`]: crate::AddressSpace::Uniform + #[cfg_attr(feature = "deserialize", serde(default))] + pub buffer: BoundsCheckPolicy, + + /// How should the generated code handle image texel loads that are out + /// of range? + /// + /// This controls the behavior of [`ImageLoad`] expressions when a coordinate, + /// texture array index, level of detail, or multisampled sample number is out of range. + /// + /// [`ImageLoad`]: crate::Expression::ImageLoad + #[cfg_attr(feature = "deserialize", serde(default))] + pub image_load: BoundsCheckPolicy, + + /// How should the generated code handle image texel stores that are out + /// of range? + /// + /// This controls the behavior of [`ImageStore`] statements when a coordinate, + /// texture array index, level of detail, or multisampled sample number is out of range. + /// + /// This policy should't be needed since all backends should ignore OOB writes. + /// + /// [`ImageStore`]: crate::Statement::ImageStore + #[cfg_attr(feature = "deserialize", serde(default))] + pub image_store: BoundsCheckPolicy, + + /// How should the generated code handle binding array indexes that are out of bounds. + #[cfg_attr(feature = "deserialize", serde(default))] + pub binding_array: BoundsCheckPolicy, +} + +/// The default `BoundsCheckPolicy` is `Unchecked`. +impl Default for BoundsCheckPolicy { + fn default() -> Self { + BoundsCheckPolicy::Unchecked + } +} + +impl BoundsCheckPolicies { + /// Determine which policy applies to `base`. + /// + /// `base` is the "base" expression (the expression being indexed) of a `Access` + /// and `AccessIndex` expression. This is either a pointer, a value, being directly + /// indexed, or a binding array. + /// + /// See the documentation for [`BoundsCheckPolicy`] for details about + /// when each policy applies. + pub fn choose_policy( + &self, + base: Handle<crate::Expression>, + types: &UniqueArena<crate::Type>, + info: &valid::FunctionInfo, + ) -> BoundsCheckPolicy { + let ty = info[base].ty.inner_with(types); + + if let crate::TypeInner::BindingArray { .. } = *ty { + return self.binding_array; + } + + match ty.pointer_space() { + Some(crate::AddressSpace::Storage { access: _ } | crate::AddressSpace::Uniform) => { + self.buffer + } + // This covers other address spaces, but also accessing vectors and + // matrices by value, where no pointer is involved. + _ => self.index, + } + } + + /// Return `true` if any of `self`'s policies are `policy`. + pub fn contains(&self, policy: BoundsCheckPolicy) -> bool { + self.index == policy + || self.buffer == policy + || self.image_load == policy + || self.image_store == policy + } +} + +/// An index that may be statically known, or may need to be computed at runtime. +/// +/// This enum lets us handle both [`Access`] and [`AccessIndex`] expressions +/// with the same code. +/// +/// [`Access`]: crate::Expression::Access +/// [`AccessIndex`]: crate::Expression::AccessIndex +#[derive(Clone, Copy, Debug)] +pub enum GuardedIndex { + Known(u32), + Expression(Handle<crate::Expression>), +} + +/// Build a set of expressions used as indices, to cache in temporary variables when +/// emitted. +/// +/// Given the bounds-check policies `policies`, construct a `BitSet` containing the handle +/// indices of all the expressions in `function` that are ever used as guarded indices +/// under the [`ReadZeroSkipWrite`] policy. The `module` argument must be the module to +/// which `function` belongs, and `info` should be that function's analysis results. +/// +/// Such index expressions will be used twice in the generated code: first for the +/// comparison to see if the index is in bounds, and then for the access itself, should +/// the comparison succeed. To avoid computing the expressions twice, the generated code +/// should cache them in temporary variables. +/// +/// Why do we need to build such a set in advance, instead of just processing access +/// expressions as we encounter them? Whether an expression needs to be cached depends on +/// whether it appears as something like the [`index`] operand of an [`Access`] expression +/// or the [`level`] operand of an [`ImageLoad`] expression, and on the index bounds check +/// policies that apply to those accesses. But [`Emit`] statements just identify a range +/// of expressions by index; there's no good way to tell what an expression is used +/// for. The only way to do it is to just iterate over all the expressions looking for +/// relevant `Access` expressions --- which is what this function does. +/// +/// Simple expressions like variable loads and constants don't make sense to cache: it's +/// no better than just re-evaluating them. But constants are not covered by `Emit` +/// statements, and `Load`s are always cached to ensure they occur at the right time, so +/// we don't bother filtering them out from this set. +/// +/// Fortunately, we don't need to deal with [`ImageStore`] statements here. When we emit +/// code for a statement, the writer isn't in the middle of an expression, so we can just +/// emit declarations for temporaries, initialized appropriately. +/// +/// None of these concerns apply for SPIR-V output, since it's easy to just reuse an +/// instruction ID in two places; that has the same semantics as a temporary variable, and +/// it's inherent in the design of SPIR-V. This function is more useful for text-based +/// back ends. +/// +/// [`ReadZeroSkipWrite`]: BoundsCheckPolicy::ReadZeroSkipWrite +/// [`index`]: crate::Expression::Access::index +/// [`Access`]: crate::Expression::Access +/// [`level`]: crate::Expression::ImageLoad::level +/// [`ImageLoad`]: crate::Expression::ImageLoad +/// [`Emit`]: crate::Statement::Emit +/// [`ImageStore`]: crate::Statement::ImageStore +pub fn find_checked_indexes( + module: &crate::Module, + function: &crate::Function, + info: &crate::valid::FunctionInfo, + policies: BoundsCheckPolicies, +) -> BitSet { + use crate::Expression as Ex; + + let mut guarded_indices = BitSet::new(); + + // Don't bother scanning if we never need `ReadZeroSkipWrite`. + if policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) { + for (_handle, expr) in function.expressions.iter() { + // There's no need to handle `AccessIndex` expressions, as their + // indices never need to be cached. + match *expr { + Ex::Access { base, index } => { + if policies.choose_policy(base, &module.types, info) + == BoundsCheckPolicy::ReadZeroSkipWrite + && access_needs_check( + base, + GuardedIndex::Expression(index), + module, + function, + info, + ) + .is_some() + { + guarded_indices.insert(index.index()); + } + } + Ex::ImageLoad { + coordinate, + array_index, + sample, + level, + .. + } => { + if policies.image_load == BoundsCheckPolicy::ReadZeroSkipWrite { + guarded_indices.insert(coordinate.index()); + if let Some(array_index) = array_index { + guarded_indices.insert(array_index.index()); + } + if let Some(sample) = sample { + guarded_indices.insert(sample.index()); + } + if let Some(level) = level { + guarded_indices.insert(level.index()); + } + } + } + _ => {} + } + } + } + + guarded_indices +} + +/// Determine whether `index` is statically known to be in bounds for `base`. +/// +/// If we can't be sure that the index is in bounds, return the limit within +/// which valid indices must fall. +/// +/// The return value is one of the following: +/// +/// - `Some(Known(n))` indicates that `n` is the largest valid index. +/// +/// - `Some(Computed(global))` indicates that the largest valid index is one +/// less than the length of the array that is the last member of the +/// struct held in `global`. +/// +/// - `None` indicates that the index need not be checked, either because it +/// is statically known to be in bounds, or because the applicable policy +/// is `Unchecked`. +/// +/// This function only handles subscriptable types: arrays, vectors, and +/// matrices. It does not handle struct member indices; those never require +/// run-time checks, so it's best to deal with them further up the call +/// chain. +pub fn access_needs_check( + base: Handle<crate::Expression>, + mut index: GuardedIndex, + module: &crate::Module, + function: &crate::Function, + info: &crate::valid::FunctionInfo, +) -> Option<IndexableLength> { + let base_inner = info[base].ty.inner_with(&module.types); + // Unwrap safety: `Err` here indicates unindexable base types and invalid + // length constants, but `access_needs_check` is only used by back ends, so + // validation should have caught those problems. + let length = base_inner.indexable_length(module).unwrap(); + index.try_resolve_to_constant(function, module); + if let (&GuardedIndex::Known(index), &IndexableLength::Known(length)) = (&index, &length) { + if index < length { + // Index is statically known to be in bounds, no check needed. + return None; + } + }; + + Some(length) +} + +impl GuardedIndex { + /// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible. + /// + /// Return values that are already `Known` unchanged. + fn try_resolve_to_constant(&mut self, function: &crate::Function, module: &crate::Module) { + if let GuardedIndex::Expression(expr) = *self { + if let Ok(value) = module + .to_ctx() + .eval_expr_to_u32_from(expr, &function.expressions) + { + *self = GuardedIndex::Known(value); + } + } + } +} + +#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)] +pub enum IndexableLengthError { + #[error("Type is not indexable, and has no length (validation error)")] + TypeNotIndexable, + #[error("Array length constant {0:?} is invalid")] + InvalidArrayLength(Handle<crate::Expression>), +} + +impl crate::TypeInner { + /// Return the length of a subscriptable type. + /// + /// The `self` parameter should be a handle to a vector, matrix, or array + /// type, a pointer to one of those, or a value pointer. Arrays may be + /// fixed-size, dynamically sized, or sized by a specializable constant. + /// This function does not handle struct member references, as with + /// `AccessIndex`. + /// + /// The value returned is appropriate for bounds checks on subscripting. + /// + /// Return an error if `self` does not describe a subscriptable type at all. + pub fn indexable_length( + &self, + module: &crate::Module, + ) -> Result<IndexableLength, IndexableLengthError> { + use crate::TypeInner as Ti; + let known_length = match *self { + Ti::Vector { size, .. } => size as _, + Ti::Matrix { columns, .. } => columns as _, + Ti::Array { size, .. } | Ti::BindingArray { size, .. } => { + return size.to_indexable_length(module); + } + Ti::ValuePointer { + size: Some(size), .. + } => size as _, + Ti::Pointer { base, .. } => { + // When assigning types to expressions, ResolveContext::Resolve + // does a separate sub-match here instead of a full recursion, + // so we'll do the same. + let base_inner = &module.types[base].inner; + match *base_inner { + Ti::Vector { size, .. } => size as _, + Ti::Matrix { columns, .. } => columns as _, + Ti::Array { size, .. } | Ti::BindingArray { size, .. } => { + return size.to_indexable_length(module) + } + _ => return Err(IndexableLengthError::TypeNotIndexable), + } + } + _ => return Err(IndexableLengthError::TypeNotIndexable), + }; + Ok(IndexableLength::Known(known_length)) + } +} + +/// The number of elements in an indexable type. +/// +/// This summarizes the length of vectors, matrices, and arrays in a way that is +/// convenient for indexing and bounds-checking code. +#[derive(Debug)] +pub enum IndexableLength { + /// Values of this type always have the given number of elements. + Known(u32), + + /// The number of elements is determined at runtime. + Dynamic, +} + +impl crate::ArraySize { + pub const fn to_indexable_length( + self, + _module: &crate::Module, + ) -> Result<IndexableLength, IndexableLengthError> { + Ok(match self { + Self::Constant(length) => IndexableLength::Known(length.get()), + Self::Dynamic => IndexableLength::Dynamic, + }) + } +} diff --git a/third_party/rust/naga/src/proc/layouter.rs b/third_party/rust/naga/src/proc/layouter.rs new file mode 100644 index 0000000000..1c78a594d1 --- /dev/null +++ b/third_party/rust/naga/src/proc/layouter.rs @@ -0,0 +1,251 @@ +use crate::arena::Handle; +use std::{fmt::Display, num::NonZeroU32, ops}; + +/// A newtype struct where its only valid values are powers of 2 +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Alignment(NonZeroU32); + +impl Alignment { + pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) }); + pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) }); + pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) }); + pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) }); + pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) }); + + pub const MIN_UNIFORM: Self = Self::SIXTEEN; + + pub const fn new(n: u32) -> Option<Self> { + if n.is_power_of_two() { + // SAFETY: value can't be 0 since we just checked if it's a power of 2 + Some(Self(unsafe { NonZeroU32::new_unchecked(n) })) + } else { + None + } + } + + /// # Panics + /// If `width` is not a power of 2 + pub fn from_width(width: u8) -> Self { + Self::new(width as u32).unwrap() + } + + /// Returns whether or not `n` is a multiple of this alignment. + pub const fn is_aligned(&self, n: u32) -> bool { + // equivalent to: `n % self.0.get() == 0` but much faster + n & (self.0.get() - 1) == 0 + } + + /// Round `n` up to the nearest alignment boundary. + pub const fn round_up(&self, n: u32) -> u32 { + // equivalent to: + // match n % self.0.get() { + // 0 => n, + // rem => n + (self.0.get() - rem), + // } + let mask = self.0.get() - 1; + (n + mask) & !mask + } +} + +impl Display for Alignment { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.get().fmt(f) + } +} + +impl ops::Mul<u32> for Alignment { + type Output = u32; + + fn mul(self, rhs: u32) -> Self::Output { + self.0.get() * rhs + } +} + +impl ops::Mul for Alignment { + type Output = Alignment; + + fn mul(self, rhs: Alignment) -> Self::Output { + // SAFETY: both lhs and rhs are powers of 2, the result will be a power of 2 + Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) }) + } +} + +impl From<crate::VectorSize> for Alignment { + fn from(size: crate::VectorSize) -> Self { + match size { + crate::VectorSize::Bi => Alignment::TWO, + crate::VectorSize::Tri => Alignment::FOUR, + crate::VectorSize::Quad => Alignment::FOUR, + } + } +} + +/// Size and alignment information for a type. +#[derive(Clone, Copy, Debug, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct TypeLayout { + pub size: u32, + pub alignment: Alignment, +} + +impl TypeLayout { + /// Produce the stride as if this type is a base of an array. + pub const fn to_stride(&self) -> u32 { + self.alignment.round_up(self.size) + } +} + +/// Helper processor that derives the sizes of all types. +/// +/// `Layouter` uses the default layout algorithm/table, described in +/// [WGSL §4.3.7, "Memory Layout"] +/// +/// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the +/// layout of the type whose handle is `handle`. +/// +/// [WGSL §4.3.7, "Memory Layout"](https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts) +#[derive(Debug, Default)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Layouter { + /// Layouts for types in an arena, indexed by `Handle` index. + layouts: Vec<TypeLayout>, +} + +impl ops::Index<Handle<crate::Type>> for Layouter { + type Output = TypeLayout; + fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout { + &self.layouts[handle.index()] + } +} + +#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] +pub enum LayoutErrorInner { + #[error("Array element type {0:?} doesn't exist")] + InvalidArrayElementType(Handle<crate::Type>), + #[error("Struct member[{0}] type {1:?} doesn't exist")] + InvalidStructMemberType(u32, Handle<crate::Type>), + #[error("Type width must be a power of two")] + NonPowerOfTwoWidth, +} + +#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] +#[error("Error laying out type {ty:?}: {inner}")] +pub struct LayoutError { + pub ty: Handle<crate::Type>, + pub inner: LayoutErrorInner, +} + +impl LayoutErrorInner { + const fn with(self, ty: Handle<crate::Type>) -> LayoutError { + LayoutError { ty, inner: self } + } +} + +impl Layouter { + /// Remove all entries from this `Layouter`, retaining storage. + pub fn clear(&mut self) { + self.layouts.clear(); + } + + /// Extend this `Layouter` with layouts for any new entries in `gctx.types`. + /// + /// Ensure that every type in `gctx.types` has a corresponding [TypeLayout] + /// in [`self.layouts`]. + /// + /// Some front ends need to be able to compute layouts for existing types + /// while module construction is still in progress and new types are still + /// being added. This function assumes that the `TypeLayout` values already + /// present in `self.layouts` cover their corresponding entries in `types`, + /// and extends `self.layouts` as needed to cover the rest. Thus, a front + /// end can call this function at any time, passing its current type and + /// constant arenas, and then assume that layouts are available for all + /// types. + #[allow(clippy::or_fun_call)] + pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> { + use crate::TypeInner as Ti; + + for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) { + let size = ty.inner.size(gctx); + let layout = match ty.inner { + Ti::Scalar(scalar) | Ti::Atomic(scalar) => { + let alignment = Alignment::new(scalar.width as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { size, alignment } + } + Ti::Vector { + size: vec_size, + scalar, + } => { + let alignment = Alignment::new(scalar.width as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { + size, + alignment: Alignment::from(vec_size) * alignment, + } + } + Ti::Matrix { + columns: _, + rows, + scalar, + } => { + let alignment = Alignment::new(scalar.width as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { + size, + alignment: Alignment::from(rows) * alignment, + } + } + Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout { + size, + alignment: Alignment::ONE, + }, + Ti::Array { + base, + stride: _, + size: _, + } => TypeLayout { + size, + alignment: if base < ty_handle { + self[base].alignment + } else { + return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle)); + }, + }, + Ti::Struct { span, ref members } => { + let mut alignment = Alignment::ONE; + for (index, member) in members.iter().enumerate() { + alignment = if member.ty < ty_handle { + alignment.max(self[member.ty].alignment) + } else { + return Err(LayoutErrorInner::InvalidStructMemberType( + index as u32, + member.ty, + ) + .with(ty_handle)); + }; + } + TypeLayout { + size: span, + alignment, + } + } + Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => TypeLayout { + size, + alignment: Alignment::ONE, + }, + }; + debug_assert!(size <= layout.size); + self.layouts.push(layout); + } + + Ok(()) + } +} diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs new file mode 100644 index 0000000000..b9ce80b5ea --- /dev/null +++ b/third_party/rust/naga/src/proc/mod.rs @@ -0,0 +1,809 @@ +/*! +[`Module`](super::Module) processing functionality. +*/ + +mod constant_evaluator; +mod emitter; +pub mod index; +mod layouter; +mod namer; +mod terminator; +mod typifier; + +pub use constant_evaluator::{ + ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, +}; +pub use emitter::Emitter; +pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; +pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout}; +pub use namer::{EntryPointIndex, NameKey, Namer}; +pub use terminator::ensure_block_returns; +pub use typifier::{ResolveContext, ResolveError, TypeResolution}; + +impl From<super::StorageFormat> for super::ScalarKind { + fn from(format: super::StorageFormat) -> Self { + use super::{ScalarKind as Sk, StorageFormat as Sf}; + match format { + Sf::R8Unorm => Sk::Float, + Sf::R8Snorm => Sk::Float, + Sf::R8Uint => Sk::Uint, + Sf::R8Sint => Sk::Sint, + Sf::R16Uint => Sk::Uint, + Sf::R16Sint => Sk::Sint, + Sf::R16Float => Sk::Float, + Sf::Rg8Unorm => Sk::Float, + Sf::Rg8Snorm => Sk::Float, + Sf::Rg8Uint => Sk::Uint, + Sf::Rg8Sint => Sk::Sint, + Sf::R32Uint => Sk::Uint, + Sf::R32Sint => Sk::Sint, + Sf::R32Float => Sk::Float, + Sf::Rg16Uint => Sk::Uint, + Sf::Rg16Sint => Sk::Sint, + Sf::Rg16Float => Sk::Float, + Sf::Rgba8Unorm => Sk::Float, + Sf::Rgba8Snorm => Sk::Float, + Sf::Rgba8Uint => Sk::Uint, + Sf::Rgba8Sint => Sk::Sint, + Sf::Bgra8Unorm => Sk::Float, + Sf::Rgb10a2Uint => Sk::Uint, + Sf::Rgb10a2Unorm => Sk::Float, + Sf::Rg11b10Float => Sk::Float, + Sf::Rg32Uint => Sk::Uint, + Sf::Rg32Sint => Sk::Sint, + Sf::Rg32Float => Sk::Float, + Sf::Rgba16Uint => Sk::Uint, + Sf::Rgba16Sint => Sk::Sint, + Sf::Rgba16Float => Sk::Float, + Sf::Rgba32Uint => Sk::Uint, + Sf::Rgba32Sint => Sk::Sint, + Sf::Rgba32Float => Sk::Float, + Sf::R16Unorm => Sk::Float, + Sf::R16Snorm => Sk::Float, + Sf::Rg16Unorm => Sk::Float, + Sf::Rg16Snorm => Sk::Float, + Sf::Rgba16Unorm => Sk::Float, + Sf::Rgba16Snorm => Sk::Float, + } + } +} + +impl super::ScalarKind { + pub const fn is_numeric(self) -> bool { + match self { + crate::ScalarKind::Sint + | crate::ScalarKind::Uint + | crate::ScalarKind::Float + | crate::ScalarKind::AbstractInt + | crate::ScalarKind::AbstractFloat => true, + crate::ScalarKind::Bool => false, + } + } +} + +impl super::Scalar { + pub const I32: Self = Self { + kind: crate::ScalarKind::Sint, + width: 4, + }; + pub const U32: Self = Self { + kind: crate::ScalarKind::Uint, + width: 4, + }; + pub const F32: Self = Self { + kind: crate::ScalarKind::Float, + width: 4, + }; + pub const F64: Self = Self { + kind: crate::ScalarKind::Float, + width: 8, + }; + pub const I64: Self = Self { + kind: crate::ScalarKind::Sint, + width: 8, + }; + pub const BOOL: Self = Self { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + }; + pub const ABSTRACT_INT: Self = Self { + kind: crate::ScalarKind::AbstractInt, + width: crate::ABSTRACT_WIDTH, + }; + pub const ABSTRACT_FLOAT: Self = Self { + kind: crate::ScalarKind::AbstractFloat, + width: crate::ABSTRACT_WIDTH, + }; + + pub const fn is_abstract(self) -> bool { + match self.kind { + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => true, + crate::ScalarKind::Sint + | crate::ScalarKind::Uint + | crate::ScalarKind::Float + | crate::ScalarKind::Bool => false, + } + } + + /// Construct a float `Scalar` with the given width. + /// + /// This is especially common when dealing with + /// `TypeInner::Matrix`, where the scalar kind is implicit. + pub const fn float(width: crate::Bytes) -> Self { + Self { + kind: crate::ScalarKind::Float, + width, + } + } + + pub const fn to_inner_scalar(self) -> crate::TypeInner { + crate::TypeInner::Scalar(self) + } + + pub const fn to_inner_vector(self, size: crate::VectorSize) -> crate::TypeInner { + crate::TypeInner::Vector { size, scalar: self } + } + + pub const fn to_inner_atomic(self) -> crate::TypeInner { + crate::TypeInner::Atomic(self) + } +} + +impl PartialEq for crate::Literal { + fn eq(&self, other: &Self) -> bool { + match (*self, *other) { + (Self::F64(a), Self::F64(b)) => a.to_bits() == b.to_bits(), + (Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(), + (Self::U32(a), Self::U32(b)) => a == b, + (Self::I32(a), Self::I32(b)) => a == b, + (Self::I64(a), Self::I64(b)) => a == b, + (Self::Bool(a), Self::Bool(b)) => a == b, + _ => false, + } + } +} +impl Eq for crate::Literal {} +impl std::hash::Hash for crate::Literal { + fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) { + match *self { + Self::F64(v) | Self::AbstractFloat(v) => { + hasher.write_u8(0); + v.to_bits().hash(hasher); + } + Self::F32(v) => { + hasher.write_u8(1); + v.to_bits().hash(hasher); + } + Self::U32(v) => { + hasher.write_u8(2); + v.hash(hasher); + } + Self::I32(v) => { + hasher.write_u8(3); + v.hash(hasher); + } + Self::Bool(v) => { + hasher.write_u8(4); + v.hash(hasher); + } + Self::I64(v) | Self::AbstractInt(v) => { + hasher.write_u8(5); + v.hash(hasher); + } + } + } +} + +impl crate::Literal { + pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> { + match (value, scalar.kind, scalar.width) { + (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)), + (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)), + (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)), + (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)), + (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)), + (1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)), + (0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)), + _ => None, + } + } + + pub const fn zero(scalar: crate::Scalar) -> Option<Self> { + Self::new(0, scalar) + } + + pub const fn one(scalar: crate::Scalar) -> Option<Self> { + Self::new(1, scalar) + } + + pub const fn width(&self) -> crate::Bytes { + match *self { + Self::F64(_) | Self::I64(_) => 8, + Self::F32(_) | Self::U32(_) | Self::I32(_) => 4, + Self::Bool(_) => crate::BOOL_WIDTH, + Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH, + } + } + pub const fn scalar(&self) -> crate::Scalar { + match *self { + Self::F64(_) => crate::Scalar::F64, + Self::F32(_) => crate::Scalar::F32, + Self::U32(_) => crate::Scalar::U32, + Self::I32(_) => crate::Scalar::I32, + Self::I64(_) => crate::Scalar::I64, + Self::Bool(_) => crate::Scalar::BOOL, + Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT, + Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT, + } + } + pub const fn scalar_kind(&self) -> crate::ScalarKind { + self.scalar().kind + } + pub const fn ty_inner(&self) -> crate::TypeInner { + crate::TypeInner::Scalar(self.scalar()) + } +} + +pub const POINTER_SPAN: u32 = 4; + +impl super::TypeInner { + /// Return the scalar type of `self`. + /// + /// If `inner` is a scalar, vector, or matrix type, return + /// its scalar type. Otherwise, return `None`. + pub const fn scalar(&self) -> Option<super::Scalar> { + use crate::TypeInner as Ti; + match *self { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar), + Ti::Matrix { scalar, .. } => Some(scalar), + _ => None, + } + } + + pub fn scalar_kind(&self) -> Option<super::ScalarKind> { + self.scalar().map(|scalar| scalar.kind) + } + + pub fn scalar_width(&self) -> Option<u8> { + self.scalar().map(|scalar| scalar.width * 8) + } + + pub const fn pointer_space(&self) -> Option<crate::AddressSpace> { + match *self { + Self::Pointer { space, .. } => Some(space), + Self::ValuePointer { space, .. } => Some(space), + _ => None, + } + } + + pub fn is_atomic_pointer(&self, types: &crate::UniqueArena<crate::Type>) -> bool { + match *self { + crate::TypeInner::Pointer { base, .. } => match types[base].inner { + crate::TypeInner::Atomic { .. } => true, + _ => false, + }, + _ => false, + } + } + + /// Get the size of this type. + pub fn size(&self, _gctx: GlobalCtx) -> u32 { + match *self { + Self::Scalar(scalar) | Self::Atomic(scalar) => scalar.width as u32, + Self::Vector { size, scalar } => size as u32 * scalar.width as u32, + // matrices are treated as arrays of aligned columns + Self::Matrix { + columns, + rows, + scalar, + } => Alignment::from(rows) * scalar.width as u32 * columns as u32, + Self::Pointer { .. } | Self::ValuePointer { .. } => POINTER_SPAN, + Self::Array { + base: _, + size, + stride, + } => { + let count = match size { + super::ArraySize::Constant(count) => count.get(), + // A dynamically-sized array has to have at least one element + super::ArraySize::Dynamic => 1, + }; + count * stride + } + Self::Struct { span, .. } => span, + Self::Image { .. } + | Self::Sampler { .. } + | Self::AccelerationStructure + | Self::RayQuery + | Self::BindingArray { .. } => 0, + } + } + + /// Return the canonical form of `self`, or `None` if it's already in + /// canonical form. + /// + /// Certain types have multiple representations in `TypeInner`. This + /// function converts all forms of equivalent types to a single + /// representative of their class, so that simply applying `Eq` to the + /// result indicates whether the types are equivalent, as far as Naga IR is + /// concerned. + pub fn canonical_form( + &self, + types: &crate::UniqueArena<crate::Type>, + ) -> Option<crate::TypeInner> { + use crate::TypeInner as Ti; + match *self { + Ti::Pointer { base, space } => match types[base].inner { + Ti::Scalar(scalar) => Some(Ti::ValuePointer { + size: None, + scalar, + space, + }), + Ti::Vector { size, scalar } => Some(Ti::ValuePointer { + size: Some(size), + scalar, + space, + }), + _ => None, + }, + _ => None, + } + } + + /// Compare `self` and `rhs` as types. + /// + /// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats + /// `ValuePointer` and `Pointer` types as equivalent. + /// + /// When you know that one side of the comparison is never a pointer, it's + /// fine to not bother with canonicalization, and just compare `TypeInner` + /// values with `==`. + pub fn equivalent( + &self, + rhs: &crate::TypeInner, + types: &crate::UniqueArena<crate::Type>, + ) -> bool { + let left = self.canonical_form(types); + let right = rhs.canonical_form(types); + left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs) + } + + pub fn is_dynamically_sized(&self, types: &crate::UniqueArena<crate::Type>) -> bool { + use crate::TypeInner as Ti; + match *self { + Ti::Array { size, .. } => size == crate::ArraySize::Dynamic, + Ti::Struct { ref members, .. } => members + .last() + .map(|last| types[last.ty].inner.is_dynamically_sized(types)) + .unwrap_or(false), + _ => false, + } + } + + pub fn components(&self) -> Option<u32> { + Some(match *self { + Self::Vector { size, .. } => size as u32, + Self::Matrix { columns, .. } => columns as u32, + Self::Array { + size: crate::ArraySize::Constant(len), + .. + } => len.get(), + Self::Struct { ref members, .. } => members.len() as u32, + _ => return None, + }) + } + + pub fn component_type(&self, index: usize) -> Option<TypeResolution> { + Some(match *self { + Self::Vector { scalar, .. } => TypeResolution::Value(crate::TypeInner::Scalar(scalar)), + Self::Matrix { rows, scalar, .. } => { + TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar }) + } + Self::Array { + base, + size: crate::ArraySize::Constant(_), + .. + } => TypeResolution::Handle(base), + Self::Struct { ref members, .. } => TypeResolution::Handle(members[index].ty), + _ => return None, + }) + } +} + +impl super::AddressSpace { + pub fn access(self) -> crate::StorageAccess { + use crate::StorageAccess as Sa; + match self { + crate::AddressSpace::Function + | crate::AddressSpace::Private + | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE, + crate::AddressSpace::Uniform => Sa::LOAD, + crate::AddressSpace::Storage { access } => access, + crate::AddressSpace::Handle => Sa::LOAD, + crate::AddressSpace::PushConstant => Sa::LOAD, + } + } +} + +impl super::MathFunction { + pub const fn argument_count(&self) -> usize { + match *self { + // comparison + Self::Abs => 1, + Self::Min => 2, + Self::Max => 2, + Self::Clamp => 3, + Self::Saturate => 1, + // trigonometry + Self::Cos => 1, + Self::Cosh => 1, + Self::Sin => 1, + Self::Sinh => 1, + Self::Tan => 1, + Self::Tanh => 1, + Self::Acos => 1, + Self::Asin => 1, + Self::Atan => 1, + Self::Atan2 => 2, + Self::Asinh => 1, + Self::Acosh => 1, + Self::Atanh => 1, + Self::Radians => 1, + Self::Degrees => 1, + // decomposition + Self::Ceil => 1, + Self::Floor => 1, + Self::Round => 1, + Self::Fract => 1, + Self::Trunc => 1, + Self::Modf => 1, + Self::Frexp => 1, + Self::Ldexp => 2, + // exponent + Self::Exp => 1, + Self::Exp2 => 1, + Self::Log => 1, + Self::Log2 => 1, + Self::Pow => 2, + // geometry + Self::Dot => 2, + Self::Outer => 2, + Self::Cross => 2, + Self::Distance => 2, + Self::Length => 1, + Self::Normalize => 1, + Self::FaceForward => 3, + Self::Reflect => 2, + Self::Refract => 3, + // computational + Self::Sign => 1, + Self::Fma => 3, + Self::Mix => 3, + Self::Step => 2, + Self::SmoothStep => 3, + Self::Sqrt => 1, + Self::InverseSqrt => 1, + Self::Inverse => 1, + Self::Transpose => 1, + Self::Determinant => 1, + // bits + Self::CountTrailingZeros => 1, + Self::CountLeadingZeros => 1, + Self::CountOneBits => 1, + Self::ReverseBits => 1, + Self::ExtractBits => 3, + Self::InsertBits => 4, + Self::FindLsb => 1, + Self::FindMsb => 1, + // data packing + Self::Pack4x8snorm => 1, + Self::Pack4x8unorm => 1, + Self::Pack2x16snorm => 1, + Self::Pack2x16unorm => 1, + Self::Pack2x16float => 1, + // data unpacking + Self::Unpack4x8snorm => 1, + Self::Unpack4x8unorm => 1, + Self::Unpack2x16snorm => 1, + Self::Unpack2x16unorm => 1, + Self::Unpack2x16float => 1, + } + } +} + +impl crate::Expression { + /// Returns true if the expression is considered emitted at the start of a function. + pub const fn needs_pre_emit(&self) -> bool { + match *self { + Self::Literal(_) + | Self::Constant(_) + | Self::ZeroValue(_) + | Self::FunctionArgument(_) + | Self::GlobalVariable(_) + | Self::LocalVariable(_) => true, + _ => false, + } + } + + /// Return true if this expression is a dynamic array index, for [`Access`]. + /// + /// This method returns true if this expression is a dynamically computed + /// index, and as such can only be used to index matrices and arrays when + /// they appear behind a pointer. See the documentation for [`Access`] for + /// details. + /// + /// Note, this does not check the _type_ of the given expression. It's up to + /// the caller to establish that the `Access` expression is well-typed + /// through other means, like [`ResolveContext`]. + /// + /// [`Access`]: crate::Expression::Access + /// [`ResolveContext`]: crate::proc::ResolveContext + pub fn is_dynamic_index(&self, module: &crate::Module) -> bool { + match *self { + Self::Literal(_) | Self::ZeroValue(_) => false, + Self::Constant(handle) => { + let constant = &module.constants[handle]; + !matches!(constant.r#override, crate::Override::None) + } + _ => true, + } + } +} + +impl crate::Function { + /// Return the global variable being accessed by the expression `pointer`. + /// + /// Assuming that `pointer` is a series of `Access` and `AccessIndex` + /// expressions that ultimately access some part of a `GlobalVariable`, + /// return a handle for that global. + /// + /// If the expression does not ultimately access a global variable, return + /// `None`. + pub fn originating_global( + &self, + mut pointer: crate::Handle<crate::Expression>, + ) -> Option<crate::Handle<crate::GlobalVariable>> { + loop { + pointer = match self.expressions[pointer] { + crate::Expression::Access { base, .. } => base, + crate::Expression::AccessIndex { base, .. } => base, + crate::Expression::GlobalVariable(handle) => return Some(handle), + crate::Expression::LocalVariable(_) => return None, + crate::Expression::FunctionArgument(_) => return None, + // There are no other expressions that produce pointer values. + _ => unreachable!(), + } + } + } +} + +impl crate::SampleLevel { + pub const fn implicit_derivatives(&self) -> bool { + match *self { + Self::Auto | Self::Bias(_) => true, + Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false, + } + } +} + +impl crate::Binding { + pub const fn to_built_in(&self) -> Option<crate::BuiltIn> { + match *self { + crate::Binding::BuiltIn(built_in) => Some(built_in), + Self::Location { .. } => None, + } + } +} + +impl super::SwizzleComponent { + pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W]; + + pub const fn index(&self) -> u32 { + match *self { + Self::X => 0, + Self::Y => 1, + Self::Z => 2, + Self::W => 3, + } + } + pub const fn from_index(idx: u32) -> Self { + match idx { + 0 => Self::X, + 1 => Self::Y, + 2 => Self::Z, + _ => Self::W, + } + } +} + +impl super::ImageClass { + pub const fn is_multisampled(self) -> bool { + match self { + crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi, + crate::ImageClass::Storage { .. } => false, + } + } + + pub const fn is_mipmapped(self) -> bool { + match self { + crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi, + crate::ImageClass::Storage { .. } => false, + } + } +} + +impl crate::Module { + pub const fn to_ctx(&self) -> GlobalCtx<'_> { + GlobalCtx { + types: &self.types, + constants: &self.constants, + const_expressions: &self.const_expressions, + } + } +} + +#[derive(Debug)] +pub(super) enum U32EvalError { + NonConst, + Negative, +} + +#[derive(Clone, Copy)] +pub struct GlobalCtx<'a> { + pub types: &'a crate::UniqueArena<crate::Type>, + pub constants: &'a crate::Arena<crate::Constant>, + pub const_expressions: &'a crate::Arena<crate::Expression>, +} + +impl GlobalCtx<'_> { + /// Try to evaluate the expression in `self.const_expressions` using its `handle` and return it as a `u32`. + #[allow(dead_code)] + pub(super) fn eval_expr_to_u32( + &self, + handle: crate::Handle<crate::Expression>, + ) -> Result<u32, U32EvalError> { + self.eval_expr_to_u32_from(handle, self.const_expressions) + } + + /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`. + pub(super) fn eval_expr_to_u32_from( + &self, + handle: crate::Handle<crate::Expression>, + arena: &crate::Arena<crate::Expression>, + ) -> Result<u32, U32EvalError> { + match self.eval_expr_to_literal_from(handle, arena) { + Some(crate::Literal::U32(value)) => Ok(value), + Some(crate::Literal::I32(value)) => { + value.try_into().map_err(|_| U32EvalError::Negative) + } + _ => Err(U32EvalError::NonConst), + } + } + + #[allow(dead_code)] + pub(crate) fn eval_expr_to_literal( + &self, + handle: crate::Handle<crate::Expression>, + ) -> Option<crate::Literal> { + self.eval_expr_to_literal_from(handle, self.const_expressions) + } + + fn eval_expr_to_literal_from( + &self, + handle: crate::Handle<crate::Expression>, + arena: &crate::Arena<crate::Expression>, + ) -> Option<crate::Literal> { + fn get( + gctx: GlobalCtx, + handle: crate::Handle<crate::Expression>, + arena: &crate::Arena<crate::Expression>, + ) -> Option<crate::Literal> { + match arena[handle] { + crate::Expression::Literal(literal) => Some(literal), + crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner { + crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar), + _ => None, + }, + _ => None, + } + } + match arena[handle] { + crate::Expression::Constant(c) => { + get(*self, self.constants[c].init, self.const_expressions) + } + _ => get(*self, handle, arena), + } + } +} + +/// Return an iterator over the individual components assembled by a +/// `Compose` expression. +/// +/// Given `ty` and `components` from an `Expression::Compose`, return an +/// iterator over the components of the resulting value. +/// +/// Normally, this would just be an iterator over `components`. However, +/// `Compose` expressions can concatenate vectors, in which case the i'th +/// value being composed is not generally the i'th element of `components`. +/// This function consults `ty` to decide if this concatenation is occurring, +/// and returns an iterator that produces the components of the result of +/// the `Compose` expression in either case. +pub fn flatten_compose<'arenas>( + ty: crate::Handle<crate::Type>, + components: &'arenas [crate::Handle<crate::Expression>], + expressions: &'arenas crate::Arena<crate::Expression>, + types: &'arenas crate::UniqueArena<crate::Type>, +) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas { + // Returning `impl Iterator` is a bit tricky. We may or may not + // want to flatten the components, but we have to settle on a + // single concrete type to return. This function returns a single + // iterator chain that handles both the flattening and + // non-flattening cases. + let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner { + (size as usize, true) + } else { + (components.len(), false) + }; + + /// Flatten `Compose` expressions if `is_vector` is true. + fn flatten_compose<'c>( + component: &'c crate::Handle<crate::Expression>, + is_vector: bool, + expressions: &'c crate::Arena<crate::Expression>, + ) -> &'c [crate::Handle<crate::Expression>] { + if is_vector { + if let crate::Expression::Compose { + ty: _, + components: ref subcomponents, + } = expressions[*component] + { + return subcomponents; + } + } + std::slice::from_ref(component) + } + + /// Flatten `Splat` expressions if `is_vector` is true. + fn flatten_splat<'c>( + component: &'c crate::Handle<crate::Expression>, + is_vector: bool, + expressions: &'c crate::Arena<crate::Expression>, + ) -> impl Iterator<Item = crate::Handle<crate::Expression>> { + let mut expr = *component; + let mut count = 1; + if is_vector { + if let crate::Expression::Splat { size, value } = expressions[expr] { + expr = value; + count = size as usize; + } + } + std::iter::repeat(expr).take(count) + } + + // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to + // flatten up to two levels of `Compose` expressions. + // + // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten + // `Splat` expressions. Fortunately, the operand of a `Splat` must + // be a scalar, so we can stop there. + components + .iter() + .flat_map(move |component| flatten_compose(component, is_vector, expressions)) + .flat_map(move |component| flatten_compose(component, is_vector, expressions)) + .flat_map(move |component| flatten_splat(component, is_vector, expressions)) + .take(size) +} + +#[test] +fn test_matrix_size() { + let module = crate::Module::default(); + assert_eq!( + crate::TypeInner::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Tri, + scalar: crate::Scalar::F32, + } + .size(module.to_ctx()), + 48, + ); +} diff --git a/third_party/rust/naga/src/proc/namer.rs b/third_party/rust/naga/src/proc/namer.rs new file mode 100644 index 0000000000..8afacb593d --- /dev/null +++ b/third_party/rust/naga/src/proc/namer.rs @@ -0,0 +1,281 @@ +use crate::{arena::Handle, FastHashMap, FastHashSet}; +use std::borrow::Cow; +use std::hash::{Hash, Hasher}; + +pub type EntryPointIndex = u16; +const SEPARATOR: char = '_'; + +#[derive(Debug, Eq, Hash, PartialEq)] +pub enum NameKey { + Constant(Handle<crate::Constant>), + GlobalVariable(Handle<crate::GlobalVariable>), + Type(Handle<crate::Type>), + StructMember(Handle<crate::Type>, u32), + Function(Handle<crate::Function>), + FunctionArgument(Handle<crate::Function>, u32), + FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>), + EntryPoint(EntryPointIndex), + EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>), + EntryPointArgument(EntryPointIndex, u32), +} + +/// This processor assigns names to all the things in a module +/// that may need identifiers in a textual backend. +#[derive(Default)] +pub struct Namer { + /// The last numeric suffix used for each base name. Zero means "no suffix". + unique: FastHashMap<String, u32>, + keywords: FastHashSet<&'static str>, + keywords_case_insensitive: FastHashSet<AsciiUniCase<&'static str>>, + reserved_prefixes: Vec<&'static str>, +} + +impl Namer { + /// Return a form of `string` suitable for use as the base of an identifier. + /// + /// - Drop leading digits. + /// - Retain only alphanumeric and `_` characters. + /// - Avoid prefixes in [`Namer::reserved_prefixes`]. + /// - Replace consecutive `_` characters with a single `_` character. + /// + /// The return value is a valid identifier prefix in all of Naga's output languages, + /// and it never ends with a `SEPARATOR` character. + /// It is used as a key into the unique table. + fn sanitize<'s>(&self, string: &'s str) -> Cow<'s, str> { + let string = string + .trim_start_matches(|c: char| c.is_numeric()) + .trim_end_matches(SEPARATOR); + + let base = if !string.is_empty() + && !string.contains("__") + && string + .chars() + .all(|c: char| c.is_ascii_alphanumeric() || c == '_') + { + Cow::Borrowed(string) + } else { + let mut filtered = string + .chars() + .filter(|&c| c.is_ascii_alphanumeric() || c == '_') + .fold(String::new(), |mut s, c| { + if s.ends_with('_') && c == '_' { + return s; + } + s.push(c); + s + }); + let stripped_len = filtered.trim_end_matches(SEPARATOR).len(); + filtered.truncate(stripped_len); + if filtered.is_empty() { + filtered.push_str("unnamed"); + } + Cow::Owned(filtered) + }; + + for prefix in &self.reserved_prefixes { + if base.starts_with(prefix) { + return format!("gen_{base}").into(); + } + } + + base + } + + /// Return a new identifier based on `label_raw`. + /// + /// The result: + /// - is a valid identifier even if `label_raw` is not + /// - conflicts with no keywords listed in `Namer::keywords`, and + /// - is different from any identifier previously constructed by this + /// `Namer`. + /// + /// Guarantee uniqueness by applying a numeric suffix when necessary. If `label_raw` + /// itself ends with digits, separate them from the suffix with an underscore. + pub fn call(&mut self, label_raw: &str) -> String { + use std::fmt::Write as _; // for write!-ing to Strings + + let base = self.sanitize(label_raw); + debug_assert!(!base.is_empty() && !base.ends_with(SEPARATOR)); + + // This would seem to be a natural place to use `HashMap::entry`. However, `entry` + // requires an owned key, and we'd like to avoid heap-allocating strings we're + // just going to throw away. The approach below double-hashes only when we create + // a new entry, in which case the heap allocation of the owned key was more + // expensive anyway. + match self.unique.get_mut(base.as_ref()) { + Some(count) => { + *count += 1; + // Add the suffix. This may fit in base's existing allocation. + let mut suffixed = base.into_owned(); + write!(suffixed, "{}{}", SEPARATOR, *count).unwrap(); + suffixed + } + None => { + let mut suffixed = base.to_string(); + if base.ends_with(char::is_numeric) + || self.keywords.contains(base.as_ref()) + || self + .keywords_case_insensitive + .contains(&AsciiUniCase(base.as_ref())) + { + suffixed.push(SEPARATOR); + } + debug_assert!(!self.keywords.contains::<str>(&suffixed)); + // `self.unique` wants to own its keys. This allocates only if we haven't + // already done so earlier. + self.unique.insert(base.into_owned(), 0); + suffixed + } + } + } + + pub fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String { + self.call(match *label { + Some(ref name) => name, + None => fallback, + }) + } + + /// Enter a local namespace for things like structs. + /// + /// Struct member names only need to be unique amongst themselves, not + /// globally. This function temporarily establishes a fresh, empty naming + /// context for the duration of the call to `body`. + fn namespace(&mut self, capacity: usize, body: impl FnOnce(&mut Self)) { + let fresh = FastHashMap::with_capacity_and_hasher(capacity, Default::default()); + let outer = std::mem::replace(&mut self.unique, fresh); + body(self); + self.unique = outer; + } + + pub fn reset( + &mut self, + module: &crate::Module, + reserved_keywords: &[&'static str], + extra_reserved_keywords: &[&'static str], + reserved_keywords_case_insensitive: &[&'static str], + reserved_prefixes: &[&'static str], + output: &mut FastHashMap<NameKey, String>, + ) { + self.reserved_prefixes.clear(); + self.reserved_prefixes.extend(reserved_prefixes.iter()); + + self.unique.clear(); + self.keywords.clear(); + self.keywords.extend(reserved_keywords.iter()); + self.keywords.extend(extra_reserved_keywords.iter()); + + debug_assert!(reserved_keywords_case_insensitive + .iter() + .all(|s| s.is_ascii())); + self.keywords_case_insensitive.clear(); + self.keywords_case_insensitive.extend( + reserved_keywords_case_insensitive + .iter() + .map(|string| (AsciiUniCase(*string))), + ); + + let mut temp = String::new(); + + for (ty_handle, ty) in module.types.iter() { + let ty_name = self.call_or(&ty.name, "type"); + output.insert(NameKey::Type(ty_handle), ty_name); + + if let crate::TypeInner::Struct { ref members, .. } = ty.inner { + // struct members have their own namespace, because access is always prefixed + self.namespace(members.len(), |namer| { + for (index, member) in members.iter().enumerate() { + let name = namer.call_or(&member.name, "member"); + output.insert(NameKey::StructMember(ty_handle, index as u32), name); + } + }) + } + } + + for (ep_index, ep) in module.entry_points.iter().enumerate() { + let ep_name = self.call(&ep.name); + output.insert(NameKey::EntryPoint(ep_index as _), ep_name); + for (index, arg) in ep.function.arguments.iter().enumerate() { + let name = self.call_or(&arg.name, "param"); + output.insert( + NameKey::EntryPointArgument(ep_index as _, index as u32), + name, + ); + } + for (handle, var) in ep.function.local_variables.iter() { + let name = self.call_or(&var.name, "local"); + output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name); + } + } + + for (fun_handle, fun) in module.functions.iter() { + let fun_name = self.call_or(&fun.name, "function"); + output.insert(NameKey::Function(fun_handle), fun_name); + for (index, arg) in fun.arguments.iter().enumerate() { + let name = self.call_or(&arg.name, "param"); + output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name); + } + for (handle, var) in fun.local_variables.iter() { + let name = self.call_or(&var.name, "local"); + output.insert(NameKey::FunctionLocal(fun_handle, handle), name); + } + } + + for (handle, var) in module.global_variables.iter() { + let name = self.call_or(&var.name, "global"); + output.insert(NameKey::GlobalVariable(handle), name); + } + + for (handle, constant) in module.constants.iter() { + let label = match constant.name { + Some(ref name) => name, + None => { + use std::fmt::Write; + // Try to be more descriptive about the constant values + temp.clear(); + write!(temp, "const_{}", output[&NameKey::Type(constant.ty)]).unwrap(); + &temp + } + }; + let name = self.call(label); + output.insert(NameKey::Constant(handle), name); + } + } +} + +/// A string wrapper type with an ascii case insensitive Eq and Hash impl +struct AsciiUniCase<S: AsRef<str> + ?Sized>(S); + +impl<S: AsRef<str>> PartialEq<Self> for AsciiUniCase<S> { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0.as_ref().eq_ignore_ascii_case(other.0.as_ref()) + } +} + +impl<S: AsRef<str>> Eq for AsciiUniCase<S> {} + +impl<S: AsRef<str>> Hash for AsciiUniCase<S> { + #[inline] + fn hash<H: Hasher>(&self, hasher: &mut H) { + for byte in self + .0 + .as_ref() + .as_bytes() + .iter() + .map(|b| b.to_ascii_lowercase()) + { + hasher.write_u8(byte); + } + } +} + +#[test] +fn test() { + let mut namer = Namer::default(); + assert_eq!(namer.call("x"), "x"); + assert_eq!(namer.call("x"), "x_1"); + assert_eq!(namer.call("x1"), "x1_"); + assert_eq!(namer.call("__x"), "_x"); + assert_eq!(namer.call("1___x"), "_x_1"); +} diff --git a/third_party/rust/naga/src/proc/terminator.rs b/third_party/rust/naga/src/proc/terminator.rs new file mode 100644 index 0000000000..a5239d4eca --- /dev/null +++ b/third_party/rust/naga/src/proc/terminator.rs @@ -0,0 +1,44 @@ +/// Ensure that the given block has return statements +/// at the end of its control flow. +/// +/// Note: we don't want to blindly append a return statement +/// to the end, because it may be either redundant or invalid, +/// e.g. when the user already has returns in if/else branches. +pub fn ensure_block_returns(block: &mut crate::Block) { + use crate::Statement as S; + match block.last_mut() { + Some(&mut S::Block(ref mut b)) => { + ensure_block_returns(b); + } + Some(&mut S::If { + condition: _, + ref mut accept, + ref mut reject, + }) => { + ensure_block_returns(accept); + ensure_block_returns(reject); + } + Some(&mut S::Switch { + selector: _, + ref mut cases, + }) => { + for case in cases.iter_mut() { + if !case.fall_through { + ensure_block_returns(&mut case.body); + } + } + } + Some(&mut (S::Emit(_) | S::Break | S::Continue | S::Return { .. } | S::Kill)) => (), + Some( + &mut (S::Loop { .. } + | S::Store { .. } + | S::ImageStore { .. } + | S::Call { .. } + | S::RayQuery { .. } + | S::Atomic { .. } + | S::WorkGroupUniformLoad { .. } + | S::Barrier(_)), + ) + | None => block.push(S::Return { value: None }, Default::default()), + } +} diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs new file mode 100644 index 0000000000..9c4403445c --- /dev/null +++ b/third_party/rust/naga/src/proc/typifier.rs @@ -0,0 +1,893 @@ +use crate::arena::{Arena, Handle, UniqueArena}; + +use thiserror::Error; + +/// The result of computing an expression's type. +/// +/// This is the (Rust) type returned by [`ResolveContext::resolve`] to represent +/// the (Naga) type it ascribes to some expression. +/// +/// You might expect such a function to simply return a `Handle<Type>`. However, +/// we want type resolution to be a read-only process, and that would limit the +/// possible results to types already present in the expression's associated +/// `UniqueArena<Type>`. Naga IR does have certain expressions whose types are +/// not certain to be present. +/// +/// So instead, type resolution returns a `TypeResolution` enum: either a +/// [`Handle`], referencing some type in the arena, or a [`Value`], holding a +/// free-floating [`TypeInner`]. This extends the range to cover anything that +/// can be represented with a `TypeInner` referring to the existing arena. +/// +/// What sorts of expressions can have types not available in the arena? +/// +/// - An [`Access`] or [`AccessIndex`] expression applied to a [`Vector`] or +/// [`Matrix`] must have a [`Scalar`] or [`Vector`] type. But since `Vector` +/// and `Matrix` represent their element and column types implicitly, not +/// via a handle, there may not be a suitable type in the expression's +/// associated arena. Instead, resolving such an expression returns a +/// `TypeResolution::Value(TypeInner::X { ... })`, where `X` is `Scalar` or +/// `Vector`. +/// +/// - Similarly, the type of an [`Access`] or [`AccessIndex`] expression +/// applied to a *pointer to* a vector or matrix must produce a *pointer to* +/// a scalar or vector type. These cannot be represented with a +/// [`TypeInner::Pointer`], since the `Pointer`'s `base` must point into the +/// arena, and as before, we cannot assume that a suitable scalar or vector +/// type is there. So we take things one step further and provide +/// [`TypeInner::ValuePointer`], specifically for the case of pointers to +/// scalars or vectors. This type fits in a `TypeInner` and is exactly +/// equivalent to a `Pointer` to a `Vector` or `Scalar`. +/// +/// So, for example, the type of an `Access` expression applied to a value of type: +/// +/// ```ignore +/// TypeInner::Matrix { columns, rows, width } +/// ``` +/// +/// might be: +/// +/// ```ignore +/// TypeResolution::Value(TypeInner::Vector { +/// size: rows, +/// kind: ScalarKind::Float, +/// width, +/// }) +/// ``` +/// +/// and the type of an access to a pointer of address space `space` to such a +/// matrix might be: +/// +/// ```ignore +/// TypeResolution::Value(TypeInner::ValuePointer { +/// size: Some(rows), +/// kind: ScalarKind::Float, +/// width, +/// space, +/// }) +/// ``` +/// +/// [`Handle`]: TypeResolution::Handle +/// [`Value`]: TypeResolution::Value +/// +/// [`Access`]: crate::Expression::Access +/// [`AccessIndex`]: crate::Expression::AccessIndex +/// +/// [`TypeInner`]: crate::TypeInner +/// [`Matrix`]: crate::TypeInner::Matrix +/// [`Pointer`]: crate::TypeInner::Pointer +/// [`Scalar`]: crate::TypeInner::Scalar +/// [`ValuePointer`]: crate::TypeInner::ValuePointer +/// [`Vector`]: crate::TypeInner::Vector +/// +/// [`TypeInner::Pointer`]: crate::TypeInner::Pointer +/// [`TypeInner::ValuePointer`]: crate::TypeInner::ValuePointer +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum TypeResolution { + /// A type stored in the associated arena. + Handle(Handle<crate::Type>), + + /// A free-floating [`TypeInner`], representing a type that may not be + /// available in the associated arena. However, the `TypeInner` itself may + /// contain `Handle<Type>` values referring to types from the arena. + /// + /// [`TypeInner`]: crate::TypeInner + Value(crate::TypeInner), +} + +impl TypeResolution { + pub const fn handle(&self) -> Option<Handle<crate::Type>> { + match *self { + Self::Handle(handle) => Some(handle), + Self::Value(_) => None, + } + } + + pub fn inner_with<'a>(&'a self, arena: &'a UniqueArena<crate::Type>) -> &'a crate::TypeInner { + match *self { + Self::Handle(handle) => &arena[handle].inner, + Self::Value(ref inner) => inner, + } + } +} + +// Clone is only implemented for numeric variants of `TypeInner`. +impl Clone for TypeResolution { + fn clone(&self) -> Self { + use crate::TypeInner as Ti; + match *self { + Self::Handle(handle) => Self::Handle(handle), + Self::Value(ref v) => Self::Value(match *v { + Ti::Scalar(scalar) => Ti::Scalar(scalar), + Ti::Vector { size, scalar } => Ti::Vector { size, scalar }, + Ti::Matrix { + rows, + columns, + scalar, + } => Ti::Matrix { + rows, + columns, + scalar, + }, + Ti::Pointer { base, space } => Ti::Pointer { base, space }, + Ti::ValuePointer { + size, + scalar, + space, + } => Ti::ValuePointer { + size, + scalar, + space, + }, + _ => unreachable!("Unexpected clone type: {:?}", v), + }), + } + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ResolveError { + #[error("Index {index} is out of bounds for expression {expr:?}")] + OutOfBoundsIndex { + expr: Handle<crate::Expression>, + index: u32, + }, + #[error("Invalid access into expression {expr:?}, indexed: {indexed}")] + InvalidAccess { + expr: Handle<crate::Expression>, + indexed: bool, + }, + #[error("Invalid sub-access into type {ty:?}, indexed: {indexed}")] + InvalidSubAccess { + ty: Handle<crate::Type>, + indexed: bool, + }, + #[error("Invalid scalar {0:?}")] + InvalidScalar(Handle<crate::Expression>), + #[error("Invalid vector {0:?}")] + InvalidVector(Handle<crate::Expression>), + #[error("Invalid pointer {0:?}")] + InvalidPointer(Handle<crate::Expression>), + #[error("Invalid image {0:?}")] + InvalidImage(Handle<crate::Expression>), + #[error("Function {name} not defined")] + FunctionNotDefined { name: String }, + #[error("Function without return type")] + FunctionReturnsVoid, + #[error("Incompatible operands: {0}")] + IncompatibleOperands(String), + #[error("Function argument {0} doesn't exist")] + FunctionArgumentNotFound(u32), + #[error("Special type is not registered within the module")] + MissingSpecialType, +} + +pub struct ResolveContext<'a> { + pub constants: &'a Arena<crate::Constant>, + pub types: &'a UniqueArena<crate::Type>, + pub special_types: &'a crate::SpecialTypes, + pub global_vars: &'a Arena<crate::GlobalVariable>, + pub local_vars: &'a Arena<crate::LocalVariable>, + pub functions: &'a Arena<crate::Function>, + pub arguments: &'a [crate::FunctionArgument], +} + +impl<'a> ResolveContext<'a> { + /// Initialize a resolve context from the module. + pub const fn with_locals( + module: &'a crate::Module, + local_vars: &'a Arena<crate::LocalVariable>, + arguments: &'a [crate::FunctionArgument], + ) -> Self { + Self { + constants: &module.constants, + types: &module.types, + special_types: &module.special_types, + global_vars: &module.global_variables, + local_vars, + functions: &module.functions, + arguments, + } + } + + /// Determine the type of `expr`. + /// + /// The `past` argument must be a closure that can resolve the types of any + /// expressions that `expr` refers to. These can be gathered by caching the + /// results of prior calls to `resolve`, perhaps as done by the + /// [`front::Typifier`] utility type. + /// + /// Type resolution is a read-only process: this method takes `self` by + /// shared reference. However, this means that we cannot add anything to + /// `self.types` that we might need to describe `expr`. To work around this, + /// this method returns a [`TypeResolution`], rather than simply returning a + /// `Handle<Type>`; see the documentation for [`TypeResolution`] for + /// details. + /// + /// [`front::Typifier`]: crate::front::Typifier + pub fn resolve( + &self, + expr: &crate::Expression, + past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>, + ) -> Result<TypeResolution, ResolveError> { + use crate::TypeInner as Ti; + let types = self.types; + Ok(match *expr { + crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) { + // Arrays and matrices can only be indexed dynamically behind a + // pointer, but that's a validation error, not a type error, so + // go ahead provide a type here. + Ti::Array { base, .. } => TypeResolution::Handle(base), + Ti::Matrix { rows, scalar, .. } => { + TypeResolution::Value(Ti::Vector { size: rows, scalar }) + } + Ti::Vector { size: _, scalar } => TypeResolution::Value(Ti::Scalar(scalar)), + Ti::ValuePointer { + size: Some(_), + scalar, + space, + } => TypeResolution::Value(Ti::ValuePointer { + size: None, + scalar, + space, + }), + Ti::Pointer { base, space } => { + TypeResolution::Value(match types[base].inner { + Ti::Array { base, .. } => Ti::Pointer { base, space }, + Ti::Vector { size: _, scalar } => Ti::ValuePointer { + size: None, + scalar, + space, + }, + // Matrices are only dynamically indexed behind a pointer + Ti::Matrix { + columns: _, + rows, + scalar, + } => Ti::ValuePointer { + size: Some(rows), + scalar, + space, + }, + Ti::BindingArray { base, .. } => Ti::Pointer { base, space }, + ref other => { + log::error!("Access sub-type {:?}", other); + return Err(ResolveError::InvalidSubAccess { + ty: base, + indexed: false, + }); + } + }) + } + Ti::BindingArray { base, .. } => TypeResolution::Handle(base), + ref other => { + log::error!("Access type {:?}", other); + return Err(ResolveError::InvalidAccess { + expr: base, + indexed: false, + }); + } + }, + crate::Expression::AccessIndex { base, index } => { + match *past(base)?.inner_with(types) { + Ti::Vector { size, scalar } => { + if index >= size as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + TypeResolution::Value(Ti::Scalar(scalar)) + } + Ti::Matrix { + columns, + rows, + scalar, + } => { + if index >= columns as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar }) + } + Ti::Array { base, .. } => TypeResolution::Handle(base), + Ti::Struct { ref members, .. } => { + let member = members + .get(index as usize) + .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?; + TypeResolution::Handle(member.ty) + } + Ti::ValuePointer { + size: Some(size), + scalar, + space, + } => { + if index >= size as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + TypeResolution::Value(Ti::ValuePointer { + size: None, + scalar, + space, + }) + } + Ti::Pointer { + base: ty_base, + space, + } => TypeResolution::Value(match types[ty_base].inner { + Ti::Array { base, .. } => Ti::Pointer { base, space }, + Ti::Vector { size, scalar } => { + if index >= size as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + Ti::ValuePointer { + size: None, + scalar, + space, + } + } + Ti::Matrix { + rows, + columns, + scalar, + } => { + if index >= columns as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + Ti::ValuePointer { + size: Some(rows), + scalar, + space, + } + } + Ti::Struct { ref members, .. } => { + let member = members + .get(index as usize) + .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?; + Ti::Pointer { + base: member.ty, + space, + } + } + Ti::BindingArray { base, .. } => Ti::Pointer { base, space }, + ref other => { + log::error!("Access index sub-type {:?}", other); + return Err(ResolveError::InvalidSubAccess { + ty: ty_base, + indexed: true, + }); + } + }), + Ti::BindingArray { base, .. } => TypeResolution::Handle(base), + ref other => { + log::error!("Access index type {:?}", other); + return Err(ResolveError::InvalidAccess { + expr: base, + indexed: true, + }); + } + } + } + crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) { + Ti::Scalar(scalar) => TypeResolution::Value(Ti::Vector { size, scalar }), + ref other => { + log::error!("Scalar type {:?}", other); + return Err(ResolveError::InvalidScalar(value)); + } + }, + crate::Expression::Swizzle { + size, + vector, + pattern: _, + } => match *past(vector)?.inner_with(types) { + Ti::Vector { size: _, scalar } => { + TypeResolution::Value(Ti::Vector { size, scalar }) + } + ref other => { + log::error!("Vector type {:?}", other); + return Err(ResolveError::InvalidVector(vector)); + } + }, + crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()), + crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty), + crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty), + crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty), + crate::Expression::FunctionArgument(index) => { + let arg = self + .arguments + .get(index as usize) + .ok_or(ResolveError::FunctionArgumentNotFound(index))?; + TypeResolution::Handle(arg.ty) + } + crate::Expression::GlobalVariable(h) => { + let var = &self.global_vars[h]; + if var.space == crate::AddressSpace::Handle { + TypeResolution::Handle(var.ty) + } else { + TypeResolution::Value(Ti::Pointer { + base: var.ty, + space: var.space, + }) + } + } + crate::Expression::LocalVariable(h) => { + let var = &self.local_vars[h]; + TypeResolution::Value(Ti::Pointer { + base: var.ty, + space: crate::AddressSpace::Function, + }) + } + crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) { + Ti::Pointer { base, space: _ } => { + if let Ti::Atomic(scalar) = types[base].inner { + TypeResolution::Value(Ti::Scalar(scalar)) + } else { + TypeResolution::Handle(base) + } + } + Ti::ValuePointer { + size, + scalar, + space: _, + } => TypeResolution::Value(match size { + Some(size) => Ti::Vector { size, scalar }, + None => Ti::Scalar(scalar), + }), + ref other => { + log::error!("Pointer type {:?}", other); + return Err(ResolveError::InvalidPointer(pointer)); + } + }, + crate::Expression::ImageSample { + image, + gather: Some(_), + .. + } => match *past(image)?.inner_with(types) { + Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector { + scalar: crate::Scalar { + kind: match class { + crate::ImageClass::Sampled { kind, multi: _ } => kind, + _ => crate::ScalarKind::Float, + }, + width: 4, + }, + size: crate::VectorSize::Quad, + }), + ref other => { + log::error!("Image type {:?}", other); + return Err(ResolveError::InvalidImage(image)); + } + }, + crate::Expression::ImageSample { image, .. } + | crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) { + Ti::Image { class, .. } => TypeResolution::Value(match class { + crate::ImageClass::Depth { multi: _ } => Ti::Scalar(crate::Scalar::F32), + crate::ImageClass::Sampled { kind, multi: _ } => Ti::Vector { + scalar: crate::Scalar { kind, width: 4 }, + size: crate::VectorSize::Quad, + }, + crate::ImageClass::Storage { format, .. } => Ti::Vector { + scalar: crate::Scalar { + kind: format.into(), + width: 4, + }, + size: crate::VectorSize::Quad, + }, + }), + ref other => { + log::error!("Image type {:?}", other); + return Err(ResolveError::InvalidImage(image)); + } + }, + crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query { + crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) { + Ti::Image { dim, .. } => match dim { + crate::ImageDimension::D1 => Ti::Scalar(crate::Scalar::U32), + crate::ImageDimension::D2 | crate::ImageDimension::Cube => Ti::Vector { + size: crate::VectorSize::Bi, + scalar: crate::Scalar::U32, + }, + crate::ImageDimension::D3 => Ti::Vector { + size: crate::VectorSize::Tri, + scalar: crate::Scalar::U32, + }, + }, + ref other => { + log::error!("Image type {:?}", other); + return Err(ResolveError::InvalidImage(image)); + } + }, + crate::ImageQuery::NumLevels + | crate::ImageQuery::NumLayers + | crate::ImageQuery::NumSamples => Ti::Scalar(crate::Scalar::U32), + }), + crate::Expression::Unary { expr, .. } => past(expr)?.clone(), + crate::Expression::Binary { op, left, right } => match op { + crate::BinaryOperator::Add + | crate::BinaryOperator::Subtract + | crate::BinaryOperator::Divide + | crate::BinaryOperator::Modulo => past(left)?.clone(), + crate::BinaryOperator::Multiply => { + let (res_left, res_right) = (past(left)?, past(right)?); + match (res_left.inner_with(types), res_right.inner_with(types)) { + ( + &Ti::Matrix { + columns: _, + rows, + scalar, + }, + &Ti::Matrix { columns, .. }, + ) => TypeResolution::Value(Ti::Matrix { + columns, + rows, + scalar, + }), + ( + &Ti::Matrix { + columns: _, + rows, + scalar, + }, + &Ti::Vector { .. }, + ) => TypeResolution::Value(Ti::Vector { size: rows, scalar }), + ( + &Ti::Vector { .. }, + &Ti::Matrix { + columns, + rows: _, + scalar, + }, + ) => TypeResolution::Value(Ti::Vector { + size: columns, + scalar, + }), + (&Ti::Scalar { .. }, _) => res_right.clone(), + (_, &Ti::Scalar { .. }) => res_left.clone(), + (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(), + (tl, tr) => { + return Err(ResolveError::IncompatibleOperands(format!( + "{tl:?} * {tr:?}" + ))) + } + } + } + crate::BinaryOperator::Equal + | crate::BinaryOperator::NotEqual + | crate::BinaryOperator::Less + | crate::BinaryOperator::LessEqual + | crate::BinaryOperator::Greater + | crate::BinaryOperator::GreaterEqual + | crate::BinaryOperator::LogicalAnd + | crate::BinaryOperator::LogicalOr => { + let scalar = crate::Scalar::BOOL; + let inner = match *past(left)?.inner_with(types) { + Ti::Scalar { .. } => Ti::Scalar(scalar), + Ti::Vector { size, .. } => Ti::Vector { size, scalar }, + ref other => { + return Err(ResolveError::IncompatibleOperands(format!( + "{op:?}({other:?}, _)" + ))) + } + }; + TypeResolution::Value(inner) + } + crate::BinaryOperator::And + | crate::BinaryOperator::ExclusiveOr + | crate::BinaryOperator::InclusiveOr + | crate::BinaryOperator::ShiftLeft + | crate::BinaryOperator::ShiftRight => past(left)?.clone(), + }, + crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty), + crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty), + crate::Expression::Select { accept, .. } => past(accept)?.clone(), + crate::Expression::Derivative { expr, .. } => past(expr)?.clone(), + crate::Expression::Relational { fun, argument } => match fun { + crate::RelationalFunction::All | crate::RelationalFunction::Any => { + TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL)) + } + crate::RelationalFunction::IsNan | crate::RelationalFunction::IsInf => { + match *past(argument)?.inner_with(types) { + Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL)), + Ti::Vector { size, .. } => TypeResolution::Value(Ti::Vector { + scalar: crate::Scalar::BOOL, + size, + }), + ref other => { + return Err(ResolveError::IncompatibleOperands(format!( + "{fun:?}({other:?})" + ))) + } + } + } + }, + crate::Expression::Math { + fun, + arg, + arg1, + arg2: _, + arg3: _, + } => { + use crate::MathFunction as Mf; + let res_arg = past(arg)?; + match fun { + // comparison + Mf::Abs | + Mf::Min | + Mf::Max | + Mf::Clamp | + Mf::Saturate | + // trigonometry + Mf::Cos | + Mf::Cosh | + Mf::Sin | + Mf::Sinh | + Mf::Tan | + Mf::Tanh | + Mf::Acos | + Mf::Asin | + Mf::Atan | + Mf::Atan2 | + Mf::Asinh | + Mf::Acosh | + Mf::Atanh | + Mf::Radians | + Mf::Degrees | + // decomposition + Mf::Ceil | + Mf::Floor | + Mf::Round | + Mf::Fract | + Mf::Trunc | + Mf::Ldexp | + // exponent + Mf::Exp | + Mf::Exp2 | + Mf::Log | + Mf::Log2 | + Mf::Pow => res_arg.clone(), + Mf::Modf | Mf::Frexp => { + let (size, width) = match res_arg.inner_with(types) { + &Ti::Scalar(crate::Scalar { + kind: crate::ScalarKind::Float, + width, + }) => (None, width), + &Ti::Vector { + scalar: crate::Scalar { + kind: crate::ScalarKind::Float, + width, + }, + size, + } => (Some(size), width), + ref other => + return Err(ResolveError::IncompatibleOperands(format!("{fun:?}({other:?}, _)"))) + }; + let result = self + .special_types + .predeclared_types + .get(&if fun == Mf::Modf { + crate::PredeclaredType::ModfResult { size, width } + } else { + crate::PredeclaredType::FrexpResult { size, width } + }) + .ok_or(ResolveError::MissingSpecialType)?; + TypeResolution::Handle(*result) + }, + // geometry + Mf::Dot => match *res_arg.inner_with(types) { + Ti::Vector { + size: _, + scalar, + } => TypeResolution::Value(Ti::Scalar(scalar)), + ref other => + return Err(ResolveError::IncompatibleOperands( + format!("{fun:?}({other:?}, _)") + )), + }, + Mf::Outer => { + let arg1 = arg1.ok_or_else(|| ResolveError::IncompatibleOperands( + format!("{fun:?}(_, None)") + ))?; + match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) { + ( + &Ti::Vector { size: columns, scalar }, + &Ti::Vector{ size: rows, .. } + ) => TypeResolution::Value(Ti::Matrix { + columns, + rows, + scalar, + }), + (left, right) => + return Err(ResolveError::IncompatibleOperands( + format!("{fun:?}({left:?}, {right:?})") + )), + } + }, + Mf::Cross => res_arg.clone(), + Mf::Distance | + Mf::Length => match *res_arg.inner_with(types) { + Ti::Scalar(scalar) | + Ti::Vector {scalar,size:_} => TypeResolution::Value(Ti::Scalar(scalar)), + ref other => return Err(ResolveError::IncompatibleOperands( + format!("{fun:?}({other:?})") + )), + }, + Mf::Normalize | + Mf::FaceForward | + Mf::Reflect | + Mf::Refract => res_arg.clone(), + // computational + Mf::Sign | + Mf::Fma | + Mf::Mix | + Mf::Step | + Mf::SmoothStep | + Mf::Sqrt | + Mf::InverseSqrt => res_arg.clone(), + Mf::Transpose => match *res_arg.inner_with(types) { + Ti::Matrix { + columns, + rows, + scalar, + } => TypeResolution::Value(Ti::Matrix { + columns: rows, + rows: columns, + scalar, + }), + ref other => return Err(ResolveError::IncompatibleOperands( + format!("{fun:?}({other:?})") + )), + }, + Mf::Inverse => match *res_arg.inner_with(types) { + Ti::Matrix { + columns, + rows, + scalar, + } if columns == rows => TypeResolution::Value(Ti::Matrix { + columns, + rows, + scalar, + }), + ref other => return Err(ResolveError::IncompatibleOperands( + format!("{fun:?}({other:?})") + )), + }, + Mf::Determinant => match *res_arg.inner_with(types) { + Ti::Matrix { + scalar, + .. + } => TypeResolution::Value(Ti::Scalar(scalar)), + ref other => return Err(ResolveError::IncompatibleOperands( + format!("{fun:?}({other:?})") + )), + }, + // bits + Mf::CountTrailingZeros | + Mf::CountLeadingZeros | + Mf::CountOneBits | + Mf::ReverseBits | + Mf::ExtractBits | + Mf::InsertBits | + Mf::FindLsb | + Mf::FindMsb => match *res_arg.inner_with(types) { + Ti::Scalar(scalar @ crate::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + .. + }) => TypeResolution::Value(Ti::Scalar(scalar)), + Ti::Vector { + size, + scalar: scalar @ crate::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + .. + } + } => TypeResolution::Value(Ti::Vector { size, scalar }), + ref other => return Err(ResolveError::IncompatibleOperands( + format!("{fun:?}({other:?})") + )), + }, + // data packing + Mf::Pack4x8snorm | + Mf::Pack4x8unorm | + Mf::Pack2x16snorm | + Mf::Pack2x16unorm | + Mf::Pack2x16float => TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)), + // data unpacking + Mf::Unpack4x8snorm | + Mf::Unpack4x8unorm => TypeResolution::Value(Ti::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar::F32 + }), + Mf::Unpack2x16snorm | + Mf::Unpack2x16unorm | + Mf::Unpack2x16float => TypeResolution::Value(Ti::Vector { + size: crate::VectorSize::Bi, + scalar: crate::Scalar::F32 + }), + } + } + crate::Expression::As { + expr, + kind, + convert, + } => match *past(expr)?.inner_with(types) { + Ti::Scalar(crate::Scalar { width, .. }) => { + TypeResolution::Value(Ti::Scalar(crate::Scalar { + kind, + width: convert.unwrap_or(width), + })) + } + Ti::Vector { + size, + scalar: crate::Scalar { kind: _, width }, + } => TypeResolution::Value(Ti::Vector { + size, + scalar: crate::Scalar { + kind, + width: convert.unwrap_or(width), + }, + }), + Ti::Matrix { + columns, + rows, + mut scalar, + } => { + if let Some(width) = convert { + scalar.width = width; + } + TypeResolution::Value(Ti::Matrix { + columns, + rows, + scalar, + }) + } + ref other => { + return Err(ResolveError::IncompatibleOperands(format!( + "{other:?} as {kind:?}" + ))) + } + }, + crate::Expression::CallResult(function) => { + let result = self.functions[function] + .result + .as_ref() + .ok_or(ResolveError::FunctionReturnsVoid)?; + TypeResolution::Handle(result.ty) + } + crate::Expression::ArrayLength(_) => { + TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)) + } + crate::Expression::RayQueryProceedResult => { + TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL)) + } + crate::Expression::RayQueryGetIntersection { .. } => { + let result = self + .special_types + .ray_intersection + .ok_or(ResolveError::MissingSpecialType)?; + TypeResolution::Handle(result) + } + }) + } +} + +#[test] +fn test_error_size() { + use std::mem::size_of; + assert_eq!(size_of::<ResolveError>(), 32); +} diff --git a/third_party/rust/naga/src/span.rs b/third_party/rust/naga/src/span.rs new file mode 100644 index 0000000000..53246b25d6 --- /dev/null +++ b/third_party/rust/naga/src/span.rs @@ -0,0 +1,501 @@ +use crate::{Arena, Handle, UniqueArena}; +use std::{error::Error, fmt, ops::Range}; + +/// A source code span, used for error reporting. +#[derive(Clone, Copy, Debug, PartialEq, Default)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +pub struct Span { + start: u32, + end: u32, +} + +impl Span { + pub const UNDEFINED: Self = Self { start: 0, end: 0 }; + /// Creates a new `Span` from a range of byte indices + /// + /// Note: end is exclusive, it doesn't belong to the `Span` + pub const fn new(start: u32, end: u32) -> Self { + Span { start, end } + } + + /// Returns a new `Span` starting at `self` and ending at `other` + pub const fn until(&self, other: &Self) -> Self { + Span { + start: self.start, + end: other.end, + } + } + + /// Modifies `self` to contain the smallest `Span` possible that + /// contains both `self` and `other` + pub fn subsume(&mut self, other: Self) { + *self = if !self.is_defined() { + // self isn't defined so use other + other + } else if !other.is_defined() { + // other isn't defined so don't try to subsume + *self + } else { + // Both self and other are defined so calculate the span that contains them both + Span { + start: self.start.min(other.start), + end: self.end.max(other.end), + } + } + } + + /// Returns the smallest `Span` possible that contains all the `Span`s + /// defined in the `from` iterator + pub fn total_span<T: Iterator<Item = Self>>(from: T) -> Self { + let mut span: Self = Default::default(); + for other in from { + span.subsume(other); + } + span + } + + /// Converts `self` to a range if the span is not unknown + pub fn to_range(self) -> Option<Range<usize>> { + if self.is_defined() { + Some(self.start as usize..self.end as usize) + } else { + None + } + } + + /// Check whether `self` was defined or is a default/unknown span + pub fn is_defined(&self) -> bool { + *self != Self::default() + } + + /// Return a [`SourceLocation`] for this span in the provided source. + pub fn location(&self, source: &str) -> SourceLocation { + let prefix = &source[..self.start as usize]; + let line_number = prefix.matches('\n').count() as u32 + 1; + let line_start = prefix.rfind('\n').map(|pos| pos + 1).unwrap_or(0); + let line_position = source[line_start..self.start as usize].chars().count() as u32 + 1; + + SourceLocation { + line_number, + line_position, + offset: self.start, + length: self.end - self.start, + } + } +} + +impl From<Range<usize>> for Span { + fn from(range: Range<usize>) -> Self { + Span { + start: range.start as u32, + end: range.end as u32, + } + } +} + +impl std::ops::Index<Span> for str { + type Output = str; + + #[inline] + fn index(&self, span: Span) -> &str { + &self[span.start as usize..span.end as usize] + } +} + +/// A human-readable representation for a span, tailored for text source. +/// +/// Corresponds to the positional members of [`GPUCompilationMessage`][gcm] from +/// the WebGPU specification, except that `offset` and `length` are in bytes +/// (UTF-8 code units), instead of UTF-16 code units. +/// +/// [gcm]: https://www.w3.org/TR/webgpu/#gpucompilationmessage +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct SourceLocation { + /// 1-based line number. + pub line_number: u32, + /// 1-based column of the start of this span + pub line_position: u32, + /// 0-based Offset in code units (in bytes) of the start of the span. + pub offset: u32, + /// Length in code units (in bytes) of the span. + pub length: u32, +} + +/// A source code span together with "context", a user-readable description of what part of the error it refers to. +pub type SpanContext = (Span, String); + +/// Wrapper class for [`Error`], augmenting it with a list of [`SpanContext`]s. +#[derive(Debug, Clone)] +pub struct WithSpan<E> { + inner: E, + spans: Vec<SpanContext>, +} + +impl<E> fmt::Display for WithSpan<E> +where + E: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +#[cfg(test)] +impl<E> PartialEq for WithSpan<E> +where + E: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.inner.eq(&other.inner) + } +} + +impl<E> Error for WithSpan<E> +where + E: Error, +{ + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.inner.source() + } +} + +impl<E> WithSpan<E> { + /// Create a new [`WithSpan`] from an [`Error`], containing no spans. + pub const fn new(inner: E) -> Self { + Self { + inner, + spans: Vec::new(), + } + } + + /// Reverse of [`Self::new`], discards span information and returns an inner error. + #[allow(clippy::missing_const_for_fn)] // ignore due to requirement of #![feature(const_precise_live_drops)] + pub fn into_inner(self) -> E { + self.inner + } + + pub const fn as_inner(&self) -> &E { + &self.inner + } + + /// Iterator over stored [`SpanContext`]s. + pub fn spans(&self) -> impl ExactSizeIterator<Item = &SpanContext> { + self.spans.iter() + } + + /// Add a new span with description. + pub fn with_span<S>(mut self, span: Span, description: S) -> Self + where + S: ToString, + { + if span.is_defined() { + self.spans.push((span, description.to_string())); + } + self + } + + /// Add a [`SpanContext`]. + pub fn with_context(self, span_context: SpanContext) -> Self { + let (span, description) = span_context; + self.with_span(span, description) + } + + /// Add a [`Handle`] from either [`Arena`] or [`UniqueArena`], borrowing its span information from there + /// and annotating with a type and the handle representation. + pub(crate) fn with_handle<T, A: SpanProvider<T>>(self, handle: Handle<T>, arena: &A) -> Self { + self.with_context(arena.get_span_context(handle)) + } + + /// Convert inner error using [`From`]. + pub fn into_other<E2>(self) -> WithSpan<E2> + where + E2: From<E>, + { + WithSpan { + inner: self.inner.into(), + spans: self.spans, + } + } + + /// Convert inner error into another type. Joins span information contained in `self` + /// with what is returned from `func`. + pub fn and_then<F, E2>(self, func: F) -> WithSpan<E2> + where + F: FnOnce(E) -> WithSpan<E2>, + { + let mut res = func(self.inner); + res.spans.extend(self.spans); + res + } + + /// Return a [`SourceLocation`] for our first span, if we have one. + pub fn location(&self, source: &str) -> Option<SourceLocation> { + if self.spans.is_empty() { + return None; + } + + Some(self.spans[0].0.location(source)) + } + + fn diagnostic(&self) -> codespan_reporting::diagnostic::Diagnostic<()> + where + E: Error, + { + use codespan_reporting::diagnostic::{Diagnostic, Label}; + let diagnostic = Diagnostic::error() + .with_message(self.inner.to_string()) + .with_labels( + self.spans() + .map(|&(span, ref desc)| { + Label::primary((), span.to_range().unwrap()).with_message(desc.to_owned()) + }) + .collect(), + ) + .with_notes({ + let mut notes = Vec::new(); + let mut source: &dyn Error = &self.inner; + while let Some(next) = Error::source(source) { + notes.push(next.to_string()); + source = next; + } + notes + }); + diagnostic + } + + /// Emits a summary of the error to standard error stream. + pub fn emit_to_stderr(&self, source: &str) + where + E: Error, + { + self.emit_to_stderr_with_path(source, "wgsl") + } + + /// Emits a summary of the error to standard error stream. + pub fn emit_to_stderr_with_path(&self, source: &str, path: &str) + where + E: Error, + { + use codespan_reporting::{files, term}; + use term::termcolor::{ColorChoice, StandardStream}; + + let files = files::SimpleFile::new(path, source); + let config = term::Config::default(); + let writer = StandardStream::stderr(ColorChoice::Auto); + term::emit(&mut writer.lock(), &config, &files, &self.diagnostic()) + .expect("cannot write error"); + } + + /// Emits a summary of the error to a string. + pub fn emit_to_string(&self, source: &str) -> String + where + E: Error, + { + self.emit_to_string_with_path(source, "wgsl") + } + + /// Emits a summary of the error to a string. + pub fn emit_to_string_with_path(&self, source: &str, path: &str) -> String + where + E: Error, + { + use codespan_reporting::{files, term}; + use term::termcolor::NoColor; + + let files = files::SimpleFile::new(path, source); + let config = codespan_reporting::term::Config::default(); + let mut writer = NoColor::new(Vec::new()); + term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error"); + String::from_utf8(writer.into_inner()).unwrap() + } +} + +/// Convenience trait for [`Error`] to be able to apply spans to anything. +pub(crate) trait AddSpan: Sized { + type Output; + /// See [`WithSpan::new`]. + fn with_span(self) -> Self::Output; + /// See [`WithSpan::with_span`]. + fn with_span_static(self, span: Span, description: &'static str) -> Self::Output; + /// See [`WithSpan::with_context`]. + fn with_span_context(self, span_context: SpanContext) -> Self::Output; + /// See [`WithSpan::with_handle`]. + fn with_span_handle<T, A: SpanProvider<T>>(self, handle: Handle<T>, arena: &A) -> Self::Output; +} + +/// Trait abstracting over getting a span from an [`Arena`] or a [`UniqueArena`]. +pub(crate) trait SpanProvider<T> { + fn get_span(&self, handle: Handle<T>) -> Span; + fn get_span_context(&self, handle: Handle<T>) -> SpanContext { + match self.get_span(handle) { + x if !x.is_defined() => (Default::default(), "".to_string()), + known => ( + known, + format!("{} {:?}", std::any::type_name::<T>(), handle), + ), + } + } +} + +impl<T> SpanProvider<T> for Arena<T> { + fn get_span(&self, handle: Handle<T>) -> Span { + self.get_span(handle) + } +} + +impl<T> SpanProvider<T> for UniqueArena<T> { + fn get_span(&self, handle: Handle<T>) -> Span { + self.get_span(handle) + } +} + +impl<E> AddSpan for E +where + E: Error, +{ + type Output = WithSpan<Self>; + fn with_span(self) -> WithSpan<Self> { + WithSpan::new(self) + } + + fn with_span_static(self, span: Span, description: &'static str) -> WithSpan<Self> { + WithSpan::new(self).with_span(span, description) + } + + fn with_span_context(self, span_context: SpanContext) -> WithSpan<Self> { + WithSpan::new(self).with_context(span_context) + } + + fn with_span_handle<T, A: SpanProvider<T>>( + self, + handle: Handle<T>, + arena: &A, + ) -> WithSpan<Self> { + WithSpan::new(self).with_handle(handle, arena) + } +} + +/// Convenience trait for [`Result`], adding a [`MapErrWithSpan::map_err_inner`] +/// mapping to [`WithSpan::and_then`]. +pub trait MapErrWithSpan<E, E2>: Sized { + type Output: Sized; + fn map_err_inner<F, E3>(self, func: F) -> Self::Output + where + F: FnOnce(E) -> WithSpan<E3>, + E2: From<E3>; +} + +impl<T, E, E2> MapErrWithSpan<E, E2> for Result<T, WithSpan<E>> { + type Output = Result<T, WithSpan<E2>>; + fn map_err_inner<F, E3>(self, func: F) -> Result<T, WithSpan<E2>> + where + F: FnOnce(E) -> WithSpan<E3>, + E2: From<E3>, + { + self.map_err(|e| e.and_then(func).into_other::<E2>()) + } +} + +#[test] +fn span_location() { + let source = "12\n45\n\n89\n"; + assert_eq!( + Span { start: 0, end: 1 }.location(source), + SourceLocation { + line_number: 1, + line_position: 1, + offset: 0, + length: 1 + } + ); + assert_eq!( + Span { start: 1, end: 2 }.location(source), + SourceLocation { + line_number: 1, + line_position: 2, + offset: 1, + length: 1 + } + ); + assert_eq!( + Span { start: 2, end: 3 }.location(source), + SourceLocation { + line_number: 1, + line_position: 3, + offset: 2, + length: 1 + } + ); + assert_eq!( + Span { start: 3, end: 5 }.location(source), + SourceLocation { + line_number: 2, + line_position: 1, + offset: 3, + length: 2 + } + ); + assert_eq!( + Span { start: 4, end: 6 }.location(source), + SourceLocation { + line_number: 2, + line_position: 2, + offset: 4, + length: 2 + } + ); + assert_eq!( + Span { start: 5, end: 6 }.location(source), + SourceLocation { + line_number: 2, + line_position: 3, + offset: 5, + length: 1 + } + ); + assert_eq!( + Span { start: 6, end: 7 }.location(source), + SourceLocation { + line_number: 3, + line_position: 1, + offset: 6, + length: 1 + } + ); + assert_eq!( + Span { start: 7, end: 8 }.location(source), + SourceLocation { + line_number: 4, + line_position: 1, + offset: 7, + length: 1 + } + ); + assert_eq!( + Span { start: 8, end: 9 }.location(source), + SourceLocation { + line_number: 4, + line_position: 2, + offset: 8, + length: 1 + } + ); + assert_eq!( + Span { start: 9, end: 10 }.location(source), + SourceLocation { + line_number: 4, + line_position: 3, + offset: 9, + length: 1 + } + ); + assert_eq!( + Span { start: 10, end: 11 }.location(source), + SourceLocation { + line_number: 5, + line_position: 1, + offset: 10, + length: 1 + } + ); +} diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs new file mode 100644 index 0000000000..df6fc5e9b0 --- /dev/null +++ b/third_party/rust/naga/src/valid/analyzer.rs @@ -0,0 +1,1281 @@ +/*! Module analyzer. + +Figures out the following properties: + - control flow uniformity + - texture/sampler pairs + - expression reference counts +!*/ + +use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags}; +use crate::span::{AddSpan as _, WithSpan}; +use crate::{ + arena::{Arena, Handle}, + proc::{ResolveContext, TypeResolution}, +}; +use std::ops; + +pub type NonUniformResult = Option<Handle<crate::Expression>>; + +// Remove this once we update our uniformity analysis and +// add support for the `derivative_uniformity` diagnostic +const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true; + +bitflags::bitflags! { + /// Kinds of expressions that require uniform control flow. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct UniformityRequirements: u8 { + const WORK_GROUP_BARRIER = 0x1; + const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 }; + const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 }; + } +} + +/// Uniform control flow characteristics. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr(test, derive(PartialEq))] +pub struct Uniformity { + /// A child expression with non-uniform result. + /// + /// This means, when the relevant invocations are scheduled on a compute unit, + /// they have to use vector registers to store an individual value + /// per invocation. + /// + /// Whenever the control flow is conditioned on such value, + /// the hardware needs to keep track of the mask of invocations, + /// and process all branches of the control flow. + /// + /// Any operations that depend on non-uniform results also produce non-uniform. + pub non_uniform_result: NonUniformResult, + /// If this expression requires uniform control flow, store the reason here. + pub requirements: UniformityRequirements, +} + +impl Uniformity { + const fn new() -> Self { + Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::empty(), + } + } +} + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug, PartialEq)] + struct ExitFlags: u8 { + /// Control flow may return from the function, which makes all the + /// subsequent statements within the current function (only!) + /// to be executed in a non-uniform control flow. + const MAY_RETURN = 0x1; + /// Control flow may be killed. Anything after `Statement::Kill` is + /// considered inside non-uniform context. + const MAY_KILL = 0x2; + } +} + +/// Uniformity characteristics of a function. +#[cfg_attr(test, derive(Debug, PartialEq))] +struct FunctionUniformity { + result: Uniformity, + exit: ExitFlags, +} + +impl ops::BitOr for FunctionUniformity { + type Output = Self; + fn bitor(self, other: Self) -> Self { + FunctionUniformity { + result: Uniformity { + non_uniform_result: self + .result + .non_uniform_result + .or(other.result.non_uniform_result), + requirements: self.result.requirements | other.result.requirements, + }, + exit: self.exit | other.exit, + } + } +} + +impl FunctionUniformity { + const fn new() -> Self { + FunctionUniformity { + result: Uniformity::new(), + exit: ExitFlags::empty(), + } + } + + /// Returns a disruptor based on the stored exit flags, if any. + const fn exit_disruptor(&self) -> Option<UniformityDisruptor> { + if self.exit.contains(ExitFlags::MAY_RETURN) { + Some(UniformityDisruptor::Return) + } else if self.exit.contains(ExitFlags::MAY_KILL) { + Some(UniformityDisruptor::Discard) + } else { + None + } + } +} + +bitflags::bitflags! { + /// Indicates how a global variable is used. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct GlobalUse: u8 { + /// Data will be read from the variable. + const READ = 0x1; + /// Data will be written to the variable. + const WRITE = 0x2; + /// The information about the data is queried. + const QUERY = 0x4; + } +} + +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct SamplingKey { + pub image: Handle<crate::GlobalVariable>, + pub sampler: Handle<crate::GlobalVariable>, +} + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct ExpressionInfo { + pub uniformity: Uniformity, + pub ref_count: usize, + assignable_global: Option<Handle<crate::GlobalVariable>>, + pub ty: TypeResolution, +} + +impl ExpressionInfo { + const fn new() -> Self { + ExpressionInfo { + uniformity: Uniformity::new(), + ref_count: 0, + assignable_global: None, + // this doesn't matter at this point, will be overwritten + ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + width: 0, + })), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +enum GlobalOrArgument { + Global(Handle<crate::GlobalVariable>), + Argument(u32), +} + +impl GlobalOrArgument { + fn from_expression( + expression_arena: &Arena<crate::Expression>, + expression: Handle<crate::Expression>, + ) -> Result<GlobalOrArgument, ExpressionError> { + Ok(match expression_arena[expression] { + crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var), + crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i), + crate::Expression::Access { base, .. } + | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] { + crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var), + _ => return Err(ExpressionError::ExpectedGlobalOrArgument), + }, + _ => return Err(ExpressionError::ExpectedGlobalOrArgument), + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +struct Sampling { + image: GlobalOrArgument, + sampler: GlobalOrArgument, +} + +#[derive(Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct FunctionInfo { + /// Validation flags. + #[allow(dead_code)] + flags: ValidationFlags, + /// Set of shader stages where calling this function is valid. + pub available_stages: ShaderStages, + /// Uniformity characteristics. + pub uniformity: Uniformity, + /// Function may kill the invocation. + pub may_kill: bool, + + /// All pairs of (texture, sampler) globals that may be used together in + /// sampling operations by this function and its callees. This includes + /// pairings that arise when this function passes textures and samplers as + /// arguments to its callees. + /// + /// This table does not include uses of textures and samplers passed as + /// arguments to this function itself, since we do not know which globals + /// those will be. However, this table *is* exhaustive when computed for an + /// entry point function: entry points never receive textures or samplers as + /// arguments, so all an entry point's sampling can be reported in terms of + /// globals. + /// + /// The GLSL back end uses this table to construct reflection info that + /// clients need to construct texture-combined sampler values. + pub sampling_set: crate::FastHashSet<SamplingKey>, + + /// How this function and its callees use this module's globals. + /// + /// This is indexed by `Handle<GlobalVariable>` indices. However, + /// `FunctionInfo` implements `std::ops::Index<Handle<GlobalVariable>>`, + /// so you can simply index this struct with a global handle to retrieve + /// its usage information. + global_uses: Box<[GlobalUse]>, + + /// Information about each expression in this function's body. + /// + /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo` + /// implements `std::ops::Index<Handle<Expression>>`, so you can simply + /// index this struct with an expression handle to retrieve its + /// `ExpressionInfo`. + expressions: Box<[ExpressionInfo]>, + + /// All (texture, sampler) pairs that may be used together in sampling + /// operations by this function and its callees, whether they are accessed + /// as globals or passed as arguments. + /// + /// Participants are represented by [`GlobalVariable`] handles whenever + /// possible, and otherwise by indices of this function's arguments. + /// + /// When analyzing a function call, we combine this data about the callee + /// with the actual arguments being passed to produce the callers' own + /// `sampling_set` and `sampling` tables. + /// + /// [`GlobalVariable`]: crate::GlobalVariable + sampling: crate::FastHashSet<Sampling>, + + /// Indicates that the function is using dual source blending. + pub dual_source_blending: bool, +} + +impl FunctionInfo { + pub const fn global_variable_count(&self) -> usize { + self.global_uses.len() + } + pub const fn expression_count(&self) -> usize { + self.expressions.len() + } + pub fn dominates_global_use(&self, other: &Self) -> bool { + for (self_global_uses, other_global_uses) in + self.global_uses.iter().zip(other.global_uses.iter()) + { + if !self_global_uses.contains(*other_global_uses) { + return false; + } + } + true + } +} + +impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo { + type Output = GlobalUse; + fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse { + &self.global_uses[handle.index()] + } +} + +impl ops::Index<Handle<crate::Expression>> for FunctionInfo { + type Output = ExpressionInfo; + fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo { + &self.expressions[handle.index()] + } +} + +/// Disruptor of the uniform control flow. +#[derive(Clone, Copy, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum UniformityDisruptor { + #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")] + Expression(Handle<crate::Expression>), + #[error("There is a Return earlier in the control flow of the function")] + Return, + #[error("There is a Discard earlier in the entry point across all called functions")] + Discard, +} + +impl FunctionInfo { + /// Adds a value-type reference to an expression. + #[must_use] + fn add_ref_impl( + &mut self, + handle: Handle<crate::Expression>, + global_use: GlobalUse, + ) -> NonUniformResult { + let info = &mut self.expressions[handle.index()]; + info.ref_count += 1; + // mark the used global as read + if let Some(global) = info.assignable_global { + self.global_uses[global.index()] |= global_use; + } + info.uniformity.non_uniform_result + } + + /// Adds a value-type reference to an expression. + #[must_use] + fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult { + self.add_ref_impl(handle, GlobalUse::READ) + } + + /// Adds a potentially assignable reference to an expression. + /// These are destinations for `Store` and `ImageStore` statements, + /// which can transit through `Access` and `AccessIndex`. + #[must_use] + fn add_assignable_ref( + &mut self, + handle: Handle<crate::Expression>, + assignable_global: &mut Option<Handle<crate::GlobalVariable>>, + ) -> NonUniformResult { + let info = &mut self.expressions[handle.index()]; + info.ref_count += 1; + // propagate the assignable global up the chain, till it either hits + // a value-type expression, or the assignment statement. + if let Some(global) = info.assignable_global { + if let Some(_old) = assignable_global.replace(global) { + unreachable!() + } + } + info.uniformity.non_uniform_result + } + + /// Inherit information from a called function. + fn process_call( + &mut self, + callee: &Self, + arguments: &[Handle<crate::Expression>], + expression_arena: &Arena<crate::Expression>, + ) -> Result<FunctionUniformity, WithSpan<FunctionError>> { + self.sampling_set + .extend(callee.sampling_set.iter().cloned()); + for sampling in callee.sampling.iter() { + // If the callee was passed the texture or sampler as an argument, + // we may now be able to determine which globals those referred to. + let image_storage = match sampling.image { + GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var), + GlobalOrArgument::Argument(i) => { + let handle = arguments[i as usize]; + GlobalOrArgument::from_expression(expression_arena, handle).map_err( + |source| { + FunctionError::Expression { handle, source } + .with_span_handle(handle, expression_arena) + }, + )? + } + }; + + let sampler_storage = match sampling.sampler { + GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var), + GlobalOrArgument::Argument(i) => { + let handle = arguments[i as usize]; + GlobalOrArgument::from_expression(expression_arena, handle).map_err( + |source| { + FunctionError::Expression { handle, source } + .with_span_handle(handle, expression_arena) + }, + )? + } + }; + + // If we've managed to pin both the image and sampler down to + // specific globals, record that in our `sampling_set`. Otherwise, + // record as much as we do know in our own `sampling` table, for our + // callers to sort out. + match (image_storage, sampler_storage) { + (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => { + self.sampling_set.insert(SamplingKey { image, sampler }); + } + (image, sampler) => { + self.sampling.insert(Sampling { image, sampler }); + } + } + } + + // Inherit global use from our callees. + for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) { + *mine |= *other; + } + + Ok(FunctionUniformity { + result: callee.uniformity.clone(), + exit: if callee.may_kill { + ExitFlags::MAY_KILL + } else { + ExitFlags::empty() + }, + }) + } + + /// Compute the [`ExpressionInfo`] for `handle`. + /// + /// Replace the dummy entry in [`self.expressions`] for `handle` + /// with a real `ExpressionInfo` value describing that expression. + /// + /// This function is called as part of a forward sweep through the + /// arena, so we can assume that all earlier expressions in the + /// arena already have valid info. Since expressions only depend + /// on earlier expressions, this includes all our subexpressions. + /// + /// Adjust the reference counts on all expressions we use. + /// + /// Also populate the [`sampling_set`], [`sampling`] and + /// [`global_uses`] fields of `self`. + /// + /// [`self.expressions`]: FunctionInfo::expressions + /// [`sampling_set`]: FunctionInfo::sampling_set + /// [`sampling`]: FunctionInfo::sampling + /// [`global_uses`]: FunctionInfo::global_uses + #[allow(clippy::or_fun_call)] + fn process_expression( + &mut self, + handle: Handle<crate::Expression>, + expression_arena: &Arena<crate::Expression>, + other_functions: &[FunctionInfo], + resolve_context: &ResolveContext, + capabilities: super::Capabilities, + ) -> Result<(), ExpressionError> { + use crate::{Expression as E, SampleLevel as Sl}; + + let expression = &expression_arena[handle]; + let mut assignable_global = None; + let uniformity = match *expression { + E::Access { base, index } => { + let base_ty = self[base].ty.inner_with(resolve_context.types); + + // build up the caps needed if this is indexed non-uniformly + let mut needed_caps = super::Capabilities::empty(); + let is_binding_array = match *base_ty { + crate::TypeInner::BindingArray { + base: array_element_ty_handle, + .. + } => { + // these are nasty aliases, but these idents are too long and break rustfmt + let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING; + let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING; + let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING; + + // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it. + let array_element_ty = + &resolve_context.types[array_element_ty_handle].inner; + + needed_caps |= match *array_element_ty { + // If we're an image, use the appropriate limit. + crate::TypeInner::Image { class, .. } => match class { + crate::ImageClass::Storage { .. } => ub_st, + _ => st_sb, + }, + crate::TypeInner::Sampler { .. } => sampler, + // If we're anything but an image, assume we're a buffer and use the address space. + _ => { + if let E::GlobalVariable(global_handle) = expression_arena[base] { + let global = &resolve_context.global_vars[global_handle]; + match global.space { + crate::AddressSpace::Uniform => ub_st, + crate::AddressSpace::Storage { .. } => st_sb, + _ => unreachable!(), + } + } else { + unreachable!() + } + } + }; + + true + } + _ => false, + }; + + if self[index].uniformity.non_uniform_result.is_some() + && !capabilities.contains(needed_caps) + && is_binding_array + { + return Err(ExpressionError::MissingCapabilities(needed_caps)); + } + + Uniformity { + non_uniform_result: self + .add_assignable_ref(base, &mut assignable_global) + .or(self.add_ref(index)), + requirements: UniformityRequirements::empty(), + } + } + E::AccessIndex { base, .. } => Uniformity { + non_uniform_result: self.add_assignable_ref(base, &mut assignable_global), + requirements: UniformityRequirements::empty(), + }, + // always uniform + E::Splat { size: _, value } => Uniformity { + non_uniform_result: self.add_ref(value), + requirements: UniformityRequirements::empty(), + }, + E::Swizzle { vector, .. } => Uniformity { + non_uniform_result: self.add_ref(vector), + requirements: UniformityRequirements::empty(), + }, + E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(), + E::Compose { ref components, .. } => { + let non_uniform_result = components + .iter() + .fold(None, |nur, &comp| nur.or(self.add_ref(comp))); + Uniformity { + non_uniform_result, + requirements: UniformityRequirements::empty(), + } + } + // depends on the builtin or interpolation + E::FunctionArgument(index) => { + let arg = &resolve_context.arguments[index as usize]; + let uniform = match arg.binding { + Some(crate::Binding::BuiltIn(built_in)) => match built_in { + // per-polygon built-ins are uniform + crate::BuiltIn::FrontFacing + // per-work-group built-ins are uniform + | crate::BuiltIn::WorkGroupId + | crate::BuiltIn::WorkGroupSize + | crate::BuiltIn::NumWorkGroups => true, + _ => false, + }, + // only flat inputs are uniform + Some(crate::Binding::Location { + interpolation: Some(crate::Interpolation::Flat), + .. + }) => true, + _ => false, + }; + Uniformity { + non_uniform_result: if uniform { None } else { Some(handle) }, + requirements: UniformityRequirements::empty(), + } + } + // depends on the address space + E::GlobalVariable(gh) => { + use crate::AddressSpace as As; + assignable_global = Some(gh); + let var = &resolve_context.global_vars[gh]; + let uniform = match var.space { + // local data is non-uniform + As::Function | As::Private => false, + // workgroup memory is exclusively accessed by the group + As::WorkGroup => true, + // uniform data + As::Uniform | As::PushConstant => true, + // storage data is only uniform when read-only + As::Storage { access } => !access.contains(crate::StorageAccess::STORE), + As::Handle => false, + }; + Uniformity { + non_uniform_result: if uniform { None } else { Some(handle) }, + requirements: UniformityRequirements::empty(), + } + } + E::LocalVariable(_) => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, + E::Load { pointer } => Uniformity { + non_uniform_result: self.add_ref(pointer), + requirements: UniformityRequirements::empty(), + }, + E::ImageSample { + image, + sampler, + gather: _, + coordinate, + array_index, + offset: _, + level, + depth_ref, + } => { + let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?; + let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?; + + match (image_storage, sampler_storage) { + (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => { + self.sampling_set.insert(SamplingKey { image, sampler }); + } + _ => { + self.sampling.insert(Sampling { + image: image_storage, + sampler: sampler_storage, + }); + } + } + + // "nur" == "Non-Uniform Result" + let array_nur = array_index.and_then(|h| self.add_ref(h)); + let level_nur = match level { + Sl::Auto | Sl::Zero => None, + Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h), + Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)), + }; + let dref_nur = depth_ref.and_then(|h| self.add_ref(h)); + Uniformity { + non_uniform_result: self + .add_ref(image) + .or(self.add_ref(sampler)) + .or(self.add_ref(coordinate)) + .or(array_nur) + .or(level_nur) + .or(dref_nur), + requirements: if level.implicit_derivatives() { + UniformityRequirements::IMPLICIT_LEVEL + } else { + UniformityRequirements::empty() + }, + } + } + E::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + let array_nur = array_index.and_then(|h| self.add_ref(h)); + let sample_nur = sample.and_then(|h| self.add_ref(h)); + let level_nur = level.and_then(|h| self.add_ref(h)); + Uniformity { + non_uniform_result: self + .add_ref(image) + .or(self.add_ref(coordinate)) + .or(array_nur) + .or(sample_nur) + .or(level_nur), + requirements: UniformityRequirements::empty(), + } + } + E::ImageQuery { image, query } => { + let query_nur = match query { + crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h), + _ => None, + }; + Uniformity { + non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur), + requirements: UniformityRequirements::empty(), + } + } + E::Unary { expr, .. } => Uniformity { + non_uniform_result: self.add_ref(expr), + requirements: UniformityRequirements::empty(), + }, + E::Binary { left, right, .. } => Uniformity { + non_uniform_result: self.add_ref(left).or(self.add_ref(right)), + requirements: UniformityRequirements::empty(), + }, + E::Select { + condition, + accept, + reject, + } => Uniformity { + non_uniform_result: self + .add_ref(condition) + .or(self.add_ref(accept)) + .or(self.add_ref(reject)), + requirements: UniformityRequirements::empty(), + }, + // explicit derivatives require uniform + E::Derivative { expr, .. } => Uniformity { + //Note: taking a derivative of a uniform doesn't make it non-uniform + non_uniform_result: self.add_ref(expr), + requirements: UniformityRequirements::DERIVATIVE, + }, + E::Relational { argument, .. } => Uniformity { + non_uniform_result: self.add_ref(argument), + requirements: UniformityRequirements::empty(), + }, + E::Math { + fun: _, + arg, + arg1, + arg2, + arg3, + } => { + let arg1_nur = arg1.and_then(|h| self.add_ref(h)); + let arg2_nur = arg2.and_then(|h| self.add_ref(h)); + let arg3_nur = arg3.and_then(|h| self.add_ref(h)); + Uniformity { + non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur), + requirements: UniformityRequirements::empty(), + } + } + E::As { expr, .. } => Uniformity { + non_uniform_result: self.add_ref(expr), + requirements: UniformityRequirements::empty(), + }, + E::CallResult(function) => other_functions[function.index()].uniformity.clone(), + E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, + E::WorkGroupUniformLoadResult { .. } => Uniformity { + // The result of WorkGroupUniformLoad is always uniform by definition + non_uniform_result: None, + // The call is what cares about uniformity, not the expression + // This expression is never emitted, so this requirement should never be used anyway? + requirements: UniformityRequirements::empty(), + }, + E::ArrayLength(expr) => Uniformity { + non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY), + requirements: UniformityRequirements::empty(), + }, + E::RayQueryGetIntersection { + query, + committed: _, + } => Uniformity { + non_uniform_result: self.add_ref(query), + requirements: UniformityRequirements::empty(), + }, + }; + + let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; + self.expressions[handle.index()] = ExpressionInfo { + uniformity, + ref_count: 0, + assignable_global, + ty, + }; + Ok(()) + } + + /// Analyzes the uniformity requirements of a block (as a sequence of statements). + /// Returns the uniformity characteristics at the *function* level, i.e. + /// whether or not the function requires to be called in uniform control flow, + /// and whether the produced result is not disrupting the control flow. + /// + /// The parent control flow is uniform if `disruptor.is_none()`. + /// + /// Returns a `NonUniformControlFlow` error if any of the expressions in the block + /// require uniformity, but the current flow is non-uniform. + #[allow(clippy::or_fun_call)] + fn process_block( + &mut self, + statements: &crate::Block, + other_functions: &[FunctionInfo], + mut disruptor: Option<UniformityDisruptor>, + expression_arena: &Arena<crate::Expression>, + ) -> Result<FunctionUniformity, WithSpan<FunctionError>> { + use crate::Statement as S; + + let mut combined_uniformity = FunctionUniformity::new(); + for statement in statements { + let uniformity = match *statement { + S::Emit(ref range) => { + let mut requirements = UniformityRequirements::empty(); + for expr in range.clone() { + let req = self.expressions[expr.index()].uniformity.requirements; + if self + .flags + .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY) + && !req.is_empty() + { + if let Some(cause) = disruptor { + return Err(FunctionError::NonUniformControlFlow(req, expr, cause) + .with_span_handle(expr, expression_arena)); + } + } + requirements |= req; + } + FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements, + }, + exit: ExitFlags::empty(), + } + } + S::Break | S::Continue => FunctionUniformity::new(), + S::Kill => FunctionUniformity { + result: Uniformity::new(), + exit: if disruptor.is_some() { + ExitFlags::MAY_KILL + } else { + ExitFlags::empty() + }, + }, + S::Barrier(_) => FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::WORK_GROUP_BARRIER, + }, + exit: ExitFlags::empty(), + }, + S::WorkGroupUniformLoad { pointer, .. } => { + let _condition_nur = self.add_ref(pointer); + + // Don't check that this call occurs in uniform control flow until Naga implements WGSL's standard + // uniformity analysis (https://github.com/gfx-rs/naga/issues/1744). + // The uniformity analysis Naga uses now is less accurate than the one in the WGSL standard, + // causing Naga to reject correct uses of `workgroupUniformLoad` in some interesting programs. + + /* + if self + .flags + .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY) + { + let condition_nur = self.add_ref(pointer); + let this_disruptor = + disruptor.or(condition_nur.map(UniformityDisruptor::Expression)); + if let Some(cause) = this_disruptor { + return Err(FunctionError::NonUniformWorkgroupUniformLoad(cause) + .with_span_static(*span, "WorkGroupUniformLoad")); + } + } */ + FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::WORK_GROUP_BARRIER, + }, + exit: ExitFlags::empty(), + } + } + S::Block(ref b) => { + self.process_block(b, other_functions, disruptor, expression_arena)? + } + S::If { + condition, + ref accept, + ref reject, + } => { + let condition_nur = self.add_ref(condition); + let branch_disruptor = + disruptor.or(condition_nur.map(UniformityDisruptor::Expression)); + let accept_uniformity = self.process_block( + accept, + other_functions, + branch_disruptor, + expression_arena, + )?; + let reject_uniformity = self.process_block( + reject, + other_functions, + branch_disruptor, + expression_arena, + )?; + accept_uniformity | reject_uniformity + } + S::Switch { + selector, + ref cases, + } => { + let selector_nur = self.add_ref(selector); + let branch_disruptor = + disruptor.or(selector_nur.map(UniformityDisruptor::Expression)); + let mut uniformity = FunctionUniformity::new(); + let mut case_disruptor = branch_disruptor; + for case in cases.iter() { + let case_uniformity = self.process_block( + &case.body, + other_functions, + case_disruptor, + expression_arena, + )?; + case_disruptor = if case.fall_through { + case_disruptor.or(case_uniformity.exit_disruptor()) + } else { + branch_disruptor + }; + uniformity = uniformity | case_uniformity; + } + uniformity + } + S::Loop { + ref body, + ref continuing, + break_if, + } => { + let body_uniformity = + self.process_block(body, other_functions, disruptor, expression_arena)?; + let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor()); + let continuing_uniformity = self.process_block( + continuing, + other_functions, + continuing_disruptor, + expression_arena, + )?; + if let Some(expr) = break_if { + let _ = self.add_ref(expr); + } + body_uniformity | continuing_uniformity + } + S::Return { value } => FunctionUniformity { + result: Uniformity { + non_uniform_result: value.and_then(|expr| self.add_ref(expr)), + requirements: UniformityRequirements::empty(), + }, + exit: if disruptor.is_some() { + ExitFlags::MAY_RETURN + } else { + ExitFlags::empty() + }, + }, + // Here and below, the used expressions are already emitted, + // and their results do not affect the function return value, + // so we can ignore their non-uniformity. + S::Store { pointer, value } => { + let _ = self.add_ref_impl(pointer, GlobalUse::WRITE); + let _ = self.add_ref(value); + FunctionUniformity::new() + } + S::ImageStore { + image, + coordinate, + array_index, + value, + } => { + let _ = self.add_ref_impl(image, GlobalUse::WRITE); + if let Some(expr) = array_index { + let _ = self.add_ref(expr); + } + let _ = self.add_ref(coordinate); + let _ = self.add_ref(value); + FunctionUniformity::new() + } + S::Call { + function, + ref arguments, + result: _, + } => { + for &argument in arguments { + let _ = self.add_ref(argument); + } + let info = &other_functions[function.index()]; + //Note: the result is validated by the Validator, not here + self.process_call(info, arguments, expression_arena)? + } + S::Atomic { + pointer, + ref fun, + value, + result: _, + } => { + let _ = self.add_ref_impl(pointer, GlobalUse::WRITE); + let _ = self.add_ref(value); + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + let _ = self.add_ref(cmp); + } + FunctionUniformity::new() + } + S::RayQuery { query, ref fun } => { + let _ = self.add_ref(query); + if let crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } = *fun + { + let _ = self.add_ref(acceleration_structure); + let _ = self.add_ref(descriptor); + } + FunctionUniformity::new() + } + }; + + disruptor = disruptor.or(uniformity.exit_disruptor()); + combined_uniformity = combined_uniformity | uniformity; + } + Ok(combined_uniformity) + } +} + +impl ModuleInfo { + /// Populates `self.const_expression_types` + pub(super) fn process_const_expression( + &mut self, + handle: Handle<crate::Expression>, + resolve_context: &ResolveContext, + gctx: crate::proc::GlobalCtx, + ) -> Result<(), super::ConstExpressionError> { + self.const_expression_types[handle.index()] = + resolve_context.resolve(&gctx.const_expressions[handle], |h| Ok(&self[h]))?; + Ok(()) + } + + /// Builds the `FunctionInfo` based on the function, and validates the + /// uniform control flow if required by the expressions of this function. + pub(super) fn process_function( + &self, + fun: &crate::Function, + module: &crate::Module, + flags: ValidationFlags, + capabilities: super::Capabilities, + ) -> Result<FunctionInfo, WithSpan<FunctionError>> { + let mut info = FunctionInfo { + flags, + available_stages: ShaderStages::all(), + uniformity: Uniformity::new(), + may_kill: false, + sampling_set: crate::FastHashSet::default(), + global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(), + expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(), + sampling: crate::FastHashSet::default(), + dual_source_blending: false, + }; + let resolve_context = + ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments); + + for (handle, _) in fun.expressions.iter() { + if let Err(source) = info.process_expression( + handle, + &fun.expressions, + &self.functions, + &resolve_context, + capabilities, + ) { + return Err(FunctionError::Expression { handle, source } + .with_span_handle(handle, &fun.expressions)); + } + } + + for (_, expr) in fun.local_variables.iter() { + if let Some(init) = expr.init { + let _ = info.add_ref(init); + } + } + + let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?; + info.uniformity = uniformity.result; + info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL); + + Ok(info) + } + + pub fn get_entry_point(&self, index: usize) -> &FunctionInfo { + &self.entry_points[index] + } +} + +#[test] +fn uniform_control_flow() { + use crate::{Expression as E, Statement as S}; + + let mut type_arena = crate::UniqueArena::new(); + let ty = type_arena.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size: crate::VectorSize::Bi, + scalar: crate::Scalar::F32, + }, + }, + Default::default(), + ); + let mut global_var_arena = Arena::new(); + let non_uniform_global = global_var_arena.append( + crate::GlobalVariable { + name: None, + init: None, + ty, + space: crate::AddressSpace::Handle, + binding: None, + }, + Default::default(), + ); + let uniform_global = global_var_arena.append( + crate::GlobalVariable { + name: None, + init: None, + ty, + binding: None, + space: crate::AddressSpace::Uniform, + }, + Default::default(), + ); + + let mut expressions = Arena::new(); + // checks the uniform control flow + let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default()); + // checks the non-uniform control flow + let derivative_expr = expressions.append( + E::Derivative { + axis: crate::DerivativeAxis::X, + ctrl: crate::DerivativeControl::None, + expr: constant_expr, + }, + Default::default(), + ); + let emit_range_constant_derivative = expressions.range_from(0); + let non_uniform_global_expr = + expressions.append(E::GlobalVariable(non_uniform_global), Default::default()); + let uniform_global_expr = + expressions.append(E::GlobalVariable(uniform_global), Default::default()); + let emit_range_globals = expressions.range_from(2); + + // checks the QUERY flag + let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default()); + // checks the transitive WRITE flag + let access_expr = expressions.append( + E::AccessIndex { + base: non_uniform_global_expr, + index: 1, + }, + Default::default(), + ); + let emit_range_query_access_globals = expressions.range_from(2); + + let mut info = FunctionInfo { + flags: ValidationFlags::all(), + available_stages: ShaderStages::all(), + uniformity: Uniformity::new(), + may_kill: false, + sampling_set: crate::FastHashSet::default(), + global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(), + expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(), + sampling: crate::FastHashSet::default(), + dual_source_blending: false, + }; + let resolve_context = ResolveContext { + constants: &Arena::new(), + types: &type_arena, + special_types: &crate::SpecialTypes::default(), + global_vars: &global_var_arena, + local_vars: &Arena::new(), + functions: &Arena::new(), + arguments: &[], + }; + for (handle, _) in expressions.iter() { + info.process_expression( + handle, + &expressions, + &[], + &resolve_context, + super::Capabilities::empty(), + ) + .unwrap(); + } + assert_eq!(info[non_uniform_global_expr].ref_count, 1); + assert_eq!(info[uniform_global_expr].ref_count, 1); + assert_eq!(info[query_expr].ref_count, 0); + assert_eq!(info[access_expr].ref_count, 0); + assert_eq!(info[non_uniform_global], GlobalUse::empty()); + assert_eq!(info[uniform_global], GlobalUse::QUERY); + + let stmt_emit1 = S::Emit(emit_range_globals.clone()); + let stmt_if_uniform = S::If { + condition: uniform_global_expr, + accept: crate::Block::new(), + reject: vec![ + S::Emit(emit_range_constant_derivative.clone()), + S::Store { + pointer: constant_expr, + value: derivative_expr, + }, + ] + .into(), + }; + assert_eq!( + info.process_block( + &vec![stmt_emit1, stmt_if_uniform].into(), + &[], + None, + &expressions + ), + Ok(FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::DERIVATIVE, + }, + exit: ExitFlags::empty(), + }), + ); + assert_eq!(info[constant_expr].ref_count, 2); + assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY); + + let stmt_emit2 = S::Emit(emit_range_globals.clone()); + let stmt_if_non_uniform = S::If { + condition: non_uniform_global_expr, + accept: vec![ + S::Emit(emit_range_constant_derivative), + S::Store { + pointer: constant_expr, + value: derivative_expr, + }, + ] + .into(), + reject: crate::Block::new(), + }; + { + let block_info = info.process_block( + &vec![stmt_emit2, stmt_if_non_uniform].into(), + &[], + None, + &expressions, + ); + if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { + assert_eq!(info[derivative_expr].ref_count, 2); + } else { + assert_eq!( + block_info, + Err(FunctionError::NonUniformControlFlow( + UniformityRequirements::DERIVATIVE, + derivative_expr, + UniformityDisruptor::Expression(non_uniform_global_expr) + ) + .with_span()), + ); + assert_eq!(info[derivative_expr].ref_count, 1); + } + } + assert_eq!(info[non_uniform_global], GlobalUse::READ); + + let stmt_emit3 = S::Emit(emit_range_globals); + let stmt_return_non_uniform = S::Return { + value: Some(non_uniform_global_expr), + }; + assert_eq!( + info.process_block( + &vec![stmt_emit3, stmt_return_non_uniform].into(), + &[], + Some(UniformityDisruptor::Return), + &expressions + ), + Ok(FunctionUniformity { + result: Uniformity { + non_uniform_result: Some(non_uniform_global_expr), + requirements: UniformityRequirements::empty(), + }, + exit: ExitFlags::MAY_RETURN, + }), + ); + assert_eq!(info[non_uniform_global_expr].ref_count, 3); + + // Check that uniformity requirements reach through a pointer + let stmt_emit4 = S::Emit(emit_range_query_access_globals); + let stmt_assign = S::Store { + pointer: access_expr, + value: query_expr, + }; + let stmt_return_pointer = S::Return { + value: Some(access_expr), + }; + let stmt_kill = S::Kill; + assert_eq!( + info.process_block( + &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(), + &[], + Some(UniformityDisruptor::Discard), + &expressions + ), + Ok(FunctionUniformity { + result: Uniformity { + non_uniform_result: Some(non_uniform_global_expr), + requirements: UniformityRequirements::empty(), + }, + exit: ExitFlags::all(), + }), + ); + assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE); +} diff --git a/third_party/rust/naga/src/valid/compose.rs b/third_party/rust/naga/src/valid/compose.rs new file mode 100644 index 0000000000..c21e98c6f2 --- /dev/null +++ b/third_party/rust/naga/src/valid/compose.rs @@ -0,0 +1,128 @@ +use crate::proc::TypeResolution; + +use crate::arena::Handle; + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum ComposeError { + #[error("Composing of type {0:?} can't be done")] + Type(Handle<crate::Type>), + #[error("Composing expects {expected} components but {given} were given")] + ComponentCount { given: u32, expected: u32 }, + #[error("Composing {index}'s component type is not expected")] + ComponentType { index: u32 }, +} + +pub fn validate_compose( + self_ty_handle: Handle<crate::Type>, + gctx: crate::proc::GlobalCtx, + component_resolutions: impl ExactSizeIterator<Item = TypeResolution>, +) -> Result<(), ComposeError> { + use crate::TypeInner as Ti; + + match gctx.types[self_ty_handle].inner { + // vectors are composed from scalars or other vectors + Ti::Vector { size, scalar } => { + let mut total = 0; + for (index, comp_res) in component_resolutions.enumerate() { + total += match *comp_res.inner_with(gctx.types) { + Ti::Scalar(comp_scalar) if comp_scalar == scalar => 1, + Ti::Vector { + size: comp_size, + scalar: comp_scalar, + } if comp_scalar == scalar => comp_size as u32, + ref other => { + log::error!( + "Vector component[{}] type {:?}, building {:?}", + index, + other, + scalar + ); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + }; + } + if size as u32 != total { + return Err(ComposeError::ComponentCount { + expected: size as u32, + given: total, + }); + } + } + // matrix are composed from column vectors + Ti::Matrix { + columns, + rows, + scalar, + } => { + let inner = Ti::Vector { size: rows, scalar }; + if columns as usize != component_resolutions.len() { + return Err(ComposeError::ComponentCount { + expected: columns as u32, + given: component_resolutions.len() as u32, + }); + } + for (index, comp_res) in component_resolutions.enumerate() { + if comp_res.inner_with(gctx.types) != &inner { + log::error!("Matrix component[{}] type {:?}", index, comp_res); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + } + } + Ti::Array { + base, + size: crate::ArraySize::Constant(count), + stride: _, + } => { + if count.get() as usize != component_resolutions.len() { + return Err(ComposeError::ComponentCount { + expected: count.get(), + given: component_resolutions.len() as u32, + }); + } + for (index, comp_res) in component_resolutions.enumerate() { + let base_inner = &gctx.types[base].inner; + let comp_res_inner = comp_res.inner_with(gctx.types); + // We don't support arrays of pointers, but it seems best not to + // embed that assumption here, so use `TypeInner::equivalent`. + if !base_inner.equivalent(comp_res_inner, gctx.types) { + log::error!("Array component[{}] type {:?}", index, comp_res); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + } + } + Ti::Struct { ref members, .. } => { + if members.len() != component_resolutions.len() { + return Err(ComposeError::ComponentCount { + given: component_resolutions.len() as u32, + expected: members.len() as u32, + }); + } + for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate() + { + let member_inner = &gctx.types[member.ty].inner; + let comp_res_inner = comp_res.inner_with(gctx.types); + // We don't support pointers in structs, but it seems best not to embed + // that assumption here, so use `TypeInner::equivalent`. + if !comp_res_inner.equivalent(member_inner, gctx.types) { + log::error!("Struct component[{}] type {:?}", index, comp_res); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + } + } + ref other => { + log::error!("Composing of {:?}", other); + return Err(ComposeError::Type(self_ty_handle)); + } + } + + Ok(()) +} diff --git a/third_party/rust/naga/src/valid/expression.rs b/third_party/rust/naga/src/valid/expression.rs new file mode 100644 index 0000000000..c82d60f062 --- /dev/null +++ b/third_party/rust/naga/src/valid/expression.rs @@ -0,0 +1,1797 @@ +use super::{ + compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ModuleInfo, + ShaderStages, TypeFlags, +}; +use crate::arena::UniqueArena; + +use crate::{ + arena::Handle, + proc::{IndexableLengthError, ResolveError}, +}; + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum ExpressionError { + #[error("Doesn't exist")] + DoesntExist, + #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")] + NotInScope, + #[error("Base type {0:?} is not compatible with this expression")] + InvalidBaseType(Handle<crate::Expression>), + #[error("Accessing with index {0:?} can't be done")] + InvalidIndexType(Handle<crate::Expression>), + #[error("Accessing {0:?} via a negative index is invalid")] + NegativeIndex(Handle<crate::Expression>), + #[error("Accessing index {1} is out of {0:?} bounds")] + IndexOutOfBounds(Handle<crate::Expression>, u32), + #[error("The expression {0:?} may only be indexed by a constant")] + IndexMustBeConstant(Handle<crate::Expression>), + #[error("Function argument {0:?} doesn't exist")] + FunctionArgumentDoesntExist(u32), + #[error("Loading of {0:?} can't be done")] + InvalidPointerType(Handle<crate::Expression>), + #[error("Array length of {0:?} can't be done")] + InvalidArrayType(Handle<crate::Expression>), + #[error("Get intersection of {0:?} can't be done")] + InvalidRayQueryType(Handle<crate::Expression>), + #[error("Splatting {0:?} can't be done")] + InvalidSplatType(Handle<crate::Expression>), + #[error("Swizzling {0:?} can't be done")] + InvalidVectorType(Handle<crate::Expression>), + #[error("Swizzle component {0:?} is outside of vector size {1:?}")] + InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize), + #[error(transparent)] + Compose(#[from] super::ComposeError), + #[error(transparent)] + IndexableLength(#[from] IndexableLengthError), + #[error("Operation {0:?} can't work with {1:?}")] + InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>), + #[error("Operation {0:?} can't work with {1:?} and {2:?}")] + InvalidBinaryOperandTypes( + crate::BinaryOperator, + Handle<crate::Expression>, + Handle<crate::Expression>, + ), + #[error("Selecting is not possible")] + InvalidSelectTypes, + #[error("Relational argument {0:?} is not a boolean vector")] + InvalidBooleanVector(Handle<crate::Expression>), + #[error("Relational argument {0:?} is not a float")] + InvalidFloatArgument(Handle<crate::Expression>), + #[error("Type resolution failed")] + Type(#[from] ResolveError), + #[error("Not a global variable")] + ExpectedGlobalVariable, + #[error("Not a global variable or a function argument")] + ExpectedGlobalOrArgument, + #[error("Needs to be an binding array instead of {0:?}")] + ExpectedBindingArrayType(Handle<crate::Type>), + #[error("Needs to be an image instead of {0:?}")] + ExpectedImageType(Handle<crate::Type>), + #[error("Needs to be an image instead of {0:?}")] + ExpectedSamplerType(Handle<crate::Type>), + #[error("Unable to operate on image class {0:?}")] + InvalidImageClass(crate::ImageClass), + #[error("Derivatives can only be taken from scalar and vector floats")] + InvalidDerivative, + #[error("Image array index parameter is misplaced")] + InvalidImageArrayIndex, + #[error("Inappropriate sample or level-of-detail index for texel access")] + InvalidImageOtherIndex, + #[error("Image array index type of {0:?} is not an integer scalar")] + InvalidImageArrayIndexType(Handle<crate::Expression>), + #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")] + InvalidImageOtherIndexType(Handle<crate::Expression>), + #[error("Image coordinate type of {1:?} does not match dimension {0:?}")] + InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>), + #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")] + ComparisonSamplingMismatch { + image: crate::ImageClass, + sampler: bool, + has_ref: bool, + }, + #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")] + InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>), + #[error("Depth reference {0:?} is not a scalar float")] + InvalidDepthReference(Handle<crate::Expression>), + #[error("Depth sample level can only be Auto or Zero")] + InvalidDepthSampleLevel, + #[error("Gather level can only be Zero")] + InvalidGatherLevel, + #[error("Gather component {0:?} doesn't exist in the image")] + InvalidGatherComponent(crate::SwizzleComponent), + #[error("Gather can't be done for image dimension {0:?}")] + InvalidGatherDimension(crate::ImageDimension), + #[error("Sample level (exact) type {0:?} is not a scalar float")] + InvalidSampleLevelExactType(Handle<crate::Expression>), + #[error("Sample level (bias) type {0:?} is not a scalar float")] + InvalidSampleLevelBiasType(Handle<crate::Expression>), + #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")] + InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>), + #[error("Unable to cast")] + InvalidCastArgument, + #[error("Invalid argument count for {0:?}")] + WrongArgumentCount(crate::MathFunction), + #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")] + InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>), + #[error("Atomic result type can't be {0:?}")] + InvalidAtomicResultType(Handle<crate::Type>), + #[error( + "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type." + )] + InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>), + #[error("Shader requires capability {0:?}")] + MissingCapabilities(super::Capabilities), + #[error(transparent)] + Literal(#[from] LiteralError), +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum ConstExpressionError { + #[error("The expression is not a constant expression")] + NonConst, + #[error(transparent)] + Compose(#[from] super::ComposeError), + #[error("Splatting {0:?} can't be done")] + InvalidSplatType(Handle<crate::Expression>), + #[error("Type resolution failed")] + Type(#[from] ResolveError), + #[error(transparent)] + Literal(#[from] LiteralError), + #[error(transparent)] + Width(#[from] super::r#type::WidthError), +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum LiteralError { + #[error("Float literal is NaN")] + NaN, + #[error("Float literal is infinite")] + Infinity, + #[error(transparent)] + Width(#[from] super::r#type::WidthError), +} + +struct ExpressionTypeResolver<'a> { + root: Handle<crate::Expression>, + types: &'a UniqueArena<crate::Type>, + info: &'a FunctionInfo, +} + +impl<'a> std::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'a> { + type Output = crate::TypeInner; + + #[allow(clippy::panic)] + fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output { + if handle < self.root { + self.info[handle].ty.inner_with(self.types) + } else { + // `Validator::validate_module_handles` should have caught this. + panic!( + "Depends on {:?}, which has not been processed yet", + self.root + ) + } + } +} + +impl super::Validator { + pub(super) fn validate_const_expression( + &self, + handle: Handle<crate::Expression>, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), ConstExpressionError> { + use crate::Expression as E; + + match gctx.const_expressions[handle] { + E::Literal(literal) => { + self.validate_literal(literal)?; + } + E::Constant(_) | E::ZeroValue(_) => {} + E::Compose { ref components, ty } => { + validate_compose( + ty, + gctx, + components.iter().map(|&handle| mod_info[handle].clone()), + )?; + } + E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) { + crate::TypeInner::Scalar { .. } => {} + _ => return Err(super::ConstExpressionError::InvalidSplatType(value)), + }, + _ => return Err(super::ConstExpressionError::NonConst), + } + + Ok(()) + } + + pub(super) fn validate_expression( + &self, + root: Handle<crate::Expression>, + expression: &crate::Expression, + function: &crate::Function, + module: &crate::Module, + info: &FunctionInfo, + mod_info: &ModuleInfo, + ) -> Result<ShaderStages, ExpressionError> { + use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti}; + + let resolver = ExpressionTypeResolver { + root, + types: &module.types, + info, + }; + + let stages = match *expression { + E::Access { base, index } => { + let base_type = &resolver[base]; + // See the documentation for `Expression::Access`. + let dynamic_indexing_restricted = match *base_type { + Ti::Vector { .. } => false, + Ti::Matrix { .. } | Ti::Array { .. } => true, + Ti::Pointer { .. } + | Ti::ValuePointer { size: Some(_), .. } + | Ti::BindingArray { .. } => false, + ref other => { + log::error!("Indexing of {:?}", other); + return Err(ExpressionError::InvalidBaseType(base)); + } + }; + match resolver[index] { + //TODO: only allow one of these + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + .. + }) => {} + ref other => { + log::error!("Indexing by {:?}", other); + return Err(ExpressionError::InvalidIndexType(index)); + } + } + if dynamic_indexing_restricted + && function.expressions[index].is_dynamic_index(module) + { + return Err(ExpressionError::IndexMustBeConstant(base)); + } + + // If we know both the length and the index, we can do the + // bounds check now. + if let crate::proc::IndexableLength::Known(known_length) = + base_type.indexable_length(module)? + { + match module + .to_ctx() + .eval_expr_to_u32_from(index, &function.expressions) + { + Ok(value) => { + if value >= known_length { + return Err(ExpressionError::IndexOutOfBounds(base, value)); + } + } + Err(crate::proc::U32EvalError::Negative) => { + return Err(ExpressionError::NegativeIndex(base)) + } + Err(crate::proc::U32EvalError::NonConst) => {} + } + } + + ShaderStages::all() + } + E::AccessIndex { base, index } => { + fn resolve_index_limit( + module: &crate::Module, + top: Handle<crate::Expression>, + ty: &crate::TypeInner, + top_level: bool, + ) -> Result<u32, ExpressionError> { + let limit = match *ty { + Ti::Vector { size, .. } + | Ti::ValuePointer { + size: Some(size), .. + } => size as u32, + Ti::Matrix { columns, .. } => columns as u32, + Ti::Array { + size: crate::ArraySize::Constant(len), + .. + } => len.get(), + Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks + Ti::Pointer { base, .. } if top_level => { + resolve_index_limit(module, top, &module.types[base].inner, false)? + } + Ti::Struct { ref members, .. } => members.len() as u32, + ref other => { + log::error!("Indexing of {:?}", other); + return Err(ExpressionError::InvalidBaseType(top)); + } + }; + Ok(limit) + } + + let limit = resolve_index_limit(module, base, &resolver[base], true)?; + if index >= limit { + return Err(ExpressionError::IndexOutOfBounds(base, limit)); + } + ShaderStages::all() + } + E::Splat { size: _, value } => match resolver[value] { + Ti::Scalar { .. } => ShaderStages::all(), + ref other => { + log::error!("Splat scalar type {:?}", other); + return Err(ExpressionError::InvalidSplatType(value)); + } + }, + E::Swizzle { + size, + vector, + pattern, + } => { + let vec_size = match resolver[vector] { + Ti::Vector { size: vec_size, .. } => vec_size, + ref other => { + log::error!("Swizzle vector type {:?}", other); + return Err(ExpressionError::InvalidVectorType(vector)); + } + }; + for &sc in pattern[..size as usize].iter() { + if sc as u8 >= vec_size as u8 { + return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size)); + } + } + ShaderStages::all() + } + E::Literal(literal) => { + self.validate_literal(literal)?; + ShaderStages::all() + } + E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Compose { ref components, ty } => { + validate_compose( + ty, + module.to_ctx(), + components.iter().map(|&handle| info[handle].ty.clone()), + )?; + ShaderStages::all() + } + E::FunctionArgument(index) => { + if index >= function.arguments.len() as u32 { + return Err(ExpressionError::FunctionArgumentDoesntExist(index)); + } + ShaderStages::all() + } + E::GlobalVariable(_handle) => ShaderStages::all(), + E::LocalVariable(_handle) => ShaderStages::all(), + E::Load { pointer } => { + match resolver[pointer] { + Ti::Pointer { base, .. } + if self.types[base.index()] + .flags + .contains(TypeFlags::SIZED | TypeFlags::DATA) => {} + Ti::ValuePointer { .. } => {} + ref other => { + log::error!("Loading {:?}", other); + return Err(ExpressionError::InvalidPointerType(pointer)); + } + } + ShaderStages::all() + } + E::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + // check the validity of expressions + let image_ty = Self::global_var_ty(module, function, image)?; + let sampler_ty = Self::global_var_ty(module, function, sampler)?; + + let comparison = match module.types[sampler_ty].inner { + Ti::Sampler { comparison } => comparison, + _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)), + }; + + let (class, dim) = match module.types[image_ty].inner { + Ti::Image { + class, + arrayed, + dim, + } => { + // check the array property + if arrayed != array_index.is_some() { + return Err(ExpressionError::InvalidImageArrayIndex); + } + if let Some(expr) = array_index { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + .. + }) => {} + _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)), + } + } + (class, dim) + } + _ => return Err(ExpressionError::ExpectedImageType(image_ty)), + }; + + // check sampling and comparison properties + let image_depth = match class { + crate::ImageClass::Sampled { + kind: crate::ScalarKind::Float, + multi: false, + } => false, + crate::ImageClass::Sampled { + kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint, + multi: false, + } if gather.is_some() => false, + crate::ImageClass::Depth { multi: false } => true, + _ => return Err(ExpressionError::InvalidImageClass(class)), + }; + if comparison != depth_ref.is_some() || (comparison && !image_depth) { + return Err(ExpressionError::ComparisonSamplingMismatch { + image: class, + sampler: comparison, + has_ref: depth_ref.is_some(), + }); + } + + // check texture coordinates type + let num_components = match dim { + crate::ImageDimension::D1 => 1, + crate::ImageDimension::D2 => 2, + crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3, + }; + match resolver[coordinate] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) if num_components == 1 => {} + Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Float, .. + }, + } if size as u32 == num_components => {} + _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)), + } + + // check constant offset + if let Some(const_expr) = offset { + match *mod_info[const_expr].inner_with(&module.types) { + Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {} + Ti::Vector { + size, + scalar: Sc { kind: Sk::Sint, .. }, + } if size as u32 == num_components => {} + _ => { + return Err(ExpressionError::InvalidSampleOffset(dim, const_expr)); + } + } + } + + // check depth reference type + if let Some(expr) = depth_ref { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) => {} + _ => return Err(ExpressionError::InvalidDepthReference(expr)), + } + match level { + crate::SampleLevel::Auto | crate::SampleLevel::Zero => {} + _ => return Err(ExpressionError::InvalidDepthSampleLevel), + } + } + + if let Some(component) = gather { + match dim { + crate::ImageDimension::D2 | crate::ImageDimension::Cube => {} + crate::ImageDimension::D1 | crate::ImageDimension::D3 => { + return Err(ExpressionError::InvalidGatherDimension(dim)) + } + }; + let max_component = match class { + crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X, + _ => crate::SwizzleComponent::W, + }; + if component > max_component { + return Err(ExpressionError::InvalidGatherComponent(component)); + } + match level { + crate::SampleLevel::Zero => {} + _ => return Err(ExpressionError::InvalidGatherLevel), + } + } + + // check level properties + match level { + crate::SampleLevel::Auto => ShaderStages::FRAGMENT, + crate::SampleLevel::Zero => ShaderStages::all(), + crate::SampleLevel::Exact(expr) => { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) => {} + _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)), + } + ShaderStages::all() + } + crate::SampleLevel::Bias(expr) => { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) => {} + _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)), + } + ShaderStages::FRAGMENT + } + crate::SampleLevel::Gradient { x, y } => { + match resolver[x] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) if num_components == 1 => {} + Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Float, .. + }, + } if size as u32 == num_components => {} + _ => { + return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x)) + } + } + match resolver[y] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) if num_components == 1 => {} + Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Float, .. + }, + } if size as u32 == num_components => {} + _ => { + return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y)) + } + } + ShaderStages::all() + } + } + } + E::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + let ty = Self::global_var_ty(module, function, image)?; + match module.types[ty].inner { + Ti::Image { + class, + arrayed, + dim, + } => { + match resolver[coordinate].image_storage_coordinates() { + Some(coord_dim) if coord_dim == dim => {} + _ => { + return Err(ExpressionError::InvalidImageCoordinateType( + dim, coordinate, + )) + } + }; + if arrayed != array_index.is_some() { + return Err(ExpressionError::InvalidImageArrayIndex); + } + if let Some(expr) = array_index { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + width: _, + }) => {} + _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)), + } + } + + match (sample, class.is_multisampled()) { + (None, false) => {} + (Some(sample), true) => { + if resolver[sample].scalar_kind() != Some(Sk::Sint) { + return Err(ExpressionError::InvalidImageOtherIndexType( + sample, + )); + } + } + _ => { + return Err(ExpressionError::InvalidImageOtherIndex); + } + } + + match (level, class.is_mipmapped()) { + (None, false) => {} + (Some(level), true) => { + if resolver[level].scalar_kind() != Some(Sk::Sint) { + return Err(ExpressionError::InvalidImageOtherIndexType(level)); + } + } + _ => { + return Err(ExpressionError::InvalidImageOtherIndex); + } + } + } + _ => return Err(ExpressionError::ExpectedImageType(ty)), + } + ShaderStages::all() + } + E::ImageQuery { image, query } => { + let ty = Self::global_var_ty(module, function, image)?; + match module.types[ty].inner { + Ti::Image { class, arrayed, .. } => { + let good = match query { + crate::ImageQuery::NumLayers => arrayed, + crate::ImageQuery::Size { level: None } => true, + crate::ImageQuery::Size { level: Some(_) } + | crate::ImageQuery::NumLevels => class.is_mipmapped(), + crate::ImageQuery::NumSamples => class.is_multisampled(), + }; + if !good { + return Err(ExpressionError::InvalidImageClass(class)); + } + } + _ => return Err(ExpressionError::ExpectedImageType(ty)), + } + ShaderStages::all() + } + E::Unary { op, expr } => { + use crate::UnaryOperator as Uo; + let inner = &resolver[expr]; + match (op, inner.scalar_kind()) { + (Uo::Negate, Some(Sk::Float | Sk::Sint)) + | (Uo::LogicalNot, Some(Sk::Bool)) + | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {} + other => { + log::error!("Op {:?} kind {:?}", op, other); + return Err(ExpressionError::InvalidUnaryOperandType(op, expr)); + } + } + ShaderStages::all() + } + E::Binary { op, left, right } => { + use crate::BinaryOperator as Bo; + let left_inner = &resolver[left]; + let right_inner = &resolver[right]; + let good = match op { + Bo::Add | Bo::Subtract => match *left_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, + }, + Ti::Matrix { .. } => left_inner == right_inner, + _ => false, + }, + Bo::Divide | Bo::Modulo => match *left_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, + }, + _ => false, + }, + Bo::Multiply => { + let kind_allowed = match left_inner.scalar_kind() { + Some(Sk::Uint | Sk::Sint | Sk::Float) => true, + Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false, + }; + let types_match = match (left_inner, right_inner) { + // Straight scalar and mixed scalar/vector. + (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2)) + | ( + &Ti::Vector { + scalar: scalar1, .. + }, + &Ti::Scalar(scalar2), + ) + | ( + &Ti::Scalar(scalar1), + &Ti::Vector { + scalar: scalar2, .. + }, + ) => scalar1 == scalar2, + // Scalar/matrix. + ( + &Ti::Scalar(Sc { + kind: Sk::Float, .. + }), + &Ti::Matrix { .. }, + ) + | ( + &Ti::Matrix { .. }, + &Ti::Scalar(Sc { + kind: Sk::Float, .. + }), + ) => true, + // Vector/vector. + ( + &Ti::Vector { + size: size1, + scalar: scalar1, + }, + &Ti::Vector { + size: size2, + scalar: scalar2, + }, + ) => scalar1 == scalar2 && size1 == size2, + // Matrix * vector. + ( + &Ti::Matrix { columns, .. }, + &Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Float, .. + }, + }, + ) => columns == size, + // Vector * matrix. + ( + &Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Float, .. + }, + }, + &Ti::Matrix { rows, .. }, + ) => size == rows, + (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => { + columns == rows + } + _ => false, + }; + let left_width = left_inner.scalar_width().unwrap_or(0); + let right_width = right_inner.scalar_width().unwrap_or(0); + kind_allowed && types_match && left_width == right_width + } + Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner, + Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => { + match *left_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, + }, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + } + } + Bo::LogicalAnd | Bo::LogicalOr => match *left_inner { + Ti::Scalar(Sc { kind: Sk::Bool, .. }) + | Ti::Vector { + scalar: Sc { kind: Sk::Bool, .. }, + .. + } => left_inner == right_inner, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + }, + Bo::And | Bo::InclusiveOr => match *left_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { + Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner, + Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, + }, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + }, + Bo::ExclusiveOr => match *left_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { + Sk::Sint | Sk::Uint => left_inner == right_inner, + Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, + }, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + }, + Bo::ShiftLeft | Bo::ShiftRight => { + let (base_size, base_scalar) = match *left_inner { + Ti::Scalar(scalar) => (Ok(None), scalar), + Ti::Vector { size, scalar } => (Ok(Some(size)), scalar), + ref other => { + log::error!("Op {:?} base type {:?}", op, other); + (Err(()), Sc::BOOL) + } + }; + let shift_size = match *right_inner { + Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None), + Ti::Vector { + size, + scalar: Sc { kind: Sk::Uint, .. }, + } => Ok(Some(size)), + ref other => { + log::error!("Op {:?} shift type {:?}", op, other); + Err(()) + } + }; + match base_scalar.kind { + Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size, + Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false, + } + } + }; + if !good { + log::error!( + "Left: {:?} of type {:?}", + function.expressions[left], + left_inner + ); + log::error!( + "Right: {:?} of type {:?}", + function.expressions[right], + right_inner + ); + return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right)); + } + ShaderStages::all() + } + E::Select { + condition, + accept, + reject, + } => { + let accept_inner = &resolver[accept]; + let reject_inner = &resolver[reject]; + let condition_good = match resolver[condition] { + Ti::Scalar(Sc { + kind: Sk::Bool, + width: _, + }) => { + // When `condition` is a single boolean, `accept` and + // `reject` can be vectors or scalars. + match *accept_inner { + Ti::Scalar { .. } | Ti::Vector { .. } => true, + _ => false, + } + } + Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Bool, + width: _, + }, + } => match *accept_inner { + Ti::Vector { + size: other_size, .. + } => size == other_size, + _ => false, + }, + _ => false, + }; + if !condition_good || accept_inner != reject_inner { + return Err(ExpressionError::InvalidSelectTypes); + } + ShaderStages::all() + } + E::Derivative { expr, .. } => { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidDerivative), + } + ShaderStages::FRAGMENT + } + E::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + let argument_inner = &resolver[argument]; + match fun { + Rf::All | Rf::Any => match *argument_inner { + Ti::Vector { + scalar: Sc { kind: Sk::Bool, .. }, + .. + } => {} + ref other => { + log::error!("All/Any of type {:?}", other); + return Err(ExpressionError::InvalidBooleanVector(argument)); + } + }, + Rf::IsNan | Rf::IsInf => match *argument_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } + if scalar.kind == Sk::Float => {} + ref other => { + log::error!("Float test of type {:?}", other); + return Err(ExpressionError::InvalidFloatArgument(argument)); + } + }, + } + ShaderStages::all() + } + E::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + + let resolve = |arg| &resolver[arg]; + let arg_ty = resolve(arg); + let arg1_ty = arg1.map(resolve); + let arg2_ty = arg2.map(resolve); + let arg3_ty = arg3.map(resolve); + match fun { + Mf::Abs => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + let good = match *arg_ty { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { + scalar.kind != Sk::Bool + } + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + } + Mf::Min | Mf::Max => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let good = match *arg_ty { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { + scalar.kind != Sk::Bool + } + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Clamp => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let good = match *arg_ty { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { + scalar.kind != Sk::Bool + } + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + if arg2_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )); + } + } + Mf::Saturate + | Mf::Cos + | Mf::Cosh + | Mf::Sin + | Mf::Sinh + | Mf::Tan + | Mf::Tanh + | Mf::Acos + | Mf::Asin + | Mf::Atan + | Mf::Asinh + | Mf::Acosh + | Mf::Atanh + | Mf::Radians + | Mf::Degrees + | Mf::Ceil + | Mf::Floor + | Mf::Round + | Mf::Fract + | Mf::Trunc + | Mf::Exp + | Mf::Exp2 + | Mf::Log + | Mf::Log2 + | Mf::Length + | Mf::Sqrt + | Mf::InverseSqrt => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } + if scalar.kind == Sk::Float => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::Sign => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Float | Sk::Sint, + .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Float | Sk::Sint, + .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } + if scalar.kind == Sk::Float => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Modf | Mf::Frexp => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + if !matches!(*arg_ty, + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } + if scalar.kind == Sk::Float) + { + return Err(ExpressionError::InvalidArgumentType(fun, 1, arg)); + } + } + Mf::Ldexp => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let size0 = match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) => None, + Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + size, + } => Some(size), + _ => { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + }; + let good = match *arg1_ty { + Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true, + Ti::Vector { + size, + scalar: Sc { kind: Sk::Sint, .. }, + } if Some(size) == size0 => true, + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Dot => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Vector { + scalar: + Sc { + kind: Sk::Float | Sk::Sint | Sk::Uint, + .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Outer | Mf::Cross | Mf::Reflect => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Refract => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + + match *arg_ty { + Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + + match (arg_ty, arg2_ty) { + ( + &Ti::Vector { + scalar: + Sc { + width: vector_width, + .. + }, + .. + }, + &Ti::Scalar(Sc { + width: scalar_width, + kind: Sk::Float, + }), + ) if vector_width == scalar_width => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )) + } + } + } + Mf::Normalize => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::FaceForward | Mf::Fma | Mf::SmoothStep => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + if arg2_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )); + } + } + Mf::Mix => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let arg_width = match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Float, + width, + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Float, + width, + }, + .. + } => width, + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + }; + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + // the last argument can always be a scalar + match *arg2_ty { + Ti::Scalar(Sc { + kind: Sk::Float, + width, + }) if width == arg_width => {} + _ if arg2_ty == arg_ty => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )); + } + } + } + Mf::Inverse | Mf::Determinant => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + let good = match *arg_ty { + Ti::Matrix { columns, rows, .. } => columns == rows, + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + } + Mf::Transpose => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Matrix { .. } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::CountTrailingZeros + | Mf::CountLeadingZeros + | Mf::CountOneBits + | Mf::ReverseBits + | Mf::FindLsb + | Mf::FindMsb => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Sint | Sk::Uint, + .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::InsertBits => { + let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Sint | Sk::Uint, + .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + match *arg2_ty { + Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )) + } + } + match *arg3_ty { + Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg3.unwrap(), + )) + } + } + } + Mf::ExtractBits => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Sint | Sk::Uint, + .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + match *arg1_ty { + Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg1.unwrap(), + )) + } + } + match *arg2_ty { + Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )) + } + } + } + Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Vector { + size: crate::VectorSize::Bi, + scalar: + Sc { + kind: Sk::Float, .. + }, + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::Pack4x8snorm | Mf::Pack4x8unorm => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Vector { + size: crate::VectorSize::Quad, + scalar: + Sc { + kind: Sk::Float, .. + }, + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::Unpack2x16float + | Mf::Unpack2x16snorm + | Mf::Unpack2x16unorm + | Mf::Unpack4x8snorm + | Mf::Unpack4x8unorm => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + } + ShaderStages::all() + } + E::As { + expr, + kind, + convert, + } => { + let mut base_scalar = match resolver[expr] { + crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => { + scalar + } + crate::TypeInner::Matrix { scalar, .. } => scalar, + _ => return Err(ExpressionError::InvalidCastArgument), + }; + base_scalar.kind = kind; + if let Some(width) = convert { + base_scalar.width = width; + } + if self.check_width(base_scalar).is_err() { + return Err(ExpressionError::InvalidCastArgument); + } + ShaderStages::all() + } + E::CallResult(function) => mod_info.functions[function.index()].available_stages, + E::AtomicResult { ty, comparison } => { + let scalar_predicate = |ty: &crate::TypeInner| match ty { + &crate::TypeInner::Scalar( + scalar @ Sc { + kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint, + .. + }, + ) => self.check_width(scalar).is_ok(), + _ => false, + }; + let good = match &module.types[ty].inner { + ty if !comparison => scalar_predicate(ty), + &crate::TypeInner::Struct { ref members, .. } if comparison => { + validate_atomic_compare_exchange_struct( + &module.types, + members, + scalar_predicate, + ) + } + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidAtomicResultType(ty)); + } + ShaderStages::all() + } + E::WorkGroupUniformLoadResult { ty } => { + if self.types[ty.index()] + .flags + // Sized | Constructible is exactly the types currently supported by + // WorkGroupUniformLoad + .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE) + { + ShaderStages::COMPUTE + } else { + return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty)); + } + } + E::ArrayLength(expr) => match resolver[expr] { + Ti::Pointer { base, .. } => { + let base_ty = &resolver.types[base]; + if let Ti::Array { + size: crate::ArraySize::Dynamic, + .. + } = base_ty.inner + { + ShaderStages::all() + } else { + return Err(ExpressionError::InvalidArrayType(expr)); + } + } + ref other => { + log::error!("Array length of {:?}", other); + return Err(ExpressionError::InvalidArrayType(expr)); + } + }, + E::RayQueryProceedResult => ShaderStages::all(), + E::RayQueryGetIntersection { + query, + committed: _, + } => match resolver[query] { + Ti::Pointer { + base, + space: crate::AddressSpace::Function, + } => match resolver.types[base].inner { + Ti::RayQuery => ShaderStages::all(), + ref other => { + log::error!("Intersection result of a pointer to {:?}", other); + return Err(ExpressionError::InvalidRayQueryType(query)); + } + }, + ref other => { + log::error!("Intersection result of {:?}", other); + return Err(ExpressionError::InvalidRayQueryType(query)); + } + }, + }; + Ok(stages) + } + + fn global_var_ty( + module: &crate::Module, + function: &crate::Function, + expr: Handle<crate::Expression>, + ) -> Result<Handle<crate::Type>, ExpressionError> { + use crate::Expression as Ex; + + match function.expressions[expr] { + Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty), + Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty), + Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => { + match function.expressions[base] { + Ex::GlobalVariable(var_handle) => { + let array_ty = module.global_variables[var_handle].ty; + + match module.types[array_ty].inner { + crate::TypeInner::BindingArray { base, .. } => Ok(base), + _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)), + } + } + _ => Err(ExpressionError::ExpectedGlobalVariable), + } + } + _ => Err(ExpressionError::ExpectedGlobalVariable), + } + } + + pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> { + self.check_width(literal.scalar())?; + check_literal_value(literal)?; + + Ok(()) + } +} + +pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> { + let is_nan = match literal { + crate::Literal::F64(v) => v.is_nan(), + crate::Literal::F32(v) => v.is_nan(), + _ => false, + }; + if is_nan { + return Err(LiteralError::NaN); + } + + let is_infinite = match literal { + crate::Literal::F64(v) => v.is_infinite(), + crate::Literal::F32(v) => v.is_infinite(), + _ => false, + }; + if is_infinite { + return Err(LiteralError::Infinity); + } + + Ok(()) +} + +#[cfg(all(test, feature = "validate"))] +/// Validate a module containing the given expression, expecting an error. +fn validate_with_expression( + expr: crate::Expression, + caps: super::Capabilities, +) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> { + use crate::span::Span; + + let mut function = crate::Function::default(); + function.expressions.append(expr, Span::default()); + function.body.push( + crate::Statement::Emit(function.expressions.range_from(0)), + Span::default(), + ); + + let mut module = crate::Module::default(); + module.functions.append(function, Span::default()); + + let mut validator = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps); + + validator.validate(&module) +} + +#[cfg(all(test, feature = "validate"))] +/// Validate a module containing the given constant expression, expecting an error. +fn validate_with_const_expression( + expr: crate::Expression, + caps: super::Capabilities, +) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> { + use crate::span::Span; + + let mut module = crate::Module::default(); + module.const_expressions.append(expr, Span::default()); + + let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps); + + validator.validate(&module) +} + +/// Using F64 in a function's expression arena is forbidden. +#[cfg(feature = "validate")] +#[test] +fn f64_runtime_literals() { + let result = validate_with_expression( + crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), + super::Capabilities::default(), + ); + let error = result.unwrap_err().into_inner(); + assert!(matches!( + error, + crate::valid::ValidationError::Function { + source: super::FunctionError::Expression { + source: super::ExpressionError::Literal(super::LiteralError::Width( + super::r#type::WidthError::MissingCapability { + name: "f64", + flag: "FLOAT64", + } + ),), + .. + }, + .. + } + )); + + let result = validate_with_expression( + crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), + super::Capabilities::default() | super::Capabilities::FLOAT64, + ); + assert!(result.is_ok()); +} + +/// Using F64 in a module's constant expression arena is forbidden. +#[cfg(feature = "validate")] +#[test] +fn f64_const_literals() { + let result = validate_with_const_expression( + crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), + super::Capabilities::default(), + ); + let error = result.unwrap_err().into_inner(); + assert!(matches!( + error, + crate::valid::ValidationError::ConstExpression { + source: super::ConstExpressionError::Literal(super::LiteralError::Width( + super::r#type::WidthError::MissingCapability { + name: "f64", + flag: "FLOAT64", + } + )), + .. + } + )); + + let result = validate_with_const_expression( + crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), + super::Capabilities::default() | super::Capabilities::FLOAT64, + ); + assert!(result.is_ok()); +} + +/// Using I64 in a function's expression arena is forbidden. +#[cfg(feature = "validate")] +#[test] +fn i64_runtime_literals() { + let result = validate_with_expression( + crate::Expression::Literal(crate::Literal::I64(1729)), + // There is no capability that enables this. + super::Capabilities::all(), + ); + let error = result.unwrap_err().into_inner(); + assert!(matches!( + error, + crate::valid::ValidationError::Function { + source: super::FunctionError::Expression { + source: super::ExpressionError::Literal(super::LiteralError::Width( + super::r#type::WidthError::Unsupported64Bit + ),), + .. + }, + .. + } + )); +} + +/// Using I64 in a module's constant expression arena is forbidden. +#[cfg(feature = "validate")] +#[test] +fn i64_const_literals() { + let result = validate_with_const_expression( + crate::Expression::Literal(crate::Literal::I64(1729)), + // There is no capability that enables this. + super::Capabilities::all(), + ); + let error = result.unwrap_err().into_inner(); + assert!(matches!( + error, + crate::valid::ValidationError::ConstExpression { + source: super::ConstExpressionError::Literal(super::LiteralError::Width( + super::r#type::WidthError::Unsupported64Bit, + ),), + .. + } + )); +} diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs new file mode 100644 index 0000000000..f0ca22cbda --- /dev/null +++ b/third_party/rust/naga/src/valid/function.rs @@ -0,0 +1,1056 @@ +use crate::arena::Handle; +use crate::arena::{Arena, UniqueArena}; + +use super::validate_atomic_compare_exchange_struct; + +use super::{ + analyzer::{UniformityDisruptor, UniformityRequirements}, + ExpressionError, FunctionInfo, ModuleInfo, +}; +use crate::span::WithSpan; +use crate::span::{AddSpan as _, MapErrWithSpan as _}; + +use bit_set::BitSet; + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum CallError { + #[error("Argument {index} expression is invalid")] + Argument { + index: usize, + source: ExpressionError, + }, + #[error("Result expression {0:?} has already been introduced earlier")] + ResultAlreadyInScope(Handle<crate::Expression>), + #[error("Result value is invalid")] + ResultValue(#[source] ExpressionError), + #[error("Requires {required} arguments, but {seen} are provided")] + ArgumentCount { required: usize, seen: usize }, + #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")] + ArgumentType { + index: usize, + required: Handle<crate::Type>, + seen_expression: Handle<crate::Expression>, + }, + #[error("The emitted expression doesn't match the call")] + ExpressionMismatch(Option<Handle<crate::Expression>>), +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum AtomicError { + #[error("Pointer {0:?} to atomic is invalid.")] + InvalidPointer(Handle<crate::Expression>), + #[error("Operand {0:?} has invalid type.")] + InvalidOperand(Handle<crate::Expression>), + #[error("Result type for {0:?} doesn't match the statement")] + ResultTypeMismatch(Handle<crate::Expression>), +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum LocalVariableError { + #[error("Local variable has a type {0:?} that can't be stored in a local variable.")] + InvalidType(Handle<crate::Type>), + #[error("Initializer doesn't match the variable type")] + InitializerType, + #[error("Initializer is not const")] + NonConstInitializer, +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum FunctionError { + #[error("Expression {handle:?} is invalid")] + Expression { + handle: Handle<crate::Expression>, + source: ExpressionError, + }, + #[error("Expression {0:?} can't be introduced - it's already in scope")] + ExpressionAlreadyInScope(Handle<crate::Expression>), + #[error("Local variable {handle:?} '{name}' is invalid")] + LocalVariable { + handle: Handle<crate::LocalVariable>, + name: String, + source: LocalVariableError, + }, + #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")] + InvalidArgumentType { index: usize, name: String }, + #[error("The function's given return type cannot be returned from functions")] + NonConstructibleReturnType, + #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")] + InvalidArgumentPointerSpace { + index: usize, + name: String, + space: crate::AddressSpace, + }, + #[error("There are instructions after `return`/`break`/`continue`")] + InstructionsAfterReturn, + #[error("The `break` is used outside of a `loop` or `switch` context")] + BreakOutsideOfLoopOrSwitch, + #[error("The `continue` is used outside of a `loop` context")] + ContinueOutsideOfLoop, + #[error("The `return` is called within a `continuing` block")] + InvalidReturnSpot, + #[error("The `return` value {0:?} does not match the function return value")] + InvalidReturnType(Option<Handle<crate::Expression>>), + #[error("The `if` condition {0:?} is not a boolean scalar")] + InvalidIfType(Handle<crate::Expression>), + #[error("The `switch` value {0:?} is not an integer scalar")] + InvalidSwitchType(Handle<crate::Expression>), + #[error("Multiple `switch` cases for {0:?} are present")] + ConflictingSwitchCase(crate::SwitchValue), + #[error("The `switch` contains cases with conflicting types")] + ConflictingCaseType, + #[error("The `switch` is missing a `default` case")] + MissingDefaultCase, + #[error("Multiple `default` cases are present")] + MultipleDefaultCases, + #[error("The last `switch` case contains a `fallthrough`")] + LastCaseFallTrough, + #[error("The pointer {0:?} doesn't relate to a valid destination for a store")] + InvalidStorePointer(Handle<crate::Expression>), + #[error("The value {0:?} can not be stored")] + InvalidStoreValue(Handle<crate::Expression>), + #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")] + InvalidStoreTypes { + pointer: Handle<crate::Expression>, + value: Handle<crate::Expression>, + }, + #[error("Image store parameters are invalid")] + InvalidImageStore(#[source] ExpressionError), + #[error("Call to {function:?} is invalid")] + InvalidCall { + function: Handle<crate::Function>, + #[source] + error: CallError, + }, + #[error("Atomic operation is invalid")] + InvalidAtomic(#[from] AtomicError), + #[error("Ray Query {0:?} is not a local variable")] + InvalidRayQueryExpression(Handle<crate::Expression>), + #[error("Acceleration structure {0:?} is not a matching expression")] + InvalidAccelerationStructure(Handle<crate::Expression>), + #[error("Ray descriptor {0:?} is not a matching expression")] + InvalidRayDescriptor(Handle<crate::Expression>), + #[error("Ray Query {0:?} does not have a matching type")] + InvalidRayQueryType(Handle<crate::Type>), + #[error( + "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}" + )] + NonUniformControlFlow( + UniformityRequirements, + Handle<crate::Expression>, + UniformityDisruptor, + ), + #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")] + PipelineInputRegularFunction { name: String }, + #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")] + PipelineOutputRegularFunction, + #[error("Required uniformity for WorkGroupUniformLoad is not fulfilled because of {0:?}")] + // The actual load statement will be "pointed to" by the span + NonUniformWorkgroupUniformLoad(UniformityDisruptor), + // This is only possible with a misbehaving frontend + #[error("The expression {0:?} for a WorkGroupUniformLoad isn't a WorkgroupUniformLoadResult")] + WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>), + #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")] + WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>), +} + +bitflags::bitflags! { + #[repr(transparent)] + #[derive(Clone, Copy)] + struct ControlFlowAbility: u8 { + /// The control can return out of this block. + const RETURN = 0x1; + /// The control can break. + const BREAK = 0x2; + /// The control can continue. + const CONTINUE = 0x4; + } +} + +struct BlockInfo { + stages: super::ShaderStages, + finished: bool, +} + +struct BlockContext<'a> { + abilities: ControlFlowAbility, + info: &'a FunctionInfo, + expressions: &'a Arena<crate::Expression>, + types: &'a UniqueArena<crate::Type>, + local_vars: &'a Arena<crate::LocalVariable>, + global_vars: &'a Arena<crate::GlobalVariable>, + functions: &'a Arena<crate::Function>, + special_types: &'a crate::SpecialTypes, + prev_infos: &'a [FunctionInfo], + return_type: Option<Handle<crate::Type>>, +} + +impl<'a> BlockContext<'a> { + fn new( + fun: &'a crate::Function, + module: &'a crate::Module, + info: &'a FunctionInfo, + prev_infos: &'a [FunctionInfo], + ) -> Self { + Self { + abilities: ControlFlowAbility::RETURN, + info, + expressions: &fun.expressions, + types: &module.types, + local_vars: &fun.local_variables, + global_vars: &module.global_variables, + functions: &module.functions, + special_types: &module.special_types, + prev_infos, + return_type: fun.result.as_ref().map(|fr| fr.ty), + } + } + + const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self { + BlockContext { abilities, ..*self } + } + + fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression { + &self.expressions[handle] + } + + fn resolve_type_impl( + &self, + handle: Handle<crate::Expression>, + valid_expressions: &BitSet, + ) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> { + if handle.index() >= self.expressions.len() { + Err(ExpressionError::DoesntExist.with_span()) + } else if !valid_expressions.contains(handle.index()) { + Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions)) + } else { + Ok(self.info[handle].ty.inner_with(self.types)) + } + } + + fn resolve_type( + &self, + handle: Handle<crate::Expression>, + valid_expressions: &BitSet, + ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> { + self.resolve_type_impl(handle, valid_expressions) + .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span()) + } + + fn resolve_pointer_type( + &self, + handle: Handle<crate::Expression>, + ) -> Result<&crate::TypeInner, FunctionError> { + if handle.index() >= self.expressions.len() { + Err(FunctionError::Expression { + handle, + source: ExpressionError::DoesntExist, + }) + } else { + Ok(self.info[handle].ty.inner_with(self.types)) + } + } +} + +impl super::Validator { + fn validate_call( + &mut self, + function: Handle<crate::Function>, + arguments: &[Handle<crate::Expression>], + result: Option<Handle<crate::Expression>>, + context: &BlockContext, + ) -> Result<super::ShaderStages, WithSpan<CallError>> { + let fun = &context.functions[function]; + if fun.arguments.len() != arguments.len() { + return Err(CallError::ArgumentCount { + required: fun.arguments.len(), + seen: arguments.len(), + } + .with_span()); + } + for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() { + let ty = context + .resolve_type_impl(expr, &self.valid_expression_set) + .map_err_inner(|source| { + CallError::Argument { index, source } + .with_span_handle(expr, context.expressions) + })?; + let arg_inner = &context.types[arg.ty].inner; + if !ty.equivalent(arg_inner, context.types) { + return Err(CallError::ArgumentType { + index, + required: arg.ty, + seen_expression: expr, + } + .with_span_handle(expr, context.expressions)); + } + } + + if let Some(expr) = result { + if self.valid_expression_set.insert(expr.index()) { + self.valid_expression_list.push(expr); + } else { + return Err(CallError::ResultAlreadyInScope(expr) + .with_span_handle(expr, context.expressions)); + } + match context.expressions[expr] { + crate::Expression::CallResult(callee) + if fun.result.is_some() && callee == function => {} + _ => { + return Err(CallError::ExpressionMismatch(result) + .with_span_handle(expr, context.expressions)) + } + } + } else if fun.result.is_some() { + return Err(CallError::ExpressionMismatch(result).with_span()); + } + + let callee_info = &context.prev_infos[function.index()]; + Ok(callee_info.available_stages) + } + + fn emit_expression( + &mut self, + handle: Handle<crate::Expression>, + context: &BlockContext, + ) -> Result<(), WithSpan<FunctionError>> { + if self.valid_expression_set.insert(handle.index()) { + self.valid_expression_list.push(handle); + Ok(()) + } else { + Err(FunctionError::ExpressionAlreadyInScope(handle) + .with_span_handle(handle, context.expressions)) + } + } + + fn validate_atomic( + &mut self, + pointer: Handle<crate::Expression>, + fun: &crate::AtomicFunction, + value: Handle<crate::Expression>, + result: Handle<crate::Expression>, + context: &BlockContext, + ) -> Result<(), WithSpan<FunctionError>> { + let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?; + let ptr_scalar = match *pointer_inner { + crate::TypeInner::Pointer { base, .. } => match context.types[base].inner { + crate::TypeInner::Atomic(scalar) => scalar, + ref other => { + log::error!("Atomic pointer to type {:?}", other); + return Err(AtomicError::InvalidPointer(pointer) + .with_span_handle(pointer, context.expressions) + .into_other()); + } + }, + ref other => { + log::error!("Atomic on type {:?}", other); + return Err(AtomicError::InvalidPointer(pointer) + .with_span_handle(pointer, context.expressions) + .into_other()); + } + }; + + let value_inner = context.resolve_type(value, &self.valid_expression_set)?; + match *value_inner { + crate::TypeInner::Scalar(scalar) if scalar == ptr_scalar => {} + ref other => { + log::error!("Atomic operand type {:?}", other); + return Err(AtomicError::InvalidOperand(value) + .with_span_handle(value, context.expressions) + .into_other()); + } + } + + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + if context.resolve_type(cmp, &self.valid_expression_set)? != value_inner { + log::error!("Atomic exchange comparison has a different type from the value"); + return Err(AtomicError::InvalidOperand(cmp) + .with_span_handle(cmp, context.expressions) + .into_other()); + } + } + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::AtomicResult { ty, comparison } + if { + let scalar_predicate = + |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(ptr_scalar); + match &context.types[ty].inner { + ty if !comparison => scalar_predicate(ty), + &crate::TypeInner::Struct { ref members, .. } if comparison => { + validate_atomic_compare_exchange_struct( + context.types, + members, + scalar_predicate, + ) + } + _ => false, + } + } => {} + _ => { + return Err(AtomicError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } + + fn validate_block_impl( + &mut self, + statements: &crate::Block, + context: &BlockContext, + ) -> Result<BlockInfo, WithSpan<FunctionError>> { + use crate::{AddressSpace, Statement as S, TypeInner as Ti}; + let mut finished = false; + let mut stages = super::ShaderStages::all(); + for (statement, &span) in statements.span_iter() { + if finished { + return Err(FunctionError::InstructionsAfterReturn + .with_span_static(span, "instructions after return")); + } + match *statement { + S::Emit(ref range) => { + for handle in range.clone() { + self.emit_expression(handle, context)?; + } + } + S::Block(ref block) => { + let info = self.validate_block(block, context)?; + stages &= info.stages; + finished = info.finished; + } + S::If { + condition, + ref accept, + ref reject, + } => { + match *context.resolve_type(condition, &self.valid_expression_set)? { + Ti::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + width: _, + }) => {} + _ => { + return Err(FunctionError::InvalidIfType(condition) + .with_span_handle(condition, context.expressions)) + } + } + stages &= self.validate_block(accept, context)?.stages; + stages &= self.validate_block(reject, context)?.stages; + } + S::Switch { + selector, + ref cases, + } => { + let uint = match context + .resolve_type(selector, &self.valid_expression_set)? + .scalar_kind() + { + Some(crate::ScalarKind::Uint) => true, + Some(crate::ScalarKind::Sint) => false, + _ => { + return Err(FunctionError::InvalidSwitchType(selector) + .with_span_handle(selector, context.expressions)) + } + }; + self.switch_values.clear(); + for case in cases { + match case.value { + crate::SwitchValue::I32(_) if !uint => {} + crate::SwitchValue::U32(_) if uint => {} + crate::SwitchValue::Default => {} + _ => { + return Err(FunctionError::ConflictingCaseType.with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "conflicting switch arm here", + )); + } + }; + if !self.switch_values.insert(case.value) { + return Err(match case.value { + crate::SwitchValue::Default => FunctionError::MultipleDefaultCases + .with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "duplicated switch arm here", + ), + _ => FunctionError::ConflictingSwitchCase(case.value) + .with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "conflicting switch arm here", + ), + }); + } + } + if !self.switch_values.contains(&crate::SwitchValue::Default) { + return Err(FunctionError::MissingDefaultCase + .with_span_static(span, "missing default case")); + } + if let Some(case) = cases.last() { + if case.fall_through { + return Err(FunctionError::LastCaseFallTrough.with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "bad switch arm here", + )); + } + } + let pass_through_abilities = context.abilities + & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE); + let sub_context = + context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK); + for case in cases { + stages &= self.validate_block(&case.body, &sub_context)?.stages; + } + } + S::Loop { + ref body, + ref continuing, + break_if, + } => { + // special handling for block scoping is needed here, + // because the continuing{} block inherits the scope + let base_expression_count = self.valid_expression_list.len(); + let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN; + stages &= self + .validate_block_impl( + body, + &context.with_abilities( + pass_through_abilities + | ControlFlowAbility::BREAK + | ControlFlowAbility::CONTINUE, + ), + )? + .stages; + stages &= self + .validate_block_impl( + continuing, + &context.with_abilities(ControlFlowAbility::empty()), + )? + .stages; + + if let Some(condition) = break_if { + match *context.resolve_type(condition, &self.valid_expression_set)? { + Ti::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + width: _, + }) => {} + _ => { + return Err(FunctionError::InvalidIfType(condition) + .with_span_handle(condition, context.expressions)) + } + } + } + + for handle in self.valid_expression_list.drain(base_expression_count..) { + self.valid_expression_set.remove(handle.index()); + } + } + S::Break => { + if !context.abilities.contains(ControlFlowAbility::BREAK) { + return Err(FunctionError::BreakOutsideOfLoopOrSwitch + .with_span_static(span, "invalid break")); + } + finished = true; + } + S::Continue => { + if !context.abilities.contains(ControlFlowAbility::CONTINUE) { + return Err(FunctionError::ContinueOutsideOfLoop + .with_span_static(span, "invalid continue")); + } + finished = true; + } + S::Return { value } => { + if !context.abilities.contains(ControlFlowAbility::RETURN) { + return Err(FunctionError::InvalidReturnSpot + .with_span_static(span, "invalid return")); + } + let value_ty = value + .map(|expr| context.resolve_type(expr, &self.valid_expression_set)) + .transpose()?; + let expected_ty = context.return_type.map(|ty| &context.types[ty].inner); + // We can't return pointers, but it seems best not to embed that + // assumption here, so use `TypeInner::equivalent` for comparison. + let okay = match (value_ty, expected_ty) { + (None, None) => true, + (Some(value_inner), Some(expected_inner)) => { + value_inner.equivalent(expected_inner, context.types) + } + (_, _) => false, + }; + + if !okay { + log::error!( + "Returning {:?} where {:?} is expected", + value_ty, + expected_ty + ); + if let Some(handle) = value { + return Err(FunctionError::InvalidReturnType(value) + .with_span_handle(handle, context.expressions)); + } else { + return Err(FunctionError::InvalidReturnType(value) + .with_span_static(span, "invalid return")); + } + } + finished = true; + } + S::Kill => { + stages &= super::ShaderStages::FRAGMENT; + finished = true; + } + S::Barrier(_) => { + stages &= super::ShaderStages::COMPUTE; + } + S::Store { pointer, value } => { + let mut current = pointer; + loop { + let _ = context + .resolve_pointer_type(current) + .map_err(|e| e.with_span())?; + match context.expressions[current] { + crate::Expression::Access { base, .. } + | crate::Expression::AccessIndex { base, .. } => current = base, + crate::Expression::LocalVariable(_) + | crate::Expression::GlobalVariable(_) + | crate::Expression::FunctionArgument(_) => break, + _ => { + return Err(FunctionError::InvalidStorePointer(current) + .with_span_handle(pointer, context.expressions)) + } + } + } + + let value_ty = context.resolve_type(value, &self.valid_expression_set)?; + match *value_ty { + Ti::Image { .. } | Ti::Sampler { .. } => { + return Err(FunctionError::InvalidStoreValue(value) + .with_span_handle(value, context.expressions)); + } + _ => {} + } + + let pointer_ty = context + .resolve_pointer_type(pointer) + .map_err(|e| e.with_span())?; + + let good = match *pointer_ty { + Ti::Pointer { base, space: _ } => match context.types[base].inner { + Ti::Atomic(scalar) => *value_ty == Ti::Scalar(scalar), + ref other => value_ty == other, + }, + Ti::ValuePointer { + size: Some(size), + scalar, + space: _, + } => *value_ty == Ti::Vector { size, scalar }, + Ti::ValuePointer { + size: None, + scalar, + space: _, + } => *value_ty == Ti::Scalar(scalar), + _ => false, + }; + if !good { + return Err(FunctionError::InvalidStoreTypes { pointer, value } + .with_span() + .with_handle(pointer, context.expressions) + .with_handle(value, context.expressions)); + } + + if let Some(space) = pointer_ty.pointer_space() { + if !space.access().contains(crate::StorageAccess::STORE) { + return Err(FunctionError::InvalidStorePointer(pointer) + .with_span_static( + context.expressions.get_span(pointer), + "writing to this location is not permitted", + )); + } + } + } + S::ImageStore { + image, + coordinate, + array_index, + value, + } => { + //Note: this code uses a lot of `FunctionError::InvalidImageStore`, + // and could probably be refactored. + let var = match *context.get_expression(image) { + crate::Expression::GlobalVariable(var_handle) => { + &context.global_vars[var_handle] + } + // We're looking at a binding index situation, so punch through the index and look at the global behind it. + crate::Expression::Access { base, .. } + | crate::Expression::AccessIndex { base, .. } => { + match *context.get_expression(base) { + crate::Expression::GlobalVariable(var_handle) => { + &context.global_vars[var_handle] + } + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::ExpectedGlobalVariable, + ) + .with_span_handle(image, context.expressions)) + } + } + } + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::ExpectedGlobalVariable, + ) + .with_span_handle(image, context.expressions)) + } + }; + + // Punch through a binding array to get the underlying type + let global_ty = match context.types[var.ty].inner { + Ti::BindingArray { base, .. } => &context.types[base].inner, + ref inner => inner, + }; + + let value_ty = match *global_ty { + Ti::Image { + class, + arrayed, + dim, + } => { + match context + .resolve_type(coordinate, &self.valid_expression_set)? + .image_storage_coordinates() + { + Some(coord_dim) if coord_dim == dim => {} + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageCoordinateType( + dim, coordinate, + ), + ) + .with_span_handle(coordinate, context.expressions)); + } + }; + if arrayed != array_index.is_some() { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageArrayIndex, + ) + .with_span_handle(coordinate, context.expressions)); + } + if let Some(expr) = array_index { + match *context.resolve_type(expr, &self.valid_expression_set)? { + Ti::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + width: _, + }) => {} + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageArrayIndexType(expr), + ) + .with_span_handle(expr, context.expressions)); + } + } + } + match class { + crate::ImageClass::Storage { format, .. } => { + crate::TypeInner::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar { + kind: format.into(), + width: 4, + }, + } + } + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageClass(class), + ) + .with_span_handle(image, context.expressions)); + } + } + } + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::ExpectedImageType(var.ty), + ) + .with_span() + .with_handle(var.ty, context.types) + .with_handle(image, context.expressions)) + } + }; + + if *context.resolve_type(value, &self.valid_expression_set)? != value_ty { + return Err(FunctionError::InvalidStoreValue(value) + .with_span_handle(value, context.expressions)); + } + } + S::Call { + function, + ref arguments, + result, + } => match self.validate_call(function, arguments, result, context) { + Ok(callee_stages) => stages &= callee_stages, + Err(error) => { + return Err(error.and_then(|error| { + FunctionError::InvalidCall { function, error } + .with_span_static(span, "invalid function call") + })) + } + }, + S::Atomic { + pointer, + ref fun, + value, + result, + } => { + self.validate_atomic(pointer, fun, value, result, context)?; + } + S::WorkGroupUniformLoad { pointer, result } => { + stages &= super::ShaderStages::COMPUTE; + let pointer_inner = + context.resolve_type(pointer, &self.valid_expression_set)?; + match *pointer_inner { + Ti::Pointer { + space: AddressSpace::WorkGroup, + .. + } => {} + Ti::ValuePointer { + space: AddressSpace::WorkGroup, + .. + } => {} + _ => { + return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer) + .with_span_static(span, "WorkGroupUniformLoad")) + } + } + self.emit_expression(result, context)?; + let ty = match &context.expressions[result] { + &crate::Expression::WorkGroupUniformLoadResult { ty } => ty, + _ => { + return Err(FunctionError::WorkgroupUniformLoadExpressionMismatch( + result, + ) + .with_span_static(span, "WorkGroupUniformLoad")); + } + }; + let expected_pointer_inner = Ti::Pointer { + base: ty, + space: AddressSpace::WorkGroup, + }; + if !expected_pointer_inner.equivalent(pointer_inner, context.types) { + return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer) + .with_span_static(span, "WorkGroupUniformLoad")); + } + } + S::RayQuery { query, ref fun } => { + let query_var = match *context.get_expression(query) { + crate::Expression::LocalVariable(var) => &context.local_vars[var], + ref other => { + log::error!("Unexpected ray query expression {other:?}"); + return Err(FunctionError::InvalidRayQueryExpression(query) + .with_span_static(span, "invalid query expression")); + } + }; + match context.types[query_var.ty].inner { + Ti::RayQuery => {} + ref other => { + log::error!("Unexpected ray query type {other:?}"); + return Err(FunctionError::InvalidRayQueryType(query_var.ty) + .with_span_static(span, "invalid query type")); + } + } + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + match *context + .resolve_type(acceleration_structure, &self.valid_expression_set)? + { + Ti::AccelerationStructure => {} + _ => { + return Err(FunctionError::InvalidAccelerationStructure( + acceleration_structure, + ) + .with_span_static(span, "invalid acceleration structure")) + } + } + let desc_ty_given = + context.resolve_type(descriptor, &self.valid_expression_set)?; + let desc_ty_expected = context + .special_types + .ray_desc + .map(|handle| &context.types[handle].inner); + if Some(desc_ty_given) != desc_ty_expected { + return Err(FunctionError::InvalidRayDescriptor(descriptor) + .with_span_static(span, "invalid ray descriptor")); + } + } + crate::RayQueryFunction::Proceed { result } => { + self.emit_expression(result, context)?; + } + crate::RayQueryFunction::Terminate => {} + } + } + } + } + Ok(BlockInfo { stages, finished }) + } + + fn validate_block( + &mut self, + statements: &crate::Block, + context: &BlockContext, + ) -> Result<BlockInfo, WithSpan<FunctionError>> { + let base_expression_count = self.valid_expression_list.len(); + let info = self.validate_block_impl(statements, context)?; + for handle in self.valid_expression_list.drain(base_expression_count..) { + self.valid_expression_set.remove(handle.index()); + } + Ok(info) + } + + fn validate_local_var( + &self, + var: &crate::LocalVariable, + gctx: crate::proc::GlobalCtx, + fun_info: &FunctionInfo, + expression_constness: &crate::proc::ExpressionConstnessTracker, + ) -> Result<(), LocalVariableError> { + log::debug!("var {:?}", var); + let type_info = self + .types + .get(var.ty.index()) + .ok_or(LocalVariableError::InvalidType(var.ty))?; + if !type_info.flags.contains(super::TypeFlags::CONSTRUCTIBLE) { + return Err(LocalVariableError::InvalidType(var.ty)); + } + + if let Some(init) = var.init { + let decl_ty = &gctx.types[var.ty].inner; + let init_ty = fun_info[init].ty.inner_with(gctx.types); + if !decl_ty.equivalent(init_ty, gctx.types) { + return Err(LocalVariableError::InitializerType); + } + + if !expression_constness.is_const(init) { + return Err(LocalVariableError::NonConstInitializer); + } + } + + Ok(()) + } + + pub(super) fn validate_function( + &mut self, + fun: &crate::Function, + module: &crate::Module, + mod_info: &ModuleInfo, + entry_point: bool, + ) -> Result<FunctionInfo, WithSpan<FunctionError>> { + let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?; + + let expression_constness = + crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions); + + for (var_handle, var) in fun.local_variables.iter() { + self.validate_local_var(var, module.to_ctx(), &info, &expression_constness) + .map_err(|source| { + FunctionError::LocalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(var.ty, &module.types) + .with_handle(var_handle, &fun.local_variables) + })?; + } + + for (index, argument) in fun.arguments.iter().enumerate() { + match module.types[argument.ty].inner.pointer_space() { + Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {} + Some(other) => { + return Err(FunctionError::InvalidArgumentPointerSpace { + index, + name: argument.name.clone().unwrap_or_default(), + space: other, + } + .with_span_handle(argument.ty, &module.types)) + } + } + // Check for the least informative error last. + if !self.types[argument.ty.index()] + .flags + .contains(super::TypeFlags::ARGUMENT) + { + return Err(FunctionError::InvalidArgumentType { + index, + name: argument.name.clone().unwrap_or_default(), + } + .with_span_handle(argument.ty, &module.types)); + } + + if !entry_point && argument.binding.is_some() { + return Err(FunctionError::PipelineInputRegularFunction { + name: argument.name.clone().unwrap_or_default(), + } + .with_span_handle(argument.ty, &module.types)); + } + } + + if let Some(ref result) = fun.result { + if !self.types[result.ty.index()] + .flags + .contains(super::TypeFlags::CONSTRUCTIBLE) + { + return Err(FunctionError::NonConstructibleReturnType + .with_span_handle(result.ty, &module.types)); + } + + if !entry_point && result.binding.is_some() { + return Err(FunctionError::PipelineOutputRegularFunction + .with_span_handle(result.ty, &module.types)); + } + } + + self.valid_expression_set.clear(); + self.valid_expression_list.clear(); + for (handle, expr) in fun.expressions.iter() { + if expr.needs_pre_emit() { + self.valid_expression_set.insert(handle.index()); + } + if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { + match self.validate_expression(handle, expr, fun, module, &info, mod_info) { + Ok(stages) => info.available_stages &= stages, + Err(source) => { + return Err(FunctionError::Expression { handle, source } + .with_span_handle(handle, &fun.expressions)) + } + } + } + } + + if self.flags.contains(super::ValidationFlags::BLOCKS) { + let stages = self + .validate_block( + &fun.body, + &BlockContext::new(fun, module, &info, &mod_info.functions), + )? + .stages; + info.available_stages &= stages; + } + Ok(info) + } +} diff --git a/third_party/rust/naga/src/valid/handles.rs b/third_party/rust/naga/src/valid/handles.rs new file mode 100644 index 0000000000..e482f293bb --- /dev/null +++ b/third_party/rust/naga/src/valid/handles.rs @@ -0,0 +1,699 @@ +//! Implementation of `Validator::validate_module_handles`. + +use crate::{ + arena::{BadHandle, BadRangeError}, + Handle, +}; + +use crate::{Arena, UniqueArena}; + +use super::ValidationError; + +use std::{convert::TryInto, hash::Hash, num::NonZeroU32}; + +impl super::Validator { + /// Validates that all handles within `module` are: + /// + /// * Valid, in the sense that they contain indices within each arena structure inside the + /// [`crate::Module`] type. + /// * No arena contents contain any items that have forward dependencies; that is, the value + /// associated with a handle only may contain references to handles in the same arena that + /// were constructed before it. + /// + /// By validating the above conditions, we free up subsequent logic to assume that handle + /// accesses are infallible. + /// + /// # Errors + /// + /// Errors returned by this method are intentionally sparse, for simplicity of implementation. + /// It is expected that only buggy frontends or fuzzers should ever emit IR that fails this + /// validation pass. + pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> { + let &crate::Module { + ref constants, + ref entry_points, + ref functions, + ref global_variables, + ref types, + ref special_types, + ref const_expressions, + } = module; + + // NOTE: Types being first is important. All other forms of validation depend on this. + for (this_handle, ty) in types.iter() { + match ty.inner { + crate::TypeInner::Scalar { .. } + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } + | crate::TypeInner::ValuePointer { .. } + | crate::TypeInner::Atomic { .. } + | crate::TypeInner::Image { .. } + | crate::TypeInner::Sampler { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => (), + crate::TypeInner::Pointer { base, space: _ } => { + this_handle.check_dep(base)?; + } + crate::TypeInner::Array { base, .. } + | crate::TypeInner::BindingArray { base, .. } => { + this_handle.check_dep(base)?; + } + crate::TypeInner::Struct { + ref members, + span: _, + } => { + this_handle.check_dep_iter(members.iter().map(|m| m.ty))?; + } + } + } + + for handle_and_expr in const_expressions.iter() { + Self::validate_const_expression_handles(handle_and_expr, constants, types)?; + } + + let validate_type = |handle| Self::validate_type_handle(handle, types); + let validate_const_expr = + |handle| Self::validate_expression_handle(handle, const_expressions); + + for (_handle, constant) in constants.iter() { + let &crate::Constant { + name: _, + r#override: _, + ty, + init, + } = constant; + validate_type(ty)?; + validate_const_expr(init)?; + } + + for (_handle, global_variable) in global_variables.iter() { + let &crate::GlobalVariable { + name: _, + space: _, + binding: _, + ty, + init, + } = global_variable; + validate_type(ty)?; + if let Some(init_expr) = init { + validate_const_expr(init_expr)?; + } + } + + let validate_function = |function_handle, function: &_| -> Result<_, InvalidHandleError> { + let &crate::Function { + name: _, + ref arguments, + ref result, + ref local_variables, + ref expressions, + ref named_expressions, + ref body, + } = function; + + for arg in arguments.iter() { + let &crate::FunctionArgument { + name: _, + ty, + binding: _, + } = arg; + validate_type(ty)?; + } + + if let &Some(crate::FunctionResult { ty, binding: _ }) = result { + validate_type(ty)?; + } + + for (_handle, local_variable) in local_variables.iter() { + let &crate::LocalVariable { name: _, ty, init } = local_variable; + validate_type(ty)?; + if let Some(init) = init { + Self::validate_expression_handle(init, expressions)?; + } + } + + for handle in named_expressions.keys().copied() { + Self::validate_expression_handle(handle, expressions)?; + } + + for handle_and_expr in expressions.iter() { + Self::validate_expression_handles( + handle_and_expr, + constants, + const_expressions, + types, + local_variables, + global_variables, + functions, + function_handle, + )?; + } + + Self::validate_block_handles(body, expressions, functions)?; + + Ok(()) + }; + + for entry_point in entry_points.iter() { + validate_function(None, &entry_point.function)?; + } + + for (function_handle, function) in functions.iter() { + validate_function(Some(function_handle), function)?; + } + + if let Some(ty) = special_types.ray_desc { + validate_type(ty)?; + } + if let Some(ty) = special_types.ray_intersection { + validate_type(ty)?; + } + + Ok(()) + } + + fn validate_type_handle( + handle: Handle<crate::Type>, + types: &UniqueArena<crate::Type>, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for_uniq(types).map(|_| ()) + } + + fn validate_constant_handle( + handle: Handle<crate::Constant>, + constants: &Arena<crate::Constant>, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(constants).map(|_| ()) + } + + fn validate_expression_handle( + handle: Handle<crate::Expression>, + expressions: &Arena<crate::Expression>, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(expressions).map(|_| ()) + } + + fn validate_function_handle( + handle: Handle<crate::Function>, + functions: &Arena<crate::Function>, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(functions).map(|_| ()) + } + + fn validate_const_expression_handles( + (handle, expression): (Handle<crate::Expression>, &crate::Expression), + constants: &Arena<crate::Constant>, + types: &UniqueArena<crate::Type>, + ) -> Result<(), InvalidHandleError> { + let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_type = |handle| Self::validate_type_handle(handle, types); + + match *expression { + crate::Expression::Literal(_) => {} + crate::Expression::Constant(constant) => { + validate_constant(constant)?; + handle.check_dep(constants[constant].init)?; + } + crate::Expression::ZeroValue(ty) => { + validate_type(ty)?; + } + crate::Expression::Compose { ty, ref components } => { + validate_type(ty)?; + handle.check_dep_iter(components.iter().copied())?; + } + _ => {} + } + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + fn validate_expression_handles( + (handle, expression): (Handle<crate::Expression>, &crate::Expression), + constants: &Arena<crate::Constant>, + const_expressions: &Arena<crate::Expression>, + types: &UniqueArena<crate::Type>, + local_variables: &Arena<crate::LocalVariable>, + global_variables: &Arena<crate::GlobalVariable>, + functions: &Arena<crate::Function>, + // The handle of the current function or `None` if it's an entry point + current_function: Option<Handle<crate::Function>>, + ) -> Result<(), InvalidHandleError> { + let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_const_expr = + |handle| Self::validate_expression_handle(handle, const_expressions); + let validate_type = |handle| Self::validate_type_handle(handle, types); + + match *expression { + crate::Expression::Access { base, index } => { + handle.check_dep(base)?.check_dep(index)?; + } + crate::Expression::AccessIndex { base, .. } => { + handle.check_dep(base)?; + } + crate::Expression::Splat { value, .. } => { + handle.check_dep(value)?; + } + crate::Expression::Swizzle { vector, .. } => { + handle.check_dep(vector)?; + } + crate::Expression::Literal(_) => {} + crate::Expression::Constant(constant) => { + validate_constant(constant)?; + } + crate::Expression::ZeroValue(ty) => { + validate_type(ty)?; + } + crate::Expression::Compose { ty, ref components } => { + validate_type(ty)?; + handle.check_dep_iter(components.iter().copied())?; + } + crate::Expression::FunctionArgument(_arg_idx) => (), + crate::Expression::GlobalVariable(global_variable) => { + global_variable.check_valid_for(global_variables)?; + } + crate::Expression::LocalVariable(local_variable) => { + local_variable.check_valid_for(local_variables)?; + } + crate::Expression::Load { pointer } => { + handle.check_dep(pointer)?; + } + crate::Expression::ImageSample { + image, + sampler, + gather: _, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + if let Some(offset) = offset { + validate_const_expr(offset)?; + } + + handle + .check_dep(image)? + .check_dep(sampler)? + .check_dep(coordinate)? + .check_dep_opt(array_index)?; + + match level { + crate::SampleLevel::Auto | crate::SampleLevel::Zero => (), + crate::SampleLevel::Exact(expr) => { + handle.check_dep(expr)?; + } + crate::SampleLevel::Bias(expr) => { + handle.check_dep(expr)?; + } + crate::SampleLevel::Gradient { x, y } => { + handle.check_dep(x)?.check_dep(y)?; + } + }; + + handle.check_dep_opt(depth_ref)?; + } + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + handle + .check_dep(image)? + .check_dep(coordinate)? + .check_dep_opt(array_index)? + .check_dep_opt(sample)? + .check_dep_opt(level)?; + } + crate::Expression::ImageQuery { image, query } => { + handle.check_dep(image)?; + match query { + crate::ImageQuery::Size { level } => { + handle.check_dep_opt(level)?; + } + crate::ImageQuery::NumLevels + | crate::ImageQuery::NumLayers + | crate::ImageQuery::NumSamples => (), + }; + } + crate::Expression::Unary { + op: _, + expr: operand, + } => { + handle.check_dep(operand)?; + } + crate::Expression::Binary { op: _, left, right } => { + handle.check_dep(left)?.check_dep(right)?; + } + crate::Expression::Select { + condition, + accept, + reject, + } => { + handle + .check_dep(condition)? + .check_dep(accept)? + .check_dep(reject)?; + } + crate::Expression::Derivative { expr: argument, .. } => { + handle.check_dep(argument)?; + } + crate::Expression::Relational { fun: _, argument } => { + handle.check_dep(argument)?; + } + crate::Expression::Math { + fun: _, + arg, + arg1, + arg2, + arg3, + } => { + handle + .check_dep(arg)? + .check_dep_opt(arg1)? + .check_dep_opt(arg2)? + .check_dep_opt(arg3)?; + } + crate::Expression::As { + expr: input, + kind: _, + convert: _, + } => { + handle.check_dep(input)?; + } + crate::Expression::CallResult(function) => { + Self::validate_function_handle(function, functions)?; + if let Some(handle) = current_function { + handle.check_dep(function)?; + } + } + crate::Expression::AtomicResult { .. } + | crate::Expression::RayQueryProceedResult + | crate::Expression::WorkGroupUniformLoadResult { .. } => (), + crate::Expression::ArrayLength(array) => { + handle.check_dep(array)?; + } + crate::Expression::RayQueryGetIntersection { + query, + committed: _, + } => { + handle.check_dep(query)?; + } + } + Ok(()) + } + + fn validate_block_handles( + block: &crate::Block, + expressions: &Arena<crate::Expression>, + functions: &Arena<crate::Function>, + ) -> Result<(), InvalidHandleError> { + let validate_block = |block| Self::validate_block_handles(block, expressions, functions); + let validate_expr = |handle| Self::validate_expression_handle(handle, expressions); + let validate_expr_opt = |handle_opt| { + if let Some(handle) = handle_opt { + validate_expr(handle)?; + } + Ok(()) + }; + + block.iter().try_for_each(|stmt| match *stmt { + crate::Statement::Emit(ref expr_range) => { + expr_range.check_valid_for(expressions)?; + Ok(()) + } + crate::Statement::Block(ref block) => { + validate_block(block)?; + Ok(()) + } + crate::Statement::If { + condition, + ref accept, + ref reject, + } => { + validate_expr(condition)?; + validate_block(accept)?; + validate_block(reject)?; + Ok(()) + } + crate::Statement::Switch { + selector, + ref cases, + } => { + validate_expr(selector)?; + for &crate::SwitchCase { + value: _, + ref body, + fall_through: _, + } in cases + { + validate_block(body)?; + } + Ok(()) + } + crate::Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + validate_block(body)?; + validate_block(continuing)?; + validate_expr_opt(break_if)?; + Ok(()) + } + crate::Statement::Return { value } => validate_expr_opt(value), + crate::Statement::Store { pointer, value } => { + validate_expr(pointer)?; + validate_expr(value)?; + Ok(()) + } + crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + validate_expr(image)?; + validate_expr(coordinate)?; + validate_expr_opt(array_index)?; + validate_expr(value)?; + Ok(()) + } + crate::Statement::Atomic { + pointer, + fun, + value, + result, + } => { + validate_expr(pointer)?; + match fun { + crate::AtomicFunction::Add + | crate::AtomicFunction::Subtract + | crate::AtomicFunction::And + | crate::AtomicFunction::ExclusiveOr + | crate::AtomicFunction::InclusiveOr + | crate::AtomicFunction::Min + | crate::AtomicFunction::Max => (), + crate::AtomicFunction::Exchange { compare } => validate_expr_opt(compare)?, + }; + validate_expr(value)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::WorkGroupUniformLoad { pointer, result } => { + validate_expr(pointer)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::Call { + function, + ref arguments, + result, + } => { + Self::validate_function_handle(function, functions)?; + for arg in arguments.iter().copied() { + validate_expr(arg)?; + } + validate_expr_opt(result)?; + Ok(()) + } + crate::Statement::RayQuery { query, ref fun } => { + validate_expr(query)?; + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + validate_expr(acceleration_structure)?; + validate_expr(descriptor)?; + } + crate::RayQueryFunction::Proceed { result } => { + validate_expr(result)?; + } + crate::RayQueryFunction::Terminate => {} + } + Ok(()) + } + crate::Statement::Break + | crate::Statement::Continue + | crate::Statement::Kill + | crate::Statement::Barrier(_) => Ok(()), + }) + } +} + +impl From<BadHandle> for ValidationError { + fn from(source: BadHandle) -> Self { + Self::InvalidHandle(source.into()) + } +} + +impl From<FwdDepError> for ValidationError { + fn from(source: FwdDepError) -> Self { + Self::InvalidHandle(source.into()) + } +} + +impl From<BadRangeError> for ValidationError { + fn from(source: BadRangeError) -> Self { + Self::InvalidHandle(source.into()) + } +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum InvalidHandleError { + #[error(transparent)] + BadHandle(#[from] BadHandle), + #[error(transparent)] + ForwardDependency(#[from] FwdDepError), + #[error(transparent)] + BadRange(#[from] BadRangeError), +} + +#[derive(Clone, Debug, thiserror::Error)] +#[error( + "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \ + which has not been processed yet" +)] +pub struct FwdDepError { + // This error is used for many `Handle` types, but there's no point in making this generic, so + // we just flatten them all to `Handle<()>` here. + subject: Handle<()>, + subject_kind: &'static str, + depends_on: Handle<()>, + depends_on_kind: &'static str, +} + +impl<T> Handle<T> { + /// Check that `self` is valid within `arena` using [`Arena::check_contains_handle`]. + pub(self) fn check_valid_for(self, arena: &Arena<T>) -> Result<(), InvalidHandleError> { + arena.check_contains_handle(self)?; + Ok(()) + } + + /// Check that `self` is valid within `arena` using [`UniqueArena::check_contains_handle`]. + pub(self) fn check_valid_for_uniq( + self, + arena: &UniqueArena<T>, + ) -> Result<(), InvalidHandleError> + where + T: Eq + Hash, + { + arena.check_contains_handle(self)?; + Ok(()) + } + + /// Check that `depends_on` was constructed before `self` by comparing handle indices. + /// + /// If `self` is a valid handle (i.e., it has been validated using [`Self::check_valid_for`]) + /// and this function returns [`Ok`], then it may be assumed that `depends_on` is also valid. + /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating + /// recursive definitions of arena-based values in linear time. + /// + /// # Errors + /// + /// If `depends_on`'s handle is from the same [`Arena`] as `self'`s, but not constructed earlier + /// than `self`'s, this function returns an error. + pub(self) fn check_dep(self, depends_on: Self) -> Result<Self, FwdDepError> { + if depends_on < self { + Ok(self) + } else { + let erase_handle_type = |handle: Handle<_>| { + Handle::new(NonZeroU32::new((handle.index() + 1).try_into().unwrap()).unwrap()) + }; + Err(FwdDepError { + subject: erase_handle_type(self), + subject_kind: std::any::type_name::<T>(), + depends_on: erase_handle_type(depends_on), + depends_on_kind: std::any::type_name::<T>(), + }) + } + } + + /// Like [`Self::check_dep`], except for [`Option`]al handle values. + pub(self) fn check_dep_opt(self, depends_on: Option<Self>) -> Result<Self, FwdDepError> { + self.check_dep_iter(depends_on.into_iter()) + } + + /// Like [`Self::check_dep`], except for [`Iterator`]s over handle values. + pub(self) fn check_dep_iter( + self, + depends_on: impl Iterator<Item = Self>, + ) -> Result<Self, FwdDepError> { + for handle in depends_on { + self.check_dep(handle)?; + } + Ok(self) + } +} + +impl<T> crate::arena::Range<T> { + pub(self) fn check_valid_for(&self, arena: &Arena<T>) -> Result<(), BadRangeError> { + arena.check_contains_range(self) + } +} + +#[test] +fn constant_deps() { + use crate::{Constant, Expression, Literal, Span, Type, TypeInner}; + + let nowhere = Span::default(); + + let mut types = UniqueArena::new(); + let mut const_exprs = Arena::new(); + let mut fun_exprs = Arena::new(); + let mut constants = Arena::new(); + + let i32_handle = types.insert( + Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::I32), + }, + nowhere, + ); + + // Construct a self-referential constant by misusing a handle to + // fun_exprs as a constant initializer. + let fun_expr = fun_exprs.append(Expression::Literal(Literal::I32(42)), nowhere); + let self_referential_const = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: i32_handle, + init: fun_expr, + }, + nowhere, + ); + let _self_referential_expr = + const_exprs.append(Expression::Constant(self_referential_const), nowhere); + + for handle_and_expr in const_exprs.iter() { + assert!(super::Validator::validate_const_expression_handles( + handle_and_expr, + &constants, + &types, + ) + .is_err()); + } +} diff --git a/third_party/rust/naga/src/valid/interface.rs b/third_party/rust/naga/src/valid/interface.rs new file mode 100644 index 0000000000..84c8b09ddb --- /dev/null +++ b/third_party/rust/naga/src/valid/interface.rs @@ -0,0 +1,709 @@ +use super::{ + analyzer::{FunctionInfo, GlobalUse}, + Capabilities, Disalignment, FunctionError, ModuleInfo, +}; +use crate::arena::{Handle, UniqueArena}; + +use crate::span::{AddSpan as _, MapErrWithSpan as _, SpanProvider as _, WithSpan}; +use bit_set::BitSet; + +const MAX_WORKGROUP_SIZE: u32 = 0x4000; + +#[derive(Clone, Debug, thiserror::Error)] +pub enum GlobalVariableError { + #[error("Usage isn't compatible with address space {0:?}")] + InvalidUsage(crate::AddressSpace), + #[error("Type isn't compatible with address space {0:?}")] + InvalidType(crate::AddressSpace), + #[error("Type flags {seen:?} do not meet the required {required:?}")] + MissingTypeFlags { + required: super::TypeFlags, + seen: super::TypeFlags, + }, + #[error("Capability {0:?} is not supported")] + UnsupportedCapability(Capabilities), + #[error("Binding decoration is missing or not applicable")] + InvalidBinding, + #[error("Alignment requirements for address space {0:?} are not met by {1:?}")] + Alignment( + crate::AddressSpace, + Handle<crate::Type>, + #[source] Disalignment, + ), + #[error("Initializer doesn't match the variable type")] + InitializerType, + #[error("Initializer can't be used with address space {0:?}")] + InitializerNotAllowed(crate::AddressSpace), + #[error("Storage address space doesn't support write-only access")] + StorageAddressSpaceWriteOnlyNotSupported, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum VaryingError { + #[error("The type {0:?} does not match the varying")] + InvalidType(Handle<crate::Type>), + #[error("The type {0:?} cannot be used for user-defined entry point inputs or outputs")] + NotIOShareableType(Handle<crate::Type>), + #[error("Interpolation is not valid")] + InvalidInterpolation, + #[error("Interpolation must be specified on vertex shader outputs and fragment shader inputs")] + MissingInterpolation, + #[error("Built-in {0:?} is not available at this stage")] + InvalidBuiltInStage(crate::BuiltIn), + #[error("Built-in type for {0:?} is invalid")] + InvalidBuiltInType(crate::BuiltIn), + #[error("Entry point arguments and return values must all have bindings")] + MissingBinding, + #[error("Struct member {0} is missing a binding")] + MemberMissingBinding(u32), + #[error("Multiple bindings at location {location} are present")] + BindingCollision { location: u32 }, + #[error("Built-in {0:?} is present more than once")] + DuplicateBuiltIn(crate::BuiltIn), + #[error("Capability {0:?} is not supported")] + UnsupportedCapability(Capabilities), + #[error("The attribute {0:?} is only valid as an output for stage {1:?}")] + InvalidInputAttributeInStage(&'static str, crate::ShaderStage), + #[error("The attribute {0:?} is not valid for stage {1:?}")] + InvalidAttributeInStage(&'static str, crate::ShaderStage), + #[error( + "The location index {location} cannot be used together with the attribute {attribute:?}" + )] + InvalidLocationAttributeCombination { + location: u32, + attribute: &'static str, + }, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum EntryPointError { + #[error("Multiple conflicting entry points")] + Conflict, + #[error("Vertex shaders must return a `@builtin(position)` output value")] + MissingVertexOutputPosition, + #[error("Early depth test is not applicable")] + UnexpectedEarlyDepthTest, + #[error("Workgroup size is not applicable")] + UnexpectedWorkgroupSize, + #[error("Workgroup size is out of range")] + OutOfRangeWorkgroupSize, + #[error("Uses operations forbidden at this stage")] + ForbiddenStageOperations, + #[error("Global variable {0:?} is used incorrectly as {1:?}")] + InvalidGlobalUsage(Handle<crate::GlobalVariable>, GlobalUse), + #[error("More than 1 push constant variable is used")] + MoreThanOnePushConstantUsed, + #[error("Bindings for {0:?} conflict with other resource")] + BindingCollision(Handle<crate::GlobalVariable>), + #[error("Argument {0} varying error")] + Argument(u32, #[source] VaryingError), + #[error(transparent)] + Result(#[from] VaryingError), + #[error("Location {location} interpolation of an integer has to be flat")] + InvalidIntegerInterpolation { location: u32 }, + #[error(transparent)] + Function(#[from] FunctionError), + #[error( + "Invalid locations {location_mask:?} are set while dual source blending. Only location 0 may be set." + )] + InvalidLocationsWhileDualSourceBlending { location_mask: BitSet }, +} + +fn storage_usage(access: crate::StorageAccess) -> GlobalUse { + let mut storage_usage = GlobalUse::QUERY; + if access.contains(crate::StorageAccess::LOAD) { + storage_usage |= GlobalUse::READ; + } + if access.contains(crate::StorageAccess::STORE) { + storage_usage |= GlobalUse::WRITE; + } + storage_usage +} + +struct VaryingContext<'a> { + stage: crate::ShaderStage, + output: bool, + second_blend_source: bool, + types: &'a UniqueArena<crate::Type>, + type_info: &'a Vec<super::r#type::TypeInfo>, + location_mask: &'a mut BitSet, + built_ins: &'a mut crate::FastHashSet<crate::BuiltIn>, + capabilities: Capabilities, + flags: super::ValidationFlags, +} + +impl VaryingContext<'_> { + fn validate_impl( + &mut self, + ty: Handle<crate::Type>, + binding: &crate::Binding, + ) -> Result<(), VaryingError> { + use crate::{BuiltIn as Bi, ShaderStage as St, TypeInner as Ti, VectorSize as Vs}; + + let ty_inner = &self.types[ty].inner; + match *binding { + crate::Binding::BuiltIn(built_in) => { + // Ignore the `invariant` field for the sake of duplicate checks, + // but use the original in error messages. + let canonical = if let crate::BuiltIn::Position { .. } = built_in { + crate::BuiltIn::Position { invariant: false } + } else { + built_in + }; + + if self.built_ins.contains(&canonical) { + return Err(VaryingError::DuplicateBuiltIn(built_in)); + } + self.built_ins.insert(canonical); + + let required = match built_in { + Bi::ClipDistance => Capabilities::CLIP_DISTANCE, + Bi::CullDistance => Capabilities::CULL_DISTANCE, + Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX, + Bi::ViewIndex => Capabilities::MULTIVIEW, + Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING, + _ => Capabilities::empty(), + }; + if !self.capabilities.contains(required) { + return Err(VaryingError::UnsupportedCapability(required)); + } + + let (visible, type_good) = match built_in { + Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => ( + self.stage == St::Vertex && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::ClipDistance | Bi::CullDistance => ( + self.stage == St::Vertex && self.output, + match *ty_inner { + Ti::Array { base, .. } => { + self.types[base].inner == Ti::Scalar(crate::Scalar::F32) + } + _ => false, + }, + ), + Bi::PointSize => ( + self.stage == St::Vertex && self.output, + *ty_inner == Ti::Scalar(crate::Scalar::F32), + ), + Bi::PointCoord => ( + self.stage == St::Fragment && !self.output, + *ty_inner + == Ti::Vector { + size: Vs::Bi, + scalar: crate::Scalar::F32, + }, + ), + Bi::Position { .. } => ( + match self.stage { + St::Vertex => self.output, + St::Fragment => !self.output, + St::Compute => false, + }, + *ty_inner + == Ti::Vector { + size: Vs::Quad, + scalar: crate::Scalar::F32, + }, + ), + Bi::ViewIndex => ( + match self.stage { + St::Vertex | St::Fragment => !self.output, + St::Compute => false, + }, + *ty_inner == Ti::Scalar(crate::Scalar::I32), + ), + Bi::FragDepth => ( + self.stage == St::Fragment && self.output, + *ty_inner == Ti::Scalar(crate::Scalar::F32), + ), + Bi::FrontFacing => ( + self.stage == St::Fragment && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::BOOL), + ), + Bi::PrimitiveIndex => ( + self.stage == St::Fragment && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::SampleIndex => ( + self.stage == St::Fragment && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::SampleMask => ( + self.stage == St::Fragment, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::LocalInvocationIndex => ( + self.stage == St::Compute && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::GlobalInvocationId + | Bi::LocalInvocationId + | Bi::WorkGroupId + | Bi::WorkGroupSize + | Bi::NumWorkGroups => ( + self.stage == St::Compute && !self.output, + *ty_inner + == Ti::Vector { + size: Vs::Tri, + scalar: crate::Scalar::U32, + }, + ), + }; + + if !visible { + return Err(VaryingError::InvalidBuiltInStage(built_in)); + } + if !type_good { + log::warn!("Wrong builtin type: {:?}", ty_inner); + return Err(VaryingError::InvalidBuiltInType(built_in)); + } + } + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source, + } => { + // Only IO-shareable types may be stored in locations. + if !self.type_info[ty.index()] + .flags + .contains(super::TypeFlags::IO_SHAREABLE) + { + return Err(VaryingError::NotIOShareableType(ty)); + } + + if second_blend_source { + if !self + .capabilities + .contains(Capabilities::DUAL_SOURCE_BLENDING) + { + return Err(VaryingError::UnsupportedCapability( + Capabilities::DUAL_SOURCE_BLENDING, + )); + } + if self.stage != crate::ShaderStage::Fragment { + return Err(VaryingError::InvalidAttributeInStage( + "second_blend_source", + self.stage, + )); + } + if !self.output { + return Err(VaryingError::InvalidInputAttributeInStage( + "second_blend_source", + self.stage, + )); + } + if location != 0 { + return Err(VaryingError::InvalidLocationAttributeCombination { + location, + attribute: "second_blend_source", + }); + } + + self.second_blend_source = true; + } else if !self.location_mask.insert(location as usize) { + if self.flags.contains(super::ValidationFlags::BINDINGS) { + return Err(VaryingError::BindingCollision { location }); + } + } + + let needs_interpolation = match self.stage { + crate::ShaderStage::Vertex => self.output, + crate::ShaderStage::Fragment => !self.output, + crate::ShaderStage::Compute => false, + }; + + // It doesn't make sense to specify a sampling when `interpolation` is `Flat`, but + // SPIR-V and GLSL both explicitly tolerate such combinations of decorators / + // qualifiers, so we won't complain about that here. + let _ = sampling; + + let required = match sampling { + Some(crate::Sampling::Sample) => Capabilities::MULTISAMPLED_SHADING, + _ => Capabilities::empty(), + }; + if !self.capabilities.contains(required) { + return Err(VaryingError::UnsupportedCapability(required)); + } + + match ty_inner.scalar_kind() { + Some(crate::ScalarKind::Float) => { + if needs_interpolation && interpolation.is_none() { + return Err(VaryingError::MissingInterpolation); + } + } + Some(_) => { + if needs_interpolation && interpolation != Some(crate::Interpolation::Flat) + { + return Err(VaryingError::InvalidInterpolation); + } + } + None => return Err(VaryingError::InvalidType(ty)), + } + } + } + + Ok(()) + } + + fn validate( + &mut self, + ty: Handle<crate::Type>, + binding: Option<&crate::Binding>, + ) -> Result<(), WithSpan<VaryingError>> { + let span_context = self.types.get_span_context(ty); + match binding { + Some(binding) => self + .validate_impl(ty, binding) + .map_err(|e| e.with_span_context(span_context)), + None => { + match self.types[ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + for (index, member) in members.iter().enumerate() { + let span_context = self.types.get_span_context(ty); + match member.binding { + None => { + if self.flags.contains(super::ValidationFlags::BINDINGS) { + return Err(VaryingError::MemberMissingBinding( + index as u32, + ) + .with_span_context(span_context)); + } + } + Some(ref binding) => self + .validate_impl(member.ty, binding) + .map_err(|e| e.with_span_context(span_context))?, + } + } + } + _ => { + if self.flags.contains(super::ValidationFlags::BINDINGS) { + return Err(VaryingError::MissingBinding.with_span()); + } + } + } + Ok(()) + } + } + } +} + +impl super::Validator { + pub(super) fn validate_global_var( + &self, + var: &crate::GlobalVariable, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), GlobalVariableError> { + use super::TypeFlags; + + log::debug!("var {:?}", var); + let inner_ty = match gctx.types[var.ty].inner { + // A binding array is (mostly) supposed to behave the same as a + // series of individually bound resources, so we can (mostly) + // validate a `binding_array<T>` as if it were just a plain `T`. + crate::TypeInner::BindingArray { base, .. } => match var.space { + crate::AddressSpace::Storage { .. } + | crate::AddressSpace::Uniform + | crate::AddressSpace::Handle => base, + _ => return Err(GlobalVariableError::InvalidUsage(var.space)), + }, + _ => var.ty, + }; + let type_info = &self.types[inner_ty.index()]; + + let (required_type_flags, is_resource) = match var.space { + crate::AddressSpace::Function => { + return Err(GlobalVariableError::InvalidUsage(var.space)) + } + crate::AddressSpace::Storage { access } => { + if let Err((ty_handle, disalignment)) = type_info.storage_layout { + if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) { + return Err(GlobalVariableError::Alignment( + var.space, + ty_handle, + disalignment, + )); + } + } + if access == crate::StorageAccess::STORE { + return Err(GlobalVariableError::StorageAddressSpaceWriteOnlyNotSupported); + } + (TypeFlags::DATA | TypeFlags::HOST_SHAREABLE, true) + } + crate::AddressSpace::Uniform => { + if let Err((ty_handle, disalignment)) = type_info.uniform_layout { + if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) { + return Err(GlobalVariableError::Alignment( + var.space, + ty_handle, + disalignment, + )); + } + } + ( + TypeFlags::DATA + | TypeFlags::COPY + | TypeFlags::SIZED + | TypeFlags::HOST_SHAREABLE, + true, + ) + } + crate::AddressSpace::Handle => { + match gctx.types[inner_ty].inner { + crate::TypeInner::Image { class, .. } => match class { + crate::ImageClass::Storage { + format: + crate::StorageFormat::R16Unorm + | crate::StorageFormat::R16Snorm + | crate::StorageFormat::Rg16Unorm + | crate::StorageFormat::Rg16Snorm + | crate::StorageFormat::Rgba16Unorm + | crate::StorageFormat::Rgba16Snorm, + .. + } => { + if !self + .capabilities + .contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS) + { + return Err(GlobalVariableError::UnsupportedCapability( + Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS, + )); + } + } + _ => {} + }, + crate::TypeInner::Sampler { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => {} + _ => { + return Err(GlobalVariableError::InvalidType(var.space)); + } + } + + (TypeFlags::empty(), true) + } + crate::AddressSpace::Private => (TypeFlags::CONSTRUCTIBLE, false), + crate::AddressSpace::WorkGroup => (TypeFlags::DATA | TypeFlags::SIZED, false), + crate::AddressSpace::PushConstant => { + if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) { + return Err(GlobalVariableError::UnsupportedCapability( + Capabilities::PUSH_CONSTANT, + )); + } + ( + TypeFlags::DATA + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::SIZED, + false, + ) + } + }; + + if !type_info.flags.contains(required_type_flags) { + return Err(GlobalVariableError::MissingTypeFlags { + seen: type_info.flags, + required: required_type_flags, + }); + } + + if is_resource != var.binding.is_some() { + if self.flags.contains(super::ValidationFlags::BINDINGS) { + return Err(GlobalVariableError::InvalidBinding); + } + } + + if let Some(init) = var.init { + match var.space { + crate::AddressSpace::Private | crate::AddressSpace::Function => {} + _ => { + return Err(GlobalVariableError::InitializerNotAllowed(var.space)); + } + } + + let decl_ty = &gctx.types[var.ty].inner; + let init_ty = mod_info[init].inner_with(gctx.types); + if !decl_ty.equivalent(init_ty, gctx.types) { + return Err(GlobalVariableError::InitializerType); + } + } + + Ok(()) + } + + pub(super) fn validate_entry_point( + &mut self, + ep: &crate::EntryPoint, + module: &crate::Module, + mod_info: &ModuleInfo, + ) -> Result<FunctionInfo, WithSpan<EntryPointError>> { + if ep.early_depth_test.is_some() { + let required = Capabilities::EARLY_DEPTH_TEST; + if !self.capabilities.contains(required) { + return Err( + EntryPointError::Result(VaryingError::UnsupportedCapability(required)) + .with_span(), + ); + } + + if ep.stage != crate::ShaderStage::Fragment { + return Err(EntryPointError::UnexpectedEarlyDepthTest.with_span()); + } + } + + if ep.stage == crate::ShaderStage::Compute { + if ep + .workgroup_size + .iter() + .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE) + { + return Err(EntryPointError::OutOfRangeWorkgroupSize.with_span()); + } + } else if ep.workgroup_size != [0; 3] { + return Err(EntryPointError::UnexpectedWorkgroupSize.with_span()); + } + + let mut info = self + .validate_function(&ep.function, module, mod_info, true) + .map_err(WithSpan::into_other)?; + + { + use super::ShaderStages; + + let stage_bit = match ep.stage { + crate::ShaderStage::Vertex => ShaderStages::VERTEX, + crate::ShaderStage::Fragment => ShaderStages::FRAGMENT, + crate::ShaderStage::Compute => ShaderStages::COMPUTE, + }; + + if !info.available_stages.contains(stage_bit) { + return Err(EntryPointError::ForbiddenStageOperations.with_span()); + } + } + + self.location_mask.clear(); + let mut argument_built_ins = crate::FastHashSet::default(); + // TODO: add span info to function arguments + for (index, fa) in ep.function.arguments.iter().enumerate() { + let mut ctx = VaryingContext { + stage: ep.stage, + output: false, + second_blend_source: false, + types: &module.types, + type_info: &self.types, + location_mask: &mut self.location_mask, + built_ins: &mut argument_built_ins, + capabilities: self.capabilities, + flags: self.flags, + }; + ctx.validate(fa.ty, fa.binding.as_ref()) + .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?; + } + + self.location_mask.clear(); + if let Some(ref fr) = ep.function.result { + let mut result_built_ins = crate::FastHashSet::default(); + let mut ctx = VaryingContext { + stage: ep.stage, + output: true, + second_blend_source: false, + types: &module.types, + type_info: &self.types, + location_mask: &mut self.location_mask, + built_ins: &mut result_built_ins, + capabilities: self.capabilities, + flags: self.flags, + }; + ctx.validate(fr.ty, fr.binding.as_ref()) + .map_err_inner(|e| EntryPointError::Result(e).with_span())?; + if ctx.second_blend_source { + // Only the first location may be used when dual source blending + if ctx.location_mask.len() == 1 && ctx.location_mask.contains(0) { + info.dual_source_blending = true; + } else { + return Err(EntryPointError::InvalidLocationsWhileDualSourceBlending { + location_mask: self.location_mask.clone(), + } + .with_span()); + } + } + + if ep.stage == crate::ShaderStage::Vertex + && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false }) + { + return Err(EntryPointError::MissingVertexOutputPosition.with_span()); + } + } else if ep.stage == crate::ShaderStage::Vertex { + return Err(EntryPointError::MissingVertexOutputPosition.with_span()); + } + + { + let used_push_constants = module + .global_variables + .iter() + .filter(|&(_, var)| var.space == crate::AddressSpace::PushConstant) + .map(|(handle, _)| handle) + .filter(|&handle| !info[handle].is_empty()); + // Check if there is more than one push constant, and error if so. + // Use a loop for when returning multiple errors is supported. + #[allow(clippy::never_loop)] + for handle in used_push_constants.skip(1) { + return Err(EntryPointError::MoreThanOnePushConstantUsed + .with_span_handle(handle, &module.global_variables)); + } + } + + self.ep_resource_bindings.clear(); + for (var_handle, var) in module.global_variables.iter() { + let usage = info[var_handle]; + if usage.is_empty() { + continue; + } + + let allowed_usage = match var.space { + crate::AddressSpace::Function => unreachable!(), + crate::AddressSpace::Uniform => GlobalUse::READ | GlobalUse::QUERY, + crate::AddressSpace::Storage { access } => storage_usage(access), + crate::AddressSpace::Handle => match module.types[var.ty].inner { + crate::TypeInner::BindingArray { base, .. } => match module.types[base].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. }, + .. + } => storage_usage(access), + _ => GlobalUse::READ | GlobalUse::QUERY, + }, + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. }, + .. + } => storage_usage(access), + _ => GlobalUse::READ | GlobalUse::QUERY, + }, + crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => GlobalUse::all(), + crate::AddressSpace::PushConstant => GlobalUse::READ, + }; + if !allowed_usage.contains(usage) { + log::warn!("\tUsage error for: {:?}", var); + log::warn!( + "\tAllowed usage: {:?}, requested: {:?}", + allowed_usage, + usage + ); + return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage) + .with_span_handle(var_handle, &module.global_variables)); + } + + if let Some(ref bind) = var.binding { + if !self.ep_resource_bindings.insert(bind.clone()) { + if self.flags.contains(super::ValidationFlags::BINDINGS) { + return Err(EntryPointError::BindingCollision(var_handle) + .with_span_handle(var_handle, &module.global_variables)); + } + } + } + } + + Ok(info) + } +} diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs new file mode 100644 index 0000000000..388495a3ac --- /dev/null +++ b/third_party/rust/naga/src/valid/mod.rs @@ -0,0 +1,477 @@ +/*! +Shader validator. +*/ + +mod analyzer; +mod compose; +mod expression; +mod function; +mod handles; +mod interface; +mod r#type; + +use crate::{ + arena::Handle, + proc::{LayoutError, Layouter, TypeResolution}, + FastHashSet, +}; +use bit_set::BitSet; +use std::ops; + +//TODO: analyze the model at the same time as we validate it, +// merge the corresponding matches over expressions and statements. + +use crate::span::{AddSpan as _, WithSpan}; +pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements}; +pub use compose::ComposeError; +pub use expression::{check_literal_value, LiteralError}; +pub use expression::{ConstExpressionError, ExpressionError}; +pub use function::{CallError, FunctionError, LocalVariableError}; +pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; +pub use r#type::{Disalignment, TypeError, TypeFlags}; + +use self::handles::InvalidHandleError; + +bitflags::bitflags! { + /// Validation flags. + /// + /// If you are working with trusted shaders, then you may be able + /// to save some time by skipping validation. + /// + /// If you do not perform full validation, invalid shaders may + /// cause Naga to panic. If you do perform full validation and + /// [`Validator::validate`] returns `Ok`, then Naga promises that + /// code generation will either succeed or return an error; it + /// should never panic. + /// + /// The default value for `ValidationFlags` is + /// `ValidationFlags::all()`. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct ValidationFlags: u8 { + /// Expressions. + const EXPRESSIONS = 0x1; + /// Statements and blocks of them. + const BLOCKS = 0x2; + /// Uniformity of control flow for operations that require it. + const CONTROL_FLOW_UNIFORMITY = 0x4; + /// Host-shareable structure layouts. + const STRUCT_LAYOUTS = 0x8; + /// Constants. + const CONSTANTS = 0x10; + /// Group, binding, and location attributes. + const BINDINGS = 0x20; + } +} + +impl Default for ValidationFlags { + fn default() -> Self { + Self::all() + } +} + +bitflags::bitflags! { + /// Allowed IR capabilities. + #[must_use] + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct Capabilities: u16 { + /// Support for [`AddressSpace:PushConstant`]. + const PUSH_CONSTANT = 0x1; + /// Float values with width = 8. + const FLOAT64 = 0x2; + /// Support for [`Builtin:PrimitiveIndex`]. + const PRIMITIVE_INDEX = 0x4; + /// Support for non-uniform indexing of sampled textures and storage buffer arrays. + const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8; + /// Support for non-uniform indexing of uniform buffers and storage texture arrays. + const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10; + /// Support for non-uniform indexing of samplers. + const SAMPLER_NON_UNIFORM_INDEXING = 0x20; + /// Support for [`Builtin::ClipDistance`]. + const CLIP_DISTANCE = 0x40; + /// Support for [`Builtin::CullDistance`]. + const CULL_DISTANCE = 0x80; + /// Support for 16-bit normalized storage texture formats. + const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100; + /// Support for [`BuiltIn::ViewIndex`]. + const MULTIVIEW = 0x200; + /// Support for `early_depth_test`. + const EARLY_DEPTH_TEST = 0x400; + /// Support for [`Builtin::SampleIndex`] and [`Sampling::Sample`]. + const MULTISAMPLED_SHADING = 0x800; + /// Support for ray queries and acceleration structures. + const RAY_QUERY = 0x1000; + /// Support for generating two sources for blending from fragment shaders. + const DUAL_SOURCE_BLENDING = 0x2000; + /// Support for arrayed cube textures. + const CUBE_ARRAY_TEXTURES = 0x4000; + } +} + +impl Default for Capabilities { + fn default() -> Self { + Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES + } +} + +bitflags::bitflags! { + /// Validation flags. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct ShaderStages: u8 { + const VERTEX = 0x1; + const FRAGMENT = 0x2; + const COMPUTE = 0x4; + } +} + +#[derive(Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct ModuleInfo { + type_flags: Vec<TypeFlags>, + functions: Vec<FunctionInfo>, + entry_points: Vec<FunctionInfo>, + const_expression_types: Box<[TypeResolution]>, +} + +impl ops::Index<Handle<crate::Type>> for ModuleInfo { + type Output = TypeFlags; + fn index(&self, handle: Handle<crate::Type>) -> &Self::Output { + &self.type_flags[handle.index()] + } +} + +impl ops::Index<Handle<crate::Function>> for ModuleInfo { + type Output = FunctionInfo; + fn index(&self, handle: Handle<crate::Function>) -> &Self::Output { + &self.functions[handle.index()] + } +} + +impl ops::Index<Handle<crate::Expression>> for ModuleInfo { + type Output = TypeResolution; + fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output { + &self.const_expression_types[handle.index()] + } +} + +#[derive(Debug)] +pub struct Validator { + flags: ValidationFlags, + capabilities: Capabilities, + types: Vec<r#type::TypeInfo>, + layouter: Layouter, + location_mask: BitSet, + ep_resource_bindings: FastHashSet<crate::ResourceBinding>, + #[allow(dead_code)] + switch_values: FastHashSet<crate::SwitchValue>, + valid_expression_list: Vec<Handle<crate::Expression>>, + valid_expression_set: BitSet, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum ConstantError { + #[error("The type doesn't match the constant")] + InvalidType, + #[error("The type is not constructible")] + NonConstructibleType, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum ValidationError { + #[error(transparent)] + InvalidHandle(#[from] InvalidHandleError), + #[error(transparent)] + Layouter(#[from] LayoutError), + #[error("Type {handle:?} '{name}' is invalid")] + Type { + handle: Handle<crate::Type>, + name: String, + source: TypeError, + }, + #[error("Constant expression {handle:?} is invalid")] + ConstExpression { + handle: Handle<crate::Expression>, + source: ConstExpressionError, + }, + #[error("Constant {handle:?} '{name}' is invalid")] + Constant { + handle: Handle<crate::Constant>, + name: String, + source: ConstantError, + }, + #[error("Global variable {handle:?} '{name}' is invalid")] + GlobalVariable { + handle: Handle<crate::GlobalVariable>, + name: String, + source: GlobalVariableError, + }, + #[error("Function {handle:?} '{name}' is invalid")] + Function { + handle: Handle<crate::Function>, + name: String, + source: FunctionError, + }, + #[error("Entry point {name} at {stage:?} is invalid")] + EntryPoint { + stage: crate::ShaderStage, + name: String, + source: EntryPointError, + }, + #[error("Module is corrupted")] + Corrupted, +} + +impl crate::TypeInner { + const fn is_sized(&self) -> bool { + match *self { + Self::Scalar { .. } + | Self::Vector { .. } + | Self::Matrix { .. } + | Self::Array { + size: crate::ArraySize::Constant(_), + .. + } + | Self::Atomic { .. } + | Self::Pointer { .. } + | Self::ValuePointer { .. } + | Self::Struct { .. } => true, + Self::Array { .. } + | Self::Image { .. } + | Self::Sampler { .. } + | Self::AccelerationStructure + | Self::RayQuery + | Self::BindingArray { .. } => false, + } + } + + /// Return the `ImageDimension` for which `self` is an appropriate coordinate. + const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> { + match *self { + Self::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + .. + }) => Some(crate::ImageDimension::D1), + Self::Vector { + size: crate::VectorSize::Bi, + scalar: + crate::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + .. + }, + } => Some(crate::ImageDimension::D2), + Self::Vector { + size: crate::VectorSize::Tri, + scalar: + crate::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + .. + }, + } => Some(crate::ImageDimension::D3), + _ => None, + } + } +} + +impl Validator { + /// Construct a new validator instance. + pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self { + Validator { + flags, + capabilities, + types: Vec::new(), + layouter: Layouter::default(), + location_mask: BitSet::new(), + ep_resource_bindings: FastHashSet::default(), + switch_values: FastHashSet::default(), + valid_expression_list: Vec::new(), + valid_expression_set: BitSet::new(), + } + } + + /// Reset the validator internals + pub fn reset(&mut self) { + self.types.clear(); + self.layouter.clear(); + self.location_mask.clear(); + self.ep_resource_bindings.clear(); + self.switch_values.clear(); + self.valid_expression_list.clear(); + self.valid_expression_set.clear(); + } + + fn validate_constant( + &self, + handle: Handle<crate::Constant>, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), ConstantError> { + let con = &gctx.constants[handle]; + + let type_info = &self.types[con.ty.index()]; + if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { + return Err(ConstantError::NonConstructibleType); + } + + let decl_ty = &gctx.types[con.ty].inner; + let init_ty = mod_info[con.init].inner_with(gctx.types); + if !decl_ty.equivalent(init_ty, gctx.types) { + return Err(ConstantError::InvalidType); + } + + Ok(()) + } + + /// Check the given module to be valid. + pub fn validate( + &mut self, + module: &crate::Module, + ) -> Result<ModuleInfo, WithSpan<ValidationError>> { + self.reset(); + self.reset_types(module.types.len()); + + Self::validate_module_handles(module).map_err(|e| e.with_span())?; + + self.layouter.update(module.to_ctx()).map_err(|e| { + let handle = e.ty; + ValidationError::from(e).with_span_handle(handle, &module.types) + })?; + + // These should all get overwritten. + let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + width: 0, + })); + + let mut mod_info = ModuleInfo { + type_flags: Vec::with_capacity(module.types.len()), + functions: Vec::with_capacity(module.functions.len()), + entry_points: Vec::with_capacity(module.entry_points.len()), + const_expression_types: vec![placeholder; module.const_expressions.len()] + .into_boxed_slice(), + }; + + for (handle, ty) in module.types.iter() { + let ty_info = self + .validate_type(handle, module.to_ctx()) + .map_err(|source| { + ValidationError::Type { + handle, + name: ty.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.types) + })?; + mod_info.type_flags.push(ty_info.flags); + self.types[handle.index()] = ty_info; + } + + { + let t = crate::Arena::new(); + let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]); + for (handle, _) in module.const_expressions.iter() { + mod_info + .process_const_expression(handle, &resolve_context, module.to_ctx()) + .map_err(|source| { + ValidationError::ConstExpression { handle, source } + .with_span_handle(handle, &module.const_expressions) + })? + } + } + + if self.flags.contains(ValidationFlags::CONSTANTS) { + for (handle, _) in module.const_expressions.iter() { + self.validate_const_expression(handle, module.to_ctx(), &mod_info) + .map_err(|source| { + ValidationError::ConstExpression { handle, source } + .with_span_handle(handle, &module.const_expressions) + })? + } + + for (handle, constant) in module.constants.iter() { + self.validate_constant(handle, module.to_ctx(), &mod_info) + .map_err(|source| { + ValidationError::Constant { + handle, + name: constant.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.constants) + })? + } + } + + for (var_handle, var) in module.global_variables.iter() { + self.validate_global_var(var, module.to_ctx(), &mod_info) + .map_err(|source| { + ValidationError::GlobalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(var_handle, &module.global_variables) + })?; + } + + for (handle, fun) in module.functions.iter() { + match self.validate_function(fun, module, &mod_info, false) { + Ok(info) => mod_info.functions.push(info), + Err(error) => { + return Err(error.and_then(|source| { + ValidationError::Function { + handle, + name: fun.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.functions) + })) + } + } + } + + let mut ep_map = FastHashSet::default(); + for ep in module.entry_points.iter() { + if !ep_map.insert((ep.stage, &ep.name)) { + return Err(ValidationError::EntryPoint { + stage: ep.stage, + name: ep.name.clone(), + source: EntryPointError::Conflict, + } + .with_span()); // TODO: keep some EP span information? + } + + match self.validate_entry_point(ep, module, &mod_info) { + Ok(info) => mod_info.entry_points.push(info), + Err(error) => { + return Err(error.and_then(|source| { + ValidationError::EntryPoint { + stage: ep.stage, + name: ep.name.clone(), + source, + } + .with_span() + })); + } + } + } + + Ok(mod_info) + } +} + +fn validate_atomic_compare_exchange_struct( + types: &crate::UniqueArena<crate::Type>, + members: &[crate::StructMember], + scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool, +) -> bool { + members.len() == 2 + && members[0].name.as_deref() == Some("old_value") + && scalar_predicate(&types[members[0].ty].inner) + && members[1].name.as_deref() == Some("exchanged") + && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL) +} diff --git a/third_party/rust/naga/src/valid/type.rs b/third_party/rust/naga/src/valid/type.rs new file mode 100644 index 0000000000..1e3e03fe19 --- /dev/null +++ b/third_party/rust/naga/src/valid/type.rs @@ -0,0 +1,643 @@ +use super::Capabilities; +use crate::{arena::Handle, proc::Alignment}; + +bitflags::bitflags! { + /// Flags associated with [`Type`]s by [`Validator`]. + /// + /// [`Type`]: crate::Type + /// [`Validator`]: crate::valid::Validator + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[repr(transparent)] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct TypeFlags: u8 { + /// Can be used for data variables. + /// + /// This flag is required on types of local variables, function + /// arguments, array elements, and struct members. + /// + /// This includes all types except `Image`, `Sampler`, + /// and some `Pointer` types. + const DATA = 0x1; + + /// The data type has a size known by pipeline creation time. + /// + /// Unsized types are quite restricted. The only unsized types permitted + /// by Naga, other than the non-[`DATA`] types like [`Image`] and + /// [`Sampler`], are dynamically-sized [`Array`s], and [`Struct`s] whose + /// last members are such arrays. See the documentation for those types + /// for details. + /// + /// [`DATA`]: TypeFlags::DATA + /// [`Image`]: crate::Type::Image + /// [`Sampler`]: crate::Type::Sampler + /// [`Array`]: crate::Type::Array + /// [`Struct`]: crate::Type::struct + const SIZED = 0x2; + + /// The data can be copied around. + const COPY = 0x4; + + /// Can be be used for user-defined IO between pipeline stages. + /// + /// This covers anything that can be in [`Location`] binding: + /// non-bool scalars and vectors, matrices, and structs and + /// arrays containing only interface types. + const IO_SHAREABLE = 0x8; + + /// Can be used for host-shareable structures. + const HOST_SHAREABLE = 0x10; + + /// This type can be passed as a function argument. + const ARGUMENT = 0x40; + + /// A WGSL [constructible] type. + /// + /// The constructible types are scalars, vectors, matrices, fixed-size + /// arrays of constructible types, and structs whose members are all + /// constructible. + /// + /// [constructible]: https://gpuweb.github.io/gpuweb/wgsl/#constructible + const CONSTRUCTIBLE = 0x80; + } +} + +#[derive(Clone, Copy, Debug, thiserror::Error)] +pub enum Disalignment { + #[error("The array stride {stride} is not a multiple of the required alignment {alignment}")] + ArrayStride { stride: u32, alignment: Alignment }, + #[error("The struct span {span}, is not a multiple of the required alignment {alignment}")] + StructSpan { span: u32, alignment: Alignment }, + #[error("The struct member[{index}] offset {offset} is not a multiple of the required alignment {alignment}")] + MemberOffset { + index: u32, + offset: u32, + alignment: Alignment, + }, + #[error("The struct member[{index}] offset {offset} must be at least {expected}")] + MemberOffsetAfterStruct { + index: u32, + offset: u32, + expected: u32, + }, + #[error("The struct member[{index}] is not statically sized")] + UnsizedMember { index: u32 }, + #[error("The type is not host-shareable")] + NonHostShareable, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum TypeError { + #[error("Capability {0:?} is required")] + MissingCapability(Capabilities), + #[error("The {0:?} scalar width {1} is not supported for an atomic")] + InvalidAtomicWidth(crate::ScalarKind, crate::Bytes), + #[error("Invalid type for pointer target {0:?}")] + InvalidPointerBase(Handle<crate::Type>), + #[error("Unsized types like {base:?} must be in the `Storage` address space, not `{space:?}`")] + InvalidPointerToUnsized { + base: Handle<crate::Type>, + space: crate::AddressSpace, + }, + #[error("Expected data type, found {0:?}")] + InvalidData(Handle<crate::Type>), + #[error("Base type {0:?} for the array is invalid")] + InvalidArrayBaseType(Handle<crate::Type>), + #[error("Matrix elements must always be floating-point types")] + MatrixElementNotFloat, + #[error("The constant {0:?} is specialized, and cannot be used as an array size")] + UnsupportedSpecializedArrayLength(Handle<crate::Constant>), + #[error("Array stride {stride} does not match the expected {expected}")] + InvalidArrayStride { stride: u32, expected: u32 }, + #[error("Field '{0}' can't be dynamically-sized, has type {1:?}")] + InvalidDynamicArray(String, Handle<crate::Type>), + #[error("The base handle {0:?} has to be a struct")] + BindingArrayBaseTypeNotStruct(Handle<crate::Type>), + #[error("Structure member[{index}] at {offset} overlaps the previous member")] + MemberOverlap { index: u32, offset: u32 }, + #[error( + "Structure member[{index}] at {offset} and size {size} crosses the structure boundary of size {span}" + )] + MemberOutOfBounds { + index: u32, + offset: u32, + size: u32, + span: u32, + }, + #[error("Structure types must have at least one member")] + EmptyStruct, + #[error(transparent)] + WidthError(#[from] WidthError), +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum WidthError { + #[error("The {0:?} scalar width {1} is not supported")] + Invalid(crate::ScalarKind, crate::Bytes), + #[error("Using `{name}` values requires the `naga::valid::Capabilities::{flag}` flag")] + MissingCapability { + name: &'static str, + flag: &'static str, + }, + + #[error("64-bit integers are not yet supported")] + Unsupported64Bit, + + #[error("Abstract types may only appear in constant expressions")] + Abstract, +} + +// Only makes sense if `flags.contains(HOST_SHAREABLE)` +type LayoutCompatibility = Result<Alignment, (Handle<crate::Type>, Disalignment)>; + +fn check_member_layout( + accum: &mut LayoutCompatibility, + member: &crate::StructMember, + member_index: u32, + member_layout: LayoutCompatibility, + parent_handle: Handle<crate::Type>, +) { + *accum = match (*accum, member_layout) { + (Ok(cur_alignment), Ok(alignment)) => { + if alignment.is_aligned(member.offset) { + Ok(cur_alignment.max(alignment)) + } else { + Err(( + parent_handle, + Disalignment::MemberOffset { + index: member_index, + offset: member.offset, + alignment, + }, + )) + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }; +} + +/// Determine whether a pointer in `space` can be passed as an argument. +/// +/// If a pointer in `space` is permitted to be passed as an argument to a +/// user-defined function, return `TypeFlags::ARGUMENT`. Otherwise, return +/// `TypeFlags::empty()`. +/// +/// Pointers passed as arguments to user-defined functions must be in the +/// `Function` or `Private` address space. +const fn ptr_space_argument_flag(space: crate::AddressSpace) -> TypeFlags { + use crate::AddressSpace as As; + match space { + As::Function | As::Private => TypeFlags::ARGUMENT, + As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant | As::WorkGroup => { + TypeFlags::empty() + } + } +} + +#[derive(Clone, Debug)] +pub(super) struct TypeInfo { + pub flags: TypeFlags, + pub uniform_layout: LayoutCompatibility, + pub storage_layout: LayoutCompatibility, +} + +impl TypeInfo { + const fn dummy() -> Self { + TypeInfo { + flags: TypeFlags::empty(), + uniform_layout: Ok(Alignment::ONE), + storage_layout: Ok(Alignment::ONE), + } + } + + const fn new(flags: TypeFlags, alignment: Alignment) -> Self { + TypeInfo { + flags, + uniform_layout: Ok(alignment), + storage_layout: Ok(alignment), + } + } +} + +impl super::Validator { + const fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> { + if self.capabilities.contains(capability) { + Ok(()) + } else { + Err(TypeError::MissingCapability(capability)) + } + } + + pub(super) const fn check_width(&self, scalar: crate::Scalar) -> Result<(), WidthError> { + let good = match scalar.kind { + crate::ScalarKind::Bool => scalar.width == crate::BOOL_WIDTH, + crate::ScalarKind::Float => { + if scalar.width == 8 { + if !self.capabilities.contains(Capabilities::FLOAT64) { + return Err(WidthError::MissingCapability { + name: "f64", + flag: "FLOAT64", + }); + } + true + } else { + scalar.width == 4 + } + } + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + if scalar.width == 8 { + return Err(WidthError::Unsupported64Bit); + } + scalar.width == 4 + } + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + return Err(WidthError::Abstract); + } + }; + if good { + Ok(()) + } else { + Err(WidthError::Invalid(scalar.kind, scalar.width)) + } + } + + pub(super) fn reset_types(&mut self, size: usize) { + self.types.clear(); + self.types.resize(size, TypeInfo::dummy()); + self.layouter.clear(); + } + + pub(super) fn validate_type( + &self, + handle: Handle<crate::Type>, + gctx: crate::proc::GlobalCtx, + ) -> Result<TypeInfo, TypeError> { + use crate::TypeInner as Ti; + Ok(match gctx.types[handle].inner { + Ti::Scalar(scalar) => { + self.check_width(scalar)?; + let shareable = if scalar.kind.is_numeric() { + TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE + } else { + TypeFlags::empty() + }; + TypeInfo::new( + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE + | shareable, + Alignment::from_width(scalar.width), + ) + } + Ti::Vector { size, scalar } => { + self.check_width(scalar)?; + let shareable = if scalar.kind.is_numeric() { + TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE + } else { + TypeFlags::empty() + }; + TypeInfo::new( + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE + | shareable, + Alignment::from(size) * Alignment::from_width(scalar.width), + ) + } + Ti::Matrix { + columns: _, + rows, + scalar, + } => { + if scalar.kind != crate::ScalarKind::Float { + return Err(TypeError::MatrixElementNotFloat); + } + self.check_width(scalar)?; + TypeInfo::new( + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE, + Alignment::from(rows) * Alignment::from_width(scalar.width), + ) + } + Ti::Atomic(crate::Scalar { kind, width }) => { + let good = match kind { + crate::ScalarKind::Bool + | crate::ScalarKind::Float + | crate::ScalarKind::AbstractInt + | crate::ScalarKind::AbstractFloat => false, + crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4, + }; + if !good { + return Err(TypeError::InvalidAtomicWidth(kind, width)); + } + TypeInfo::new( + TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE, + Alignment::from_width(width), + ) + } + Ti::Pointer { base, space } => { + use crate::AddressSpace as As; + + let base_info = &self.types[base.index()]; + if !base_info.flags.contains(TypeFlags::DATA) { + return Err(TypeError::InvalidPointerBase(base)); + } + + // Runtime-sized values can only live in the `Storage` address + // space, so it's useless to have a pointer to such a type in + // any other space. + // + // Detecting this problem here prevents the definition of + // functions like: + // + // fn f(p: ptr<workgroup, UnsizedType>) -> ... { ... } + // + // which would otherwise be permitted, but uncallable. (They + // may also present difficulties in code generation). + if !base_info.flags.contains(TypeFlags::SIZED) { + match space { + As::Storage { .. } => {} + _ => { + return Err(TypeError::InvalidPointerToUnsized { base, space }); + } + } + } + + // `Validator::validate_function` actually checks the address + // space of pointer arguments explicitly before checking the + // `ARGUMENT` flag, to give better error messages. But it seems + // best to set `ARGUMENT` accurately anyway. + let argument_flag = ptr_space_argument_flag(space); + + // Pointers cannot be stored in variables, structure members, or + // array elements, so we do not mark them as `DATA`. + TypeInfo::new( + argument_flag | TypeFlags::SIZED | TypeFlags::COPY, + Alignment::ONE, + ) + } + Ti::ValuePointer { + size: _, + scalar, + space, + } => { + // ValuePointer should be treated the same way as the equivalent + // Pointer / Scalar / Vector combination, so each step in those + // variants' match arms should have a counterpart here. + // + // However, some cases are trivial: All our implicit base types + // are DATA and SIZED, so we can never return + // `InvalidPointerBase` or `InvalidPointerToUnsized`. + self.check_width(scalar)?; + + // `Validator::validate_function` actually checks the address + // space of pointer arguments explicitly before checking the + // `ARGUMENT` flag, to give better error messages. But it seems + // best to set `ARGUMENT` accurately anyway. + let argument_flag = ptr_space_argument_flag(space); + + // Pointers cannot be stored in variables, structure members, or + // array elements, so we do not mark them as `DATA`. + TypeInfo::new( + argument_flag | TypeFlags::SIZED | TypeFlags::COPY, + Alignment::ONE, + ) + } + Ti::Array { base, size, stride } => { + let base_info = &self.types[base.index()]; + if !base_info.flags.contains(TypeFlags::DATA | TypeFlags::SIZED) { + return Err(TypeError::InvalidArrayBaseType(base)); + } + + let base_layout = self.layouter[base]; + let general_alignment = base_layout.alignment; + let uniform_layout = match base_info.uniform_layout { + Ok(base_alignment) => { + let alignment = base_alignment + .max(general_alignment) + .max(Alignment::MIN_UNIFORM); + if alignment.is_aligned(stride) { + Ok(alignment) + } else { + Err((handle, Disalignment::ArrayStride { stride, alignment })) + } + } + Err(e) => Err(e), + }; + let storage_layout = match base_info.storage_layout { + Ok(base_alignment) => { + let alignment = base_alignment.max(general_alignment); + if alignment.is_aligned(stride) { + Ok(alignment) + } else { + Err((handle, Disalignment::ArrayStride { stride, alignment })) + } + } + Err(e) => Err(e), + }; + + let type_info_mask = match size { + crate::ArraySize::Constant(_) => { + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE + } + crate::ArraySize::Dynamic => { + // Non-SIZED types may only appear as the last element of a structure. + // This is enforced by checks for SIZED-ness for all compound types, + // and a special case for structs. + TypeFlags::DATA | TypeFlags::COPY | TypeFlags::HOST_SHAREABLE + } + }; + + TypeInfo { + flags: base_info.flags & type_info_mask, + uniform_layout, + storage_layout, + } + } + Ti::Struct { ref members, span } => { + if members.is_empty() { + return Err(TypeError::EmptyStruct); + } + + let mut ti = TypeInfo::new( + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::IO_SHAREABLE + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE, + Alignment::ONE, + ); + ti.uniform_layout = Ok(Alignment::MIN_UNIFORM); + + let mut min_offset = 0; + + let mut prev_struct_data: Option<(u32, u32)> = None; + + for (i, member) in members.iter().enumerate() { + let base_info = &self.types[member.ty.index()]; + if !base_info.flags.contains(TypeFlags::DATA) { + return Err(TypeError::InvalidData(member.ty)); + } + if !base_info.flags.contains(TypeFlags::HOST_SHAREABLE) { + if ti.uniform_layout.is_ok() { + ti.uniform_layout = Err((member.ty, Disalignment::NonHostShareable)); + } + if ti.storage_layout.is_ok() { + ti.storage_layout = Err((member.ty, Disalignment::NonHostShareable)); + } + } + ti.flags &= base_info.flags; + + if member.offset < min_offset { + // HACK: this could be nicer. We want to allow some structures + // to not bother with offsets/alignments if they are never + // used for host sharing. + if member.offset == 0 { + ti.flags.set(TypeFlags::HOST_SHAREABLE, false); + } else { + return Err(TypeError::MemberOverlap { + index: i as u32, + offset: member.offset, + }); + } + } + + let base_size = gctx.types[member.ty].inner.size(gctx); + min_offset = member.offset + base_size; + if min_offset > span { + return Err(TypeError::MemberOutOfBounds { + index: i as u32, + offset: member.offset, + size: base_size, + span, + }); + } + + check_member_layout( + &mut ti.uniform_layout, + member, + i as u32, + base_info.uniform_layout, + handle, + ); + check_member_layout( + &mut ti.storage_layout, + member, + i as u32, + base_info.storage_layout, + handle, + ); + + // Validate rule: If a structure member itself has a structure type S, + // then the number of bytes between the start of that member and + // the start of any following member must be at least roundUp(16, SizeOf(S)). + if let Some((span, offset)) = prev_struct_data { + let diff = member.offset - offset; + let min = Alignment::MIN_UNIFORM.round_up(span); + if diff < min { + ti.uniform_layout = Err(( + handle, + Disalignment::MemberOffsetAfterStruct { + index: i as u32, + offset: member.offset, + expected: offset + min, + }, + )); + } + }; + + prev_struct_data = match gctx.types[member.ty].inner { + crate::TypeInner::Struct { span, .. } => Some((span, member.offset)), + _ => None, + }; + + // The last field may be an unsized array. + if !base_info.flags.contains(TypeFlags::SIZED) { + let is_array = match gctx.types[member.ty].inner { + crate::TypeInner::Array { .. } => true, + _ => false, + }; + if !is_array || i + 1 != members.len() { + let name = member.name.clone().unwrap_or_default(); + return Err(TypeError::InvalidDynamicArray(name, member.ty)); + } + if ti.uniform_layout.is_ok() { + ti.uniform_layout = + Err((handle, Disalignment::UnsizedMember { index: i as u32 })); + } + } + } + + let alignment = self.layouter[handle].alignment; + if !alignment.is_aligned(span) { + ti.uniform_layout = Err((handle, Disalignment::StructSpan { span, alignment })); + ti.storage_layout = Err((handle, Disalignment::StructSpan { span, alignment })); + } + + ti + } + Ti::Image { + dim, + arrayed, + class: _, + } => { + if arrayed && matches!(dim, crate::ImageDimension::Cube) { + self.require_type_capability(Capabilities::CUBE_ARRAY_TEXTURES)?; + } + TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE) + } + Ti::Sampler { .. } => TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE), + Ti::AccelerationStructure => { + self.require_type_capability(Capabilities::RAY_QUERY)?; + TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE) + } + Ti::RayQuery => { + self.require_type_capability(Capabilities::RAY_QUERY)?; + TypeInfo::new( + TypeFlags::DATA | TypeFlags::CONSTRUCTIBLE | TypeFlags::SIZED, + Alignment::ONE, + ) + } + Ti::BindingArray { base, size } => { + if base >= handle { + return Err(TypeError::InvalidArrayBaseType(base)); + } + let type_info_mask = match size { + crate::ArraySize::Constant(_) => TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE, + crate::ArraySize::Dynamic => { + // Final type is non-sized + TypeFlags::HOST_SHAREABLE + } + }; + let base_info = &self.types[base.index()]; + + if base_info.flags.contains(TypeFlags::DATA) { + // Currently Naga only supports binding arrays of structs for non-handle types. + match gctx.types[base].inner { + crate::TypeInner::Struct { .. } => {} + _ => return Err(TypeError::BindingArrayBaseTypeNotStruct(base)), + }; + } + + TypeInfo::new(base_info.flags & type_info_mask, Alignment::ONE) + } + }) + } +} |