summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src')
-rw-r--r--third_party/rust/naga/src/arena.rs717
-rw-r--r--third_party/rust/naga/src/back/dot/mod.rs695
-rw-r--r--third_party/rust/naga/src/back/glsl/features.rs512
-rw-r--r--third_party/rust/naga/src/back/glsl/keywords.rs204
-rw-r--r--third_party/rust/naga/src/back/glsl/mod.rs4114
-rw-r--r--third_party/rust/naga/src/back/hlsl/conv.rs222
-rw-r--r--third_party/rust/naga/src/back/hlsl/help.rs1126
-rw-r--r--third_party/rust/naga/src/back/hlsl/keywords.rs166
-rw-r--r--third_party/rust/naga/src/back/hlsl/mod.rs302
-rw-r--r--third_party/rust/naga/src/back/hlsl/storage.rs510
-rw-r--r--third_party/rust/naga/src/back/hlsl/writer.rs3188
-rw-r--r--third_party/rust/naga/src/back/mod.rs250
-rw-r--r--third_party/rust/naga/src/back/msl/keywords.rs217
-rw-r--r--third_party/rust/naga/src/back/msl/mod.rs489
-rw-r--r--third_party/rust/naga/src/back/msl/sampler.rs175
-rw-r--r--third_party/rust/naga/src/back/msl/writer.rs4429
-rw-r--r--third_party/rust/naga/src/back/spv/block.rs2235
-rw-r--r--third_party/rust/naga/src/back/spv/helpers.rs108
-rw-r--r--third_party/rust/naga/src/back/spv/image.rs1269
-rw-r--r--third_party/rust/naga/src/back/spv/index.rs417
-rw-r--r--third_party/rust/naga/src/back/spv/instructions.rs1063
-rw-r--r--third_party/rust/naga/src/back/spv/layout.rs210
-rw-r--r--third_party/rust/naga/src/back/spv/mod.rs729
-rw-r--r--third_party/rust/naga/src/back/spv/ray.rs273
-rw-r--r--third_party/rust/naga/src/back/spv/recyclable.rs60
-rw-r--r--third_party/rust/naga/src/back/spv/selection.rs257
-rw-r--r--third_party/rust/naga/src/back/spv/writer.rs1966
-rw-r--r--third_party/rust/naga/src/back/wgsl/mod.rs52
-rw-r--r--third_party/rust/naga/src/back/wgsl/writer.rs1968
-rw-r--r--third_party/rust/naga/src/block.rs145
-rw-r--r--third_party/rust/naga/src/front/glsl/ast.rs393
-rw-r--r--third_party/rust/naga/src/front/glsl/builtins.rs2497
-rw-r--r--third_party/rust/naga/src/front/glsl/constants.rs984
-rw-r--r--third_party/rust/naga/src/front/glsl/context.rs1609
-rw-r--r--third_party/rust/naga/src/front/glsl/error.rs134
-rw-r--r--third_party/rust/naga/src/front/glsl/functions.rs1680
-rw-r--r--third_party/rust/naga/src/front/glsl/lex.rs301
-rw-r--r--third_party/rust/naga/src/front/glsl/mod.rs235
-rw-r--r--third_party/rust/naga/src/front/glsl/offset.rs173
-rw-r--r--third_party/rust/naga/src/front/glsl/parser.rs448
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/declarations.rs671
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/expressions.rs546
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/functions.rs654
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/types.rs434
-rw-r--r--third_party/rust/naga/src/front/glsl/parser_tests.rs832
-rw-r--r--third_party/rust/naga/src/front/glsl/token.rs137
-rw-r--r--third_party/rust/naga/src/front/glsl/types.rs335
-rw-r--r--third_party/rust/naga/src/front/glsl/variables.rs679
-rw-r--r--third_party/rust/naga/src/front/interpolator.rs61
-rw-r--r--third_party/rust/naga/src/front/mod.rs374
-rw-r--r--third_party/rust/naga/src/front/spv/convert.rs181
-rw-r--r--third_party/rust/naga/src/front/spv/error.rs127
-rw-r--r--third_party/rust/naga/src/front/spv/function.rs661
-rw-r--r--third_party/rust/naga/src/front/spv/image.rs737
-rw-r--r--third_party/rust/naga/src/front/spv/mod.rs5195
-rw-r--r--third_party/rust/naga/src/front/spv/null.rs175
-rw-r--r--third_party/rust/naga/src/front/type_gen.rs314
-rw-r--r--third_party/rust/naga/src/front/wgsl/error.rs690
-rw-r--r--third_party/rust/naga/src/front/wgsl/index.rs193
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/construction.rs668
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/mod.rs2426
-rw-r--r--third_party/rust/naga/src/front/wgsl/mod.rs340
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/ast.rs486
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/conv.rs236
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/lexer.rs723
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/mod.rs2298
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/number.rs443
-rw-r--r--third_party/rust/naga/src/front/wgsl/tests.rs483
-rw-r--r--third_party/rust/naga/src/keywords/mod.rs6
-rw-r--r--third_party/rust/naga/src/keywords/wgsl.rs229
-rw-r--r--third_party/rust/naga/src/lib.rs1892
-rw-r--r--third_party/rust/naga/src/proc/index.rs437
-rw-r--r--third_party/rust/naga/src/proc/layouter.rs256
-rw-r--r--third_party/rust/naga/src/proc/mod.rs493
-rw-r--r--third_party/rust/naga/src/proc/namer.rs261
-rw-r--r--third_party/rust/naga/src/proc/terminator.rs43
-rw-r--r--third_party/rust/naga/src/proc/typifier.rs909
-rw-r--r--third_party/rust/naga/src/span.rs523
-rw-r--r--third_party/rust/naga/src/valid/analyzer.rs1200
-rw-r--r--third_party/rust/naga/src/valid/compose.rs138
-rw-r--r--third_party/rust/naga/src/valid/expression.rs1482
-rw-r--r--third_party/rust/naga/src/valid/function.rs1048
-rw-r--r--third_party/rust/naga/src/valid/handles.rs652
-rw-r--r--third_party/rust/naga/src/valid/interface.rs676
-rw-r--r--third_party/rust/naga/src/valid/mod.rs467
-rw-r--r--third_party/rust/naga/src/valid/type.rs636
86 files changed, 68969 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/arena.rs b/third_party/rust/naga/src/arena.rs
new file mode 100644
index 0000000000..5c2681c834
--- /dev/null
+++ b/third_party/rust/naga/src/arena.rs
@@ -0,0 +1,717 @@
+use std::{cmp::Ordering, fmt, hash, marker::PhantomData, num::NonZeroU32, ops};
+
+/// An unique index in the arena array that a handle points to.
+/// The "non-zero" part ensures that an `Option<Handle<T>>` has
+/// the same size and representation as `Handle<T>`.
+type Index = NonZeroU32;
+
+use crate::Span;
+use indexmap::set::IndexSet;
+
+#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)]
+#[error("Handle {index} of {kind} is either not present, or inaccessible yet")]
+pub struct BadHandle {
+ pub kind: &'static str,
+ pub index: usize,
+}
+
+impl BadHandle {
+ fn new<T>(handle: Handle<T>) -> Self {
+ Self {
+ kind: std::any::type_name::<T>(),
+ index: handle.index(),
+ }
+ }
+}
+
+/// A strongly typed reference to an arena item.
+///
+/// A `Handle` value can be used as an index into an [`Arena`] or [`UniqueArena`].
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(
+ any(feature = "serialize", feature = "deserialize"),
+ serde(transparent)
+)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct Handle<T> {
+ index: Index,
+ #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))]
+ marker: PhantomData<T>,
+}
+
+impl<T> Clone for Handle<T> {
+ fn clone(&self) -> Self {
+ Handle {
+ index: self.index,
+ marker: self.marker,
+ }
+ }
+}
+
+impl<T> Copy for Handle<T> {}
+
+impl<T> PartialEq for Handle<T> {
+ fn eq(&self, other: &Self) -> bool {
+ self.index == other.index
+ }
+}
+
+impl<T> Eq for Handle<T> {}
+
+impl<T> PartialOrd for Handle<T> {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ self.index.partial_cmp(&other.index)
+ }
+}
+
+impl<T> Ord for Handle<T> {
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.index.cmp(&other.index)
+ }
+}
+
+impl<T> fmt::Debug for Handle<T> {
+ fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ write!(formatter, "[{}]", self.index)
+ }
+}
+
+impl<T> hash::Hash for Handle<T> {
+ fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
+ self.index.hash(hasher)
+ }
+}
+
+impl<T> Handle<T> {
+ #[cfg(test)]
+ pub const DUMMY: Self = Handle {
+ index: unsafe { NonZeroU32::new_unchecked(u32::MAX) },
+ marker: PhantomData,
+ };
+
+ pub(crate) const fn new(index: Index) -> Self {
+ Handle {
+ index,
+ marker: PhantomData,
+ }
+ }
+
+ /// Returns the zero-based index of this handle.
+ pub const fn index(self) -> usize {
+ let index = self.index.get() - 1;
+ index as usize
+ }
+
+ /// Convert a `usize` index into a `Handle<T>`.
+ fn from_usize(index: usize) -> Self {
+ let handle_index = u32::try_from(index + 1)
+ .ok()
+ .and_then(Index::new)
+ .expect("Failed to insert into arena. Handle overflows");
+ Handle::new(handle_index)
+ }
+
+ /// Convert a `usize` index into a `Handle<T>`, without range checks.
+ const unsafe fn from_usize_unchecked(index: usize) -> Self {
+ Handle::new(Index::new_unchecked((index + 1) as u32))
+ }
+}
+
+/// A strongly typed range of handles.
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(
+ any(feature = "serialize", feature = "deserialize"),
+ serde(transparent)
+)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct Range<T> {
+ inner: ops::Range<u32>,
+ #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))]
+ marker: PhantomData<T>,
+}
+
+impl<T> Range<T> {
+ pub(crate) const fn erase_type(self) -> Range<()> {
+ let Self { inner, marker: _ } = self;
+ Range {
+ inner,
+ marker: PhantomData,
+ }
+ }
+}
+
+// NOTE: Keep this diagnostic in sync with that of [`BadHandle`].
+#[derive(Clone, Debug, thiserror::Error)]
+#[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")]
+pub struct BadRangeError {
+ // This error is used for many `Handle` types, but there's no point in making this generic, so
+ // we just flatten them all to `Handle<()>` here.
+ kind: &'static str,
+ range: Range<()>,
+}
+
+impl BadRangeError {
+ pub fn new<T>(range: Range<T>) -> Self {
+ Self {
+ kind: std::any::type_name::<T>(),
+ range: range.erase_type(),
+ }
+ }
+}
+
+impl<T> Clone for Range<T> {
+ fn clone(&self) -> Self {
+ Range {
+ inner: self.inner.clone(),
+ marker: self.marker,
+ }
+ }
+}
+
+impl<T> fmt::Debug for Range<T> {
+ fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ write!(formatter, "[{}..{}]", self.inner.start + 1, self.inner.end)
+ }
+}
+
+impl<T> Iterator for Range<T> {
+ type Item = Handle<T>;
+ fn next(&mut self) -> Option<Self::Item> {
+ if self.inner.start < self.inner.end {
+ self.inner.start += 1;
+ Some(Handle {
+ index: NonZeroU32::new(self.inner.start).unwrap(),
+ marker: self.marker,
+ })
+ } else {
+ None
+ }
+ }
+}
+
+impl<T> Range<T> {
+ pub fn new_from_bounds(first: Handle<T>, last: Handle<T>) -> Self {
+ Self {
+ inner: (first.index() as u32)..(last.index() as u32 + 1),
+ marker: Default::default(),
+ }
+ }
+}
+
+/// An arena holding some kind of component (e.g., type, constant,
+/// instruction, etc.) that can be referenced.
+///
+/// Adding new items to the arena produces a strongly-typed [`Handle`].
+/// The arena can be indexed using the given handle to obtain
+/// a reference to the stored item.
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "serialize", serde(transparent))]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Arena<T> {
+ /// Values of this arena.
+ data: Vec<T>,
+ #[cfg(feature = "span")]
+ #[cfg_attr(feature = "serialize", serde(skip))]
+ span_info: Vec<Span>,
+}
+
+impl<T> Default for Arena<T> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<T: fmt::Debug> fmt::Debug for Arena<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.debug_map().entries(self.iter()).finish()
+ }
+}
+
+impl<T> Arena<T> {
+ /// Create a new arena with no initial capacity allocated.
+ pub const fn new() -> Self {
+ Arena {
+ data: Vec::new(),
+ #[cfg(feature = "span")]
+ span_info: Vec::new(),
+ }
+ }
+
+ /// Extracts the inner vector.
+ #[allow(clippy::missing_const_for_fn)] // ignore due to requirement of #![feature(const_precise_live_drops)]
+ pub fn into_inner(self) -> Vec<T> {
+ self.data
+ }
+
+ /// Returns the current number of items stored in this arena.
+ pub fn len(&self) -> usize {
+ self.data.len()
+ }
+
+ /// Returns `true` if the arena contains no elements.
+ pub fn is_empty(&self) -> bool {
+ self.data.is_empty()
+ }
+
+ /// Returns an iterator over the items stored in this arena, returning both
+ /// the item's handle and a reference to it.
+ pub fn iter(&self) -> impl DoubleEndedIterator<Item = (Handle<T>, &T)> {
+ self.data
+ .iter()
+ .enumerate()
+ .map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) })
+ }
+
+ /// Returns a iterator over the items stored in this arena,
+ /// returning both the item's handle and a mutable reference to it.
+ pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T)> {
+ self.data
+ .iter_mut()
+ .enumerate()
+ .map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) })
+ }
+
+ /// Adds a new value to the arena, returning a typed handle.
+ pub fn append(&mut self, value: T, span: Span) -> Handle<T> {
+ #[cfg(not(feature = "span"))]
+ let _ = span;
+ let index = self.data.len();
+ self.data.push(value);
+ #[cfg(feature = "span")]
+ self.span_info.push(span);
+ Handle::from_usize(index)
+ }
+
+ /// Fetch a handle to an existing type.
+ pub fn fetch_if<F: Fn(&T) -> bool>(&self, fun: F) -> Option<Handle<T>> {
+ self.data
+ .iter()
+ .position(fun)
+ .map(|index| unsafe { Handle::from_usize_unchecked(index) })
+ }
+
+ /// Adds a value with a custom check for uniqueness:
+ /// returns a handle pointing to
+ /// an existing element if the check succeeds, or adds a new
+ /// element otherwise.
+ pub fn fetch_if_or_append<F: Fn(&T, &T) -> bool>(
+ &mut self,
+ value: T,
+ span: Span,
+ fun: F,
+ ) -> Handle<T> {
+ if let Some(index) = self.data.iter().position(|d| fun(d, &value)) {
+ unsafe { Handle::from_usize_unchecked(index) }
+ } else {
+ self.append(value, span)
+ }
+ }
+
+ /// Adds a value with a check for uniqueness, where the check is plain comparison.
+ pub fn fetch_or_append(&mut self, value: T, span: Span) -> Handle<T>
+ where
+ T: PartialEq,
+ {
+ self.fetch_if_or_append(value, span, T::eq)
+ }
+
+ pub fn try_get(&self, handle: Handle<T>) -> Result<&T, BadHandle> {
+ self.data
+ .get(handle.index())
+ .ok_or_else(|| BadHandle::new(handle))
+ }
+
+ /// Get a mutable reference to an element in the arena.
+ pub fn get_mut(&mut self, handle: Handle<T>) -> &mut T {
+ self.data.get_mut(handle.index()).unwrap()
+ }
+
+ /// Get the range of handles from a particular number of elements to the end.
+ pub fn range_from(&self, old_length: usize) -> Range<T> {
+ Range {
+ inner: old_length as u32..self.data.len() as u32,
+ marker: PhantomData,
+ }
+ }
+
+ /// Clears the arena keeping all allocations
+ pub fn clear(&mut self) {
+ self.data.clear()
+ }
+
+ pub fn get_span(&self, handle: Handle<T>) -> Span {
+ #[cfg(feature = "span")]
+ {
+ *self
+ .span_info
+ .get(handle.index())
+ .unwrap_or(&Span::default())
+ }
+ #[cfg(not(feature = "span"))]
+ {
+ let _ = handle;
+ Span::default()
+ }
+ }
+
+ /// Assert that `handle` is valid for this arena.
+ pub fn check_contains_handle(&self, handle: Handle<T>) -> Result<(), BadHandle> {
+ if handle.index() < self.data.len() {
+ Ok(())
+ } else {
+ Err(BadHandle::new(handle))
+ }
+ }
+
+ /// Assert that `range` is valid for this arena.
+ pub fn check_contains_range(&self, range: &Range<T>) -> Result<(), BadRangeError> {
+ // Since `range.inner` is a `Range<u32>`, we only need to
+ // check that the start precedes the end, and that the end is
+ // in range.
+ if range.inner.start > range.inner.end
+ || self
+ .check_contains_handle(Handle::new(range.inner.end.try_into().unwrap()))
+ .is_err()
+ {
+ Err(BadRangeError::new(range.clone()))
+ } else {
+ Ok(())
+ }
+ }
+}
+
+#[cfg(feature = "deserialize")]
+impl<'de, T> serde::Deserialize<'de> for Arena<T>
+where
+ T: serde::Deserialize<'de>,
+{
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ let data = Vec::deserialize(deserializer)?;
+ #[cfg(feature = "span")]
+ let span_info = std::iter::repeat(Span::default())
+ .take(data.len())
+ .collect();
+
+ Ok(Self {
+ data,
+ #[cfg(feature = "span")]
+ span_info,
+ })
+ }
+}
+
+impl<T> ops::Index<Handle<T>> for Arena<T> {
+ type Output = T;
+ fn index(&self, handle: Handle<T>) -> &T {
+ &self.data[handle.index()]
+ }
+}
+
+impl<T> ops::IndexMut<Handle<T>> for Arena<T> {
+ fn index_mut(&mut self, handle: Handle<T>) -> &mut T {
+ &mut self.data[handle.index()]
+ }
+}
+
+impl<T> ops::Index<Range<T>> for Arena<T> {
+ type Output = [T];
+ fn index(&self, range: Range<T>) -> &[T] {
+ &self.data[range.inner.start as usize..range.inner.end as usize]
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn append_non_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.append(0, Default::default());
+ let t2 = arena.append(0, Default::default());
+ assert!(t1 != t2);
+ assert!(arena[t1] == arena[t2]);
+ }
+
+ #[test]
+ fn append_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.append(0, Default::default());
+ let t2 = arena.append(1, Default::default());
+ assert!(t1 != t2);
+ assert!(arena[t1] != arena[t2]);
+ }
+
+ #[test]
+ fn fetch_or_append_non_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.fetch_or_append(0, Default::default());
+ let t2 = arena.fetch_or_append(0, Default::default());
+ assert!(t1 == t2);
+ assert!(arena[t1] == arena[t2])
+ }
+
+ #[test]
+ fn fetch_or_append_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.fetch_or_append(0, Default::default());
+ let t2 = arena.fetch_or_append(1, Default::default());
+ assert!(t1 != t2);
+ assert!(arena[t1] != arena[t2]);
+ }
+}
+
+/// An arena whose elements are guaranteed to be unique.
+///
+/// A `UniqueArena` holds a set of unique values of type `T`, each with an
+/// associated [`Span`]. Inserting a value returns a `Handle<T>`, which can be
+/// used to index the `UniqueArena` and obtain shared access to the `T` element.
+/// Access via a `Handle` is an array lookup - no hash lookup is necessary.
+///
+/// The element type must implement `Eq` and `Hash`. Insertions of equivalent
+/// elements, according to `Eq`, all return the same `Handle`.
+///
+/// Once inserted, elements may not be mutated.
+///
+/// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like,
+/// `UniqueArena` is `HashSet`-like.
+#[cfg_attr(feature = "clone", derive(Clone))]
+pub struct UniqueArena<T> {
+ set: IndexSet<T>,
+
+ /// Spans for the elements, indexed by handle.
+ ///
+ /// The length of this vector is always equal to `set.len()`. `IndexSet`
+ /// promises that its elements "are indexed in a compact range, without
+ /// holes in the range 0..set.len()", so we can always use the indices
+ /// returned by insertion as indices into this vector.
+ #[cfg(feature = "span")]
+ span_info: Vec<Span>,
+}
+
+impl<T> UniqueArena<T> {
+ /// Create a new arena with no initial capacity allocated.
+ pub fn new() -> Self {
+ UniqueArena {
+ set: IndexSet::new(),
+ #[cfg(feature = "span")]
+ span_info: Vec::new(),
+ }
+ }
+
+ /// Return the current number of items stored in this arena.
+ pub fn len(&self) -> usize {
+ self.set.len()
+ }
+
+ /// Return `true` if the arena contains no elements.
+ pub fn is_empty(&self) -> bool {
+ self.set.is_empty()
+ }
+
+ /// Clears the arena, keeping all allocations.
+ pub fn clear(&mut self) {
+ self.set.clear();
+ #[cfg(feature = "span")]
+ self.span_info.clear();
+ }
+
+ /// Return the span associated with `handle`.
+ ///
+ /// If a value has been inserted multiple times, the span returned is the
+ /// one provided with the first insertion.
+ ///
+ /// If the `span` feature is not enabled, always return `Span::default`.
+ /// This can be detected with [`Span::is_defined`].
+ pub fn get_span(&self, handle: Handle<T>) -> Span {
+ #[cfg(feature = "span")]
+ {
+ *self
+ .span_info
+ .get(handle.index())
+ .unwrap_or(&Span::default())
+ }
+ #[cfg(not(feature = "span"))]
+ {
+ let _ = handle;
+ Span::default()
+ }
+ }
+}
+
+impl<T: Eq + hash::Hash> UniqueArena<T> {
+ /// Returns an iterator over the items stored in this arena, returning both
+ /// the item's handle and a reference to it.
+ pub fn iter(&self) -> impl DoubleEndedIterator<Item = (Handle<T>, &T)> {
+ self.set.iter().enumerate().map(|(i, v)| {
+ let position = i + 1;
+ let index = unsafe { Index::new_unchecked(position as u32) };
+ (Handle::new(index), v)
+ })
+ }
+
+ /// Insert a new value into the arena.
+ ///
+ /// Return a [`Handle<T>`], which can be used to index this arena to get a
+ /// shared reference to the element.
+ ///
+ /// If this arena already contains an element that is `Eq` to `value`,
+ /// return a `Handle` to the existing element, and drop `value`.
+ ///
+ /// When the `span` feature is enabled, if `value` is inserted into the
+ /// arena, associate `span` with it. An element's span can be retrieved with
+ /// the [`get_span`] method.
+ ///
+ /// [`Handle<T>`]: Handle
+ /// [`get_span`]: UniqueArena::get_span
+ pub fn insert(&mut self, value: T, span: Span) -> Handle<T> {
+ let (index, added) = self.set.insert_full(value);
+
+ #[cfg(feature = "span")]
+ {
+ if added {
+ debug_assert!(index == self.span_info.len());
+ self.span_info.push(span);
+ }
+
+ debug_assert!(self.set.len() == self.span_info.len());
+ }
+
+ #[cfg(not(feature = "span"))]
+ let _ = (span, added);
+
+ Handle::from_usize(index)
+ }
+
+ /// Replace an old value with a new value.
+ ///
+ /// # Panics
+ ///
+ /// - if the old value is not in the arena
+ /// - if the new value already exists in the arena
+ pub fn replace(&mut self, old: Handle<T>, new: T) {
+ let (index, added) = self.set.insert_full(new);
+ assert!(added && index == self.set.len() - 1);
+
+ self.set.swap_remove_index(old.index()).unwrap();
+ }
+
+ /// Return this arena's handle for `value`, if present.
+ ///
+ /// If this arena already contains an element equal to `value`,
+ /// return its handle. Otherwise, return `None`.
+ pub fn get(&self, value: &T) -> Option<Handle<T>> {
+ self.set
+ .get_index_of(value)
+ .map(|index| unsafe { Handle::from_usize_unchecked(index) })
+ }
+
+ /// Return this arena's value at `handle`, if that is a valid handle.
+ pub fn get_handle(&self, handle: Handle<T>) -> Result<&T, BadHandle> {
+ self.set
+ .get_index(handle.index())
+ .ok_or_else(|| BadHandle::new(handle))
+ }
+
+ /// Assert that `handle` is valid for this arena.
+ pub fn check_contains_handle(&self, handle: Handle<T>) -> Result<(), BadHandle> {
+ if handle.index() < self.set.len() {
+ Ok(())
+ } else {
+ Err(BadHandle::new(handle))
+ }
+ }
+}
+
+impl<T> Default for UniqueArena<T> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<T: fmt::Debug + Eq + hash::Hash> fmt::Debug for UniqueArena<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.debug_map().entries(self.iter()).finish()
+ }
+}
+
+impl<T> ops::Index<Handle<T>> for UniqueArena<T> {
+ type Output = T;
+ fn index(&self, handle: Handle<T>) -> &T {
+ &self.set[handle.index()]
+ }
+}
+
+#[cfg(feature = "serialize")]
+impl<T> serde::Serialize for UniqueArena<T>
+where
+ T: Eq + hash::Hash + serde::Serialize,
+{
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: serde::Serializer,
+ {
+ self.set.serialize(serializer)
+ }
+}
+
+#[cfg(feature = "deserialize")]
+impl<'de, T> serde::Deserialize<'de> for UniqueArena<T>
+where
+ T: Eq + hash::Hash + serde::Deserialize<'de>,
+{
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ let set = IndexSet::deserialize(deserializer)?;
+ #[cfg(feature = "span")]
+ let span_info = std::iter::repeat(Span::default()).take(set.len()).collect();
+
+ Ok(Self {
+ set,
+ #[cfg(feature = "span")]
+ span_info,
+ })
+ }
+}
+
+//Note: largely borrowed from `HashSet` implementation
+#[cfg(feature = "arbitrary")]
+impl<'a, T> arbitrary::Arbitrary<'a> for UniqueArena<T>
+where
+ T: Eq + hash::Hash + arbitrary::Arbitrary<'a>,
+{
+ fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
+ let mut arena = Self::default();
+ for elem in u.arbitrary_iter()? {
+ arena.set.insert(elem?);
+ #[cfg(feature = "span")]
+ arena.span_info.push(Span::UNDEFINED);
+ }
+ Ok(arena)
+ }
+
+ fn arbitrary_take_rest(u: arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
+ let mut arena = Self::default();
+ for elem in u.arbitrary_take_rest_iter()? {
+ arena.set.insert(elem?);
+ #[cfg(feature = "span")]
+ arena.span_info.push(Span::UNDEFINED);
+ }
+ Ok(arena)
+ }
+
+ #[inline]
+ fn size_hint(depth: usize) -> (usize, Option<usize>) {
+ let depth_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
+ arbitrary::size_hint::and(depth_hint, (0, None))
+ }
+}
diff --git a/third_party/rust/naga/src/back/dot/mod.rs b/third_party/rust/naga/src/back/dot/mod.rs
new file mode 100644
index 0000000000..1eebbee067
--- /dev/null
+++ b/third_party/rust/naga/src/back/dot/mod.rs
@@ -0,0 +1,695 @@
+/*!
+Backend for [DOT][dot] (Graphviz).
+
+This backend writes a graph in the DOT language, for the ease
+of IR inspection and debugging.
+
+[dot]: https://graphviz.org/doc/info/lang.html
+*/
+
+use crate::{
+ arena::Handle,
+ valid::{FunctionInfo, ModuleInfo},
+};
+
+use std::{
+ borrow::Cow,
+ fmt::{Error as FmtError, Write as _},
+};
+
+/// Configuration options for the dot backend
+#[derive(Default)]
+pub struct Options {
+ /// Only emit function bodies
+ pub cfg_only: bool,
+}
+
+/// Identifier used to address a graph node
+type NodeId = usize;
+
+/// Stores the target nodes for control flow statements
+#[derive(Default, Clone, Copy)]
+struct Targets {
+ /// The node, if some, where continue operations will land
+ continue_target: Option<usize>,
+ /// The node, if some, where break operations will land
+ break_target: Option<usize>,
+}
+
+/// Stores information about the graph of statements
+#[derive(Default)]
+struct StatementGraph {
+ /// List of node names
+ nodes: Vec<&'static str>,
+ /// List of edges of the control flow, the items are defined as
+ /// (from, to, label)
+ flow: Vec<(NodeId, NodeId, &'static str)>,
+ /// List of implicit edges of the control flow, used for jump
+ /// operations such as continue or break, the items are defined as
+ /// (from, to, label, color_id)
+ jumps: Vec<(NodeId, NodeId, &'static str, usize)>,
+ /// List of dependency relationships between a statement node and
+ /// expressions
+ dependencies: Vec<(NodeId, Handle<crate::Expression>, &'static str)>,
+ /// List of expression emitted by statement node
+ emits: Vec<(NodeId, Handle<crate::Expression>)>,
+ /// List of function call by statement node
+ calls: Vec<(NodeId, Handle<crate::Function>)>,
+}
+
+impl StatementGraph {
+ /// Adds a new block to the statement graph, returning the first and last node, respectively
+ fn add(&mut self, block: &[crate::Statement], targets: Targets) -> (NodeId, NodeId) {
+ use crate::Statement as S;
+
+ // The first node of the block isn't a statement but a virtual node
+ let root = self.nodes.len();
+ self.nodes.push(if root == 0 { "Root" } else { "Node" });
+ // Track the last placed node, this will be returned to the caller and
+ // will also be used to generate the control flow edges
+ let mut last_node = root;
+ for statement in block {
+ // Reserve a new node for the current statement and link it to the
+ // node of the previous statement
+ let id = self.nodes.len();
+ self.flow.push((last_node, id, ""));
+ self.nodes.push(""); // reserve space
+
+ // Track the node identifier for the merge node, the merge node is
+ // the last node of a statement, normally this is the node itself,
+ // but for control flow statements such as `if`s and `switch`s this
+ // is a virtual node where all branches merge back.
+ let mut merge_id = id;
+
+ self.nodes[id] = match *statement {
+ S::Emit(ref range) => {
+ for handle in range.clone() {
+ self.emits.push((id, handle));
+ }
+ "Emit"
+ }
+ S::Kill => "Kill", //TODO: link to the beginning
+ S::Break => {
+ // Try to link to the break target, otherwise produce
+ // a broken connection
+ if let Some(target) = targets.break_target {
+ self.jumps.push((id, target, "Break", 5))
+ } else {
+ self.jumps.push((id, root, "Broken", 7))
+ }
+ "Break"
+ }
+ S::Continue => {
+ // Try to link to the continue target, otherwise produce
+ // a broken connection
+ if let Some(target) = targets.continue_target {
+ self.jumps.push((id, target, "Continue", 5))
+ } else {
+ self.jumps.push((id, root, "Broken", 7))
+ }
+ "Continue"
+ }
+ S::Barrier(_flags) => "Barrier",
+ S::Block(ref b) => {
+ let (other, last) = self.add(b, targets);
+ self.flow.push((id, other, ""));
+ // All following nodes should connect to the end of the block
+ // statement so change the merge id to it.
+ merge_id = last;
+ "Block"
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ self.dependencies.push((id, condition, "condition"));
+ let (accept_id, accept_last) = self.add(accept, targets);
+ self.flow.push((id, accept_id, "accept"));
+ let (reject_id, reject_last) = self.add(reject, targets);
+ self.flow.push((id, reject_id, "reject"));
+
+ // Create a merge node, link the branches to it and set it
+ // as the merge node to make the next statement node link to it
+ merge_id = self.nodes.len();
+ self.nodes.push("Merge");
+ self.flow.push((accept_last, merge_id, ""));
+ self.flow.push((reject_last, merge_id, ""));
+
+ "If"
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ self.dependencies.push((id, selector, "selector"));
+
+ // Create a merge node and set it as the merge node to make
+ // the next statement node link to it
+ merge_id = self.nodes.len();
+ self.nodes.push("Merge");
+
+ // Create a new targets structure and set the break target
+ // to the merge node
+ let mut targets = targets;
+ targets.break_target = Some(merge_id);
+
+ for case in cases {
+ let (case_id, case_last) = self.add(&case.body, targets);
+ let label = match case.value {
+ crate::SwitchValue::Default => "default",
+ _ => "case",
+ };
+ self.flow.push((id, case_id, label));
+ // Link the last node of the branch to the merge node
+ self.flow.push((case_last, merge_id, ""));
+ }
+ "Switch"
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ // Create a new targets structure and set the break target
+ // to the merge node, this must happen before generating the
+ // continuing block since it can break.
+ let mut targets = targets;
+ targets.break_target = Some(id);
+
+ let (continuing_id, continuing_last) = self.add(continuing, targets);
+
+ // Set the the continue target to the beginning
+ // of the newly generated continuing block
+ targets.continue_target = Some(continuing_id);
+
+ let (body_id, body_last) = self.add(body, targets);
+
+ self.flow.push((id, body_id, "body"));
+
+ // Link the last node of the body to the continuing block
+ self.flow.push((body_last, continuing_id, "continuing"));
+ // Link the last node of the continuing block back to the
+ // beginning of the loop body
+ self.flow.push((continuing_last, body_id, "continuing"));
+
+ if let Some(expr) = break_if {
+ self.dependencies.push((continuing_id, expr, "break if"));
+ }
+
+ "Loop"
+ }
+ S::Return { value } => {
+ if let Some(expr) = value {
+ self.dependencies.push((id, expr, "value"));
+ }
+ "Return"
+ }
+ S::Store { pointer, value } => {
+ self.dependencies.push((id, value, "value"));
+ self.emits.push((id, pointer));
+ "Store"
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ self.dependencies.push((id, image, "image"));
+ self.dependencies.push((id, coordinate, "coordinate"));
+ if let Some(expr) = array_index {
+ self.dependencies.push((id, expr, "array_index"));
+ }
+ self.dependencies.push((id, value, "value"));
+ "ImageStore"
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ for &arg in arguments {
+ self.dependencies.push((id, arg, "arg"));
+ }
+ if let Some(expr) = result {
+ self.emits.push((id, expr));
+ }
+ self.calls.push((id, function));
+ "Call"
+ }
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ self.emits.push((id, result));
+ self.dependencies.push((id, pointer, "pointer"));
+ self.dependencies.push((id, value, "value"));
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ self.dependencies.push((id, cmp, "cmp"));
+ }
+ "Atomic"
+ }
+ S::RayQuery { query, ref fun } => {
+ self.dependencies.push((id, query, "query"));
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ self.dependencies.push((
+ id,
+ acceleration_structure,
+ "acceleration_structure",
+ ));
+ self.dependencies.push((id, descriptor, "descriptor"));
+ "RayQueryInitialize"
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ self.emits.push((id, result));
+ "RayQueryProceed"
+ }
+ crate::RayQueryFunction::Terminate => "RayQueryTerminate",
+ }
+ }
+ };
+ // Set the last node to the merge node
+ last_node = merge_id;
+ }
+ (root, last_node)
+ }
+}
+
+#[allow(clippy::manual_unwrap_or)]
+fn name(option: &Option<String>) -> &str {
+ match *option {
+ Some(ref name) => name,
+ None => "",
+ }
+}
+
+/// set39 color scheme from <https://graphviz.org/doc/info/colors.html>
+const COLORS: &[&str] = &[
+ "white", // pattern starts at 1
+ "#8dd3c7", "#ffffb3", "#bebada", "#fb8072", "#80b1d3", "#fdb462", "#b3de69", "#fccde5",
+ "#d9d9d9",
+];
+
+fn write_fun(
+ output: &mut String,
+ prefix: String,
+ fun: &crate::Function,
+ info: Option<&FunctionInfo>,
+ options: &Options,
+) -> Result<(), FmtError> {
+ writeln!(output, "\t\tnode [ style=filled ]")?;
+
+ if !options.cfg_only {
+ for (handle, var) in fun.local_variables.iter() {
+ writeln!(
+ output,
+ "\t\t{}_l{} [ shape=hexagon label=\"{:?} '{}'\" ]",
+ prefix,
+ handle.index(),
+ handle,
+ name(&var.name),
+ )?;
+ }
+
+ write_function_expressions(output, &prefix, fun, info)?;
+ }
+
+ let mut sg = StatementGraph::default();
+ sg.add(&fun.body, Targets::default());
+ for (index, label) in sg.nodes.into_iter().enumerate() {
+ writeln!(
+ output,
+ "\t\t{prefix}_s{index} [ shape=square label=\"{label}\" ]",
+ )?;
+ }
+ for (from, to, label) in sg.flow {
+ writeln!(
+ output,
+ "\t\t{prefix}_s{from} -> {prefix}_s{to} [ arrowhead=tee label=\"{label}\" ]",
+ )?;
+ }
+ for (from, to, label, color_id) in sg.jumps {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> {}_s{} [ arrowhead=tee style=dashed color=\"{}\" label=\"{}\" ]",
+ prefix, from, prefix, to, COLORS[color_id], label,
+ )?;
+ }
+
+ if !options.cfg_only {
+ for (to, expr, label) in sg.dependencies {
+ writeln!(
+ output,
+ "\t\t{}_e{} -> {}_s{} [ label=\"{}\" ]",
+ prefix,
+ expr.index(),
+ prefix,
+ to,
+ label,
+ )?;
+ }
+ for (from, to) in sg.emits {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> {}_e{} [ style=dotted ]",
+ prefix,
+ from,
+ prefix,
+ to.index(),
+ )?;
+ }
+ }
+
+ for (from, function) in sg.calls {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> f{}_s0",
+ prefix,
+ from,
+ function.index(),
+ )?;
+ }
+
+ Ok(())
+}
+
+fn write_function_expressions(
+ output: &mut String,
+ prefix: &str,
+ fun: &crate::Function,
+ info: Option<&FunctionInfo>,
+) -> Result<(), FmtError> {
+ enum Payload<'a> {
+ Arguments(&'a [Handle<crate::Expression>]),
+ Local(Handle<crate::LocalVariable>),
+ Global(Handle<crate::GlobalVariable>),
+ }
+
+ let mut edges = crate::FastHashMap::<&str, _>::default();
+ let mut payload = None;
+ for (handle, expression) in fun.expressions.iter() {
+ use crate::Expression as E;
+ let (label, color_id) = match *expression {
+ E::Access { base, index } => {
+ edges.insert("base", base);
+ edges.insert("index", index);
+ ("Access".into(), 1)
+ }
+ E::AccessIndex { base, index } => {
+ edges.insert("base", base);
+ (format!("AccessIndex[{index}]").into(), 1)
+ }
+ E::Constant(_) => ("Constant".into(), 2),
+ E::Splat { size, value } => {
+ edges.insert("value", value);
+ (format!("Splat{size:?}").into(), 3)
+ }
+ E::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ edges.insert("vector", vector);
+ (format!("Swizzle{:?}", &pattern[..size as usize]).into(), 3)
+ }
+ E::Compose { ref components, .. } => {
+ payload = Some(Payload::Arguments(components));
+ ("Compose".into(), 3)
+ }
+ E::FunctionArgument(index) => (format!("Argument[{index}]").into(), 1),
+ E::GlobalVariable(h) => {
+ payload = Some(Payload::Global(h));
+ ("Global".into(), 2)
+ }
+ E::LocalVariable(h) => {
+ payload = Some(Payload::Local(h));
+ ("Local".into(), 1)
+ }
+ E::Load { pointer } => {
+ edges.insert("pointer", pointer);
+ ("Load".into(), 4)
+ }
+ E::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset: _,
+ level,
+ depth_ref,
+ } => {
+ edges.insert("image", image);
+ edges.insert("sampler", sampler);
+ edges.insert("coordinate", coordinate);
+ if let Some(expr) = array_index {
+ edges.insert("array_index", expr);
+ }
+ match level {
+ crate::SampleLevel::Auto => {}
+ crate::SampleLevel::Zero => {}
+ crate::SampleLevel::Exact(expr) => {
+ edges.insert("level", expr);
+ }
+ crate::SampleLevel::Bias(expr) => {
+ edges.insert("bias", expr);
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ edges.insert("grad_x", x);
+ edges.insert("grad_y", y);
+ }
+ }
+ if let Some(expr) = depth_ref {
+ edges.insert("depth_ref", expr);
+ }
+ let string = match gather {
+ Some(component) => Cow::Owned(format!("ImageGather{component:?}")),
+ _ => Cow::Borrowed("ImageSample"),
+ };
+ (string, 5)
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ edges.insert("image", image);
+ edges.insert("coordinate", coordinate);
+ if let Some(expr) = array_index {
+ edges.insert("array_index", expr);
+ }
+ if let Some(sample) = sample {
+ edges.insert("sample", sample);
+ }
+ if let Some(level) = level {
+ edges.insert("level", level);
+ }
+ ("ImageLoad".into(), 5)
+ }
+ E::ImageQuery { image, query } => {
+ edges.insert("image", image);
+ let args = match query {
+ crate::ImageQuery::Size { level } => {
+ if let Some(expr) = level {
+ edges.insert("level", expr);
+ }
+ Cow::from("ImageSize")
+ }
+ _ => Cow::Owned(format!("{query:?}")),
+ };
+ (args, 7)
+ }
+ E::Unary { op, expr } => {
+ edges.insert("expr", expr);
+ (format!("{op:?}").into(), 6)
+ }
+ E::Binary { op, left, right } => {
+ edges.insert("left", left);
+ edges.insert("right", right);
+ (format!("{op:?}").into(), 6)
+ }
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ edges.insert("condition", condition);
+ edges.insert("accept", accept);
+ edges.insert("reject", reject);
+ ("Select".into(), 3)
+ }
+ E::Derivative { axis, ctrl, expr } => {
+ edges.insert("", expr);
+ (format!("d{axis:?}{ctrl:?}").into(), 8)
+ }
+ E::Relational { fun, argument } => {
+ edges.insert("arg", argument);
+ (format!("{fun:?}").into(), 6)
+ }
+ E::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ edges.insert("arg", arg);
+ if let Some(expr) = arg1 {
+ edges.insert("arg1", expr);
+ }
+ if let Some(expr) = arg2 {
+ edges.insert("arg2", expr);
+ }
+ if let Some(expr) = arg3 {
+ edges.insert("arg3", expr);
+ }
+ (format!("{fun:?}").into(), 7)
+ }
+ E::As {
+ kind,
+ expr,
+ convert,
+ } => {
+ edges.insert("", expr);
+ let string = match convert {
+ Some(width) => format!("Convert<{kind:?},{width}>"),
+ None => format!("Bitcast<{kind:?}>"),
+ };
+ (string.into(), 3)
+ }
+ E::CallResult(_function) => ("CallResult".into(), 4),
+ E::AtomicResult { .. } => ("AtomicResult".into(), 4),
+ E::ArrayLength(expr) => {
+ edges.insert("", expr);
+ ("ArrayLength".into(), 7)
+ }
+ E::RayQueryProceedResult => ("rayQueryProceedResult".into(), 4),
+ E::RayQueryGetIntersection { query, committed } => {
+ edges.insert("", query);
+ let ty = if committed { "Committed" } else { "Candidate" };
+ (format!("rayQueryGet{}Intersection", ty).into(), 4)
+ }
+ };
+
+ // give uniform expressions an outline
+ let color_attr = match info {
+ Some(info) if info[handle].uniformity.non_uniform_result.is_none() => "fillcolor",
+ _ => "color",
+ };
+ writeln!(
+ output,
+ "\t\t{}_e{} [ {}=\"{}\" label=\"{:?} {}\" ]",
+ prefix,
+ handle.index(),
+ color_attr,
+ COLORS[color_id],
+ handle,
+ label,
+ )?;
+
+ for (key, edge) in edges.drain() {
+ writeln!(
+ output,
+ "\t\t{}_e{} -> {}_e{} [ label=\"{}\" ]",
+ prefix,
+ edge.index(),
+ prefix,
+ handle.index(),
+ key,
+ )?;
+ }
+ match payload.take() {
+ Some(Payload::Arguments(list)) => {
+ write!(output, "\t\t{{")?;
+ for &comp in list {
+ write!(output, " {}_e{}", prefix, comp.index())?;
+ }
+ writeln!(output, " }} -> {}_e{}", prefix, handle.index())?;
+ }
+ Some(Payload::Local(h)) => {
+ writeln!(
+ output,
+ "\t\t{}_l{} -> {}_e{}",
+ prefix,
+ h.index(),
+ prefix,
+ handle.index(),
+ )?;
+ }
+ Some(Payload::Global(h)) => {
+ writeln!(
+ output,
+ "\t\tg{} -> {}_e{} [fillcolor=gray]",
+ h.index(),
+ prefix,
+ handle.index(),
+ )?;
+ }
+ None => {}
+ }
+ }
+
+ Ok(())
+}
+
+/// Write shader module to a [`String`].
+pub fn write(
+ module: &crate::Module,
+ mod_info: Option<&ModuleInfo>,
+ options: Options,
+) -> Result<String, FmtError> {
+ use std::fmt::Write as _;
+
+ let mut output = String::new();
+ output += "digraph Module {\n";
+
+ if !options.cfg_only {
+ writeln!(output, "\tsubgraph cluster_globals {{")?;
+ writeln!(output, "\t\tlabel=\"Globals\"")?;
+ for (handle, var) in module.global_variables.iter() {
+ writeln!(
+ output,
+ "\t\tg{} [ shape=hexagon label=\"{:?} {:?}/'{}'\" ]",
+ handle.index(),
+ handle,
+ var.space,
+ name(&var.name),
+ )?;
+ }
+ writeln!(output, "\t}}")?;
+ }
+
+ for (handle, fun) in module.functions.iter() {
+ let prefix = format!("f{}", handle.index());
+ writeln!(output, "\tsubgraph cluster_{prefix} {{")?;
+ writeln!(
+ output,
+ "\t\tlabel=\"Function{:?}/'{}'\"",
+ handle,
+ name(&fun.name)
+ )?;
+ let info = mod_info.map(|a| &a[handle]);
+ write_fun(&mut output, prefix, fun, info, &options)?;
+ writeln!(output, "\t}}")?;
+ }
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let prefix = format!("ep{ep_index}");
+ writeln!(output, "\tsubgraph cluster_{prefix} {{")?;
+ writeln!(output, "\t\tlabel=\"{:?}/'{}'\"", ep.stage, ep.name)?;
+ let info = mod_info.map(|a| a.get_entry_point(ep_index));
+ write_fun(&mut output, prefix, &ep.function, info, &options)?;
+ writeln!(output, "\t}}")?;
+ }
+
+ output += "}\n";
+ Ok(output)
+}
diff --git a/third_party/rust/naga/src/back/glsl/features.rs b/third_party/rust/naga/src/back/glsl/features.rs
new file mode 100644
index 0000000000..8392969b56
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/features.rs
@@ -0,0 +1,512 @@
+use super::{BackendResult, Error, Version, Writer};
+use crate::{
+ AddressSpace, Binding, Bytes, Expression, Handle, ImageClass, ImageDimension, Interpolation,
+ Sampling, ScalarKind, ShaderStage, StorageFormat, Type, TypeInner,
+};
+use std::fmt::Write;
+
+bitflags::bitflags! {
+ /// Structure used to encode additions to GLSL that aren't supported by all versions.
+ pub struct Features: u32 {
+ /// Buffer address space support.
+ const BUFFER_STORAGE = 1;
+ const ARRAY_OF_ARRAYS = 1 << 1;
+ /// 8 byte floats.
+ const DOUBLE_TYPE = 1 << 2;
+ /// More image formats.
+ const FULL_IMAGE_FORMATS = 1 << 3;
+ const MULTISAMPLED_TEXTURES = 1 << 4;
+ const MULTISAMPLED_TEXTURE_ARRAYS = 1 << 5;
+ const CUBE_TEXTURES_ARRAY = 1 << 6;
+ const COMPUTE_SHADER = 1 << 7;
+ /// Image load and early depth tests.
+ const IMAGE_LOAD_STORE = 1 << 8;
+ const CONSERVATIVE_DEPTH = 1 << 9;
+ /// Interpolation and auxiliary qualifiers.
+ ///
+ /// Perspective, Flat, and Centroid are available in all GLSL versions we support.
+ const NOPERSPECTIVE_QUALIFIER = 1 << 11;
+ const SAMPLE_QUALIFIER = 1 << 12;
+ const CLIP_DISTANCE = 1 << 13;
+ const CULL_DISTANCE = 1 << 14;
+ /// Sample ID.
+ const SAMPLE_VARIABLES = 1 << 15;
+ /// Arrays with a dynamic length.
+ const DYNAMIC_ARRAY_SIZE = 1 << 16;
+ const MULTI_VIEW = 1 << 17;
+ /// Texture samples query
+ const TEXTURE_SAMPLES = 1 << 18;
+ /// Texture levels query
+ const TEXTURE_LEVELS = 1 << 19;
+ /// Image size query
+ const IMAGE_SIZE = 1 << 20;
+ }
+}
+
+/// Helper structure used to store the required [`Features`] needed to output a
+/// [`Module`](crate::Module)
+///
+/// Provides helper methods to check for availability and writing required extensions
+pub struct FeaturesManager(Features);
+
+impl FeaturesManager {
+ /// Creates a new [`FeaturesManager`] instance
+ pub const fn new() -> Self {
+ Self(Features::empty())
+ }
+
+ /// Adds to the list of required [`Features`]
+ pub fn request(&mut self, features: Features) {
+ self.0 |= features
+ }
+
+ /// Checks that all required [`Features`] are available for the specified
+ /// [`Version`](super::Version) otherwise returns an
+ /// [`Error::MissingFeatures`](super::Error::MissingFeatures)
+ pub fn check_availability(&self, version: Version) -> BackendResult {
+ // Will store all the features that are unavailable
+ let mut missing = Features::empty();
+
+ // Helper macro to check for feature availability
+ macro_rules! check_feature {
+ // Used when only core glsl supports the feature
+ ($feature:ident, $core:literal) => {
+ if self.0.contains(Features::$feature)
+ && (version < Version::Desktop($core) || version.is_es())
+ {
+ missing |= Features::$feature;
+ }
+ };
+ // Used when both core and es support the feature
+ ($feature:ident, $core:literal, $es:literal) => {
+ if self.0.contains(Features::$feature)
+ && (version < Version::Desktop($core) || version < Version::new_gles($es))
+ {
+ missing |= Features::$feature;
+ }
+ };
+ }
+
+ check_feature!(COMPUTE_SHADER, 420, 310);
+ check_feature!(BUFFER_STORAGE, 400, 310);
+ check_feature!(DOUBLE_TYPE, 150);
+ check_feature!(CUBE_TEXTURES_ARRAY, 130, 310);
+ check_feature!(MULTISAMPLED_TEXTURES, 150, 300);
+ check_feature!(MULTISAMPLED_TEXTURE_ARRAYS, 150, 310);
+ check_feature!(ARRAY_OF_ARRAYS, 120, 310);
+ check_feature!(IMAGE_LOAD_STORE, 130, 310);
+ check_feature!(CONSERVATIVE_DEPTH, 130, 300);
+ check_feature!(CONSERVATIVE_DEPTH, 130, 300);
+ check_feature!(NOPERSPECTIVE_QUALIFIER, 130);
+ check_feature!(SAMPLE_QUALIFIER, 400, 320);
+ check_feature!(CLIP_DISTANCE, 130, 300 /* with extension */);
+ check_feature!(CULL_DISTANCE, 450, 300 /* with extension */);
+ check_feature!(SAMPLE_VARIABLES, 400, 300);
+ check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
+ match version {
+ Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
+ _ => check_feature!(MULTI_VIEW, 140, 310),
+ };
+ // Only available on glsl core, this means that opengl es can't query the number
+ // of samples nor levels in a image and neither do bound checks on the sample nor
+ // the level argument of texelFecth
+ check_feature!(TEXTURE_SAMPLES, 150);
+ check_feature!(TEXTURE_LEVELS, 130);
+ check_feature!(IMAGE_SIZE, 430, 310);
+
+ // Return an error if there are missing features
+ if missing.is_empty() {
+ Ok(())
+ } else {
+ Err(Error::MissingFeatures(missing))
+ }
+ }
+
+ /// Helper method used to write all needed extensions
+ ///
+ /// # Notes
+ /// This won't check for feature availability so it might output extensions that aren't even
+ /// supported.[`check_availability`](Self::check_availability) will check feature availability
+ pub fn write(&self, version: Version, mut out: impl Write) -> BackendResult {
+ if self.0.contains(Features::COMPUTE_SHADER) && !version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_compute_shader.txt
+ writeln!(out, "#extension GL_ARB_compute_shader : require")?;
+ }
+
+ if self.0.contains(Features::BUFFER_STORAGE) && !version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_storage_buffer_object.txt
+ writeln!(
+ out,
+ "#extension GL_ARB_shader_storage_buffer_object : require"
+ )?;
+ }
+
+ if self.0.contains(Features::DOUBLE_TYPE) && version < Version::Desktop(400) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_gpu_shader_fp64.txt
+ writeln!(out, "#extension GL_ARB_gpu_shader_fp64 : require")?;
+ }
+
+ if self.0.contains(Features::CUBE_TEXTURES_ARRAY) {
+ if version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_cube_map_array.txt
+ writeln!(out, "#extension GL_EXT_texture_cube_map_array : require")?;
+ } else if version < Version::Desktop(400) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_cube_map_array.txt
+ writeln!(out, "#extension GL_ARB_texture_cube_map_array : require")?;
+ }
+ }
+
+ if self.0.contains(Features::MULTISAMPLED_TEXTURE_ARRAYS) && version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_texture_storage_multisample_2d_array.txt
+ writeln!(
+ out,
+ "#extension GL_OES_texture_storage_multisample_2d_array : require"
+ )?;
+ }
+
+ if self.0.contains(Features::ARRAY_OF_ARRAYS) && version < Version::Desktop(430) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_arrays_of_arrays.txt
+ writeln!(out, "#extension ARB_arrays_of_arrays : require")?;
+ }
+
+ if self.0.contains(Features::IMAGE_LOAD_STORE) {
+ if self.0.contains(Features::FULL_IMAGE_FORMATS) && version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/NV/NV_image_formats.txt
+ writeln!(out, "#extension GL_NV_image_formats : require")?;
+ }
+
+ if version < Version::Desktop(420) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_image_load_store.txt
+ writeln!(out, "#extension GL_ARB_shader_image_load_store : require")?;
+ }
+ }
+
+ if self.0.contains(Features::CONSERVATIVE_DEPTH) {
+ if version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_conservative_depth.txt
+ writeln!(out, "#extension GL_EXT_conservative_depth : require")?;
+ }
+
+ if version < Version::Desktop(420) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt
+ writeln!(out, "#extension GL_ARB_conservative_depth : require")?;
+ }
+ }
+
+ if (self.0.contains(Features::CLIP_DISTANCE) || self.0.contains(Features::CULL_DISTANCE))
+ && version.is_es()
+ {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_clip_cull_distance.txt
+ writeln!(out, "#extension GL_EXT_clip_cull_distance : require")?;
+ }
+
+ if self.0.contains(Features::SAMPLE_VARIABLES) && version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_sample_variables.txt
+ writeln!(out, "#extension GL_OES_sample_variables : require")?;
+ }
+
+ if self.0.contains(Features::SAMPLE_VARIABLES) && version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_sample_variables.txt
+ writeln!(out, "#extension GL_OES_sample_variables : require")?;
+ }
+
+ if self.0.contains(Features::MULTI_VIEW) {
+ if let Version::Embedded { is_webgl: true, .. } = version {
+ // https://www.khronos.org/registry/OpenGL/extensions/OVR/OVR_multiview2.txt
+ writeln!(out, "#extension GL_OVR_multiview2 : require")?;
+ } else {
+ // https://github.com/KhronosGroup/GLSL/blob/master/extensions/ext/GL_EXT_multiview.txt
+ writeln!(out, "#extension GL_EXT_multiview : require")?;
+ }
+ }
+
+ if self.0.contains(Features::TEXTURE_SAMPLES) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_texture_image_samples.txt
+ writeln!(
+ out,
+ "#extension GL_ARB_shader_texture_image_samples : require"
+ )?;
+ }
+
+ if self.0.contains(Features::TEXTURE_LEVELS) && version < Version::Desktop(430) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_query_levels.txt
+ writeln!(out, "#extension GL_ARB_texture_query_levels : require")?;
+ }
+
+ Ok(())
+ }
+}
+
+impl<'a, W> Writer<'a, W> {
+ /// Helper method that searches the module for all the needed [`Features`]
+ ///
+ /// # Errors
+ /// If the version doesn't support any of the needed [`Features`] a
+ /// [`Error::MissingFeatures`](super::Error::MissingFeatures) will be returned
+ pub(super) fn collect_required_features(&mut self) -> BackendResult {
+ let ep_info = self.info.get_entry_point(self.entry_point_idx as usize);
+
+ if let Some(depth_test) = self.entry_point.early_depth_test {
+ // If IMAGE_LOAD_STORE is supported for this version of GLSL
+ if self.options.version.supports_early_depth_test() {
+ self.features.request(Features::IMAGE_LOAD_STORE);
+ }
+
+ if depth_test.conservative.is_some() {
+ self.features.request(Features::CONSERVATIVE_DEPTH);
+ }
+ }
+
+ for arg in self.entry_point.function.arguments.iter() {
+ self.varying_required_features(arg.binding.as_ref(), arg.ty);
+ }
+ if let Some(ref result) = self.entry_point.function.result {
+ self.varying_required_features(result.binding.as_ref(), result.ty);
+ }
+
+ if let ShaderStage::Compute = self.entry_point.stage {
+ self.features.request(Features::COMPUTE_SHADER)
+ }
+
+ if self.multiview.is_some() {
+ self.features.request(Features::MULTI_VIEW);
+ }
+
+ for (ty_handle, ty) in self.module.types.iter() {
+ match ty.inner {
+ TypeInner::Scalar { kind, width } => self.scalar_required_features(kind, width),
+ TypeInner::Vector { kind, width, .. } => self.scalar_required_features(kind, width),
+ TypeInner::Matrix { width, .. } => {
+ self.scalar_required_features(ScalarKind::Float, width)
+ }
+ TypeInner::Array { base, size, .. } => {
+ if let TypeInner::Array { .. } = self.module.types[base].inner {
+ self.features.request(Features::ARRAY_OF_ARRAYS)
+ }
+
+ // If the array is dynamically sized
+ if size == crate::ArraySize::Dynamic {
+ let mut is_used = false;
+
+ // Check if this type is used in a global that is needed by the current entrypoint
+ for (global_handle, global) in self.module.global_variables.iter() {
+ // Skip unused globals
+ if ep_info[global_handle].is_empty() {
+ continue;
+ }
+
+ // If this array is the type of a global, then this array is used
+ if global.ty == ty_handle {
+ is_used = true;
+ break;
+ }
+
+ // If the type of this global is a struct
+ if let crate::TypeInner::Struct { ref members, .. } =
+ self.module.types[global.ty].inner
+ {
+ // Check the last element of the struct to see if it's type uses
+ // this array
+ if let Some(last) = members.last() {
+ if last.ty == ty_handle {
+ is_used = true;
+ break;
+ }
+ }
+ }
+ }
+
+ // If this dynamically size array is used, we need dynamic array size support
+ if is_used {
+ self.features.request(Features::DYNAMIC_ARRAY_SIZE);
+ }
+ }
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ if arrayed && dim == ImageDimension::Cube {
+ self.features.request(Features::CUBE_TEXTURES_ARRAY)
+ }
+
+ match class {
+ ImageClass::Sampled { multi: true, .. }
+ | ImageClass::Depth { multi: true } => {
+ self.features.request(Features::MULTISAMPLED_TEXTURES);
+ if arrayed {
+ self.features.request(Features::MULTISAMPLED_TEXTURE_ARRAYS);
+ }
+ }
+ ImageClass::Storage { format, .. } => match format {
+ StorageFormat::R8Unorm
+ | StorageFormat::R8Snorm
+ | StorageFormat::R8Uint
+ | StorageFormat::R8Sint
+ | StorageFormat::R16Uint
+ | StorageFormat::R16Sint
+ | StorageFormat::R16Float
+ | StorageFormat::Rg8Unorm
+ | StorageFormat::Rg8Snorm
+ | StorageFormat::Rg8Uint
+ | StorageFormat::Rg8Sint
+ | StorageFormat::Rg16Uint
+ | StorageFormat::Rg16Sint
+ | StorageFormat::Rg16Float
+ | StorageFormat::Rgb10a2Unorm
+ | StorageFormat::Rg11b10Float
+ | StorageFormat::Rg32Uint
+ | StorageFormat::Rg32Sint
+ | StorageFormat::Rg32Float => {
+ self.features.request(Features::FULL_IMAGE_FORMATS)
+ }
+ _ => {}
+ },
+ ImageClass::Sampled { multi: false, .. }
+ | ImageClass::Depth { multi: false } => {}
+ }
+ }
+ _ => {}
+ }
+ }
+
+ let mut push_constant_used = false;
+
+ for (handle, global) in self.module.global_variables.iter() {
+ if ep_info[handle].is_empty() {
+ continue;
+ }
+ match global.space {
+ AddressSpace::WorkGroup => self.features.request(Features::COMPUTE_SHADER),
+ AddressSpace::Storage { .. } => self.features.request(Features::BUFFER_STORAGE),
+ AddressSpace::PushConstant => {
+ if push_constant_used {
+ return Err(Error::MultiplePushConstants);
+ }
+ push_constant_used = true;
+ }
+ _ => {}
+ }
+ }
+
+ // We will need to pass some of the members to a closure, so we need
+ // to separate them otherwise the borrow checker will complain, this
+ // shouldn't be needed in rust 2021
+ let &mut Self {
+ module,
+ info,
+ ref mut features,
+ entry_point,
+ entry_point_idx,
+ ref policies,
+ ..
+ } = self;
+
+ // Loop trough all expressions in both functions and the entry point
+ // to check for needed features
+ for (expressions, info) in module
+ .functions
+ .iter()
+ .map(|(h, f)| (&f.expressions, &info[h]))
+ .chain(std::iter::once((
+ &entry_point.function.expressions,
+ info.get_entry_point(entry_point_idx as usize),
+ )))
+ {
+ for (_, expr) in expressions.iter() {
+ match *expr {
+ // Check for queries that neeed aditonal features
+ Expression::ImageQuery {
+ image,
+ query,
+ ..
+ } => match query {
+ // Storage images use `imageSize` which is only available
+ // in glsl > 420
+ //
+ // layers queries are also implemented as size queries
+ crate::ImageQuery::Size { .. } | crate::ImageQuery::NumLayers => {
+ if let TypeInner::Image {
+ class: crate::ImageClass::Storage { .. }, ..
+ } = *info[image].ty.inner_with(&module.types) {
+ features.request(Features::IMAGE_SIZE)
+ }
+ },
+ crate::ImageQuery::NumLevels => features.request(Features::TEXTURE_LEVELS),
+ crate::ImageQuery::NumSamples => features.request(Features::TEXTURE_SAMPLES),
+ }
+ ,
+ // Check for image loads that needs bound checking on the sample
+ // or level argument since this requires a feature
+ Expression::ImageLoad {
+ sample, level, ..
+ } => {
+ if policies.image != crate::proc::BoundsCheckPolicy::Unchecked {
+ if sample.is_some() {
+ features.request(Features::TEXTURE_SAMPLES)
+ }
+
+ if level.is_some() {
+ features.request(Features::TEXTURE_LEVELS)
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+
+ self.features.check_availability(self.options.version)
+ }
+
+ /// Helper method that checks the [`Features`] needed by a scalar
+ fn scalar_required_features(&mut self, kind: ScalarKind, width: Bytes) {
+ if kind == ScalarKind::Float && width == 8 {
+ self.features.request(Features::DOUBLE_TYPE);
+ }
+ }
+
+ fn varying_required_features(&mut self, binding: Option<&Binding>, ty: Handle<Type>) {
+ match self.module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for member in members {
+ self.varying_required_features(member.binding.as_ref(), member.ty);
+ }
+ }
+ _ => {
+ if let Some(binding) = binding {
+ match *binding {
+ Binding::BuiltIn(built_in) => match built_in {
+ crate::BuiltIn::ClipDistance => {
+ self.features.request(Features::CLIP_DISTANCE)
+ }
+ crate::BuiltIn::CullDistance => {
+ self.features.request(Features::CULL_DISTANCE)
+ }
+ crate::BuiltIn::SampleIndex => {
+ self.features.request(Features::SAMPLE_VARIABLES)
+ }
+ crate::BuiltIn::ViewIndex => {
+ self.features.request(Features::MULTI_VIEW)
+ }
+ _ => {}
+ },
+ Binding::Location {
+ location: _,
+ interpolation,
+ sampling,
+ } => {
+ if interpolation == Some(Interpolation::Linear) {
+ self.features.request(Features::NOPERSPECTIVE_QUALIFIER);
+ }
+ if sampling == Some(Sampling::Sample) {
+ self.features.request(Features::SAMPLE_QUALIFIER);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/glsl/keywords.rs b/third_party/rust/naga/src/back/glsl/keywords.rs
new file mode 100644
index 0000000000..5a2836c189
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/keywords.rs
@@ -0,0 +1,204 @@
+pub const RESERVED_KEYWORDS: &[&str] = &[
+ "attribute",
+ "const",
+ "uniform",
+ "varying",
+ "buffer",
+ "shared",
+ "coherent",
+ "volatile",
+ "restrict",
+ "readonly",
+ "writeonly",
+ "atomic_uint",
+ "layout",
+ "centroid",
+ "flat",
+ "smooth",
+ "noperspective",
+ "patch",
+ "sample",
+ "break",
+ "continue",
+ "do",
+ "for",
+ "while",
+ "switch",
+ "case",
+ "default",
+ "if",
+ "else",
+ "subroutine",
+ "in",
+ "out",
+ "inout",
+ "float",
+ "double",
+ "int",
+ "void",
+ "bool",
+ "true",
+ "false",
+ "invariant",
+ "precise",
+ "discard",
+ "return",
+ "mat2",
+ "mat3",
+ "mat4",
+ "dmat2",
+ "dmat3",
+ "dmat4",
+ "mat2x2",
+ "mat2x3",
+ "mat2x4",
+ "dmat2x2",
+ "dmat2x3",
+ "dmat2x4",
+ "mat3x2",
+ "mat3x3",
+ "mat3x4",
+ "dmat3x2",
+ "dmat3x3",
+ "dmat3x4",
+ "mat4x2",
+ "mat4x3",
+ "mat4x4",
+ "dmat4x2",
+ "dmat4x3",
+ "dmat4x4",
+ "vec2",
+ "vec3",
+ "vec4",
+ "ivec2",
+ "ivec3",
+ "ivec4",
+ "bvec2",
+ "bvec3",
+ "bvec4",
+ "dvec2",
+ "dvec3",
+ "dvec4",
+ "uint",
+ "uvec2",
+ "uvec3",
+ "uvec4",
+ "lowp",
+ "mediump",
+ "highp",
+ "precision",
+ "sampler1D",
+ "sampler2D",
+ "sampler3D",
+ "samplerCube",
+ "sampler1DShadow",
+ "sampler2DShadow",
+ "samplerCubeShadow",
+ "sampler1DArray",
+ "sampler2DArray",
+ "sampler1DArrayShadow",
+ "sampler2DArrayShadow",
+ "isampler1D",
+ "isampler2D",
+ "isampler3D",
+ "isamplerCube",
+ "isampler1DArray",
+ "isampler2DArray",
+ "usampler1D",
+ "usampler2D",
+ "usampler3D",
+ "usamplerCube",
+ "usampler1DArray",
+ "usampler2DArray",
+ "sampler2DRect",
+ "sampler2DRectShadow",
+ "isampler2D",
+ "Rect",
+ "usampler2DRect",
+ "samplerBuffer",
+ "isamplerBuffer",
+ "usamplerBuffer",
+ "sampler2DMS",
+ "isampler2DMS",
+ "usampler2DMS",
+ "sampler2DMSArray",
+ "isampler2DMSArray",
+ "usampler2DMSArray",
+ "samplerCubeArray",
+ "samplerCubeArrayShadow",
+ "isamplerCubeArray",
+ "usamplerCubeArray",
+ "image1D",
+ "iimage1D",
+ "uimage1D",
+ "image2D",
+ "iimage2D",
+ "uimage2D",
+ "image3D",
+ "iimage3D",
+ "uimage3D",
+ "image2DRect",
+ "iimage2DRect",
+ "uimage2DRect",
+ "imageCube",
+ "iimageCube",
+ "uimageCube",
+ "imageBuffer",
+ "iimageBuffer",
+ "uimageBuffer",
+ "image1DArray",
+ "iimage1DArray",
+ "uimage1DArray",
+ "image2DArray",
+ "iimage2DArray",
+ "uimage2DArray",
+ "imageCubeArray",
+ "iimageCubeArray",
+ "uimageCubeArray",
+ "image2DMS",
+ "iimage2DMS",
+ "uimage2DMS",
+ "image2DMSArray",
+ "iimage2DMSArray",
+ "uimage2DMSArraystruct",
+ "common",
+ "partition",
+ "active",
+ "asm",
+ "class",
+ "union",
+ "enum",
+ "typedef",
+ "template",
+ "this",
+ "resource",
+ "goto",
+ "inline",
+ "noinline",
+ "public",
+ "static",
+ "extern",
+ "external",
+ "interface",
+ "long",
+ "short",
+ "half",
+ "fixed",
+ "unsigned",
+ "superp",
+ "input",
+ "output",
+ "hvec2",
+ "hvec3",
+ "hvec4",
+ "fvec2",
+ "fvec3",
+ "fvec4",
+ "sampler3DRect",
+ "filter",
+ "sizeof",
+ "cast",
+ "namespace",
+ "using",
+ "main",
+];
diff --git a/third_party/rust/naga/src/back/glsl/mod.rs b/third_party/rust/naga/src/back/glsl/mod.rs
new file mode 100644
index 0000000000..9195b96837
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/mod.rs
@@ -0,0 +1,4114 @@
+/*!
+Backend for [GLSL][glsl] (OpenGL Shading Language).
+
+The main structure is [`Writer`], it maintains internal state that is used
+to output a [`Module`](crate::Module) into glsl
+
+# Supported versions
+### Core
+- 330
+- 400
+- 410
+- 420
+- 430
+- 450
+
+### ES
+- 300
+- 310
+
+[glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php
+*/
+
+// GLSL is mostly a superset of C but it also removes some parts of it this is a list of relevant
+// aspects for this backend.
+//
+// The most notable change is the introduction of the version preprocessor directive that must
+// always be the first line of a glsl file and is written as
+// `#version number profile`
+// `number` is the version itself (i.e. 300) and `profile` is the
+// shader profile we only support "core" and "es", the former is used in desktop applications and
+// the later is used in embedded contexts, mobile devices and browsers. Each one as it's own
+// versions (at the time of writing this the latest version for "core" is 460 and for "es" is 320)
+//
+// Other important preprocessor addition is the extension directive which is written as
+// `#extension name: behaviour`
+// Extensions provide increased features in a plugin fashion but they aren't required to be
+// supported hence why they are called extensions, that's why `behaviour` is used it specifies
+// whether the extension is strictly required or if it should only be enabled if needed. In our case
+// when we use extensions we set behaviour to `require` always.
+//
+// The only thing that glsl removes that makes a difference are pointers.
+//
+// Additions that are relevant for the backend are the discard keyword, the introduction of
+// vector, matrices, samplers, image types and functions that provide common shader operations
+
+pub use features::Features;
+
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, ShaderStage, TypeInner,
+};
+use features::FeaturesManager;
+use std::{
+ cmp::Ordering,
+ fmt,
+ fmt::{Error as FmtError, Write},
+};
+use thiserror::Error;
+
+/// Contains the features related code and the features querying method
+mod features;
+/// Contains a constant with a slice of all the reserved keywords RESERVED_KEYWORDS
+mod keywords;
+
+/// List of supported `core` GLSL versions.
+pub const SUPPORTED_CORE_VERSIONS: &[u16] = &[330, 400, 410, 420, 430, 440, 450];
+/// List of supported `es` GLSL versions.
+pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320];
+
+/// The suffix of the variable that will hold the calculated clamped level
+/// of detail for bounds checking in `ImageLoad`
+const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod";
+
+/// Mapping between resources and bindings.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, u8>;
+
+impl crate::AtomicFunction {
+ const fn to_glsl(self) -> &'static str {
+ match self {
+ Self::Add | Self::Subtract => "Add",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { compare: Some(_) } => "", //TODO
+ }
+ }
+}
+
+impl crate::AddressSpace {
+ const fn is_buffer(&self) -> bool {
+ match *self {
+ crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => true,
+ _ => false,
+ }
+ }
+
+ /// Whether a variable with this address space can be initialized
+ const fn initializable(&self) -> bool {
+ match *self {
+ crate::AddressSpace::Function | crate::AddressSpace::Private => true,
+ crate::AddressSpace::WorkGroup
+ | crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::Handle
+ | crate::AddressSpace::PushConstant => false,
+ }
+ }
+}
+
+/// A GLSL version.
+#[derive(Debug, Copy, Clone, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum Version {
+ /// `core` GLSL.
+ Desktop(u16),
+ /// `es` GLSL.
+ Embedded { version: u16, is_webgl: bool },
+}
+
+impl Version {
+ /// Create a new gles version
+ pub const fn new_gles(version: u16) -> Self {
+ Self::Embedded {
+ version,
+ is_webgl: false,
+ }
+ }
+
+ /// Returns true if self is `Version::Embedded` (i.e. is a es version)
+ const fn is_es(&self) -> bool {
+ match *self {
+ Version::Desktop(_) => false,
+ Version::Embedded { .. } => true,
+ }
+ }
+
+ /// Returns true if targetting WebGL
+ const fn is_webgl(&self) -> bool {
+ match *self {
+ Version::Desktop(_) => false,
+ Version::Embedded { is_webgl, .. } => is_webgl,
+ }
+ }
+
+ /// Checks the list of currently supported versions and returns true if it contains the
+ /// specified version
+ ///
+ /// # Notes
+ /// As an invalid version number will never be added to the supported version list
+ /// so this also checks for version validity
+ fn is_supported(&self) -> bool {
+ match *self {
+ Version::Desktop(v) => SUPPORTED_CORE_VERSIONS.contains(&v),
+ Version::Embedded { version: v, .. } => SUPPORTED_ES_VERSIONS.contains(&v),
+ }
+ }
+
+ /// Checks if the version supports all of the explicit layouts:
+ /// - `location=` qualifiers for bindings
+ /// - `binding=` qualifiers for resources
+ ///
+ /// Note: `location=` for vertex inputs and fragment outputs is supported
+ /// unconditionally for GLES 300.
+ fn supports_explicit_locations(&self) -> bool {
+ *self >= Version::Desktop(410) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_early_depth_test(&self) -> bool {
+ *self >= Version::Desktop(130) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_std430_layout(&self) -> bool {
+ *self >= Version::Desktop(430) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_fma_function(&self) -> bool {
+ *self >= Version::Desktop(400) || *self >= Version::new_gles(320)
+ }
+
+ fn supports_integer_functions(&self) -> bool {
+ *self >= Version::Desktop(400) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_derivative_control(&self) -> bool {
+ *self >= Version::Desktop(450)
+ }
+}
+
+impl PartialOrd for Version {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ match (*self, *other) {
+ (Version::Desktop(x), Version::Desktop(y)) => Some(x.cmp(&y)),
+ (Version::Embedded { version: x, .. }, Version::Embedded { version: y, .. }) => {
+ Some(x.cmp(&y))
+ }
+ _ => None,
+ }
+ }
+}
+
+impl fmt::Display for Version {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match *self {
+ Version::Desktop(v) => write!(f, "{v} core"),
+ Version::Embedded { version: v, .. } => write!(f, "{v} es"),
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Configuration flags for the [`Writer`].
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct WriterFlags: u32 {
+ /// Flip output Y and extend Z from (0, 1) to (-1, 1).
+ const ADJUST_COORDINATE_SPACE = 0x1;
+ /// Supports GL_EXT_texture_shadow_lod on the host, which provides
+ /// additional functions on shadows and arrays of shadows.
+ const TEXTURE_SHADOW_LOD = 0x2;
+ /// Include unused global variables, constants and functions. By default the output will exclude
+ /// global variables that are not used in the specified entrypoint (including indirect use),
+ /// all constant declarations, and functions that use excluded global variables.
+ const INCLUDE_UNUSED_ITEMS = 0x4;
+ /// Emit `PointSize` output builtin to vertex shaders, which is
+ /// required for drawing with `PointList` topology.
+ ///
+ /// https://registry.khronos.org/OpenGL/specs/es/3.2/GLSL_ES_Specification_3.20.html#built-in-language-variables
+ /// The variable gl_PointSize is intended for a shader to write the size of the point to be rasterized. It is measured in pixels.
+ /// If gl_PointSize is not written to, its value is undefined in subsequent pipe stages.
+ const FORCE_POINT_SIZE = 0x10;
+ }
+}
+
+/// Configuration used in the [`Writer`].
+#[derive(Debug, Clone)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// The GLSL version to be used.
+ pub version: Version,
+ /// Configuration flags for the [`Writer`].
+ pub writer_flags: WriterFlags,
+ /// Map of resources association to binding locations.
+ pub binding_map: BindingMap,
+ /// Should workgroup variables be zero initialized (by polyfilling)?
+ pub zero_initialize_workgroup_memory: bool,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ version: Version::new_gles(310),
+ writer_flags: WriterFlags::ADJUST_COORDINATE_SPACE,
+ binding_map: BindingMap::default(),
+ zero_initialize_workgroup_memory: true,
+ }
+ }
+}
+
+/// A subset of options meant to be changed per pipeline.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// The stage of the entry point.
+ pub shader_stage: ShaderStage,
+ /// The name of the entry point.
+ ///
+ /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown.
+ pub entry_point: String,
+ /// How many views to render to, if doing multiview rendering.
+ pub multiview: Option<std::num::NonZeroU32>,
+}
+
+/// Reflection info for texture mappings and uniforms.
+pub struct ReflectionInfo {
+ /// Mapping between texture names and variables/samplers.
+ pub texture_mapping: crate::FastHashMap<String, TextureMapping>,
+ /// Mapping between uniform variables and names.
+ pub uniforms: crate::FastHashMap<Handle<crate::GlobalVariable>, String>,
+}
+
+/// Mapping between a texture and its sampler, if it exists.
+///
+/// GLSL pre-Vulkan has no concept of separate textures and samplers. Instead, everything is a
+/// `gsamplerN` where `g` is the scalar type and `N` is the dimension. But naga uses separate textures
+/// and samplers in the IR, so the backend produces a [`FastHashMap`](crate::FastHashMap) with the texture name
+/// as a key and a [`TextureMapping`] as a value. This way, the user knows where to bind.
+///
+/// [`Storage`](crate::ImageClass::Storage) images produce `gimageN` and don't have an associated sampler,
+/// so the [`sampler`](Self::sampler) field will be [`None`].
+#[derive(Debug, Clone)]
+pub struct TextureMapping {
+ /// Handle to the image global variable.
+ pub texture: Handle<crate::GlobalVariable>,
+ /// Handle to the associated sampler global variable, if it exists.
+ pub sampler: Option<Handle<crate::GlobalVariable>>,
+}
+
+/// Helper structure that generates a number
+#[derive(Default)]
+struct IdGenerator(u32);
+
+impl IdGenerator {
+ /// Generates a number that's guaranteed to be unique for this `IdGenerator`
+ fn generate(&mut self) -> u32 {
+ // It's just an increasing number but it does the job
+ let ret = self.0;
+ self.0 += 1;
+ ret
+ }
+}
+
+/// Helper wrapper used to get a name for a varying
+///
+/// Varying have different naming schemes depending on their binding:
+/// - Varyings with builtin bindings get the from [`glsl_built_in`](glsl_built_in).
+/// - Varyings with location bindings are named `_S_location_X` where `S` is a
+/// prefix identifying which pipeline stage the varying connects, and `X` is
+/// the location.
+struct VaryingName<'a> {
+ binding: &'a crate::Binding,
+ stage: ShaderStage,
+ output: bool,
+ targetting_webgl: bool,
+}
+impl fmt::Display for VaryingName<'_> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match *self.binding {
+ crate::Binding::Location { location, .. } => {
+ let prefix = match (self.stage, self.output) {
+ (ShaderStage::Compute, _) => unreachable!(),
+ // pipeline to vertex
+ (ShaderStage::Vertex, false) => "p2vs",
+ // vertex to fragment
+ (ShaderStage::Vertex, true) | (ShaderStage::Fragment, false) => "vs2fs",
+ // fragment to pipeline
+ (ShaderStage::Fragment, true) => "fs2p",
+ };
+ write!(f, "_{prefix}_location{location}",)
+ }
+ crate::Binding::BuiltIn(built_in) => {
+ write!(
+ f,
+ "{}",
+ glsl_built_in(built_in, self.output, self.targetting_webgl)
+ )
+ }
+ }
+ }
+}
+
+impl ShaderStage {
+ const fn to_str(self) -> &'static str {
+ match self {
+ ShaderStage::Compute => "cs",
+ ShaderStage::Fragment => "fs",
+ ShaderStage::Vertex => "vs",
+ }
+ }
+}
+
+/// Shorthand result used internally by the backend
+type BackendResult<T = ()> = Result<T, Error>;
+
+/// A GLSL compilation error.
+#[derive(Debug, Error)]
+pub enum Error {
+ /// A error occurred while writing to the output.
+ #[error("Format error")]
+ FmtError(#[from] FmtError),
+ /// The specified [`Version`] doesn't have all required [`Features`].
+ ///
+ /// Contains the missing [`Features`].
+ #[error("The selected version doesn't support {0:?}")]
+ MissingFeatures(Features),
+ /// [`AddressSpace::PushConstant`](crate::AddressSpace::PushConstant) was used more than
+ /// once in the entry point, which isn't supported.
+ #[error("Multiple push constants aren't supported")]
+ MultiplePushConstants,
+ /// The specified [`Version`] isn't supported.
+ #[error("The specified version isn't supported")]
+ VersionNotSupported,
+ /// The entry point couldn't be found.
+ #[error("The requested entry point couldn't be found")]
+ EntryPointNotFound,
+ /// A call was made to an unsupported external.
+ #[error("A call was made to an unsupported external: {0}")]
+ UnsupportedExternal(String),
+ /// A scalar with an unsupported width was requested.
+ #[error("A scalar with an unsupported width was requested: {0:?} {1:?}")]
+ UnsupportedScalar(crate::ScalarKind, crate::Bytes),
+ /// A image was used with multiple samplers, which isn't supported.
+ #[error("A image was used with multiple samplers")]
+ ImageMultipleSamplers,
+ #[error("{0}")]
+ Custom(String),
+}
+
+/// Binary operation with a different logic on the GLSL side.
+enum BinaryOperation {
+ /// Vector comparison should use the function like `greaterThan()`, etc.
+ VectorCompare,
+ /// Vector component wise operation; used to polyfill unsupported ops like `|` and `&` for `bvecN`'s
+ VectorComponentWise,
+ /// GLSL `%` is SPIR-V `OpUMod/OpSMod` and `mod()` is `OpFMod`, but [`BinaryOperator::Modulo`](crate::BinaryOperator::Modulo) is `OpFRem`.
+ Modulo,
+ /// Any plain operation. No additional logic required.
+ Other,
+}
+
+/// Writer responsible for all code generation.
+pub struct Writer<'a, W> {
+ // Inputs
+ /// The module being written.
+ module: &'a crate::Module,
+ /// The module analysis.
+ info: &'a valid::ModuleInfo,
+ /// The output writer.
+ out: W,
+ /// User defined configuration to be used.
+ options: &'a Options,
+ /// The bound checking policies to be used
+ policies: proc::BoundsCheckPolicies,
+
+ // Internal State
+ /// Features manager used to store all the needed features and write them.
+ features: FeaturesManager,
+ namer: proc::Namer,
+ /// A map with all the names needed for writing the module
+ /// (generated by a [`Namer`](crate::proc::Namer)).
+ names: crate::FastHashMap<NameKey, String>,
+ /// A map with the names of global variables needed for reflections.
+ reflection_names_globals: crate::FastHashMap<Handle<crate::GlobalVariable>, String>,
+ /// The selected entry point.
+ entry_point: &'a crate::EntryPoint,
+ /// The index of the selected entry point.
+ entry_point_idx: proc::EntryPointIndex,
+ /// A generator for unique block numbers.
+ block_id: IdGenerator,
+ /// Set of expressions that have associated temporary variables.
+ named_expressions: crate::NamedExpressions,
+ /// Set of expressions that need to be baked to avoid unnecessary repetition in output
+ need_bake_expressions: back::NeedBakeExpressions,
+ /// How many views to render to, if doing multiview rendering.
+ multiview: Option<std::num::NonZeroU32>,
+}
+
+impl<'a, W: Write> Writer<'a, W> {
+ /// Creates a new [`Writer`] instance.
+ ///
+ /// # Errors
+ /// - If the version specified is invalid or supported.
+ /// - If the entry point couldn't be found in the module.
+ /// - If the version specified doesn't support some used features.
+ pub fn new(
+ out: W,
+ module: &'a crate::Module,
+ info: &'a valid::ModuleInfo,
+ options: &'a Options,
+ pipeline_options: &'a PipelineOptions,
+ policies: proc::BoundsCheckPolicies,
+ ) -> Result<Self, Error> {
+ // Check if the requested version is supported
+ if !options.version.is_supported() {
+ log::error!("Version {}", options.version);
+ return Err(Error::VersionNotSupported);
+ }
+
+ // Try to find the entry point and corresponding index
+ let ep_idx = module
+ .entry_points
+ .iter()
+ .position(|ep| {
+ pipeline_options.shader_stage == ep.stage && pipeline_options.entry_point == ep.name
+ })
+ .ok_or(Error::EntryPointNotFound)?;
+
+ // Generate a map with names required to write the module
+ let mut names = crate::FastHashMap::default();
+ let mut namer = proc::Namer::default();
+ namer.reset(module, keywords::RESERVED_KEYWORDS, &["gl_"], &mut names);
+
+ // Build the instance
+ let mut this = Self {
+ module,
+ info,
+ out,
+ options,
+ policies,
+
+ namer,
+ features: FeaturesManager::new(),
+ names,
+ reflection_names_globals: crate::FastHashMap::default(),
+ entry_point: &module.entry_points[ep_idx],
+ entry_point_idx: ep_idx as u16,
+ multiview: pipeline_options.multiview,
+ block_id: IdGenerator::default(),
+ named_expressions: Default::default(),
+ need_bake_expressions: Default::default(),
+ };
+
+ // Find all features required to print this module
+ this.collect_required_features()?;
+
+ Ok(this)
+ }
+
+ /// Writes the [`Module`](crate::Module) as glsl to the output
+ ///
+ /// # Notes
+ /// If an error occurs while writing, the output might have been written partially
+ ///
+ /// # Panics
+ /// Might panic if the module is invalid
+ pub fn write(&mut self) -> Result<ReflectionInfo, Error> {
+ // We use `writeln!(self.out)` throughout the write to add newlines
+ // to make the output more readable
+
+ let es = self.options.version.is_es();
+
+ // Write the version (It must be the first thing or it isn't a valid glsl output)
+ writeln!(self.out, "#version {}", self.options.version)?;
+ // Write all the needed extensions
+ //
+ // This used to be the last thing being written as it allowed to search for features while
+ // writing the module saving some loops but some older versions (420 or less) required the
+ // extensions to appear before being used, even though extensions are part of the
+ // preprocessor not the processor ¯\_(ツ)_/¯
+ self.features.write(self.options.version, &mut self.out)?;
+
+ // Write the additional extensions
+ if self
+ .options
+ .writer_flags
+ .contains(WriterFlags::TEXTURE_SHADOW_LOD)
+ {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_shadow_lod.txt
+ writeln!(self.out, "#extension GL_EXT_texture_shadow_lod : require")?;
+ }
+
+ // glsl es requires a precision to be specified for floats and ints
+ // TODO: Should this be user configurable?
+ if es {
+ writeln!(self.out)?;
+ writeln!(self.out, "precision highp float;")?;
+ writeln!(self.out, "precision highp int;")?;
+ writeln!(self.out)?;
+ }
+
+ if self.entry_point.stage == ShaderStage::Compute {
+ let workgroup_size = self.entry_point.workgroup_size;
+ writeln!(
+ self.out,
+ "layout(local_size_x = {}, local_size_y = {}, local_size_z = {}) in;",
+ workgroup_size[0], workgroup_size[1], workgroup_size[2]
+ )?;
+ writeln!(self.out)?;
+ }
+
+ // Enable early depth tests if needed
+ if let Some(depth_test) = self.entry_point.early_depth_test {
+ // If early depth test is supported for this version of GLSL
+ if self.options.version.supports_early_depth_test() {
+ writeln!(self.out, "layout(early_fragment_tests) in;")?;
+
+ if let Some(conservative) = depth_test.conservative {
+ use crate::ConservativeDepth as Cd;
+
+ let depth = match conservative {
+ Cd::GreaterEqual => "greater",
+ Cd::LessEqual => "less",
+ Cd::Unchanged => "unchanged",
+ };
+ writeln!(self.out, "layout (depth_{depth}) out float gl_FragDepth;")?;
+ }
+ writeln!(self.out)?;
+ } else {
+ log::warn!(
+ "Early depth testing is not supported for this version of GLSL: {}",
+ self.options.version
+ );
+ }
+ }
+
+ if self.entry_point.stage == ShaderStage::Vertex && self.options.version.is_webgl() {
+ if let Some(multiview) = self.multiview.as_ref() {
+ writeln!(self.out, "layout(num_views = {multiview}) in;")?;
+ writeln!(self.out)?;
+ }
+ }
+
+ let ep_info = self.info.get_entry_point(self.entry_point_idx as usize);
+
+ // Write struct types.
+ //
+ // This are always ordered because the IR is structured in a way that
+ // you can't make a struct without adding all of its members first.
+ for (handle, ty) in self.module.types.iter() {
+ if let TypeInner::Struct { ref members, .. } = ty.inner {
+ // Structures ending with runtime-sized arrays can only be
+ // rendered as shader storage blocks in GLSL, not stand-alone
+ // struct types.
+ if !self.module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&self.module.types)
+ {
+ let name = &self.names[&NameKey::Type(handle)];
+ write!(self.out, "struct {name} ")?;
+ self.write_struct_body(handle, members)?;
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+
+ // Write the globals
+ //
+ // Unless explicitly disabled with WriterFlags::INCLUDE_UNUSED_ITEMS,
+ // we filter all globals that aren't used by the selected entry point as they might be
+ // interfere with each other (i.e. two globals with the same location but different with
+ // different classes)
+ let include_unused = self
+ .options
+ .writer_flags
+ .contains(WriterFlags::INCLUDE_UNUSED_ITEMS);
+ for (handle, global) in self.module.global_variables.iter() {
+ let is_unused = ep_info[handle].is_empty();
+ if !include_unused && is_unused {
+ continue;
+ }
+
+ match self.module.types[global.ty].inner {
+ // We treat images separately because they might require
+ // writing the storage format
+ TypeInner::Image {
+ mut dim,
+ arrayed,
+ class,
+ } => {
+ // Gather the storage format if needed
+ let storage_format_access = match self.module.types[global.ty].inner {
+ TypeInner::Image {
+ class: crate::ImageClass::Storage { format, access },
+ ..
+ } => Some((format, access)),
+ _ => None,
+ };
+
+ if dim == crate::ImageDimension::D1 && es {
+ dim = crate::ImageDimension::D2
+ }
+
+ // Gether the location if needed
+ let layout_binding = if self.options.version.supports_explicit_locations() {
+ let br = global.binding.as_ref().unwrap();
+ self.options.binding_map.get(br).cloned()
+ } else {
+ None
+ };
+
+ // Write all the layout qualifiers
+ if layout_binding.is_some() || storage_format_access.is_some() {
+ write!(self.out, "layout(")?;
+ if let Some(binding) = layout_binding {
+ write!(self.out, "binding = {binding}")?;
+ }
+ if let Some((format, _)) = storage_format_access {
+ let format_str = glsl_storage_format(format);
+ let separator = match layout_binding {
+ Some(_) => ",",
+ None => "",
+ };
+ write!(self.out, "{separator}{format_str}")?;
+ }
+ write!(self.out, ") ")?;
+ }
+
+ if let Some((_, access)) = storage_format_access {
+ self.write_storage_access(access)?;
+ }
+
+ // All images in glsl are `uniform`
+ // The trailing space is important
+ write!(self.out, "uniform ")?;
+
+ // write the type
+ //
+ // This is way we need the leading space because `write_image_type` doesn't add
+ // any spaces at the beginning or end
+ self.write_image_type(dim, arrayed, class)?;
+
+ // Finally write the name and end the global with a `;`
+ // The leading space is important
+ let global_name = self.get_global_name(handle, global);
+ writeln!(self.out, " {global_name};")?;
+ writeln!(self.out)?;
+
+ self.reflection_names_globals.insert(handle, global_name);
+ }
+ // glsl has no concept of samplers so we just ignore it
+ TypeInner::Sampler { .. } => continue,
+ // All other globals are written by `write_global`
+ _ => {
+ self.write_global(handle, global)?;
+ // Add a newline (only for readability)
+ writeln!(self.out)?;
+ }
+ }
+ }
+
+ if include_unused {
+ // write named constants
+ for (handle, constant) in self.module.constants.iter() {
+ if let Some(name) = constant.name.as_ref() {
+ write!(self.out, "const ")?;
+ match constant.inner {
+ crate::ConstantInner::Scalar { width, value } => {
+ // create a TypeInner to write
+ let inner = TypeInner::Scalar {
+ width,
+ kind: value.scalar_kind(),
+ };
+ self.write_value_type(&inner)?;
+ }
+ crate::ConstantInner::Composite { ty, .. } => {
+ self.write_type(ty)?;
+ }
+ };
+ write!(self.out, " {name} = ")?;
+ self.write_constant(handle)?;
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+
+ for arg in self.entry_point.function.arguments.iter() {
+ self.write_varying(arg.binding.as_ref(), arg.ty, false)?;
+ }
+ if let Some(ref result) = self.entry_point.function.result {
+ self.write_varying(result.binding.as_ref(), result.ty, true)?;
+ }
+ writeln!(self.out)?;
+
+ // Write all regular functions
+ for (handle, function) in self.module.functions.iter() {
+ // Check that the function doesn't use globals that aren't supported
+ // by the current entry point
+ if !include_unused && !ep_info.dominates_global_use(&self.info[handle]) {
+ continue;
+ }
+
+ let fun_info = &self.info[handle];
+
+ // Write the function
+ self.write_function(back::FunctionType::Function(handle), function, fun_info)?;
+
+ writeln!(self.out)?;
+ }
+
+ self.write_function(
+ back::FunctionType::EntryPoint(self.entry_point_idx),
+ &self.entry_point.function,
+ ep_info,
+ )?;
+
+ // Add newline at the end of file
+ writeln!(self.out)?;
+
+ // Collect all reflection info and return it to the user
+ self.collect_reflection_info()
+ }
+
+ fn write_array_size(
+ &mut self,
+ base: Handle<crate::Type>,
+ size: crate::ArraySize,
+ ) -> BackendResult {
+ write!(self.out, "[")?;
+
+ // Write the array size
+ // Writes nothing if `ArraySize::Dynamic`
+ // Panics if `ArraySize::Constant` has a constant that isn't an sint or uint
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ match self.module.constants[const_handle].inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Uint(size),
+ } => write!(self.out, "{size}")?,
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Sint(size),
+ } => write!(self.out, "{size}")?,
+ _ => unreachable!(),
+ }
+ }
+ crate::ArraySize::Dynamic => (),
+ }
+
+ write!(self.out, "]")?;
+
+ if let TypeInner::Array {
+ base: next_base,
+ size: next_size,
+ ..
+ } = self.module.types[base].inner
+ {
+ self.write_array_size(next_base, next_size)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ ///
+ /// # Panics
+ /// - If type is either a image, a sampler, a pointer, or a struct
+ /// - If it's an Array with a [`ArraySize::Constant`](crate::ArraySize::Constant) with a
+ /// constant that isn't a [`Scalar`](crate::ConstantInner::Scalar) or if the
+ /// scalar value isn't an [`Sint`](crate::ScalarValue::Sint) or [`Uint`](crate::ScalarValue::Uint)
+ fn write_value_type(&mut self, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ // Scalars are simple we just get the full name from `glsl_scalar`
+ TypeInner::Scalar { kind, width }
+ | TypeInner::Atomic { kind, width }
+ | TypeInner::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space: _,
+ } => write!(self.out, "{}", glsl_scalar(kind, width)?.full)?,
+ // Vectors are just `gvecN` where `g` is the scalar prefix and `N` is the vector size
+ TypeInner::Vector { size, kind, width }
+ | TypeInner::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ space: _,
+ } => write!(
+ self.out,
+ "{}vec{}",
+ glsl_scalar(kind, width)?.prefix,
+ size as u8
+ )?,
+ // Matrices are written with `gmatMxN` where `g` is the scalar prefix (only floats and
+ // doubles are allowed), `M` is the columns count and `N` is the rows count
+ //
+ // glsl supports a matrix shorthand `gmatN` where `N` = `M` but it doesn't justify the
+ // extra branch to write matrices this way
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => write!(
+ self.out,
+ "{}mat{}x{}",
+ glsl_scalar(crate::ScalarKind::Float, width)?.prefix,
+ columns as u8,
+ rows as u8
+ )?,
+ // GLSL arrays are written as `type name[size]`
+ // Current code is written arrays only as `[size]`
+ // Base `type` and `name` should be written outside
+ TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?,
+ // Panic if either Image, Sampler, Pointer, or a Struct is being written
+ //
+ // Write all variants instead of `_` so that if new variants are added a
+ // no exhaustiveness error is thrown
+ TypeInner::Pointer { .. }
+ | TypeInner::Struct { .. }
+ | TypeInner::Image { .. }
+ | TypeInner::Sampler { .. }
+ | TypeInner::AccelerationStructure
+ | TypeInner::RayQuery
+ | TypeInner::BindingArray { .. } => {
+ return Err(Error::Custom(format!("Unable to write type {inner:?}")))
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ ///
+ /// # Panics
+ /// - If type is either a image or sampler
+ /// - If it's an Array with a [`ArraySize::Constant`](crate::ArraySize::Constant) with a
+ /// constant that isn't a [`Scalar`](crate::ConstantInner::Scalar) or if the
+ /// scalar value isn't an [`Sint`](crate::ScalarValue::Sint) or [`Uint`](crate::ScalarValue::Uint)
+ fn write_type(&mut self, ty: Handle<crate::Type>) -> BackendResult {
+ match self.module.types[ty].inner {
+ // glsl has no pointer types so just write types as normal and loads are skipped
+ TypeInner::Pointer { base, .. } => self.write_type(base),
+ // glsl structs are written as just the struct name
+ TypeInner::Struct { .. } => {
+ // Get the struct name
+ let name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{name}")?;
+ Ok(())
+ }
+ // glsl array has the size separated from the base type
+ TypeInner::Array { base, .. } => self.write_type(base),
+ ref other => self.write_value_type(other),
+ }
+ }
+
+ /// Helper method to write a image type
+ ///
+ /// # Notes
+ /// Adds no leading or trailing whitespace
+ fn write_image_type(
+ &mut self,
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ ) -> BackendResult {
+ // glsl images consist of four parts the scalar prefix, the image "type", the dimensions
+ // and modifiers
+ //
+ // There exists two image types
+ // - sampler - for sampled images
+ // - image - for storage images
+ //
+ // There are three possible modifiers that can be used together and must be written in
+ // this order to be valid
+ // - MS - used if it's a multisampled image
+ // - Array - used if it's an image array
+ // - Shadow - used if it's a depth image
+ use crate::ImageClass as Ic;
+
+ let (base, kind, ms, comparison) = match class {
+ Ic::Sampled { kind, multi: true } => ("sampler", kind, "MS", ""),
+ Ic::Sampled { kind, multi: false } => ("sampler", kind, "", ""),
+ Ic::Depth { multi: true } => ("sampler", crate::ScalarKind::Float, "MS", ""),
+ Ic::Depth { multi: false } => ("sampler", crate::ScalarKind::Float, "", "Shadow"),
+ Ic::Storage { format, .. } => ("image", format.into(), "", ""),
+ };
+
+ write!(
+ self.out,
+ "highp {}{}{}{}{}{}",
+ glsl_scalar(kind, 4)?.prefix,
+ base,
+ glsl_dimension(dim),
+ ms,
+ if arrayed { "Array" } else { "" },
+ comparison
+ )?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write non images/sampler globals
+ ///
+ /// # Notes
+ /// Adds a newline
+ ///
+ /// # Panics
+ /// If the global has type sampler
+ fn write_global(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ if self.options.version.supports_explicit_locations() {
+ if let Some(ref br) = global.binding {
+ match self.options.binding_map.get(br) {
+ Some(binding) => {
+ let layout = match global.space {
+ crate::AddressSpace::Storage { .. } => {
+ if self.options.version.supports_std430_layout() {
+ "std430, "
+ } else {
+ "std140, "
+ }
+ }
+ crate::AddressSpace::Uniform => "std140, ",
+ _ => "",
+ };
+ write!(self.out, "layout({layout}binding = {binding}) ")?
+ }
+ None => {
+ log::debug!("unassigned binding for {:?}", global.name);
+ if let crate::AddressSpace::Storage { .. } = global.space {
+ if self.options.version.supports_std430_layout() {
+ write!(self.out, "layout(std430) ")?
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if let crate::AddressSpace::Storage { access } = global.space {
+ self.write_storage_access(access)?;
+ }
+
+ if let Some(storage_qualifier) = glsl_storage_qualifier(global.space) {
+ write!(self.out, "{storage_qualifier} ")?;
+ }
+
+ match global.space {
+ crate::AddressSpace::Private => {
+ self.write_simple_global(handle, global)?;
+ }
+ crate::AddressSpace::WorkGroup => {
+ self.write_simple_global(handle, global)?;
+ }
+ crate::AddressSpace::PushConstant => {
+ self.write_simple_global(handle, global)?;
+ }
+ crate::AddressSpace::Uniform => {
+ self.write_interface_block(handle, global)?;
+ }
+ crate::AddressSpace::Storage { .. } => {
+ self.write_interface_block(handle, global)?;
+ }
+ // A global variable in the `Function` address space is a
+ // contradiction in terms.
+ crate::AddressSpace::Function => unreachable!(),
+ // Textures and samplers are handled directly in `Writer::write`.
+ crate::AddressSpace::Handle => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ fn write_simple_global(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ self.write_type(global.ty)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+
+ if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+
+ if global.space.initializable() && is_value_init_supported(self.module, global.ty) {
+ write!(self.out, " = ")?;
+ if let Some(init) = global.init {
+ self.write_constant(init)?;
+ } else {
+ self.write_zero_init_value(global.ty)?;
+ }
+ }
+
+ writeln!(self.out, ";")?;
+
+ if let crate::AddressSpace::PushConstant = global.space {
+ let global_name = self.get_global_name(handle, global);
+ self.reflection_names_globals.insert(handle, global_name);
+ }
+
+ Ok(())
+ }
+
+ /// Write an interface block for a single Naga global.
+ ///
+ /// Write `block_name { members }`. Since `block_name` must be unique
+ /// between blocks and structs, we add `_block_ID` where `ID` is a
+ /// `IdGenerator` generated number. Write `members` in the same way we write
+ /// a struct's members.
+ fn write_interface_block(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ // Write the block name, it's just the struct name appended with `_block_ID`
+ let ty_name = &self.names[&NameKey::Type(global.ty)];
+ let block_name = format!(
+ "{}_block_{}{:?}",
+ ty_name,
+ self.block_id.generate(),
+ self.entry_point.stage,
+ );
+ write!(self.out, "{block_name} ")?;
+ self.reflection_names_globals.insert(handle, block_name);
+
+ match self.module.types[global.ty].inner {
+ crate::TypeInner::Struct { ref members, .. }
+ if self.module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&self.module.types) =>
+ {
+ // Structs with dynamically sized arrays must have their
+ // members lifted up as members of the interface block. GLSL
+ // can't write such struct types anyway.
+ self.write_struct_body(global.ty, members)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+ }
+ _ => {
+ // A global of any other type is written as the sole member
+ // of the interface block. Since the interface block is
+ // anonymous, this becomes visible in the global scope.
+ write!(self.out, "{{ ")?;
+ self.write_type(global.ty)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, "; }}")?;
+ }
+ }
+
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// Clears `need_bake_expressions` set before adding to it
+ fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+ for (fun_handle, expr) in func.expressions.iter() {
+ let expr_info = &info[fun_handle];
+ let min_ref_count = func.expressions[fun_handle].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(fun_handle);
+ }
+
+ let inner = expr_info.ty.inner_with(&self.module.types);
+
+ if let Expression::Math { fun, arg, arg1, .. } = *expr {
+ match fun {
+ crate::MathFunction::Dot => {
+ // if the expression is a Dot product with integer arguments,
+ // then the args needs baking as well
+ if let TypeInner::Scalar { kind, .. } = *inner {
+ match kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ }
+ _ => {}
+ }
+ }
+ }
+ crate::MathFunction::CountLeadingZeros => {
+ if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ self.need_bake_expressions.insert(arg);
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+
+ /// Helper method used to get a name for a global
+ ///
+ /// Globals have different naming schemes depending on their binding:
+ /// - Globals without bindings use the name from the [`Namer`](crate::proc::Namer)
+ /// - Globals with resource binding are named `_group_X_binding_Y` where `X`
+ /// is the group and `Y` is the binding
+ fn get_global_name(
+ &self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> String {
+ match global.binding {
+ Some(ref br) => {
+ format!(
+ "_group_{}_binding_{}_{}",
+ br.group,
+ br.binding,
+ self.entry_point.stage.to_str()
+ )
+ }
+ None => self.names[&NameKey::GlobalVariable(handle)].clone(),
+ }
+ }
+
+ /// Helper method used to write a name for a global without additional heap allocation
+ fn write_global_name(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ match global.binding {
+ Some(ref br) => write!(
+ self.out,
+ "_group_{}_binding_{}_{}",
+ br.group,
+ br.binding,
+ self.entry_point.stage.to_str()
+ )?,
+ None => write!(
+ self.out,
+ "{}",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?,
+ }
+
+ Ok(())
+ }
+
+ /// Write a GLSL global that will carry a Naga entry point's argument or return value.
+ ///
+ /// A Naga entry point's arguments and return value are rendered in GLSL as
+ /// variables at global scope with the `in` and `out` storage qualifiers.
+ /// The code we generate for `main` loads from all the `in` globals into
+ /// appropriately named locals. Before it returns, `main` assigns the
+ /// components of its return value into all the `out` globals.
+ ///
+ /// This function writes a declaration for one such GLSL global,
+ /// representing a value passed into or returned from [`self.entry_point`]
+ /// that has a [`Location`] binding. The global's name is generated based on
+ /// the location index and the shader stages being connected; see
+ /// [`VaryingName`]. This means we don't need to know the names of
+ /// arguments, just their types and bindings.
+ ///
+ /// Emit nothing for entry point arguments or return values with [`BuiltIn`]
+ /// bindings; `main` will read from or assign to the appropriate GLSL
+ /// special variable; these are pre-declared. As an exception, we do declare
+ /// `gl_Position` or `gl_FragCoord` with the `invariant` qualifier if
+ /// needed.
+ ///
+ /// Use `output` together with [`self.entry_point.stage`] to determine which
+ /// shader stages are being connected, and choose the `in` or `out` storage
+ /// qualifier.
+ ///
+ /// [`self.entry_point`]: Writer::entry_point
+ /// [`self.entry_point.stage`]: crate::EntryPoint::stage
+ /// [`Location`]: crate::Binding::Location
+ /// [`BuiltIn`]: crate::Binding::BuiltIn
+ fn write_varying(
+ &mut self,
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ output: bool,
+ ) -> Result<(), Error> {
+ // For a struct, emit a separate global for each member with a binding.
+ if let crate::TypeInner::Struct { ref members, .. } = self.module.types[ty].inner {
+ for member in members {
+ self.write_varying(member.binding.as_ref(), member.ty, output)?;
+ }
+ return Ok(());
+ }
+
+ let binding = match binding {
+ None => return Ok(()),
+ Some(binding) => binding,
+ };
+
+ let (location, interpolation, sampling) = match *binding {
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ } => (location, interpolation, sampling),
+ crate::Binding::BuiltIn(built_in) => {
+ if let crate::BuiltIn::Position { invariant: true } = built_in {
+ match (self.options.version, self.entry_point.stage) {
+ (
+ Version::Embedded {
+ version: 300,
+ is_webgl: true,
+ },
+ ShaderStage::Fragment,
+ ) => {
+ // `invariant gl_FragCoord` is not allowed in WebGL2 and possibly
+ // OpenGL ES in general (waiting on confirmation).
+ //
+ // See https://github.com/KhronosGroup/WebGL/issues/3518
+ }
+ _ => {
+ writeln!(
+ self.out,
+ "invariant {};",
+ glsl_built_in(built_in, output, self.options.version.is_webgl())
+ )?;
+ }
+ }
+ }
+ return Ok(());
+ }
+ };
+
+ // Write the interpolation modifier if needed
+ //
+ // We ignore all interpolation and auxiliary modifiers that aren't used in fragment
+ // shaders' input globals or vertex shaders' output globals.
+ let emit_interpolation_and_auxiliary = match self.entry_point.stage {
+ ShaderStage::Vertex => output,
+ ShaderStage::Fragment => !output,
+ ShaderStage::Compute => false,
+ };
+
+ // Write the I/O locations, if allowed
+ if self.options.version.supports_explicit_locations() || !emit_interpolation_and_auxiliary {
+ write!(self.out, "layout(location = {location}) ")?;
+ }
+
+ // Write the interpolation qualifier.
+ if let Some(interp) = interpolation {
+ if emit_interpolation_and_auxiliary {
+ write!(self.out, "{} ", glsl_interpolation(interp))?;
+ }
+ }
+
+ // Write the sampling auxiliary qualifier.
+ //
+ // Before GLSL 4.2, the `centroid` and `sample` qualifiers were required to appear
+ // immediately before the `in` / `out` qualifier, so we'll just follow that rule
+ // here, regardless of the version.
+ if let Some(sampling) = sampling {
+ if emit_interpolation_and_auxiliary {
+ if let Some(qualifier) = glsl_sampling(sampling) {
+ write!(self.out, "{qualifier} ")?;
+ }
+ }
+ }
+
+ // Write the input/output qualifier.
+ write!(self.out, "{} ", if output { "out" } else { "in" })?;
+
+ // Write the type
+ // `write_type` adds no leading or trailing spaces
+ self.write_type(ty)?;
+
+ // Finally write the global name and end the global with a `;` and a newline
+ // Leading space is important
+ let vname = VaryingName {
+ binding: &crate::Binding::Location {
+ location,
+ interpolation: None,
+ sampling: None,
+ },
+ stage: self.entry_point.stage,
+ output,
+ targetting_webgl: self.options.version.is_webgl(),
+ };
+ writeln!(self.out, " {vname};")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write functions (both entry points and regular functions)
+ ///
+ /// # Notes
+ /// Adds a newline
+ fn write_function(
+ &mut self,
+ ty: back::FunctionType,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ ) -> BackendResult {
+ // Create a function context for the function being written
+ let ctx = back::FunctionCtx {
+ ty,
+ info,
+ expressions: &func.expressions,
+ named_expressions: &func.named_expressions,
+ };
+
+ self.named_expressions.clear();
+ self.update_expressions_to_bake(func, info);
+
+ // Write the function header
+ //
+ // glsl headers are the same as in c:
+ // `ret_type name(args)`
+ // `ret_type` is the return type
+ // `name` is the function name
+ // `args` is a comma separated list of `type name`
+ // | - `type` is the argument type
+ // | - `name` is the argument name
+
+ // Start by writing the return type if any otherwise write void
+ // This is the only place where `void` is a valid type
+ // (though it's more a keyword than a type)
+ if let back::FunctionType::EntryPoint(_) = ctx.ty {
+ write!(self.out, "void")?;
+ } else if let Some(ref result) = func.result {
+ self.write_type(result.ty)?;
+ } else {
+ write!(self.out, "void")?;
+ }
+
+ // Write the function name and open parentheses for the argument list
+ let function_name = match ctx.ty {
+ back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
+ back::FunctionType::EntryPoint(_) => "main",
+ };
+ write!(self.out, " {function_name}(")?;
+
+ // Write the comma separated argument list
+ //
+ // We need access to `Self` here so we use the reference passed to the closure as an
+ // argument instead of capturing as that would cause a borrow checker error
+ let arguments = match ctx.ty {
+ back::FunctionType::EntryPoint(_) => &[][..],
+ back::FunctionType::Function(_) => &func.arguments,
+ };
+ let arguments: Vec<_> = arguments
+ .iter()
+ .enumerate()
+ .filter(|&(_, arg)| match self.module.types[arg.ty].inner {
+ TypeInner::Sampler { .. } => false,
+ _ => true,
+ })
+ .collect();
+ self.write_slice(&arguments, |this, _, &(i, arg)| {
+ // Write the argument type
+ match this.module.types[arg.ty].inner {
+ // We treat images separately because they might require
+ // writing the storage format
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ // Write the storage format if needed
+ if let TypeInner::Image {
+ class: crate::ImageClass::Storage { format, .. },
+ ..
+ } = this.module.types[arg.ty].inner
+ {
+ write!(this.out, "layout({}) ", glsl_storage_format(format))?;
+ }
+
+ // write the type
+ //
+ // This is way we need the leading space because `write_image_type` doesn't add
+ // any spaces at the beginning or end
+ this.write_image_type(dim, arrayed, class)?;
+ }
+ TypeInner::Pointer { base, .. } => {
+ // write parameter qualifiers
+ write!(this.out, "inout ")?;
+ this.write_type(base)?;
+ }
+ // All other types are written by `write_type`
+ _ => {
+ this.write_type(arg.ty)?;
+ }
+ }
+
+ // Write the argument name
+ // The leading space is important
+ write!(this.out, " {}", &this.names[&ctx.argument_key(i as u32)])?;
+
+ // Write array size
+ match this.module.types[arg.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ this.write_array_size(base, size)?;
+ }
+ TypeInner::Pointer { base, .. } => {
+ if let TypeInner::Array { base, size, .. } = this.module.types[base].inner {
+ this.write_array_size(base, size)?;
+ }
+ }
+ _ => {}
+ }
+
+ Ok(())
+ })?;
+
+ // Close the parentheses and open braces to start the function body
+ writeln!(self.out, ") {{")?;
+
+ if self.options.zero_initialize_workgroup_memory
+ && ctx.ty.is_compute_entry_point(self.module)
+ {
+ self.write_workgroup_variables_initialization(&ctx)?;
+ }
+
+ // Compose the function arguments from globals, in case of an entry point.
+ if let back::FunctionType::EntryPoint(ep_index) = ctx.ty {
+ let stage = self.module.entry_points[ep_index as usize].stage;
+ for (index, arg) in func.arguments.iter().enumerate() {
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(arg.ty)?;
+ let name = &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
+ write!(self.out, " {name}")?;
+ write!(self.out, " = ")?;
+ match self.module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ self.write_type(arg.ty)?;
+ write!(self.out, "(")?;
+ for (index, member) in members.iter().enumerate() {
+ let varying_name = VaryingName {
+ binding: member.binding.as_ref().unwrap(),
+ stage,
+ output: false,
+ targetting_webgl: self.options.version.is_webgl(),
+ };
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{varying_name}")?;
+ }
+ writeln!(self.out, ");")?;
+ }
+ _ => {
+ let varying_name = VaryingName {
+ binding: arg.binding.as_ref().unwrap(),
+ stage,
+ output: false,
+ targetting_webgl: self.options.version.is_webgl(),
+ };
+ writeln!(self.out, "{varying_name};")?;
+ }
+ }
+ }
+ }
+
+ // Write all function locals
+ // Locals are `type name (= init)?;` where the init part (including the =) are optional
+ //
+ // Always adds a newline
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability) and the type
+ // `write_type` adds no trailing space
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(local.ty)?;
+
+ // Write the local name
+ // The leading space is important
+ write!(self.out, " {}", self.names[&ctx.name_key(handle)])?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = self.module.types[local.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ // Put the equal signal only if there's a initializer
+ // The leading and trailing spaces aren't needed but help with readability
+ write!(self.out, " = ")?;
+
+ // Write the constant
+ // `write_constant` adds no trailing or leading space/newline
+ self.write_constant(init)?;
+ } else if is_value_init_supported(self.module, local.ty) {
+ write!(self.out, " = ")?;
+ self.write_zero_init_value(local.ty)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // Write a statement, the indentation should always be 1 when writing the function body
+ // `write_stmt` adds a newline
+ self.write_stmt(sta, &ctx, back::Level(1))?;
+ }
+
+ // Close braces and add a newline
+ writeln!(self.out, "}}")?;
+
+ Ok(())
+ }
+
+ fn write_workgroup_variables_initialization(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ let mut vars = self
+ .module
+ .global_variables
+ .iter()
+ .filter(|&(handle, var)| {
+ !ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ .peekable();
+
+ if vars.peek().is_some() {
+ let level = back::Level(1);
+
+ writeln!(self.out, "{level}if (gl_LocalInvocationID == uvec3(0u)) {{")?;
+
+ for (handle, var) in vars {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{}{} = ", level.next(), name)?;
+ self.write_zero_init_value(var.ty)?;
+ writeln!(self.out, ";")?;
+ }
+
+ writeln!(self.out, "{level}}}")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method that writes a list of comma separated `T` with a writer function `F`
+ ///
+ /// The writer function `F` receives a mutable reference to `self` that if needed won't cause
+ /// borrow checker issues (using for example a closure with `self` will cause issues), the
+ /// second argument is the 0 based index of the element on the list, and the last element is
+ /// a reference to the element `T` being written
+ ///
+ /// # Notes
+ /// - Adds no newlines or leading/trailing whitespace
+ /// - The last element won't have a trailing `,`
+ fn write_slice<T, F: FnMut(&mut Self, u32, &T) -> BackendResult>(
+ &mut self,
+ data: &[T],
+ mut f: F,
+ ) -> BackendResult {
+ // Loop trough `data` invoking `f` for each element
+ for (i, item) in data.iter().enumerate() {
+ f(self, i as u32, item)?;
+
+ // Only write a comma if isn't the last element
+ if i != data.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write constants
+ ///
+ /// # Notes
+ /// Adds no newlines or leading/trailing whitespace
+ fn write_constant(&mut self, handle: Handle<crate::Constant>) -> BackendResult {
+ use crate::ScalarValue as Sv;
+
+ match self.module.constants[handle].inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } => match *value {
+ // Signed integers don't need anything special
+ Sv::Sint(int) => write!(self.out, "{int}")?,
+ // Unsigned integers need a `u` at the end
+ //
+ // While `core` doesn't necessarily need it, it's allowed and since `es` needs it we
+ // always write it as the extra branch wouldn't have any benefit in readability
+ Sv::Uint(int) => write!(self.out, "{int}u")?,
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero which is needed for a valid glsl float constant
+ Sv::Float(float) => write!(self.out, "{float:?}")?,
+ // Booleans are either `true` or `false` so nothing special needs to be done
+ Sv::Bool(boolean) => write!(self.out, "{boolean}")?,
+ },
+ // Composite constant are created using the same syntax as compose
+ // `type(components)` where `components` is a comma separated list of constants
+ crate::ConstantInner::Composite { ty, ref components } => {
+ self.write_type(ty)?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, "(")?;
+
+ // Write the comma separated constants
+ self.write_slice(components, |this, _, arg| this.write_constant(*arg))?;
+
+ write!(self.out, ")")?
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to output a dot product as an arithmetic expression
+ ///
+ fn write_dot_product(
+ &mut self,
+ arg: Handle<crate::Expression>,
+ arg1: Handle<crate::Expression>,
+ size: usize,
+ ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ // Write parantheses around the dot product expression to prevent operators
+ // with different precedences from applying earlier.
+ write!(self.out, "(")?;
+
+ // Cycle trough all the components of the vector
+ for index in 0..size {
+ let component = back::COMPONENTS[index];
+ // Write the addition to the previous product
+ // This will print an extra '+' at the beginning but that is fine in glsl
+ write!(self.out, " + ")?;
+ // Write the first vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.write_expr(arg, ctx)?;
+ // Access the current component on the first vector
+ write!(self.out, ".{component} * ")?;
+ // Write the second vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.write_expr(arg1, ctx)?;
+ // Access the current component on the second vector
+ write!(self.out, ".{component}")?;
+ }
+
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_struct_body(
+ &mut self,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ ) -> BackendResult {
+ // glsl structs are written as in C
+ // `struct name() { members };`
+ // | `struct` is a keyword
+ // | `name` is the struct name
+ // | `members` is a semicolon separated list of `type name`
+ // | `type` is the member type
+ // | `name` is the member name
+ writeln!(self.out, "{{")?;
+
+ for (idx, member) in members.iter().enumerate() {
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+
+ match self.module.types[member.ty].inner {
+ TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ self.write_type(base)?;
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, idx as u32)]
+ )?;
+ // Write [size]
+ self.write_array_size(base, size)?;
+ // Newline is important
+ writeln!(self.out, ";")?;
+ }
+ _ => {
+ // Write the member type
+ // Adds no trailing space
+ self.write_type(member.ty)?;
+
+ // Write the member name and put a semicolon
+ // The leading space is important
+ // All members must have a semicolon even the last one
+ writeln!(
+ self.out,
+ " {};",
+ &self.names[&NameKey::StructMember(handle, idx as u32)]
+ )?;
+ }
+ }
+ }
+
+ write!(self.out, "}}")?;
+ Ok(())
+ }
+
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ sta: &crate::Statement,
+ ctx: &back::FunctionCtx,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::Statement;
+
+ match *sta {
+ // This is where we can generate intermediate constants for some expression types.
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let info = &ctx.info[handle];
+ let ptr_class = info.ty.inner_with(&self.module.types).pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ // GLSL can't save a pointer-valued expression in a variable,
+ // but we shouldn't ever need to: they should never be named expressions,
+ // and none of the expression types flagged by bake_ref_count can be pointer-valued.
+ None
+ } else if let Some(name) = ctx.named_expressions.get(&handle) {
+ // Front end provides names for all variables at the start of writing.
+ // But we write them to step by step. We need to recache them
+ // Otherwise, we could accidentally write variable name instead of full expression.
+ // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
+ Some(self.namer.call(name))
+ } else if self.need_bake_expressions.contains(&handle) {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else if info.ref_count == 0 {
+ Some(self.namer.call(""))
+ } else {
+ None
+ };
+
+ // If we are going to write an `ImageLoad` next and the target image
+ // is sampled and we are using the `Restrict` policy for bounds
+ // checking images we need to write a local holding the clamped lod.
+ if let crate::Expression::ImageLoad {
+ image,
+ level: Some(level_expr),
+ ..
+ } = ctx.expressions[handle]
+ {
+ if let TypeInner::Image {
+ class: crate::ImageClass::Sampled { .. },
+ ..
+ } = *ctx.info[image].ty.inner_with(&self.module.types)
+ {
+ if let proc::BoundsCheckPolicy::Restrict = self.policies.image {
+ write!(self.out, "{level}")?;
+ self.write_clamped_lod(ctx, handle, image, level_expr)?
+ }
+ }
+ }
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.write_named_expr(handle, name, ctx)?;
+ }
+ }
+ }
+ // Blocks are simple we just need to write the block statements between braces
+ // We could also just print the statements but this is more readable and maps more
+ // closely to the IR
+ Statement::Block(ref block) => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ // Ifs are written as in C:
+ // ```
+ // if(condition) {
+ // accept
+ // } else {
+ // reject
+ // }
+ // ```
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "if (")?;
+ self.write_expr(condition, ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ // Switch are written as in C:
+ // ```
+ // switch (selector) {
+ // // Fallthrough
+ // case label:
+ // block
+ // // Non fallthrough
+ // case label:
+ // block
+ // break;
+ // default:
+ // block
+ // }
+ // ```
+ // Where the `default` case happens isn't important but we put it last
+ // so that we don't need to print a `break` for it
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{level}")?;
+ write!(self.out, "switch(")?;
+ self.write_expr(selector, ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ // Write all cases
+ let l2 = level.next();
+ for case in cases {
+ match case.value {
+ crate::SwitchValue::I32(value) => write!(self.out, "{l2}case {value}:")?,
+ crate::SwitchValue::U32(value) => write!(self.out, "{l2}case {value}u:")?,
+ crate::SwitchValue::Default => write!(self.out, "{l2}default:")?,
+ }
+
+ let write_block_braces = !(case.fall_through && case.body.is_empty());
+ if write_block_braces {
+ writeln!(self.out, " {{")?;
+ } else {
+ writeln!(self.out)?;
+ }
+
+ for sta in case.body.iter() {
+ self.write_stmt(sta, ctx, l2.next())?;
+ }
+
+ if !case.fall_through && case.body.last().map_or(true, |s| !s.is_terminator()) {
+ writeln!(self.out, "{}break;", l2.next())?;
+ }
+
+ if write_block_braces {
+ writeln!(self.out, "{l2}}}")?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ // Loops in naga IR are based on wgsl loops, glsl can emulate the behaviour by using a
+ // while true loop and appending the continuing block to the body resulting on:
+ // ```
+ // bool loop_init = true;
+ // while(true) {
+ // if (!loop_init) { <continuing> }
+ // loop_init = false;
+ // <body>
+ // }
+ // ```
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{level}bool {gate_name} = true;")?;
+ writeln!(self.out, "{level}while(true) {{")?;
+ let l2 = level.next();
+ let l3 = l2.next();
+ writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
+ for sta in continuing {
+ self.write_stmt(sta, ctx, l3)?;
+ }
+ if let Some(condition) = break_if {
+ write!(self.out, "{l3}if (")?;
+ self.write_expr(condition, ctx)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", l3.next())?;
+ writeln!(self.out, "{l3}}}")?;
+ }
+ writeln!(self.out, "{l2}}}")?;
+ writeln!(self.out, "{}{} = false;", level.next(), gate_name)?;
+ } else {
+ writeln!(self.out, "{level}while(true) {{")?;
+ }
+ for sta in body {
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ // Break, continue and return as written as in C
+ // `break;`
+ Statement::Break => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "break;")?
+ }
+ // `continue;`
+ Statement::Continue => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "continue;")?
+ }
+ // `return expr;`, `expr` is optional
+ Statement::Return { value } => {
+ write!(self.out, "{level}")?;
+ match ctx.ty {
+ back::FunctionType::Function(_) => {
+ write!(self.out, "return")?;
+ // Write the expression to be returned if needed
+ if let Some(expr) = value {
+ write!(self.out, " ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ back::FunctionType::EntryPoint(ep_index) => {
+ let mut has_point_size = false;
+ let ep = &self.module.entry_points[ep_index as usize];
+ if let Some(ref result) = ep.function.result {
+ let value = value.unwrap();
+ match self.module.types[result.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let temp_struct_name = match ctx.expressions[value] {
+ crate::Expression::Compose { .. } => {
+ let return_struct = "_tmp_return";
+ write!(
+ self.out,
+ "{} {} = ",
+ &self.names[&NameKey::Type(result.ty)],
+ return_struct
+ )?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{level}")?;
+ Some(return_struct)
+ }
+ _ => None,
+ };
+
+ for (index, member) in members.iter().enumerate() {
+ if let Some(crate::Binding::BuiltIn(
+ crate::BuiltIn::PointSize,
+ )) = member.binding
+ {
+ has_point_size = true;
+ }
+
+ let varying_name = VaryingName {
+ binding: member.binding.as_ref().unwrap(),
+ stage: ep.stage,
+ output: true,
+ targetting_webgl: self.options.version.is_webgl(),
+ };
+ write!(self.out, "{varying_name} = ")?;
+
+ if let Some(struct_name) = temp_struct_name {
+ write!(self.out, "{struct_name}")?;
+ } else {
+ self.write_expr(value, ctx)?;
+ }
+
+ // Write field name
+ writeln!(
+ self.out,
+ ".{};",
+ &self.names
+ [&NameKey::StructMember(result.ty, index as u32)]
+ )?;
+ write!(self.out, "{level}")?;
+ }
+ }
+ _ => {
+ let name = VaryingName {
+ binding: result.binding.as_ref().unwrap(),
+ stage: ep.stage,
+ output: true,
+ targetting_webgl: self.options.version.is_webgl(),
+ };
+ write!(self.out, "{name} = ")?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{level}")?;
+ }
+ }
+ }
+
+ let is_vertex_stage = self.module.entry_points[ep_index as usize].stage
+ == ShaderStage::Vertex;
+ if is_vertex_stage
+ && self
+ .options
+ .writer_flags
+ .contains(WriterFlags::ADJUST_COORDINATE_SPACE)
+ {
+ writeln!(
+ self.out,
+ "gl_Position.yz = vec2(-gl_Position.y, gl_Position.z * 2.0 - gl_Position.w);",
+ )?;
+ write!(self.out, "{level}")?;
+ }
+
+ if is_vertex_stage
+ && self
+ .options
+ .writer_flags
+ .contains(WriterFlags::FORCE_POINT_SIZE)
+ && !has_point_size
+ {
+ writeln!(self.out, "gl_PointSize = 1.0;")?;
+ write!(self.out, "{level}")?;
+ }
+ writeln!(self.out, "return;")?;
+ }
+ }
+ }
+ // This is one of the places were glsl adds to the syntax of C in this case the discard
+ // keyword which ceases all further processing in a fragment shader, it's called OpKill
+ // in spir-v that's why it's called `Statement::Kill`
+ Statement::Kill => writeln!(self.out, "{level}discard;")?,
+ Statement::Barrier(flags) => {
+ self.write_barrier(flags, level)?;
+ }
+ // Stores in glsl are just variable assignments written as `pointer = value;`
+ Statement::Store { pointer, value } => {
+ write!(self.out, "{level}")?;
+ self.write_expr(pointer, ctx)?;
+ write!(self.out, " = ")?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?
+ }
+ // Stores a value into an image.
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{level}")?;
+ self.write_image_store(ctx, image, coordinate, array_index, value)?
+ }
+ // A `Call` is written `name(arguments)` where `arguments` is a comma separated expressions list
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ let result = self.module.functions[function].result.as_ref().unwrap();
+ self.write_type(result.ty)?;
+ write!(self.out, " {name} = ")?;
+ self.named_expressions.insert(expr, name);
+ }
+ write!(self.out, "{}(", &self.names[&NameKey::Function(function)])?;
+ let arguments: Vec<_> = arguments
+ .iter()
+ .enumerate()
+ .filter_map(|(i, arg)| {
+ let arg_ty = self.module.functions[function].arguments[i].ty;
+ match self.module.types[arg_ty].inner {
+ TypeInner::Sampler { .. } => None,
+ _ => Some(*arg),
+ }
+ })
+ .collect();
+ self.write_slice(&arguments, |this, _, arg| this.write_expr(*arg, ctx))?;
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
+ self.write_value_type(res_ty)?;
+ write!(self.out, " {res_name} = ")?;
+ self.named_expressions.insert(result, res_name);
+
+ let fun_str = fun.to_glsl();
+ write!(self.out, "atomic{fun_str}(")?;
+ self.write_expr(pointer, ctx)?;
+ write!(self.out, ", ")?;
+ // handle the special cases
+ match *fun {
+ crate::AtomicFunction::Subtract => {
+ // we just wrote `InterlockedAdd`, so negate the argument
+ write!(self.out, "-")?;
+ }
+ crate::AtomicFunction::Exchange { compare: Some(_) } => {
+ return Err(Error::Custom(
+ "atomic CompareExchange is not implemented".to_string(),
+ ));
+ }
+ _ => {}
+ }
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ Statement::RayQuery { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write expressions
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ fn write_expr(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+
+ match ctx.expressions[expr] {
+ // `Access` is applied to arrays, vectors and matrices and is written as indexing
+ Expression::Access { base, index } => {
+ self.write_expr(base, ctx)?;
+ write!(self.out, "[")?;
+ self.write_expr(index, ctx)?;
+ write!(self.out, "]")?
+ }
+ // `AccessIndex` is the same as `Access` except that the index is a constant and it can
+ // be applied to structs, in this case we need to find the name of the field at that
+ // index and write `base.field_name`
+ Expression::AccessIndex { base, index } => {
+ self.write_expr(base, ctx)?;
+
+ let base_ty_res = &ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&self.module.types);
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, space: _ } => {
+ resolved = &self.module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ match *resolved {
+ TypeInner::Vector { .. } => {
+ // Write vector access as a swizzle
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
+ TypeInner::Struct { .. } => {
+ // This will never panic in case the type is a `Struct`, this is not true
+ // for other types so we can only check while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ self.out,
+ ".{}",
+ &self.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
+ }
+ }
+ // Constants are delegated to `write_constant`
+ Expression::Constant(constant) => self.write_constant(constant)?,
+ // `Splat` needs to actually write down a vector, it's not always inferred in GLSL.
+ Expression::Splat { size: _, value } => {
+ let resolved = ctx.info[expr].ty.inner_with(&self.module.types);
+ self.write_value_type(resolved)?;
+ write!(self.out, "(")?;
+ self.write_expr(value, ctx)?;
+ write!(self.out, ")")?
+ }
+ // `Swizzle` adds a few letters behind the dot.
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(vector, ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ // `Compose` is pretty simple we just write `type(components)` where `components` is a
+ // comma separated list of expressions
+ Expression::Compose { ty, ref components } => {
+ self.write_type(ty)?;
+
+ let resolved = ctx.info[expr].ty.inner_with(&self.module.types);
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(base, size)?;
+ }
+
+ write!(self.out, "(")?;
+ self.write_slice(components, |this, _, arg| this.write_expr(*arg, ctx))?;
+ write!(self.out, ")")?
+ }
+ // Function arguments are written as the argument name
+ Expression::FunctionArgument(pos) => {
+ write!(self.out, "{}", &self.names[&ctx.argument_key(pos)])?
+ }
+ // Global variables need some special work for their name but
+ // `get_global_name` does the work for us
+ Expression::GlobalVariable(handle) => {
+ let global = &self.module.global_variables[handle];
+ self.write_global_name(handle, global)?
+ }
+ // A local is written as it's name
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&ctx.name_key(handle)])?
+ }
+ // glsl has no pointers so there's no load operation, just write the pointer expression
+ Expression::Load { pointer } => self.write_expr(pointer, ctx)?,
+ // `ImageSample` is a bit complicated compared to the rest of the IR.
+ //
+ // First there are three variations depending whether the sample level is explicitly set,
+ // if it's automatic or it it's bias:
+ // `texture(image, coordinate)` - Automatic sample level
+ // `texture(image, coordinate, bias)` - Bias sample level
+ // `textureLod(image, coordinate, level)` - Zero or Exact sample level
+ //
+ // Furthermore if `depth_ref` is some we need to append it to the coordinate vector
+ Expression::ImageSample {
+ image,
+ sampler: _, //TODO?
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ let dim = match *ctx.info[image].ty.inner_with(&self.module.types) {
+ TypeInner::Image { dim, .. } => dim,
+ _ => unreachable!(),
+ };
+
+ if dim == crate::ImageDimension::Cube
+ && array_index.is_some()
+ && depth_ref.is_some()
+ {
+ match level {
+ crate::SampleLevel::Zero
+ | crate::SampleLevel::Exact(_)
+ | crate::SampleLevel::Gradient { .. }
+ | crate::SampleLevel::Bias(_) => {
+ return Err(Error::Custom(String::from(
+ "gsamplerCubeArrayShadow isn't supported in textureGrad, \
+ textureLod or texture with bias",
+ )))
+ }
+ crate::SampleLevel::Auto => {}
+ }
+ }
+
+ // textureLod on sampler2DArrayShadow and samplerCubeShadow does not exist in GLSL.
+ // To emulate this, we will have to use textureGrad with a constant gradient of 0.
+ let workaround_lod_array_shadow_as_grad = (array_index.is_some()
+ || dim == crate::ImageDimension::Cube)
+ && depth_ref.is_some()
+ && gather.is_none()
+ && !self
+ .options
+ .writer_flags
+ .contains(WriterFlags::TEXTURE_SHADOW_LOD);
+
+ //Write the function to be used depending on the sample level
+ let fun_name = match level {
+ crate::SampleLevel::Zero if gather.is_some() => "textureGather",
+ crate::SampleLevel::Auto | crate::SampleLevel::Bias(_) => "texture",
+ crate::SampleLevel::Zero | crate::SampleLevel::Exact(_) => {
+ if workaround_lod_array_shadow_as_grad {
+ "textureGrad"
+ } else {
+ "textureLod"
+ }
+ }
+ crate::SampleLevel::Gradient { .. } => "textureGrad",
+ };
+ let offset_name = match offset {
+ Some(_) => "Offset",
+ None => "",
+ };
+
+ write!(self.out, "{fun_name}{offset_name}(")?;
+
+ // Write the image that will be used
+ self.write_expr(image, ctx)?;
+ // The space here isn't required but it helps with readability
+ write!(self.out, ", ")?;
+
+ // We need to get the coordinates vector size to later build a vector that's `size + 1`
+ // if `depth_ref` is some, if it isn't a vector we panic as that's not a valid expression
+ let mut coord_dim = match *ctx.info[coordinate].ty.inner_with(&self.module.types) {
+ TypeInner::Vector { size, .. } => size as u8,
+ TypeInner::Scalar { .. } => 1,
+ _ => unreachable!(),
+ };
+
+ if array_index.is_some() {
+ coord_dim += 1;
+ }
+ let merge_depth_ref = depth_ref.is_some() && gather.is_none() && coord_dim < 4;
+ if merge_depth_ref {
+ coord_dim += 1;
+ }
+
+ let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es();
+ let is_vec = tex_1d_hack || coord_dim != 1;
+ // Compose a new texture coordinates vector
+ if is_vec {
+ write!(self.out, "vec{}(", coord_dim + tex_1d_hack as u8)?;
+ }
+ self.write_expr(coordinate, ctx)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0.0")?;
+ }
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ if merge_depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(depth_ref.unwrap(), ctx)?;
+ }
+ if is_vec {
+ write!(self.out, ")")?;
+ }
+
+ if let (Some(expr), false) = (depth_ref, merge_depth_ref) {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+
+ match level {
+ // Auto needs no more arguments
+ crate::SampleLevel::Auto => (),
+ // Zero needs level set to 0
+ crate::SampleLevel::Zero => {
+ if workaround_lod_array_shadow_as_grad {
+ let vec_dim = match dim {
+ crate::ImageDimension::Cube => 3,
+ _ => 2,
+ };
+ write!(self.out, ", vec{vec_dim}(0.0), vec{vec_dim}(0.0)")?;
+ } else if gather.is_none() {
+ write!(self.out, ", 0.0")?;
+ }
+ }
+ // Exact and bias require another argument
+ crate::SampleLevel::Exact(expr) => {
+ if workaround_lod_array_shadow_as_grad {
+ log::warn!("Unable to `textureLod` a shadow array, ignoring the LOD");
+ write!(self.out, ", vec2(0,0), vec2(0,0)")?;
+ } else {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ }
+ crate::SampleLevel::Bias(_) => {
+ // This needs to be done after the offset writing
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ // If we are using sampler2D to replace sampler1D, we also
+ // need to make sure to use vec2 gradients
+ if tex_1d_hack {
+ write!(self.out, ", vec2(")?;
+ self.write_expr(x, ctx)?;
+ write!(self.out, ", 0.0)")?;
+ write!(self.out, ", vec2(")?;
+ self.write_expr(y, ctx)?;
+ write!(self.out, ", 0.0)")?;
+ } else {
+ write!(self.out, ", ")?;
+ self.write_expr(x, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(y, ctx)?;
+ }
+ }
+ }
+
+ if let Some(constant) = offset {
+ write!(self.out, ", ")?;
+ if tex_1d_hack {
+ write!(self.out, "ivec2(")?;
+ }
+ self.write_constant(constant)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0)")?;
+ }
+ }
+
+ // Bias is always the last argument
+ if let crate::SampleLevel::Bias(expr) = level {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+
+ if let (Some(component), None) = (gather, depth_ref) {
+ write!(self.out, ", {}", component as usize)?;
+ }
+
+ // End the function
+ write!(self.out, ")")?
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => self.write_image_load(expr, ctx, image, coordinate, array_index, sample, level)?,
+ // Query translates into one of the:
+ // - textureSize/imageSize
+ // - textureQueryLevels
+ // - textureSamples/imageSamples
+ Expression::ImageQuery { image, query } => {
+ use crate::ImageClass;
+
+ // This will only panic if the module is invalid
+ let (dim, class) = match *ctx.info[image].ty.inner_with(&self.module.types) {
+ TypeInner::Image {
+ dim,
+ arrayed: _,
+ class,
+ } => (dim, class),
+ _ => unreachable!(),
+ };
+ let components = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 => 3,
+ crate::ImageDimension::Cube => 2,
+ };
+
+ if let crate::ImageQuery::Size { .. } = query {
+ match components {
+ 1 => write!(self.out, "uint(")?,
+ _ => write!(self.out, "uvec{components}(")?,
+ }
+ } else {
+ write!(self.out, "uint(")?;
+ }
+
+ match query {
+ crate::ImageQuery::Size { level } => {
+ match class {
+ ImageClass::Sampled { multi, .. } | ImageClass::Depth { multi } => {
+ write!(self.out, "textureSize(")?;
+ self.write_expr(image, ctx)?;
+ if let Some(expr) = level {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ } else if !multi {
+ // All textureSize calls requires an lod argument
+ // except for multisampled samplers
+ write!(self.out, ", 0")?;
+ }
+ }
+ ImageClass::Storage { .. } => {
+ write!(self.out, "imageSize(")?;
+ self.write_expr(image, ctx)?;
+ }
+ }
+ write!(self.out, ")")?;
+ if components != 1 || self.options.version.is_es() {
+ write!(self.out, ".{}", &"xyz"[..components])?;
+ }
+ }
+ crate::ImageQuery::NumLevels => {
+ write!(self.out, "textureQueryLevels(",)?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ")",)?;
+ }
+ crate::ImageQuery::NumLayers => {
+ let fun_name = match class {
+ ImageClass::Sampled { .. } | ImageClass::Depth { .. } => "textureSize",
+ ImageClass::Storage { .. } => "imageSize",
+ };
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(image, ctx)?;
+ // All textureSize calls requires an lod argument
+ // except for multisampled samplers
+ if class.is_multisampled() {
+ write!(self.out, ", 0")?;
+ }
+ write!(self.out, ")")?;
+ if components != 1 || self.options.version.is_es() {
+ write!(self.out, ".{}", back::COMPONENTS[components])?;
+ }
+ }
+ crate::ImageQuery::NumSamples => {
+ let fun_name = match class {
+ ImageClass::Sampled { .. } | ImageClass::Depth { .. } => {
+ "textureSamples"
+ }
+ ImageClass::Storage { .. } => "imageSamples",
+ };
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ")",)?;
+ }
+ }
+
+ write!(self.out, ")")?;
+ }
+ // `Unary` is pretty straightforward
+ // "-" - for `Negate`
+ // "~" - for `Not` if it's an integer
+ // "!" - for `Not` if it's a boolean
+ //
+ // We also wrap the everything in parentheses to avoid precedence issues
+ Expression::Unary { op, expr } => {
+ use crate::{ScalarKind as Sk, UnaryOperator as Uo};
+
+ let ty = ctx.info[expr].ty.inner_with(&self.module.types);
+
+ match *ty {
+ TypeInner::Vector { kind: Sk::Bool, .. } => {
+ write!(self.out, "not(")?;
+ }
+ _ => {
+ let operator = match op {
+ Uo::Negate => "-",
+ Uo::Not => match ty.scalar_kind() {
+ Some(Sk::Sint) | Some(Sk::Uint) => "~",
+ Some(Sk::Bool) => "!",
+ ref other => {
+ return Err(Error::Custom(format!(
+ "Cannot apply not to type {other:?}"
+ )))
+ }
+ },
+ };
+
+ write!(self.out, "{operator}(")?;
+ }
+ }
+
+ self.write_expr(expr, ctx)?;
+
+ write!(self.out, ")")?
+ }
+ // `Binary` we just write `left op right`, except when dealing with
+ // comparison operations on vectors as they are implemented with
+ // builtin functions.
+ // Once again we wrap everything in parentheses to avoid precedence issues
+ Expression::Binary {
+ mut op,
+ left,
+ right,
+ } => {
+ // Holds `Some(function_name)` if the binary operation is
+ // implemented as a function call
+ use crate::{BinaryOperator as Bo, ScalarKind as Sk, TypeInner as Ti};
+
+ let left_inner = ctx.info[left].ty.inner_with(&self.module.types);
+ let right_inner = ctx.info[right].ty.inner_with(&self.module.types);
+
+ let function = match (left_inner, right_inner) {
+ (&Ti::Vector { kind, .. }, &Ti::Vector { .. }) => match op {
+ Bo::Less
+ | Bo::LessEqual
+ | Bo::Greater
+ | Bo::GreaterEqual
+ | Bo::Equal
+ | Bo::NotEqual => BinaryOperation::VectorCompare,
+ Bo::Modulo if kind == Sk::Float => BinaryOperation::Modulo,
+ Bo::And if kind == Sk::Bool => {
+ op = crate::BinaryOperator::LogicalAnd;
+ BinaryOperation::VectorComponentWise
+ }
+ Bo::InclusiveOr if kind == Sk::Bool => {
+ op = crate::BinaryOperator::LogicalOr;
+ BinaryOperation::VectorComponentWise
+ }
+ _ => BinaryOperation::Other,
+ },
+ _ => match (left_inner.scalar_kind(), right_inner.scalar_kind()) {
+ (Some(Sk::Float), _) | (_, Some(Sk::Float)) => match op {
+ Bo::Modulo => BinaryOperation::Modulo,
+ _ => BinaryOperation::Other,
+ },
+ (Some(Sk::Bool), Some(Sk::Bool)) => match op {
+ Bo::InclusiveOr => {
+ op = crate::BinaryOperator::LogicalOr;
+ BinaryOperation::Other
+ }
+ Bo::And => {
+ op = crate::BinaryOperator::LogicalAnd;
+ BinaryOperation::Other
+ }
+ _ => BinaryOperation::Other,
+ },
+ _ => BinaryOperation::Other,
+ },
+ };
+
+ match function {
+ BinaryOperation::VectorCompare => {
+ let op_str = match op {
+ Bo::Less => "lessThan(",
+ Bo::LessEqual => "lessThanEqual(",
+ Bo::Greater => "greaterThan(",
+ Bo::GreaterEqual => "greaterThanEqual(",
+ Bo::Equal => "equal(",
+ Bo::NotEqual => "notEqual(",
+ _ => unreachable!(),
+ };
+ write!(self.out, "{op_str}")?;
+ self.write_expr(left, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, ")")?;
+ }
+ BinaryOperation::VectorComponentWise => {
+ self.write_value_type(left_inner)?;
+ write!(self.out, "(")?;
+
+ let size = match *left_inner {
+ Ti::Vector { size, .. } => size,
+ _ => unreachable!(),
+ };
+
+ for i in 0..size as usize {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+
+ self.write_expr(left, ctx)?;
+ write!(self.out, ".{}", back::COMPONENTS[i])?;
+
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+
+ self.write_expr(right, ctx)?;
+ write!(self.out, ".{}", back::COMPONENTS[i])?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+ BinaryOperation::Modulo => {
+ write!(self.out, "(")?;
+
+ // write `e1 - e2 * trunc(e1 / e2)`
+ self.write_expr(left, ctx)?;
+ write!(self.out, " - ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, " * ")?;
+ write!(self.out, "trunc(")?;
+ self.write_expr(left, ctx)?;
+ write!(self.out, " / ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, ")")?;
+
+ write!(self.out, ")")?;
+ }
+ BinaryOperation::Other => {
+ write!(self.out, "(")?;
+
+ self.write_expr(left, ctx)?;
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+ self.write_expr(right, ctx)?;
+
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ // `Select` is written as `condition ? accept : reject`
+ // We wrap everything in parentheses to avoid precedence issues
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ let cond_ty = ctx.info[condition].ty.inner_with(&self.module.types);
+ let vec_select = if let TypeInner::Vector { .. } = *cond_ty {
+ true
+ } else {
+ false
+ };
+
+ // TODO: Boolean mix on desktop required GL_EXT_shader_integer_mix
+ if vec_select {
+ // Glsl defines that for mix when the condition is a boolean the first element
+ // is picked if condition is false and the second if condition is true
+ write!(self.out, "mix(")?;
+ self.write_expr(reject, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(accept, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(condition, ctx)?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(condition, ctx)?;
+ write!(self.out, " ? ")?;
+ self.write_expr(accept, ctx)?;
+ write!(self.out, " : ")?;
+ self.write_expr(reject, ctx)?;
+ }
+
+ write!(self.out, ")")?
+ }
+ // `Derivative` is a function call to a glsl provided function
+ Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ let fun_name = if self.options.version.supports_derivative_control() {
+ match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => "dFdxCoarse",
+ (Axis::X, Ctrl::Fine) => "dFdxFine",
+ (Axis::X, Ctrl::None) => "dFdx",
+ (Axis::Y, Ctrl::Coarse) => "dFdyCoarse",
+ (Axis::Y, Ctrl::Fine) => "dFdyFine",
+ (Axis::Y, Ctrl::None) => "dFdy",
+ (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
+ (Axis::Width, Ctrl::Fine) => "fwidthFine",
+ (Axis::Width, Ctrl::None) => "fwidth",
+ }
+ } else {
+ match axis {
+ Axis::X => "dFdx",
+ Axis::Y => "dFdy",
+ Axis::Width => "fwidth",
+ }
+ };
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?
+ }
+ // `Relational` is a normal function call to some glsl provided functions
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_name = match fun {
+ // There's no specific function for this but we can invert the result of `isinf`
+ Rf::IsFinite => "!isinf",
+ Rf::IsInf => "isinf",
+ Rf::IsNan => "isnan",
+ // There's also no function for this but we can invert `isnan`
+ Rf::IsNormal => "!isnan",
+ Rf::All => "all",
+ Rf::Any => "any",
+ };
+ write!(self.out, "{fun_name}(")?;
+
+ self.write_expr(argument, ctx)?;
+
+ write!(self.out, ")")?
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let fun_name = match fun {
+ // comparison
+ Mf::Abs => "abs",
+ Mf::Min => "min",
+ Mf::Max => "max",
+ Mf::Clamp => "clamp",
+ Mf::Saturate => {
+ write!(self.out, "clamp(")?;
+
+ self.write_expr(arg, ctx)?;
+
+ match *ctx.info[arg].ty.inner_with(&self.module.types) {
+ crate::TypeInner::Vector { size, .. } => write!(
+ self.out,
+ ", vec{}(0.0), vec{0}(1.0)",
+ back::vector_size_str(size)
+ )?,
+ _ => write!(self.out, ", 0.0, 1.0")?,
+ }
+
+ write!(self.out, ")")?;
+
+ return Ok(());
+ }
+ // trigonometry
+ Mf::Cos => "cos",
+ Mf::Cosh => "cosh",
+ Mf::Sin => "sin",
+ Mf::Sinh => "sinh",
+ Mf::Tan => "tan",
+ Mf::Tanh => "tanh",
+ Mf::Acos => "acos",
+ Mf::Asin => "asin",
+ Mf::Atan => "atan",
+ Mf::Asinh => "asinh",
+ Mf::Acosh => "acosh",
+ Mf::Atanh => "atanh",
+ Mf::Radians => "radians",
+ Mf::Degrees => "degrees",
+ // glsl doesn't have atan2 function
+ // use two-argument variation of the atan function
+ Mf::Atan2 => "atan",
+ // decomposition
+ Mf::Ceil => "ceil",
+ Mf::Floor => "floor",
+ Mf::Round => "roundEven",
+ Mf::Fract => "fract",
+ Mf::Trunc => "trunc",
+ Mf::Modf => "modf",
+ Mf::Frexp => "frexp",
+ Mf::Ldexp => "ldexp",
+ // exponent
+ Mf::Exp => "exp",
+ Mf::Exp2 => "exp2",
+ Mf::Log => "log",
+ Mf::Log2 => "log2",
+ Mf::Pow => "pow",
+ // geometry
+ Mf::Dot => match *ctx.info[arg].ty.inner_with(&self.module.types) {
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Float,
+ ..
+ } => "dot",
+ crate::TypeInner::Vector { size, .. } => {
+ return self.write_dot_product(arg, arg1.unwrap(), size as usize, ctx)
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => "outerProduct",
+ Mf::Cross => "cross",
+ Mf::Distance => "distance",
+ Mf::Length => "length",
+ Mf::Normalize => "normalize",
+ Mf::FaceForward => "faceforward",
+ Mf::Reflect => "reflect",
+ Mf::Refract => "refract",
+ // computational
+ Mf::Sign => "sign",
+ Mf::Fma => {
+ if self.options.version.supports_fma_function() {
+ // Use the fma function when available
+ "fma"
+ } else {
+ // No fma support. Transform the function call into an arithmetic expression
+ write!(self.out, "(")?;
+
+ self.write_expr(arg, ctx)?;
+ write!(self.out, " * ")?;
+
+ let arg1 =
+ arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?;
+ self.write_expr(arg1, ctx)?;
+ write!(self.out, " + ")?;
+
+ let arg2 =
+ arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?;
+ self.write_expr(arg2, ctx)?;
+ write!(self.out, ")")?;
+
+ return Ok(());
+ }
+ }
+ Mf::Mix => "mix",
+ Mf::Step => "step",
+ Mf::SmoothStep => "smoothstep",
+ Mf::Sqrt => "sqrt",
+ Mf::InverseSqrt => "inversesqrt",
+ Mf::Inverse => "inverse",
+ Mf::Transpose => "transpose",
+ Mf::Determinant => "determinant",
+ // bits
+ Mf::CountTrailingZeros => {
+ match *ctx.info[arg].ty.inner_with(&self.module.types) {
+ crate::TypeInner::Vector { size, kind, .. } => {
+ let s = back::vector_size_str(size);
+ if let crate::ScalarKind::Uint = kind {
+ write!(self.out, "min(uvec{s}(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), uvec{s}(32u))")?;
+ } else {
+ write!(self.out, "ivec{s}(min(uvec{s}(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), uvec{s}(32u)))")?;
+ }
+ }
+ crate::TypeInner::Scalar { kind, .. } => {
+ if let crate::ScalarKind::Uint = kind {
+ write!(self.out, "min(uint(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), 32u)")?;
+ } else {
+ write!(self.out, "int(min(uint(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), 32u))")?;
+ }
+ }
+ _ => unreachable!(),
+ };
+ return Ok(());
+ }
+ Mf::CountLeadingZeros => {
+ if self.options.version.supports_integer_functions() {
+ match *ctx.info[arg].ty.inner_with(&self.module.types) {
+ crate::TypeInner::Vector { size, kind, .. } => {
+ let s = back::vector_size_str(size);
+
+ if let crate::ScalarKind::Uint = kind {
+ write!(self.out, "uvec{s}(ivec{s}(31) - findMSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "mix(ivec{s}(31) - findMSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, "), ivec{s}(0), lessThan(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ", ivec{s}(0)))")?;
+ }
+ }
+ crate::TypeInner::Scalar { kind, .. } => {
+ if let crate::ScalarKind::Uint = kind {
+ write!(self.out, "uint(31 - findMSB(")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, " < 0 ? 0 : 31 - findMSB(")?;
+ }
+
+ self.write_expr(arg, ctx)?;
+ write!(self.out, "))")?;
+ }
+ _ => unreachable!(),
+ };
+ } else {
+ match *ctx.info[arg].ty.inner_with(&self.module.types) {
+ crate::TypeInner::Vector { size, kind, .. } => {
+ let s = back::vector_size_str(size);
+
+ if let crate::ScalarKind::Uint = kind {
+ write!(self.out, "uvec{s}(")?;
+ write!(self.out, "vec{s}(31.0) - floor(log2(vec{s}(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5)))")?;
+ } else {
+ write!(self.out, "ivec{s}(")?;
+ write!(self.out, "mix(vec{s}(31.0) - floor(log2(vec{s}(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5)), ")?;
+ write!(self.out, "vec{s}(0.0), lessThan(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ", ivec{s}(0u))))")?;
+ }
+ }
+ crate::TypeInner::Scalar { kind, .. } => {
+ if let crate::ScalarKind::Uint = kind {
+ write!(self.out, "uint(31.0 - floor(log2(float(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5)))")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, " < 0 ? 0 : int(")?;
+ write!(self.out, "31.0 - floor(log2(float(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5))))")?;
+ }
+ }
+ _ => unreachable!(),
+ };
+ }
+
+ return Ok(());
+ }
+ Mf::CountOneBits => "bitCount",
+ Mf::ReverseBits => "bitfieldReverse",
+ Mf::ExtractBits => "bitfieldExtract",
+ Mf::InsertBits => "bitfieldInsert",
+ Mf::FindLsb => "findLSB",
+ Mf::FindMsb => "findMSB",
+ // data packing
+ Mf::Pack4x8snorm => "packSnorm4x8",
+ Mf::Pack4x8unorm => "packUnorm4x8",
+ Mf::Pack2x16snorm => "packSnorm2x16",
+ Mf::Pack2x16unorm => "packUnorm2x16",
+ Mf::Pack2x16float => "packHalf2x16",
+ // data unpacking
+ Mf::Unpack4x8snorm => "unpackSnorm4x8",
+ Mf::Unpack4x8unorm => "unpackUnorm4x8",
+ Mf::Unpack2x16snorm => "unpackSnorm2x16",
+ Mf::Unpack2x16unorm => "unpackUnorm2x16",
+ Mf::Unpack2x16float => "unpackHalf2x16",
+ };
+
+ let extract_bits = fun == Mf::ExtractBits;
+ let insert_bits = fun == Mf::InsertBits;
+
+ // Some GLSL functions always return signed integers (like findMSB),
+ // so they need to be cast to uint if the argument is also an uint.
+ let ret_might_need_int_to_uint =
+ matches!(fun, Mf::FindLsb | Mf::FindMsb | Mf::CountOneBits | Mf::Abs);
+
+ // Some GLSL functions only accept signed integers (like abs),
+ // so they need their argument cast from uint to int.
+ let arg_might_need_uint_to_int = matches!(fun, Mf::Abs);
+
+ // Check if the argument is an unsigned integer and return the vector size
+ // in case it's a vector
+ let maybe_uint_size = match *ctx.info[arg].ty.inner_with(&self.module.types) {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ } => Some(None),
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Uint,
+ size,
+ ..
+ } => Some(Some(size)),
+ _ => None,
+ };
+
+ // Cast to uint if the function needs it
+ if ret_might_need_int_to_uint {
+ if let Some(maybe_size) = maybe_uint_size {
+ match maybe_size {
+ Some(size) => write!(self.out, "uvec{}(", size as u8)?,
+ None => write!(self.out, "uint(")?,
+ }
+ }
+ }
+
+ write!(self.out, "{fun_name}(")?;
+
+ // Cast to int if the function needs it
+ if arg_might_need_uint_to_int {
+ if let Some(maybe_size) = maybe_uint_size {
+ match maybe_size {
+ Some(size) => write!(self.out, "ivec{}(", size as u8)?,
+ None => write!(self.out, "int(")?,
+ }
+ }
+ }
+
+ self.write_expr(arg, ctx)?;
+
+ // Close the cast from uint to int
+ if arg_might_need_uint_to_int && maybe_uint_size.is_some() {
+ write!(self.out, ")")?
+ }
+
+ if let Some(arg) = arg1 {
+ write!(self.out, ", ")?;
+ if extract_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ if let Some(arg) = arg2 {
+ write!(self.out, ", ")?;
+ if extract_bits || insert_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ if let Some(arg) = arg3 {
+ write!(self.out, ", ")?;
+ if insert_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ write!(self.out, ")")?;
+
+ // Close the cast from int to uint
+ if ret_might_need_int_to_uint && maybe_uint_size.is_some() {
+ write!(self.out, ")")?
+ }
+ }
+ // `As` is always a call.
+ // If `convert` is true the function name is the type
+ // Else the function name is one of the glsl provided bitcast functions
+ Expression::As {
+ expr,
+ kind: target_kind,
+ convert,
+ } => {
+ let inner = ctx.info[expr].ty.inner_with(&self.module.types);
+ match convert {
+ Some(width) => {
+ // this is similar to `write_type`, but with the target kind
+ let scalar = glsl_scalar(target_kind, width)?;
+ match *inner {
+ TypeInner::Matrix { columns, rows, .. } => write!(
+ self.out,
+ "{}mat{}x{}",
+ scalar.prefix, columns as u8, rows as u8
+ )?,
+ TypeInner::Vector { size, .. } => {
+ write!(self.out, "{}vec{}", scalar.prefix, size as u8)?
+ }
+ _ => write!(self.out, "{}", scalar.full)?,
+ }
+
+ write!(self.out, "(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?
+ }
+ None => {
+ use crate::ScalarKind as Sk;
+
+ let target_vector_type = match *inner {
+ TypeInner::Vector { size, width, .. } => Some(TypeInner::Vector {
+ size,
+ width,
+ kind: target_kind,
+ }),
+ _ => None,
+ };
+
+ let source_kind = inner.scalar_kind().unwrap();
+
+ match (source_kind, target_kind, target_vector_type) {
+ // No conversion needed
+ (Sk::Sint, Sk::Sint, _)
+ | (Sk::Uint, Sk::Uint, _)
+ | (Sk::Float, Sk::Float, _)
+ | (Sk::Bool, Sk::Bool, _) => {
+ self.write_expr(expr, ctx)?;
+ return Ok(());
+ }
+
+ // Cast to/from floats
+ (Sk::Float, Sk::Sint, _) => write!(self.out, "floatBitsToInt")?,
+ (Sk::Float, Sk::Uint, _) => write!(self.out, "floatBitsToUint")?,
+ (Sk::Sint, Sk::Float, _) => write!(self.out, "intBitsToFloat")?,
+ (Sk::Uint, Sk::Float, _) => write!(self.out, "uintBitsToFloat")?,
+
+ // Cast between vector types
+ (_, _, Some(vector)) => {
+ self.write_value_type(&vector)?;
+ }
+
+ // There is no way to bitcast between Uint/Sint in glsl. Use constructor conversion
+ (Sk::Uint | Sk::Bool, Sk::Sint, None) => write!(self.out, "int")?,
+ (Sk::Sint | Sk::Bool, Sk::Uint, None) => write!(self.out, "uint")?,
+ (Sk::Bool, Sk::Float, None) => write!(self.out, "float")?,
+ (Sk::Sint | Sk::Uint | Sk::Float, Sk::Bool, None) => {
+ write!(self.out, "bool")?
+ }
+ };
+
+ write!(self.out, "(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ // These expressions never show up in `Emit`.
+ Expression::CallResult(_)
+ | Expression::AtomicResult { .. }
+ | Expression::RayQueryProceedResult => unreachable!(),
+ // `ArrayLength` is written as `expr.length()` and we convert it to a uint
+ Expression::ArrayLength(expr) => {
+ write!(self.out, "uint(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ".length())")?
+ }
+ // not supported yet
+ Expression::RayQueryGetIntersection { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper function to write the local holding the clamped lod
+ fn write_clamped_lod(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ expr: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ level_expr: Handle<crate::Expression>,
+ ) -> Result<(), Error> {
+ // Define our local and start a call to `clamp`
+ write!(
+ self.out,
+ "int {}{}{} = clamp(",
+ back::BAKE_PREFIX,
+ expr.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ // Write the lod that will be clamped
+ self.write_expr(level_expr, ctx)?;
+ // Set the min value to 0 and start a call to `textureQueryLevels` to get
+ // the maximum value
+ write!(self.out, ", 0, textureQueryLevels(")?;
+ // Write the target image as an argument to `textureQueryLevels`
+ self.write_expr(image, ctx)?;
+ // Close the call to `textureQueryLevels` subtract 1 from it since
+ // the lod argument is 0 based, close the `clamp` call and end the
+ // local declaration statement.
+ writeln!(self.out, ") - 1);")?;
+
+ Ok(())
+ }
+
+ // Helper method used to retrieve how many elements a coordinate vector
+ // for the images operations need.
+ fn get_coordinate_vector_size(&self, dim: crate::ImageDimension, arrayed: bool) -> u8 {
+ // openGL es doesn't have 1D images so we need workaround it
+ let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es();
+ // Get how many components the coordinate vector needs for the dimensions only
+ let tex_coord_size = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 => 3,
+ crate::ImageDimension::Cube => 2,
+ };
+ // Calculate the true size of the coordinate vector by adding 1 for arrayed images
+ // and another 1 if we need to workaround 1D images by making them 2D
+ tex_coord_size + tex_1d_hack as u8 + arrayed as u8
+ }
+
+ /// Helper method to write the coordinate vector for image operations
+ fn write_texture_coord(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ vector_size: u8,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ // Emulate 1D images as 2D for profiles that don't support it (glsl es)
+ tex_1d_hack: bool,
+ ) -> Result<(), Error> {
+ match array_index {
+ // If the image needs an array indice we need to add it to the end of our
+ // coordinate vector, to do so we will use the `ivec(ivec, scalar)`
+ // constructor notation (NOTE: the inner `ivec` can also be a scalar, this
+ // is important for 1D arrayed images).
+ Some(layer_expr) => {
+ write!(self.out, "ivec{vector_size}(")?;
+ self.write_expr(coordinate, ctx)?;
+ write!(self.out, ", ")?;
+ // If we are replacing sampler1D with sampler2D we also need
+ // to add another zero to the coordinates vector for the y component
+ if tex_1d_hack {
+ write!(self.out, "0, ")?;
+ }
+ self.write_expr(layer_expr, ctx)?;
+ write!(self.out, ")")?;
+ }
+ // Otherwise write just the expression (and the 1D hack if needed)
+ None => {
+ let uvec_size = match *ctx.info[coordinate].ty.inner_with(&self.module.types) {
+ TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ } => Some(None),
+ TypeInner::Vector {
+ size,
+ kind: crate::ScalarKind::Uint,
+ ..
+ } => Some(Some(size as u32)),
+ _ => None,
+ };
+ if tex_1d_hack {
+ write!(self.out, "ivec2(")?;
+ } else if uvec_size.is_some() {
+ match uvec_size {
+ Some(None) => write!(self.out, "int(")?,
+ Some(Some(size)) => write!(self.out, "ivec{size}(")?,
+ _ => {}
+ }
+ }
+ self.write_expr(coordinate, ctx)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0)")?;
+ } else if uvec_size.is_some() {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write the `ImageStore` statement
+ fn write_image_store(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ value: Handle<crate::Expression>,
+ ) -> Result<(), Error> {
+ use crate::ImageDimension as IDim;
+
+ // NOTE: openGL requires that `imageStore`s have no effets when the texel is invalid
+ // so we don't need to generate bounds checks (OpenGL 4.2 Core §3.9.20)
+
+ // This will only panic if the module is invalid
+ let dim = match *ctx.info[image].ty.inner_with(&self.module.types) {
+ TypeInner::Image { dim, .. } => dim,
+ _ => unreachable!(),
+ };
+
+ // Begin our call to `imageStore`
+ write!(self.out, "imageStore(")?;
+ self.write_expr(image, ctx)?;
+ // Separate the image argument from the coordinates
+ write!(self.out, ", ")?;
+
+ // openGL es doesn't have 1D images so we need workaround it
+ let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es();
+ // Write the coordinate vector
+ self.write_texture_coord(
+ ctx,
+ // Get the size of the coordinate vector
+ self.get_coordinate_vector_size(dim, array_index.is_some()),
+ coordinate,
+ array_index,
+ tex_1d_hack,
+ )?;
+
+ // Separate the coordinate from the value to write and write the expression
+ // of the value to write.
+ write!(self.out, ", ")?;
+ self.write_expr(value, ctx)?;
+ // End the call to `imageStore` and the statement.
+ writeln!(self.out, ");")?;
+
+ Ok(())
+ }
+
+ /// Helper method for writing an `ImageLoad` expression.
+ #[allow(clippy::too_many_arguments)]
+ fn write_image_load(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ level: Option<Handle<crate::Expression>>,
+ ) -> Result<(), Error> {
+ use crate::ImageDimension as IDim;
+
+ // `ImageLoad` is a bit complicated.
+ // There are two functions one for sampled
+ // images another for storage images, the former uses `texelFetch` and the
+ // latter uses `imageLoad`.
+ //
+ // Furthermore we have `level` which is always `Some` for sampled images
+ // and `None` for storage images, so we end up with two functions:
+ // - `texelFetch(image, coordinate, level)` for sampled images
+ // - `imageLoad(image, coordinate)` for storage images
+ //
+ // Finally we also have to consider bounds checking, for storage images
+ // this is easy since openGL requires that invalid texels always return
+ // 0, for sampled images we need to either verify that all arguments are
+ // in bounds (`ReadZeroSkipWrite`) or make them a valid texel (`Restrict`).
+
+ // This will only panic if the module is invalid
+ let (dim, class) = match *ctx.info[image].ty.inner_with(&self.module.types) {
+ TypeInner::Image {
+ dim,
+ arrayed: _,
+ class,
+ } => (dim, class),
+ _ => unreachable!(),
+ };
+
+ // Get the name of the function to be used for the load operation
+ // and the policy to be used with it.
+ let (fun_name, policy) = match class {
+ // Sampled images inherit the policy from the user passed policies
+ crate::ImageClass::Sampled { .. } => ("texelFetch", self.policies.image),
+ crate::ImageClass::Storage { .. } => {
+ // OpenGL 4.2 Core §3.9.20 defines that out of bounds texels in `imageLoad`s
+ // always return zero values so we don't need to generate bounds checks
+ ("imageLoad", proc::BoundsCheckPolicy::Unchecked)
+ }
+ // TODO: Is there even a function for this?
+ crate::ImageClass::Depth { multi: _ } => {
+ return Err(Error::Custom(
+ "WGSL `textureLoad` from depth textures is not supported in GLSL".to_string(),
+ ))
+ }
+ };
+
+ // openGL es doesn't have 1D images so we need workaround it
+ let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es();
+ // Get the size of the coordinate vector
+ let vector_size = self.get_coordinate_vector_size(dim, array_index.is_some());
+
+ if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy {
+ // To write the bounds checks for `ReadZeroSkipWrite` we will use a
+ // ternary operator since we are in the middle of an expression and
+ // need to return a value.
+ //
+ // NOTE: glsl does short circuit when evaluating logical
+ // expressions so we can be sure that after we test a
+ // condition it will be true for the next ones
+
+ // Write parantheses around the ternary operator to prevent problems with
+ // expressions emitted before or after it having more precedence
+ write!(self.out, "(",)?;
+
+ // The lod check needs to precede the size check since we need
+ // to use the lod to get the size of the image at that level.
+ if let Some(level_expr) = level {
+ self.write_expr(level_expr, ctx)?;
+ write!(self.out, " < textureQueryLevels(",)?;
+ self.write_expr(image, ctx)?;
+ // Chain the next check
+ write!(self.out, ") && ")?;
+ }
+
+ // Check that the sample arguments doesn't exceed the number of samples
+ if let Some(sample_expr) = sample {
+ self.write_expr(sample_expr, ctx)?;
+ write!(self.out, " < textureSamples(",)?;
+ self.write_expr(image, ctx)?;
+ // Chain the next check
+ write!(self.out, ") && ")?;
+ }
+
+ // We now need to write the size checks for the coordinates and array index
+ // first we write the comparation function in case the image is 1D non arrayed
+ // (and no 1D to 2D hack was needed) we are comparing scalars so the less than
+ // operator will suffice, but otherwise we'll be comparing two vectors so we'll
+ // need to use the `lessThan` function but it returns a vector of booleans (one
+ // for each comparison) so we need to fold it all in one scalar boolean, since
+ // we want all comparisons to pass we use the `all` function which will only
+ // return `true` if all the elements of the boolean vector are also `true`.
+ //
+ // So we'll end with one of the following forms
+ // - `coord < textureSize(image, lod)` for 1D images
+ // - `all(lessThan(coord, textureSize(image, lod)))` for normal images
+ // - `all(lessThan(ivec(coord, array_index), textureSize(image, lod)))`
+ // for arrayed images
+ // - `all(lessThan(coord, textureSize(image)))` for multi sampled images
+
+ if vector_size != 1 {
+ write!(self.out, "all(lessThan(")?;
+ }
+
+ // Write the coordinate vector
+ self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?;
+
+ if vector_size != 1 {
+ // If we used the `lessThan` function we need to separate the
+ // coordinates from the image size.
+ write!(self.out, ", ")?;
+ } else {
+ // If we didn't use it (ie. 1D images) we perform the comparsion
+ // using the less than operator.
+ write!(self.out, " < ")?;
+ }
+
+ // Call `textureSize` to get our image size
+ write!(self.out, "textureSize(")?;
+ self.write_expr(image, ctx)?;
+ // `textureSize` uses the lod as a second argument for mipmapped images
+ if let Some(level_expr) = level {
+ // Separate the image from the lod
+ write!(self.out, ", ")?;
+ self.write_expr(level_expr, ctx)?;
+ }
+ // Close the `textureSize` call
+ write!(self.out, ")")?;
+
+ if vector_size != 1 {
+ // Close the `all` and `lessThan` calls
+ write!(self.out, "))")?;
+ }
+
+ // Finally end the condition part of the ternary operator
+ write!(self.out, " ? ")?;
+ }
+
+ // Begin the call to the function used to load the texel
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ", ")?;
+
+ // If we are using `Restrict` bounds checking we need to pass valid texel
+ // coordinates, to do so we use the `clamp` function to get a value between
+ // 0 and the image size - 1 (indexing begins at 0)
+ if let proc::BoundsCheckPolicy::Restrict = policy {
+ write!(self.out, "clamp(")?;
+ }
+
+ // Write the coordinate vector
+ self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?;
+
+ // If we are using `Restrict` bounds checking we need to write the rest of the
+ // clamp we initiated before writing the coordinates.
+ if let proc::BoundsCheckPolicy::Restrict = policy {
+ // Write the min value 0
+ if vector_size == 1 {
+ write!(self.out, ", 0")?;
+ } else {
+ write!(self.out, ", ivec{vector_size}(0)")?;
+ }
+ // Start the `textureSize` call to use as the max value.
+ write!(self.out, ", textureSize(")?;
+ self.write_expr(image, ctx)?;
+ // If the image is mipmapped we need to add the lod argument to the
+ // `textureSize` call, but this needs to be the clamped lod, this should
+ // have been generated earlier and put in a local.
+ if class.is_mipmapped() {
+ write!(
+ self.out,
+ ", {}{}{}",
+ back::BAKE_PREFIX,
+ handle.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ }
+ // Close the `textureSize` call
+ write!(self.out, ")")?;
+
+ // Subtract 1 from the `textureSize` call since the coordinates are zero based.
+ if vector_size == 1 {
+ write!(self.out, " - 1")?;
+ } else {
+ write!(self.out, " - ivec{vector_size}(1)")?;
+ }
+
+ // Close the `clamp` call
+ write!(self.out, ")")?;
+
+ // Add the clamped lod (if present) as the second argument to the
+ // image load function.
+ if level.is_some() {
+ write!(
+ self.out,
+ ", {}{}{}",
+ back::BAKE_PREFIX,
+ handle.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ }
+
+ // If a sample argument is needed we need to clamp it between 0 and
+ // the number of samples the image has.
+ if let Some(sample_expr) = sample {
+ write!(self.out, ", clamp(")?;
+ self.write_expr(sample_expr, ctx)?;
+ // Set the min value to 0 and start the call to `textureSamples`
+ write!(self.out, ", 0, textureSamples(")?;
+ self.write_expr(image, ctx)?;
+ // Close the `textureSamples` call, subtract 1 from it since the sample
+ // argument is zero based, and close the `clamp` call
+ writeln!(self.out, ") - 1)")?;
+ }
+ } else if let Some(sample_or_level) = sample.or(level) {
+ // If no bounds checking is need just add the sample or level argument
+ // after the coordinates
+ write!(self.out, ", ")?;
+ self.write_expr(sample_or_level, ctx)?;
+ }
+
+ // Close the image load function.
+ write!(self.out, ")")?;
+
+ // If we were using the `ReadZeroSkipWrite` policy we need to end the first branch
+ // (which is taken if the condition is `true`) with a colon (`:`) and write the
+ // second branch which is just a 0 value.
+ if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy {
+ // Get the kind of the output value.
+ let kind = match class {
+ // Only sampled images can reach here since storage images
+ // don't need bounds checks and depth images aren't implmented
+ crate::ImageClass::Sampled { kind, .. } => kind,
+ _ => unreachable!(),
+ };
+
+ // End the first branch
+ write!(self.out, " : ")?;
+ // Write the 0 value
+ write!(self.out, "{}vec4(", glsl_scalar(kind, 4)?.prefix,)?;
+ self.write_zero_init_scalar(kind)?;
+ // Close the zero value constructor
+ write!(self.out, ")")?;
+ // Close the parantheses surrounding our ternary
+ write!(self.out, ")")?;
+ }
+
+ Ok(())
+ }
+
+ fn write_named_expr(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ name: String,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ match ctx.info[handle].ty {
+ proc::TypeResolution::Handle(ty_handle) => match self.module.types[ty_handle].inner {
+ TypeInner::Struct { .. } => {
+ let ty_name = &self.names[&NameKey::Type(ty_handle)];
+ write!(self.out, "{ty_name}")?;
+ }
+ _ => {
+ self.write_type(ty_handle)?;
+ }
+ },
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(inner)?;
+ }
+ }
+
+ let base_ty_res = &ctx.info[handle].ty;
+ let resolved = base_ty_res.inner_with(&self.module.types);
+
+ write!(self.out, " {name}")?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_expr(handle, ctx)?;
+ writeln!(self.out, ";")?;
+ self.named_expressions.insert(handle, name);
+
+ Ok(())
+ }
+
+ /// Helper function that write string with default zero initialization for supported types
+ fn write_zero_init_value(&mut self, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &self.module.types[ty].inner;
+ match *inner {
+ TypeInner::Scalar { kind, .. } | TypeInner::Atomic { kind, .. } => {
+ self.write_zero_init_scalar(kind)?;
+ }
+ TypeInner::Vector { kind, .. } => {
+ self.write_value_type(inner)?;
+ write!(self.out, "(")?;
+ self.write_zero_init_scalar(kind)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Matrix { .. } => {
+ self.write_value_type(inner)?;
+ write!(self.out, "(")?;
+ self.write_zero_init_scalar(crate::ScalarKind::Float)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Array { base, size, .. } => {
+ let count = match size
+ .to_indexable_length(self.module)
+ .expect("Bad array size")
+ {
+ proc::IndexableLength::Known(count) => count,
+ proc::IndexableLength::Dynamic => return Ok(()),
+ };
+ self.write_type(base)?;
+ self.write_array_size(base, size)?;
+ write!(self.out, "(")?;
+ for _ in 1..count {
+ self.write_zero_init_value(base)?;
+ write!(self.out, ", ")?;
+ }
+ // write last parameter without comma and space
+ self.write_zero_init_value(base)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Struct { ref members, .. } => {
+ let name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{name}(")?;
+ for (i, member) in members.iter().enumerate() {
+ self.write_zero_init_value(member.ty)?;
+ if i != members.len().saturating_sub(1) {
+ write!(self.out, ", ")?;
+ }
+ }
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper function that write string with zero initialization for scalar
+ fn write_zero_init_scalar(&mut self, kind: crate::ScalarKind) -> BackendResult {
+ match kind {
+ crate::ScalarKind::Bool => write!(self.out, "false")?,
+ crate::ScalarKind::Uint => write!(self.out, "0u")?,
+ crate::ScalarKind::Float => write!(self.out, "0.0")?,
+ crate::ScalarKind::Sint => write!(self.out, "0")?,
+ }
+
+ Ok(())
+ }
+
+ /// Issue a memory barrier. Please note that to ensure visibility,
+ /// OpenGL always requires a call to the `barrier()` function after a `memoryBarrier*()`
+ fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
+ if flags.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{level}memoryBarrierBuffer();")?;
+ }
+ if flags.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{level}memoryBarrierShared();")?;
+ }
+ writeln!(self.out, "{level}barrier();")?;
+ Ok(())
+ }
+
+ /// Helper function that return the glsl storage access string of [`StorageAccess`](crate::StorageAccess)
+ ///
+ /// glsl allows adding both `readonly` and `writeonly` but this means that
+ /// they can only be used to query information about the resource which isn't what
+ /// we want here so when storage access is both `LOAD` and `STORE` add no modifiers
+ fn write_storage_access(&mut self, storage_access: crate::StorageAccess) -> BackendResult {
+ if !storage_access.contains(crate::StorageAccess::STORE) {
+ write!(self.out, "readonly ")?;
+ }
+ if !storage_access.contains(crate::StorageAccess::LOAD) {
+ write!(self.out, "writeonly ")?;
+ }
+ Ok(())
+ }
+
+ /// Helper method used to produce the reflection info that's returned to the user
+ fn collect_reflection_info(&self) -> Result<ReflectionInfo, Error> {
+ use std::collections::hash_map::Entry;
+ let info = self.info.get_entry_point(self.entry_point_idx as usize);
+ let mut texture_mapping = crate::FastHashMap::default();
+ let mut uniforms = crate::FastHashMap::default();
+
+ for sampling in info.sampling_set.iter() {
+ let tex_name = self.reflection_names_globals[&sampling.image].clone();
+
+ match texture_mapping.entry(tex_name) {
+ Entry::Vacant(v) => {
+ v.insert(TextureMapping {
+ texture: sampling.image,
+ sampler: Some(sampling.sampler),
+ });
+ }
+ Entry::Occupied(e) => {
+ if e.get().sampler != Some(sampling.sampler) {
+ log::error!("Conflicting samplers for {}", e.key());
+ return Err(Error::ImageMultipleSamplers);
+ }
+ }
+ }
+ }
+
+ for (handle, var) in self.module.global_variables.iter() {
+ if info[handle].is_empty() {
+ continue;
+ }
+ match self.module.types[var.ty].inner {
+ crate::TypeInner::Image { .. } => {
+ let tex_name = self.reflection_names_globals[&handle].clone();
+ match texture_mapping.entry(tex_name) {
+ Entry::Vacant(v) => {
+ v.insert(TextureMapping {
+ texture: handle,
+ sampler: None,
+ });
+ }
+ Entry::Occupied(_) => {
+ // already used with a sampler, do nothing
+ }
+ }
+ }
+ _ => match var.space {
+ crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => {
+ let name = self.reflection_names_globals[&handle].clone();
+ uniforms.insert(handle, name);
+ }
+ _ => (),
+ },
+ }
+ }
+
+ Ok(ReflectionInfo {
+ texture_mapping,
+ uniforms,
+ })
+ }
+}
+
+/// Structure returned by [`glsl_scalar`](glsl_scalar)
+///
+/// It contains both a prefix used in other types and the full type name
+struct ScalarString<'a> {
+ /// The prefix used to compose other types
+ prefix: &'a str,
+ /// The name of the scalar type
+ full: &'a str,
+}
+
+/// Helper function that returns scalar related strings
+///
+/// Check [`ScalarString`](ScalarString) for the information provided
+///
+/// # Errors
+/// If a [`Float`](crate::ScalarKind::Float) with an width that isn't 4 or 8
+const fn glsl_scalar(
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+) -> Result<ScalarString<'static>, Error> {
+ use crate::ScalarKind as Sk;
+
+ Ok(match kind {
+ Sk::Sint => ScalarString {
+ prefix: "i",
+ full: "int",
+ },
+ Sk::Uint => ScalarString {
+ prefix: "u",
+ full: "uint",
+ },
+ Sk::Float => match width {
+ 4 => ScalarString {
+ prefix: "",
+ full: "float",
+ },
+ 8 => ScalarString {
+ prefix: "d",
+ full: "double",
+ },
+ _ => return Err(Error::UnsupportedScalar(kind, width)),
+ },
+ Sk::Bool => ScalarString {
+ prefix: "b",
+ full: "bool",
+ },
+ })
+}
+
+/// Helper function that returns the glsl variable name for a builtin
+const fn glsl_built_in(
+ built_in: crate::BuiltIn,
+ output: bool,
+ targetting_webgl: bool,
+) -> &'static str {
+ use crate::BuiltIn as Bi;
+
+ match built_in {
+ Bi::Position { .. } => {
+ if output {
+ "gl_Position"
+ } else {
+ "gl_FragCoord"
+ }
+ }
+ Bi::ViewIndex if targetting_webgl => "int(gl_ViewID_OVR)",
+ Bi::ViewIndex => "gl_ViewIndex",
+ // vertex
+ Bi::BaseInstance => "uint(gl_BaseInstance)",
+ Bi::BaseVertex => "uint(gl_BaseVertex)",
+ Bi::ClipDistance => "gl_ClipDistance",
+ Bi::CullDistance => "gl_CullDistance",
+ Bi::InstanceIndex => "uint(gl_InstanceID)",
+ Bi::PointSize => "gl_PointSize",
+ Bi::VertexIndex => "uint(gl_VertexID)",
+ // fragment
+ Bi::FragDepth => "gl_FragDepth",
+ Bi::PointCoord => "gl_PointCoord",
+ Bi::FrontFacing => "gl_FrontFacing",
+ Bi::PrimitiveIndex => "uint(gl_PrimitiveID)",
+ Bi::SampleIndex => "gl_SampleID",
+ Bi::SampleMask => {
+ if output {
+ "gl_SampleMask"
+ } else {
+ "gl_SampleMaskIn"
+ }
+ }
+ // compute
+ Bi::GlobalInvocationId => "gl_GlobalInvocationID",
+ Bi::LocalInvocationId => "gl_LocalInvocationID",
+ Bi::LocalInvocationIndex => "gl_LocalInvocationIndex",
+ Bi::WorkGroupId => "gl_WorkGroupID",
+ Bi::WorkGroupSize => "gl_WorkGroupSize",
+ Bi::NumWorkGroups => "gl_NumWorkGroups",
+ }
+}
+
+/// Helper function that returns the string corresponding to the address space
+const fn glsl_storage_qualifier(space: crate::AddressSpace) -> Option<&'static str> {
+ use crate::AddressSpace as As;
+
+ match space {
+ As::Function => None,
+ As::Private => None,
+ As::Storage { .. } => Some("buffer"),
+ As::Uniform => Some("uniform"),
+ As::Handle => Some("uniform"),
+ As::WorkGroup => Some("shared"),
+ As::PushConstant => Some("uniform"),
+ }
+}
+
+/// Helper function that returns the string corresponding to the glsl interpolation qualifier
+const fn glsl_interpolation(interpolation: crate::Interpolation) -> &'static str {
+ use crate::Interpolation as I;
+
+ match interpolation {
+ I::Perspective => "smooth",
+ I::Linear => "noperspective",
+ I::Flat => "flat",
+ }
+}
+
+/// Return the GLSL auxiliary qualifier for the given sampling value.
+const fn glsl_sampling(sampling: crate::Sampling) -> Option<&'static str> {
+ use crate::Sampling as S;
+
+ match sampling {
+ S::Center => None,
+ S::Centroid => Some("centroid"),
+ S::Sample => Some("sample"),
+ }
+}
+
+/// Helper function that returns the glsl dimension string of [`ImageDimension`](crate::ImageDimension)
+const fn glsl_dimension(dim: crate::ImageDimension) -> &'static str {
+ use crate::ImageDimension as IDim;
+
+ match dim {
+ IDim::D1 => "1D",
+ IDim::D2 => "2D",
+ IDim::D3 => "3D",
+ IDim::Cube => "Cube",
+ }
+}
+
+/// Helper function that returns the glsl storage format string of [`StorageFormat`](crate::StorageFormat)
+const fn glsl_storage_format(format: crate::StorageFormat) -> &'static str {
+ use crate::StorageFormat as Sf;
+
+ match format {
+ Sf::R8Unorm => "r8",
+ Sf::R8Snorm => "r8_snorm",
+ Sf::R8Uint => "r8ui",
+ Sf::R8Sint => "r8i",
+ Sf::R16Uint => "r16ui",
+ Sf::R16Sint => "r16i",
+ Sf::R16Float => "r16f",
+ Sf::Rg8Unorm => "rg8",
+ Sf::Rg8Snorm => "rg8_snorm",
+ Sf::Rg8Uint => "rg8ui",
+ Sf::Rg8Sint => "rg8i",
+ Sf::R32Uint => "r32ui",
+ Sf::R32Sint => "r32i",
+ Sf::R32Float => "r32f",
+ Sf::Rg16Uint => "rg16ui",
+ Sf::Rg16Sint => "rg16i",
+ Sf::Rg16Float => "rg16f",
+ Sf::Rgba8Unorm => "rgba8",
+ Sf::Rgba8Snorm => "rgba8_snorm",
+ Sf::Rgba8Uint => "rgba8ui",
+ Sf::Rgba8Sint => "rgba8i",
+ Sf::Rgb10a2Unorm => "rgb10_a2ui",
+ Sf::Rg11b10Float => "r11f_g11f_b10f",
+ Sf::Rg32Uint => "rg32ui",
+ Sf::Rg32Sint => "rg32i",
+ Sf::Rg32Float => "rg32f",
+ Sf::Rgba16Uint => "rgba16ui",
+ Sf::Rgba16Sint => "rgba16i",
+ Sf::Rgba16Float => "rgba16f",
+ Sf::Rgba32Uint => "rgba32ui",
+ Sf::Rgba32Sint => "rgba32i",
+ Sf::Rgba32Float => "rgba32f",
+ Sf::R16Unorm => "r16",
+ Sf::R16Snorm => "r16_snorm",
+ Sf::Rg16Unorm => "rg16",
+ Sf::Rg16Snorm => "rg16_snorm",
+ Sf::Rgba16Unorm => "rgba16",
+ Sf::Rgba16Snorm => "rgba16_snorm",
+ }
+}
+
+fn is_value_init_supported(module: &crate::Module, ty: Handle<crate::Type>) -> bool {
+ match module.types[ty].inner {
+ TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => true,
+ TypeInner::Array { base, size, .. } => {
+ size != crate::ArraySize::Dynamic && is_value_init_supported(module, base)
+ }
+ TypeInner::Struct { ref members, .. } => members
+ .iter()
+ .all(|member| is_value_init_supported(module, member.ty)),
+ _ => false,
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs
new file mode 100644
index 0000000000..5eb24962f6
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/conv.rs
@@ -0,0 +1,222 @@
+use std::borrow::Cow;
+
+use crate::proc::Alignment;
+
+use super::Error;
+
+impl crate::ScalarKind {
+ pub(super) fn to_hlsl_cast(self) -> &'static str {
+ match self {
+ Self::Float => "asfloat",
+ Self::Sint => "asint",
+ Self::Uint => "asuint",
+ Self::Bool => unreachable!(),
+ }
+ }
+
+ /// Helper function that returns scalar related strings
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar>
+ pub(super) const fn to_hlsl_str(self, width: crate::Bytes) -> Result<&'static str, Error> {
+ match self {
+ Self::Sint => Ok("int"),
+ Self::Uint => Ok("uint"),
+ Self::Float => match width {
+ 2 => Ok("half"),
+ 4 => Ok("float"),
+ 8 => Ok("double"),
+ _ => Err(Error::UnsupportedScalar(self, width)),
+ },
+ Self::Bool => Ok("bool"),
+ }
+ }
+}
+
+impl crate::TypeInner {
+ pub(super) const fn is_matrix(&self) -> bool {
+ match *self {
+ Self::Matrix { .. } => true,
+ _ => false,
+ }
+ }
+
+ pub(super) fn size_hlsl(
+ &self,
+ types: &crate::UniqueArena<crate::Type>,
+ constants: &crate::Arena<crate::Constant>,
+ ) -> u32 {
+ match *self {
+ Self::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let stride = Alignment::from(rows) * width as u32;
+ let last_row_size = rows as u32 * width as u32;
+ ((columns as u32 - 1) * stride) + last_row_size
+ }
+ Self::Array { base, size, stride } => {
+ let count = match size {
+ crate::ArraySize::Constant(handle) => {
+ constants[handle].to_array_length().unwrap_or(1)
+ }
+ // A dynamically-sized array has to have at least one element
+ crate::ArraySize::Dynamic => 1,
+ };
+ let last_el_size = types[base].inner.size_hlsl(types, constants);
+ ((count - 1) * stride) + last_el_size
+ }
+ _ => self.size(constants),
+ }
+ }
+
+ /// Used to generate the name of the wrapped type constructor
+ pub(super) fn hlsl_type_id<'a>(
+ base: crate::Handle<crate::Type>,
+ types: &crate::UniqueArena<crate::Type>,
+ constants: &crate::Arena<crate::Constant>,
+ names: &'a crate::FastHashMap<crate::proc::NameKey, String>,
+ ) -> Result<Cow<'a, str>, Error> {
+ Ok(match types[base].inner {
+ crate::TypeInner::Scalar { kind, width } => Cow::Borrowed(kind.to_hlsl_str(width)?),
+ crate::TypeInner::Vector { size, kind, width } => Cow::Owned(format!(
+ "{}{}",
+ kind.to_hlsl_str(width)?,
+ crate::back::vector_size_str(size)
+ )),
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => Cow::Owned(format!(
+ "{}{}x{}",
+ crate::ScalarKind::Float.to_hlsl_str(width)?,
+ crate::back::vector_size_str(columns),
+ crate::back::vector_size_str(rows),
+ )),
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => Cow::Owned(format!(
+ "array{}_{}_",
+ constants[size].to_array_length().unwrap(),
+ Self::hlsl_type_id(base, types, constants, names)?
+ )),
+ crate::TypeInner::Struct { .. } => {
+ Cow::Borrowed(&names[&crate::proc::NameKey::Type(base)])
+ }
+ _ => unreachable!(),
+ })
+ }
+}
+
+impl crate::StorageFormat {
+ pub(super) const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::R16Float => "float",
+ Self::R8Unorm | Self::R16Unorm => "unorm float",
+ Self::R8Snorm | Self::R16Snorm => "snorm float",
+ Self::R8Uint | Self::R16Uint => "uint",
+ Self::R8Sint | Self::R16Sint => "int",
+
+ Self::Rg16Float => "float2",
+ Self::Rg8Unorm | Self::Rg16Unorm => "unorm float2",
+ Self::Rg8Snorm | Self::Rg16Snorm => "snorm float2",
+
+ Self::Rg8Sint | Self::Rg16Sint => "int2",
+ Self::Rg8Uint | Self::Rg16Uint => "uint2",
+
+ Self::Rg11b10Float => "float3",
+
+ Self::Rgba16Float | Self::R32Float | Self::Rg32Float | Self::Rgba32Float => "float4",
+ Self::Rgba8Unorm | Self::Rgba16Unorm | Self::Rgb10a2Unorm => "unorm float4",
+ Self::Rgba8Snorm | Self::Rgba16Snorm => "snorm float4",
+
+ Self::Rgba8Uint
+ | Self::Rgba16Uint
+ | Self::R32Uint
+ | Self::Rg32Uint
+ | Self::Rgba32Uint => "uint4",
+ Self::Rgba8Sint
+ | Self::Rgba16Sint
+ | Self::R32Sint
+ | Self::Rg32Sint
+ | Self::Rgba32Sint => "int4",
+ }
+ }
+}
+
+impl crate::BuiltIn {
+ pub(super) fn to_hlsl_str(self) -> Result<&'static str, Error> {
+ Ok(match self {
+ Self::Position { .. } => "SV_Position",
+ // vertex
+ Self::ClipDistance => "SV_ClipDistance",
+ Self::CullDistance => "SV_CullDistance",
+ Self::InstanceIndex => "SV_InstanceID",
+ Self::VertexIndex => "SV_VertexID",
+ // fragment
+ Self::FragDepth => "SV_Depth",
+ Self::FrontFacing => "SV_IsFrontFace",
+ Self::PrimitiveIndex => "SV_PrimitiveID",
+ Self::SampleIndex => "SV_SampleIndex",
+ Self::SampleMask => "SV_Coverage",
+ // compute
+ Self::GlobalInvocationId => "SV_DispatchThreadID",
+ Self::LocalInvocationId => "SV_GroupThreadID",
+ Self::LocalInvocationIndex => "SV_GroupIndex",
+ Self::WorkGroupId => "SV_GroupID",
+ // The specific semantic we use here doesn't matter, because references
+ // to this field will get replaced with references to `SPECIAL_CBUF_VAR`
+ // in `Writer::write_expr`.
+ Self::NumWorkGroups => "SV_GroupID",
+ Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
+ return Err(Error::Unimplemented(format!("builtin {self:?}")))
+ }
+ Self::PointSize | Self::ViewIndex | Self::PointCoord => {
+ return Err(Error::Custom(format!("Unsupported builtin {self:?}")))
+ }
+ })
+ }
+}
+
+impl crate::Interpolation {
+ /// Return the string corresponding to the HLSL interpolation qualifier.
+ pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
+ match self {
+ // Would be "linear", but it's the default interpolation in SM4 and up
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-struct#interpolation-modifiers-introduced-in-shader-model-4
+ Self::Perspective => None,
+ Self::Linear => Some("noperspective"),
+ Self::Flat => Some("nointerpolation"),
+ }
+ }
+}
+
+impl crate::Sampling {
+ /// Return the HLSL auxiliary qualifier for the given sampling value.
+ pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
+ match self {
+ Self::Center => None,
+ Self::Centroid => Some("centroid"),
+ Self::Sample => Some("sample"),
+ }
+ }
+}
+
+impl crate::AtomicFunction {
+ /// Return the HLSL suffix for the `InterlockedXxx` method.
+ pub(super) const fn to_hlsl_suffix(self) -> &'static str {
+ match self {
+ Self::Add | Self::Subtract => "Add",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { .. } => "", //TODO
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs
new file mode 100644
index 0000000000..8ae9baf62d
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/help.rs
@@ -0,0 +1,1126 @@
+/*!
+Helpers for the hlsl backend
+
+Important note about `Expression::ImageQuery`/`Expression::ArrayLength` and hlsl backend:
+
+Due to implementation of `GetDimensions` function in hlsl (<https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>)
+backend can't work with it as an expression.
+Instead, it generates a unique wrapped function per `Expression::ImageQuery`, based on texture info and query function.
+See `WrappedImageQuery` struct that represents a unique function and will be generated before writing all statements and expressions.
+This allowed to works with `Expression::ImageQuery` as expression and write wrapped function.
+
+For example:
+```wgsl
+let dim_1d = textureDimensions(image_1d);
+```
+
+```hlsl
+int NagaDimensions1D(Texture1D<float4>)
+{
+ uint4 ret;
+ image_1d.GetDimensions(ret.x);
+ return ret.x;
+}
+
+int dim_1d = NagaDimensions1D(image_1d);
+```
+*/
+
+use super::{super::FunctionCtx, BackendResult};
+use crate::{arena::Handle, proc::NameKey};
+use std::fmt::Write;
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedArrayLength {
+ pub(super) writable: bool,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedImageQuery {
+ pub(super) dim: crate::ImageDimension,
+ pub(super) arrayed: bool,
+ pub(super) class: crate::ImageClass,
+ pub(super) query: ImageQuery,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedConstructor {
+ pub(super) ty: Handle<crate::Type>,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedStructMatrixAccess {
+ pub(super) ty: Handle<crate::Type>,
+ pub(super) index: u32,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedMatCx2 {
+ pub(super) columns: crate::VectorSize,
+}
+
+/// HLSL backend requires its own `ImageQuery` enum.
+///
+/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
+/// IR version can't be unique per function, because it's store mipmap level as an expression.
+///
+/// For example:
+/// ```wgsl
+/// let dim_cube_array_lod = textureDimensions(image_cube_array, 1);
+/// let dim_cube_array_lod2 = textureDimensions(image_cube_array, 1);
+/// ```
+///
+/// ```ir
+/// ImageQuery {
+/// image: [1],
+/// query: Size {
+/// level: Some(
+/// [1],
+/// ),
+/// },
+/// },
+/// ImageQuery {
+/// image: [1],
+/// query: Size {
+/// level: Some(
+/// [2],
+/// ),
+/// },
+/// },
+/// ```
+///
+/// HLSL should generate only 1 function for this case.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) enum ImageQuery {
+ Size,
+ SizeLevel,
+ NumLevels,
+ NumLayers,
+ NumSamples,
+}
+
+impl From<crate::ImageQuery> for ImageQuery {
+ fn from(q: crate::ImageQuery) -> Self {
+ use crate::ImageQuery as Iq;
+ match q {
+ Iq::Size { level: Some(_) } => ImageQuery::SizeLevel,
+ Iq::Size { level: None } => ImageQuery::Size,
+ Iq::NumLevels => ImageQuery::NumLevels,
+ Iq::NumLayers => ImageQuery::NumLayers,
+ Iq::NumSamples => ImageQuery::NumSamples,
+ }
+ }
+}
+
+impl<'a, W: Write> super::Writer<'a, W> {
+ pub(super) fn write_image_type(
+ &mut self,
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ ) -> BackendResult {
+ let access_str = match class {
+ crate::ImageClass::Storage { .. } => "RW",
+ _ => "",
+ };
+ let dim_str = dim.to_hlsl_str();
+ let arrayed_str = if arrayed { "Array" } else { "" };
+ write!(self.out, "{access_str}Texture{dim_str}{arrayed_str}")?;
+ match class {
+ crate::ImageClass::Depth { multi } => {
+ let multi_str = if multi { "MS" } else { "" };
+ write!(self.out, "{multi_str}<float>")?
+ }
+ crate::ImageClass::Sampled { kind, multi } => {
+ let multi_str = if multi { "MS" } else { "" };
+ let scalar_kind_str = kind.to_hlsl_str(4)?;
+ write!(self.out, "{multi_str}<{scalar_kind_str}4>")?
+ }
+ crate::ImageClass::Storage { format, .. } => {
+ let storage_format_str = format.to_hlsl_str();
+ write!(self.out, "<{storage_format_str}>")?
+ }
+ }
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_array_length_function_name(
+ &mut self,
+ query: WrappedArrayLength,
+ ) -> BackendResult {
+ let access_str = if query.writable { "RW" } else { "" };
+ write!(self.out, "NagaBufferLength{access_str}",)?;
+
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ArrayLength`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer-getdimensions>
+ pub(super) fn write_wrapped_array_length_function(
+ &mut self,
+ module: &crate::Module,
+ wal: WrappedArrayLength,
+ expr_handle: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const ARGUMENT_VARIABLE_NAME: &str = "buffer";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ let ret_ty = func_ctx.info[expr_handle].ty.inner_with(&module.types);
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_array_length_function_name(wal)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let access_str = if wal.writable { "RW" } else { "" };
+ writeln!(
+ self.out,
+ "{access_str}ByteAddressBuffer {ARGUMENT_VARIABLE_NAME})"
+ )?;
+ // Write function body
+ writeln!(self.out, "{{")?;
+
+ // Write `GetDimensions` function.
+ writeln!(self.out, "{INDENT}uint {RETURN_VARIABLE_NAME};")?;
+ writeln!(
+ self.out,
+ "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions({RETURN_VARIABLE_NAME});"
+ )?;
+
+ // Write return value
+ writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_image_query_function_name(
+ &mut self,
+ query: WrappedImageQuery,
+ ) -> BackendResult {
+ let dim_str = query.dim.to_hlsl_str();
+ let class_str = match query.class {
+ crate::ImageClass::Sampled { multi: true, .. } => "MS",
+ crate::ImageClass::Depth { multi: true } => "DepthMS",
+ crate::ImageClass::Depth { multi: false } => "Depth",
+ crate::ImageClass::Sampled { multi: false, .. } => "",
+ crate::ImageClass::Storage { .. } => "RW",
+ };
+ let arrayed_str = if query.arrayed { "Array" } else { "" };
+ let query_str = match query.query {
+ ImageQuery::Size => "Dimensions",
+ ImageQuery::SizeLevel => "MipDimensions",
+ ImageQuery::NumLevels => "NumLevels",
+ ImageQuery::NumLayers => "NumLayers",
+ ImageQuery::NumSamples => "NumSamples",
+ };
+
+ write!(self.out, "Naga{class_str}{query_str}{dim_str}{arrayed_str}")?;
+
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ImageQuery`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>
+ pub(super) fn write_wrapped_image_query_function(
+ &mut self,
+ module: &crate::Module,
+ wiq: WrappedImageQuery,
+ expr_handle: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ use crate::{
+ back::{COMPONENTS, INDENT},
+ ImageDimension as IDim,
+ };
+
+ const ARGUMENT_VARIABLE_NAME: &str = "tex";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+ const MIP_LEVEL_PARAM: &str = "mip_level";
+
+ // Write function return type and name
+ let ret_ty = func_ctx.info[expr_handle].ty.inner_with(&module.types);
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_image_query_function_name(wiq)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ // Texture always first parameter
+ self.write_image_type(wiq.dim, wiq.arrayed, wiq.class)?;
+ write!(self.out, " {ARGUMENT_VARIABLE_NAME}")?;
+ // Mipmap is a second parameter if exists
+ if let ImageQuery::SizeLevel = wiq.query {
+ write!(self.out, ", uint {MIP_LEVEL_PARAM}")?;
+ }
+ writeln!(self.out, ")")?;
+
+ // Write function body
+ writeln!(self.out, "{{")?;
+
+ let array_coords = usize::from(wiq.arrayed);
+ // extra parameter is the mip level count or the sample count
+ let extra_coords = match wiq.class {
+ crate::ImageClass::Storage { .. } => 0,
+ crate::ImageClass::Sampled { .. } | crate::ImageClass::Depth { .. } => 1,
+ };
+
+ // GetDimensions Overloaded Methods
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions#overloaded-methods
+ let (ret_swizzle, number_of_params) = match wiq.query {
+ ImageQuery::Size | ImageQuery::SizeLevel => {
+ let ret = match wiq.dim {
+ IDim::D1 => "x",
+ IDim::D2 => "xy",
+ IDim::D3 => "xyz",
+ IDim::Cube => "xy",
+ };
+ (ret, ret.len() + array_coords + extra_coords)
+ }
+ ImageQuery::NumLevels | ImageQuery::NumSamples | ImageQuery::NumLayers => {
+ if wiq.arrayed || wiq.dim == IDim::D3 {
+ ("w", 4)
+ } else {
+ ("z", 3)
+ }
+ }
+ };
+
+ // Write `GetDimensions` function.
+ writeln!(self.out, "{INDENT}uint4 {RETURN_VARIABLE_NAME};")?;
+ write!(self.out, "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions(")?;
+ match wiq.query {
+ ImageQuery::SizeLevel => {
+ write!(self.out, "{MIP_LEVEL_PARAM}, ")?;
+ }
+ _ => match wiq.class {
+ crate::ImageClass::Sampled { multi: true, .. }
+ | crate::ImageClass::Depth { multi: true }
+ | crate::ImageClass::Storage { .. } => {}
+ _ => {
+ // Write zero mipmap level for supported types
+ write!(self.out, "0, ")?;
+ }
+ },
+ }
+
+ for component in COMPONENTS[..number_of_params - 1].iter() {
+ write!(self.out, "{RETURN_VARIABLE_NAME}.{component}, ")?;
+ }
+
+ // write last parameter without comma and space for last parameter
+ write!(
+ self.out,
+ "{}.{}",
+ RETURN_VARIABLE_NAME,
+ COMPONENTS[number_of_params - 1]
+ )?;
+
+ writeln!(self.out, ");")?;
+
+ // Write return value
+ writeln!(
+ self.out,
+ "{INDENT}return {RETURN_VARIABLE_NAME}.{ret_swizzle};"
+ )?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_constructor_function_name(
+ &mut self,
+ module: &crate::Module,
+ constructor: WrappedConstructor,
+ ) -> BackendResult {
+ let name = crate::TypeInner::hlsl_type_id(
+ constructor.ty,
+ &module.types,
+ &module.constants,
+ &self.names,
+ )?;
+ write!(self.out, "Construct{name}")?;
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::Compose` for structures.
+ pub(super) fn write_wrapped_constructor_function(
+ &mut self,
+ module: &crate::Module,
+ constructor: WrappedConstructor,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const ARGUMENT_VARIABLE_NAME: &str = "arg";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ if let crate::TypeInner::Array { base, size, .. } = module.types[constructor.ty].inner {
+ write!(self.out, "typedef ")?;
+ self.write_type(module, constructor.ty)?;
+ write!(self.out, " ret_")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ self.write_array_size(module, base, size)?;
+ writeln!(self.out, ";")?;
+
+ write!(self.out, "ret_")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ } else {
+ self.write_type(module, constructor.ty)?;
+ }
+ write!(self.out, " ")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+
+ let mut write_arg = |i, ty| -> BackendResult {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_type(module, ty)?;
+ write!(self.out, " {ARGUMENT_VARIABLE_NAME}{i}")?;
+ if let crate::TypeInner::Array { base, size, .. } = module.types[ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ Ok(())
+ };
+
+ match module.types[constructor.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (i, member) in members.iter().enumerate() {
+ write_arg(i, member.ty)?;
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ let count = module.constants[size].to_array_length().unwrap();
+ for i in 0..count as usize {
+ write_arg(i, base)?;
+ }
+ }
+ _ => unreachable!(),
+ };
+
+ write!(self.out, ")")?;
+
+ // Write function body
+ writeln!(self.out, " {{")?;
+
+ match module.types[constructor.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let struct_name = &self.names[&NameKey::Type(constructor.ty)];
+ writeln!(
+ self.out,
+ "{INDENT}{struct_name} {RETURN_VARIABLE_NAME} = ({struct_name})0;"
+ )?;
+ for (i, member) in members.iter().enumerate() {
+ let field_name = &self.names[&NameKey::StructMember(constructor.ty, i as u32)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix {
+ columns,
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ for j in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}{RETURN_VARIABLE_NAME}.{field_name}_{j} = {ARGUMENT_VARIABLE_NAME}{i}[{j}];"
+ )?;
+ }
+ }
+ ref other => {
+ // We cast arrays of native HLSL `floatCx2`s to arrays of `matCx2`s
+ // (where the inner matrix is represented by a struct with C `float2` members).
+ // See the module-level block comment in mod.rs for details.
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, member.ty)
+ {
+ write!(
+ self.out,
+ "{}{}.{} = (__mat{}x2",
+ INDENT, RETURN_VARIABLE_NAME, field_name, columns as u8
+ )?;
+ if let crate::TypeInner::Array { base, size, .. } = *other {
+ self.write_array_size(module, base, size)?;
+ }
+ writeln!(self.out, "){ARGUMENT_VARIABLE_NAME}{i};",)?;
+ } else {
+ writeln!(
+ self.out,
+ "{INDENT}{RETURN_VARIABLE_NAME}.{field_name} = {ARGUMENT_VARIABLE_NAME}{i};",
+ )?;
+ }
+ }
+ }
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ write!(self.out, "{INDENT}")?;
+ self.write_type(module, base)?;
+ write!(self.out, " {RETURN_VARIABLE_NAME}")?;
+ self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
+ write!(self.out, " = {{ ")?;
+ let count = module.constants[size].to_array_length().unwrap();
+ for i in 0..count {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{ARGUMENT_VARIABLE_NAME}{i}")?;
+ }
+ writeln!(self.out, " }};",)?;
+ }
+ _ => unreachable!(),
+ }
+
+ // Write return value
+ writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_get_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "GetMat{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to get a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_get_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+
+ // Write function return type and name
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let ret_ty = &module.types[member.ty].inner;
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_struct_matrix_get_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}")?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ // Write return value
+ write!(self.out, "{INDENT}return ")?;
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, "(")?;
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+ writeln!(self.out, ");")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMat{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const MATRIX_ARGUMENT_VARIABLE_NAME: &str = "mat";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ self.write_type(module, member.ty)?;
+ write!(self.out, " {MATRIX_ARGUMENT_VARIABLE_NAME}")?;
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {MATRIX_ARGUMENT_VARIABLE_NAME}[{i}];"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_vec_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMatVec{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a vec2 on a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_vec_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const VECTOR_ARGUMENT_VARIABLE_NAME: &str = "vec";
+ const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let vec_ty = match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { rows, width, .. } => crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ },
+ _ => unreachable!(),
+ };
+ self.write_value_type(module, &vec_ty)?;
+ write!(
+ self.out,
+ " {VECTOR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}"
+ )?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ writeln!(
+ self.out,
+ "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
+ )?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {VECTOR_ARGUMENT_VARIABLE_NAME}; break; }}"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ writeln!(self.out, "{INDENT}}}")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_scalar_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMatScalar{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a float on a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_scalar_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const SCALAR_ARGUMENT_VARIABLE_NAME: &str = "scalar";
+ const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
+ const VECTOR_INDEX_ARGUMENT_VARIABLE_NAME: &str = "vec_idx";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_scalar_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let scalar_ty = match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { width, .. } => crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width,
+ },
+ _ => unreachable!(),
+ };
+ self.write_value_type(module, &scalar_ty)?;
+ write!(
+ self.out,
+ " {SCALAR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}, uint {VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}"
+ )?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ writeln!(
+ self.out,
+ "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
+ )?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}[{VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}] = {SCALAR_ARGUMENT_VARIABLE_NAME}; break; }}"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ writeln!(self.out, "{INDENT}}}")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ImageQuery` and `Expression::ArrayLength`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>
+ pub(super) fn write_wrapped_functions(
+ &mut self,
+ module: &crate::Module,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ for (handle, _) in func_ctx.expressions.iter() {
+ match func_ctx.expressions[handle] {
+ crate::Expression::ArrayLength(expr) => {
+ let global_expr = match func_ctx.expressions[expr] {
+ crate::Expression::GlobalVariable(_) => expr,
+ crate::Expression::AccessIndex { base, index: _ } => base,
+ ref other => unreachable!("Array length of {:?}", other),
+ };
+ let global_var = match func_ctx.expressions[global_expr] {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &module.global_variables[var_handle]
+ }
+ ref other => unreachable!("Array length of base {:?}", other),
+ };
+ let storage_access = match global_var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => crate::StorageAccess::default(),
+ };
+ let wal = WrappedArrayLength {
+ writable: storage_access.contains(crate::StorageAccess::STORE),
+ };
+
+ if !self.wrapped.array_lengths.contains(&wal) {
+ self.write_wrapped_array_length_function(module, wal, handle, func_ctx)?;
+ self.wrapped.array_lengths.insert(wal);
+ }
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ let wiq = match *func_ctx.info[image].ty.inner_with(&module.types) {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => WrappedImageQuery {
+ dim,
+ arrayed,
+ class,
+ query: query.into(),
+ },
+ _ => unreachable!("we only query images"),
+ };
+
+ if !self.wrapped.image_queries.contains(&wiq) {
+ self.write_wrapped_image_query_function(module, wiq, handle, func_ctx)?;
+ self.wrapped.image_queries.insert(wiq);
+ }
+ }
+ // Write `WrappedConstructor` for structs that are loaded from `AddressSpace::Storage`
+ // since they will later be used by the fn `write_storage_load`
+ crate::Expression::Load { pointer } => {
+ let pointer_space = func_ctx.info[pointer]
+ .ty
+ .inner_with(&module.types)
+ .pointer_space();
+
+ if let Some(crate::AddressSpace::Storage { .. }) = pointer_space {
+ if let Some(ty) = func_ctx.info[handle].ty.handle() {
+ write_wrapped_constructor(self, ty, module)?;
+ }
+ }
+
+ fn write_wrapped_constructor<W: Write>(
+ writer: &mut super::Writer<'_, W>,
+ ty: Handle<crate::Type>,
+ module: &crate::Module,
+ ) -> BackendResult {
+ match module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for member in members {
+ write_wrapped_constructor(writer, member.ty, module)?;
+ }
+
+ let constructor = WrappedConstructor { ty };
+ if !writer.wrapped.constructors.contains(&constructor) {
+ writer
+ .write_wrapped_constructor_function(module, constructor)?;
+ writer.wrapped.constructors.insert(constructor);
+ }
+ }
+ crate::TypeInner::Array { base, .. } => {
+ write_wrapped_constructor(writer, base, module)?;
+ let constructor = WrappedConstructor { ty };
+ if !writer.wrapped.constructors.contains(&constructor) {
+ writer
+ .write_wrapped_constructor_function(module, constructor)?;
+ writer.wrapped.constructors.insert(constructor);
+ }
+ }
+ _ => {}
+ };
+
+ Ok(())
+ }
+ }
+ crate::Expression::Compose { ty, components: _ } => {
+ let constructor = match module.types[ty].inner {
+ crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => {
+ WrappedConstructor { ty }
+ }
+ _ => continue,
+ };
+ if !self.wrapped.constructors.contains(&constructor) {
+ self.write_wrapped_constructor_function(module, constructor)?;
+ self.wrapped.constructors.insert(constructor);
+ }
+ }
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s
+ // (see top level module docs for details).
+ //
+ // The functions injected here are required to get the matrix accesses working.
+ crate::Expression::AccessIndex { base, index } => {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ let base_ty_handle = match *resolved {
+ crate::TypeInner::Pointer { base, .. } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+ if let crate::TypeInner::Struct { ref members, .. } = *resolved {
+ let member = &members[index as usize];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ let ty = base_ty_handle.unwrap();
+ let access = WrappedStructMatrixAccess { ty, index };
+
+ if !self.wrapped.struct_matrix_access.contains(&access) {
+ self.write_wrapped_struct_matrix_get_function(module, access)?;
+ self.write_wrapped_struct_matrix_set_function(module, access)?;
+ self.write_wrapped_struct_matrix_set_vec_function(
+ module, access,
+ )?;
+ self.write_wrapped_struct_matrix_set_scalar_function(
+ module, access,
+ )?;
+ self.wrapped.struct_matrix_access.insert(access);
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ _ => {}
+ };
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_constructor_function_for_constant(
+ &mut self,
+ module: &crate::Module,
+ constant: &crate::Constant,
+ ) -> BackendResult {
+ if let crate::ConstantInner::Composite { ty, ref components } = constant.inner {
+ match module.types[ty].inner {
+ crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => {
+ let constructor = WrappedConstructor { ty };
+ if !self.wrapped.constructors.contains(&constructor) {
+ self.write_wrapped_constructor_function(module, constructor)?;
+ self.wrapped.constructors.insert(constructor);
+ }
+ }
+ _ => {}
+ }
+ for constant in components {
+ self.write_wrapped_constructor_function_for_constant(
+ module,
+ &module.constants[*constant],
+ )?;
+ }
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn write_texture_coordinates(
+ &mut self,
+ kind: &str,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ mip_level: Option<Handle<crate::Expression>>,
+ module: &crate::Module,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ // HLSL expects the array index to be merged with the coordinate
+ let extra = array_index.is_some() as usize + (mip_level.is_some()) as usize;
+ if extra == 0 {
+ self.write_expr(module, coordinate, func_ctx)?;
+ } else {
+ let num_coords = match *func_ctx.info[coordinate].ty.inner_with(&module.types) {
+ crate::TypeInner::Scalar { .. } => 1,
+ crate::TypeInner::Vector { size, .. } => size as usize,
+ _ => unreachable!(),
+ };
+ write!(self.out, "{}{}(", kind, num_coords + extra)?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ if let Some(expr) = mip_level {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Ok(())
+ }
+
+ pub(super) fn write_mat_cx2_typedef_and_functions(
+ &mut self,
+ WrappedMatCx2 { columns }: WrappedMatCx2,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ // typedef
+ write!(self.out, "typedef struct {{ ")?;
+ for i in 0..columns as u8 {
+ write!(self.out, "float2 _{i}; ")?;
+ }
+ writeln!(self.out, "}} __mat{}x2;", columns as u8)?;
+
+ // __get_col_of_mat
+ writeln!(
+ self.out,
+ "float2 __get_col_of_mat{}x2(__mat{}x2 mat, uint idx) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(self.out, "{INDENT}case {i}: {{ return mat._{i}; }}")?;
+ }
+ writeln!(self.out, "{INDENT}default: {{ return (float2)0; }}")?;
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ // __set_col_of_mat
+ writeln!(
+ self.out,
+ "void __set_col_of_mat{}x2(__mat{}x2 mat, uint idx, float2 value) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(self.out, "{INDENT}case {i}: {{ mat._{i} = value; break; }}")?;
+ }
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ // __set_el_of_mat
+ writeln!(
+ self.out,
+ "void __set_el_of_mat{}x2(__mat{}x2 mat, uint idx, uint vec_idx, float value) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ mat._{i}[vec_idx] = value; break; }}"
+ )?;
+ }
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_all_mat_cx2_typedefs_and_functions(
+ &mut self,
+ module: &crate::Module,
+ ) -> BackendResult {
+ for (handle, _) in module.global_variables.iter() {
+ let global = &module.global_variables[handle];
+
+ if global.space == crate::AddressSpace::Uniform {
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, global.ty)
+ {
+ let entry = WrappedMatCx2 { columns };
+ if !self.wrapped.mat_cx2s.contains(&entry) {
+ self.write_mat_cx2_typedef_and_functions(entry)?;
+ self.wrapped.mat_cx2s.insert(entry);
+ }
+ }
+ }
+ }
+
+ for (_, ty) in module.types.iter() {
+ if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
+ for member in members.iter() {
+ if let crate::TypeInner::Array { .. } = module.types[member.ty].inner {
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, member.ty)
+ {
+ let entry = WrappedMatCx2 { columns };
+ if !self.wrapped.mat_cx2s.contains(&entry) {
+ self.write_mat_cx2_typedef_and_functions(entry)?;
+ self.wrapped.mat_cx2s.insert(entry);
+ }
+ }
+ }
+ }
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/keywords.rs b/third_party/rust/naga/src/back/hlsl/keywords.rs
new file mode 100644
index 0000000000..7519b767a1
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/keywords.rs
@@ -0,0 +1,166 @@
+/*!
+HLSL Reserved Words
+- <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-appendix-keywords>
+- <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-appendix-reserved-words>
+*/
+
+pub const RESERVED: &[&str] = &[
+ "AppendStructuredBuffer",
+ "asm",
+ "asm_fragment",
+ "BlendState",
+ "bool",
+ "break",
+ "Buffer",
+ "ByteAddressBuffer",
+ "case",
+ "cbuffer",
+ "centroid",
+ "class",
+ "column_major",
+ "compile",
+ "compile_fragment",
+ "CompileShader",
+ "const",
+ "continue",
+ "ComputeShader",
+ "ConsumeStructuredBuffer",
+ "default",
+ "DepthStencilState",
+ "DepthStencilView",
+ "discard",
+ "do",
+ "double",
+ "DomainShader",
+ "dword",
+ "else",
+ "export",
+ "extern",
+ "false",
+ "float",
+ "for",
+ "fxgroup",
+ "GeometryShader",
+ "groupshared",
+ "half",
+ "Hullshader",
+ "if",
+ "in",
+ "inline",
+ "inout",
+ "InputPatch",
+ "int",
+ "interface",
+ "line",
+ "lineadj",
+ "linear",
+ "LineStream",
+ "matrix",
+ "min16float",
+ "min10float",
+ "min16int",
+ "min12int",
+ "min16uint",
+ "namespace",
+ "nointerpolation",
+ "noperspective",
+ "NULL",
+ "out",
+ "OutputPatch",
+ "packoffset",
+ "pass",
+ "pixelfragment",
+ "PixelShader",
+ "point",
+ "PointStream",
+ "precise",
+ "RasterizerState",
+ "RenderTargetView",
+ "return",
+ "register",
+ "row_major",
+ "RWBuffer",
+ "RWByteAddressBuffer",
+ "RWStructuredBuffer",
+ "RWTexture1D",
+ "RWTexture1DArray",
+ "RWTexture2D",
+ "RWTexture2DArray",
+ "RWTexture3D",
+ "sample",
+ "sampler",
+ "SamplerState",
+ "SamplerComparisonState",
+ "shared",
+ "snorm",
+ "stateblock",
+ "stateblock_state",
+ "static",
+ "string",
+ "struct",
+ "switch",
+ "StructuredBuffer",
+ "tbuffer",
+ "technique",
+ "technique10",
+ "technique11",
+ "texture",
+ "Texture1D",
+ "Texture1DArray",
+ "Texture2D",
+ "Texture2DArray",
+ "Texture2DMS",
+ "Texture2DMSArray",
+ "Texture3D",
+ "TextureCube",
+ "TextureCubeArray",
+ "true",
+ "typedef",
+ "triangle",
+ "triangleadj",
+ "TriangleStream",
+ "uint",
+ "uniform",
+ "unorm",
+ "unsigned",
+ "vector",
+ "vertexfragment",
+ "VertexShader",
+ "void",
+ "volatile",
+ "while",
+ "auto",
+ "case",
+ "catch",
+ "char",
+ "class",
+ "const_cast",
+ "default",
+ "delete",
+ "dynamic_cast",
+ "enum",
+ "explicit",
+ "friend",
+ "goto",
+ "long",
+ "mutable",
+ "new",
+ "operator",
+ "private",
+ "protected",
+ "public",
+ "reinterpret_cast",
+ "short",
+ "signed",
+ "sizeof",
+ "static_cast",
+ "template",
+ "this",
+ "throw",
+ "try",
+ "typename",
+ "union",
+ "unsigned",
+ "using",
+ "virtual",
+];
diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs
new file mode 100644
index 0000000000..f3a6f9106c
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/mod.rs
@@ -0,0 +1,302 @@
+/*!
+Backend for [HLSL][hlsl] (High-Level Shading Language).
+
+# Supported shader model versions:
+- 5.0
+- 5.1
+- 6.0
+
+# Layout of values in `uniform` buffers
+
+WGSL's ["Internal Layout of Values"][ilov] rules specify how each WGSL
+type should be stored in `uniform` and `storage` buffers. The HLSL we
+generate must access values in that form, even when it is not what
+HLSL would use normally.
+
+The rules described here only apply to WGSL `uniform` variables. WGSL
+`storage` buffers are translated as HLSL `ByteAddressBuffers`, for
+which we generate `Load` and `Store` method calls with explicit byte
+offsets. WGSL pipeline inputs must be scalars or vectors; they cannot
+be matrices, which is where the interesting problems arise.
+
+## Row- and column-major ordering for matrices
+
+WGSL specifies that matrices in uniform buffers are stored in
+column-major order. This matches HLSL's default, so one might expect
+things to be straightforward. Unfortunately, WGSL and HLSL disagree on
+what indexing a matrix means: in WGSL, `m[i]` retrieves the `i`'th
+*column* of `m`, whereas in HLSL it retrieves the `i`'th *row*. We
+want to avoid translating `m[i]` into some complicated reassembly of a
+vector from individually fetched components, so this is a problem.
+
+However, with a bit of trickery, it is possible to use HLSL's `m[i]`
+as the translation of WGSL's `m[i]`:
+
+- We declare all matrices in uniform buffers in HLSL with the
+ `row_major` qualifier, and transpose the row and column counts: a
+ WGSL `mat3x4<f32>`, say, becomes an HLSL `row_major float3x4`. (Note
+ that WGSL and HLSL type names put the row and column in reverse
+ order.) Since the HLSL type is the transpose of how WebGPU directs
+ the user to store the data, HLSL will load all matrices transposed.
+
+- Since matrices are transposed, an HLSL indexing expression retrieves
+ the "columns" of the intended WGSL value, as desired.
+
+- For vector-matrix multiplication, since `mul(transpose(m), v)` is
+ equivalent to `mul(v, m)` (note the reversal of the arguments), and
+ `mul(v, transpose(m))` is equivalent to `mul(m, v)`, we can
+ translate WGSL `m * v` and `v * m` to HLSL by simply reversing the
+ arguments to `mul`.
+
+## Padding in two-row matrices
+
+An HLSL `row_major floatKx2` matrix has padding between its rows that
+the WGSL `matKx2<f32>` matrix it represents does not. HLSL stores all
+matrix rows [aligned on 16-byte boundaries][16bb], whereas WGSL says
+that the columns of a `matKx2<f32>` need only be [aligned as required
+for `vec2<f32>`][ilov], which is [eight-byte alignment][8bb].
+
+To compensate for this, any time a `matKx2<f32>` appears in a WGSL
+`uniform` variable, whether directly as the variable's type or as part
+of a struct/array, we actually emit `K` separate `float2` members, and
+assemble/disassemble the matrix from its columns (in WGSL; rows in
+HLSL) upon load and store.
+
+For example, the following WGSL struct type:
+
+```ignore
+struct Baz {
+ m: mat3x2<f32>,
+}
+```
+
+is rendered as the HLSL struct type:
+
+```ignore
+struct Baz {
+ float2 m_0; float2 m_1; float2 m_2;
+};
+```
+
+The `wrapped_struct_matrix` functions in `help.rs` generate HLSL
+helper functions to access such members, converting between the stored
+form and the HLSL matrix types appropriately. For example, for reading
+the member `m` of the `Baz` struct above, we emit:
+
+```ignore
+float3x2 GetMatmOnBaz(Baz obj) {
+ return float3x2(obj.m_0, obj.m_1, obj.m_2);
+}
+```
+
+We also emit an analogous `Set` function, as well as functions for
+accessing individual columns by dynamic index.
+
+[hlsl]: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl
+[ilov]: https://gpuweb.github.io/gpuweb/wgsl/#internal-value-layout
+[16bb]: https://github.com/microsoft/DirectXShaderCompiler/wiki/Buffer-Packing#constant-buffer-packing
+[8bb]: https://gpuweb.github.io/gpuweb/wgsl/#alignment-and-size
+*/
+
+mod conv;
+mod help;
+mod keywords;
+mod storage;
+mod writer;
+
+use std::fmt::Error as FmtError;
+use thiserror::Error;
+
+use crate::{back, proc};
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BindTarget {
+ pub space: u8,
+ pub register: u32,
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
+
+/// A HLSL shader model version.
+#[allow(non_snake_case, non_camel_case_types)]
+#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum ShaderModel {
+ V5_0,
+ V5_1,
+ V6_0,
+}
+
+impl ShaderModel {
+ pub const fn to_str(self) -> &'static str {
+ match self {
+ Self::V5_0 => "5_0",
+ Self::V5_1 => "5_1",
+ Self::V6_0 => "6_0",
+ }
+ }
+}
+
+impl crate::ShaderStage {
+ pub const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::Vertex => "vs",
+ Self::Fragment => "ps",
+ Self::Compute => "cs",
+ }
+ }
+}
+
+impl crate::ImageDimension {
+ const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::D1 => "1D",
+ Self::D2 => "2D",
+ Self::D3 => "3D",
+ Self::Cube => "Cube",
+ }
+ }
+}
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum EntryPointError {
+ #[error("mapping of {0:?} is missing")]
+ MissingBinding(crate::ResourceBinding),
+}
+
+/// Configuration used in the [`Writer`].
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// The hlsl shader model to be used
+ pub shader_model: ShaderModel,
+ /// Map of resources association to binding locations.
+ pub binding_map: BindingMap,
+ /// Don't panic on missing bindings, instead generate any HLSL.
+ pub fake_missing_bindings: bool,
+ /// Add special constants to `SV_VertexIndex` and `SV_InstanceIndex`,
+ /// to make them work like in Vulkan/Metal, with help of the host.
+ pub special_constants_binding: Option<BindTarget>,
+ /// Bind target of the push constant buffer
+ pub push_constants_target: Option<BindTarget>,
+ /// Should workgroup variables be zero initialized (by polyfilling)?
+ pub zero_initialize_workgroup_memory: bool,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ shader_model: ShaderModel::V5_1,
+ binding_map: BindingMap::default(),
+ fake_missing_bindings: true,
+ special_constants_binding: None,
+ push_constants_target: None,
+ zero_initialize_workgroup_memory: true,
+ }
+ }
+}
+
+impl Options {
+ fn resolve_resource_binding(
+ &self,
+ res_binding: &crate::ResourceBinding,
+ ) -> Result<BindTarget, EntryPointError> {
+ match self.binding_map.get(res_binding) {
+ Some(target) => Ok(target.clone()),
+ None if self.fake_missing_bindings => Ok(BindTarget {
+ space: res_binding.group as u8,
+ register: res_binding.binding,
+ binding_array_size: None,
+ }),
+ None => Err(EntryPointError::MissingBinding(res_binding.clone())),
+ }
+ }
+}
+
+/// Reflection info for entry point names.
+#[derive(Default)]
+pub struct ReflectionInfo {
+ /// Mapping of the entry point names.
+ ///
+ /// Each item in the array corresponds to an entry point index. The real entry point name may be different if one of the
+ /// reserved words are used.
+ ///
+ /// Note: Some entry points may fail translation because of missing bindings.
+ pub entry_point_names: Vec<Result<String, EntryPointError>>,
+}
+
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error(transparent)]
+ IoError(#[from] FmtError),
+ #[error("A scalar with an unsupported width was requested: {0:?} {1:?}")]
+ UnsupportedScalar(crate::ScalarKind, crate::Bytes),
+ #[error("{0}")]
+ Unimplemented(String), // TODO: Error used only during development
+ #[error("{0}")]
+ Custom(String),
+}
+
+#[derive(Default)]
+struct Wrapped {
+ array_lengths: crate::FastHashSet<help::WrappedArrayLength>,
+ image_queries: crate::FastHashSet<help::WrappedImageQuery>,
+ constructors: crate::FastHashSet<help::WrappedConstructor>,
+ struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>,
+ mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>,
+}
+
+impl Wrapped {
+ fn clear(&mut self) {
+ self.array_lengths.clear();
+ self.image_queries.clear();
+ self.constructors.clear();
+ self.struct_matrix_access.clear();
+ self.mat_cx2s.clear();
+ }
+}
+
+pub struct Writer<'a, W> {
+ out: W,
+ names: crate::FastHashMap<proc::NameKey, String>,
+ namer: proc::Namer,
+ /// HLSL backend options
+ options: &'a Options,
+ /// Information about entry point arguments and result types.
+ entry_point_io: Vec<writer::EntryPointInterface>,
+ /// Set of expressions that have associated temporary variables
+ named_expressions: crate::NamedExpressions,
+ wrapped: Wrapped,
+
+ /// A reference to some part of a global variable, lowered to a series of
+ /// byte offset calculations.
+ ///
+ /// See the [`storage`] module for background on why we need this.
+ ///
+ /// Each [`SubAccess`] in the vector is a lowering of some [`Access`] or
+ /// [`AccessIndex`] expression to the level of byte strides and offsets. See
+ /// [`SubAccess`] for details.
+ ///
+ /// This field is a member of [`Writer`] solely to allow re-use of
+ /// the `Vec`'s dynamic allocation. The value is no longer needed
+ /// once HLSL for the access has been generated.
+ ///
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`SubAccess`]: storage::SubAccess
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ temp_access_chain: Vec<storage::SubAccess>,
+ need_bake_expressions: back::NeedBakeExpressions,
+}
diff --git a/third_party/rust/naga/src/back/hlsl/storage.rs b/third_party/rust/naga/src/back/hlsl/storage.rs
new file mode 100644
index 0000000000..813dd73649
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/storage.rs
@@ -0,0 +1,510 @@
+/*!
+Generating accesses to [`ByteAddressBuffer`] contents.
+
+Naga IR globals in the [`Storage`] address space are rendered as
+[`ByteAddressBuffer`]s or [`RWByteAddressBuffer`]s in HLSL. These
+buffers don't have HLSL types (structs, arrays, etc.); instead, they
+are just raw blocks of bytes, with methods to load and store values of
+specific types at particular byte offsets. This means that Naga must
+translate chains of [`Access`] and [`AccessIndex`] expressions into
+HLSL expressions that compute byte offsets into the buffer.
+
+To generate code for a [`Storage`] access:
+
+- Call [`Writer::fill_access_chain`] on the expression referring to
+ the value. This populates [`Writer::temp_access_chain`] with the
+ appropriate byte offset calculations, as a vector of [`SubAccess`]
+ values.
+
+- Call [`Writer::write_storage_address`] to emit an HLSL expression
+ for a given slice of [`SubAccess`] values.
+
+Naga IR expressions can operate on composite values of any type, but
+[`ByteAddressBuffer`] and [`RWByteAddressBuffer`] have only a fixed
+set of `Load` and `Store` methods, to access one through four
+consecutive 32-bit values. To synthesize a Naga access, you can
+initialize [`temp_access_chain`] to refer to the composite, and then
+temporarily push and pop additional steps on
+[`Writer::temp_access_chain`] to generate accesses to the individual
+elements/members.
+
+The [`temp_access_chain`] field is a member of [`Writer`] solely to
+allow re-use of the `Vec`'s dynamic allocation. Its value is no longer
+needed once HLSL for the access has been generated.
+
+[`Storage`]: crate::AddressSpace::Storage
+[`ByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-byteaddressbuffer
+[`RWByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer
+[`Access`]: crate::Expression::Access
+[`AccessIndex`]: crate::Expression::AccessIndex
+[`Writer::fill_access_chain`]: super::Writer::fill_access_chain
+[`Writer::write_storage_address`]: super::Writer::write_storage_address
+[`Writer::temp_access_chain`]: super::Writer::temp_access_chain
+[`temp_access_chain`]: super::Writer::temp_access_chain
+[`Writer`]: super::Writer
+*/
+
+use super::{super::FunctionCtx, BackendResult, Error};
+use crate::{
+ proc::{Alignment, NameKey, TypeResolution},
+ Handle,
+};
+
+use std::{fmt, mem};
+
+const STORE_TEMP_NAME: &str = "_value";
+
+/// One step in accessing a [`Storage`] global's component or element.
+///
+/// [`Writer::temp_access_chain`] holds a series of these structures,
+/// describing how to compute the byte offset of a particular element
+/// or member of some global variable in the [`Storage`] address
+/// space.
+///
+/// [`Writer::temp_access_chain`]: super::Writer::temp_access_chain
+/// [`Storage`]: crate::AddressSpace::Storage
+#[derive(Debug)]
+pub(super) enum SubAccess {
+ /// Add the given byte offset. This is used for struct members, or
+ /// known components of a vector or matrix. In all those cases,
+ /// the byte offset is a compile-time constant.
+ Offset(u32),
+
+ /// Scale `value` by `stride`, and add that to the current byte
+ /// offset. This is used to compute the offset of an array element
+ /// whose index is computed at runtime.
+ Index {
+ value: Handle<crate::Expression>,
+ stride: u32,
+ },
+}
+
+pub(super) enum StoreValue {
+ Expression(Handle<crate::Expression>),
+ TempIndex {
+ depth: usize,
+ index: u32,
+ ty: TypeResolution,
+ },
+ TempAccess {
+ depth: usize,
+ base: Handle<crate::Type>,
+ member_index: u32,
+ },
+}
+
+impl<W: fmt::Write> super::Writer<'_, W> {
+ pub(super) fn write_storage_address(
+ &mut self,
+ module: &crate::Module,
+ chain: &[SubAccess],
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ if chain.is_empty() {
+ write!(self.out, "0")?;
+ }
+ for (i, access) in chain.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, "+")?;
+ }
+ match *access {
+ SubAccess::Offset(offset) => {
+ write!(self.out, "{offset}")?;
+ }
+ SubAccess::Index { value, stride } => {
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, "*{stride}")?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn write_storage_load_sequence<I: Iterator<Item = (TypeResolution, u32)>>(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ sequence: I,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ for (i, (ty_resolution, offset)) in sequence.enumerate() {
+ // add the index temporarily
+ self.temp_access_chain.push(SubAccess::Offset(offset));
+ if i != 0 {
+ write!(self.out, ", ")?;
+ };
+ self.write_storage_load(module, var_handle, ty_resolution, func_ctx)?;
+ self.temp_access_chain.pop();
+ }
+ Ok(())
+ }
+
+ /// Emit code to access a [`Storage`] global's component.
+ ///
+ /// Emit HLSL to access the component of `var_handle`, a global
+ /// variable in the [`Storage`] address space, whose type is
+ /// `result_ty` and whose location within the global is given by
+ /// [`self.temp_access_chain`]. See the [`storage`] module's
+ /// documentation for background.
+ ///
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`self.temp_access_chain`]: super::Writer::temp_access_chain
+ pub(super) fn write_storage_load(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ result_ty: TypeResolution,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ match *result_ty.inner_with(&module.types) {
+ crate::TypeInner::Scalar { kind, width: _ } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ let cast = kind.to_hlsl_cast();
+ write!(self.out, "{cast}({var_name}.Load(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, "))")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Vector {
+ size,
+ kind,
+ width: _,
+ } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ let cast = kind.to_hlsl_cast();
+ write!(self.out, "{}({}.Load{}(", cast, var_name, size as u8)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, "))")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ write!(
+ self.out,
+ "{}{}x{}(",
+ crate::ScalarKind::Float.to_hlsl_str(width)?,
+ columns as u8,
+ rows as u8,
+ )?;
+
+ // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
+ let row_stride = Alignment::from(rows) * width as u32;
+ let iter = (0..columns as u32).map(|i| {
+ let ty_inner = crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ };
+ (TypeResolution::Value(ty_inner), i * row_stride)
+ });
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(const_handle),
+ ..
+ } => {
+ let constructor = super::help::WrappedConstructor {
+ ty: result_ty.handle().unwrap(),
+ };
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ write!(self.out, "(")?;
+ let count = module.constants[const_handle].to_array_length().unwrap();
+ let stride = module.types[base].inner.size(&module.constants);
+ let iter = (0..count).map(|i| (TypeResolution::Handle(base), stride * i));
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ let constructor = super::help::WrappedConstructor {
+ ty: result_ty.handle().unwrap(),
+ };
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ write!(self.out, "(")?;
+ let iter = members
+ .iter()
+ .map(|m| (TypeResolution::Handle(m.ty), m.offset));
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+ }
+
+ fn write_store_value(
+ &mut self,
+ module: &crate::Module,
+ value: &StoreValue,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ match *value {
+ StoreValue::Expression(expr) => self.write_expr(module, expr, func_ctx)?,
+ StoreValue::TempIndex {
+ depth,
+ index,
+ ty: _,
+ } => write!(self.out, "{STORE_TEMP_NAME}{depth}[{index}]")?,
+ StoreValue::TempAccess {
+ depth,
+ base,
+ member_index,
+ } => {
+ let name = &self.names[&NameKey::StructMember(base, member_index)];
+ write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")?
+ }
+ }
+ Ok(())
+ }
+
+ /// Helper function to write down the Store operation on a `ByteAddressBuffer`.
+ pub(super) fn write_storage_store(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ value: StoreValue,
+ func_ctx: &FunctionCtx,
+ level: crate::back::Level,
+ ) -> BackendResult {
+ let temp_resolution;
+ let ty_resolution = match value {
+ StoreValue::Expression(expr) => &func_ctx.info[expr].ty,
+ StoreValue::TempIndex {
+ depth: _,
+ index: _,
+ ref ty,
+ } => ty,
+ StoreValue::TempAccess {
+ depth: _,
+ base,
+ member_index,
+ } => {
+ let ty_handle = match module.types[base].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ members[member_index as usize].ty
+ }
+ _ => unreachable!(),
+ };
+ temp_resolution = TypeResolution::Handle(ty_handle);
+ &temp_resolution
+ }
+ };
+ match *ty_resolution.inner_with(&module.types) {
+ crate::TypeInner::Scalar { .. } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{level}{var_name}.Store(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, ", asuint(")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, "));")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Vector { size, .. } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, ", asuint(")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, "));")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ let depth = level.0 + 1;
+ write!(
+ self.out,
+ "{}{}{}x{} {}{} = ",
+ level.next(),
+ crate::ScalarKind::Float.to_hlsl_str(width)?,
+ columns as u8,
+ rows as u8,
+ STORE_TEMP_NAME,
+ depth,
+ )?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+
+ // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
+ let row_stride = Alignment::from(rows) * width as u32;
+
+ // then iterate the stores
+ for i in 0..columns as u32 {
+ self.temp_access_chain
+ .push(SubAccess::Offset(i * row_stride));
+ let ty_inner = crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ };
+ let sv = StoreValue::TempIndex {
+ depth,
+ index: i,
+ ty: TypeResolution::Value(ty_inner),
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(const_handle),
+ ..
+ } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ write!(self.out, "{}", level.next())?;
+ self.write_value_type(module, &module.types[base].inner)?;
+ let depth = level.next().0;
+ write!(self.out, " {STORE_TEMP_NAME}{depth}")?;
+ self.write_array_size(module, base, crate::ArraySize::Constant(const_handle))?;
+ write!(self.out, " = ")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ // then iterate the stores
+ let count = module.constants[const_handle].to_array_length().unwrap();
+ let stride = module.types[base].inner.size(&module.constants);
+ for i in 0..count {
+ self.temp_access_chain.push(SubAccess::Offset(i * stride));
+ let sv = StoreValue::TempIndex {
+ depth,
+ index: i,
+ ty: TypeResolution::Handle(base),
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ let depth = level.next().0;
+ let struct_ty = ty_resolution.handle().unwrap();
+ let struct_name = &self.names[&NameKey::Type(struct_ty)];
+ write!(
+ self.out,
+ "{}{} {}{} = ",
+ level.next(),
+ struct_name,
+ STORE_TEMP_NAME,
+ depth
+ )?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ // then iterate the stores
+ for (i, member) in members.iter().enumerate() {
+ self.temp_access_chain
+ .push(SubAccess::Offset(member.offset));
+ let sv = StoreValue::TempAccess {
+ depth,
+ base: struct_ty,
+ member_index: i as u32,
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+ }
+
+ /// Set [`temp_access_chain`] to compute the byte offset of `cur_expr`.
+ ///
+ /// The `cur_expr` expression must be a reference to a global
+ /// variable in the [`Storage`] address space, or a chain of
+ /// [`Access`] and [`AccessIndex`] expressions referring to some
+ /// component of such a global.
+ ///
+ /// [`temp_access_chain`]: super::Writer::temp_access_chain
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ pub(super) fn fill_access_chain(
+ &mut self,
+ module: &crate::Module,
+ mut cur_expr: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> Result<Handle<crate::GlobalVariable>, Error> {
+ enum AccessIndex {
+ Expression(Handle<crate::Expression>),
+ Constant(u32),
+ }
+ enum Parent<'a> {
+ Array { stride: u32 },
+ Struct(&'a [crate::StructMember]),
+ }
+ self.temp_access_chain.clear();
+
+ loop {
+ let (next_expr, access_index) = match func_ctx.expressions[cur_expr] {
+ crate::Expression::GlobalVariable(handle) => return Ok(handle),
+ crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)),
+ crate::Expression::AccessIndex { base, index } => {
+ (base, AccessIndex::Constant(index))
+ }
+ ref other => {
+ return Err(Error::Unimplemented(format!("Pointer access of {other:?}")))
+ }
+ };
+
+ let parent = match *func_ctx.info[next_expr].ty.inner_with(&module.types) {
+ crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members),
+ crate::TypeInner::Array { stride, .. } => Parent::Array { stride },
+ crate::TypeInner::Vector { width, .. } => Parent::Array {
+ stride: width as u32,
+ },
+ crate::TypeInner::Matrix { rows, width, .. } => Parent::Array {
+ // The stride between matrices is the count of rows as this is how
+ // long each column is.
+ stride: Alignment::from(rows) * width as u32,
+ },
+ _ => unreachable!(),
+ },
+ crate::TypeInner::ValuePointer { width, .. } => Parent::Array {
+ stride: width as u32,
+ },
+ _ => unreachable!(),
+ };
+
+ let sub = match (parent, access_index) {
+ (Parent::Array { stride }, AccessIndex::Expression(value)) => {
+ SubAccess::Index { value, stride }
+ }
+ (Parent::Array { stride }, AccessIndex::Constant(index)) => {
+ SubAccess::Offset(stride * index)
+ }
+ (Parent::Struct(members), AccessIndex::Constant(index)) => {
+ SubAccess::Offset(members[index as usize].offset)
+ }
+ (Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(),
+ };
+
+ self.temp_access_chain.push(sub);
+ cur_expr = next_expr;
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs
new file mode 100644
index 0000000000..a3810c5dab
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/writer.rs
@@ -0,0 +1,3188 @@
+use super::{
+ help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess},
+ storage::StoreValue,
+ BackendResult, Error, Options,
+};
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, Module, ScalarKind, ShaderStage, TypeInner,
+};
+use std::{fmt, mem};
+
+const LOCATION_SEMANTIC: &str = "LOC";
+const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
+const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
+const SPECIAL_BASE_VERTEX: &str = "base_vertex";
+const SPECIAL_BASE_INSTANCE: &str = "base_instance";
+const SPECIAL_OTHER: &str = "other";
+
+struct EpStructMember {
+ name: String,
+ ty: Handle<crate::Type>,
+ // technically, this should always be `Some`
+ binding: Option<crate::Binding>,
+ index: u32,
+}
+
+/// Structure contains information required for generating
+/// wrapped structure of all entry points arguments
+struct EntryPointBinding {
+ /// Name of the fake EP argument that contains the struct
+ /// with all the flattened input data.
+ arg_name: String,
+ /// Generated structure name
+ ty_name: String,
+ /// Members of generated structure
+ members: Vec<EpStructMember>,
+}
+
+pub(super) struct EntryPointInterface {
+ /// If `Some`, the input of an entry point is gathered in a special
+ /// struct with members sorted by binding.
+ /// The `EntryPointBinding::members` array is sorted by index,
+ /// so that we can walk it in `write_ep_arguments_initialization`.
+ input: Option<EntryPointBinding>,
+ /// If `Some`, the output of an entry point is flattened.
+ /// The `EntryPointBinding::members` array is sorted by binding,
+ /// So that we can walk it in `Statement::Return` handler.
+ output: Option<EntryPointBinding>,
+}
+
+#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
+enum InterfaceKey {
+ Location(u32),
+ BuiltIn(crate::BuiltIn),
+ Other,
+}
+
+impl InterfaceKey {
+ const fn new(binding: Option<&crate::Binding>) -> Self {
+ match binding {
+ Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
+ Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
+ None => Self::Other,
+ }
+ }
+}
+
+#[derive(Copy, Clone, PartialEq)]
+enum Io {
+ Input,
+ Output,
+}
+
+impl<'a, W: fmt::Write> super::Writer<'a, W> {
+ pub fn new(out: W, options: &'a Options) -> Self {
+ Self {
+ out,
+ names: crate::FastHashMap::default(),
+ namer: proc::Namer::default(),
+ options,
+ entry_point_io: Vec::new(),
+ named_expressions: crate::NamedExpressions::default(),
+ wrapped: super::Wrapped::default(),
+ temp_access_chain: Vec::new(),
+ need_bake_expressions: Default::default(),
+ }
+ }
+
+ fn reset(&mut self, module: &Module) {
+ self.names.clear();
+ self.namer
+ .reset(module, super::keywords::RESERVED, &[], &mut self.names);
+ self.entry_point_io.clear();
+ self.named_expressions.clear();
+ self.wrapped.clear();
+ self.need_bake_expressions.clear();
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// Clears `need_bake_expressions` set before adding to it
+ fn update_expressions_to_bake(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ ) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+ for (fun_handle, expr) in func.expressions.iter() {
+ let expr_info = &info[fun_handle];
+ let min_ref_count = func.expressions[fun_handle].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(fun_handle);
+ }
+
+ if let Expression::Math { fun, arg, .. } = *expr {
+ match fun {
+ crate::MathFunction::Asinh
+ | crate::MathFunction::Acosh
+ | crate::MathFunction::Atanh
+ | crate::MathFunction::Unpack2x16float => {
+ self.need_bake_expressions.insert(arg);
+ }
+ crate::MathFunction::CountLeadingZeros => {
+ let inner = info[fun_handle].ty.inner_with(&module.types);
+ if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ self.need_bake_expressions.insert(arg);
+ }
+ }
+ _ => {}
+ }
+ }
+
+ if let Expression::Derivative { axis, ctrl, expr } = *expr {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
+ self.need_bake_expressions.insert(expr);
+ }
+ }
+ }
+ }
+
+ pub fn write(
+ &mut self,
+ module: &Module,
+ module_info: &valid::ModuleInfo,
+ ) -> Result<super::ReflectionInfo, Error> {
+ self.reset(module);
+
+ // Write special constants, if needed
+ if let Some(ref bt) = self.options.special_constants_binding {
+ writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
+ writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_BASE_VERTEX)?;
+ writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_BASE_INSTANCE)?;
+ writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
+ writeln!(self.out, "}};")?;
+ write!(
+ self.out,
+ "ConstantBuffer<{}> {}: register(b{}",
+ SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
+ )?;
+ if bt.space != 0 {
+ write!(self.out, ", space{}", bt.space)?;
+ }
+ writeln!(self.out, ");")?;
+ }
+
+ // Write all constants
+ // For example, input wgsl shader:
+ // ```wgsl
+ // let c_scale: f32 = 1.2;
+ // return VertexOutput(uv, vec4<f32>(c_scale * pos, 0.0, 1.0));
+ // ```
+ //
+ // Output shader:
+ // ```hlsl
+ // static const float c_scale = 1.2;
+ // const VertexOutput vertexoutput1 = { vertexinput.uv3, float4((c_scale * vertexinput.pos1), 0.0, 1.0) };
+ // ```
+ //
+ // If we remove `write_global_constant` `c_scale` will be inlined.
+ for (handle, constant) in module.constants.iter() {
+ if constant.name.is_some() {
+ self.write_global_constant(module, &constant.inner, handle)?;
+ }
+ }
+
+ // Extra newline for readability
+ writeln!(self.out)?;
+
+ // Save all entry point output types
+ let ep_results = module
+ .entry_points
+ .iter()
+ .map(|ep| (ep.stage, ep.function.result.clone()))
+ .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
+
+ self.write_all_mat_cx2_typedefs_and_functions(module)?;
+
+ // Write all structs
+ for (handle, ty) in module.types.iter() {
+ if let TypeInner::Struct { ref members, span } = ty.inner {
+ if module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&module.types)
+ {
+ // unsized arrays can only be in storage buffers,
+ // for which we use `ByteAddressBuffer` anyway.
+ continue;
+ }
+
+ let ep_result = ep_results.iter().find(|e| {
+ if let Some(ref result) = e.1 {
+ result.ty == handle
+ } else {
+ false
+ }
+ });
+
+ self.write_struct(
+ module,
+ handle,
+ members,
+ span,
+ ep_result.map(|r| (r.0, Io::Output)),
+ )?;
+ writeln!(self.out)?;
+ }
+ }
+
+ // Write wrapped constructor functions used in constants
+ for (_, constant) in module.constants.iter() {
+ self.write_wrapped_constructor_function_for_constant(module, constant)?;
+ }
+
+ // Write all globals
+ for (ty, _) in module.global_variables.iter() {
+ self.write_global(module, ty)?;
+ }
+
+ if !module.global_variables.is_empty() {
+ // Add extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Write all entry points wrapped structs
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
+ let ep_io = self.write_ep_interface(module, &ep.function, ep.stage, &ep_name)?;
+ self.entry_point_io.push(ep_io);
+ }
+
+ // Write all regular functions
+ for (handle, function) in module.functions.iter() {
+ let info = &module_info[handle];
+
+ // Check if all of the globals are accessible
+ if !self.options.fake_missing_bindings {
+ if let Some((var_handle, _)) =
+ module
+ .global_variables
+ .iter()
+ .find(|&(var_handle, var)| match var.binding {
+ Some(ref binding) if !info[var_handle].is_empty() => {
+ self.options.resolve_resource_binding(binding).is_err()
+ }
+ _ => false,
+ })
+ {
+ log::info!(
+ "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
+ handle,
+ function.name,
+ var_handle
+ );
+ continue;
+ }
+ }
+
+ let ctx = back::FunctionCtx {
+ ty: back::FunctionType::Function(handle),
+ info,
+ expressions: &function.expressions,
+ named_expressions: &function.named_expressions,
+ };
+ let name = self.names[&NameKey::Function(handle)].clone();
+
+ // Write wrapped function for `Expression::ImageQuery` and `Expressions::ArrayLength`
+ // before writing all statements and expressions.
+ self.write_wrapped_functions(module, &ctx)?;
+
+ self.write_function(module, name.as_str(), function, &ctx, info)?;
+
+ writeln!(self.out)?;
+ }
+
+ let mut entry_point_names = Vec::with_capacity(module.entry_points.len());
+
+ // Write all entry points
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let info = module_info.get_entry_point(index);
+
+ if !self.options.fake_missing_bindings {
+ let mut ep_error = None;
+ for (var_handle, var) in module.global_variables.iter() {
+ match var.binding {
+ Some(ref binding) if !info[var_handle].is_empty() => {
+ if let Err(err) = self.options.resolve_resource_binding(binding) {
+ ep_error = Some(err);
+ break;
+ }
+ }
+ _ => {}
+ }
+ }
+ if let Some(err) = ep_error {
+ entry_point_names.push(Err(err));
+ continue;
+ }
+ }
+
+ let ctx = back::FunctionCtx {
+ ty: back::FunctionType::EntryPoint(index as u16),
+ info,
+ expressions: &ep.function.expressions,
+ named_expressions: &ep.function.named_expressions,
+ };
+
+ // Write wrapped function for `Expression::ImageQuery` and `Expressions::ArrayLength`
+ // before writing all statements and expressions.
+ self.write_wrapped_functions(module, &ctx)?;
+
+ if ep.stage == ShaderStage::Compute {
+ // HLSL is calling workgroup size "num threads"
+ let num_threads = ep.workgroup_size;
+ writeln!(
+ self.out,
+ "[numthreads({}, {}, {})]",
+ num_threads[0], num_threads[1], num_threads[2]
+ )?;
+ }
+
+ let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
+ self.write_function(module, &name, &ep.function, &ctx, info)?;
+
+ if index < module.entry_points.len() - 1 {
+ writeln!(self.out)?;
+ }
+
+ entry_point_names.push(Ok(name));
+ }
+
+ Ok(super::ReflectionInfo { entry_point_names })
+ }
+
+ fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
+ match *binding {
+ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
+ write!(self.out, "precise ")?;
+ }
+ crate::Binding::Location {
+ interpolation,
+ sampling,
+ ..
+ } => {
+ if let Some(interpolation) = interpolation {
+ if let Some(string) = interpolation.to_hlsl_str() {
+ write!(self.out, "{string} ")?
+ }
+ }
+
+ if let Some(sampling) = sampling {
+ if let Some(string) = sampling.to_hlsl_str() {
+ write!(self.out, "{string} ")?
+ }
+ }
+ }
+ crate::Binding::BuiltIn(_) => {}
+ }
+
+ Ok(())
+ }
+
+ //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
+ // if they are struct, so that the `stage` argument here could be omitted.
+ fn write_semantic(
+ &mut self,
+ binding: &crate::Binding,
+ stage: Option<(ShaderStage, Io)>,
+ ) -> BackendResult {
+ match *binding {
+ crate::Binding::BuiltIn(builtin) => {
+ let builtin_str = builtin.to_hlsl_str()?;
+ write!(self.out, " : {builtin_str}")?;
+ }
+ crate::Binding::Location { location, .. } => {
+ if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
+ write!(self.out, " : SV_Target{location}")?;
+ } else {
+ write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ fn write_interface_struct(
+ &mut self,
+ module: &Module,
+ shader_stage: (ShaderStage, Io),
+ struct_name: String,
+ mut members: Vec<EpStructMember>,
+ ) -> Result<EntryPointBinding, Error> {
+ // Sort the members so that first come the user-defined varyings
+ // in ascending locations, and then built-ins. This allows VS and FS
+ // interfaces to match with regards to order.
+ members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
+
+ write!(self.out, "struct {struct_name}")?;
+ writeln!(self.out, " {{")?;
+ for m in members.iter() {
+ write!(self.out, "{}", back::INDENT)?;
+ if let Some(ref binding) = m.binding {
+ self.write_modifier(binding)?;
+ }
+ self.write_type(module, m.ty)?;
+ write!(self.out, " {}", &m.name)?;
+ if let Some(ref binding) = m.binding {
+ self.write_semantic(binding, Some(shader_stage))?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ writeln!(self.out, "}};")?;
+ writeln!(self.out)?;
+
+ match shader_stage.1 {
+ Io::Input => {
+ // bring back the original order
+ members.sort_by_key(|m| m.index);
+ }
+ Io::Output => {
+ // keep it sorted by binding
+ }
+ }
+
+ Ok(EntryPointBinding {
+ arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
+ ty_name: struct_name,
+ members,
+ })
+ }
+
+ /// Flatten all entry point arguments into a single struct.
+ /// This is needed since we need to re-order them: first placing user locations,
+ /// then built-ins.
+ fn write_ep_input_struct(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ stage: ShaderStage,
+ entry_point_name: &str,
+ ) -> Result<EntryPointBinding, Error> {
+ let struct_name = format!("{stage:?}Input_{entry_point_name}");
+
+ let mut fake_members = Vec::new();
+ for arg in func.arguments.iter() {
+ match module.types[arg.ty].inner {
+ TypeInner::Struct { ref members, .. } => {
+ for member in members.iter() {
+ let name = self.namer.call_or(&member.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name,
+ ty: member.ty,
+ binding: member.binding.clone(),
+ index,
+ });
+ }
+ }
+ _ => {
+ let member_name = self.namer.call_or(&arg.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name: member_name,
+ ty: arg.ty,
+ binding: arg.binding.clone(),
+ index,
+ });
+ }
+ }
+ }
+
+ self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
+ }
+
+ /// Flatten all entry point results into a single struct.
+ /// This is needed since we need to re-order them: first placing user locations,
+ /// then built-ins.
+ fn write_ep_output_struct(
+ &mut self,
+ module: &Module,
+ result: &crate::FunctionResult,
+ stage: ShaderStage,
+ entry_point_name: &str,
+ ) -> Result<EntryPointBinding, Error> {
+ let struct_name = format!("{stage:?}Output_{entry_point_name}");
+
+ let mut fake_members = Vec::new();
+ let empty = [];
+ let members = match module.types[result.ty].inner {
+ TypeInner::Struct { ref members, .. } => members,
+ ref other => {
+ log::error!("Unexpected {:?} output type without a binding", other);
+ &empty[..]
+ }
+ };
+
+ for member in members.iter() {
+ let member_name = self.namer.call_or(&member.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name: member_name,
+ ty: member.ty,
+ binding: member.binding.clone(),
+ index,
+ });
+ }
+
+ self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
+ }
+
+ /// Writes special interface structures for an entry point. The special structures have
+ /// all the fields flattened into them and sorted by binding. They are only needed for
+ /// VS outputs and FS inputs, so that these interfaces match.
+ fn write_ep_interface(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ stage: ShaderStage,
+ ep_name: &str,
+ ) -> Result<EntryPointInterface, Error> {
+ Ok(EntryPointInterface {
+ input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment {
+ Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
+ } else {
+ None
+ },
+ output: match func.result {
+ Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
+ Some(self.write_ep_output_struct(module, fr, stage, ep_name)?)
+ }
+ _ => None,
+ },
+ })
+ }
+
+ /// Write an entry point preface that initializes the arguments as specified in IR.
+ fn write_ep_arguments_initialization(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ ep_index: u16,
+ ) -> BackendResult {
+ let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
+ Some(ep_input) => ep_input,
+ None => return Ok(()),
+ };
+ let mut fake_iter = ep_input.members.iter();
+ for (arg_index, arg) in func.arguments.iter().enumerate() {
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(module, arg.ty)?;
+ let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
+ write!(self.out, " {arg_name}")?;
+ match module.types[arg.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ self.write_array_size(module, base, size)?;
+ let fake_member = fake_iter.next().unwrap();
+ writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ }
+ TypeInner::Struct { ref members, .. } => {
+ write!(self.out, " = {{ ")?;
+ for index in 0..members.len() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ let fake_member = fake_iter.next().unwrap();
+ write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
+ }
+ writeln!(self.out, " }};")?;
+ }
+ _ => {
+ let fake_member = fake_iter.next().unwrap();
+ writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ }
+ }
+ }
+ assert!(fake_iter.next().is_none());
+ Ok(())
+ }
+
+ /// Helper method used to write global variables
+ /// # Notes
+ /// Always adds a newline
+ fn write_global(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::GlobalVariable>,
+ ) -> BackendResult {
+ let global = &module.global_variables[handle];
+ let inner = &module.types[global.ty].inner;
+
+ if let Some(ref binding) = global.binding {
+ if let Err(err) = self.options.resolve_resource_binding(binding) {
+ log::info!(
+ "Skipping global {:?} (name {:?}) for being inaccessible: {}",
+ handle,
+ global.name,
+ err,
+ );
+ return Ok(());
+ }
+ }
+
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
+ let register_ty = match global.space {
+ crate::AddressSpace::Function => unreachable!("Function address space"),
+ crate::AddressSpace::Private => {
+ write!(self.out, "static ")?;
+ self.write_type(module, global.ty)?;
+ ""
+ }
+ crate::AddressSpace::WorkGroup => {
+ write!(self.out, "groupshared ")?;
+ self.write_type(module, global.ty)?;
+ ""
+ }
+ crate::AddressSpace::Uniform => {
+ // constant buffer declarations are expected to be inlined, e.g.
+ // `cbuffer foo: register(b0) { field1: type1; }`
+ write!(self.out, "cbuffer")?;
+ "b"
+ }
+ crate::AddressSpace::Storage { access } => {
+ let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
+ ("RW", "u")
+ } else {
+ ("", "t")
+ };
+ write!(self.out, "{prefix}ByteAddressBuffer")?;
+ register
+ }
+ crate::AddressSpace::Handle => {
+ let handle_ty = match *inner {
+ TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
+ _ => inner,
+ };
+
+ let register = match *handle_ty {
+ TypeInner::Sampler { .. } => "s",
+ // all storage textures are UAV, unconditionally
+ TypeInner::Image {
+ class: crate::ImageClass::Storage { .. },
+ ..
+ } => "u",
+ _ => "t",
+ };
+ self.write_type(module, global.ty)?;
+ register
+ }
+ crate::AddressSpace::PushConstant => {
+ // The type of the push constants will be wrapped in `ConstantBuffer`
+ write!(self.out, "ConstantBuffer<")?;
+ "b"
+ }
+ };
+
+ // If the global is a push constant write the type now because it will be a
+ // generic argument to `ConstantBuffer`
+ if global.space == crate::AddressSpace::PushConstant {
+ self.write_global_type(module, global.ty)?;
+
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ // Close the angled brackets for the generic argument
+ write!(self.out, ">")?;
+ }
+
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, " {name}")?;
+
+ // Push constants need to be assigned a binding explicitly by the consumer
+ // since naga has no way to know the binding from the shader alone
+ if global.space == crate::AddressSpace::PushConstant {
+ let target = self
+ .options
+ .push_constants_target
+ .as_ref()
+ .expect("No bind target was defined for the push constants block");
+ write!(self.out, ": register(b{}", target.register)?;
+ if target.space != 0 {
+ write!(self.out, ", space{}", target.space)?;
+ }
+ write!(self.out, ")")?;
+ }
+
+ if let Some(ref binding) = global.binding {
+ // this was already resolved earlier when we started evaluating an entry point.
+ let bt = self.options.resolve_resource_binding(binding).unwrap();
+
+ // need to write the binding array size if the type was emitted with `write_type`
+ if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
+ if let Some(overridden_size) = bt.binding_array_size {
+ write!(self.out, "[{overridden_size}]")?;
+ } else {
+ self.write_array_size(module, base, size)?;
+ }
+ }
+
+ write!(self.out, " : register({}{}", register_ty, bt.register)?;
+ if bt.space != 0 {
+ write!(self.out, ", space{}", bt.space)?;
+ }
+ write!(self.out, ")")?;
+ } else {
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ if global.space == crate::AddressSpace::Private {
+ write!(self.out, " = ")?;
+ if let Some(init) = global.init {
+ self.write_constant(module, init)?;
+ } else {
+ self.write_default_init(module, global.ty)?;
+ }
+ }
+ }
+
+ if global.space == crate::AddressSpace::Uniform {
+ write!(self.out, " {{ ")?;
+
+ self.write_global_type(module, global.ty)?;
+
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?;
+
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ writeln!(self.out, "; }}")?;
+ } else {
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global constants
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_global_constant(
+ &mut self,
+ module: &Module,
+ inner: &crate::ConstantInner,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ write!(self.out, "static const ")?;
+ match *inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } => {
+ // Write type
+ let ty_str = match *value {
+ crate::ScalarValue::Sint(_) => "int",
+ crate::ScalarValue::Uint(_) => "uint",
+ crate::ScalarValue::Float(_) => "float",
+ crate::ScalarValue::Bool(_) => "bool",
+ };
+ let name = &self.names[&NameKey::Constant(handle)];
+ write!(self.out, "{ty_str} {name} = ")?;
+
+ // Second match required to avoid heap allocation by `format!()`
+ match *value {
+ crate::ScalarValue::Sint(value) => write!(self.out, "{value}")?,
+ crate::ScalarValue::Uint(value) => write!(self.out, "{value}")?,
+ crate::ScalarValue::Float(value) => {
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero
+ write!(self.out, "{value:?}")?
+ }
+ crate::ScalarValue::Bool(value) => write!(self.out, "{value}")?,
+ };
+ }
+ crate::ConstantInner::Composite { ty, ref components } => {
+ self.write_type(module, ty)?;
+ let name = &self.names[&NameKey::Constant(handle)];
+ write!(self.out, " {name} = ")?;
+ self.write_composite_constant(module, ty, components)?;
+ }
+ }
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
+ pub(super) fn write_array_size(
+ &mut self,
+ module: &Module,
+ base: Handle<crate::Type>,
+ size: crate::ArraySize,
+ ) -> BackendResult {
+ write!(self.out, "[")?;
+
+ // Write the array size
+ // Writes nothing if `ArraySize::Dynamic`
+ // Panics if `ArraySize::Constant` has a constant that isn't an sint or uint
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let size = module.constants[const_handle].to_array_length().unwrap();
+ write!(self.out, "{size}")?;
+ }
+ crate::ArraySize::Dynamic => {}
+ }
+
+ write!(self.out, "]")?;
+
+ if let TypeInner::Array {
+ base: next_base,
+ size: next_size,
+ ..
+ } = module.types[base].inner
+ {
+ self.write_array_size(module, next_base, next_size)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_struct(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ span: u32,
+ shader_stage: Option<(ShaderStage, Io)>,
+ ) -> BackendResult {
+ // Write struct name
+ let struct_name = &self.names[&NameKey::Type(handle)];
+ writeln!(self.out, "struct {struct_name} {{")?;
+
+ let mut last_offset = 0;
+ for (index, member) in members.iter().enumerate() {
+ if member.binding.is_none() && member.offset > last_offset {
+ // using int as padding should work as long as the backend
+ // doesn't support a type that's less than 4 bytes in size
+ // (Error::UnsupportedScalar catches this)
+ let padding = (member.offset - last_offset) / 4;
+ for i in 0..padding {
+ writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
+ }
+ }
+ let ty_inner = &module.types[member.ty].inner;
+ last_offset = member.offset + ty_inner.size_hlsl(&module.types, &module.constants);
+
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+
+ match module.types[member.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ // HLSL arrays are written as `type name[size]`
+
+ self.write_global_type(module, member.ty)?;
+
+ // Write `name`
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, index as u32)]
+ )?;
+ // Write [size]
+ self.write_array_size(module, base, size)?;
+ }
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ TypeInner::Matrix {
+ rows,
+ columns,
+ width,
+ } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
+ let vec_ty = crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ };
+ let field_name_key = NameKey::StructMember(handle, index as u32);
+
+ for i in 0..columns as u8 {
+ if i != 0 {
+ write!(self.out, "; ")?;
+ }
+ self.write_value_type(module, &vec_ty)?;
+ write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
+ }
+ }
+ _ => {
+ // Write modifier before type
+ if let Some(ref binding) = member.binding {
+ self.write_modifier(binding)?;
+ }
+
+ // Even though Naga IR matrices are column-major, we must describe
+ // matrices passed from the CPU as being in row-major order.
+ // See the module-level block comment in mod.rs for details.
+ if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
+ write!(self.out, "row_major ")?;
+ }
+
+ // Write the member type and name
+ self.write_type(module, member.ty)?;
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, index as u32)]
+ )?;
+ }
+ }
+
+ if let Some(ref binding) = member.binding {
+ self.write_semantic(binding, shader_stage)?;
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
+ if members.last().unwrap().binding.is_none() && span > last_offset {
+ let padding = (span - last_offset) / 4;
+ for i in 0..padding {
+ writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
+ }
+ }
+
+ writeln!(self.out, "}};")?;
+ Ok(())
+ }
+
+ /// Helper method used to write global/structs non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_global_type(
+ &mut self,
+ module: &Module,
+ ty: Handle<crate::Type>,
+ ) -> BackendResult {
+ let matrix_data = get_inner_matrix_data(module, ty);
+
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = matrix_data
+ {
+ write!(self.out, "__mat{}x2", columns as u8)?;
+ } else {
+ // Even though Naga IR matrices are column-major, we must describe
+ // matrices passed from the CPU as being in row-major order.
+ // See the module-level block comment in mod.rs for details.
+ if matrix_data.is_some() {
+ write!(self.out, "row_major ")?;
+ }
+
+ self.write_type(module, ty)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &module.types[ty].inner;
+ match *inner {
+ TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
+ // hlsl array has the size separated from the base type
+ TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
+ self.write_type(module, base)?
+ }
+ ref other => self.write_value_type(module, other)?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ TypeInner::Scalar { kind, width } | TypeInner::Atomic { kind, width } => {
+ write!(self.out, "{}", kind.to_hlsl_str(width)?)?;
+ }
+ TypeInner::Vector { size, kind, width } => {
+ write!(
+ self.out,
+ "{}{}",
+ kind.to_hlsl_str(width)?,
+ back::vector_size_str(size)
+ )?;
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ // The IR supports only float matrix
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
+
+ // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
+ write!(
+ self.out,
+ "{}{}x{}",
+ crate::ScalarKind::Float.to_hlsl_str(width)?,
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ )?;
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ self.write_image_type(dim, arrayed, class)?;
+ }
+ TypeInner::Sampler { comparison } => {
+ let sampler = if comparison {
+ "SamplerComparisonState"
+ } else {
+ "SamplerState"
+ };
+ write!(self.out, "{sampler}")?;
+ }
+ // HLSL arrays are written as `type name[size]`
+ // Current code is written arrays only as `[size]`
+ // Base `type` and `name` should be written outside
+ TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
+ self.write_array_size(module, base, size)?;
+ }
+ _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write functions
+ /// # Notes
+ /// Ends in a newline
+ fn write_function(
+ &mut self,
+ module: &Module,
+ name: &str,
+ func: &crate::Function,
+ func_ctx: &back::FunctionCtx<'_>,
+ info: &valid::FunctionInfo,
+ ) -> BackendResult {
+ // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
+
+ self.update_expressions_to_bake(module, func, info);
+
+ // Write modifier
+ if let Some(crate::FunctionResult {
+ binding:
+ Some(
+ ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position {
+ invariant: true,
+ }),
+ ),
+ ..
+ }) = func.result
+ {
+ self.write_modifier(binding)?;
+ }
+
+ // Write return type
+ if let Some(ref result) = func.result {
+ match func_ctx.ty {
+ back::FunctionType::Function(_) => {
+ self.write_type(module, result.ty)?;
+ }
+ back::FunctionType::EntryPoint(index) => {
+ if let Some(ref ep_output) = self.entry_point_io[index as usize].output {
+ write!(self.out, "{}", ep_output.ty_name)?;
+ } else {
+ self.write_type(module, result.ty)?;
+ }
+ }
+ }
+ } else {
+ write!(self.out, "void")?;
+ }
+
+ // Write function name
+ write!(self.out, " {name}(")?;
+
+ let need_workgroup_variables_initialization =
+ self.need_workgroup_variables_initialization(func_ctx, module);
+
+ // Write function arguments for non entry point functions
+ match func_ctx.ty {
+ back::FunctionType::Function(handle) => {
+ for (index, arg) in func.arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // Write argument type
+ let arg_ty = match module.types[arg.ty].inner {
+ // pointers in function arguments are expected and resolve to `inout`
+ TypeInner::Pointer { base, .. } => {
+ //TODO: can we narrow this down to just `in` when possible?
+ write!(self.out, "inout ")?;
+ base
+ }
+ _ => arg.ty,
+ };
+ self.write_type(module, arg_ty)?;
+
+ let argument_name =
+ &self.names[&NameKey::FunctionArgument(handle, index as u32)];
+
+ // Write argument name. Space is important.
+ write!(self.out, " {argument_name}")?;
+ if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ }
+ }
+ back::FunctionType::EntryPoint(ep_index) => {
+ if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
+ write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?;
+ } else {
+ let stage = module.entry_points[ep_index as usize].stage;
+ for (index, arg) in func.arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_type(module, arg.ty)?;
+
+ let argument_name =
+ &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
+
+ write!(self.out, " {argument_name}")?;
+ if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ if let Some(ref binding) = arg.binding {
+ self.write_semantic(binding, Some((stage, Io::Input)))?;
+ }
+ }
+
+ if need_workgroup_variables_initialization {
+ if !func.arguments.is_empty() {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
+ }
+ }
+ }
+ }
+ // Ends of arguments
+ write!(self.out, ")")?;
+
+ // Write semantic if it present
+ if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
+ let stage = module.entry_points[index as usize].stage;
+ if let Some(crate::FunctionResult {
+ binding: Some(ref binding),
+ ..
+ }) = func.result
+ {
+ self.write_semantic(binding, Some((stage, Io::Output)))?;
+ }
+ }
+
+ // Function body start
+ writeln!(self.out)?;
+ writeln!(self.out, "{{")?;
+
+ if need_workgroup_variables_initialization {
+ self.write_workgroup_variables_initialization(func_ctx, module)?;
+ }
+
+ if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
+ self.write_ep_arguments_initialization(module, func, index)?;
+ }
+
+ // Write function local variables
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability)
+ write!(self.out, "{}", back::INDENT)?;
+
+ // Write the local name
+ // The leading space is important
+ self.write_type(module, local.ty)?;
+ write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ write!(self.out, " = ")?;
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ // Put the equal signal only if there's a initializer
+ // The leading and trailing spaces aren't needed but help with readability
+
+ // Write the constant
+ // `write_constant` adds no trailing or leading space/newline
+ self.write_constant(module, init)?;
+ } else {
+ // Zero initialize local variables
+ self.write_default_init(module, local.ty)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ if !func.local_variables.is_empty() {
+ writeln!(self.out)?;
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // The indentation should always be 1 when writing the function body
+ self.write_stmt(module, sta, func_ctx, back::Level(1))?;
+ }
+
+ writeln!(self.out, "}}")?;
+
+ self.named_expressions.clear();
+
+ Ok(())
+ }
+
+ fn need_workgroup_variables_initialization(
+ &mut self,
+ func_ctx: &back::FunctionCtx,
+ module: &Module,
+ ) -> bool {
+ self.options.zero_initialize_workgroup_memory
+ && func_ctx.ty.is_compute_entry_point(module)
+ && module.global_variables.iter().any(|(handle, var)| {
+ !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ }
+
+ fn write_workgroup_variables_initialization(
+ &mut self,
+ func_ctx: &back::FunctionCtx,
+ module: &Module,
+ ) -> BackendResult {
+ let level = back::Level(1);
+
+ writeln!(
+ self.out,
+ "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
+ )?;
+
+ let vars = module.global_variables.iter().filter(|&(handle, var)| {
+ !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ });
+
+ for (handle, var) in vars {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{}{} = ", level.next(), name)?;
+ self.write_default_init(module, var.ty)?;
+ writeln!(self.out, ";")?;
+ }
+
+ writeln!(self.out, "{level}}}")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)
+ }
+
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ module: &Module,
+ stmt: &crate::Statement,
+ func_ctx: &back::FunctionCtx<'_>,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::Statement;
+
+ match *stmt {
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let info = &func_ctx.info[handle];
+ let ptr_class = info.ty.inner_with(&module.types).pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ // HLSL can't save a pointer-valued expression in a variable,
+ // but we shouldn't ever need to: they should never be named expressions,
+ // and none of the expression types flagged by bake_ref_count can be pointer-valued.
+ None
+ } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
+ // Front end provides names for all variables at the start of writing.
+ // But we write them to step by step. We need to recache them
+ // Otherwise, we could accidentally write variable name instead of full expression.
+ // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
+ Some(self.namer.call(name))
+ } else if self.need_bake_expressions.contains(&handle) {
+ Some(format!("_expr{}", handle.index()))
+ } else if info.ref_count == 0 {
+ Some(self.namer.call(""))
+ } else {
+ None
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.write_named_expr(module, handle, name, func_ctx)?;
+ }
+ }
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Block(ref block) => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, level.next())?
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "if (")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ let l2 = level.next();
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Kill => writeln!(self.out, "{level}discard;")?,
+ Statement::Return { value: None } => {
+ writeln!(self.out, "{level}return;")?;
+ }
+ Statement::Return { value: Some(expr) } => {
+ let base_ty_res = &func_ctx.info[expr].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ if let TypeInner::Pointer { base, space: _ } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ if let TypeInner::Struct { .. } = *resolved {
+ // We can safely unwrap here, since we now we working with struct
+ let ty = base_ty_res.handle().unwrap();
+ let struct_name = &self.names[&NameKey::Type(ty)];
+ let variable_name = self.namer.call(&struct_name.to_lowercase());
+ write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
+ self.write_expr(module, expr, func_ctx)?;
+ writeln!(self.out, ";")?;
+
+ // for entry point returns, we may need to reshuffle the outputs into a different struct
+ let ep_output = match func_ctx.ty {
+ back::FunctionType::Function(_) => None,
+ back::FunctionType::EntryPoint(index) => {
+ self.entry_point_io[index as usize].output.as_ref()
+ }
+ };
+ let final_name = match ep_output {
+ Some(ep_output) => {
+ let final_name = self.namer.call(&variable_name);
+ write!(
+ self.out,
+ "{}const {} {} = {{ ",
+ level, ep_output.ty_name, final_name,
+ )?;
+ for (index, m) in ep_output.members.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
+ write!(self.out, "{variable_name}.{member_name}")?;
+ }
+ writeln!(self.out, " }};")?;
+ final_name
+ }
+ None => variable_name,
+ };
+ writeln!(self.out, "{level}return {final_name};")?;
+ } else {
+ write!(self.out, "{level}return ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ writeln!(self.out, ";")?
+ }
+ }
+ Statement::Store { pointer, value } => {
+ let ty_inner = func_ctx.info[pointer].ty.inner_with(&module.types);
+ if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ self.write_storage_store(
+ module,
+ var_handle,
+ StoreValue::Expression(value),
+ func_ctx,
+ level,
+ )?;
+ } else {
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ //
+ // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
+ // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
+ struct MatrixAccess {
+ base: Handle<crate::Expression>,
+ index: u32,
+ }
+ enum Index {
+ Expression(Handle<crate::Expression>),
+ Static(u32),
+ }
+
+ let get_members = |expr: Handle<crate::Expression>| {
+ let base_ty_res = &func_ctx.info[expr].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+ match *resolved {
+ TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ TypeInner::Struct { ref members, .. } => Some(members),
+ _ => None,
+ },
+ _ => None,
+ }
+ };
+
+ let mut matrix = None;
+ let mut vector = None;
+ let mut scalar = None;
+
+ let mut current_expr = pointer;
+ for _ in 0..3 {
+ let resolved = func_ctx.info[current_expr].ty.inner_with(&module.types);
+
+ match (resolved, &func_ctx.expressions[current_expr]) {
+ (
+ &TypeInner::Pointer { base: ty, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) if matches!(
+ module.types[ty].inner,
+ TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ }
+ ) && get_members(base)
+ .map(|members| members[index as usize].binding.is_none())
+ == Some(true) =>
+ {
+ matrix = Some(MatrixAccess { base, index });
+ break;
+ }
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::Access { base, index },
+ ) => {
+ vector = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ vector = Some(Index::Static(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::Access { base, index },
+ ) => {
+ scalar = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ scalar = Some(Index::Static(index));
+ current_expr = base;
+ }
+ _ => break,
+ }
+ }
+
+ write!(self.out, "{level}")?;
+
+ if let Some(MatrixAccess { index, base }) = matrix {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+ let ty = match *resolved {
+ TypeInner::Pointer { base, .. } => base,
+ _ => base_ty_res.handle().unwrap(),
+ };
+
+ if let Some(Index::Static(vec_index)) = vector {
+ self.write_expr(module, base, func_ctx)?;
+ write!(
+ self.out,
+ ".{}_{}",
+ &self.names[&NameKey::StructMember(ty, index)],
+ vec_index
+ )?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, "[")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ write!(self.out, "]")?;
+ }
+
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ } else {
+ let access = WrappedStructMatrixAccess { ty, index };
+ match (&vector, &scalar) {
+ (&Some(_), &Some(_)) => {
+ self.write_wrapped_struct_matrix_set_scalar_function_name(
+ access,
+ )?;
+ }
+ (&Some(_), &None) => {
+ self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
+ }
+ (&None, _) => {
+ self.write_wrapped_struct_matrix_set_function_name(access)?;
+ }
+ }
+
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+
+ if let Some(Index::Expression(vec_index)) = vector {
+ write!(self.out, ", ")?;
+ self.write_expr(module, vec_index, func_ctx)?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, ", ")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
+ } else {
+ // We handle `Store`s to __matCx2 column vectors and scalar elements via
+ // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
+ struct MatrixData {
+ columns: crate::VectorSize,
+ base: Handle<crate::Expression>,
+ }
+
+ enum Index {
+ Expression(Handle<crate::Expression>),
+ Static(u32),
+ }
+
+ let mut matrix = None;
+ let mut vector = None;
+ let mut scalar = None;
+
+ let mut current_expr = pointer;
+ for _ in 0..3 {
+ let resolved = func_ctx.info[current_expr].ty.inner_with(&module.types);
+ match (resolved, &func_ctx.expressions[current_expr]) {
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::Access { base, index },
+ ) => {
+ vector = Some(index);
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::Access { base, index },
+ ) => {
+ scalar = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ scalar = Some(Index::Static(index));
+ current_expr = base;
+ }
+ _ => {
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(
+ module,
+ current_expr,
+ func_ctx,
+ true,
+ ) {
+ matrix = Some(MatrixData {
+ columns,
+ base: current_expr,
+ });
+ }
+
+ break;
+ }
+ }
+ }
+
+ if let (Some(MatrixData { columns, base }), Some(vec_index)) =
+ (matrix, vector)
+ {
+ if scalar.is_some() {
+ write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
+ } else {
+ write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
+ }
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, vec_index, func_ctx)?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, ", ")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+
+ writeln!(self.out, ");")?;
+ } else {
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, " = ")?;
+
+ // We cast the RHS of this store in cases where the LHS
+ // is a struct member with type:
+ // - matCx2 or
+ // - a (possibly nested) array of matCx2's
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(
+ module, pointer, func_ctx, false,
+ ) {
+ let mut resolved =
+ func_ctx.info[pointer].ty.inner_with(&module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ write!(self.out, "(__mat{}x2", columns as u8)?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, ")")?;
+ }
+
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?
+ }
+ }
+ }
+ }
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let l2 = level.next();
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{level}bool {gate_name} = true;")?;
+ writeln!(self.out, "{level}while(true) {{")?;
+ writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
+ let l3 = l2.next();
+ for sta in continuing.iter() {
+ self.write_stmt(module, sta, func_ctx, l3)?;
+ }
+ if let Some(condition) = break_if {
+ write!(self.out, "{l3}if (")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", l3.next())?;
+ writeln!(self.out, "{l3}}}")?;
+ }
+ writeln!(self.out, "{l2}}}")?;
+ writeln!(self.out, "{l2}{gate_name} = false;")?;
+ } else {
+ writeln!(self.out, "{level}while(true) {{")?;
+ }
+
+ for sta in body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Break => writeln!(self.out, "{level}break;")?,
+ Statement::Continue => writeln!(self.out, "{level}continue;")?,
+ Statement::Barrier(barrier) => {
+ self.write_barrier(barrier, level)?;
+ }
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{level}")?;
+ self.write_expr(module, image, func_ctx)?;
+
+ write!(self.out, "[")?;
+ if let Some(index) = array_index {
+ // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here
+ write!(self.out, "int3(")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(module, coordinate, func_ctx)?;
+ }
+ write!(self.out, "]")?;
+
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ }
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ write!(self.out, "const ")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ let expr_ty = &func_ctx.info[expr].ty;
+ match *expr_ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+ write!(self.out, " {name} = ")?;
+ self.named_expressions.insert(expr, name);
+ }
+ let func_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{func_name}(")?;
+ for (index, argument) in arguments.iter().enumerate() {
+ self.write_expr(module, *argument, func_ctx)?;
+ // Only write a comma if isn't the last element
+ if index != arguments.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ match func_ctx.info[result].ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+
+ // Validation ensures that `pointer` has a `Pointer` type.
+ let pointer_space = func_ctx.info[pointer]
+ .ty
+ .inner_with(&module.types)
+ .pointer_space()
+ .unwrap();
+
+ let fun_str = fun.to_hlsl_suffix();
+ write!(self.out, " {res_name}; ")?;
+ match pointer_space {
+ crate::AddressSpace::WorkGroup => {
+ write!(self.out, "Interlocked{fun_str}(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ }
+ crate::AddressSpace::Storage { .. } => {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ // The call to `self.write_storage_address` wants
+ // mutable access to all of `self`, so temporarily take
+ // ownership of our reusable access chain buffer.
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{var_name}.Interlocked{fun_str}(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ self.temp_access_chain = chain;
+ }
+ ref other => {
+ return Err(Error::Custom(format!(
+ "invalid address space {other:?} for atomic statement"
+ )))
+ }
+ }
+ write!(self.out, ", ")?;
+ // handle the special cases
+ match *fun {
+ crate::AtomicFunction::Subtract => {
+ // we just wrote `InterlockedAdd`, so negate the argument
+ write!(self.out, "-")?;
+ }
+ crate::AtomicFunction::Exchange { compare: Some(_) } => {
+ return Err(Error::Unimplemented("atomic CompareExchange".to_string()));
+ }
+ _ => {}
+ }
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ", {res_name});")?;
+ self.named_expressions.insert(result, res_name);
+ }
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{level}")?;
+ write!(self.out, "switch(")?;
+ self.write_expr(module, selector, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ // Write all cases
+ let indent_level_1 = level.next();
+ let indent_level_2 = indent_level_1.next();
+
+ for (i, case) in cases.iter().enumerate() {
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ write!(self.out, "{indent_level_1}case {value}:")?
+ }
+ crate::SwitchValue::U32(value) => {
+ write!(self.out, "{indent_level_1}case {value}u:")?
+ }
+ crate::SwitchValue::Default => {
+ write!(self.out, "{indent_level_1}default:")?
+ }
+ }
+
+ // The new block is not only stylistic, it plays a role here:
+ // We might end up having to write the same case body
+ // multiple times due to FXC not supporting fallthrough.
+ // Therefore, some `Expression`s written by `Statement::Emit`
+ // will end up having the same name (`_expr<handle_index>`).
+ // So we need to put each case in its own scope.
+ let write_block_braces = !(case.fall_through && case.body.is_empty());
+ if write_block_braces {
+ writeln!(self.out, " {{")?;
+ } else {
+ writeln!(self.out)?;
+ }
+
+ // Although FXC does support a series of case clauses before
+ // a block[^yes], it does not support fallthrough from a
+ // non-empty case block to the next[^no]. If this case has a
+ // non-empty body with a fallthrough, emulate that by
+ // duplicating the bodies of all the cases it would fall
+ // into as extensions of this case's own body. This makes
+ // the HLSL output potentially quadratic in the size of the
+ // Naga IR.
+ //
+ // [^yes]: ```hlsl
+ // case 1:
+ // case 2: do_stuff()
+ // ```
+ // [^no]: ```hlsl
+ // case 1: do_this();
+ // case 2: do_that();
+ // ```
+ if case.fall_through && !case.body.is_empty() {
+ let curr_len = i + 1;
+ let end_case_idx = curr_len
+ + cases
+ .iter()
+ .skip(curr_len)
+ .position(|case| !case.fall_through)
+ .unwrap();
+ let indent_level_3 = indent_level_2.next();
+ for case in &cases[i..=end_case_idx] {
+ writeln!(self.out, "{indent_level_2}{{")?;
+ let prev_len = self.named_expressions.len();
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, indent_level_3)?;
+ }
+ // Clear all named expressions that were previously inserted by the statements in the block
+ self.named_expressions.truncate(prev_len);
+ writeln!(self.out, "{indent_level_2}}}")?;
+ }
+
+ let last_case = &cases[end_case_idx];
+ if last_case.body.last().map_or(true, |s| !s.is_terminator()) {
+ writeln!(self.out, "{indent_level_2}break;")?;
+ }
+ } else {
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, indent_level_2)?;
+ }
+ if !case.fall_through
+ && case.body.last().map_or(true, |s| !s.is_terminator())
+ {
+ writeln!(self.out, "{indent_level_2}break;")?;
+ }
+ }
+
+ if write_block_braces {
+ writeln!(self.out, "{indent_level_1}}}")?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::RayQuery { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write expressions
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ pub(super) fn write_expr(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ // Handle the special semantics for base vertex/instance
+ let ff_input = if self.options.special_constants_binding.is_some() {
+ func_ctx.is_fixed_function_input(expr, module)
+ } else {
+ None
+ };
+ let closing_bracket = match ff_input {
+ Some(crate::BuiltIn::VertexIndex) => {
+ write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_BASE_VERTEX} + ")?;
+ ")"
+ }
+ Some(crate::BuiltIn::InstanceIndex) => {
+ write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_BASE_INSTANCE} + ",)?;
+ ")"
+ }
+ Some(crate::BuiltIn::NumWorkGroups) => {
+ //Note: despite their names (`BASE_VERTEX` and `BASE_INSTANCE`),
+ // in compute shaders the special constants contain the number
+ // of workgroups, which we are using here.
+ write!(
+ self.out,
+ "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_BASE_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_BASE_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
+ )?;
+ return Ok(());
+ }
+ _ => "",
+ };
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{name}{closing_bracket}")?;
+ return Ok(());
+ }
+
+ let expression = &func_ctx.expressions[expr];
+
+ match *expression {
+ Expression::Constant(constant) => self.write_constant(module, constant)?,
+ Expression::Compose { ty, ref components } => {
+ match module.types[ty].inner {
+ TypeInner::Struct { .. } | TypeInner::Array { .. } => {
+ self.write_wrapped_constructor_function_name(
+ module,
+ WrappedConstructor { ty },
+ )?;
+ }
+ _ => {
+ self.write_type(module, ty)?;
+ }
+ };
+
+ write!(self.out, "(")?;
+
+ for (index, &component) in components.iter().enumerate() {
+ if index != 0 {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ self.write_expr(module, component, func_ctx)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ // All of the multiplication can be expressed as `mul`,
+ // except vector * vector, which needs to use the "*" operator.
+ Expression::Binary {
+ op: crate::BinaryOperator::Multiply,
+ left,
+ right,
+ } if func_ctx.info[left].ty.inner_with(&module.types).is_matrix()
+ || func_ctx.info[right]
+ .ty
+ .inner_with(&module.types)
+ .is_matrix() =>
+ {
+ // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
+ write!(self.out, "mul(")?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) != sign(right) return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+
+ // While HLSL supports float operands with the % operator it is only
+ // defined in cases where both sides are either positive or negative.
+ Expression::Binary {
+ op: crate::BinaryOperator::Modulo,
+ left,
+ right,
+ } if func_ctx.info[left]
+ .ty
+ .inner_with(&module.types)
+ .scalar_kind()
+ == Some(crate::ScalarKind::Float) =>
+ {
+ write!(self.out, "fmod(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Binary { op, left, right } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, " {} ", crate::back::binary_operation_str(op))?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Access { base, index } => {
+ if let Some(crate::AddressSpace::Storage { .. }) = func_ctx.info[expr]
+ .ty
+ .inner_with(&module.types)
+ .pointer_space()
+ {
+ // do nothing, the chain is written on `Load`/`Store`
+ } else {
+ // We use the function __get_col_of_matCx2 here in cases
+ // where `base`s type resolves to a matCx2 and is part of a
+ // struct member with type of (possibly nested) array of matCx2's.
+ //
+ // Note that this only works for `Load`s and we handle
+ // `Store`s differently in `Statement::Store`.
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
+ {
+ write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, ")")?;
+ return Ok(());
+ }
+
+ let base_ty_res = &func_ctx.info[base].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+
+ let non_uniform_qualifier = match *resolved {
+ TypeInner::BindingArray { .. } => {
+ let uniformity = &func_ctx.info[index].uniformity;
+
+ uniformity.non_uniform_result.is_some()
+ }
+ _ => false,
+ };
+
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, "[")?;
+ if non_uniform_qualifier {
+ write!(self.out, "NonUniformResourceIndex(")?;
+ }
+ self.write_expr(module, index, func_ctx)?;
+ if non_uniform_qualifier {
+ write!(self.out, ")")?;
+ }
+ write!(self.out, "]")?;
+ }
+ }
+ Expression::AccessIndex { base, index } => {
+ if let Some(crate::AddressSpace::Storage { .. }) = func_ctx.info[expr]
+ .ty
+ .inner_with(&module.types)
+ .pointer_space()
+ {
+ // do nothing, the chain is written on `Load`/`Store`
+ } else {
+ fn write_access<W: fmt::Write>(
+ writer: &mut super::Writer<'_, W>,
+ resolved: &TypeInner,
+ base_ty_handle: Option<Handle<crate::Type>>,
+ index: u32,
+ ) -> BackendResult {
+ match *resolved {
+ // We specifcally lift the ValuePointer to this case. While `[0]` is valid
+ // HLSL for any vector behind a value pointer, FXC completely miscompiles
+ // it and generates completely nonsensical DXBC.
+ //
+ // See https://github.com/gfx-rs/naga/issues/2095 for more details.
+ TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
+ // Write vector access as a swizzle
+ write!(writer.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::BindingArray { .. } => write!(writer.out, "[{index}]")?,
+ TypeInner::Struct { .. } => {
+ // This will never panic in case the type is a `Struct`, this is not true
+ // for other types so we can only check while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ writer.out,
+ ".{}",
+ &writer.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => {
+ return Err(Error::Custom(format!("Cannot index {other:?}")))
+ }
+ }
+ Ok(())
+ }
+
+ // We write the matrix column access in a special way since
+ // the type of `base` is our special __matCx2 struct.
+ if let Some(MatrixType {
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ ..
+ }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
+ {
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, "._{index}")?;
+ return Ok(());
+ }
+
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, .. } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ //
+ // We handle matrix reconstruction here for Loads.
+ // Stores are handled directly by `Statement::Store`.
+ if let TypeInner::Struct { ref members, .. } = *resolved {
+ let member = &members[index as usize];
+
+ match module.types[member.ty].inner {
+ TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ let ty = base_ty_handle.unwrap();
+ self.write_wrapped_struct_matrix_get_function_name(
+ WrappedStructMatrixAccess { ty, index },
+ )?;
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ")")?;
+ return Ok(());
+ }
+ _ => {}
+ }
+ }
+
+ self.write_expr(module, base, func_ctx)?;
+ write_access(self, resolved, base_ty_handle, index)?;
+ }
+ }
+ Expression::FunctionArgument(pos) => {
+ let key = match func_ctx.ty {
+ back::FunctionType::Function(handle) => NameKey::FunctionArgument(handle, pos),
+ back::FunctionType::EntryPoint(index) => {
+ NameKey::EntryPointArgument(index, pos)
+ }
+ };
+ let name = &self.names[&key];
+ write!(self.out, "{name}")?;
+ }
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ use crate::SampleLevel as Sl;
+ const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
+
+ let (base_str, component_str) = match gather {
+ Some(component) => ("Gather", COMPONENTS[component as usize]),
+ None => ("Sample", ""),
+ };
+ let cmp_str = match depth_ref {
+ Some(_) => "Cmp",
+ None => "",
+ };
+ let level_str = match level {
+ Sl::Zero if gather.is_none() => "LevelZero",
+ Sl::Auto | Sl::Zero => "",
+ Sl::Exact(_) => "Level",
+ Sl::Bias(_) => "Bias",
+ Sl::Gradient { .. } => "Grad",
+ };
+
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_texture_coordinates(
+ "float",
+ coordinate,
+ array_index,
+ None,
+ module,
+ func_ctx,
+ )?;
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ match level {
+ Sl::Auto | Sl::Zero => {}
+ Sl::Exact(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Bias(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Gradient { x, y } => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, x, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, y, func_ctx)?;
+ }
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.write_constant(module, offset)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::ImageQuery { image, query } => {
+ // use wrapped image query function
+ if let TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } = *func_ctx.info[image].ty.inner_with(&module.types)
+ {
+ let wrapped_image_query = WrappedImageQuery {
+ dim,
+ arrayed,
+ class,
+ query: query.into(),
+ };
+
+ self.write_wrapped_image_query_function_name(wrapped_image_query)?;
+ write!(self.out, "(")?;
+ // Image always first param
+ self.write_expr(module, image, func_ctx)?;
+ if let crate::ImageQuery::Size { level: Some(level) } = query {
+ write!(self.out, ", ")?;
+ self.write_expr(module, level, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ".Load(")?;
+
+ self.write_texture_coordinates(
+ "int",
+ coordinate,
+ array_index,
+ level,
+ module,
+ func_ctx,
+ )?;
+
+ if let Some(sample) = sample {
+ write!(self.out, ", ")?;
+ self.write_expr(module, sample, func_ctx)?;
+ }
+
+ // close bracket for Load function
+ write!(self.out, ")")?;
+
+ // return x component if return type is scalar
+ if let TypeInner::Scalar { .. } = *func_ctx.info[expr].ty.inner_with(&module.types)
+ {
+ write!(self.out, ".x")?;
+ }
+ }
+ Expression::GlobalVariable(handle) => match module.global_variables[handle].space {
+ crate::AddressSpace::Storage { .. } => {}
+ _ => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{name}")?;
+ }
+ },
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
+ }
+ Expression::Load { pointer } => {
+ match func_ctx.info[pointer]
+ .ty
+ .inner_with(&module.types)
+ .pointer_space()
+ {
+ Some(crate::AddressSpace::Storage { .. }) => {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ let result_ty = func_ctx.info[expr].ty.clone();
+ self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
+ }
+ _ => {
+ let mut close_paren = false;
+
+ // We cast the value loaded to a native HLSL floatCx2
+ // in cases where it is of type:
+ // - __matCx2 or
+ // - a (possibly nested) array of __matCx2's
+ if let Some(MatrixType {
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ ..
+ }) = get_inner_matrix_of_struct_array_member(
+ module, pointer, func_ctx, false,
+ )
+ .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
+ {
+ let mut resolved = func_ctx.info[pointer].ty.inner_with(&module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ write!(self.out, "((")?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_type(module, base)?;
+ self.write_array_size(module, base, size)?;
+ } else {
+ self.write_value_type(module, resolved)?;
+ }
+ write!(self.out, ")")?;
+ close_paren = true;
+ }
+
+ self.write_expr(module, pointer, func_ctx)?;
+
+ if close_paren {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ }
+ Expression::Unary { op, expr } => {
+ use crate::{ScalarKind as Sk, UnaryOperator as Uo};
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
+ let op_str = match op {
+ Uo::Negate => "-",
+ Uo::Not => match func_ctx.info[expr]
+ .ty
+ .inner_with(&module.types)
+ .scalar_kind()
+ {
+ Some(Sk::Sint) | Some(Sk::Uint) => "~",
+ Some(Sk::Bool) => "!",
+ ref other => {
+ return Err(Error::Custom(format!(
+ "Cannot apply not to type {other:?}"
+ )))
+ }
+ },
+ };
+ write!(self.out, "{op_str}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let inner = func_ctx.info[expr].ty.inner_with(&module.types);
+ match convert {
+ Some(dst_width) => {
+ match *inner {
+ TypeInner::Vector { size, .. } => {
+ write!(
+ self.out,
+ "{}{}(",
+ kind.to_hlsl_str(dst_width)?,
+ back::vector_size_str(size)
+ )?;
+ }
+ TypeInner::Scalar { .. } => {
+ write!(self.out, "{}(", kind.to_hlsl_str(dst_width)?,)?;
+ }
+ TypeInner::Matrix { columns, rows, .. } => {
+ write!(
+ self.out,
+ "{}{}x{}(",
+ kind.to_hlsl_str(dst_width)?,
+ back::vector_size_str(columns),
+ back::vector_size_str(rows)
+ )?;
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_expr expression::as {inner:?}"
+ )));
+ }
+ };
+ }
+ None => {
+ write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
+ }
+ }
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ enum Function {
+ Asincosh { is_sin: bool },
+ Atanh,
+ Unpack2x16float,
+ Regular(&'static str),
+ MissingIntOverload(&'static str),
+ CountTrailingZeros,
+ CountLeadingZeros,
+ }
+
+ let fun = match fun {
+ // comparison
+ Mf::Abs => Function::Regular("abs"),
+ Mf::Min => Function::Regular("min"),
+ Mf::Max => Function::Regular("max"),
+ Mf::Clamp => Function::Regular("clamp"),
+ Mf::Saturate => Function::Regular("saturate"),
+ // trigonometry
+ Mf::Cos => Function::Regular("cos"),
+ Mf::Cosh => Function::Regular("cosh"),
+ Mf::Sin => Function::Regular("sin"),
+ Mf::Sinh => Function::Regular("sinh"),
+ Mf::Tan => Function::Regular("tan"),
+ Mf::Tanh => Function::Regular("tanh"),
+ Mf::Acos => Function::Regular("acos"),
+ Mf::Asin => Function::Regular("asin"),
+ Mf::Atan => Function::Regular("atan"),
+ Mf::Atan2 => Function::Regular("atan2"),
+ Mf::Asinh => Function::Asincosh { is_sin: true },
+ Mf::Acosh => Function::Asincosh { is_sin: false },
+ Mf::Atanh => Function::Atanh,
+ Mf::Radians => Function::Regular("radians"),
+ Mf::Degrees => Function::Regular("degrees"),
+ // decomposition
+ Mf::Ceil => Function::Regular("ceil"),
+ Mf::Floor => Function::Regular("floor"),
+ Mf::Round => Function::Regular("round"),
+ Mf::Fract => Function::Regular("frac"),
+ Mf::Trunc => Function::Regular("trunc"),
+ Mf::Modf => Function::Regular("modf"),
+ Mf::Frexp => Function::Regular("frexp"),
+ Mf::Ldexp => Function::Regular("ldexp"),
+ // exponent
+ Mf::Exp => Function::Regular("exp"),
+ Mf::Exp2 => Function::Regular("exp2"),
+ Mf::Log => Function::Regular("log"),
+ Mf::Log2 => Function::Regular("log2"),
+ Mf::Pow => Function::Regular("pow"),
+ // geometry
+ Mf::Dot => Function::Regular("dot"),
+ //Mf::Outer => ,
+ Mf::Cross => Function::Regular("cross"),
+ Mf::Distance => Function::Regular("distance"),
+ Mf::Length => Function::Regular("length"),
+ Mf::Normalize => Function::Regular("normalize"),
+ Mf::FaceForward => Function::Regular("faceforward"),
+ Mf::Reflect => Function::Regular("reflect"),
+ Mf::Refract => Function::Regular("refract"),
+ // computational
+ Mf::Sign => Function::Regular("sign"),
+ Mf::Fma => Function::Regular("mad"),
+ Mf::Mix => Function::Regular("lerp"),
+ Mf::Step => Function::Regular("step"),
+ Mf::SmoothStep => Function::Regular("smoothstep"),
+ Mf::Sqrt => Function::Regular("sqrt"),
+ Mf::InverseSqrt => Function::Regular("rsqrt"),
+ //Mf::Inverse =>,
+ Mf::Transpose => Function::Regular("transpose"),
+ Mf::Determinant => Function::Regular("determinant"),
+ // bits
+ Mf::CountTrailingZeros => Function::CountTrailingZeros,
+ Mf::CountLeadingZeros => Function::CountLeadingZeros,
+ Mf::CountOneBits => Function::MissingIntOverload("countbits"),
+ Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
+ Mf::FindLsb => Function::Regular("firstbitlow"),
+ Mf::FindMsb => Function::Regular("firstbithigh"),
+ Mf::Unpack2x16float => Function::Unpack2x16float,
+ _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
+ };
+
+ match fun {
+ Function::Asincosh { is_sin } => {
+ write!(self.out, "log(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " + sqrt(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " * ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ match is_sin {
+ true => write!(self.out, " + 1.0))")?,
+ false => write!(self.out, " - 1.0))")?,
+ }
+ }
+ Function::Atanh => {
+ write!(self.out, "0.5 * log((1.0 + ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") / (1.0 - ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ }
+ Function::Unpack2x16float => {
+ write!(self.out, "float2(f16tof32(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "), f16tof32((")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") >> 16))")?;
+ }
+ Function::Regular(fun_name) => {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ if let Some(arg) = arg1 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ if let Some(arg) = arg2 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ if let Some(arg) = arg3 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ write!(self.out, ")")?
+ }
+ Function::MissingIntOverload(fun_name) => {
+ let scalar_kind = &func_ctx.info[arg]
+ .ty
+ .inner_with(&module.types)
+ .scalar_kind();
+ if let Some(ScalarKind::Sint) = *scalar_kind {
+ write!(self.out, "asint({fun_name}(asuint(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ } else {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Function::CountTrailingZeros => {
+ match *func_ctx.info[arg].ty.inner_with(&module.types) {
+ TypeInner::Vector { size, kind, .. } => {
+ let s = match size {
+ crate::VectorSize::Bi => ".xx",
+ crate::VectorSize::Tri => ".xxx",
+ crate::VectorSize::Quad => ".xxxx",
+ };
+
+ if let ScalarKind::Uint = kind {
+ write!(self.out, "min((32u){s}, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "asint(min((32u){s}, asuint(firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))))")?;
+ }
+ }
+ TypeInner::Scalar { kind, .. } => {
+ if let ScalarKind::Uint = kind {
+ write!(self.out, "min(32u, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "asint(min(32u, asuint(firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))))")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ return Ok(());
+ }
+ Function::CountLeadingZeros => {
+ match *func_ctx.info[arg].ty.inner_with(&module.types) {
+ TypeInner::Vector { size, kind, .. } => {
+ let s = match size {
+ crate::VectorSize::Bi => ".xx",
+ crate::VectorSize::Tri => ".xxx",
+ crate::VectorSize::Quad => ".xxxx",
+ };
+
+ if let ScalarKind::Uint = kind {
+ write!(self.out, "asuint((31){s} - firstbithigh(")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ " < (0){s} ? (0){s} : (31){s} - firstbithigh("
+ )?;
+ }
+ }
+ TypeInner::Scalar { kind, .. } => {
+ if let ScalarKind::Uint = kind {
+ write!(self.out, "asuint(31 - firstbithigh(")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " < 0 ? 0 : 31 - firstbithigh(")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+
+ return Ok(());
+ }
+ }
+ }
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(module, vector, func_ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ Expression::ArrayLength(expr) => {
+ let var_handle = match func_ctx.expressions[expr] {
+ Expression::AccessIndex { base, index: _ } => {
+ match func_ctx.expressions[base] {
+ Expression::GlobalVariable(handle) => handle,
+ _ => unreachable!(),
+ }
+ }
+ Expression::GlobalVariable(handle) => handle,
+ _ => unreachable!(),
+ };
+
+ let var = &module.global_variables[var_handle];
+ let (offset, stride) = match module.types[var.ty].inner {
+ TypeInner::Array { stride, .. } => (0, stride),
+ TypeInner::Struct { ref members, .. } => {
+ let last = members.last().unwrap();
+ let stride = match module.types[last.ty].inner {
+ TypeInner::Array { stride, .. } => stride,
+ _ => unreachable!(),
+ };
+ (last.offset, stride)
+ }
+ _ => unreachable!(),
+ };
+
+ let storage_access = match var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => crate::StorageAccess::default(),
+ };
+ let wrapped_array_length = WrappedArrayLength {
+ writable: storage_access.contains(crate::StorageAccess::STORE),
+ };
+
+ write!(self.out, "((")?;
+ self.write_wrapped_array_length_function_name(wrapped_array_length)?;
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "({var_name}) - {offset}) / {stride})")?
+ }
+ Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
+ let tail = match ctrl {
+ Ctrl::Coarse => "coarse",
+ Ctrl::Fine => "fine",
+ Ctrl::None => unreachable!(),
+ };
+ write!(self.out, "abs(ddx_{tail}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")) + abs(ddy_{tail}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, "))")?
+ } else {
+ let fun_str = match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => "ddx_coarse",
+ (Axis::X, Ctrl::Fine) => "ddx_fine",
+ (Axis::X, Ctrl::None) => "ddx",
+ (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
+ (Axis::Y, Ctrl::Fine) => "ddy_fine",
+ (Axis::Y, Ctrl::None) => "ddy",
+ (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
+ (Axis::Width, Ctrl::None) => "fwidth",
+ };
+ write!(self.out, "{fun_str}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ }
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_str = match fun {
+ Rf::All => "all",
+ Rf::Any => "any",
+ Rf::IsNan => "isnan",
+ Rf::IsInf => "isinf",
+ Rf::IsFinite => "isfinite",
+ Rf::IsNormal => "isnormal",
+ };
+ write!(self.out, "{fun_str}(")?;
+ self.write_expr(module, argument, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Splat { size, value } => {
+ // hlsl is not supported one value constructor
+ // if we write, for example, int4(0), dxc returns error:
+ // error: too few elements in vector initialization (expected 4 elements, have 1)
+ let number_of_components = match size {
+ crate::VectorSize::Bi => "xx",
+ crate::VectorSize::Tri => "xxx",
+ crate::VectorSize::Quad => "xxxx",
+ };
+ write!(self.out, "(")?;
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, ").{number_of_components}")?
+ }
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, condition, func_ctx)?;
+ write!(self.out, " ? ")?;
+ self.write_expr(module, accept, func_ctx)?;
+ write!(self.out, " : ")?;
+ self.write_expr(module, reject, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ // Not supported yet
+ Expression::RayQueryGetIntersection { .. } => unreachable!(),
+ // Nothing to do here, since call expression already cached
+ Expression::CallResult(_)
+ | Expression::AtomicResult { .. }
+ | Expression::RayQueryProceedResult => {}
+ }
+
+ if !closing_bracket.is_empty() {
+ write!(self.out, "{closing_bracket}")?;
+ }
+ Ok(())
+ }
+
+ /// Helper method used to write constants
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ fn write_constant(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ let constant = &module.constants[handle];
+ match constant.inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } => {
+ if constant.name.is_some() {
+ write!(self.out, "{}", &self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.write_scalar_value(*value)?;
+ }
+ }
+ crate::ConstantInner::Composite { ty, ref components } => {
+ self.write_composite_constant(module, ty, components)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn write_composite_constant(
+ &mut self,
+ module: &Module,
+ ty: Handle<crate::Type>,
+ components: &[Handle<crate::Constant>],
+ ) -> BackendResult {
+ match module.types[ty].inner {
+ TypeInner::Struct { .. } | TypeInner::Array { .. } => {
+ self.write_wrapped_constructor_function_name(module, WrappedConstructor { ty })?;
+ }
+ _ => {
+ self.write_type(module, ty)?;
+ }
+ };
+ write!(self.out, "(")?;
+ for (index, constant) in components.iter().enumerate() {
+ self.write_constant(module, *constant)?;
+ // Only write a comma if isn't the last element
+ if index != components.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ write!(self.out, ")")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write [`ScalarValue`](crate::ScalarValue)
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_scalar_value(&mut self, value: crate::ScalarValue) -> BackendResult {
+ use crate::ScalarValue as Sv;
+
+ match value {
+ Sv::Sint(value) => write!(self.out, "{value}")?,
+ Sv::Uint(value) => write!(self.out, "{value}u")?,
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero
+ Sv::Float(value) => write!(self.out, "{value:?}")?,
+ Sv::Bool(value) => write!(self.out, "{value}")?,
+ }
+
+ Ok(())
+ }
+
+ fn write_named_expr(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Expression>,
+ name: String,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ match ctx.info[handle].ty {
+ proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
+ TypeInner::Struct { .. } => {
+ let ty_name = &self.names[&NameKey::Type(ty_handle)];
+ write!(self.out, "{ty_name}")?;
+ }
+ _ => {
+ self.write_type(module, ty_handle)?;
+ }
+ },
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(module, inner)?;
+ }
+ }
+
+ let base_ty_res = &ctx.info[handle].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+
+ write!(self.out, " {name}")?;
+ // If rhs is a array type, we should write array size
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_expr(module, handle, ctx)?;
+ writeln!(self.out, ";")?;
+ self.named_expressions.insert(handle, name);
+
+ Ok(())
+ }
+
+ /// Helper function that write default zero initialization
+ fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ write!(self.out, "(")?;
+ self.write_type(module, ty)?;
+ if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, ")0")?;
+ Ok(())
+ }
+
+ fn write_barrier(&mut self, barrier: crate::Barrier, level: back::Level) -> BackendResult {
+ if barrier.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
+ }
+ if barrier.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
+ }
+ Ok(())
+ }
+}
+
+pub(super) struct MatrixType {
+ pub(super) columns: crate::VectorSize,
+ pub(super) rows: crate::VectorSize,
+ pub(super) width: crate::Bytes,
+}
+
+pub(super) fn get_inner_matrix_data(
+ module: &Module,
+ handle: Handle<crate::Type>,
+) -> Option<MatrixType> {
+ match module.types[handle].inner {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => Some(MatrixType {
+ columns,
+ rows,
+ width,
+ }),
+ TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
+ _ => None,
+ }
+}
+
+/// Returns the matrix data if the access chain starting at `base`:
+/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
+/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
+/// - ends at an expression with resolved type of [`TypeInner::Struct`]
+pub(super) fn get_inner_matrix_of_struct_array_member(
+ module: &Module,
+ base: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ direct: bool,
+) -> Option<MatrixType> {
+ let mut mat_data = None;
+ let mut array_base = None;
+
+ let mut current_base = base;
+ loop {
+ let mut resolved = func_ctx.info[current_base].ty.inner_with(&module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ };
+
+ match *resolved {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ mat_data = Some(MatrixType {
+ columns,
+ rows,
+ width,
+ })
+ }
+ TypeInner::Array { base, .. } => {
+ array_base = Some(base);
+ }
+ TypeInner::Struct { .. } => {
+ if let Some(array_base) = array_base {
+ if direct {
+ return mat_data;
+ } else {
+ return get_inner_matrix_data(module, array_base);
+ }
+ }
+
+ break;
+ }
+ _ => break,
+ }
+
+ current_base = match func_ctx.expressions[current_base] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ _ => break,
+ };
+ }
+ None
+}
+
+/// Returns the matrix data if the access chain starting at `base`:
+/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
+/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
+/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
+fn get_inner_matrix_of_global_uniform(
+ module: &Module,
+ base: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+) -> Option<MatrixType> {
+ let mut mat_data = None;
+ let mut array_base = None;
+
+ let mut current_base = base;
+ loop {
+ let mut resolved = func_ctx.info[current_base].ty.inner_with(&module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ };
+
+ match *resolved {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ mat_data = Some(MatrixType {
+ columns,
+ rows,
+ width,
+ })
+ }
+ TypeInner::Array { base, .. } => {
+ array_base = Some(base);
+ }
+ _ => break,
+ }
+
+ current_base = match func_ctx.expressions[current_base] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ crate::Expression::GlobalVariable(handle)
+ if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
+ {
+ return mat_data.or_else(|| {
+ array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
+ })
+ }
+ _ => break,
+ };
+ }
+ None
+}
diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs
new file mode 100644
index 0000000000..8467ee787b
--- /dev/null
+++ b/third_party/rust/naga/src/back/mod.rs
@@ -0,0 +1,250 @@
+/*!
+Backend functions that export shader [`Module`](super::Module)s into binary and text formats.
+*/
+#![allow(dead_code)] // can be dead if none of the enabled backends need it
+
+#[cfg(feature = "dot-out")]
+pub mod dot;
+#[cfg(feature = "glsl-out")]
+pub mod glsl;
+#[cfg(feature = "hlsl-out")]
+pub mod hlsl;
+#[cfg(feature = "msl-out")]
+pub mod msl;
+#[cfg(feature = "spv-out")]
+pub mod spv;
+#[cfg(feature = "wgsl-out")]
+pub mod wgsl;
+
+const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
+const INDENT: &str = " ";
+const BAKE_PREFIX: &str = "_e";
+
+type NeedBakeExpressions = crate::FastHashSet<crate::Handle<crate::Expression>>;
+
+#[derive(Clone, Copy)]
+struct Level(usize);
+
+impl Level {
+ const fn next(&self) -> Self {
+ Level(self.0 + 1)
+ }
+}
+
+impl std::fmt::Display for Level {
+ fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+ (0..self.0).try_for_each(|_| formatter.write_str(INDENT))
+ }
+}
+
+/// Stores the current function type (either a regular function or an entry point)
+///
+/// Also stores data needed to identify it (handle for a regular function or index for an entry point)
+enum FunctionType {
+ /// A regular function and it's handle
+ Function(crate::Handle<crate::Function>),
+ /// A entry point and it's index
+ EntryPoint(crate::proc::EntryPointIndex),
+}
+
+impl FunctionType {
+ fn is_compute_entry_point(&self, module: &crate::Module) -> bool {
+ match *self {
+ FunctionType::EntryPoint(index) => {
+ module.entry_points[index as usize].stage == crate::ShaderStage::Compute
+ }
+ FunctionType::Function(_) => false,
+ }
+ }
+}
+
+/// Helper structure that stores data needed when writing the function
+struct FunctionCtx<'a> {
+ /// The current function being written
+ ty: FunctionType,
+ /// Analysis about the function
+ info: &'a crate::valid::FunctionInfo,
+ /// The expression arena of the current function being written
+ expressions: &'a crate::Arena<crate::Expression>,
+ /// Map of expressions that have associated variable names
+ named_expressions: &'a crate::NamedExpressions,
+}
+
+impl FunctionCtx<'_> {
+ /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a local in the current function
+ const fn name_key(&self, local: crate::Handle<crate::LocalVariable>) -> crate::proc::NameKey {
+ match self.ty {
+ FunctionType::Function(handle) => crate::proc::NameKey::FunctionLocal(handle, local),
+ FunctionType::EntryPoint(idx) => crate::proc::NameKey::EntryPointLocal(idx, local),
+ }
+ }
+
+ /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a function argument.
+ ///
+ /// # Panics
+ /// - If the function arguments are less or equal to `arg`
+ const fn argument_key(&self, arg: u32) -> crate::proc::NameKey {
+ match self.ty {
+ FunctionType::Function(handle) => crate::proc::NameKey::FunctionArgument(handle, arg),
+ FunctionType::EntryPoint(ep_index) => {
+ crate::proc::NameKey::EntryPointArgument(ep_index, arg)
+ }
+ }
+ }
+
+ // Returns true if the given expression points to a fixed-function pipeline input.
+ fn is_fixed_function_input(
+ &self,
+ mut expression: crate::Handle<crate::Expression>,
+ module: &crate::Module,
+ ) -> Option<crate::BuiltIn> {
+ let ep_function = match self.ty {
+ FunctionType::Function(_) => return None,
+ FunctionType::EntryPoint(ep_index) => &module.entry_points[ep_index as usize].function,
+ };
+ let mut built_in = None;
+ loop {
+ match self.expressions[expression] {
+ crate::Expression::FunctionArgument(arg_index) => {
+ return match ep_function.arguments[arg_index as usize].binding {
+ Some(crate::Binding::BuiltIn(bi)) => Some(bi),
+ _ => built_in,
+ };
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ match *self.info[base].ty.inner_with(&module.types) {
+ crate::TypeInner::Struct { ref members, .. } => {
+ if let Some(crate::Binding::BuiltIn(bi)) =
+ members[index as usize].binding
+ {
+ built_in = Some(bi);
+ }
+ }
+ _ => return None,
+ }
+ expression = base;
+ }
+ _ => return None,
+ }
+ }
+ }
+}
+
+impl crate::Expression {
+ /// Returns the ref count, upon reaching which this expression
+ /// should be considered for baking.
+ ///
+ /// Note: we have to cache any expressions that depend on the control flow,
+ /// or otherwise they may be moved into a non-uniform control flow, accidentally.
+ /// See the [module-level documentation][emit] for details.
+ ///
+ /// [emit]: index.html#expression-evaluation-time
+ const fn bake_ref_count(&self) -> usize {
+ match *self {
+ // accesses are never cached, only loads are
+ crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => usize::MAX,
+ // sampling may use the control flow, and image ops look better by themselves
+ crate::Expression::ImageSample { .. } | crate::Expression::ImageLoad { .. } => 1,
+ // derivatives use the control flow
+ crate::Expression::Derivative { .. } => 1,
+ // TODO: We need a better fix for named `Load` expressions
+ // More info - https://github.com/gfx-rs/naga/pull/914
+ // And https://github.com/gfx-rs/naga/issues/910
+ crate::Expression::Load { .. } => 1,
+ // cache expressions that are referenced multiple times
+ _ => 2,
+ }
+ }
+}
+
+/// Helper function that returns the string corresponding to the [`BinaryOperator`](crate::BinaryOperator)
+/// # Notes
+/// Used by `glsl-out`, `msl-out`, `wgsl-out`, `hlsl-out`.
+const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
+ use crate::BinaryOperator as Bo;
+ match op {
+ Bo::Add => "+",
+ Bo::Subtract => "-",
+ Bo::Multiply => "*",
+ Bo::Divide => "/",
+ Bo::Modulo => "%",
+ Bo::Equal => "==",
+ Bo::NotEqual => "!=",
+ Bo::Less => "<",
+ Bo::LessEqual => "<=",
+ Bo::Greater => ">",
+ Bo::GreaterEqual => ">=",
+ Bo::And => "&",
+ Bo::ExclusiveOr => "^",
+ Bo::InclusiveOr => "|",
+ Bo::LogicalAnd => "&&",
+ Bo::LogicalOr => "||",
+ Bo::ShiftLeft => "<<",
+ Bo::ShiftRight => ">>",
+ }
+}
+
+/// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize)
+/// # Notes
+/// Used by `msl-out`, `wgsl-out`, `hlsl-out`.
+const fn vector_size_str(size: crate::VectorSize) -> &'static str {
+ match size {
+ crate::VectorSize::Bi => "2",
+ crate::VectorSize::Tri => "3",
+ crate::VectorSize::Quad => "4",
+ }
+}
+
+impl crate::TypeInner {
+ const fn is_handle(&self) -> bool {
+ match *self {
+ crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => true,
+ _ => false,
+ }
+ }
+}
+
+impl crate::Statement {
+ /// Returns true if the statement directly terminates the current block.
+ ///
+ /// Used to decide whether case blocks require a explicit `break`.
+ pub const fn is_terminator(&self) -> bool {
+ match *self {
+ crate::Statement::Break
+ | crate::Statement::Continue
+ | crate::Statement::Return { .. }
+ | crate::Statement::Kill => true,
+ _ => false,
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Ray flags, for a [`RayDesc`]'s `flags` field.
+ ///
+ /// Note that these exactly correspond to the SPIR-V "Ray Flags" mask, and
+ /// the SPIR-V backend passes them directly through to the
+ /// `OpRayQueryInitializeKHR` instruction. (We have to choose something, so
+ /// we might as well make one back end's life easier.)
+ ///
+ /// [`RayDesc`]: crate::Module::generate_ray_desc_type
+ #[derive(Default)]
+ pub struct RayFlag: u32 {
+ const OPAQUE = 0x01;
+ const NO_OPAQUE = 0x02;
+ const TERMINATE_ON_FIRST_HIT = 0x04;
+ const SKIP_CLOSEST_HIT_SHADER = 0x08;
+ const CULL_BACK_FACING = 0x10;
+ const CULL_FRONT_FACING = 0x20;
+ const CULL_OPAQUE = 0x40;
+ const CULL_NO_OPAQUE = 0x80;
+ const SKIP_TRIANGLES = 0x100;
+ const SKIP_AABBS = 0x200;
+ }
+}
+
+#[repr(u32)]
+enum RayIntersectionType {
+ Triangle = 1,
+ BoundingBox = 4,
+}
diff --git a/third_party/rust/naga/src/back/msl/keywords.rs b/third_party/rust/naga/src/back/msl/keywords.rs
new file mode 100644
index 0000000000..a3a9c52dcc
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/keywords.rs
@@ -0,0 +1,217 @@
+//TODO: find a complete list
+pub const RESERVED: &[&str] = &[
+ // control flow
+ "break",
+ "if",
+ "else",
+ "continue",
+ "goto",
+ "do",
+ "while",
+ "for",
+ "switch",
+ "case",
+ // types and values
+ "void",
+ "unsigned",
+ "signed",
+ "bool",
+ "char",
+ "int",
+ "uint",
+ "long",
+ "float",
+ "double",
+ "char8_t",
+ "wchar_t",
+ "true",
+ "false",
+ "nullptr",
+ "union",
+ "class",
+ "struct",
+ "enum",
+ // other
+ "main",
+ "using",
+ "decltype",
+ "sizeof",
+ "typeof",
+ "typedef",
+ "explicit",
+ "export",
+ "friend",
+ "namespace",
+ "operator",
+ "public",
+ "template",
+ "typename",
+ "typeid",
+ "co_await",
+ "co_return",
+ "co_yield",
+ "module",
+ "import",
+ "ray_data",
+ "vec_step",
+ "visible",
+ "as_type",
+ "this",
+ // qualifiers
+ "mutable",
+ "static",
+ "volatile",
+ "restrict",
+ "const",
+ "non-temporal",
+ "dereferenceable",
+ "invariant",
+ // exceptions
+ "throw",
+ "try",
+ "catch",
+ // operators
+ "const_cast",
+ "dynamic_cast",
+ "reinterpret_cast",
+ "static_cast",
+ "new",
+ "delete",
+ "and",
+ "and_eq",
+ "bitand",
+ "bitor",
+ "compl",
+ "not",
+ "not_eq",
+ "or",
+ "or_eq",
+ "xor",
+ "xor_eq",
+ "compl",
+ // Metal-specific
+ "constant",
+ "device",
+ "threadgroup",
+ "threadgroup_imageblock",
+ "kernel",
+ "compute",
+ "vertex",
+ "fragment",
+ "read_only",
+ "write_only",
+ "read_write",
+ "auto",
+ // Metal reserved types
+ "llong",
+ "ullong",
+ "quad",
+ "complex",
+ "imaginary",
+ // Metal constants
+ "CHAR_BIT",
+ "SCHAR_MAX",
+ "SCHAR_MIN",
+ "UCHAR_MAX",
+ "CHAR_MAX",
+ "CHAR_MIN",
+ "USHRT_MAX",
+ "SHRT_MAX",
+ "SHRT_MIN",
+ "UINT_MAX",
+ "INT_MAX",
+ "INT_MIN",
+ "ULONG_MAX",
+ "LONG_MAX",
+ "LONG_MIN",
+ "ULLONG_MAX",
+ "LLONG_MAX",
+ "LLONG_MIN",
+ "FLT_DIG",
+ "FLT_MANT_DIG",
+ "FLT_MAX_10_EXP",
+ "FLT_MAX_EXP",
+ "FLT_MIN_10_EXP",
+ "FLT_MIN_EXP",
+ "FLT_RADIX",
+ "FLT_MAX",
+ "FLT_MIN",
+ "FLT_EPSILON",
+ "FLT_DECIMAL_DIG",
+ "FP_ILOGB0",
+ "FP_ILOGB0",
+ "FP_ILOGBNAN",
+ "FP_ILOGBNAN",
+ "MAXFLOAT",
+ "HUGE_VALF",
+ "INFINITY",
+ "NAN",
+ "M_E_F",
+ "M_LOG2E_F",
+ "M_LOG10E_F",
+ "M_LN2_F",
+ "M_LN10_F",
+ "M_PI_F",
+ "M_PI_2_F",
+ "M_PI_4_F",
+ "M_1_PI_F",
+ "M_2_PI_F",
+ "M_2_SQRTPI_F",
+ "M_SQRT2_F",
+ "M_SQRT1_2_F",
+ "HALF_DIG",
+ "HALF_MANT_DIG",
+ "HALF_MAX_10_EXP",
+ "HALF_MAX_EXP",
+ "HALF_MIN_10_EXP",
+ "HALF_MIN_EXP",
+ "HALF_RADIX",
+ "HALF_MAX",
+ "HALF_MIN",
+ "HALF_EPSILON",
+ "HALF_DECIMAL_DIG",
+ "MAXHALF",
+ "HUGE_VALH",
+ "M_E_H",
+ "M_LOG2E_H",
+ "M_LOG10E_H",
+ "M_LN2_H",
+ "M_LN10_H",
+ "M_PI_H",
+ "M_PI_2_H",
+ "M_PI_4_H",
+ "M_1_PI_H",
+ "M_2_PI_H",
+ "M_2_SQRTPI_H",
+ "M_SQRT2_H",
+ "M_SQRT1_2_H",
+ "DBL_DIG",
+ "DBL_MANT_DIG",
+ "DBL_MAX_10_EXP",
+ "DBL_MAX_EXP",
+ "DBL_MIN_10_EXP",
+ "DBL_MIN_EXP",
+ "DBL_RADIX",
+ "DBL_MAX",
+ "DBL_MIN",
+ "DBL_EPSILON",
+ "DBL_DECIMAL_DIG",
+ "MAXDOUBLE",
+ "HUGE_VAL",
+ "M_E",
+ "M_LOG2E",
+ "M_LOG10E",
+ "M_LN2",
+ "M_LN10",
+ "M_PI",
+ "M_PI_2",
+ "M_PI_4",
+ "M_1_PI",
+ "M_2_PI",
+ "M_2_SQRTPI",
+ "M_SQRT2",
+ "M_SQRT1_2",
+ // Naga utilities
+ "DefaultConstructible",
+ "clamped_lod_e",
+];
diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs
new file mode 100644
index 0000000000..3174d4b756
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/mod.rs
@@ -0,0 +1,489 @@
+/*!
+Backend for [MSL][msl] (Metal Shading Language).
+
+## Binding model
+
+Metal's bindings are flat per resource. Since there isn't an obvious mapping
+from SPIR-V's descriptor sets, we require a separate mapping provided in the options.
+This mapping may have one or more resource end points for each descriptor set + index
+pair.
+
+## Entry points
+
+Even though MSL and our IR appear to be similar in that the entry points in both can
+accept arguments and return values, the restrictions are different.
+MSL allows the varyings to be either in separate arguments, or inside a single
+`[[stage_in]]` struct. We gather input varyings and form this artificial structure.
+We also add all the (non-Private) globals into the arguments.
+
+At the beginning of the entry point, we assign the local constants and re-compose
+the arguments as they are declared on IR side, so that the rest of the logic can
+pretend that MSL doesn't have all the restrictions it has.
+
+For the result type, if it's a structure, we re-compose it with a temporary value
+holding the result.
+
+[msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+*/
+
+use crate::{arena::Handle, proc::index, valid::ModuleInfo};
+use std::fmt::{Error as FmtError, Write};
+
+mod keywords;
+pub mod sampler;
+mod writer;
+
+pub use writer::Writer;
+
+pub type Slot = u8;
+pub type InlineSamplerIndex = u8;
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum BindSamplerTarget {
+ Resource(Slot),
+ Inline(InlineSamplerIndex),
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
+pub struct BindTarget {
+ pub buffer: Option<Slot>,
+ pub texture: Option<Slot>,
+ pub sampler: Option<BindSamplerTarget>,
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+ pub mutable: bool,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
+
+#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
+pub struct EntryPointResources {
+ pub resources: BindingMap,
+
+ pub push_constant_buffer: Option<Slot>,
+
+ /// The slot of a buffer that contains an array of `u32`,
+ /// one for the size of each bound buffer that contains a runtime array,
+ /// in order of [`crate::GlobalVariable`] declarations.
+ pub sizes_buffer: Option<Slot>,
+}
+
+pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>;
+
+enum ResolvedBinding {
+ BuiltIn(crate::BuiltIn),
+ Attribute(u32),
+ Color(u32),
+ User {
+ prefix: &'static str,
+ index: u32,
+ interpolation: Option<ResolvedInterpolation>,
+ },
+ Resource(BindTarget),
+}
+
+#[derive(Copy, Clone)]
+enum ResolvedInterpolation {
+ CenterPerspective,
+ CenterNoPerspective,
+ CentroidPerspective,
+ CentroidNoPerspective,
+ SamplePerspective,
+ SampleNoPerspective,
+ Flat,
+}
+
+// Note: some of these should be removed in favor of proper IR validation.
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error(transparent)]
+ Format(#[from] FmtError),
+ #[error("bind target {0:?} is empty")]
+ UnimplementedBindTarget(BindTarget),
+ #[error("composing of {0:?} is not implemented yet")]
+ UnsupportedCompose(Handle<crate::Type>),
+ #[error("operation {0:?} is not implemented yet")]
+ UnsupportedBinaryOp(crate::BinaryOperator),
+ #[error("standard function '{0}' is not implemented yet")]
+ UnsupportedCall(String),
+ #[error("feature '{0}' is not implemented yet")]
+ FeatureNotImplemented(String),
+ #[error("module is not valid")]
+ Validation,
+ #[error("BuiltIn {0:?} is not supported")]
+ UnsupportedBuiltIn(crate::BuiltIn),
+ #[error("capability {0:?} is not supported")]
+ CapabilityNotSupported(crate::valid::Capabilities),
+ #[error("address space {0:?} is not supported for target MSL version")]
+ UnsupportedAddressSpace(crate::AddressSpace),
+ #[error("attribute '{0}' is not supported for target MSL version")]
+ UnsupportedAttribute(String),
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum EntryPointError {
+ #[error("global '{0}' doesn't have a binding")]
+ MissingBinding(String),
+ #[error("mapping of {0:?} is missing")]
+ MissingBindTarget(crate::ResourceBinding),
+ #[error("mapping for push constants is missing")]
+ MissingPushConstants,
+ #[error("mapping for sizes buffer is missing")]
+ MissingSizesBuffer,
+}
+
+/// Points in the MSL code where we might emit a pipeline input or output.
+///
+/// Note that, even though vertex shaders' outputs are always fragment
+/// shaders' inputs, we still need to distinguish `VertexOutput` and
+/// `FragmentInput`, since there are certain differences in the way
+/// [`ResolvedBinding`s] are represented on either side.
+///
+/// [`ResolvedBinding`s]: ResolvedBinding
+#[derive(Clone, Copy, Debug)]
+enum LocationMode {
+ /// Input to the vertex shader.
+ VertexInput,
+
+ /// Output from the vertex shader.
+ VertexOutput,
+
+ /// Input to the fragment shader.
+ FragmentInput,
+
+ /// Output from the fragment shader.
+ FragmentOutput,
+
+ /// Compute shader input or output.
+ Uniform,
+}
+
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// (Major, Minor) target version of the Metal Shading Language.
+ pub lang_version: (u8, u8),
+ /// Map of entry-point resources, indexed by entry point function name, to slots.
+ pub per_entry_point_map: EntryPointResourceMap,
+ /// Samplers to be inlined into the code.
+ pub inline_samplers: Vec<sampler::InlineSampler>,
+ /// Make it possible to link different stages via SPIRV-Cross.
+ pub spirv_cross_compatibility: bool,
+ /// Don't panic on missing bindings, instead generate invalid MSL.
+ pub fake_missing_bindings: bool,
+ /// Bounds checking policies.
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub bounds_check_policies: index::BoundsCheckPolicies,
+ /// Should workgroup variables be zero initialized (by polyfilling)?
+ pub zero_initialize_workgroup_memory: bool,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ lang_version: (2, 0),
+ per_entry_point_map: EntryPointResourceMap::default(),
+ inline_samplers: Vec::new(),
+ spirv_cross_compatibility: false,
+ fake_missing_bindings: true,
+ bounds_check_policies: index::BoundsCheckPolicies::default(),
+ zero_initialize_workgroup_memory: true,
+ }
+ }
+}
+
+/// A subset of options that are meant to be changed per pipeline.
+#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// Allow `BuiltIn::PointSize` in the vertex shader.
+ ///
+ /// Metal doesn't like this for non-point primitive topologies.
+ pub allow_point_size: bool,
+}
+
+impl Options {
+ fn resolve_local_binding(
+ &self,
+ binding: &crate::Binding,
+ mode: LocationMode,
+ ) -> Result<ResolvedBinding, Error> {
+ match *binding {
+ crate::Binding::BuiltIn(mut built_in) => {
+ if let crate::BuiltIn::Position { ref mut invariant } = built_in {
+ if *invariant && self.lang_version < (2, 1) {
+ return Err(Error::UnsupportedAttribute("invariant".to_string()));
+ }
+
+ // The 'invariant' attribute may only appear on vertex
+ // shader outputs, not fragment shader inputs.
+ if !matches!(mode, LocationMode::VertexOutput) {
+ *invariant = false;
+ }
+ }
+
+ Ok(ResolvedBinding::BuiltIn(built_in))
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ } => match mode {
+ LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)),
+ LocationMode::FragmentOutput => Ok(ResolvedBinding::Color(location)),
+ LocationMode::VertexOutput | LocationMode::FragmentInput => {
+ Ok(ResolvedBinding::User {
+ prefix: if self.spirv_cross_compatibility {
+ "locn"
+ } else {
+ "loc"
+ },
+ index: location,
+ interpolation: {
+ // unwrap: The verifier ensures that vertex shader outputs and fragment
+ // shader inputs always have fully specified interpolation, and that
+ // sampling is `None` only for Flat interpolation.
+ let interpolation = interpolation.unwrap();
+ let sampling = sampling.unwrap_or(crate::Sampling::Center);
+ Some(ResolvedInterpolation::from_binding(interpolation, sampling))
+ },
+ })
+ }
+ LocationMode::Uniform => {
+ log::error!(
+ "Unexpected Binding::Location({}) for the Uniform mode",
+ location
+ );
+ Err(Error::Validation)
+ }
+ },
+ }
+ }
+
+ fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
+ self.per_entry_point_map.get(&ep.name)
+ }
+
+ fn get_resource_binding_target(
+ &self,
+ ep: &crate::EntryPoint,
+ res_binding: &crate::ResourceBinding,
+ ) -> Option<&BindTarget> {
+ self.get_entry_point_resources(ep)
+ .and_then(|res| res.resources.get(res_binding))
+ }
+
+ fn resolve_resource_binding(
+ &self,
+ ep: &crate::EntryPoint,
+ res_binding: &crate::ResourceBinding,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let target = self.get_resource_binding_target(ep, res_binding);
+ match target {
+ Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingBindTarget(res_binding.clone())),
+ }
+ }
+
+ fn resolve_push_constants(
+ &self,
+ ep: &crate::EntryPoint,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let slot = self
+ .get_entry_point_resources(ep)
+ .and_then(|res| res.push_constant_buffer);
+ match slot {
+ Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
+ buffer: Some(slot),
+ ..Default::default()
+ })),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingPushConstants),
+ }
+ }
+
+ fn resolve_sizes_buffer(
+ &self,
+ ep: &crate::EntryPoint,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let slot = self
+ .get_entry_point_resources(ep)
+ .and_then(|res| res.sizes_buffer);
+ match slot {
+ Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
+ buffer: Some(slot),
+ ..Default::default()
+ })),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingSizesBuffer),
+ }
+ }
+}
+
+impl ResolvedBinding {
+ fn as_inline_sampler<'a>(&self, options: &'a Options) -> Option<&'a sampler::InlineSampler> {
+ match *self {
+ Self::Resource(BindTarget {
+ sampler: Some(BindSamplerTarget::Inline(index)),
+ ..
+ }) => Some(&options.inline_samplers[index as usize]),
+ _ => None,
+ }
+ }
+
+ const fn as_bind_target(&self) -> Option<&BindTarget> {
+ match *self {
+ Self::Resource(ref target) => Some(target),
+ _ => None,
+ }
+ }
+
+ fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> {
+ write!(out, " [[")?;
+ match *self {
+ Self::BuiltIn(built_in) => {
+ use crate::BuiltIn as Bi;
+ let name = match built_in {
+ Bi::Position { invariant: false } => "position",
+ Bi::Position { invariant: true } => "position, invariant",
+ // vertex
+ Bi::BaseInstance => "base_instance",
+ Bi::BaseVertex => "base_vertex",
+ Bi::ClipDistance => "clip_distance",
+ Bi::InstanceIndex => "instance_id",
+ Bi::PointSize => "point_size",
+ Bi::VertexIndex => "vertex_id",
+ // fragment
+ Bi::FragDepth => "depth(any)",
+ Bi::PointCoord => "point_coord",
+ Bi::FrontFacing => "front_facing",
+ Bi::PrimitiveIndex => "primitive_id",
+ Bi::SampleIndex => "sample_id",
+ Bi::SampleMask => "sample_mask",
+ // compute
+ Bi::GlobalInvocationId => "thread_position_in_grid",
+ Bi::LocalInvocationId => "thread_position_in_threadgroup",
+ Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
+ Bi::WorkGroupId => "threadgroup_position_in_grid",
+ Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
+ Bi::NumWorkGroups => "threadgroups_per_grid",
+ Bi::CullDistance | Bi::ViewIndex => {
+ return Err(Error::UnsupportedBuiltIn(built_in))
+ }
+ };
+ write!(out, "{name}")?;
+ }
+ Self::Attribute(index) => write!(out, "attribute({index})")?,
+ Self::Color(index) => write!(out, "color({index})")?,
+ Self::User {
+ prefix,
+ index,
+ interpolation,
+ } => {
+ write!(out, "user({prefix}{index})")?;
+ if let Some(interpolation) = interpolation {
+ write!(out, ", ")?;
+ interpolation.try_fmt(out)?;
+ }
+ }
+ Self::Resource(ref target) => {
+ if let Some(id) = target.buffer {
+ write!(out, "buffer({id})")?;
+ } else if let Some(id) = target.texture {
+ write!(out, "texture({id})")?;
+ } else if let Some(BindSamplerTarget::Resource(id)) = target.sampler {
+ write!(out, "sampler({id})")?;
+ } else {
+ return Err(Error::UnimplementedBindTarget(target.clone()));
+ }
+ }
+ }
+ write!(out, "]]")?;
+ Ok(())
+ }
+}
+
+impl ResolvedInterpolation {
+ const fn from_binding(interpolation: crate::Interpolation, sampling: crate::Sampling) -> Self {
+ use crate::Interpolation as I;
+ use crate::Sampling as S;
+
+ match (interpolation, sampling) {
+ (I::Perspective, S::Center) => Self::CenterPerspective,
+ (I::Perspective, S::Centroid) => Self::CentroidPerspective,
+ (I::Perspective, S::Sample) => Self::SamplePerspective,
+ (I::Linear, S::Center) => Self::CenterNoPerspective,
+ (I::Linear, S::Centroid) => Self::CentroidNoPerspective,
+ (I::Linear, S::Sample) => Self::SampleNoPerspective,
+ (I::Flat, _) => Self::Flat,
+ }
+ }
+
+ fn try_fmt<W: Write>(self, out: &mut W) -> Result<(), Error> {
+ let identifier = match self {
+ Self::CenterPerspective => "center_perspective",
+ Self::CenterNoPerspective => "center_no_perspective",
+ Self::CentroidPerspective => "centroid_perspective",
+ Self::CentroidNoPerspective => "centroid_no_perspective",
+ Self::SamplePerspective => "sample_perspective",
+ Self::SampleNoPerspective => "sample_no_perspective",
+ Self::Flat => "flat",
+ };
+ out.write_str(identifier)?;
+ Ok(())
+ }
+}
+
+/// Information about a translated module that is required
+/// for the use of the result.
+pub struct TranslationInfo {
+ /// Mapping of the entry point names. Each item in the array
+ /// corresponds to an entry point index.
+ ///
+ ///Note: Some entry points may fail translation because of missing bindings.
+ pub entry_point_names: Vec<Result<String, EntryPointError>>,
+}
+
+pub fn write_string(
+ module: &crate::Module,
+ info: &ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+) -> Result<(String, TranslationInfo), Error> {
+ let mut w = writer::Writer::new(String::new());
+ let info = w.write(module, info, options, pipeline_options)?;
+ Ok((w.finish(), info))
+}
+
+#[test]
+fn test_error_size() {
+ use std::mem::size_of;
+ assert_eq!(size_of::<Error>(), 32);
+}
diff --git a/third_party/rust/naga/src/back/msl/sampler.rs b/third_party/rust/naga/src/back/msl/sampler.rs
new file mode 100644
index 0000000000..3b95fa3781
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/sampler.rs
@@ -0,0 +1,175 @@
+#[cfg(feature = "deserialize")]
+use serde::Deserialize;
+#[cfg(feature = "serialize")]
+use serde::Serialize;
+use std::{num::NonZeroU32, ops::Range};
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Coord {
+ Normalized,
+ Pixel,
+}
+
+impl Default for Coord {
+ fn default() -> Self {
+ Self::Normalized
+ }
+}
+
+impl Coord {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Normalized => "normalized",
+ Self::Pixel => "pixel",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Address {
+ Repeat,
+ MirroredRepeat,
+ ClampToEdge,
+ ClampToZero,
+ ClampToBorder,
+}
+
+impl Default for Address {
+ fn default() -> Self {
+ Self::ClampToEdge
+ }
+}
+
+impl Address {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Repeat => "repeat",
+ Self::MirroredRepeat => "mirrored_repeat",
+ Self::ClampToEdge => "clamp_to_edge",
+ Self::ClampToZero => "clamp_to_zero",
+ Self::ClampToBorder => "clamp_to_border",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum BorderColor {
+ TransparentBlack,
+ OpaqueBlack,
+ OpaqueWhite,
+}
+
+impl Default for BorderColor {
+ fn default() -> Self {
+ Self::TransparentBlack
+ }
+}
+
+impl BorderColor {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::TransparentBlack => "transparent_black",
+ Self::OpaqueBlack => "opaque_black",
+ Self::OpaqueWhite => "opaque_white",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Filter {
+ Nearest,
+ Linear,
+}
+
+impl Filter {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Nearest => "nearest",
+ Self::Linear => "linear",
+ }
+ }
+}
+
+impl Default for Filter {
+ fn default() -> Self {
+ Self::Nearest
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum CompareFunc {
+ Never,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+ Equal,
+ NotEqual,
+ Always,
+}
+
+impl Default for CompareFunc {
+ fn default() -> Self {
+ Self::Never
+ }
+}
+
+impl CompareFunc {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Never => "never",
+ Self::Less => "less",
+ Self::LessEqual => "less_equal",
+ Self::Greater => "greater",
+ Self::GreaterEqual => "greater_equal",
+ Self::Equal => "equal",
+ Self::NotEqual => "not_equal",
+ Self::Always => "always",
+ }
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct InlineSampler {
+ pub coord: Coord,
+ pub address: [Address; 3],
+ pub border_color: BorderColor,
+ pub mag_filter: Filter,
+ pub min_filter: Filter,
+ pub mip_filter: Option<Filter>,
+ pub lod_clamp: Option<Range<f32>>,
+ pub max_anisotropy: Option<NonZeroU32>,
+ pub compare_func: CompareFunc,
+}
+
+impl Eq for InlineSampler {}
+
+#[allow(clippy::derive_hash_xor_eq)]
+impl std::hash::Hash for InlineSampler {
+ fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
+ self.coord.hash(hasher);
+ self.address.hash(hasher);
+ self.border_color.hash(hasher);
+ self.mag_filter.hash(hasher);
+ self.min_filter.hash(hasher);
+ self.mip_filter.hash(hasher);
+ self.lod_clamp
+ .as_ref()
+ .map(|range| (range.start.to_bits(), range.end.to_bits()))
+ .hash(hasher);
+ self.max_anisotropy.hash(hasher);
+ self.compare_func.hash(hasher);
+ }
+}
diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs
new file mode 100644
index 0000000000..88600a8f11
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/writer.rs
@@ -0,0 +1,4429 @@
+use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
+use crate::{
+ arena::Handle,
+ back,
+ proc::index,
+ proc::{self, NameKey, TypeResolution},
+ valid, FastHashMap, FastHashSet,
+};
+use bit_set::BitSet;
+use std::{
+ fmt::{Display, Error as FmtError, Formatter, Write},
+ iter,
+};
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+const NAMESPACE: &str = "metal";
+// The name of the array member of the Metal struct types we generate to
+// represent Naga `Array` types. See the comments in `Writer::write_type_defs`
+// for details.
+const WRAPPED_ARRAY_FIELD: &str = "inner";
+// This is a hack: we need to pass a pointer to an atomic,
+// but generally the backend isn't putting "&" in front of every pointer.
+// Some more general handling of pointers is needed to be implemented here.
+const ATOMIC_REFERENCE: &str = "&";
+
+const RT_NAMESPACE: &str = "metal::raytracing";
+const RAY_QUERY_TYPE: &str = "_RayQuery";
+const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
+const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
+const RAY_QUERY_FIELD_READY: &str = "ready";
+const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
+
+/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
+///
+/// The `sizes` slice determines whether this function writes a
+/// scalar, vector, or matrix type:
+///
+/// - An empty slice produces a scalar type.
+/// - A one-element slice produces a vector type.
+/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
+fn put_numeric_type(
+ out: &mut impl Write,
+ kind: crate::ScalarKind,
+ sizes: &[crate::VectorSize],
+) -> Result<(), FmtError> {
+ match (kind, sizes) {
+ (kind, &[]) => {
+ write!(out, "{}", kind.to_msl_name())
+ }
+ (kind, &[rows]) => {
+ write!(
+ out,
+ "{}::{}{}",
+ NAMESPACE,
+ kind.to_msl_name(),
+ back::vector_size_str(rows)
+ )
+ }
+ (kind, &[rows, columns]) => {
+ write!(
+ out,
+ "{}::{}{}x{}",
+ NAMESPACE,
+ kind.to_msl_name(),
+ back::vector_size_str(columns),
+ back::vector_size_str(rows)
+ )
+ }
+ (_, _) => Ok(()), // not meaningful
+ }
+}
+
+/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
+const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
+
+struct TypeContext<'a> {
+ handle: Handle<crate::Type>,
+ module: &'a crate::Module,
+ names: &'a FastHashMap<NameKey, String>,
+ access: crate::StorageAccess,
+ binding: Option<&'a super::ResolvedBinding>,
+ first_time: bool,
+}
+
+impl<'a> Display for TypeContext<'a> {
+ fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
+ let ty = &self.module.types[self.handle];
+ if ty.needs_alias() && !self.first_time {
+ let name = &self.names[&NameKey::Type(self.handle)];
+ return write!(out, "{name}");
+ }
+
+ match ty.inner {
+ crate::TypeInner::Scalar { kind, .. } => put_numeric_type(out, kind, &[]),
+ crate::TypeInner::Atomic { kind, .. } => {
+ write!(out, "{}::atomic_{}", NAMESPACE, kind.to_msl_name())
+ }
+ crate::TypeInner::Vector { size, kind, .. } => put_numeric_type(out, kind, &[size]),
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ put_numeric_type(out, crate::ScalarKind::Float, &[rows, columns])
+ }
+ crate::TypeInner::Pointer { base, space } => {
+ let sub = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+ let space_name = match space.to_msl_name() {
+ Some(name) => name,
+ None => return Ok(()),
+ };
+ write!(out, "{space_name} {sub}&")
+ }
+ crate::TypeInner::ValuePointer {
+ size,
+ kind,
+ width: _,
+ space,
+ } => {
+ match space.to_msl_name() {
+ Some(name) => write!(out, "{name} ")?,
+ None => return Ok(()),
+ };
+ match size {
+ Some(rows) => put_numeric_type(out, kind, &[rows])?,
+ None => put_numeric_type(out, kind, &[])?,
+ };
+
+ write!(out, "&")
+ }
+ crate::TypeInner::Array { base, .. } => {
+ let sub = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+ // Array lengths go at the end of the type definition,
+ // so just print the element type here.
+ write!(out, "{sub}")
+ }
+ crate::TypeInner::Struct { .. } => unreachable!(),
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let dim_str = match dim {
+ crate::ImageDimension::D1 => "1d",
+ crate::ImageDimension::D2 => "2d",
+ crate::ImageDimension::D3 => "3d",
+ crate::ImageDimension::Cube => "cube",
+ };
+ let (texture_str, msaa_str, kind, access) = match class {
+ crate::ImageClass::Sampled { kind, multi } => {
+ let (msaa_str, access) = if multi {
+ ("_ms", "read")
+ } else {
+ ("", "sample")
+ };
+ ("texture", msaa_str, kind, access)
+ }
+ crate::ImageClass::Depth { multi } => {
+ let (msaa_str, access) = if multi {
+ ("_ms", "read")
+ } else {
+ ("", "sample")
+ };
+ ("depth", msaa_str, crate::ScalarKind::Float, access)
+ }
+ crate::ImageClass::Storage { format, .. } => {
+ let access = if self
+ .access
+ .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
+ {
+ "read_write"
+ } else if self.access.contains(crate::StorageAccess::STORE) {
+ "write"
+ } else if self.access.contains(crate::StorageAccess::LOAD) {
+ "read"
+ } else {
+ log::warn!(
+ "Storage access for {:?} (name '{}'): {:?}",
+ self.handle,
+ ty.name.as_deref().unwrap_or_default(),
+ self.access
+ );
+ unreachable!("module is not valid");
+ };
+ ("texture", "", format.into(), access)
+ }
+ };
+ let base_name = kind.to_msl_name();
+ let array_str = if arrayed { "_array" } else { "" };
+ write!(
+ out,
+ "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
+ )
+ }
+ crate::TypeInner::Sampler { comparison: _ } => {
+ write!(out, "{NAMESPACE}::sampler")
+ }
+ crate::TypeInner::AccelerationStructure => {
+ write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
+ }
+ crate::TypeInner::RayQuery => {
+ write!(out, "{RAY_QUERY_TYPE}")
+ }
+ crate::TypeInner::BindingArray { base, size } => {
+ let base_tyname = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+
+ if let Some(&super::ResolvedBinding::Resource(super::BindTarget {
+ binding_array_size: Some(override_size),
+ ..
+ })) = self.binding
+ {
+ write!(out, "{NAMESPACE}::array<{base_tyname}, {override_size}>")
+ } else if let crate::ArraySize::Constant(size) = size {
+ let constant_ctx = ConstantContext {
+ handle: size,
+ arena: &self.module.constants,
+ names: self.names,
+ first_time: false,
+ };
+ write!(out, "{NAMESPACE}::array<{base_tyname}, {constant_ctx}>")
+ } else {
+ unreachable!("metal requires all arrays be constant sized");
+ }
+ }
+ }
+ }
+}
+
+struct TypedGlobalVariable<'a> {
+ module: &'a crate::Module,
+ names: &'a FastHashMap<NameKey, String>,
+ handle: Handle<crate::GlobalVariable>,
+ usage: valid::GlobalUse,
+ binding: Option<&'a super::ResolvedBinding>,
+ reference: bool,
+}
+
+impl<'a> TypedGlobalVariable<'a> {
+ fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
+ let var = &self.module.global_variables[self.handle];
+ let name = &self.names[&NameKey::GlobalVariable(self.handle)];
+
+ let storage_access = match var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => match self.module.types[var.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => access,
+ crate::TypeInner::BindingArray { base, .. } => {
+ match self.module.types[base].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => access,
+ _ => crate::StorageAccess::default(),
+ }
+ }
+ _ => crate::StorageAccess::default(),
+ },
+ };
+ let ty_name = TypeContext {
+ handle: var.ty,
+ module: self.module,
+ names: self.names,
+ access: storage_access,
+ binding: self.binding,
+ first_time: false,
+ };
+
+ let (space, access, reference) = match var.space.to_msl_name() {
+ Some(space) if self.reference => {
+ let access = if var.space.needs_access_qualifier()
+ && !self.usage.contains(valid::GlobalUse::WRITE)
+ {
+ "const"
+ } else {
+ ""
+ };
+ (space, access, "&")
+ }
+ _ => ("", "", ""),
+ };
+
+ Ok(write!(
+ out,
+ "{}{}{}{}{}{} {}",
+ space,
+ if space.is_empty() { "" } else { " " },
+ ty_name,
+ if access.is_empty() { "" } else { " " },
+ access,
+ reference,
+ name,
+ )?)
+ }
+}
+
+struct ConstantContext<'a> {
+ handle: Handle<crate::Constant>,
+ arena: &'a crate::Arena<crate::Constant>,
+ names: &'a FastHashMap<NameKey, String>,
+ first_time: bool,
+}
+
+impl<'a> Display for ConstantContext<'a> {
+ fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
+ let con = &self.arena[self.handle];
+ if con.needs_alias() && !self.first_time {
+ let name = &self.names[&NameKey::Constant(self.handle)];
+ return write!(out, "{name}");
+ }
+
+ match con.inner {
+ crate::ConstantInner::Scalar { value, width: _ } => match value {
+ crate::ScalarValue::Sint(value) => {
+ write!(out, "{value}")
+ }
+ crate::ScalarValue::Uint(value) => {
+ write!(out, "{value}u")
+ }
+ crate::ScalarValue::Float(value) => {
+ if value.is_infinite() {
+ let sign = if value.is_sign_negative() { "-" } else { "" };
+ write!(out, "{sign}INFINITY")
+ } else if value.is_nan() {
+ write!(out, "NAN")
+ } else {
+ let suffix = if value.fract() == 0.0 { ".0" } else { "" };
+
+ write!(out, "{value}{suffix}")
+ }
+ }
+ crate::ScalarValue::Bool(value) => {
+ write!(out, "{value}")
+ }
+ },
+ crate::ConstantInner::Composite { .. } => unreachable!("should be aliased"),
+ }
+ }
+}
+
+pub struct Writer<W> {
+ out: W,
+ names: FastHashMap<NameKey, String>,
+ named_expressions: crate::NamedExpressions,
+ /// Set of expressions that need to be baked to avoid unnecessary repetition in output
+ need_bake_expressions: back::NeedBakeExpressions,
+ namer: proc::Namer,
+ #[cfg(test)]
+ put_expression_stack_pointers: FastHashSet<*const ()>,
+ #[cfg(test)]
+ put_block_stack_pointers: FastHashSet<*const ()>,
+ /// Set of (struct type, struct field index) denoting which fields require
+ /// padding inserted **before** them (i.e. between fields at index - 1 and index)
+ struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
+}
+
+impl crate::ScalarKind {
+ const fn to_msl_name(self) -> &'static str {
+ match self {
+ Self::Float => "float",
+ Self::Sint => "int",
+ Self::Uint => "uint",
+ Self::Bool => "bool",
+ }
+ }
+}
+
+const fn separate(need_separator: bool) -> &'static str {
+ if need_separator {
+ ","
+ } else {
+ ""
+ }
+}
+
+fn should_pack_struct_member(
+ members: &[crate::StructMember],
+ span: u32,
+ index: usize,
+ module: &crate::Module,
+) -> Option<crate::ScalarKind> {
+ let member = &members[index];
+ //Note: this is imperfect - the same structure can be used for host-shared
+ // things, where packed float would matter.
+ if member.binding.is_some() {
+ return None;
+ }
+
+ let ty_inner = &module.types[member.ty].inner;
+ let last_offset = member.offset + ty_inner.size(&module.constants);
+ let next_offset = match members.get(index + 1) {
+ Some(next) => next.offset,
+ None => span,
+ };
+ let is_tight = next_offset == last_offset;
+
+ match *ty_inner {
+ crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ width: 4,
+ kind,
+ } if member.offset & 0xF != 0 || is_tight => Some(kind),
+ _ => None,
+ }
+}
+
+fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
+ match arena[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ if let Some(member) = members.last() {
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = arena[member.ty].inner
+ {
+ return true;
+ }
+ }
+ false
+ }
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => true,
+ _ => false,
+ }
+}
+
+impl crate::AddressSpace {
+ /// Returns true if global variables in this address space are
+ /// passed in function arguments. These arguments need to be
+ /// passed through any functions called from the entry point.
+ const fn needs_pass_through(&self) -> bool {
+ match *self {
+ Self::Uniform
+ | Self::Storage { .. }
+ | Self::Private
+ | Self::WorkGroup
+ | Self::PushConstant
+ | Self::Handle => true,
+ Self::Function => false,
+ }
+ }
+
+ /// Returns true if the address space may need a "const" qualifier.
+ const fn needs_access_qualifier(&self) -> bool {
+ match *self {
+ //Note: we are ignoring the storage access here, and instead
+ // rely on the actual use of a global by functions. This means we
+ // may end up with "const" even if the binding is read-write,
+ // and that should be OK.
+ Self::Storage { .. } => true,
+ // These should always be read-write.
+ Self::Private | Self::WorkGroup => false,
+ // These translate to `constant` address space, no need for qualifiers.
+ Self::Uniform | Self::PushConstant => false,
+ // Not applicable.
+ Self::Handle | Self::Function => false,
+ }
+ }
+
+ const fn to_msl_name(self) -> Option<&'static str> {
+ match self {
+ Self::Handle => None,
+ Self::Uniform | Self::PushConstant => Some("constant"),
+ Self::Storage { .. } => Some("device"),
+ Self::Private | Self::Function => Some("thread"),
+ Self::WorkGroup => Some("threadgroup"),
+ }
+ }
+}
+
+impl crate::Type {
+ // Returns `true` if we need to emit an alias for this type.
+ const fn needs_alias(&self) -> bool {
+ use crate::TypeInner as Ti;
+
+ match self.inner {
+ // value types are concise enough, we only alias them if they are named
+ Ti::Scalar { .. }
+ | Ti::Vector { .. }
+ | Ti::Matrix { .. }
+ | Ti::Atomic { .. }
+ | Ti::Pointer { .. }
+ | Ti::ValuePointer { .. } => self.name.is_some(),
+ // composite types are better to be aliased, regardless of the name
+ Ti::Struct { .. } | Ti::Array { .. } => true,
+ // handle types may be different, depending on the global var access, so we always inline them
+ Ti::Image { .. }
+ | Ti::Sampler { .. }
+ | Ti::AccelerationStructure
+ | Ti::RayQuery
+ | Ti::BindingArray { .. } => false,
+ }
+ }
+}
+
+impl crate::Constant {
+ // Returns `true` if we need to emit an alias for this constant.
+ const fn needs_alias(&self) -> bool {
+ match self.inner {
+ crate::ConstantInner::Scalar { .. } => self.name.is_some(),
+ crate::ConstantInner::Composite { .. } => true,
+ }
+ }
+}
+
+enum FunctionOrigin {
+ Handle(Handle<crate::Function>),
+ EntryPoint(proc::EntryPointIndex),
+}
+
+/// A level of detail argument.
+///
+/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
+/// save the clamped level of detail in a temporary variable whose name is based
+/// on the handle of the `ImageLoad` expression. But for other policies, we just
+/// use the expression directly.
+///
+/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+#[derive(Clone, Copy)]
+enum LevelOfDetail {
+ Direct(Handle<crate::Expression>),
+ Restricted(Handle<crate::Expression>),
+}
+
+/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
+///
+/// When this is used in code paths unconcerned with the `Restrict` bounds check
+/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
+/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
+/// be too awkward. If that changes, we can revisit.
+///
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+/// [`ImageStore`]: crate::Statement::ImageStore
+struct TexelAddress {
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ level: Option<LevelOfDetail>,
+}
+
+struct ExpressionContext<'a> {
+ function: &'a crate::Function,
+ origin: FunctionOrigin,
+ info: &'a valid::FunctionInfo,
+ module: &'a crate::Module,
+ pipeline_options: &'a PipelineOptions,
+ policies: index::BoundsCheckPolicies,
+
+ /// A bitset containing the `Expression` handle indexes of expressions used
+ /// as indices in `ReadZeroSkipWrite`-policy accesses. These may need to be
+ /// cached in temporary variables. See `index::find_checked_indexes` for
+ /// details.
+ guarded_indices: BitSet,
+}
+
+impl<'a> ExpressionContext<'a> {
+ fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
+ self.info[handle].ty.inner_with(&self.module.types)
+ }
+
+ /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
+ ///
+ /// Only mipmapped images need to specify a level of detail. Since 1D
+ /// textures cannot have mipmaps, MSL requires that the level argument to
+ /// texture1d queries and accesses must be a constexpr 0. It's easiest
+ /// just to omit the level entirely for 1D textures.
+ fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
+ let image_ty = self.resolve_type(image);
+ if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
+ class.is_mipmapped() && dim != crate::ImageDimension::D1
+ } else {
+ false
+ }
+ }
+
+ fn choose_bounds_check_policy(
+ &self,
+ pointer: Handle<crate::Expression>,
+ ) -> index::BoundsCheckPolicy {
+ self.policies
+ .choose_policy(pointer, &self.module.types, self.info)
+ }
+
+ fn access_needs_check(
+ &self,
+ base: Handle<crate::Expression>,
+ index: index::GuardedIndex,
+ ) -> Option<index::IndexableLength> {
+ index::access_needs_check(base, index, self.module, self.function, self.info)
+ }
+
+ fn get_packed_vec_kind(
+ &self,
+ expr_handle: Handle<crate::Expression>,
+ ) -> Option<crate::ScalarKind> {
+ match self.function.expressions[expr_handle] {
+ crate::Expression::AccessIndex { base, index } => {
+ let ty = match *self.resolve_type(base) {
+ crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
+ ref ty => ty,
+ };
+ match *ty {
+ crate::TypeInner::Struct {
+ ref members, span, ..
+ } => should_pack_struct_member(members, span, index as usize, self.module),
+ _ => None,
+ }
+ }
+ _ => None,
+ }
+ }
+}
+
+struct StatementContext<'a> {
+ expression: ExpressionContext<'a>,
+ mod_info: &'a valid::ModuleInfo,
+ result_struct: Option<&'a str>,
+}
+
+impl<W: Write> Writer<W> {
+ /// Creates a new `Writer` instance.
+ pub fn new(out: W) -> Self {
+ Writer {
+ out,
+ names: FastHashMap::default(),
+ named_expressions: Default::default(),
+ need_bake_expressions: Default::default(),
+ namer: proc::Namer::default(),
+ #[cfg(test)]
+ put_expression_stack_pointers: Default::default(),
+ #[cfg(test)]
+ put_block_stack_pointers: Default::default(),
+ struct_member_pads: FastHashSet::default(),
+ }
+ }
+
+ /// Finishes writing and returns the output.
+ // See https://github.com/rust-lang/rust-clippy/issues/4979.
+ #[allow(clippy::missing_const_for_fn)]
+ pub fn finish(self) -> W {
+ self.out
+ }
+
+ fn put_call_parameters(
+ &mut self,
+ parameters: impl Iterator<Item = Handle<crate::Expression>>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ write!(self.out, "(")?;
+ for (i, handle) in parameters.enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.put_expression(handle, context, true)?;
+ }
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_level_of_detail(
+ &mut self,
+ level: LevelOfDetail,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match level {
+ LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
+ LevelOfDetail::Restricted(load) => {
+ write!(self.out, "{}{}", CLAMPED_LOD_LOAD_PREFIX, load.index())?
+ }
+ }
+ Ok(())
+ }
+
+ fn put_image_query(
+ &mut self,
+ image: Handle<crate::Expression>,
+ query: &str,
+ level: Option<LevelOfDetail>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_{query}(")?;
+ if let Some(level) = level {
+ self.put_level_of_detail(level, context)?;
+ }
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_image_size_query(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: Option<LevelOfDetail>,
+ kind: crate::ScalarKind,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ //Note: MSL only has separate width/height/depth queries,
+ // so compose the result of them.
+ let dim = match *context.resolve_type(image) {
+ crate::TypeInner::Image { dim, .. } => dim,
+ ref other => unreachable!("Unexpected type {:?}", other),
+ };
+ let coordinate_type = kind.to_msl_name();
+ match dim {
+ crate::ImageDimension::D1 => {
+ // Since 1D textures never have mipmaps, MSL requires that the
+ // `level` argument be a constexpr 0. It's simplest for us just
+ // to pass `None` and omit the level entirely.
+ if kind == crate::ScalarKind::Uint {
+ // No need to construct a vector. No cast needed.
+ self.put_image_query(image, "width", None, context)?;
+ } else {
+ // There's no definition for `int` in the `metal` namespace.
+ write!(self.out, "int(")?;
+ self.put_image_query(image, "width", None, context)?;
+ write!(self.out, ")")?;
+ }
+ }
+ crate::ImageDimension::D2 => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "height", level, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::ImageDimension::D3 => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "height", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "depth", level, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::ImageDimension::Cube => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn put_cast_to_uint_scalar_or_vector(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // coordinates in IR are int, but Metal expects uint
+ match *context.resolve_type(expr) {
+ crate::TypeInner::Scalar { .. } => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Uint, &[])?
+ }
+ crate::TypeInner::Vector { size, .. } => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Uint, &[size])?
+ }
+ _ => return Err(Error::Validation),
+ };
+
+ write!(self.out, "(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_image_sample_level(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: crate::SampleLevel,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let has_levels = context.image_needs_lod(image);
+ match level {
+ crate::SampleLevel::Auto => {}
+ crate::SampleLevel::Zero => {
+ //TODO: do we support Zero on `Sampled` image classes?
+ }
+ _ if !has_levels => {
+ log::warn!("1D image can't be sampled with level {:?}", level);
+ }
+ crate::SampleLevel::Exact(h) => {
+ write!(self.out, ", {NAMESPACE}::level(")?;
+ self.put_expression(h, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::SampleLevel::Bias(h) => {
+ write!(self.out, ", {NAMESPACE}::bias(")?;
+ self.put_expression(h, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ write!(self.out, ", {NAMESPACE}::gradient2d(")?;
+ self.put_expression(x, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(y, context, true)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn put_image_coordinate_limits(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: Option<LevelOfDetail>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
+ write!(self.out, " - 1")?;
+ Ok(())
+ }
+
+ /// General function for writing restricted image indexes.
+ ///
+ /// This is used to produce restricted mip levels, array indices, and sample
+ /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
+ /// [`Restrict`] bounds check policy.
+ ///
+ /// This function writes an expression of the form:
+ ///
+ /// ```ignore
+ ///
+ /// metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
+ ///
+ /// ```
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`ImageStore`]: crate::Statement::ImageStore
+ /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
+ fn put_restricted_scalar_image_index(
+ &mut self,
+ image: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ limit_method: &str,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ write!(self.out, "{NAMESPACE}::min(uint(")?;
+ self.put_expression(index, context, true)?;
+ write!(self.out, "), ")?;
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".{limit_method}() - 1)")?;
+ Ok(())
+ }
+
+ fn put_restricted_texel_address(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // Write the coordinate.
+ write!(self.out, "{NAMESPACE}::min(")?;
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_coordinate_limits(image, address.level, context)?;
+ write!(self.out, ")")?;
+
+ // Write the array index, if present.
+ if let Some(array_index) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
+ }
+
+ // Write the sample index, if present.
+ if let Some(sample) = address.sample {
+ write!(self.out, ", ")?;
+ self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
+ }
+
+ // The level of detail should be clamped and cached by
+ // `put_cache_restricted_level`, so we don't need to clamp it here.
+ if let Some(level) = address.level {
+ write!(self.out, ", ")?;
+ self.put_level_of_detail(level, context)?;
+ }
+
+ Ok(())
+ }
+
+ /// Write an expression that is true if the given image access is in bounds.
+ fn put_image_access_bounds_check(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let mut conjunction = "";
+
+ // First, check the level of detail. Only if that is in bounds can we
+ // use it to find the appropriate bounds for the coordinates.
+ let level = if let Some(level) = address.level {
+ write!(self.out, "uint(")?;
+ self.put_level_of_detail(level, context)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_num_mip_levels()")?;
+ conjunction = " && ";
+ Some(level)
+ } else {
+ None
+ };
+
+ // Check sample index, if present.
+ if let Some(sample) = address.sample {
+ write!(self.out, "uint(")?;
+ self.put_expression(sample, context, true)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_num_samples()")?;
+ conjunction = " && ";
+ }
+
+ // Check array index, if present.
+ if let Some(array_index) = address.array_index {
+ write!(self.out, "{conjunction}uint(")?;
+ self.put_expression(array_index, context, true)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_array_size()")?;
+ conjunction = " && ";
+ }
+
+ // Finally, check if the coordinates are within bounds.
+ let coord_is_vector = match *context.resolve_type(address.coordinate) {
+ crate::TypeInner::Vector { .. } => true,
+ _ => false,
+ };
+ write!(self.out, "{conjunction}")?;
+ if coord_is_vector {
+ write!(self.out, "{NAMESPACE}::all(")?;
+ }
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ write!(self.out, " < ")?;
+ self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
+ if coord_is_vector {
+ write!(self.out, ")")?;
+ }
+
+ Ok(())
+ }
+
+ fn put_image_load(
+ &mut self,
+ load: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ mut address: TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.policies.image {
+ proc::BoundsCheckPolicy::Restrict => {
+ // Use the cached restricted level of detail, if any. Omit the
+ // level altogether for 1D textures.
+ if address.level.is_some() {
+ address.level = if context.image_needs_lod(image) {
+ Some(LevelOfDetail::Restricted(load))
+ } else {
+ None
+ }
+ }
+
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".read(")?;
+ self.put_restricted_texel_address(image, &address, context)?;
+ write!(self.out, ")")?;
+ }
+ proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ write!(self.out, "(")?;
+ self.put_image_access_bounds_check(image, &address, context)?;
+ write!(self.out, " ? ")?;
+ self.put_unchecked_image_load(image, &address, context)?;
+ write!(self.out, ": DefaultConstructible())")?;
+ }
+ proc::BoundsCheckPolicy::Unchecked => {
+ self.put_unchecked_image_load(image, &address, context)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_image_load(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".read(")?;
+ // coordinates in IR are int, but Metal expects uint
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ if let Some(expr) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, context, true)?;
+ }
+ if let Some(sample) = address.sample {
+ write!(self.out, ", ")?;
+ self.put_expression(sample, context, true)?;
+ }
+ if let Some(level) = address.level {
+ if context.image_needs_lod(image) {
+ write!(self.out, ", ")?;
+ self.put_level_of_detail(level, context)?;
+ }
+ }
+ write!(self.out, ")")?;
+
+ Ok(())
+ }
+
+ fn put_image_store(
+ &mut self,
+ level: back::Level,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ value: Handle<crate::Expression>,
+ context: &StatementContext,
+ ) -> BackendResult {
+ match context.expression.policies.image {
+ proc::BoundsCheckPolicy::Restrict => {
+ // We don't have a restricted level value, because we don't
+ // support writes to mipmapped textures.
+ debug_assert!(address.level.is_none());
+
+ write!(self.out, "{level}")?;
+ self.put_expression(image, &context.expression, false)?;
+ write!(self.out, ".write(")?;
+ self.put_expression(value, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ self.put_restricted_texel_address(image, address, &context.expression)?;
+ writeln!(self.out, ");")?;
+ }
+ proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ write!(self.out, "{level}if (")?;
+ self.put_image_access_bounds_check(image, address, &context.expression)?;
+ writeln!(self.out, ") {{")?;
+ self.put_unchecked_image_store(level.next(), image, address, value, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ proc::BoundsCheckPolicy::Unchecked => {
+ self.put_unchecked_image_store(level, image, address, value, context)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_image_store(
+ &mut self,
+ level: back::Level,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ value: Handle<crate::Expression>,
+ context: &StatementContext,
+ ) -> BackendResult {
+ write!(self.out, "{level}")?;
+ self.put_expression(image, &context.expression, false)?;
+ write!(self.out, ".write(")?;
+ self.put_expression(value, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ // coordinates in IR are int, but Metal expects uint
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
+ if let Some(expr) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, &context.expression, true)?;
+ }
+ writeln!(self.out, ");")?;
+
+ Ok(())
+ }
+
+ fn put_compose(
+ &mut self,
+ ty: Handle<crate::Type>,
+ components: &[Handle<crate::Expression>],
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.module.types[ty].inner {
+ crate::TypeInner::Scalar { width: 4, kind } if components.len() == 1 => {
+ write!(self.out, "{}", kind.to_msl_name())?;
+ self.put_call_parameters(components.iter().cloned(), context)?;
+ }
+ crate::TypeInner::Vector { size, kind, .. } => {
+ put_numeric_type(&mut self.out, kind, &[size])?;
+ self.put_call_parameters(components.iter().cloned(), context)?;
+ }
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Float, &[rows, columns])?;
+ self.put_call_parameters(components.iter().cloned(), context)?;
+ }
+ crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
+ write!(self.out, "{} {{", &self.names[&NameKey::Type(ty)])?;
+ for (index, &component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // insert padding initialization, if needed
+ if self.struct_member_pads.contains(&(ty, index as u32)) {
+ write!(self.out, "{{}}, ")?;
+ }
+ self.put_expression(component, context, true)?;
+ }
+ write!(self.out, "}}")?;
+ }
+ _ => return Err(Error::UnsupportedCompose(ty)),
+ }
+ Ok(())
+ }
+
+ /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
+ ///
+ /// The 'maximum valid index' is simply one less than the array's length.
+ ///
+ /// This emits an expression of the form `a / b`, so the caller must
+ /// parenthesize its output if it will be applying operators of higher
+ /// precedence.
+ ///
+ /// `handle` must be the handle of a global variable whose final member is a
+ /// dynamically sized array.
+ fn put_dynamic_array_max_index(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let global = &context.module.global_variables[handle];
+ let (offset, array_ty) = match context.module.types[global.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => match members.last() {
+ Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
+ None => return Err(Error::Validation),
+ },
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => (0, global.ty),
+ _ => return Err(Error::Validation),
+ };
+
+ let (size, stride) = match context.module.types[array_ty].inner {
+ crate::TypeInner::Array { base, stride, .. } => (
+ context.module.types[base]
+ .inner
+ .size(&context.module.constants),
+ stride,
+ ),
+ _ => return Err(Error::Validation),
+ };
+
+ // When the stride length is larger than the size, the final element's stride of
+ // bytes would have padding following the value. But the buffer size in
+ // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
+ // enough to hold the actual values' bytes.
+ //
+ // So subtract off the size to get a byte size that falls at the start or within
+ // the final element. Then divide by the stride size, to get one less than the
+ // length, and then add one. This works even if the buffer size does include the
+ // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
+ // if there are zero elements in the array, but the WebGPU `validating shader binding`
+ // rules, together with draw-time validation when `minBindingSize` is zero,
+ // prevent that.
+ write!(
+ self.out,
+ "(_buffer_sizes.size{idx} - {offset} - {size}) / {stride}",
+ idx = handle.index(),
+ offset = offset,
+ size = size,
+ stride = stride,
+ )?;
+ Ok(())
+ }
+
+ fn put_atomic_fetch(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ key: &str,
+ value: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_atomic_operation(pointer, "fetch_", key, value, context)
+ }
+
+ fn put_atomic_operation(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ key1: &str,
+ key2: &str,
+ value: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // If the pointer we're passing to the atomic operation needs to be conditional
+ // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
+ // the pointer operand should be unchecked.
+ let policy = context.choose_bounds_check_policy(pointer);
+ let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
+
+ // If requested and successfully put bounds checks, continue the ternary expression.
+ if checked {
+ write!(self.out, " ? ")?;
+ }
+
+ write!(
+ self.out,
+ "{NAMESPACE}::atomic_{key1}{key2}_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, context)?;
+ write!(self.out, ", ")?;
+ self.put_expression(value, context, true)?;
+ write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
+
+ // Finish the ternary expression.
+ if checked {
+ write!(self.out, " : DefaultConstructible()")?;
+ }
+
+ Ok(())
+ }
+
+ /// Emit code for the arithmetic expression of the dot product.
+ ///
+ fn put_dot_product(
+ &mut self,
+ arg: Handle<crate::Expression>,
+ arg1: Handle<crate::Expression>,
+ size: usize,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // Write parantheses around the dot product expression to prevent operators
+ // with different precedences from applying earlier.
+ write!(self.out, "(")?;
+
+ // Cycle trough all the components of the vector
+ for index in 0..size {
+ let component = back::COMPONENTS[index];
+ // Write the addition to the previous product
+ // This will print an extra '+' at the beginning but that is fine in msl
+ write!(self.out, " + ")?;
+ // Write the first vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.put_expression(arg, context, true)?;
+ // Access the current component on the first vector
+ write!(self.out, ".{component} * ")?;
+ // Write the second vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.put_expression(arg1, context, true)?;
+ // Access the current component on the second vector
+ write!(self.out, ".{component}")?;
+ }
+
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ /// Emit code for the expression `expr_handle`.
+ ///
+ /// The `is_scoped` argument is true if the surrounding operators have the
+ /// precedence of the comma operator, or lower. So, for example:
+ ///
+ /// - Pass `true` for `is_scoped` when writing function arguments, an
+ /// expression statement, an initializer expression, or anything already
+ /// wrapped in parenthesis.
+ ///
+ /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
+ /// almost anything else.
+ fn put_expression(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ // Add to the set in order to track the stack size.
+ #[cfg(test)]
+ #[allow(trivial_casts)]
+ self.put_expression_stack_pointers
+ .insert(&expr_handle as *const _ as *const ());
+
+ if let Some(name) = self.named_expressions.get(&expr_handle) {
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+
+ let expression = &context.function.expressions[expr_handle];
+ log::trace!("expression {:?} = {:?}", expr_handle, expression);
+ match *expression {
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => {
+ // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
+ // Since `put_bounds_checks` and `put_access_chain` handle an entire
+ // access chain at a time, recursing back through `put_expression` only
+ // for index expressions and the base object, we will never see intermediate
+ // `Access` or `AccessIndex` expressions here.
+ let policy = context.choose_bounds_check_policy(base);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(
+ expr_handle,
+ context,
+ back::Level(0),
+ if is_scoped { "" } else { "(" },
+ )?
+ {
+ write!(self.out, " ? ")?;
+ self.put_access_chain(expr_handle, policy, context)?;
+ write!(self.out, " : DefaultConstructible()")?;
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ } else {
+ self.put_access_chain(expr_handle, policy, context)?;
+ }
+ }
+ crate::Expression::Constant(handle) => {
+ let coco = ConstantContext {
+ handle,
+ arena: &context.module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, "{coco}")?;
+ }
+ crate::Expression::Splat { size, value } => {
+ let scalar_kind = match *context.resolve_type(value) {
+ crate::TypeInner::Scalar { kind, .. } => kind,
+ _ => return Err(Error::Validation),
+ };
+ put_numeric_type(&mut self.out, scalar_kind, &[size])?;
+ write!(self.out, "(")?;
+ self.put_expression(value, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.put_wrapped_expression_for_packed_vec3_access(vector, context, false)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
+ }
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ self.put_compose(ty, components, context)?;
+ }
+ crate::Expression::FunctionArgument(index) => {
+ let name_key = match context.origin {
+ FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
+ FunctionOrigin::EntryPoint(ep_index) => {
+ NameKey::EntryPointArgument(ep_index, index)
+ }
+ };
+ let name = &self.names[&name_key];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::LocalVariable(handle) => {
+ let name_key = match context.origin {
+ FunctionOrigin::Handle(fun_handle) => {
+ NameKey::FunctionLocal(fun_handle, handle)
+ }
+ FunctionOrigin::EntryPoint(ep_index) => {
+ NameKey::EntryPointLocal(ep_index, handle)
+ }
+ };
+ let name = &self.names[&name_key];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ let main_op = match gather {
+ Some(_) => "gather",
+ None => "sample",
+ };
+ let comparison_op = match depth_ref {
+ Some(_) => "_compare",
+ None => "",
+ };
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".{main_op}{comparison_op}(")?;
+ self.put_expression(sampler, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(coordinate, context, true)?;
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, context, true)?;
+ }
+ if let Some(dref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.put_expression(dref, context, true)?;
+ }
+
+ self.put_image_sample_level(image, level, context)?;
+
+ if let Some(constant) = offset {
+ let coco = ConstantContext {
+ handle: constant,
+ arena: &context.module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, ", {coco}")?;
+ }
+ match gather {
+ None | Some(crate::SwizzleComponent::X) => {}
+ Some(component) => {
+ let is_cube_map = match *context.resolve_type(image) {
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ ..
+ } => true,
+ _ => false,
+ };
+ // Offset always comes before the gather, except
+ // in cube maps where it's not applicable
+ if offset.is_none() && !is_cube_map {
+ write!(self.out, ", {NAMESPACE}::int2(0)")?;
+ }
+ let letter = back::COMPONENTS[component as usize];
+ write!(self.out, ", {NAMESPACE}::component::{letter}")?;
+ }
+ }
+ write!(self.out, ")")?;
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let address = TexelAddress {
+ coordinate,
+ array_index,
+ sample,
+ level: level.map(LevelOfDetail::Direct),
+ };
+ self.put_image_load(expr_handle, image, address, context)?;
+ }
+ //Note: for all the queries, the signed integers are expected,
+ // so a conversion is needed.
+ crate::Expression::ImageQuery { image, query } => match query {
+ crate::ImageQuery::Size { level } => {
+ self.put_image_size_query(
+ image,
+ level.map(LevelOfDetail::Direct),
+ crate::ScalarKind::Uint,
+ context,
+ )?;
+ }
+ crate::ImageQuery::NumLevels => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_num_mip_levels()")?;
+ }
+ crate::ImageQuery::NumLayers => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_array_size()")?;
+ }
+ crate::ImageQuery::NumSamples => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_num_samples()")?;
+ }
+ },
+ crate::Expression::Unary { op, expr } => {
+ use crate::{ScalarKind as Sk, UnaryOperator as Uo};
+ let op_str = match op {
+ Uo::Negate => "-",
+ Uo::Not => match context.resolve_type(expr).scalar_kind() {
+ Some(Sk::Sint) | Some(Sk::Uint) => "~",
+ Some(Sk::Bool) => "!",
+ _ => return Err(Error::Validation),
+ },
+ };
+ write!(self.out, "{op_str}(")?;
+ self.put_expression(expr, context, false)?;
+ write!(self.out, ")")?;
+ }
+ crate::Expression::Binary { op, left, right } => {
+ let op_str = crate::back::binary_operation_str(op);
+ let kind = context
+ .resolve_type(left)
+ .scalar_kind()
+ .ok_or(Error::UnsupportedBinaryOp(op))?;
+
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+
+ if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
+ write!(self.out, "{NAMESPACE}::fmod(")?;
+ self.put_expression(left, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(right, context, true)?;
+ write!(self.out, ")")?;
+ } else {
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+
+ // Cast packed vector if necessary
+ // Packed vector - matrix multiplications are not supported in MSL
+ if op == crate::BinaryOperator::Multiply
+ && matches!(
+ context.resolve_type(right),
+ &crate::TypeInner::Matrix { .. }
+ )
+ {
+ self.put_wrapped_expression_for_packed_vec3_access(left, context, false)?;
+ } else {
+ self.put_expression(left, context, false)?;
+ }
+
+ write!(self.out, " {op_str} ")?;
+
+ // See comment above
+ if op == crate::BinaryOperator::Multiply
+ && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
+ {
+ self.put_wrapped_expression_for_packed_vec3_access(right, context, false)?;
+ } else {
+ self.put_expression(right, context, false)?;
+ }
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => match *context.resolve_type(condition) {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ } => {
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+ self.put_expression(condition, context, false)?;
+ write!(self.out, " ? ")?;
+ self.put_expression(accept, context, false)?;
+ write!(self.out, " : ")?;
+ self.put_expression(reject, context, false)?;
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Bool,
+ ..
+ } => {
+ write!(self.out, "{NAMESPACE}::select(")?;
+ self.put_expression(reject, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(accept, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(condition, context, true)?;
+ write!(self.out, ")")?;
+ }
+ _ => return Err(Error::Validation),
+ },
+ crate::Expression::Derivative { axis, expr, .. } => {
+ use crate::DerivativeAxis as Axis;
+ let op = match axis {
+ Axis::X => "dfdx",
+ Axis::Y => "dfdy",
+ Axis::Width => "fwidth",
+ };
+ write!(self.out, "{NAMESPACE}::{op}")?;
+ self.put_call_parameters(iter::once(expr), context)?;
+ }
+ crate::Expression::Relational { fun, argument } => {
+ let op = match fun {
+ crate::RelationalFunction::Any => "any",
+ crate::RelationalFunction::All => "all",
+ crate::RelationalFunction::IsNan => "isnan",
+ crate::RelationalFunction::IsInf => "isinf",
+ crate::RelationalFunction::IsFinite => "isfinite",
+ crate::RelationalFunction::IsNormal => "isnormal",
+ };
+ write!(self.out, "{NAMESPACE}::{op}")?;
+ self.put_call_parameters(iter::once(argument), context)?;
+ }
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let scalar_argument = match *context.resolve_type(arg) {
+ crate::TypeInner::Scalar { .. } => true,
+ _ => false,
+ };
+
+ let fun_name = match fun {
+ // comparison
+ Mf::Abs => "abs",
+ Mf::Min => "min",
+ Mf::Max => "max",
+ Mf::Clamp => "clamp",
+ Mf::Saturate => "saturate",
+ // trigonometry
+ Mf::Cos => "cos",
+ Mf::Cosh => "cosh",
+ Mf::Sin => "sin",
+ Mf::Sinh => "sinh",
+ Mf::Tan => "tan",
+ Mf::Tanh => "tanh",
+ Mf::Acos => "acos",
+ Mf::Asin => "asin",
+ Mf::Atan => "atan",
+ Mf::Atan2 => "atan2",
+ Mf::Asinh => "asinh",
+ Mf::Acosh => "acosh",
+ Mf::Atanh => "atanh",
+ Mf::Radians => "",
+ Mf::Degrees => "",
+ // decomposition
+ Mf::Ceil => "ceil",
+ Mf::Floor => "floor",
+ Mf::Round => "rint",
+ Mf::Fract => "fract",
+ Mf::Trunc => "trunc",
+ Mf::Modf => "modf",
+ Mf::Frexp => "frexp",
+ Mf::Ldexp => "ldexp",
+ // exponent
+ Mf::Exp => "exp",
+ Mf::Exp2 => "exp2",
+ Mf::Log => "log",
+ Mf::Log2 => "log2",
+ Mf::Pow => "pow",
+ // geometry
+ Mf::Dot => match *context.resolve_type(arg) {
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Float,
+ ..
+ } => "dot",
+ crate::TypeInner::Vector { size, .. } => {
+ return self.put_dot_product(arg, arg1.unwrap(), size as usize, context)
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
+ Mf::Cross => "cross",
+ Mf::Distance => "distance",
+ Mf::Length if scalar_argument => "abs",
+ Mf::Length => "length",
+ Mf::Normalize => "normalize",
+ Mf::FaceForward => "faceforward",
+ Mf::Reflect => "reflect",
+ Mf::Refract => "refract",
+ // computational
+ Mf::Sign => "sign",
+ Mf::Fma => "fma",
+ Mf::Mix => "mix",
+ Mf::Step => "step",
+ Mf::SmoothStep => "smoothstep",
+ Mf::Sqrt => "sqrt",
+ Mf::InverseSqrt => "rsqrt",
+ Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
+ Mf::Transpose => "transpose",
+ Mf::Determinant => "determinant",
+ // bits
+ Mf::CountTrailingZeros => "ctz",
+ Mf::CountLeadingZeros => "clz",
+ Mf::CountOneBits => "popcount",
+ Mf::ReverseBits => "reverse_bits",
+ Mf::ExtractBits => "extract_bits",
+ Mf::InsertBits => "insert_bits",
+ Mf::FindLsb => "",
+ Mf::FindMsb => "",
+ // data packing
+ Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
+ Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
+ Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
+ Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
+ Mf::Pack2x16float => "",
+ // data unpacking
+ Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
+ Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
+ Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
+ Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
+ Mf::Unpack2x16float => "",
+ };
+
+ if fun == Mf::Distance && scalar_argument {
+ write!(self.out, "{NAMESPACE}::abs(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, " - ")?;
+ self.put_expression(arg1.unwrap(), context, false)?;
+ write!(self.out, ")")?;
+ } else if fun == Mf::FindLsb {
+ write!(self.out, "((({NAMESPACE}::ctz(")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ") + 1) % 33) - 1)")?;
+ } else if fun == Mf::FindMsb {
+ let inner = context.resolve_type(arg);
+
+ write!(self.out, "{NAMESPACE}::select(31 - {NAMESPACE}::clz(")?;
+
+ if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ write!(self.out, "{NAMESPACE}::select(")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ", ~")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " < 0)")?;
+ } else {
+ self.put_expression(arg, context, true)?;
+ }
+
+ write!(self.out, "), ")?;
+
+ // or metal will complain that select is ambiguous
+ match *inner {
+ crate::TypeInner::Vector { size, kind, .. } => {
+ let size = back::vector_size_str(size);
+ if let crate::ScalarKind::Sint = kind {
+ write!(self.out, "int{size}")?;
+ } else {
+ write!(self.out, "uint{size}")?;
+ }
+ }
+ crate::TypeInner::Scalar { kind, .. } => {
+ if let crate::ScalarKind::Sint = kind {
+ write!(self.out, "int")?;
+ } else {
+ write!(self.out, "uint")?;
+ }
+ }
+ _ => (),
+ }
+
+ write!(self.out, "(-1), ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " == 0 || ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " == -1)")?;
+ } else if fun == Mf::Unpack2x16float {
+ write!(self.out, "float2(as_type<half2>(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, "))")?;
+ } else if fun == Mf::Pack2x16float {
+ write!(self.out, "as_type<uint>(half2(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, "))")?;
+ } else if fun == Mf::Radians {
+ write!(self.out, "((")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, ") * 0.017453292519943295474)")?;
+ } else if fun == Mf::Degrees {
+ write!(self.out, "((")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, ") * 57.295779513082322865)")?;
+ } else {
+ write!(self.out, "{NAMESPACE}::{fun_name}")?;
+ self.put_call_parameters(
+ iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
+ context,
+ )?;
+ }
+ }
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => match *context.resolve_type(expr) {
+ crate::TypeInner::Scalar {
+ kind: src_kind,
+ width: src_width,
+ }
+ | crate::TypeInner::Vector {
+ kind: src_kind,
+ width: src_width,
+ ..
+ } => {
+ let is_bool_cast =
+ kind == crate::ScalarKind::Bool || src_kind == crate::ScalarKind::Bool;
+ let op = match convert {
+ Some(w) if w == src_width || is_bool_cast => "static_cast",
+ Some(8) if kind == crate::ScalarKind::Float => {
+ return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
+ }
+ Some(_) => return Err(Error::Validation),
+ None => "as_type",
+ };
+ write!(self.out, "{op}<")?;
+ match *context.resolve_type(expr) {
+ crate::TypeInner::Vector { size, .. } => {
+ put_numeric_type(&mut self.out, kind, &[size])?
+ }
+ _ => put_numeric_type(&mut self.out, kind, &[])?,
+ };
+ write!(self.out, ">(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ put_numeric_type(&mut self.out, kind, &[rows, columns])?;
+ write!(self.out, "(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ }
+ _ => return Err(Error::Validation),
+ },
+ // has to be a named expression
+ crate::Expression::CallResult(_)
+ | crate::Expression::AtomicResult { .. }
+ | crate::Expression::RayQueryProceedResult => {
+ unreachable!()
+ }
+ crate::Expression::ArrayLength(expr) => {
+ // Find the global to which the array belongs.
+ let global = match context.function.expressions[expr] {
+ crate::Expression::AccessIndex { base, .. } => {
+ match context.function.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => handle,
+ _ => return Err(Error::Validation),
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => handle,
+ _ => return Err(Error::Validation),
+ };
+
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+ write!(self.out, "1 + ")?;
+ self.put_dynamic_array_max_index(global, context)?;
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ crate::Expression::RayQueryGetIntersection { query, committed } => {
+ if !committed {
+ unimplemented!()
+ }
+ let ty = context.module.special_types.ray_intersection.unwrap();
+ let type_name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
+ self.put_expression(query, context, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
+ let fields = [
+ "distance",
+ "user_instance_id",
+ "instance_id",
+ "", // SBT offset
+ "geometry_id",
+ "primitive_id",
+ "triangle_barycentric_coord",
+ "triangle_front_facing",
+ "", // padding
+ "object_to_world_transform",
+ "world_to_object_transform",
+ ];
+ for field in fields {
+ write!(self.out, ", ")?;
+ if field.is_empty() {
+ write!(self.out, "{{}}")?;
+ } else {
+ self.put_expression(query, context, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
+ }
+ }
+ write!(self.out, "}}")?;
+ }
+ }
+ Ok(())
+ }
+
+ /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
+ fn put_wrapped_expression_for_packed_vec3_access(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ if let Some(scalar_kind) = context.get_packed_vec_kind(expr_handle) {
+ write!(self.out, "{}::{}3(", NAMESPACE, scalar_kind.to_msl_name())?;
+ self.put_expression(expr_handle, context, is_scoped)?;
+ write!(self.out, ")")?;
+ } else {
+ self.put_expression(expr_handle, context, is_scoped)?;
+ }
+ Ok(())
+ }
+
+ /// Write a `GuardedIndex` as a Metal expression.
+ fn put_index(
+ &mut self,
+ index: index::GuardedIndex,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ match index {
+ index::GuardedIndex::Expression(expr) => {
+ self.put_expression(expr, context, is_scoped)?
+ }
+ index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
+ }
+ Ok(())
+ }
+
+ /// Emit an index bounds check condition for `chain`, if required.
+ ///
+ /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
+ /// operating either on a pointer to a value, or on a value directly. If we cannot
+ /// statically determine that all indexing operations in `chain` are within
+ /// bounds, then write a conditional expression to check them dynamically,
+ /// and return true. All accesses in the chain are checked by the generated
+ /// expression.
+ ///
+ /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
+ ///
+ /// The text written is of the form:
+ ///
+ /// ```ignore
+ /// {level}{prefix}uint(i) < 4 && uint(j) < 10
+ /// ```
+ ///
+ /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
+ /// statements, presumably these arguments start an indented `if` statement; for
+ /// [`Load`] expressions, the caller is probably building up a ternary `?:`
+ /// expression. In either case, what is written is not a complete syntactic structure
+ /// in its own right, and the caller will have to finish it off if we return `true`.
+ ///
+ /// If no expression is written, return false.
+ ///
+ /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
+ /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
+ /// [`Store`]: crate::Statement::Store
+ /// [`Load`]: crate::Expression::Load
+ #[allow(unused_variables)]
+ fn put_bounds_checks(
+ &mut self,
+ mut chain: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ level: back::Level,
+ prefix: &'static str,
+ ) -> Result<bool, Error> {
+ let mut check_written = false;
+
+ // Iterate over the access chain, handling each expression.
+ loop {
+ // Produce a `GuardedIndex`, so we can shared code between the
+ // `Access` and `AccessIndex` cases.
+ let (base, guarded_index) = match context.function.expressions[chain] {
+ crate::Expression::Access { base, index } => {
+ (base, Some(index::GuardedIndex::Expression(index)))
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ // Don't try to check indices into structs. Validation already took
+ // care of them, and index::needs_guard doesn't handle that case.
+ let mut base_inner = context.resolve_type(base);
+ if let crate::TypeInner::Pointer { base, .. } = *base_inner {
+ base_inner = &context.module.types[base].inner;
+ }
+ match *base_inner {
+ crate::TypeInner::Struct { .. } => (base, None),
+ _ => (base, Some(index::GuardedIndex::Known(index))),
+ }
+ }
+ _ => break,
+ };
+
+ if let Some(index) = guarded_index {
+ if let Some(length) = context.access_needs_check(base, index) {
+ if check_written {
+ write!(self.out, " && ")?;
+ } else {
+ write!(self.out, "{level}{prefix}")?;
+ check_written = true;
+ }
+
+ // Check that the index falls within bounds. Do this with a single
+ // comparison, by casting the index to `uint` first, so that negative
+ // indices become large positive values.
+ write!(self.out, "uint(")?;
+ self.put_index(index, context, true)?;
+ self.out.write_str(") < ")?;
+ match length {
+ index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
+ index::IndexableLength::Dynamic => {
+ let global = context
+ .function
+ .originating_global(base)
+ .ok_or(Error::Validation)?;
+ write!(self.out, "1 + ")?;
+ self.put_dynamic_array_max_index(global, context)?
+ }
+ }
+ }
+ }
+
+ chain = base
+ }
+
+ Ok(check_written)
+ }
+
+ /// Write the access chain `chain`.
+ ///
+ /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
+ /// operating either on a pointer to a value, or on a value directly.
+ ///
+ /// Generate bounds checks code only if `policy` is [`Restrict`]. The
+ /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
+ /// that must be handled in the caller.
+ ///
+ /// Handle the entire chain, recursing back into `put_expression` only for index
+ /// expressions and the base expression that originates the pointer or composite value
+ /// being accessed. This allows `put_expression` to assume that any `Access` or
+ /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
+ /// `ReadZeroSkipWrite` checks.
+ ///
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
+ /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
+ fn put_access_chain(
+ &mut self,
+ chain: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.function.expressions[chain] {
+ crate::Expression::Access { base, index } => {
+ let mut base_ty = context.resolve_type(base);
+
+ // Look through any pointers to see what we're really indexing.
+ if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
+ base_ty = &context.module.types[base].inner;
+ }
+
+ self.put_subscripted_access_chain(
+ base,
+ base_ty,
+ index::GuardedIndex::Expression(index),
+ policy,
+ context,
+ )?;
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ let base_resolution = &context.info[base].ty;
+ let mut base_ty = base_resolution.inner_with(&context.module.types);
+ let mut base_ty_handle = base_resolution.handle();
+
+ // Look through any pointers to see what we're really indexing.
+ if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
+ base_ty = &context.module.types[base].inner;
+ base_ty_handle = Some(base);
+ }
+
+ // Handle structs and anything else that can use `.x` syntax here, so
+ // `put_subscripted_access_chain` won't have to handle the absurd case of
+ // indexing a struct with an expression.
+ match *base_ty {
+ crate::TypeInner::Struct { .. } => {
+ let base_ty = base_ty_handle.unwrap();
+ self.put_access_chain(base, policy, context)?;
+ let name = &self.names[&NameKey::StructMember(base_ty, index)];
+ write!(self.out, ".{name}")?;
+ }
+ crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
+ self.put_access_chain(base, policy, context)?;
+ // Prior to Metal v2.1 component access for packed vectors wasn't available
+ // however array indexing is
+ if context.get_packed_vec_kind(base).is_some() {
+ write!(self.out, "[{index}]")?;
+ } else {
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
+ }
+ }
+ _ => {
+ self.put_subscripted_access_chain(
+ base,
+ base_ty,
+ index::GuardedIndex::Known(index),
+ policy,
+ context,
+ )?;
+ }
+ }
+ }
+ _ => self.put_expression(chain, context, false)?,
+ }
+
+ Ok(())
+ }
+
+ /// Write a `[]`-style access of `base` by `index`.
+ ///
+ /// If `policy` is [`Restrict`], then generate code as needed to force all index
+ /// values within bounds.
+ ///
+ /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
+ /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
+ /// removed. Our callers often already have this handy.
+ ///
+ /// This only emits `[]` expressions; it doesn't handle struct member accesses or
+ /// referencing vector components by name.
+ ///
+ /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
+ /// [`Array`]: crate::TypeInner::Array
+ /// [`Vector`]: crate::TypeInner::Vector
+ /// [`Pointer`]: crate::TypeInner::Pointer
+ fn put_subscripted_access_chain(
+ &mut self,
+ base: Handle<crate::Expression>,
+ base_ty: &crate::TypeInner,
+ index: index::GuardedIndex,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let accessing_wrapped_array = match *base_ty {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(_),
+ ..
+ } => true,
+ _ => false,
+ };
+
+ self.put_access_chain(base, policy, context)?;
+ if accessing_wrapped_array {
+ write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
+ }
+ write!(self.out, "[")?;
+
+ // Decide whether this index needs to be clamped to fall within range.
+ let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
+ context.access_needs_check(base, index)
+ } else {
+ None
+ };
+ if let Some(limit) = restriction_needed {
+ write!(self.out, "{NAMESPACE}::min(unsigned(")?;
+ self.put_index(index, context, true)?;
+ write!(self.out, "), ")?;
+ match limit {
+ index::IndexableLength::Known(limit) => {
+ write!(self.out, "{}u", limit - 1)?;
+ }
+ index::IndexableLength::Dynamic => {
+ let global = context
+ .function
+ .originating_global(base)
+ .ok_or(Error::Validation)?;
+ self.put_dynamic_array_max_index(global, context)?;
+ }
+ }
+ write!(self.out, ")")?;
+ } else {
+ self.put_index(index, context, true)?;
+ }
+
+ write!(self.out, "]")?;
+
+ Ok(())
+ }
+
+ fn put_load(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ // Since access chains never cross between address spaces, we can just
+ // check the index bounds check policy once at the top.
+ let policy = context.choose_bounds_check_policy(pointer);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(
+ pointer,
+ context,
+ back::Level(0),
+ if is_scoped { "" } else { "(" },
+ )?
+ {
+ write!(self.out, " ? ")?;
+ self.put_unchecked_load(pointer, policy, context)?;
+ write!(self.out, " : DefaultConstructible()")?;
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ } else {
+ self.put_unchecked_load(pointer, policy, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_load(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let is_atomic = match *context.resolve_type(pointer) {
+ crate::TypeInner::Pointer { base, .. } => match context.module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => true,
+ _ => false,
+ },
+ _ => false,
+ };
+
+ if is_atomic {
+ write!(
+ self.out,
+ "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, context)?;
+ write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
+ } else {
+ // We don't do any dereferencing with `*` here as pointer arguments to functions
+ // are done by `&` references and not `*` pointers. These do not need to be
+ // dereferenced.
+ self.put_access_chain(pointer, policy, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_return_value(
+ &mut self,
+ level: back::Level,
+ expr_handle: Handle<crate::Expression>,
+ result_struct: Option<&str>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match result_struct {
+ Some(struct_name) => {
+ let mut has_point_size = false;
+ let result_ty = context.function.result.as_ref().unwrap().ty;
+ match context.module.types[result_ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let tmp = "_tmp";
+ write!(self.out, "{level}const auto {tmp} = ")?;
+ self.put_expression(expr_handle, context, true)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{level}return {struct_name} {{")?;
+
+ let mut is_first = true;
+
+ for (index, member) in members.iter().enumerate() {
+ if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
+ member.binding
+ {
+ has_point_size = true;
+ if !context.pipeline_options.allow_point_size {
+ continue;
+ }
+ }
+
+ let comma = if is_first { "" } else { "," };
+ is_first = false;
+ let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
+ // HACK: we are forcefully deduplicating the expression here
+ // to convert from a wrapped struct to a raw array, e.g.
+ // `float gl_ClipDistance1 [[clip_distance]] [1];`.
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(const_handle),
+ ..
+ } = context.module.types[member.ty].inner
+ {
+ let size = context.module.constants[const_handle]
+ .to_array_length()
+ .unwrap();
+ write!(self.out, "{comma} {{")?;
+ for j in 0..size {
+ if j != 0 {
+ write!(self.out, ",")?;
+ }
+ write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
+ }
+ write!(self.out, "}}")?;
+ } else {
+ write!(self.out, "{comma} {tmp}.{name}")?;
+ }
+ }
+ }
+ _ => {
+ write!(self.out, "{level}return {struct_name} {{ ")?;
+ self.put_expression(expr_handle, context, true)?;
+ }
+ }
+
+ if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
+ let stage = context.module.entry_points[ep_index as usize].stage;
+ if context.pipeline_options.allow_point_size
+ && stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // point size was injected and comes last
+ write!(self.out, ", 1.0")?;
+ }
+ }
+ write!(self.out, " }}")?;
+ }
+ None => {
+ write!(self.out, "{level}return ")?;
+ self.put_expression(expr_handle, context, true)?;
+ }
+ }
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// This function overwrites the contents of `self.need_bake_expressions`
+ fn update_expressions_to_bake(
+ &mut self,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ context: &ExpressionContext,
+ ) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+
+ for (expr_handle, expr) in func.expressions.iter() {
+ // Expressions whose reference count is above the
+ // threshold should always be stored in temporaries.
+ let expr_info = &info[expr_handle];
+ let min_ref_count = func.expressions[expr_handle].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(expr_handle);
+ } else {
+ match expr_info.ty {
+ // force ray desc to be baked: it's used multiple times internally
+ TypeResolution::Handle(h)
+ if Some(h) == context.module.special_types.ray_desc =>
+ {
+ self.need_bake_expressions.insert(expr_handle);
+ }
+ _ => {}
+ }
+ }
+
+ if let Expression::Math { fun, arg, arg1, .. } = *expr {
+ match fun {
+ crate::MathFunction::Dot => {
+ // WGSL's `dot` function works on any `vecN` type, but Metal's only
+ // works on floating-point vectors, so we emit inline code for
+ // integer vector `dot` calls. But that code uses each argument `N`
+ // times, once for each component (see `put_dot_product`), so to
+ // avoid duplicated evaluation, we must bake integer operands.
+
+ // check what kind of product this is depending
+ // on the resolve type of the Dot function itself
+ if let crate::TypeInner::Scalar { kind, .. } =
+ *context.resolve_type(expr_handle)
+ {
+ match kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ }
+ _ => {}
+ }
+ }
+ }
+ crate::MathFunction::FindMsb => {
+ self.need_bake_expressions.insert(arg);
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+
+ fn start_baking_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ name: &str,
+ ) -> BackendResult {
+ match context.info[handle].ty {
+ TypeResolution::Handle(ty_handle) => {
+ let ty_name = TypeContext {
+ handle: ty_handle,
+ module: context.module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name}")?;
+ }
+ TypeResolution::Value(crate::TypeInner::Scalar { kind, .. }) => {
+ put_numeric_type(&mut self.out, kind, &[])?;
+ }
+ TypeResolution::Value(crate::TypeInner::Vector { size, kind, .. }) => {
+ put_numeric_type(&mut self.out, kind, &[size])?;
+ }
+ TypeResolution::Value(crate::TypeInner::Matrix { columns, rows, .. }) => {
+ put_numeric_type(&mut self.out, crate::ScalarKind::Float, &[rows, columns])?;
+ }
+ TypeResolution::Value(ref other) => {
+ log::warn!("Type {:?} isn't a known local", other); //TEMP!
+ return Err(Error::FeatureNotImplemented("weird local type".to_string()));
+ }
+ }
+
+ //TODO: figure out the naming scheme that wouldn't collide with user names.
+ write!(self.out, " {name} = ")?;
+
+ Ok(())
+ }
+
+ /// Cache a clamped level of detail value, if necessary.
+ ///
+ /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
+ /// properly clamped level of detail value both in the access itself, and
+ /// for fetching the size of the requested MIP level, needed to clamp the
+ /// coordinates. To avoid recomputing this clamped level of detail, we cache
+ /// it in a temporary variable, as part of the [`Emit`] statement covering
+ /// the [`ImageLoad`] expression.
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
+ /// [`Emit`]: crate::Statement::Emit
+ fn put_cache_restricted_level(
+ &mut self,
+ load: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ mip_level: Option<Handle<crate::Expression>>,
+ indent: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ // Does this image access actually require (or even permit) a
+ // level-of-detail, and does the policy require us to restrict it?
+ let level_of_detail = match mip_level {
+ Some(level) => level,
+ None => return Ok(()),
+ };
+
+ if context.expression.policies.image != index::BoundsCheckPolicy::Restrict
+ || !context.expression.image_needs_lod(image)
+ {
+ return Ok(());
+ }
+
+ write!(
+ self.out,
+ "{}uint {}{} = ",
+ indent,
+ CLAMPED_LOD_LOAD_PREFIX,
+ load.index(),
+ )?;
+ self.put_restricted_scalar_image_index(
+ image,
+ level_of_detail,
+ "get_num_mip_levels",
+ &context.expression,
+ )?;
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ fn put_block(
+ &mut self,
+ level: back::Level,
+ statements: &[crate::Statement],
+ context: &StatementContext,
+ ) -> BackendResult {
+ // Add to the set in order to track the stack size.
+ #[cfg(test)]
+ #[allow(trivial_casts)]
+ self.put_block_stack_pointers
+ .insert(&level as *const _ as *const ());
+
+ for statement in statements {
+ log::trace!("statement[{}] {:?}", level.0, statement);
+ match *statement {
+ crate::Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ // `ImageLoad` expressions covered by the `Restrict` bounds check policy
+ // may need to cache a clamped version of their level-of-detail argument.
+ if let crate::Expression::ImageLoad {
+ image,
+ level: mip_level,
+ ..
+ } = context.expression.function.expressions[handle]
+ {
+ self.put_cache_restricted_level(
+ handle, image, mip_level, level, context,
+ )?;
+ }
+
+ let info = &context.expression.info[handle];
+ let ptr_class = info
+ .ty
+ .inner_with(&context.expression.module.types)
+ .pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ None // don't bake pointer expressions (just yet)
+ } else if let Some(name) =
+ context.expression.function.named_expressions.get(&handle)
+ {
+ // The `crate::Function::named_expressions` table holds
+ // expressions that should be saved in temporaries once they
+ // are `Emit`ted. We only add them to `self.named_expressions`
+ // when we reach the `Emit` that covers them, so that we don't
+ // try to use their names before we've actually initialized
+ // the temporary that holds them.
+ //
+ // Don't assume the names in `named_expressions` are unique,
+ // or even valid. Use the `Namer`.
+ Some(self.namer.call(name))
+ } else if info.ref_count == 0 {
+ Some(self.namer.call(""))
+ } else {
+ // If this expression is an index that we're going to first compare
+ // against a limit, and then actually use as an index, then we may
+ // want to cache it in a temporary, to avoid evaluating it twice.
+ let bake =
+ if context.expression.guarded_indices.contains(handle.index()) {
+ true
+ } else {
+ self.need_bake_expressions.contains(&handle)
+ };
+
+ if bake {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else {
+ None
+ }
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.start_baking_expression(handle, &context.expression, &name)?;
+ self.put_expression(handle, &context.expression, true)?;
+ self.named_expressions.insert(handle, name);
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+ crate::Statement::Block(ref block) => {
+ if !block.is_empty() {
+ writeln!(self.out, "{level}{{")?;
+ self.put_block(level.next(), block, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}if (")?;
+ self.put_expression(condition, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ self.put_block(level.next(), accept, context)?;
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+ self.put_block(level.next(), reject, context)?;
+ }
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ write!(self.out, "{level}switch(")?;
+ self.put_expression(selector, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ let lcase = level.next();
+ for case in cases.iter() {
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ write!(self.out, "{lcase}case {value}:")?;
+ }
+ crate::SwitchValue::U32(value) => {
+ write!(self.out, "{lcase}case {value}u:")?;
+ }
+ crate::SwitchValue::Default => {
+ write!(self.out, "{lcase}default:")?;
+ }
+ }
+
+ let write_block_braces = !(case.fall_through && case.body.is_empty());
+ if write_block_braces {
+ writeln!(self.out, " {{")?;
+ } else {
+ writeln!(self.out)?;
+ }
+
+ self.put_block(lcase.next(), &case.body, context)?;
+ if !case.fall_through
+ && case.body.last().map_or(true, |s| !s.is_terminator())
+ {
+ writeln!(self.out, "{}break;", lcase.next())?;
+ }
+
+ if write_block_braces {
+ writeln!(self.out, "{lcase}}}")?;
+ }
+ }
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{level}bool {gate_name} = true;")?;
+ writeln!(self.out, "{level}while(true) {{")?;
+ let lif = level.next();
+ let lcontinuing = lif.next();
+ writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
+ self.put_block(lcontinuing, continuing, context)?;
+ if let Some(condition) = break_if {
+ write!(self.out, "{lcontinuing}if (")?;
+ self.put_expression(condition, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", lcontinuing.next())?;
+ writeln!(self.out, "{lcontinuing}}}")?;
+ }
+ writeln!(self.out, "{lif}}}")?;
+ writeln!(self.out, "{lif}{gate_name} = false;")?;
+ } else {
+ writeln!(self.out, "{level}while(true) {{")?;
+ }
+ self.put_block(level.next(), body, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Break => {
+ writeln!(self.out, "{level}break;")?;
+ }
+ crate::Statement::Continue => {
+ writeln!(self.out, "{level}continue;")?;
+ }
+ crate::Statement::Return {
+ value: Some(expr_handle),
+ } => {
+ self.put_return_value(
+ level,
+ expr_handle,
+ context.result_struct,
+ &context.expression,
+ )?;
+ }
+ crate::Statement::Return { value: None } => {
+ writeln!(self.out, "{level}return;")?;
+ }
+ crate::Statement::Kill => {
+ writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
+ }
+ crate::Statement::Barrier(flags) => {
+ self.write_barrier(flags, level)?;
+ }
+ crate::Statement::Store { pointer, value } => {
+ self.put_store(pointer, value, level, context)?
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ let address = TexelAddress {
+ coordinate,
+ array_index,
+ sample: None,
+ level: None,
+ };
+ self.put_image_store(level, image, &address, value, context)?
+ }
+ crate::Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ self.start_baking_expression(expr, &context.expression, &name)?;
+ self.named_expressions.insert(expr, name);
+ }
+ let fun_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{fun_name}(")?;
+ // first, write down the actual arguments
+ for (i, &handle) in arguments.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.put_expression(handle, &context.expression, true)?;
+ }
+ // follow-up with any global resources used
+ let mut separate = !arguments.is_empty();
+ let fun_info = &context.mod_info[function];
+ let mut supports_array_length = false;
+ for (handle, var) in context.expression.module.global_variables.iter() {
+ if fun_info[handle].is_empty() {
+ continue;
+ }
+ if var.space.needs_pass_through() {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ if separate {
+ write!(self.out, ", ")?;
+ } else {
+ separate = true;
+ }
+ write!(self.out, "{name}")?;
+ }
+ supports_array_length |=
+ needs_array_length(var.ty, &context.expression.module.types);
+ }
+ if supports_array_length {
+ if separate {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "_buffer_sizes")?;
+ }
+
+ // done
+ writeln!(self.out, ");")?;
+ }
+ crate::Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_baking_expression(result, &context.expression, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+ match *fun {
+ crate::AtomicFunction::Add => {
+ self.put_atomic_fetch(pointer, "add", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Subtract => {
+ self.put_atomic_fetch(pointer, "sub", value, &context.expression)?;
+ }
+ crate::AtomicFunction::And => {
+ self.put_atomic_fetch(pointer, "and", value, &context.expression)?;
+ }
+ crate::AtomicFunction::InclusiveOr => {
+ self.put_atomic_fetch(pointer, "or", value, &context.expression)?;
+ }
+ crate::AtomicFunction::ExclusiveOr => {
+ self.put_atomic_fetch(pointer, "xor", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Min => {
+ self.put_atomic_fetch(pointer, "min", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Max => {
+ self.put_atomic_fetch(pointer, "max", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Exchange { compare: None } => {
+ self.put_atomic_operation(
+ pointer,
+ "exchange",
+ "",
+ value,
+ &context.expression,
+ )?;
+ }
+ crate::AtomicFunction::Exchange { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "atomic CompareExchange".to_string(),
+ ));
+ }
+ }
+ // done
+ writeln!(self.out, ";")?;
+ }
+ crate::Statement::RayQuery { query, ref fun } => {
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ //TODO: how to deal with winding?
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
+ {
+ let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
+ let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
+ writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
+ }
+ {
+ let f_opaque = back::RayFlag::OPAQUE.bits();
+ let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
+ writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
+ }
+ {
+ let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ writeln!(self.out, ".flags & {flag}) != 0);")?;
+ }
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".origin, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".dir, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".tmin, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".tmax), ")?;
+ self.put_expression(acceleration_structure, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".cull_mask);")?;
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ write!(self.out, "{level}")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.named_expressions.insert(result, name);
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
+ //TODO: actually proceed?
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
+ }
+ crate::RayQueryFunction::Terminate => {
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?;
+ }
+ }
+ }
+ }
+ }
+
+ // un-emit expressions
+ //TODO: take care of loop/continuing?
+ for statement in statements {
+ if let crate::Statement::Emit(ref range) = *statement {
+ for handle in range.clone() {
+ self.named_expressions.remove(&handle);
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn put_store(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ level: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ let policy = context.expression.choose_bounds_check_policy(pointer);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
+ {
+ writeln!(self.out, ") {{")?;
+ self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
+ writeln!(self.out, "{level}}}")?;
+ } else {
+ self.put_unchecked_store(pointer, value, policy, level, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_store(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ level: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ let pointer_inner = context.expression.resolve_type(pointer);
+ let (array_size, is_atomic) = match *pointer_inner {
+ crate::TypeInner::Pointer { base, .. } => {
+ match context.expression.module.types[base].inner {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(ch),
+ ..
+ } => (Some(ch), false),
+ crate::TypeInner::Atomic { .. } => (None, true),
+ _ => (None, false),
+ }
+ }
+ _ => (None, false),
+ };
+
+ // we can't assign fixed-size arrays
+ if let Some(const_handle) = array_size {
+ let size = context.expression.module.constants[const_handle]
+ .to_array_length()
+ .unwrap();
+ write!(self.out, "{level}for(int _i=0; _i<{size}; ++_i) ")?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, ".{WRAPPED_ARRAY_FIELD}[_i] = ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ".{WRAPPED_ARRAY_FIELD}[_i];")?;
+ } else if is_atomic {
+ write!(
+ self.out,
+ "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, ", ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
+ } else {
+ write!(self.out, "{level}")?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, " = ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ pub fn write(
+ &mut self,
+ module: &crate::Module,
+ info: &valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+ ) -> Result<TranslationInfo, Error> {
+ self.names.clear();
+ self.namer
+ .reset(module, super::keywords::RESERVED, &[], &mut self.names);
+ self.struct_member_pads.clear();
+
+ writeln!(
+ self.out,
+ "// language: metal{}.{}",
+ options.lang_version.0, options.lang_version.1
+ )?;
+ writeln!(self.out, "#include <metal_stdlib>")?;
+ writeln!(self.out, "#include <simd/simd.h>")?;
+ writeln!(self.out)?;
+ // Work around Metal bug where `uint` is not available by default
+ writeln!(self.out, "using {NAMESPACE}::uint;")?;
+
+ if module.types.iter().any(|(_, t)| match t.inner {
+ crate::TypeInner::RayQuery => true,
+ _ => false,
+ }) {
+ let tab = back::INDENT;
+ writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
+ let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
+ writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
+ writeln!(
+ self.out,
+ "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
+ )?;
+ writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
+ writeln!(self.out, "}};")?;
+ writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
+ let v_triangle = back::RayIntersectionType::Triangle as u32;
+ let v_bbox = back::RayIntersectionType::BoundingBox as u32;
+ writeln!(
+ self.out,
+ "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
+ )?;
+ writeln!(
+ self.out,
+ "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
+ )?;
+ writeln!(self.out, "}}")?;
+ }
+ if options
+ .bounds_check_policies
+ .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
+ {
+ self.put_default_constructible()?;
+ }
+ writeln!(self.out)?;
+
+ {
+ let mut indices = vec![];
+ for (handle, var) in module.global_variables.iter() {
+ if needs_array_length(var.ty, &module.types) {
+ let idx = handle.index();
+ indices.push(idx);
+ }
+ }
+
+ if !indices.is_empty() {
+ writeln!(self.out, "struct _mslBufferSizes {{")?;
+
+ for idx in indices {
+ writeln!(self.out, "{}uint size{};", back::INDENT, idx)?;
+ }
+
+ writeln!(self.out, "}};")?;
+ writeln!(self.out)?;
+ }
+ };
+
+ self.write_scalar_constants(module)?;
+ self.write_type_defs(module)?;
+ self.write_composite_constants(module)?;
+ self.write_functions(module, info, options, pipeline_options)
+ }
+
+ /// Write the definition for the `DefaultConstructible` class.
+ ///
+ /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
+ /// produce 'zero' values for any type, including structs, arrays, and so
+ /// on. We could do this by emitting default constructor applications, but
+ /// that would entail printing the name of the type, which is more trouble
+ /// than you'd think. Instead, we just construct this magic C++14 class that
+ /// can be converted to any type that can be default constructed, using
+ /// template parameter inference to detect which type is needed, so we don't
+ /// have to figure out the name.
+ ///
+ /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
+ fn put_default_constructible(&mut self) -> BackendResult {
+ let tab = back::INDENT;
+ writeln!(self.out, "struct DefaultConstructible {{")?;
+ writeln!(self.out, "{tab}template<typename T>")?;
+ writeln!(self.out, "{tab}operator T() && {{")?;
+ writeln!(self.out, "{tab}{tab}return T {{}};")?;
+ writeln!(self.out, "{tab}}}")?;
+ writeln!(self.out, "}};")?;
+ Ok(())
+ }
+
+ fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
+ for (handle, ty) in module.types.iter() {
+ if !ty.needs_alias() {
+ continue;
+ }
+ let name = &self.names[&NameKey::Type(handle)];
+ match ty.inner {
+ // Naga IR can pass around arrays by value, but Metal, following
+ // C++, performs an array-to-pointer conversion (C++ [conv.array])
+ // on expressions of array type, so assigning the array by value
+ // isn't possible. However, Metal *does* assign structs by
+ // value. So in our Metal output, we wrap all array types in
+ // synthetic struct types:
+ //
+ // struct type1 {
+ // float inner[10]
+ // };
+ //
+ // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
+ // any expression that actually wants access to the array.
+ crate::TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ let base_name = TypeContext {
+ handle: base,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let coco = ConstantContext {
+ handle: const_handle,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+
+ writeln!(self.out, "struct {name} {{")?;
+ writeln!(
+ self.out,
+ "{}{} {}[{}];",
+ back::INDENT,
+ base_name,
+ WRAPPED_ARRAY_FIELD,
+ coco
+ )?;
+ writeln!(self.out, "}};")?;
+ }
+ crate::ArraySize::Dynamic => {
+ writeln!(self.out, "typedef {base_name} {name}[1];")?;
+ }
+ }
+ }
+ crate::TypeInner::Struct {
+ ref members, span, ..
+ } => {
+ writeln!(self.out, "struct {name} {{")?;
+ let mut last_offset = 0;
+ for (index, member) in members.iter().enumerate() {
+ // quick and dirty way to figure out if we need this...
+ if member.binding.is_none() && member.offset > last_offset {
+ self.struct_member_pads.insert((handle, index as u32));
+ let pad = member.offset - last_offset;
+ writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
+ }
+ let ty_inner = &module.types[member.ty].inner;
+ last_offset = member.offset + ty_inner.size(&module.constants);
+
+ let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
+
+ // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
+ match should_pack_struct_member(members, span, index, module) {
+ Some(kind) => {
+ writeln!(
+ self.out,
+ "{}{}::packed_{}3 {};",
+ back::INDENT,
+ NAMESPACE,
+ kind.to_msl_name(),
+ member_name
+ )?;
+ }
+ None => {
+ let base_name = TypeContext {
+ handle: member.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ writeln!(
+ self.out,
+ "{}{} {};",
+ back::INDENT,
+ base_name,
+ member_name
+ )?;
+
+ // for 3-component vectors, add one component
+ if let crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ kind: _,
+ width,
+ } = *ty_inner
+ {
+ last_offset += width as u32;
+ }
+ }
+ }
+ }
+ writeln!(self.out, "}};")?;
+ }
+ _ => {
+ let ty_name = TypeContext {
+ handle,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: true,
+ };
+ writeln!(self.out, "typedef {ty_name} {name};")?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn write_scalar_constants(&mut self, module: &crate::Module) -> BackendResult {
+ for (handle, constant) in module.constants.iter() {
+ match constant.inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } if constant.name.is_some() => {
+ debug_assert!(constant.needs_alias());
+ write!(self.out, "constexpr constant ")?;
+ match *value {
+ crate::ScalarValue::Sint(_) => {
+ write!(self.out, "int")?;
+ }
+ crate::ScalarValue::Uint(_) => {
+ write!(self.out, "unsigned")?;
+ }
+ crate::ScalarValue::Float(_) => {
+ write!(self.out, "float")?;
+ }
+ crate::ScalarValue::Bool(_) => {
+ write!(self.out, "bool")?;
+ }
+ }
+ let name = &self.names[&NameKey::Constant(handle)];
+ let coco = ConstantContext {
+ handle,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: true,
+ };
+ writeln!(self.out, " {name} = {coco};")?;
+ }
+ _ => {}
+ }
+ }
+ Ok(())
+ }
+
+ fn write_composite_constants(&mut self, module: &crate::Module) -> BackendResult {
+ for (handle, constant) in module.constants.iter() {
+ match constant.inner {
+ crate::ConstantInner::Scalar { .. } => {}
+ crate::ConstantInner::Composite { ty, ref components } => {
+ debug_assert!(constant.needs_alias());
+ let name = &self.names[&NameKey::Constant(handle)];
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "constant {ty_name} {name} = {{",)?;
+ for (i, &sub_handle) in components.iter().enumerate() {
+ // insert padding initialization, if needed
+ if self.struct_member_pads.contains(&(ty, i as u32)) {
+ write!(self.out, ", {{}}")?;
+ }
+ let separator = if i != 0 { ", " } else { "" };
+ let coco = ConstantContext {
+ handle: sub_handle,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, "{separator}{coco}")?;
+ }
+ writeln!(self.out, "}};")?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn put_inline_sampler_properties(
+ &mut self,
+ level: back::Level,
+ sampler: &sm::InlineSampler,
+ ) -> BackendResult {
+ for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
+ writeln!(
+ self.out,
+ "{}{}::{}_address::{},",
+ level,
+ NAMESPACE,
+ letter,
+ address.as_str(),
+ )?;
+ }
+ writeln!(
+ self.out,
+ "{}{}::mag_filter::{},",
+ level,
+ NAMESPACE,
+ sampler.mag_filter.as_str(),
+ )?;
+ writeln!(
+ self.out,
+ "{}{}::min_filter::{},",
+ level,
+ NAMESPACE,
+ sampler.min_filter.as_str(),
+ )?;
+ if let Some(filter) = sampler.mip_filter {
+ writeln!(
+ self.out,
+ "{}{}::mip_filter::{},",
+ level,
+ NAMESPACE,
+ filter.as_str(),
+ )?;
+ }
+ // avoid setting it on platforms that don't support it
+ if sampler.border_color != sm::BorderColor::TransparentBlack {
+ writeln!(
+ self.out,
+ "{}{}::border_color::{},",
+ level,
+ NAMESPACE,
+ sampler.border_color.as_str(),
+ )?;
+ }
+ //TODO: I'm not able to feed this in a way that MSL likes:
+ //>error: use of undeclared identifier 'lod_clamp'
+ //>error: no member named 'max_anisotropy' in namespace 'metal'
+ if false {
+ if let Some(ref lod) = sampler.lod_clamp {
+ writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
+ }
+ if let Some(aniso) = sampler.max_anisotropy {
+ writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
+ }
+ }
+ if sampler.compare_func != sm::CompareFunc::Never {
+ writeln!(
+ self.out,
+ "{}{}::compare_func::{},",
+ level,
+ NAMESPACE,
+ sampler.compare_func.as_str(),
+ )?;
+ }
+ writeln!(
+ self.out,
+ "{}{}::coord::{}",
+ level,
+ NAMESPACE,
+ sampler.coord.as_str()
+ )?;
+ Ok(())
+ }
+
+ // Returns the array of mapped entry point names.
+ fn write_functions(
+ &mut self,
+ module: &crate::Module,
+ mod_info: &valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+ ) -> Result<TranslationInfo, Error> {
+ let mut pass_through_globals = Vec::new();
+ for (fun_handle, fun) in module.functions.iter() {
+ log::trace!(
+ "function {:?}, handle {:?}",
+ fun.name.as_deref().unwrap_or("(anonymous)"),
+ fun_handle
+ );
+
+ let fun_info = &mod_info[fun_handle];
+ pass_through_globals.clear();
+ let mut supports_array_length = false;
+ for (handle, var) in module.global_variables.iter() {
+ if !fun_info[handle].is_empty() {
+ if var.space.needs_pass_through() {
+ pass_through_globals.push(handle);
+ }
+ supports_array_length |= needs_array_length(var.ty, &module.types);
+ }
+ }
+
+ writeln!(self.out)?;
+ let fun_name = &self.names[&NameKey::Function(fun_handle)];
+ match fun.result {
+ Some(ref result) => {
+ let ty_name = TypeContext {
+ handle: result.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name}")?;
+ }
+ None => {
+ write!(self.out, "void")?;
+ }
+ }
+ writeln!(self.out, " {fun_name}(")?;
+
+ for (index, arg) in fun.arguments.iter().enumerate() {
+ let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
+ let param_type_name = TypeContext {
+ handle: arg.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let separator = separate(
+ !pass_through_globals.is_empty()
+ || index + 1 != fun.arguments.len()
+ || supports_array_length,
+ );
+ writeln!(
+ self.out,
+ "{}{} {}{}",
+ back::INDENT,
+ param_type_name,
+ name,
+ separator
+ )?;
+ }
+ for (index, &handle) in pass_through_globals.iter().enumerate() {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage: fun_info[handle],
+ binding: None,
+ reference: true,
+ };
+ let separator =
+ separate(index + 1 != pass_through_globals.len() || supports_array_length);
+ write!(self.out, "{}", back::INDENT)?;
+ tyvar.try_fmt(&mut self.out)?;
+ writeln!(self.out, "{separator}")?;
+ }
+
+ if supports_array_length {
+ writeln!(
+ self.out,
+ "{}constant _mslBufferSizes& _buffer_sizes",
+ back::INDENT
+ )?;
+ }
+
+ writeln!(self.out, ") {{")?;
+
+ for (local_handle, local) in fun.local_variables.iter() {
+ let ty_name = TypeContext {
+ handle: local.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)];
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, local_name)?;
+ match local.init {
+ Some(value) => {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, " = {coco}")?;
+ }
+ None => {
+ write!(self.out, " = {{}}")?;
+ }
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ let guarded_indices =
+ index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
+
+ let context = StatementContext {
+ expression: ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::Handle(fun_handle),
+ info: fun_info,
+ policies: options.bounds_check_policies,
+ guarded_indices,
+ module,
+ pipeline_options,
+ },
+ mod_info,
+ result_struct: None,
+ };
+ self.named_expressions.clear();
+ self.update_expressions_to_bake(fun, fun_info, &context.expression);
+ self.put_block(back::Level(1), &fun.body, &context)?;
+ writeln!(self.out, "}}")?;
+ }
+
+ let mut info = TranslationInfo {
+ entry_point_names: Vec::with_capacity(module.entry_points.len()),
+ };
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let fun = &ep.function;
+ let fun_info = mod_info.get_entry_point(ep_index);
+ let mut ep_error = None;
+
+ log::trace!(
+ "entry point {:?}, index {:?}",
+ fun.name.as_deref().unwrap_or("(anonymous)"),
+ ep_index
+ );
+
+ // Is any global variable used by this entry point dynamically sized?
+ let supports_array_length = module
+ .global_variables
+ .iter()
+ .filter(|&(handle, _)| !fun_info[handle].is_empty())
+ .any(|(_, var)| needs_array_length(var.ty, &module.types));
+
+ // skip this entry point if any global bindings are missing,
+ // or their types are incompatible.
+ if !options.fake_missing_bindings {
+ for (var_handle, var) in module.global_variables.iter() {
+ if fun_info[var_handle].is_empty() {
+ continue;
+ }
+ match var.space {
+ crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::Handle => {
+ let br = match var.binding {
+ Some(ref br) => br,
+ None => {
+ let var_name = var.name.clone().unwrap_or_default();
+ ep_error =
+ Some(super::EntryPointError::MissingBinding(var_name));
+ break;
+ }
+ };
+ let target = options.get_resource_binding_target(ep, br);
+ let good = match target {
+ Some(target) => {
+ let binding_ty = match module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => {
+ &module.types[base].inner
+ }
+ ref ty => ty,
+ };
+ match *binding_ty {
+ crate::TypeInner::Image { .. } => target.texture.is_some(),
+ crate::TypeInner::Sampler { .. } => {
+ target.sampler.is_some()
+ }
+ _ => target.buffer.is_some(),
+ }
+ }
+ None => false,
+ };
+ if !good {
+ ep_error =
+ Some(super::EntryPointError::MissingBindTarget(br.clone()));
+ break;
+ }
+ }
+ crate::AddressSpace::PushConstant => {
+ if let Err(e) = options.resolve_push_constants(ep) {
+ ep_error = Some(e);
+ break;
+ }
+ }
+ crate::AddressSpace::Function
+ | crate::AddressSpace::Private
+ | crate::AddressSpace::WorkGroup => {}
+ }
+ }
+ if supports_array_length {
+ if let Err(err) = options.resolve_sizes_buffer(ep) {
+ ep_error = Some(err);
+ }
+ }
+ }
+
+ if let Some(err) = ep_error {
+ info.entry_point_names.push(Err(err));
+ continue;
+ }
+ let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
+ info.entry_point_names.push(Ok(fun_name.clone()));
+
+ writeln!(self.out)?;
+
+ let (em_str, in_mode, out_mode) = match ep.stage {
+ crate::ShaderStage::Vertex => (
+ "vertex",
+ LocationMode::VertexInput,
+ LocationMode::VertexOutput,
+ ),
+ crate::ShaderStage::Fragment { .. } => (
+ "fragment",
+ LocationMode::FragmentInput,
+ LocationMode::FragmentOutput,
+ ),
+ crate::ShaderStage::Compute { .. } => {
+ ("kernel", LocationMode::Uniform, LocationMode::Uniform)
+ }
+ };
+
+ // List all the Naga `EntryPoint`'s `Function`'s arguments,
+ // flattening structs into their members. In Metal, we will pass
+ // each of these values to the entry point as a separate argument—
+ // except for the varyings, handled next.
+ let mut flattened_arguments = Vec::new();
+ for (arg_index, arg) in fun.arguments.iter().enumerate() {
+ match module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (member_index, member) in members.iter().enumerate() {
+ let member_index = member_index as u32;
+ flattened_arguments.push((
+ NameKey::StructMember(arg.ty, member_index),
+ member.ty,
+ member.binding.as_ref(),
+ ));
+ }
+ }
+ _ => flattened_arguments.push((
+ NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
+ arg.ty,
+ arg.binding.as_ref(),
+ )),
+ }
+ }
+
+ // Identify the varyings among the argument values, and emit a
+ // struct type named `<fun>Input` to hold them.
+ let stage_in_name = format!("{fun_name}Input");
+ let varyings_member_name = self.namer.call("varyings");
+ let mut has_varyings = false;
+ if !flattened_arguments.is_empty() {
+ writeln!(self.out, "struct {stage_in_name} {{")?;
+ for &(ref name_key, ty, binding) in flattened_arguments.iter() {
+ let binding = match binding {
+ Some(ref binding @ &crate::Binding::Location { .. }) => binding,
+ _ => continue,
+ };
+ has_varyings = true;
+ let name = &self.names[name_key];
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let resolved = options.resolve_local_binding(binding, in_mode)?;
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out, ";")?;
+ }
+ writeln!(self.out, "}};")?;
+ }
+
+ // Define a struct type named for the return value, if any, named
+ // `<fun>Output`.
+ let stage_out_name = format!("{fun_name}Output");
+ let result_member_name = self.namer.call("member");
+ let result_type_name = match fun.result {
+ Some(ref result) => {
+ let mut result_members = Vec::new();
+ if let crate::TypeInner::Struct { ref members, .. } =
+ module.types[result.ty].inner
+ {
+ for (member_index, member) in members.iter().enumerate() {
+ result_members.push((
+ &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
+ member.ty,
+ member.binding.as_ref(),
+ ));
+ }
+ } else {
+ result_members.push((
+ &result_member_name,
+ result.ty,
+ result.binding.as_ref(),
+ ));
+ }
+
+ writeln!(self.out, "struct {stage_out_name} {{")?;
+ let mut has_point_size = false;
+ for (name, ty, binding) in result_members {
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: true,
+ };
+ let binding = binding.ok_or(Error::Validation)?;
+
+ if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
+ has_point_size = true;
+ if !pipeline_options.allow_point_size {
+ continue;
+ }
+ }
+
+ let array_len = match module.types[ty].inner {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(handle),
+ ..
+ } => module.constants[handle].to_array_length(),
+ _ => None,
+ };
+ let resolved = options.resolve_local_binding(binding, out_mode)?;
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ if let Some(array_len) = array_len {
+ write!(self.out, " [{array_len}]")?;
+ }
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out, ";")?;
+ }
+
+ if pipeline_options.allow_point_size
+ && ep.stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // inject the point size output last
+ writeln!(
+ self.out,
+ "{}float _point_size [[point_size]];",
+ back::INDENT
+ )?;
+ }
+ writeln!(self.out, "}};")?;
+ &stage_out_name
+ }
+ None => "void",
+ };
+
+ // Write the entry point function's name, and begin its argument list.
+ writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
+ let mut is_first_argument = true;
+
+ // If we have produced a struct holding the `EntryPoint`'s
+ // `Function`'s arguments' varyings, pass that struct first.
+ if has_varyings {
+ writeln!(
+ self.out,
+ " {stage_in_name} {varyings_member_name} [[stage_in]]"
+ )?;
+ is_first_argument = false;
+ }
+
+ let mut local_invocation_id = None;
+
+ // Then pass the remaining arguments not included in the varyings
+ // struct.
+ //
+ // Since `Namer.reset` wasn't expecting struct members to be
+ // suddenly injected into the normal namespace like this,
+ // `self.names` doesn't keep them distinct from other variables.
+ // Generate fresh names for these arguments, and remember the
+ // mapping.
+ let mut flattened_member_names = FastHashMap::default();
+ for &(ref name_key, ty, binding) in flattened_arguments.iter() {
+ let binding = match binding {
+ Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
+ _ => continue,
+ };
+ let name = if let NameKey::StructMember(ty, index) = *name_key {
+ // We should always insert a fresh entry here, but use
+ // `or_insert` to get a reference to the `String` we just
+ // inserted.
+ flattened_member_names
+ .entry(NameKey::StructMember(ty, index))
+ .or_insert_with(|| self.namer.call(&self.names[name_key]))
+ } else {
+ &self.names[name_key]
+ };
+
+ if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
+ local_invocation_id = Some(name_key);
+ }
+
+ let ty_name = TypeContext {
+ handle: ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let resolved = options.resolve_local_binding(binding, in_mode)?;
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ write!(self.out, "{separator} {ty_name} {name}")?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out)?;
+ }
+
+ let need_workgroup_variables_initialization =
+ self.need_workgroup_variables_initialization(options, ep, module, fun_info);
+
+ if need_workgroup_variables_initialization && local_invocation_id.is_none() {
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ writeln!(
+ self.out,
+ "{separator} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]"
+ )?;
+ }
+
+ // Those global variables used by this entry point and its callees
+ // get passed as arguments. `Private` globals are an exception, they
+ // don't outlive this invocation, so we declare them below as locals
+ // within the entry point.
+ for (handle, var) in module.global_variables.iter() {
+ let usage = fun_info[handle];
+ if usage.is_empty() || var.space == crate::AddressSpace::Private {
+ continue;
+ }
+ // the resolves have already been checked for `!fake_missing_bindings` case
+ let resolved = match var.space {
+ crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
+ crate::AddressSpace::WorkGroup => None,
+ crate::AddressSpace::Storage { .. } if options.lang_version < (2, 0) => {
+ return Err(Error::UnsupportedAddressSpace(var.space))
+ }
+ _ => options
+ .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
+ .ok(),
+ };
+ if let Some(ref resolved) = resolved {
+ // Inline samplers are be defined in the EP body
+ if resolved.as_inline_sampler(options).is_some() {
+ continue;
+ }
+ }
+
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage,
+ binding: resolved.as_ref(),
+ reference: true,
+ };
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ write!(self.out, "{separator} ")?;
+ tyvar.try_fmt(&mut self.out)?;
+ if let Some(resolved) = resolved {
+ resolved.try_fmt(&mut self.out)?;
+ }
+ if let Some(value) = var.init {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, " = {coco}")?;
+ }
+ writeln!(self.out)?;
+ }
+
+ // If this entry uses any variable-length arrays, their sizes are
+ // passed as a final struct-typed argument.
+ if supports_array_length {
+ // this is checked earlier
+ let resolved = options.resolve_sizes_buffer(ep).unwrap();
+ let separator = if module.global_variables.is_empty() {
+ ' '
+ } else {
+ ','
+ };
+ write!(
+ self.out,
+ "{separator} constant _mslBufferSizes& _buffer_sizes",
+ )?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out)?;
+ }
+
+ // end of the entry point argument list
+ writeln!(self.out, ") {{")?;
+
+ if need_workgroup_variables_initialization {
+ self.write_workgroup_variables_initialization(
+ module,
+ mod_info,
+ fun_info,
+ local_invocation_id,
+ )?;
+ }
+
+ // Metal doesn't support private mutable variables outside of functions,
+ // so we put them here, just like the locals.
+ for (handle, var) in module.global_variables.iter() {
+ let usage = fun_info[handle];
+ if usage.is_empty() {
+ continue;
+ }
+ if var.space == crate::AddressSpace::Private {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage,
+ binding: None,
+ reference: false,
+ };
+ write!(self.out, "{}", back::INDENT)?;
+ tyvar.try_fmt(&mut self.out)?;
+ match var.init {
+ Some(value) => {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ writeln!(self.out, " = {coco};")?;
+ }
+ None => {
+ writeln!(self.out, " = {{}};")?;
+ }
+ };
+ } else if let Some(ref binding) = var.binding {
+ // write an inline sampler
+ let resolved = options.resolve_resource_binding(ep, binding).unwrap();
+ if let Some(sampler) = resolved.as_inline_sampler(options) {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ writeln!(
+ self.out,
+ "{}constexpr {}::sampler {}(",
+ back::INDENT,
+ NAMESPACE,
+ name
+ )?;
+ self.put_inline_sampler_properties(back::Level(2), sampler)?;
+ writeln!(self.out, "{});", back::INDENT)?;
+ }
+ }
+ }
+
+ // Now take the arguments that we gathered into structs, and the
+ // structs that we flattened into arguments, and emit local
+ // variables with initializers that put everything back the way the
+ // body code expects.
+ //
+ // If we had to generate fresh names for struct members passed as
+ // arguments, be sure to use those names when rebuilding the struct.
+ //
+ // "Each day, I change some zeros to ones, and some ones to zeros.
+ // The rest, I leave alone."
+ for (arg_index, arg) in fun.arguments.iter().enumerate() {
+ let arg_name =
+ &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
+ match module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let struct_name = &self.names[&NameKey::Type(arg.ty)];
+ write!(
+ self.out,
+ "{}const {} {} = {{ ",
+ back::INDENT,
+ struct_name,
+ arg_name
+ )?;
+ for (member_index, member) in members.iter().enumerate() {
+ let key = NameKey::StructMember(arg.ty, member_index as u32);
+ // If it's not in the varying struct, then we should
+ // have passed it as its own argument and assigned
+ // it a new name.
+ let name = match member.binding {
+ Some(crate::Binding::BuiltIn { .. }) => {
+ &flattened_member_names[&key]
+ }
+ _ => &self.names[&key],
+ };
+ if member_index != 0 {
+ write!(self.out, ", ")?;
+ }
+ if let Some(crate::Binding::Location { .. }) = member.binding {
+ write!(self.out, "{varyings_member_name}.")?;
+ }
+ write!(self.out, "{name}")?;
+ }
+ writeln!(self.out, " }};")?;
+ }
+ _ => {
+ if let Some(crate::Binding::Location { .. }) = arg.binding {
+ writeln!(
+ self.out,
+ "{}const auto {} = {}.{};",
+ back::INDENT,
+ arg_name,
+ varyings_member_name,
+ arg_name
+ )?;
+ }
+ }
+ }
+ }
+
+ // Finally, declare all the local variables that we need
+ //TODO: we can postpone this till the relevant expressions are emitted
+ for (local_handle, local) in fun.local_variables.iter() {
+ let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)];
+ let ty_name = TypeContext {
+ handle: local.ty,
+ module,
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ match local.init {
+ Some(value) => {
+ let coco = ConstantContext {
+ handle: value,
+ arena: &module.constants,
+ names: &self.names,
+ first_time: false,
+ };
+ write!(self.out, " = {coco}")?;
+ }
+ None => {
+ write!(self.out, " = {{}}")?;
+ }
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ let guarded_indices =
+ index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
+
+ let context = StatementContext {
+ expression: ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::EntryPoint(ep_index as _),
+ info: fun_info,
+ policies: options.bounds_check_policies,
+ guarded_indices,
+ module,
+ pipeline_options,
+ },
+ mod_info,
+ result_struct: Some(&stage_out_name),
+ };
+ self.named_expressions.clear();
+ self.update_expressions_to_bake(fun, fun_info, &context.expression);
+ self.put_block(back::Level(1), &fun.body, &context)?;
+ writeln!(self.out, "}}")?;
+ if ep_index + 1 != module.entry_points.len() {
+ writeln!(self.out)?;
+ }
+ }
+
+ Ok(info)
+ }
+
+ fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
+ // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
+ // so we try to avoid it here.
+ if flags.is_empty() {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
+ )?;
+ }
+ if flags.contains(crate::Barrier::STORAGE) {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
+ )?;
+ }
+ if flags.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
+ )?;
+ }
+ Ok(())
+ }
+}
+
+/// Initializing workgroup variables is more tricky for Metal because we have to deal
+/// with atomics at the type-level (which don't have a copy constructor).
+mod workgroup_mem_init {
+ use crate::EntryPoint;
+
+ use super::*;
+
+ enum Access {
+ GlobalVariable(Handle<crate::GlobalVariable>),
+ StructMember(Handle<crate::Type>, u32),
+ Array(usize),
+ }
+
+ impl Access {
+ fn write<W: Write>(
+ &self,
+ writer: &mut W,
+ names: &FastHashMap<NameKey, String>,
+ ) -> Result<(), core::fmt::Error> {
+ match *self {
+ Access::GlobalVariable(handle) => {
+ write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
+ }
+ Access::StructMember(handle, index) => {
+ write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
+ }
+ Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
+ }
+ }
+ }
+
+ struct AccessStack {
+ stack: Vec<Access>,
+ array_depth: usize,
+ }
+
+ impl AccessStack {
+ const fn new() -> Self {
+ Self {
+ stack: Vec::new(),
+ array_depth: 0,
+ }
+ }
+
+ fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
+ let array_depth = self.array_depth;
+ self.stack.push(Access::Array(array_depth));
+ self.array_depth += 1;
+ let res = cb(self, array_depth);
+ self.stack.pop();
+ self.array_depth -= 1;
+ res
+ }
+
+ fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
+ self.stack.push(new);
+ let res = cb(self);
+ self.stack.pop();
+ res
+ }
+
+ fn write<W: Write>(
+ &self,
+ writer: &mut W,
+ names: &FastHashMap<NameKey, String>,
+ ) -> Result<(), core::fmt::Error> {
+ for next in self.stack.iter() {
+ next.write(writer, names)?;
+ }
+ Ok(())
+ }
+ }
+
+ impl<W: Write> Writer<W> {
+ pub(super) fn need_workgroup_variables_initialization(
+ &mut self,
+ options: &Options,
+ ep: &EntryPoint,
+ module: &crate::Module,
+ fun_info: &valid::FunctionInfo,
+ ) -> bool {
+ options.zero_initialize_workgroup_memory
+ && ep.stage == crate::ShaderStage::Compute
+ && module.global_variables.iter().any(|(handle, var)| {
+ !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ }
+
+ pub(super) fn write_workgroup_variables_initialization(
+ &mut self,
+ module: &crate::Module,
+ module_info: &valid::ModuleInfo,
+ fun_info: &valid::FunctionInfo,
+ local_invocation_id: Option<&NameKey>,
+ ) -> BackendResult {
+ let level = back::Level(1);
+
+ writeln!(
+ self.out,
+ "{}if ({}::all({} == {}::uint3(0u))) {{",
+ level,
+ NAMESPACE,
+ local_invocation_id
+ .map(|name_key| self.names[name_key].as_str())
+ .unwrap_or("__local_invocation_id"),
+ NAMESPACE,
+ )?;
+
+ let mut access_stack = AccessStack::new();
+
+ let vars = module.global_variables.iter().filter(|&(handle, var)| {
+ !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ });
+
+ for (handle, var) in vars {
+ access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ var.ty,
+ access_stack,
+ level.next(),
+ )
+ })?;
+ }
+
+ writeln!(self.out, "{level}}}")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)
+ }
+
+ fn write_workgroup_variable_initialization(
+ &mut self,
+ module: &crate::Module,
+ module_info: &valid::ModuleInfo,
+ ty: Handle<crate::Type>,
+ access_stack: &mut AccessStack,
+ level: back::Level,
+ ) -> BackendResult {
+ if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
+ write!(self.out, "{level}")?;
+ access_stack.write(&mut self.out, &self.names)?;
+ writeln!(self.out, " = {{}};")?;
+ } else {
+ match module.types[ty].inner {
+ crate::TypeInner::Atomic { .. } => {
+ write!(
+ self.out,
+ "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
+ )?;
+ access_stack.write(&mut self.out, &self.names)?;
+ writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
+ }
+ crate::TypeInner::Array { base, size, .. } => {
+ let count = match size.to_indexable_length(module).expect("Bad array size")
+ {
+ proc::IndexableLength::Known(count) => count,
+ proc::IndexableLength::Dynamic => unreachable!(),
+ };
+
+ access_stack.enter_array(|access_stack, array_depth| {
+ writeln!(
+ self.out,
+ "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
+ )?;
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ base,
+ access_stack,
+ level.next(),
+ )?;
+ writeln!(self.out, "{level}}}")?;
+ BackendResult::Ok(())
+ })?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (index, member) in members.iter().enumerate() {
+ access_stack.enter(
+ Access::StructMember(ty, index as u32),
+ |access_stack| {
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ member.ty,
+ access_stack,
+ level,
+ )
+ },
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ Ok(())
+ }
+ }
+}
+
+#[test]
+fn test_stack_size() {
+ use crate::valid::{Capabilities, ValidationFlags};
+ // create a module with at least one expression nested
+ let mut module = crate::Module::default();
+ let constant = module.constants.append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Float(1.0),
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+ let mut fun = crate::Function::default();
+ let const_expr = fun
+ .expressions
+ .append(crate::Expression::Constant(constant), Default::default());
+ let nested_expr = fun.expressions.append(
+ crate::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr: const_expr,
+ },
+ Default::default(),
+ );
+ fun.body.push(
+ crate::Statement::Emit(fun.expressions.range_from(1)),
+ Default::default(),
+ );
+ fun.body.push(
+ crate::Statement::If {
+ condition: nested_expr,
+ accept: crate::Block::new(),
+ reject: crate::Block::new(),
+ },
+ Default::default(),
+ );
+ let _ = module.functions.append(fun, Default::default());
+ // analyse the module
+ let info = crate::valid::Validator::new(ValidationFlags::empty(), Capabilities::empty())
+ .validate(&module)
+ .unwrap();
+ // process the module
+ let mut writer = Writer::new(String::new());
+ writer
+ .write(&module, &info, &Default::default(), &Default::default())
+ .unwrap();
+
+ {
+ // check expression stack
+ let mut addresses_start = usize::MAX;
+ let mut addresses_end = 0usize;
+ for pointer in writer.put_expression_stack_pointers {
+ addresses_start = addresses_start.min(pointer as usize);
+ addresses_end = addresses_end.max(pointer as usize);
+ }
+ let stack_size = addresses_end - addresses_start;
+ // check the size (in debug only)
+ // last observed macOS value: 20528 (CI)
+ if !(11000..=25000).contains(&stack_size) {
+ panic!("`put_expression` stack size {stack_size} has changed!");
+ }
+ }
+
+ {
+ // check block stack
+ let mut addresses_start = usize::MAX;
+ let mut addresses_end = 0usize;
+ for pointer in writer.put_block_stack_pointers {
+ addresses_start = addresses_start.min(pointer as usize);
+ addresses_end = addresses_end.max(pointer as usize);
+ }
+ let stack_size = addresses_end - addresses_start;
+ // check the size (in debug only)
+ // last observed macOS value: 19152 (CI)
+ if !(9000..=20000).contains(&stack_size) {
+ panic!("`put_block` stack size {stack_size} has changed!");
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs
new file mode 100644
index 0000000000..8366df415a
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/block.rs
@@ -0,0 +1,2235 @@
+/*!
+Implementations for `BlockContext` methods.
+*/
+
+use super::{
+ index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, Dimension,
+ Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer, WriterFlags,
+};
+use crate::{arena::Handle, proc::TypeResolution};
+use spirv::Word;
+
+fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
+ match *type_inner {
+ crate::TypeInner::Scalar { .. } => Dimension::Scalar,
+ crate::TypeInner::Vector { .. } => Dimension::Vector,
+ crate::TypeInner::Matrix { .. } => Dimension::Matrix,
+ _ => unreachable!(),
+ }
+}
+
+/// The results of emitting code for a left-hand-side expression.
+///
+/// On success, `write_expression_pointer` returns one of these.
+enum ExpressionPointer {
+ /// The pointer to the expression's value is available, as the value of the
+ /// expression with the given id.
+ Ready { pointer_id: Word },
+
+ /// The access expression must be conditional on the value of `condition`, a boolean
+ /// expression that is true if all indices are in bounds. If `condition` is true, then
+ /// `access` is an `OpAccessChain` instruction that will compute a pointer to the
+ /// expression's value. If `condition` is false, then executing `access` would be
+ /// undefined behavior.
+ Conditional {
+ condition: Word,
+ access: Instruction,
+ },
+}
+
+/// The termination statement to be added to the end of the block
+pub enum BlockExit {
+ /// Generates an OpReturn (void return)
+ Return,
+ /// Generates an OpBranch to the specified block
+ Branch {
+ /// The branch target block
+ target: Word,
+ },
+ /// Translates a loop `break if` into an `OpBranchConditional` to the
+ /// merge block if true (the merge block is passed through [`LoopContext::break_id`]
+ /// or else to the loop header (passed through [`preamble_id`])
+ ///
+ /// [`preamble_id`]: Self::BreakIf::preamble_id
+ BreakIf {
+ /// The condition of the `break if`
+ condition: Handle<crate::Expression>,
+ /// The loop header block id
+ preamble_id: Word,
+ },
+}
+
+impl Writer {
+ // Flip Y coordinate to adjust for coordinate space difference
+ // between SPIR-V and our IR.
+ // The `position_id` argument is a pointer to a `vecN<f32>`,
+ // whose `y` component we will negate.
+ fn write_epilogue_position_y_flip(
+ &mut self,
+ position_id: Word,
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: Some(spirv::StorageClass::Output),
+ }));
+ let index_y_id = self.get_index_constant(1);
+ let access_id = self.id_gen.next();
+ body.push(Instruction::access_chain(
+ float_ptr_type_id,
+ access_id,
+ position_id,
+ &[index_y_id],
+ ));
+
+ let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: None,
+ }));
+ let load_id = self.id_gen.next();
+ body.push(Instruction::load(float_type_id, load_id, access_id, None));
+
+ let neg_id = self.id_gen.next();
+ body.push(Instruction::unary(
+ spirv::Op::FNegate,
+ float_type_id,
+ neg_id,
+ load_id,
+ ));
+
+ body.push(Instruction::store(access_id, neg_id, None));
+ Ok(())
+ }
+
+ // Clamp fragment depth between 0 and 1.
+ fn write_epilogue_frag_depth_clamp(
+ &mut self,
+ frag_depth_id: Word,
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: None,
+ }));
+ let value0_id = self.get_constant_scalar(crate::ScalarValue::Float(0.0), 4);
+ let value1_id = self.get_constant_scalar(crate::ScalarValue::Float(1.0), 4);
+
+ let original_id = self.id_gen.next();
+ body.push(Instruction::load(
+ float_type_id,
+ original_id,
+ frag_depth_id,
+ None,
+ ));
+
+ let clamp_id = self.id_gen.next();
+ body.push(Instruction::ext_inst(
+ self.gl450_ext_inst_id,
+ spirv::GLOp::FClamp,
+ float_type_id,
+ clamp_id,
+ &[original_id, value0_id, value1_id],
+ ));
+
+ body.push(Instruction::store(frag_depth_id, clamp_id, None));
+ Ok(())
+ }
+
+ fn write_entry_point_return(
+ &mut self,
+ value_id: Word,
+ ir_result: &crate::FunctionResult,
+ result_members: &[ResultMember],
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ for (index, res_member) in result_members.iter().enumerate() {
+ let member_value_id = match ir_result.binding {
+ Some(_) => value_id,
+ None => {
+ let member_value_id = self.id_gen.next();
+ body.push(Instruction::composite_extract(
+ res_member.type_id,
+ member_value_id,
+ value_id,
+ &[index as u32],
+ ));
+ member_value_id
+ }
+ };
+
+ body.push(Instruction::store(res_member.id, member_value_id, None));
+
+ match res_member.built_in {
+ Some(crate::BuiltIn::Position { .. })
+ if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
+ {
+ self.write_epilogue_position_y_flip(res_member.id, body)?;
+ }
+ Some(crate::BuiltIn::FragDepth)
+ if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
+ {
+ self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
+ }
+ _ => {}
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<'w> BlockContext<'w> {
+ /// Decide whether to put off emitting instructions for `expr_handle`.
+ ///
+ /// We would like to gather together chains of `Access` and `AccessIndex`
+ /// Naga expressions into a single `OpAccessChain` SPIR-V instruction. To do
+ /// this, we don't generate instructions for these exprs when we first
+ /// encounter them. Their ids in `self.writer.cached.ids` are left as zero. Then,
+ /// once we encounter a `Load` or `Store` expression that actually needs the
+ /// chain's value, we call `write_expression_pointer` to handle the whole
+ /// thing in one fell swoop.
+ fn is_intermediate(&self, expr_handle: Handle<crate::Expression>) -> bool {
+ match self.ir_function.expressions[expr_handle] {
+ crate::Expression::GlobalVariable(handle) => {
+ let ty = self.ir_module.global_variables[handle].ty;
+ match self.ir_module.types[ty].inner {
+ crate::TypeInner::BindingArray { .. } => false,
+ _ => true,
+ }
+ }
+ crate::Expression::LocalVariable(_) => true,
+ crate::Expression::FunctionArgument(index) => {
+ let arg = &self.ir_function.arguments[index as usize];
+ self.ir_module.types[arg.ty].inner.pointer_space().is_some()
+ }
+
+ // The chain rule: if this `Access...`'s `base` operand was
+ // previously omitted, then omit this one, too.
+ _ => self.cached.ids[expr_handle.index()] == 0,
+ }
+ }
+
+ /// Cache an expression for a value.
+ pub(super) fn cache_expression_value(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
+
+ let id = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => {
+ // See `is_intermediate`; we'll handle this later in
+ // `write_expression_pointer`.
+ 0
+ }
+ crate::Expression::Access { base, index } => {
+ let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
+ match *base_ty_inner {
+ crate::TypeInner::Vector { .. } => {
+ self.write_vector_access(expr_handle, base, index, block)?
+ }
+ crate::TypeInner::BindingArray {
+ base: binding_type, ..
+ } => {
+ let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
+ base: binding_type,
+ class: spirv::StorageClass::UniformConstant,
+ });
+
+ let result_id = match self.write_expression_pointer(
+ expr_handle,
+ block,
+ Some(binding_array_false_pointer),
+ )? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Texture array out-of-bounds handling",
+ ));
+ }
+ };
+
+ let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));
+
+ let load_id = self.gen_id();
+ block.body.push(Instruction::load(
+ binding_type_id,
+ load_id,
+ result_id,
+ None,
+ ));
+
+ if self.fun_info[index].uniformity.non_uniform_result.is_some() {
+ self.writer.require_any(
+ "NonUniformEXT",
+ &[spirv::Capability::ShaderNonUniform],
+ )?;
+ self.writer.use_extension("SPV_EXT_descriptor_indexing");
+ self.writer
+ .decorate(load_id, spirv::Decoration::NonUniform, &[]);
+ }
+ load_id
+ }
+ ref other => {
+ log::error!(
+ "Unable to access base {:?} of type {:?}",
+ self.ir_function.expressions[base],
+ other
+ );
+ return Err(Error::Validation(
+ "only vectors may be dynamically indexed by value",
+ ));
+ }
+ }
+ }
+ crate::Expression::AccessIndex { base, index: _ } if self.is_intermediate(base) => {
+ // See `is_intermediate`; we'll handle this later in
+ // `write_expression_pointer`.
+ 0
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. }
+ | crate::TypeInner::Array { .. }
+ | crate::TypeInner::Struct { .. } => {
+ // We never need bounds checks here: dynamically sized arrays can
+ // only appear behind pointers, and are thus handled by the
+ // `is_intermediate` case above. Everything else's size is
+ // statically known and checked in validation.
+ let id = self.gen_id();
+ let base_id = self.cached[base];
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ id,
+ base_id,
+ &[index],
+ ));
+ id
+ }
+ crate::TypeInner::BindingArray {
+ base: binding_type, ..
+ } => {
+ let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
+ base: binding_type,
+ class: spirv::StorageClass::UniformConstant,
+ });
+
+ let result_id = match self.write_expression_pointer(
+ expr_handle,
+ block,
+ Some(binding_array_false_pointer),
+ )? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Texture array out-of-bounds handling",
+ ));
+ }
+ };
+
+ let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));
+
+ let load_id = self.gen_id();
+ block.body.push(Instruction::load(
+ binding_type_id,
+ load_id,
+ result_id,
+ None,
+ ));
+
+ load_id
+ }
+ ref other => {
+ log::error!("Unable to access index of {:?}", other);
+ return Err(Error::FeatureNotImplemented("access index for type"));
+ }
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ self.writer.global_variables[handle.index()].access_id
+ }
+ crate::Expression::Constant(handle) => self.writer.constant_ids[handle.index()],
+ crate::Expression::Splat { size, value } => {
+ let value_id = self.cached[value];
+ let components = [value_id; 4];
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ id,
+ &components[..size as usize],
+ ));
+ id
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ let vector_id = self.cached[vector];
+ self.temp_list.clear();
+ for &sc in pattern[..size as usize].iter() {
+ self.temp_list.push(sc as Word);
+ }
+ let id = self.gen_id();
+ block.body.push(Instruction::vector_shuffle(
+ result_type_id,
+ id,
+ vector_id,
+ vector_id,
+ &self.temp_list,
+ ));
+ id
+ }
+ crate::Expression::Compose {
+ ty: _,
+ ref components,
+ } => {
+ self.temp_list.clear();
+ for &component in components {
+ self.temp_list.push(self.cached[component]);
+ }
+
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ id,
+ &self.temp_list,
+ ));
+ id
+ }
+ crate::Expression::Unary { op, expr } => {
+ let id = self.gen_id();
+ let expr_id = self.cached[expr];
+ let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
+
+ let spirv_op = match op {
+ crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
+ Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNot,
+ Some(crate::ScalarKind::Uint) | None => {
+ log::error!("Unable to negate {:?}", expr_ty_inner);
+ return Err(Error::FeatureNotImplemented("negation"));
+ }
+ },
+ crate::UnaryOperator::Not => match expr_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNot,
+ _ => spirv::Op::Not,
+ },
+ };
+
+ block
+ .body
+ .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
+ id
+ }
+ crate::Expression::Binary { op, left, right } => {
+ let id = self.gen_id();
+ let left_id = self.cached[left];
+ let right_id = self.cached[right];
+
+ let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
+ let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
+
+ let left_dimension = get_dimension(left_ty_inner);
+ let right_dimension = get_dimension(right_ty_inner);
+
+ let mut reverse_operands = false;
+
+ let spirv_op = match op {
+ crate::BinaryOperator::Add => match *left_ty_inner {
+ crate::TypeInner::Scalar { kind, .. }
+ | crate::TypeInner::Vector { kind, .. } => match kind {
+ crate::ScalarKind::Float => spirv::Op::FAdd,
+ _ => spirv::Op::IAdd,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ self.write_matrix_matrix_column_op(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ columns,
+ rows,
+ width,
+ spirv::Op::FAdd,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Subtract => match *left_ty_inner {
+ crate::TypeInner::Scalar { kind, .. }
+ | crate::TypeInner::Vector { kind, .. } => match kind {
+ crate::ScalarKind::Float => spirv::Op::FSub,
+ _ => spirv::Op::ISub,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ self.write_matrix_matrix_column_op(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ columns,
+ rows,
+ width,
+ spirv::Op::FSub,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) {
+ (Dimension::Scalar, Dimension::Vector) => {
+ self.write_vector_scalar_mult(
+ block,
+ id,
+ result_type_id,
+ right_id,
+ left_id,
+ right_ty_inner,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ (Dimension::Vector, Dimension::Scalar) => {
+ self.write_vector_scalar_mult(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ left_ty_inner,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ (Dimension::Vector, Dimension::Matrix) => spirv::Op::VectorTimesMatrix,
+ (Dimension::Matrix, Dimension::Scalar) => spirv::Op::MatrixTimesScalar,
+ (Dimension::Scalar, Dimension::Matrix) => {
+ reverse_operands = true;
+ spirv::Op::MatrixTimesScalar
+ }
+ (Dimension::Matrix, Dimension::Vector) => spirv::Op::MatrixTimesVector,
+ (Dimension::Matrix, Dimension::Matrix) => spirv::Op::MatrixTimesMatrix,
+ (Dimension::Vector, Dimension::Vector)
+ | (Dimension::Scalar, Dimension::Scalar)
+ if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) =>
+ {
+ spirv::Op::FMul
+ }
+ (Dimension::Vector, Dimension::Vector)
+ | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
+ },
+ crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
+ Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
+ // TODO: handle undefined behavior
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ Some(crate::ScalarKind::Sint) => spirv::Op::SRem,
+ // TODO: handle undefined behavior
+ // if right == 0 return 0
+ Some(crate::ScalarKind::Uint) => spirv::Op::UMod,
+ // TODO: handle undefined behavior
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+ Some(crate::ScalarKind::Float) => spirv::Op::FRem,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+ spirv::Op::IEqual
+ }
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+ spirv::Op::INotEqual
+ }
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
+ _ => spirv::Op::BitwiseAnd,
+ },
+ crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
+ crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
+ _ => spirv::Op::BitwiseOr,
+ },
+ crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
+ crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
+ crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
+ crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
+ _ => unimplemented!(),
+ },
+ };
+
+ block.body.push(Instruction::binary(
+ spirv_op,
+ result_type_id,
+ id,
+ if reverse_operands { right_id } else { left_id },
+ if reverse_operands { left_id } else { right_id },
+ ));
+ id
+ }
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+ enum MathOp {
+ Ext(spirv::GLOp),
+ Custom(Instruction),
+ }
+
+ let arg0_id = self.cached[arg];
+ let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
+ let arg_scalar_kind = arg_ty.scalar_kind();
+ let arg1_id = match arg1 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+ let arg2_id = match arg2 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+ let arg3_id = match arg3 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+
+ let id = self.gen_id();
+ let math_op = match fun {
+ // comparison
+ Mf::Abs => {
+ match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
+ Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
+ Some(crate::ScalarKind::Uint) => {
+ MathOp::Custom(Instruction::unary(
+ spirv::Op::CopyObject, // do nothing
+ result_type_id,
+ id,
+ arg0_id,
+ ))
+ }
+ other => unimplemented!("Unexpected abs({:?})", other),
+ }
+ }
+ Mf::Min => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
+ other => unimplemented!("Unexpected min({:?})", other),
+ }),
+ Mf::Max => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
+ other => unimplemented!("Unexpected max({:?})", other),
+ }),
+ Mf::Clamp => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FClamp,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SClamp,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UClamp,
+ other => unimplemented!("Unexpected max({:?})", other),
+ }),
+ Mf::Saturate => {
+ let (maybe_size, width) = match *arg_ty {
+ crate::TypeInner::Vector { size, width, .. } => (Some(size), width),
+ crate::TypeInner::Scalar { width, .. } => (None, width),
+ ref other => unimplemented!("Unexpected saturate({:?})", other),
+ };
+
+ let mut arg1_id = self
+ .writer
+ .get_constant_scalar(crate::ScalarValue::Float(0.0), width);
+ let mut arg2_id = self
+ .writer
+ .get_constant_scalar(crate::ScalarValue::Float(1.0), width);
+
+ if let Some(size) = maybe_size {
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ kind: crate::ScalarKind::Float,
+ width,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(size as _, arg1_id);
+
+ arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
+
+ self.temp_list.fill(arg2_id);
+
+ arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
+ }
+
+ MathOp::Custom(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FClamp,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, arg2_id],
+ ))
+ }
+ // trigonometry
+ Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
+ Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
+ Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
+ Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
+ Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
+ Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
+ Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
+ Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
+ Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
+ Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
+ Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
+ Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
+ Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
+ Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
+ Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
+ // decomposition
+ Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
+ Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
+ Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
+ Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
+ Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
+ Mf::Modf => MathOp::Ext(spirv::GLOp::Modf),
+ Mf::Frexp => MathOp::Ext(spirv::GLOp::Frexp),
+ Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
+ // geometry
+ Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Float,
+ ..
+ } => MathOp::Custom(Instruction::binary(
+ spirv::Op::Dot,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ )),
+ // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available
+ crate::TypeInner::Vector { size, .. } => {
+ self.write_dot_product(
+ id,
+ result_type_id,
+ arg0_id,
+ arg1_id,
+ size as u32,
+ block,
+ );
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => MathOp::Custom(Instruction::binary(
+ spirv::Op::OuterProduct,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ )),
+ Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
+ Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
+ Mf::Length => MathOp::Ext(spirv::GLOp::Length),
+ Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
+ Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
+ Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
+ Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
+ // exponent
+ Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
+ Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
+ Mf::Log => MathOp::Ext(spirv::GLOp::Log),
+ Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
+ Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
+ // computational
+ Mf::Sign => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
+ other => unimplemented!("Unexpected sign({:?})", other),
+ }),
+ Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
+ Mf::Mix => {
+ let selector = arg2.unwrap();
+ let selector_ty =
+ self.fun_info[selector].ty.inner_with(&self.ir_module.types);
+ match (arg_ty, selector_ty) {
+ // if the selector is a scalar, we need to splat it
+ (
+ &crate::TypeInner::Vector { size, .. },
+ &crate::TypeInner::Scalar { kind, width },
+ ) => {
+ let selector_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(size),
+ kind,
+ width,
+ pointer_space: None,
+ }));
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, arg2_id);
+
+ let selector_id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ selector_type_id,
+ selector_id,
+ &self.temp_list,
+ ));
+
+ MathOp::Custom(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FMix,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, selector_id],
+ ))
+ }
+ _ => MathOp::Ext(spirv::GLOp::FMix),
+ }
+ }
+ Mf::Step => MathOp::Ext(spirv::GLOp::Step),
+ Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
+ Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
+ Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
+ Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
+ Mf::Transpose => MathOp::Custom(Instruction::unary(
+ spirv::Op::Transpose,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
+ Mf::ReverseBits => MathOp::Custom(Instruction::unary(
+ spirv::Op::BitReverse,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::CountTrailingZeros => {
+ let uint = crate::ScalarValue::Uint(32);
+ let uint_id = match *arg_ty {
+ crate::TypeInner::Vector { size, width, .. } => {
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ kind: crate::ScalarKind::Uint,
+ width,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(
+ size as _,
+ self.writer.get_constant_scalar(uint, width),
+ );
+
+ self.writer.get_constant_composite(ty, &self.temp_list)
+ }
+ crate::TypeInner::Scalar { width, .. } => {
+ self.writer.get_constant_scalar(uint, width)
+ }
+ _ => unreachable!(),
+ };
+
+ let lsb_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FindILsb,
+ result_type_id,
+ lsb_id,
+ &[arg0_id],
+ ));
+
+ MathOp::Custom(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ result_type_id,
+ id,
+ &[uint_id, lsb_id],
+ ))
+ }
+ Mf::CountLeadingZeros => {
+ let int = crate::ScalarValue::Sint(31);
+
+ let (int_type_id, int_id) = match *arg_ty {
+ crate::TypeInner::Vector { size, width, .. } => {
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ kind: crate::ScalarKind::Sint,
+ width,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list
+ .resize(size as _, self.writer.get_constant_scalar(int, width));
+
+ (
+ self.get_type_id(ty),
+ self.writer.get_constant_composite(ty, &self.temp_list),
+ )
+ }
+ crate::TypeInner::Scalar { width, .. } => (
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Sint,
+ width,
+ pointer_space: None,
+ })),
+ self.writer.get_constant_scalar(int, width),
+ ),
+ _ => unreachable!(),
+ };
+
+ let msb_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FindUMsb,
+ int_type_id,
+ msb_id,
+ &[arg0_id],
+ ));
+
+ MathOp::Custom(Instruction::binary(
+ spirv::Op::ISub,
+ result_type_id,
+ id,
+ int_id,
+ msb_id,
+ ))
+ }
+ Mf::CountOneBits => MathOp::Custom(Instruction::unary(
+ spirv::Op::BitCount,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::ExtractBits => {
+ let op = match arg_scalar_kind {
+ Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
+ Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
+ other => unimplemented!("Unexpected sign({:?})", other),
+ };
+ MathOp::Custom(Instruction::ternary(
+ op,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ arg2_id,
+ ))
+ }
+ Mf::InsertBits => MathOp::Custom(Instruction::quaternary(
+ spirv::Op::BitFieldInsert,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ arg2_id,
+ arg3_id,
+ )),
+ Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb),
+ Mf::FindMsb => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
+ other => unimplemented!("Unexpected findMSB({:?})", other),
+ }),
+ Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
+ Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
+ Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
+ Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
+ Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
+ Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
+ Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
+ Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
+ Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
+ Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
+ };
+
+ block.body.push(match math_op {
+ MathOp::Ext(op) => Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ op,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
+ ),
+ MathOp::Custom(inst) => inst,
+ });
+ id
+ }
+ crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id,
+ crate::Expression::Load { pointer } => {
+ match self.write_expression_pointer(pointer, block, None)? {
+ ExpressionPointer::Ready { pointer_id } => {
+ let id = self.gen_id();
+ let atomic_space =
+ match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Pointer { base, space } => {
+ match self.ir_module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Some(space),
+ _ => None,
+ }
+ }
+ _ => None,
+ };
+ let instruction = if let Some(space) = atomic_space {
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ Instruction::atomic_load(
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ )
+ } else {
+ Instruction::load(result_type_id, id, pointer_id, None)
+ };
+ block.body.push(instruction);
+ id
+ }
+ ExpressionPointer::Conditional { condition, access } => {
+ //TODO: support atomics?
+ self.write_conditional_indexed_load(
+ result_type_id,
+ condition,
+ block,
+ move |id_gen, block| {
+ // The in-bounds path. Perform the access and the load.
+ let pointer_id = access.result_id.unwrap();
+ let value_id = id_gen.next();
+ block.body.push(access);
+ block.body.push(Instruction::load(
+ result_type_id,
+ value_id,
+ pointer_id,
+ None,
+ ));
+ value_id
+ },
+ )
+ }
+ }
+ }
+ crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
+ crate::Expression::CallResult(_)
+ | crate::Expression::AtomicResult { .. }
+ | crate::Expression::RayQueryProceedResult => self.cached[expr_handle],
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ use crate::ScalarKind as Sk;
+
+ let expr_id = self.cached[expr];
+ let (src_kind, src_size, src_width, is_matrix) =
+ match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Scalar { kind, width } => (kind, None, width, false),
+ crate::TypeInner::Vector { kind, width, size } => {
+ (kind, Some(size), width, false)
+ }
+ crate::TypeInner::Matrix { width, .. } => (kind, None, width, true),
+ ref other => {
+ log::error!("As source {:?}", other);
+ return Err(Error::Validation("Unexpected Expression::As source"));
+ }
+ };
+
+ enum Cast {
+ Identity,
+ Unary(spirv::Op),
+ Binary(spirv::Op, Word),
+ Ternary(spirv::Op, Word, Word),
+ }
+
+ let cast = if is_matrix {
+ // we only support identity casts for matrices
+ Cast::Unary(spirv::Op::CopyObject)
+ } else {
+ match (src_kind, kind, convert) {
+ // Filter out identity casts. Some Adreno drivers are
+ // confused by no-op OpBitCast instructions.
+ (src_kind, kind, convert)
+ if src_kind == kind && convert.unwrap_or(src_width) == src_width =>
+ {
+ Cast::Identity
+ }
+ (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject),
+ (_, _, None) => Cast::Unary(spirv::Op::Bitcast),
+ // casting to a bool - generate `OpXxxNotEqual`
+ (_, Sk::Bool, Some(_)) => {
+ let (op, value) = match src_kind {
+ Sk::Sint => (spirv::Op::INotEqual, crate::ScalarValue::Sint(0)),
+ Sk::Uint => (spirv::Op::INotEqual, crate::ScalarValue::Uint(0)),
+ Sk::Float => {
+ (spirv::Op::FUnordNotEqual, crate::ScalarValue::Float(0.0))
+ }
+ Sk::Bool => unreachable!(),
+ };
+ let zero_scalar_id = self.writer.get_constant_scalar(value, src_width);
+ let zero_id = match src_size {
+ Some(size) => {
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ kind: src_kind,
+ width: src_width,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(size as _, zero_scalar_id);
+
+ self.writer.get_constant_composite(ty, &self.temp_list)
+ }
+ None => zero_scalar_id,
+ };
+
+ Cast::Binary(op, zero_id)
+ }
+ // casting from a bool - generate `OpSelect`
+ (Sk::Bool, _, Some(dst_width)) => {
+ let (val0, val1) = match kind {
+ Sk::Sint => {
+ (crate::ScalarValue::Sint(0), crate::ScalarValue::Sint(1))
+ }
+ Sk::Uint => {
+ (crate::ScalarValue::Uint(0), crate::ScalarValue::Uint(1))
+ }
+ Sk::Float => (
+ crate::ScalarValue::Float(0.0),
+ crate::ScalarValue::Float(1.0),
+ ),
+ Sk::Bool => unreachable!(),
+ };
+ let scalar0_id = self.writer.get_constant_scalar(val0, dst_width);
+ let scalar1_id = self.writer.get_constant_scalar(val1, dst_width);
+ let (accept_id, reject_id) = match src_size {
+ Some(size) => {
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ kind,
+ width: dst_width,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(size as _, scalar0_id);
+
+ let vec0_id =
+ self.writer.get_constant_composite(ty, &self.temp_list);
+
+ self.temp_list.fill(scalar1_id);
+
+ let vec1_id =
+ self.writer.get_constant_composite(ty, &self.temp_list);
+
+ (vec1_id, vec0_id)
+ }
+ None => (scalar1_id, scalar0_id),
+ };
+
+ Cast::Ternary(spirv::Op::Select, accept_id, reject_id)
+ }
+ (Sk::Float, Sk::Uint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToU),
+ (Sk::Float, Sk::Sint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToS),
+ (Sk::Float, Sk::Float, Some(dst_width)) if src_width != dst_width => {
+ Cast::Unary(spirv::Op::FConvert)
+ }
+ (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF),
+ (Sk::Sint, Sk::Sint, Some(dst_width)) if src_width != dst_width => {
+ Cast::Unary(spirv::Op::SConvert)
+ }
+ (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF),
+ (Sk::Uint, Sk::Uint, Some(dst_width)) if src_width != dst_width => {
+ Cast::Unary(spirv::Op::UConvert)
+ }
+ // We assume it's either an identity cast, or int-uint.
+ _ => Cast::Unary(spirv::Op::Bitcast),
+ }
+ };
+
+ let id = self.gen_id();
+ let instruction = match cast {
+ Cast::Identity => None,
+ Cast::Unary(op) => Some(Instruction::unary(op, result_type_id, id, expr_id)),
+ Cast::Binary(op, operand) => Some(Instruction::binary(
+ op,
+ result_type_id,
+ id,
+ expr_id,
+ operand,
+ )),
+ Cast::Ternary(op, op1, op2) => Some(Instruction::ternary(
+ op,
+ result_type_id,
+ id,
+ expr_id,
+ op1,
+ op2,
+ )),
+ };
+ if let Some(instruction) = instruction {
+ block.body.push(instruction);
+ id
+ } else {
+ expr_id
+ }
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => self.write_image_load(
+ result_type_id,
+ image,
+ coordinate,
+ array_index,
+ level,
+ sample,
+ block,
+ )?,
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => self.write_image_sample(
+ result_type_id,
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ block,
+ )?,
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ let id = self.gen_id();
+ let mut condition_id = self.cached[condition];
+ let accept_id = self.cached[accept];
+ let reject_id = self.cached[reject];
+
+ let condition_ty = self.fun_info[condition]
+ .ty
+ .inner_with(&self.ir_module.types);
+ let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
+
+ if let (
+ &crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width,
+ },
+ &crate::TypeInner::Vector { size, .. },
+ ) = (condition_ty, object_ty)
+ {
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, condition_id);
+
+ let bool_vector_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(size),
+ kind: crate::ScalarKind::Bool,
+ width,
+ pointer_space: None,
+ }));
+
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ bool_vector_type_id,
+ id,
+ &self.temp_list,
+ ));
+ condition_id = id
+ }
+
+ let instruction =
+ Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
+ block.body.push(instruction);
+ id
+ }
+ crate::Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ match ctrl {
+ Ctrl::Coarse | Ctrl::Fine => {
+ self.writer.require_any(
+ "DerivativeControl",
+ &[spirv::Capability::DerivativeControl],
+ )?;
+ }
+ Ctrl::None => {}
+ }
+ let id = self.gen_id();
+ let expr_id = self.cached[expr];
+ let op = match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
+ (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
+ (Axis::X, Ctrl::None) => spirv::Op::DPdx,
+ (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
+ (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
+ (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
+ (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
+ (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
+ (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
+ };
+ block
+ .body
+ .push(Instruction::derivative(op, result_type_id, id, expr_id));
+ id
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ self.write_image_query(result_type_id, image, query, block)?
+ }
+ crate::Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+ let arg_id = self.cached[argument];
+ let op = match fun {
+ Rf::All => spirv::Op::All,
+ Rf::Any => spirv::Op::Any,
+ Rf::IsNan => spirv::Op::IsNan,
+ Rf::IsInf => spirv::Op::IsInf,
+ //TODO: these require Kernel capability
+ Rf::IsFinite | Rf::IsNormal => {
+ return Err(Error::FeatureNotImplemented("is finite/normal"))
+ }
+ };
+ let id = self.gen_id();
+ block
+ .body
+ .push(Instruction::relational(op, result_type_id, id, arg_id));
+ id
+ }
+ crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
+ crate::Expression::RayQueryGetIntersection { query, committed } => {
+ if !committed {
+ return Err(Error::FeatureNotImplemented("candidate intersection"));
+ }
+ self.write_ray_query_get_intersection(query, block)
+ }
+ };
+
+ self.cached[expr_handle] = id;
+ Ok(())
+ }
+
+ /// Build an `OpAccessChain` instruction.
+ ///
+ /// Emit any needed bounds-checking expressions to `block`.
+ ///
+ /// Some cases we need to generate a different return type than what the IR gives us.
+ /// This is because pointers to binding arrays don't exist in the IR, but we need to
+ /// create them to create an access chain in SPIRV.
+ ///
+ /// On success, the return value is an [`ExpressionPointer`] value; see the
+ /// documentation for that type.
+ fn write_expression_pointer(
+ &mut self,
+ mut expr_handle: Handle<crate::Expression>,
+ block: &mut Block,
+ return_type_override: Option<LookupType>,
+ ) -> Result<ExpressionPointer, Error> {
+ let result_lookup_ty = match self.fun_info[expr_handle].ty {
+ TypeResolution::Handle(ty_handle) => match return_type_override {
+ // We use the return type override as a special case for binding arrays as the OpAccessChain
+ // needs to return a pointer, but indexing into a binding array just gives you the type of
+ // the binding in the IR.
+ Some(ty) => ty,
+ None => LookupType::Handle(ty_handle),
+ },
+ TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()),
+ };
+ let result_type_id = self.get_type_id(result_lookup_ty);
+
+ // The id of the boolean `and` of all dynamic bounds checks up to this point. If
+ // `None`, then we haven't done any dynamic bounds checks yet.
+ //
+ // When we have a chain of bounds checks, we combine them with `OpLogicalAnd`, not
+ // a short-circuit branch. This means we might do comparisons we don't need to,
+ // but we expect these checks to almost always succeed, and keeping branches to a
+ // minimum is essential.
+ let mut accumulated_checks = None;
+
+ self.temp_list.clear();
+ let root_id = loop {
+ expr_handle = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::Access { base, index } => {
+ let index_id = match self.write_bounds_check(base, index, block)? {
+ BoundsCheckResult::KnownInBounds(known_index) => {
+ // Even if the index is known, `OpAccessIndex`
+ // requires expression operands, not literals.
+ let scalar = crate::ScalarValue::Uint(known_index as u64);
+ self.writer.get_constant_scalar(scalar, 4)
+ }
+ BoundsCheckResult::Computed(computed_index_id) => computed_index_id,
+ BoundsCheckResult::Conditional(comparison_id) => {
+ match accumulated_checks {
+ Some(prior_checks) => {
+ let combined = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::LogicalAnd,
+ self.writer.get_bool_type_id(),
+ combined,
+ prior_checks,
+ comparison_id,
+ ));
+ accumulated_checks = Some(combined);
+ }
+ None => {
+ // Start a fresh chain of checks.
+ accumulated_checks = Some(comparison_id);
+ }
+ }
+
+ // Either way, the index to use is unchanged.
+ self.cached[index]
+ }
+ };
+ self.temp_list.push(index_id);
+
+ base
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ let const_id = self.get_index_constant(index);
+ self.temp_list.push(const_id);
+ base
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let gv = &self.writer.global_variables[handle.index()];
+ break gv.access_id;
+ }
+ crate::Expression::LocalVariable(variable) => {
+ let local_var = &self.function.variables[&variable];
+ break local_var.id;
+ }
+ crate::Expression::FunctionArgument(index) => {
+ break self.function.parameter_id(index);
+ }
+ ref other => unimplemented!("Unexpected pointer expression {:?}", other),
+ }
+ };
+
+ let pointer = if self.temp_list.is_empty() {
+ ExpressionPointer::Ready {
+ pointer_id: root_id,
+ }
+ } else {
+ self.temp_list.reverse();
+ let pointer_id = self.gen_id();
+ let access =
+ Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);
+
+ // If we generated some bounds checks, we need to leave it to our
+ // caller to generate the branch, the access, the load or store, and
+ // the zero value (for loads). Otherwise, we can emit the access
+ // ourselves, and just hand them the id of the pointer.
+ match accumulated_checks {
+ Some(condition) => ExpressionPointer::Conditional { condition, access },
+ None => {
+ block.body.push(access);
+ ExpressionPointer::Ready { pointer_id }
+ }
+ }
+ };
+
+ Ok(pointer)
+ }
+
+ /// Build the instructions for matrix - matrix column operations
+ #[allow(clippy::too_many_arguments)]
+ fn write_matrix_matrix_column_op(
+ &mut self,
+ block: &mut Block,
+ result_id: Word,
+ result_type_id: Word,
+ left_id: Word,
+ right_id: Word,
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: u8,
+ op: spirv::Op,
+ ) {
+ self.temp_list.clear();
+
+ let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(rows),
+ kind: crate::ScalarKind::Float,
+ width,
+ pointer_space: None,
+ }));
+
+ for index in 0..columns as u32 {
+ let column_id_left = self.gen_id();
+ let column_id_right = self.gen_id();
+ let column_id_res = self.gen_id();
+
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ column_id_left,
+ left_id,
+ &[index],
+ ));
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ column_id_right,
+ right_id,
+ &[index],
+ ));
+ block.body.push(Instruction::binary(
+ op,
+ vector_type_id,
+ column_id_res,
+ column_id_left,
+ column_id_right,
+ ));
+
+ self.temp_list.push(column_id_res);
+ }
+
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ result_id,
+ &self.temp_list,
+ ));
+ }
+
+ /// Build the instructions for vector - scalar multiplication
+ fn write_vector_scalar_mult(
+ &mut self,
+ block: &mut Block,
+ result_id: Word,
+ result_type_id: Word,
+ vector_id: Word,
+ scalar_id: Word,
+ vector: &crate::TypeInner,
+ ) {
+ let (size, kind) = match *vector {
+ crate::TypeInner::Vector { size, kind, .. } => (size, kind),
+ _ => unreachable!(),
+ };
+
+ let (op, operand_id) = match kind {
+ crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
+ _ => {
+ let operand_id = self.gen_id();
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, scalar_id);
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ operand_id,
+ &self.temp_list,
+ ));
+ (spirv::Op::IMul, operand_id)
+ }
+ };
+
+ block.body.push(Instruction::binary(
+ op,
+ result_type_id,
+ result_id,
+ vector_id,
+ operand_id,
+ ));
+ }
+
+ /// Build the instructions for the arithmetic expression of a dot product
+ fn write_dot_product(
+ &mut self,
+ result_id: Word,
+ result_type_id: Word,
+ arg0_id: Word,
+ arg1_id: Word,
+ size: u32,
+ block: &mut Block,
+ ) {
+ let mut partial_sum = self.writer.write_constant_null(result_type_id);
+ let last_component = size - 1;
+ for index in 0..=last_component {
+ // compute the product of the current components
+ let a_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ a_id,
+ arg0_id,
+ &[index],
+ ));
+ let b_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ b_id,
+ arg1_id,
+ &[index],
+ ));
+ let prod_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::IMul,
+ result_type_id,
+ prod_id,
+ a_id,
+ b_id,
+ ));
+
+ // choose the id for the next sum, depending on current index
+ let id = if index == last_component {
+ result_id
+ } else {
+ self.gen_id()
+ };
+
+ // sum the computed product with the partial sum
+ block.body.push(Instruction::binary(
+ spirv::Op::IAdd,
+ result_type_id,
+ id,
+ partial_sum,
+ prod_id,
+ ));
+ // set the id of the result as the previous partial sum
+ partial_sum = id;
+ }
+ }
+
+ pub(super) fn write_block(
+ &mut self,
+ label_id: Word,
+ statements: &[crate::Statement],
+ exit: BlockExit,
+ loop_context: LoopContext,
+ ) -> Result<(), Error> {
+ let mut block = Block::new(label_id);
+
+ for statement in statements {
+ match *statement {
+ crate::Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ self.cache_expression_value(handle, &mut block)?;
+ }
+ }
+ crate::Statement::Block(ref block_statements) => {
+ let scope_id = self.gen_id();
+ self.function.consume(block, Instruction::branch(scope_id));
+
+ let merge_id = self.gen_id();
+ self.write_block(
+ scope_id,
+ block_statements,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ )?;
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let condition_id = self.cached[condition];
+
+ let merge_id = self.gen_id();
+ block.body.push(Instruction::selection_merge(
+ merge_id,
+ spirv::SelectionControl::NONE,
+ ));
+
+ let accept_id = if accept.is_empty() {
+ None
+ } else {
+ Some(self.gen_id())
+ };
+ let reject_id = if reject.is_empty() {
+ None
+ } else {
+ Some(self.gen_id())
+ };
+
+ self.function.consume(
+ block,
+ Instruction::branch_conditional(
+ condition_id,
+ accept_id.unwrap_or(merge_id),
+ reject_id.unwrap_or(merge_id),
+ ),
+ );
+
+ if let Some(block_id) = accept_id {
+ self.write_block(
+ block_id,
+ accept,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ )?;
+ }
+ if let Some(block_id) = reject_id {
+ self.write_block(
+ block_id,
+ reject,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ )?;
+ }
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ let selector_id = self.cached[selector];
+
+ let merge_id = self.gen_id();
+ block.body.push(Instruction::selection_merge(
+ merge_id,
+ spirv::SelectionControl::NONE,
+ ));
+
+ let mut default_id = None;
+ // id of previous empty fall-through case
+ let mut last_id = None;
+
+ let mut raw_cases = Vec::with_capacity(cases.len());
+ let mut case_ids = Vec::with_capacity(cases.len());
+ for case in cases.iter() {
+ // take id of previous empty fall-through case or generate a new one
+ let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
+
+ if case.fall_through && case.body.is_empty() {
+ last_id = Some(label_id);
+ }
+
+ case_ids.push(label_id);
+
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ raw_cases.push(super::instructions::Case {
+ value: value as Word,
+ label_id,
+ });
+ }
+ crate::SwitchValue::U32(value) => {
+ raw_cases.push(super::instructions::Case { value, label_id });
+ }
+ crate::SwitchValue::Default => {
+ default_id = Some(label_id);
+ }
+ }
+ }
+
+ let default_id = default_id.unwrap();
+
+ self.function.consume(
+ block,
+ Instruction::switch(selector_id, default_id, &raw_cases),
+ );
+
+ let inner_context = LoopContext {
+ break_id: Some(merge_id),
+ ..loop_context
+ };
+
+ for (i, (case, label_id)) in cases
+ .iter()
+ .zip(case_ids.iter())
+ .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
+ .enumerate()
+ {
+ let case_finish_id = if case.fall_through {
+ case_ids[i + 1]
+ } else {
+ merge_id
+ };
+ self.write_block(
+ *label_id,
+ &case.body,
+ BlockExit::Branch {
+ target: case_finish_id,
+ },
+ inner_context,
+ )?;
+ }
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let preamble_id = self.gen_id();
+ self.function
+ .consume(block, Instruction::branch(preamble_id));
+
+ let merge_id = self.gen_id();
+ let body_id = self.gen_id();
+ let continuing_id = self.gen_id();
+
+ // SPIR-V requires the continuing to the `OpLoopMerge`,
+ // so we have to start a new block with it.
+ block = Block::new(preamble_id);
+ block.body.push(Instruction::loop_merge(
+ merge_id,
+ continuing_id,
+ spirv::SelectionControl::NONE,
+ ));
+ self.function.consume(block, Instruction::branch(body_id));
+
+ self.write_block(
+ body_id,
+ body,
+ BlockExit::Branch {
+ target: continuing_id,
+ },
+ LoopContext {
+ continuing_id: Some(continuing_id),
+ break_id: Some(merge_id),
+ },
+ )?;
+
+ let exit = match break_if {
+ Some(condition) => BlockExit::BreakIf {
+ condition,
+ preamble_id,
+ },
+ None => BlockExit::Branch {
+ target: preamble_id,
+ },
+ };
+
+ self.write_block(
+ continuing_id,
+ continuing,
+ exit,
+ LoopContext {
+ continuing_id: None,
+ break_id: Some(merge_id),
+ },
+ )?;
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Break => {
+ self.function
+ .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
+ return Ok(());
+ }
+ crate::Statement::Continue => {
+ self.function.consume(
+ block,
+ Instruction::branch(loop_context.continuing_id.unwrap()),
+ );
+ return Ok(());
+ }
+ crate::Statement::Return { value: Some(value) } => {
+ let value_id = self.cached[value];
+ let instruction = match self.function.entry_point_context {
+ // If this is an entry point, and we need to return anything,
+ // let's instead store the output variables and return `void`.
+ Some(ref context) => {
+ self.writer.write_entry_point_return(
+ value_id,
+ self.ir_function.result.as_ref().unwrap(),
+ &context.results,
+ &mut block.body,
+ )?;
+ Instruction::return_void()
+ }
+ None => Instruction::return_value(value_id),
+ };
+ self.function.consume(block, instruction);
+ return Ok(());
+ }
+ crate::Statement::Return { value: None } => {
+ self.function.consume(block, Instruction::return_void());
+ return Ok(());
+ }
+ crate::Statement::Kill => {
+ self.function.consume(block, Instruction::kill());
+ return Ok(());
+ }
+ crate::Statement::Barrier(flags) => {
+ self.writer.write_barrier(flags, &mut block);
+ }
+ crate::Statement::Store { pointer, value } => {
+ let value_id = self.cached[value];
+ match self.write_expression_pointer(pointer, &mut block, None)? {
+ ExpressionPointer::Ready { pointer_id } => {
+ let atomic_space = match *self.fun_info[pointer]
+ .ty
+ .inner_with(&self.ir_module.types)
+ {
+ crate::TypeInner::Pointer { base, space } => {
+ match self.ir_module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Some(space),
+ _ => None,
+ }
+ }
+ _ => None,
+ };
+ let instruction = if let Some(space) = atomic_space {
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ Instruction::atomic_store(
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ } else {
+ Instruction::store(pointer_id, value_id, None)
+ };
+ block.body.push(instruction);
+ }
+ ExpressionPointer::Conditional { condition, access } => {
+ let mut selection = Selection::start(&mut block, ());
+ selection.if_true(self, condition, ());
+
+ // The in-bounds path. Perform the access and the store.
+ let pointer_id = access.result_id.unwrap();
+ selection.block().body.push(access);
+ selection
+ .block()
+ .body
+ .push(Instruction::store(pointer_id, value_id, None));
+
+ // Finish the in-bounds block and start the merge block. This
+ // is the block we'll leave current on return.
+ selection.finish(self, ());
+ }
+ };
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
+ crate::Statement::Call {
+ function: local_function,
+ ref arguments,
+ result,
+ } => {
+ let id = self.gen_id();
+ self.temp_list.clear();
+ for &argument in arguments {
+ self.temp_list.push(self.cached[argument]);
+ }
+
+ let type_id = match result {
+ Some(expr) => {
+ self.cached[expr] = id;
+ self.get_expression_type_id(&self.fun_info[expr].ty)
+ }
+ None => self.writer.void_type,
+ };
+
+ block.body.push(Instruction::function_call(
+ type_id,
+ id,
+ self.writer.lookup_function[&local_function],
+ &self.temp_list,
+ ));
+ }
+ crate::Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ let id = self.gen_id();
+ let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
+
+ self.cached[result] = id;
+
+ let pointer_id =
+ match self.write_expression_pointer(pointer, &mut block, None)? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Atomics out-of-bounds handling",
+ ));
+ }
+ };
+
+ let space = self.fun_info[pointer]
+ .ty
+ .inner_with(&self.ir_module.types)
+ .pointer_space()
+ .unwrap();
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ let value_id = self.cached[value];
+ let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
+
+ let instruction = match *fun {
+ crate::AtomicFunction::Add => Instruction::atomic_binary(
+ spirv::Op::AtomicIAdd,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::Subtract => Instruction::atomic_binary(
+ spirv::Op::AtomicISub,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::And => Instruction::atomic_binary(
+ spirv::Op::AtomicAnd,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary(
+ spirv::Op::AtomicOr,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary(
+ spirv::Op::AtomicXor,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::Min => {
+ let spirv_op = match *value_inner {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ } => spirv::Op::AtomicSMin,
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ } => spirv::Op::AtomicUMin,
+ _ => unimplemented!(),
+ };
+ Instruction::atomic_binary(
+ spirv_op,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Max => {
+ let spirv_op = match *value_inner {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ } => spirv::Op::AtomicSMax,
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ } => spirv::Op::AtomicUMax,
+ _ => unimplemented!(),
+ };
+ Instruction::atomic_binary(
+ spirv_op,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Exchange { compare: None } => {
+ Instruction::atomic_binary(
+ spirv::Op::AtomicExchange,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
+ let scalar_type_id = match *value_inner {
+ crate::TypeInner::Scalar { kind, width } => {
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind,
+ width,
+ pointer_space: None,
+ }))
+ }
+ _ => unimplemented!(),
+ };
+ let bool_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ pointer_space: None,
+ }));
+
+ let cas_result_id = self.gen_id();
+ let equality_result_id = self.gen_id();
+ let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
+ cas_instr.set_type(scalar_type_id);
+ cas_instr.set_result(cas_result_id);
+ cas_instr.add_operand(pointer_id);
+ cas_instr.add_operand(scope_constant_id);
+ cas_instr.add_operand(semantics_id); // semantics if equal
+ cas_instr.add_operand(semantics_id); // semantics if not equal
+ cas_instr.add_operand(value_id);
+ cas_instr.add_operand(self.cached[cmp]);
+ block.body.push(cas_instr);
+ block.body.push(Instruction::binary(
+ spirv::Op::IEqual,
+ bool_type_id,
+ equality_result_id,
+ cas_result_id,
+ self.cached[cmp],
+ ));
+ Instruction::composite_construct(
+ result_type_id,
+ id,
+ &[cas_result_id, equality_result_id],
+ )
+ }
+ };
+
+ block.body.push(instruction);
+ }
+ crate::Statement::RayQuery { query, ref fun } => {
+ self.write_ray_query_function(query, fun, &mut block);
+ }
+ }
+ }
+
+ let termination = match exit {
+ // We're generating code for the top-level Block of the function, so we
+ // need to end it with some kind of return instruction.
+ BlockExit::Return => match self.ir_function.result {
+ Some(ref result) if self.function.entry_point_context.is_none() => {
+ let type_id = self.get_type_id(LookupType::Handle(result.ty));
+ let null_id = self.writer.write_constant_null(type_id);
+ Instruction::return_value(null_id)
+ }
+ _ => Instruction::return_void(),
+ },
+ BlockExit::Branch { target } => Instruction::branch(target),
+ BlockExit::BreakIf {
+ condition,
+ preamble_id,
+ } => {
+ let condition_id = self.cached[condition];
+
+ Instruction::branch_conditional(
+ condition_id,
+ loop_context.break_id.unwrap(),
+ preamble_id,
+ )
+ }
+ };
+
+ self.function.consume(block, termination);
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs
new file mode 100644
index 0000000000..1ef0db1912
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/helpers.rs
@@ -0,0 +1,108 @@
+use crate::{Handle, UniqueArena};
+use spirv::Word;
+
+pub(super) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> {
+ bytes
+ .chunks(4)
+ .map(|chars| chars.iter().rev().fold(0u32, |u, c| (u << 8) | *c as u32))
+ .collect()
+}
+
+pub(super) fn string_to_words(input: &str) -> Vec<Word> {
+ let bytes = input.as_bytes();
+ let mut words = bytes_to_words(bytes);
+
+ if bytes.len() % 4 == 0 {
+ // nul-termination
+ words.push(0x0u32);
+ }
+
+ words
+}
+
+pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::StorageClass {
+ match space {
+ crate::AddressSpace::Handle => spirv::StorageClass::UniformConstant,
+ crate::AddressSpace::Function => spirv::StorageClass::Function,
+ crate::AddressSpace::Private => spirv::StorageClass::Private,
+ crate::AddressSpace::Storage { .. } => spirv::StorageClass::StorageBuffer,
+ crate::AddressSpace::Uniform => spirv::StorageClass::Uniform,
+ crate::AddressSpace::WorkGroup => spirv::StorageClass::Workgroup,
+ crate::AddressSpace::PushConstant => spirv::StorageClass::PushConstant,
+ }
+}
+
+pub(super) fn contains_builtin(
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ arena: &UniqueArena<crate::Type>,
+ built_in: crate::BuiltIn,
+) -> bool {
+ if let Some(&crate::Binding::BuiltIn(bi)) = binding {
+ bi == built_in
+ } else if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner {
+ members
+ .iter()
+ .any(|member| contains_builtin(member.binding.as_ref(), member.ty, arena, built_in))
+ } else {
+ false // unreachable
+ }
+}
+
+impl crate::AddressSpace {
+ pub(super) const fn to_spirv_semantics_and_scope(
+ self,
+ ) -> (spirv::MemorySemantics, spirv::Scope) {
+ match self {
+ Self::Storage { .. } => (spirv::MemorySemantics::UNIFORM_MEMORY, spirv::Scope::Device),
+ Self::WorkGroup => (
+ spirv::MemorySemantics::WORKGROUP_MEMORY,
+ spirv::Scope::Workgroup,
+ ),
+ _ => (spirv::MemorySemantics::empty(), spirv::Scope::Invocation),
+ }
+ }
+}
+
+/// Return true if the global requires a type decorated with `Block`.
+///
+/// Vulkan spec v1.3 §15.6.2, "Descriptor Set Interface", says:
+///
+/// > Variables identified with the `Uniform` storage class are used to
+/// > access transparent buffer backed resources. Such variables must
+/// > be:
+/// >
+/// > - typed as `OpTypeStruct`, or an array of this type,
+/// >
+/// > - identified with a `Block` or `BufferBlock` decoration, and
+/// >
+/// > - laid out explicitly using the `Offset`, `ArrayStride`, and
+/// > `MatrixStride` decorations as specified in §15.6.4, "Offset
+/// > and Stride Assignment."
+// See `back::spv::GlobalVariable::access_id` for details.
+pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariable) -> bool {
+ match var.space {
+ crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::PushConstant => {}
+ _ => return false,
+ };
+ match ir_module.types[var.ty].inner {
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => match members.last() {
+ Some(member) => match ir_module.types[member.ty].inner {
+ // Structs with dynamically sized arrays can't be copied and can't be wrapped.
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => false,
+ _ => true,
+ },
+ None => false,
+ },
+ // if it's not a structure, let's wrap it to be able to put "Block"
+ _ => true,
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/image.rs b/third_party/rust/naga/src/back/spv/image.rs
new file mode 100644
index 0000000000..27f3520502
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/image.rs
@@ -0,0 +1,1269 @@
+/*!
+Generating SPIR-V for image operations.
+*/
+
+use super::{
+ selection::{MergeTuple, Selection},
+ Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType,
+};
+use crate::arena::Handle;
+use spirv::Word;
+
+/// Information about a vector of coordinates.
+///
+/// The coordinate vectors expected by SPIR-V `OpImageRead` and `OpImageFetch`
+/// supply the array index for arrayed images as an additional component at
+/// the end, whereas Naga's `ImageLoad`, `ImageStore`, and `ImageSample` carry
+/// the array index as a separate field.
+///
+/// In the process of generating code to compute the combined vector, we also
+/// produce SPIR-V types and vector lengths that are useful elsewhere. This
+/// struct gathers that information into one place, with standard names.
+struct ImageCoordinates {
+ /// The SPIR-V id of the combined coordinate/index vector value.
+ ///
+ /// Note: when indexing a non-arrayed 1D image, this will be a scalar.
+ value_id: Word,
+
+ /// The SPIR-V id of the type of `value`.
+ type_id: Word,
+
+ /// The number of components in `value`, if it is a vector, or `None` if it
+ /// is a scalar.
+ size: Option<crate::VectorSize>,
+}
+
+/// A trait for image access (load or store) code generators.
+///
+/// Types implementing this trait hold information about an `ImageStore` or
+/// `ImageLoad` operation that is not affected by the bounds check policy. The
+/// `generate` method emits code for the access, given the results of bounds
+/// checking.
+///
+/// The [`image`] bounds checks policy affects access coordinates, level of
+/// detail, and sample index, but never the image id, result type (if any), or
+/// the specific SPIR-V instruction used. Types that implement this trait gather
+/// together the latter category, so we don't have to plumb them through the
+/// bounds-checking code.
+///
+/// [`image`]: crate::proc::BoundsCheckPolicies::index
+trait Access {
+ /// The Rust type that represents SPIR-V values and types for this access.
+ ///
+ /// For operations like loads, this is `Word`. For operations like stores,
+ /// this is `()`.
+ ///
+ /// For `ReadZeroSkipWrite`, this will be the type of the selection
+ /// construct that performs the bounds checks, so it must implement
+ /// `MergeTuple`.
+ type Output: MergeTuple + Copy + Clone;
+
+ /// Write an image access to `block`.
+ ///
+ /// Access the texel at `coordinates_id`. The optional `level_id` indicates
+ /// the level of detail, and `sample_id` is the index of the sample to
+ /// access in a multisampled texel.
+ ///
+ /// Ths method assumes that `coordinates_id` has already had the image array
+ /// index, if any, folded in, as done by `write_image_coordinates`.
+ ///
+ /// Return the value id produced by the instruction, if any.
+ ///
+ /// Use `id_gen` to generate SPIR-V ids as necessary.
+ fn generate(
+ &self,
+ id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Self::Output;
+
+ /// Return the SPIR-V type of the value produced by the code written by
+ /// `generate`. If the access does not produce a value, `Self::Output`
+ /// should be `()`.
+ fn result_type(&self) -> Self::Output;
+
+ /// Construct the SPIR-V 'zero' value to be returned for an out-of-bounds
+ /// access under the `ReadZeroSkipWrite` policy. If the access does not
+ /// produce a value, `Self::Output` should be `()`.
+ fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Self::Output;
+}
+
+/// Texel access information for an [`ImageLoad`] expression.
+///
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+struct Load {
+ /// The specific opcode we'll use to perform the fetch. Storage images
+ /// require `OpImageRead`, while sampled images require `OpImageFetch`.
+ opcode: spirv::Op,
+
+ /// The type id produced by the actual image access instruction.
+ type_id: Word,
+
+ /// The id of the image being accessed.
+ image_id: Word,
+}
+
+impl Load {
+ fn from_image_expr(
+ ctx: &mut BlockContext<'_>,
+ image_id: Word,
+ image_class: crate::ImageClass,
+ result_type_id: Word,
+ ) -> Result<Load, Error> {
+ let opcode = match image_class {
+ crate::ImageClass::Storage { .. } => spirv::Op::ImageRead,
+ crate::ImageClass::Depth { .. } | crate::ImageClass::Sampled { .. } => {
+ spirv::Op::ImageFetch
+ }
+ };
+
+ // `OpImageRead` and `OpImageFetch` instructions produce vec4<f32>
+ // values. Most of the time, we can just use `result_type_id` for
+ // this. The exception is that `Expression::ImageLoad` from a depth
+ // image produces a scalar `f32`, so in that case we need to find
+ // the right SPIR-V type for the access instruction here.
+ let type_id = match image_class {
+ crate::ImageClass::Depth { .. } => {
+ ctx.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Quad),
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: None,
+ }))
+ }
+ _ => result_type_id,
+ };
+
+ Ok(Load {
+ opcode,
+ type_id,
+ image_id,
+ })
+ }
+}
+
+impl Access for Load {
+ type Output = Word;
+
+ /// Write an instruction to access a given texel of this image.
+ fn generate(
+ &self,
+ id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Word {
+ let texel_id = id_gen.next();
+ let mut instruction = Instruction::image_fetch_or_read(
+ self.opcode,
+ self.type_id,
+ texel_id,
+ self.image_id,
+ coordinates_id,
+ );
+
+ match (level_id, sample_id) {
+ (None, None) => {}
+ (Some(level_id), None) => {
+ instruction.add_operand(spirv::ImageOperands::LOD.bits());
+ instruction.add_operand(level_id);
+ }
+ (None, Some(sample_id)) => {
+ instruction.add_operand(spirv::ImageOperands::SAMPLE.bits());
+ instruction.add_operand(sample_id);
+ }
+ // There's no such thing as a multi-sampled mipmap.
+ (Some(_), Some(_)) => unreachable!(),
+ }
+
+ block.body.push(instruction);
+
+ texel_id
+ }
+
+ fn result_type(&self) -> Word {
+ self.type_id
+ }
+
+ fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Word {
+ ctx.writer.write_constant_null(self.type_id)
+ }
+}
+
+/// Texel access information for a [`Store`] statement.
+///
+/// [`Store`]: crate::Statement::Store
+struct Store {
+ /// The id of the image being written to.
+ image_id: Word,
+
+ /// The value we're going to write to the texel.
+ value_id: Word,
+}
+
+impl Access for Store {
+ /// Stores don't generate any value.
+ type Output = ();
+
+ fn generate(
+ &self,
+ _id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ _level_id: Option<Word>,
+ _sample_id: Option<Word>,
+ block: &mut Block,
+ ) {
+ block.body.push(Instruction::image_write(
+ self.image_id,
+ coordinates_id,
+ self.value_id,
+ ));
+ }
+
+ /// Stores don't generate any value, so this just returns `()`.
+ fn result_type(&self) {}
+
+ /// Stores don't generate any value, so this just returns `()`.
+ fn out_of_bounds_value(&self, _ctx: &mut BlockContext<'_>) {}
+}
+
+impl<'w> BlockContext<'w> {
+ /// Extend image coordinates with an array index, if necessary.
+ ///
+ /// Whereas [`Expression::ImageLoad`] and [`ImageSample`] treat the array
+ /// index as a separate operand from the coordinates, SPIR-V image access
+ /// instructions include the array index in the `coordinates` operand. This
+ /// function builds a SPIR-V coordinate vector from a Naga coordinate vector
+ /// and array index, if one is supplied, and returns a `ImageCoordinates`
+ /// struct describing what it built.
+ ///
+ /// If `array_index` is `Some(expr)`, then this function constructs a new
+ /// vector that is `coordinates` with `array_index` concatenated onto the
+ /// end: a `vec2` becomes a `vec3`, a scalar becomes a `vec2`, and so on.
+ ///
+ /// If `array_index` is `None`, then the return value uses `coordinates`
+ /// unchanged. Note that, when indexing a non-arrayed 1D image, this will be
+ /// a scalar value.
+ ///
+ /// If needed, this function generates code to convert the array index,
+ /// always an integer scalar, to match the component type of `coordinates`.
+ /// Naga's `ImageLoad` and SPIR-V's `OpImageRead`, `OpImageFetch`, and
+ /// `OpImageWrite` all use integer coordinates, while Naga's `ImageSample`
+ /// and SPIR-V's `OpImageSample...` instructions all take floating-point
+ /// coordinate vectors.
+ ///
+ /// [`Expression::ImageLoad`]: crate::Expression::ImageLoad
+ /// [`ImageSample`]: crate::Expression::ImageSample
+ fn write_image_coordinates(
+ &mut self,
+ coordinates: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<ImageCoordinates, Error> {
+ use crate::TypeInner as Ti;
+ use crate::VectorSize as Vs;
+
+ let coordinates_id = self.cached[coordinates];
+ let ty = &self.fun_info[coordinates].ty;
+ let inner_ty = ty.inner_with(&self.ir_module.types);
+
+ // If there's no array index, the image coordinates are exactly the
+ // `coordinate` field of the `Expression::ImageLoad`. No work is needed.
+ let array_index = match array_index {
+ None => {
+ let value_id = coordinates_id;
+ let type_id = self.get_expression_type_id(ty);
+ let size = match *inner_ty {
+ Ti::Scalar { .. } => None,
+ Ti::Vector { size, .. } => Some(size),
+ _ => return Err(Error::Validation("coordinate type")),
+ };
+ return Ok(ImageCoordinates {
+ value_id,
+ type_id,
+ size,
+ });
+ }
+ Some(ix) => ix,
+ };
+
+ // Find the component type of `coordinates`, and figure out the size the
+ // combined coordinate vector will have.
+ let (component_kind, size) = match *inner_ty {
+ Ti::Scalar { kind, width: 4 } => (kind, Some(Vs::Bi)),
+ Ti::Vector {
+ kind,
+ width: 4,
+ size: Vs::Bi,
+ } => (kind, Some(Vs::Tri)),
+ Ti::Vector {
+ kind,
+ width: 4,
+ size: Vs::Tri,
+ } => (kind, Some(Vs::Quad)),
+ Ti::Vector { size: Vs::Quad, .. } => {
+ return Err(Error::Validation("extending vec4 coordinate"));
+ }
+ ref other => {
+ log::error!("wrong coordinate type {:?}", other);
+ return Err(Error::Validation("coordinate type"));
+ }
+ };
+
+ // Convert the index to the coordinate component type, if necessary.
+ let array_index_id = self.cached[array_index];
+ let ty = &self.fun_info[array_index].ty;
+ let inner_ty = ty.inner_with(&self.ir_module.types);
+ let array_index_kind = if let Ti::Scalar { kind, width: 4 } = *inner_ty {
+ debug_assert!(matches!(
+ kind,
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint
+ ));
+ kind
+ } else {
+ unreachable!("we only allow i32 and u32");
+ };
+ let cast = match (component_kind, array_index_kind) {
+ (crate::ScalarKind::Sint, crate::ScalarKind::Sint)
+ | (crate::ScalarKind::Uint, crate::ScalarKind::Uint) => None,
+ (crate::ScalarKind::Sint, crate::ScalarKind::Uint)
+ | (crate::ScalarKind::Uint, crate::ScalarKind::Sint) => Some(spirv::Op::Bitcast),
+ (crate::ScalarKind::Float, crate::ScalarKind::Sint) => Some(spirv::Op::ConvertSToF),
+ (crate::ScalarKind::Float, crate::ScalarKind::Uint) => Some(spirv::Op::ConvertUToF),
+ (crate::ScalarKind::Bool, _) => unreachable!("we don't allow bool for component"),
+ (_, crate::ScalarKind::Bool | crate::ScalarKind::Float) => {
+ unreachable!("we don't allow bool or float for array index")
+ }
+ };
+ let reconciled_array_index_id = if let Some(cast) = cast {
+ let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: component_kind,
+ width: 4,
+ pointer_space: None,
+ }));
+ let reconciled_id = self.gen_id();
+ block.body.push(Instruction::unary(
+ cast,
+ component_ty_id,
+ reconciled_id,
+ array_index_id,
+ ));
+ reconciled_id
+ } else {
+ array_index_id
+ };
+
+ // Find the SPIR-V type for the combined coordinates/index vector.
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: size,
+ kind: component_kind,
+ width: 4,
+ pointer_space: None,
+ }));
+
+ // Schmear the coordinates and index together.
+ let value_id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ type_id,
+ value_id,
+ &[coordinates_id, reconciled_array_index_id],
+ ));
+ Ok(ImageCoordinates {
+ value_id,
+ type_id,
+ size,
+ })
+ }
+
+ pub(super) fn get_image_id(&mut self, expr_handle: Handle<crate::Expression>) -> Word {
+ let id = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::GlobalVariable(handle) => {
+ self.writer.global_variables[handle.index()].handle_id
+ }
+ crate::Expression::FunctionArgument(i) => {
+ self.function.parameters[i as usize].handle_id
+ }
+ crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => {
+ self.cached[expr_handle]
+ }
+ ref other => unreachable!("Unexpected image expression {:?}", other),
+ };
+
+ if id == 0 {
+ unreachable!(
+ "Image expression {:?} doesn't have a handle ID",
+ expr_handle
+ );
+ }
+
+ id
+ }
+
+ /// Generate a vector or scalar 'one' for arithmetic on `coordinates`.
+ ///
+ /// If `coordinates` is a scalar, return a scalar one. Otherwise, return
+ /// a vector of ones.
+ fn write_coordinate_one(&mut self, coordinates: &ImageCoordinates) -> Result<Word, Error> {
+ let one = self.get_scope_constant(1);
+ match coordinates.size {
+ None => Ok(one),
+ Some(vector_size) => {
+ let ones = [one; 4];
+ let id = self.gen_id();
+ Instruction::constant_composite(
+ coordinates.type_id,
+ id,
+ &ones[..vector_size as usize],
+ )
+ .to_words(&mut self.writer.logical_layout.declarations);
+ Ok(id)
+ }
+ }
+ }
+
+ /// Generate code to restrict `input` to fall between zero and one less than
+ /// `size_id`.
+ ///
+ /// Both must be 32-bit scalar integer values, whose type is given by
+ /// `type_id`. The computed value is also of type `type_id`.
+ fn restrict_scalar(
+ &mut self,
+ type_id: Word,
+ input_id: Word,
+ size_id: Word,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let i32_one_id = self.get_scope_constant(1);
+
+ // Subtract one from `size` to get the largest valid value.
+ let limit_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ type_id,
+ limit_id,
+ size_id,
+ i32_one_id,
+ ));
+
+ // Use an unsigned minimum, to handle both positive out-of-range values
+ // and negative values in a single instruction: negative values of
+ // `input_id` get treated as very large positive values.
+ let restricted_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ type_id,
+ restricted_id,
+ &[input_id, limit_id],
+ ));
+
+ Ok(restricted_id)
+ }
+
+ /// Write instructions to query the size of an image.
+ ///
+ /// This takes care of selecting the right instruction depending on whether
+ /// a level of detail parameter is present.
+ fn write_coordinate_bounds(
+ &mut self,
+ type_id: Word,
+ image_id: Word,
+ level_id: Option<Word>,
+ block: &mut Block,
+ ) -> Word {
+ let coordinate_bounds_id = self.gen_id();
+ match level_id {
+ Some(level_id) => {
+ // A level of detail was provided, so fetch the image size for
+ // that level.
+ let mut inst = Instruction::image_query(
+ spirv::Op::ImageQuerySizeLod,
+ type_id,
+ coordinate_bounds_id,
+ image_id,
+ );
+ inst.add_operand(level_id);
+ block.body.push(inst);
+ }
+ _ => {
+ // No level of detail was given.
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySize,
+ type_id,
+ coordinate_bounds_id,
+ image_id,
+ ));
+ }
+ }
+
+ coordinate_bounds_id
+ }
+
+ /// Write code to restrict coordinates for an image reference.
+ ///
+ /// First, clamp the level of detail or sample index to fall within bounds.
+ /// Then, obtain the image size, possibly using the clamped level of detail.
+ /// Finally, use an unsigned minimum instruction to force all coordinates
+ /// into range.
+ ///
+ /// Return a triple `(COORDS, LEVEL, SAMPLE)`, where `COORDS` is a coordinate
+ /// vector (including the array index, if any), `LEVEL` is an optional level
+ /// of detail, and `SAMPLE` is an optional sample index, all guaranteed to
+ /// be in-bounds for `image_id`.
+ ///
+ /// The result is usually a vector, but it is a scalar when indexing
+ /// non-arrayed 1D images.
+ fn write_restricted_coordinates(
+ &mut self,
+ image_id: Word,
+ coordinates: ImageCoordinates,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Result<(Word, Option<Word>, Option<Word>), Error> {
+ self.writer.require_any(
+ "the `Restrict` image bounds check policy",
+ &[spirv::Capability::ImageQuery],
+ )?;
+
+ let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ pointer_space: None,
+ }));
+
+ // If `level` is `Some`, clamp it to fall within bounds. This must
+ // happen first, because we'll use it to query the image size for
+ // clamping the actual coordinates.
+ let level_id = level_id
+ .map(|level_id| {
+ // Find the number of mipmap levels in this image.
+ let num_levels_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ i32_type_id,
+ num_levels_id,
+ image_id,
+ ));
+
+ self.restrict_scalar(i32_type_id, level_id, num_levels_id, block)
+ })
+ .transpose()?;
+
+ // If `sample_id` is `Some`, clamp it to fall within bounds.
+ let sample_id = sample_id
+ .map(|sample_id| {
+ // Find the number of samples per texel.
+ let num_samples_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ i32_type_id,
+ num_samples_id,
+ image_id,
+ ));
+
+ self.restrict_scalar(i32_type_id, sample_id, num_samples_id, block)
+ })
+ .transpose()?;
+
+ // Obtain the image bounds, including the array element count.
+ let coordinate_bounds_id =
+ self.write_coordinate_bounds(coordinates.type_id, image_id, level_id, block);
+
+ // Compute maximum valid values from the bounds.
+ let ones = self.write_coordinate_one(&coordinates)?;
+ let coordinate_limit_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ coordinates.type_id,
+ coordinate_limit_id,
+ coordinate_bounds_id,
+ ones,
+ ));
+
+ // Restrict the coordinates to fall within those bounds.
+ //
+ // Use an unsigned minimum, to handle both positive out-of-range values
+ // and negative values in a single instruction: negative values of
+ // `coordinates` get treated as very large positive values.
+ let restricted_coordinates_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ coordinates.type_id,
+ restricted_coordinates_id,
+ &[coordinates.value_id, coordinate_limit_id],
+ ));
+
+ Ok((restricted_coordinates_id, level_id, sample_id))
+ }
+
+ fn write_conditional_image_access<A: Access>(
+ &mut self,
+ image_id: Word,
+ coordinates: ImageCoordinates,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ access: &A,
+ ) -> Result<A::Output, Error> {
+ self.writer.require_any(
+ "the `ReadZeroSkipWrite` image bounds check policy",
+ &[spirv::Capability::ImageQuery],
+ )?;
+
+ let bool_type_id = self.writer.get_bool_type_id();
+ let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ pointer_space: None,
+ }));
+
+ let null_id = access.out_of_bounds_value(self);
+
+ let mut selection = Selection::start(block, access.result_type());
+
+ // If `level_id` is `Some`, check whether it is within bounds. This must
+ // happen first, because we'll be supplying this as an argument when we
+ // query the image size.
+ if let Some(level_id) = level_id {
+ // Find the number of mipmap levels in this image.
+ let num_levels_id = self.gen_id();
+ selection.block().body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ i32_type_id,
+ num_levels_id,
+ image_id,
+ ));
+
+ let lod_cond_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ bool_type_id,
+ lod_cond_id,
+ level_id,
+ num_levels_id,
+ ));
+
+ selection.if_true(self, lod_cond_id, null_id);
+ }
+
+ // If `sample_id` is `Some`, check whether it is in bounds.
+ if let Some(sample_id) = sample_id {
+ // Find the number of samples per texel.
+ let num_samples_id = self.gen_id();
+ selection.block().body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ i32_type_id,
+ num_samples_id,
+ image_id,
+ ));
+
+ let samples_cond_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ bool_type_id,
+ samples_cond_id,
+ sample_id,
+ num_samples_id,
+ ));
+
+ selection.if_true(self, samples_cond_id, null_id);
+ }
+
+ // Obtain the image bounds, including any array element count.
+ let coordinate_bounds_id = self.write_coordinate_bounds(
+ coordinates.type_id,
+ image_id,
+ level_id,
+ selection.block(),
+ );
+
+ // Compare the coordinates against the bounds.
+ let coords_bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: coordinates.size,
+ kind: crate::ScalarKind::Bool,
+ width: 1,
+ pointer_space: None,
+ }));
+ let coords_conds_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ coords_bool_type_id,
+ coords_conds_id,
+ coordinates.value_id,
+ coordinate_bounds_id,
+ ));
+
+ // If the comparison above was a vector comparison, then we need to
+ // check that all components of the comparison are true.
+ let coords_cond_id = if coords_bool_type_id != bool_type_id {
+ let id = self.gen_id();
+ selection.block().body.push(Instruction::relational(
+ spirv::Op::All,
+ bool_type_id,
+ id,
+ coords_conds_id,
+ ));
+ id
+ } else {
+ coords_conds_id
+ };
+
+ selection.if_true(self, coords_cond_id, null_id);
+
+ // All conditions are met. We can carry out the access.
+ let texel_id = access.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ level_id,
+ sample_id,
+ selection.block(),
+ );
+
+ // This, then, is the value of the 'true' branch.
+ Ok(selection.finish(self, texel_id))
+ }
+
+ /// Generate code for an `ImageLoad` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageLoad` variant.
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn write_image_load(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ level: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let image_id = self.get_image_id(image);
+ let image_type = self.fun_info[image].ty.inner_with(&self.ir_module.types);
+ let image_class = match *image_type {
+ crate::TypeInner::Image { class, .. } => class,
+ _ => return Err(Error::Validation("image type")),
+ };
+
+ let access = Load::from_image_expr(self, image_id, image_class, result_type_id)?;
+ let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
+
+ let level_id = level.map(|expr| self.cached[expr]);
+ let sample_id = sample.map(|expr| self.cached[expr]);
+
+ // Perform the access, according to the bounds check policy.
+ let access_id = match self.writer.bounds_check_policies.image {
+ crate::proc::BoundsCheckPolicy::Restrict => {
+ let (coords, level_id, sample_id) = self.write_restricted_coordinates(
+ image_id,
+ coordinates,
+ level_id,
+ sample_id,
+ block,
+ )?;
+ access.generate(&mut self.writer.id_gen, coords, level_id, sample_id, block)
+ }
+ crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => self
+ .write_conditional_image_access(
+ image_id,
+ coordinates,
+ level_id,
+ sample_id,
+ block,
+ &access,
+ )?,
+ crate::proc::BoundsCheckPolicy::Unchecked => access.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ level_id,
+ sample_id,
+ block,
+ ),
+ };
+
+ // For depth images, `ImageLoad` expressions produce a single f32,
+ // whereas the SPIR-V instructions always produce a vec4. So we may have
+ // to pull out the component we need.
+ let result_id = if result_type_id == access.result_type() {
+ // The instruction produced the type we expected. We can use
+ // its result as-is.
+ access_id
+ } else {
+ // For `ImageClass::Depth` images, SPIR-V gave us four components,
+ // but we only want the first one.
+ let component_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ component_id,
+ access_id,
+ &[0],
+ ));
+ component_id
+ };
+
+ Ok(result_id)
+ }
+
+ /// Generate code for an `ImageSample` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageSample` variant.
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn write_image_sample(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ sampler: Handle<crate::Expression>,
+ gather: Option<crate::SwizzleComponent>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ offset: Option<Handle<crate::Constant>>,
+ level: crate::SampleLevel,
+ depth_ref: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ use super::instructions::SampleLod;
+ // image
+ let image_id = self.get_image_id(image);
+ let image_type = self.fun_info[image].ty.handle().unwrap();
+ // SPIR-V doesn't know about our `Depth` class, and it returns
+ // `vec4<f32>`, so we need to grab the first component out of it.
+ let needs_sub_access = match self.ir_module.types[image_type].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Depth { .. },
+ ..
+ } => depth_ref.is_none() && gather.is_none(),
+ _ => false,
+ };
+ let sample_result_type_id = if needs_sub_access {
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Quad),
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: None,
+ }))
+ } else {
+ result_type_id
+ };
+
+ // OpTypeSampledImage
+ let image_type_id = self.get_type_id(LookupType::Handle(image_type));
+ let sampled_image_type_id =
+ self.get_type_id(LookupType::Local(LocalType::SampledImage { image_type_id }));
+
+ let sampler_id = self.get_image_id(sampler);
+ let coordinates_id = self
+ .write_image_coordinates(coordinate, array_index, block)?
+ .value_id;
+
+ let sampled_image_id = self.gen_id();
+ block.body.push(Instruction::sampled_image(
+ sampled_image_type_id,
+ sampled_image_id,
+ image_id,
+ sampler_id,
+ ));
+ let id = self.gen_id();
+
+ let depth_id = depth_ref.map(|handle| self.cached[handle]);
+ let mut mask = spirv::ImageOperands::empty();
+ mask.set(spirv::ImageOperands::CONST_OFFSET, offset.is_some());
+
+ let mut main_instruction = match (level, gather) {
+ (_, Some(component)) => {
+ let component_id = self.get_index_constant(component as u32);
+ let mut inst = Instruction::image_gather(
+ sample_result_type_id,
+ id,
+ sampled_image_id,
+ coordinates_id,
+ component_id,
+ depth_id,
+ );
+ if !mask.is_empty() {
+ inst.add_operand(mask.bits());
+ }
+ inst
+ }
+ (crate::SampleLevel::Zero, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let zero_id = self
+ .writer
+ .get_constant_scalar(crate::ScalarValue::Float(0.0), 4);
+
+ mask |= spirv::ImageOperands::LOD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(zero_id);
+
+ inst
+ }
+ (crate::SampleLevel::Auto, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Implicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+ if !mask.is_empty() {
+ inst.add_operand(mask.bits());
+ }
+ inst
+ }
+ (crate::SampleLevel::Exact(lod_handle), None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let lod_id = self.cached[lod_handle];
+ mask |= spirv::ImageOperands::LOD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(lod_id);
+
+ inst
+ }
+ (crate::SampleLevel::Bias(bias_handle), None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Implicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let bias_id = self.cached[bias_handle];
+ mask |= spirv::ImageOperands::BIAS;
+ inst.add_operand(mask.bits());
+ inst.add_operand(bias_id);
+
+ inst
+ }
+ (crate::SampleLevel::Gradient { x, y }, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let x_id = self.cached[x];
+ let y_id = self.cached[y];
+ mask |= spirv::ImageOperands::GRAD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(x_id);
+ inst.add_operand(y_id);
+
+ inst
+ }
+ };
+
+ if let Some(offset_const) = offset {
+ let offset_id = self.writer.constant_ids[offset_const.index()];
+ main_instruction.add_operand(offset_id);
+ }
+
+ block.body.push(main_instruction);
+
+ let id = if needs_sub_access {
+ let sub_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ sub_id,
+ id,
+ &[0],
+ ));
+ sub_id
+ } else {
+ id
+ };
+
+ Ok(id)
+ }
+
+ /// Generate code for an `ImageQuery` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageQuery` variant.
+ pub(super) fn write_image_query(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ query: crate::ImageQuery,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ use crate::{ImageClass as Ic, ImageDimension as Id, ImageQuery as Iq};
+
+ let image_id = self.get_image_id(image);
+ let image_type = self.fun_info[image].ty.handle().unwrap();
+ let (dim, arrayed, class) = match self.ir_module.types[image_type].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => (dim, arrayed, class),
+ _ => {
+ return Err(Error::Validation("image type"));
+ }
+ };
+
+ self.writer
+ .require_any("image queries", &[spirv::Capability::ImageQuery])?;
+
+ let id = match query {
+ Iq::Size { level } => {
+ let dim_coords = match dim {
+ Id::D1 => 1,
+ Id::D2 | Id::Cube => 2,
+ Id::D3 => 3,
+ };
+ let array_coords = usize::from(arrayed);
+ let vector_size = match dim_coords + array_coords {
+ 2 => Some(crate::VectorSize::Bi),
+ 3 => Some(crate::VectorSize::Tri),
+ 4 => Some(crate::VectorSize::Quad),
+ _ => None,
+ };
+ let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ pointer_space: None,
+ }));
+
+ let (query_op, level_id) = match class {
+ Ic::Sampled { multi: true, .. }
+ | Ic::Depth { multi: true }
+ | Ic::Storage { .. } => (spirv::Op::ImageQuerySize, None),
+ _ => {
+ let level_id = match level {
+ Some(expr) => self.cached[expr],
+ None => self.get_index_constant(0),
+ };
+ (spirv::Op::ImageQuerySizeLod, Some(level_id))
+ }
+ };
+
+ // The ID of the vector returned by SPIR-V, which contains the dimensions
+ // as well as the layer count.
+ let id_extended = self.gen_id();
+ let mut inst = Instruction::image_query(
+ query_op,
+ extended_size_type_id,
+ id_extended,
+ image_id,
+ );
+ if let Some(expr_id) = level_id {
+ inst.add_operand(expr_id);
+ }
+ block.body.push(inst);
+
+ let bitcast_type_id = self.get_type_id(
+ LocalType::Value {
+ vector_size,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ pointer_space: None,
+ }
+ .into(),
+ );
+ let bitcast_id = self.gen_id();
+ block.body.push(Instruction::unary(
+ spirv::Op::Bitcast,
+ bitcast_type_id,
+ bitcast_id,
+ id_extended,
+ ));
+
+ if result_type_id != bitcast_type_id {
+ let id = self.gen_id();
+ let components = match dim {
+ // always pick the first component, and duplicate it for all 3 dimensions
+ Id::Cube => &[0u32, 0][..],
+ _ => &[0u32, 1, 2, 3][..dim_coords],
+ };
+ block.body.push(Instruction::vector_shuffle(
+ result_type_id,
+ id,
+ bitcast_id,
+ bitcast_id,
+ components,
+ ));
+
+ id
+ } else {
+ bitcast_id
+ }
+ }
+ Iq::NumLevels => {
+ let query_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ self.get_type_id(
+ LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ pointer_space: None,
+ }
+ .into(),
+ ),
+ query_id,
+ image_id,
+ ));
+
+ let id = self.gen_id();
+ block.body.push(Instruction::unary(
+ spirv::Op::Bitcast,
+ result_type_id,
+ id,
+ query_id,
+ ));
+
+ id
+ }
+ Iq::NumLayers => {
+ let vec_size = match dim {
+ Id::D1 => crate::VectorSize::Bi,
+ Id::D2 | Id::Cube => crate::VectorSize::Tri,
+ Id::D3 => crate::VectorSize::Quad,
+ };
+ let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(vec_size),
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ pointer_space: None,
+ }));
+ let id_extended = self.gen_id();
+ let mut inst = Instruction::image_query(
+ spirv::Op::ImageQuerySizeLod,
+ extended_size_type_id,
+ id_extended,
+ image_id,
+ );
+ inst.add_operand(self.get_index_constant(0));
+ block.body.push(inst);
+
+ let extract_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ self.get_type_id(
+ LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ pointer_space: None,
+ }
+ .into(),
+ ),
+ extract_id,
+ id_extended,
+ &[vec_size as u32 - 1],
+ ));
+
+ let id = self.gen_id();
+ block.body.push(Instruction::unary(
+ spirv::Op::Bitcast,
+ result_type_id,
+ id,
+ extract_id,
+ ));
+
+ id
+ }
+ Iq::NumSamples => {
+ let query_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ self.get_type_id(
+ LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ pointer_space: None,
+ }
+ .into(),
+ ),
+ query_id,
+ image_id,
+ ));
+
+ let id = self.gen_id();
+ block.body.push(Instruction::unary(
+ spirv::Op::Bitcast,
+ result_type_id,
+ id,
+ query_id,
+ ));
+
+ id
+ }
+ };
+
+ Ok(id)
+ }
+
+ pub(super) fn write_image_store(
+ &mut self,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ value: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ let image_id = self.get_image_id(image);
+ let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
+ let value_id = self.cached[value];
+
+ let write = Store { image_id, value_id };
+
+ match self.writer.bounds_check_policies.image {
+ crate::proc::BoundsCheckPolicy::Restrict => {
+ let (coords, _, _) =
+ self.write_restricted_coordinates(image_id, coordinates, None, None, block)?;
+ write.generate(&mut self.writer.id_gen, coords, None, None, block);
+ }
+ crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ self.write_conditional_image_access(
+ image_id,
+ coordinates,
+ None,
+ None,
+ block,
+ &write,
+ )?;
+ }
+ crate::proc::BoundsCheckPolicy::Unchecked => {
+ write.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ None,
+ None,
+ block,
+ );
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/index.rs b/third_party/rust/naga/src/back/spv/index.rs
new file mode 100644
index 0000000000..d2cbdf4d6d
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/index.rs
@@ -0,0 +1,417 @@
+/*!
+Bounds-checking for SPIR-V output.
+*/
+
+use super::{
+ helpers::global_needs_wrapper, selection::Selection, Block, BlockContext, Error, IdGenerator,
+ Instruction, Word,
+};
+use crate::{arena::Handle, proc::BoundsCheckPolicy};
+
+/// The results of performing a bounds check.
+///
+/// On success, `write_bounds_check` returns a value of this type.
+pub(super) enum BoundsCheckResult {
+ /// The index is statically known and in bounds, with the given value.
+ KnownInBounds(u32),
+
+ /// The given instruction computes the index to be used.
+ Computed(Word),
+
+ /// The given instruction computes a boolean condition which is true
+ /// if the index is in bounds.
+ Conditional(Word),
+}
+
+/// A value that we either know at translation time, or need to compute at runtime.
+pub(super) enum MaybeKnown<T> {
+ /// The value is known at shader translation time.
+ Known(T),
+
+ /// The value is computed by the instruction with the given id.
+ Computed(Word),
+}
+
+impl<'w> BlockContext<'w> {
+ /// Emit code to compute the length of a run-time array.
+ ///
+ /// Given `array`, an expression referring a runtime-sized array, return the
+ /// instruction id for the array's length.
+ pub(super) fn write_runtime_array_length(
+ &mut self,
+ array: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ // Naga IR permits runtime-sized arrays as global variables or as the
+ // final member of a struct that is a global variable. SPIR-V permits
+ // only the latter, so this back end wraps bare runtime-sized arrays
+ // in a made-up struct; see `helpers::global_needs_wrapper` and its uses.
+ // This code must handle both cases.
+ let (structure_id, last_member_index) = match self.ir_function.expressions[array] {
+ crate::Expression::AccessIndex { base, index } => {
+ match self.ir_function.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => (
+ self.writer.global_variables[handle.index()].access_id,
+ index,
+ ),
+ _ => return Err(Error::Validation("array length expression")),
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let global = &self.ir_module.global_variables[handle];
+ if !global_needs_wrapper(self.ir_module, global) {
+ return Err(Error::Validation("array length expression"));
+ }
+
+ (self.writer.global_variables[handle.index()].var_id, 0)
+ }
+ _ => return Err(Error::Validation("array length expression")),
+ };
+
+ let length_id = self.gen_id();
+ block.body.push(Instruction::array_length(
+ self.writer.get_uint_type_id(),
+ length_id,
+ structure_id,
+ last_member_index,
+ ));
+
+ Ok(length_id)
+ }
+
+ /// Compute the length of a subscriptable value.
+ ///
+ /// Given `sequence`, an expression referring to some indexable type, return
+ /// its length. The result may either be computed by SPIR-V instructions, or
+ /// known at shader translation time.
+ ///
+ /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
+ /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
+ /// sized, or use a specializable constant as its length.
+ fn write_sequence_length(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<MaybeKnown<u32>, Error> {
+ let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types);
+ match sequence_ty.indexable_length(self.ir_module) {
+ Ok(crate::proc::IndexableLength::Known(known_length)) => {
+ Ok(MaybeKnown::Known(known_length))
+ }
+ Ok(crate::proc::IndexableLength::Dynamic) => {
+ let length_id = self.write_runtime_array_length(sequence, block)?;
+ Ok(MaybeKnown::Computed(length_id))
+ }
+ Err(err) => {
+ log::error!("Sequence length for {:?} failed: {}", sequence, err);
+ Err(Error::Validation("indexable length"))
+ }
+ }
+ }
+
+ /// Compute the maximum valid index of a subscriptable value.
+ ///
+ /// Given `sequence`, an expression referring to some indexable type, return
+ /// its maximum valid index - one less than its length. The result may
+ /// either be computed, or known at shader translation time.
+ ///
+ /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
+ /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
+ /// sized, or use a specializable constant as its length.
+ fn write_sequence_max_index(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<MaybeKnown<u32>, Error> {
+ match self.write_sequence_length(sequence, block)? {
+ MaybeKnown::Known(known_length) => {
+ // We should have thrown out all attempts to subscript zero-length
+ // sequences during validation, so the following subtraction should never
+ // underflow.
+ assert!(known_length > 0);
+ // Compute the max index from the length now.
+ Ok(MaybeKnown::Known(known_length - 1))
+ }
+ MaybeKnown::Computed(length_id) => {
+ // Emit code to compute the max index from the length.
+ let const_one_id = self.get_index_constant(1);
+ let max_index_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ self.writer.get_uint_type_id(),
+ max_index_id,
+ length_id,
+ const_one_id,
+ ));
+ Ok(MaybeKnown::Computed(max_index_id))
+ }
+ }
+ }
+
+ /// Restrict an index to be in range for a vector, matrix, or array.
+ ///
+ /// This is used to implement `BoundsCheckPolicy::Restrict`. An in-bounds
+ /// index is left unchanged. An out-of-bounds index is replaced with some
+ /// arbitrary in-bounds index. Note,this is not necessarily clamping; for
+ /// example, negative indices might be changed to refer to the last element
+ /// of the sequence, not the first, as clamping would do.
+ ///
+ /// Either return the restricted index value, if known, or add instructions
+ /// to `block` to compute it, and return the id of the result. See the
+ /// documentation for `BoundsCheckResult` for details.
+ ///
+ /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
+ /// `Pointer` to any of those, or a `ValuePointer`. An array may be
+ /// fixed-size, dynamically sized, or use a specializable constant as its
+ /// length.
+ pub(super) fn write_restricted_index(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let index_id = self.cached[index];
+
+ // Get the sequence's maximum valid index. Return early if we've already
+ // done the bounds check.
+ let max_index_id = match self.write_sequence_max_index(sequence, block)? {
+ MaybeKnown::Known(known_max_index) => {
+ if let crate::Expression::Constant(index_k) = self.ir_function.expressions[index] {
+ if let Some(known_index) = self.ir_module.constants[index_k].to_array_length() {
+ // Both the index and length are known at compile time.
+ //
+ // In strict WGSL compliance mode, out-of-bounds indices cannot be
+ // reported at shader translation time, and must be replaced with
+ // in-bounds indices at run time. So we cannot assume that
+ // validation ensured the index was in bounds. Restrict now.
+ let restricted = std::cmp::min(known_index, known_max_index);
+ return Ok(BoundsCheckResult::KnownInBounds(restricted));
+ }
+ }
+
+ self.get_index_constant(known_max_index)
+ }
+ MaybeKnown::Computed(max_index_id) => max_index_id,
+ };
+
+ // One or the other of the index or length is dynamic, so emit code for
+ // BoundsCheckPolicy::Restrict.
+ let restricted_index_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ self.writer.get_uint_type_id(),
+ restricted_index_id,
+ &[index_id, max_index_id],
+ ));
+ Ok(BoundsCheckResult::Computed(restricted_index_id))
+ }
+
+ /// Write an index bounds comparison to `block`, if needed.
+ ///
+ /// If we're able to determine statically that `index` is in bounds for
+ /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual
+ /// value of the index. (In principle, one could know that the index is in
+ /// bounds without knowing its specific value, but in our simple-minded
+ /// situation, we always know it.)
+ ///
+ /// If instead we must generate code to perform the comparison at run time,
+ /// return `Conditional(comparison_id)`, where `comparison_id` is an
+ /// instruction producing a boolean value that is true if `index` is in
+ /// bounds for `sequence`.
+ ///
+ /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
+ /// `Pointer` to any of those, or a `ValuePointer`. An array may be
+ /// fixed-size, dynamically sized, or use a specializable constant as its
+ /// length.
+ fn write_index_comparison(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let index_id = self.cached[index];
+
+ // Get the sequence's length. Return early if we've already done the
+ // bounds check.
+ let length_id = match self.write_sequence_length(sequence, block)? {
+ MaybeKnown::Known(known_length) => {
+ if let crate::Expression::Constant(index_k) = self.ir_function.expressions[index] {
+ if let Some(known_index) = self.ir_module.constants[index_k].to_array_length() {
+ // Both the index and length are known at compile time.
+ //
+ // It would be nice to assume that, since we are using the
+ // `ReadZeroSkipWrite` policy, we are not in strict WGSL
+ // compliance mode, and thus we can count on the validator to have
+ // rejected any programs with known out-of-bounds indices, and
+ // thus just return `KnownInBounds` here without actually
+ // checking.
+ //
+ // But it's also reasonable to expect that bounds check policies
+ // and error reporting policies should be able to vary
+ // independently without introducing security holes. So, we should
+ // support the case where bad indices do not cause validation
+ // errors, and are handled via `ReadZeroSkipWrite`.
+ //
+ // In theory, when `known_index` is bad, we could return a new
+ // `KnownOutOfBounds` variant here. But it's simpler just to fall
+ // through and let the bounds check take place. The shader is
+ // broken anyway, so it doesn't make sense to invest in emitting
+ // the ideal code for it.
+ if known_index < known_length {
+ return Ok(BoundsCheckResult::KnownInBounds(known_index));
+ }
+ }
+ }
+
+ self.get_index_constant(known_length)
+ }
+ MaybeKnown::Computed(length_id) => length_id,
+ };
+
+ // Compare the index against the length.
+ let condition_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ self.writer.get_bool_type_id(),
+ condition_id,
+ index_id,
+ length_id,
+ ));
+
+ // Indicate that we did generate the check.
+ Ok(BoundsCheckResult::Conditional(condition_id))
+ }
+
+ /// Emit a conditional load for `BoundsCheckPolicy::ReadZeroSkipWrite`.
+ ///
+ /// Generate code to load a value of `result_type` if `condition` is true,
+ /// and generate a null value of that type if it is false. Call `emit_load`
+ /// to emit the instructions to perform the load. Return the id of the
+ /// merged value of the two branches.
+ pub(super) fn write_conditional_indexed_load<F>(
+ &mut self,
+ result_type: Word,
+ condition: Word,
+ block: &mut Block,
+ emit_load: F,
+ ) -> Word
+ where
+ F: FnOnce(&mut IdGenerator, &mut Block) -> Word,
+ {
+ // For the out-of-bounds case, we produce a zero value.
+ let null_id = self.writer.write_constant_null(result_type);
+
+ let mut selection = Selection::start(block, result_type);
+
+ // As it turns out, we don't actually need a full 'if-then-else'
+ // structure for this: SPIR-V constants are declared up front, so the
+ // 'else' block would have no instructions. Instead we emit something
+ // like this:
+ //
+ // result = zero;
+ // if in_bounds {
+ // result = do the load;
+ // }
+ // use result;
+
+ // Continue only if the index was in bounds. Otherwise, branch to the
+ // merge block.
+ selection.if_true(self, condition, null_id);
+
+ // The in-bounds path. Perform the access and the load.
+ let loaded_value = emit_load(&mut self.writer.id_gen, selection.block());
+
+ selection.finish(self, loaded_value)
+ }
+
+ /// Emit code for bounds checks for an array, vector, or matrix access.
+ ///
+ /// This implements either `index_bounds_check_policy` or
+ /// `buffer_bounds_check_policy`, depending on the address space of the
+ /// pointer being accessed.
+ ///
+ /// Return a `BoundsCheckResult` indicating how the index should be
+ /// consumed. See that type's documentation for details.
+ pub(super) fn write_bounds_check(
+ &mut self,
+ base: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let policy = self.writer.bounds_check_policies.choose_policy(
+ base,
+ &self.ir_module.types,
+ self.fun_info,
+ );
+
+ Ok(match policy {
+ BoundsCheckPolicy::Restrict => self.write_restricted_index(base, index, block)?,
+ BoundsCheckPolicy::ReadZeroSkipWrite => {
+ self.write_index_comparison(base, index, block)?
+ }
+ BoundsCheckPolicy::Unchecked => BoundsCheckResult::Computed(self.cached[index]),
+ })
+ }
+
+ /// Emit code to subscript a vector by value with a computed index.
+ ///
+ /// Return the id of the element value.
+ pub(super) fn write_vector_access(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ base: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
+
+ let base_id = self.cached[base];
+ let index_id = self.cached[index];
+
+ let result_id = match self.write_bounds_check(base, index, block)? {
+ BoundsCheckResult::KnownInBounds(known_index) => {
+ let result_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ result_id,
+ base_id,
+ &[known_index],
+ ));
+ result_id
+ }
+ BoundsCheckResult::Computed(computed_index_id) => {
+ let result_id = self.gen_id();
+ block.body.push(Instruction::vector_extract_dynamic(
+ result_type_id,
+ result_id,
+ base_id,
+ computed_index_id,
+ ));
+ result_id
+ }
+ BoundsCheckResult::Conditional(comparison_id) => {
+ // Run-time bounds checks were required. Emit
+ // conditional load.
+ self.write_conditional_indexed_load(
+ result_type_id,
+ comparison_id,
+ block,
+ |id_gen, block| {
+ // The in-bounds path. Generate the access.
+ let element_id = id_gen.next();
+ block.body.push(Instruction::vector_extract_dynamic(
+ result_type_id,
+ element_id,
+ base_id,
+ index_id,
+ ));
+ element_id
+ },
+ )
+ }
+ };
+
+ Ok(result_id)
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs
new file mode 100644
index 0000000000..96d0278285
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/instructions.rs
@@ -0,0 +1,1063 @@
+use super::helpers;
+use spirv::{Op, Word};
+
+pub(super) enum Signedness {
+ Unsigned = 0,
+ Signed = 1,
+}
+
+pub(super) enum SampleLod {
+ Explicit,
+ Implicit,
+}
+
+pub(super) struct Case {
+ pub value: Word,
+ pub label_id: Word,
+}
+
+impl super::Instruction {
+ //
+ // Debug Instructions
+ //
+
+ pub(super) fn source(source_language: spirv::SourceLanguage, version: u32) -> Self {
+ let mut instruction = Self::new(Op::Source);
+ instruction.add_operand(source_language as u32);
+ instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes()));
+ instruction
+ }
+
+ pub(super) fn name(target_id: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::Name);
+ instruction.add_operand(target_id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn member_name(target_id: Word, member: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::MemberName);
+ instruction.add_operand(target_id);
+ instruction.add_operand(member);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ //
+ // Annotation Instructions
+ //
+
+ pub(super) fn decorate(
+ target_id: Word,
+ decoration: spirv::Decoration,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::Decorate);
+ instruction.add_operand(target_id);
+ instruction.add_operand(decoration as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ pub(super) fn member_decorate(
+ target_id: Word,
+ member_index: Word,
+ decoration: spirv::Decoration,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::MemberDecorate);
+ instruction.add_operand(target_id);
+ instruction.add_operand(member_index);
+ instruction.add_operand(decoration as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ //
+ // Extension Instructions
+ //
+
+ pub(super) fn extension(name: &str) -> Self {
+ let mut instruction = Self::new(Op::Extension);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn ext_inst_import(id: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::ExtInstImport);
+ instruction.set_result(id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn ext_inst(
+ set_id: Word,
+ op: spirv::GLOp,
+ result_type_id: Word,
+ id: Word,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ExtInst);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(set_id);
+ instruction.add_operand(op as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ //
+ // Mode-Setting Instructions
+ //
+
+ pub(super) fn memory_model(
+ addressing_model: spirv::AddressingModel,
+ memory_model: spirv::MemoryModel,
+ ) -> Self {
+ let mut instruction = Self::new(Op::MemoryModel);
+ instruction.add_operand(addressing_model as u32);
+ instruction.add_operand(memory_model as u32);
+ instruction
+ }
+
+ pub(super) fn entry_point(
+ execution_model: spirv::ExecutionModel,
+ entry_point_id: Word,
+ name: &str,
+ interface_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::EntryPoint);
+ instruction.add_operand(execution_model as u32);
+ instruction.add_operand(entry_point_id);
+ instruction.add_operands(helpers::string_to_words(name));
+
+ for interface_id in interface_ids {
+ instruction.add_operand(*interface_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn execution_mode(
+ entry_point_id: Word,
+ execution_mode: spirv::ExecutionMode,
+ args: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ExecutionMode);
+ instruction.add_operand(entry_point_id);
+ instruction.add_operand(execution_mode as u32);
+ for arg in args {
+ instruction.add_operand(*arg);
+ }
+ instruction
+ }
+
+ pub(super) fn capability(capability: spirv::Capability) -> Self {
+ let mut instruction = Self::new(Op::Capability);
+ instruction.add_operand(capability as u32);
+ instruction
+ }
+
+ //
+ // Type-Declaration Instructions
+ //
+
+ pub(super) fn type_void(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeVoid);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_bool(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeBool);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_int(id: Word, width: Word, signedness: Signedness) -> Self {
+ let mut instruction = Self::new(Op::TypeInt);
+ instruction.set_result(id);
+ instruction.add_operand(width);
+ instruction.add_operand(signedness as u32);
+ instruction
+ }
+
+ pub(super) fn type_float(id: Word, width: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeFloat);
+ instruction.set_result(id);
+ instruction.add_operand(width);
+ instruction
+ }
+
+ pub(super) fn type_vector(
+ id: Word,
+ component_type_id: Word,
+ component_count: crate::VectorSize,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeVector);
+ instruction.set_result(id);
+ instruction.add_operand(component_type_id);
+ instruction.add_operand(component_count as u32);
+ instruction
+ }
+
+ pub(super) fn type_matrix(
+ id: Word,
+ column_type_id: Word,
+ column_count: crate::VectorSize,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeMatrix);
+ instruction.set_result(id);
+ instruction.add_operand(column_type_id);
+ instruction.add_operand(column_count as u32);
+ instruction
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn type_image(
+ id: Word,
+ sampled_type_id: Word,
+ dim: spirv::Dim,
+ flags: super::ImageTypeFlags,
+ image_format: spirv::ImageFormat,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeImage);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_type_id);
+ instruction.add_operand(dim as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::DEPTH) as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::ARRAYED) as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::MULTISAMPLED) as u32);
+ instruction.add_operand(if flags.contains(super::ImageTypeFlags::SAMPLED) {
+ 1
+ } else {
+ 2
+ });
+ instruction.add_operand(image_format as u32);
+ instruction
+ }
+
+ pub(super) fn type_sampler(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeSampler);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_acceleration_structure(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeAccelerationStructureKHR);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_ray_query(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeRayQueryKHR);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_sampled_image(id: Word, image_type_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeSampledImage);
+ instruction.set_result(id);
+ instruction.add_operand(image_type_id);
+ instruction
+ }
+
+ pub(super) fn type_array(id: Word, element_type_id: Word, length_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeArray);
+ instruction.set_result(id);
+ instruction.add_operand(element_type_id);
+ instruction.add_operand(length_id);
+ instruction
+ }
+
+ pub(super) fn type_runtime_array(id: Word, element_type_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeRuntimeArray);
+ instruction.set_result(id);
+ instruction.add_operand(element_type_id);
+ instruction
+ }
+
+ pub(super) fn type_struct(id: Word, member_ids: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::TypeStruct);
+ instruction.set_result(id);
+
+ for member_id in member_ids {
+ instruction.add_operand(*member_id)
+ }
+
+ instruction
+ }
+
+ pub(super) fn type_pointer(
+ id: Word,
+ storage_class: spirv::StorageClass,
+ type_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypePointer);
+ instruction.set_result(id);
+ instruction.add_operand(storage_class as u32);
+ instruction.add_operand(type_id);
+ instruction
+ }
+
+ pub(super) fn type_function(id: Word, return_type_id: Word, parameter_ids: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::TypeFunction);
+ instruction.set_result(id);
+ instruction.add_operand(return_type_id);
+
+ for parameter_id in parameter_ids {
+ instruction.add_operand(*parameter_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Constant-Creation Instructions
+ //
+
+ pub(super) fn constant_null(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantNull);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant_true(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantTrue);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant_false(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantFalse);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant(result_type_id: Word, id: Word, values: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::Constant);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for value in values {
+ instruction.add_operand(*value);
+ }
+
+ instruction
+ }
+
+ pub(super) fn constant_composite(
+ result_type_id: Word,
+ id: Word,
+ constituent_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ConstantComposite);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for constituent_id in constituent_ids {
+ instruction.add_operand(*constituent_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Memory Instructions
+ //
+
+ pub(super) fn variable(
+ result_type_id: Word,
+ id: Word,
+ storage_class: spirv::StorageClass,
+ initializer_id: Option<Word>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Variable);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(storage_class as u32);
+
+ if let Some(initializer_id) = initializer_id {
+ instruction.add_operand(initializer_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn load(
+ result_type_id: Word,
+ id: Word,
+ pointer_id: Word,
+ memory_access: Option<spirv::MemoryAccess>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Load);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer_id);
+
+ if let Some(memory_access) = memory_access {
+ instruction.add_operand(memory_access.bits());
+ }
+
+ instruction
+ }
+
+ pub(super) fn atomic_load(
+ result_type_id: Word,
+ id: Word,
+ pointer_id: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::AtomicLoad);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction
+ }
+
+ pub(super) fn store(
+ pointer_id: Word,
+ value_id: Word,
+ memory_access: Option<spirv::MemoryAccess>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Store);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(value_id);
+
+ if let Some(memory_access) = memory_access {
+ instruction.add_operand(memory_access.bits());
+ }
+
+ instruction
+ }
+
+ pub(super) fn atomic_store(
+ pointer_id: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ value_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::AtomicStore);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction.add_operand(value_id);
+ instruction
+ }
+
+ pub(super) fn access_chain(
+ result_type_id: Word,
+ id: Word,
+ base_id: Word,
+ index_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::AccessChain);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(base_id);
+
+ for index_id in index_ids {
+ instruction.add_operand(*index_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn array_length(
+ result_type_id: Word,
+ id: Word,
+ structure_id: Word,
+ array_member: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::ArrayLength);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(structure_id);
+ instruction.add_operand(array_member);
+ instruction
+ }
+
+ //
+ // Function Instructions
+ //
+
+ pub(super) fn function(
+ return_type_id: Word,
+ id: Word,
+ function_control: spirv::FunctionControl,
+ function_type_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Function);
+ instruction.set_type(return_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(function_control.bits());
+ instruction.add_operand(function_type_id);
+ instruction
+ }
+
+ pub(super) fn function_parameter(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::FunctionParameter);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) const fn function_end() -> Self {
+ Self::new(Op::FunctionEnd)
+ }
+
+ pub(super) fn function_call(
+ result_type_id: Word,
+ id: Word,
+ function_id: Word,
+ argument_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::FunctionCall);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(function_id);
+
+ for argument_id in argument_ids {
+ instruction.add_operand(*argument_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Image Instructions
+ //
+
+ pub(super) fn sampled_image(
+ result_type_id: Word,
+ id: Word,
+ image: Word,
+ sampler: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::SampledImage);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction.add_operand(sampler);
+ instruction
+ }
+
+ pub(super) fn image_sample(
+ result_type_id: Word,
+ id: Word,
+ lod: SampleLod,
+ sampled_image: Word,
+ coordinates: Word,
+ depth_ref: Option<Word>,
+ ) -> Self {
+ let op = match (lod, depth_ref) {
+ (SampleLod::Explicit, None) => Op::ImageSampleExplicitLod,
+ (SampleLod::Implicit, None) => Op::ImageSampleImplicitLod,
+ (SampleLod::Explicit, Some(_)) => Op::ImageSampleDrefExplicitLod,
+ (SampleLod::Implicit, Some(_)) => Op::ImageSampleDrefImplicitLod,
+ };
+
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_image);
+ instruction.add_operand(coordinates);
+ if let Some(dref) = depth_ref {
+ instruction.add_operand(dref);
+ }
+
+ instruction
+ }
+
+ pub(super) fn image_gather(
+ result_type_id: Word,
+ id: Word,
+ sampled_image: Word,
+ coordinates: Word,
+ component_id: Word,
+ depth_ref: Option<Word>,
+ ) -> Self {
+ let op = match depth_ref {
+ None => Op::ImageGather,
+ Some(_) => Op::ImageDrefGather,
+ };
+
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_image);
+ instruction.add_operand(coordinates);
+ if let Some(dref) = depth_ref {
+ instruction.add_operand(dref);
+ } else {
+ instruction.add_operand(component_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn image_fetch_or_read(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ image: Word,
+ coordinates: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction.add_operand(coordinates);
+ instruction
+ }
+
+ pub(super) fn image_write(image: Word, coordinates: Word, value: Word) -> Self {
+ let mut instruction = Self::new(Op::ImageWrite);
+ instruction.add_operand(image);
+ instruction.add_operand(coordinates);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ pub(super) fn image_query(op: Op, result_type_id: Word, id: Word, image: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction
+ }
+
+ //
+ // Ray Query Instructions
+ //
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn ray_query_initialize(
+ query: Word,
+ acceleration_structure: Word,
+ ray_flags: Word,
+ cull_mask: Word,
+ ray_origin: Word,
+ ray_tmin: Word,
+ ray_dir: Word,
+ ray_tmax: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::RayQueryInitializeKHR);
+ instruction.add_operand(query);
+ instruction.add_operand(acceleration_structure);
+ instruction.add_operand(ray_flags);
+ instruction.add_operand(cull_mask);
+ instruction.add_operand(ray_origin);
+ instruction.add_operand(ray_tmin);
+ instruction.add_operand(ray_dir);
+ instruction.add_operand(ray_tmax);
+ instruction
+ }
+
+ pub(super) fn ray_query_proceed(result_type_id: Word, id: Word, query: Word) -> Self {
+ let mut instruction = Self::new(Op::RayQueryProceedKHR);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(query);
+ instruction
+ }
+
+ pub(super) fn ray_query_get_intersection(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ query: Word,
+ intersection: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(query);
+ instruction.add_operand(intersection);
+ instruction
+ }
+
+ //
+ // Conversion Instructions
+ //
+ pub(super) fn unary(op: Op, result_type_id: Word, id: Word, value: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ //
+ // Composite Instructions
+ //
+
+ pub(super) fn composite_construct(
+ result_type_id: Word,
+ id: Word,
+ constituent_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::CompositeConstruct);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for constituent_id in constituent_ids {
+ instruction.add_operand(*constituent_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn composite_extract(
+ result_type_id: Word,
+ id: Word,
+ composite_id: Word,
+ indices: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::CompositeExtract);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ instruction.add_operand(composite_id);
+ for index in indices {
+ instruction.add_operand(*index);
+ }
+
+ instruction
+ }
+
+ pub(super) fn vector_extract_dynamic(
+ result_type_id: Word,
+ id: Word,
+ vector_id: Word,
+ index_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::VectorExtractDynamic);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ instruction.add_operand(vector_id);
+ instruction.add_operand(index_id);
+
+ instruction
+ }
+
+ pub(super) fn vector_shuffle(
+ result_type_id: Word,
+ id: Word,
+ v1_id: Word,
+ v2_id: Word,
+ components: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::VectorShuffle);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(v1_id);
+ instruction.add_operand(v2_id);
+
+ for &component in components {
+ instruction.add_operand(component);
+ }
+
+ instruction
+ }
+
+ //
+ // Arithmetic Instructions
+ //
+ pub(super) fn binary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction
+ }
+
+ pub(super) fn ternary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ operand_3: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction.add_operand(operand_3);
+ instruction
+ }
+
+ pub(super) fn quaternary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ operand_3: Word,
+ operand_4: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction.add_operand(operand_3);
+ instruction.add_operand(operand_4);
+ instruction
+ }
+
+ pub(super) fn relational(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(expr_id);
+ instruction
+ }
+
+ pub(super) fn atomic_binary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ pointer: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ value: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ //
+ // Bit Instructions
+ //
+
+ //
+ // Relational and Logical Instructions
+ //
+
+ //
+ // Derivative Instructions
+ //
+
+ pub(super) fn derivative(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(expr_id);
+ instruction
+ }
+
+ //
+ // Control-Flow Instructions
+ //
+
+ pub(super) fn phi(
+ result_type_id: Word,
+ result_id: Word,
+ var_parent_pairs: &[(Word, Word)],
+ ) -> Self {
+ let mut instruction = Self::new(Op::Phi);
+ instruction.add_operand(result_type_id);
+ instruction.add_operand(result_id);
+ for &(variable, parent) in var_parent_pairs {
+ instruction.add_operand(variable);
+ instruction.add_operand(parent);
+ }
+ instruction
+ }
+
+ pub(super) fn selection_merge(
+ merge_id: Word,
+ selection_control: spirv::SelectionControl,
+ ) -> Self {
+ let mut instruction = Self::new(Op::SelectionMerge);
+ instruction.add_operand(merge_id);
+ instruction.add_operand(selection_control.bits());
+ instruction
+ }
+
+ pub(super) fn loop_merge(
+ merge_id: Word,
+ continuing_id: Word,
+ selection_control: spirv::SelectionControl,
+ ) -> Self {
+ let mut instruction = Self::new(Op::LoopMerge);
+ instruction.add_operand(merge_id);
+ instruction.add_operand(continuing_id);
+ instruction.add_operand(selection_control.bits());
+ instruction
+ }
+
+ pub(super) fn label(id: Word) -> Self {
+ let mut instruction = Self::new(Op::Label);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn branch(id: Word) -> Self {
+ let mut instruction = Self::new(Op::Branch);
+ instruction.add_operand(id);
+ instruction
+ }
+
+ // TODO Branch Weights not implemented.
+ pub(super) fn branch_conditional(
+ condition_id: Word,
+ true_label: Word,
+ false_label: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::BranchConditional);
+ instruction.add_operand(condition_id);
+ instruction.add_operand(true_label);
+ instruction.add_operand(false_label);
+ instruction
+ }
+
+ pub(super) fn switch(selector_id: Word, default_id: Word, cases: &[Case]) -> Self {
+ let mut instruction = Self::new(Op::Switch);
+ instruction.add_operand(selector_id);
+ instruction.add_operand(default_id);
+ for case in cases {
+ instruction.add_operand(case.value);
+ instruction.add_operand(case.label_id);
+ }
+ instruction
+ }
+
+ pub(super) fn select(
+ result_type_id: Word,
+ id: Word,
+ condition_id: Word,
+ accept_id: Word,
+ reject_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Select);
+ instruction.add_operand(result_type_id);
+ instruction.add_operand(id);
+ instruction.add_operand(condition_id);
+ instruction.add_operand(accept_id);
+ instruction.add_operand(reject_id);
+ instruction
+ }
+
+ pub(super) const fn kill() -> Self {
+ Self::new(Op::Kill)
+ }
+
+ pub(super) const fn return_void() -> Self {
+ Self::new(Op::Return)
+ }
+
+ pub(super) fn return_value(value_id: Word) -> Self {
+ let mut instruction = Self::new(Op::ReturnValue);
+ instruction.add_operand(value_id);
+ instruction
+ }
+
+ //
+ // Atomic Instructions
+ //
+
+ //
+ // Primitive Instructions
+ //
+
+ // Barriers
+
+ pub(super) fn control_barrier(
+ exec_scope_id: Word,
+ mem_scope_id: Word,
+ semantics_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::ControlBarrier);
+ instruction.add_operand(exec_scope_id);
+ instruction.add_operand(mem_scope_id);
+ instruction.add_operand(semantics_id);
+ instruction
+ }
+}
+
+impl From<crate::StorageFormat> for spirv::ImageFormat {
+ fn from(format: crate::StorageFormat) -> Self {
+ use crate::StorageFormat as Sf;
+ match format {
+ Sf::R8Unorm => Self::R8,
+ Sf::R8Snorm => Self::R8Snorm,
+ Sf::R8Uint => Self::R8ui,
+ Sf::R8Sint => Self::R8i,
+ Sf::R16Uint => Self::R16ui,
+ Sf::R16Sint => Self::R16i,
+ Sf::R16Float => Self::R16f,
+ Sf::Rg8Unorm => Self::Rg8,
+ Sf::Rg8Snorm => Self::Rg8Snorm,
+ Sf::Rg8Uint => Self::Rg8ui,
+ Sf::Rg8Sint => Self::Rg8i,
+ Sf::R32Uint => Self::R32ui,
+ Sf::R32Sint => Self::R32i,
+ Sf::R32Float => Self::R32f,
+ Sf::Rg16Uint => Self::Rg16ui,
+ Sf::Rg16Sint => Self::Rg16i,
+ Sf::Rg16Float => Self::Rg16f,
+ Sf::Rgba8Unorm => Self::Rgba8,
+ Sf::Rgba8Snorm => Self::Rgba8Snorm,
+ Sf::Rgba8Uint => Self::Rgba8ui,
+ Sf::Rgba8Sint => Self::Rgba8i,
+ Sf::Rgb10a2Unorm => Self::Rgb10a2ui,
+ Sf::Rg11b10Float => Self::R11fG11fB10f,
+ Sf::Rg32Uint => Self::Rg32ui,
+ Sf::Rg32Sint => Self::Rg32i,
+ Sf::Rg32Float => Self::Rg32f,
+ Sf::Rgba16Uint => Self::Rgba16ui,
+ Sf::Rgba16Sint => Self::Rgba16i,
+ Sf::Rgba16Float => Self::Rgba16f,
+ Sf::Rgba32Uint => Self::Rgba32ui,
+ Sf::Rgba32Sint => Self::Rgba32i,
+ Sf::Rgba32Float => Self::Rgba32f,
+ Sf::R16Unorm => Self::R16,
+ Sf::R16Snorm => Self::R16Snorm,
+ Sf::Rg16Unorm => Self::Rg16,
+ Sf::Rg16Snorm => Self::Rg16Snorm,
+ Sf::Rgba16Unorm => Self::Rgba16,
+ Sf::Rgba16Snorm => Self::Rgba16Snorm,
+ }
+ }
+}
+
+impl From<crate::ImageDimension> for spirv::Dim {
+ fn from(dim: crate::ImageDimension) -> Self {
+ use crate::ImageDimension as Id;
+ match dim {
+ Id::D1 => Self::Dim1D,
+ Id::D2 => Self::Dim2D,
+ Id::D3 => Self::Dim3D,
+ Id::Cube => Self::DimCube,
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/layout.rs b/third_party/rust/naga/src/back/spv/layout.rs
new file mode 100644
index 0000000000..39117a3d2a
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/layout.rs
@@ -0,0 +1,210 @@
+use super::{Instruction, LogicalLayout, PhysicalLayout};
+use spirv::{Op, Word, MAGIC_NUMBER};
+use std::iter;
+
+// https://github.com/KhronosGroup/SPIRV-Headers/pull/195
+const GENERATOR: Word = 28;
+
+impl PhysicalLayout {
+ pub(super) const fn new(version: Word) -> Self {
+ PhysicalLayout {
+ magic_number: MAGIC_NUMBER,
+ version,
+ generator: GENERATOR,
+ bound: 0,
+ instruction_schema: 0x0u32,
+ }
+ }
+
+ pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(iter::once(self.magic_number));
+ sink.extend(iter::once(self.version));
+ sink.extend(iter::once(self.generator));
+ sink.extend(iter::once(self.bound));
+ sink.extend(iter::once(self.instruction_schema));
+ }
+}
+
+impl super::recyclable::Recyclable for PhysicalLayout {
+ fn recycle(self) -> Self {
+ PhysicalLayout {
+ magic_number: self.magic_number,
+ version: self.version,
+ generator: self.generator,
+ instruction_schema: self.instruction_schema,
+ bound: 0,
+ }
+ }
+}
+
+impl LogicalLayout {
+ pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(self.capabilities.iter().cloned());
+ sink.extend(self.extensions.iter().cloned());
+ sink.extend(self.ext_inst_imports.iter().cloned());
+ sink.extend(self.memory_model.iter().cloned());
+ sink.extend(self.entry_points.iter().cloned());
+ sink.extend(self.execution_modes.iter().cloned());
+ sink.extend(self.debugs.iter().cloned());
+ sink.extend(self.annotations.iter().cloned());
+ sink.extend(self.declarations.iter().cloned());
+ sink.extend(self.function_declarations.iter().cloned());
+ sink.extend(self.function_definitions.iter().cloned());
+ }
+}
+
+impl super::recyclable::Recyclable for LogicalLayout {
+ fn recycle(self) -> Self {
+ Self {
+ capabilities: self.capabilities.recycle(),
+ extensions: self.extensions.recycle(),
+ ext_inst_imports: self.ext_inst_imports.recycle(),
+ memory_model: self.memory_model.recycle(),
+ entry_points: self.entry_points.recycle(),
+ execution_modes: self.execution_modes.recycle(),
+ debugs: self.debugs.recycle(),
+ annotations: self.annotations.recycle(),
+ declarations: self.declarations.recycle(),
+ function_declarations: self.function_declarations.recycle(),
+ function_definitions: self.function_definitions.recycle(),
+ }
+ }
+}
+
+impl Instruction {
+ pub(super) const fn new(op: Op) -> Self {
+ Instruction {
+ op,
+ wc: 1, // Always start at 1 for the first word (OP + WC),
+ type_id: None,
+ result_id: None,
+ operands: vec![],
+ }
+ }
+
+ #[allow(clippy::panic)]
+ pub(super) fn set_type(&mut self, id: Word) {
+ assert!(self.type_id.is_none(), "Type can only be set once");
+ self.type_id = Some(id);
+ self.wc += 1;
+ }
+
+ #[allow(clippy::panic)]
+ pub(super) fn set_result(&mut self, id: Word) {
+ assert!(self.result_id.is_none(), "Result can only be set once");
+ self.result_id = Some(id);
+ self.wc += 1;
+ }
+
+ pub(super) fn add_operand(&mut self, operand: Word) {
+ self.operands.push(operand);
+ self.wc += 1;
+ }
+
+ pub(super) fn add_operands(&mut self, operands: Vec<Word>) {
+ for operand in operands.into_iter() {
+ self.add_operand(operand)
+ }
+ }
+
+ pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(Some(self.wc << 16 | self.op as u32));
+ sink.extend(self.type_id);
+ sink.extend(self.result_id);
+ sink.extend(self.operands.iter().cloned());
+ }
+}
+
+impl Instruction {
+ #[cfg(test)]
+ fn validate(&self, words: &[Word]) {
+ let mut inst_index = 0;
+ let (wc, op) = ((words[inst_index] >> 16) as u16, words[inst_index] as u16);
+ inst_index += 1;
+
+ assert_eq!(wc, words.len() as u16);
+ assert_eq!(op, self.op as u16);
+
+ if self.type_id.is_some() {
+ assert_eq!(words[inst_index], self.type_id.unwrap());
+ inst_index += 1;
+ }
+
+ if self.result_id.is_some() {
+ assert_eq!(words[inst_index], self.result_id.unwrap());
+ inst_index += 1;
+ }
+
+ for (op_index, i) in (inst_index..wc as usize).enumerate() {
+ assert_eq!(words[i], self.operands[op_index]);
+ }
+ }
+}
+
+#[test]
+fn test_physical_layout_in_words() {
+ let bound = 5;
+ let version = 0x10203;
+
+ let mut output = vec![];
+ let mut layout = PhysicalLayout::new(version);
+ layout.bound = bound;
+
+ layout.in_words(&mut output);
+
+ assert_eq!(&output, &[MAGIC_NUMBER, version, GENERATOR, bound, 0,]);
+}
+
+#[test]
+fn test_logical_layout_in_words() {
+ let mut output = vec![];
+ let mut layout = LogicalLayout::default();
+ let layout_vectors = 11;
+ let mut instructions = Vec::with_capacity(layout_vectors);
+
+ let vector_names = &[
+ "Capabilities",
+ "Extensions",
+ "External Instruction Imports",
+ "Memory Model",
+ "Entry Points",
+ "Execution Modes",
+ "Debugs",
+ "Annotations",
+ "Declarations",
+ "Function Declarations",
+ "Function Definitions",
+ ];
+
+ for (i, _) in vector_names.iter().enumerate().take(layout_vectors) {
+ let mut dummy_instruction = Instruction::new(Op::Constant);
+ dummy_instruction.set_type((i + 1) as u32);
+ dummy_instruction.set_result((i + 2) as u32);
+ dummy_instruction.add_operand((i + 3) as u32);
+ dummy_instruction.add_operands(super::helpers::string_to_words(
+ format!("This is the vector: {}", vector_names[i]).as_str(),
+ ));
+ instructions.push(dummy_instruction);
+ }
+
+ instructions[0].to_words(&mut layout.capabilities);
+ instructions[1].to_words(&mut layout.extensions);
+ instructions[2].to_words(&mut layout.ext_inst_imports);
+ instructions[3].to_words(&mut layout.memory_model);
+ instructions[4].to_words(&mut layout.entry_points);
+ instructions[5].to_words(&mut layout.execution_modes);
+ instructions[6].to_words(&mut layout.debugs);
+ instructions[7].to_words(&mut layout.annotations);
+ instructions[8].to_words(&mut layout.declarations);
+ instructions[9].to_words(&mut layout.function_declarations);
+ instructions[10].to_words(&mut layout.function_definitions);
+
+ layout.in_words(&mut output);
+
+ let mut index: usize = 0;
+ for instruction in instructions {
+ let wc = instruction.wc as usize;
+ instruction.validate(&output[index..index + wc]);
+ index += wc;
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs
new file mode 100644
index 0000000000..9b084911b1
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/mod.rs
@@ -0,0 +1,729 @@
+/*!
+Backend for [SPIR-V][spv] (Standard Portable Intermediate Representation).
+
+[spv]: https://www.khronos.org/registry/SPIR-V/
+*/
+
+mod block;
+mod helpers;
+mod image;
+mod index;
+mod instructions;
+mod layout;
+mod ray;
+mod recyclable;
+mod selection;
+mod writer;
+
+pub use spirv::Capability;
+
+use crate::arena::Handle;
+use crate::proc::{BoundsCheckPolicies, TypeResolution};
+
+use spirv::Word;
+use std::ops;
+use thiserror::Error;
+
+#[derive(Clone)]
+struct PhysicalLayout {
+ magic_number: Word,
+ version: Word,
+ generator: Word,
+ bound: Word,
+ instruction_schema: Word,
+}
+
+#[derive(Default)]
+struct LogicalLayout {
+ capabilities: Vec<Word>,
+ extensions: Vec<Word>,
+ ext_inst_imports: Vec<Word>,
+ memory_model: Vec<Word>,
+ entry_points: Vec<Word>,
+ execution_modes: Vec<Word>,
+ debugs: Vec<Word>,
+ annotations: Vec<Word>,
+ declarations: Vec<Word>,
+ function_declarations: Vec<Word>,
+ function_definitions: Vec<Word>,
+}
+
+struct Instruction {
+ op: spirv::Op,
+ wc: u32,
+ type_id: Option<Word>,
+ result_id: Option<Word>,
+ operands: Vec<Word>,
+}
+
+const BITS_PER_BYTE: crate::Bytes = 8;
+
+#[derive(Clone, Debug, Error)]
+pub enum Error {
+ #[error("The requested entry point couldn't be found")]
+ EntryPointNotFound,
+ #[error("target SPIRV-{0}.{1} is not supported")]
+ UnsupportedVersion(u8, u8),
+ #[error("using {0} requires at least one of the capabilities {1:?}, but none are available")]
+ MissingCapabilities(&'static str, Vec<Capability>),
+ #[error("unimplemented {0}")]
+ FeatureNotImplemented(&'static str),
+ #[error("module is not validated properly: {0}")]
+ Validation(&'static str),
+}
+
+#[derive(Default)]
+struct IdGenerator(Word);
+
+impl IdGenerator {
+ fn next(&mut self) -> Word {
+ self.0 += 1;
+ self.0
+ }
+}
+
+/// A SPIR-V block to which we are still adding instructions.
+///
+/// A `Block` represents a SPIR-V block that does not yet have a termination
+/// instruction like `OpBranch` or `OpReturn`.
+///
+/// The `OpLabel` that starts the block is implicit. It will be emitted based on
+/// `label_id` when we write the block to a `LogicalLayout`.
+///
+/// To terminate a `Block`, pass the block and the termination instruction to
+/// `Function::consume`. This takes ownership of the `Block` and transforms it
+/// into a `TerminatedBlock`.
+struct Block {
+ label_id: Word,
+ body: Vec<Instruction>,
+}
+
+/// A SPIR-V block that ends with a termination instruction.
+struct TerminatedBlock {
+ label_id: Word,
+ body: Vec<Instruction>,
+}
+
+impl Block {
+ const fn new(label_id: Word) -> Self {
+ Block {
+ label_id,
+ body: Vec::new(),
+ }
+ }
+}
+
+struct LocalVariable {
+ id: Word,
+ instruction: Instruction,
+}
+
+struct ResultMember {
+ id: Word,
+ type_id: Word,
+ built_in: Option<crate::BuiltIn>,
+}
+
+struct EntryPointContext {
+ argument_ids: Vec<Word>,
+ results: Vec<ResultMember>,
+}
+
+#[derive(Default)]
+struct Function {
+ signature: Option<Instruction>,
+ parameters: Vec<FunctionArgument>,
+ variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
+ blocks: Vec<TerminatedBlock>,
+ entry_point_context: Option<EntryPointContext>,
+}
+
+impl Function {
+ fn consume(&mut self, mut block: Block, termination: Instruction) {
+ block.body.push(termination);
+ self.blocks.push(TerminatedBlock {
+ label_id: block.label_id,
+ body: block.body,
+ })
+ }
+
+ fn parameter_id(&self, index: u32) -> Word {
+ match self.entry_point_context {
+ Some(ref context) => context.argument_ids[index as usize],
+ None => self.parameters[index as usize]
+ .instruction
+ .result_id
+ .unwrap(),
+ }
+ }
+}
+
+/// Characteristics of a SPIR-V `OpTypeImage` type.
+///
+/// SPIR-V requires non-composite types to be unique, including images. Since we
+/// use `LocalType` for this deduplication, it's essential that `LocalImageType`
+/// be equal whenever the corresponding `OpTypeImage`s would be. To reduce the
+/// likelihood of mistakes, we use fields that correspond exactly to the
+/// operands of an `OpTypeImage` instruction, using the actual SPIR-V types
+/// where practical.
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+struct LocalImageType {
+ sampled_type: crate::ScalarKind,
+ dim: spirv::Dim,
+ flags: ImageTypeFlags,
+ image_format: spirv::ImageFormat,
+}
+
+bitflags::bitflags! {
+ /// Flags corresponding to the boolean(-ish) parameters to OpTypeImage.
+ pub struct ImageTypeFlags: u8 {
+ const DEPTH = 0x1;
+ const ARRAYED = 0x2;
+ const MULTISAMPLED = 0x4;
+ const SAMPLED = 0x8;
+ }
+}
+
+impl LocalImageType {
+ /// Construct a `LocalImageType` from the fields of a `TypeInner::Image`.
+ fn from_inner(dim: crate::ImageDimension, arrayed: bool, class: crate::ImageClass) -> Self {
+ let make_flags = |multi: bool, other: ImageTypeFlags| -> ImageTypeFlags {
+ let mut flags = other;
+ flags.set(ImageTypeFlags::ARRAYED, arrayed);
+ flags.set(ImageTypeFlags::MULTISAMPLED, multi);
+ flags
+ };
+
+ let dim = spirv::Dim::from(dim);
+
+ match class {
+ crate::ImageClass::Sampled { kind, multi } => LocalImageType {
+ sampled_type: kind,
+ dim,
+ flags: make_flags(multi, ImageTypeFlags::SAMPLED),
+ image_format: spirv::ImageFormat::Unknown,
+ },
+ crate::ImageClass::Depth { multi } => LocalImageType {
+ sampled_type: crate::ScalarKind::Float,
+ dim,
+ flags: make_flags(multi, ImageTypeFlags::DEPTH | ImageTypeFlags::SAMPLED),
+ image_format: spirv::ImageFormat::Unknown,
+ },
+ crate::ImageClass::Storage { format, access: _ } => LocalImageType {
+ sampled_type: crate::ScalarKind::from(format),
+ dim,
+ flags: make_flags(false, ImageTypeFlags::empty()),
+ image_format: format.into(),
+ },
+ }
+ }
+}
+
+/// A SPIR-V type constructed during code generation.
+///
+/// This is the variant of [`LookupType`] used to represent types that might not
+/// be available in the arena. Variants are present here for one of two reasons:
+///
+/// - They represent types synthesized during code generation, as explained
+/// in the documentation for [`LookupType`].
+///
+/// - They represent types for which SPIR-V forbids duplicate `OpType...`
+/// instructions, requiring deduplication.
+///
+/// This is not a complete copy of [`TypeInner`]: for example, SPIR-V generation
+/// never synthesizes new struct types, so `LocalType` has nothing for that.
+///
+/// Each `LocalType` variant should be handled identically to its analogous
+/// `TypeInner` variant. You can use the [`make_local`] function to help with
+/// this, by converting everything possible to a `LocalType` before inspecting
+/// it.
+///
+/// ## `Localtype` equality and SPIR-V `OpType` uniqueness
+///
+/// The definition of `Eq` on `LocalType` is carefully chosen to help us follow
+/// certain SPIR-V rules. SPIR-V §2.8 requires some classes of `OpType...`
+/// instructions to be unique; for example, you can't have two `OpTypeInt 32 1`
+/// instructions in the same module. All 32-bit signed integers must use the
+/// same type id.
+///
+/// All SPIR-V types that must be unique can be represented as a `LocalType`,
+/// and two `LocalType`s are always `Eq` if SPIR-V would require them to use the
+/// same `OpType...` instruction. This lets us avoid duplicates by recording the
+/// ids of the type instructions we've already generated in a hash table,
+/// [`Writer::lookup_type`], keyed by `LocalType`.
+///
+/// As another example, [`LocalImageType`], stored in the `LocalType::Image`
+/// variant, is designed to help us deduplicate `OpTypeImage` instructions. See
+/// its documentation for details.
+///
+/// `LocalType` also includes variants like `Pointer` that do not need to be
+/// unique - but it is harmless to avoid the duplication.
+///
+/// As it always must, the `Hash` implementation respects the `Eq` relation.
+///
+/// [`TypeInner`]: crate::TypeInner
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+enum LocalType {
+ /// A scalar, vector, or pointer to one of those.
+ Value {
+ /// If `None`, this represents a scalar type. If `Some`, this represents
+ /// a vector type of the given size.
+ vector_size: Option<crate::VectorSize>,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ pointer_space: Option<spirv::StorageClass>,
+ },
+ /// A matrix of floating-point values.
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+ Pointer {
+ base: Handle<crate::Type>,
+ class: spirv::StorageClass,
+ },
+ Image(LocalImageType),
+ SampledImage {
+ image_type_id: Word,
+ },
+ Sampler,
+ PointerToBindingArray {
+ base: Handle<crate::Type>,
+ size: u64,
+ },
+ BindingArray {
+ base: Handle<crate::Type>,
+ size: u64,
+ },
+ AccelerationStructure,
+ RayQuery,
+}
+
+/// A type encountered during SPIR-V generation.
+///
+/// In the process of writing SPIR-V, we need to synthesize various types for
+/// intermediate results and such: pointer types, vector/matrix component types,
+/// or even booleans, which usually appear in SPIR-V code even when they're not
+/// used by the module source.
+///
+/// However, we can't use `crate::Type` or `crate::TypeInner` for these, as the
+/// type arena may not contain what we need (it only contains types used
+/// directly by other parts of the IR), and the IR module is immutable, so we
+/// can't add anything to it.
+///
+/// So for local use in the SPIR-V writer, we use this type, which holds either
+/// a handle into the arena, or a [`LocalType`] containing something synthesized
+/// locally.
+///
+/// This is very similar to the [`proc::TypeResolution`] enum, with `LocalType`
+/// playing the role of `TypeInner`. However, `LocalType` also has other
+/// properties needed for SPIR-V generation; see the description of
+/// [`LocalType`] for details.
+///
+/// [`proc::TypeResolution`]: crate::proc::TypeResolution
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+enum LookupType {
+ Handle(Handle<crate::Type>),
+ Local(LocalType),
+}
+
+impl From<LocalType> for LookupType {
+ fn from(local: LocalType) -> Self {
+ Self::Local(local)
+ }
+}
+
+#[derive(Debug, PartialEq, Clone, Hash, Eq)]
+struct LookupFunctionType {
+ parameter_type_ids: Vec<Word>,
+ return_type_id: Word,
+}
+
+fn make_local(inner: &crate::TypeInner) -> Option<LocalType> {
+ Some(match *inner {
+ crate::TypeInner::Scalar { kind, width } | crate::TypeInner::Atomic { kind, width } => {
+ LocalType::Value {
+ vector_size: None,
+ kind,
+ width,
+ pointer_space: None,
+ }
+ }
+ crate::TypeInner::Vector { size, kind, width } => LocalType::Value {
+ vector_size: Some(size),
+ kind,
+ width,
+ pointer_space: None,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => LocalType::Matrix {
+ columns,
+ rows,
+ width,
+ },
+ crate::TypeInner::Pointer { base, space } => LocalType::Pointer {
+ base,
+ class: helpers::map_storage_class(space),
+ },
+ crate::TypeInner::ValuePointer {
+ size,
+ kind,
+ width,
+ space,
+ } => LocalType::Value {
+ vector_size: size,
+ kind,
+ width,
+ pointer_space: Some(helpers::map_storage_class(space)),
+ },
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
+ crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
+ crate::TypeInner::AccelerationStructure => LocalType::AccelerationStructure,
+ crate::TypeInner::RayQuery => LocalType::RayQuery,
+ crate::TypeInner::Array { .. }
+ | crate::TypeInner::Struct { .. }
+ | crate::TypeInner::BindingArray { .. } => return None,
+ })
+}
+
+#[derive(Debug)]
+enum Dimension {
+ Scalar,
+ Vector,
+ Matrix,
+}
+
+/// A map from evaluated [`Expression`](crate::Expression)s to their SPIR-V ids.
+///
+/// When we emit code to evaluate a given `Expression`, we record the
+/// SPIR-V id of its value here, under its `Handle<Expression>` index.
+///
+/// A `CachedExpressions` value can be indexed by a `Handle<Expression>` value.
+///
+/// [emit]: index.html#expression-evaluation-time-and-scope
+#[derive(Default)]
+struct CachedExpressions {
+ ids: Vec<Word>,
+}
+impl CachedExpressions {
+ fn reset(&mut self, length: usize) {
+ self.ids.clear();
+ self.ids.resize(length, 0);
+ }
+}
+impl ops::Index<Handle<crate::Expression>> for CachedExpressions {
+ type Output = Word;
+ fn index(&self, h: Handle<crate::Expression>) -> &Word {
+ let id = &self.ids[h.index()];
+ if *id == 0 {
+ unreachable!("Expression {:?} is not cached!", h);
+ }
+ id
+ }
+}
+impl ops::IndexMut<Handle<crate::Expression>> for CachedExpressions {
+ fn index_mut(&mut self, h: Handle<crate::Expression>) -> &mut Word {
+ let id = &mut self.ids[h.index()];
+ if *id != 0 {
+ unreachable!("Expression {:?} is already cached!", h);
+ }
+ id
+ }
+}
+impl recyclable::Recyclable for CachedExpressions {
+ fn recycle(self) -> Self {
+ CachedExpressions {
+ ids: self.ids.recycle(),
+ }
+ }
+}
+
+#[derive(Eq, Hash, PartialEq)]
+enum CachedConstant {
+ Scalar {
+ value: crate::ScalarValue,
+ width: crate::Bytes,
+ },
+ Composite {
+ ty: LookupType,
+ constituent_ids: Vec<Word>,
+ },
+}
+
+#[derive(Clone)]
+struct GlobalVariable {
+ /// ID of the OpVariable that declares the global.
+ ///
+ /// If you need the variable's value, use [`access_id`] instead of this
+ /// field. If we wrapped the Naga IR `GlobalVariable`'s type in a struct to
+ /// comply with Vulkan's requirements, then this points to the `OpVariable`
+ /// with the synthesized struct type, whereas `access_id` points to the
+ /// field of said struct that holds the variable's actual value.
+ ///
+ /// This is used to compute the `access_id` pointer in function prologues,
+ /// and used for `ArrayLength` expressions, which do need the struct.
+ ///
+ /// [`access_id`]: GlobalVariable::access_id
+ var_id: Word,
+
+ /// For `AddressSpace::Handle` variables, this ID is recorded in the function
+ /// prelude block (and reset before every function) as `OpLoad` of the variable.
+ /// It is then used for all the global ops, such as `OpImageSample`.
+ handle_id: Word,
+
+ /// Actual ID used to access this variable.
+ /// For wrapped buffer variables, this ID is `OpAccessChain` into the
+ /// wrapper. Otherwise, the same as `var_id`.
+ ///
+ /// Vulkan requires that globals in the `StorageBuffer` and `Uniform` storage
+ /// classes must be structs with the `Block` decoration, but WGSL and Naga IR
+ /// make no such requirement. So for such variables, we generate a wrapper struct
+ /// type with a single element of the type given by Naga, generate an
+ /// `OpAccessChain` for that member in the function prelude, and use that pointer
+ /// to refer to the global in the function body. This is the id of that access,
+ /// updated for each function in `write_function`.
+ access_id: Word,
+}
+
+impl GlobalVariable {
+ const fn dummy() -> Self {
+ Self {
+ var_id: 0,
+ handle_id: 0,
+ access_id: 0,
+ }
+ }
+
+ const fn new(id: Word) -> Self {
+ Self {
+ var_id: id,
+ handle_id: 0,
+ access_id: 0,
+ }
+ }
+
+ /// Prepare `self` for use within a single function.
+ fn reset_for_function(&mut self) {
+ self.handle_id = 0;
+ self.access_id = 0;
+ }
+}
+
+struct FunctionArgument {
+ /// Actual instruction of the argument.
+ instruction: Instruction,
+ handle_id: Word,
+}
+
+/// General information needed to emit SPIR-V for Naga statements.
+struct BlockContext<'w> {
+ /// The writer handling the module to which this code belongs.
+ writer: &'w mut Writer,
+
+ /// The [`Module`](crate::Module) for which we're generating code.
+ ir_module: &'w crate::Module,
+
+ /// The [`Function`](crate::Function) for which we're generating code.
+ ir_function: &'w crate::Function,
+
+ /// Information module validation produced about
+ /// [`ir_function`](BlockContext::ir_function).
+ fun_info: &'w crate::valid::FunctionInfo,
+
+ /// The [`spv::Function`](Function) to which we are contributing SPIR-V instructions.
+ function: &'w mut Function,
+
+ /// SPIR-V ids for expressions we've evaluated.
+ cached: CachedExpressions,
+
+ /// The `Writer`'s temporary vector, for convenience.
+ temp_list: Vec<Word>,
+}
+
+impl BlockContext<'_> {
+ fn gen_id(&mut self) -> Word {
+ self.writer.id_gen.next()
+ }
+
+ fn get_type_id(&mut self, lookup_type: LookupType) -> Word {
+ self.writer.get_type_id(lookup_type)
+ }
+
+ fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
+ self.writer.get_expression_type_id(tr)
+ }
+
+ fn get_index_constant(&mut self, index: Word) -> Word {
+ self.writer
+ .get_constant_scalar(crate::ScalarValue::Uint(index as _), 4)
+ }
+
+ fn get_scope_constant(&mut self, scope: Word) -> Word {
+ self.writer
+ .get_constant_scalar(crate::ScalarValue::Sint(scope as _), 4)
+ }
+}
+
+#[derive(Clone, Copy, Default)]
+struct LoopContext {
+ continuing_id: Option<Word>,
+ break_id: Option<Word>,
+}
+
+pub struct Writer {
+ physical_layout: PhysicalLayout,
+ logical_layout: LogicalLayout,
+ id_gen: IdGenerator,
+
+ /// The set of capabilities modules are permitted to use.
+ ///
+ /// This is initialized from `Options::capabilities`.
+ capabilities_available: Option<crate::FastHashSet<Capability>>,
+
+ /// The set of capabilities used by this module.
+ ///
+ /// If `capabilities_available` is `Some`, then this is always a subset of
+ /// that.
+ capabilities_used: crate::FastHashSet<Capability>,
+
+ /// The set of spirv extensions used.
+ extensions_used: crate::FastHashSet<&'static str>,
+
+ debugs: Vec<Instruction>,
+ annotations: Vec<Instruction>,
+ flags: WriterFlags,
+ bounds_check_policies: BoundsCheckPolicies,
+ zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
+ void_type: Word,
+ //TODO: convert most of these into vectors, addressable by handle indices
+ lookup_type: crate::FastHashMap<LookupType, Word>,
+ lookup_function: crate::FastHashMap<Handle<crate::Function>, Word>,
+ lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>,
+ constant_ids: Vec<Word>,
+ cached_constants: crate::FastHashMap<CachedConstant, Word>,
+ global_variables: Vec<GlobalVariable>,
+ binding_map: BindingMap,
+
+ // Cached expressions are only meaningful within a BlockContext, but we
+ // retain the table here between functions to save heap allocations.
+ saved_cached: CachedExpressions,
+
+ gl450_ext_inst_id: Word,
+ // Just a temporary list of SPIR-V ids
+ temp_list: Vec<Word>,
+}
+
+bitflags::bitflags! {
+ pub struct WriterFlags: u32 {
+ /// Include debug labels for everything.
+ const DEBUG = 0x1;
+ /// Flip Y coordinate of `BuiltIn::Position` output.
+ const ADJUST_COORDINATE_SPACE = 0x2;
+ /// Emit `OpName` for input/output locations.
+ /// Contrary to spec, some drivers treat it as semantic, not allowing
+ /// any conflicts.
+ const LABEL_VARYINGS = 0x4;
+ /// Emit `PointSize` output builtin to vertex shaders, which is
+ /// required for drawing with `PointList` topology.
+ const FORCE_POINT_SIZE = 0x8;
+ /// Clamp `BuiltIn::FragDepth` output between 0 and 1.
+ const CLAMP_FRAG_DEPTH = 0x10;
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BindingInfo {
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindingInfo>;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum ZeroInitializeWorkgroupMemoryMode {
+ /// Via `VK_KHR_zero_initialize_workgroup_memory` or Vulkan 1.3
+ Native,
+ /// Via assignments + barrier
+ Polyfill,
+ None,
+}
+
+#[derive(Debug, Clone)]
+pub struct Options {
+ /// (Major, Minor) target version of the SPIR-V.
+ pub lang_version: (u8, u8),
+
+ /// Configuration flags for the writer.
+ pub flags: WriterFlags,
+
+ /// Map of resources to information about the binding.
+ pub binding_map: BindingMap,
+
+ /// If given, the set of capabilities modules are allowed to use. Code that
+ /// requires capabilities beyond these is rejected with an error.
+ ///
+ /// If this is `None`, all capabilities are permitted.
+ pub capabilities: Option<crate::FastHashSet<Capability>>,
+
+ /// How should generate code handle array, vector, matrix, or image texel
+ /// indices that are out of range?
+ pub bounds_check_policies: BoundsCheckPolicies,
+
+ /// Dictates the way workgroup variables should be zero initialized
+ pub zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ let mut flags = WriterFlags::ADJUST_COORDINATE_SPACE
+ | WriterFlags::LABEL_VARYINGS
+ | WriterFlags::CLAMP_FRAG_DEPTH;
+ if cfg!(debug_assertions) {
+ flags |= WriterFlags::DEBUG;
+ }
+ Options {
+ lang_version: (1, 0),
+ flags,
+ binding_map: BindingMap::default(),
+ capabilities: None,
+ bounds_check_policies: crate::proc::BoundsCheckPolicies::default(),
+ zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill,
+ }
+ }
+}
+
+// A subset of options meant to be changed per pipeline.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// The stage of the entry point.
+ pub shader_stage: crate::ShaderStage,
+ /// The name of the entry point.
+ ///
+ /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown.
+ pub entry_point: String,
+}
+
+pub fn write_vec(
+ module: &crate::Module,
+ info: &crate::valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: Option<&PipelineOptions>,
+) -> Result<Vec<u32>, Error> {
+ let mut words = Vec::new();
+ let mut w = Writer::new(options)?;
+ w.write(module, info, pipeline_options, &mut words)?;
+ Ok(words)
+}
diff --git a/third_party/rust/naga/src/back/spv/ray.rs b/third_party/rust/naga/src/back/spv/ray.rs
new file mode 100644
index 0000000000..79eb2ff971
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/ray.rs
@@ -0,0 +1,273 @@
+/*!
+Generating SPIR-V for ray query operations.
+*/
+
+use super::{Block, BlockContext, Instruction, LocalType, LookupType};
+use crate::arena::Handle;
+
+impl<'w> BlockContext<'w> {
+ pub(super) fn write_ray_query_function(
+ &mut self,
+ query: Handle<crate::Expression>,
+ function: &crate::RayQueryFunction,
+ block: &mut Block,
+ ) {
+ let query_id = self.cached[query];
+ match *function {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ //Note: composite extract indices and types must match `generate_ray_desc_type`
+ let desc_id = self.cached[descriptor];
+ let acc_struct_id = self.get_image_id(acceleration_structure);
+ let width = 4;
+
+ let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Uint,
+ width,
+ pointer_space: None,
+ }));
+ let ray_flags_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ flag_type_id,
+ ray_flags_id,
+ desc_id,
+ &[0],
+ ));
+ let cull_mask_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ flag_type_id,
+ cull_mask_id,
+ desc_id,
+ &[1],
+ ));
+
+ let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width,
+ pointer_space: None,
+ }));
+ let tmin_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ scalar_type_id,
+ tmin_id,
+ desc_id,
+ &[2],
+ ));
+ let tmax_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ scalar_type_id,
+ tmax_id,
+ desc_id,
+ &[3],
+ ));
+
+ let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ kind: crate::ScalarKind::Float,
+ width,
+ pointer_space: None,
+ }));
+ let ray_origin_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ ray_origin_id,
+ desc_id,
+ &[4],
+ ));
+ let ray_dir_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ ray_dir_id,
+ desc_id,
+ &[5],
+ ));
+
+ block.body.push(Instruction::ray_query_initialize(
+ query_id,
+ acc_struct_id,
+ ray_flags_id,
+ cull_mask_id,
+ ray_origin_id,
+ tmin_id,
+ ray_dir_id,
+ tmax_id,
+ ));
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ let id = self.gen_id();
+ self.cached[result] = id;
+ let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
+
+ block
+ .body
+ .push(Instruction::ray_query_proceed(result_type_id, id, query_id));
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ }
+
+ pub(super) fn write_ray_query_get_intersection(
+ &mut self,
+ query: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> spirv::Word {
+ let width = 4;
+ let query_id = self.cached[query];
+ let intersection_id = self.writer.get_constant_scalar(
+ crate::ScalarValue::Uint(
+ spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
+ ),
+ width,
+ );
+
+ let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Uint,
+ width,
+ pointer_space: None,
+ }));
+ let kind_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionTypeKHR,
+ flag_type_id,
+ kind_id,
+ query_id,
+ intersection_id,
+ ));
+ let instance_custom_index_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
+ flag_type_id,
+ instance_custom_index_id,
+ query_id,
+ intersection_id,
+ ));
+ let instance_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
+ flag_type_id,
+ instance_id,
+ query_id,
+ intersection_id,
+ ));
+ let sbt_record_offset_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
+ flag_type_id,
+ sbt_record_offset_id,
+ query_id,
+ intersection_id,
+ ));
+ let geometry_index_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
+ flag_type_id,
+ geometry_index_id,
+ query_id,
+ intersection_id,
+ ));
+ let primitive_index_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
+ flag_type_id,
+ primitive_index_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width,
+ pointer_space: None,
+ }));
+ let t_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionTKHR,
+ scalar_type_id,
+ t_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let barycentrics_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Bi),
+ kind: crate::ScalarKind::Float,
+ width,
+ pointer_space: None,
+ }));
+ let barycentrics_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
+ barycentrics_type_id,
+ barycentrics_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ pointer_space: None,
+ }));
+ let front_face_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
+ bool_type_id,
+ front_face_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let transform_type_id = self.get_type_id(LookupType::Local(LocalType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width,
+ }));
+ let object_to_world_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
+ transform_type_id,
+ object_to_world_id,
+ query_id,
+ intersection_id,
+ ));
+ let world_to_object_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
+ transform_type_id,
+ world_to_object_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let id = self.gen_id();
+ let intersection_type_id = self.get_type_id(LookupType::Handle(
+ self.ir_module.special_types.ray_intersection.unwrap(),
+ ));
+ //Note: the arguments must match `generate_ray_intersection_type` layout
+ block.body.push(Instruction::composite_construct(
+ intersection_type_id,
+ id,
+ &[
+ kind_id,
+ t_id,
+ instance_custom_index_id,
+ instance_id,
+ sbt_record_offset_id,
+ geometry_index_id,
+ primitive_index_id,
+ barycentrics_id,
+ front_face_id,
+ object_to_world_id,
+ world_to_object_id,
+ ],
+ ));
+ id
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/recyclable.rs b/third_party/rust/naga/src/back/spv/recyclable.rs
new file mode 100644
index 0000000000..49f3a02741
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/recyclable.rs
@@ -0,0 +1,60 @@
+/*!
+Reusing collections' previous allocations.
+*/
+
+/// A value that can be reset to its initial state, retaining its current allocations.
+///
+/// Naga attempts to lower the cost of SPIR-V generation by allowing clients to
+/// reuse the same `Writer` for multiple Module translations. Reusing a `Writer`
+/// means that the `Vec`s, `HashMap`s, and other heap-allocated structures the
+/// `Writer` uses internally begin the translation with heap-allocated buffers
+/// ready to use.
+///
+/// But this approach introduces the risk of `Writer` state leaking from one
+/// module to the next. When a developer adds fields to `Writer` or its internal
+/// types, they must remember to reset their contents between modules.
+///
+/// One trick to ensure that every field has been accounted for is to use Rust's
+/// struct literal syntax to construct a new, reset value. If a developer adds a
+/// field, but neglects to update the reset code, the compiler will complain
+/// that a field is missing from the literal. This trait's `recycle` method
+/// takes `self` by value, and returns `Self` by value, encouraging the use of
+/// struct literal expressions in its implementation.
+pub trait Recyclable {
+ /// Clear `self`, retaining its current memory allocations.
+ ///
+ /// Shrink the buffer if it's currently much larger than was actually used.
+ /// This prevents a module with exceptionally large allocations from causing
+ /// the `Writer` to retain more memory than it needs indefinitely.
+ fn recycle(self) -> Self;
+}
+
+// Stock values for various collections.
+
+impl<T> Recyclable for Vec<T> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K, V, S: Clone> Recyclable for std::collections::HashMap<K, V, S> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K, S: Clone> Recyclable for std::collections::HashSet<K, S> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K: Ord, V> Recyclable for std::collections::BTreeMap<K, V> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/selection.rs b/third_party/rust/naga/src/back/spv/selection.rs
new file mode 100644
index 0000000000..788b1f10ab
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/selection.rs
@@ -0,0 +1,257 @@
+/*!
+Generate SPIR-V conditional structures.
+
+Builders for `if` structures with `and`s.
+
+The types in this module track the information needed to emit SPIR-V code
+for complex conditional structures, like those whose conditions involve
+short-circuiting 'and' and 'or' structures. These track labels and can emit
+`OpPhi` instructions to merge values produced along different paths.
+
+This currently only supports exactly the forms Naga uses, so it doesn't
+support `or` or `else`, and only supports zero or one merged values.
+
+Naga needs to emit code roughly like this:
+
+```ignore
+
+ value = DEFAULT;
+ if COND1 && COND2 {
+ value = THEN_VALUE;
+ }
+ // use value
+
+```
+
+Assuming `ctx` and `block` are a mutable references to a [`BlockContext`]
+and the current [`Block`], and `merge_type` is the SPIR-V type for the
+merged value `value`, we can build SPIR-V for the code above like so:
+
+```ignore
+
+ let cond = Selection::start(block, merge_type);
+ // ... compute `cond1` ...
+ cond.if_true(ctx, cond1, DEFAULT);
+ // ... compute `cond2` ...
+ cond.if_true(ctx, cond2, DEFAULT);
+ // ... compute THEN_VALUE
+ let merged_value = cond.finish(ctx, THEN_VALUE);
+
+```
+
+After this, `merged_value` is either `DEFAULT` or `THEN_VALUE`, depending on
+the path by which the merged block was reached.
+
+This takes care of writing all branch instructions, including an
+`OpSelectionMerge` annotation in the header block; starting new blocks and
+assigning them labels; and emitting the `OpPhi` that gathers together the
+right sources for the merged values, for every path through the selection
+construct.
+
+When there is no merged value to produce, you can pass `()` for `merge_type`
+and the merge values. In this case no `OpPhi` instructions are produced, and
+the `finish` method returns `()`.
+
+To enforce proper nesting, a `Selection` takes ownership of the `&mut Block`
+pointer for the duration of its lifetime. To obtain the block for generating
+code in the selection's body, call the `Selection::block` method.
+*/
+
+use super::{Block, BlockContext, Instruction};
+use spirv::Word;
+
+/// A private struct recording what we know about the selection construct so far.
+pub(super) struct Selection<'b, M: MergeTuple> {
+ /// The block pointer we're emitting code into.
+ block: &'b mut Block,
+
+ /// The label of the selection construct's merge block, or `None` if we
+ /// haven't yet written the `OpSelectionMerge` merge instruction.
+ merge_label: Option<Word>,
+
+ /// A set of `(VALUES, PARENT)` pairs, used to build `OpPhi` instructions in
+ /// the merge block. Each `PARENT` is the label of a predecessor block of
+ /// the merge block. The corresponding `VALUES` holds the ids of the values
+ /// that `PARENT` contributes to the merged values.
+ ///
+ /// We emit all branches to the merge block, so we know all its
+ /// predecessors. And we refuse to emit a branch unless we're given the
+ /// values the branching block contributes to the merge, so we always have
+ /// everything we need to emit the correct phis, by construction.
+ values: Vec<(M, Word)>,
+
+ /// The types of the values in each element of `values`.
+ merge_types: M,
+}
+
+impl<'b, M: MergeTuple> Selection<'b, M> {
+ /// Start a new selection construct.
+ ///
+ /// The `block` argument indicates the selection's header block.
+ ///
+ /// The `merge_types` argument should be a `Word` or tuple of `Word`s, each
+ /// value being the SPIR-V result type id of an `OpPhi` instruction that
+ /// will be written to the selection's merge block when this selection's
+ /// [`finish`] method is called. This argument may also be `()`, for
+ /// selections that produce no values.
+ ///
+ /// (This function writes no code to `block` itself; it simply constructs a
+ /// fresh `Selection`.)
+ ///
+ /// [`finish`]: Selection::finish
+ pub(super) fn start(block: &'b mut Block, merge_types: M) -> Self {
+ Selection {
+ block,
+ merge_label: None,
+ values: vec![],
+ merge_types,
+ }
+ }
+
+ pub(super) fn block(&mut self) -> &mut Block {
+ self.block
+ }
+
+ /// Branch to a successor block if `cond` is true, otherwise merge.
+ ///
+ /// If `cond` is false, branch to the merge block, using `values` as the
+ /// merged values. Otherwise, proceed to a new block.
+ ///
+ /// The `values` argument must be the same shape as the `merge_types`
+ /// argument passed to `Selection::start`.
+ pub(super) fn if_true(&mut self, ctx: &mut BlockContext, cond: Word, values: M) {
+ self.values.push((values, self.block.label_id));
+
+ let merge_label = self.make_merge_label(ctx);
+ let next_label = ctx.gen_id();
+ ctx.function.consume(
+ std::mem::replace(self.block, Block::new(next_label)),
+ Instruction::branch_conditional(cond, next_label, merge_label),
+ );
+ }
+
+ /// Emit an unconditional branch to the merge block, and compute merged
+ /// values.
+ ///
+ /// Use `final_values` as the merged values contributed by the current
+ /// block, and transition to the merge block, emitting `OpPhi` instructions
+ /// to produce the merged values. This must be the same shape as the
+ /// `merge_types` argument passed to [`Selection::start`].
+ ///
+ /// Return the SPIR-V ids of the merged values. This value has the same
+ /// shape as the `merge_types` argument passed to `Selection::start`.
+ pub(super) fn finish(self, ctx: &mut BlockContext, final_values: M) -> M {
+ match self {
+ Selection {
+ merge_label: None, ..
+ } => {
+ // We didn't actually emit any branches, so `self.values` must
+ // be empty, and `final_values` are the only sources we have for
+ // the merged values. Easy peasy.
+ final_values
+ }
+
+ Selection {
+ block,
+ merge_label: Some(merge_label),
+ mut values,
+ merge_types,
+ } => {
+ // Emit the final branch and transition to the merge block.
+ values.push((final_values, block.label_id));
+ ctx.function.consume(
+ std::mem::replace(block, Block::new(merge_label)),
+ Instruction::branch(merge_label),
+ );
+
+ // Now that we're in the merge block, build the phi instructions.
+ merge_types.write_phis(ctx, block, &values)
+ }
+ }
+ }
+
+ /// Return the id of the merge block, writing a merge instruction if needed.
+ fn make_merge_label(&mut self, ctx: &mut BlockContext) -> Word {
+ match self.merge_label {
+ None => {
+ let merge_label = ctx.gen_id();
+ self.block.body.push(Instruction::selection_merge(
+ merge_label,
+ spirv::SelectionControl::NONE,
+ ));
+ self.merge_label = Some(merge_label);
+ merge_label
+ }
+ Some(merge_label) => merge_label,
+ }
+ }
+}
+
+/// A trait to help `Selection` manage any number of merged values.
+///
+/// Some selection constructs, like a `ReadZeroSkipWrite` bounds check on a
+/// [`Load`] expression, produce a single merged value. Others produce no merged
+/// value, like a bounds check on a [`Store`] statement.
+///
+/// To let `Selection` work nicely with both cases, we let the merge type
+/// argument passed to [`Selection::start`] be any type that implements this
+/// `MergeTuple` trait. `MergeTuple` is then implemented for `()`, `Word`,
+/// `(Word, Word)`, and so on.
+///
+/// A `MergeTuple` type can represent either a bunch of SPIR-V types or values;
+/// the `merge_types` argument to `Selection::start` are type ids, whereas the
+/// `values` arguments to the [`if_true`] and [`finish`] methods are value ids.
+/// The set of merged value returned by `finish` is a tuple of value ids.
+///
+/// In fact, since Naga only uses zero- and single-valued selection constructs
+/// at present, we only implement `MergeTuple` for `()` and `Word`. But if you
+/// add more cases, feel free to add more implementations. Once const generics
+/// are available, we could have a single implementation of `MergeTuple` for all
+/// lengths of arrays, and be done with it.
+///
+/// [`Load`]: crate::Expression::Load
+/// [`Store`]: crate::Statement::Store
+/// [`if_true`]: Selection::if_true
+/// [`finish`]: Selection::finish
+pub(super) trait MergeTuple: Sized {
+ /// Write OpPhi instructions for the given set of predecessors.
+ ///
+ /// The `predecessors` vector should be a vector of `(LABEL, VALUES)` pairs,
+ /// where each `VALUES` holds the values contributed by the branch from
+ /// `LABEL`, which should be one of the current block's predecessors.
+ fn write_phis(
+ self,
+ ctx: &mut BlockContext,
+ block: &mut Block,
+ predecessors: &[(Self, Word)],
+ ) -> Self;
+}
+
+/// Selections that produce a single merged value.
+///
+/// For example, `ImageLoad` with `BoundsCheckPolicy::ReadZeroSkipWrite` either
+/// returns a texel value or zeros.
+impl MergeTuple for Word {
+ fn write_phis(
+ self,
+ ctx: &mut BlockContext,
+ block: &mut Block,
+ predecessors: &[(Word, Word)],
+ ) -> Word {
+ let merged_value = ctx.gen_id();
+ block
+ .body
+ .push(Instruction::phi(self, merged_value, predecessors));
+ merged_value
+ }
+}
+
+/// Selections that produce no merged values.
+///
+/// For example, `ImageStore` under `BoundsCheckPolicy::ReadZeroSkipWrite`
+/// either does the store or skips it, but in neither case does it produce a
+/// value.
+impl MergeTuple for () {
+ /// No phis need to be generated.
+ fn write_phis(self, _: &mut BlockContext, _: &mut Block, _: &[((), Word)]) {}
+}
diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs
new file mode 100644
index 0000000000..ba235e6d03
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/writer.rs
@@ -0,0 +1,1966 @@
+use super::{
+ helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
+ make_local, Block, BlockContext, CachedConstant, CachedExpressions, EntryPointContext, Error,
+ Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, LocalType, LocalVariable,
+ LogicalLayout, LookupFunctionType, LookupType, LoopContext, Options, PhysicalLayout,
+ PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE,
+};
+use crate::{
+ arena::{Handle, UniqueArena},
+ back::spv::BindingInfo,
+ proc::{Alignment, TypeResolution},
+ valid::{FunctionInfo, ModuleInfo},
+};
+use spirv::Word;
+use std::collections::hash_map::Entry;
+
+struct FunctionInterface<'a> {
+ varying_ids: &'a mut Vec<Word>,
+ stage: crate::ShaderStage,
+}
+
+impl Function {
+ fn to_words(&self, sink: &mut impl Extend<Word>) {
+ self.signature.as_ref().unwrap().to_words(sink);
+ for argument in self.parameters.iter() {
+ argument.instruction.to_words(sink);
+ }
+ for (index, block) in self.blocks.iter().enumerate() {
+ Instruction::label(block.label_id).to_words(sink);
+ if index == 0 {
+ for local_var in self.variables.values() {
+ local_var.instruction.to_words(sink);
+ }
+ }
+ for instruction in block.body.iter() {
+ instruction.to_words(sink);
+ }
+ }
+ }
+}
+
+impl Writer {
+ pub fn new(options: &Options) -> Result<Self, Error> {
+ let (major, minor) = options.lang_version;
+ if major != 1 {
+ return Err(Error::UnsupportedVersion(major, minor));
+ }
+ let raw_version = ((major as u32) << 16) | ((minor as u32) << 8);
+
+ let mut capabilities_used = crate::FastHashSet::default();
+ capabilities_used.insert(spirv::Capability::Shader);
+
+ let mut id_gen = IdGenerator::default();
+ let gl450_ext_inst_id = id_gen.next();
+ let void_type = id_gen.next();
+
+ Ok(Writer {
+ physical_layout: PhysicalLayout::new(raw_version),
+ logical_layout: LogicalLayout::default(),
+ id_gen,
+ capabilities_available: options.capabilities.clone(),
+ capabilities_used,
+ extensions_used: crate::FastHashSet::default(),
+ debugs: vec![],
+ annotations: vec![],
+ flags: options.flags,
+ bounds_check_policies: options.bounds_check_policies,
+ zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
+ void_type,
+ lookup_type: crate::FastHashMap::default(),
+ lookup_function: crate::FastHashMap::default(),
+ lookup_function_type: crate::FastHashMap::default(),
+ constant_ids: Vec::new(),
+ cached_constants: crate::FastHashMap::default(),
+ global_variables: Vec::new(),
+ binding_map: options.binding_map.clone(),
+ saved_cached: CachedExpressions::default(),
+ gl450_ext_inst_id,
+ temp_list: Vec::new(),
+ })
+ }
+
+ /// Reset `Writer` to its initial state, retaining any allocations.
+ ///
+ /// Why not just implement `Recyclable` for `Writer`? By design,
+ /// `Recyclable::recycle` requires ownership of the value, not just
+ /// `&mut`; see the trait documentation. But we need to use this method
+ /// from functions like `Writer::write`, which only have `&mut Writer`.
+ /// Workarounds include unsafe code (`std::ptr::read`, then `write`, ugh)
+ /// or something like a `Default` impl that returns an oddly-initialized
+ /// `Writer`, which is worse.
+ fn reset(&mut self) {
+ use super::recyclable::Recyclable;
+ use std::mem::take;
+
+ let mut id_gen = IdGenerator::default();
+ let gl450_ext_inst_id = id_gen.next();
+ let void_type = id_gen.next();
+
+ // Every field of the old writer that is not determined by the `Options`
+ // passed to `Writer::new` should be reset somehow.
+ let fresh = Writer {
+ // Copied from the old Writer:
+ flags: self.flags,
+ bounds_check_policies: self.bounds_check_policies,
+ zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
+ capabilities_available: take(&mut self.capabilities_available),
+ binding_map: take(&mut self.binding_map),
+
+ // Initialized afresh:
+ id_gen,
+ void_type,
+ gl450_ext_inst_id,
+
+ // Recycled:
+ capabilities_used: take(&mut self.capabilities_used).recycle(),
+ extensions_used: take(&mut self.extensions_used).recycle(),
+ physical_layout: self.physical_layout.clone().recycle(),
+ logical_layout: take(&mut self.logical_layout).recycle(),
+ debugs: take(&mut self.debugs).recycle(),
+ annotations: take(&mut self.annotations).recycle(),
+ lookup_type: take(&mut self.lookup_type).recycle(),
+ lookup_function: take(&mut self.lookup_function).recycle(),
+ lookup_function_type: take(&mut self.lookup_function_type).recycle(),
+ constant_ids: take(&mut self.constant_ids).recycle(),
+ cached_constants: take(&mut self.cached_constants).recycle(),
+ global_variables: take(&mut self.global_variables).recycle(),
+ saved_cached: take(&mut self.saved_cached).recycle(),
+ temp_list: take(&mut self.temp_list).recycle(),
+ };
+
+ *self = fresh;
+
+ self.capabilities_used.insert(spirv::Capability::Shader);
+ }
+
+ /// Indicate that the code requires any one of the listed capabilities.
+ ///
+ /// If nothing in `capabilities` appears in the available capabilities
+ /// specified in the [`Options`] from which this `Writer` was created,
+ /// return an error. The `what` string is used in the error message to
+ /// explain what provoked the requirement. (If no available capabilities were
+ /// given, assume everything is available.)
+ ///
+ /// The first acceptable capability will be added to this `Writer`'s
+ /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the
+ /// result. For this reason, more specific capabilities should be listed
+ /// before more general.
+ ///
+ /// [`capabilities_used`]: Writer::capabilities_used
+ pub(super) fn require_any(
+ &mut self,
+ what: &'static str,
+ capabilities: &[spirv::Capability],
+ ) -> Result<(), Error> {
+ match *capabilities {
+ [] => Ok(()),
+ [first, ..] => {
+ // Find the first acceptable capability, or return an error if
+ // there is none.
+ let selected = match self.capabilities_available {
+ None => first,
+ Some(ref available) => {
+ match capabilities.iter().find(|cap| available.contains(cap)) {
+ Some(&cap) => cap,
+ None => {
+ return Err(Error::MissingCapabilities(what, capabilities.to_vec()))
+ }
+ }
+ }
+ };
+ self.capabilities_used.insert(selected);
+ Ok(())
+ }
+ }
+ }
+
+ /// Indicate that the code uses the given extension.
+ pub(super) fn use_extension(&mut self, extension: &'static str) {
+ self.extensions_used.insert(extension);
+ }
+
+ pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word {
+ match self.lookup_type.entry(lookup_ty) {
+ Entry::Occupied(e) => *e.get(),
+ Entry::Vacant(e) => {
+ let local = match lookup_ty {
+ LookupType::Handle(_handle) => unreachable!("Handles are populated at start"),
+ LookupType::Local(local) => local,
+ };
+
+ let id = self.id_gen.next();
+ e.insert(id);
+ self.write_type_declaration_local(id, local);
+ id
+ }
+ }
+ }
+
+ pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
+ let lookup_ty = match *tr {
+ TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
+ TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()),
+ };
+ self.get_type_id(lookup_ty)
+ }
+
+ pub(super) fn get_pointer_id(
+ &mut self,
+ arena: &UniqueArena<crate::Type>,
+ handle: Handle<crate::Type>,
+ class: spirv::StorageClass,
+ ) -> Result<Word, Error> {
+ let ty_id = self.get_type_id(LookupType::Handle(handle));
+ if let crate::TypeInner::Pointer { .. } = arena[handle].inner {
+ return Ok(ty_id);
+ }
+ let lookup_type = LookupType::Local(LocalType::Pointer {
+ base: handle,
+ class,
+ });
+ Ok(if let Some(&id) = self.lookup_type.get(&lookup_type) {
+ id
+ } else {
+ let id = self.id_gen.next();
+ let instruction = Instruction::type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_type.insert(lookup_type, id);
+ id
+ })
+ }
+
+ pub(super) fn get_uint_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_float_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_uint3_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_float_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
+ let lookup_type = LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ pointer_space: Some(class),
+ });
+ if let Some(&id) = self.lookup_type.get(&lookup_type) {
+ id
+ } else {
+ let id = self.id_gen.next();
+ let ty_id = self.get_float_type_id();
+ let instruction = Instruction::type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_type.insert(lookup_type, id);
+ id
+ }
+ }
+
+ pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
+ let lookup_type = LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ pointer_space: Some(class),
+ });
+ if let Some(&id) = self.lookup_type.get(&lookup_type) {
+ id
+ } else {
+ let id = self.id_gen.next();
+ let ty_id = self.get_uint3_type_id();
+ let instruction = Instruction::type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_type.insert(lookup_type, id);
+ id
+ }
+ }
+
+ pub(super) fn get_bool_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ kind: crate::ScalarKind::Bool,
+ width: 1,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_bool3_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ kind: crate::ScalarKind::Bool,
+ width: 1,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
+ self.annotations
+ .push(Instruction::decorate(id, decoration, operands));
+ }
+
+ fn write_function(
+ &mut self,
+ ir_function: &crate::Function,
+ info: &FunctionInfo,
+ ir_module: &crate::Module,
+ mut interface: Option<FunctionInterface>,
+ ) -> Result<Word, Error> {
+ let mut function = Function::default();
+
+ for (handle, variable) in ir_function.local_variables.iter() {
+ let id = self.id_gen.next();
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = variable.name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ let init_word = variable
+ .init
+ .map(|constant| self.constant_ids[constant.index()]);
+ let pointer_type_id =
+ self.get_pointer_id(&ir_module.types, variable.ty, spirv::StorageClass::Function)?;
+ let instruction = Instruction::variable(
+ pointer_type_id,
+ id,
+ spirv::StorageClass::Function,
+ init_word.or_else(|| match ir_module.types[variable.ty].inner {
+ crate::TypeInner::RayQuery => None,
+ _ => {
+ let type_id = self.get_type_id(LookupType::Handle(variable.ty));
+ Some(self.write_constant_null(type_id))
+ }
+ }),
+ );
+ function
+ .variables
+ .insert(handle, LocalVariable { id, instruction });
+ }
+
+ let prelude_id = self.id_gen.next();
+ let mut prelude = Block::new(prelude_id);
+ let mut ep_context = EntryPointContext {
+ argument_ids: Vec::new(),
+ results: Vec::new(),
+ };
+
+ let mut local_invocation_id = None;
+
+ let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
+ for argument in ir_function.arguments.iter() {
+ let class = spirv::StorageClass::Input;
+ let handle_ty = ir_module.types[argument.ty].inner.is_handle();
+ let argument_type_id = match handle_ty {
+ true => self.get_pointer_id(
+ &ir_module.types,
+ argument.ty,
+ spirv::StorageClass::UniformConstant,
+ )?,
+ false => self.get_type_id(LookupType::Handle(argument.ty)),
+ };
+
+ if let Some(ref mut iface) = interface {
+ let id = if let Some(ref binding) = argument.binding {
+ let name = argument.name.as_deref();
+
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ argument.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(argument_type_id, id, varying_id, None));
+
+ if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
+ local_invocation_id = Some(id);
+ }
+
+ id
+ } else if let crate::TypeInner::Struct { ref members, .. } =
+ ir_module.types[argument.ty].inner
+ {
+ let struct_id = self.id_gen.next();
+ let mut constituent_ids = Vec::with_capacity(members.len());
+ for member in members {
+ let type_id = self.get_type_id(LookupType::Handle(member.ty));
+ let name = member.name.as_deref();
+ let binding = member.binding.as_ref().unwrap();
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ member.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(type_id, id, varying_id, None));
+ constituent_ids.push(id);
+
+ if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) {
+ local_invocation_id = Some(id);
+ }
+ }
+ prelude.body.push(Instruction::composite_construct(
+ argument_type_id,
+ struct_id,
+ &constituent_ids,
+ ));
+ struct_id
+ } else {
+ unreachable!("Missing argument binding on an entry point");
+ };
+ ep_context.argument_ids.push(id);
+ } else {
+ let argument_id = self.id_gen.next();
+ let instruction = Instruction::function_parameter(argument_type_id, argument_id);
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = argument.name {
+ self.debugs.push(Instruction::name(argument_id, name));
+ }
+ }
+ function.parameters.push(FunctionArgument {
+ instruction,
+ handle_id: if handle_ty {
+ let id = self.id_gen.next();
+ prelude.body.push(Instruction::load(
+ self.get_type_id(LookupType::Handle(argument.ty)),
+ id,
+ argument_id,
+ None,
+ ));
+ id
+ } else {
+ 0
+ },
+ });
+ parameter_type_ids.push(argument_type_id);
+ };
+ }
+
+ let return_type_id = match ir_function.result {
+ Some(ref result) => {
+ if let Some(ref mut iface) = interface {
+ let mut has_point_size = false;
+ let class = spirv::StorageClass::Output;
+ if let Some(ref binding) = result.binding {
+ has_point_size |=
+ *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
+ let type_id = self.get_type_id(LookupType::Handle(result.ty));
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ None,
+ result.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ ep_context.results.push(ResultMember {
+ id: varying_id,
+ type_id,
+ built_in: binding.to_built_in(),
+ });
+ } else if let crate::TypeInner::Struct { ref members, .. } =
+ ir_module.types[result.ty].inner
+ {
+ for member in members {
+ let type_id = self.get_type_id(LookupType::Handle(member.ty));
+ let name = member.name.as_deref();
+ let binding = member.binding.as_ref().unwrap();
+ has_point_size |=
+ *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ member.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ ep_context.results.push(ResultMember {
+ id: varying_id,
+ type_id,
+ built_in: binding.to_built_in(),
+ });
+ }
+ } else {
+ unreachable!("Missing result binding on an entry point");
+ }
+
+ if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
+ && iface.stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // add point size artificially
+ let varying_id = self.id_gen.next();
+ let pointer_type_id = self.get_float_pointer_type_id(class);
+ Instruction::variable(pointer_type_id, varying_id, class, None)
+ .to_words(&mut self.logical_layout.declarations);
+ self.decorate(
+ varying_id,
+ spirv::Decoration::BuiltIn,
+ &[spirv::BuiltIn::PointSize as u32],
+ );
+ iface.varying_ids.push(varying_id);
+
+ let default_value_id =
+ self.get_constant_scalar(crate::ScalarValue::Float(1.0), 4);
+ prelude
+ .body
+ .push(Instruction::store(varying_id, default_value_id, None));
+ }
+ self.void_type
+ } else {
+ self.get_type_id(LookupType::Handle(result.ty))
+ }
+ }
+ None => self.void_type,
+ };
+
+ let lookup_function_type = LookupFunctionType {
+ parameter_type_ids,
+ return_type_id,
+ };
+
+ let function_id = self.id_gen.next();
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = ir_function.name {
+ self.debugs.push(Instruction::name(function_id, name));
+ }
+ }
+
+ let function_type = self.get_function_type(lookup_function_type);
+ function.signature = Some(Instruction::function(
+ return_type_id,
+ function_id,
+ spirv::FunctionControl::empty(),
+ function_type,
+ ));
+
+ if interface.is_some() {
+ function.entry_point_context = Some(ep_context);
+ }
+
+ // fill up the `GlobalVariable::access_id`
+ for gv in self.global_variables.iter_mut() {
+ gv.reset_for_function();
+ }
+ for (handle, var) in ir_module.global_variables.iter() {
+ if info[handle].is_empty() {
+ continue;
+ }
+
+ let mut gv = self.global_variables[handle.index()].clone();
+ if let Some(ref mut iface) = interface {
+ // Have to include global variables in the interface
+ if self.physical_layout.version >= 0x10400 {
+ iface.varying_ids.push(gv.var_id);
+ }
+ }
+
+ // Handle globals are pre-emitted and should be loaded automatically.
+ //
+ // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
+ let is_binding_array = match ir_module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { .. } => true,
+ _ => false,
+ };
+
+ if var.space == crate::AddressSpace::Handle && !is_binding_array {
+ let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(var_type_id, id, gv.var_id, None));
+ gv.access_id = gv.var_id;
+ gv.handle_id = id;
+ } else if global_needs_wrapper(ir_module, var) {
+ let class = map_storage_class(var.space);
+ let pointer_type_id = self.get_pointer_id(&ir_module.types, var.ty, class)?;
+ let index_id = self.get_index_constant(0);
+
+ let id = self.id_gen.next();
+ prelude.body.push(Instruction::access_chain(
+ pointer_type_id,
+ id,
+ gv.var_id,
+ &[index_id],
+ ));
+ gv.access_id = id;
+ } else {
+ // by default, the variable ID is accessed as is
+ gv.access_id = gv.var_id;
+ };
+
+ // work around borrow checking in the presence of `self.xxx()` calls
+ self.global_variables[handle.index()] = gv;
+ }
+
+ // Create a `BlockContext` for generating SPIR-V for the function's
+ // body.
+ let mut context = BlockContext {
+ ir_module,
+ ir_function,
+ fun_info: info,
+ function: &mut function,
+ // Re-use the cached expression table from prior functions.
+ cached: std::mem::take(&mut self.saved_cached),
+
+ // Steal the Writer's temp list for a bit.
+ temp_list: std::mem::take(&mut self.temp_list),
+ writer: self,
+ };
+
+ // fill up the pre-emitted expressions
+ context.cached.reset(ir_function.expressions.len());
+ for (handle, expr) in ir_function.expressions.iter() {
+ if expr.needs_pre_emit() {
+ context.cache_expression_value(handle, &mut prelude)?;
+ }
+ }
+
+ let next_id = context.gen_id();
+
+ context
+ .function
+ .consume(prelude, Instruction::branch(next_id));
+
+ let workgroup_vars_init_exit_block_id =
+ match (context.writer.zero_initialize_workgroup_memory, interface) {
+ (
+ super::ZeroInitializeWorkgroupMemoryMode::Polyfill,
+ Some(
+ ref mut interface @ FunctionInterface {
+ stage: crate::ShaderStage::Compute,
+ ..
+ },
+ ),
+ ) => context.writer.generate_workgroup_vars_init_block(
+ next_id,
+ ir_module,
+ info,
+ local_invocation_id,
+ interface,
+ context.function,
+ ),
+ _ => None,
+ };
+
+ let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id {
+ exit_id
+ } else {
+ next_id
+ };
+
+ context.write_block(
+ main_id,
+ &ir_function.body,
+ super::block::BlockExit::Return,
+ LoopContext::default(),
+ )?;
+
+ // Consume the `BlockContext`, ending its borrows and letting the
+ // `Writer` steal back its cached expression table and temp_list.
+ let BlockContext {
+ cached, temp_list, ..
+ } = context;
+ self.saved_cached = cached;
+ self.temp_list = temp_list;
+
+ function.to_words(&mut self.logical_layout.function_definitions);
+ Instruction::function_end().to_words(&mut self.logical_layout.function_definitions);
+
+ Ok(function_id)
+ }
+
+ fn write_execution_mode(
+ &mut self,
+ function_id: Word,
+ mode: spirv::ExecutionMode,
+ ) -> Result<(), Error> {
+ //self.check(mode.required_capabilities())?;
+ Instruction::execution_mode(function_id, mode, &[])
+ .to_words(&mut self.logical_layout.execution_modes);
+ Ok(())
+ }
+
+ // TODO Move to instructions module
+ fn write_entry_point(
+ &mut self,
+ entry_point: &crate::EntryPoint,
+ info: &FunctionInfo,
+ ir_module: &crate::Module,
+ ) -> Result<Instruction, Error> {
+ let mut interface_ids = Vec::new();
+ let function_id = self.write_function(
+ &entry_point.function,
+ info,
+ ir_module,
+ Some(FunctionInterface {
+ varying_ids: &mut interface_ids,
+ stage: entry_point.stage,
+ }),
+ )?;
+
+ let exec_model = match entry_point.stage {
+ crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
+ crate::ShaderStage::Fragment => {
+ self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
+ if let Some(ref result) = entry_point.function.result {
+ if contains_builtin(
+ result.binding.as_ref(),
+ result.ty,
+ &ir_module.types,
+ crate::BuiltIn::FragDepth,
+ ) {
+ self.write_execution_mode(
+ function_id,
+ spirv::ExecutionMode::DepthReplacing,
+ )?;
+ }
+ }
+ spirv::ExecutionModel::Fragment
+ }
+ crate::ShaderStage::Compute => {
+ let execution_mode = spirv::ExecutionMode::LocalSize;
+ //self.check(execution_mode.required_capabilities())?;
+ Instruction::execution_mode(
+ function_id,
+ execution_mode,
+ &entry_point.workgroup_size,
+ )
+ .to_words(&mut self.logical_layout.execution_modes);
+ spirv::ExecutionModel::GLCompute
+ }
+ };
+ //self.check(exec_model.required_capabilities())?;
+
+ Ok(Instruction::entry_point(
+ exec_model,
+ function_id,
+ &entry_point.name,
+ interface_ids.as_slice(),
+ ))
+ }
+
+ fn make_scalar(
+ &mut self,
+ id: Word,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ ) -> Instruction {
+ use crate::ScalarKind as Sk;
+
+ let bits = (width * BITS_PER_BYTE) as u32;
+ match kind {
+ Sk::Sint | Sk::Uint => {
+ let signedness = if kind == Sk::Sint {
+ super::instructions::Signedness::Signed
+ } else {
+ super::instructions::Signedness::Unsigned
+ };
+ let cap = match bits {
+ 8 => Some(spirv::Capability::Int8),
+ 16 => Some(spirv::Capability::Int16),
+ 64 => Some(spirv::Capability::Int64),
+ _ => None,
+ };
+ if let Some(cap) = cap {
+ self.capabilities_used.insert(cap);
+ }
+ Instruction::type_int(id, bits, signedness)
+ }
+ Sk::Float => {
+ if bits == 64 {
+ self.capabilities_used.insert(spirv::Capability::Float64);
+ }
+ Instruction::type_float(id, bits)
+ }
+ Sk::Bool => Instruction::type_bool(id),
+ }
+ }
+
+ fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
+ match *inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let sampled = match class {
+ crate::ImageClass::Sampled { .. } => true,
+ crate::ImageClass::Depth { .. } => true,
+ crate::ImageClass::Storage { format, .. } => {
+ self.request_image_format_capabilities(format.into())?;
+ false
+ }
+ };
+
+ match dim {
+ crate::ImageDimension::D1 => {
+ if sampled {
+ self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
+ } else {
+ self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
+ }
+ }
+ crate::ImageDimension::Cube if arrayed => {
+ if sampled {
+ self.require_any(
+ "sampled cube array images",
+ &[spirv::Capability::SampledCubeArray],
+ )?;
+ } else {
+ self.require_any(
+ "cube array storage images",
+ &[spirv::Capability::ImageCubeArray],
+ )?;
+ }
+ }
+ _ => {}
+ }
+ }
+ crate::TypeInner::AccelerationStructure => {
+ self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?;
+ }
+ crate::TypeInner::RayQuery => {
+ self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
+ }
+ _ => {}
+ }
+ Ok(())
+ }
+
+ fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
+ let instruction = match local_ty {
+ LocalType::Value {
+ vector_size: None,
+ kind,
+ width,
+ pointer_space: None,
+ } => self.make_scalar(id, kind, width),
+ LocalType::Value {
+ vector_size: Some(size),
+ kind,
+ width,
+ pointer_space: None,
+ } => {
+ let scalar_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind,
+ width,
+ pointer_space: None,
+ }));
+ Instruction::type_vector(id, scalar_id, size)
+ }
+ LocalType::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let vector_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(rows),
+ kind: crate::ScalarKind::Float,
+ width,
+ pointer_space: None,
+ }));
+ Instruction::type_matrix(id, vector_id, columns)
+ }
+ LocalType::Pointer { base, class } => {
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ Instruction::type_pointer(id, class, type_id)
+ }
+ LocalType::Value {
+ vector_size,
+ kind,
+ width,
+ pointer_space: Some(class),
+ } => {
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size,
+ kind,
+ width,
+ pointer_space: None,
+ }));
+ Instruction::type_pointer(id, class, type_id)
+ }
+ LocalType::Image(image) => {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ kind: image.sampled_type,
+ width: 4,
+ pointer_space: None,
+ };
+ let type_id = self.get_type_id(LookupType::Local(local_type));
+ Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
+ }
+ LocalType::Sampler => Instruction::type_sampler(id),
+ LocalType::SampledImage { image_type_id } => {
+ Instruction::type_sampled_image(id, image_type_id)
+ }
+ LocalType::BindingArray { base, size } => {
+ let inner_ty = self.get_type_id(LookupType::Handle(base));
+ let scalar_id = self.get_constant_scalar(crate::ScalarValue::Uint(size), 4);
+ Instruction::type_array(id, inner_ty, scalar_id)
+ }
+ LocalType::PointerToBindingArray { base, size } => {
+ let inner_ty =
+ self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size }));
+ Instruction::type_pointer(id, spirv::StorageClass::UniformConstant, inner_ty)
+ }
+ LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
+ LocalType::RayQuery => Instruction::type_ray_query(id),
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ }
+
+ fn write_type_declaration_arena(
+ &mut self,
+ arena: &UniqueArena<crate::Type>,
+ handle: Handle<crate::Type>,
+ ) -> Result<Word, Error> {
+ let ty = &arena[handle];
+ let id = if let Some(local) = make_local(&ty.inner) {
+ // This type can be represented as a `LocalType`, so check if we've
+ // already written an instruction for it. If not, do so now, with
+ // `write_type_declaration_local`.
+ match self.lookup_type.entry(LookupType::Local(local)) {
+ // We already have an id for this `LocalType`.
+ Entry::Occupied(e) => *e.get(),
+
+ // It's a type we haven't seen before.
+ Entry::Vacant(e) => {
+ let id = self.id_gen.next();
+ e.insert(id);
+
+ self.write_type_declaration_local(id, local);
+
+ // If it's a type that needs SPIR-V capabilities, request them now,
+ // so write_type_declaration_local can stay infallible.
+ self.request_type_capabilities(&ty.inner)?;
+
+ id
+ }
+ }
+ } else {
+ use spirv::Decoration;
+
+ let id = self.id_gen.next();
+ let instruction = match ty.inner {
+ crate::TypeInner::Array { base, size, stride } => {
+ self.decorate(id, Decoration::ArrayStride, &[stride]);
+
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let length_id = self.constant_ids[const_handle.index()];
+ Instruction::type_array(id, type_id, length_id)
+ }
+ crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
+ }
+ }
+ crate::TypeInner::BindingArray { base, size } => {
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let length_id = self.constant_ids[const_handle.index()];
+ Instruction::type_array(id, type_id, length_id)
+ }
+ crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
+ }
+ }
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => {
+ let mut member_ids = Vec::with_capacity(members.len());
+ for (index, member) in members.iter().enumerate() {
+ self.decorate_struct_member(id, index, member, arena)?;
+ let member_id = self.get_type_id(LookupType::Handle(member.ty));
+ member_ids.push(member_id);
+ }
+ Instruction::type_struct(id, member_ids.as_slice())
+ }
+
+ // These all have TypeLocal representations, so they should have been
+ // handled by `write_type_declaration_local` above.
+ crate::TypeInner::Scalar { .. }
+ | crate::TypeInner::Atomic { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. }
+ | crate::TypeInner::Pointer { .. }
+ | crate::TypeInner::ValuePointer { .. }
+ | crate::TypeInner::Image { .. }
+ | crate::TypeInner::Sampler { .. }
+ | crate::TypeInner::AccelerationStructure
+ | crate::TypeInner::RayQuery => unreachable!(),
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ id
+ };
+
+ // Add this handle as a new alias for that type.
+ self.lookup_type.insert(LookupType::Handle(handle), id);
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = ty.name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ Ok(id)
+ }
+
+ fn request_image_format_capabilities(
+ &mut self,
+ format: spirv::ImageFormat,
+ ) -> Result<(), Error> {
+ use spirv::ImageFormat as If;
+ match format {
+ If::Rg32f
+ | If::Rg16f
+ | If::R11fG11fB10f
+ | If::R16f
+ | If::Rgba16
+ | If::Rgb10A2
+ | If::Rg16
+ | If::Rg8
+ | If::R16
+ | If::R8
+ | If::Rgba16Snorm
+ | If::Rg16Snorm
+ | If::Rg8Snorm
+ | If::R16Snorm
+ | If::R8Snorm
+ | If::Rg32i
+ | If::Rg16i
+ | If::Rg8i
+ | If::R16i
+ | If::R8i
+ | If::Rgb10a2ui
+ | If::Rg32ui
+ | If::Rg16ui
+ | If::Rg8ui
+ | If::R16ui
+ | If::R8ui => self.require_any(
+ "storage image format",
+ &[spirv::Capability::StorageImageExtendedFormats],
+ ),
+ If::R64ui | If::R64i => self.require_any(
+ "64-bit integer storage image format",
+ &[spirv::Capability::Int64ImageEXT],
+ ),
+ If::Unknown
+ | If::Rgba32f
+ | If::Rgba16f
+ | If::R32f
+ | If::Rgba8
+ | If::Rgba8Snorm
+ | If::Rgba32i
+ | If::Rgba16i
+ | If::Rgba8i
+ | If::R32i
+ | If::Rgba32ui
+ | If::Rgba16ui
+ | If::Rgba8ui
+ | If::R32ui => Ok(()),
+ }
+ }
+
+ pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
+ self.get_constant_scalar(crate::ScalarValue::Uint(index as _), 4)
+ }
+
+ pub(super) fn get_constant_scalar(
+ &mut self,
+ value: crate::ScalarValue,
+ width: crate::Bytes,
+ ) -> Word {
+ let scalar = CachedConstant::Scalar { value, width };
+ if let Some(&id) = self.cached_constants.get(&scalar) {
+ return id;
+ }
+ let id = self.id_gen.next();
+ self.write_constant_scalar(id, &value, width, None);
+ self.cached_constants.insert(scalar, id);
+ id
+ }
+
+ fn write_constant_scalar(
+ &mut self,
+ id: Word,
+ value: &crate::ScalarValue,
+ width: crate::Bytes,
+ debug_name: Option<&String>,
+ ) {
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(name) = debug_name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ kind: value.scalar_kind(),
+ width,
+ pointer_space: None,
+ }));
+ let (solo, pair);
+ let instruction = match *value {
+ crate::ScalarValue::Sint(val) => {
+ let words = match width {
+ 4 => {
+ solo = [val as u32];
+ &solo[..]
+ }
+ 8 => {
+ pair = [val as u32, (val >> 32) as u32];
+ &pair
+ }
+ _ => unreachable!(),
+ };
+ Instruction::constant(type_id, id, words)
+ }
+ crate::ScalarValue::Uint(val) => {
+ let words = match width {
+ 4 => {
+ solo = [val as u32];
+ &solo[..]
+ }
+ 8 => {
+ pair = [val as u32, (val >> 32) as u32];
+ &pair
+ }
+ _ => unreachable!(),
+ };
+ Instruction::constant(type_id, id, words)
+ }
+ crate::ScalarValue::Float(val) => {
+ let words = match width {
+ 4 => {
+ solo = [(val as f32).to_bits()];
+ &solo[..]
+ }
+ 8 => {
+ let bits = f64::to_bits(val);
+ pair = [bits as u32, (bits >> 32) as u32];
+ &pair
+ }
+ _ => unreachable!(),
+ };
+ Instruction::constant(type_id, id, words)
+ }
+ crate::ScalarValue::Bool(true) => Instruction::constant_true(type_id, id),
+ crate::ScalarValue::Bool(false) => Instruction::constant_false(type_id, id),
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ }
+
+ pub(super) fn get_constant_composite(
+ &mut self,
+ ty: LookupType,
+ constituent_ids: &[Word],
+ ) -> Word {
+ let composite = CachedConstant::Composite {
+ ty,
+ constituent_ids: constituent_ids.to_vec(),
+ };
+ if let Some(&id) = self.cached_constants.get(&composite) {
+ return id;
+ }
+ let id = self.id_gen.next();
+ self.write_constant_composite(id, ty, constituent_ids, None);
+ self.cached_constants.insert(composite, id);
+ id
+ }
+
+ fn write_constant_composite(
+ &mut self,
+ id: Word,
+ ty: LookupType,
+ constituent_ids: &[Word],
+ debug_name: Option<&String>,
+ ) {
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(name) = debug_name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+ let type_id = self.get_type_id(ty);
+ Instruction::constant_composite(type_id, id, constituent_ids)
+ .to_words(&mut self.logical_layout.declarations);
+ }
+
+ pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
+ let null_id = self.id_gen.next();
+ Instruction::constant_null(type_id, null_id)
+ .to_words(&mut self.logical_layout.declarations);
+ null_id
+ }
+
+ pub(super) fn write_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
+ let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
+ spirv::Scope::Device
+ } else {
+ spirv::Scope::Workgroup
+ };
+ let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
+ semantics.set(
+ spirv::MemorySemantics::UNIFORM_MEMORY,
+ flags.contains(crate::Barrier::STORAGE),
+ );
+ semantics.set(
+ spirv::MemorySemantics::WORKGROUP_MEMORY,
+ flags.contains(crate::Barrier::WORK_GROUP),
+ );
+ let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32);
+ let mem_scope_id = self.get_index_constant(memory_scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ block.body.push(Instruction::control_barrier(
+ exec_scope_id,
+ mem_scope_id,
+ semantics_id,
+ ));
+ }
+
+ fn generate_workgroup_vars_init_block(
+ &mut self,
+ entry_id: Word,
+ ir_module: &crate::Module,
+ info: &FunctionInfo,
+ local_invocation_id: Option<Word>,
+ interface: &mut FunctionInterface,
+ function: &mut Function,
+ ) -> Option<Word> {
+ let body = ir_module
+ .global_variables
+ .iter()
+ .filter(|&(handle, var)| {
+ !info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ .map(|(handle, var)| {
+ // It's safe to use `var_id` here, not `access_id`, because only
+ // variables in the `Uniform` and `StorageBuffer` address spaces
+ // get wrapped, and we're initializing `WorkGroup` variables.
+ let var_id = self.global_variables[handle.index()].var_id;
+ let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
+ let init_word = self.write_constant_null(var_type_id);
+ Instruction::store(var_id, init_word, None)
+ })
+ .collect::<Vec<_>>();
+
+ if body.is_empty() {
+ return None;
+ }
+
+ let uint3_type_id = self.get_uint3_type_id();
+
+ let mut pre_if_block = Block::new(entry_id);
+
+ let local_invocation_id = if let Some(local_invocation_id) = local_invocation_id {
+ local_invocation_id
+ } else {
+ let varying_id = self.id_gen.next();
+ let class = spirv::StorageClass::Input;
+ let pointer_type_id = self.get_uint3_pointer_type_id(class);
+
+ Instruction::variable(pointer_type_id, varying_id, class, None)
+ .to_words(&mut self.logical_layout.declarations);
+
+ self.decorate(
+ varying_id,
+ spirv::Decoration::BuiltIn,
+ &[spirv::BuiltIn::LocalInvocationId as u32],
+ );
+
+ interface.varying_ids.push(varying_id);
+ let id = self.id_gen.next();
+ pre_if_block
+ .body
+ .push(Instruction::load(uint3_type_id, id, varying_id, None));
+
+ id
+ };
+
+ let zero_id = self.write_constant_null(uint3_type_id);
+ let bool3_type_id = self.get_bool3_type_id();
+
+ let eq_id = self.id_gen.next();
+ pre_if_block.body.push(Instruction::binary(
+ spirv::Op::IEqual,
+ bool3_type_id,
+ eq_id,
+ local_invocation_id,
+ zero_id,
+ ));
+
+ let condition_id = self.id_gen.next();
+ let bool_type_id = self.get_bool_type_id();
+ pre_if_block.body.push(Instruction::relational(
+ spirv::Op::All,
+ bool_type_id,
+ condition_id,
+ eq_id,
+ ));
+
+ let merge_id = self.id_gen.next();
+ pre_if_block.body.push(Instruction::selection_merge(
+ merge_id,
+ spirv::SelectionControl::NONE,
+ ));
+
+ let accept_id = self.id_gen.next();
+ function.consume(
+ pre_if_block,
+ Instruction::branch_conditional(condition_id, accept_id, merge_id),
+ );
+
+ let accept_block = Block {
+ label_id: accept_id,
+ body,
+ };
+ function.consume(accept_block, Instruction::branch(merge_id));
+
+ let mut post_if_block = Block::new(merge_id);
+
+ self.write_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block);
+
+ let next_id = self.id_gen.next();
+ function.consume(post_if_block, Instruction::branch(next_id));
+ Some(next_id)
+ }
+
+ /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
+ ///
+ /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
+ /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
+ /// interface is represented by global variables in the `Input` and `Output`
+ /// storage classes, with decorations indicating which builtin or location
+ /// each variable corresponds to.
+ ///
+ /// This function emits a single global `OpVariable` for a single value from
+ /// the interface, and adds appropriate decorations to indicate which
+ /// builtin or location it represents, how it should be interpolated, and so
+ /// on. The `class` argument gives the variable's SPIR-V storage class,
+ /// which should be either [`Input`] or [`Output`].
+ ///
+ /// [`Binding`]: crate::Binding
+ /// [`Function`]: crate::Function
+ /// [`EntryPoint`]: crate::EntryPoint
+ /// [`Input`]: spirv::StorageClass::Input
+ /// [`Output`]: spirv::StorageClass::Output
+ fn write_varying(
+ &mut self,
+ ir_module: &crate::Module,
+ stage: crate::ShaderStage,
+ class: spirv::StorageClass,
+ debug_name: Option<&str>,
+ ty: Handle<crate::Type>,
+ binding: &crate::Binding,
+ ) -> Result<Word, Error> {
+ let id = self.id_gen.next();
+ let pointer_type_id = self.get_pointer_id(&ir_module.types, ty, class)?;
+ Instruction::variable(pointer_type_id, id, class, None)
+ .to_words(&mut self.logical_layout.declarations);
+
+ if self
+ .flags
+ .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
+ {
+ if let Some(name) = debug_name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ use spirv::{BuiltIn, Decoration};
+
+ match *binding {
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ } => {
+ self.decorate(id, Decoration::Location, &[location]);
+
+ let no_decorations =
+ // VUID-StandaloneSpirv-Flat-06202
+ // > The Flat, NoPerspective, Sample, and Centroid decorations
+ // > must not be used on variables with the Input storage class in a vertex shader
+ (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
+ // VUID-StandaloneSpirv-Flat-06201
+ // > The Flat, NoPerspective, Sample, and Centroid decorations
+ // > must not be used on variables with the Output storage class in a fragment shader
+ (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);
+
+ if !no_decorations {
+ match interpolation {
+ // Perspective-correct interpolation is the default in SPIR-V.
+ None | Some(crate::Interpolation::Perspective) => (),
+ Some(crate::Interpolation::Flat) => {
+ self.decorate(id, Decoration::Flat, &[]);
+ }
+ Some(crate::Interpolation::Linear) => {
+ self.decorate(id, Decoration::NoPerspective, &[]);
+ }
+ }
+ match sampling {
+ // Center sampling is the default in SPIR-V.
+ None | Some(crate::Sampling::Center) => (),
+ Some(crate::Sampling::Centroid) => {
+ self.decorate(id, Decoration::Centroid, &[]);
+ }
+ Some(crate::Sampling::Sample) => {
+ self.require_any(
+ "per-sample interpolation",
+ &[spirv::Capability::SampleRateShading],
+ )?;
+ self.decorate(id, Decoration::Sample, &[]);
+ }
+ }
+ }
+ }
+ crate::Binding::BuiltIn(built_in) => {
+ use crate::BuiltIn as Bi;
+ let built_in = match built_in {
+ Bi::Position { invariant } => {
+ if invariant {
+ self.decorate(id, Decoration::Invariant, &[]);
+ }
+
+ if class == spirv::StorageClass::Output {
+ BuiltIn::Position
+ } else {
+ BuiltIn::FragCoord
+ }
+ }
+ Bi::ViewIndex => {
+ self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?;
+ BuiltIn::ViewIndex
+ }
+ // vertex
+ Bi::BaseInstance => BuiltIn::BaseInstance,
+ Bi::BaseVertex => BuiltIn::BaseVertex,
+ Bi::ClipDistance => BuiltIn::ClipDistance,
+ Bi::CullDistance => BuiltIn::CullDistance,
+ Bi::InstanceIndex => BuiltIn::InstanceIndex,
+ Bi::PointSize => BuiltIn::PointSize,
+ Bi::VertexIndex => BuiltIn::VertexIndex,
+ // fragment
+ Bi::FragDepth => BuiltIn::FragDepth,
+ Bi::PointCoord => BuiltIn::PointCoord,
+ Bi::FrontFacing => BuiltIn::FrontFacing,
+ Bi::PrimitiveIndex => {
+ self.require_any(
+ "`primitive_index` built-in",
+ &[spirv::Capability::Geometry],
+ )?;
+ BuiltIn::PrimitiveId
+ }
+ Bi::SampleIndex => {
+ self.require_any(
+ "`sample_index` built-in",
+ &[spirv::Capability::SampleRateShading],
+ )?;
+
+ BuiltIn::SampleId
+ }
+ Bi::SampleMask => BuiltIn::SampleMask,
+ // compute
+ Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId,
+ Bi::LocalInvocationId => BuiltIn::LocalInvocationId,
+ Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
+ Bi::WorkGroupId => BuiltIn::WorkgroupId,
+ Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
+ Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
+ };
+
+ self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);
+
+ use crate::ScalarKind as Sk;
+
+ // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`:
+ //
+ // > Any variable with integer or double-precision floating-
+ // > point type and with Input storage class in a fragment
+ // > shader, must be decorated Flat
+ if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
+ let is_flat = match ir_module.types[ty].inner {
+ crate::TypeInner::Scalar { kind, .. }
+ | crate::TypeInner::Vector { kind, .. } => match kind {
+ Sk::Uint | Sk::Sint | Sk::Bool => true,
+ Sk::Float => false,
+ },
+ _ => false,
+ };
+
+ if is_flat {
+ self.decorate(id, Decoration::Flat, &[]);
+ }
+ }
+ }
+ }
+
+ Ok(id)
+ }
+
+ fn write_global_variable(
+ &mut self,
+ ir_module: &crate::Module,
+ global_variable: &crate::GlobalVariable,
+ ) -> Result<Word, Error> {
+ use spirv::Decoration;
+
+ let id = self.id_gen.next();
+ let class = map_storage_class(global_variable.space);
+
+ //self.check(class.required_capabilities())?;
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = global_variable.name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ let storage_access = match global_variable.space {
+ crate::AddressSpace::Storage { access } => Some(access),
+ _ => match ir_module.types[global_variable.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => Some(access),
+ _ => None,
+ },
+ };
+ if let Some(storage_access) = storage_access {
+ if !storage_access.contains(crate::StorageAccess::LOAD) {
+ self.decorate(id, Decoration::NonReadable, &[]);
+ }
+ if !storage_access.contains(crate::StorageAccess::STORE) {
+ self.decorate(id, Decoration::NonWritable, &[]);
+ }
+ }
+
+ let mut substitute_inner_type_lookup = None;
+ if let Some(ref res_binding) = global_variable.binding {
+ self.decorate(id, Decoration::DescriptorSet, &[res_binding.group]);
+ self.decorate(id, Decoration::Binding, &[res_binding.binding]);
+
+ if let Some(&BindingInfo {
+ binding_array_size: Some(remapped_binding_array_size),
+ }) = self.binding_map.get(res_binding)
+ {
+ if let crate::TypeInner::BindingArray { base, .. } =
+ ir_module.types[global_variable.ty].inner
+ {
+ substitute_inner_type_lookup =
+ Some(LookupType::Local(LocalType::PointerToBindingArray {
+ base,
+ size: remapped_binding_array_size as u64,
+ }))
+ }
+ } else {
+ }
+ };
+
+ let init_word = global_variable
+ .init
+ .map(|constant| self.constant_ids[constant.index()]);
+ let inner_type_id = self.get_type_id(
+ substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)),
+ );
+
+ // generate the wrapping structure if needed
+ let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) {
+ let wrapper_type_id = self.id_gen.next();
+
+ self.decorate(wrapper_type_id, Decoration::Block, &[]);
+ let member = crate::StructMember {
+ name: None,
+ ty: global_variable.ty,
+ binding: None,
+ offset: 0,
+ };
+ self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?;
+
+ Instruction::type_struct(wrapper_type_id, &[inner_type_id])
+ .to_words(&mut self.logical_layout.declarations);
+
+ let pointer_type_id = self.id_gen.next();
+ Instruction::type_pointer(pointer_type_id, class, wrapper_type_id)
+ .to_words(&mut self.logical_layout.declarations);
+
+ pointer_type_id
+ } else {
+ // This is a global variable in the Storage address space. The only
+ // way it could have `global_needs_wrapper() == false` is if it has
+ // a runtime-sized array. In this case, we need to decorate it with
+ // Block.
+ if let crate::AddressSpace::Storage { .. } = global_variable.space {
+ self.decorate(inner_type_id, Decoration::Block, &[]);
+ }
+ if substitute_inner_type_lookup.is_some() {
+ inner_type_id
+ } else {
+ self.get_pointer_id(&ir_module.types, global_variable.ty, class)?
+ }
+ };
+
+ let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) {
+ (crate::AddressSpace::Private, _)
+ | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => {
+ init_word.or_else(|| Some(self.write_constant_null(inner_type_id)))
+ }
+ _ => init_word,
+ };
+
+ Instruction::variable(pointer_type_id, id, class, init_word)
+ .to_words(&mut self.logical_layout.declarations);
+ Ok(id)
+ }
+
+ /// Write the necessary decorations for a struct member.
+ ///
+ /// Emit decorations for the `index`'th member of the struct type
+ /// designated by `struct_id`, described by `member`.
+ fn decorate_struct_member(
+ &mut self,
+ struct_id: Word,
+ index: usize,
+ member: &crate::StructMember,
+ arena: &UniqueArena<crate::Type>,
+ ) -> Result<(), Error> {
+ use spirv::Decoration;
+
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::Offset,
+ &[member.offset],
+ ));
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = member.name {
+ self.debugs
+ .push(Instruction::member_name(struct_id, index as u32, name));
+ }
+ }
+
+ // Matrices and arrays of matrices both require decorations,
+ // so "see through" an array to determine if they're needed.
+ let member_array_subty_inner = match arena[member.ty].inner {
+ crate::TypeInner::Array { base, .. } => &arena[base].inner,
+ ref other => other,
+ };
+ if let crate::TypeInner::Matrix {
+ columns: _,
+ rows,
+ width,
+ } = *member_array_subty_inner
+ {
+ let byte_stride = Alignment::from(rows) * width as u32;
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::ColMajor,
+ &[],
+ ));
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::MatrixStride,
+ &[byte_stride],
+ ));
+ }
+
+ Ok(())
+ }
+
+ fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
+ match self
+ .lookup_function_type
+ .entry(lookup_function_type.clone())
+ {
+ Entry::Occupied(e) => *e.get(),
+ Entry::Vacant(_) => {
+ let id = self.id_gen.next();
+ let instruction = Instruction::type_function(
+ id,
+ lookup_function_type.return_type_id,
+ &lookup_function_type.parameter_type_ids,
+ );
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_function_type.insert(lookup_function_type, id);
+ id
+ }
+ }
+ }
+
+ fn write_physical_layout(&mut self) {
+ self.physical_layout.bound = self.id_gen.0 + 1;
+ }
+
+ fn write_logical_layout(
+ &mut self,
+ ir_module: &crate::Module,
+ mod_info: &ModuleInfo,
+ ep_index: Option<usize>,
+ ) -> Result<(), Error> {
+ fn has_view_index_check(
+ ir_module: &crate::Module,
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ ) -> bool {
+ match ir_module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| {
+ has_view_index_check(ir_module, member.binding.as_ref(), member.ty)
+ }),
+ _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)),
+ }
+ }
+
+ let has_storage_buffers =
+ ir_module
+ .global_variables
+ .iter()
+ .any(|(_, var)| match var.space {
+ crate::AddressSpace::Storage { .. } => true,
+ _ => false,
+ });
+ let has_view_index = ir_module
+ .entry_points
+ .iter()
+ .flat_map(|entry| entry.function.arguments.iter())
+ .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty));
+ let has_ray_query = ir_module.special_types.ray_desc.is_some()
+ | ir_module.special_types.ray_intersection.is_some();
+
+ if self.physical_layout.version < 0x10300 && has_storage_buffers {
+ // enable the storage buffer class on < SPV-1.3
+ Instruction::extension("SPV_KHR_storage_buffer_storage_class")
+ .to_words(&mut self.logical_layout.extensions);
+ }
+ if has_view_index {
+ Instruction::extension("SPV_KHR_multiview")
+ .to_words(&mut self.logical_layout.extensions)
+ }
+ if has_ray_query {
+ Instruction::extension("SPV_KHR_ray_query")
+ .to_words(&mut self.logical_layout.extensions)
+ }
+ Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations);
+ Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450")
+ .to_words(&mut self.logical_layout.ext_inst_imports);
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ self.debugs
+ .push(Instruction::source(spirv::SourceLanguage::GLSL, 450));
+ }
+
+ self.constant_ids.resize(ir_module.constants.len(), 0);
+ // first, output all the scalar constants
+ for (handle, constant) in ir_module.constants.iter() {
+ match constant.inner {
+ crate::ConstantInner::Composite { .. } => continue,
+ crate::ConstantInner::Scalar { width, ref value } => {
+ self.constant_ids[handle.index()] = match constant.name {
+ Some(ref name) => {
+ let id = self.id_gen.next();
+ self.write_constant_scalar(id, value, width, Some(name));
+ id
+ }
+ None => self.get_constant_scalar(*value, width),
+ };
+ }
+ }
+ }
+
+ // then all types, some of them may rely on constants and struct type set
+ for (handle, _) in ir_module.types.iter() {
+ self.write_type_declaration_arena(&ir_module.types, handle)?;
+ }
+
+ // then all the composite constants, they rely on types
+ for (handle, constant) in ir_module.constants.iter() {
+ match constant.inner {
+ crate::ConstantInner::Scalar { .. } => continue,
+ crate::ConstantInner::Composite { ty, ref components } => {
+ let ty = LookupType::Handle(ty);
+
+ let mut constituent_ids = Vec::with_capacity(components.len());
+ for constituent in components.iter() {
+ let constituent_id = self.constant_ids[constituent.index()];
+ constituent_ids.push(constituent_id);
+ }
+
+ self.constant_ids[handle.index()] = match constant.name {
+ Some(ref name) => {
+ let id = self.id_gen.next();
+ self.write_constant_composite(id, ty, &constituent_ids, Some(name));
+ id
+ }
+ None => self.get_constant_composite(ty, &constituent_ids),
+ };
+ }
+ }
+ }
+ debug_assert_eq!(self.constant_ids.iter().position(|&id| id == 0), None);
+
+ // now write all globals
+ for (handle, var) in ir_module.global_variables.iter() {
+ // If a single entry point was specified, only write `OpVariable` instructions
+ // for the globals it actually uses. Emit dummies for the others,
+ // to preserve the indices in `global_variables`.
+ let gvar = match ep_index {
+ Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => {
+ GlobalVariable::dummy()
+ }
+ _ => {
+ let id = self.write_global_variable(ir_module, var)?;
+ GlobalVariable::new(id)
+ }
+ };
+ self.global_variables.push(gvar);
+ }
+
+ // all functions
+ for (handle, ir_function) in ir_module.functions.iter() {
+ let info = &mod_info[handle];
+ if let Some(index) = ep_index {
+ let ep_info = mod_info.get_entry_point(index);
+ // If this function uses globals that we omitted from the SPIR-V
+ // because the entry point and its callees didn't use them,
+ // then we must skip it.
+ if !ep_info.dominates_global_use(info) {
+ log::info!("Skip function {:?}", ir_function.name);
+ continue;
+ }
+ }
+ let id = self.write_function(ir_function, info, ir_module, None)?;
+ self.lookup_function.insert(handle, id);
+ }
+
+ // and entry points
+ for (index, ir_ep) in ir_module.entry_points.iter().enumerate() {
+ if ep_index.is_some() && ep_index != Some(index) {
+ continue;
+ }
+ let info = mod_info.get_entry_point(index);
+ let ep_instruction = self.write_entry_point(ir_ep, info, ir_module)?;
+ ep_instruction.to_words(&mut self.logical_layout.entry_points);
+ }
+
+ for capability in self.capabilities_used.iter() {
+ Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities);
+ }
+ for extension in self.extensions_used.iter() {
+ Instruction::extension(extension).to_words(&mut self.logical_layout.extensions);
+ }
+ if ir_module.entry_points.is_empty() {
+ // SPIR-V doesn't like modules without entry points
+ Instruction::capability(spirv::Capability::Linkage)
+ .to_words(&mut self.logical_layout.capabilities);
+ }
+
+ let addressing_model = spirv::AddressingModel::Logical;
+ let memory_model = spirv::MemoryModel::GLSL450;
+ //self.check(addressing_model.required_capabilities())?;
+ //self.check(memory_model.required_capabilities())?;
+
+ Instruction::memory_model(addressing_model, memory_model)
+ .to_words(&mut self.logical_layout.memory_model);
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ for debug in self.debugs.iter() {
+ debug.to_words(&mut self.logical_layout.debugs);
+ }
+ }
+
+ for annotation in self.annotations.iter() {
+ annotation.to_words(&mut self.logical_layout.annotations);
+ }
+
+ Ok(())
+ }
+
+ pub fn write(
+ &mut self,
+ ir_module: &crate::Module,
+ info: &ModuleInfo,
+ pipeline_options: Option<&PipelineOptions>,
+ words: &mut Vec<Word>,
+ ) -> Result<(), Error> {
+ self.reset();
+
+ // Try to find the entry point and corresponding index
+ let ep_index = match pipeline_options {
+ Some(po) => {
+ let index = ir_module
+ .entry_points
+ .iter()
+ .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name)
+ .ok_or(Error::EntryPointNotFound)?;
+ Some(index)
+ }
+ None => None,
+ };
+
+ self.write_logical_layout(ir_module, info, ep_index)?;
+ self.write_physical_layout();
+
+ self.physical_layout.in_words(words);
+ self.logical_layout.in_words(words);
+ Ok(())
+ }
+
+ /// Return the set of capabilities the last module written used.
+ pub const fn get_capabilities_used(&self) -> &crate::FastHashSet<spirv::Capability> {
+ &self.capabilities_used
+ }
+}
+
+#[test]
+fn test_write_physical_layout() {
+ let mut writer = Writer::new(&Options::default()).unwrap();
+ assert_eq!(writer.physical_layout.bound, 0);
+ writer.write_physical_layout();
+ assert_eq!(writer.physical_layout.bound, 3);
+}
diff --git a/third_party/rust/naga/src/back/wgsl/mod.rs b/third_party/rust/naga/src/back/wgsl/mod.rs
new file mode 100644
index 0000000000..d731b1ca0c
--- /dev/null
+++ b/third_party/rust/naga/src/back/wgsl/mod.rs
@@ -0,0 +1,52 @@
+/*!
+Backend for [WGSL][wgsl] (WebGPU Shading Language).
+
+[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
+*/
+
+mod writer;
+
+use thiserror::Error;
+
+pub use writer::{Writer, WriterFlags};
+
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error(transparent)]
+ FmtError(#[from] std::fmt::Error),
+ #[error("{0}")]
+ Custom(String),
+ #[error("{0}")]
+ Unimplemented(String), // TODO: Error used only during development
+ #[error("Unsupported math function: {0:?}")]
+ UnsupportedMathFunction(crate::MathFunction),
+ #[error("Unsupported relational function: {0:?}")]
+ UnsupportedRelationalFunction(crate::RelationalFunction),
+}
+
+pub fn write_string(
+ module: &crate::Module,
+ info: &crate::valid::ModuleInfo,
+ flags: WriterFlags,
+) -> Result<String, Error> {
+ let mut w = Writer::new(String::new(), flags);
+ w.write(module, info)?;
+ let output = w.finish();
+ Ok(output)
+}
+
+impl crate::AtomicFunction {
+ const fn to_wgsl(self) -> &'static str {
+ match self {
+ Self::Add => "Add",
+ Self::Subtract => "Sub",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { .. } => "CompareExchangeWeak",
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs
new file mode 100644
index 0000000000..92086c94a8
--- /dev/null
+++ b/third_party/rust/naga/src/back/wgsl/writer.rs
@@ -0,0 +1,1968 @@
+use super::Error;
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, Module, ShaderStage, TypeInner,
+};
+use std::fmt::Write;
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+/// WGSL [attribute](https://gpuweb.github.io/gpuweb/wgsl/#attributes)
+enum Attribute {
+ Binding(u32),
+ BuiltIn(crate::BuiltIn),
+ Group(u32),
+ Invariant,
+ Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>),
+ Location(u32),
+ Stage(ShaderStage),
+ WorkGroupSize([u32; 3]),
+}
+
+/// The WGSL form that `write_expr_with_indirection` should use to render a Naga
+/// expression.
+///
+/// Sometimes a Naga `Expression` alone doesn't provide enough information to
+/// choose the right rendering for it in WGSL. For example, one natural WGSL
+/// rendering of a Naga `LocalVariable(x)` expression might be `&x`, since
+/// `LocalVariable` produces a pointer to the local variable's storage. But when
+/// rendering a `Store` statement, the `pointer` operand must be the left hand
+/// side of a WGSL assignment, so the proper rendering is `x`.
+///
+/// The caller of `write_expr_with_indirection` must provide an `Expected` value
+/// to indicate how ambiguous expressions should be rendered.
+#[derive(Clone, Copy, Debug)]
+enum Indirection {
+ /// Render pointer-construction expressions as WGSL `ptr`-typed expressions.
+ ///
+ /// This is the right choice for most cases. Whenever a Naga pointer
+ /// expression is not the `pointer` operand of a `Load` or `Store`, it
+ /// must be a WGSL pointer expression.
+ Ordinary,
+
+ /// Render pointer-construction expressions as WGSL reference-typed
+ /// expressions.
+ ///
+ /// For example, this is the right choice for the `pointer` operand when
+ /// rendering a `Store` statement as a WGSL assignment.
+ Reference,
+}
+
+bitflags::bitflags! {
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct WriterFlags: u32 {
+ /// Always annotate the type information instead of inferring.
+ const EXPLICIT_TYPES = 0x1;
+ }
+}
+
+pub struct Writer<W> {
+ out: W,
+ flags: WriterFlags,
+ names: crate::FastHashMap<NameKey, String>,
+ namer: proc::Namer,
+ named_expressions: crate::NamedExpressions,
+ ep_results: Vec<(ShaderStage, Handle<crate::Type>)>,
+}
+
+impl<W: Write> Writer<W> {
+ pub fn new(out: W, flags: WriterFlags) -> Self {
+ Writer {
+ out,
+ flags,
+ names: crate::FastHashMap::default(),
+ namer: proc::Namer::default(),
+ named_expressions: crate::NamedExpressions::default(),
+ ep_results: vec![],
+ }
+ }
+
+ fn reset(&mut self, module: &Module) {
+ self.names.clear();
+ self.namer.reset(
+ module,
+ crate::keywords::wgsl::RESERVED,
+ // an identifier must not start with two underscore
+ &["__"],
+ &mut self.names,
+ );
+ self.named_expressions.clear();
+ self.ep_results.clear();
+ }
+
+ pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
+ self.reset(module);
+
+ // Save all ep result types
+ for (_, ep) in module.entry_points.iter().enumerate() {
+ if let Some(ref result) = ep.function.result {
+ self.ep_results.push((ep.stage, result.ty));
+ }
+ }
+
+ // Write all structs
+ for (handle, ty) in module.types.iter() {
+ if let TypeInner::Struct {
+ ref members,
+ span: _,
+ } = ty.inner
+ {
+ self.write_struct(module, handle, members)?;
+ writeln!(self.out)?;
+ }
+ }
+
+ // Write all constants
+ for (handle, constant) in module.constants.iter() {
+ if constant.name.is_some() {
+ self.write_global_constant(module, &constant.inner, handle)?;
+ }
+ }
+
+ // Write all globals
+ for (ty, global) in module.global_variables.iter() {
+ self.write_global(module, global, ty)?;
+ }
+
+ if !module.global_variables.is_empty() {
+ // Add extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Write all regular functions
+ for (handle, function) in module.functions.iter() {
+ let fun_info = &info[handle];
+
+ let func_ctx = back::FunctionCtx {
+ ty: back::FunctionType::Function(handle),
+ info: fun_info,
+ expressions: &function.expressions,
+ named_expressions: &function.named_expressions,
+ };
+
+ // Write the function
+ self.write_function(module, function, &func_ctx)?;
+
+ writeln!(self.out)?;
+ }
+
+ // Write all entry points
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let attributes = match ep.stage {
+ ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
+ ShaderStage::Compute => vec![
+ Attribute::Stage(ShaderStage::Compute),
+ Attribute::WorkGroupSize(ep.workgroup_size),
+ ],
+ };
+
+ self.write_attributes(&attributes)?;
+ // Add a newline after attribute
+ writeln!(self.out)?;
+
+ let func_ctx = back::FunctionCtx {
+ ty: back::FunctionType::EntryPoint(index as u16),
+ info: info.get_entry_point(index),
+ expressions: &ep.function.expressions,
+ named_expressions: &ep.function.named_expressions,
+ };
+ self.write_function(module, &ep.function, &func_ctx)?;
+
+ if index < module.entry_points.len() - 1 {
+ writeln!(self.out)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write [`ScalarValue`](crate::ScalarValue)
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_scalar_value(&mut self, value: crate::ScalarValue) -> BackendResult {
+ use crate::ScalarValue as Sv;
+
+ match value {
+ Sv::Sint(value) => write!(self.out, "{value}")?,
+ Sv::Uint(value) => write!(self.out, "{value}u")?,
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero
+ Sv::Float(value) => write!(self.out, "{value:?}")?,
+ Sv::Bool(value) => write!(self.out, "{value}")?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write struct name
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_struct_name(&mut self, module: &Module, handle: Handle<crate::Type>) -> BackendResult {
+ if module.types[handle].name.is_none() {
+ if let Some(&(stage, _)) = self.ep_results.iter().find(|&&(_, ty)| ty == handle) {
+ let name = match stage {
+ ShaderStage::Compute => "ComputeOutput",
+ ShaderStage::Fragment => "FragmentOutput",
+ ShaderStage::Vertex => "VertexOutput",
+ };
+
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+ }
+
+ write!(self.out, "{}", self.names[&NameKey::Type(handle)])?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write
+ /// [functions](https://gpuweb.github.io/gpuweb/wgsl/#functions)
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_function(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ let func_name = match func_ctx.ty {
+ back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
+ back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
+ };
+
+ // Write function name
+ write!(self.out, "fn {func_name}(")?;
+
+ // Write function arguments
+ for (index, arg) in func.arguments.iter().enumerate() {
+ // Write argument attribute if a binding is present
+ if let Some(ref binding) = arg.binding {
+ self.write_attributes(&map_binding_to_attribute(
+ binding,
+ module.types[arg.ty].inner.scalar_kind(),
+ ))?;
+ }
+ // Write argument name
+ let argument_name = match func_ctx.ty {
+ back::FunctionType::Function(handle) => {
+ &self.names[&NameKey::FunctionArgument(handle, index as u32)]
+ }
+ back::FunctionType::EntryPoint(ep_index) => {
+ &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]
+ }
+ };
+
+ write!(self.out, "{argument_name}: ")?;
+ // Write argument type
+ self.write_type(module, arg.ty)?;
+ if index < func.arguments.len() - 1 {
+ // Add a separator between args
+ write!(self.out, ", ")?;
+ }
+ }
+
+ write!(self.out, ")")?;
+
+ // Write function return type
+ if let Some(ref result) = func.result {
+ write!(self.out, " -> ")?;
+ if let Some(ref binding) = result.binding {
+ self.write_attributes(&map_binding_to_attribute(
+ binding,
+ module.types[result.ty].inner.scalar_kind(),
+ ))?;
+ }
+ self.write_type(module, result.ty)?;
+ }
+
+ write!(self.out, " {{")?;
+ writeln!(self.out)?;
+
+ // Write function local variables
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability)
+ write!(self.out, "{}", back::INDENT)?;
+
+ // Write the local name
+ // The leading space is important
+ write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?;
+
+ // Write the local type
+ self.write_type(module, local.ty)?;
+
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ // Put the equal signal only if there's a initializer
+ // The leading and trailing spaces aren't needed but help with readability
+ write!(self.out, " = ")?;
+
+ // Write the constant
+ // `write_constant` adds no trailing or leading space/newline
+ self.write_constant(module, init)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ if !func.local_variables.is_empty() {
+ writeln!(self.out)?;
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // The indentation should always be 1 when writing the function body
+ self.write_stmt(module, sta, func_ctx, back::Level(1))?;
+ }
+
+ writeln!(self.out, "}}")?;
+
+ self.named_expressions.clear();
+
+ Ok(())
+ }
+
+ /// Helper method to write a attribute
+ fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult {
+ for attribute in attributes {
+ match *attribute {
+ Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
+ Attribute::BuiltIn(builtin_attrib) => {
+ let builtin = builtin_str(builtin_attrib)?;
+ write!(self.out, "@builtin({builtin}) ")?;
+ }
+ Attribute::Stage(shader_stage) => {
+ let stage_str = match shader_stage {
+ ShaderStage::Vertex => "vertex",
+ ShaderStage::Fragment => "fragment",
+ ShaderStage::Compute => "compute",
+ };
+ write!(self.out, "@{stage_str} ")?;
+ }
+ Attribute::WorkGroupSize(size) => {
+ write!(
+ self.out,
+ "@workgroup_size({}, {}, {}) ",
+ size[0], size[1], size[2]
+ )?;
+ }
+ Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?,
+ Attribute::Group(id) => write!(self.out, "@group({id}) ")?,
+ Attribute::Invariant => write!(self.out, "@invariant ")?,
+ Attribute::Interpolate(interpolation, sampling) => {
+ if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
+ write!(
+ self.out,
+ "@interpolate({}, {}) ",
+ interpolation_str(
+ interpolation.unwrap_or(crate::Interpolation::Perspective)
+ ),
+ sampling_str(sampling.unwrap_or(crate::Sampling::Center))
+ )?;
+ } else if interpolation.is_some()
+ && interpolation != Some(crate::Interpolation::Perspective)
+ {
+ write!(
+ self.out,
+ "@interpolate({}) ",
+ interpolation_str(
+ interpolation.unwrap_or(crate::Interpolation::Perspective)
+ )
+ )?;
+ }
+ }
+ };
+ }
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_struct(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ ) -> BackendResult {
+ write!(self.out, "struct ")?;
+ self.write_struct_name(module, handle)?;
+ write!(self.out, " {{")?;
+ writeln!(self.out)?;
+ for (index, member) in members.iter().enumerate() {
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+ if let Some(ref binding) = member.binding {
+ self.write_attributes(&map_binding_to_attribute(
+ binding,
+ module.types[member.ty].inner.scalar_kind(),
+ ))?;
+ }
+ // Write struct member name and type
+ let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
+ write!(self.out, "{member_name}: ")?;
+ self.write_type(module, member.ty)?;
+ write!(self.out, ",")?;
+ writeln!(self.out)?;
+ }
+
+ write!(self.out, "}}")?;
+
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &module.types[ty].inner;
+ match *inner {
+ TypeInner::Struct { .. } => self.write_struct_name(module, ty)?,
+ ref other => self.write_value_type(module, other)?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ TypeInner::Vector { size, kind, width } => write!(
+ self.out,
+ "vec{}<{}>",
+ back::vector_size_str(size),
+ scalar_kind_str(kind, width),
+ )?,
+ TypeInner::Sampler { comparison: false } => {
+ write!(self.out, "sampler")?;
+ }
+ TypeInner::Sampler { comparison: true } => {
+ write!(self.out, "sampler_comparison")?;
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ // More about texture types: https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ use crate::ImageClass as Ic;
+
+ let dim_str = image_dimension_str(dim);
+ let arrayed_str = if arrayed { "_array" } else { "" };
+ let (class_str, multisampled_str, format_str, storage_str) = match class {
+ Ic::Sampled { kind, multi } => (
+ "",
+ if multi { "multisampled_" } else { "" },
+ scalar_kind_str(kind, 4),
+ "",
+ ),
+ Ic::Depth { multi } => {
+ ("depth_", if multi { "multisampled_" } else { "" }, "", "")
+ }
+ Ic::Storage { format, access } => (
+ "storage_",
+ "",
+ storage_format_str(format),
+ if access.contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
+ {
+ ",read_write"
+ } else if access.contains(crate::StorageAccess::LOAD) {
+ ",read"
+ } else {
+ ",write"
+ },
+ ),
+ };
+ write!(
+ self.out,
+ "texture_{class_str}{multisampled_str}{dim_str}{arrayed_str}"
+ )?;
+
+ if !format_str.is_empty() {
+ write!(self.out, "<{format_str}{storage_str}>")?;
+ }
+ }
+ TypeInner::Scalar { kind, width } => {
+ write!(self.out, "{}", scalar_kind_str(kind, width))?;
+ }
+ TypeInner::Atomic { kind, width } => {
+ write!(self.out, "atomic<{}>", scalar_kind_str(kind, width))?;
+ }
+ TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ // More info https://gpuweb.github.io/gpuweb/wgsl/#array-types
+ // array<A, 3> -- Constant array
+ // array<A> -- Dynamic array
+ write!(self.out, "array<")?;
+ match size {
+ crate::ArraySize::Constant(handle) => {
+ self.write_type(module, base)?;
+ write!(self.out, ",")?;
+ self.write_constant(module, handle)?;
+ }
+ crate::ArraySize::Dynamic => {
+ self.write_type(module, base)?;
+ }
+ }
+ write!(self.out, ">")?;
+ }
+ TypeInner::BindingArray { base, size } => {
+ // More info https://github.com/gpuweb/gpuweb/issues/2105
+ write!(self.out, "binding_array<")?;
+ match size {
+ crate::ArraySize::Constant(handle) => {
+ self.write_type(module, base)?;
+ write!(self.out, ",")?;
+ self.write_constant(module, handle)?;
+ }
+ crate::ArraySize::Dynamic => {
+ self.write_type(module, base)?;
+ }
+ }
+ write!(self.out, ">")?;
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width: _,
+ } => {
+ write!(
+ self.out,
+ //TODO: Can matrix be other than f32?
+ "mat{}x{}<f32>",
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ )?;
+ }
+ TypeInner::Pointer { base, space } => {
+ let (address, maybe_access) = address_space_str(space);
+ // Everything but `AddressSpace::Handle` gives us a `address` name, but
+ // Naga IR never produces pointers to handles, so it doesn't matter much
+ // how we write such a type. Just write it as the base type alone.
+ if let Some(space) = address {
+ write!(self.out, "ptr<{space}, ")?;
+ }
+ self.write_type(module, base)?;
+ if address.is_some() {
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ }
+ }
+ TypeInner::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space,
+ } => {
+ let (address, maybe_access) = address_space_str(space);
+ if let Some(space) = address {
+ write!(self.out, "ptr<{}, {}", space, scalar_kind_str(kind, width))?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ } else {
+ return Err(Error::Unimplemented(format!(
+ "ValuePointer to AddressSpace::Handle {inner:?}"
+ )));
+ }
+ }
+ TypeInner::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ space,
+ } => {
+ let (address, maybe_access) = address_space_str(space);
+ if let Some(space) = address {
+ write!(
+ self.out,
+ "ptr<{}, vec{}<{}>",
+ space,
+ back::vector_size_str(size),
+ scalar_kind_str(kind, width)
+ )?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ } else {
+ return Err(Error::Unimplemented(format!(
+ "ValuePointer to AddressSpace::Handle {inner:?}"
+ )));
+ }
+ write!(self.out, ">")?;
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!("write_value_type {inner:?}")));
+ }
+ }
+
+ Ok(())
+ }
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ module: &Module,
+ stmt: &crate::Statement,
+ func_ctx: &back::FunctionCtx<'_>,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::{Expression, Statement};
+
+ match *stmt {
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let info = &func_ctx.info[handle];
+ let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
+ // Front end provides names for all variables at the start of writing.
+ // But we write them to step by step. We need to recache them
+ // Otherwise, we could accidentally write variable name instead of full expression.
+ // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
+ Some(self.namer.call(name))
+ } else if info.ref_count == 0 {
+ write!(self.out, "{level}_ = ")?;
+ self.write_expr(module, handle, func_ctx)?;
+ writeln!(self.out, ";")?;
+ continue;
+ } else {
+ let expr = &func_ctx.expressions[handle];
+ let min_ref_count = expr.bake_ref_count();
+ // Forcefully creating baking expressions in some cases to help with readability
+ let required_baking_expr = match *expr {
+ Expression::ImageLoad { .. }
+ | Expression::ImageQuery { .. }
+ | Expression::ImageSample { .. } => true,
+ _ => false,
+ };
+ if min_ref_count <= info.ref_count || required_baking_expr {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else {
+ None
+ }
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.start_named_expr(module, handle, func_ctx, &name)?;
+ self.write_expr(module, handle, func_ctx)?;
+ self.named_expressions.insert(handle, name);
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "if ")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, " {{")?;
+
+ let l2 = level.next();
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Return { value } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "return")?;
+ if let Some(return_value) = value {
+ // The leading space is important
+ write!(self.out, " ")?;
+ self.write_expr(module, return_value, func_ctx)?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Kill => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "discard;")?
+ }
+ Statement::Store { pointer, value } => {
+ write!(self.out, "{level}")?;
+
+ let is_atomic = match *func_ctx.info[pointer].ty.inner_with(&module.types) {
+ crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => true,
+ _ => false,
+ },
+ _ => false,
+ };
+ if is_atomic {
+ write!(self.out, "atomicStore(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr_with_indirection(
+ module,
+ pointer,
+ func_ctx,
+ Indirection::Reference,
+ )?;
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ }
+ writeln!(self.out, ";")?
+ }
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ self.start_named_expr(module, expr, func_ctx, &name)?;
+ self.named_expressions.insert(expr, name);
+ }
+ let func_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{func_name}(")?;
+ for (index, &argument) in arguments.iter().enumerate() {
+ self.write_expr(module, argument, func_ctx)?;
+ // Only write a comma if isn't the last element
+ if index != arguments.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+
+ let fun_str = fun.to_wgsl();
+ write!(self.out, "atomic{fun_str}(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ write!(self.out, ", ")?;
+ self.write_expr(module, cmp, func_ctx)?;
+ }
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ");")?
+ }
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "textureStore(")?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(array_index_expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index_expr, func_ctx)?;
+ }
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Block(ref block) => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, level.next())?
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{level}")?;
+ write!(self.out, "switch ")?;
+ self.write_expr(module, selector, func_ctx)?;
+ writeln!(self.out, " {{")?;
+
+ let l2 = level.next();
+ let mut new_case = true;
+ for case in cases {
+ if case.fall_through && !case.body.is_empty() {
+ // TODO: we could do the same workaround as we did for the HLSL backend
+ return Err(Error::Unimplemented(
+ "fall-through switch case block".into(),
+ ));
+ }
+
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ if new_case {
+ write!(self.out, "{l2}case ")?;
+ }
+ write!(self.out, "{value}")?;
+ }
+ crate::SwitchValue::U32(value) => {
+ if new_case {
+ write!(self.out, "{l2}case ")?;
+ }
+ write!(self.out, "{value}u")?;
+ }
+ crate::SwitchValue::Default => {
+ if new_case {
+ if case.fall_through {
+ write!(self.out, "{l2}case ")?;
+ } else {
+ write!(self.out, "{l2}")?;
+ }
+ }
+ write!(self.out, "default")?;
+ }
+ }
+
+ new_case = !case.fall_through;
+
+ if case.fall_through {
+ write!(self.out, ", ")?;
+ } else {
+ writeln!(self.out, ": {{")?;
+ }
+
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2.next())?;
+ }
+
+ if !case.fall_through {
+ writeln!(self.out, "{l2}}}")?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "loop {{")?;
+
+ let l2 = level.next();
+ for sta in body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // The continuing is optional so we don't need to write it if
+ // it is empty, but the `break if` counts as a continuing statement
+ // so even if `continuing` is empty we must generate it if a
+ // `break if` exists
+ if !continuing.is_empty() || break_if.is_some() {
+ writeln!(self.out, "{l2}continuing {{")?;
+ for sta in continuing.iter() {
+ self.write_stmt(module, sta, func_ctx, l2.next())?;
+ }
+
+ // The `break if` is always the last
+ // statement of the `continuing` block
+ if let Some(condition) = break_if {
+ // The trailing space is important
+ write!(self.out, "{}break if ", l2.next())?;
+ self.write_expr(module, condition, func_ctx)?;
+ // Close the `break if` statement
+ writeln!(self.out, ";")?;
+ }
+
+ writeln!(self.out, "{l2}}}")?;
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Break => {
+ writeln!(self.out, "{level}break;")?;
+ }
+ Statement::Continue => {
+ writeln!(self.out, "{level}continue;")?;
+ }
+ Statement::Barrier(barrier) => {
+ if barrier.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{level}storageBarrier();")?;
+ }
+
+ if barrier.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{level}workgroupBarrier();")?;
+ }
+ }
+ Statement::RayQuery { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Return the sort of indirection that `expr`'s plain form evaluates to.
+ ///
+ /// An expression's 'plain form' is the most general rendition of that
+ /// expression into WGSL, lacking `&` or `*` operators:
+ ///
+ /// - The plain form of `LocalVariable(x)` is simply `x`, which is a reference
+ /// to the local variable's storage.
+ ///
+ /// - The plain form of `GlobalVariable(g)` is simply `g`, which is usually a
+ /// reference to the global variable's storage. However, globals in the
+ /// `Handle` address space are immutable, and `GlobalVariable` expressions for
+ /// those produce the value directly, not a pointer to it. Such
+ /// `GlobalVariable` expressions are `Ordinary`.
+ ///
+ /// - `Access` and `AccessIndex` are `Reference` when their `base` operand is a
+ /// pointer. If they are applied directly to a composite value, they are
+ /// `Ordinary`.
+ ///
+ /// Note that `FunctionArgument` expressions are never `Reference`, even when
+ /// the argument's type is `Pointer`. `FunctionArgument` always evaluates to the
+ /// argument's value directly, so any pointer it produces is merely the value
+ /// passed by the caller.
+ fn plain_form_indirection(
+ &self,
+ expr: Handle<crate::Expression>,
+ module: &Module,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> Indirection {
+ use crate::Expression as Ex;
+
+ // Named expressions are `let` expressions, which apply the Load Rule,
+ // so if their type is a Naga pointer, then that must be a WGSL pointer
+ // as well.
+ if self.named_expressions.contains_key(&expr) {
+ return Indirection::Ordinary;
+ }
+
+ match func_ctx.expressions[expr] {
+ Ex::LocalVariable(_) => Indirection::Reference,
+ Ex::GlobalVariable(handle) => {
+ let global = &module.global_variables[handle];
+ match global.space {
+ crate::AddressSpace::Handle => Indirection::Ordinary,
+ _ => Indirection::Reference,
+ }
+ }
+ Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
+ let base_ty = func_ctx.info[base].ty.inner_with(&module.types);
+ match *base_ty {
+ crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
+ Indirection::Reference
+ }
+ _ => Indirection::Ordinary,
+ }
+ }
+ _ => Indirection::Ordinary,
+ }
+ }
+
+ fn start_named_expr(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx,
+ name: &str,
+ ) -> BackendResult {
+ // Write variable name
+ write!(self.out, "let {name}")?;
+ if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
+ write!(self.out, ": ")?;
+ let ty = &func_ctx.info[handle].ty;
+ // Write variable type
+ match *ty {
+ proc::TypeResolution::Handle(handle) => {
+ self.write_type(module, handle)?;
+ }
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(module, inner)?;
+ }
+ }
+ }
+
+ write!(self.out, " = ")?;
+ Ok(())
+ }
+
+ /// Write the ordinary WGSL form of `expr`.
+ ///
+ /// See `write_expr_with_indirection` for details.
+ fn write_expr(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
+ }
+
+ /// Write `expr` as a WGSL expression with the requested indirection.
+ ///
+ /// In terms of the WGSL grammar, the resulting expression is a
+ /// `singular_expression`. It may be parenthesized. This makes it suitable
+ /// for use as the operand of a unary or binary operator without worrying
+ /// about precedence.
+ ///
+ /// This does not produce newlines or indentation.
+ ///
+ /// The `requested` argument indicates (roughly) whether Naga
+ /// `Pointer`-valued expressions represent WGSL references or pointers. See
+ /// `Indirection` for details.
+ fn write_expr_with_indirection(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ requested: Indirection,
+ ) -> BackendResult {
+ // If the plain form of the expression is not what we need, emit the
+ // operator necessary to correct that.
+ let plain = self.plain_form_indirection(expr, module, func_ctx);
+ match (requested, plain) {
+ (Indirection::Ordinary, Indirection::Reference) => {
+ write!(self.out, "(&")?;
+ self.write_expr_plain_form(module, expr, func_ctx, plain)?;
+ write!(self.out, ")")?;
+ }
+ (Indirection::Reference, Indirection::Ordinary) => {
+ write!(self.out, "(*")?;
+ self.write_expr_plain_form(module, expr, func_ctx, plain)?;
+ write!(self.out, ")")?;
+ }
+ (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
+ }
+
+ Ok(())
+ }
+
+ /// Write the 'plain form' of `expr`.
+ ///
+ /// An expression's 'plain form' is the most general rendition of that
+ /// expression into WGSL, lacking `&` or `*` operators. The plain forms of
+ /// `LocalVariable(x)` and `GlobalVariable(g)` are simply `x` and `g`. Such
+ /// Naga expressions represent both WGSL pointers and references; it's the
+ /// caller's responsibility to distinguish those cases appropriately.
+ fn write_expr_plain_form(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ indirection: Indirection,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+
+ let expression = &func_ctx.expressions[expr];
+
+ // Write the plain WGSL form of a Naga expression.
+ //
+ // The plain form of `LocalVariable` and `GlobalVariable` expressions is
+ // simply the variable name; `*` and `&` operators are never emitted.
+ //
+ // The plain form of `Access` and `AccessIndex` expressions are WGSL
+ // `postfix_expression` forms for member/component access and
+ // subscripting.
+ match *expression {
+ Expression::Constant(constant) => self.write_constant(module, constant)?,
+ Expression::Compose { ty, ref components } => {
+ self.write_type(module, ty)?;
+ write!(self.out, "(")?;
+ for (index, component) in components.iter().enumerate() {
+ self.write_expr(module, *component, func_ctx)?;
+ // Only write a comma if isn't the last element
+ if index != components.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ write!(self.out, ")")?
+ }
+ Expression::FunctionArgument(pos) => {
+ let name_key = func_ctx.argument_key(pos);
+ let name = &self.names[&name_key];
+ write!(self.out, "{name}")?;
+ }
+ Expression::Binary { op, left, right } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Access { base, index } => {
+ self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
+ write!(self.out, "[")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, "]")?
+ }
+ Expression::AccessIndex { base, index } => {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+
+ self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
+
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, space: _ } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ match *resolved {
+ TypeInner::Vector { .. } => {
+ // Write vector access as a swizzle
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::BindingArray { .. }
+ | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
+ TypeInner::Struct { .. } => {
+ // This will never panic in case the type is a `Struct`, this is not true
+ // for other types so we can only check while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ self.out,
+ ".{}",
+ &self.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
+ }
+ }
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather: None,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ use crate::SampleLevel as Sl;
+
+ let suffix_cmp = match depth_ref {
+ Some(_) => "Compare",
+ None => "",
+ };
+ let suffix_level = match level {
+ Sl::Auto => "",
+ Sl::Zero | Sl::Exact(_) => "Level",
+ Sl::Bias(_) => "Bias",
+ Sl::Gradient { .. } => "Grad",
+ };
+
+ write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ match level {
+ Sl::Auto => {}
+ Sl::Zero => {
+ // Level 0 is implied for depth comparison
+ if depth_ref.is_none() {
+ write!(self.out, ", 0.0")?;
+ }
+ }
+ Sl::Exact(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Bias(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Gradient { x, y } => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, x, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, y, func_ctx)?;
+ }
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.write_constant(module, offset)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather: Some(component),
+ coordinate,
+ array_index,
+ offset,
+ level: _,
+ depth_ref,
+ } => {
+ let suffix_cmp = match depth_ref {
+ Some(_) => "Compare",
+ None => "",
+ };
+
+ write!(self.out, "textureGather{suffix_cmp}(")?;
+ match *func_ctx.info[image].ty.inner_with(&module.types) {
+ TypeInner::Image {
+ class: crate::ImageClass::Depth { multi: _ },
+ ..
+ } => {}
+ _ => {
+ write!(self.out, "{}, ", component as u8)?;
+ }
+ }
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.write_constant(module, offset)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::ImageQuery { image, query } => {
+ use crate::ImageQuery as Iq;
+
+ let texture_function = match query {
+ Iq::Size { .. } => "textureDimensions",
+ Iq::NumLevels => "textureNumLevels",
+ Iq::NumLayers => "textureNumLayers",
+ Iq::NumSamples => "textureNumSamples",
+ };
+
+ write!(self.out, "{texture_function}(")?;
+ self.write_expr(module, image, func_ctx)?;
+ if let Iq::Size { level: Some(level) } = query {
+ write!(self.out, ", ")?;
+ self.write_expr(module, level, func_ctx)?;
+ };
+ write!(self.out, ")")?;
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ write!(self.out, "textureLoad(")?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+ if let Some(index) = sample.or(level) {
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Expression::GlobalVariable(handle) => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{name}")?;
+ }
+ Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let inner = func_ctx.info[expr].ty.inner_with(&module.types);
+ match *inner {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ ..
+ } => {
+ let scalar_kind_str = scalar_kind_str(kind, convert.unwrap_or(width));
+ write!(
+ self.out,
+ "mat{}x{}<{}>",
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ scalar_kind_str
+ )?;
+ }
+ TypeInner::Vector { size, width, .. } => {
+ let vector_size_str = back::vector_size_str(size);
+ let scalar_kind_str = scalar_kind_str(kind, convert.unwrap_or(width));
+ if convert.is_some() {
+ write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?;
+ } else {
+ write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?;
+ }
+ }
+ TypeInner::Scalar { width, .. } => {
+ let scalar_kind_str = scalar_kind_str(kind, convert.unwrap_or(width));
+ if convert.is_some() {
+ write!(self.out, "{scalar_kind_str}")?
+ } else {
+ write!(self.out, "bitcast<{scalar_kind_str}>")?
+ }
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_expr expression::as {inner:?}"
+ )));
+ }
+ };
+ write!(self.out, "(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Splat { size, value } => {
+ let inner = func_ctx.info[value].ty.inner_with(&module.types);
+ let (scalar_kind, scalar_width) = match *inner {
+ crate::TypeInner::Scalar { kind, width } => (kind, width),
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_expr expression::splat {inner:?}"
+ )));
+ }
+ };
+ let scalar = scalar_kind_str(scalar_kind, scalar_width);
+ let size = back::vector_size_str(size);
+
+ write!(self.out, "vec{size}<{scalar}>(")?;
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Load { pointer } => {
+ let is_atomic = match *func_ctx.info[pointer].ty.inner_with(&module.types) {
+ crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => true,
+ _ => false,
+ },
+ _ => false,
+ };
+
+ if is_atomic {
+ write!(self.out, "atomicLoad(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr_with_indirection(
+ module,
+ pointer,
+ func_ctx,
+ Indirection::Reference,
+ )?;
+ }
+ }
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
+ }
+ Expression::ArrayLength(expr) => {
+ write!(self.out, "arrayLength(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ enum Function {
+ Regular(&'static str),
+ }
+
+ let function = match fun {
+ Mf::Abs => Function::Regular("abs"),
+ Mf::Min => Function::Regular("min"),
+ Mf::Max => Function::Regular("max"),
+ Mf::Clamp => Function::Regular("clamp"),
+ Mf::Saturate => Function::Regular("saturate"),
+ // trigonometry
+ Mf::Cos => Function::Regular("cos"),
+ Mf::Cosh => Function::Regular("cosh"),
+ Mf::Sin => Function::Regular("sin"),
+ Mf::Sinh => Function::Regular("sinh"),
+ Mf::Tan => Function::Regular("tan"),
+ Mf::Tanh => Function::Regular("tanh"),
+ Mf::Acos => Function::Regular("acos"),
+ Mf::Asin => Function::Regular("asin"),
+ Mf::Atan => Function::Regular("atan"),
+ Mf::Atan2 => Function::Regular("atan2"),
+ Mf::Asinh => Function::Regular("asinh"),
+ Mf::Acosh => Function::Regular("acosh"),
+ Mf::Atanh => Function::Regular("atanh"),
+ Mf::Radians => Function::Regular("radians"),
+ Mf::Degrees => Function::Regular("degrees"),
+ // decomposition
+ Mf::Ceil => Function::Regular("ceil"),
+ Mf::Floor => Function::Regular("floor"),
+ Mf::Round => Function::Regular("round"),
+ Mf::Fract => Function::Regular("fract"),
+ Mf::Trunc => Function::Regular("trunc"),
+ Mf::Modf => Function::Regular("modf"),
+ Mf::Frexp => Function::Regular("frexp"),
+ Mf::Ldexp => Function::Regular("ldexp"),
+ // exponent
+ Mf::Exp => Function::Regular("exp"),
+ Mf::Exp2 => Function::Regular("exp2"),
+ Mf::Log => Function::Regular("log"),
+ Mf::Log2 => Function::Regular("log2"),
+ Mf::Pow => Function::Regular("pow"),
+ // geometry
+ Mf::Dot => Function::Regular("dot"),
+ Mf::Outer => Function::Regular("outerProduct"),
+ Mf::Cross => Function::Regular("cross"),
+ Mf::Distance => Function::Regular("distance"),
+ Mf::Length => Function::Regular("length"),
+ Mf::Normalize => Function::Regular("normalize"),
+ Mf::FaceForward => Function::Regular("faceForward"),
+ Mf::Reflect => Function::Regular("reflect"),
+ Mf::Refract => Function::Regular("refract"),
+ // computational
+ Mf::Sign => Function::Regular("sign"),
+ Mf::Fma => Function::Regular("fma"),
+ Mf::Mix => Function::Regular("mix"),
+ Mf::Step => Function::Regular("step"),
+ Mf::SmoothStep => Function::Regular("smoothstep"),
+ Mf::Sqrt => Function::Regular("sqrt"),
+ Mf::InverseSqrt => Function::Regular("inverseSqrt"),
+ Mf::Transpose => Function::Regular("transpose"),
+ Mf::Determinant => Function::Regular("determinant"),
+ // bits
+ Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"),
+ Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"),
+ Mf::CountOneBits => Function::Regular("countOneBits"),
+ Mf::ReverseBits => Function::Regular("reverseBits"),
+ Mf::ExtractBits => Function::Regular("extractBits"),
+ Mf::InsertBits => Function::Regular("insertBits"),
+ Mf::FindLsb => Function::Regular("firstTrailingBit"),
+ Mf::FindMsb => Function::Regular("firstLeadingBit"),
+ // data packing
+ Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"),
+ Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"),
+ Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"),
+ Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"),
+ Mf::Pack2x16float => Function::Regular("pack2x16float"),
+ // data unpacking
+ Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"),
+ Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"),
+ Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"),
+ Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"),
+ Mf::Unpack2x16float => Function::Regular("unpack2x16float"),
+ Mf::Inverse => {
+ return Err(Error::UnsupportedMathFunction(fun));
+ }
+ };
+
+ match function {
+ Function::Regular(fun_name) => {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ write!(self.out, ")")?
+ }
+ }
+ }
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(module, vector, func_ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ Expression::Unary { op, expr } => {
+ let unary = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::Not => {
+ match *func_ctx.info[expr].ty.inner_with(&module.types) {
+ TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ }
+ | TypeInner::Vector { .. } => "!",
+ _ => "~",
+ }
+ }
+ };
+
+ write!(self.out, "{unary}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+
+ write!(self.out, ")")?
+ }
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "select(")?;
+ self.write_expr(module, reject, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, accept, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, condition, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ let op = match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => "dpdxCoarse",
+ (Axis::X, Ctrl::Fine) => "dpdxFine",
+ (Axis::X, Ctrl::None) => "dpdx",
+ (Axis::Y, Ctrl::Coarse) => "dpdyCoarse",
+ (Axis::Y, Ctrl::Fine) => "dpdyFine",
+ (Axis::Y, Ctrl::None) => "dpdy",
+ (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
+ (Axis::Width, Ctrl::Fine) => "fwidthFine",
+ (Axis::Width, Ctrl::None) => "fwidth",
+ };
+ write!(self.out, "{op}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_name = match fun {
+ Rf::All => "all",
+ Rf::Any => "any",
+ _ => return Err(Error::UnsupportedRelationalFunction(fun)),
+ };
+ write!(self.out, "{fun_name}(")?;
+
+ self.write_expr(module, argument, func_ctx)?;
+
+ write!(self.out, ")")?
+ }
+ // Not supported yet
+ Expression::RayQueryGetIntersection { .. } => unreachable!(),
+ // Nothing to do here, since call expression already cached
+ Expression::CallResult(_)
+ | Expression::AtomicResult { .. }
+ | Expression::RayQueryProceedResult => {}
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global variables
+ /// # Notes
+ /// Always adds a newline
+ fn write_global(
+ &mut self,
+ module: &Module,
+ global: &crate::GlobalVariable,
+ handle: Handle<crate::GlobalVariable>,
+ ) -> BackendResult {
+ // Write group and binding attributes if present
+ if let Some(ref binding) = global.binding {
+ self.write_attributes(&[
+ Attribute::Group(binding.group),
+ Attribute::Binding(binding.binding),
+ ])?;
+ writeln!(self.out)?;
+ }
+
+ // First write global name and address space if supported
+ write!(self.out, "var")?;
+ let (address, maybe_access) = address_space_str(global.space);
+ if let Some(space) = address {
+ write!(self.out, "<{space}")?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ }
+ write!(
+ self.out,
+ " {}: ",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?;
+
+ // Write global type
+ self.write_type(module, global.ty)?;
+
+ // Write initializer
+ if let Some(init) = global.init {
+ write!(self.out, " = ")?;
+ self.write_constant(module, init)?;
+ }
+
+ // End with semicolon
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write constants
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ fn write_constant(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ let constant = &module.constants[handle];
+ match constant.inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } => {
+ if constant.name.is_some() {
+ write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.write_scalar_value(*value)?;
+ }
+ }
+ crate::ConstantInner::Composite { ty, ref components } => {
+ self.write_type(module, ty)?;
+ write!(self.out, "(")?;
+
+ // Write the comma separated constants
+ for (index, constant) in components.iter().enumerate() {
+ self.write_constant(module, *constant)?;
+ // Only write a comma if isn't the last element
+ if index != components.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ write!(self.out, ")")?
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global constants
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_global_constant(
+ &mut self,
+ module: &Module,
+ inner: &crate::ConstantInner,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ match *inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ ref value,
+ } => {
+ let name = &self.names[&NameKey::Constant(handle)];
+ // First write only constant name
+ write!(self.out, "const {name}: ")?;
+ // Next write constant type and value
+ match *value {
+ crate::ScalarValue::Sint(value) => {
+ write!(self.out, "i32 = {value}")?;
+ }
+ crate::ScalarValue::Uint(value) => {
+ write!(self.out, "u32 = {value}u")?;
+ }
+ crate::ScalarValue::Float(value) => {
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero
+ write!(self.out, "f32 = {value:?}")?;
+ }
+ crate::ScalarValue::Bool(value) => {
+ write!(self.out, "bool = {value}")?;
+ }
+ };
+ // End with semicolon
+ writeln!(self.out, ";")?;
+ }
+ crate::ConstantInner::Composite { ty, ref components } => {
+ let name = &self.names[&NameKey::Constant(handle)];
+ // First write only constant name
+ write!(self.out, "const {name}: ")?;
+ // Next write constant type
+ self.write_type(module, ty)?;
+
+ write!(self.out, " = ")?;
+ self.write_type(module, ty)?;
+
+ write!(self.out, "(")?;
+ for (index, constant) in components.iter().enumerate() {
+ self.write_constant(module, *constant)?;
+ // Only write a comma if isn't the last element
+ if index != components.len().saturating_sub(1) {
+ // The leading space is for readability only
+ write!(self.out, ", ")?;
+ }
+ }
+ write!(self.out, ");")?;
+ }
+ }
+ // End with extra newline for readability
+ writeln!(self.out)?;
+ Ok(())
+ }
+
+ // See https://github.com/rust-lang/rust-clippy/issues/4979.
+ #[allow(clippy::missing_const_for_fn)]
+ pub fn finish(self) -> W {
+ self.out
+ }
+}
+
+fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> {
+ use crate::BuiltIn as Bi;
+
+ Ok(match built_in {
+ Bi::VertexIndex => "vertex_index",
+ Bi::InstanceIndex => "instance_index",
+ Bi::Position { .. } => "position",
+ Bi::FrontFacing => "front_facing",
+ Bi::FragDepth => "frag_depth",
+ Bi::LocalInvocationId => "local_invocation_id",
+ Bi::LocalInvocationIndex => "local_invocation_index",
+ Bi::GlobalInvocationId => "global_invocation_id",
+ Bi::WorkGroupId => "workgroup_id",
+ Bi::NumWorkGroups => "num_workgroups",
+ Bi::SampleIndex => "sample_index",
+ Bi::SampleMask => "sample_mask",
+ Bi::PrimitiveIndex => "primitive_index",
+ Bi::ViewIndex => "view_index",
+ Bi::BaseInstance
+ | Bi::BaseVertex
+ | Bi::ClipDistance
+ | Bi::CullDistance
+ | Bi::PointSize
+ | Bi::PointCoord
+ | Bi::WorkGroupSize => {
+ return Err(Error::Custom(format!("Unsupported builtin {built_in:?}")))
+ }
+ })
+}
+
+const fn image_dimension_str(dim: crate::ImageDimension) -> &'static str {
+ use crate::ImageDimension as IDim;
+
+ match dim {
+ IDim::D1 => "1d",
+ IDim::D2 => "2d",
+ IDim::D3 => "3d",
+ IDim::Cube => "cube",
+ }
+}
+
+const fn scalar_kind_str(kind: crate::ScalarKind, width: u8) -> &'static str {
+ use crate::ScalarKind as Sk;
+
+ match (kind, width) {
+ (Sk::Float, 8) => "f64",
+ (Sk::Float, 4) => "f32",
+ (Sk::Sint, 4) => "i32",
+ (Sk::Uint, 4) => "u32",
+ (Sk::Bool, 1) => "bool",
+ _ => unreachable!(),
+ }
+}
+
+const fn storage_format_str(format: crate::StorageFormat) -> &'static str {
+ use crate::StorageFormat as Sf;
+
+ match format {
+ Sf::R8Unorm => "r8unorm",
+ Sf::R8Snorm => "r8snorm",
+ Sf::R8Uint => "r8uint",
+ Sf::R8Sint => "r8sint",
+ Sf::R16Uint => "r16uint",
+ Sf::R16Sint => "r16sint",
+ Sf::R16Float => "r16float",
+ Sf::Rg8Unorm => "rg8unorm",
+ Sf::Rg8Snorm => "rg8snorm",
+ Sf::Rg8Uint => "rg8uint",
+ Sf::Rg8Sint => "rg8sint",
+ Sf::R32Uint => "r32uint",
+ Sf::R32Sint => "r32sint",
+ Sf::R32Float => "r32float",
+ Sf::Rg16Uint => "rg16uint",
+ Sf::Rg16Sint => "rg16sint",
+ Sf::Rg16Float => "rg16float",
+ Sf::Rgba8Unorm => "rgba8unorm",
+ Sf::Rgba8Snorm => "rgba8snorm",
+ Sf::Rgba8Uint => "rgba8uint",
+ Sf::Rgba8Sint => "rgba8sint",
+ Sf::Rgb10a2Unorm => "rgb10a2unorm",
+ Sf::Rg11b10Float => "rg11b10float",
+ Sf::Rg32Uint => "rg32uint",
+ Sf::Rg32Sint => "rg32sint",
+ Sf::Rg32Float => "rg32float",
+ Sf::Rgba16Uint => "rgba16uint",
+ Sf::Rgba16Sint => "rgba16sint",
+ Sf::Rgba16Float => "rgba16float",
+ Sf::Rgba32Uint => "rgba32uint",
+ Sf::Rgba32Sint => "rgba32sint",
+ Sf::Rgba32Float => "rgba32float",
+ Sf::R16Unorm => "r16unorm",
+ Sf::R16Snorm => "r16snorm",
+ Sf::Rg16Unorm => "rg16unorm",
+ Sf::Rg16Snorm => "rg16snorm",
+ Sf::Rgba16Unorm => "rgba16unorm",
+ Sf::Rgba16Snorm => "rgba16snorm",
+ }
+}
+
+/// Helper function that returns the string corresponding to the WGSL interpolation qualifier
+const fn interpolation_str(interpolation: crate::Interpolation) -> &'static str {
+ use crate::Interpolation as I;
+
+ match interpolation {
+ I::Perspective => "perspective",
+ I::Linear => "linear",
+ I::Flat => "flat",
+ }
+}
+
+/// Return the WGSL auxiliary qualifier for the given sampling value.
+const fn sampling_str(sampling: crate::Sampling) -> &'static str {
+ use crate::Sampling as S;
+
+ match sampling {
+ S::Center => "",
+ S::Centroid => "centroid",
+ S::Sample => "sample",
+ }
+}
+
+const fn address_space_str(
+ space: crate::AddressSpace,
+) -> (Option<&'static str>, Option<&'static str>) {
+ use crate::AddressSpace as As;
+
+ (
+ Some(match space {
+ As::Private => "private",
+ As::Uniform => "uniform",
+ As::Storage { access } => {
+ if access.contains(crate::StorageAccess::STORE) {
+ return (Some("storage"), Some("read_write"));
+ } else {
+ "storage"
+ }
+ }
+ As::PushConstant => "push_constant",
+ As::WorkGroup => "workgroup",
+ As::Handle => return (None, None),
+ As::Function => "function",
+ }),
+ None,
+ )
+}
+
+fn map_binding_to_attribute(
+ binding: &crate::Binding,
+ scalar_kind: Option<crate::ScalarKind>,
+) -> Vec<Attribute> {
+ match *binding {
+ crate::Binding::BuiltIn(built_in) => {
+ if let crate::BuiltIn::Position { invariant: true } = built_in {
+ vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
+ } else {
+ vec![Attribute::BuiltIn(built_in)]
+ }
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ } => match scalar_kind {
+ Some(crate::ScalarKind::Float) => vec![
+ Attribute::Location(location),
+ Attribute::Interpolate(interpolation, sampling),
+ ],
+ _ => vec![Attribute::Location(location)],
+ },
+ }
+}
diff --git a/third_party/rust/naga/src/block.rs b/third_party/rust/naga/src/block.rs
new file mode 100644
index 0000000000..ba4c95ab3a
--- /dev/null
+++ b/third_party/rust/naga/src/block.rs
@@ -0,0 +1,145 @@
+use crate::{Span, Statement};
+use std::ops::{Deref, DerefMut, RangeBounds};
+
+/// A code block is a vector of statements, with maybe a vector of spans.
+#[derive(Debug, Clone, Default)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "serialize", serde(transparent))]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct Block {
+ body: Vec<Statement>,
+ #[cfg(feature = "span")]
+ #[cfg_attr(feature = "serialize", serde(skip))]
+ span_info: Vec<Span>,
+}
+
+impl Block {
+ pub const fn new() -> Self {
+ Self {
+ body: Vec::new(),
+ #[cfg(feature = "span")]
+ span_info: Vec::new(),
+ }
+ }
+
+ pub fn from_vec(body: Vec<Statement>) -> Self {
+ #[cfg(feature = "span")]
+ let span_info = std::iter::repeat(Span::default())
+ .take(body.len())
+ .collect();
+ Self {
+ body,
+ #[cfg(feature = "span")]
+ span_info,
+ }
+ }
+
+ pub fn with_capacity(capacity: usize) -> Self {
+ Self {
+ body: Vec::with_capacity(capacity),
+ #[cfg(feature = "span")]
+ span_info: Vec::with_capacity(capacity),
+ }
+ }
+
+ #[allow(unused_variables)]
+ pub fn push(&mut self, end: Statement, span: Span) {
+ self.body.push(end);
+ #[cfg(feature = "span")]
+ self.span_info.push(span);
+ }
+
+ pub fn extend(&mut self, item: Option<(Statement, Span)>) {
+ if let Some((end, span)) = item {
+ self.push(end, span)
+ }
+ }
+
+ pub fn extend_block(&mut self, other: Self) {
+ #[cfg(feature = "span")]
+ self.span_info.extend(other.span_info);
+ self.body.extend(other.body);
+ }
+
+ pub fn append(&mut self, other: &mut Self) {
+ #[cfg(feature = "span")]
+ self.span_info.append(&mut other.span_info);
+ self.body.append(&mut other.body);
+ }
+
+ pub fn cull<R: RangeBounds<usize> + Clone>(&mut self, range: R) {
+ #[cfg(feature = "span")]
+ self.span_info.drain(range.clone());
+ self.body.drain(range);
+ }
+
+ pub fn splice<R: RangeBounds<usize> + Clone>(&mut self, range: R, other: Self) {
+ #[cfg(feature = "span")]
+ self.span_info
+ .splice(range.clone(), other.span_info.into_iter());
+ self.body.splice(range, other.body.into_iter());
+ }
+ pub fn span_iter(&self) -> impl Iterator<Item = (&Statement, &Span)> {
+ #[cfg(feature = "span")]
+ let span_iter = self.span_info.iter();
+ #[cfg(not(feature = "span"))]
+ let span_iter = std::iter::repeat_with(|| &Span::UNDEFINED);
+
+ self.body.iter().zip(span_iter)
+ }
+
+ pub fn span_iter_mut(&mut self) -> impl Iterator<Item = (&mut Statement, Option<&mut Span>)> {
+ #[cfg(feature = "span")]
+ let span_iter = self.span_info.iter_mut().map(Some);
+ #[cfg(not(feature = "span"))]
+ let span_iter = std::iter::repeat_with(|| None);
+
+ self.body.iter_mut().zip(span_iter)
+ }
+
+ pub fn is_empty(&self) -> bool {
+ self.body.is_empty()
+ }
+
+ pub fn len(&self) -> usize {
+ self.body.len()
+ }
+}
+
+impl Deref for Block {
+ type Target = [Statement];
+ fn deref(&self) -> &[Statement] {
+ &self.body
+ }
+}
+
+impl DerefMut for Block {
+ fn deref_mut(&mut self) -> &mut [Statement] {
+ &mut self.body
+ }
+}
+
+impl<'a> IntoIterator for &'a Block {
+ type Item = &'a Statement;
+ type IntoIter = std::slice::Iter<'a, Statement>;
+
+ fn into_iter(self) -> std::slice::Iter<'a, Statement> {
+ self.iter()
+ }
+}
+
+#[cfg(feature = "deserialize")]
+impl<'de> serde::Deserialize<'de> for Block {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ Ok(Self::from_vec(Vec::deserialize(deserializer)?))
+ }
+}
+
+impl From<Vec<Statement>> for Block {
+ fn from(body: Vec<Statement>) -> Self {
+ Self::from_vec(body)
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/ast.rs b/third_party/rust/naga/src/front/glsl/ast.rs
new file mode 100644
index 0000000000..cbb21f6de9
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/ast.rs
@@ -0,0 +1,393 @@
+use std::{borrow::Cow, fmt};
+
+use super::{builtins::MacroCall, context::ExprPos, Span};
+use crate::{
+ AddressSpace, BinaryOperator, Binding, Constant, Expression, Function, GlobalVariable, Handle,
+ Interpolation, Sampling, StorageAccess, Type, UnaryOperator,
+};
+
+#[derive(Debug, Clone, Copy)]
+pub enum GlobalLookupKind {
+ Variable(Handle<GlobalVariable>),
+ Constant(Handle<Constant>, Handle<Type>),
+ BlockSelect(Handle<GlobalVariable>, u32),
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct GlobalLookup {
+ pub kind: GlobalLookupKind,
+ pub entry_arg: Option<usize>,
+ pub mutable: bool,
+}
+
+#[derive(Debug, Clone)]
+pub struct ParameterInfo {
+ pub qualifier: ParameterQualifier,
+ /// Whether the parameter should be treated as a depth image instead of a
+ /// sampled image.
+ pub depth: bool,
+}
+
+/// How the function is implemented
+#[derive(Clone, Copy)]
+pub enum FunctionKind {
+ /// The function is user defined
+ Call(Handle<Function>),
+ /// The function is a builtin
+ Macro(MacroCall),
+}
+
+impl fmt::Debug for FunctionKind {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match *self {
+ Self::Call(_) => write!(f, "Call"),
+ Self::Macro(_) => write!(f, "Macro"),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Overload {
+ /// Normalized function parameters, modifiers are not applied
+ pub parameters: Vec<Handle<Type>>,
+ pub parameters_info: Vec<ParameterInfo>,
+ /// How the function is implemented
+ pub kind: FunctionKind,
+ /// Whether this function was already defined or is just a prototype
+ pub defined: bool,
+ /// Whether this overload is the one provided by the language or has
+ /// been redeclared by the user (builtins only)
+ pub internal: bool,
+ /// Whether or not this function returns void (nothing)
+ pub void: bool,
+}
+
+bitflags::bitflags! {
+ /// Tracks the variations of the builtin already generated, this is needed because some
+ /// builtins overloads can't be generated unless explicitly used, since they might cause
+ /// unneeded capabilities to be requested
+ #[derive(Default)]
+ pub struct BuiltinVariations: u32 {
+ /// Request the standard overloads
+ const STANDARD = 1 << 0;
+ /// Request overloads that use the double type
+ const DOUBLE = 1 << 1;
+ /// Request overloads that use samplerCubeArray(Shadow)
+ const CUBE_TEXTURES_ARRAY = 1 << 2;
+ /// Request overloads that use sampler2DMSArray
+ const D2_MULTI_TEXTURES_ARRAY = 1 << 3;
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct FunctionDeclaration {
+ pub overloads: Vec<Overload>,
+ /// Tracks the builtin overload variations that were already generated
+ pub variations: BuiltinVariations,
+}
+
+#[derive(Debug)]
+pub struct EntryArg {
+ pub name: Option<String>,
+ pub binding: Binding,
+ pub handle: Handle<GlobalVariable>,
+ pub storage: StorageQualifier,
+}
+
+#[derive(Debug, Clone)]
+pub struct VariableReference {
+ pub expr: Handle<Expression>,
+ /// Wether the variable is of a pointer type (and needs loading) or not
+ pub load: bool,
+ /// Wether the value of the variable can be changed or not
+ pub mutable: bool,
+ pub constant: Option<(Handle<Constant>, Handle<Type>)>,
+ pub entry_arg: Option<usize>,
+}
+
+#[derive(Debug, Clone)]
+pub struct HirExpr {
+ pub kind: HirExprKind,
+ pub meta: Span,
+}
+
+#[derive(Debug, Clone)]
+pub enum HirExprKind {
+ Access {
+ base: Handle<HirExpr>,
+ index: Handle<HirExpr>,
+ },
+ Select {
+ base: Handle<HirExpr>,
+ field: String,
+ },
+ Constant(Handle<Constant>),
+ Binary {
+ left: Handle<HirExpr>,
+ op: BinaryOperator,
+ right: Handle<HirExpr>,
+ },
+ Unary {
+ op: UnaryOperator,
+ expr: Handle<HirExpr>,
+ },
+ Variable(VariableReference),
+ Call(FunctionCall),
+ /// Represents the ternary operator in glsl (`:?`)
+ Conditional {
+ /// The expression that will decide which branch to take, must evaluate to a boolean
+ condition: Handle<HirExpr>,
+ /// The expression that will be evaluated if [`condition`] returns `true`
+ ///
+ /// [`condition`]: Self::Conditional::condition
+ accept: Handle<HirExpr>,
+ /// The expression that will be evaluated if [`condition`] returns `false`
+ ///
+ /// [`condition`]: Self::Conditional::condition
+ reject: Handle<HirExpr>,
+ },
+ Assign {
+ tgt: Handle<HirExpr>,
+ value: Handle<HirExpr>,
+ },
+ /// A prefix/postfix operator like `++`
+ PrePostfix {
+ /// The operation to be performed
+ op: BinaryOperator,
+ /// Whether this is a postfix or a prefix
+ postfix: bool,
+ /// The target expression
+ expr: Handle<HirExpr>,
+ },
+ /// A method call like `what.something(a, b, c)`
+ Method {
+ /// expression the method call applies to (`what` in the example)
+ expr: Handle<HirExpr>,
+ /// the method name (`something` in the example)
+ name: String,
+ /// the arguments to the method (`a`, `b`, and `c` in the example)
+ args: Vec<Handle<HirExpr>>,
+ },
+}
+
+#[derive(Debug, Hash, PartialEq, Eq)]
+pub enum QualifierKey<'a> {
+ String(Cow<'a, str>),
+ /// Used for `std140` and `std430` layout qualifiers
+ Layout,
+ /// Used for image formats
+ Format,
+}
+
+#[derive(Debug)]
+pub enum QualifierValue {
+ None,
+ Uint(u32),
+ Layout(StructLayout),
+ Format(crate::StorageFormat),
+}
+
+#[derive(Debug, Default)]
+pub struct TypeQualifiers<'a> {
+ pub span: Span,
+ pub storage: (StorageQualifier, Span),
+ pub invariant: Option<Span>,
+ pub interpolation: Option<(Interpolation, Span)>,
+ pub precision: Option<(Precision, Span)>,
+ pub sampling: Option<(Sampling, Span)>,
+ /// Memory qualifiers used in the declaration to set the storage access to be used
+ /// in declarations that support it (storage images and buffers)
+ pub storage_access: Option<(StorageAccess, Span)>,
+ pub layout_qualifiers: crate::FastHashMap<QualifierKey<'a>, (QualifierValue, Span)>,
+}
+
+impl<'a> TypeQualifiers<'a> {
+ /// Appends `errors` with errors for all unused qualifiers
+ pub fn unused_errors(&self, errors: &mut Vec<super::Error>) {
+ if let Some(meta) = self.invariant {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Invariant qualifier can only be used in in/out variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ if let Some((_, meta)) = self.interpolation {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Interpolation qualifiers can only be used in in/out variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ if let Some((_, meta)) = self.sampling {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Sampling qualifiers can only be used in in/out variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ if let Some((_, meta)) = self.storage_access {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Memory qualifiers can only be used in storage variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ for &(_, meta) in self.layout_qualifiers.values() {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError("Unexpected qualifier".into()),
+ meta,
+ });
+ }
+ }
+
+ /// Removes the layout qualifier with `name`, if it exists and adds an error if it isn't
+ /// a [`QualifierValue::Uint`]
+ pub fn uint_layout_qualifier(
+ &mut self,
+ name: &'a str,
+ errors: &mut Vec<super::Error>,
+ ) -> Option<u32> {
+ match self
+ .layout_qualifiers
+ .remove(&QualifierKey::String(name.into()))
+ {
+ Some((QualifierValue::Uint(v), _)) => Some(v),
+ Some((_, meta)) => {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError("Qualifier expects a uint value".into()),
+ meta,
+ });
+ // Return a dummy value instead of `None` to differentiate from
+ // the qualifier not existing, since some parts might require the
+ // qualifier to exist and throwing another error that it doesn't
+ // exist would be unhelpful
+ Some(0)
+ }
+ _ => None,
+ }
+ }
+
+ /// Removes the layout qualifier with `name`, if it exists and adds an error if it isn't
+ /// a [`QualifierValue::None`]
+ pub fn none_layout_qualifier(&mut self, name: &'a str, errors: &mut Vec<super::Error>) -> bool {
+ match self
+ .layout_qualifiers
+ .remove(&QualifierKey::String(name.into()))
+ {
+ Some((QualifierValue::None, _)) => true,
+ Some((_, meta)) => {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Qualifier doesn't expect a value".into(),
+ ),
+ meta,
+ });
+ // Return a `true` to since the qualifier is defined and adding
+ // another error for it not being defined would be unhelpful
+ true
+ }
+ _ => false,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub enum FunctionCallKind {
+ TypeConstructor(Handle<Type>),
+ Function(String),
+}
+
+#[derive(Debug, Clone)]
+pub struct FunctionCall {
+ pub kind: FunctionCallKind,
+ pub args: Vec<Handle<HirExpr>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum StorageQualifier {
+ AddressSpace(AddressSpace),
+ Input,
+ Output,
+ Const,
+}
+
+impl Default for StorageQualifier {
+ fn default() -> Self {
+ StorageQualifier::AddressSpace(AddressSpace::Function)
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum StructLayout {
+ Std140,
+ Std430,
+}
+
+// TODO: Encode precision hints in the IR
+/// A precision hint used in GLSL declarations.
+///
+/// Precision hints can be used to either speed up shader execution or control
+/// the precision of arithmetic operations.
+///
+/// To use a precision hint simply add it before the type in the declaration.
+/// ```glsl
+/// mediump float a;
+/// ```
+///
+/// The default when no precision is declared is `highp` which means that all
+/// operations operate with the type defined width.
+///
+/// For `mediump` and `lowp` operations follow the spir-v
+/// [`RelaxedPrecision`][RelaxedPrecision] decoration semantics.
+///
+/// [RelaxedPrecision]: https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#_a_id_relaxedprecisionsection_a_relaxed_precision
+#[derive(Debug, Clone, PartialEq, Copy)]
+pub enum Precision {
+ /// `lowp` precision
+ Low,
+ /// `mediump` precision
+ Medium,
+ /// `highp` precision
+ High,
+}
+
+#[derive(Debug, Clone, PartialEq, Copy)]
+pub enum ParameterQualifier {
+ In,
+ Out,
+ InOut,
+ Const,
+}
+
+impl ParameterQualifier {
+ /// Returns true if the argument should be passed as a lhs expression
+ pub const fn is_lhs(&self) -> bool {
+ match *self {
+ ParameterQualifier::Out | ParameterQualifier::InOut => true,
+ _ => false,
+ }
+ }
+
+ /// Converts from a parameter qualifier into a [`ExprPos`](ExprPos)
+ pub const fn as_pos(&self) -> ExprPos {
+ match *self {
+ ParameterQualifier::Out | ParameterQualifier::InOut => ExprPos::Lhs,
+ _ => ExprPos::Rhs,
+ }
+ }
+}
+
+/// The GLSL profile used by a shader.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum Profile {
+ /// The `core` profile, default when no profile is specified.
+ Core,
+}
diff --git a/third_party/rust/naga/src/front/glsl/builtins.rs b/third_party/rust/naga/src/front/glsl/builtins.rs
new file mode 100644
index 0000000000..14e223d6a4
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/builtins.rs
@@ -0,0 +1,2497 @@
+use super::{
+ ast::{
+ BuiltinVariations, FunctionDeclaration, FunctionKind, Overload, ParameterInfo,
+ ParameterQualifier,
+ },
+ context::Context,
+ Error, ErrorKind, Frontend, Result,
+};
+use crate::{
+ BinaryOperator, Block, Constant, DerivativeAxis as Axis, DerivativeControl as Ctrl, Expression,
+ Handle, ImageClass, ImageDimension as Dim, ImageQuery, MathFunction, Module,
+ RelationalFunction, SampleLevel, ScalarKind as Sk, Span, Type, TypeInner, UnaryOperator,
+ VectorSize,
+};
+
+impl crate::ScalarKind {
+ const fn dummy_storage_format(&self) -> crate::StorageFormat {
+ match *self {
+ Sk::Sint => crate::StorageFormat::R16Sint,
+ Sk::Uint => crate::StorageFormat::R16Uint,
+ _ => crate::StorageFormat::R16Float,
+ }
+ }
+}
+
+impl Module {
+ /// Helper function, to create a function prototype for a builtin
+ fn add_builtin(&mut self, args: Vec<TypeInner>, builtin: MacroCall) -> Overload {
+ let mut parameters = Vec::with_capacity(args.len());
+ let mut parameters_info = Vec::with_capacity(args.len());
+
+ for arg in args {
+ parameters.push(self.types.insert(
+ Type {
+ name: None,
+ inner: arg,
+ },
+ Span::default(),
+ ));
+ parameters_info.push(ParameterInfo {
+ qualifier: ParameterQualifier::In,
+ depth: false,
+ });
+ }
+
+ Overload {
+ parameters,
+ parameters_info,
+ kind: FunctionKind::Macro(builtin),
+ defined: false,
+ internal: true,
+ void: false,
+ }
+ }
+}
+
+const fn make_coords_arg(number_of_components: usize, kind: Sk) -> TypeInner {
+ let width = 4;
+
+ match number_of_components {
+ 1 => TypeInner::Scalar { kind, width },
+ _ => TypeInner::Vector {
+ size: match number_of_components {
+ 2 => VectorSize::Bi,
+ 3 => VectorSize::Tri,
+ _ => VectorSize::Quad,
+ },
+ kind,
+ width,
+ },
+ }
+}
+
+/// Inject builtins into the declaration
+///
+/// This is done to not add a large startup cost and not increase memory
+/// usage if it isn't needed.
+pub fn inject_builtin(
+ declaration: &mut FunctionDeclaration,
+ module: &mut Module,
+ name: &str,
+ mut variations: BuiltinVariations,
+) {
+ log::trace!(
+ "{} variations: {:?} {:?}",
+ name,
+ variations,
+ declaration.variations
+ );
+ // Don't regeneate variations
+ variations.remove(declaration.variations);
+ declaration.variations |= variations;
+
+ if variations.contains(BuiltinVariations::STANDARD) {
+ inject_standard_builtins(declaration, module, name)
+ }
+
+ if variations.contains(BuiltinVariations::DOUBLE) {
+ inject_double_builtin(declaration, module, name)
+ }
+
+ let width = 4;
+ match name {
+ "texture"
+ | "textureGrad"
+ | "textureGradOffset"
+ | "textureLod"
+ | "textureLodOffset"
+ | "textureOffset"
+ | "textureProj"
+ | "textureProjGrad"
+ | "textureProjGradOffset"
+ | "textureProjLod"
+ | "textureProjLodOffset"
+ | "textureProjOffset" => {
+ let f = |kind, dim, arrayed, multi, shadow| {
+ for bits in 0..=0b11 {
+ let variant = bits & 0b1 != 0;
+ let bias = bits & 0b10 != 0;
+
+ let (proj, offset, level_type) = match name {
+ // texture(gsampler, gvec P, [float bias]);
+ "texture" => (false, false, TextureLevelType::None),
+ // textureGrad(gsampler, gvec P, gvec dPdx, gvec dPdy);
+ "textureGrad" => (false, false, TextureLevelType::Grad),
+ // textureGradOffset(gsampler, gvec P, gvec dPdx, gvec dPdy, ivec offset);
+ "textureGradOffset" => (false, true, TextureLevelType::Grad),
+ // textureLod(gsampler, gvec P, float lod);
+ "textureLod" => (false, false, TextureLevelType::Lod),
+ // textureLodOffset(gsampler, gvec P, float lod, ivec offset);
+ "textureLodOffset" => (false, true, TextureLevelType::Lod),
+ // textureOffset(gsampler, gvec+1 P, ivec offset, [float bias]);
+ "textureOffset" => (false, true, TextureLevelType::None),
+ // textureProj(gsampler, gvec+1 P, [float bias]);
+ "textureProj" => (true, false, TextureLevelType::None),
+ // textureProjGrad(gsampler, gvec+1 P, gvec dPdx, gvec dPdy);
+ "textureProjGrad" => (true, false, TextureLevelType::Grad),
+ // textureProjGradOffset(gsampler, gvec+1 P, gvec dPdx, gvec dPdy, ivec offset);
+ "textureProjGradOffset" => (true, true, TextureLevelType::Grad),
+ // textureProjLod(gsampler, gvec+1 P, float lod);
+ "textureProjLod" => (true, false, TextureLevelType::Lod),
+ // textureProjLodOffset(gsampler, gvec+1 P, gvec dPdx, gvec dPdy, ivec offset);
+ "textureProjLodOffset" => (true, true, TextureLevelType::Lod),
+ // textureProjOffset(gsampler, gvec+1 P, ivec offset, [float bias]);
+ "textureProjOffset" => (true, true, TextureLevelType::None),
+ _ => unreachable!(),
+ };
+
+ let builtin = MacroCall::Texture {
+ proj,
+ offset,
+ shadow,
+ level_type,
+ };
+
+ // Parse out the variant settings.
+ let grad = level_type == TextureLevelType::Grad;
+ let lod = level_type == TextureLevelType::Lod;
+
+ let supports_variant = proj && !shadow;
+ if variant && !supports_variant {
+ continue;
+ }
+
+ if bias && !matches!(level_type, TextureLevelType::None) {
+ continue;
+ }
+
+ // Proj doesn't work with arrayed or Cube
+ if proj && (arrayed || dim == Dim::Cube) {
+ continue;
+ }
+
+ // texture operations with offset are not supported for cube maps
+ if dim == Dim::Cube && offset {
+ continue;
+ }
+
+ // sampler2DArrayShadow can't be used in textureLod or in texture with bias
+ if (lod || bias) && arrayed && shadow && dim == Dim::D2 {
+ continue;
+ }
+
+ // TODO: glsl supports using bias with depth samplers but naga doesn't
+ if bias && shadow {
+ continue;
+ }
+
+ let class = match shadow {
+ true => ImageClass::Depth { multi },
+ false => ImageClass::Sampled { kind, multi },
+ };
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ };
+
+ let num_coords_from_dim = image_dims_to_coords_size(dim).min(3);
+ let mut num_coords = num_coords_from_dim;
+
+ if shadow && proj {
+ num_coords = 4;
+ } else if dim == Dim::D1 && shadow {
+ num_coords = 3;
+ } else if shadow {
+ num_coords += 1;
+ } else if proj {
+ if variant && num_coords == 4 {
+ // Normal form already has 4 components, no need to have a variant form.
+ continue;
+ } else if variant {
+ num_coords = 4;
+ } else {
+ num_coords += 1;
+ }
+ }
+
+ if !(dim == Dim::D1 && shadow) {
+ num_coords += arrayed as usize;
+ }
+
+ // Special case: texture(gsamplerCubeArrayShadow) kicks the shadow compare ref to a separate argument,
+ // since it would otherwise take five arguments. It also can't take a bias, nor can it be proj/grad/lod/offset
+ // (presumably because nobody asked for it, and implementation complexity?)
+ if num_coords >= 5 {
+ if lod || grad || offset || proj || bias {
+ continue;
+ }
+ debug_assert!(dim == Dim::Cube && shadow && arrayed);
+ }
+ debug_assert!(num_coords <= 5);
+
+ let vector = make_coords_arg(num_coords, Sk::Float);
+ let mut args = vec![image, vector];
+
+ if num_coords == 5 {
+ args.push(TypeInner::Scalar {
+ kind: Sk::Float,
+ width,
+ });
+ }
+
+ match level_type {
+ TextureLevelType::Lod => {
+ args.push(TypeInner::Scalar {
+ kind: Sk::Float,
+ width,
+ });
+ }
+ TextureLevelType::Grad => {
+ args.push(make_coords_arg(num_coords_from_dim, Sk::Float));
+ args.push(make_coords_arg(num_coords_from_dim, Sk::Float));
+ }
+ TextureLevelType::None => {}
+ };
+
+ if offset {
+ args.push(make_coords_arg(num_coords_from_dim, Sk::Sint));
+ }
+
+ if bias {
+ args.push(TypeInner::Scalar {
+ kind: Sk::Float,
+ width,
+ });
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, builtin));
+ }
+ };
+
+ texture_args_generator(TextureArgsOptions::SHADOW | variations.into(), f)
+ }
+ "textureSize" => {
+ let f = |kind, dim, arrayed, multi, shadow| {
+ let class = match shadow {
+ true => ImageClass::Depth { multi },
+ false => ImageClass::Sampled { kind, multi },
+ };
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ };
+
+ let mut args = vec![image];
+
+ if !multi {
+ args.push(TypeInner::Scalar {
+ kind: Sk::Sint,
+ width,
+ })
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::TextureSize { arrayed }))
+ };
+
+ texture_args_generator(
+ TextureArgsOptions::SHADOW | TextureArgsOptions::MULTI | variations.into(),
+ f,
+ )
+ }
+ "texelFetch" | "texelFetchOffset" => {
+ let offset = "texelFetchOffset" == name;
+ let f = |kind, dim, arrayed, multi, _shadow| {
+ // Cube images aren't supported
+ if let Dim::Cube = dim {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Sampled { kind, multi },
+ };
+
+ let dim_value = image_dims_to_coords_size(dim);
+ let coordinates = make_coords_arg(dim_value + arrayed as usize, Sk::Sint);
+
+ let mut args = vec![
+ image,
+ coordinates,
+ TypeInner::Scalar {
+ kind: Sk::Sint,
+ width,
+ },
+ ];
+
+ if offset {
+ args.push(make_coords_arg(dim_value, Sk::Sint));
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::ImageLoad { multi }))
+ };
+
+ // Don't generate shadow images since they aren't supported
+ texture_args_generator(TextureArgsOptions::MULTI | variations.into(), f)
+ }
+ "imageSize" => {
+ let f = |kind: Sk, dim, arrayed, _, _| {
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ if dim == Dim::Cube {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Storage {
+ format: kind.dummy_storage_format(),
+ access: crate::StorageAccess::empty(),
+ },
+ };
+
+ declaration
+ .overloads
+ .push(module.add_builtin(vec![image], MacroCall::TextureSize { arrayed }))
+ };
+
+ texture_args_generator(variations.into(), f)
+ }
+ "imageLoad" => {
+ let f = |kind: Sk, dim, arrayed, _, _| {
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ if dim == Dim::Cube {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Storage {
+ format: kind.dummy_storage_format(),
+ access: crate::StorageAccess::LOAD,
+ },
+ };
+
+ let dim_value = image_dims_to_coords_size(dim);
+ let mut coord_size = dim_value + arrayed as usize;
+ // > Every OpenGL API call that operates on cubemap array
+ // > textures takes layer-faces, not array layers
+ //
+ // So this means that imageCubeArray only takes a three component
+ // vector coordinate and the third component is a layer index.
+ if Dim::Cube == dim && arrayed {
+ coord_size = 3
+ }
+ let coordinates = make_coords_arg(coord_size, Sk::Sint);
+
+ let args = vec![image, coordinates];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::ImageLoad { multi: false }))
+ };
+
+ // Don't generate shadow nor multisampled images since they aren't supported
+ texture_args_generator(variations.into(), f)
+ }
+ "imageStore" => {
+ let f = |kind: Sk, dim, arrayed, _, _| {
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ if dim == Dim::Cube {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Storage {
+ format: kind.dummy_storage_format(),
+ access: crate::StorageAccess::STORE,
+ },
+ };
+
+ let dim_value = image_dims_to_coords_size(dim);
+ let mut coord_size = dim_value + arrayed as usize;
+ // > Every OpenGL API call that operates on cubemap array
+ // > textures takes layer-faces, not array layers
+ //
+ // So this means that imageCubeArray only takes a three component
+ // vector coordinate and the third component is a layer index.
+ if Dim::Cube == dim && arrayed {
+ coord_size = 3
+ }
+ let coordinates = make_coords_arg(coord_size, Sk::Sint);
+
+ let args = vec![
+ image,
+ coordinates,
+ TypeInner::Vector {
+ size: VectorSize::Quad,
+ kind,
+ width,
+ },
+ ];
+
+ let mut overload = module.add_builtin(args, MacroCall::ImageStore);
+ overload.void = true;
+ declaration.overloads.push(overload)
+ };
+
+ // Don't generate shadow nor multisampled images since they aren't supported
+ texture_args_generator(variations.into(), f)
+ }
+ _ => {}
+ }
+}
+
+/// Injects the builtins into declaration that don't need any special variations
+fn inject_standard_builtins(
+ declaration: &mut FunctionDeclaration,
+ module: &mut Module,
+ name: &str,
+) {
+ let width = 4;
+ match name {
+ "sampler1D" | "sampler1DArray" | "sampler2D" | "sampler2DArray" | "sampler2DMS"
+ | "sampler2DMSArray" | "sampler3D" | "samplerCube" | "samplerCubeArray" => {
+ declaration.overloads.push(module.add_builtin(
+ vec![
+ TypeInner::Image {
+ dim: match name {
+ "sampler1D" | "sampler1DArray" => Dim::D1,
+ "sampler2D" | "sampler2DArray" | "sampler2DMS" | "sampler2DMSArray" => {
+ Dim::D2
+ }
+ "sampler3D" => Dim::D3,
+ _ => Dim::Cube,
+ },
+ arrayed: matches!(
+ name,
+ "sampler1DArray"
+ | "sampler2DArray"
+ | "sampler2DMSArray"
+ | "samplerCubeArray"
+ ),
+ class: ImageClass::Sampled {
+ kind: Sk::Float,
+ multi: matches!(name, "sampler2DMS" | "sampler2DMSArray"),
+ },
+ },
+ TypeInner::Sampler { comparison: false },
+ ],
+ MacroCall::Sampler,
+ ))
+ }
+ "sampler1DShadow"
+ | "sampler1DArrayShadow"
+ | "sampler2DShadow"
+ | "sampler2DArrayShadow"
+ | "samplerCubeShadow"
+ | "samplerCubeArrayShadow" => {
+ let dim = match name {
+ "sampler1DShadow" | "sampler1DArrayShadow" => Dim::D1,
+ "sampler2DShadow" | "sampler2DArrayShadow" => Dim::D2,
+ _ => Dim::Cube,
+ };
+ let arrayed = matches!(
+ name,
+ "sampler1DArrayShadow" | "sampler2DArrayShadow" | "samplerCubeArrayShadow"
+ );
+
+ for i in 0..2 {
+ let ty = TypeInner::Image {
+ dim,
+ arrayed,
+ class: match i {
+ 0 => ImageClass::Sampled {
+ kind: Sk::Float,
+ multi: false,
+ },
+ _ => ImageClass::Depth { multi: false },
+ },
+ };
+
+ declaration.overloads.push(module.add_builtin(
+ vec![ty, TypeInner::Sampler { comparison: true }],
+ MacroCall::SamplerShadow,
+ ))
+ }
+ }
+ "sin" | "exp" | "exp2" | "sinh" | "cos" | "cosh" | "tan" | "tanh" | "acos" | "asin"
+ | "log" | "log2" | "radians" | "degrees" | "asinh" | "acosh" | "atanh"
+ | "floatBitsToInt" | "floatBitsToUint" | "dFdx" | "dFdxFine" | "dFdxCoarse" | "dFdy"
+ | "dFdyFine" | "dFdyCoarse" | "fwidth" | "fwidthFine" | "fwidthCoarse" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = Sk::Float;
+
+ declaration.overloads.push(module.add_builtin(
+ vec![match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ }],
+ match name {
+ "sin" => MacroCall::MathFunction(MathFunction::Sin),
+ "exp" => MacroCall::MathFunction(MathFunction::Exp),
+ "exp2" => MacroCall::MathFunction(MathFunction::Exp2),
+ "sinh" => MacroCall::MathFunction(MathFunction::Sinh),
+ "cos" => MacroCall::MathFunction(MathFunction::Cos),
+ "cosh" => MacroCall::MathFunction(MathFunction::Cosh),
+ "tan" => MacroCall::MathFunction(MathFunction::Tan),
+ "tanh" => MacroCall::MathFunction(MathFunction::Tanh),
+ "acos" => MacroCall::MathFunction(MathFunction::Acos),
+ "asin" => MacroCall::MathFunction(MathFunction::Asin),
+ "log" => MacroCall::MathFunction(MathFunction::Log),
+ "log2" => MacroCall::MathFunction(MathFunction::Log2),
+ "asinh" => MacroCall::MathFunction(MathFunction::Asinh),
+ "acosh" => MacroCall::MathFunction(MathFunction::Acosh),
+ "atanh" => MacroCall::MathFunction(MathFunction::Atanh),
+ "radians" => MacroCall::MathFunction(MathFunction::Radians),
+ "degrees" => MacroCall::MathFunction(MathFunction::Degrees),
+ "floatBitsToInt" => MacroCall::BitCast(Sk::Sint),
+ "floatBitsToUint" => MacroCall::BitCast(Sk::Uint),
+ "dFdxCoarse" => MacroCall::Derivate(Axis::X, Ctrl::Coarse),
+ "dFdyCoarse" => MacroCall::Derivate(Axis::Y, Ctrl::Coarse),
+ "fwidthCoarse" => MacroCall::Derivate(Axis::Width, Ctrl::Coarse),
+ "dFdxFine" => MacroCall::Derivate(Axis::X, Ctrl::Fine),
+ "dFdyFine" => MacroCall::Derivate(Axis::Y, Ctrl::Fine),
+ "fwidthFine" => MacroCall::Derivate(Axis::Width, Ctrl::Fine),
+ "dFdx" => MacroCall::Derivate(Axis::X, Ctrl::None),
+ "dFdy" => MacroCall::Derivate(Axis::Y, Ctrl::None),
+ "fwidth" => MacroCall::Derivate(Axis::Width, Ctrl::None),
+ _ => unreachable!(),
+ },
+ ))
+ }
+ }
+ "intBitsToFloat" | "uintBitsToFloat" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = match name {
+ "intBitsToFloat" => Sk::Sint,
+ _ => Sk::Uint,
+ };
+
+ declaration.overloads.push(module.add_builtin(
+ vec![match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ }],
+ MacroCall::BitCast(Sk::Float),
+ ))
+ }
+ }
+ "pow" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = Sk::Float;
+ let ty = || match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+
+ declaration.overloads.push(
+ module
+ .add_builtin(vec![ty(), ty()], MacroCall::MathFunction(MathFunction::Pow)),
+ )
+ }
+ }
+ "abs" | "sign" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 - float/sint
+ for bits in 0..0b1000 {
+ let size = match bits & 0b11 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = match bits >> 2 {
+ 0b0 => Sk::Float,
+ _ => Sk::Sint,
+ };
+
+ let args = vec![match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ }];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ MacroCall::MathFunction(match name {
+ "abs" => MathFunction::Abs,
+ "sign" => MathFunction::Sign,
+ _ => unreachable!(),
+ }),
+ ))
+ }
+ }
+ "bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" | "findLSB"
+ | "findMSB" => {
+ let fun = match name {
+ "bitCount" => MathFunction::CountOneBits,
+ "bitfieldReverse" => MathFunction::ReverseBits,
+ "bitfieldExtract" => MathFunction::ExtractBits,
+ "bitfieldInsert" => MathFunction::InsertBits,
+ "findLSB" => MathFunction::FindLsb,
+ "findMSB" => MathFunction::FindMsb,
+ _ => unreachable!(),
+ };
+
+ let mc = match fun {
+ MathFunction::ExtractBits => MacroCall::BitfieldExtract,
+ MathFunction::InsertBits => MacroCall::BitfieldInsert,
+ _ => MacroCall::MathFunction(fun),
+ };
+
+ // bits layout
+ // bit 0 - int/uint
+ // bit 1 through 2 - dims
+ for bits in 0..0b1000 {
+ let kind = match bits & 0b1 {
+ 0b0 => Sk::Sint,
+ _ => Sk::Uint,
+ };
+ let size = match bits >> 1 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let ty = || match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+
+ let mut args = vec![ty()];
+
+ match fun {
+ MathFunction::ExtractBits => {
+ args.push(TypeInner::Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ });
+ args.push(TypeInner::Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ });
+ }
+ MathFunction::InsertBits => {
+ args.push(ty());
+ args.push(TypeInner::Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ });
+ args.push(TypeInner::Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ });
+ }
+ _ => {}
+ }
+
+ // we need to cast the return type of findLsb / findMsb
+ let mc = if kind == Sk::Uint {
+ match mc {
+ MacroCall::MathFunction(MathFunction::FindLsb) => MacroCall::FindLsbUint,
+ MacroCall::MathFunction(MathFunction::FindMsb) => MacroCall::FindMsbUint,
+ mc => mc,
+ }
+ } else {
+ mc
+ };
+
+ declaration.overloads.push(module.add_builtin(args, mc))
+ }
+ }
+ "packSnorm4x8" | "packUnorm4x8" | "packSnorm2x16" | "packUnorm2x16" | "packHalf2x16" => {
+ let fun = match name {
+ "packSnorm4x8" => MathFunction::Pack4x8snorm,
+ "packUnorm4x8" => MathFunction::Pack4x8unorm,
+ "packSnorm2x16" => MathFunction::Pack2x16unorm,
+ "packUnorm2x16" => MathFunction::Pack2x16snorm,
+ "packHalf2x16" => MathFunction::Pack2x16float,
+ _ => unreachable!(),
+ };
+
+ let ty = match fun {
+ MathFunction::Pack4x8snorm | MathFunction::Pack4x8unorm => TypeInner::Vector {
+ size: crate::VectorSize::Quad,
+ kind: Sk::Float,
+ width: 4,
+ },
+ MathFunction::Pack2x16unorm
+ | MathFunction::Pack2x16snorm
+ | MathFunction::Pack2x16float => TypeInner::Vector {
+ size: crate::VectorSize::Bi,
+ kind: Sk::Float,
+ width: 4,
+ },
+ _ => unreachable!(),
+ };
+
+ let args = vec![ty];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(fun)));
+ }
+ "unpackSnorm4x8" | "unpackUnorm4x8" | "unpackSnorm2x16" | "unpackUnorm2x16"
+ | "unpackHalf2x16" => {
+ let fun = match name {
+ "unpackSnorm4x8" => MathFunction::Unpack4x8snorm,
+ "unpackUnorm4x8" => MathFunction::Unpack4x8unorm,
+ "unpackSnorm2x16" => MathFunction::Unpack2x16snorm,
+ "unpackUnorm2x16" => MathFunction::Unpack2x16unorm,
+ "unpackHalf2x16" => MathFunction::Unpack2x16float,
+ _ => unreachable!(),
+ };
+
+ let args = vec![TypeInner::Scalar {
+ kind: Sk::Uint,
+ width: 4,
+ }];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(fun)));
+ }
+ "atan" => {
+ // bits layout
+ // bit 0 - atan/atan2
+ // bit 1 through 2 - dims
+ for bits in 0..0b1000 {
+ let fun = match bits & 0b1 {
+ 0b0 => MathFunction::Atan,
+ _ => MathFunction::Atan2,
+ };
+ let size = match bits >> 1 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = Sk::Float;
+ let ty = || match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+
+ let mut args = vec![ty()];
+
+ if fun == MathFunction::Atan2 {
+ args.push(ty())
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(fun)))
+ }
+ }
+ "all" | "any" | "not" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b11 {
+ let size = match bits {
+ 0b00 => VectorSize::Bi,
+ 0b01 => VectorSize::Tri,
+ _ => VectorSize::Quad,
+ };
+
+ let args = vec![TypeInner::Vector {
+ size,
+ kind: Sk::Bool,
+ width: crate::BOOL_WIDTH,
+ }];
+
+ let fun = match name {
+ "all" => MacroCall::Relational(RelationalFunction::All),
+ "any" => MacroCall::Relational(RelationalFunction::Any),
+ "not" => MacroCall::Unary(UnaryOperator::Not),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" => {
+ for bits in 0..0b1001 {
+ let (size, kind) = match bits {
+ 0b0000 => (VectorSize::Bi, Sk::Float),
+ 0b0001 => (VectorSize::Tri, Sk::Float),
+ 0b0010 => (VectorSize::Quad, Sk::Float),
+ 0b0011 => (VectorSize::Bi, Sk::Sint),
+ 0b0100 => (VectorSize::Tri, Sk::Sint),
+ 0b0101 => (VectorSize::Quad, Sk::Sint),
+ 0b0110 => (VectorSize::Bi, Sk::Uint),
+ 0b0111 => (VectorSize::Tri, Sk::Uint),
+ _ => (VectorSize::Quad, Sk::Uint),
+ };
+
+ let ty = || TypeInner::Vector { size, kind, width };
+ let args = vec![ty(), ty()];
+
+ let fun = MacroCall::Binary(match name {
+ "lessThan" => BinaryOperator::Less,
+ "greaterThan" => BinaryOperator::Greater,
+ "lessThanEqual" => BinaryOperator::LessEqual,
+ "greaterThanEqual" => BinaryOperator::GreaterEqual,
+ _ => unreachable!(),
+ });
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "equal" | "notEqual" => {
+ for bits in 0..0b1100 {
+ let (size, kind) = match bits {
+ 0b0000 => (VectorSize::Bi, Sk::Float),
+ 0b0001 => (VectorSize::Tri, Sk::Float),
+ 0b0010 => (VectorSize::Quad, Sk::Float),
+ 0b0011 => (VectorSize::Bi, Sk::Sint),
+ 0b0100 => (VectorSize::Tri, Sk::Sint),
+ 0b0101 => (VectorSize::Quad, Sk::Sint),
+ 0b0110 => (VectorSize::Bi, Sk::Uint),
+ 0b0111 => (VectorSize::Tri, Sk::Uint),
+ 0b1000 => (VectorSize::Quad, Sk::Uint),
+ 0b1001 => (VectorSize::Bi, Sk::Bool),
+ 0b1010 => (VectorSize::Tri, Sk::Bool),
+ _ => (VectorSize::Quad, Sk::Bool),
+ };
+
+ let width = if let Sk::Bool = kind {
+ crate::BOOL_WIDTH
+ } else {
+ width
+ };
+
+ let ty = || TypeInner::Vector { size, kind, width };
+ let args = vec![ty(), ty()];
+
+ let fun = MacroCall::Binary(match name {
+ "equal" => BinaryOperator::Equal,
+ "notEqual" => BinaryOperator::NotEqual,
+ _ => unreachable!(),
+ });
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "min" | "max" => {
+ // bits layout
+ // bit 0 through 1 - scalar kind
+ // bit 2 through 4 - dims
+ for bits in 0..0b11100 {
+ let kind = match bits & 0b11 {
+ 0b00 => Sk::Float,
+ 0b01 => Sk::Sint,
+ 0b10 => Sk::Uint,
+ _ => continue,
+ };
+ let (size, second_size) = match bits >> 2 {
+ 0b000 => (None, None),
+ 0b001 => (Some(VectorSize::Bi), None),
+ 0b010 => (Some(VectorSize::Tri), None),
+ 0b011 => (Some(VectorSize::Quad), None),
+ 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)),
+ 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)),
+ _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)),
+ };
+
+ let args = vec![
+ match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ },
+ match second_size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ },
+ ];
+
+ let fun = match name {
+ "max" => MacroCall::Splatted(MathFunction::Max, size, 1),
+ "min" => MacroCall::Splatted(MathFunction::Min, size, 1),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "mix" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 through 4 - types
+ //
+ // 0b10011 is the last element since splatted single elements
+ // were already added
+ for bits in 0..0b10011 {
+ let size = match bits & 0b11 {
+ 0b00 => Some(VectorSize::Bi),
+ 0b01 => Some(VectorSize::Tri),
+ 0b10 => Some(VectorSize::Quad),
+ _ => None,
+ };
+ let (kind, splatted, boolean) = match bits >> 2 {
+ 0b000 => (Sk::Sint, false, true),
+ 0b001 => (Sk::Uint, false, true),
+ 0b010 => (Sk::Float, false, true),
+ 0b011 => (Sk::Float, false, false),
+ _ => (Sk::Float, true, false),
+ };
+
+ let ty = |kind, width| match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+ let args = vec![
+ ty(kind, width),
+ ty(kind, width),
+ match (boolean, splatted) {
+ (true, _) => ty(Sk::Bool, crate::BOOL_WIDTH),
+ (_, false) => TypeInner::Scalar { kind, width },
+ _ => ty(kind, width),
+ },
+ ];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ match boolean {
+ true => MacroCall::MixBoolean,
+ false => MacroCall::Splatted(MathFunction::Mix, size, 2),
+ },
+ ))
+ }
+ }
+ "clamp" => {
+ // bits layout
+ // bit 0 through 1 - float/int/uint
+ // bit 2 through 3 - dims
+ // bit 4 - splatted
+ //
+ // 0b11010 is the last element since splatted single elements
+ // were already added
+ for bits in 0..0b11011 {
+ let kind = match bits & 0b11 {
+ 0b00 => Sk::Float,
+ 0b01 => Sk::Sint,
+ 0b10 => Sk::Uint,
+ _ => continue,
+ };
+ let size = match (bits >> 2) & 0b11 {
+ 0b00 => Some(VectorSize::Bi),
+ 0b01 => Some(VectorSize::Tri),
+ 0b10 => Some(VectorSize::Quad),
+ _ => None,
+ };
+ let splatted = bits & 0b10000 == 0b10000;
+
+ let base_ty = || match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+ let limit_ty = || match splatted {
+ true => TypeInner::Scalar { kind, width },
+ false => base_ty(),
+ };
+
+ let args = vec![base_ty(), limit_ty(), limit_ty()];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::Clamp(size)))
+ }
+ }
+ "barrier" => declaration
+ .overloads
+ .push(module.add_builtin(Vec::new(), MacroCall::Barrier)),
+ // Add common builtins with floats
+ _ => inject_common_builtin(declaration, module, name, 4),
+ }
+}
+
+/// Injects the builtins into declaration that need doubles
+fn inject_double_builtin(declaration: &mut FunctionDeclaration, module: &mut Module, name: &str) {
+ let width = 8;
+ match name {
+ "abs" | "sign" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let kind = Sk::Float;
+
+ let args = vec![match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ }];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ MacroCall::MathFunction(match name {
+ "abs" => MathFunction::Abs,
+ "sign" => MathFunction::Sign,
+ _ => unreachable!(),
+ }),
+ ))
+ }
+ }
+ "min" | "max" => {
+ // bits layout
+ // bit 0 through 2 - dims
+ for bits in 0..0b111 {
+ let (size, second_size) = match bits {
+ 0b000 => (None, None),
+ 0b001 => (Some(VectorSize::Bi), None),
+ 0b010 => (Some(VectorSize::Tri), None),
+ 0b011 => (Some(VectorSize::Quad), None),
+ 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)),
+ 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)),
+ _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)),
+ };
+ let kind = Sk::Float;
+
+ let args = vec![
+ match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ },
+ match second_size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ },
+ ];
+
+ let fun = match name {
+ "max" => MacroCall::Splatted(MathFunction::Max, size, 1),
+ "min" => MacroCall::Splatted(MathFunction::Min, size, 1),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "mix" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 through 3 - splatted/boolean
+ //
+ // 0b1010 is the last element since splatted with single elements
+ // is equal to normal single elements
+ for bits in 0..0b1011 {
+ let size = match bits & 0b11 {
+ 0b00 => Some(VectorSize::Quad),
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => None,
+ };
+ let kind = Sk::Float;
+ let (splatted, boolean) = match bits >> 2 {
+ 0b00 => (false, false),
+ 0b01 => (false, true),
+ _ => (true, false),
+ };
+
+ let ty = |kind, width| match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+ let args = vec![
+ ty(kind, width),
+ ty(kind, width),
+ match (boolean, splatted) {
+ (true, _) => ty(Sk::Bool, crate::BOOL_WIDTH),
+ (_, false) => TypeInner::Scalar { kind, width },
+ _ => ty(kind, width),
+ },
+ ];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ match boolean {
+ true => MacroCall::MixBoolean,
+ false => MacroCall::Splatted(MathFunction::Mix, size, 2),
+ },
+ ))
+ }
+ }
+ "clamp" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 - splatted
+ //
+ // 0b110 is the last element since splatted with single elements
+ // is equal to normal single elements
+ for bits in 0..0b111 {
+ let kind = Sk::Float;
+ let size = match bits & 0b11 {
+ 0b00 => Some(VectorSize::Bi),
+ 0b01 => Some(VectorSize::Tri),
+ 0b10 => Some(VectorSize::Quad),
+ _ => None,
+ };
+ let splatted = bits & 0b100 == 0b100;
+
+ let base_ty = || match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+ let limit_ty = || match splatted {
+ true => TypeInner::Scalar { kind, width },
+ false => base_ty(),
+ };
+
+ let args = vec![base_ty(), limit_ty(), limit_ty()];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::Clamp(size)))
+ }
+ }
+ "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" | "equal"
+ | "notEqual" => {
+ for bits in 0..0b11 {
+ let (size, kind) = match bits {
+ 0b00 => (VectorSize::Bi, Sk::Float),
+ 0b01 => (VectorSize::Tri, Sk::Float),
+ _ => (VectorSize::Quad, Sk::Float),
+ };
+
+ let ty = || TypeInner::Vector { size, kind, width };
+ let args = vec![ty(), ty()];
+
+ let fun = MacroCall::Binary(match name {
+ "lessThan" => BinaryOperator::Less,
+ "greaterThan" => BinaryOperator::Greater,
+ "lessThanEqual" => BinaryOperator::LessEqual,
+ "greaterThanEqual" => BinaryOperator::GreaterEqual,
+ "equal" => BinaryOperator::Equal,
+ "notEqual" => BinaryOperator::NotEqual,
+ _ => unreachable!(),
+ });
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ // Add common builtins with doubles
+ _ => inject_common_builtin(declaration, module, name, 8),
+ }
+}
+
+/// Injects the builtins into declaration that can used either float or doubles
+fn inject_common_builtin(
+ declaration: &mut FunctionDeclaration,
+ module: &mut Module,
+ name: &str,
+ float_width: crate::Bytes,
+) {
+ match name {
+ "ceil" | "round" | "roundEven" | "floor" | "fract" | "trunc" | "sqrt" | "inversesqrt"
+ | "normalize" | "length" | "isinf" | "isnan" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let args = vec![match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ }];
+
+ let fun = match name {
+ "ceil" => MacroCall::MathFunction(MathFunction::Ceil),
+ "round" | "roundEven" => MacroCall::MathFunction(MathFunction::Round),
+ "floor" => MacroCall::MathFunction(MathFunction::Floor),
+ "fract" => MacroCall::MathFunction(MathFunction::Fract),
+ "trunc" => MacroCall::MathFunction(MathFunction::Trunc),
+ "sqrt" => MacroCall::MathFunction(MathFunction::Sqrt),
+ "inversesqrt" => MacroCall::MathFunction(MathFunction::InverseSqrt),
+ "normalize" => MacroCall::MathFunction(MathFunction::Normalize),
+ "length" => MacroCall::MathFunction(MathFunction::Length),
+ "isinf" => MacroCall::Relational(RelationalFunction::IsInf),
+ "isnan" => MacroCall::Relational(RelationalFunction::IsNan),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "dot" | "reflect" | "distance" | "ldexp" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let ty = || match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ };
+
+ let fun = match name {
+ "dot" => MacroCall::MathFunction(MathFunction::Dot),
+ "reflect" => MacroCall::MathFunction(MathFunction::Reflect),
+ "distance" => MacroCall::MathFunction(MathFunction::Distance),
+ "ldexp" => MacroCall::MathFunction(MathFunction::Ldexp),
+ _ => unreachable!(),
+ };
+
+ declaration
+ .overloads
+ .push(module.add_builtin(vec![ty(), ty()], fun))
+ }
+ }
+ "transpose" => {
+ // bits layout
+ // bit 0 through 3 - dims
+ for bits in 0..0b1001 {
+ let (rows, columns) = match bits {
+ 0b0000 => (VectorSize::Bi, VectorSize::Bi),
+ 0b0001 => (VectorSize::Bi, VectorSize::Tri),
+ 0b0010 => (VectorSize::Bi, VectorSize::Quad),
+ 0b0011 => (VectorSize::Tri, VectorSize::Bi),
+ 0b0100 => (VectorSize::Tri, VectorSize::Tri),
+ 0b0101 => (VectorSize::Tri, VectorSize::Quad),
+ 0b0110 => (VectorSize::Quad, VectorSize::Bi),
+ 0b0111 => (VectorSize::Quad, VectorSize::Tri),
+ _ => (VectorSize::Quad, VectorSize::Quad),
+ };
+
+ declaration.overloads.push(module.add_builtin(
+ vec![TypeInner::Matrix {
+ columns,
+ rows,
+ width: float_width,
+ }],
+ MacroCall::MathFunction(MathFunction::Transpose),
+ ))
+ }
+ }
+ "inverse" | "determinant" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b11 {
+ let (rows, columns) = match bits {
+ 0b00 => (VectorSize::Bi, VectorSize::Bi),
+ 0b01 => (VectorSize::Tri, VectorSize::Tri),
+ _ => (VectorSize::Quad, VectorSize::Quad),
+ };
+
+ let args = vec![TypeInner::Matrix {
+ columns,
+ rows,
+ width: float_width,
+ }];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ MacroCall::MathFunction(match name {
+ "inverse" => MathFunction::Inverse,
+ "determinant" => MathFunction::Determinant,
+ _ => unreachable!(),
+ }),
+ ))
+ }
+ }
+ "mod" | "step" => {
+ // bits layout
+ // bit 0 through 2 - dims
+ for bits in 0..0b111 {
+ let (size, second_size) = match bits {
+ 0b000 => (None, None),
+ 0b001 => (Some(VectorSize::Bi), None),
+ 0b010 => (Some(VectorSize::Tri), None),
+ 0b011 => (Some(VectorSize::Quad), None),
+ 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)),
+ 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)),
+ _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)),
+ };
+
+ let mut args = Vec::with_capacity(2);
+ let step = name == "step";
+
+ for i in 0..2 {
+ let maybe_size = match i == step as u32 {
+ true => size,
+ false => second_size,
+ };
+
+ args.push(match maybe_size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ })
+ }
+
+ let fun = match name {
+ "mod" => MacroCall::Mod(size),
+ "step" => MacroCall::Splatted(MathFunction::Step, size, 0),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "modf" | "frexp" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let ty = module.types.insert(
+ Type {
+ name: None,
+ inner: match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ },
+ },
+ Span::default(),
+ );
+
+ let parameters = vec![ty, ty];
+
+ let fun = match name {
+ "modf" => MacroCall::MathFunction(MathFunction::Modf),
+ "frexp" => MacroCall::MathFunction(MathFunction::Frexp),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(Overload {
+ parameters,
+ parameters_info: vec![
+ ParameterInfo {
+ qualifier: ParameterQualifier::In,
+ depth: false,
+ },
+ ParameterInfo {
+ qualifier: ParameterQualifier::Out,
+ depth: false,
+ },
+ ],
+ kind: FunctionKind::Macro(fun),
+ defined: false,
+ internal: true,
+ void: false,
+ })
+ }
+ }
+ "cross" => {
+ let args = vec![
+ TypeInner::Vector {
+ size: VectorSize::Tri,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ TypeInner::Vector {
+ size: VectorSize::Tri,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ ];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Cross)))
+ }
+ "outerProduct" => {
+ // bits layout
+ // bit 0 through 3 - dims
+ for bits in 0..0b1001 {
+ let (size1, size2) = match bits {
+ 0b0000 => (VectorSize::Bi, VectorSize::Bi),
+ 0b0001 => (VectorSize::Bi, VectorSize::Tri),
+ 0b0010 => (VectorSize::Bi, VectorSize::Quad),
+ 0b0011 => (VectorSize::Tri, VectorSize::Bi),
+ 0b0100 => (VectorSize::Tri, VectorSize::Tri),
+ 0b0101 => (VectorSize::Tri, VectorSize::Quad),
+ 0b0110 => (VectorSize::Quad, VectorSize::Bi),
+ 0b0111 => (VectorSize::Quad, VectorSize::Tri),
+ _ => (VectorSize::Quad, VectorSize::Quad),
+ };
+
+ let args = vec![
+ TypeInner::Vector {
+ size: size1,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ TypeInner::Vector {
+ size: size2,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ ];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Outer)))
+ }
+ }
+ "faceforward" | "fma" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let ty = || match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ };
+ let args = vec![ty(), ty(), ty()];
+
+ let fun = match name {
+ "faceforward" => MacroCall::MathFunction(MathFunction::FaceForward),
+ "fma" => MacroCall::MathFunction(MathFunction::Fma),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "refract" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let ty = || match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ };
+ let args = vec![
+ ty(),
+ ty(),
+ TypeInner::Scalar {
+ kind: Sk::Float,
+ width: 4,
+ },
+ ];
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Refract)))
+ }
+ }
+ "smoothstep" => {
+ // bit 0 - splatted
+ // bit 1 through 2 - dims
+ for bits in 0..0b1000 {
+ let splatted = bits & 0b1 == 0b1;
+ let size = match bits >> 1 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ if splatted && size.is_none() {
+ continue;
+ }
+
+ let base_ty = || match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ kind: Sk::Float,
+ width: float_width,
+ },
+ None => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ };
+ let ty = || match splatted {
+ true => TypeInner::Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ },
+ false => base_ty(),
+ };
+ declaration.overloads.push(module.add_builtin(
+ vec![ty(), ty(), base_ty()],
+ MacroCall::SmoothStep { splatted: size },
+ ))
+ }
+ }
+ // The function isn't a builtin or we don't yet support it
+ _ => {}
+ }
+}
+
+#[derive(Clone, Copy, PartialEq, Debug)]
+pub enum TextureLevelType {
+ None,
+ Lod,
+ Grad,
+}
+
+/// A compiler defined builtin function
+#[derive(Clone, Copy, PartialEq, Debug)]
+pub enum MacroCall {
+ Sampler,
+ SamplerShadow,
+ Texture {
+ proj: bool,
+ offset: bool,
+ shadow: bool,
+ level_type: TextureLevelType,
+ },
+ TextureSize {
+ arrayed: bool,
+ },
+ ImageLoad {
+ multi: bool,
+ },
+ ImageStore,
+ MathFunction(MathFunction),
+ FindLsbUint,
+ FindMsbUint,
+ BitfieldExtract,
+ BitfieldInsert,
+ Relational(RelationalFunction),
+ Unary(UnaryOperator),
+ Binary(BinaryOperator),
+ Mod(Option<VectorSize>),
+ Splatted(MathFunction, Option<VectorSize>, usize),
+ MixBoolean,
+ Clamp(Option<VectorSize>),
+ BitCast(Sk),
+ Derivate(Axis, Ctrl),
+ Barrier,
+ /// SmoothStep needs a separate variant because it might need it's inputs
+ /// to be splatted depending on the overload
+ SmoothStep {
+ /// The size of the splat operation if some
+ splatted: Option<VectorSize>,
+ },
+}
+
+impl MacroCall {
+ /// Adds the necessary expressions and statements to the passed body and
+ /// finally returns the final expression with the correct result
+ pub fn call(
+ &self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ body: &mut Block,
+ args: &mut [Handle<Expression>],
+ meta: Span,
+ ) -> Result<Option<Handle<Expression>>> {
+ Ok(Some(match *self {
+ MacroCall::Sampler => {
+ ctx.samplers.insert(args[0], args[1]);
+ args[0]
+ }
+ MacroCall::SamplerShadow => {
+ sampled_to_depth(
+ &mut frontend.module,
+ ctx,
+ args[0],
+ meta,
+ &mut frontend.errors,
+ );
+ frontend.invalidate_expression(ctx, args[0], meta)?;
+ ctx.samplers.insert(args[0], args[1]);
+ args[0]
+ }
+ MacroCall::Texture {
+ proj,
+ offset,
+ shadow,
+ level_type,
+ } => {
+ let mut coords = args[1];
+
+ if proj {
+ let size = match *frontend.resolve_type(ctx, coords, meta)? {
+ TypeInner::Vector { size, .. } => size,
+ _ => unreachable!(),
+ };
+ let mut right = ctx.add_expression(
+ Expression::AccessIndex {
+ base: coords,
+ index: size as u32 - 1,
+ },
+ Span::default(),
+ body,
+ );
+ let left = if let VectorSize::Bi = size {
+ ctx.add_expression(
+ Expression::AccessIndex {
+ base: coords,
+ index: 0,
+ },
+ Span::default(),
+ body,
+ )
+ } else {
+ let size = match size {
+ VectorSize::Tri => VectorSize::Bi,
+ _ => VectorSize::Tri,
+ };
+ right = ctx.add_expression(
+ Expression::Splat { size, value: right },
+ Span::default(),
+ body,
+ );
+ ctx.vector_resize(size, coords, Span::default(), body)
+ };
+ coords = ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Divide,
+ left,
+ right,
+ },
+ Span::default(),
+ body,
+ );
+ }
+
+ let extra = args.get(2).copied();
+ let comps =
+ frontend.coordinate_components(ctx, args[0], coords, extra, meta, body)?;
+
+ let mut num_args = 2;
+
+ if comps.used_extra {
+ num_args += 1;
+ };
+
+ // Parse out explicit texture level.
+ let mut level = match level_type {
+ TextureLevelType::None => SampleLevel::Auto,
+
+ TextureLevelType::Lod => {
+ num_args += 1;
+
+ if shadow {
+ log::warn!("Assuming LOD {:?} is zero", args[2],);
+
+ SampleLevel::Zero
+ } else {
+ SampleLevel::Exact(args[2])
+ }
+ }
+
+ TextureLevelType::Grad => {
+ num_args += 2;
+
+ if shadow {
+ log::warn!(
+ "Assuming gradients {:?} and {:?} are not greater than 1",
+ args[2],
+ args[3],
+ );
+ SampleLevel::Zero
+ } else {
+ SampleLevel::Gradient {
+ x: args[2],
+ y: args[3],
+ }
+ }
+ }
+ };
+
+ let texture_offset = match offset {
+ true => {
+ let offset_arg = args[num_args];
+ num_args += 1;
+ match frontend.solve_constant(ctx, offset_arg, meta) {
+ Ok(v) => Some(v),
+ Err(e) => {
+ frontend.errors.push(e);
+ None
+ }
+ }
+ }
+ false => None,
+ };
+
+ // Now go back and look for optional bias arg (if available)
+ if let TextureLevelType::None = level_type {
+ level = args
+ .get(num_args)
+ .copied()
+ .map_or(SampleLevel::Auto, SampleLevel::Bias);
+ }
+
+ texture_call(ctx, args[0], level, comps, texture_offset, body, meta)?
+ }
+
+ MacroCall::TextureSize { arrayed } => {
+ let mut expr = ctx.add_expression(
+ Expression::ImageQuery {
+ image: args[0],
+ query: ImageQuery::Size {
+ level: args.get(1).copied(),
+ },
+ },
+ Span::default(),
+ body,
+ );
+
+ if arrayed {
+ let mut components = Vec::with_capacity(4);
+
+ let size = match *frontend.resolve_type(ctx, expr, meta)? {
+ TypeInner::Vector { size: ori_size, .. } => {
+ for index in 0..(ori_size as u32) {
+ components.push(ctx.add_expression(
+ Expression::AccessIndex { base: expr, index },
+ Span::default(),
+ body,
+ ))
+ }
+
+ match ori_size {
+ VectorSize::Bi => VectorSize::Tri,
+ _ => VectorSize::Quad,
+ }
+ }
+ _ => {
+ components.push(expr);
+ VectorSize::Bi
+ }
+ };
+
+ components.push(ctx.add_expression(
+ Expression::ImageQuery {
+ image: args[0],
+ query: ImageQuery::NumLayers,
+ },
+ Span::default(),
+ body,
+ ));
+
+ let ty = frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ },
+ Span::default(),
+ );
+
+ expr = ctx.add_expression(Expression::Compose { components, ty }, meta, body)
+ }
+
+ ctx.add_expression(
+ Expression::As {
+ expr,
+ kind: Sk::Sint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::ImageLoad { multi } => {
+ let comps =
+ frontend.coordinate_components(ctx, args[0], args[1], None, meta, body)?;
+ let (sample, level) = match (multi, args.get(2)) {
+ (_, None) => (None, None),
+ (true, Some(&arg)) => (Some(arg), None),
+ (false, Some(&arg)) => (None, Some(arg)),
+ };
+ ctx.add_expression(
+ Expression::ImageLoad {
+ image: args[0],
+ coordinate: comps.coordinate,
+ array_index: comps.array_index,
+ sample,
+ level,
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::ImageStore => {
+ let comps =
+ frontend.coordinate_components(ctx, args[0], args[1], None, meta, body)?;
+ ctx.emit_restart(body);
+ body.push(
+ crate::Statement::ImageStore {
+ image: args[0],
+ coordinate: comps.coordinate,
+ array_index: comps.array_index,
+ value: args[2],
+ },
+ meta,
+ );
+ return Ok(None);
+ }
+ MacroCall::MathFunction(fun) => ctx.add_expression(
+ Expression::Math {
+ fun,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: args.get(3).copied(),
+ },
+ Span::default(),
+ body,
+ ),
+ mc @ (MacroCall::FindLsbUint | MacroCall::FindMsbUint) => {
+ let fun = match mc {
+ MacroCall::FindLsbUint => MathFunction::FindLsb,
+ MacroCall::FindMsbUint => MathFunction::FindMsb,
+ _ => unreachable!(),
+ };
+ let res = ctx.add_expression(
+ Expression::Math {
+ fun,
+ arg: args[0],
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ Span::default(),
+ body,
+ );
+ ctx.add_expression(
+ Expression::As {
+ expr: res,
+ kind: Sk::Sint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::BitfieldInsert => {
+ let conv_arg_2 = ctx.add_expression(
+ Expression::As {
+ expr: args[2],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ );
+ let conv_arg_3 = ctx.add_expression(
+ Expression::As {
+ expr: args[3],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ );
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::InsertBits,
+ arg: args[0],
+ arg1: Some(args[1]),
+ arg2: Some(conv_arg_2),
+ arg3: Some(conv_arg_3),
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::BitfieldExtract => {
+ let conv_arg_1 = ctx.add_expression(
+ Expression::As {
+ expr: args[1],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ );
+ let conv_arg_2 = ctx.add_expression(
+ Expression::As {
+ expr: args[2],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ body,
+ );
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::ExtractBits,
+ arg: args[0],
+ arg1: Some(conv_arg_1),
+ arg2: Some(conv_arg_2),
+ arg3: None,
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::Relational(fun) => ctx.add_expression(
+ Expression::Relational {
+ fun,
+ argument: args[0],
+ },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Unary(op) => ctx.add_expression(
+ Expression::Unary { op, expr: args[0] },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Binary(op) => ctx.add_expression(
+ Expression::Binary {
+ op,
+ left: args[0],
+ right: args[1],
+ },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Mod(size) => {
+ ctx.implicit_splat(frontend, &mut args[1], meta, size)?;
+
+ // x - y * floor(x / y)
+
+ let div = ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Divide,
+ left: args[0],
+ right: args[1],
+ },
+ Span::default(),
+ body,
+ );
+ let floor = ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::Floor,
+ arg: div,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ Span::default(),
+ body,
+ );
+ let mult = ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Multiply,
+ left: floor,
+ right: args[1],
+ },
+ Span::default(),
+ body,
+ );
+ ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Subtract,
+ left: args[0],
+ right: mult,
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::Splatted(fun, size, i) => {
+ ctx.implicit_splat(frontend, &mut args[i], meta, size)?;
+
+ ctx.add_expression(
+ Expression::Math {
+ fun,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: args.get(3).copied(),
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::MixBoolean => ctx.add_expression(
+ Expression::Select {
+ condition: args[2],
+ accept: args[1],
+ reject: args[0],
+ },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Clamp(size) => {
+ ctx.implicit_splat(frontend, &mut args[1], meta, size)?;
+ ctx.implicit_splat(frontend, &mut args[2], meta, size)?;
+
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::Clamp,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: args.get(3).copied(),
+ },
+ Span::default(),
+ body,
+ )
+ }
+ MacroCall::BitCast(kind) => ctx.add_expression(
+ Expression::As {
+ expr: args[0],
+ kind,
+ convert: None,
+ },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Derivate(axis, ctrl) => ctx.add_expression(
+ Expression::Derivative {
+ axis,
+ ctrl,
+ expr: args[0],
+ },
+ Span::default(),
+ body,
+ ),
+ MacroCall::Barrier => {
+ ctx.emit_restart(body);
+ body.push(crate::Statement::Barrier(crate::Barrier::all()), meta);
+ return Ok(None);
+ }
+ MacroCall::SmoothStep { splatted } => {
+ ctx.implicit_splat(frontend, &mut args[0], meta, splatted)?;
+ ctx.implicit_splat(frontend, &mut args[1], meta, splatted)?;
+
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::SmoothStep,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: None,
+ },
+ Span::default(),
+ body,
+ )
+ }
+ }))
+ }
+}
+
+fn texture_call(
+ ctx: &mut Context,
+ image: Handle<Expression>,
+ level: SampleLevel,
+ comps: CoordComponents,
+ offset: Option<Handle<Constant>>,
+ body: &mut Block,
+ meta: Span,
+) -> Result<Handle<Expression>> {
+ if let Some(sampler) = ctx.samplers.get(&image).copied() {
+ let mut array_index = comps.array_index;
+
+ if let Some(ref mut array_index_expr) = array_index {
+ ctx.conversion(array_index_expr, meta, Sk::Sint, 4)?;
+ }
+
+ Ok(ctx.add_expression(
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather: None, //TODO
+ coordinate: comps.coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref: comps.depth_ref,
+ },
+ meta,
+ body,
+ ))
+ } else {
+ Err(Error {
+ kind: ErrorKind::SemanticError("Bad call".into()),
+ meta,
+ })
+ }
+}
+
+/// Helper struct for texture calls with the separate components from the vector argument
+///
+/// Obtained by calling [`coordinate_components`](Frontend::coordinate_components)
+#[derive(Debug)]
+struct CoordComponents {
+ coordinate: Handle<Expression>,
+ depth_ref: Option<Handle<Expression>>,
+ array_index: Option<Handle<Expression>>,
+ used_extra: bool,
+}
+
+impl Frontend {
+ /// Helper function for texture calls, splits the vector argument into it's components
+ fn coordinate_components(
+ &mut self,
+ ctx: &mut Context,
+ image: Handle<Expression>,
+ coord: Handle<Expression>,
+ extra: Option<Handle<Expression>>,
+ meta: Span,
+ body: &mut Block,
+ ) -> Result<CoordComponents> {
+ if let TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } = *self.resolve_type(ctx, image, meta)?
+ {
+ let image_size = match dim {
+ Dim::D1 => None,
+ Dim::D2 => Some(VectorSize::Bi),
+ Dim::D3 => Some(VectorSize::Tri),
+ Dim::Cube => Some(VectorSize::Tri),
+ };
+ let coord_size = match *self.resolve_type(ctx, coord, meta)? {
+ TypeInner::Vector { size, .. } => Some(size),
+ _ => None,
+ };
+ let (shadow, storage) = match class {
+ ImageClass::Depth { .. } => (true, false),
+ ImageClass::Storage { .. } => (false, true),
+ ImageClass::Sampled { .. } => (false, false),
+ };
+
+ let coordinate = match (image_size, coord_size) {
+ (Some(size), Some(coord_s)) if size != coord_s => {
+ ctx.vector_resize(size, coord, Span::default(), body)
+ }
+ (None, Some(_)) => ctx.add_expression(
+ Expression::AccessIndex {
+ base: coord,
+ index: 0,
+ },
+ Span::default(),
+ body,
+ ),
+ _ => coord,
+ };
+
+ let mut coord_index = image_size.map_or(1, |s| s as u32);
+
+ let array_index = if arrayed && !(storage && dim == Dim::Cube) {
+ let index = coord_index;
+ coord_index += 1;
+
+ Some(ctx.add_expression(
+ Expression::AccessIndex { base: coord, index },
+ Span::default(),
+ body,
+ ))
+ } else {
+ None
+ };
+ let mut used_extra = false;
+ let depth_ref = match shadow {
+ true => {
+ let index = coord_index;
+
+ if index == 4 {
+ used_extra = true;
+ extra
+ } else {
+ Some(ctx.add_expression(
+ Expression::AccessIndex { base: coord, index },
+ Span::default(),
+ body,
+ ))
+ }
+ }
+ false => None,
+ };
+
+ Ok(CoordComponents {
+ coordinate,
+ depth_ref,
+ array_index,
+ used_extra,
+ })
+ } else {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Type is not an image".into()),
+ meta,
+ });
+
+ Ok(CoordComponents {
+ coordinate: coord,
+ depth_ref: None,
+ array_index: None,
+ used_extra: false,
+ })
+ }
+ }
+}
+
+/// Helper function to cast a expression holding a sampled image to a
+/// depth image.
+pub fn sampled_to_depth(
+ module: &mut Module,
+ ctx: &mut Context,
+ image: Handle<Expression>,
+ meta: Span,
+ errors: &mut Vec<Error>,
+) {
+ // Get the a mutable type handle of the underlying image storage
+ let ty = match ctx[image] {
+ Expression::GlobalVariable(handle) => &mut module.global_variables.get_mut(handle).ty,
+ Expression::FunctionArgument(i) => {
+ // Mark the function argument as carrying a depth texture
+ ctx.parameters_info[i as usize].depth = true;
+ // NOTE: We need to later also change the parameter type
+ &mut ctx.arguments[i as usize].ty
+ }
+ _ => {
+ // Only globals and function arguments are allowed to carry an image
+ return errors.push(Error {
+ kind: ErrorKind::SemanticError("Not a valid texture expression".into()),
+ meta,
+ });
+ }
+ };
+
+ match module.types[*ty].inner {
+ // Update the image class to depth in case it already isn't
+ TypeInner::Image {
+ class,
+ dim,
+ arrayed,
+ } => match class {
+ ImageClass::Sampled { multi, .. } => {
+ *ty = module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Depth { multi },
+ },
+ },
+ Span::default(),
+ )
+ }
+ ImageClass::Depth { .. } => {}
+ // Other image classes aren't allowed to be transformed to depth
+ ImageClass::Storage { .. } => errors.push(Error {
+ kind: ErrorKind::SemanticError("Not a texture".into()),
+ meta,
+ }),
+ },
+ _ => errors.push(Error {
+ kind: ErrorKind::SemanticError("Not a texture".into()),
+ meta,
+ }),
+ };
+
+ // Copy the handle to allow borrowing the `ctx` again
+ let ty = *ty;
+
+ // If the image was passed through a function argument we also need to change
+ // the corresponding parameter
+ if let Expression::FunctionArgument(i) = ctx[image] {
+ ctx.parameters[i as usize] = ty;
+ }
+}
+
+bitflags::bitflags! {
+ /// Influences the operation `texture_args_generator`
+ struct TextureArgsOptions: u32 {
+ /// Generates multisampled variants of images
+ const MULTI = 1 << 0;
+ /// Generates shadow variants of images
+ const SHADOW = 1 << 1;
+ /// Generates standard images
+ const STANDARD = 1 << 2;
+ /// Generates cube arrayed images
+ const CUBE_ARRAY = 1 << 3;
+ /// Generates cube arrayed images
+ const D2_MULTI_ARRAY = 1 << 4;
+ }
+}
+
+impl From<BuiltinVariations> for TextureArgsOptions {
+ fn from(variations: BuiltinVariations) -> Self {
+ let mut options = TextureArgsOptions::empty();
+ if variations.contains(BuiltinVariations::STANDARD) {
+ options |= TextureArgsOptions::STANDARD
+ }
+ if variations.contains(BuiltinVariations::CUBE_TEXTURES_ARRAY) {
+ options |= TextureArgsOptions::CUBE_ARRAY
+ }
+ if variations.contains(BuiltinVariations::D2_MULTI_TEXTURES_ARRAY) {
+ options |= TextureArgsOptions::D2_MULTI_ARRAY
+ }
+ options
+ }
+}
+
+/// Helper function to generate the image components for texture/image builtins
+///
+/// Calls the passed function `f` with:
+/// ```text
+/// f(ScalarKind, ImageDimension, arrayed, multi, shadow)
+/// ```
+///
+/// `options` controls extra image variants generation like multisampling and depth,
+/// see the struct documentation
+fn texture_args_generator(
+ options: TextureArgsOptions,
+ mut f: impl FnMut(crate::ScalarKind, Dim, bool, bool, bool),
+) {
+ for kind in [Sk::Float, Sk::Uint, Sk::Sint].iter().copied() {
+ for dim in [Dim::D1, Dim::D2, Dim::D3, Dim::Cube].iter().copied() {
+ for arrayed in [false, true].iter().copied() {
+ if dim == Dim::Cube && arrayed {
+ if !options.contains(TextureArgsOptions::CUBE_ARRAY) {
+ continue;
+ }
+ } else if Dim::D2 == dim
+ && options.contains(TextureArgsOptions::MULTI)
+ && arrayed
+ && options.contains(TextureArgsOptions::D2_MULTI_ARRAY)
+ {
+ // multisampling for sampler2DMSArray
+ f(kind, dim, arrayed, true, false);
+ } else if !options.contains(TextureArgsOptions::STANDARD) {
+ continue;
+ }
+
+ f(kind, dim, arrayed, false, false);
+
+ // 3D images can't be neither arrayed nor shadow
+ // so we break out early, this way arrayed will always
+ // be false and we won't hit the shadow branch
+ if let Dim::D3 = dim {
+ break;
+ }
+
+ if Dim::D2 == dim && options.contains(TextureArgsOptions::MULTI) && !arrayed {
+ // multisampling
+ f(kind, dim, arrayed, true, false);
+ }
+
+ if Sk::Float == kind && options.contains(TextureArgsOptions::SHADOW) {
+ // shadow
+ f(kind, dim, arrayed, false, true);
+ }
+ }
+ }
+ }
+}
+
+/// Helper functions used to convert from a image dimension into a integer representing the
+/// number of components needed for the coordinates vector (1 means scalar instead of vector)
+const fn image_dims_to_coords_size(dim: Dim) -> usize {
+ match dim {
+ Dim::D1 => 1,
+ Dim::D2 => 2,
+ _ => 3,
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/constants.rs b/third_party/rust/naga/src/front/glsl/constants.rs
new file mode 100644
index 0000000000..045a9c6ffb
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/constants.rs
@@ -0,0 +1,984 @@
+use crate::{
+ arena::{Arena, Handle, UniqueArena},
+ BinaryOperator, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
+ UnaryOperator,
+};
+
+#[derive(Debug)]
+pub struct ConstantSolver<'a> {
+ pub types: &'a mut UniqueArena<Type>,
+ pub expressions: &'a Arena<Expression>,
+ pub constants: &'a mut Arena<Constant>,
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum ConstantSolvingError {
+ #[error("Constants cannot access function arguments")]
+ FunctionArg,
+ #[error("Constants cannot access global variables")]
+ GlobalVariable,
+ #[error("Constants cannot access local variables")]
+ LocalVariable,
+ #[error("Cannot get the array length of a non array type")]
+ InvalidArrayLengthArg,
+ #[error("Constants cannot get the array length of a dynamically sized array")]
+ ArrayLengthDynamic,
+ #[error("Constants cannot call functions")]
+ Call,
+ #[error("Constants don't support atomic functions")]
+ Atomic,
+ #[error("Constants don't support relational functions")]
+ Relational,
+ #[error("Constants don't support derivative functions")]
+ Derivative,
+ #[error("Constants don't support select expressions")]
+ Select,
+ #[error("Constants don't support load expressions")]
+ Load,
+ #[error("Constants don't support image expressions")]
+ ImageExpression,
+ #[error("Constants don't support ray query expressions")]
+ RayQueryExpression,
+ #[error("Cannot access the type")]
+ InvalidAccessBase,
+ #[error("Cannot access at the index")]
+ InvalidAccessIndex,
+ #[error("Cannot access with index of type")]
+ InvalidAccessIndexTy,
+ #[error("Constants don't support bitcasts")]
+ Bitcast,
+ #[error("Cannot cast type")]
+ InvalidCastArg,
+ #[error("Cannot apply the unary op to the argument")]
+ InvalidUnaryOpArg,
+ #[error("Cannot apply the binary op to the arguments")]
+ InvalidBinaryOpArgs,
+ #[error("Cannot apply math function to type")]
+ InvalidMathArg,
+ #[error("Splat is defined only on scalar values")]
+ SplatScalarOnly,
+ #[error("Can only swizzle vector constants")]
+ SwizzleVectorOnly,
+ #[error("Not implemented as constant expression: {0}")]
+ NotImplemented(String),
+}
+
+impl<'a> ConstantSolver<'a> {
+ pub fn solve(
+ &mut self,
+ expr: Handle<Expression>,
+ ) -> Result<Handle<Constant>, ConstantSolvingError> {
+ let span = self.expressions.get_span(expr);
+ match self.expressions[expr] {
+ Expression::Constant(constant) => Ok(constant),
+ Expression::AccessIndex { base, index } => self.access(base, index as usize),
+ Expression::Access { base, index } => {
+ let index = self.solve(index)?;
+
+ self.access(base, self.constant_index(index)?)
+ }
+ Expression::Splat {
+ size,
+ value: splat_value,
+ } => {
+ let value_constant = self.solve(splat_value)?;
+ let ty = match self.constants[value_constant].inner {
+ ConstantInner::Scalar { ref value, width } => {
+ let kind = value.scalar_kind();
+ self.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector { size, kind, width },
+ },
+ span,
+ )
+ }
+ ConstantInner::Composite { .. } => {
+ return Err(ConstantSolvingError::SplatScalarOnly);
+ }
+ };
+
+ let inner = ConstantInner::Composite {
+ ty,
+ components: vec![value_constant; size as usize],
+ };
+ Ok(self.register_constant(inner, span))
+ }
+ Expression::Swizzle {
+ size,
+ vector: src_vector,
+ pattern,
+ } => {
+ let src_constant = self.solve(src_vector)?;
+ let (ty, src_components) = match self.constants[src_constant].inner {
+ ConstantInner::Scalar { .. } => {
+ return Err(ConstantSolvingError::SwizzleVectorOnly);
+ }
+ ConstantInner::Composite {
+ ty,
+ components: ref src_components,
+ } => match self.types[ty].inner {
+ crate::TypeInner::Vector {
+ size: _,
+ kind,
+ width,
+ } => {
+ let dst_ty = self.types.insert(
+ Type {
+ name: None,
+ inner: crate::TypeInner::Vector { size, kind, width },
+ },
+ span,
+ );
+ (dst_ty, &src_components[..])
+ }
+ _ => {
+ return Err(ConstantSolvingError::SwizzleVectorOnly);
+ }
+ },
+ };
+
+ let components = pattern
+ .iter()
+ .map(|&sc| src_components[sc as usize])
+ .collect();
+ let inner = ConstantInner::Composite { ty, components };
+
+ Ok(self.register_constant(inner, span))
+ }
+ Expression::Compose { ty, ref components } => {
+ let components = components
+ .iter()
+ .map(|c| self.solve(*c))
+ .collect::<Result<_, _>>()?;
+ let inner = ConstantInner::Composite { ty, components };
+
+ Ok(self.register_constant(inner, span))
+ }
+ Expression::Unary { expr, op } => {
+ let expr_constant = self.solve(expr)?;
+
+ self.unary_op(op, expr_constant, span)
+ }
+ Expression::Binary { left, right, op } => {
+ let left_constant = self.solve(left)?;
+ let right_constant = self.solve(right)?;
+
+ self.binary_op(op, left_constant, right_constant, span)
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ ..
+ } => {
+ let arg = self.solve(arg)?;
+ let arg1 = arg1.map(|arg| self.solve(arg)).transpose()?;
+ let arg2 = arg2.map(|arg| self.solve(arg)).transpose()?;
+
+ let const0 = &self.constants[arg].inner;
+ let const1 = arg1.map(|arg| &self.constants[arg].inner);
+ let const2 = arg2.map(|arg| &self.constants[arg].inner);
+
+ match fun {
+ crate::MathFunction::Pow => {
+ let (value, width) = match (const0, const1.unwrap()) {
+ (
+ &ConstantInner::Scalar {
+ width,
+ value: value0,
+ },
+ &ConstantInner::Scalar { value: value1, .. },
+ ) => (
+ match (value0, value1) {
+ (ScalarValue::Sint(a), ScalarValue::Sint(b)) => {
+ ScalarValue::Sint(a.pow(b as u32))
+ }
+ (ScalarValue::Uint(a), ScalarValue::Uint(b)) => {
+ ScalarValue::Uint(a.pow(b as u32))
+ }
+ (ScalarValue::Float(a), ScalarValue::Float(b)) => {
+ ScalarValue::Float(a.powf(b))
+ }
+ _ => return Err(ConstantSolvingError::InvalidMathArg),
+ },
+ width,
+ ),
+ _ => return Err(ConstantSolvingError::InvalidMathArg),
+ };
+
+ let inner = ConstantInner::Scalar { width, value };
+ Ok(self.register_constant(inner, span))
+ }
+ crate::MathFunction::Clamp => {
+ let (value, width) = match (const0, const1.unwrap(), const2.unwrap()) {
+ (
+ &ConstantInner::Scalar {
+ width,
+ value: value0,
+ },
+ &ConstantInner::Scalar { value: value1, .. },
+ &ConstantInner::Scalar { value: value2, .. },
+ ) => (
+ match (value0, value1, value2) {
+ (
+ ScalarValue::Sint(a),
+ ScalarValue::Sint(b),
+ ScalarValue::Sint(c),
+ ) => ScalarValue::Sint(a.clamp(b, c)),
+ (
+ ScalarValue::Uint(a),
+ ScalarValue::Uint(b),
+ ScalarValue::Uint(c),
+ ) => ScalarValue::Uint(a.clamp(b, c)),
+ (
+ ScalarValue::Float(a),
+ ScalarValue::Float(b),
+ ScalarValue::Float(c),
+ ) => ScalarValue::Float(glsl_float_clamp(a, b, c)),
+ _ => return Err(ConstantSolvingError::InvalidMathArg),
+ },
+ width,
+ ),
+ _ => {
+ return Err(ConstantSolvingError::NotImplemented(format!(
+ "{fun:?} applied to vector values"
+ )))
+ }
+ };
+
+ let inner = ConstantInner::Scalar { width, value };
+ Ok(self.register_constant(inner, span))
+ }
+ _ => Err(ConstantSolvingError::NotImplemented(format!("{fun:?}"))),
+ }
+ }
+ Expression::As {
+ convert,
+ expr,
+ kind,
+ } => {
+ let expr_constant = self.solve(expr)?;
+
+ match convert {
+ Some(width) => self.cast(expr_constant, kind, width, span),
+ None => Err(ConstantSolvingError::Bitcast),
+ }
+ }
+ Expression::ArrayLength(expr) => {
+ let array = self.solve(expr)?;
+
+ match self.constants[array].inner {
+ ConstantInner::Scalar { .. } => {
+ Err(ConstantSolvingError::InvalidArrayLengthArg)
+ }
+ ConstantInner::Composite { ty, .. } => match self.types[ty].inner {
+ TypeInner::Array { size, .. } => match size {
+ crate::ArraySize::Constant(constant) => Ok(constant),
+ crate::ArraySize::Dynamic => {
+ Err(ConstantSolvingError::ArrayLengthDynamic)
+ }
+ },
+ _ => Err(ConstantSolvingError::InvalidArrayLengthArg),
+ },
+ }
+ }
+
+ Expression::Load { .. } => Err(ConstantSolvingError::Load),
+ Expression::Select { .. } => Err(ConstantSolvingError::Select),
+ Expression::LocalVariable(_) => Err(ConstantSolvingError::LocalVariable),
+ Expression::Derivative { .. } => Err(ConstantSolvingError::Derivative),
+ Expression::Relational { .. } => Err(ConstantSolvingError::Relational),
+ Expression::CallResult { .. } => Err(ConstantSolvingError::Call),
+ Expression::AtomicResult { .. } => Err(ConstantSolvingError::Atomic),
+ Expression::FunctionArgument(_) => Err(ConstantSolvingError::FunctionArg),
+ Expression::GlobalVariable(_) => Err(ConstantSolvingError::GlobalVariable),
+ Expression::ImageSample { .. }
+ | Expression::ImageLoad { .. }
+ | Expression::ImageQuery { .. } => Err(ConstantSolvingError::ImageExpression),
+ Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
+ Err(ConstantSolvingError::RayQueryExpression)
+ }
+ }
+ }
+
+ fn access(
+ &mut self,
+ base: Handle<Expression>,
+ index: usize,
+ ) -> Result<Handle<Constant>, ConstantSolvingError> {
+ let base = self.solve(base)?;
+
+ match self.constants[base].inner {
+ ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase),
+ ConstantInner::Composite { ty, ref components } => {
+ match self.types[ty].inner {
+ TypeInner::Vector { .. }
+ | TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::Struct { .. } => (),
+ _ => return Err(ConstantSolvingError::InvalidAccessBase),
+ }
+
+ components
+ .get(index)
+ .copied()
+ .ok_or(ConstantSolvingError::InvalidAccessIndex)
+ }
+ }
+ }
+
+ fn constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError> {
+ match self.constants[constant].inner {
+ ConstantInner::Scalar {
+ value: ScalarValue::Uint(index),
+ ..
+ } => Ok(index as usize),
+ _ => Err(ConstantSolvingError::InvalidAccessIndexTy),
+ }
+ }
+
+ fn cast(
+ &mut self,
+ constant: Handle<Constant>,
+ kind: ScalarKind,
+ target_width: crate::Bytes,
+ span: crate::Span,
+ ) -> Result<Handle<Constant>, ConstantSolvingError> {
+ let mut inner = self.constants[constant].inner.clone();
+
+ match inner {
+ ConstantInner::Scalar {
+ ref mut value,
+ ref mut width,
+ } => {
+ *width = target_width;
+ *value = match kind {
+ ScalarKind::Sint => ScalarValue::Sint(match *value {
+ ScalarValue::Sint(v) => v,
+ ScalarValue::Uint(v) => v as i64,
+ ScalarValue::Float(v) => v as i64,
+ ScalarValue::Bool(v) => v as i64,
+ }),
+ ScalarKind::Uint => ScalarValue::Uint(match *value {
+ ScalarValue::Sint(v) => v as u64,
+ ScalarValue::Uint(v) => v,
+ ScalarValue::Float(v) => v as u64,
+ ScalarValue::Bool(v) => v as u64,
+ }),
+ ScalarKind::Float => ScalarValue::Float(match *value {
+ ScalarValue::Sint(v) => v as f64,
+ ScalarValue::Uint(v) => v as f64,
+ ScalarValue::Float(v) => v,
+ ScalarValue::Bool(v) => v as u64 as f64,
+ }),
+ ScalarKind::Bool => ScalarValue::Bool(match *value {
+ ScalarValue::Sint(v) => v != 0,
+ ScalarValue::Uint(v) => v != 0,
+ ScalarValue::Float(v) => v != 0.0,
+ ScalarValue::Bool(v) => v,
+ }),
+ }
+ }
+ ConstantInner::Composite {
+ ty,
+ ref mut components,
+ } => {
+ match self.types[ty].inner {
+ TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
+ _ => return Err(ConstantSolvingError::InvalidCastArg),
+ }
+
+ for component in components {
+ *component = self.cast(*component, kind, target_width, span)?;
+ }
+ }
+ }
+
+ Ok(self.register_constant(inner, span))
+ }
+
+ fn unary_op(
+ &mut self,
+ op: UnaryOperator,
+ constant: Handle<Constant>,
+ span: crate::Span,
+ ) -> Result<Handle<Constant>, ConstantSolvingError> {
+ let mut inner = self.constants[constant].inner.clone();
+
+ match inner {
+ ConstantInner::Scalar { ref mut value, .. } => match op {
+ UnaryOperator::Negate => match *value {
+ ScalarValue::Sint(ref mut v) => *v = -*v,
+ ScalarValue::Float(ref mut v) => *v = -*v,
+ _ => return Err(ConstantSolvingError::InvalidUnaryOpArg),
+ },
+ UnaryOperator::Not => match *value {
+ ScalarValue::Sint(ref mut v) => *v = !*v,
+ ScalarValue::Uint(ref mut v) => *v = !*v,
+ ScalarValue::Bool(ref mut v) => *v = !*v,
+ ScalarValue::Float(_) => return Err(ConstantSolvingError::InvalidUnaryOpArg),
+ },
+ },
+ ConstantInner::Composite {
+ ty,
+ ref mut components,
+ } => {
+ match self.types[ty].inner {
+ TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
+ _ => return Err(ConstantSolvingError::InvalidCastArg),
+ }
+
+ for component in components {
+ *component = self.unary_op(op, *component, span)?
+ }
+ }
+ }
+
+ Ok(self.register_constant(inner, span))
+ }
+
+ fn binary_op(
+ &mut self,
+ op: BinaryOperator,
+ left: Handle<Constant>,
+ right: Handle<Constant>,
+ span: crate::Span,
+ ) -> Result<Handle<Constant>, ConstantSolvingError> {
+ let left_inner = &self.constants[left].inner;
+ let right_inner = &self.constants[right].inner;
+
+ let inner = match (left_inner, right_inner) {
+ (
+ &ConstantInner::Scalar {
+ value: left_value,
+ width,
+ },
+ &ConstantInner::Scalar {
+ value: right_value,
+ width: _,
+ },
+ ) => {
+ let value = match op {
+ BinaryOperator::Equal => ScalarValue::Bool(left_value == right_value),
+ BinaryOperator::NotEqual => ScalarValue::Bool(left_value != right_value),
+ BinaryOperator::Less => ScalarValue::Bool(left_value < right_value),
+ BinaryOperator::LessEqual => ScalarValue::Bool(left_value <= right_value),
+ BinaryOperator::Greater => ScalarValue::Bool(left_value > right_value),
+ BinaryOperator::GreaterEqual => ScalarValue::Bool(left_value >= right_value),
+
+ _ => match (left_value, right_value) {
+ (ScalarValue::Sint(a), ScalarValue::Sint(b)) => {
+ ScalarValue::Sint(match op {
+ BinaryOperator::Add => a.wrapping_add(b),
+ BinaryOperator::Subtract => a.wrapping_sub(b),
+ BinaryOperator::Multiply => a.wrapping_mul(b),
+ BinaryOperator::Divide => a.checked_div(b).unwrap_or(0),
+ BinaryOperator::Modulo => a.checked_rem(b).unwrap_or(0),
+ BinaryOperator::And => a & b,
+ BinaryOperator::ExclusiveOr => a ^ b,
+ BinaryOperator::InclusiveOr => a | b,
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ })
+ }
+ (ScalarValue::Sint(a), ScalarValue::Uint(b)) => {
+ ScalarValue::Sint(match op {
+ BinaryOperator::ShiftLeft => a.wrapping_shl(b as u32),
+ BinaryOperator::ShiftRight => a.wrapping_shr(b as u32),
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ })
+ }
+ (ScalarValue::Uint(a), ScalarValue::Uint(b)) => {
+ ScalarValue::Uint(match op {
+ BinaryOperator::Add => a.wrapping_add(b),
+ BinaryOperator::Subtract => a.wrapping_sub(b),
+ BinaryOperator::Multiply => a.wrapping_mul(b),
+ BinaryOperator::Divide => a.checked_div(b).unwrap_or(0),
+ BinaryOperator::Modulo => a.checked_rem(b).unwrap_or(0),
+ BinaryOperator::And => a & b,
+ BinaryOperator::ExclusiveOr => a ^ b,
+ BinaryOperator::InclusiveOr => a | b,
+ BinaryOperator::ShiftLeft => a.wrapping_shl(b as u32),
+ BinaryOperator::ShiftRight => a.wrapping_shr(b as u32),
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ })
+ }
+ (ScalarValue::Float(a), ScalarValue::Float(b)) => {
+ ScalarValue::Float(match op {
+ BinaryOperator::Add => a + b,
+ BinaryOperator::Subtract => a - b,
+ BinaryOperator::Multiply => a * b,
+ BinaryOperator::Divide => a / b,
+ BinaryOperator::Modulo => a - b * (a / b).floor(),
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ })
+ }
+ (ScalarValue::Bool(a), ScalarValue::Bool(b)) => {
+ ScalarValue::Bool(match op {
+ BinaryOperator::LogicalAnd => a && b,
+ BinaryOperator::LogicalOr => a || b,
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ })
+ }
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ },
+ };
+
+ ConstantInner::Scalar { value, width }
+ }
+ (&ConstantInner::Composite { ref components, ty }, &ConstantInner::Scalar { .. }) => {
+ let mut components = components.clone();
+ for comp in components.iter_mut() {
+ *comp = self.binary_op(op, *comp, right, span)?;
+ }
+ ConstantInner::Composite { ty, components }
+ }
+ (&ConstantInner::Scalar { .. }, &ConstantInner::Composite { ref components, ty }) => {
+ let mut components = components.clone();
+ for comp in components.iter_mut() {
+ *comp = self.binary_op(op, left, *comp, span)?;
+ }
+ ConstantInner::Composite { ty, components }
+ }
+ _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
+ };
+
+ Ok(self.register_constant(inner, span))
+ }
+
+ fn register_constant(&mut self, inner: ConstantInner, span: crate::Span) -> Handle<Constant> {
+ self.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ span,
+ )
+ }
+}
+
+/// Helper function to implement the GLSL `max` function for floats.
+///
+/// While Rust does provide a `f64::max` method, it has a different behavior than the
+/// GLSL `max` for NaNs. In Rust, if any of the arguments is a NaN, then the other
+/// is returned.
+///
+/// This leads to different results in the following example
+/// ```
+/// use std::cmp::max;
+/// std::f64::NAN.max(1.0);
+/// ```
+///
+/// Rust will return `1.0` while GLSL should return NaN.
+fn glsl_float_max(x: f64, y: f64) -> f64 {
+ if x < y {
+ y
+ } else {
+ x
+ }
+}
+
+/// Helper function to implement the GLSL `min` function for floats.
+///
+/// While Rust does provide a `f64::min` method, it has a different behavior than the
+/// GLSL `min` for NaNs. In Rust, if any of the arguments is a NaN, then the other
+/// is returned.
+///
+/// This leads to different results in the following example
+/// ```
+/// use std::cmp::min;
+/// std::f64::NAN.min(1.0);
+/// ```
+///
+/// Rust will return `1.0` while GLSL should return NaN.
+fn glsl_float_min(x: f64, y: f64) -> f64 {
+ if y < x {
+ y
+ } else {
+ x
+ }
+}
+
+/// Helper function to implement the GLSL `clamp` function for floats.
+///
+/// While Rust does provide a `f64::clamp` method, it panics if either
+/// `min` or `max` are `NaN`s which is not the behavior specified by
+/// the glsl specification.
+fn glsl_float_clamp(value: f64, min: f64, max: f64) -> f64 {
+ glsl_float_min(glsl_float_max(value, min), max)
+}
+
+#[cfg(test)]
+mod tests {
+ use std::vec;
+
+ use crate::{
+ Arena, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
+ UnaryOperator, UniqueArena, VectorSize,
+ };
+
+ use super::ConstantSolver;
+
+ #[test]
+ fn nan_handling() {
+ assert!(super::glsl_float_max(f64::NAN, 2.0).is_nan());
+ assert!(!super::glsl_float_max(2.0, f64::NAN).is_nan());
+
+ assert!(super::glsl_float_min(f64::NAN, 2.0).is_nan());
+ assert!(!super::glsl_float_min(2.0, f64::NAN).is_nan());
+
+ assert!(super::glsl_float_clamp(f64::NAN, 1.0, 2.0).is_nan());
+ assert!(!super::glsl_float_clamp(1.0, f64::NAN, 2.0).is_nan());
+ assert!(!super::glsl_float_clamp(1.0, 2.0, f64::NAN).is_nan());
+ }
+
+ #[test]
+ fn unary_op() {
+ let mut types = UniqueArena::new();
+ let mut expressions = Arena::new();
+ let mut constants = Arena::new();
+
+ let vec_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: VectorSize::Bi,
+ kind: ScalarKind::Sint,
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+
+ let h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(4),
+ },
+ },
+ Default::default(),
+ );
+
+ let h1 = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(8),
+ },
+ },
+ Default::default(),
+ );
+
+ let vec_h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Composite {
+ ty: vec_ty,
+ components: vec![h, h1],
+ },
+ },
+ Default::default(),
+ );
+
+ let expr = expressions.append(Expression::Constant(h), Default::default());
+ let expr1 = expressions.append(Expression::Constant(vec_h), Default::default());
+
+ let root1 = expressions.append(
+ Expression::Unary {
+ op: UnaryOperator::Negate,
+ expr,
+ },
+ Default::default(),
+ );
+
+ let root2 = expressions.append(
+ Expression::Unary {
+ op: UnaryOperator::Not,
+ expr,
+ },
+ Default::default(),
+ );
+
+ let root3 = expressions.append(
+ Expression::Unary {
+ op: UnaryOperator::Not,
+ expr: expr1,
+ },
+ Default::default(),
+ );
+
+ let mut solver = ConstantSolver {
+ types: &mut types,
+ expressions: &expressions,
+ constants: &mut constants,
+ };
+
+ let res1 = solver.solve(root1).unwrap();
+ let res2 = solver.solve(root2).unwrap();
+ let res3 = solver.solve(root3).unwrap();
+
+ assert_eq!(
+ constants[res1].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(-4),
+ },
+ );
+
+ assert_eq!(
+ constants[res2].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(!4),
+ },
+ );
+
+ let res3_inner = &constants[res3].inner;
+
+ match *res3_inner {
+ ConstantInner::Composite {
+ ref ty,
+ ref components,
+ } => {
+ assert_eq!(*ty, vec_ty);
+ let mut components_iter = components.iter().copied();
+ assert_eq!(
+ constants[components_iter.next().unwrap()].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(!4),
+ },
+ );
+ assert_eq!(
+ constants[components_iter.next().unwrap()].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(!8),
+ },
+ );
+ assert!(components_iter.next().is_none());
+ }
+ _ => panic!("Expected vector"),
+ }
+ }
+
+ #[test]
+ fn cast() {
+ let mut expressions = Arena::new();
+ let mut constants = Arena::new();
+
+ let h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Sint(4),
+ },
+ },
+ Default::default(),
+ );
+
+ let expr = expressions.append(Expression::Constant(h), Default::default());
+
+ let root = expressions.append(
+ Expression::As {
+ expr,
+ kind: ScalarKind::Bool,
+ convert: Some(crate::BOOL_WIDTH),
+ },
+ Default::default(),
+ );
+
+ let mut solver = ConstantSolver {
+ types: &mut UniqueArena::new(),
+ expressions: &expressions,
+ constants: &mut constants,
+ };
+
+ let res = solver.solve(root).unwrap();
+
+ assert_eq!(
+ constants[res].inner,
+ ConstantInner::Scalar {
+ width: crate::BOOL_WIDTH,
+ value: ScalarValue::Bool(true),
+ },
+ );
+ }
+
+ #[test]
+ fn access() {
+ let mut types = UniqueArena::new();
+ let mut expressions = Arena::new();
+ let mut constants = Arena::new();
+
+ let matrix_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns: VectorSize::Bi,
+ rows: VectorSize::Tri,
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+
+ let vec_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: VectorSize::Tri,
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+
+ let mut vec1_components = Vec::with_capacity(3);
+ let mut vec2_components = Vec::with_capacity(3);
+
+ for i in 0..3 {
+ let h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(i as f64),
+ },
+ },
+ Default::default(),
+ );
+
+ vec1_components.push(h)
+ }
+
+ for i in 3..6 {
+ let h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(i as f64),
+ },
+ },
+ Default::default(),
+ );
+
+ vec2_components.push(h)
+ }
+
+ let vec1 = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Composite {
+ ty: vec_ty,
+ components: vec1_components,
+ },
+ },
+ Default::default(),
+ );
+
+ let vec2 = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Composite {
+ ty: vec_ty,
+ components: vec2_components,
+ },
+ },
+ Default::default(),
+ );
+
+ let h = constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Composite {
+ ty: matrix_ty,
+ components: vec![vec1, vec2],
+ },
+ },
+ Default::default(),
+ );
+
+ let base = expressions.append(Expression::Constant(h), Default::default());
+ let root1 = expressions.append(
+ Expression::AccessIndex { base, index: 1 },
+ Default::default(),
+ );
+ let root2 = expressions.append(
+ Expression::AccessIndex {
+ base: root1,
+ index: 2,
+ },
+ Default::default(),
+ );
+
+ let mut solver = ConstantSolver {
+ types: &mut types,
+ expressions: &expressions,
+ constants: &mut constants,
+ };
+
+ let res1 = solver.solve(root1).unwrap();
+ let res2 = solver.solve(root2).unwrap();
+
+ let res1_inner = &constants[res1].inner;
+
+ match *res1_inner {
+ ConstantInner::Composite {
+ ref ty,
+ ref components,
+ } => {
+ assert_eq!(*ty, vec_ty);
+ let mut components_iter = components.iter().copied();
+ assert_eq!(
+ constants[components_iter.next().unwrap()].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(3.),
+ },
+ );
+ assert_eq!(
+ constants[components_iter.next().unwrap()].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(4.),
+ },
+ );
+ assert_eq!(
+ constants[components_iter.next().unwrap()].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(5.),
+ },
+ );
+ assert!(components_iter.next().is_none());
+ }
+ _ => panic!("Expected vector"),
+ }
+
+ assert_eq!(
+ constants[res2].inner,
+ ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(5.),
+ },
+ );
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/context.rs b/third_party/rust/naga/src/front/glsl/context.rs
new file mode 100644
index 0000000000..6df4850efa
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/context.rs
@@ -0,0 +1,1609 @@
+use super::{
+ ast::{
+ GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier,
+ VariableReference,
+ },
+ error::{Error, ErrorKind},
+ types::{scalar_components, type_power},
+ Frontend, Result,
+};
+use crate::{
+ front::{Emitter, Typifier},
+ AddressSpace, Arena, BinaryOperator, Block, Constant, Expression, FastHashMap,
+ FunctionArgument, Handle, LocalVariable, RelationalFunction, ScalarKind, ScalarValue, Span,
+ Statement, Type, TypeInner, VectorSize,
+};
+use std::{convert::TryFrom, ops::Index};
+
+/// The position at which an expression is, used while lowering
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub enum ExprPos {
+ /// The expression is in the left hand side of an assignment
+ Lhs,
+ /// The expression is in the right hand side of an assignment
+ Rhs,
+ /// The expression is an array being indexed, needed to allow constant
+ /// arrays to be dynamically indexed
+ AccessBase {
+ /// The index is a constant
+ constant_index: bool,
+ },
+}
+
+impl ExprPos {
+ /// Returns an lhs position if the current position is lhs otherwise AccessBase
+ const fn maybe_access_base(&self, constant_index: bool) -> Self {
+ match *self {
+ ExprPos::Lhs
+ | ExprPos::AccessBase {
+ constant_index: false,
+ } => *self,
+ _ => ExprPos::AccessBase { constant_index },
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Context {
+ pub expressions: Arena<Expression>,
+ pub locals: Arena<LocalVariable>,
+
+ /// The [`FunctionArgument`]s for the final [`crate::Function`].
+ ///
+ /// Parameters with the `out` and `inout` qualifiers have [`Pointer`] types
+ /// here. For example, an `inout vec2 a` argument would be a [`Pointer`] to
+ /// a [`Vector`].
+ ///
+ /// [`Pointer`]: crate::TypeInner::Pointer
+ /// [`Vector`]: crate::TypeInner::Vector
+ pub arguments: Vec<FunctionArgument>,
+
+ /// The parameter types given in the source code.
+ ///
+ /// The `out` and `inout` qualifiers don't affect the types that appear
+ /// here. For example, an `inout vec2 a` argument would simply be a
+ /// [`Vector`], not a pointer to one.
+ ///
+ /// [`Vector`]: crate::TypeInner::Vector
+ pub parameters: Vec<Handle<Type>>,
+ pub parameters_info: Vec<ParameterInfo>,
+
+ pub symbol_table: crate::front::SymbolTable<String, VariableReference>,
+ pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,
+
+ pub typifier: Typifier,
+ emitter: Emitter,
+ stmt_ctx: Option<StmtContext>,
+}
+
+impl Context {
+ pub fn new(frontend: &Frontend, body: &mut Block) -> Self {
+ let mut this = Context {
+ expressions: Arena::new(),
+ locals: Arena::new(),
+ arguments: Vec::new(),
+
+ parameters: Vec::new(),
+ parameters_info: Vec::new(),
+
+ symbol_table: crate::front::SymbolTable::default(),
+ samplers: FastHashMap::default(),
+
+ typifier: Typifier::new(),
+ emitter: Emitter::default(),
+ stmt_ctx: Some(StmtContext::new()),
+ };
+
+ this.emit_start();
+
+ for &(ref name, lookup) in frontend.global_variables.iter() {
+ this.add_global(frontend, name, lookup, body)
+ }
+
+ this
+ }
+
+ pub fn add_global(
+ &mut self,
+ frontend: &Frontend,
+ name: &str,
+ GlobalLookup {
+ kind,
+ entry_arg,
+ mutable,
+ }: GlobalLookup,
+ body: &mut Block,
+ ) {
+ self.emit_end(body);
+ let (expr, load, constant) = match kind {
+ GlobalLookupKind::Variable(v) => {
+ let span = frontend.module.global_variables.get_span(v);
+ let res = (
+ self.expressions.append(Expression::GlobalVariable(v), span),
+ frontend.module.global_variables[v].space != AddressSpace::Handle,
+ None,
+ );
+ self.emit_start();
+
+ res
+ }
+ GlobalLookupKind::BlockSelect(handle, index) => {
+ let span = frontend.module.global_variables.get_span(handle);
+ let base = self
+ .expressions
+ .append(Expression::GlobalVariable(handle), span);
+ self.emit_start();
+ let expr = self
+ .expressions
+ .append(Expression::AccessIndex { base, index }, span);
+
+ (
+ expr,
+ {
+ let ty = frontend.module.global_variables[handle].ty;
+
+ match frontend.module.types[ty].inner {
+ TypeInner::Struct { ref members, .. } => {
+ if let TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = frontend.module.types[members[index as usize].ty].inner
+ {
+ false
+ } else {
+ true
+ }
+ }
+ _ => true,
+ }
+ },
+ None,
+ )
+ }
+ GlobalLookupKind::Constant(v, ty) => {
+ let span = frontend.module.constants.get_span(v);
+ let res = (
+ self.expressions.append(Expression::Constant(v), span),
+ false,
+ Some((v, ty)),
+ );
+ self.emit_start();
+ res
+ }
+ };
+
+ let var = VariableReference {
+ expr,
+ load,
+ mutable,
+ constant,
+ entry_arg,
+ };
+
+ self.symbol_table.add(name.into(), var);
+ }
+
+ /// Starts the expression emitter
+ ///
+ /// # Panics
+ ///
+ /// - If called twice in a row without calling [`emit_end`][Self::emit_end].
+ #[inline]
+ pub fn emit_start(&mut self) {
+ self.emitter.start(&self.expressions)
+ }
+
+ /// Emits all the expressions captured by the emitter to the passed `body`
+ ///
+ /// # Panics
+ ///
+ /// - If called before calling [`emit_start`].
+ /// - If called twice in a row without calling [`emit_start`].
+ ///
+ /// [`emit_start`]: Self::emit_start
+ pub fn emit_end(&mut self, body: &mut Block) {
+ body.extend(self.emitter.finish(&self.expressions))
+ }
+
+ /// Emits all the expressions captured by the emitter to the passed `body`
+ /// and starts the emitter again
+ ///
+ /// # Panics
+ ///
+ /// - If called before calling [`emit_start`][Self::emit_start].
+ pub fn emit_restart(&mut self, body: &mut Block) {
+ self.emit_end(body);
+ self.emit_start()
+ }
+
+ pub fn add_expression(
+ &mut self,
+ expr: Expression,
+ meta: Span,
+ body: &mut Block,
+ ) -> Handle<Expression> {
+ let needs_pre_emit = expr.needs_pre_emit();
+ if needs_pre_emit {
+ self.emit_end(body);
+ }
+ let handle = self.expressions.append(expr, meta);
+ if needs_pre_emit {
+ self.emit_start();
+ }
+ handle
+ }
+
+ /// Add variable to current scope
+ ///
+ /// Returns a variable if a variable with the same name was already defined,
+ /// otherwise returns `None`
+ pub fn add_local_var(
+ &mut self,
+ name: String,
+ expr: Handle<Expression>,
+ mutable: bool,
+ ) -> Option<VariableReference> {
+ let var = VariableReference {
+ expr,
+ load: true,
+ mutable,
+ constant: None,
+ entry_arg: None,
+ };
+
+ self.symbol_table.add(name, var)
+ }
+
+ /// Add function argument to current scope
+ pub fn add_function_arg(
+ &mut self,
+ frontend: &mut Frontend,
+ body: &mut Block,
+ name_meta: Option<(String, Span)>,
+ ty: Handle<Type>,
+ qualifier: ParameterQualifier,
+ ) {
+ let index = self.arguments.len();
+ let mut arg = FunctionArgument {
+ name: name_meta.as_ref().map(|&(ref name, _)| name.clone()),
+ ty,
+ binding: None,
+ };
+ self.parameters.push(ty);
+
+ let opaque = match frontend.module.types[ty].inner {
+ TypeInner::Image { .. } | TypeInner::Sampler { .. } => true,
+ _ => false,
+ };
+
+ if qualifier.is_lhs() {
+ let span = frontend.module.types.get_span(arg.ty);
+ arg.ty = frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Pointer {
+ base: arg.ty,
+ space: AddressSpace::Function,
+ },
+ },
+ span,
+ )
+ }
+
+ self.arguments.push(arg);
+
+ self.parameters_info.push(ParameterInfo {
+ qualifier,
+ depth: false,
+ });
+
+ if let Some((name, meta)) = name_meta {
+ let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta, body);
+ let mutable = qualifier != ParameterQualifier::Const && !opaque;
+ let load = qualifier.is_lhs();
+
+ let var = if mutable && !load {
+ let handle = self.locals.append(
+ LocalVariable {
+ name: Some(name.clone()),
+ ty,
+ init: None,
+ },
+ meta,
+ );
+ let local_expr = self.add_expression(Expression::LocalVariable(handle), meta, body);
+
+ self.emit_restart(body);
+
+ body.push(
+ Statement::Store {
+ pointer: local_expr,
+ value: expr,
+ },
+ meta,
+ );
+
+ VariableReference {
+ expr: local_expr,
+ load: true,
+ mutable,
+ constant: None,
+ entry_arg: None,
+ }
+ } else {
+ VariableReference {
+ expr,
+ load,
+ mutable,
+ constant: None,
+ entry_arg: None,
+ }
+ };
+
+ self.symbol_table.add(name, var);
+ }
+ }
+
+ /// Returns a [`StmtContext`](StmtContext) to be used in parsing and lowering
+ ///
+ /// # Panics
+ /// - If more than one [`StmtContext`](StmtContext) are active at the same
+ /// time or if the previous call didn't use it in lowering.
+ #[must_use]
+ pub fn stmt_ctx(&mut self) -> StmtContext {
+ self.stmt_ctx.take().unwrap()
+ }
+
+ /// Lowers a [`HirExpr`](HirExpr) which might produce a [`Expression`](Expression).
+ ///
+ /// consumes a [`StmtContext`](StmtContext) returning it to the context so
+ /// that it can be used again later.
+ pub fn lower(
+ &mut self,
+ mut stmt: StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ body: &mut Block,
+ ) -> Result<(Option<Handle<Expression>>, Span)> {
+ let res = self.lower_inner(&stmt, frontend, expr, pos, body);
+
+ stmt.hir_exprs.clear();
+ self.stmt_ctx = Some(stmt);
+
+ res
+ }
+
+ /// Similar to [`lower`](Self::lower) but returns an error if the expression
+ /// returns void (ie. doesn't produce a [`Expression`](Expression)).
+ ///
+ /// consumes a [`StmtContext`](StmtContext) returning it to the context so
+ /// that it can be used again later.
+ pub fn lower_expect(
+ &mut self,
+ mut stmt: StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ body: &mut Block,
+ ) -> Result<(Handle<Expression>, Span)> {
+ let res = self.lower_expect_inner(&stmt, frontend, expr, pos, body);
+
+ stmt.hir_exprs.clear();
+ self.stmt_ctx = Some(stmt);
+
+ res
+ }
+
+ /// internal implementation of [`lower_expect`](Self::lower_expect)
+ ///
+ /// this method is only public because it's used in
+ /// [`function_call`](Frontend::function_call), unless you know what
+ /// you're doing use [`lower_expect`](Self::lower_expect)
+ pub fn lower_expect_inner(
+ &mut self,
+ stmt: &StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ body: &mut Block,
+ ) -> Result<(Handle<Expression>, Span)> {
+ let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos, body)?;
+
+ let expr = match maybe_expr {
+ Some(e) => e,
+ None => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError("Expression returns void".into()),
+ meta,
+ })
+ }
+ };
+
+ Ok((expr, meta))
+ }
+
+ fn lower_store(
+ &mut self,
+ pointer: Handle<Expression>,
+ value: Handle<Expression>,
+ meta: Span,
+ body: &mut Block,
+ ) {
+ if let Expression::Swizzle {
+ size,
+ mut vector,
+ pattern,
+ } = self.expressions[pointer]
+ {
+ // Stores to swizzled values are not directly supported,
+ // lower them as series of per-component stores.
+ let size = match size {
+ VectorSize::Bi => 2,
+ VectorSize::Tri => 3,
+ VectorSize::Quad => 4,
+ };
+
+ if let Expression::Load { pointer } = self.expressions[vector] {
+ vector = pointer;
+ }
+
+ #[allow(clippy::needless_range_loop)]
+ for index in 0..size {
+ let dst = self.add_expression(
+ Expression::AccessIndex {
+ base: vector,
+ index: pattern[index].index(),
+ },
+ meta,
+ body,
+ );
+ let src = self.add_expression(
+ Expression::AccessIndex {
+ base: value,
+ index: index as u32,
+ },
+ meta,
+ body,
+ );
+
+ self.emit_restart(body);
+
+ body.push(
+ Statement::Store {
+ pointer: dst,
+ value: src,
+ },
+ meta,
+ );
+ }
+ } else {
+ self.emit_restart(body);
+
+ body.push(Statement::Store { pointer, value }, meta);
+ }
+ }
+
+ /// Internal implementation of [`lower`](Self::lower)
+ fn lower_inner(
+ &mut self,
+ stmt: &StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ body: &mut Block,
+ ) -> Result<(Option<Handle<Expression>>, Span)> {
+ let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];
+
+ log::debug!("Lowering {:?} (kind {:?}, pos {:?})", expr, kind, pos);
+
+ let handle = match *kind {
+ HirExprKind::Access { base, index } => {
+ let (index, index_meta) =
+ self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs, body)?;
+ let maybe_constant_index = match pos {
+ // Don't try to generate `AccessIndex` if in a LHS position, since it
+ // wouldn't produce a pointer.
+ ExprPos::Lhs => None,
+ _ => frontend.solve_constant(self, index, index_meta).ok(),
+ };
+
+ let base = self
+ .lower_expect_inner(
+ stmt,
+ frontend,
+ base,
+ pos.maybe_access_base(maybe_constant_index.is_some()),
+ body,
+ )?
+ .0;
+
+ let pointer = maybe_constant_index
+ .and_then(|constant| {
+ Some(self.add_expression(
+ Expression::AccessIndex {
+ base,
+ index: match frontend.module.constants[constant].inner {
+ crate::ConstantInner::Scalar {
+ value: ScalarValue::Uint(i),
+ ..
+ } => u32::try_from(i).ok()?,
+ crate::ConstantInner::Scalar {
+ value: ScalarValue::Sint(i),
+ ..
+ } => u32::try_from(i).ok()?,
+ _ => return None,
+ },
+ },
+ meta,
+ body,
+ ))
+ })
+ .unwrap_or_else(|| {
+ self.add_expression(Expression::Access { base, index }, meta, body)
+ });
+
+ if ExprPos::Rhs == pos {
+ let resolved = frontend.resolve_type(self, pointer, meta)?;
+ if resolved.pointer_space().is_some() {
+ return Ok((
+ Some(self.add_expression(Expression::Load { pointer }, meta, body)),
+ meta,
+ ));
+ }
+ }
+
+ pointer
+ }
+ HirExprKind::Select { base, ref field } => {
+ let base = self.lower_expect_inner(stmt, frontend, base, pos, body)?.0;
+
+ frontend.field_selection(self, pos, body, base, field, meta)?
+ }
+ HirExprKind::Constant(constant) if pos != ExprPos::Lhs => {
+ self.add_expression(Expression::Constant(constant), meta, body)
+ }
+ HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
+ let (mut left, left_meta) =
+ self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs, body)?;
+ let (mut right, right_meta) =
+ self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs, body)?;
+
+ match op {
+ BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => self
+ .implicit_conversion(
+ frontend,
+ &mut right,
+ right_meta,
+ ScalarKind::Uint,
+ 4,
+ )?,
+ _ => self.binary_implicit_conversion(
+ frontend, &mut left, left_meta, &mut right, right_meta,
+ )?,
+ }
+
+ frontend.typifier_grow(self, left, left_meta)?;
+ frontend.typifier_grow(self, right, right_meta)?;
+
+ let left_inner = self.typifier.get(left, &frontend.module.types);
+ let right_inner = self.typifier.get(right, &frontend.module.types);
+
+ match (left_inner, right_inner) {
+ (
+ &TypeInner::Matrix {
+ columns: left_columns,
+ rows: left_rows,
+ width: left_width,
+ },
+ &TypeInner::Matrix {
+ columns: right_columns,
+ rows: right_rows,
+ width: right_width,
+ },
+ ) => {
+ let dimensions_ok = if op == BinaryOperator::Multiply {
+ left_columns == right_rows
+ } else {
+ left_columns == right_columns && left_rows == right_rows
+ };
+
+ // Check that the two arguments have the same dimensions
+ if !dimensions_ok || left_width != right_width {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "Cannot apply operation to {left_inner:?} and {right_inner:?}"
+ )
+ .into(),
+ ),
+ meta,
+ })
+ }
+
+ match op {
+ BinaryOperator::Divide => {
+ // Naga IR doesn't support matrix division so we need to
+ // divide the columns individually and reassemble the matrix
+ let mut components = Vec::with_capacity(left_columns as usize);
+
+ for index in 0..left_columns as u32 {
+ // Get the column vectors
+ let left_vector = self.add_expression(
+ Expression::AccessIndex { base: left, index },
+ meta,
+ body,
+ );
+ let right_vector = self.add_expression(
+ Expression::AccessIndex { base: right, index },
+ meta,
+ body,
+ );
+
+ // Divide the vectors
+ let column = self.expressions.append(
+ Expression::Binary {
+ op,
+ left: left_vector,
+ right: right_vector,
+ },
+ meta,
+ );
+
+ components.push(column)
+ }
+
+ // Rebuild the matrix from the divided vectors
+ self.expressions.append(
+ Expression::Compose {
+ ty: frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns: left_columns,
+ rows: left_rows,
+ width: left_width,
+ },
+ },
+ Span::default(),
+ ),
+ components,
+ },
+ meta,
+ )
+ }
+ BinaryOperator::Equal | BinaryOperator::NotEqual => {
+ // Naga IR doesn't support matrix comparisons so we need to
+ // compare the columns individually and then fold them together
+ //
+ // The folding is done using a logical and for equality and
+ // a logical or for inequality
+ let equals = op == BinaryOperator::Equal;
+
+ let (op, combine, fun) = match equals {
+ true => (
+ BinaryOperator::Equal,
+ BinaryOperator::LogicalAnd,
+ RelationalFunction::All,
+ ),
+ false => (
+ BinaryOperator::NotEqual,
+ BinaryOperator::LogicalOr,
+ RelationalFunction::Any,
+ ),
+ };
+
+ let mut root = None;
+
+ for index in 0..left_columns as u32 {
+ // Get the column vectors
+ let left_vector = self.add_expression(
+ Expression::AccessIndex { base: left, index },
+ meta,
+ body,
+ );
+ let right_vector = self.add_expression(
+ Expression::AccessIndex { base: right, index },
+ meta,
+ body,
+ );
+
+ let argument = self.expressions.append(
+ Expression::Binary {
+ op,
+ left: left_vector,
+ right: right_vector,
+ },
+ meta,
+ );
+
+ // The result of comparing two vectors is a boolean vector
+ // so use a relational function like all to get a single
+ // boolean value
+ let compare = self.add_expression(
+ Expression::Relational { fun, argument },
+ meta,
+ body,
+ );
+
+ // Fold the result
+ root = Some(match root {
+ Some(right) => self.add_expression(
+ Expression::Binary {
+ op: combine,
+ left: compare,
+ right,
+ },
+ meta,
+ body,
+ ),
+ None => compare,
+ });
+ }
+
+ root.unwrap()
+ }
+ _ => self.add_expression(
+ Expression::Binary { left, op, right },
+ meta,
+ body,
+ ),
+ }
+ }
+ (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op {
+ BinaryOperator::Equal | BinaryOperator::NotEqual => {
+ let equals = op == BinaryOperator::Equal;
+
+ let (op, fun) = match equals {
+ true => (BinaryOperator::Equal, RelationalFunction::All),
+ false => (BinaryOperator::NotEqual, RelationalFunction::Any),
+ };
+
+ let argument = self
+ .expressions
+ .append(Expression::Binary { op, left, right }, meta);
+
+ self.add_expression(
+ Expression::Relational { fun, argument },
+ meta,
+ body,
+ )
+ }
+ _ => {
+ self.add_expression(Expression::Binary { left, op, right }, meta, body)
+ }
+ },
+ (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op {
+ BinaryOperator::Add
+ | BinaryOperator::Subtract
+ | BinaryOperator::Divide
+ | BinaryOperator::And
+ | BinaryOperator::ExclusiveOr
+ | BinaryOperator::InclusiveOr
+ | BinaryOperator::ShiftLeft
+ | BinaryOperator::ShiftRight => {
+ let scalar_vector = self.add_expression(
+ Expression::Splat { size, value: right },
+ meta,
+ body,
+ );
+
+ self.add_expression(
+ Expression::Binary {
+ op,
+ left,
+ right: scalar_vector,
+ },
+ meta,
+ body,
+ )
+ }
+ _ => {
+ self.add_expression(Expression::Binary { left, op, right }, meta, body)
+ }
+ },
+ (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op {
+ BinaryOperator::Add
+ | BinaryOperator::Subtract
+ | BinaryOperator::Divide
+ | BinaryOperator::And
+ | BinaryOperator::ExclusiveOr
+ | BinaryOperator::InclusiveOr => {
+ let scalar_vector = self.add_expression(
+ Expression::Splat { size, value: left },
+ meta,
+ body,
+ );
+
+ self.add_expression(
+ Expression::Binary {
+ op,
+ left: scalar_vector,
+ right,
+ },
+ meta,
+ body,
+ )
+ }
+ _ => {
+ self.add_expression(Expression::Binary { left, op, right }, meta, body)
+ }
+ },
+ (
+ &TypeInner::Scalar {
+ width: left_width, ..
+ },
+ &TypeInner::Matrix {
+ rows,
+ columns,
+ width: right_width,
+ },
+ ) => {
+ // Check that the two arguments have the same width
+ if left_width != right_width {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "Cannot apply operation to {left_inner:?} and {right_inner:?}"
+ )
+ .into(),
+ ),
+ meta,
+ })
+ }
+
+ match op {
+ BinaryOperator::Divide
+ | BinaryOperator::Add
+ | BinaryOperator::Subtract => {
+ // Naga IR doesn't support all matrix by scalar operations so
+ // we need for some to turn the scalar into a vector by
+ // splatting it and then for each column vector apply the
+ // operation and finally reconstruct the matrix
+ let scalar_vector = self.add_expression(
+ Expression::Splat {
+ size: rows,
+ value: left,
+ },
+ meta,
+ body,
+ );
+
+ let mut components = Vec::with_capacity(columns as usize);
+
+ for index in 0..columns as u32 {
+ // Get the column vector
+ let matrix_column = self.add_expression(
+ Expression::AccessIndex { base: right, index },
+ meta,
+ body,
+ );
+
+ // Apply the operation to the splatted vector and
+ // the column vector
+ let column = self.expressions.append(
+ Expression::Binary {
+ op,
+ left: scalar_vector,
+ right: matrix_column,
+ },
+ meta,
+ );
+
+ components.push(column)
+ }
+
+ // Rebuild the matrix from the operation result vectors
+ self.expressions.append(
+ Expression::Compose {
+ ty: frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns,
+ rows,
+ width: left_width,
+ },
+ },
+ Span::default(),
+ ),
+ components,
+ },
+ meta,
+ )
+ }
+ _ => self.add_expression(
+ Expression::Binary { left, op, right },
+ meta,
+ body,
+ ),
+ }
+ }
+ (
+ &TypeInner::Matrix {
+ rows,
+ columns,
+ width: left_width,
+ },
+ &TypeInner::Scalar {
+ width: right_width, ..
+ },
+ ) => {
+ // Check that the two arguments have the same width
+ if left_width != right_width {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "Cannot apply operation to {left_inner:?} and {right_inner:?}"
+ )
+ .into(),
+ ),
+ meta,
+ })
+ }
+
+ match op {
+ BinaryOperator::Divide
+ | BinaryOperator::Add
+ | BinaryOperator::Subtract => {
+ // Naga IR doesn't support all matrix by scalar operations so
+ // we need for some to turn the scalar into a vector by
+ // splatting it and then for each column vector apply the
+ // operation and finally reconstruct the matrix
+
+ let scalar_vector = self.add_expression(
+ Expression::Splat {
+ size: rows,
+ value: right,
+ },
+ meta,
+ body,
+ );
+
+ let mut components = Vec::with_capacity(columns as usize);
+
+ for index in 0..columns as u32 {
+ // Get the column vector
+ let matrix_column = self.add_expression(
+ Expression::AccessIndex { base: left, index },
+ meta,
+ body,
+ );
+
+ // Apply the operation to the splatted vector and
+ // the column vector
+ let column = self.expressions.append(
+ Expression::Binary {
+ op,
+ left: matrix_column,
+ right: scalar_vector,
+ },
+ meta,
+ );
+
+ components.push(column)
+ }
+
+ // Rebuild the matrix from the operation result vectors
+ self.expressions.append(
+ Expression::Compose {
+ ty: frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns,
+ rows,
+ width: left_width,
+ },
+ },
+ Span::default(),
+ ),
+ components,
+ },
+ meta,
+ )
+ }
+ _ => self.add_expression(
+ Expression::Binary { left, op, right },
+ meta,
+ body,
+ ),
+ }
+ }
+ _ => self.add_expression(Expression::Binary { left, op, right }, meta, body),
+ }
+ }
+ HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => {
+ let expr = self
+ .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs, body)?
+ .0;
+
+ self.add_expression(Expression::Unary { op, expr }, meta, body)
+ }
+ HirExprKind::Variable(ref var) => match pos {
+ ExprPos::Lhs => {
+ if !var.mutable {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Variable cannot be used in LHS position".into(),
+ ),
+ meta,
+ })
+ }
+
+ var.expr
+ }
+ ExprPos::AccessBase { constant_index } => {
+ // If the index isn't constant all accesses backed by a constant base need
+ // to be done through a proxy local variable, since constants have a non
+ // pointer type which is required for dynamic indexing
+ if !constant_index {
+ if let Some((constant, ty)) = var.constant {
+ let local = self.locals.append(
+ LocalVariable {
+ name: None,
+ ty,
+ init: Some(constant),
+ },
+ Span::default(),
+ );
+
+ self.add_expression(
+ Expression::LocalVariable(local),
+ Span::default(),
+ body,
+ )
+ } else {
+ var.expr
+ }
+ } else {
+ var.expr
+ }
+ }
+ _ if var.load => {
+ self.add_expression(Expression::Load { pointer: var.expr }, meta, body)
+ }
+ ExprPos::Rhs => var.expr,
+ },
+ HirExprKind::Call(ref call) if pos != ExprPos::Lhs => {
+ let maybe_expr = frontend.function_or_constructor_call(
+ self,
+ stmt,
+ body,
+ call.kind.clone(),
+ &call.args,
+ meta,
+ )?;
+ return Ok((maybe_expr, meta));
+ }
+ // `HirExprKind::Conditional` represents the ternary operator in glsl (`:?`)
+ //
+ // The ternary operator is defined to only evaluate one of the two possible
+ // expressions which means that it's behavior is that of an `if` statement,
+ // and it's merely syntatic sugar for it.
+ HirExprKind::Conditional {
+ condition,
+ accept,
+ reject,
+ } if ExprPos::Lhs != pos => {
+ // Given an expression `a ? b : c`, we need to produce a Naga
+ // statement roughly like:
+ //
+ // var temp;
+ // if a {
+ // temp = convert(b);
+ // } else {
+ // temp = convert(c);
+ // }
+ //
+ // where `convert` stands for type conversions to bring `b` and `c` to
+ // the same type, and then use `temp` to represent the value of the whole
+ // conditional expression in subsequent code.
+
+ // Lower the condition first to the current bodyy
+ let condition = self
+ .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs, body)?
+ .0;
+
+ // Emit all expressions since we will be adding statements to
+ // other bodies next
+ self.emit_restart(body);
+
+ // Create the bodies for the two cases
+ let mut accept_body = Block::new();
+ let mut reject_body = Block::new();
+
+ // Lower the `true` branch
+ let (mut accept, accept_meta) =
+ self.lower_expect_inner(stmt, frontend, accept, pos, &mut accept_body)?;
+
+ // Flush the body of the `true` branch, to start emitting on the
+ // `false` branch
+ self.emit_restart(&mut accept_body);
+
+ // Lower the `false` branch
+ let (mut reject, reject_meta) =
+ self.lower_expect_inner(stmt, frontend, reject, pos, &mut reject_body)?;
+
+ // Flush the body of the `false` branch
+ self.emit_restart(&mut reject_body);
+
+ // We need to do some custom implicit conversions since the two target expressions
+ // are in different bodies
+ if let (
+ Some((accept_power, accept_width, accept_kind)),
+ Some((reject_power, reject_width, reject_kind)),
+ ) = (
+ // Get the components of both branches and calculate the type power
+ self.expr_scalar_components(frontend, accept, accept_meta)?
+ .and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))),
+ self.expr_scalar_components(frontend, reject, reject_meta)?
+ .and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))),
+ ) {
+ match accept_power.cmp(&reject_power) {
+ std::cmp::Ordering::Less => {
+ self.conversion(&mut accept, accept_meta, reject_kind, reject_width)?;
+ // The expression belongs to the `true` branch so we need to flush to
+ // the respective body
+ self.emit_end(&mut accept_body);
+ }
+ // Technically there's nothing to flush but later we will need to
+ // add some expressions that must not be emitted so instead
+ // of flushing, starting and flushing again, just make sure
+ // everything is flushed.
+ std::cmp::Ordering::Equal => self.emit_end(body),
+ std::cmp::Ordering::Greater => {
+ self.conversion(&mut reject, reject_meta, accept_kind, accept_width)?;
+ // The expression belongs to the `false` branch so we need to flush to
+ // the respective body
+ self.emit_end(&mut reject_body);
+ }
+ }
+ } else {
+ // Technically there's nothing to flush but later we will need to
+ // add some expressions that must not be emitted.
+ self.emit_end(body)
+ }
+
+ // We need to get the type of the resulting expression to create the local,
+ // this must be done after implicit conversions to ensure both branches have
+ // the same type.
+ let ty = frontend.resolve_type_handle(self, accept, accept_meta)?;
+
+ // Add the local that will hold the result of our conditional
+ let local = self.locals.append(
+ LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ meta,
+ );
+
+ // Note: `Expression::LocalVariable` must not be emited so it's important
+ // that at this point the emitter is flushed but not started.
+ let local_expr = self
+ .expressions
+ .append(Expression::LocalVariable(local), meta);
+
+ // Add to each body the store to the result variable
+ accept_body.push(
+ Statement::Store {
+ pointer: local_expr,
+ value: accept,
+ },
+ accept_meta,
+ );
+ reject_body.push(
+ Statement::Store {
+ pointer: local_expr,
+ value: reject,
+ },
+ reject_meta,
+ );
+
+ // Finally add the `If` to the main body with the `condition` we lowered
+ // earlier and the branches we prepared.
+ body.push(
+ Statement::If {
+ condition,
+ accept: accept_body,
+ reject: reject_body,
+ },
+ meta,
+ );
+
+ // Restart the emitter
+ self.emit_start();
+
+ // Note: `Expression::Load` must be emited before it's used so make
+ // sure the emitter is active here.
+ self.expressions.append(
+ Expression::Load {
+ pointer: local_expr,
+ },
+ meta,
+ )
+ }
+ HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => {
+ let (pointer, ptr_meta) =
+ self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs, body)?;
+ let (mut value, value_meta) =
+ self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs, body)?;
+
+ let ty = match *frontend.resolve_type(self, pointer, ptr_meta)? {
+ TypeInner::Pointer { base, .. } => &frontend.module.types[base].inner,
+ ref ty => ty,
+ };
+
+ if let Some((kind, width)) = scalar_components(ty) {
+ self.implicit_conversion(frontend, &mut value, value_meta, kind, width)?;
+ }
+
+ self.lower_store(pointer, value, meta, body);
+
+ value
+ }
+ HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => {
+ let (pointer, _) =
+ self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs, body)?;
+ let left = if let Expression::Swizzle { .. } = self.expressions[pointer] {
+ pointer
+ } else {
+ self.add_expression(Expression::Load { pointer }, meta, body)
+ };
+
+ let make_constant_inner = |kind, width| {
+ let value = match kind {
+ ScalarKind::Sint => crate::ScalarValue::Sint(1),
+ ScalarKind::Uint => crate::ScalarValue::Uint(1),
+ ScalarKind::Float => crate::ScalarValue::Float(1.0),
+ ScalarKind::Bool => return None,
+ };
+
+ Some(crate::ConstantInner::Scalar { width, value })
+ };
+ let res = match *frontend.resolve_type(self, left, meta)? {
+ TypeInner::Scalar { kind, width } => {
+ let ty = TypeInner::Scalar { kind, width };
+ make_constant_inner(kind, width).map(|i| (ty, i, None, None))
+ }
+ TypeInner::Vector { size, kind, width } => {
+ let ty = TypeInner::Vector { size, kind, width };
+ make_constant_inner(kind, width).map(|i| (ty, i, Some(size), None))
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let ty = TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ };
+ make_constant_inner(ScalarKind::Float, width)
+ .map(|i| (ty, i, Some(rows), Some(columns)))
+ }
+ _ => None,
+ };
+ let (ty_inner, inner, rows, columns) = match res {
+ Some(res) => res,
+ None => {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Increment/decrement only works on scalar/vector/matrix".into(),
+ ),
+ meta,
+ });
+ return Ok((Some(left), meta));
+ }
+ };
+
+ let constant_1 = frontend.module.constants.append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ Default::default(),
+ );
+ let mut right = self.add_expression(Expression::Constant(constant_1), meta, body);
+
+ // Glsl allows pre/postfixes operations on vectors and matrices, so if the
+ // target is either of them change the right side of the addition to be splatted
+ // to the same size as the target, furthermore if the target is a matrix
+ // use a composed matrix using the splatted value.
+ if let Some(size) = rows {
+ right =
+ self.add_expression(Expression::Splat { size, value: right }, meta, body);
+
+ if let Some(cols) = columns {
+ let ty = frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: ty_inner,
+ },
+ meta,
+ );
+
+ right = self.add_expression(
+ Expression::Compose {
+ ty,
+ components: std::iter::repeat(right).take(cols as usize).collect(),
+ },
+ meta,
+ body,
+ );
+ }
+ }
+
+ let value = self.add_expression(Expression::Binary { op, left, right }, meta, body);
+
+ self.lower_store(pointer, value, meta, body);
+
+ if postfix {
+ left
+ } else {
+ value
+ }
+ }
+ HirExprKind::Method {
+ expr: object,
+ ref name,
+ ref args,
+ } if ExprPos::Lhs != pos => {
+ let args = args
+ .iter()
+ .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs, body))
+ .collect::<Result<Vec<_>>>()?;
+ match name.as_ref() {
+ "length" => {
+ if !args.is_empty() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ ".length() doesn't take any arguments".into(),
+ ),
+ meta,
+ });
+ }
+ let lowered_array = self
+ .lower_expect_inner(stmt, frontend, object, pos, body)?
+ .0;
+ let array_type = frontend.resolve_type(self, lowered_array, meta)?;
+
+ match *array_type {
+ TypeInner::Array {
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ let mut array_length =
+ self.add_expression(Expression::Constant(size), meta, body);
+ self.forced_conversion(
+ frontend,
+ &mut array_length,
+ meta,
+ ScalarKind::Sint,
+ 4,
+ )?;
+ array_length
+ }
+ // let the error be handled in type checking if it's not a dynamic array
+ _ => {
+ let mut array_length = self.add_expression(
+ Expression::ArrayLength(lowered_array),
+ meta,
+ body,
+ );
+ self.conversion(&mut array_length, meta, ScalarKind::Sint, 4)?;
+ array_length
+ }
+ }
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("unknown method '{name}'").into(),
+ ),
+ meta,
+ });
+ }
+ }
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr])
+ .into(),
+ ),
+ meta,
+ })
+ }
+ };
+
+ log::trace!(
+ "Lowered {:?}\n\tKind = {:?}\n\tPos = {:?}\n\tResult = {:?}",
+ expr,
+ kind,
+ pos,
+ handle
+ );
+
+ Ok((Some(handle), meta))
+ }
+
+ pub fn expr_scalar_components(
+ &mut self,
+ frontend: &Frontend,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<Option<(ScalarKind, crate::Bytes)>> {
+ let ty = frontend.resolve_type(self, expr, meta)?;
+ Ok(scalar_components(ty))
+ }
+
+ pub fn expr_power(
+ &mut self,
+ frontend: &Frontend,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<Option<u32>> {
+ Ok(self
+ .expr_scalar_components(frontend, expr, meta)?
+ .and_then(|(kind, width)| type_power(kind, width)))
+ }
+
+ pub fn conversion(
+ &mut self,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ kind: ScalarKind,
+ width: crate::Bytes,
+ ) -> Result<()> {
+ *expr = self.expressions.append(
+ Expression::As {
+ expr: *expr,
+ kind,
+ convert: Some(width),
+ },
+ meta,
+ );
+
+ Ok(())
+ }
+
+ pub fn implicit_conversion(
+ &mut self,
+ frontend: &Frontend,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ kind: ScalarKind,
+ width: crate::Bytes,
+ ) -> Result<()> {
+ if let (Some(tgt_power), Some(expr_power)) = (
+ type_power(kind, width),
+ self.expr_power(frontend, *expr, meta)?,
+ ) {
+ if tgt_power > expr_power {
+ self.conversion(expr, meta, kind, width)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ pub fn forced_conversion(
+ &mut self,
+ frontend: &Frontend,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ kind: ScalarKind,
+ width: crate::Bytes,
+ ) -> Result<()> {
+ if let Some((expr_scalar_kind, expr_width)) =
+ self.expr_scalar_components(frontend, *expr, meta)?
+ {
+ if expr_scalar_kind != kind || expr_width != width {
+ self.conversion(expr, meta, kind, width)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ pub fn binary_implicit_conversion(
+ &mut self,
+ frontend: &Frontend,
+ left: &mut Handle<Expression>,
+ left_meta: Span,
+ right: &mut Handle<Expression>,
+ right_meta: Span,
+ ) -> Result<()> {
+ let left_components = self.expr_scalar_components(frontend, *left, left_meta)?;
+ let right_components = self.expr_scalar_components(frontend, *right, right_meta)?;
+
+ if let (
+ Some((left_power, left_width, left_kind)),
+ Some((right_power, right_width, right_kind)),
+ ) = (
+ left_components.and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))),
+ right_components
+ .and_then(|(kind, width)| Some((type_power(kind, width)?, width, kind))),
+ ) {
+ match left_power.cmp(&right_power) {
+ std::cmp::Ordering::Less => {
+ self.conversion(left, left_meta, right_kind, right_width)?;
+ }
+ std::cmp::Ordering::Equal => {}
+ std::cmp::Ordering::Greater => {
+ self.conversion(right, right_meta, left_kind, left_width)?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ pub fn implicit_splat(
+ &mut self,
+ frontend: &Frontend,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ vector_size: Option<VectorSize>,
+ ) -> Result<()> {
+ let expr_type = frontend.resolve_type(self, *expr, meta)?;
+
+ if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) {
+ *expr = self
+ .expressions
+ .append(Expression::Splat { size, value: *expr }, meta)
+ }
+
+ Ok(())
+ }
+
+ pub fn vector_resize(
+ &mut self,
+ size: VectorSize,
+ vector: Handle<Expression>,
+ meta: Span,
+ body: &mut Block,
+ ) -> Handle<Expression> {
+ self.add_expression(
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern: crate::SwizzleComponent::XYZW,
+ },
+ meta,
+ body,
+ )
+ }
+}
+
+impl Index<Handle<Expression>> for Context {
+ type Output = Expression;
+
+ fn index(&self, index: Handle<Expression>) -> &Self::Output {
+ &self.expressions[index]
+ }
+}
+
+/// Helper struct passed when parsing expressions
+///
+/// This struct should only be obtained through [`stmt_ctx`](Context::stmt_ctx)
+/// and only one of these may be active at any time per context.
+#[derive(Debug)]
+pub struct StmtContext {
+ /// A arena of high level expressions which can be lowered through a
+ /// [`Context`](Context) to naga's [`Expression`](crate::Expression)s
+ pub hir_exprs: Arena<HirExpr>,
+}
+
+impl StmtContext {
+ const fn new() -> Self {
+ StmtContext {
+ hir_exprs: Arena::new(),
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/error.rs b/third_party/rust/naga/src/front/glsl/error.rs
new file mode 100644
index 0000000000..46b3aecd08
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/error.rs
@@ -0,0 +1,134 @@
+use super::{constants::ConstantSolvingError, token::TokenValue};
+use crate::Span;
+use pp_rs::token::PreprocessorError;
+use std::borrow::Cow;
+use thiserror::Error;
+
+fn join_with_comma(list: &[ExpectedToken]) -> String {
+ let mut string = "".to_string();
+ for (i, val) in list.iter().enumerate() {
+ string.push_str(&val.to_string());
+ match i {
+ i if i == list.len() - 1 => {}
+ i if i == list.len() - 2 => string.push_str(" or "),
+ _ => string.push_str(", "),
+ }
+ }
+ string
+}
+
+/// One of the expected tokens returned in [`InvalidToken`](ErrorKind::InvalidToken).
+#[derive(Debug, PartialEq)]
+pub enum ExpectedToken {
+ /// A specific token was expected.
+ Token(TokenValue),
+ /// A type was expected.
+ TypeName,
+ /// An identifier was expected.
+ Identifier,
+ /// An integer literal was expected.
+ IntLiteral,
+ /// A float literal was expected.
+ FloatLiteral,
+ /// A boolean literal was expected.
+ BoolLiteral,
+ /// The end of file was expected.
+ Eof,
+}
+impl From<TokenValue> for ExpectedToken {
+ fn from(token: TokenValue) -> Self {
+ ExpectedToken::Token(token)
+ }
+}
+impl std::fmt::Display for ExpectedToken {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match *self {
+ ExpectedToken::Token(ref token) => write!(f, "{token:?}"),
+ ExpectedToken::TypeName => write!(f, "a type"),
+ ExpectedToken::Identifier => write!(f, "identifier"),
+ ExpectedToken::IntLiteral => write!(f, "integer literal"),
+ ExpectedToken::FloatLiteral => write!(f, "float literal"),
+ ExpectedToken::BoolLiteral => write!(f, "bool literal"),
+ ExpectedToken::Eof => write!(f, "end of file"),
+ }
+ }
+}
+
+/// Information about the cause of an error.
+#[derive(Debug, Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum ErrorKind {
+ /// Whilst parsing as encountered an unexpected EOF.
+ #[error("Unexpected end of file")]
+ EndOfFile,
+ /// The shader specified an unsupported or invalid profile.
+ #[error("Invalid profile: {0}")]
+ InvalidProfile(String),
+ /// The shader requested an unsupported or invalid version.
+ #[error("Invalid version: {0}")]
+ InvalidVersion(u64),
+ /// Whilst parsing an unexpected token was encountered.
+ ///
+ /// A list of expected tokens is also returned.
+ #[error("Expected {}, found {0:?}", join_with_comma(.1))]
+ InvalidToken(TokenValue, Vec<ExpectedToken>),
+ /// A specific feature is not yet implemented.
+ ///
+ /// To help prioritize work please open an issue in the github issue tracker
+ /// if none exist already or react to the already existing one.
+ #[error("Not implemented: {0}")]
+ NotImplemented(&'static str),
+ /// A reference to a variable that wasn't declared was used.
+ #[error("Unknown variable: {0}")]
+ UnknownVariable(String),
+ /// A reference to a type that wasn't declared was used.
+ #[error("Unknown type: {0}")]
+ UnknownType(String),
+ /// A reference to a non existent member of a type was made.
+ #[error("Unknown field: {0}")]
+ UnknownField(String),
+ /// An unknown layout qualifier was used.
+ ///
+ /// If the qualifier does exist please open an issue in the github issue tracker
+ /// if none exist already or react to the already existing one to help
+ /// prioritize work.
+ #[error("Unknown layout qualifier: {0}")]
+ UnknownLayoutQualifier(String),
+ /// Unsupported matrix of the form matCx2
+ ///
+ /// Our IR expects matrices of the form matCx2 to have a stride of 8 however
+ /// matrices in the std140 layout have a stride of at least 16
+ #[error("unsupported matrix of the form matCx2 in std140 block layout")]
+ UnsupportedMatrixTypeInStd140,
+ /// A variable with the same name already exists in the current scope.
+ #[error("Variable already declared: {0}")]
+ VariableAlreadyDeclared(String),
+ /// A semantic error was detected in the shader.
+ #[error("{0}")]
+ SemanticError(Cow<'static, str>),
+ /// An error was returned by the preprocessor.
+ #[error("{0:?}")]
+ PreprocessorError(PreprocessorError),
+ /// The parser entered an illegal state and exited
+ ///
+ /// This obviously is a bug and as such should be reported in the github issue tracker
+ #[error("Internal error: {0}")]
+ InternalError(&'static str),
+}
+
+impl From<ConstantSolvingError> for ErrorKind {
+ fn from(err: ConstantSolvingError) -> Self {
+ ErrorKind::SemanticError(err.to_string().into())
+ }
+}
+
+/// Error returned during shader parsing.
+#[derive(Debug, Error)]
+#[error("{kind}")]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Error {
+ /// Holds the information about the error itself.
+ pub kind: ErrorKind,
+ /// Holds information about the range of the source code where the error happened.
+ pub meta: Span,
+}
diff --git a/third_party/rust/naga/src/front/glsl/functions.rs b/third_party/rust/naga/src/front/glsl/functions.rs
new file mode 100644
index 0000000000..10c964b5e0
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/functions.rs
@@ -0,0 +1,1680 @@
+use super::{
+ ast::*,
+ builtins::{inject_builtin, sampled_to_depth},
+ context::{Context, ExprPos, StmtContext},
+ error::{Error, ErrorKind},
+ types::scalar_components,
+ Frontend, Result,
+};
+use crate::{
+ front::glsl::types::type_power, proc::ensure_block_returns, AddressSpace, Arena, Block,
+ Constant, ConstantInner, EntryPoint, Expression, Function, FunctionArgument, FunctionResult,
+ Handle, LocalVariable, ScalarKind, ScalarValue, Span, Statement, StructMember, Type, TypeInner,
+};
+use std::iter;
+
+/// Struct detailing a store operation that must happen after a function call
+struct ProxyWrite {
+ /// The store target
+ target: Handle<Expression>,
+ /// A pointer to read the value of the store
+ value: Handle<Expression>,
+ /// An optional conversion to be applied
+ convert: Option<(ScalarKind, crate::Bytes)>,
+}
+
+impl Frontend {
+ fn add_constant_value(
+ &mut self,
+ scalar_kind: ScalarKind,
+ value: u64,
+ meta: Span,
+ ) -> Handle<Constant> {
+ let value = match scalar_kind {
+ ScalarKind::Uint => ScalarValue::Uint(value),
+ ScalarKind::Sint => ScalarValue::Sint(value as i64),
+ ScalarKind::Float => ScalarValue::Float(value as f64),
+ ScalarKind::Bool => unreachable!(),
+ };
+
+ self.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar { width: 4, value },
+ },
+ meta,
+ )
+ }
+
+ pub(crate) fn function_or_constructor_call(
+ &mut self,
+ ctx: &mut Context,
+ stmt: &StmtContext,
+ body: &mut Block,
+ fc: FunctionCallKind,
+ raw_args: &[Handle<HirExpr>],
+ meta: Span,
+ ) -> Result<Option<Handle<Expression>>> {
+ let args: Vec<_> = raw_args
+ .iter()
+ .map(|e| ctx.lower_expect_inner(stmt, self, *e, ExprPos::Rhs, body))
+ .collect::<Result<_>>()?;
+
+ match fc {
+ FunctionCallKind::TypeConstructor(ty) => {
+ if args.len() == 1 {
+ self.constructor_single(ctx, body, ty, args[0], meta)
+ .map(Some)
+ } else {
+ self.constructor_many(ctx, body, ty, args, meta).map(Some)
+ }
+ }
+ FunctionCallKind::Function(name) => {
+ self.function_call(ctx, stmt, body, name, args, raw_args, meta)
+ }
+ }
+ }
+
+ fn constructor_single(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ ty: Handle<Type>,
+ (mut value, expr_meta): (Handle<Expression>, Span),
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let expr_type = self.resolve_type(ctx, value, expr_meta)?;
+
+ let vector_size = match *expr_type {
+ TypeInner::Vector { size, .. } => Some(size),
+ _ => None,
+ };
+
+ // Special case: if casting from a bool, we need to use Select and not As.
+ match self.module.types[ty].inner.scalar_kind() {
+ Some(result_scalar_kind)
+ if expr_type.scalar_kind() == Some(ScalarKind::Bool)
+ && result_scalar_kind != ScalarKind::Bool =>
+ {
+ let c0 = self.add_constant_value(result_scalar_kind, 0u64, meta);
+ let c1 = self.add_constant_value(result_scalar_kind, 1u64, meta);
+ let mut reject = ctx.add_expression(Expression::Constant(c0), expr_meta, body);
+ let mut accept = ctx.add_expression(Expression::Constant(c1), expr_meta, body);
+
+ ctx.implicit_splat(self, &mut reject, meta, vector_size)?;
+ ctx.implicit_splat(self, &mut accept, meta, vector_size)?;
+
+ let h = ctx.add_expression(
+ Expression::Select {
+ accept,
+ reject,
+ condition: value,
+ },
+ expr_meta,
+ body,
+ );
+
+ return Ok(h);
+ }
+ _ => {}
+ }
+
+ Ok(match self.module.types[ty].inner {
+ TypeInner::Vector { size, kind, width } if vector_size.is_none() => {
+ ctx.forced_conversion(self, &mut value, expr_meta, kind, width)?;
+
+ if let TypeInner::Scalar { .. } = *self.resolve_type(ctx, value, expr_meta)? {
+ ctx.add_expression(Expression::Splat { size, value }, meta, body)
+ } else {
+ self.vector_constructor(
+ ctx,
+ body,
+ ty,
+ size,
+ kind,
+ width,
+ &[(value, expr_meta)],
+ meta,
+ )?
+ }
+ }
+ TypeInner::Scalar { kind, width } => {
+ let mut expr = value;
+ if let TypeInner::Vector { .. } | TypeInner::Matrix { .. } =
+ *self.resolve_type(ctx, value, expr_meta)?
+ {
+ expr = ctx.add_expression(
+ Expression::AccessIndex {
+ base: expr,
+ index: 0,
+ },
+ meta,
+ body,
+ );
+ }
+
+ if let TypeInner::Matrix { .. } = *self.resolve_type(ctx, value, expr_meta)? {
+ expr = ctx.add_expression(
+ Expression::AccessIndex {
+ base: expr,
+ index: 0,
+ },
+ meta,
+ body,
+ );
+ }
+
+ ctx.add_expression(
+ Expression::As {
+ kind,
+ expr,
+ convert: Some(width),
+ },
+ meta,
+ body,
+ )
+ }
+ TypeInner::Vector { size, kind, width } => {
+ if vector_size.map_or(true, |s| s != size) {
+ value = ctx.vector_resize(size, value, expr_meta, body);
+ }
+
+ ctx.add_expression(
+ Expression::As {
+ kind,
+ expr: value,
+ convert: Some(width),
+ },
+ meta,
+ body,
+ )
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => self.matrix_one_arg(
+ ctx,
+ body,
+ ty,
+ columns,
+ rows,
+ width,
+ (value, expr_meta),
+ meta,
+ )?,
+ TypeInner::Struct { ref members, .. } => {
+ let scalar_components = members
+ .get(0)
+ .and_then(|member| scalar_components(&self.module.types[member.ty].inner));
+ if let Some((kind, width)) = scalar_components {
+ ctx.implicit_conversion(self, &mut value, expr_meta, kind, width)?;
+ }
+
+ ctx.add_expression(
+ Expression::Compose {
+ ty,
+ components: vec![value],
+ },
+ meta,
+ body,
+ )
+ }
+
+ TypeInner::Array { base, .. } => {
+ let scalar_components = scalar_components(&self.module.types[base].inner);
+ if let Some((kind, width)) = scalar_components {
+ ctx.implicit_conversion(self, &mut value, expr_meta, kind, width)?;
+ }
+
+ ctx.add_expression(
+ Expression::Compose {
+ ty,
+ components: vec![value],
+ },
+ meta,
+ body,
+ )
+ }
+ _ => {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Bad type constructor".into()),
+ meta,
+ });
+
+ value
+ }
+ })
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn matrix_one_arg(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ ty: Handle<Type>,
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ (mut value, expr_meta): (Handle<Expression>, Span),
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let mut components = Vec::with_capacity(columns as usize);
+ // TODO: casts
+ // `Expression::As` doesn't support matrix width
+ // casts so we need to do some extra work for casts
+
+ ctx.forced_conversion(self, &mut value, expr_meta, ScalarKind::Float, width)?;
+ match *self.resolve_type(ctx, value, expr_meta)? {
+ TypeInner::Scalar { .. } => {
+ // If a matrix is constructed with a single scalar value, then that
+ // value is used to initialize all the values along the diagonal of
+ // the matrix; the rest are given zeros.
+ let vector_ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: rows,
+ kind: ScalarKind::Float,
+ width,
+ },
+ },
+ meta,
+ );
+ let zero_constant = self.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width,
+ value: ScalarValue::Float(0.0),
+ },
+ },
+ meta,
+ );
+ let zero = ctx.add_expression(Expression::Constant(zero_constant), meta, body);
+
+ for i in 0..columns as u32 {
+ components.push(
+ ctx.add_expression(
+ Expression::Compose {
+ ty: vector_ty,
+ components: (0..rows as u32)
+ .map(|r| match r == i {
+ true => value,
+ false => zero,
+ })
+ .collect(),
+ },
+ meta,
+ body,
+ ),
+ )
+ }
+ }
+ TypeInner::Matrix {
+ rows: ori_rows,
+ columns: ori_cols,
+ ..
+ } => {
+ // If a matrix is constructed from a matrix, then each component
+ // (column i, row j) in the result that has a corresponding component
+ // (column i, row j) in the argument will be initialized from there. All
+ // other components will be initialized to the identity matrix.
+
+ let zero_constant = self.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width,
+ value: ScalarValue::Float(0.0),
+ },
+ },
+ meta,
+ );
+ let zero = ctx.add_expression(Expression::Constant(zero_constant), meta, body);
+ let one_constant = self.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width,
+ value: ScalarValue::Float(1.0),
+ },
+ },
+ meta,
+ );
+ let one = ctx.add_expression(Expression::Constant(one_constant), meta, body);
+ let vector_ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: rows,
+ kind: ScalarKind::Float,
+ width,
+ },
+ },
+ meta,
+ );
+
+ for i in 0..columns as u32 {
+ if i < ori_cols as u32 {
+ use std::cmp::Ordering;
+
+ let vector = ctx.add_expression(
+ Expression::AccessIndex {
+ base: value,
+ index: i,
+ },
+ meta,
+ body,
+ );
+
+ components.push(match ori_rows.cmp(&rows) {
+ Ordering::Less => {
+ let components = (0..rows as u32)
+ .map(|r| {
+ if r < ori_rows as u32 {
+ ctx.add_expression(
+ Expression::AccessIndex {
+ base: vector,
+ index: r,
+ },
+ meta,
+ body,
+ )
+ } else if r == i {
+ one
+ } else {
+ zero
+ }
+ })
+ .collect();
+
+ ctx.add_expression(
+ Expression::Compose {
+ ty: vector_ty,
+ components,
+ },
+ meta,
+ body,
+ )
+ }
+ Ordering::Equal => vector,
+ Ordering::Greater => ctx.vector_resize(rows, vector, meta, body),
+ })
+ } else {
+ let vec_constant = self.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Composite {
+ ty: vector_ty,
+ components: (0..rows as u32)
+ .map(|r| match r == i {
+ true => one_constant,
+ false => zero_constant,
+ })
+ .collect(),
+ },
+ },
+ meta,
+ );
+ let vec =
+ ctx.add_expression(Expression::Constant(vec_constant), meta, body);
+
+ components.push(vec)
+ }
+ }
+ }
+ _ => {
+ components = iter::repeat(value).take(columns as usize).collect();
+ }
+ }
+
+ Ok(ctx.add_expression(Expression::Compose { ty, components }, meta, body))
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn vector_constructor(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ ty: Handle<Type>,
+ size: crate::VectorSize,
+ kind: ScalarKind,
+ width: crate::Bytes,
+ args: &[(Handle<Expression>, Span)],
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let mut components = Vec::with_capacity(size as usize);
+
+ for (mut arg, expr_meta) in args.iter().copied() {
+ ctx.forced_conversion(self, &mut arg, expr_meta, kind, width)?;
+
+ if components.len() >= size as usize {
+ break;
+ }
+
+ match *self.resolve_type(ctx, arg, expr_meta)? {
+ TypeInner::Scalar { .. } => components.push(arg),
+ TypeInner::Matrix { rows, columns, .. } => {
+ components.reserve(rows as usize * columns as usize);
+ for c in 0..(columns as u32) {
+ let base = ctx.add_expression(
+ Expression::AccessIndex {
+ base: arg,
+ index: c,
+ },
+ expr_meta,
+ body,
+ );
+ for r in 0..(rows as u32) {
+ components.push(ctx.add_expression(
+ Expression::AccessIndex { base, index: r },
+ expr_meta,
+ body,
+ ))
+ }
+ }
+ }
+ TypeInner::Vector { size: ori_size, .. } => {
+ components.reserve(ori_size as usize);
+ for index in 0..(ori_size as u32) {
+ components.push(ctx.add_expression(
+ Expression::AccessIndex { base: arg, index },
+ expr_meta,
+ body,
+ ))
+ }
+ }
+ _ => components.push(arg),
+ }
+ }
+
+ components.truncate(size as usize);
+
+ Ok(ctx.add_expression(Expression::Compose { ty, components }, meta, body))
+ }
+
+ fn constructor_many(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ ty: Handle<Type>,
+ args: Vec<(Handle<Expression>, Span)>,
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let mut components = Vec::with_capacity(args.len());
+
+ match self.module.types[ty].inner {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let mut flattened = Vec::with_capacity(columns as usize * rows as usize);
+
+ for (mut arg, meta) in args.iter().copied() {
+ ctx.forced_conversion(self, &mut arg, meta, ScalarKind::Float, width)?;
+
+ match *self.resolve_type(ctx, arg, meta)? {
+ TypeInner::Vector { size, .. } => {
+ for i in 0..(size as u32) {
+ flattened.push(ctx.add_expression(
+ Expression::AccessIndex {
+ base: arg,
+ index: i,
+ },
+ meta,
+ body,
+ ))
+ }
+ }
+ _ => flattened.push(arg),
+ }
+ }
+
+ let ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: rows,
+ kind: ScalarKind::Float,
+ width,
+ },
+ },
+ meta,
+ );
+
+ for chunk in flattened.chunks(rows as usize) {
+ components.push(ctx.add_expression(
+ Expression::Compose {
+ ty,
+ components: Vec::from(chunk),
+ },
+ meta,
+ body,
+ ))
+ }
+ }
+ TypeInner::Vector { size, kind, width } => {
+ return self.vector_constructor(ctx, body, ty, size, kind, width, &args, meta)
+ }
+ TypeInner::Array { base, .. } => {
+ for (mut arg, meta) in args.iter().copied() {
+ let scalar_components = scalar_components(&self.module.types[base].inner);
+ if let Some((kind, width)) = scalar_components {
+ ctx.implicit_conversion(self, &mut arg, meta, kind, width)?;
+ }
+
+ components.push(arg)
+ }
+ }
+ TypeInner::Struct { ref members, .. } => {
+ for ((mut arg, meta), member) in args.iter().copied().zip(members.iter()) {
+ let scalar_components = scalar_components(&self.module.types[member.ty].inner);
+ if let Some((kind, width)) = scalar_components {
+ ctx.implicit_conversion(self, &mut arg, meta, kind, width)?;
+ }
+
+ components.push(arg)
+ }
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError("Constructor: Too many arguments".into()),
+ meta,
+ })
+ }
+ }
+
+ Ok(ctx.add_expression(Expression::Compose { ty, components }, meta, body))
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn function_call(
+ &mut self,
+ ctx: &mut Context,
+ stmt: &StmtContext,
+ body: &mut Block,
+ name: String,
+ args: Vec<(Handle<Expression>, Span)>,
+ raw_args: &[Handle<HirExpr>],
+ meta: Span,
+ ) -> Result<Option<Handle<Expression>>> {
+ // Grow the typifier to be able to index it later without needing
+ // to hold the context mutably
+ for &(expr, span) in args.iter() {
+ self.typifier_grow(ctx, expr, span)?;
+ }
+
+ // Check if the passed arguments require any special variations
+ let mut variations = builtin_required_variations(
+ args.iter()
+ .map(|&(expr, _)| ctx.typifier.get(expr, &self.module.types)),
+ );
+
+ // Initiate the declaration if it wasn't previously initialized and inject builtins
+ let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| {
+ variations |= BuiltinVariations::STANDARD;
+ Default::default()
+ });
+ inject_builtin(declaration, &mut self.module, &name, variations);
+
+ // Borrow again but without mutability, at this point a declaration is guaranteed
+ let declaration = self.lookup_function.get(&name).unwrap();
+
+ // Possibly contains the overload to be used in the call
+ let mut maybe_overload = None;
+ // The conversions needed for the best analyzed overload, this is initialized all to
+ // `NONE` to make sure that conversions always pass the first time without ambiguity
+ let mut old_conversions = vec![Conversion::None; args.len()];
+ // Tracks whether the comparison between overloads lead to an ambiguity
+ let mut ambiguous = false;
+
+ // Iterate over all the available overloads to select either an exact match or a
+ // overload which has suitable implicit conversions
+ 'outer: for (overload_idx, overload) in declaration.overloads.iter().enumerate() {
+ // If the overload and the function call don't have the same number of arguments
+ // continue to the next overload
+ if args.len() != overload.parameters.len() {
+ continue;
+ }
+
+ log::trace!("Testing overload {}", overload_idx);
+
+ // Stores whether the current overload matches exactly the function call
+ let mut exact = true;
+ // State of the selection
+ // If None we still don't know what is the best overload
+ // If Some(true) the new overload is better
+ // If Some(false) the old overload is better
+ let mut superior = None;
+ // Store the conversions for the current overload so that later they can replace the
+ // conversions used for querying the best overload
+ let mut new_conversions = vec![Conversion::None; args.len()];
+
+ // Loop through the overload parameters and check if the current overload is better
+ // compared to the previous best overload.
+ for (i, overload_parameter) in overload.parameters.iter().enumerate() {
+ let call_argument = &args[i];
+ let parameter_info = &overload.parameters_info[i];
+
+ // If the image is used in the overload as a depth texture convert it
+ // before comparing, otherwise exact matches wouldn't be reported
+ if parameter_info.depth {
+ sampled_to_depth(
+ &mut self.module,
+ ctx,
+ call_argument.0,
+ call_argument.1,
+ &mut self.errors,
+ );
+ self.invalidate_expression(ctx, call_argument.0, call_argument.1)?
+ }
+
+ let overload_param_ty = &self.module.types[*overload_parameter].inner;
+ let call_arg_ty = self.resolve_type(ctx, call_argument.0, call_argument.1)?;
+
+ log::trace!(
+ "Testing parameter {}\n\tOverload = {:?}\n\tCall = {:?}",
+ i,
+ overload_param_ty,
+ call_arg_ty
+ );
+
+ // Storage images cannot be directly compared since while the access is part of the
+ // type in naga's IR, in glsl they are a qualifier and don't enter in the match as
+ // long as the access needed is satisfied.
+ if let (
+ &TypeInner::Image {
+ class:
+ crate::ImageClass::Storage {
+ format: overload_format,
+ access: overload_access,
+ },
+ dim: overload_dim,
+ arrayed: overload_arrayed,
+ },
+ &TypeInner::Image {
+ class:
+ crate::ImageClass::Storage {
+ format: call_format,
+ access: call_access,
+ },
+ dim: call_dim,
+ arrayed: call_arrayed,
+ },
+ ) = (overload_param_ty, call_arg_ty)
+ {
+ // Images size must match otherwise the overload isn't what we want
+ let good_size = call_dim == overload_dim && call_arrayed == overload_arrayed;
+ // Glsl requires the formats to strictly match unless you are builtin
+ // function overload and have not been replaced, in which case we only
+ // check that the format scalar kind matches
+ let good_format = overload_format == call_format
+ || (overload.internal
+ && ScalarKind::from(overload_format) == ScalarKind::from(call_format));
+ if !(good_size && good_format) {
+ continue 'outer;
+ }
+
+ // While storage access mismatch is an error it isn't one that causes
+ // the overload matching to fail so we defer the error and consider
+ // that the images match exactly
+ if !call_access.contains(overload_access) {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "'{name}': image needs {overload_access:?} access but only {call_access:?} was provided"
+ )
+ .into(),
+ ),
+ meta,
+ });
+ }
+
+ // The images satisfy the conditions to be considered as an exact match
+ new_conversions[i] = Conversion::Exact;
+ continue;
+ } else if overload_param_ty == call_arg_ty {
+ // If the types match there's no need to check for conversions so continue
+ new_conversions[i] = Conversion::Exact;
+ continue;
+ }
+
+ // Glsl defines that inout follows both the conversions for input parameters and
+ // output parameters, this means that the type must have a conversion from both the
+ // call argument to the function parameter and the function parameter to the call
+ // argument, the only way this is possible is for the conversion to be an identity
+ // (i.e. call argument = function parameter)
+ if let ParameterQualifier::InOut = parameter_info.qualifier {
+ continue 'outer;
+ }
+
+ // The function call argument and the function definition
+ // parameter are not equal at this point, so we need to try
+ // implicit conversions.
+ //
+ // Now there are two cases, the argument is defined as a normal
+ // parameter (`in` or `const`), in this case an implicit
+ // conversion is made from the calling argument to the
+ // definition argument. If the parameter is `out` the
+ // opposite needs to be done, so the implicit conversion is made
+ // from the definition argument to the calling argument.
+ let maybe_conversion = if parameter_info.qualifier.is_lhs() {
+ conversion(call_arg_ty, overload_param_ty)
+ } else {
+ conversion(overload_param_ty, call_arg_ty)
+ };
+
+ let conversion = match maybe_conversion {
+ Some(info) => info,
+ None => continue 'outer,
+ };
+
+ // At this point a conversion will be needed so the overload no longer
+ // exactly matches the call arguments
+ exact = false;
+
+ // Compare the conversions needed for this overload parameter to that of the
+ // last overload analyzed respective parameter, the value is:
+ // - `true` when the new overload argument has a better conversion
+ // - `false` when the old overload argument has a better conversion
+ let best_arg = match (conversion, old_conversions[i]) {
+ // An exact match is always better, we don't need to check this for the
+ // current overload since it was checked earlier
+ (_, Conversion::Exact) => false,
+ // No overload was yet analyzed so this one is the best yet
+ (_, Conversion::None) => true,
+ // A conversion from a float to a double is the best possible conversion
+ (Conversion::FloatToDouble, _) => true,
+ (_, Conversion::FloatToDouble) => false,
+ // A conversion from a float to an integer is preferred than one
+ // from double to an integer
+ (Conversion::IntToFloat, Conversion::IntToDouble) => true,
+ (Conversion::IntToDouble, Conversion::IntToFloat) => false,
+ // This case handles things like no conversion and exact which were already
+ // treated and other cases which no conversion is better than the other
+ _ => continue,
+ };
+
+ // Check if the best parameter corresponds to the current selected overload
+ // to pass to the next comparison, if this isn't true mark it as ambiguous
+ match best_arg {
+ true => match superior {
+ Some(false) => ambiguous = true,
+ _ => {
+ superior = Some(true);
+ new_conversions[i] = conversion
+ }
+ },
+ false => match superior {
+ Some(true) => ambiguous = true,
+ _ => superior = Some(false),
+ },
+ }
+ }
+
+ // The overload matches exactly the function call so there's no ambiguity (since
+ // repeated overload aren't allowed) and the current overload is selected, no
+ // further querying is needed.
+ if exact {
+ maybe_overload = Some(overload);
+ ambiguous = false;
+ break;
+ }
+
+ match superior {
+ // New overload is better keep it
+ Some(true) => {
+ maybe_overload = Some(overload);
+ // Replace the conversions
+ old_conversions = new_conversions;
+ }
+ // Old overload is better do nothing
+ Some(false) => {}
+ // No overload was better than the other this can be caused
+ // when all conversions are ambiguous in which the overloads themselves are
+ // ambiguous.
+ None => {
+ ambiguous = true;
+ // Assign the new overload, this helps ensures that in this case of
+ // ambiguity the parsing won't end immediately and allow for further
+ // collection of errors.
+ maybe_overload = Some(overload);
+ }
+ }
+ }
+
+ if ambiguous {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Ambiguous best function for '{name}'").into(),
+ ),
+ meta,
+ })
+ }
+
+ let overload = maybe_overload.ok_or_else(|| Error {
+ kind: ErrorKind::SemanticError(format!("Unknown function '{name}'").into()),
+ meta,
+ })?;
+
+ let parameters_info = overload.parameters_info.clone();
+ let parameters = overload.parameters.clone();
+ let is_void = overload.void;
+ let kind = overload.kind;
+
+ let mut arguments = Vec::with_capacity(args.len());
+ let mut proxy_writes = Vec::new();
+
+ // Iterate through the function call arguments applying transformations as needed
+ for (((parameter_info, call_argument), expr), parameter) in parameters_info
+ .iter()
+ .zip(&args)
+ .zip(raw_args)
+ .zip(&parameters)
+ {
+ let (mut handle, meta) =
+ ctx.lower_expect_inner(stmt, self, *expr, parameter_info.qualifier.as_pos(), body)?;
+
+ if parameter_info.qualifier.is_lhs() {
+ self.process_lhs_argument(
+ ctx,
+ body,
+ meta,
+ *parameter,
+ parameter_info,
+ handle,
+ call_argument,
+ &mut proxy_writes,
+ &mut arguments,
+ )?;
+
+ continue;
+ }
+
+ let scalar_comps = scalar_components(&self.module.types[*parameter].inner);
+
+ // Apply implicit conversions as needed
+ if let Some((kind, width)) = scalar_comps {
+ ctx.implicit_conversion(self, &mut handle, meta, kind, width)?;
+ }
+
+ arguments.push(handle)
+ }
+
+ match kind {
+ FunctionKind::Call(function) => {
+ ctx.emit_end(body);
+
+ let result = if !is_void {
+ Some(ctx.add_expression(Expression::CallResult(function), meta, body))
+ } else {
+ None
+ };
+
+ body.push(
+ crate::Statement::Call {
+ function,
+ arguments,
+ result,
+ },
+ meta,
+ );
+
+ ctx.emit_start();
+
+ // Write back all the variables that were scheduled to their original place
+ for proxy_write in proxy_writes {
+ let mut value = ctx.add_expression(
+ Expression::Load {
+ pointer: proxy_write.value,
+ },
+ meta,
+ body,
+ );
+
+ if let Some((kind, width)) = proxy_write.convert {
+ ctx.conversion(&mut value, meta, kind, width)?;
+ }
+
+ ctx.emit_restart(body);
+
+ body.push(
+ Statement::Store {
+ pointer: proxy_write.target,
+ value,
+ },
+ meta,
+ );
+ }
+
+ Ok(result)
+ }
+ FunctionKind::Macro(builtin) => {
+ builtin.call(self, ctx, body, arguments.as_mut_slice(), meta)
+ }
+ }
+ }
+
+ /// Processes a function call argument that appears in place of an output
+ /// parameter.
+ #[allow(clippy::too_many_arguments)]
+ fn process_lhs_argument(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ meta: Span,
+ parameter_ty: Handle<Type>,
+ parameter_info: &ParameterInfo,
+ original: Handle<Expression>,
+ call_argument: &(Handle<Expression>, Span),
+ proxy_writes: &mut Vec<ProxyWrite>,
+ arguments: &mut Vec<Handle<Expression>>,
+ ) -> Result<()> {
+ let original_ty = self.resolve_type(ctx, original, meta)?;
+ let original_pointer_space = original_ty.pointer_space();
+
+ // The type of a possible spill variable needed for a proxy write
+ let mut maybe_ty = match *original_ty {
+ // If the argument is to be passed as a pointer but the type of the
+ // expression returns a vector it must mean that it was for example
+ // swizzled and it must be spilled into a local before calling
+ TypeInner::Vector { size, kind, width } => Some(self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector { size, kind, width },
+ },
+ Span::default(),
+ )),
+ // If the argument is a pointer whose address space isn't `Function`, an
+ // indirection through a local variable is needed to align the address
+ // spaces of the call argument and the overload parameter.
+ TypeInner::Pointer { base, space } if space != AddressSpace::Function => Some(base),
+ TypeInner::ValuePointer {
+ size,
+ kind,
+ width,
+ space,
+ } if space != AddressSpace::Function => {
+ let inner = match size {
+ Some(size) => TypeInner::Vector { size, kind, width },
+ None => TypeInner::Scalar { kind, width },
+ };
+
+ Some(
+ self.module
+ .types
+ .insert(Type { name: None, inner }, Span::default()),
+ )
+ }
+ _ => None,
+ };
+
+ // Since the original expression might be a pointer and we want a value
+ // for the proxy writes, we might need to load the pointer.
+ let value = if original_pointer_space.is_some() {
+ ctx.add_expression(
+ Expression::Load { pointer: original },
+ Span::default(),
+ body,
+ )
+ } else {
+ original
+ };
+
+ let call_arg_ty = self.resolve_type(ctx, call_argument.0, call_argument.1)?;
+ let overload_param_ty = &self.module.types[parameter_ty].inner;
+ let needs_conversion = call_arg_ty != overload_param_ty;
+
+ let arg_scalar_comps = scalar_components(call_arg_ty);
+
+ // Since output parameters also allow implicit conversions from the
+ // parameter to the argument, we need to spill the conversion to a
+ // variable and create a proxy write for the original variable.
+ if needs_conversion {
+ maybe_ty = Some(parameter_ty);
+ }
+
+ if let Some(ty) = maybe_ty {
+ // Create the spill variable
+ let spill_var = ctx.locals.append(
+ LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ Span::default(),
+ );
+ let spill_expr =
+ ctx.add_expression(Expression::LocalVariable(spill_var), Span::default(), body);
+
+ // If the argument is also copied in we must store the value of the
+ // original variable to the spill variable.
+ if let ParameterQualifier::InOut = parameter_info.qualifier {
+ body.push(
+ Statement::Store {
+ pointer: spill_expr,
+ value,
+ },
+ Span::default(),
+ );
+ }
+
+ // Add the spill variable as an argument to the function call
+ arguments.push(spill_expr);
+
+ let convert = if needs_conversion {
+ arg_scalar_comps
+ } else {
+ None
+ };
+
+ // Register the temporary local to be written back to it's original
+ // place after the function call
+ if let Expression::Swizzle {
+ size,
+ mut vector,
+ pattern,
+ } = ctx.expressions[original]
+ {
+ if let Expression::Load { pointer } = ctx.expressions[vector] {
+ vector = pointer;
+ }
+
+ for (i, component) in pattern.iter().take(size as usize).enumerate() {
+ let original = ctx.add_expression(
+ Expression::AccessIndex {
+ base: vector,
+ index: *component as u32,
+ },
+ Span::default(),
+ body,
+ );
+
+ let spill_component = ctx.add_expression(
+ Expression::AccessIndex {
+ base: spill_expr,
+ index: i as u32,
+ },
+ Span::default(),
+ body,
+ );
+
+ proxy_writes.push(ProxyWrite {
+ target: original,
+ value: spill_component,
+ convert,
+ });
+ }
+ } else {
+ proxy_writes.push(ProxyWrite {
+ target: original,
+ value: spill_expr,
+ convert,
+ });
+ }
+ } else {
+ arguments.push(original);
+ }
+
+ Ok(())
+ }
+
+ pub(crate) fn add_function(
+ &mut self,
+ ctx: Context,
+ name: String,
+ result: Option<FunctionResult>,
+ mut body: Block,
+ meta: Span,
+ ) {
+ ensure_block_returns(&mut body);
+
+ let void = result.is_none();
+
+ let &mut Frontend {
+ ref mut lookup_function,
+ ref mut module,
+ ..
+ } = self;
+
+ // Check if the passed arguments require any special variations
+ let mut variations =
+ builtin_required_variations(ctx.parameters.iter().map(|&arg| &module.types[arg].inner));
+
+ // Initiate the declaration if it wasn't previously initialized and inject builtins
+ let declaration = lookup_function.entry(name.clone()).or_insert_with(|| {
+ variations |= BuiltinVariations::STANDARD;
+ Default::default()
+ });
+ inject_builtin(declaration, module, &name, variations);
+
+ let Context {
+ expressions,
+ locals,
+ arguments,
+ parameters,
+ parameters_info,
+ ..
+ } = ctx;
+
+ let function = Function {
+ name: Some(name),
+ arguments,
+ result,
+ local_variables: locals,
+ expressions,
+ named_expressions: crate::NamedExpressions::default(),
+ body,
+ };
+
+ 'outer: for decl in declaration.overloads.iter_mut() {
+ if parameters.len() != decl.parameters.len() {
+ continue;
+ }
+
+ for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) {
+ let new_inner = &module.types[*new_parameter].inner;
+ let old_inner = &module.types[*old_parameter].inner;
+
+ if new_inner != old_inner {
+ continue 'outer;
+ }
+ }
+
+ if decl.defined {
+ return self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Function already defined".into()),
+ meta,
+ });
+ }
+
+ decl.defined = true;
+ decl.parameters_info = parameters_info;
+ match decl.kind {
+ FunctionKind::Call(handle) => *self.module.functions.get_mut(handle) = function,
+ FunctionKind::Macro(_) => {
+ let handle = module.functions.append(function, meta);
+ decl.kind = FunctionKind::Call(handle)
+ }
+ }
+ return;
+ }
+
+ let handle = module.functions.append(function, meta);
+ declaration.overloads.push(Overload {
+ parameters,
+ parameters_info,
+ kind: FunctionKind::Call(handle),
+ defined: true,
+ internal: false,
+ void,
+ });
+ }
+
+ pub(crate) fn add_prototype(
+ &mut self,
+ ctx: Context,
+ name: String,
+ result: Option<FunctionResult>,
+ meta: Span,
+ ) {
+ let void = result.is_none();
+
+ let &mut Frontend {
+ ref mut lookup_function,
+ ref mut module,
+ ..
+ } = self;
+
+ // Check if the passed arguments require any special variations
+ let mut variations =
+ builtin_required_variations(ctx.parameters.iter().map(|&arg| &module.types[arg].inner));
+
+ // Initiate the declaration if it wasn't previously initialized and inject builtins
+ let declaration = lookup_function.entry(name.clone()).or_insert_with(|| {
+ variations |= BuiltinVariations::STANDARD;
+ Default::default()
+ });
+ inject_builtin(declaration, module, &name, variations);
+
+ let Context {
+ arguments,
+ parameters,
+ parameters_info,
+ ..
+ } = ctx;
+
+ let function = Function {
+ name: Some(name),
+ arguments,
+ result,
+ ..Default::default()
+ };
+
+ 'outer: for decl in declaration.overloads.iter() {
+ if parameters.len() != decl.parameters.len() {
+ continue;
+ }
+
+ for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) {
+ let new_inner = &module.types[*new_parameter].inner;
+ let old_inner = &module.types[*old_parameter].inner;
+
+ if new_inner != old_inner {
+ continue 'outer;
+ }
+ }
+
+ return self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Prototype already defined".into()),
+ meta,
+ });
+ }
+
+ let handle = module.functions.append(function, meta);
+ declaration.overloads.push(Overload {
+ parameters,
+ parameters_info,
+ kind: FunctionKind::Call(handle),
+ defined: false,
+ internal: false,
+ void,
+ });
+ }
+
+ /// Helper function for building the input/output interface of the entry point
+ ///
+ /// Calls `f` with the data of the entry point argument, flattening composite types
+ /// recursively
+ ///
+ /// The passed arguments to the callback are:
+ /// - The name
+ /// - The pointer expression to the global storage
+ /// - The handle to the type of the entry point argument
+ /// - The binding of the entry point argument
+ /// - The expression arena
+ fn arg_type_walker(
+ &self,
+ name: Option<String>,
+ binding: crate::Binding,
+ pointer: Handle<Expression>,
+ ty: Handle<Type>,
+ expressions: &mut Arena<Expression>,
+ f: &mut impl FnMut(
+ Option<String>,
+ Handle<Expression>,
+ Handle<Type>,
+ crate::Binding,
+ &mut Arena<Expression>,
+ ),
+ ) {
+ match self.module.types[ty].inner {
+ TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(constant),
+ ..
+ } => {
+ let mut location = match binding {
+ crate::Binding::Location { location, .. } => location,
+ crate::Binding::BuiltIn(_) => return,
+ };
+
+ // TODO: Better error reporting
+ // right now we just don't walk the array if the size isn't known at
+ // compile time and let validation catch it
+ let size = match self.module.constants[constant].to_array_length() {
+ Some(val) => val,
+ None => return f(name, pointer, ty, binding, expressions),
+ };
+
+ let interpolation =
+ self.module.types[base]
+ .inner
+ .scalar_kind()
+ .map(|kind| match kind {
+ ScalarKind::Float => crate::Interpolation::Perspective,
+ _ => crate::Interpolation::Flat,
+ });
+
+ for index in 0..size {
+ let member_pointer = expressions.append(
+ Expression::AccessIndex {
+ base: pointer,
+ index,
+ },
+ crate::Span::default(),
+ );
+
+ let binding = crate::Binding::Location {
+ location,
+ interpolation,
+ sampling: None,
+ };
+ location += 1;
+
+ self.arg_type_walker(
+ name.clone(),
+ binding,
+ member_pointer,
+ base,
+ expressions,
+ f,
+ )
+ }
+ }
+ TypeInner::Struct { ref members, .. } => {
+ let mut location = match binding {
+ crate::Binding::Location { location, .. } => location,
+ crate::Binding::BuiltIn(_) => return,
+ };
+
+ for (i, member) in members.iter().enumerate() {
+ let member_pointer = expressions.append(
+ Expression::AccessIndex {
+ base: pointer,
+ index: i as u32,
+ },
+ crate::Span::default(),
+ );
+
+ let binding = match member.binding.clone() {
+ Some(binding) => binding,
+ None => {
+ let interpolation = self.module.types[member.ty]
+ .inner
+ .scalar_kind()
+ .map(|kind| match kind {
+ ScalarKind::Float => crate::Interpolation::Perspective,
+ _ => crate::Interpolation::Flat,
+ });
+ let binding = crate::Binding::Location {
+ location,
+ interpolation,
+ sampling: None,
+ };
+ location += 1;
+ binding
+ }
+ };
+
+ self.arg_type_walker(
+ member.name.clone(),
+ binding,
+ member_pointer,
+ member.ty,
+ expressions,
+ f,
+ )
+ }
+ }
+ _ => f(name, pointer, ty, binding, expressions),
+ }
+ }
+
+ pub(crate) fn add_entry_point(
+ &mut self,
+ function: Handle<Function>,
+ global_init_body: Block,
+ mut expressions: Arena<Expression>,
+ ) {
+ let mut arguments = Vec::new();
+ let mut body = Block::with_capacity(
+ // global init body
+ global_init_body.len() +
+ // prologue and epilogue
+ self.entry_args.len() * 2
+ // Call, Emit for composing struct and return
+ + 3,
+ );
+
+ for arg in self.entry_args.iter() {
+ if arg.storage != StorageQualifier::Input {
+ continue;
+ }
+
+ let pointer =
+ expressions.append(Expression::GlobalVariable(arg.handle), Default::default());
+
+ self.arg_type_walker(
+ arg.name.clone(),
+ arg.binding.clone(),
+ pointer,
+ self.module.global_variables[arg.handle].ty,
+ &mut expressions,
+ &mut |name, pointer, ty, binding, expressions| {
+ let idx = arguments.len() as u32;
+
+ arguments.push(FunctionArgument {
+ name,
+ ty,
+ binding: Some(binding),
+ });
+
+ let value =
+ expressions.append(Expression::FunctionArgument(idx), Default::default());
+ body.push(Statement::Store { pointer, value }, Default::default());
+ },
+ )
+ }
+
+ body.extend_block(global_init_body);
+
+ body.push(
+ Statement::Call {
+ function,
+ arguments: Vec::new(),
+ result: None,
+ },
+ Default::default(),
+ );
+
+ let mut span = 0;
+ let mut members = Vec::new();
+ let mut components = Vec::new();
+
+ for arg in self.entry_args.iter() {
+ if arg.storage != StorageQualifier::Output {
+ continue;
+ }
+
+ let pointer =
+ expressions.append(Expression::GlobalVariable(arg.handle), Default::default());
+
+ self.arg_type_walker(
+ arg.name.clone(),
+ arg.binding.clone(),
+ pointer,
+ self.module.global_variables[arg.handle].ty,
+ &mut expressions,
+ &mut |name, pointer, ty, binding, expressions| {
+ members.push(StructMember {
+ name,
+ ty,
+ binding: Some(binding),
+ offset: span,
+ });
+
+ span += self.module.types[ty].inner.size(&self.module.constants);
+
+ let len = expressions.len();
+ let load = expressions.append(Expression::Load { pointer }, Default::default());
+ body.push(
+ Statement::Emit(expressions.range_from(len)),
+ Default::default(),
+ );
+ components.push(load)
+ },
+ )
+ }
+
+ let (ty, value) = if !components.is_empty() {
+ let ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Struct { members, span },
+ },
+ Default::default(),
+ );
+
+ let len = expressions.len();
+ let res =
+ expressions.append(Expression::Compose { ty, components }, Default::default());
+ body.push(
+ Statement::Emit(expressions.range_from(len)),
+ Default::default(),
+ );
+
+ (Some(ty), Some(res))
+ } else {
+ (None, None)
+ };
+
+ body.push(Statement::Return { value }, Default::default());
+
+ self.module.entry_points.push(EntryPoint {
+ name: "main".to_string(),
+ stage: self.meta.stage,
+ early_depth_test: Some(crate::EarlyDepthTest { conservative: None })
+ .filter(|_| self.meta.early_fragment_tests),
+ workgroup_size: self.meta.workgroup_size,
+ function: Function {
+ arguments,
+ expressions,
+ body,
+ result: ty.map(|ty| FunctionResult { ty, binding: None }),
+ ..Default::default()
+ },
+ });
+ }
+}
+
+/// Helper enum containing the type of conversion need for a call
+#[derive(PartialEq, Eq, Clone, Copy, Debug)]
+enum Conversion {
+ /// No conversion needed
+ Exact,
+ /// Float to double conversion needed
+ FloatToDouble,
+ /// Int or uint to float conversion needed
+ IntToFloat,
+ /// Int or uint to double conversion needed
+ IntToDouble,
+ /// Other type of conversion needed
+ Other,
+ /// No conversion was yet registered
+ None,
+}
+
+/// Helper function, returns the type of conversion from `source` to `target`, if a
+/// conversion is not possible returns None.
+fn conversion(target: &TypeInner, source: &TypeInner) -> Option<Conversion> {
+ use ScalarKind::*;
+
+ // Gather the `ScalarKind` and scalar width from both the target and the source
+ let (target_kind, target_width, source_kind, source_width) = match (target, source) {
+ // Conversions between scalars are allowed
+ (
+ &TypeInner::Scalar {
+ kind: tgt_kind,
+ width: tgt_width,
+ },
+ &TypeInner::Scalar {
+ kind: src_kind,
+ width: src_width,
+ },
+ ) => (tgt_kind, tgt_width, src_kind, src_width),
+ // Conversions between vectors of the same size are allowed
+ (
+ &TypeInner::Vector {
+ kind: tgt_kind,
+ size: tgt_size,
+ width: tgt_width,
+ },
+ &TypeInner::Vector {
+ kind: src_kind,
+ size: src_size,
+ width: src_width,
+ },
+ ) if tgt_size == src_size => (tgt_kind, tgt_width, src_kind, src_width),
+ // Conversions between matrices of the same size are allowed
+ (
+ &TypeInner::Matrix {
+ rows: tgt_rows,
+ columns: tgt_cols,
+ width: tgt_width,
+ },
+ &TypeInner::Matrix {
+ rows: src_rows,
+ columns: src_cols,
+ width: src_width,
+ },
+ ) if tgt_cols == src_cols && tgt_rows == src_rows => (Float, tgt_width, Float, src_width),
+ _ => return None,
+ };
+
+ // Check if source can be converted into target, if this is the case then the type
+ // power of target must be higher than that of source
+ let target_power = type_power(target_kind, target_width);
+ let source_power = type_power(source_kind, source_width);
+ if target_power < source_power {
+ return None;
+ }
+
+ Some(
+ match ((target_kind, target_width), (source_kind, source_width)) {
+ // A conversion from a float to a double is special
+ ((Float, 8), (Float, 4)) => Conversion::FloatToDouble,
+ // A conversion from an integer to a float is special
+ ((Float, 4), (Sint | Uint, _)) => Conversion::IntToFloat,
+ // A conversion from an integer to a double is special
+ ((Float, 8), (Sint | Uint, _)) => Conversion::IntToDouble,
+ _ => Conversion::Other,
+ },
+ )
+}
+
+/// Helper method returning all the non standard builtin variations needed
+/// to process the function call with the passed arguments
+fn builtin_required_variations<'a>(args: impl Iterator<Item = &'a TypeInner>) -> BuiltinVariations {
+ let mut variations = BuiltinVariations::empty();
+
+ for ty in args {
+ match *ty {
+ TypeInner::ValuePointer { kind, width, .. }
+ | TypeInner::Scalar { kind, width }
+ | TypeInner::Vector { kind, width, .. } => {
+ if kind == ScalarKind::Float && width == 8 {
+ variations |= BuiltinVariations::DOUBLE
+ }
+ }
+ TypeInner::Matrix { width, .. } => {
+ if width == 8 {
+ variations |= BuiltinVariations::DOUBLE
+ }
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ if dim == crate::ImageDimension::Cube && arrayed {
+ variations |= BuiltinVariations::CUBE_TEXTURES_ARRAY
+ }
+
+ if dim == crate::ImageDimension::D2 && arrayed && class.is_multisampled() {
+ variations |= BuiltinVariations::D2_MULTI_TEXTURES_ARRAY
+ }
+ }
+ _ => {}
+ }
+ }
+
+ variations
+}
diff --git a/third_party/rust/naga/src/front/glsl/lex.rs b/third_party/rust/naga/src/front/glsl/lex.rs
new file mode 100644
index 0000000000..1b59a9bf3e
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/lex.rs
@@ -0,0 +1,301 @@
+use super::{
+ ast::Precision,
+ token::{Directive, DirectiveKind, Token, TokenValue},
+ types::parse_type,
+};
+use crate::{FastHashMap, Span, StorageAccess};
+use pp_rs::{
+ pp::Preprocessor,
+ token::{PreprocessorError, Punct, TokenValue as PPTokenValue},
+};
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct LexerResult {
+ pub kind: LexerResultKind,
+ pub meta: Span,
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum LexerResultKind {
+ Token(Token),
+ Directive(Directive),
+ Error(PreprocessorError),
+}
+
+pub struct Lexer<'a> {
+ pp: Preprocessor<'a>,
+}
+
+impl<'a> Lexer<'a> {
+ pub fn new(input: &'a str, defines: &'a FastHashMap<String, String>) -> Self {
+ let mut pp = Preprocessor::new(input);
+ for (define, value) in defines {
+ pp.add_define(define, value).unwrap(); //TODO: handle error
+ }
+ Lexer { pp }
+ }
+}
+
+impl<'a> Iterator for Lexer<'a> {
+ type Item = LexerResult;
+ fn next(&mut self) -> Option<Self::Item> {
+ let pp_token = match self.pp.next()? {
+ Ok(t) => t,
+ Err((err, loc)) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Error(err),
+ meta: loc.into(),
+ });
+ }
+ };
+
+ let meta = pp_token.location.into();
+ let value = match pp_token.value {
+ PPTokenValue::Extension(extension) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Extension,
+ tokens: extension.tokens,
+ }),
+ meta,
+ })
+ }
+ PPTokenValue::Float(float) => TokenValue::FloatConstant(float),
+ PPTokenValue::Ident(ident) => {
+ match ident.as_str() {
+ // Qualifiers
+ "layout" => TokenValue::Layout,
+ "in" => TokenValue::In,
+ "out" => TokenValue::Out,
+ "uniform" => TokenValue::Uniform,
+ "buffer" => TokenValue::Buffer,
+ "shared" => TokenValue::Shared,
+ "invariant" => TokenValue::Invariant,
+ "flat" => TokenValue::Interpolation(crate::Interpolation::Flat),
+ "noperspective" => TokenValue::Interpolation(crate::Interpolation::Linear),
+ "smooth" => TokenValue::Interpolation(crate::Interpolation::Perspective),
+ "centroid" => TokenValue::Sampling(crate::Sampling::Centroid),
+ "sample" => TokenValue::Sampling(crate::Sampling::Sample),
+ "const" => TokenValue::Const,
+ "inout" => TokenValue::InOut,
+ "precision" => TokenValue::Precision,
+ "highp" => TokenValue::PrecisionQualifier(Precision::High),
+ "mediump" => TokenValue::PrecisionQualifier(Precision::Medium),
+ "lowp" => TokenValue::PrecisionQualifier(Precision::Low),
+ "restrict" => TokenValue::Restrict,
+ "readonly" => TokenValue::MemoryQualifier(StorageAccess::LOAD),
+ "writeonly" => TokenValue::MemoryQualifier(StorageAccess::STORE),
+ // values
+ "true" => TokenValue::BoolConstant(true),
+ "false" => TokenValue::BoolConstant(false),
+ // jump statements
+ "continue" => TokenValue::Continue,
+ "break" => TokenValue::Break,
+ "return" => TokenValue::Return,
+ "discard" => TokenValue::Discard,
+ // selection statements
+ "if" => TokenValue::If,
+ "else" => TokenValue::Else,
+ "switch" => TokenValue::Switch,
+ "case" => TokenValue::Case,
+ "default" => TokenValue::Default,
+ // iteration statements
+ "while" => TokenValue::While,
+ "do" => TokenValue::Do,
+ "for" => TokenValue::For,
+ // types
+ "void" => TokenValue::Void,
+ "struct" => TokenValue::Struct,
+ word => match parse_type(word) {
+ Some(t) => TokenValue::TypeName(t),
+ None => TokenValue::Identifier(String::from(word)),
+ },
+ }
+ }
+ PPTokenValue::Integer(integer) => TokenValue::IntConstant(integer),
+ PPTokenValue::Punct(punct) => match punct {
+ // Compound assignments
+ Punct::AddAssign => TokenValue::AddAssign,
+ Punct::SubAssign => TokenValue::SubAssign,
+ Punct::MulAssign => TokenValue::MulAssign,
+ Punct::DivAssign => TokenValue::DivAssign,
+ Punct::ModAssign => TokenValue::ModAssign,
+ Punct::LeftShiftAssign => TokenValue::LeftShiftAssign,
+ Punct::RightShiftAssign => TokenValue::RightShiftAssign,
+ Punct::AndAssign => TokenValue::AndAssign,
+ Punct::XorAssign => TokenValue::XorAssign,
+ Punct::OrAssign => TokenValue::OrAssign,
+
+ // Two character punctuation
+ Punct::Increment => TokenValue::Increment,
+ Punct::Decrement => TokenValue::Decrement,
+ Punct::LogicalAnd => TokenValue::LogicalAnd,
+ Punct::LogicalOr => TokenValue::LogicalOr,
+ Punct::LogicalXor => TokenValue::LogicalXor,
+ Punct::LessEqual => TokenValue::LessEqual,
+ Punct::GreaterEqual => TokenValue::GreaterEqual,
+ Punct::EqualEqual => TokenValue::Equal,
+ Punct::NotEqual => TokenValue::NotEqual,
+ Punct::LeftShift => TokenValue::LeftShift,
+ Punct::RightShift => TokenValue::RightShift,
+
+ // Parenthesis or similar
+ Punct::LeftBrace => TokenValue::LeftBrace,
+ Punct::RightBrace => TokenValue::RightBrace,
+ Punct::LeftParen => TokenValue::LeftParen,
+ Punct::RightParen => TokenValue::RightParen,
+ Punct::LeftBracket => TokenValue::LeftBracket,
+ Punct::RightBracket => TokenValue::RightBracket,
+
+ // Other one character punctuation
+ Punct::LeftAngle => TokenValue::LeftAngle,
+ Punct::RightAngle => TokenValue::RightAngle,
+ Punct::Semicolon => TokenValue::Semicolon,
+ Punct::Comma => TokenValue::Comma,
+ Punct::Colon => TokenValue::Colon,
+ Punct::Dot => TokenValue::Dot,
+ Punct::Equal => TokenValue::Assign,
+ Punct::Bang => TokenValue::Bang,
+ Punct::Minus => TokenValue::Dash,
+ Punct::Tilde => TokenValue::Tilde,
+ Punct::Plus => TokenValue::Plus,
+ Punct::Star => TokenValue::Star,
+ Punct::Slash => TokenValue::Slash,
+ Punct::Percent => TokenValue::Percent,
+ Punct::Pipe => TokenValue::VerticalBar,
+ Punct::Caret => TokenValue::Caret,
+ Punct::Ampersand => TokenValue::Ampersand,
+ Punct::Question => TokenValue::Question,
+ },
+ PPTokenValue::Pragma(pragma) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Pragma,
+ tokens: pragma.tokens,
+ }),
+ meta,
+ })
+ }
+ PPTokenValue::Version(version) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Version {
+ is_first_directive: version.is_first_directive,
+ },
+ tokens: version.tokens,
+ }),
+ meta,
+ })
+ }
+ };
+
+ Some(LexerResult {
+ kind: LexerResultKind::Token(Token { value, meta }),
+ meta,
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use pp_rs::token::{Integer, Location, Token as PPToken, TokenValue as PPTokenValue};
+
+ use super::{
+ super::token::{Directive, DirectiveKind, Token, TokenValue},
+ Lexer, LexerResult, LexerResultKind,
+ };
+ use crate::Span;
+
+ #[test]
+ fn lex_tokens() {
+ let defines = crate::FastHashMap::default();
+
+ // line comments
+ let mut lex = Lexer::new("#version 450\nvoid main () {}", &defines);
+ let mut location = Location::default();
+ location.start = 9;
+ location.end = 12;
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Version {
+ is_first_directive: true
+ },
+ tokens: vec![PPToken {
+ value: PPTokenValue::Integer(Integer {
+ signed: true,
+ value: 450,
+ width: 32
+ }),
+ location
+ }]
+ }),
+ meta: Span::new(1, 8)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::Void,
+ meta: Span::new(13, 17)
+ }),
+ meta: Span::new(13, 17)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::Identifier("main".into()),
+ meta: Span::new(18, 22)
+ }),
+ meta: Span::new(18, 22)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::LeftParen,
+ meta: Span::new(23, 24)
+ }),
+ meta: Span::new(23, 24)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::RightParen,
+ meta: Span::new(24, 25)
+ }),
+ meta: Span::new(24, 25)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::LeftBrace,
+ meta: Span::new(26, 27)
+ }),
+ meta: Span::new(26, 27)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::RightBrace,
+ meta: Span::new(27, 28)
+ }),
+ meta: Span::new(27, 28)
+ }
+ );
+ assert_eq!(lex.next(), None);
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/mod.rs b/third_party/rust/naga/src/front/glsl/mod.rs
new file mode 100644
index 0000000000..a1aa4622d3
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/mod.rs
@@ -0,0 +1,235 @@
+/*!
+Frontend for [GLSL][glsl] (OpenGL Shading Language).
+
+To begin, take a look at the documentation for the [`Frontend`].
+
+# Supported versions
+## Vulkan
+- 440 (partial)
+- 450
+- 460
+
+[glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php
+*/
+
+pub use ast::{Precision, Profile};
+pub use error::{Error, ErrorKind, ExpectedToken};
+pub use token::TokenValue;
+
+use crate::{proc::Layouter, FastHashMap, FastHashSet, Handle, Module, ShaderStage, Span, Type};
+use ast::{EntryArg, FunctionDeclaration, GlobalLookup};
+use parser::ParsingContext;
+
+mod ast;
+mod builtins;
+mod constants;
+mod context;
+mod error;
+mod functions;
+mod lex;
+mod offset;
+mod parser;
+#[cfg(test)]
+mod parser_tests;
+mod token;
+mod types;
+mod variables;
+
+type Result<T> = std::result::Result<T, Error>;
+
+/// Per-shader options passed to [`parse`](Frontend::parse).
+///
+/// The [`From`](From) trait is implemented for [`ShaderStage`](ShaderStage) to
+/// provide a quick way to create a Options instance.
+/// ```rust
+/// # use naga::ShaderStage;
+/// # use naga::front::glsl::Options;
+/// Options::from(ShaderStage::Vertex);
+/// ```
+#[derive(Debug)]
+pub struct Options {
+ /// The shader stage in the pipeline.
+ pub stage: ShaderStage,
+ /// Preprocesor definitions to be used, akin to having
+ /// ```glsl
+ /// #define key value
+ /// ```
+ /// for each key value pair in the map.
+ pub defines: FastHashMap<String, String>,
+}
+
+impl From<ShaderStage> for Options {
+ fn from(stage: ShaderStage) -> Self {
+ Options {
+ stage,
+ defines: FastHashMap::default(),
+ }
+ }
+}
+
+/// Additional information about the GLSL shader.
+///
+/// Stores additional information about the GLSL shader which might not be
+/// stored in the shader [`Module`](Module).
+#[derive(Debug)]
+pub struct ShaderMetadata {
+ /// The GLSL version specified in the shader through the use of the
+ /// `#version` preprocessor directive.
+ pub version: u16,
+ /// The GLSL profile specified in the shader through the use of the
+ /// `#version` preprocessor directive.
+ pub profile: Profile,
+ /// The shader stage in the pipeline, passed to the [`parse`](Frontend::parse)
+ /// method via the [`Options`](Options) struct.
+ pub stage: ShaderStage,
+
+ /// The workgroup size for compute shaders, defaults to `[1; 3]` for
+ /// compute shaders and `[0; 3]` for non compute shaders.
+ pub workgroup_size: [u32; 3],
+ /// Whether or not early fragment tests where requested by the shader.
+ /// Defaults to `false`.
+ pub early_fragment_tests: bool,
+
+ /// The shader can request extensions via the
+ /// `#extension` preprocessor directive, in the directive a behavior
+ /// parameter is used to control whether the extension should be disabled,
+ /// warn on usage, enabled if possible or required.
+ ///
+ /// This field only stores extensions which were required or requested to
+ /// be enabled if possible and they are supported.
+ pub extensions: FastHashSet<String>,
+}
+
+impl ShaderMetadata {
+ fn reset(&mut self, stage: ShaderStage) {
+ self.version = 0;
+ self.profile = Profile::Core;
+ self.stage = stage;
+ self.workgroup_size = [u32::from(stage == ShaderStage::Compute); 3];
+ self.early_fragment_tests = false;
+ self.extensions.clear();
+ }
+}
+
+impl Default for ShaderMetadata {
+ fn default() -> Self {
+ ShaderMetadata {
+ version: 0,
+ profile: Profile::Core,
+ stage: ShaderStage::Vertex,
+ workgroup_size: [0; 3],
+ early_fragment_tests: false,
+ extensions: FastHashSet::default(),
+ }
+ }
+}
+
+/// The `Frontend` is the central structure of the GLSL frontend.
+///
+/// To instantiate a new `Frontend` the [`Default`](Default) trait is used, so a
+/// call to the associated function [`Frontend::default`](Frontend::default) will
+/// return a new `Frontend` instance.
+///
+/// To parse a shader simply call the [`parse`](Frontend::parse) method with a
+/// [`Options`](Options) struct and a [`&str`](str) holding the glsl code.
+///
+/// The `Frontend` also provides the [`metadata`](Frontend::metadata) to get some
+/// further information about the previously parsed shader, like version and
+/// extensions used (see the documentation for
+/// [`ShaderMetadata`](ShaderMetadata) to see all the returned information)
+///
+/// # Example usage
+/// ```rust
+/// use naga::ShaderStage;
+/// use naga::front::glsl::{Frontend, Options};
+///
+/// let glsl = r#"
+/// #version 450 core
+///
+/// void main() {}
+/// "#;
+///
+/// let mut frontend = Frontend::default();
+/// let options = Options::from(ShaderStage::Vertex);
+/// frontend.parse(&options, glsl);
+/// ```
+///
+/// # Reusability
+///
+/// If there's a need to parse more than one shader reusing the same `Frontend`
+/// instance may be beneficial since internal allocations will be reused.
+///
+/// Calling the [`parse`](Frontend::parse) method multiple times will reset the
+/// `Frontend` so no extra care is needed when reusing.
+#[derive(Debug, Default)]
+pub struct Frontend {
+ meta: ShaderMetadata,
+
+ lookup_function: FastHashMap<String, FunctionDeclaration>,
+ lookup_type: FastHashMap<String, Handle<Type>>,
+
+ global_variables: Vec<(String, GlobalLookup)>,
+
+ entry_args: Vec<EntryArg>,
+
+ layouter: Layouter,
+
+ errors: Vec<Error>,
+
+ module: Module,
+}
+
+impl Frontend {
+ fn reset(&mut self, stage: ShaderStage) {
+ self.meta.reset(stage);
+
+ self.lookup_function.clear();
+ self.lookup_type.clear();
+ self.global_variables.clear();
+ self.entry_args.clear();
+ self.layouter.clear();
+
+ // This is necessary because if the last parsing errored out, the module
+ // wouldn't have been taken
+ self.module = Module::default();
+ }
+
+ /// Parses a shader either outputting a shader [`Module`](Module) or a list
+ /// of [`Error`](Error)s.
+ ///
+ /// Multiple calls using the same `Frontend` and different shaders are supported.
+ pub fn parse(
+ &mut self,
+ options: &Options,
+ source: &str,
+ ) -> std::result::Result<Module, Vec<Error>> {
+ self.reset(options.stage);
+
+ let lexer = lex::Lexer::new(source, &options.defines);
+ let mut ctx = ParsingContext::new(lexer);
+
+ if let Err(e) = ctx.parse(self) {
+ self.errors.push(e);
+ }
+
+ if self.errors.is_empty() {
+ Ok(std::mem::take(&mut self.module))
+ } else {
+ Err(std::mem::take(&mut self.errors))
+ }
+ }
+
+ /// Returns additional information about the parsed shader which might not be
+ /// stored in the [`Module`](Module), see the documentation for
+ /// [`ShaderMetadata`](ShaderMetadata) for more information about the
+ /// returned data.
+ ///
+ /// # Notes
+ ///
+ /// Following an unsuccessful parsing the state of the returned information
+ /// is undefined, it might contain only partial information about the
+ /// current shader, the previous shader or both.
+ pub const fn metadata(&self) -> &ShaderMetadata {
+ &self.meta
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/offset.rs b/third_party/rust/naga/src/front/glsl/offset.rs
new file mode 100644
index 0000000000..d7950489cb
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/offset.rs
@@ -0,0 +1,173 @@
+/*!
+Module responsible for calculating the offset and span for types.
+
+There exists two types of layouts std140 and std430 (there's technically
+two more layouts, shared and packed. Shared is not supported by spirv. Packed is
+implementation dependent and for now it's just implemented as an alias to
+std140).
+
+The OpenGl spec (the layout rules are defined by the OpenGl spec in section
+7.6.2.2 as opposed to the GLSL spec) uses the term basic machine units which are
+equivalent to bytes.
+*/
+
+use super::{
+ ast::StructLayout,
+ error::{Error, ErrorKind},
+ Span,
+};
+use crate::{proc::Alignment, Arena, Constant, Handle, Type, TypeInner, UniqueArena};
+
+/// Struct with information needed for defining a struct member.
+///
+/// Returned by [`calculate_offset`](calculate_offset)
+#[derive(Debug)]
+pub struct TypeAlignSpan {
+ /// The handle to the type, this might be the same handle passed to
+ /// [`calculate_offset`](calculate_offset) or a new such a new array type
+ /// with a different stride set.
+ pub ty: Handle<Type>,
+ /// The alignment required by the type.
+ pub align: Alignment,
+ /// The size of the type.
+ pub span: u32,
+}
+
+/// Returns the type, alignment and span of a struct member according to a [`StructLayout`](StructLayout).
+///
+/// The functions returns a [`TypeAlignSpan`](TypeAlignSpan) which has a `ty` member
+/// this should be used as the struct member type because for example arrays may have to
+/// change the stride and as such need to have a different type.
+pub fn calculate_offset(
+ mut ty: Handle<Type>,
+ meta: Span,
+ layout: StructLayout,
+ types: &mut UniqueArena<Type>,
+ constants: &Arena<Constant>,
+ errors: &mut Vec<Error>,
+) -> TypeAlignSpan {
+ // When using the std430 storage layout, shader storage blocks will be laid out in buffer storage
+ // identically to uniform and shader storage blocks using the std140 layout, except
+ // that the base alignment and stride of arrays of scalars and vectors in rule 4 and of
+ // structures in rule 9 are not rounded up a multiple of the base alignment of a vec4.
+
+ let (align, span) = match types[ty].inner {
+ // 1. If the member is a scalar consuming N basic machine units,
+ // the base alignment is N.
+ TypeInner::Scalar { width, .. } => (Alignment::from_width(width), width as u32),
+ // 2. If the member is a two- or four-component vector with components
+ // consuming N basic machine units, the base alignment is 2N or 4N, respectively.
+ // 3. If the member is a three-component vector with components consuming N
+ // basic machine units, the base alignment is 4N.
+ TypeInner::Vector { size, width, .. } => (
+ Alignment::from(size) * Alignment::from_width(width),
+ size as u32 * width as u32,
+ ),
+ // 4. If the member is an array of scalars or vectors, the base alignment and array
+ // stride are set to match the base alignment of a single array element, according
+ // to rules (1), (2), and (3), and rounded up to the base alignment of a vec4.
+ // TODO: Matrices array
+ TypeInner::Array { base, size, .. } => {
+ let info = calculate_offset(base, meta, layout, types, constants, errors);
+
+ let name = types[ty].name.clone();
+
+ // See comment at the beginning of the function
+ let (align, stride) = if StructLayout::Std430 == layout {
+ (info.align, info.align.round_up(info.span))
+ } else {
+ let align = info.align.max(Alignment::MIN_UNIFORM);
+ (align, align.round_up(info.span))
+ };
+
+ let span = match size {
+ crate::ArraySize::Constant(s) => {
+ constants[s].to_array_length().unwrap_or(1) * stride
+ }
+ crate::ArraySize::Dynamic => stride,
+ };
+
+ let ty_span = types.get_span(ty);
+ ty = types.insert(
+ Type {
+ name,
+ inner: TypeInner::Array {
+ base: info.ty,
+ size,
+ stride,
+ },
+ },
+ ty_span,
+ );
+
+ (align, span)
+ }
+ // 5. If the member is a column-major matrix with C columns and R rows, the
+ // matrix is stored identically to an array of C column vectors with R
+ // components each, according to rule (4)
+ // TODO: Row major matrices
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let mut align = Alignment::from(rows) * Alignment::from_width(width);
+
+ // See comment at the beginning of the function
+ if StructLayout::Std430 != layout {
+ align = align.max(Alignment::MIN_UNIFORM);
+ }
+
+ // See comment on the error kind
+ if StructLayout::Std140 == layout && rows == crate::VectorSize::Bi {
+ errors.push(Error {
+ kind: ErrorKind::UnsupportedMatrixTypeInStd140,
+ meta,
+ });
+ }
+
+ (align, align * columns as u32)
+ }
+ TypeInner::Struct { ref members, .. } => {
+ let mut span = 0;
+ let mut align = Alignment::ONE;
+ let mut members = members.clone();
+ let name = types[ty].name.clone();
+
+ for member in members.iter_mut() {
+ let info = calculate_offset(member.ty, meta, layout, types, constants, errors);
+
+ let member_alignment = info.align;
+ span = member_alignment.round_up(span);
+ align = member_alignment.max(align);
+
+ member.ty = info.ty;
+ member.offset = span;
+
+ span += info.span;
+ }
+
+ span = align.round_up(span);
+
+ let ty_span = types.get_span(ty);
+ ty = types.insert(
+ Type {
+ name,
+ inner: TypeInner::Struct { members, span },
+ },
+ ty_span,
+ );
+
+ (align, span)
+ }
+ _ => {
+ errors.push(Error {
+ kind: ErrorKind::SemanticError("Invalid struct member type".into()),
+ meta,
+ });
+ (Alignment::ONE, 0)
+ }
+ };
+
+ TypeAlignSpan { ty, align, span }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser.rs b/third_party/rust/naga/src/front/glsl/parser.rs
new file mode 100644
index 0000000000..ed139e799d
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser.rs
@@ -0,0 +1,448 @@
+use super::{
+ ast::{FunctionKind, Profile, TypeQualifiers},
+ context::{Context, ExprPos},
+ error::ExpectedToken,
+ error::{Error, ErrorKind},
+ lex::{Lexer, LexerResultKind},
+ token::{Directive, DirectiveKind},
+ token::{Token, TokenValue},
+ variables::{GlobalOrConstant, VarDeclaration},
+ Frontend, Result,
+};
+use crate::{arena::Handle, Block, Constant, ConstantInner, Expression, ScalarValue, Span, Type};
+use core::convert::TryFrom;
+use pp_rs::token::{PreprocessorError, Token as PPToken, TokenValue as PPTokenValue};
+use std::iter::Peekable;
+
+mod declarations;
+mod expressions;
+mod functions;
+mod types;
+
+pub struct ParsingContext<'source> {
+ lexer: Peekable<Lexer<'source>>,
+ /// Used to store tokens already consumed by the parser but that need to be backtracked
+ backtracked_token: Option<Token>,
+ last_meta: Span,
+}
+
+impl<'source> ParsingContext<'source> {
+ pub fn new(lexer: Lexer<'source>) -> Self {
+ ParsingContext {
+ lexer: lexer.peekable(),
+ backtracked_token: None,
+ last_meta: Span::default(),
+ }
+ }
+
+ /// Helper method for backtracking from a consumed token
+ ///
+ /// This method should always be used instead of assigning to `backtracked_token` since
+ /// it validates that backtracking hasn't occurred more than one time in a row
+ ///
+ /// # Panics
+ /// - If the parser already backtracked without bumping in between
+ pub fn backtrack(&mut self, token: Token) -> Result<()> {
+ // This should never happen
+ if let Some(ref prev_token) = self.backtracked_token {
+ return Err(Error {
+ kind: ErrorKind::InternalError("The parser tried to backtrack twice in a row"),
+ meta: prev_token.meta,
+ });
+ }
+
+ self.backtracked_token = Some(token);
+
+ Ok(())
+ }
+
+ pub fn expect_ident(&mut self, frontend: &mut Frontend) -> Result<(String, Span)> {
+ let token = self.bump(frontend)?;
+
+ match token.value {
+ TokenValue::Identifier(name) => Ok((name, token.meta)),
+ _ => Err(Error {
+ kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]),
+ meta: token.meta,
+ }),
+ }
+ }
+
+ pub fn expect(&mut self, frontend: &mut Frontend, value: TokenValue) -> Result<Token> {
+ let token = self.bump(frontend)?;
+
+ if token.value != value {
+ Err(Error {
+ kind: ErrorKind::InvalidToken(token.value, vec![value.into()]),
+ meta: token.meta,
+ })
+ } else {
+ Ok(token)
+ }
+ }
+
+ pub fn next(&mut self, frontend: &mut Frontend) -> Option<Token> {
+ loop {
+ if let Some(token) = self.backtracked_token.take() {
+ self.last_meta = token.meta;
+ break Some(token);
+ }
+
+ let res = self.lexer.next()?;
+
+ match res.kind {
+ LexerResultKind::Token(token) => {
+ self.last_meta = token.meta;
+ break Some(token);
+ }
+ LexerResultKind::Directive(directive) => {
+ frontend.handle_directive(directive, res.meta)
+ }
+ LexerResultKind::Error(error) => frontend.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(error),
+ meta: res.meta,
+ }),
+ }
+ }
+ }
+
+ pub fn bump(&mut self, frontend: &mut Frontend) -> Result<Token> {
+ self.next(frontend).ok_or(Error {
+ kind: ErrorKind::EndOfFile,
+ meta: self.last_meta,
+ })
+ }
+
+ /// Returns None on the end of the file rather than an error like other methods
+ pub fn bump_if(&mut self, frontend: &mut Frontend, value: TokenValue) -> Option<Token> {
+ if self.peek(frontend).filter(|t| t.value == value).is_some() {
+ self.bump(frontend).ok()
+ } else {
+ None
+ }
+ }
+
+ pub fn peek(&mut self, frontend: &mut Frontend) -> Option<&Token> {
+ loop {
+ if let Some(ref token) = self.backtracked_token {
+ break Some(token);
+ }
+
+ match self.lexer.peek()?.kind {
+ LexerResultKind::Token(_) => {
+ let res = self.lexer.peek()?;
+
+ match res.kind {
+ LexerResultKind::Token(ref token) => break Some(token),
+ _ => unreachable!(),
+ }
+ }
+ LexerResultKind::Error(_) | LexerResultKind::Directive(_) => {
+ let res = self.lexer.next()?;
+
+ match res.kind {
+ LexerResultKind::Directive(directive) => {
+ frontend.handle_directive(directive, res.meta)
+ }
+ LexerResultKind::Error(error) => frontend.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(error),
+ meta: res.meta,
+ }),
+ LexerResultKind::Token(_) => unreachable!(),
+ }
+ }
+ }
+ }
+ }
+
+ pub fn expect_peek(&mut self, frontend: &mut Frontend) -> Result<&Token> {
+ let meta = self.last_meta;
+ self.peek(frontend).ok_or(Error {
+ kind: ErrorKind::EndOfFile,
+ meta,
+ })
+ }
+
+ pub fn parse(&mut self, frontend: &mut Frontend) -> Result<()> {
+ // Body and expression arena for global initialization
+ let mut body = Block::new();
+ let mut ctx = Context::new(frontend, &mut body);
+
+ while self.peek(frontend).is_some() {
+ self.parse_external_declaration(frontend, &mut ctx, &mut body)?;
+ }
+
+ // Add an `EntryPoint` to `parser.module` for `main`, if a
+ // suitable overload exists. Error out if we can't find one.
+ if let Some(declaration) = frontend.lookup_function.get("main") {
+ for decl in declaration.overloads.iter() {
+ if let FunctionKind::Call(handle) = decl.kind {
+ if decl.defined && decl.parameters.is_empty() {
+ frontend.add_entry_point(handle, body, ctx.expressions);
+ return Ok(());
+ }
+ }
+ }
+ }
+
+ Err(Error {
+ kind: ErrorKind::SemanticError("Missing entry point".into()),
+ meta: Span::default(),
+ })
+ }
+
+ fn parse_uint_constant(&mut self, frontend: &mut Frontend) -> Result<(u32, Span)> {
+ let (value, meta) = self.parse_constant_expression(frontend)?;
+
+ let int = match frontend.module.constants[value].inner {
+ ConstantInner::Scalar {
+ value: ScalarValue::Uint(int),
+ ..
+ } => u32::try_from(int).map_err(|_| Error {
+ kind: ErrorKind::SemanticError("int constant overflows".into()),
+ meta,
+ })?,
+ ConstantInner::Scalar {
+ value: ScalarValue::Sint(int),
+ ..
+ } => u32::try_from(int).map_err(|_| Error {
+ kind: ErrorKind::SemanticError("int constant overflows".into()),
+ meta,
+ })?,
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError("Expected a uint constant".into()),
+ meta,
+ })
+ }
+ };
+
+ Ok((int, meta))
+ }
+
+ fn parse_constant_expression(
+ &mut self,
+ frontend: &mut Frontend,
+ ) -> Result<(Handle<Constant>, Span)> {
+ let mut block = Block::new();
+
+ let mut ctx = Context::new(frontend, &mut block);
+
+ let mut stmt_ctx = ctx.stmt_ctx();
+ let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, &mut block, None)?;
+ let (root, meta) = ctx.lower_expect(stmt_ctx, frontend, expr, ExprPos::Rhs, &mut block)?;
+
+ Ok((frontend.solve_constant(&ctx, root, meta)?, meta))
+ }
+}
+
+impl Frontend {
+ fn handle_directive(&mut self, directive: Directive, meta: Span) {
+ let mut tokens = directive.tokens.into_iter();
+
+ match directive.kind {
+ DirectiveKind::Version { is_first_directive } => {
+ if !is_first_directive {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "#version must occur first in shader".into(),
+ ),
+ meta,
+ })
+ }
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Integer(int),
+ location,
+ }) => match int.value {
+ 440 | 450 | 460 => self.meta.version = int.value as u16,
+ _ => self.errors.push(Error {
+ kind: ErrorKind::InvalidVersion(int.value),
+ meta: location.into(),
+ }),
+ },
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine),
+ meta,
+ }),
+ };
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Ident(name),
+ location,
+ }) => match name.as_str() {
+ "core" => self.meta.profile = Profile::Core,
+ _ => self.errors.push(Error {
+ kind: ErrorKind::InvalidProfile(name),
+ meta: location.into(),
+ }),
+ },
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => {}
+ };
+
+ if let Some(PPToken { value, location }) = tokens.next() {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ })
+ }
+ }
+ DirectiveKind::Extension => {
+ // TODO: Proper extension handling
+ // - Checking for extension support in the compiler
+ // - Handle behaviors such as warn
+ // - Handle the all extension
+ let name = match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Ident(name),
+ ..
+ }) => Some(name),
+ Some(PPToken { value, location }) => {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ });
+
+ None
+ }
+ None => {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(
+ PreprocessorError::UnexpectedNewLine,
+ ),
+ meta,
+ });
+
+ None
+ }
+ };
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Punct(pp_rs::token::Punct::Colon),
+ ..
+ }) => {}
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine),
+ meta,
+ }),
+ };
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Ident(behavior),
+ location,
+ }) => match behavior.as_str() {
+ "require" | "enable" | "warn" | "disable" => {
+ if let Some(name) = name {
+ self.meta.extensions.insert(name);
+ }
+ }
+ _ => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ PPTokenValue::Ident(behavior),
+ )),
+ meta: location.into(),
+ }),
+ },
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine),
+ meta,
+ }),
+ }
+
+ if let Some(PPToken { value, location }) = tokens.next() {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ })
+ }
+ }
+ DirectiveKind::Pragma => {
+ // TODO: handle some common pragmas?
+ }
+ }
+ }
+}
+
+pub struct DeclarationContext<'ctx, 'qualifiers> {
+ qualifiers: TypeQualifiers<'qualifiers>,
+ /// Indicates a global declaration
+ external: bool,
+
+ ctx: &'ctx mut Context,
+ body: &'ctx mut Block,
+}
+
+impl<'ctx, 'qualifiers> DeclarationContext<'ctx, 'qualifiers> {
+ fn add_var(
+ &mut self,
+ frontend: &mut Frontend,
+ ty: Handle<Type>,
+ name: String,
+ init: Option<Handle<Constant>>,
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let decl = VarDeclaration {
+ qualifiers: &mut self.qualifiers,
+ ty,
+ name: Some(name),
+ init,
+ meta,
+ };
+
+ match self.external {
+ true => {
+ let global = frontend.add_global_var(self.ctx, self.body, decl)?;
+ let expr = match global {
+ GlobalOrConstant::Global(handle) => Expression::GlobalVariable(handle),
+ GlobalOrConstant::Constant(handle) => Expression::Constant(handle),
+ };
+ Ok(self.ctx.add_expression(expr, meta, self.body))
+ }
+ false => frontend.add_local_var(self.ctx, self.body, decl),
+ }
+ }
+
+ /// Emits all the expressions captured by the emitter and starts the emitter again
+ ///
+ /// Alias to [`emit_restart`] with the declaration body
+ ///
+ /// [`emit_restart`]: Context::emit_restart
+ #[inline]
+ fn flush_expressions(&mut self) {
+ self.ctx.emit_restart(self.body);
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/declarations.rs b/third_party/rust/naga/src/front/glsl/parser/declarations.rs
new file mode 100644
index 0000000000..9085469dda
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/declarations.rs
@@ -0,0 +1,671 @@
+use crate::{
+ front::glsl::{
+ ast::{
+ GlobalLookup, GlobalLookupKind, Precision, QualifierKey, QualifierValue,
+ StorageQualifier, StructLayout, TypeQualifiers,
+ },
+ context::{Context, ExprPos},
+ error::ExpectedToken,
+ offset,
+ token::{Token, TokenValue},
+ types::scalar_components,
+ variables::{GlobalOrConstant, VarDeclaration},
+ Error, ErrorKind, Frontend, Span,
+ },
+ proc::Alignment,
+ AddressSpace, Block, Expression, FunctionResult, Handle, ScalarKind, Statement, StructMember,
+ Type, TypeInner,
+};
+
+use super::{DeclarationContext, ParsingContext, Result};
+
+/// Helper method used to retrieve the child type of `ty` at
+/// index `i`.
+///
+/// # Note
+///
+/// Does not check if the index is valid and returns the same type
+/// when indexing out-of-bounds a struct or indexing a non indexable
+/// type.
+fn element_or_member_type(
+ ty: Handle<Type>,
+ i: usize,
+ types: &mut crate::UniqueArena<Type>,
+) -> Handle<Type> {
+ match types[ty].inner {
+ // The child type of a vector is a scalar of the same kind and width
+ TypeInner::Vector { kind, width, .. } => types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar { kind, width },
+ },
+ Default::default(),
+ ),
+ // The child type of a matrix is a vector of floats with the same
+ // width and the size of the matrix rows.
+ TypeInner::Matrix { rows, width, .. } => types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: rows,
+ kind: ScalarKind::Float,
+ width,
+ },
+ },
+ Default::default(),
+ ),
+ // The child type of an array is the base type of the array
+ TypeInner::Array { base, .. } => base,
+ // The child type of a struct at index `i` is the type of it's
+ // member at that same index.
+ //
+ // In case the index is out of bounds the same type is returned
+ TypeInner::Struct { ref members, .. } => {
+ members.get(i).map(|member| member.ty).unwrap_or(ty)
+ }
+ // The type isn't indexable, the same type is returned
+ _ => ty,
+ }
+}
+
+impl<'source> ParsingContext<'source> {
+ pub fn parse_external_declaration(
+ &mut self,
+ frontend: &mut Frontend,
+ global_ctx: &mut Context,
+ global_body: &mut Block,
+ ) -> Result<()> {
+ if self
+ .parse_declaration(frontend, global_ctx, global_body, true)?
+ .is_none()
+ {
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Semicolon if frontend.meta.version == 460 => Ok(()),
+ _ => {
+ let expected = match frontend.meta.version {
+ 460 => vec![TokenValue::Semicolon.into(), ExpectedToken::Eof],
+ _ => vec![ExpectedToken::Eof],
+ };
+ Err(Error {
+ kind: ErrorKind::InvalidToken(token.value, expected),
+ meta: token.meta,
+ })
+ }
+ }
+ } else {
+ Ok(())
+ }
+ }
+
+ pub fn parse_initializer(
+ &mut self,
+ frontend: &mut Frontend,
+ ty: Handle<Type>,
+ ctx: &mut Context,
+ body: &mut Block,
+ ) -> Result<(Handle<Expression>, Span)> {
+ // initializer:
+ // assignment_expression
+ // LEFT_BRACE initializer_list RIGHT_BRACE
+ // LEFT_BRACE initializer_list COMMA RIGHT_BRACE
+ //
+ // initializer_list:
+ // initializer
+ // initializer_list COMMA initializer
+ if let Some(Token { mut meta, .. }) = self.bump_if(frontend, TokenValue::LeftBrace) {
+ // initializer_list
+ let mut components = Vec::new();
+ loop {
+ // The type expected to be parsed inside the initializer list
+ let new_ty =
+ element_or_member_type(ty, components.len(), &mut frontend.module.types);
+
+ components.push(self.parse_initializer(frontend, new_ty, ctx, body)?.0);
+
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Comma => {
+ if let Some(Token { meta: end_meta, .. }) =
+ self.bump_if(frontend, TokenValue::RightBrace)
+ {
+ meta.subsume(end_meta);
+ break;
+ }
+ }
+ TokenValue::RightBrace => {
+ meta.subsume(token.meta);
+ break;
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Comma.into(), TokenValue::RightBrace.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ }
+ }
+
+ Ok((
+ ctx.add_expression(Expression::Compose { ty, components }, meta, body),
+ meta,
+ ))
+ } else {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_assignment(frontend, ctx, &mut stmt, body)?;
+ let (mut init, init_meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
+
+ let scalar_components = scalar_components(&frontend.module.types[ty].inner);
+ if let Some((kind, width)) = scalar_components {
+ ctx.implicit_conversion(frontend, &mut init, init_meta, kind, width)?;
+ }
+
+ Ok((init, init_meta))
+ }
+ }
+
+ // Note: caller preparsed the type and qualifiers
+ // Note: caller skips this if the fallthrough token is not expected to be consumed here so this
+ // produced Error::InvalidToken if it isn't consumed
+ pub fn parse_init_declarator_list(
+ &mut self,
+ frontend: &mut Frontend,
+ mut ty: Handle<Type>,
+ ctx: &mut DeclarationContext,
+ ) -> Result<()> {
+ // init_declarator_list:
+ // single_declaration
+ // init_declarator_list COMMA IDENTIFIER
+ // init_declarator_list COMMA IDENTIFIER array_specifier
+ // init_declarator_list COMMA IDENTIFIER array_specifier EQUAL initializer
+ // init_declarator_list COMMA IDENTIFIER EQUAL initializer
+ //
+ // single_declaration:
+ // fully_specified_type
+ // fully_specified_type IDENTIFIER
+ // fully_specified_type IDENTIFIER array_specifier
+ // fully_specified_type IDENTIFIER array_specifier EQUAL initializer
+ // fully_specified_type IDENTIFIER EQUAL initializer
+
+ // Consume any leading comma, e.g. this is valid: `float, a=1;`
+ if self
+ .peek(frontend)
+ .map_or(false, |t| t.value == TokenValue::Comma)
+ {
+ self.next(frontend);
+ }
+
+ loop {
+ let token = self.bump(frontend)?;
+ let name = match token.value {
+ TokenValue::Semicolon => break,
+ TokenValue::Identifier(name) => name,
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ };
+ let mut meta = token.meta;
+
+ // array_specifier
+ // array_specifier EQUAL initializer
+ // EQUAL initializer
+
+ // parse an array specifier if it exists
+ // NOTE: unlike other parse methods this one doesn't expect an array specifier and
+ // returns Ok(None) rather than an error if there is not one
+ self.parse_array_specifier(frontend, &mut meta, &mut ty)?;
+
+ let init = self
+ .bump_if(frontend, TokenValue::Assign)
+ .map::<Result<_>, _>(|_| {
+ let (mut expr, init_meta) =
+ self.parse_initializer(frontend, ty, ctx.ctx, ctx.body)?;
+
+ let scalar_components = scalar_components(&frontend.module.types[ty].inner);
+ if let Some((kind, width)) = scalar_components {
+ ctx.ctx
+ .implicit_conversion(frontend, &mut expr, init_meta, kind, width)?;
+ }
+
+ meta.subsume(init_meta);
+
+ Ok((expr, init_meta))
+ })
+ .transpose()?;
+
+ let is_const = ctx.qualifiers.storage.0 == StorageQualifier::Const;
+ let maybe_constant = if ctx.external {
+ if let Some((root, meta)) = init {
+ match frontend.solve_constant(ctx.ctx, root, meta) {
+ Ok(res) => Some(res),
+ // If the declaration is external (global scope) and is constant qualified
+ // then the initializer must be a constant expression
+ Err(err) if is_const => return Err(err),
+ _ => None,
+ }
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ let pointer = ctx.add_var(frontend, ty, name, maybe_constant, meta)?;
+
+ if let Some((value, _)) = init.filter(|_| maybe_constant.is_none()) {
+ ctx.flush_expressions();
+ ctx.body.push(Statement::Store { pointer, value }, meta);
+ }
+
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Semicolon => break,
+ TokenValue::Comma => {}
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Comma.into(), TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ /// `external` whether or not we are in a global or local context
+ pub fn parse_declaration(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ body: &mut Block,
+ external: bool,
+ ) -> Result<Option<Span>> {
+ //declaration:
+ // function_prototype SEMICOLON
+ //
+ // init_declarator_list SEMICOLON
+ // PRECISION precision_qualifier type_specifier SEMICOLON
+ //
+ // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE SEMICOLON
+ // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER SEMICOLON
+ // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER array_specifier SEMICOLON
+ // type_qualifier SEMICOLON type_qualifier IDENTIFIER SEMICOLON
+ // type_qualifier IDENTIFIER identifier_list SEMICOLON
+
+ if self.peek_type_qualifier(frontend) || self.peek_type_name(frontend) {
+ let mut qualifiers = self.parse_type_qualifiers(frontend)?;
+
+ if self.peek_type_name(frontend) {
+ // This branch handles variables and function prototypes and if
+ // external is true also function definitions
+ let (ty, mut meta) = self.parse_type(frontend)?;
+
+ let token = self.bump(frontend)?;
+ let token_fallthrough = match token.value {
+ TokenValue::Identifier(name) => match self.expect_peek(frontend)?.value {
+ TokenValue::LeftParen => {
+ // This branch handles function definition and prototypes
+ self.bump(frontend)?;
+
+ let result = ty.map(|ty| FunctionResult { ty, binding: None });
+ let mut body = Block::new();
+
+ let mut context = Context::new(frontend, &mut body);
+
+ self.parse_function_args(frontend, &mut context, &mut body)?;
+
+ let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta;
+ meta.subsume(end_meta);
+
+ let token = self.bump(frontend)?;
+ return match token.value {
+ TokenValue::Semicolon => {
+ // This branch handles function prototypes
+ frontend.add_prototype(context, name, result, meta);
+
+ Ok(Some(meta))
+ }
+ TokenValue::LeftBrace if external => {
+ // This branch handles function definitions
+ // as you can see by the guard this branch
+ // only happens if external is also true
+
+ // parse the body
+ self.parse_compound_statement(
+ token.meta,
+ frontend,
+ &mut context,
+ &mut body,
+ &mut None,
+ )?;
+
+ frontend.add_function(context, name, result, body, meta);
+
+ Ok(Some(meta))
+ }
+ _ if external => Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::LeftBrace.into(),
+ TokenValue::Semicolon.into(),
+ ],
+ ),
+ meta: token.meta,
+ }),
+ _ => Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ }),
+ };
+ }
+ // Pass the token to the init_declarator_list parser
+ _ => Token {
+ value: TokenValue::Identifier(name),
+ meta: token.meta,
+ },
+ },
+ // Pass the token to the init_declarator_list parser
+ _ => token,
+ };
+
+ // If program execution has reached here then this will be a
+ // init_declarator_list
+ // token_fallthrough will have a token that was already bumped
+ if let Some(ty) = ty {
+ let mut ctx = DeclarationContext {
+ qualifiers,
+ external,
+ ctx,
+ body,
+ };
+
+ self.backtrack(token_fallthrough)?;
+ self.parse_init_declarator_list(frontend, ty, &mut ctx)?;
+ } else {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError("Declaration cannot have void type".into()),
+ meta,
+ })
+ }
+
+ Ok(Some(meta))
+ } else {
+ // This branch handles struct definitions and modifiers like
+ // ```glsl
+ // layout(early_fragment_tests);
+ // ```
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Identifier(ty_name) => {
+ if self.bump_if(frontend, TokenValue::LeftBrace).is_some() {
+ self.parse_block_declaration(
+ frontend,
+ ctx,
+ body,
+ &mut qualifiers,
+ ty_name,
+ token.meta,
+ )
+ .map(Some)
+ } else {
+ if qualifiers.invariant.take().is_some() {
+ frontend.make_variable_invariant(ctx, body, &ty_name, token.meta);
+
+ qualifiers.unused_errors(&mut frontend.errors);
+ self.expect(frontend, TokenValue::Semicolon)?;
+ return Ok(Some(qualifiers.span));
+ }
+
+ //TODO: declaration
+ // type_qualifier IDENTIFIER SEMICOLON
+ // type_qualifier IDENTIFIER identifier_list SEMICOLON
+ Err(Error {
+ kind: ErrorKind::NotImplemented("variable qualifier"),
+ meta: token.meta,
+ })
+ }
+ }
+ TokenValue::Semicolon => {
+ if let Some(value) =
+ qualifiers.uint_layout_qualifier("local_size_x", &mut frontend.errors)
+ {
+ frontend.meta.workgroup_size[0] = value;
+ }
+ if let Some(value) =
+ qualifiers.uint_layout_qualifier("local_size_y", &mut frontend.errors)
+ {
+ frontend.meta.workgroup_size[1] = value;
+ }
+ if let Some(value) =
+ qualifiers.uint_layout_qualifier("local_size_z", &mut frontend.errors)
+ {
+ frontend.meta.workgroup_size[2] = value;
+ }
+
+ frontend.meta.early_fragment_tests |= qualifiers
+ .none_layout_qualifier("early_fragment_tests", &mut frontend.errors);
+
+ qualifiers.unused_errors(&mut frontend.errors);
+
+ Ok(Some(qualifiers.span))
+ }
+ _ => Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ }),
+ }
+ }
+ } else {
+ match self.peek(frontend).map(|t| &t.value) {
+ Some(&TokenValue::Precision) => {
+ // PRECISION precision_qualifier type_specifier SEMICOLON
+ self.bump(frontend)?;
+
+ let token = self.bump(frontend)?;
+ let _ = match token.value {
+ TokenValue::PrecisionQualifier(p) => p,
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::PrecisionQualifier(Precision::High).into(),
+ TokenValue::PrecisionQualifier(Precision::Medium).into(),
+ TokenValue::PrecisionQualifier(Precision::Low).into(),
+ ],
+ ),
+ meta: token.meta,
+ })
+ }
+ };
+
+ let (ty, meta) = self.parse_type_non_void(frontend)?;
+
+ match frontend.module.types[ty].inner {
+ TypeInner::Scalar {
+ kind: ScalarKind::Float | ScalarKind::Sint,
+ ..
+ } => {}
+ _ => frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Precision statement can only work on floats and ints".into(),
+ ),
+ meta,
+ }),
+ }
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+
+ Ok(Some(meta))
+ }
+ _ => Ok(None),
+ }
+ }
+ }
+
+ pub fn parse_block_declaration(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ body: &mut Block,
+ qualifiers: &mut TypeQualifiers,
+ ty_name: String,
+ mut meta: Span,
+ ) -> Result<Span> {
+ let layout = match qualifiers.layout_qualifiers.remove(&QualifierKey::Layout) {
+ Some((QualifierValue::Layout(l), _)) => l,
+ None => {
+ if let StorageQualifier::AddressSpace(AddressSpace::Storage { .. }) =
+ qualifiers.storage.0
+ {
+ StructLayout::Std430
+ } else {
+ StructLayout::Std140
+ }
+ }
+ _ => unreachable!(),
+ };
+
+ let mut members = Vec::new();
+ let span = self.parse_struct_declaration_list(frontend, &mut members, layout)?;
+ self.expect(frontend, TokenValue::RightBrace)?;
+
+ let mut ty = frontend.module.types.insert(
+ Type {
+ name: Some(ty_name),
+ inner: TypeInner::Struct {
+ members: members.clone(),
+ span,
+ },
+ },
+ Default::default(),
+ );
+
+ let token = self.bump(frontend)?;
+ let name = match token.value {
+ TokenValue::Semicolon => None,
+ TokenValue::Identifier(name) => {
+ self.parse_array_specifier(frontend, &mut meta, &mut ty)?;
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+
+ Some(name)
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ };
+
+ let global = frontend.add_global_var(
+ ctx,
+ body,
+ VarDeclaration {
+ qualifiers,
+ ty,
+ name,
+ init: None,
+ meta,
+ },
+ )?;
+
+ for (i, k, ty) in members.into_iter().enumerate().filter_map(|(i, m)| {
+ let ty = m.ty;
+ m.name.map(|s| (i as u32, s, ty))
+ }) {
+ let lookup = GlobalLookup {
+ kind: match global {
+ GlobalOrConstant::Global(handle) => GlobalLookupKind::BlockSelect(handle, i),
+ GlobalOrConstant::Constant(handle) => GlobalLookupKind::Constant(handle, ty),
+ },
+ entry_arg: None,
+ mutable: true,
+ };
+ ctx.add_global(frontend, &k, lookup, body);
+
+ frontend.global_variables.push((k, lookup));
+ }
+
+ Ok(meta)
+ }
+
+ // TODO: Accept layout arguments
+ pub fn parse_struct_declaration_list(
+ &mut self,
+ frontend: &mut Frontend,
+ members: &mut Vec<StructMember>,
+ layout: StructLayout,
+ ) -> Result<u32> {
+ let mut span = 0;
+ let mut align = Alignment::ONE;
+
+ loop {
+ // TODO: type_qualifier
+
+ let (mut ty, mut meta) = self.parse_type_non_void(frontend)?;
+ let (name, end_meta) = self.expect_ident(frontend)?;
+
+ meta.subsume(end_meta);
+
+ self.parse_array_specifier(frontend, &mut meta, &mut ty)?;
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+
+ let info = offset::calculate_offset(
+ ty,
+ meta,
+ layout,
+ &mut frontend.module.types,
+ &frontend.module.constants,
+ &mut frontend.errors,
+ );
+
+ let member_alignment = info.align;
+ span = member_alignment.round_up(span);
+ align = member_alignment.max(align);
+
+ members.push(StructMember {
+ name: Some(name),
+ ty: info.ty,
+ binding: None,
+ offset: span,
+ });
+
+ span += info.span;
+
+ if let TokenValue::RightBrace = self.expect_peek(frontend)?.value {
+ break;
+ }
+ }
+
+ span = align.round_up(span);
+
+ Ok(span)
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/expressions.rs b/third_party/rust/naga/src/front/glsl/parser/expressions.rs
new file mode 100644
index 0000000000..f09e58b6f6
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/expressions.rs
@@ -0,0 +1,546 @@
+use crate::{
+ front::glsl::{
+ ast::{FunctionCall, FunctionCallKind, HirExpr, HirExprKind},
+ context::{Context, StmtContext},
+ error::{ErrorKind, ExpectedToken},
+ parser::ParsingContext,
+ token::{Token, TokenValue},
+ Error, Frontend, Result, Span,
+ },
+ ArraySize, BinaryOperator, Block, Constant, ConstantInner, Handle, ScalarValue, Type,
+ TypeInner, UnaryOperator,
+};
+
+impl<'source> ParsingContext<'source> {
+ pub fn parse_primary(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ ) -> Result<Handle<HirExpr>> {
+ let mut token = self.bump(frontend)?;
+
+ let (width, value) = match token.value {
+ TokenValue::IntConstant(int) => (
+ (int.width / 8) as u8,
+ if int.signed {
+ ScalarValue::Sint(int.value as i64)
+ } else {
+ ScalarValue::Uint(int.value)
+ },
+ ),
+ TokenValue::FloatConstant(float) => (
+ (float.width / 8) as u8,
+ ScalarValue::Float(float.value as f64),
+ ),
+ TokenValue::BoolConstant(value) => (1, ScalarValue::Bool(value)),
+ TokenValue::LeftParen => {
+ let expr = self.parse_expression(frontend, ctx, stmt, body)?;
+ let meta = self.expect(frontend, TokenValue::RightParen)?.meta;
+
+ token.meta.subsume(meta);
+
+ return Ok(expr);
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::LeftParen.into(),
+ ExpectedToken::IntLiteral,
+ ExpectedToken::FloatLiteral,
+ ExpectedToken::BoolLiteral,
+ ],
+ ),
+ meta: token.meta,
+ });
+ }
+ };
+
+ let handle = frontend.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar { width, value },
+ },
+ token.meta,
+ );
+
+ Ok(stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Constant(handle),
+ meta: token.meta,
+ },
+ Default::default(),
+ ))
+ }
+
+ pub fn parse_function_call_args(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ meta: &mut Span,
+ ) -> Result<Vec<Handle<HirExpr>>> {
+ let mut args = Vec::new();
+ if let Some(token) = self.bump_if(frontend, TokenValue::RightParen) {
+ meta.subsume(token.meta);
+ } else {
+ loop {
+ args.push(self.parse_assignment(frontend, ctx, stmt, body)?);
+
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Comma => {}
+ TokenValue::RightParen => {
+ meta.subsume(token.meta);
+ break;
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Comma.into(), TokenValue::RightParen.into()],
+ ),
+ meta: token.meta,
+ });
+ }
+ }
+ }
+ }
+
+ Ok(args)
+ }
+
+ pub fn parse_postfix(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ ) -> Result<Handle<HirExpr>> {
+ let mut base = if self.peek_type_name(frontend) {
+ let (mut handle, mut meta) = self.parse_type_non_void(frontend)?;
+
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let args = self.parse_function_call_args(frontend, ctx, stmt, body, &mut meta)?;
+
+ if let TypeInner::Array {
+ size: ArraySize::Dynamic,
+ stride,
+ base,
+ } = frontend.module.types[handle].inner
+ {
+ let span = frontend.module.types.get_span(handle);
+
+ let constant = frontend.module.constants.fetch_or_append(
+ Constant {
+ name: None,
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Uint(args.len() as u64),
+ },
+ },
+ Span::default(),
+ );
+ handle = frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Array {
+ stride,
+ base,
+ size: ArraySize::Constant(constant),
+ },
+ },
+ span,
+ )
+ }
+
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Call(FunctionCall {
+ kind: FunctionCallKind::TypeConstructor(handle),
+ args,
+ }),
+ meta,
+ },
+ Default::default(),
+ )
+ } else if let TokenValue::Identifier(_) = self.expect_peek(frontend)?.value {
+ let (name, mut meta) = self.expect_ident(frontend)?;
+
+ let expr = if self.bump_if(frontend, TokenValue::LeftParen).is_some() {
+ let args = self.parse_function_call_args(frontend, ctx, stmt, body, &mut meta)?;
+
+ let kind = match frontend.lookup_type.get(&name) {
+ Some(ty) => FunctionCallKind::TypeConstructor(*ty),
+ None => FunctionCallKind::Function(name),
+ };
+
+ HirExpr {
+ kind: HirExprKind::Call(FunctionCall { kind, args }),
+ meta,
+ }
+ } else {
+ let var = match frontend.lookup_variable(ctx, body, &name, meta) {
+ Some(var) => var,
+ None => {
+ return Err(Error {
+ kind: ErrorKind::UnknownVariable(name),
+ meta,
+ })
+ }
+ };
+
+ HirExpr {
+ kind: HirExprKind::Variable(var),
+ meta,
+ }
+ };
+
+ stmt.hir_exprs.append(expr, Default::default())
+ } else {
+ self.parse_primary(frontend, ctx, stmt, body)?
+ };
+
+ while let TokenValue::LeftBracket
+ | TokenValue::Dot
+ | TokenValue::Increment
+ | TokenValue::Decrement = self.expect_peek(frontend)?.value
+ {
+ let Token { value, mut meta } = self.bump(frontend)?;
+
+ match value {
+ TokenValue::LeftBracket => {
+ let index = self.parse_expression(frontend, ctx, stmt, body)?;
+ let end_meta = self.expect(frontend, TokenValue::RightBracket)?.meta;
+
+ meta.subsume(end_meta);
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Access { base, index },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ TokenValue::Dot => {
+ let (field, end_meta) = self.expect_ident(frontend)?;
+
+ if self.bump_if(frontend, TokenValue::LeftParen).is_some() {
+ let args =
+ self.parse_function_call_args(frontend, ctx, stmt, body, &mut meta)?;
+
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Method {
+ expr: base,
+ name: field,
+ args,
+ },
+ meta,
+ },
+ Default::default(),
+ );
+ continue;
+ }
+
+ meta.subsume(end_meta);
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Select { base, field },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ TokenValue::Increment | TokenValue::Decrement => {
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::PrePostfix {
+ op: match value {
+ TokenValue::Increment => crate::BinaryOperator::Add,
+ _ => crate::BinaryOperator::Subtract,
+ },
+ postfix: true,
+ expr: base,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ Ok(base)
+ }
+
+ pub fn parse_unary(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ ) -> Result<Handle<HirExpr>> {
+ Ok(match self.expect_peek(frontend)?.value {
+ TokenValue::Plus | TokenValue::Dash | TokenValue::Bang | TokenValue::Tilde => {
+ let Token { value, mut meta } = self.bump(frontend)?;
+
+ let expr = self.parse_unary(frontend, ctx, stmt, body)?;
+ let end_meta = stmt.hir_exprs[expr].meta;
+
+ let kind = match value {
+ TokenValue::Dash => HirExprKind::Unary {
+ op: UnaryOperator::Negate,
+ expr,
+ },
+ TokenValue::Bang | TokenValue::Tilde => HirExprKind::Unary {
+ op: UnaryOperator::Not,
+ expr,
+ },
+ _ => return Ok(expr),
+ };
+
+ meta.subsume(end_meta);
+ stmt.hir_exprs
+ .append(HirExpr { kind, meta }, Default::default())
+ }
+ TokenValue::Increment | TokenValue::Decrement => {
+ let Token { value, meta } = self.bump(frontend)?;
+
+ let expr = self.parse_unary(frontend, ctx, stmt, body)?;
+
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::PrePostfix {
+ op: match value {
+ TokenValue::Increment => crate::BinaryOperator::Add,
+ _ => crate::BinaryOperator::Subtract,
+ },
+ postfix: false,
+ expr,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ _ => self.parse_postfix(frontend, ctx, stmt, body)?,
+ })
+ }
+
+ pub fn parse_binary(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ passthrough: Option<Handle<HirExpr>>,
+ min_bp: u8,
+ ) -> Result<Handle<HirExpr>> {
+ let mut left = passthrough
+ .ok_or(ErrorKind::EndOfFile /* Dummy error */)
+ .or_else(|_| self.parse_unary(frontend, ctx, stmt, body))?;
+ let mut meta = stmt.hir_exprs[left].meta;
+
+ while let Some((l_bp, r_bp)) = binding_power(&self.expect_peek(frontend)?.value) {
+ if l_bp < min_bp {
+ break;
+ }
+
+ let Token { value, .. } = self.bump(frontend)?;
+
+ let right = self.parse_binary(frontend, ctx, stmt, body, None, r_bp)?;
+ let end_meta = stmt.hir_exprs[right].meta;
+
+ meta.subsume(end_meta);
+ left = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Binary {
+ left,
+ op: match value {
+ TokenValue::LogicalOr => BinaryOperator::LogicalOr,
+ TokenValue::LogicalXor => BinaryOperator::NotEqual,
+ TokenValue::LogicalAnd => BinaryOperator::LogicalAnd,
+ TokenValue::VerticalBar => BinaryOperator::InclusiveOr,
+ TokenValue::Caret => BinaryOperator::ExclusiveOr,
+ TokenValue::Ampersand => BinaryOperator::And,
+ TokenValue::Equal => BinaryOperator::Equal,
+ TokenValue::NotEqual => BinaryOperator::NotEqual,
+ TokenValue::GreaterEqual => BinaryOperator::GreaterEqual,
+ TokenValue::LessEqual => BinaryOperator::LessEqual,
+ TokenValue::LeftAngle => BinaryOperator::Less,
+ TokenValue::RightAngle => BinaryOperator::Greater,
+ TokenValue::LeftShift => BinaryOperator::ShiftLeft,
+ TokenValue::RightShift => BinaryOperator::ShiftRight,
+ TokenValue::Plus => BinaryOperator::Add,
+ TokenValue::Dash => BinaryOperator::Subtract,
+ TokenValue::Star => BinaryOperator::Multiply,
+ TokenValue::Slash => BinaryOperator::Divide,
+ TokenValue::Percent => BinaryOperator::Modulo,
+ _ => unreachable!(),
+ },
+ right,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+
+ Ok(left)
+ }
+
+ pub fn parse_conditional(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ passthrough: Option<Handle<HirExpr>>,
+ ) -> Result<Handle<HirExpr>> {
+ let mut condition = self.parse_binary(frontend, ctx, stmt, body, passthrough, 0)?;
+ let mut meta = stmt.hir_exprs[condition].meta;
+
+ if self.bump_if(frontend, TokenValue::Question).is_some() {
+ let accept = self.parse_expression(frontend, ctx, stmt, body)?;
+ self.expect(frontend, TokenValue::Colon)?;
+ let reject = self.parse_assignment(frontend, ctx, stmt, body)?;
+ let end_meta = stmt.hir_exprs[reject].meta;
+
+ meta.subsume(end_meta);
+ condition = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Conditional {
+ condition,
+ accept,
+ reject,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+
+ Ok(condition)
+ }
+
+ pub fn parse_assignment(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ ) -> Result<Handle<HirExpr>> {
+ let tgt = self.parse_unary(frontend, ctx, stmt, body)?;
+ let mut meta = stmt.hir_exprs[tgt].meta;
+
+ Ok(match self.expect_peek(frontend)?.value {
+ TokenValue::Assign => {
+ self.bump(frontend)?;
+ let value = self.parse_assignment(frontend, ctx, stmt, body)?;
+ let end_meta = stmt.hir_exprs[value].meta;
+
+ meta.subsume(end_meta);
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Assign { tgt, value },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ TokenValue::OrAssign
+ | TokenValue::AndAssign
+ | TokenValue::AddAssign
+ | TokenValue::DivAssign
+ | TokenValue::ModAssign
+ | TokenValue::SubAssign
+ | TokenValue::MulAssign
+ | TokenValue::LeftShiftAssign
+ | TokenValue::RightShiftAssign
+ | TokenValue::XorAssign => {
+ let token = self.bump(frontend)?;
+ let right = self.parse_assignment(frontend, ctx, stmt, body)?;
+ let end_meta = stmt.hir_exprs[right].meta;
+
+ meta.subsume(end_meta);
+ let value = stmt.hir_exprs.append(
+ HirExpr {
+ meta,
+ kind: HirExprKind::Binary {
+ left: tgt,
+ op: match token.value {
+ TokenValue::OrAssign => BinaryOperator::InclusiveOr,
+ TokenValue::AndAssign => BinaryOperator::And,
+ TokenValue::AddAssign => BinaryOperator::Add,
+ TokenValue::DivAssign => BinaryOperator::Divide,
+ TokenValue::ModAssign => BinaryOperator::Modulo,
+ TokenValue::SubAssign => BinaryOperator::Subtract,
+ TokenValue::MulAssign => BinaryOperator::Multiply,
+ TokenValue::LeftShiftAssign => BinaryOperator::ShiftLeft,
+ TokenValue::RightShiftAssign => BinaryOperator::ShiftRight,
+ TokenValue::XorAssign => BinaryOperator::ExclusiveOr,
+ _ => unreachable!(),
+ },
+ right,
+ },
+ },
+ Default::default(),
+ );
+
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Assign { tgt, value },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ _ => self.parse_conditional(frontend, ctx, stmt, body, Some(tgt))?,
+ })
+ }
+
+ pub fn parse_expression(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ body: &mut Block,
+ ) -> Result<Handle<HirExpr>> {
+ let mut expr = self.parse_assignment(frontend, ctx, stmt, body)?;
+
+ while let TokenValue::Comma = self.expect_peek(frontend)?.value {
+ self.bump(frontend)?;
+ expr = self.parse_assignment(frontend, ctx, stmt, body)?;
+ }
+
+ Ok(expr)
+ }
+}
+
+const fn binding_power(value: &TokenValue) -> Option<(u8, u8)> {
+ Some(match *value {
+ TokenValue::LogicalOr => (1, 2),
+ TokenValue::LogicalXor => (3, 4),
+ TokenValue::LogicalAnd => (5, 6),
+ TokenValue::VerticalBar => (7, 8),
+ TokenValue::Caret => (9, 10),
+ TokenValue::Ampersand => (11, 12),
+ TokenValue::Equal | TokenValue::NotEqual => (13, 14),
+ TokenValue::GreaterEqual
+ | TokenValue::LessEqual
+ | TokenValue::LeftAngle
+ | TokenValue::RightAngle => (15, 16),
+ TokenValue::LeftShift | TokenValue::RightShift => (17, 18),
+ TokenValue::Plus | TokenValue::Dash => (19, 20),
+ TokenValue::Star | TokenValue::Slash | TokenValue::Percent => (21, 22),
+ _ => return None,
+ })
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/functions.rs b/third_party/rust/naga/src/front/glsl/parser/functions.rs
new file mode 100644
index 0000000000..781f4cffed
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/functions.rs
@@ -0,0 +1,654 @@
+use crate::front::glsl::context::ExprPos;
+use crate::front::glsl::Span;
+use crate::{
+ front::glsl::{
+ ast::ParameterQualifier,
+ context::Context,
+ parser::ParsingContext,
+ token::{Token, TokenValue},
+ variables::VarDeclaration,
+ Error, ErrorKind, Frontend, Result,
+ },
+ Block, ConstantInner, Expression, ScalarValue, Statement, SwitchCase, UnaryOperator,
+};
+
+impl<'source> ParsingContext<'source> {
+ pub fn peek_parameter_qualifier(&mut self, frontend: &mut Frontend) -> bool {
+ self.peek(frontend).map_or(false, |t| match t.value {
+ TokenValue::In | TokenValue::Out | TokenValue::InOut | TokenValue::Const => true,
+ _ => false,
+ })
+ }
+
+ /// Returns the parsed `ParameterQualifier` or `ParameterQualifier::In`
+ pub fn parse_parameter_qualifier(&mut self, frontend: &mut Frontend) -> ParameterQualifier {
+ if self.peek_parameter_qualifier(frontend) {
+ match self.bump(frontend).unwrap().value {
+ TokenValue::In => ParameterQualifier::In,
+ TokenValue::Out => ParameterQualifier::Out,
+ TokenValue::InOut => ParameterQualifier::InOut,
+ TokenValue::Const => ParameterQualifier::Const,
+ _ => unreachable!(),
+ }
+ } else {
+ ParameterQualifier::In
+ }
+ }
+
+ pub fn parse_statement(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ body: &mut Block,
+ terminator: &mut Option<usize>,
+ ) -> Result<Option<Span>> {
+ // Type qualifiers always identify a declaration statement
+ if self.peek_type_qualifier(frontend) {
+ return self.parse_declaration(frontend, ctx, body, false);
+ }
+
+ // Type names can identify either declaration statements or type constructors
+ // depending on wether the token following the type name is a `(` (LeftParen)
+ if self.peek_type_name(frontend) {
+ // Start by consuming the type name so that we can peek the token after it
+ let token = self.bump(frontend)?;
+ // Peek the next token and check if it's a `(` (LeftParen) if so the statement
+ // is a constructor, otherwise it's a declaration. We need to do the check
+ // beforehand and not in the if since we will backtrack before the if
+ let declaration = TokenValue::LeftParen != self.expect_peek(frontend)?.value;
+
+ self.backtrack(token)?;
+
+ if declaration {
+ return self.parse_declaration(frontend, ctx, body, false);
+ }
+ }
+
+ let new_break = || {
+ let mut block = Block::new();
+ block.push(Statement::Break, crate::Span::default());
+ block
+ };
+
+ let &Token {
+ ref value,
+ mut meta,
+ } = self.expect_peek(frontend)?;
+
+ let meta_rest = match *value {
+ TokenValue::Continue => {
+ let meta = self.bump(frontend)?.meta;
+ body.push(Statement::Continue, meta);
+ terminator.get_or_insert(body.len());
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ TokenValue::Break => {
+ let meta = self.bump(frontend)?.meta;
+ body.push(Statement::Break, meta);
+ terminator.get_or_insert(body.len());
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ TokenValue::Return => {
+ self.bump(frontend)?;
+ let (value, meta) = match self.expect_peek(frontend)?.value {
+ TokenValue::Semicolon => (None, self.bump(frontend)?.meta),
+ _ => {
+ // TODO: Implicit conversions
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ self.expect(frontend, TokenValue::Semicolon)?;
+ let (handle, meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ (Some(handle), meta)
+ }
+ };
+
+ ctx.emit_restart(body);
+
+ body.push(Statement::Return { value }, meta);
+ terminator.get_or_insert(body.len());
+
+ meta
+ }
+ TokenValue::Discard => {
+ let meta = self.bump(frontend)?.meta;
+ body.push(Statement::Kill, meta);
+ terminator.get_or_insert(body.len());
+
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ TokenValue::If => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let condition = {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ let (handle, more_meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ meta.subsume(more_meta);
+ handle
+ };
+ self.expect(frontend, TokenValue::RightParen)?;
+
+ ctx.emit_restart(body);
+
+ let mut accept = Block::new();
+ if let Some(more_meta) =
+ self.parse_statement(frontend, ctx, &mut accept, &mut None)?
+ {
+ meta.subsume(more_meta)
+ }
+
+ let mut reject = Block::new();
+ if self.bump_if(frontend, TokenValue::Else).is_some() {
+ if let Some(more_meta) =
+ self.parse_statement(frontend, ctx, &mut reject, &mut None)?
+ {
+ meta.subsume(more_meta);
+ }
+ }
+
+ body.push(
+ Statement::If {
+ condition,
+ accept,
+ reject,
+ },
+ meta,
+ );
+
+ meta
+ }
+ TokenValue::Switch => {
+ let mut meta = self.bump(frontend)?.meta;
+ let end_meta;
+
+ self.expect(frontend, TokenValue::LeftParen)?;
+
+ let (selector, uint) = {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ let (root, meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ let uint = frontend.resolve_type(ctx, root, meta)?.scalar_kind()
+ == Some(crate::ScalarKind::Uint);
+ (root, uint)
+ };
+
+ self.expect(frontend, TokenValue::RightParen)?;
+
+ ctx.emit_restart(body);
+
+ let mut cases = Vec::new();
+ // Track if any default case is present in the switch statement.
+ let mut default_present = false;
+
+ self.expect(frontend, TokenValue::LeftBrace)?;
+ loop {
+ let value = match self.expect_peek(frontend)?.value {
+ TokenValue::Case => {
+ self.bump(frontend)?;
+
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ let (root, meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ let constant = frontend.solve_constant(ctx, root, meta)?;
+
+ match frontend.module.constants[constant].inner {
+ ConstantInner::Scalar {
+ value: ScalarValue::Sint(int),
+ ..
+ } => match uint {
+ true => crate::SwitchValue::U32(int as u32),
+ false => crate::SwitchValue::I32(int as i32),
+ },
+ ConstantInner::Scalar {
+ value: ScalarValue::Uint(int),
+ ..
+ } => crate::SwitchValue::U32(int as u32),
+ _ => {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Case values can only be integers".into(),
+ ),
+ meta,
+ });
+
+ crate::SwitchValue::I32(0)
+ }
+ }
+ }
+ TokenValue::Default => {
+ self.bump(frontend)?;
+ default_present = true;
+ crate::SwitchValue::Default
+ }
+ TokenValue::RightBrace => {
+ end_meta = self.bump(frontend)?.meta;
+ break;
+ }
+ _ => {
+ let Token { value, meta } = self.bump(frontend)?;
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ value,
+ vec![
+ TokenValue::Case.into(),
+ TokenValue::Default.into(),
+ TokenValue::RightBrace.into(),
+ ],
+ ),
+ meta,
+ });
+ }
+ };
+
+ self.expect(frontend, TokenValue::Colon)?;
+
+ let mut body = Block::new();
+
+ let mut case_terminator = None;
+ loop {
+ match self.expect_peek(frontend)?.value {
+ TokenValue::Case | TokenValue::Default | TokenValue::RightBrace => {
+ break
+ }
+ _ => {
+ self.parse_statement(
+ frontend,
+ ctx,
+ &mut body,
+ &mut case_terminator,
+ )?;
+ }
+ }
+ }
+
+ let mut fall_through = true;
+
+ if let Some(mut idx) = case_terminator {
+ if let Statement::Break = body[idx - 1] {
+ fall_through = false;
+ idx -= 1;
+ }
+
+ body.cull(idx..)
+ }
+
+ cases.push(SwitchCase {
+ value,
+ body,
+ fall_through,
+ })
+ }
+
+ meta.subsume(end_meta);
+
+ // NOTE: do not unwrap here since a switch statement isn't required
+ // to have any cases.
+ if let Some(case) = cases.last_mut() {
+ // GLSL requires that the last case not be empty, so we check
+ // that here and produce an error otherwise (fall_through must
+ // also be checked because `break`s count as statements but
+ // they aren't added to the body)
+ if case.body.is_empty() && case.fall_through {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "last case/default label must be followed by statements".into(),
+ ),
+ meta,
+ })
+ }
+
+ // GLSL allows the last case to not have any `break` statement,
+ // this would mark it as fall through but naga's IR requires that
+ // the last case must not be fall through, so we mark need to mark
+ // the last case as not fall through always.
+ case.fall_through = false;
+ }
+
+ // Add an empty default case in case non was present, this is needed because
+ // naga's IR requires that all switch statements must have a default case but
+ // GLSL doesn't require that, so we might need to add an empty default case.
+ if !default_present {
+ cases.push(SwitchCase {
+ value: crate::SwitchValue::Default,
+ body: Block::new(),
+ fall_through: false,
+ })
+ }
+
+ body.push(Statement::Switch { selector, cases }, meta);
+
+ meta
+ }
+ TokenValue::While => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ let mut loop_body = Block::new();
+
+ let mut stmt = ctx.stmt_ctx();
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let root = self.parse_expression(frontend, ctx, &mut stmt, &mut loop_body)?;
+ meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta);
+
+ let (expr, expr_meta) =
+ ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs, &mut loop_body)?;
+ let condition = ctx.add_expression(
+ Expression::Unary {
+ op: UnaryOperator::Not,
+ expr,
+ },
+ expr_meta,
+ &mut loop_body,
+ );
+
+ ctx.emit_restart(&mut loop_body);
+
+ loop_body.push(
+ Statement::If {
+ condition,
+ accept: new_break(),
+ reject: Block::new(),
+ },
+ crate::Span::default(),
+ );
+
+ meta.subsume(expr_meta);
+
+ if let Some(body_meta) =
+ self.parse_statement(frontend, ctx, &mut loop_body, &mut None)?
+ {
+ meta.subsume(body_meta);
+ }
+
+ body.push(
+ Statement::Loop {
+ body: loop_body,
+ continuing: Block::new(),
+ break_if: None,
+ },
+ meta,
+ );
+
+ meta
+ }
+ TokenValue::Do => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ let mut loop_body = Block::new();
+
+ let mut terminator = None;
+ self.parse_statement(frontend, ctx, &mut loop_body, &mut terminator)?;
+
+ let mut stmt = ctx.stmt_ctx();
+
+ self.expect(frontend, TokenValue::While)?;
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let root = self.parse_expression(frontend, ctx, &mut stmt, &mut loop_body)?;
+ let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta;
+
+ meta.subsume(end_meta);
+
+ let (expr, expr_meta) =
+ ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs, &mut loop_body)?;
+ let condition = ctx.add_expression(
+ Expression::Unary {
+ op: UnaryOperator::Not,
+ expr,
+ },
+ expr_meta,
+ &mut loop_body,
+ );
+
+ ctx.emit_restart(&mut loop_body);
+
+ loop_body.push(
+ Statement::If {
+ condition,
+ accept: new_break(),
+ reject: Block::new(),
+ },
+ crate::Span::default(),
+ );
+
+ if let Some(idx) = terminator {
+ loop_body.cull(idx..)
+ }
+
+ body.push(
+ Statement::Loop {
+ body: loop_body,
+ continuing: Block::new(),
+ break_if: None,
+ },
+ meta,
+ );
+
+ meta
+ }
+ TokenValue::For => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ ctx.symbol_table.push_scope();
+ self.expect(frontend, TokenValue::LeftParen)?;
+
+ if self.bump_if(frontend, TokenValue::Semicolon).is_none() {
+ if self.peek_type_name(frontend) || self.peek_type_qualifier(frontend) {
+ self.parse_declaration(frontend, ctx, body, false)?;
+ } else {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ ctx.lower(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ self.expect(frontend, TokenValue::Semicolon)?;
+ }
+ }
+
+ let (mut block, mut continuing) = (Block::new(), Block::new());
+
+ if self.bump_if(frontend, TokenValue::Semicolon).is_none() {
+ let (expr, expr_meta) = if self.peek_type_name(frontend)
+ || self.peek_type_qualifier(frontend)
+ {
+ let mut qualifiers = self.parse_type_qualifiers(frontend)?;
+ let (ty, mut meta) = self.parse_type_non_void(frontend)?;
+ let name = self.expect_ident(frontend)?.0;
+
+ self.expect(frontend, TokenValue::Assign)?;
+
+ let (value, end_meta) =
+ self.parse_initializer(frontend, ty, ctx, &mut block)?;
+ meta.subsume(end_meta);
+
+ let decl = VarDeclaration {
+ qualifiers: &mut qualifiers,
+ ty,
+ name: Some(name),
+ init: None,
+ meta,
+ };
+
+ let pointer = frontend.add_local_var(ctx, &mut block, decl)?;
+
+ ctx.emit_restart(&mut block);
+
+ block.push(Statement::Store { pointer, value }, meta);
+
+ (value, end_meta)
+ } else {
+ let mut stmt = ctx.stmt_ctx();
+ let root = self.parse_expression(frontend, ctx, &mut stmt, &mut block)?;
+ ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs, &mut block)?
+ };
+
+ let condition = ctx.add_expression(
+ Expression::Unary {
+ op: UnaryOperator::Not,
+ expr,
+ },
+ expr_meta,
+ &mut block,
+ );
+
+ ctx.emit_restart(&mut block);
+
+ block.push(
+ Statement::If {
+ condition,
+ accept: new_break(),
+ reject: Block::new(),
+ },
+ crate::Span::default(),
+ );
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+ }
+
+ match self.expect_peek(frontend)?.value {
+ TokenValue::RightParen => {}
+ _ => {
+ let mut stmt = ctx.stmt_ctx();
+ let rest =
+ self.parse_expression(frontend, ctx, &mut stmt, &mut continuing)?;
+ ctx.lower(stmt, frontend, rest, ExprPos::Rhs, &mut continuing)?;
+ }
+ }
+
+ meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta);
+
+ if let Some(stmt_meta) =
+ self.parse_statement(frontend, ctx, &mut block, &mut None)?
+ {
+ meta.subsume(stmt_meta);
+ }
+
+ body.push(
+ Statement::Loop {
+ body: block,
+ continuing,
+ break_if: None,
+ },
+ meta,
+ );
+
+ ctx.symbol_table.pop_scope();
+
+ meta
+ }
+ TokenValue::LeftBrace => {
+ let meta = self.bump(frontend)?.meta;
+
+ let mut block = Block::new();
+
+ let mut block_terminator = None;
+ let meta = self.parse_compound_statement(
+ meta,
+ frontend,
+ ctx,
+ &mut block,
+ &mut block_terminator,
+ )?;
+
+ body.push(Statement::Block(block), meta);
+ if block_terminator.is_some() {
+ terminator.get_or_insert(body.len());
+ }
+
+ meta
+ }
+ TokenValue::Semicolon => self.bump(frontend)?.meta,
+ _ => {
+ // Attempt to force expression parsing for remainder of the
+ // tokens. Unknown or invalid tokens will be caught there and
+ // turned into an error.
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt, body)?;
+ ctx.lower(stmt, frontend, expr, ExprPos::Rhs, body)?;
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ };
+
+ meta.subsume(meta_rest);
+ Ok(Some(meta))
+ }
+
+ pub fn parse_compound_statement(
+ &mut self,
+ mut meta: Span,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ body: &mut Block,
+ terminator: &mut Option<usize>,
+ ) -> Result<Span> {
+ ctx.symbol_table.push_scope();
+
+ loop {
+ if let Some(Token {
+ meta: brace_meta, ..
+ }) = self.bump_if(frontend, TokenValue::RightBrace)
+ {
+ meta.subsume(brace_meta);
+ break;
+ }
+
+ let stmt = self.parse_statement(frontend, ctx, body, terminator)?;
+
+ if let Some(stmt_meta) = stmt {
+ meta.subsume(stmt_meta);
+ }
+ }
+
+ if let Some(idx) = *terminator {
+ body.cull(idx..)
+ }
+
+ ctx.symbol_table.pop_scope();
+
+ Ok(meta)
+ }
+
+ pub fn parse_function_args(
+ &mut self,
+ frontend: &mut Frontend,
+ context: &mut Context,
+ body: &mut Block,
+ ) -> Result<()> {
+ if self.bump_if(frontend, TokenValue::Void).is_some() {
+ return Ok(());
+ }
+
+ loop {
+ if self.peek_type_name(frontend) || self.peek_parameter_qualifier(frontend) {
+ let qualifier = self.parse_parameter_qualifier(frontend);
+ let mut ty = self.parse_type_non_void(frontend)?.0;
+
+ match self.expect_peek(frontend)?.value {
+ TokenValue::Comma => {
+ self.bump(frontend)?;
+ context.add_function_arg(frontend, body, None, ty, qualifier);
+ continue;
+ }
+ TokenValue::Identifier(_) => {
+ let mut name = self.expect_ident(frontend)?;
+ self.parse_array_specifier(frontend, &mut name.1, &mut ty)?;
+
+ context.add_function_arg(frontend, body, Some(name), ty, qualifier);
+
+ if self.bump_if(frontend, TokenValue::Comma).is_some() {
+ continue;
+ }
+
+ break;
+ }
+ _ => break,
+ }
+ }
+
+ break;
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/types.rs b/third_party/rust/naga/src/front/glsl/parser/types.rs
new file mode 100644
index 0000000000..08a70669a0
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/types.rs
@@ -0,0 +1,434 @@
+use crate::{
+ front::glsl::{
+ ast::{QualifierKey, QualifierValue, StorageQualifier, StructLayout, TypeQualifiers},
+ error::ExpectedToken,
+ parser::ParsingContext,
+ token::{Token, TokenValue},
+ Error, ErrorKind, Frontend, Result,
+ },
+ AddressSpace, ArraySize, Handle, Span, Type, TypeInner,
+};
+
+impl<'source> ParsingContext<'source> {
+ /// Parses an optional array_specifier returning wether or not it's present
+ /// and modifying the type handle if it exists
+ pub fn parse_array_specifier(
+ &mut self,
+ frontend: &mut Frontend,
+ span: &mut Span,
+ ty: &mut Handle<Type>,
+ ) -> Result<()> {
+ while self.parse_array_specifier_single(frontend, span, ty)? {}
+ Ok(())
+ }
+
+ /// Implementation of [`Self::parse_array_specifier`] for a single array_specifier
+ fn parse_array_specifier_single(
+ &mut self,
+ frontend: &mut Frontend,
+ span: &mut Span,
+ ty: &mut Handle<Type>,
+ ) -> Result<bool> {
+ if self.bump_if(frontend, TokenValue::LeftBracket).is_some() {
+ let size = if let Some(Token { meta, .. }) =
+ self.bump_if(frontend, TokenValue::RightBracket)
+ {
+ span.subsume(meta);
+ ArraySize::Dynamic
+ } else {
+ let (value, constant_span) = self.parse_uint_constant(frontend)?;
+ let constant = frontend.module.constants.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(value as u64),
+ },
+ },
+ constant_span,
+ );
+ let end_span = self.expect(frontend, TokenValue::RightBracket)?.meta;
+ span.subsume(end_span);
+ ArraySize::Constant(constant)
+ };
+
+ frontend
+ .layouter
+ .update(&frontend.module.types, &frontend.module.constants)
+ .unwrap();
+ let stride = frontend.layouter[*ty].to_stride();
+ *ty = frontend.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Array {
+ base: *ty,
+ size,
+ stride,
+ },
+ },
+ *span,
+ );
+
+ Ok(true)
+ } else {
+ Ok(false)
+ }
+ }
+
+ pub fn parse_type(&mut self, frontend: &mut Frontend) -> Result<(Option<Handle<Type>>, Span)> {
+ let token = self.bump(frontend)?;
+ let mut handle = match token.value {
+ TokenValue::Void => return Ok((None, token.meta)),
+ TokenValue::TypeName(ty) => frontend.module.types.insert(ty, token.meta),
+ TokenValue::Struct => {
+ let mut meta = token.meta;
+ let ty_name = self.expect_ident(frontend)?.0;
+ self.expect(frontend, TokenValue::LeftBrace)?;
+ let mut members = Vec::new();
+ let span = self.parse_struct_declaration_list(
+ frontend,
+ &mut members,
+ StructLayout::Std140,
+ )?;
+ let end_meta = self.expect(frontend, TokenValue::RightBrace)?.meta;
+ meta.subsume(end_meta);
+ let ty = frontend.module.types.insert(
+ Type {
+ name: Some(ty_name.clone()),
+ inner: TypeInner::Struct { members, span },
+ },
+ meta,
+ );
+ frontend.lookup_type.insert(ty_name, ty);
+ ty
+ }
+ TokenValue::Identifier(ident) => match frontend.lookup_type.get(&ident) {
+ Some(ty) => *ty,
+ None => {
+ return Err(Error {
+ kind: ErrorKind::UnknownType(ident),
+ meta: token.meta,
+ })
+ }
+ },
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::Void.into(),
+ TokenValue::Struct.into(),
+ ExpectedToken::TypeName,
+ ],
+ ),
+ meta: token.meta,
+ });
+ }
+ };
+
+ let mut span = token.meta;
+ self.parse_array_specifier(frontend, &mut span, &mut handle)?;
+ Ok((Some(handle), span))
+ }
+
+ pub fn parse_type_non_void(&mut self, frontend: &mut Frontend) -> Result<(Handle<Type>, Span)> {
+ let (maybe_ty, meta) = self.parse_type(frontend)?;
+ let ty = maybe_ty.ok_or_else(|| Error {
+ kind: ErrorKind::SemanticError("Type can't be void".into()),
+ meta,
+ })?;
+
+ Ok((ty, meta))
+ }
+
+ pub fn peek_type_qualifier(&mut self, frontend: &mut Frontend) -> bool {
+ self.peek(frontend).map_or(false, |t| match t.value {
+ TokenValue::Invariant
+ | TokenValue::Interpolation(_)
+ | TokenValue::Sampling(_)
+ | TokenValue::PrecisionQualifier(_)
+ | TokenValue::Const
+ | TokenValue::In
+ | TokenValue::Out
+ | TokenValue::Uniform
+ | TokenValue::Shared
+ | TokenValue::Buffer
+ | TokenValue::Restrict
+ | TokenValue::MemoryQualifier(_)
+ | TokenValue::Layout => true,
+ _ => false,
+ })
+ }
+
+ pub fn parse_type_qualifiers<'a>(
+ &mut self,
+ frontend: &mut Frontend,
+ ) -> Result<TypeQualifiers<'a>> {
+ let mut qualifiers = TypeQualifiers::default();
+
+ while self.peek_type_qualifier(frontend) {
+ let token = self.bump(frontend)?;
+
+ // Handle layout qualifiers outside the match since this can push multiple values
+ if token.value == TokenValue::Layout {
+ self.parse_layout_qualifier_id_list(frontend, &mut qualifiers)?;
+ continue;
+ }
+
+ qualifiers.span.subsume(token.meta);
+
+ match token.value {
+ TokenValue::Invariant => {
+ if qualifiers.invariant.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one invariant qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.invariant = Some(token.meta);
+ }
+ TokenValue::Interpolation(i) => {
+ if qualifiers.interpolation.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one interpolation qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.interpolation = Some((i, token.meta));
+ }
+ TokenValue::Const
+ | TokenValue::In
+ | TokenValue::Out
+ | TokenValue::Uniform
+ | TokenValue::Shared
+ | TokenValue::Buffer => {
+ let storage = match token.value {
+ TokenValue::Const => StorageQualifier::Const,
+ TokenValue::In => StorageQualifier::Input,
+ TokenValue::Out => StorageQualifier::Output,
+ TokenValue::Uniform => {
+ StorageQualifier::AddressSpace(AddressSpace::Uniform)
+ }
+ TokenValue::Shared => {
+ StorageQualifier::AddressSpace(AddressSpace::WorkGroup)
+ }
+ TokenValue::Buffer => {
+ StorageQualifier::AddressSpace(AddressSpace::Storage {
+ access: crate::StorageAccess::all(),
+ })
+ }
+ _ => unreachable!(),
+ };
+
+ if StorageQualifier::AddressSpace(AddressSpace::Function)
+ != qualifiers.storage.0
+ {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one storage qualifier per declaration".into(),
+ ),
+ meta: token.meta,
+ });
+ }
+
+ qualifiers.storage = (storage, token.meta);
+ }
+ TokenValue::Sampling(s) => {
+ if qualifiers.sampling.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one sampling qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.sampling = Some((s, token.meta));
+ }
+ TokenValue::PrecisionQualifier(p) => {
+ if qualifiers.precision.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one precision qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.precision = Some((p, token.meta));
+ }
+ TokenValue::MemoryQualifier(access) => {
+ let storage_access = qualifiers
+ .storage_access
+ .get_or_insert((crate::StorageAccess::all(), Span::default()));
+ if !storage_access.0.contains(!access) {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "The same memory qualifier can only be used once".into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ storage_access.0 &= access;
+ storage_access.1.subsume(token.meta);
+ }
+ TokenValue::Restrict => continue,
+ _ => unreachable!(),
+ };
+ }
+
+ Ok(qualifiers)
+ }
+
+ pub fn parse_layout_qualifier_id_list(
+ &mut self,
+ frontend: &mut Frontend,
+ qualifiers: &mut TypeQualifiers,
+ ) -> Result<()> {
+ self.expect(frontend, TokenValue::LeftParen)?;
+ loop {
+ self.parse_layout_qualifier_id(frontend, &mut qualifiers.layout_qualifiers)?;
+
+ if self.bump_if(frontend, TokenValue::Comma).is_some() {
+ continue;
+ }
+
+ break;
+ }
+ let token = self.expect(frontend, TokenValue::RightParen)?;
+ qualifiers.span.subsume(token.meta);
+
+ Ok(())
+ }
+
+ pub fn parse_layout_qualifier_id(
+ &mut self,
+ frontend: &mut Frontend,
+ qualifiers: &mut crate::FastHashMap<QualifierKey, (QualifierValue, Span)>,
+ ) -> Result<()> {
+ // layout_qualifier_id:
+ // IDENTIFIER
+ // IDENTIFIER EQUAL constant_expression
+ // SHARED
+ let mut token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Identifier(name) => {
+ let (key, value) = match name.as_str() {
+ "std140" => (
+ QualifierKey::Layout,
+ QualifierValue::Layout(StructLayout::Std140),
+ ),
+ "std430" => (
+ QualifierKey::Layout,
+ QualifierValue::Layout(StructLayout::Std430),
+ ),
+ word => {
+ if let Some(format) = map_image_format(word) {
+ (QualifierKey::Format, QualifierValue::Format(format))
+ } else {
+ let key = QualifierKey::String(name.into());
+ let value = if self.bump_if(frontend, TokenValue::Assign).is_some() {
+ let (value, end_meta) = match self.parse_uint_constant(frontend) {
+ Ok(v) => v,
+ Err(e) => {
+ frontend.errors.push(e);
+ (0, Span::default())
+ }
+ };
+ token.meta.subsume(end_meta);
+
+ QualifierValue::Uint(value)
+ } else {
+ QualifierValue::None
+ };
+
+ (key, value)
+ }
+ }
+ };
+
+ qualifiers.insert(key, (value, token.meta));
+ }
+ _ => frontend.errors.push(Error {
+ kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]),
+ meta: token.meta,
+ }),
+ }
+
+ Ok(())
+ }
+
+ pub fn peek_type_name(&mut self, frontend: &mut Frontend) -> bool {
+ self.peek(frontend).map_or(false, |t| match t.value {
+ TokenValue::TypeName(_) | TokenValue::Void => true,
+ TokenValue::Struct => true,
+ TokenValue::Identifier(ref ident) => frontend.lookup_type.contains_key(ident),
+ _ => false,
+ })
+ }
+}
+
+fn map_image_format(word: &str) -> Option<crate::StorageFormat> {
+ use crate::StorageFormat as Sf;
+
+ let format = match word {
+ // float-image-format-qualifier:
+ "rgba32f" => Sf::Rgba32Float,
+ "rgba16f" => Sf::Rgba16Float,
+ "rg32f" => Sf::Rg32Float,
+ "rg16f" => Sf::Rg16Float,
+ "r11f_g11f_b10f" => Sf::Rg11b10Float,
+ "r32f" => Sf::R32Float,
+ "r16f" => Sf::R16Float,
+ "rgba16" => Sf::Rgba16Unorm,
+ "rgb10_a2" => Sf::Rgb10a2Unorm,
+ "rgba8" => Sf::Rgba8Unorm,
+ "rg16" => Sf::Rg16Unorm,
+ "rg8" => Sf::Rg8Unorm,
+ "r16" => Sf::R16Unorm,
+ "r8" => Sf::R8Unorm,
+ "rgba16_snorm" => Sf::Rgba16Snorm,
+ "rgba8_snorm" => Sf::Rgba8Snorm,
+ "rg16_snorm" => Sf::Rg16Snorm,
+ "rg8_snorm" => Sf::Rg8Snorm,
+ "r16_snorm" => Sf::R16Snorm,
+ "r8_snorm" => Sf::R8Snorm,
+ // int-image-format-qualifier:
+ "rgba32i" => Sf::Rgba32Sint,
+ "rgba16i" => Sf::Rgba16Sint,
+ "rgba8i" => Sf::Rgba8Sint,
+ "rg32i" => Sf::Rg32Sint,
+ "rg16i" => Sf::Rg16Sint,
+ "rg8i" => Sf::Rg8Sint,
+ "r32i" => Sf::R32Sint,
+ "r16i" => Sf::R16Sint,
+ "r8i" => Sf::R8Sint,
+ // uint-image-format-qualifier:
+ "rgba32ui" => Sf::Rgba32Uint,
+ "rgba16ui" => Sf::Rgba16Uint,
+ "rgba8ui" => Sf::Rgba8Uint,
+ "rg32ui" => Sf::Rg32Uint,
+ "rg16ui" => Sf::Rg16Uint,
+ "rg8ui" => Sf::Rg8Uint,
+ "r32ui" => Sf::R32Uint,
+ "r16ui" => Sf::R16Uint,
+ "r8ui" => Sf::R8Uint,
+ // TODO: These next ones seem incorrect to me
+ // "rgb10_a2ui" => Sf::Rgb10a2Unorm,
+ _ => return None,
+ };
+
+ Some(format)
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser_tests.rs b/third_party/rust/naga/src/front/glsl/parser_tests.rs
new file mode 100644
index 0000000000..ae529a6fbe
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser_tests.rs
@@ -0,0 +1,832 @@
+use super::{
+ ast::Profile,
+ error::ExpectedToken,
+ error::{Error, ErrorKind},
+ token::TokenValue,
+ Frontend, Options, Span,
+};
+use crate::ShaderStage;
+use pp_rs::token::PreprocessorError;
+
+#[test]
+fn version() {
+ let mut frontend = Frontend::default();
+
+ // invalid versions
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 99000\n void main(){}",
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::InvalidVersion(99000),
+ meta: Span::new(9, 14)
+ }],
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 449\n void main(){}",
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::InvalidVersion(449),
+ meta: Span::new(9, 12)
+ }]
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450 smart\n void main(){}",
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::InvalidProfile("smart".into()),
+ meta: Span::new(13, 18),
+ }]
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450\nvoid main(){} #version 450",
+ )
+ .err()
+ .unwrap(),
+ vec![
+ Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedHash,),
+ meta: Span::new(27, 28),
+ },
+ Error {
+ kind: ErrorKind::InvalidToken(
+ TokenValue::Identifier("version".into()),
+ vec![ExpectedToken::Eof]
+ ),
+ meta: Span::new(28, 35)
+ }
+ ]
+ );
+
+ // valid versions
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ " # version 450\nvoid main() {}",
+ )
+ .unwrap();
+ assert_eq!(
+ (frontend.metadata().version, frontend.metadata().profile),
+ (450, Profile::Core)
+ );
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450\nvoid main() {}",
+ )
+ .unwrap();
+ assert_eq!(
+ (frontend.metadata().version, frontend.metadata().profile),
+ (450, Profile::Core)
+ );
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450 core\nvoid main(void) {}",
+ )
+ .unwrap();
+ assert_eq!(
+ (frontend.metadata().version, frontend.metadata().profile),
+ (450, Profile::Core)
+ );
+}
+
+#[test]
+fn control_flow() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ if (true) {
+ return 1;
+ } else {
+ return 2;
+ }
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ if (true) {
+ return 1;
+ }
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ int x;
+ int y = 3;
+ switch (5) {
+ case 2:
+ x = 2;
+ case 5:
+ x = 5;
+ y = 2;
+ break;
+ default:
+ x = 0;
+ }
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ int x = 0;
+ while(x < 5) {
+ x = x + 1;
+ }
+ do {
+ x = x - 1;
+ } while(x >= 4)
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ int x = 0;
+ for(int i = 0; i < 10;) {
+ x = x + 2;
+ }
+ for(;;);
+ return x;
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn declarations() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(location = 0) in vec2 v_uv;
+ layout(location = 0) out vec4 o_color;
+ layout(set = 1, binding = 1) uniform texture2D tex;
+ layout(set = 1, binding = 2) uniform sampler tex_sampler;
+
+ layout(early_fragment_tests) in;
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(std140, set = 2, binding = 0)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ ivec4 atlas_offs;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(push_constant)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ ivec4 atlas_offs;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(std430, set = 2, binding = 0)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ ivec4 atlas_offs;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(std140, set = 2, binding = 0)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ } block_var;
+
+ void main() {
+ load_time * model_offs;
+ block_var.load_time * block_var.model_offs;
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ float vector = vec4(1.0 / 17.0, 9.0 / 17.0, 3.0 / 17.0, 11.0 / 17.0);
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ precision highp float;
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn textures() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(location = 0) in vec2 v_uv;
+ layout(location = 0) out vec4 o_color;
+ layout(set = 1, binding = 1) uniform texture2D tex;
+ layout(set = 1, binding = 2) uniform sampler tex_sampler;
+ void main() {
+ o_color = texture(sampler2D(tex, tex_sampler), v_uv);
+ o_color.a = texture(sampler2D(tex, tex_sampler), v_uv, 2.0).a;
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn functions() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test1(float);
+ void test1(float) {}
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test2(float a) {}
+ void test3(float a, float b) {}
+ void test4(float, float) {}
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(float a) { return a; }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(vec4 p) {
+ return p.x;
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ // Function overloading
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(vec2 p);
+ float test(vec3 p);
+ float test(vec4 p);
+
+ float test(vec2 p) {
+ return p.x;
+ }
+
+ float test(vec3 p) {
+ return p.x;
+ }
+
+ float test(vec4 p) {
+ return p.x;
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ int test(vec4 p) {
+ return p.x;
+ }
+
+ float test(vec4 p) {
+ return p.x;
+ }
+
+ void main() {}
+ "#,
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::SemanticError("Function already defined".into()),
+ meta: Span::new(134, 152),
+ }]
+ );
+
+ println!();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float callee(uint q) {
+ return float(q);
+ }
+
+ float caller() {
+ callee(1u);
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ // Nested function call
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ layout(set = 0, binding = 1) uniform texture2D t_noise;
+ layout(set = 0, binding = 2) uniform sampler s_noise;
+
+ void main() {
+ textureLod(sampler2D(t_noise, s_noise), vec2(1.0), 0);
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void fun(vec2 in_parameter, out float out_parameter) {
+ ivec2 _ = ivec2(in_parameter);
+ }
+
+ void main() {
+ float a;
+ fun(vec2(1.0), a);
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn constants() {
+ use crate::{Constant, ConstantInner, ScalarValue};
+ let mut frontend = Frontend::default();
+
+ let module = frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ const float a = 1.0;
+ float global = a;
+ const float b = a;
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ let mut constants = module.constants.iter();
+
+ assert_eq!(
+ constants.next().unwrap().1,
+ &Constant {
+ name: Some("a".to_owned()),
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(1.0)
+ }
+ }
+ );
+ assert_eq!(
+ constants.next().unwrap().1,
+ &Constant {
+ name: Some("b".to_owned()),
+ specialization: None,
+ inner: ConstantInner::Scalar {
+ width: 4,
+ value: ScalarValue::Float(1.0)
+ }
+ }
+ );
+
+ assert!(constants.next().is_none());
+}
+
+#[test]
+fn function_overloading() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+
+ float saturate(float v) { return clamp(v, 0.0, 1.0); }
+ vec2 saturate(vec2 v) { return clamp(v, vec2(0.0), vec2(1.0)); }
+ vec3 saturate(vec3 v) { return clamp(v, vec3(0.0), vec3(1.0)); }
+ vec4 saturate(vec4 v) { return clamp(v, vec4(0.0), vec4(1.0)); }
+
+ void main() {
+ float v1 = saturate(1.5);
+ vec2 v2 = saturate(vec2(0.5, 1.5));
+ vec3 v3 = saturate(vec3(0.5, 1.5, 2.5));
+ vec3 v4 = saturate(vec4(0.5, 1.5, 2.5, 3.5));
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn implicit_conversions() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ mat4 a = mat4(1);
+ float b = 1u;
+ float c = 1 + 2.0;
+ }
+ "#,
+ )
+ .unwrap();
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test(int a) {}
+ void test(uint a) {}
+
+ void main() {
+ test(1.0);
+ }
+ "#,
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::SemanticError("Unknown function \'test\'".into()),
+ meta: Span::new(156, 165),
+ }]
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test(float a) {}
+ void test(uint a) {}
+
+ void main() {
+ test(1);
+ }
+ "#,
+ )
+ .err()
+ .unwrap(),
+ vec![Error {
+ kind: ErrorKind::SemanticError("Ambiguous best function for \'test\'".into()),
+ meta: Span::new(158, 165),
+ }]
+ );
+}
+
+#[test]
+fn structs() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ Test {
+ vec4 pos;
+ } xx;
+
+ void main() {}
+ "#,
+ )
+ .unwrap_err();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ struct Test {
+ vec4 pos;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ const int NUM_VECS = 42;
+ struct Test {
+ vec4 vecs[NUM_VECS];
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ struct Hello {
+ vec4 test;
+ } test() {
+ return Hello( vec4(1.0) );
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ struct Test {};
+
+ void main() {}
+ "#,
+ )
+ .unwrap_err();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ inout struct Test {
+ vec4 x;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap_err();
+}
+
+#[test]
+fn swizzles() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ vec4 v = vec4(1);
+ v.xyz = vec3(2);
+ v.x = 5.0;
+ v.xyz.zxy.yx.xy = vec2(5.0, 1.0);
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ vec4 v = vec4(1);
+ v.xx = vec2(5.0);
+ }
+ "#,
+ )
+ .unwrap_err();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ vec3 v = vec3(1);
+ v.w = 2.0;
+ }
+ "#,
+ )
+ .unwrap_err();
+}
+
+#[test]
+fn expressions() {
+ let mut frontend = Frontend::default();
+
+ // Vector indexing
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(int index) {
+ vec4 v = vec4(1.0, 2.0, 3.0, 4.0);
+ return v[index] + 1.0;
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ // Prefix increment/decrement
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ uint index = 0;
+
+ --index;
+ ++index;
+ }
+ "#,
+ )
+ .unwrap();
+
+ // Dynamic indexing of array
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ const vec4 positions[1] = { vec4(0) };
+
+ gl_Position = positions[gl_VertexIndex];
+ }
+ "#,
+ )
+ .unwrap();
+}
diff --git a/third_party/rust/naga/src/front/glsl/token.rs b/third_party/rust/naga/src/front/glsl/token.rs
new file mode 100644
index 0000000000..74f10cbf85
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/token.rs
@@ -0,0 +1,137 @@
+pub use pp_rs::token::{Float, Integer, Location, PreprocessorError, Token as PPToken};
+
+use super::ast::Precision;
+use crate::{Interpolation, Sampling, Span, Type};
+
+impl From<Location> for Span {
+ fn from(loc: Location) -> Self {
+ Span::new(loc.start, loc.end)
+ }
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Token {
+ pub value: TokenValue,
+ pub meta: Span,
+}
+
+/// A token passed from the lexing used in the parsing.
+///
+/// This type is exported since it's returned in the
+/// [`InvalidToken`](super::ErrorKind::InvalidToken) error.
+#[derive(Debug, PartialEq)]
+pub enum TokenValue {
+ Identifier(String),
+
+ FloatConstant(Float),
+ IntConstant(Integer),
+ BoolConstant(bool),
+
+ Layout,
+ In,
+ Out,
+ InOut,
+ Uniform,
+ Buffer,
+ Const,
+ Shared,
+
+ Restrict,
+ /// A `glsl` memory qualifier such as `writeonly`
+ ///
+ /// The associated [`crate::StorageAccess`] is the access being allowed
+ /// (for example `writeonly` has an associated value of [`crate::StorageAccess::STORE`])
+ MemoryQualifier(crate::StorageAccess),
+
+ Invariant,
+ Interpolation(Interpolation),
+ Sampling(Sampling),
+ Precision,
+ PrecisionQualifier(Precision),
+
+ Continue,
+ Break,
+ Return,
+ Discard,
+
+ If,
+ Else,
+ Switch,
+ Case,
+ Default,
+ While,
+ Do,
+ For,
+
+ Void,
+ Struct,
+ TypeName(Type),
+
+ Assign,
+ AddAssign,
+ SubAssign,
+ MulAssign,
+ DivAssign,
+ ModAssign,
+ LeftShiftAssign,
+ RightShiftAssign,
+ AndAssign,
+ XorAssign,
+ OrAssign,
+
+ Increment,
+ Decrement,
+
+ LogicalOr,
+ LogicalAnd,
+ LogicalXor,
+
+ LessEqual,
+ GreaterEqual,
+ Equal,
+ NotEqual,
+
+ LeftShift,
+ RightShift,
+
+ LeftBrace,
+ RightBrace,
+ LeftParen,
+ RightParen,
+ LeftBracket,
+ RightBracket,
+ LeftAngle,
+ RightAngle,
+
+ Comma,
+ Semicolon,
+ Colon,
+ Dot,
+ Bang,
+ Dash,
+ Tilde,
+ Plus,
+ Star,
+ Slash,
+ Percent,
+ VerticalBar,
+ Caret,
+ Ampersand,
+ Question,
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Directive {
+ pub kind: DirectiveKind,
+ pub tokens: Vec<PPToken>,
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum DirectiveKind {
+ Version { is_first_directive: bool },
+ Extension,
+ Pragma,
+}
diff --git a/third_party/rust/naga/src/front/glsl/types.rs b/third_party/rust/naga/src/front/glsl/types.rs
new file mode 100644
index 0000000000..a7967848d5
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/types.rs
@@ -0,0 +1,335 @@
+use super::{
+ constants::ConstantSolver, context::Context, Error, ErrorKind, Frontend, Result, Span,
+};
+use crate::{
+ proc::ResolveContext, Bytes, Constant, Expression, Handle, ImageClass, ImageDimension,
+ ScalarKind, Type, TypeInner, VectorSize,
+};
+
+pub fn parse_type(type_name: &str) -> Option<Type> {
+ match type_name {
+ "bool" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ },
+ }),
+ "float" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ }),
+ "double" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Float,
+ width: 8,
+ },
+ }),
+ "int" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Sint,
+ width: 4,
+ },
+ }),
+ "uint" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Uint,
+ width: 4,
+ },
+ }),
+ "sampler" | "samplerShadow" => Some(Type {
+ name: None,
+ inner: TypeInner::Sampler {
+ comparison: type_name == "samplerShadow",
+ },
+ }),
+ word => {
+ fn kind_width_parse(ty: &str) -> Option<(ScalarKind, u8)> {
+ Some(match ty {
+ "" => (ScalarKind::Float, 4),
+ "b" => (ScalarKind::Bool, crate::BOOL_WIDTH),
+ "i" => (ScalarKind::Sint, 4),
+ "u" => (ScalarKind::Uint, 4),
+ "d" => (ScalarKind::Float, 8),
+ _ => return None,
+ })
+ }
+
+ fn size_parse(n: &str) -> Option<VectorSize> {
+ Some(match n {
+ "2" => VectorSize::Bi,
+ "3" => VectorSize::Tri,
+ "4" => VectorSize::Quad,
+ _ => return None,
+ })
+ }
+
+ let vec_parse = |word: &str| {
+ let mut iter = word.split("vec");
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ let (kind, width) = kind_width_parse(kind)?;
+ let size = size_parse(size)?;
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Vector { size, kind, width },
+ })
+ };
+
+ let mat_parse = |word: &str| {
+ let mut iter = word.split("mat");
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ let (_, width) = kind_width_parse(kind)?;
+
+ let (columns, rows) = if let Some(size) = size_parse(size) {
+ (size, size)
+ } else {
+ let mut iter = size.split('x');
+ match (iter.next()?, iter.next()?, iter.next()) {
+ (col, row, None) => (size_parse(col)?, size_parse(row)?),
+ _ => return None,
+ }
+ };
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ },
+ })
+ };
+
+ let texture_parse = |word: &str| {
+ let mut iter = word.split("texture");
+
+ let texture_kind = |ty| {
+ Some(match ty {
+ "" => ScalarKind::Float,
+ "i" => ScalarKind::Sint,
+ "u" => ScalarKind::Uint,
+ _ => return None,
+ })
+ };
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ let kind = texture_kind(kind)?;
+
+ let sampled = |multi| ImageClass::Sampled { kind, multi };
+
+ let (dim, arrayed, class) = match size {
+ "1D" => (ImageDimension::D1, false, sampled(false)),
+ "1DArray" => (ImageDimension::D1, true, sampled(false)),
+ "2D" => (ImageDimension::D2, false, sampled(false)),
+ "2DArray" => (ImageDimension::D2, true, sampled(false)),
+ "2DMS" => (ImageDimension::D2, false, sampled(true)),
+ "2DMSArray" => (ImageDimension::D2, true, sampled(true)),
+ "3D" => (ImageDimension::D3, false, sampled(false)),
+ "Cube" => (ImageDimension::Cube, false, sampled(false)),
+ "CubeArray" => (ImageDimension::Cube, true, sampled(false)),
+ _ => return None,
+ };
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ },
+ })
+ };
+
+ let image_parse = |word: &str| {
+ let mut iter = word.split("image");
+
+ let texture_kind = |ty| {
+ Some(match ty {
+ "" => ScalarKind::Float,
+ "i" => ScalarKind::Sint,
+ "u" => ScalarKind::Uint,
+ _ => return None,
+ })
+ };
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ // TODO: Check that the texture format and the kind match
+ let _ = texture_kind(kind)?;
+
+ let class = ImageClass::Storage {
+ format: crate::StorageFormat::R8Uint,
+ access: crate::StorageAccess::all(),
+ };
+
+ // TODO: glsl support multisampled storage images, naga doesn't
+ let (dim, arrayed) = match size {
+ "1D" => (ImageDimension::D1, false),
+ "1DArray" => (ImageDimension::D1, true),
+ "2D" => (ImageDimension::D2, false),
+ "2DArray" => (ImageDimension::D2, true),
+ "3D" => (ImageDimension::D3, false),
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ // "Cube" => (ImageDimension::Cube, false),
+ // "CubeArray" => (ImageDimension::Cube, true),
+ _ => return None,
+ };
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ },
+ })
+ };
+
+ vec_parse(word)
+ .or_else(|| mat_parse(word))
+ .or_else(|| texture_parse(word))
+ .or_else(|| image_parse(word))
+ }
+ }
+}
+
+pub const fn scalar_components(ty: &TypeInner) -> Option<(ScalarKind, Bytes)> {
+ match *ty {
+ TypeInner::Scalar { kind, width } => Some((kind, width)),
+ TypeInner::Vector { kind, width, .. } => Some((kind, width)),
+ TypeInner::Matrix { width, .. } => Some((ScalarKind::Float, width)),
+ TypeInner::ValuePointer { kind, width, .. } => Some((kind, width)),
+ _ => None,
+ }
+}
+
+pub const fn type_power(kind: ScalarKind, width: Bytes) -> Option<u32> {
+ Some(match kind {
+ ScalarKind::Sint => 0,
+ ScalarKind::Uint => 1,
+ ScalarKind::Float if width == 4 => 2,
+ ScalarKind::Float => 3,
+ ScalarKind::Bool => return None,
+ })
+}
+
+impl Frontend {
+ /// Resolves the types of the expressions until `expr` (inclusive)
+ ///
+ /// This needs to be done before the [`typifier`] can be queried for
+ /// the types of the expressions in the range between the last grow and `expr`.
+ ///
+ /// # Note
+ ///
+ /// The `resolve_type*` methods (like [`resolve_type`]) automatically
+ /// grow the [`typifier`] so calling this method is not necessary when using
+ /// them.
+ ///
+ /// [`typifier`]: Context::typifier
+ /// [`resolve_type`]: Self::resolve_type
+ pub(crate) fn typifier_grow(
+ &self,
+ ctx: &mut Context,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<()> {
+ let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments);
+
+ ctx.typifier
+ .grow(expr, &ctx.expressions, &resolve_ctx)
+ .map_err(|error| Error {
+ kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()),
+ meta,
+ })
+ }
+
+ /// Gets the type for the result of the `expr` expression
+ ///
+ /// Automatically grows the [`typifier`] to `expr` so calling
+ /// [`typifier_grow`] is not necessary
+ ///
+ /// [`typifier`]: Context::typifier
+ /// [`typifier_grow`]: Self::typifier_grow
+ pub(crate) fn resolve_type<'b>(
+ &'b self,
+ ctx: &'b mut Context,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<&'b TypeInner> {
+ self.typifier_grow(ctx, expr, meta)?;
+ Ok(ctx.typifier.get(expr, &self.module.types))
+ }
+
+ /// Gets the type handle for the result of the `expr` expression
+ ///
+ /// Automatically grows the [`typifier`] to `expr` so calling
+ /// [`typifier_grow`] is not necessary
+ ///
+ /// # Note
+ ///
+ /// Consider using [`resolve_type`] whenever possible
+ /// since it doesn't require adding each type to the [`types`] arena
+ /// and it doesn't need to mutably borrow the [`Parser`][Self]
+ ///
+ /// [`types`]: crate::Module::types
+ /// [`typifier`]: Context::typifier
+ /// [`typifier_grow`]: Self::typifier_grow
+ /// [`resolve_type`]: Self::resolve_type
+ pub(crate) fn resolve_type_handle(
+ &mut self,
+ ctx: &mut Context,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<Handle<Type>> {
+ self.typifier_grow(ctx, expr, meta)?;
+ Ok(ctx.typifier.register_type(expr, &mut self.module.types))
+ }
+
+ /// Invalidates the cached type resolution for `expr` forcing a recomputation
+ pub(crate) fn invalidate_expression<'b>(
+ &'b self,
+ ctx: &'b mut Context,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<()> {
+ let resolve_ctx = ResolveContext::with_locals(&self.module, &ctx.locals, &ctx.arguments);
+
+ ctx.typifier
+ .invalidate(expr, &ctx.expressions, &resolve_ctx)
+ .map_err(|error| Error {
+ kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()),
+ meta,
+ })
+ }
+
+ pub(crate) fn solve_constant(
+ &mut self,
+ ctx: &Context,
+ root: Handle<Expression>,
+ meta: Span,
+ ) -> Result<Handle<Constant>> {
+ let mut solver = ConstantSolver {
+ types: &mut self.module.types,
+ expressions: &ctx.expressions,
+ constants: &mut self.module.constants,
+ };
+
+ solver.solve(root).map_err(|e| Error {
+ kind: e.into(),
+ meta,
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs
new file mode 100644
index 0000000000..4c0e0a927d
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/variables.rs
@@ -0,0 +1,679 @@
+use super::{
+ ast::*,
+ context::{Context, ExprPos},
+ error::{Error, ErrorKind},
+ Frontend, Result, Span,
+};
+use crate::{
+ AddressSpace, Binding, Block, BuiltIn, Constant, Expression, GlobalVariable, Handle,
+ Interpolation, LocalVariable, ResourceBinding, ScalarKind, ShaderStage, SwizzleComponent, Type,
+ TypeInner, VectorSize,
+};
+
+pub struct VarDeclaration<'a, 'key> {
+ pub qualifiers: &'a mut TypeQualifiers<'key>,
+ pub ty: Handle<Type>,
+ pub name: Option<String>,
+ pub init: Option<Handle<Constant>>,
+ pub meta: Span,
+}
+
+/// Information about a builtin used in [`add_builtin`](Frontend::add_builtin).
+struct BuiltInData {
+ /// The type of the builtin.
+ inner: TypeInner,
+ /// The associated builtin class.
+ builtin: BuiltIn,
+ /// Whether the builtin can be written to or not.
+ mutable: bool,
+ /// The storage used for the builtin.
+ storage: StorageQualifier,
+}
+
+pub enum GlobalOrConstant {
+ Global(Handle<GlobalVariable>),
+ Constant(Handle<Constant>),
+}
+
+impl Frontend {
+ /// Adds a builtin and returns a variable reference to it
+ fn add_builtin(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ name: &str,
+ data: BuiltInData,
+ meta: Span,
+ ) -> Option<VariableReference> {
+ let ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: data.inner,
+ },
+ meta,
+ );
+
+ let handle = self.module.global_variables.append(
+ GlobalVariable {
+ name: Some(name.into()),
+ space: AddressSpace::Private,
+ binding: None,
+ ty,
+ init: None,
+ },
+ meta,
+ );
+
+ let idx = self.entry_args.len();
+ self.entry_args.push(EntryArg {
+ name: None,
+ binding: Binding::BuiltIn(data.builtin),
+ handle,
+ storage: data.storage,
+ });
+
+ self.global_variables.push((
+ name.into(),
+ GlobalLookup {
+ kind: GlobalLookupKind::Variable(handle),
+ entry_arg: Some(idx),
+ mutable: data.mutable,
+ },
+ ));
+
+ let expr = ctx.add_expression(Expression::GlobalVariable(handle), meta, body);
+
+ let var = VariableReference {
+ expr,
+ load: true,
+ mutable: data.mutable,
+ constant: None,
+ entry_arg: Some(idx),
+ };
+
+ ctx.symbol_table.add_root(name.into(), var.clone());
+
+ Some(var)
+ }
+
+ pub(crate) fn lookup_variable(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ name: &str,
+ meta: Span,
+ ) -> Option<VariableReference> {
+ if let Some(var) = ctx.symbol_table.lookup(name).cloned() {
+ return Some(var);
+ }
+
+ let data = match name {
+ "gl_Position" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Quad,
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ builtin: BuiltIn::Position { invariant: false },
+ mutable: true,
+ storage: StorageQualifier::Output,
+ },
+ "gl_FragCoord" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Quad,
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ builtin: BuiltIn::Position { invariant: false },
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_PointCoord" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Bi,
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ builtin: BuiltIn::PointCoord,
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_GlobalInvocationID"
+ | "gl_NumWorkGroups"
+ | "gl_WorkGroupSize"
+ | "gl_WorkGroupID"
+ | "gl_LocalInvocationID" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Tri,
+ kind: ScalarKind::Uint,
+ width: 4,
+ },
+ builtin: match name {
+ "gl_GlobalInvocationID" => BuiltIn::GlobalInvocationId,
+ "gl_NumWorkGroups" => BuiltIn::NumWorkGroups,
+ "gl_WorkGroupSize" => BuiltIn::WorkGroupSize,
+ "gl_WorkGroupID" => BuiltIn::WorkGroupId,
+ "gl_LocalInvocationID" => BuiltIn::LocalInvocationId,
+ _ => unreachable!(),
+ },
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_FrontFacing" => BuiltInData {
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ },
+ builtin: BuiltIn::FrontFacing,
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_PointSize" | "gl_FragDepth" => BuiltInData {
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ builtin: match name {
+ "gl_PointSize" => BuiltIn::PointSize,
+ "gl_FragDepth" => BuiltIn::FragDepth,
+ _ => unreachable!(),
+ },
+ mutable: true,
+ storage: StorageQualifier::Output,
+ },
+ "gl_ClipDistance" | "gl_CullDistance" => {
+ let base = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ },
+ meta,
+ );
+
+ BuiltInData {
+ inner: TypeInner::Array {
+ base,
+ size: crate::ArraySize::Dynamic,
+ stride: 4,
+ },
+ builtin: match name {
+ "gl_ClipDistance" => BuiltIn::ClipDistance,
+ "gl_CullDistance" => BuiltIn::CullDistance,
+ _ => unreachable!(),
+ },
+ mutable: self.meta.stage == ShaderStage::Vertex,
+ storage: StorageQualifier::Output,
+ }
+ }
+ _ => {
+ let builtin = match name {
+ "gl_BaseVertex" => BuiltIn::BaseVertex,
+ "gl_BaseInstance" => BuiltIn::BaseInstance,
+ "gl_PrimitiveID" => BuiltIn::PrimitiveIndex,
+ "gl_InstanceIndex" => BuiltIn::InstanceIndex,
+ "gl_VertexIndex" => BuiltIn::VertexIndex,
+ "gl_SampleID" => BuiltIn::SampleIndex,
+ "gl_LocalInvocationIndex" => BuiltIn::LocalInvocationIndex,
+ _ => return None,
+ };
+
+ BuiltInData {
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Uint,
+ width: 4,
+ },
+ builtin,
+ mutable: false,
+ storage: StorageQualifier::Input,
+ }
+ }
+ };
+
+ self.add_builtin(ctx, body, name, data, meta)
+ }
+
+ pub(crate) fn make_variable_invariant(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ name: &str,
+ meta: Span,
+ ) {
+ if let Some(var) = self.lookup_variable(ctx, body, name, meta) {
+ if let Some(index) = var.entry_arg {
+ if let Binding::BuiltIn(BuiltIn::Position { ref mut invariant }) =
+ self.entry_args[index].binding
+ {
+ *invariant = true;
+ }
+ }
+ }
+ }
+
+ pub(crate) fn field_selection(
+ &mut self,
+ ctx: &mut Context,
+ pos: ExprPos,
+ body: &mut Block,
+ expression: Handle<Expression>,
+ name: &str,
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let (ty, is_pointer) = match *self.resolve_type(ctx, expression, meta)? {
+ TypeInner::Pointer { base, .. } => (&self.module.types[base].inner, true),
+ ref ty => (ty, false),
+ };
+ match *ty {
+ TypeInner::Struct { ref members, .. } => {
+ let index = members
+ .iter()
+ .position(|m| m.name == Some(name.into()))
+ .ok_or_else(|| Error {
+ kind: ErrorKind::UnknownField(name.into()),
+ meta,
+ })?;
+ let pointer = ctx.add_expression(
+ Expression::AccessIndex {
+ base: expression,
+ index: index as u32,
+ },
+ meta,
+ body,
+ );
+
+ Ok(match pos {
+ ExprPos::Rhs if is_pointer => {
+ ctx.add_expression(Expression::Load { pointer }, meta, body)
+ }
+ _ => pointer,
+ })
+ }
+ // swizzles (xyzw, rgba, stpq)
+ TypeInner::Vector { size, .. } => {
+ let check_swizzle_components = |comps: &str| {
+ name.chars()
+ .map(|c| {
+ comps
+ .find(c)
+ .filter(|i| *i < size as usize)
+ .map(|i| SwizzleComponent::from_index(i as u32))
+ })
+ .collect::<Option<Vec<SwizzleComponent>>>()
+ };
+
+ let components = check_swizzle_components("xyzw")
+ .or_else(|| check_swizzle_components("rgba"))
+ .or_else(|| check_swizzle_components("stpq"));
+
+ if let Some(components) = components {
+ if let ExprPos::Lhs = pos {
+ let not_unique = (1..components.len())
+ .any(|i| components[i..].contains(&components[i - 1]));
+ if not_unique {
+ self.errors.push(Error {
+ kind:
+ ErrorKind::SemanticError(
+ format!(
+ "swizzle cannot have duplicate components in left-hand-side expression for \"{name:?}\""
+ )
+ .into(),
+ ),
+ meta ,
+ })
+ }
+ }
+
+ let mut pattern = [SwizzleComponent::X; 4];
+ for (pat, component) in pattern.iter_mut().zip(&components) {
+ *pat = *component;
+ }
+
+ // flatten nested swizzles (vec.zyx.xy.x => vec.z)
+ let mut expression = expression;
+ if let Expression::Swizzle {
+ size: _,
+ vector,
+ pattern: ref src_pattern,
+ } = ctx[expression]
+ {
+ expression = vector;
+ for pat in &mut pattern {
+ *pat = src_pattern[pat.index() as usize];
+ }
+ }
+
+ let size = match components.len() {
+ // Swizzles with just one component are accesses and not swizzles
+ 1 => {
+ match pos {
+ // If the position is in the right hand side and the base
+ // vector is a pointer, load it, otherwise the swizzle would
+ // produce a pointer
+ ExprPos::Rhs if is_pointer => {
+ expression = ctx.add_expression(
+ Expression::Load {
+ pointer: expression,
+ },
+ meta,
+ body,
+ );
+ }
+ _ => {}
+ };
+ return Ok(ctx.add_expression(
+ Expression::AccessIndex {
+ base: expression,
+ index: pattern[0].index(),
+ },
+ meta,
+ body,
+ ));
+ }
+ 2 => VectorSize::Bi,
+ 3 => VectorSize::Tri,
+ 4 => VectorSize::Quad,
+ _ => {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Bad swizzle size for \"{name:?}\"").into(),
+ ),
+ meta,
+ });
+
+ VectorSize::Quad
+ }
+ };
+
+ if is_pointer {
+ // NOTE: for lhs expression, this extra load ends up as an unused expr, because the
+ // assignment will extract the pointer and use it directly anyway. Unfortunately we
+ // need it for validation to pass, as swizzles cannot operate on pointer values.
+ expression = ctx.add_expression(
+ Expression::Load {
+ pointer: expression,
+ },
+ meta,
+ body,
+ );
+ }
+
+ Ok(ctx.add_expression(
+ Expression::Swizzle {
+ size,
+ vector: expression,
+ pattern,
+ },
+ meta,
+ body,
+ ))
+ } else {
+ Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Invalid swizzle for vector \"{name}\"").into(),
+ ),
+ meta,
+ })
+ }
+ }
+ _ => Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Can't lookup field on this type \"{name}\"").into(),
+ ),
+ meta,
+ }),
+ }
+ }
+
+ pub(crate) fn add_global_var(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ VarDeclaration {
+ qualifiers,
+ mut ty,
+ name,
+ init,
+ meta,
+ }: VarDeclaration,
+ ) -> Result<GlobalOrConstant> {
+ let storage = qualifiers.storage.0;
+ let (ret, lookup) = match storage {
+ StorageQualifier::Input | StorageQualifier::Output => {
+ let input = storage == StorageQualifier::Input;
+ // TODO: glslang seems to use a counter for variables without
+ // explicit location (even if that causes collisions)
+ let location = qualifiers
+ .uint_layout_qualifier("location", &mut self.errors)
+ .unwrap_or(0);
+ let interpolation = qualifiers.interpolation.take().map(|(i, _)| i).or_else(|| {
+ let kind = self.module.types[ty].inner.scalar_kind()?;
+ Some(match kind {
+ ScalarKind::Float => Interpolation::Perspective,
+ _ => Interpolation::Flat,
+ })
+ });
+ let sampling = qualifiers.sampling.take().map(|(s, _)| s);
+
+ let handle = self.module.global_variables.append(
+ GlobalVariable {
+ name: name.clone(),
+ space: AddressSpace::Private,
+ binding: None,
+ ty,
+ init,
+ },
+ meta,
+ );
+
+ let idx = self.entry_args.len();
+ self.entry_args.push(EntryArg {
+ name: name.clone(),
+ binding: Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ },
+ handle,
+ storage,
+ });
+
+ let lookup = GlobalLookup {
+ kind: GlobalLookupKind::Variable(handle),
+ entry_arg: Some(idx),
+ mutable: !input,
+ };
+
+ (GlobalOrConstant::Global(handle), lookup)
+ }
+ StorageQualifier::Const => {
+ let init = init.ok_or_else(|| Error {
+ kind: ErrorKind::SemanticError("const values must have an initializer".into()),
+ meta,
+ })?;
+
+ let lookup = GlobalLookup {
+ kind: GlobalLookupKind::Constant(init, ty),
+ entry_arg: None,
+ mutable: false,
+ };
+
+ if let Some(name) = name.as_ref() {
+ let constant = self.module.constants.get_mut(init);
+ if constant.name.is_none() {
+ // set the name of the constant
+ constant.name = Some(name.clone())
+ } else {
+ // add a copy of the constant with the new name
+ let new_const = Constant {
+ name: Some(name.clone()),
+ specialization: constant.specialization,
+ inner: constant.inner.clone(),
+ };
+ self.module.constants.fetch_or_append(new_const, meta);
+ }
+ }
+
+ (GlobalOrConstant::Constant(init), lookup)
+ }
+ StorageQualifier::AddressSpace(mut space) => {
+ match space {
+ AddressSpace::Storage { ref mut access } => {
+ if let Some((allowed_access, _)) = qualifiers.storage_access.take() {
+ *access = allowed_access;
+ }
+ }
+ AddressSpace::Uniform => match self.module.types[ty].inner {
+ TypeInner::Image {
+ class,
+ dim,
+ arrayed,
+ } => {
+ if let crate::ImageClass::Storage {
+ mut access,
+ mut format,
+ } = class
+ {
+ if let Some((allowed_access, _)) = qualifiers.storage_access.take()
+ {
+ access = allowed_access;
+ }
+
+ match qualifiers.layout_qualifiers.remove(&QualifierKey::Format) {
+ Some((QualifierValue::Format(f), _)) => format = f,
+ // TODO: glsl supports images without format qualifier
+ // if they are `writeonly`
+ None => self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "image types require a format layout qualifier".into(),
+ ),
+ meta,
+ }),
+ _ => unreachable!(),
+ }
+
+ ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class: crate::ImageClass::Storage { format, access },
+ },
+ },
+ meta,
+ );
+ }
+
+ space = AddressSpace::Handle
+ }
+ TypeInner::Sampler { .. } => space = AddressSpace::Handle,
+ _ => {
+ if qualifiers.none_layout_qualifier("push_constant", &mut self.errors) {
+ space = AddressSpace::PushConstant
+ }
+ }
+ },
+ AddressSpace::Function => space = AddressSpace::Private,
+ _ => {}
+ };
+
+ let binding = match space {
+ AddressSpace::Uniform | AddressSpace::Storage { .. } | AddressSpace::Handle => {
+ let binding = qualifiers.uint_layout_qualifier("binding", &mut self.errors);
+ if binding.is_none() {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "uniform/buffer blocks require layout(binding=X)".into(),
+ ),
+ meta,
+ });
+ }
+ let set = qualifiers.uint_layout_qualifier("set", &mut self.errors);
+ binding.map(|binding| ResourceBinding {
+ group: set.unwrap_or(0),
+ binding,
+ })
+ }
+ _ => None,
+ };
+
+ let handle = self.module.global_variables.append(
+ GlobalVariable {
+ name: name.clone(),
+ space,
+ binding,
+ ty,
+ init,
+ },
+ meta,
+ );
+
+ let lookup = GlobalLookup {
+ kind: GlobalLookupKind::Variable(handle),
+ entry_arg: None,
+ mutable: true,
+ };
+
+ (GlobalOrConstant::Global(handle), lookup)
+ }
+ };
+
+ if let Some(name) = name {
+ ctx.add_global(self, &name, lookup, body);
+
+ self.global_variables.push((name, lookup));
+ }
+
+ qualifiers.unused_errors(&mut self.errors);
+
+ Ok(ret)
+ }
+
+ pub(crate) fn add_local_var(
+ &mut self,
+ ctx: &mut Context,
+ body: &mut Block,
+ decl: VarDeclaration,
+ ) -> Result<Handle<Expression>> {
+ let storage = decl.qualifiers.storage;
+ let mutable = match storage.0 {
+ StorageQualifier::AddressSpace(AddressSpace::Function) => true,
+ StorageQualifier::Const => false,
+ _ => {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Locals cannot have a storage qualifier".into()),
+ meta: storage.1,
+ });
+ true
+ }
+ };
+
+ let handle = ctx.locals.append(
+ LocalVariable {
+ name: decl.name.clone(),
+ ty: decl.ty,
+ init: None,
+ },
+ decl.meta,
+ );
+ let expr = ctx.add_expression(Expression::LocalVariable(handle), decl.meta, body);
+
+ if let Some(name) = decl.name {
+ let maybe_var = ctx.add_local_var(name.clone(), expr, mutable);
+
+ if maybe_var.is_some() {
+ self.errors.push(Error {
+ kind: ErrorKind::VariableAlreadyDeclared(name),
+ meta: decl.meta,
+ })
+ }
+ }
+
+ decl.qualifiers.unused_errors(&mut self.errors);
+
+ Ok(expr)
+ }
+}
diff --git a/third_party/rust/naga/src/front/interpolator.rs b/third_party/rust/naga/src/front/interpolator.rs
new file mode 100644
index 0000000000..96f5db6ff7
--- /dev/null
+++ b/third_party/rust/naga/src/front/interpolator.rs
@@ -0,0 +1,61 @@
+/*!
+Interpolation defaults.
+*/
+
+impl crate::Binding {
+ /// Apply the usual default interpolation for `ty` to `binding`.
+ ///
+ /// This function is a utility front ends may use to satisfy the Naga IR's
+ /// requirement, meant to ensure that input languages' policies have been
+ /// applied appropriately, that all I/O `Binding`s from the vertex shader to the
+ /// fragment shader must have non-`None` `interpolation` values.
+ ///
+ /// All the shader languages Naga supports have similar rules:
+ /// perspective-correct, center-sampled interpolation is the default for any
+ /// binding that can vary, and everything else either defaults to flat, or
+ /// requires an explicit flat qualifier/attribute/what-have-you.
+ ///
+ /// If `binding` is not a [`Location`] binding, or if its [`interpolation`] is
+ /// already set, then make no changes. Otherwise, set `binding`'s interpolation
+ /// and sampling to reasonable defaults depending on `ty`, the type of the value
+ /// being interpolated:
+ ///
+ /// - If `ty` is a floating-point scalar, vector, or matrix type, then
+ /// default to [`Perspective`] interpolation and [`Center`] sampling.
+ ///
+ /// - If `ty` is an integral scalar or vector, then default to [`Flat`]
+ /// interpolation, which has no associated sampling.
+ ///
+ /// - For any other types, make no change. Such types are not permitted as
+ /// user-defined IO values, and will probably be flagged by the verifier
+ ///
+ /// When structs appear in input or output types, each member ought to have its
+ /// own [`Binding`], so structs are simply covered by the third case.
+ ///
+ /// [`Binding`]: crate::Binding
+ /// [`Location`]: crate::Binding::Location
+ /// [`interpolation`]: crate::Binding::Location::interpolation
+ /// [`Perspective`]: crate::Interpolation::Perspective
+ /// [`Flat`]: crate::Interpolation::Flat
+ /// [`Center`]: crate::Sampling::Center
+ pub fn apply_default_interpolation(&mut self, ty: &crate::TypeInner) {
+ if let crate::Binding::Location {
+ location: _,
+ interpolation: ref mut interpolation @ None,
+ ref mut sampling,
+ } = *self
+ {
+ match ty.scalar_kind() {
+ Some(crate::ScalarKind::Float) => {
+ *interpolation = Some(crate::Interpolation::Perspective);
+ *sampling = Some(crate::Sampling::Center);
+ }
+ Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+ *interpolation = Some(crate::Interpolation::Flat);
+ *sampling = None;
+ }
+ Some(_) | None => {}
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/mod.rs b/third_party/rust/naga/src/front/mod.rs
new file mode 100644
index 0000000000..d6f38671ea
--- /dev/null
+++ b/third_party/rust/naga/src/front/mod.rs
@@ -0,0 +1,374 @@
+/*!
+Frontend parsers that consume binary and text shaders and load them into [`Module`](super::Module)s.
+*/
+
+mod interpolator;
+mod type_gen;
+
+#[cfg(feature = "glsl-in")]
+pub mod glsl;
+#[cfg(feature = "spv-in")]
+pub mod spv;
+#[cfg(feature = "wgsl-in")]
+pub mod wgsl;
+
+use crate::{
+ arena::{Arena, Handle, UniqueArena},
+ proc::{ResolveContext, ResolveError, TypeResolution},
+ FastHashMap,
+};
+use std::ops;
+
+/// Helper class to emit expressions
+#[allow(dead_code)]
+#[derive(Default, Debug)]
+struct Emitter {
+ start_len: Option<usize>,
+}
+
+#[allow(dead_code)]
+impl Emitter {
+ fn start(&mut self, arena: &Arena<crate::Expression>) {
+ if self.start_len.is_some() {
+ unreachable!("Emitting has already started!");
+ }
+ self.start_len = Some(arena.len());
+ }
+ #[must_use]
+ fn finish(
+ &mut self,
+ arena: &Arena<crate::Expression>,
+ ) -> Option<(crate::Statement, crate::span::Span)> {
+ let start_len = self.start_len.take().unwrap();
+ if start_len != arena.len() {
+ #[allow(unused_mut)]
+ let mut span = crate::span::Span::default();
+ let range = arena.range_from(start_len);
+ #[cfg(feature = "span")]
+ for handle in range.clone() {
+ span.subsume(arena.get_span(handle))
+ }
+ Some((crate::Statement::Emit(range), span))
+ } else {
+ None
+ }
+ }
+}
+
+#[allow(dead_code)]
+impl super::ConstantInner {
+ const fn boolean(value: bool) -> Self {
+ Self::Scalar {
+ width: super::BOOL_WIDTH,
+ value: super::ScalarValue::Bool(value),
+ }
+ }
+}
+
+/// A table of types for an `Arena<Expression>`.
+///
+/// A front end can use a `Typifier` to get types for an arena's expressions
+/// while it is still contributing expressions to it. At any point, you can call
+/// [`typifier.grow(expr, arena, ctx)`], where `expr` is a `Handle<Expression>`
+/// referring to something in `arena`, and the `Typifier` will resolve the types
+/// of all the expressions up to and including `expr`. Then you can write
+/// `typifier[handle]` to get the type of any handle at or before `expr`.
+///
+/// Note that `Typifier` does *not* build an `Arena<Type>` as a part of its
+/// usual operation. Ideally, a module's type arena should only contain types
+/// actually needed by `Handle<Type>`s elsewhere in the module — functions,
+/// variables, [`Compose`] expressions, other types, and so on — so we don't
+/// want every little thing that occurs as the type of some intermediate
+/// expression to show up there.
+///
+/// Instead, `Typifier` accumulates a [`TypeResolution`] for each expression,
+/// which refers to the `Arena<Type>` in the [`ResolveContext`] passed to `grow`
+/// as needed. [`TypeResolution`] is a lightweight representation for
+/// intermediate types like this; see its documentation for details.
+///
+/// If you do need to register a `Typifier`'s conclusion in an `Arena<Type>`
+/// (say, for a [`LocalVariable`] whose type you've inferred), you can use
+/// [`register_type`] to do so.
+///
+/// [`typifier.grow(expr, arena)`]: Typifier::grow
+/// [`register_type`]: Typifier::register_type
+/// [`Compose`]: crate::Expression::Compose
+/// [`LocalVariable`]: crate::LocalVariable
+#[derive(Debug, Default)]
+pub struct Typifier {
+ resolutions: Vec<TypeResolution>,
+}
+
+impl Typifier {
+ pub const fn new() -> Self {
+ Typifier {
+ resolutions: Vec::new(),
+ }
+ }
+
+ pub fn reset(&mut self) {
+ self.resolutions.clear()
+ }
+
+ pub fn get<'a>(
+ &'a self,
+ expr_handle: Handle<crate::Expression>,
+ types: &'a UniqueArena<crate::Type>,
+ ) -> &'a crate::TypeInner {
+ self.resolutions[expr_handle.index()].inner_with(types)
+ }
+
+ /// Add an expression's type to an `Arena<Type>`.
+ ///
+ /// Add the type of `expr_handle` to `types`, and return a `Handle<Type>`
+ /// referring to it.
+ ///
+ /// # Note
+ ///
+ /// If you just need a [`TypeInner`] for `expr_handle`'s type, consider
+ /// using `typifier[expression].inner_with(types)` instead. Calling
+ /// [`TypeResolution::inner_with`] often lets us avoid adding anything to
+ /// the arena, which can significantly reduce the number of types that end
+ /// up in the final module.
+ ///
+ /// [`TypeInner`]: crate::TypeInner
+ pub fn register_type(
+ &self,
+ expr_handle: Handle<crate::Expression>,
+ types: &mut UniqueArena<crate::Type>,
+ ) -> Handle<crate::Type> {
+ match self[expr_handle].clone() {
+ TypeResolution::Handle(handle) => handle,
+ TypeResolution::Value(inner) => {
+ types.insert(crate::Type { name: None, inner }, crate::Span::UNDEFINED)
+ }
+ }
+ }
+
+ /// Grow this typifier until it contains a type for `expr_handle`.
+ pub fn grow(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ ctx: &ResolveContext,
+ ) -> Result<(), ResolveError> {
+ if self.resolutions.len() <= expr_handle.index() {
+ for (eh, expr) in expressions.iter().skip(self.resolutions.len()) {
+ //Note: the closure can't `Err` by construction
+ let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?;
+ log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution);
+ self.resolutions.push(resolution);
+ }
+ }
+ Ok(())
+ }
+
+ /// Recompute the type resolution for `expr_handle`.
+ ///
+ /// If the type of `expr_handle` hasn't yet been calculated, call
+ /// [`grow`](Self::grow) to ensure it is covered.
+ ///
+ /// In either case, when this returns, `self[expr_handle]` should be an
+ /// updated type resolution for `expr_handle`.
+ pub fn invalidate(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ ctx: &ResolveContext,
+ ) -> Result<(), ResolveError> {
+ if self.resolutions.len() <= expr_handle.index() {
+ self.grow(expr_handle, expressions, ctx)
+ } else {
+ let expr = &expressions[expr_handle];
+ //Note: the closure can't `Err` by construction
+ let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?;
+ self.resolutions[expr_handle.index()] = resolution;
+ Ok(())
+ }
+ }
+}
+
+impl ops::Index<Handle<crate::Expression>> for Typifier {
+ type Output = TypeResolution;
+ fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
+ &self.resolutions[handle.index()]
+ }
+}
+
+/// Type representing a lexical scope, associating a name to a single variable
+///
+/// The scope is generic over the variable representation and name representaion
+/// in order to allow larger flexibility on the frontends on how they might
+/// represent them.
+type Scope<Name, Var> = FastHashMap<Name, Var>;
+
+/// Structure responsible for managing variable lookups and keeping track of
+/// lexical scopes
+///
+/// The symbol table is generic over the variable representation and its name
+/// to allow larger flexibility on the frontends on how they might represent them.
+///
+/// ```
+/// use naga::front::SymbolTable;
+///
+/// // Create a new symbol table with `u32`s representing the variable
+/// let mut symbol_table: SymbolTable<&str, u32> = SymbolTable::default();
+///
+/// // Add two variables named `var1` and `var2` with 0 and 2 respectively
+/// symbol_table.add("var1", 0);
+/// symbol_table.add("var2", 2);
+///
+/// // Check that `var1` exists and is `0`
+/// assert_eq!(symbol_table.lookup("var1"), Some(&0));
+///
+/// // Push a new scope and add a variable to it named `var1` shadowing the
+/// // variable of our previous scope
+/// symbol_table.push_scope();
+/// symbol_table.add("var1", 1);
+///
+/// // Check that `var1` now points to the new value of `1` and `var2` still
+/// // exists with its value of `2`
+/// assert_eq!(symbol_table.lookup("var1"), Some(&1));
+/// assert_eq!(symbol_table.lookup("var2"), Some(&2));
+///
+/// // Pop the scope
+/// symbol_table.pop_scope();
+///
+/// // Check that `var1` now refers to our initial variable with value `0`
+/// assert_eq!(symbol_table.lookup("var1"), Some(&0));
+/// ```
+///
+/// Scopes are ordered as a LIFO stack so a variable defined in a later scope
+/// with the same name as another variable defined in a earlier scope will take
+/// precedence in the lookup. Scopes can be added with [`push_scope`] and
+/// removed with [`pop_scope`].
+///
+/// A root scope is added when the symbol table is created and must always be
+/// present. Trying to pop it will result in a panic.
+///
+/// Variables can be added with [`add`] and looked up with [`lookup`]. Adding a
+/// variable will do so in the currently active scope and as mentioned
+/// previously a lookup will search from the current scope to the root scope.
+///
+/// [`push_scope`]: Self::push_scope
+/// [`pop_scope`]: Self::push_scope
+/// [`add`]: Self::add
+/// [`lookup`]: Self::lookup
+pub struct SymbolTable<Name, Var> {
+ /// Stack of lexical scopes. Not all scopes are active; see [`cursor`].
+ ///
+ /// [`cursor`]: Self::cursor
+ scopes: Vec<Scope<Name, Var>>,
+ /// Limit of the [`scopes`] stack (exclusive). By using a separate value for
+ /// the stack length instead of `Vec`'s own internal length, the scopes can
+ /// be reused to cache memory allocations.
+ ///
+ /// [`scopes`]: Self::scopes
+ cursor: usize,
+}
+
+impl<Name, Var> SymbolTable<Name, Var> {
+ /// Adds a new lexical scope.
+ ///
+ /// All variables declared after this point will be added to this scope
+ /// until another scope is pushed or [`pop_scope`] is called, causing this
+ /// scope to be removed along with all variables added to it.
+ ///
+ /// [`pop_scope`]: Self::pop_scope
+ pub fn push_scope(&mut self) {
+ // If the cursor is equal to the scope's stack length then we need to
+ // push another empty scope. Otherwise we can reuse the already existing
+ // scope.
+ if self.scopes.len() == self.cursor {
+ self.scopes.push(FastHashMap::default())
+ } else {
+ self.scopes[self.cursor].clear();
+ }
+
+ self.cursor += 1;
+ }
+
+ /// Removes the current lexical scope and all its variables
+ ///
+ /// # PANICS
+ /// - If the current lexical scope is the root scope
+ pub fn pop_scope(&mut self) {
+ // Despite the method title, the variables are only deleted when the
+ // scope is reused. This is because while a clear is inevitable if the
+ // scope needs to be reused, there are cases where the scope might be
+ // popped and not reused, i.e. if another scope with the same nesting
+ // level is never pushed again.
+ assert!(self.cursor != 1, "Tried to pop the root scope");
+
+ self.cursor -= 1;
+ }
+}
+
+impl<Name, Var> SymbolTable<Name, Var>
+where
+ Name: std::hash::Hash + Eq,
+{
+ /// Perform a lookup for a variable named `name`.
+ ///
+ /// As stated in the struct level documentation the lookup will proceed from
+ /// the current scope to the root scope, returning `Some` when a variable is
+ /// found or `None` if there doesn't exist a variable with `name` in any
+ /// scope.
+ pub fn lookup<Q: ?Sized>(&self, name: &Q) -> Option<&Var>
+ where
+ Name: std::borrow::Borrow<Q>,
+ Q: std::hash::Hash + Eq,
+ {
+ // Iterate backwards trough the scopes and try to find the variable
+ for scope in self.scopes[..self.cursor].iter().rev() {
+ if let Some(var) = scope.get(name) {
+ return Some(var);
+ }
+ }
+
+ None
+ }
+
+ /// Adds a new variable to the current scope.
+ ///
+ /// Returns the previous variable with the same name in this scope if it
+ /// exists, so that the frontend might handle it in case variable shadowing
+ /// is disallowed.
+ pub fn add(&mut self, name: Name, var: Var) -> Option<Var> {
+ self.scopes[self.cursor - 1].insert(name, var)
+ }
+
+ /// Adds a new variable to the root scope.
+ ///
+ /// This is used in GLSL for builtins which aren't known in advance and only
+ /// when used for the first time, so there must be a way to add those
+ /// declarations to the root unconditionally from the current scope.
+ ///
+ /// Returns the previous variable with the same name in the root scope if it
+ /// exists, so that the frontend might handle it in case variable shadowing
+ /// is disallowed.
+ pub fn add_root(&mut self, name: Name, var: Var) -> Option<Var> {
+ self.scopes[0].insert(name, var)
+ }
+}
+
+impl<Name, Var> Default for SymbolTable<Name, Var> {
+ /// Constructs a new symbol table with a root scope
+ fn default() -> Self {
+ Self {
+ scopes: vec![FastHashMap::default()],
+ cursor: 1,
+ }
+ }
+}
+
+use std::fmt;
+
+impl<Name: fmt::Debug, Var: fmt::Debug> fmt::Debug for SymbolTable<Name, Var> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str("SymbolTable ")?;
+ f.debug_list()
+ .entries(self.scopes[..self.cursor].iter())
+ .finish()
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/convert.rs b/third_party/rust/naga/src/front/spv/convert.rs
new file mode 100644
index 0000000000..2967555d2b
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/convert.rs
@@ -0,0 +1,181 @@
+use super::error::Error;
+use num_traits::cast::FromPrimitive;
+use std::convert::TryInto;
+
+pub(super) const fn map_binary_operator(word: spirv::Op) -> Result<crate::BinaryOperator, Error> {
+ use crate::BinaryOperator;
+ use spirv::Op;
+
+ match word {
+ // Arithmetic Instructions +, -, *, /, %
+ Op::IAdd | Op::FAdd => Ok(BinaryOperator::Add),
+ Op::ISub | Op::FSub => Ok(BinaryOperator::Subtract),
+ Op::IMul | Op::FMul => Ok(BinaryOperator::Multiply),
+ Op::UDiv | Op::SDiv | Op::FDiv => Ok(BinaryOperator::Divide),
+ Op::SRem => Ok(BinaryOperator::Modulo),
+ // Relational and Logical Instructions
+ Op::IEqual | Op::FOrdEqual | Op::FUnordEqual | Op::LogicalEqual => {
+ Ok(BinaryOperator::Equal)
+ }
+ Op::INotEqual | Op::FOrdNotEqual | Op::FUnordNotEqual | Op::LogicalNotEqual => {
+ Ok(BinaryOperator::NotEqual)
+ }
+ Op::ULessThan | Op::SLessThan | Op::FOrdLessThan | Op::FUnordLessThan => {
+ Ok(BinaryOperator::Less)
+ }
+ Op::ULessThanEqual
+ | Op::SLessThanEqual
+ | Op::FOrdLessThanEqual
+ | Op::FUnordLessThanEqual => Ok(BinaryOperator::LessEqual),
+ Op::UGreaterThan | Op::SGreaterThan | Op::FOrdGreaterThan | Op::FUnordGreaterThan => {
+ Ok(BinaryOperator::Greater)
+ }
+ Op::UGreaterThanEqual
+ | Op::SGreaterThanEqual
+ | Op::FOrdGreaterThanEqual
+ | Op::FUnordGreaterThanEqual => Ok(BinaryOperator::GreaterEqual),
+ Op::BitwiseOr => Ok(BinaryOperator::InclusiveOr),
+ Op::BitwiseXor => Ok(BinaryOperator::ExclusiveOr),
+ Op::BitwiseAnd => Ok(BinaryOperator::And),
+ _ => Err(Error::UnknownBinaryOperator(word)),
+ }
+}
+
+pub(super) const fn map_relational_fun(
+ word: spirv::Op,
+) -> Result<crate::RelationalFunction, Error> {
+ use crate::RelationalFunction as Rf;
+ use spirv::Op;
+
+ match word {
+ Op::All => Ok(Rf::All),
+ Op::Any => Ok(Rf::Any),
+ Op::IsNan => Ok(Rf::IsNan),
+ Op::IsInf => Ok(Rf::IsInf),
+ Op::IsFinite => Ok(Rf::IsFinite),
+ Op::IsNormal => Ok(Rf::IsNormal),
+ _ => Err(Error::UnknownRelationalFunction(word)),
+ }
+}
+
+pub(super) const fn map_vector_size(word: spirv::Word) -> Result<crate::VectorSize, Error> {
+ match word {
+ 2 => Ok(crate::VectorSize::Bi),
+ 3 => Ok(crate::VectorSize::Tri),
+ 4 => Ok(crate::VectorSize::Quad),
+ _ => Err(Error::InvalidVectorSize(word)),
+ }
+}
+
+pub(super) fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> {
+ use spirv::Dim as D;
+ match D::from_u32(word) {
+ Some(D::Dim1D) => Ok(crate::ImageDimension::D1),
+ Some(D::Dim2D) => Ok(crate::ImageDimension::D2),
+ Some(D::Dim3D) => Ok(crate::ImageDimension::D3),
+ Some(D::DimCube) => Ok(crate::ImageDimension::Cube),
+ _ => Err(Error::UnsupportedImageDim(word)),
+ }
+}
+
+pub(super) fn map_image_format(word: spirv::Word) -> Result<crate::StorageFormat, Error> {
+ match spirv::ImageFormat::from_u32(word) {
+ Some(spirv::ImageFormat::R8) => Ok(crate::StorageFormat::R8Unorm),
+ Some(spirv::ImageFormat::R8Snorm) => Ok(crate::StorageFormat::R8Snorm),
+ Some(spirv::ImageFormat::R8ui) => Ok(crate::StorageFormat::R8Uint),
+ Some(spirv::ImageFormat::R8i) => Ok(crate::StorageFormat::R8Sint),
+ Some(spirv::ImageFormat::R16) => Ok(crate::StorageFormat::R16Unorm),
+ Some(spirv::ImageFormat::R16Snorm) => Ok(crate::StorageFormat::R16Snorm),
+ Some(spirv::ImageFormat::R16ui) => Ok(crate::StorageFormat::R16Uint),
+ Some(spirv::ImageFormat::R16i) => Ok(crate::StorageFormat::R16Sint),
+ Some(spirv::ImageFormat::R16f) => Ok(crate::StorageFormat::R16Float),
+ Some(spirv::ImageFormat::Rg8) => Ok(crate::StorageFormat::Rg8Unorm),
+ Some(spirv::ImageFormat::Rg8Snorm) => Ok(crate::StorageFormat::Rg8Snorm),
+ Some(spirv::ImageFormat::Rg8ui) => Ok(crate::StorageFormat::Rg8Uint),
+ Some(spirv::ImageFormat::Rg8i) => Ok(crate::StorageFormat::Rg8Sint),
+ Some(spirv::ImageFormat::R32ui) => Ok(crate::StorageFormat::R32Uint),
+ Some(spirv::ImageFormat::R32i) => Ok(crate::StorageFormat::R32Sint),
+ Some(spirv::ImageFormat::R32f) => Ok(crate::StorageFormat::R32Float),
+ Some(spirv::ImageFormat::Rg16) => Ok(crate::StorageFormat::Rg16Unorm),
+ Some(spirv::ImageFormat::Rg16Snorm) => Ok(crate::StorageFormat::Rg16Snorm),
+ Some(spirv::ImageFormat::Rg16ui) => Ok(crate::StorageFormat::Rg16Uint),
+ Some(spirv::ImageFormat::Rg16i) => Ok(crate::StorageFormat::Rg16Sint),
+ Some(spirv::ImageFormat::Rg16f) => Ok(crate::StorageFormat::Rg16Float),
+ Some(spirv::ImageFormat::Rgba8) => Ok(crate::StorageFormat::Rgba8Unorm),
+ Some(spirv::ImageFormat::Rgba8Snorm) => Ok(crate::StorageFormat::Rgba8Snorm),
+ Some(spirv::ImageFormat::Rgba8ui) => Ok(crate::StorageFormat::Rgba8Uint),
+ Some(spirv::ImageFormat::Rgba8i) => Ok(crate::StorageFormat::Rgba8Sint),
+ Some(spirv::ImageFormat::Rgb10a2ui) => Ok(crate::StorageFormat::Rgb10a2Unorm),
+ Some(spirv::ImageFormat::R11fG11fB10f) => Ok(crate::StorageFormat::Rg11b10Float),
+ Some(spirv::ImageFormat::Rg32ui) => Ok(crate::StorageFormat::Rg32Uint),
+ Some(spirv::ImageFormat::Rg32i) => Ok(crate::StorageFormat::Rg32Sint),
+ Some(spirv::ImageFormat::Rg32f) => Ok(crate::StorageFormat::Rg32Float),
+ Some(spirv::ImageFormat::Rgba16) => Ok(crate::StorageFormat::Rgba16Unorm),
+ Some(spirv::ImageFormat::Rgba16Snorm) => Ok(crate::StorageFormat::Rgba16Snorm),
+ Some(spirv::ImageFormat::Rgba16ui) => Ok(crate::StorageFormat::Rgba16Uint),
+ Some(spirv::ImageFormat::Rgba16i) => Ok(crate::StorageFormat::Rgba16Sint),
+ Some(spirv::ImageFormat::Rgba16f) => Ok(crate::StorageFormat::Rgba16Float),
+ Some(spirv::ImageFormat::Rgba32ui) => Ok(crate::StorageFormat::Rgba32Uint),
+ Some(spirv::ImageFormat::Rgba32i) => Ok(crate::StorageFormat::Rgba32Sint),
+ Some(spirv::ImageFormat::Rgba32f) => Ok(crate::StorageFormat::Rgba32Float),
+ _ => Err(Error::UnsupportedImageFormat(word)),
+ }
+}
+
+pub(super) fn map_width(word: spirv::Word) -> Result<crate::Bytes, Error> {
+ (word >> 3) // bits to bytes
+ .try_into()
+ .map_err(|_| Error::InvalidTypeWidth(word))
+}
+
+pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::BuiltIn, Error> {
+ use spirv::BuiltIn as Bi;
+ Ok(match spirv::BuiltIn::from_u32(word) {
+ Some(Bi::Position | Bi::FragCoord) => crate::BuiltIn::Position { invariant },
+ Some(Bi::ViewIndex) => crate::BuiltIn::ViewIndex,
+ // vertex
+ Some(Bi::BaseInstance) => crate::BuiltIn::BaseInstance,
+ Some(Bi::BaseVertex) => crate::BuiltIn::BaseVertex,
+ Some(Bi::ClipDistance) => crate::BuiltIn::ClipDistance,
+ Some(Bi::CullDistance) => crate::BuiltIn::CullDistance,
+ Some(Bi::InstanceIndex) => crate::BuiltIn::InstanceIndex,
+ Some(Bi::PointSize) => crate::BuiltIn::PointSize,
+ Some(Bi::VertexIndex) => crate::BuiltIn::VertexIndex,
+ // fragment
+ Some(Bi::FragDepth) => crate::BuiltIn::FragDepth,
+ Some(Bi::PointCoord) => crate::BuiltIn::PointCoord,
+ Some(Bi::FrontFacing) => crate::BuiltIn::FrontFacing,
+ Some(Bi::PrimitiveId) => crate::BuiltIn::PrimitiveIndex,
+ Some(Bi::SampleId) => crate::BuiltIn::SampleIndex,
+ Some(Bi::SampleMask) => crate::BuiltIn::SampleMask,
+ // compute
+ Some(Bi::GlobalInvocationId) => crate::BuiltIn::GlobalInvocationId,
+ Some(Bi::LocalInvocationId) => crate::BuiltIn::LocalInvocationId,
+ Some(Bi::LocalInvocationIndex) => crate::BuiltIn::LocalInvocationIndex,
+ Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
+ Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
+ Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
+ _ => return Err(Error::UnsupportedBuiltIn(word)),
+ })
+}
+
+pub(super) fn map_storage_class(word: spirv::Word) -> Result<super::ExtendedClass, Error> {
+ use super::ExtendedClass as Ec;
+ use spirv::StorageClass as Sc;
+ Ok(match Sc::from_u32(word) {
+ Some(Sc::Function) => Ec::Global(crate::AddressSpace::Function),
+ Some(Sc::Input) => Ec::Input,
+ Some(Sc::Output) => Ec::Output,
+ Some(Sc::Private) => Ec::Global(crate::AddressSpace::Private),
+ Some(Sc::UniformConstant) => Ec::Global(crate::AddressSpace::Handle),
+ Some(Sc::StorageBuffer) => Ec::Global(crate::AddressSpace::Storage {
+ //Note: this is restricted by decorations later
+ access: crate::StorageAccess::all(),
+ }),
+ // we expect the `Storage` case to be filtered out before calling this function.
+ Some(Sc::Uniform) => Ec::Global(crate::AddressSpace::Uniform),
+ Some(Sc::Workgroup) => Ec::Global(crate::AddressSpace::WorkGroup),
+ Some(Sc::PushConstant) => Ec::Global(crate::AddressSpace::PushConstant),
+ _ => return Err(Error::UnsupportedStorageClass(word)),
+ })
+}
diff --git a/third_party/rust/naga/src/front/spv/error.rs b/third_party/rust/naga/src/front/spv/error.rs
new file mode 100644
index 0000000000..6c9cf384f1
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/error.rs
@@ -0,0 +1,127 @@
+use super::ModuleState;
+use crate::arena::Handle;
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error("invalid header")]
+ InvalidHeader,
+ #[error("invalid word count")]
+ InvalidWordCount,
+ #[error("unknown instruction {0}")]
+ UnknownInstruction(u16),
+ #[error("unknown capability %{0}")]
+ UnknownCapability(spirv::Word),
+ #[error("unsupported instruction {1:?} at {0:?}")]
+ UnsupportedInstruction(ModuleState, spirv::Op),
+ #[error("unsupported capability {0:?}")]
+ UnsupportedCapability(spirv::Capability),
+ #[error("unsupported extension {0}")]
+ UnsupportedExtension(String),
+ #[error("unsupported extension set {0}")]
+ UnsupportedExtSet(String),
+ #[error("unsupported extension instantiation set %{0}")]
+ UnsupportedExtInstSet(spirv::Word),
+ #[error("unsupported extension instantiation %{0}")]
+ UnsupportedExtInst(spirv::Word),
+ #[error("unsupported type {0:?}")]
+ UnsupportedType(Handle<crate::Type>),
+ #[error("unsupported execution model %{0}")]
+ UnsupportedExecutionModel(spirv::Word),
+ #[error("unsupported execution mode %{0}")]
+ UnsupportedExecutionMode(spirv::Word),
+ #[error("unsupported storage class %{0}")]
+ UnsupportedStorageClass(spirv::Word),
+ #[error("unsupported image dimension %{0}")]
+ UnsupportedImageDim(spirv::Word),
+ #[error("unsupported image format %{0}")]
+ UnsupportedImageFormat(spirv::Word),
+ #[error("unsupported builtin %{0}")]
+ UnsupportedBuiltIn(spirv::Word),
+ #[error("unsupported control flow %{0}")]
+ UnsupportedControlFlow(spirv::Word),
+ #[error("unsupported binary operator %{0}")]
+ UnsupportedBinaryOperator(spirv::Word),
+ #[error("Naga supports OpTypeRuntimeArray in the StorageBuffer storage class only")]
+ UnsupportedRuntimeArrayStorageClass,
+ #[error("unsupported matrix stride {stride} for a {columns}x{rows} matrix with scalar width={width}")]
+ UnsupportedMatrixStride {
+ stride: u32,
+ columns: u8,
+ rows: u8,
+ width: u8,
+ },
+ #[error("unknown binary operator {0:?}")]
+ UnknownBinaryOperator(spirv::Op),
+ #[error("unknown relational function {0:?}")]
+ UnknownRelationalFunction(spirv::Op),
+ #[error("invalid parameter {0:?}")]
+ InvalidParameter(spirv::Op),
+ #[error("invalid operand count {1} for {0:?}")]
+ InvalidOperandCount(spirv::Op, u16),
+ #[error("invalid operand")]
+ InvalidOperand,
+ #[error("invalid id %{0}")]
+ InvalidId(spirv::Word),
+ #[error("invalid decoration %{0}")]
+ InvalidDecoration(spirv::Word),
+ #[error("invalid type width %{0}")]
+ InvalidTypeWidth(spirv::Word),
+ #[error("invalid sign %{0}")]
+ InvalidSign(spirv::Word),
+ #[error("invalid inner type %{0}")]
+ InvalidInnerType(spirv::Word),
+ #[error("invalid vector size %{0}")]
+ InvalidVectorSize(spirv::Word),
+ #[error("invalid access type %{0}")]
+ InvalidAccessType(spirv::Word),
+ #[error("invalid access {0:?}")]
+ InvalidAccess(crate::Expression),
+ #[error("invalid access index %{0}")]
+ InvalidAccessIndex(spirv::Word),
+ #[error("invalid binding %{0}")]
+ InvalidBinding(spirv::Word),
+ #[error("invalid global var {0:?}")]
+ InvalidGlobalVar(crate::Expression),
+ #[error("invalid image/sampler expression {0:?}")]
+ InvalidImageExpression(crate::Expression),
+ #[error("invalid image base type {0:?}")]
+ InvalidImageBaseType(Handle<crate::Type>),
+ #[error("invalid image {0:?}")]
+ InvalidImage(Handle<crate::Type>),
+ #[error("invalid as type {0:?}")]
+ InvalidAsType(Handle<crate::Type>),
+ #[error("invalid vector type {0:?}")]
+ InvalidVectorType(Handle<crate::Type>),
+ #[error("inconsistent comparison sampling {0:?}")]
+ InconsistentComparisonSampling(Handle<crate::GlobalVariable>),
+ #[error("wrong function result type %{0}")]
+ WrongFunctionResultType(spirv::Word),
+ #[error("wrong function argument type %{0}")]
+ WrongFunctionArgumentType(spirv::Word),
+ #[error("missing decoration {0:?}")]
+ MissingDecoration(spirv::Decoration),
+ #[error("bad string")]
+ BadString,
+ #[error("incomplete data")]
+ IncompleteData,
+ #[error("invalid terminator")]
+ InvalidTerminator,
+ #[error("invalid edge classification")]
+ InvalidEdgeClassification,
+ #[error("cycle detected in the CFG during traversal at {0}")]
+ ControlFlowGraphCycle(crate::front::spv::BlockId),
+ #[error("recursive function call %{0}")]
+ FunctionCallCycle(spirv::Word),
+ #[error("invalid array size {0:?}")]
+ InvalidArraySize(Handle<crate::Constant>),
+ #[error("invalid barrier scope %{0}")]
+ InvalidBarrierScope(spirv::Word),
+ #[error("invalid barrier memory semantics %{0}")]
+ InvalidBarrierMemorySemantics(spirv::Word),
+ #[error(
+ "arrays of images / samplers are supported only through bindings for \
+ now (i.e. you can't create an array of images or samplers that doesn't \
+ come from a binding)"
+ )]
+ NonBindingArrayOfImageOrSamplers,
+}
diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs
new file mode 100644
index 0000000000..5dc781504e
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/function.rs
@@ -0,0 +1,661 @@
+use crate::{
+ arena::{Arena, Handle},
+ front::spv::{BlockContext, BodyIndex},
+};
+
+use super::{Error, Instruction, LookupExpression, LookupHelper as _};
+use crate::front::Emitter;
+
+pub type BlockId = u32;
+
+#[derive(Copy, Clone, Debug)]
+pub struct MergeInstruction {
+ pub merge_block_id: BlockId,
+ pub continue_block_id: Option<BlockId>,
+}
+
+impl<I: Iterator<Item = u32>> super::Frontend<I> {
+ // Registers a function call. It will generate a dummy handle to call, which
+ // gets resolved after all the functions are processed.
+ pub(super) fn add_call(
+ &mut self,
+ from: spirv::Word,
+ to: spirv::Word,
+ ) -> Handle<crate::Function> {
+ let dummy_handle = self
+ .dummy_functions
+ .append(crate::Function::default(), Default::default());
+ self.deferred_function_calls.push(to);
+ self.function_call_graph.add_edge(from, to, ());
+ dummy_handle
+ }
+
+ pub(super) fn parse_function(&mut self, module: &mut crate::Module) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.lookup_expression.clear();
+ self.lookup_load_override.clear();
+ self.lookup_sampled_image.clear();
+
+ let result_type_id = self.next()?;
+ let fun_id = self.next()?;
+ let _fun_control = self.next()?;
+ let fun_type_id = self.next()?;
+
+ let mut fun = {
+ let ft = self.lookup_function_type.lookup(fun_type_id)?;
+ if ft.return_type_id != result_type_id {
+ return Err(Error::WrongFunctionResultType(result_type_id));
+ }
+ crate::Function {
+ name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name),
+ arguments: Vec::with_capacity(ft.parameter_type_ids.len()),
+ result: if self.lookup_void_type == Some(result_type_id) {
+ None
+ } else {
+ let lookup_result_ty = self.lookup_type.lookup(result_type_id)?;
+ Some(crate::FunctionResult {
+ ty: lookup_result_ty.handle,
+ binding: None,
+ })
+ },
+ local_variables: Arena::new(),
+ expressions: self
+ .make_expression_storage(&module.global_variables, &module.constants),
+ named_expressions: crate::NamedExpressions::default(),
+ body: crate::Block::new(),
+ }
+ };
+
+ // read parameters
+ for i in 0..fun.arguments.capacity() {
+ let start = self.data_offset;
+ match self.next_inst()? {
+ Instruction {
+ op: spirv::Op::FunctionParameter,
+ wc: 3,
+ } => {
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let handle = fun.expressions.append(
+ crate::Expression::FunctionArgument(i as u32),
+ self.span_from(start),
+ );
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id,
+ // Setting this to an invalid id will cause get_expr_handle
+ // to default to the main body making sure no load/stores
+ // are added.
+ block_id: 0,
+ },
+ );
+ //Note: we redo the lookup in order to work around `self` borrowing
+
+ if type_id
+ != self
+ .lookup_function_type
+ .lookup(fun_type_id)?
+ .parameter_type_ids[i]
+ {
+ return Err(Error::WrongFunctionArgumentType(type_id));
+ }
+ let ty = self.lookup_type.lookup(type_id)?.handle;
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ fun.arguments.push(crate::FunctionArgument {
+ name: decor.name,
+ ty,
+ binding: None,
+ });
+ }
+ Instruction { op, .. } => return Err(Error::InvalidParameter(op)),
+ }
+ }
+
+ // Read body
+ self.function_call_graph.add_node(fun_id);
+ let mut parameters_sampling =
+ vec![super::image::SamplingFlags::empty(); fun.arguments.len()];
+
+ let mut block_ctx = BlockContext {
+ phis: Default::default(),
+ blocks: Default::default(),
+ body_for_label: Default::default(),
+ mergers: Default::default(),
+ bodies: Default::default(),
+ function_id: fun_id,
+ expressions: &mut fun.expressions,
+ local_arena: &mut fun.local_variables,
+ const_arena: &mut module.constants,
+ type_arena: &module.types,
+ global_arena: &module.global_variables,
+ arguments: &fun.arguments,
+ parameter_sampling: &mut parameters_sampling,
+ };
+ // Insert the main body whose parent is also himself
+ block_ctx.bodies.push(super::Body::with_parent(0));
+
+ // Scan the blocks and add them as nodes
+ loop {
+ let fun_inst = self.next_inst()?;
+ log::debug!("{:?}", fun_inst.op);
+ match fun_inst.op {
+ spirv::Op::Line => {
+ fun_inst.expect(4)?;
+ let _file_id = self.next()?;
+ let _row_id = self.next()?;
+ let _col_id = self.next()?;
+ }
+ spirv::Op::Label => {
+ // Read the label ID
+ fun_inst.expect(2)?;
+ let block_id = self.next()?;
+
+ self.next_block(block_id, &mut block_ctx)?;
+ }
+ spirv::Op::FunctionEnd => {
+ fun_inst.expect(1)?;
+ break;
+ }
+ _ => {
+ return Err(Error::UnsupportedInstruction(self.state, fun_inst.op));
+ }
+ }
+ }
+
+ if let Some(ref prefix) = self.options.block_ctx_dump_prefix {
+ let dump_suffix = match self.lookup_entry_point.get(&fun_id) {
+ Some(ep) => format!("block_ctx.{:?}-{}.txt", ep.stage, ep.name),
+ None => format!("block_ctx.Fun-{}.txt", module.functions.len()),
+ };
+ let dest = prefix.join(dump_suffix);
+ let dump = format!("{block_ctx:#?}");
+ if let Err(e) = std::fs::write(&dest, dump) {
+ log::error!("Unable to dump the block context into {:?}: {}", dest, e);
+ }
+ }
+
+ // Emit `Store` statements to properly initialize all the local variables we
+ // created for `phi` expressions.
+ //
+ // Note that get_expr_handle also contributes slightly odd entries to this table,
+ // to get the spill.
+ for phi in block_ctx.phis.iter() {
+ // Get a pointer to the local variable for the phi's value.
+ let phi_pointer = block_ctx.expressions.append(
+ crate::Expression::LocalVariable(phi.local),
+ crate::Span::default(),
+ );
+
+ // At the end of each of `phi`'s predecessor blocks, store the corresponding
+ // source value in the phi's local variable.
+ for &(source, predecessor) in phi.expressions.iter() {
+ let source_lexp = &self.lookup_expression[&source];
+ let predecessor_body_idx = block_ctx.body_for_label[&predecessor];
+ // If the expression is a global/argument it will have a 0 block
+ // id so we must use a default value instead of panicking
+ let source_body_idx = block_ctx
+ .body_for_label
+ .get(&source_lexp.block_id)
+ .copied()
+ .unwrap_or(0);
+
+ // If the Naga `Expression` generated for `source` is in scope, then we
+ // can simply store that in the phi's local variable.
+ //
+ // Otherwise, spill the source value to a local variable in the block that
+ // defines it. (We know this store dominates the predecessor; otherwise,
+ // the phi wouldn't have been able to refer to that source expression in
+ // the first place.) Then, the predecessor block can count on finding the
+ // source's value in that local variable.
+ let value = if super::is_parent(predecessor_body_idx, source_body_idx, &block_ctx) {
+ source_lexp.handle
+ } else {
+ // The source SPIR-V expression is not defined in the phi's
+ // predecessor block, nor is it a globally available expression. So it
+ // must be defined off in some other block that merely dominates the
+ // predecessor. This means that the corresponding Naga `Expression`
+ // may not be in scope in the predecessor block.
+ //
+ // In the block that defines `source`, spill it to a fresh local
+ // variable, to ensure we can still use it at the end of the
+ // predecessor.
+ let ty = self.lookup_type[&source_lexp.type_id].handle;
+ let local = block_ctx.local_arena.append(
+ crate::LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ crate::Span::default(),
+ );
+
+ let pointer = block_ctx.expressions.append(
+ crate::Expression::LocalVariable(local),
+ crate::Span::default(),
+ );
+
+ // Get the spilled value of the source expression.
+ let start = block_ctx.expressions.len();
+ let expr = block_ctx
+ .expressions
+ .append(crate::Expression::Load { pointer }, crate::Span::default());
+ let range = block_ctx.expressions.range_from(start);
+
+ block_ctx
+ .blocks
+ .get_mut(&predecessor)
+ .unwrap()
+ .push(crate::Statement::Emit(range), crate::Span::default());
+
+ // At the end of the block that defines it, spill the source
+ // expression's value.
+ block_ctx
+ .blocks
+ .get_mut(&source_lexp.block_id)
+ .unwrap()
+ .push(
+ crate::Statement::Store {
+ pointer,
+ value: source_lexp.handle,
+ },
+ crate::Span::default(),
+ );
+
+ expr
+ };
+
+ // At the end of the phi predecessor block, store the source
+ // value in the phi's value.
+ block_ctx.blocks.get_mut(&predecessor).unwrap().push(
+ crate::Statement::Store {
+ pointer: phi_pointer,
+ value,
+ },
+ crate::Span::default(),
+ )
+ }
+ }
+
+ fun.body = block_ctx.lower();
+
+ // done
+ let fun_handle = module.functions.append(fun, self.span_from_with_op(start));
+ self.lookup_function.insert(
+ fun_id,
+ super::LookupFunction {
+ handle: fun_handle,
+ parameters_sampling,
+ },
+ );
+
+ if let Some(ep) = self.lookup_entry_point.remove(&fun_id) {
+ // create a wrapping function
+ let mut function = crate::Function {
+ name: Some(format!("{}_wrap", ep.name)),
+ arguments: Vec::new(),
+ result: None,
+ local_variables: Arena::new(),
+ expressions: Arena::new(),
+ named_expressions: crate::NamedExpressions::default(),
+ body: crate::Block::new(),
+ };
+
+ // 1. copy the inputs from arguments to privates
+ for &v_id in ep.variable_ids.iter() {
+ let lvar = self.lookup_variable.lookup(v_id)?;
+ if let super::Variable::Input(ref arg) = lvar.inner {
+ let span = module.global_variables.get_span(lvar.handle);
+ let arg_expr = function.expressions.append(
+ crate::Expression::FunctionArgument(function.arguments.len() as u32),
+ span,
+ );
+ let load_expr = if arg.ty == module.global_variables[lvar.handle].ty {
+ arg_expr
+ } else {
+ // The only case where the type is different is if we need to treat
+ // unsigned integer as signed.
+ let mut emitter = Emitter::default();
+ emitter.start(&function.expressions);
+ let handle = function.expressions.append(
+ crate::Expression::As {
+ expr: arg_expr,
+ kind: crate::ScalarKind::Sint,
+ convert: Some(4),
+ },
+ span,
+ );
+ function.body.extend(emitter.finish(&function.expressions));
+ handle
+ };
+ function.body.push(
+ crate::Statement::Store {
+ pointer: function
+ .expressions
+ .append(crate::Expression::GlobalVariable(lvar.handle), span),
+ value: load_expr,
+ },
+ span,
+ );
+
+ let mut arg = arg.clone();
+ if ep.stage == crate::ShaderStage::Fragment {
+ if let Some(ref mut binding) = arg.binding {
+ binding.apply_default_interpolation(&module.types[arg.ty].inner);
+ }
+ }
+ function.arguments.push(arg);
+ }
+ }
+ // 2. call the wrapped function
+ let fake_id = !(module.entry_points.len() as u32); // doesn't matter, as long as it's not a collision
+ let dummy_handle = self.add_call(fake_id, fun_id);
+ function.body.push(
+ crate::Statement::Call {
+ function: dummy_handle,
+ arguments: Vec::new(),
+ result: None,
+ },
+ crate::Span::default(),
+ );
+
+ // 3. copy the outputs from privates to the result
+ let mut members = Vec::new();
+ let mut components = Vec::new();
+ for &v_id in ep.variable_ids.iter() {
+ let lvar = self.lookup_variable.lookup(v_id)?;
+ if let super::Variable::Output(ref result) = lvar.inner {
+ let span = module.global_variables.get_span(lvar.handle);
+ let expr_handle = function
+ .expressions
+ .append(crate::Expression::GlobalVariable(lvar.handle), span);
+
+ // Cull problematic builtins of gl_PerVertex.
+ // See the docs for `Frontend::gl_per_vertex_builtin_access`.
+ {
+ let ty = &module.types[result.ty];
+ match ty.inner {
+ crate::TypeInner::Struct {
+ members: ref original_members,
+ span,
+ } if ty.name.as_deref() == Some("gl_PerVertex") => {
+ let mut new_members = original_members.clone();
+ for member in &mut new_members {
+ if let Some(crate::Binding::BuiltIn(built_in)) = member.binding
+ {
+ if !self.gl_per_vertex_builtin_access.contains(&built_in) {
+ member.binding = None
+ }
+ }
+ }
+ if &new_members != original_members {
+ module.types.replace(
+ result.ty,
+ crate::Type {
+ name: ty.name.clone(),
+ inner: crate::TypeInner::Struct {
+ members: new_members,
+ span,
+ },
+ },
+ );
+ }
+ }
+ _ => {}
+ }
+ }
+
+ match module.types[result.ty].inner {
+ crate::TypeInner::Struct {
+ members: ref sub_members,
+ ..
+ } => {
+ for (index, sm) in sub_members.iter().enumerate() {
+ if sm.binding.is_none() {
+ continue;
+ }
+ let mut sm = sm.clone();
+
+ if let Some(ref mut binding) = sm.binding {
+ if ep.stage == crate::ShaderStage::Vertex {
+ binding.apply_default_interpolation(
+ &module.types[sm.ty].inner,
+ );
+ }
+ }
+
+ members.push(sm);
+
+ components.push(function.expressions.append(
+ crate::Expression::AccessIndex {
+ base: expr_handle,
+ index: index as u32,
+ },
+ span,
+ ));
+ }
+ }
+ ref inner => {
+ let mut binding = result.binding.clone();
+ if let Some(ref mut binding) = binding {
+ if ep.stage == crate::ShaderStage::Vertex {
+ binding.apply_default_interpolation(inner);
+ }
+ }
+
+ members.push(crate::StructMember {
+ name: None,
+ ty: result.ty,
+ binding,
+ offset: 0,
+ });
+ // populate just the globals first, then do `Load` in a
+ // separate step, so that we can get a range.
+ components.push(expr_handle);
+ }
+ }
+ }
+ }
+
+ for (member_index, member) in members.iter().enumerate() {
+ match member.binding {
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::Position { .. }))
+ if self.options.adjust_coordinate_space =>
+ {
+ let mut emitter = Emitter::default();
+ emitter.start(&function.expressions);
+ let global_expr = components[member_index];
+ let span = function.expressions.get_span(global_expr);
+ let access_expr = function.expressions.append(
+ crate::Expression::AccessIndex {
+ base: global_expr,
+ index: 1,
+ },
+ span,
+ );
+ let load_expr = function.expressions.append(
+ crate::Expression::Load {
+ pointer: access_expr,
+ },
+ span,
+ );
+ let neg_expr = function.expressions.append(
+ crate::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr: load_expr,
+ },
+ span,
+ );
+ function.body.extend(emitter.finish(&function.expressions));
+ function.body.push(
+ crate::Statement::Store {
+ pointer: access_expr,
+ value: neg_expr,
+ },
+ span,
+ );
+ }
+ _ => {}
+ }
+ }
+
+ let mut emitter = Emitter::default();
+ emitter.start(&function.expressions);
+ for component in components.iter_mut() {
+ let load_expr = crate::Expression::Load {
+ pointer: *component,
+ };
+ let span = function.expressions.get_span(*component);
+ *component = function.expressions.append(load_expr, span);
+ }
+
+ match members[..] {
+ [] => {}
+ [ref member] => {
+ function.body.extend(emitter.finish(&function.expressions));
+ let span = function.expressions.get_span(components[0]);
+ function.body.push(
+ crate::Statement::Return {
+ value: components.first().cloned(),
+ },
+ span,
+ );
+ function.result = Some(crate::FunctionResult {
+ ty: member.ty,
+ binding: member.binding.clone(),
+ });
+ }
+ _ => {
+ let span = crate::Span::total_span(
+ components.iter().map(|h| function.expressions.get_span(*h)),
+ );
+ let ty = module.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Struct {
+ members,
+ span: 0xFFFF, // shouldn't matter
+ },
+ },
+ span,
+ );
+ let result_expr = function
+ .expressions
+ .append(crate::Expression::Compose { ty, components }, span);
+ function.body.extend(emitter.finish(&function.expressions));
+ function.body.push(
+ crate::Statement::Return {
+ value: Some(result_expr),
+ },
+ span,
+ );
+ function.result = Some(crate::FunctionResult { ty, binding: None });
+ }
+ }
+
+ module.entry_points.push(crate::EntryPoint {
+ name: ep.name,
+ stage: ep.stage,
+ early_depth_test: ep.early_depth_test,
+ workgroup_size: ep.workgroup_size,
+ function,
+ });
+ }
+
+ Ok(())
+ }
+}
+
+impl<'function> BlockContext<'function> {
+ /// Consumes the `BlockContext` producing a Ir [`Block`](crate::Block)
+ fn lower(mut self) -> crate::Block {
+ fn lower_impl(
+ blocks: &mut crate::FastHashMap<spirv::Word, crate::Block>,
+ bodies: &[super::Body],
+ body_idx: BodyIndex,
+ ) -> crate::Block {
+ let mut block = crate::Block::new();
+
+ for item in bodies[body_idx].data.iter() {
+ match *item {
+ super::BodyFragment::BlockId(id) => block.append(blocks.get_mut(&id).unwrap()),
+ super::BodyFragment::If {
+ condition,
+ accept,
+ reject,
+ } => {
+ let accept = lower_impl(blocks, bodies, accept);
+ let reject = lower_impl(blocks, bodies, reject);
+
+ block.push(
+ crate::Statement::If {
+ condition,
+ accept,
+ reject,
+ },
+ crate::Span::default(),
+ )
+ }
+ super::BodyFragment::Loop { body, continuing } => {
+ let body = lower_impl(blocks, bodies, body);
+ let continuing = lower_impl(blocks, bodies, continuing);
+
+ block.push(
+ crate::Statement::Loop {
+ body,
+ continuing,
+ break_if: None,
+ },
+ crate::Span::default(),
+ )
+ }
+ super::BodyFragment::Switch {
+ selector,
+ ref cases,
+ default,
+ } => {
+ let mut ir_cases: Vec<_> = cases
+ .iter()
+ .map(|&(value, body_idx)| {
+ let body = lower_impl(blocks, bodies, body_idx);
+
+ // Handle simple cases that would make a fallthrough statement unreachable code
+ let fall_through = body.last().map_or(true, |s| !s.is_terminator());
+
+ crate::SwitchCase {
+ value: crate::SwitchValue::I32(value),
+ body,
+ fall_through,
+ }
+ })
+ .collect();
+ ir_cases.push(crate::SwitchCase {
+ value: crate::SwitchValue::Default,
+ body: lower_impl(blocks, bodies, default),
+ fall_through: false,
+ });
+
+ block.push(
+ crate::Statement::Switch {
+ selector,
+ cases: ir_cases,
+ },
+ crate::Span::default(),
+ )
+ }
+ super::BodyFragment::Break => {
+ block.push(crate::Statement::Break, crate::Span::default())
+ }
+ super::BodyFragment::Continue => {
+ block.push(crate::Statement::Continue, crate::Span::default())
+ }
+ }
+ }
+
+ block
+ }
+
+ lower_impl(&mut self.blocks, &self.bodies, 0)
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/image.rs b/third_party/rust/naga/src/front/spv/image.rs
new file mode 100644
index 0000000000..757843f789
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/image.rs
@@ -0,0 +1,737 @@
+use crate::arena::{Arena, Handle, UniqueArena};
+
+use super::{Error, LookupExpression, LookupHelper as _};
+
+#[derive(Clone, Debug)]
+pub(super) struct LookupSampledImage {
+ image: Handle<crate::Expression>,
+ sampler: Handle<crate::Expression>,
+}
+
+bitflags::bitflags! {
+ /// Flags describing sampling method.
+ pub struct SamplingFlags: u32 {
+ /// Regular sampling.
+ const REGULAR = 0x1;
+ /// Comparison sampling.
+ const COMPARISON = 0x2;
+ }
+}
+
+impl<'function> super::BlockContext<'function> {
+ fn get_image_expr_ty(
+ &self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Type>, Error> {
+ match self.expressions[handle] {
+ crate::Expression::GlobalVariable(handle) => Ok(self.global_arena[handle].ty),
+ crate::Expression::FunctionArgument(i) => Ok(self.arguments[i as usize].ty),
+ ref other => Err(Error::InvalidImageExpression(other.clone())),
+ }
+ }
+}
+
+/// Options of a sampling operation.
+#[derive(Debug)]
+pub struct SamplingOptions {
+ /// Projection sampling: the division by W is expected to happen
+ /// in the texture unit.
+ pub project: bool,
+ /// Depth comparison sampling with a reference value.
+ pub compare: bool,
+}
+
+enum ExtraCoordinate {
+ ArrayLayer,
+ Projection,
+ Garbage,
+}
+
+/// Return the texture coordinates separated from the array layer,
+/// and/or divided by the projection term.
+///
+/// The Proj sampling ops expect an extra coordinate for the W.
+/// The arrayed (can't be Proj!) images expect an extra coordinate for the layer.
+fn extract_image_coordinates(
+ image_dim: crate::ImageDimension,
+ extra_coordinate: ExtraCoordinate,
+ base: Handle<crate::Expression>,
+ coordinate_ty: Handle<crate::Type>,
+ ctx: &mut super::BlockContext,
+) -> (Handle<crate::Expression>, Option<Handle<crate::Expression>>) {
+ let (given_size, kind) = match ctx.type_arena[coordinate_ty].inner {
+ crate::TypeInner::Scalar { kind, .. } => (None, kind),
+ crate::TypeInner::Vector { size, kind, .. } => (Some(size), kind),
+ ref other => unreachable!("Unexpected texture coordinate {:?}", other),
+ };
+
+ let required_size = image_dim.required_coordinate_size();
+ let required_ty = required_size.map(|size| {
+ ctx.type_arena
+ .get(&crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size,
+ kind,
+ width: 4,
+ },
+ })
+ .expect("Required coordinate type should have been set up by `parse_type_image`!")
+ });
+ let extra_expr = crate::Expression::AccessIndex {
+ base,
+ index: required_size.map_or(1, |size| size as u32),
+ };
+
+ let base_span = ctx.expressions.get_span(base);
+
+ match extra_coordinate {
+ ExtraCoordinate::ArrayLayer => {
+ let extracted = match required_size {
+ None => ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index: 0 }, base_span),
+ Some(size) => {
+ let mut components = Vec::with_capacity(size as usize);
+ for index in 0..size as u32 {
+ let comp = ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index }, base_span);
+ components.push(comp);
+ }
+ ctx.expressions.append(
+ crate::Expression::Compose {
+ ty: required_ty.unwrap(),
+ components,
+ },
+ base_span,
+ )
+ }
+ };
+ let array_index_f32 = ctx.expressions.append(extra_expr, base_span);
+ let array_index = ctx.expressions.append(
+ crate::Expression::As {
+ kind: crate::ScalarKind::Sint,
+ expr: array_index_f32,
+ convert: Some(4),
+ },
+ base_span,
+ );
+ (extracted, Some(array_index))
+ }
+ ExtraCoordinate::Projection => {
+ let projection = ctx.expressions.append(extra_expr, base_span);
+ let divided = match required_size {
+ None => {
+ let temp = ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index: 0 }, base_span);
+ ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: temp,
+ right: projection,
+ },
+ base_span,
+ )
+ }
+ Some(size) => {
+ let mut components = Vec::with_capacity(size as usize);
+ for index in 0..size as u32 {
+ let temp = ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index }, base_span);
+ let comp = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: temp,
+ right: projection,
+ },
+ base_span,
+ );
+ components.push(comp);
+ }
+ ctx.expressions.append(
+ crate::Expression::Compose {
+ ty: required_ty.unwrap(),
+ components,
+ },
+ base_span,
+ )
+ }
+ };
+ (divided, None)
+ }
+ ExtraCoordinate::Garbage if given_size == required_size => (base, None),
+ ExtraCoordinate::Garbage => {
+ use crate::SwizzleComponent as Sc;
+ let cut_expr = match required_size {
+ None => crate::Expression::AccessIndex { base, index: 0 },
+ Some(size) => crate::Expression::Swizzle {
+ size,
+ vector: base,
+ pattern: [Sc::X, Sc::Y, Sc::Z, Sc::W],
+ },
+ };
+ (ctx.expressions.append(cut_expr, base_span), None)
+ }
+ }
+}
+
+pub(super) fn patch_comparison_type(
+ flags: SamplingFlags,
+ var: &mut crate::GlobalVariable,
+ arena: &mut UniqueArena<crate::Type>,
+) -> bool {
+ if !flags.contains(SamplingFlags::COMPARISON) {
+ return true;
+ }
+ if flags == SamplingFlags::all() {
+ return false;
+ }
+
+ log::debug!("Flipping comparison for {:?}", var);
+ let original_ty = &arena[var.ty];
+ let original_ty_span = arena.get_span(var.ty);
+ let ty_inner = match original_ty.inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Sampled { multi, .. },
+ dim,
+ arrayed,
+ } => crate::TypeInner::Image {
+ class: crate::ImageClass::Depth { multi },
+ dim,
+ arrayed,
+ },
+ crate::TypeInner::Sampler { .. } => crate::TypeInner::Sampler { comparison: true },
+ ref other => unreachable!("Unexpected type for comparison mutation: {:?}", other),
+ };
+
+ let name = original_ty.name.clone();
+ var.ty = arena.insert(
+ crate::Type {
+ name,
+ inner: ty_inner,
+ },
+ original_ty_span,
+ );
+ true
+}
+
+impl<I: Iterator<Item = u32>> super::Frontend<I> {
+ pub(super) fn parse_image_couple(&mut self) -> Result<(), Error> {
+ let _result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+ let sampler_id = self.next()?;
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+ let sampler_lexp = self.lookup_expression.lookup(sampler_id)?;
+ self.lookup_sampled_image.insert(
+ result_id,
+ LookupSampledImage {
+ image: image_lexp.handle,
+ sampler: sampler_lexp.handle,
+ },
+ );
+ Ok(())
+ }
+
+ pub(super) fn parse_image_uncouple(&mut self, block_id: spirv::Word) -> Result<(), Error> {
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let sampled_image_id = self.next()?;
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: self.lookup_sampled_image.lookup(sampled_image_id)?.image,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ pub(super) fn parse_image_write(
+ &mut self,
+ words_left: u16,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::front::Emitter,
+ block: &mut crate::Block,
+ body_idx: usize,
+ ) -> Result<crate::Statement, Error> {
+ let image_id = self.next()?;
+ let coordinate_id = self.next()?;
+ let value_id = self.next()?;
+
+ let image_ops = if words_left != 0 { self.next()? } else { 0 };
+
+ if image_ops != 0 {
+ let other = spirv::ImageOperands::from_bits_truncate(image_ops);
+ log::warn!("Unknown image write ops {:?}", other);
+ for _ in 1..words_left {
+ self.next()?;
+ }
+ }
+
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+ let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?;
+
+ let coord_lexp = self.lookup_expression.lookup(coordinate_id)?;
+ let coord_handle =
+ self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx);
+ let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle;
+ let (coordinate, array_index) = match ctx.type_arena[image_ty].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: _,
+ } => extract_image_coordinates(
+ dim,
+ if arrayed {
+ ExtraCoordinate::ArrayLayer
+ } else {
+ ExtraCoordinate::Garbage
+ },
+ coord_handle,
+ coord_type_handle,
+ ctx,
+ ),
+ _ => return Err(Error::InvalidImage(image_ty)),
+ };
+
+ let value_lexp = self.lookup_expression.lookup(value_id)?;
+ let value = self.get_expr_handle(value_id, value_lexp, ctx, emitter, block, body_idx);
+
+ Ok(crate::Statement::ImageStore {
+ image: image_lexp.handle,
+ coordinate,
+ array_index,
+ value,
+ })
+ }
+
+ pub(super) fn parse_image_load(
+ &mut self,
+ mut words_left: u16,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::front::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+ let coordinate_id = self.next()?;
+
+ let mut image_ops = if words_left != 0 {
+ words_left -= 1;
+ self.next()?
+ } else {
+ 0
+ };
+
+ let mut sample = None;
+ let mut level = None;
+ while image_ops != 0 {
+ let bit = 1 << image_ops.trailing_zeros();
+ match spirv::ImageOperands::from_bits_truncate(bit) {
+ spirv::ImageOperands::LOD => {
+ let lod_expr = self.next()?;
+ let lod_lexp = self.lookup_expression.lookup(lod_expr)?;
+ let lod_handle =
+ self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx);
+ level = Some(lod_handle);
+ words_left -= 1;
+ }
+ spirv::ImageOperands::SAMPLE => {
+ let sample_expr = self.next()?;
+ let sample_handle = self.lookup_expression.lookup(sample_expr)?.handle;
+ sample = Some(sample_handle);
+ words_left -= 1;
+ }
+ other => {
+ log::warn!("Unknown image load op {:?}", other);
+ for _ in 0..words_left {
+ self.next()?;
+ }
+ break;
+ }
+ }
+ image_ops ^= bit;
+ }
+
+ // No need to call get_expr_handle here since only globals/arguments are
+ // allowed as images and they are always in the root scope
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+ let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?;
+
+ let coord_lexp = self.lookup_expression.lookup(coordinate_id)?;
+ let coord_handle =
+ self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx);
+ let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle;
+ let (coordinate, array_index) = match ctx.type_arena[image_ty].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: _,
+ } => extract_image_coordinates(
+ dim,
+ if arrayed {
+ ExtraCoordinate::ArrayLayer
+ } else {
+ ExtraCoordinate::Garbage
+ },
+ coord_handle,
+ coord_type_handle,
+ ctx,
+ ),
+ _ => return Err(Error::InvalidImage(image_ty)),
+ };
+
+ let expr = crate::Expression::ImageLoad {
+ image: image_lexp.handle,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn parse_image_sample(
+ &mut self,
+ mut words_left: u16,
+ options: SamplingOptions,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::front::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let sampled_image_id = self.next()?;
+ let coordinate_id = self.next()?;
+ let dref_id = if options.compare {
+ Some(self.next()?)
+ } else {
+ None
+ };
+
+ let mut image_ops = if words_left != 0 {
+ words_left -= 1;
+ self.next()?
+ } else {
+ 0
+ };
+
+ let mut level = crate::SampleLevel::Auto;
+ let mut offset = None;
+ while image_ops != 0 {
+ let bit = 1 << image_ops.trailing_zeros();
+ match spirv::ImageOperands::from_bits_truncate(bit) {
+ spirv::ImageOperands::BIAS => {
+ let bias_expr = self.next()?;
+ let bias_lexp = self.lookup_expression.lookup(bias_expr)?;
+ let bias_handle =
+ self.get_expr_handle(bias_expr, bias_lexp, ctx, emitter, block, body_idx);
+ level = crate::SampleLevel::Bias(bias_handle);
+ words_left -= 1;
+ }
+ spirv::ImageOperands::LOD => {
+ let lod_expr = self.next()?;
+ let lod_lexp = self.lookup_expression.lookup(lod_expr)?;
+ let lod_handle =
+ self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx);
+ level = if options.compare {
+ log::debug!("Assuming {:?} is zero", lod_handle);
+ crate::SampleLevel::Zero
+ } else {
+ crate::SampleLevel::Exact(lod_handle)
+ };
+ words_left -= 1;
+ }
+ spirv::ImageOperands::GRAD => {
+ let grad_x_expr = self.next()?;
+ let grad_x_lexp = self.lookup_expression.lookup(grad_x_expr)?;
+ let grad_x_handle = self.get_expr_handle(
+ grad_x_expr,
+ grad_x_lexp,
+ ctx,
+ emitter,
+ block,
+ body_idx,
+ );
+ let grad_y_expr = self.next()?;
+ let grad_y_lexp = self.lookup_expression.lookup(grad_y_expr)?;
+ let grad_y_handle = self.get_expr_handle(
+ grad_y_expr,
+ grad_y_lexp,
+ ctx,
+ emitter,
+ block,
+ body_idx,
+ );
+ level = if options.compare {
+ log::debug!(
+ "Assuming gradients {:?} and {:?} are not greater than 1",
+ grad_x_handle,
+ grad_y_handle
+ );
+ crate::SampleLevel::Zero
+ } else {
+ crate::SampleLevel::Gradient {
+ x: grad_x_handle,
+ y: grad_y_handle,
+ }
+ };
+ words_left -= 2;
+ }
+ spirv::ImageOperands::CONST_OFFSET => {
+ let offset_constant = self.next()?;
+ let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle;
+ offset = Some(offset_handle);
+ words_left -= 1;
+ }
+ other => {
+ log::warn!("Unknown image sample operand {:?}", other);
+ for _ in 0..words_left {
+ self.next()?;
+ }
+ break;
+ }
+ }
+ image_ops ^= bit;
+ }
+
+ let si_lexp = self.lookup_sampled_image.lookup(sampled_image_id)?;
+ let coord_lexp = self.lookup_expression.lookup(coordinate_id)?;
+ let coord_handle =
+ self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx);
+ let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle;
+
+ let sampling_bit = if options.compare {
+ SamplingFlags::COMPARISON
+ } else {
+ SamplingFlags::REGULAR
+ };
+
+ let image_ty = match ctx.expressions[si_lexp.image] {
+ crate::Expression::GlobalVariable(handle) => {
+ if let Some(flags) = self.handle_sampling.get_mut(&handle) {
+ *flags |= sampling_bit;
+ }
+
+ ctx.global_arena[handle].ty
+ }
+
+ crate::Expression::FunctionArgument(i) => {
+ ctx.parameter_sampling[i as usize] |= sampling_bit;
+ ctx.arguments[i as usize].ty
+ }
+
+ crate::Expression::Access { base, .. } => match ctx.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => {
+ if let Some(flags) = self.handle_sampling.get_mut(&handle) {
+ *flags |= sampling_bit;
+ }
+
+ match ctx.type_arena[ctx.global_arena[handle].ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => base,
+ _ => return Err(Error::InvalidGlobalVar(ctx.expressions[base].clone())),
+ }
+ }
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ },
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ };
+
+ match ctx.expressions[si_lexp.sampler] {
+ crate::Expression::GlobalVariable(handle) => {
+ *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit;
+ }
+
+ crate::Expression::FunctionArgument(i) => {
+ ctx.parameter_sampling[i as usize] |= sampling_bit;
+ }
+
+ crate::Expression::Access { base, .. } => match ctx.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => {
+ *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit;
+ }
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ },
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ }
+
+ let ((coordinate, array_index), depth_ref) = match ctx.type_arena[image_ty].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: _,
+ } => (
+ extract_image_coordinates(
+ dim,
+ if options.project {
+ ExtraCoordinate::Projection
+ } else if arrayed {
+ ExtraCoordinate::ArrayLayer
+ } else {
+ ExtraCoordinate::Garbage
+ },
+ coord_handle,
+ coord_type_handle,
+ ctx,
+ ),
+ {
+ match dref_id {
+ Some(id) => {
+ let expr_lexp = self.lookup_expression.lookup(id)?;
+ let mut expr =
+ self.get_expr_handle(id, expr_lexp, ctx, emitter, block, body_idx);
+
+ if options.project {
+ let required_size = dim.required_coordinate_size();
+ let right = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: coord_handle,
+ index: required_size.map_or(1, |size| size as u32),
+ },
+ crate::Span::default(),
+ );
+ expr = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: expr,
+ right,
+ },
+ crate::Span::default(),
+ )
+ };
+ Some(expr)
+ }
+ None => None,
+ }
+ },
+ ),
+ _ => return Err(Error::InvalidImage(image_ty)),
+ };
+
+ let expr = crate::Expression::ImageSample {
+ image: si_lexp.image,
+ sampler: si_lexp.sampler,
+ gather: None, //TODO
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ pub(super) fn parse_image_query_size(
+ &mut self,
+ at_level: bool,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::front::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+ let level = if at_level {
+ let level_id = self.next()?;
+ let level_lexp = self.lookup_expression.lookup(level_id)?;
+ Some(self.get_expr_handle(level_id, level_lexp, ctx, emitter, block, body_idx))
+ } else {
+ None
+ };
+
+ // No need to call get_expr_handle here since only globals/arguments are
+ // allowed as images and they are always in the root scope
+ //TODO: handle arrays and cubes
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+
+ let expr = crate::Expression::ImageQuery {
+ image: image_lexp.handle,
+ query: crate::ImageQuery::Size { level },
+ };
+ let expr = crate::Expression::As {
+ expr: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ kind: crate::ScalarKind::Sint,
+ convert: Some(4),
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ pub(super) fn parse_image_query_other(
+ &mut self,
+ query: crate::ImageQuery,
+ expressions: &mut Arena<crate::Expression>,
+ block_id: spirv::Word,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+
+ // No need to call get_expr_handle here since only globals/arguments are
+ // allowed as images and they are always in the root scope
+ let image_lexp = self.lookup_expression.lookup(image_id)?.clone();
+
+ let expr = crate::Expression::ImageQuery {
+ image: image_lexp.handle,
+ query,
+ };
+ let expr = crate::Expression::As {
+ expr: expressions.append(expr, self.span_from_with_op(start)),
+ kind: crate::ScalarKind::Sint,
+ convert: Some(4),
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs
new file mode 100644
index 0000000000..c69a230cb0
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/mod.rs
@@ -0,0 +1,5195 @@
+/*!
+Frontend for [SPIR-V][spv] (Standard Portable Intermediate Representation).
+
+## ID lookups
+
+Our IR links to everything with `Handle`, while SPIR-V uses IDs.
+In order to keep track of the associations, the parser has many lookup tables.
+There map `spv::Word` into a specific IR handle, plus potentially a bit of
+extra info, such as the related SPIR-V type ID.
+TODO: would be nice to find ways that avoid looking up as much
+
+## Inputs/Outputs
+
+We create a private variable for each input/output. The relevant inputs are
+populated at the start of an entry point. The outputs are saved at the end.
+
+The function associated with an entry point is wrapped in another function,
+such that we can handle any `Return` statements without problems.
+
+## Row-major matrices
+
+We don't handle them natively, since the IR only expects column majority.
+Instead, we detect when such matrix is accessed in the `OpAccessChain`,
+and we generate a parallel expression that loads the value, but transposed.
+This value then gets used instead of `OpLoad` result later on.
+
+[spv]: https://www.khronos.org/registry/SPIR-V/
+*/
+
+mod convert;
+mod error;
+mod function;
+mod image;
+mod null;
+
+use convert::*;
+pub use error::Error;
+use function::*;
+
+use crate::{
+ arena::{Arena, Handle, UniqueArena},
+ proc::{Alignment, Layouter},
+ FastHashMap, FastHashSet,
+};
+
+use num_traits::cast::FromPrimitive;
+use petgraph::graphmap::GraphMap;
+use std::{convert::TryInto, mem, num::NonZeroU32, path::PathBuf};
+
+pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
+ spirv::Capability::Shader,
+ spirv::Capability::VulkanMemoryModel,
+ spirv::Capability::ClipDistance,
+ spirv::Capability::CullDistance,
+ spirv::Capability::SampleRateShading,
+ spirv::Capability::DerivativeControl,
+ spirv::Capability::InterpolationFunction,
+ spirv::Capability::Matrix,
+ spirv::Capability::ImageQuery,
+ spirv::Capability::Sampled1D,
+ spirv::Capability::Image1D,
+ spirv::Capability::SampledCubeArray,
+ spirv::Capability::ImageCubeArray,
+ spirv::Capability::ImageMSArray,
+ spirv::Capability::StorageImageExtendedFormats,
+ spirv::Capability::Sampled1D,
+ spirv::Capability::SampledCubeArray,
+ spirv::Capability::Int8,
+ spirv::Capability::Int16,
+ spirv::Capability::Int64,
+ spirv::Capability::Float16,
+ spirv::Capability::Float64,
+ spirv::Capability::Geometry,
+ spirv::Capability::MultiView,
+ // tricky ones
+ spirv::Capability::UniformBufferArrayDynamicIndexing,
+ spirv::Capability::StorageBufferArrayDynamicIndexing,
+];
+pub const SUPPORTED_EXTENSIONS: &[&str] = &[
+ "SPV_KHR_storage_buffer_storage_class",
+ "SPV_KHR_vulkan_memory_model",
+ "SPV_KHR_multiview",
+];
+pub const SUPPORTED_EXT_SETS: &[&str] = &["GLSL.std.450"];
+
+#[derive(Copy, Clone)]
+pub struct Instruction {
+ op: spirv::Op,
+ wc: u16,
+}
+
+impl Instruction {
+ const fn expect(self, count: u16) -> Result<(), Error> {
+ if self.wc == count {
+ Ok(())
+ } else {
+ Err(Error::InvalidOperandCount(self.op, self.wc))
+ }
+ }
+
+ fn expect_at_least(self, count: u16) -> Result<u16, Error> {
+ self.wc
+ .checked_sub(count)
+ .ok_or(Error::InvalidOperandCount(self.op, self.wc))
+ }
+}
+
+impl crate::TypeInner {
+ fn can_comparison_sample(&self, module: &crate::Module) -> bool {
+ match *self {
+ crate::TypeInner::Image {
+ class:
+ crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Float,
+ multi: false,
+ },
+ ..
+ } => true,
+ crate::TypeInner::Sampler { .. } => true,
+ crate::TypeInner::BindingArray { base, .. } => {
+ module.types[base].inner.can_comparison_sample(module)
+ }
+ _ => false,
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
+pub enum ModuleState {
+ Empty,
+ Capability,
+ Extension,
+ ExtInstImport,
+ MemoryModel,
+ EntryPoint,
+ ExecutionMode,
+ Source,
+ Name,
+ ModuleProcessed,
+ Annotation,
+ Type,
+ Function,
+}
+
+trait LookupHelper {
+ type Target;
+ fn lookup(&self, key: spirv::Word) -> Result<&Self::Target, Error>;
+}
+
+impl<T> LookupHelper for FastHashMap<spirv::Word, T> {
+ type Target = T;
+ fn lookup(&self, key: spirv::Word) -> Result<&T, Error> {
+ self.get(&key).ok_or(Error::InvalidId(key))
+ }
+}
+
+impl crate::ImageDimension {
+ const fn required_coordinate_size(&self) -> Option<crate::VectorSize> {
+ match *self {
+ crate::ImageDimension::D1 => None,
+ crate::ImageDimension::D2 => Some(crate::VectorSize::Bi),
+ crate::ImageDimension::D3 => Some(crate::VectorSize::Tri),
+ crate::ImageDimension::Cube => Some(crate::VectorSize::Tri),
+ }
+ }
+}
+
+type MemberIndex = u32;
+
+bitflags::bitflags! {
+ #[derive(Default)]
+ struct DecorationFlags: u32 {
+ const NON_READABLE = 0x1;
+ const NON_WRITABLE = 0x2;
+ }
+}
+
+impl DecorationFlags {
+ fn to_storage_access(self) -> crate::StorageAccess {
+ let mut access = crate::StorageAccess::all();
+ if self.contains(DecorationFlags::NON_READABLE) {
+ access &= !crate::StorageAccess::LOAD;
+ }
+ if self.contains(DecorationFlags::NON_WRITABLE) {
+ access &= !crate::StorageAccess::STORE;
+ }
+ access
+ }
+}
+
+#[derive(Debug, PartialEq)]
+enum Majority {
+ Column,
+ Row,
+}
+
+#[derive(Debug, Default)]
+struct Decoration {
+ name: Option<String>,
+ built_in: Option<spirv::Word>,
+ location: Option<spirv::Word>,
+ desc_set: Option<spirv::Word>,
+ desc_index: Option<spirv::Word>,
+ specialization: Option<spirv::Word>,
+ storage_buffer: bool,
+ offset: Option<spirv::Word>,
+ array_stride: Option<NonZeroU32>,
+ matrix_stride: Option<NonZeroU32>,
+ matrix_major: Option<Majority>,
+ invariant: bool,
+ interpolation: Option<crate::Interpolation>,
+ sampling: Option<crate::Sampling>,
+ flags: DecorationFlags,
+}
+
+impl Decoration {
+ fn debug_name(&self) -> &str {
+ match self.name {
+ Some(ref name) => name.as_str(),
+ None => "?",
+ }
+ }
+
+ const fn resource_binding(&self) -> Option<crate::ResourceBinding> {
+ match *self {
+ Decoration {
+ desc_set: Some(group),
+ desc_index: Some(binding),
+ ..
+ } => Some(crate::ResourceBinding { group, binding }),
+ _ => None,
+ }
+ }
+
+ fn io_binding(&self) -> Result<crate::Binding, Error> {
+ match *self {
+ Decoration {
+ built_in: Some(built_in),
+ location: None,
+ invariant,
+ ..
+ } => Ok(crate::Binding::BuiltIn(map_builtin(built_in, invariant)?)),
+ Decoration {
+ built_in: None,
+ location: Some(location),
+ interpolation,
+ sampling,
+ ..
+ } => Ok(crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ }),
+ _ => Err(Error::MissingDecoration(spirv::Decoration::Location)),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct LookupFunctionType {
+ parameter_type_ids: Vec<spirv::Word>,
+ return_type_id: spirv::Word,
+}
+
+struct LookupFunction {
+ handle: Handle<crate::Function>,
+ parameters_sampling: Vec<image::SamplingFlags>,
+}
+
+#[derive(Debug)]
+struct EntryPoint {
+ stage: crate::ShaderStage,
+ name: String,
+ early_depth_test: Option<crate::EarlyDepthTest>,
+ workgroup_size: [u32; 3],
+ variable_ids: Vec<spirv::Word>,
+}
+
+#[derive(Clone, Debug)]
+struct LookupType {
+ handle: Handle<crate::Type>,
+ base_id: Option<spirv::Word>,
+}
+
+#[derive(Debug)]
+struct LookupConstant {
+ handle: Handle<crate::Constant>,
+ type_id: spirv::Word,
+}
+
+#[derive(Debug)]
+enum Variable {
+ Global,
+ Input(crate::FunctionArgument),
+ Output(crate::FunctionResult),
+}
+
+#[derive(Debug)]
+struct LookupVariable {
+ inner: Variable,
+ handle: Handle<crate::GlobalVariable>,
+ type_id: spirv::Word,
+}
+
+/// Information about SPIR-V result ids, stored in `Parser::lookup_expression`.
+#[derive(Clone, Debug)]
+struct LookupExpression {
+ /// The `Expression` constructed for this result.
+ ///
+ /// Note that, while a SPIR-V result id can be used in any block dominated
+ /// by its definition, a Naga `Expression` is only in scope for the rest of
+ /// its subtree. `Parser::get_expr_handle` takes care of spilling the result
+ /// to a `LocalVariable` which can then be used anywhere.
+ handle: Handle<crate::Expression>,
+
+ /// The SPIR-V type of this result.
+ type_id: spirv::Word,
+
+ /// The label id of the block that defines this expression.
+ ///
+ /// This is zero for globals, constants, and function parameters, since they
+ /// originate outside any function's block.
+ block_id: spirv::Word,
+}
+
+#[derive(Debug)]
+struct LookupMember {
+ type_id: spirv::Word,
+ // This is true for either matrices, or arrays of matrices (yikes).
+ row_major: bool,
+}
+
+#[derive(Clone, Debug)]
+enum LookupLoadOverride {
+ /// For arrays of matrices, we track them but not loading yet.
+ Pending,
+ /// For matrices, vectors, and scalars, we pre-load the data.
+ Loaded(Handle<crate::Expression>),
+}
+
+#[derive(PartialEq)]
+enum ExtendedClass {
+ Global(crate::AddressSpace),
+ Input,
+ Output,
+}
+
+#[derive(Clone, Debug)]
+pub struct Options {
+ /// The IR coordinate space matches all the APIs except SPIR-V,
+ /// so by default we flip the Y coordinate of the `BuiltIn::Position`.
+ /// This flag can be used to avoid this.
+ pub adjust_coordinate_space: bool,
+ /// Only allow shaders with the known set of capabilities.
+ pub strict_capabilities: bool,
+ pub block_ctx_dump_prefix: Option<PathBuf>,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ adjust_coordinate_space: true,
+ strict_capabilities: false,
+ block_ctx_dump_prefix: None,
+ }
+ }
+}
+
+/// An index into the `BlockContext::bodies` table.
+type BodyIndex = usize;
+
+/// An intermediate representation of a Naga [`Statement`].
+///
+/// `Body` and `BodyFragment` values form a tree: the `BodyIndex` fields of the
+/// variants are indices of the child `Body` values in [`BlockContext::bodies`].
+/// The `lower` function assembles the final `Statement` tree from this `Body`
+/// tree. See [`BlockContext`] for details.
+///
+/// [`Statement`]: crate::Statement
+#[derive(Debug)]
+enum BodyFragment {
+ BlockId(spirv::Word),
+ If {
+ condition: Handle<crate::Expression>,
+ accept: BodyIndex,
+ reject: BodyIndex,
+ },
+ Loop {
+ body: BodyIndex,
+ continuing: BodyIndex,
+ },
+ Switch {
+ selector: Handle<crate::Expression>,
+ cases: Vec<(i32, BodyIndex)>,
+ default: BodyIndex,
+ },
+ Break,
+ Continue,
+}
+
+/// An intermediate representation of a Naga [`Block`].
+///
+/// This will be assembled into a `Block` once we've added spills for phi nodes
+/// and out-of-scope expressions. See [`BlockContext`] for details.
+///
+/// [`Block`]: crate::Block
+#[derive(Debug)]
+struct Body {
+ /// The index of the direct parent of this body
+ parent: usize,
+ data: Vec<BodyFragment>,
+}
+
+impl Body {
+ /// Creates a new empty `Body` with the specified `parent`
+ pub const fn with_parent(parent: usize) -> Self {
+ Body {
+ parent,
+ data: Vec::new(),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct PhiExpression {
+ /// The local variable used for the phi node
+ local: Handle<crate::LocalVariable>,
+ /// List of (expression, block)
+ expressions: Vec<(spirv::Word, spirv::Word)>,
+}
+
+#[derive(Debug)]
+enum MergeBlockInformation {
+ LoopMerge,
+ LoopContinue,
+ SelectionMerge,
+ SwitchMerge,
+}
+
+/// Fragments of Naga IR, to be assembled into `Statements` once data flow is
+/// resolved.
+///
+/// We can't build a Naga `Statement` tree directly from SPIR-V blocks for two
+/// main reasons:
+///
+/// - A SPIR-V expression can be used in any SPIR-V block dominated by its
+/// definition, whereas Naga expressions are scoped to the rest of their
+/// subtree. This means that discovering an expression use later in the
+/// function retroactively requires us to have spilled that expression into a
+/// local variable back before we left its scope.
+///
+/// - We translate SPIR-V OpPhi expressions as Naga local variables in which we
+/// store the appropriate value before jumping to the OpPhi's block.
+///
+/// Both cases require us to go back and amend previously generated Naga IR
+/// based on things we discover later. But modifying old blocks in arbitrary
+/// spots in a `Statement` tree is awkward.
+///
+/// Instead, as we iterate through the function's body, we accumulate
+/// control-flow-free fragments of Naga IR in the [`blocks`] table, while
+/// building a skeleton of the Naga `Statement` tree in [`bodies`]. We note any
+/// spills and temporaries we must introduce in [`phis`].
+///
+/// Finally, once we've processed the entire function, we add temporaries and
+/// spills to the fragmentary `Blocks` as directed by `phis`, and assemble them
+/// into the final Naga `Statement` tree as directed by `bodies`.
+///
+/// [`blocks`]: BlockContext::blocks
+/// [`bodies`]: BlockContext::bodies
+/// [`phis`]: BlockContext::phis
+/// [`lower`]: function::lower
+#[derive(Debug)]
+struct BlockContext<'function> {
+ /// Phi nodes encountered when parsing the function, used to generate spills
+ /// to local variables.
+ phis: Vec<PhiExpression>,
+
+ /// Fragments of control-flow-free Naga IR.
+ ///
+ /// These will be stitched together into a proper `Statement` tree according
+ /// to `bodies`, once parsing is complete.
+ blocks: FastHashMap<spirv::Word, crate::Block>,
+
+ /// Map from block label ids to the index of the corresponding `Body` in
+ /// `bodies`.
+ body_for_label: FastHashMap<spirv::Word, BodyIndex>,
+
+ /// SPIR-V metadata about merge/continue blocks.
+ mergers: FastHashMap<spirv::Word, MergeBlockInformation>,
+
+ /// A table of `Body` values, each representing a block in the final IR.
+ bodies: Vec<Body>,
+
+ /// Id of the function currently being processed
+ function_id: spirv::Word,
+ /// Expression arena of the function currently being processed
+ expressions: &'function mut Arena<crate::Expression>,
+ /// Local variables arena of the function currently being processed
+ local_arena: &'function mut Arena<crate::LocalVariable>,
+ /// Constants arena of the module being processed
+ const_arena: &'function mut Arena<crate::Constant>,
+ /// Type arena of the module being processed
+ type_arena: &'function UniqueArena<crate::Type>,
+ /// Global arena of the module being processed
+ global_arena: &'function Arena<crate::GlobalVariable>,
+ /// Arguments of the function currently being processed
+ arguments: &'function [crate::FunctionArgument],
+ /// Metadata about the usage of function parameters as sampling objects
+ parameter_sampling: &'function mut [image::SamplingFlags],
+}
+
+enum SignAnchor {
+ Result,
+ Operand,
+}
+
+pub struct Frontend<I> {
+ data: I,
+ data_offset: usize,
+ state: ModuleState,
+ layouter: Layouter,
+ temp_bytes: Vec<u8>,
+ ext_glsl_id: Option<spirv::Word>,
+ future_decor: FastHashMap<spirv::Word, Decoration>,
+ future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>,
+ lookup_member: FastHashMap<(Handle<crate::Type>, MemberIndex), LookupMember>,
+ handle_sampling: FastHashMap<Handle<crate::GlobalVariable>, image::SamplingFlags>,
+ lookup_type: FastHashMap<spirv::Word, LookupType>,
+ lookup_void_type: Option<spirv::Word>,
+ lookup_storage_buffer_types: FastHashMap<Handle<crate::Type>, crate::StorageAccess>,
+ // Lookup for samplers and sampled images, storing flags on how they are used.
+ lookup_constant: FastHashMap<spirv::Word, LookupConstant>,
+ lookup_variable: FastHashMap<spirv::Word, LookupVariable>,
+ lookup_expression: FastHashMap<spirv::Word, LookupExpression>,
+ // Load overrides are used to work around row-major matrices
+ lookup_load_override: FastHashMap<spirv::Word, LookupLoadOverride>,
+ lookup_sampled_image: FastHashMap<spirv::Word, image::LookupSampledImage>,
+ lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>,
+ lookup_function: FastHashMap<spirv::Word, LookupFunction>,
+ lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>,
+ //Note: each `OpFunctionCall` gets a single entry here, indexed by the
+ // dummy `Handle<crate::Function>` of the call site.
+ deferred_function_calls: Vec<spirv::Word>,
+ dummy_functions: Arena<crate::Function>,
+ // Graph of all function calls through the module.
+ // It's used to sort the functions (as nodes) topologically,
+ // so that in the IR any called function is already known.
+ function_call_graph: GraphMap<spirv::Word, (), petgraph::Directed>,
+ options: Options,
+ index_constants: Vec<Handle<crate::Constant>>,
+ index_constant_expressions: Vec<Handle<crate::Expression>>,
+
+ /// Maps for a switch from a case target to the respective body and associated literals that
+ /// use that target block id.
+ ///
+ /// Used to preserve allocations between instruction parsing.
+ switch_cases: indexmap::IndexMap<
+ spirv::Word,
+ (BodyIndex, Vec<i32>),
+ std::hash::BuildHasherDefault<rustc_hash::FxHasher>,
+ >,
+
+ /// Tracks access to gl_PerVertex's builtins, it is used to cull unused builtins since initializing those can
+ /// affect performance and the mere presence of some of these builtins might cause backends to error since they
+ /// might be unsupported.
+ ///
+ /// The problematic builtins are: PointSize, ClipDistance and CullDistance.
+ ///
+ /// glslang declares those by default even though they are never written to
+ /// (see <https://github.com/KhronosGroup/glslang/issues/1868>)
+ gl_per_vertex_builtin_access: FastHashSet<crate::BuiltIn>,
+}
+
+impl<I: Iterator<Item = u32>> Frontend<I> {
+ pub fn new(data: I, options: &Options) -> Self {
+ Frontend {
+ data,
+ data_offset: 0,
+ state: ModuleState::Empty,
+ layouter: Layouter::default(),
+ temp_bytes: Vec::new(),
+ ext_glsl_id: None,
+ future_decor: FastHashMap::default(),
+ future_member_decor: FastHashMap::default(),
+ handle_sampling: FastHashMap::default(),
+ lookup_member: FastHashMap::default(),
+ lookup_type: FastHashMap::default(),
+ lookup_void_type: None,
+ lookup_storage_buffer_types: FastHashMap::default(),
+ lookup_constant: FastHashMap::default(),
+ lookup_variable: FastHashMap::default(),
+ lookup_expression: FastHashMap::default(),
+ lookup_load_override: FastHashMap::default(),
+ lookup_sampled_image: FastHashMap::default(),
+ lookup_function_type: FastHashMap::default(),
+ lookup_function: FastHashMap::default(),
+ lookup_entry_point: FastHashMap::default(),
+ deferred_function_calls: Vec::default(),
+ dummy_functions: Arena::new(),
+ function_call_graph: GraphMap::new(),
+ options: options.clone(),
+ index_constants: Vec::new(),
+ index_constant_expressions: Vec::new(),
+ switch_cases: indexmap::IndexMap::default(),
+ gl_per_vertex_builtin_access: FastHashSet::default(),
+ }
+ }
+
+ fn span_from(&self, from: usize) -> crate::Span {
+ crate::Span::from(from..self.data_offset)
+ }
+
+ fn span_from_with_op(&self, from: usize) -> crate::Span {
+ crate::Span::from((from - 4)..self.data_offset)
+ }
+
+ fn next(&mut self) -> Result<u32, Error> {
+ if let Some(res) = self.data.next() {
+ self.data_offset += 4;
+ Ok(res)
+ } else {
+ Err(Error::IncompleteData)
+ }
+ }
+
+ fn next_inst(&mut self) -> Result<Instruction, Error> {
+ let word = self.next()?;
+ let (wc, opcode) = ((word >> 16) as u16, (word & 0xffff) as u16);
+ if wc == 0 {
+ return Err(Error::InvalidWordCount);
+ }
+ let op = spirv::Op::from_u16(opcode).ok_or(Error::UnknownInstruction(opcode))?;
+
+ Ok(Instruction { op, wc })
+ }
+
+ fn next_string(&mut self, mut count: u16) -> Result<(String, u16), Error> {
+ self.temp_bytes.clear();
+ loop {
+ if count == 0 {
+ return Err(Error::BadString);
+ }
+ count -= 1;
+ let chars = self.next()?.to_le_bytes();
+ let pos = chars.iter().position(|&c| c == 0).unwrap_or(4);
+ self.temp_bytes.extend_from_slice(&chars[..pos]);
+ if pos < 4 {
+ break;
+ }
+ }
+ std::str::from_utf8(&self.temp_bytes)
+ .map(|s| (s.to_owned(), count))
+ .map_err(|_| Error::BadString)
+ }
+
+ fn next_decoration(
+ &mut self,
+ inst: Instruction,
+ base_words: u16,
+ dec: &mut Decoration,
+ ) -> Result<(), Error> {
+ let raw = self.next()?;
+ let dec_typed = spirv::Decoration::from_u32(raw).ok_or(Error::InvalidDecoration(raw))?;
+ log::trace!("\t\t{}: {:?}", dec.debug_name(), dec_typed);
+ match dec_typed {
+ spirv::Decoration::BuiltIn => {
+ inst.expect(base_words + 2)?;
+ dec.built_in = Some(self.next()?);
+ }
+ spirv::Decoration::Location => {
+ inst.expect(base_words + 2)?;
+ dec.location = Some(self.next()?);
+ }
+ spirv::Decoration::DescriptorSet => {
+ inst.expect(base_words + 2)?;
+ dec.desc_set = Some(self.next()?);
+ }
+ spirv::Decoration::Binding => {
+ inst.expect(base_words + 2)?;
+ dec.desc_index = Some(self.next()?);
+ }
+ spirv::Decoration::BufferBlock => {
+ dec.storage_buffer = true;
+ }
+ spirv::Decoration::Offset => {
+ inst.expect(base_words + 2)?;
+ dec.offset = Some(self.next()?);
+ }
+ spirv::Decoration::ArrayStride => {
+ inst.expect(base_words + 2)?;
+ dec.array_stride = NonZeroU32::new(self.next()?);
+ }
+ spirv::Decoration::MatrixStride => {
+ inst.expect(base_words + 2)?;
+ dec.matrix_stride = NonZeroU32::new(self.next()?);
+ }
+ spirv::Decoration::Invariant => {
+ dec.invariant = true;
+ }
+ spirv::Decoration::NoPerspective => {
+ dec.interpolation = Some(crate::Interpolation::Linear);
+ }
+ spirv::Decoration::Flat => {
+ dec.interpolation = Some(crate::Interpolation::Flat);
+ }
+ spirv::Decoration::Centroid => {
+ dec.sampling = Some(crate::Sampling::Centroid);
+ }
+ spirv::Decoration::Sample => {
+ dec.sampling = Some(crate::Sampling::Sample);
+ }
+ spirv::Decoration::NonReadable => {
+ dec.flags |= DecorationFlags::NON_READABLE;
+ }
+ spirv::Decoration::NonWritable => {
+ dec.flags |= DecorationFlags::NON_WRITABLE;
+ }
+ spirv::Decoration::ColMajor => {
+ dec.matrix_major = Some(Majority::Column);
+ }
+ spirv::Decoration::RowMajor => {
+ dec.matrix_major = Some(Majority::Row);
+ }
+ spirv::Decoration::SpecId => {
+ dec.specialization = Some(self.next()?);
+ }
+ other => {
+ log::warn!("Unknown decoration {:?}", other);
+ for _ in base_words + 1..inst.wc {
+ let _var = self.next()?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ /// Return the Naga `Expression` for a given SPIR-V result `id`.
+ ///
+ /// `lookup` must be the `LookupExpression` for `id`.
+ ///
+ /// SPIR-V result ids can be used by any block dominated by the id's
+ /// definition, but Naga `Expressions` are only in scope for the remainder
+ /// of their `Statement` subtree. This means that the `Expression` generated
+ /// for `id` may no longer be in scope. In such cases, this function takes
+ /// care of spilling the value of `id` to a `LocalVariable` which can then
+ /// be used anywhere. The SPIR-V domination rule ensures that the
+ /// `LocalVariable` has been initialized before it is used.
+ ///
+ /// The `body_idx` argument should be the index of the `Body` that hopes to
+ /// use `id`'s `Expression`.
+ fn get_expr_handle(
+ &self,
+ id: spirv::Word,
+ lookup: &LookupExpression,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ body_idx: BodyIndex,
+ ) -> Handle<crate::Expression> {
+ // What `Body` was `id` defined in?
+ let expr_body_idx = ctx
+ .body_for_label
+ .get(&lookup.block_id)
+ .copied()
+ .unwrap_or(0);
+
+ // Don't need to do a load/store if the expression is in the main body
+ // or if the expression is in the same body as where the query was
+ // requested. The body_idx might actually not be the final one if a loop
+ // or conditional occurs but in those cases we know that the new body
+ // will be a subscope of the body that was passed so we can still reuse
+ // the handle and not issue a load/store.
+ if is_parent(body_idx, expr_body_idx, ctx) {
+ lookup.handle
+ } else {
+ // Add a temporary variable of the same type which will be used to
+ // store the original expression and used in the current block
+ let ty = self.lookup_type[&lookup.type_id].handle;
+ let local = ctx.local_arena.append(
+ crate::LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ crate::Span::default(),
+ );
+
+ block.extend(emitter.finish(ctx.expressions));
+ let pointer = ctx.expressions.append(
+ crate::Expression::LocalVariable(local),
+ crate::Span::default(),
+ );
+ emitter.start(ctx.expressions);
+ let expr = ctx
+ .expressions
+ .append(crate::Expression::Load { pointer }, crate::Span::default());
+
+ // Add a slightly odd entry to the phi table, so that while `id`'s
+ // `Expression` is still in scope, the usual phi processing will
+ // spill its value to `local`, where we can find it later.
+ //
+ // This pretends that the block in which `id` is defined is the
+ // predecessor of some other block with a phi in it that cites id as
+ // one of its sources, and uses `local` as its variable. There is no
+ // such phi, but nobody needs to know that.
+ ctx.phis.push(PhiExpression {
+ local,
+ expressions: vec![(id, lookup.block_id)],
+ });
+
+ expr
+ }
+ }
+
+ fn parse_expr_unary_op(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::UnaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p_id = self.next()?;
+
+ let p_lexp = self.lookup_expression.lookup(p_id)?;
+ let handle = self.get_expr_handle(p_id, p_lexp, ctx, emitter, block, body_idx);
+
+ let expr = crate::Expression::Unary { op, expr: handle };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_expr_binary_op(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+
+ let expr = crate::Expression::Binary { op, left, right };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ /// A more complicated version of the unary op,
+ /// where we force the operand to have the same type as the result.
+ fn parse_expr_unary_op_sign_adjusted(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::UnaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+
+ let result_lookup_ty = self.lookup_type.lookup(result_type_id)?;
+ let kind = ctx.type_arena[result_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let expr = crate::Expression::Unary {
+ op,
+ expr: if p1_lexp.type_id == result_type_id {
+ left
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ /// A more complicated version of the binary op,
+ /// where we force the operand to have the same type as the result.
+ /// This is mostly needed for "i++" and "i--" coming from GLSL.
+ #[allow(clippy::too_many_arguments)]
+ fn parse_expr_binary_op_sign_adjusted(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ // For arithmetic operations, we need the sign of operands to match the result.
+ // For boolean operations, however, the operands need to match the signs, but
+ // result is always different - a boolean.
+ anchor: SignAnchor,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+
+ let expected_type_id = match anchor {
+ SignAnchor::Result => result_type_id,
+ SignAnchor::Operand => p1_lexp.type_id,
+ };
+ let expected_lookup_ty = self.lookup_type.lookup(expected_type_id)?;
+ let kind = ctx.type_arena[expected_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let expr = crate::Expression::Binary {
+ op,
+ left: if p1_lexp.type_id == expected_type_id {
+ left
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ right: if p2_lexp.type_id == expected_type_id {
+ right
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: right,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ /// A version of the binary op where one or both of the arguments might need to be casted to a
+ /// specific integer kind (unsigned or signed), used for operations like OpINotEqual or
+ /// OpUGreaterThan.
+ #[allow(clippy::too_many_arguments)]
+ fn parse_expr_int_comparison(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ kind: crate::ScalarKind,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p1_lookup_ty = self.lookup_type.lookup(p1_lexp.type_id)?;
+ let p1_kind = ctx.type_arena[p1_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+ let p2_lookup_ty = self.lookup_type.lookup(p2_lexp.type_id)?;
+ let p2_kind = ctx.type_arena[p2_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let expr = crate::Expression::Binary {
+ op,
+ left: if p1_kind == kind {
+ left
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ right: if p2_kind == kind {
+ right
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: right,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_expr_shift_op(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let p2_handle = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+ // convert the shift to Uint
+ let right = ctx.expressions.append(
+ crate::Expression::As {
+ expr: p2_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ );
+
+ let expr = crate::Expression::Binary { op, left, right };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_expr_derivative(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut super::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ (axis, ctrl): (crate::DerivativeAxis, crate::DerivativeControl),
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let arg_id = self.next()?;
+
+ let arg_lexp = self.lookup_expression.lookup(arg_id)?;
+ let arg_handle = self.get_expr_handle(arg_id, arg_lexp, ctx, emitter, block, body_idx);
+
+ let expr = crate::Expression::Derivative {
+ axis,
+ ctrl,
+ expr: arg_handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn insert_composite(
+ &self,
+ root_expr: Handle<crate::Expression>,
+ root_type_id: spirv::Word,
+ object_expr: Handle<crate::Expression>,
+ selections: &[spirv::Word],
+ type_arena: &UniqueArena<crate::Type>,
+ expressions: &mut Arena<crate::Expression>,
+ constants: &Arena<crate::Constant>,
+ span: crate::Span,
+ ) -> Result<Handle<crate::Expression>, Error> {
+ let selection = match selections.first() {
+ Some(&index) => index,
+ None => return Ok(object_expr),
+ };
+ let root_span = expressions.get_span(root_expr);
+ let root_lookup = self.lookup_type.lookup(root_type_id)?;
+ let (count, child_type_id) = match type_arena[root_lookup.handle].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let child_member = self
+ .lookup_member
+ .get(&(root_lookup.handle, selection))
+ .ok_or(Error::InvalidAccessType(root_type_id))?;
+ (members.len(), child_member.type_id)
+ }
+ crate::TypeInner::Array { size, .. } => {
+ let size = match size {
+ crate::ArraySize::Constant(handle) => match constants[handle] {
+ crate::Constant {
+ specialization: Some(_),
+ ..
+ } => return Err(Error::UnsupportedType(root_lookup.handle)),
+ ref unspecialized => unspecialized
+ .to_array_length()
+ .ok_or(Error::InvalidArraySize(handle))?,
+ },
+ // A runtime sized array is not a composite type
+ crate::ArraySize::Dynamic => {
+ return Err(Error::InvalidAccessType(root_type_id))
+ }
+ };
+
+ let child_type_id = root_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(root_type_id))?;
+
+ (size as usize, child_type_id)
+ }
+ crate::TypeInner::Vector { size, .. }
+ | crate::TypeInner::Matrix { columns: size, .. } => {
+ let child_type_id = root_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(root_type_id))?;
+ (size as usize, child_type_id)
+ }
+ _ => return Err(Error::InvalidAccessType(root_type_id)),
+ };
+
+ let mut components = Vec::with_capacity(count);
+ for index in 0..count as u32 {
+ let expr = expressions.append(
+ crate::Expression::AccessIndex {
+ base: root_expr,
+ index,
+ },
+ if index == selection { span } else { root_span },
+ );
+ components.push(expr);
+ }
+ components[selection as usize] = self.insert_composite(
+ components[selection as usize],
+ child_type_id,
+ object_expr,
+ &selections[1..],
+ type_arena,
+ expressions,
+ constants,
+ span,
+ )?;
+
+ Ok(expressions.append(
+ crate::Expression::Compose {
+ ty: root_lookup.handle,
+ components,
+ },
+ span,
+ ))
+ }
+
+ /// Add the next SPIR-V block's contents to `block_ctx`.
+ ///
+ /// Except for the function's entry block, `block_id` should be the label of
+ /// a block we've seen mentioned before, with an entry in
+ /// `block_ctx.body_for_label` to tell us which `Body` it contributes to.
+ fn next_block(&mut self, block_id: spirv::Word, ctx: &mut BlockContext) -> Result<(), Error> {
+ // Extend `body` with the correct form for a branch to `target`.
+ fn merger(body: &mut Body, target: &MergeBlockInformation) {
+ body.data.push(match *target {
+ MergeBlockInformation::LoopContinue => BodyFragment::Continue,
+ MergeBlockInformation::LoopMerge | MergeBlockInformation::SwitchMerge => {
+ BodyFragment::Break
+ }
+
+ // Finishing a selection merge means just falling off the end of
+ // the `accept` or `reject` block of the `If` statement.
+ MergeBlockInformation::SelectionMerge => return,
+ })
+ }
+
+ let mut emitter = super::Emitter::default();
+ emitter.start(ctx.expressions);
+
+ // Find the `Body` that this block belongs to. Index zero is the
+ // function's root `Body`, corresponding to `Function::body`.
+ let mut body_idx = *ctx.body_for_label.entry(block_id).or_default();
+ let mut block = crate::Block::new();
+ // Stores the merge block as defined by a `OpSelectionMerge` otherwise is `None`
+ //
+ // This is used in `OpSwitch` to promote the `MergeBlockInformation` from
+ // `SelectionMerge` to `SwitchMerge` to allow `Break`s this isn't desirable for
+ // `LoopMerge`s because otherwise `Continue`s wouldn't be allowed
+ let mut selection_merge_block = None;
+
+ macro_rules! get_expr_handle {
+ ($id:expr, $lexp:expr) => {
+ self.get_expr_handle($id, $lexp, ctx, &mut emitter, &mut block, body_idx)
+ };
+ }
+ macro_rules! parse_expr_op {
+ ($op:expr, BINARY) => {
+ self.parse_expr_binary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op)
+ };
+
+ ($op:expr, SHIFT) => {
+ self.parse_expr_shift_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op)
+ };
+ ($op:expr, UNARY) => {
+ self.parse_expr_unary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op)
+ };
+ ($axis:expr, $ctrl:expr, DERIVATIVE) => {
+ self.parse_expr_derivative(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ ($axis, $ctrl),
+ )
+ };
+ }
+
+ let terminator = loop {
+ use spirv::Op;
+ let start = self.data_offset;
+ let inst = self.next_inst()?;
+ let span = crate::Span::from(start..(start + 4 * (inst.wc as usize)));
+ log::debug!("\t\t{:?} [{}]", inst.op, inst.wc);
+
+ match inst.op {
+ Op::Line => {
+ inst.expect(4)?;
+ let _file_id = self.next()?;
+ let _row_id = self.next()?;
+ let _col_id = self.next()?;
+ }
+ Op::NoLine => inst.expect(1)?,
+ Op::Undef => {
+ inst.expect(3)?;
+ let (type_id, id, handle) =
+ self.parse_null_constant(inst, ctx.type_arena, ctx.const_arena)?;
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle: ctx
+ .expressions
+ .append(crate::Expression::Constant(handle), span),
+ type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Variable => {
+ inst.expect_at_least(4)?;
+ block.extend(emitter.finish(ctx.expressions));
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let _storage_class = self.next()?;
+ let init = if inst.wc > 4 {
+ inst.expect(5)?;
+ let init_id = self.next()?;
+ let lconst = self.lookup_constant.lookup(init_id)?;
+ Some(lconst.handle)
+ } else {
+ None
+ };
+
+ let name = self
+ .future_decor
+ .remove(&result_id)
+ .and_then(|decor| decor.name);
+ if let Some(ref name) = name {
+ log::debug!("\t\t\tid={} name={}", result_id, name);
+ }
+ let lookup_ty = self.lookup_type.lookup(result_type_id)?;
+ let var_handle = ctx.local_arena.append(
+ crate::LocalVariable {
+ name,
+ ty: match ctx.type_arena[lookup_ty.handle].inner {
+ crate::TypeInner::Pointer { base, .. } => base,
+ _ => lookup_ty.handle,
+ },
+ init,
+ },
+ span,
+ );
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx
+ .expressions
+ .append(crate::Expression::LocalVariable(var_handle), span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ emitter.start(ctx.expressions);
+ }
+ Op::Phi => {
+ inst.expect_at_least(3)?;
+ block.extend(emitter.finish(ctx.expressions));
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+
+ let name = format!("phi_{result_id}");
+ let local = ctx.local_arena.append(
+ crate::LocalVariable {
+ name: Some(name),
+ ty: self.lookup_type.lookup(result_type_id)?.handle,
+ init: None,
+ },
+ self.span_from(start),
+ );
+ let pointer = ctx
+ .expressions
+ .append(crate::Expression::LocalVariable(local), span);
+
+ let in_count = (inst.wc - 3) / 2;
+ let mut phi = PhiExpression {
+ local,
+ expressions: Vec::with_capacity(in_count as usize),
+ };
+ for _ in 0..in_count {
+ let expr = self.next()?;
+ let block = self.next()?;
+ phi.expressions.push((expr, block));
+ }
+
+ ctx.phis.push(phi);
+ emitter.start(ctx.expressions);
+
+ // Associate the lookup with an actual value, which is emitted
+ // into the current block.
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx
+ .expressions
+ .append(crate::Expression::Load { pointer }, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::AccessChain | Op::InBoundsAccessChain => {
+ struct AccessExpression {
+ base_handle: Handle<crate::Expression>,
+ type_id: spirv::Word,
+ load_override: Option<LookupLoadOverride>,
+ }
+
+ inst.expect_at_least(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", base_id);
+
+ let mut acex = {
+ let lexp = self.lookup_expression.lookup(base_id)?;
+ let lty = self.lookup_type.lookup(lexp.type_id)?;
+
+ // HACK `OpAccessChain` and `OpInBoundsAccessChain`
+ // require for the result type to be a pointer, but if
+ // we're given a pointer to an image / sampler, it will
+ // be *already* dereferenced, since we do that early
+ // during `parse_type_pointer()`.
+ //
+ // This can happen only through `BindingArray`, since
+ // that's the only case where one can obtain a pointer
+ // to an image / sampler, and so let's match on that:
+ let dereference = match ctx.type_arena[lty.handle].inner {
+ crate::TypeInner::BindingArray { .. } => false,
+ _ => true,
+ };
+
+ let type_id = if dereference {
+ lty.base_id.ok_or(Error::InvalidAccessType(lexp.type_id))?
+ } else {
+ lexp.type_id
+ };
+
+ AccessExpression {
+ base_handle: get_expr_handle!(base_id, lexp),
+ type_id,
+ load_override: self.lookup_load_override.get(&base_id).cloned(),
+ }
+ };
+
+ for _ in 4..inst.wc {
+ let access_id = self.next()?;
+ log::trace!("\t\t\tlooking up index expr {:?}", access_id);
+ let index_expr = self.lookup_expression.lookup(access_id)?.clone();
+ let index_expr_handle = get_expr_handle!(access_id, &index_expr);
+ let index_expr_data = &ctx.expressions[index_expr.handle];
+ let index_maybe = match *index_expr_data {
+ crate::Expression::Constant(const_handle) => {
+ Some(ctx.const_arena[const_handle].to_array_length().ok_or(
+ Error::InvalidAccess(crate::Expression::Constant(const_handle)),
+ )?)
+ }
+ _ => None,
+ };
+
+ log::trace!("\t\t\tlooking up type {:?}", acex.type_id);
+ let type_lookup = self.lookup_type.lookup(acex.type_id)?;
+ let ty = &ctx.type_arena[type_lookup.handle];
+ acex = match ty.inner {
+ // can only index a struct with a constant
+ crate::TypeInner::Struct { ref members, .. } => {
+ let index = index_maybe
+ .ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?;
+
+ let lookup_member = self
+ .lookup_member
+ .get(&(type_lookup.handle, index))
+ .ok_or(Error::InvalidAccessType(acex.type_id))?;
+ let base_handle = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: acex.base_handle,
+ index,
+ },
+ span,
+ );
+
+ if ty.name.as_deref() == Some("gl_PerVertex") {
+ if let Some(crate::Binding::BuiltIn(built_in)) =
+ members[index as usize].binding
+ {
+ self.gl_per_vertex_builtin_access.insert(built_in);
+ }
+ }
+
+ AccessExpression {
+ base_handle,
+ type_id: lookup_member.type_id,
+ load_override: if lookup_member.row_major {
+ debug_assert!(acex.load_override.is_none());
+ let sub_type_lookup =
+ self.lookup_type.lookup(lookup_member.type_id)?;
+ Some(match ctx.type_arena[sub_type_lookup.handle].inner {
+ // load it transposed, to match column major expectations
+ crate::TypeInner::Matrix { .. } => {
+ let loaded = ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: base_handle,
+ },
+ span,
+ );
+ let transposed = ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Transpose,
+ arg: loaded,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ );
+ LookupLoadOverride::Loaded(transposed)
+ }
+ _ => LookupLoadOverride::Pending,
+ })
+ } else {
+ None
+ },
+ }
+ }
+ crate::TypeInner::Matrix { .. } => {
+ let load_override = match acex.load_override {
+ // We are indexing inside a row-major matrix
+ Some(LookupLoadOverride::Loaded(load_expr)) => {
+ let index = index_maybe.ok_or_else(|| {
+ Error::InvalidAccess(index_expr_data.clone())
+ })?;
+ let sub_handle = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: load_expr,
+ index,
+ },
+ span,
+ );
+ Some(LookupLoadOverride::Loaded(sub_handle))
+ }
+ _ => None,
+ };
+ let sub_expr = match index_maybe {
+ Some(index) => crate::Expression::AccessIndex {
+ base: acex.base_handle,
+ index,
+ },
+ None => crate::Expression::Access {
+ base: acex.base_handle,
+ index: index_expr_handle,
+ },
+ };
+ AccessExpression {
+ base_handle: ctx.expressions.append(sub_expr, span),
+ type_id: type_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(acex.type_id))?,
+ load_override,
+ }
+ }
+ // This must be a vector or an array.
+ _ => {
+ let base_handle = ctx.expressions.append(
+ crate::Expression::Access {
+ base: acex.base_handle,
+ index: index_expr_handle,
+ },
+ span,
+ );
+ let load_override = match acex.load_override {
+ // If there is a load override in place, then we always end up
+ // with a side-loaded value here.
+ Some(lookup_load_override) => {
+ let sub_expr = match lookup_load_override {
+ // We must be indexing into the array of row-major matrices.
+ // Let's load the result of indexing and transpose it.
+ LookupLoadOverride::Pending => {
+ let loaded = ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: base_handle,
+ },
+ span,
+ );
+ ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Transpose,
+ arg: loaded,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ )
+ }
+ // We are indexing inside a row-major matrix.
+ LookupLoadOverride::Loaded(load_expr) => {
+ ctx.expressions.append(
+ crate::Expression::Access {
+ base: load_expr,
+ index: index_expr_handle,
+ },
+ span,
+ )
+ }
+ };
+ Some(LookupLoadOverride::Loaded(sub_expr))
+ }
+ None => None,
+ };
+ AccessExpression {
+ base_handle,
+ type_id: type_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(acex.type_id))?,
+ load_override,
+ }
+ }
+ };
+ }
+
+ if let Some(load_expr) = acex.load_override {
+ self.lookup_load_override.insert(result_id, load_expr);
+ }
+ let lookup_expression = LookupExpression {
+ handle: acex.base_handle,
+ type_id: result_type_id,
+ block_id,
+ };
+ self.lookup_expression.insert(result_id, lookup_expression);
+ }
+ Op::VectorExtractDynamic => {
+ inst.expect(5)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let composite_id = self.next()?;
+ let index_id = self.next()?;
+
+ let root_lexp = self.lookup_expression.lookup(composite_id)?;
+ let root_handle = get_expr_handle!(composite_id, root_lexp);
+ let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?;
+ let index_lexp = self.lookup_expression.lookup(index_id)?;
+ let index_handle = get_expr_handle!(index_id, index_lexp);
+
+ let num_components = match ctx.type_arena[root_type_lookup.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as usize,
+ _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)),
+ };
+
+ let mut handle = ctx.expressions.append(
+ crate::Expression::Access {
+ base: root_handle,
+ index: self.index_constant_expressions[0],
+ },
+ span,
+ );
+ for &index_expr in self.index_constant_expressions[1..num_components].iter() {
+ let access_expr = ctx.expressions.append(
+ crate::Expression::Access {
+ base: root_handle,
+ index: index_expr,
+ },
+ span,
+ );
+ let cond = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Equal,
+ left: index_expr,
+ right: index_handle,
+ },
+ span,
+ );
+ handle = ctx.expressions.append(
+ crate::Expression::Select {
+ condition: cond,
+ accept: access_expr,
+ reject: handle,
+ },
+ span,
+ );
+ }
+
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::VectorInsertDynamic => {
+ inst.expect(6)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let composite_id = self.next()?;
+ let object_id = self.next()?;
+ let index_id = self.next()?;
+
+ let object_lexp = self.lookup_expression.lookup(object_id)?;
+ let object_handle = get_expr_handle!(object_id, object_lexp);
+ let root_lexp = self.lookup_expression.lookup(composite_id)?;
+ let root_handle = get_expr_handle!(composite_id, root_lexp);
+ let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?;
+ let index_lexp = self.lookup_expression.lookup(index_id)?;
+ let mut index_handle = get_expr_handle!(index_id, index_lexp);
+ let index_type = self.lookup_type.lookup(index_lexp.type_id)?.handle;
+
+ // SPIR-V allows signed and unsigned indices but naga's is strict about
+ // types and since the `index_constants` are all signed integers, we need
+ // to cast the index to a signed integer if it's unsigned.
+ if let Some(crate::ScalarKind::Uint) =
+ ctx.type_arena[index_type].inner.scalar_kind()
+ {
+ index_handle = ctx.expressions.append(
+ crate::Expression::As {
+ expr: index_handle,
+ kind: crate::ScalarKind::Sint,
+ convert: None,
+ },
+ span,
+ )
+ }
+
+ let num_components = match ctx.type_arena[root_type_lookup.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as usize,
+ _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)),
+ };
+ let mut components = Vec::with_capacity(num_components);
+ for &index_expr in self.index_constant_expressions[..num_components].iter() {
+ let access_expr = ctx.expressions.append(
+ crate::Expression::Access {
+ base: root_handle,
+ index: index_expr,
+ },
+ span,
+ );
+ let cond = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Equal,
+ left: index_expr,
+ right: index_handle,
+ },
+ span,
+ );
+ let handle = ctx.expressions.append(
+ crate::Expression::Select {
+ condition: cond,
+ accept: object_handle,
+ reject: access_expr,
+ },
+ span,
+ );
+ components.push(handle);
+ }
+ let handle = ctx.expressions.append(
+ crate::Expression::Compose {
+ ty: root_type_lookup.handle,
+ components,
+ },
+ span,
+ );
+
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CompositeExtract => {
+ inst.expect_at_least(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", base_id);
+ let mut lexp = self.lookup_expression.lookup(base_id)?.clone();
+ lexp.handle = get_expr_handle!(base_id, &lexp);
+ for _ in 4..inst.wc {
+ let index = self.next()?;
+ log::trace!("\t\t\tlooking up type {:?}", lexp.type_id);
+ let type_lookup = self.lookup_type.lookup(lexp.type_id)?;
+ let type_id = match ctx.type_arena[type_lookup.handle].inner {
+ crate::TypeInner::Struct { .. } => {
+ self.lookup_member
+ .get(&(type_lookup.handle, index))
+ .ok_or(Error::InvalidAccessType(lexp.type_id))?
+ .type_id
+ }
+ crate::TypeInner::Array { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. } => type_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(lexp.type_id))?,
+ ref other => {
+ log::warn!("composite type {:?}", other);
+ return Err(Error::UnsupportedType(type_lookup.handle));
+ }
+ };
+ lexp = LookupExpression {
+ handle: ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: lexp.handle,
+ index,
+ },
+ span,
+ ),
+ type_id,
+ block_id,
+ };
+ }
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: lexp.handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CompositeInsert => {
+ inst.expect_at_least(5)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let object_id = self.next()?;
+ let composite_id = self.next()?;
+ let mut selections = Vec::with_capacity(inst.wc as usize - 5);
+ for _ in 5..inst.wc {
+ selections.push(self.next()?);
+ }
+
+ let object_lexp = self.lookup_expression.lookup(object_id)?.clone();
+ let object_handle = get_expr_handle!(object_id, &object_lexp);
+ let root_lexp = self.lookup_expression.lookup(composite_id)?.clone();
+ let root_handle = get_expr_handle!(composite_id, &root_lexp);
+ let handle = self.insert_composite(
+ root_handle,
+ result_type_id,
+ object_handle,
+ &selections,
+ ctx.type_arena,
+ ctx.expressions,
+ ctx.const_arena,
+ span,
+ )?;
+
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CompositeConstruct => {
+ inst.expect_at_least(3)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let mut components = Vec::with_capacity(inst.wc as usize - 2);
+ for _ in 3..inst.wc {
+ let comp_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", comp_id);
+ let lexp = self.lookup_expression.lookup(comp_id)?;
+ let handle = get_expr_handle!(comp_id, lexp);
+ components.push(handle);
+ }
+ let ty = self.lookup_type.lookup(result_type_id)?.handle;
+ let first = components[0];
+ let expr = match ctx.type_arena[ty].inner {
+ // this is an optimization to detect the splat
+ crate::TypeInner::Vector { size, .. }
+ if components.len() == size as usize
+ && components[1..].iter().all(|&c| c == first) =>
+ {
+ crate::Expression::Splat { size, value: first }
+ }
+ _ => crate::Expression::Compose { ty, components },
+ };
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Load => {
+ inst.expect_at_least(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let pointer_id = self.next()?;
+ if inst.wc != 4 {
+ inst.expect(5)?;
+ let _memory_access = self.next()?;
+ }
+
+ let base_lexp = self.lookup_expression.lookup(pointer_id)?;
+ let base_handle = get_expr_handle!(pointer_id, base_lexp);
+ let type_lookup = self.lookup_type.lookup(base_lexp.type_id)?;
+ let handle = match ctx.type_arena[type_lookup.handle].inner {
+ crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => {
+ base_handle
+ }
+ _ => match self.lookup_load_override.get(&pointer_id) {
+ Some(&LookupLoadOverride::Loaded(handle)) => handle,
+ //Note: we aren't handling `LookupLoadOverride::Pending` properly here
+ _ => ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: base_handle,
+ },
+ span,
+ ),
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Store => {
+ inst.expect_at_least(3)?;
+
+ let pointer_id = self.next()?;
+ let value_id = self.next()?;
+ if inst.wc != 3 {
+ inst.expect(4)?;
+ let _memory_access = self.next()?;
+ }
+ let base_expr = self.lookup_expression.lookup(pointer_id)?;
+ let base_handle = get_expr_handle!(pointer_id, base_expr);
+ let value_expr = self.lookup_expression.lookup(value_id)?;
+ let value_handle = get_expr_handle!(value_id, value_expr);
+
+ block.extend(emitter.finish(ctx.expressions));
+ block.push(
+ crate::Statement::Store {
+ pointer: base_handle,
+ value: value_handle,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ // Arithmetic Instructions +, -, *, /, %
+ Op::SNegate | Op::FNegate => {
+ inst.expect(4)?;
+ self.parse_expr_unary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ crate::UnaryOperator::Negate,
+ )?;
+ }
+ Op::IAdd
+ | Op::ISub
+ | Op::IMul
+ | Op::BitwiseOr
+ | Op::BitwiseXor
+ | Op::BitwiseAnd
+ | Op::SDiv
+ | Op::SRem => {
+ inst.expect(5)?;
+ let operator = map_binary_operator(inst.op)?;
+ self.parse_expr_binary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ operator,
+ SignAnchor::Result,
+ )?;
+ }
+ Op::IEqual | Op::INotEqual => {
+ inst.expect(5)?;
+ let operator = map_binary_operator(inst.op)?;
+ self.parse_expr_binary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ operator,
+ SignAnchor::Operand,
+ )?;
+ }
+ Op::FAdd => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Add, BINARY)?;
+ }
+ Op::FSub => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Subtract, BINARY)?;
+ }
+ Op::FMul => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?;
+ }
+ Op::UDiv | Op::FDiv => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Divide, BINARY)?;
+ }
+ Op::UMod | Op::FRem => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Modulo, BINARY)?;
+ }
+ Op::SMod => {
+ inst.expect(5)?;
+
+ // x - y * int(floor(float(x) / float(y)))
+
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(
+ p1_id,
+ p1_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(
+ p2_id,
+ p2_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+
+ let result_ty = self.lookup_type.lookup(result_type_id)?;
+ let inner = &ctx.type_arena[result_ty.handle].inner;
+ let kind = inner.scalar_kind().unwrap();
+ let size = inner.size(ctx.const_arena) as u8;
+
+ let left_cast = ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind: crate::ScalarKind::Float,
+ convert: Some(size),
+ },
+ span,
+ );
+ let right_cast = ctx.expressions.append(
+ crate::Expression::As {
+ expr: right,
+ kind: crate::ScalarKind::Float,
+ convert: Some(size),
+ },
+ span,
+ );
+ let div = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: left_cast,
+ right: right_cast,
+ },
+ span,
+ );
+ let floor = ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Floor,
+ arg: div,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ );
+ let cast = ctx.expressions.append(
+ crate::Expression::As {
+ expr: floor,
+ kind,
+ convert: Some(size),
+ },
+ span,
+ );
+ let mult = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Multiply,
+ left: cast,
+ right,
+ },
+ span,
+ );
+ let sub = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Subtract,
+ left,
+ right: mult,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: sub,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::FMod => {
+ inst.expect(5)?;
+
+ // x - y * floor(x / y)
+
+ let start = self.data_offset;
+ let span = self.span_from_with_op(start);
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(
+ p1_id,
+ p1_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(
+ p2_id,
+ p2_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+
+ let div = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left,
+ right,
+ },
+ span,
+ );
+ let floor = ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Floor,
+ arg: div,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ );
+ let mult = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Multiply,
+ left: floor,
+ right,
+ },
+ span,
+ );
+ let sub = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Subtract,
+ left,
+ right: mult,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: sub,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::VectorTimesScalar
+ | Op::VectorTimesMatrix
+ | Op::MatrixTimesScalar
+ | Op::MatrixTimesVector
+ | Op::MatrixTimesMatrix => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?;
+ }
+ Op::Transpose => {
+ inst.expect(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let matrix_id = self.next()?;
+ let matrix_lexp = self.lookup_expression.lookup(matrix_id)?;
+ let matrix_handle = get_expr_handle!(matrix_id, matrix_lexp);
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::Transpose,
+ arg: matrix_handle,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Dot => {
+ inst.expect(5)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let left_id = self.next()?;
+ let right_id = self.next()?;
+ let left_lexp = self.lookup_expression.lookup(left_id)?;
+ let left_handle = get_expr_handle!(left_id, left_lexp);
+ let right_lexp = self.lookup_expression.lookup(right_id)?;
+ let right_handle = get_expr_handle!(right_id, right_lexp);
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::Dot,
+ arg: left_handle,
+ arg1: Some(right_handle),
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::BitFieldInsert => {
+ inst.expect(7)?;
+
+ let start = self.data_offset;
+ let span = self.span_from_with_op(start);
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ let insert_id = self.next()?;
+ let offset_id = self.next()?;
+ let count_id = self.next()?;
+ let base_lexp = self.lookup_expression.lookup(base_id)?;
+ let base_handle = get_expr_handle!(base_id, base_lexp);
+ let insert_lexp = self.lookup_expression.lookup(insert_id)?;
+ let insert_handle = get_expr_handle!(insert_id, insert_lexp);
+ let offset_lexp = self.lookup_expression.lookup(offset_id)?;
+ let offset_handle = get_expr_handle!(offset_id, offset_lexp);
+ let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?;
+ let count_lexp = self.lookup_expression.lookup(count_id)?;
+ let count_handle = get_expr_handle!(count_id, count_lexp);
+ let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?;
+
+ let offset_kind = ctx.type_arena[offset_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+ let count_kind = ctx.type_arena[count_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: offset_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ offset_handle
+ };
+
+ let count_cast_handle = if count_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: count_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ count_handle
+ };
+
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::InsertBits,
+ arg: base_handle,
+ arg1: Some(insert_handle),
+ arg2: Some(offset_cast_handle),
+ arg3: Some(count_cast_handle),
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::BitFieldSExtract | Op::BitFieldUExtract => {
+ inst.expect(6)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ let offset_id = self.next()?;
+ let count_id = self.next()?;
+ let base_lexp = self.lookup_expression.lookup(base_id)?;
+ let base_handle = get_expr_handle!(base_id, base_lexp);
+ let offset_lexp = self.lookup_expression.lookup(offset_id)?;
+ let offset_handle = get_expr_handle!(offset_id, offset_lexp);
+ let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?;
+ let count_lexp = self.lookup_expression.lookup(count_id)?;
+ let count_handle = get_expr_handle!(count_id, count_lexp);
+ let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?;
+
+ let offset_kind = ctx.type_arena[offset_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+ let count_kind = ctx.type_arena[count_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: offset_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ offset_handle
+ };
+
+ let count_cast_handle = if count_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: count_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ count_handle
+ };
+
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::ExtractBits,
+ arg: base_handle,
+ arg1: Some(offset_cast_handle),
+ arg2: Some(count_cast_handle),
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::BitReverse | Op::BitCount => {
+ inst.expect(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ let base_lexp = self.lookup_expression.lookup(base_id)?;
+ let base_handle = get_expr_handle!(base_id, base_lexp);
+ let expr = crate::Expression::Math {
+ fun: match inst.op {
+ Op::BitReverse => crate::MathFunction::ReverseBits,
+ Op::BitCount => crate::MathFunction::CountOneBits,
+ _ => unreachable!(),
+ },
+ arg: base_handle,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::OuterProduct => {
+ inst.expect(5)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let left_id = self.next()?;
+ let right_id = self.next()?;
+ let left_lexp = self.lookup_expression.lookup(left_id)?;
+ let left_handle = get_expr_handle!(left_id, left_lexp);
+ let right_lexp = self.lookup_expression.lookup(right_id)?;
+ let right_handle = get_expr_handle!(right_id, right_lexp);
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::Outer,
+ arg: left_handle,
+ arg1: Some(right_handle),
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ // Bitwise instructions
+ Op::Not => {
+ inst.expect(4)?;
+ self.parse_expr_unary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ crate::UnaryOperator::Not,
+ )?;
+ }
+ Op::ShiftRightLogical => {
+ inst.expect(5)?;
+ //TODO: convert input and result to usigned
+ parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?;
+ }
+ Op::ShiftRightArithmetic => {
+ inst.expect(5)?;
+ //TODO: convert input and result to signed
+ parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?;
+ }
+ Op::ShiftLeftLogical => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::ShiftLeft, SHIFT)?;
+ }
+ // Sampling
+ Op::Image => {
+ inst.expect(4)?;
+ self.parse_image_uncouple(block_id)?;
+ }
+ Op::SampledImage => {
+ inst.expect(5)?;
+ self.parse_image_couple()?;
+ }
+ Op::ImageWrite => {
+ let extra = inst.expect_at_least(4)?;
+ let stmt =
+ self.parse_image_write(extra, ctx, &mut emitter, &mut block, body_idx)?;
+ block.extend(emitter.finish(ctx.expressions));
+ block.push(stmt, span);
+ emitter.start(ctx.expressions);
+ }
+ Op::ImageFetch | Op::ImageRead => {
+ let extra = inst.expect_at_least(5)?;
+ self.parse_image_load(
+ extra,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleImplicitLod | Op::ImageSampleExplicitLod => {
+ let extra = inst.expect_at_least(5)?;
+ let options = image::SamplingOptions {
+ compare: false,
+ project: false,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleProjImplicitLod | Op::ImageSampleProjExplicitLod => {
+ let extra = inst.expect_at_least(5)?;
+ let options = image::SamplingOptions {
+ compare: false,
+ project: true,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleDrefImplicitLod | Op::ImageSampleDrefExplicitLod => {
+ let extra = inst.expect_at_least(6)?;
+ let options = image::SamplingOptions {
+ compare: true,
+ project: false,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleProjDrefImplicitLod | Op::ImageSampleProjDrefExplicitLod => {
+ let extra = inst.expect_at_least(6)?;
+ let options = image::SamplingOptions {
+ compare: true,
+ project: true,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageQuerySize => {
+ inst.expect(4)?;
+ self.parse_image_query_size(
+ false,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageQuerySizeLod => {
+ inst.expect(5)?;
+ self.parse_image_query_size(
+ true,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageQueryLevels => {
+ inst.expect(4)?;
+ self.parse_image_query_other(
+ crate::ImageQuery::NumLevels,
+ ctx.expressions,
+ block_id,
+ )?;
+ }
+ Op::ImageQuerySamples => {
+ inst.expect(4)?;
+ self.parse_image_query_other(
+ crate::ImageQuery::NumSamples,
+ ctx.expressions,
+ block_id,
+ )?;
+ }
+ // other ops
+ Op::Select => {
+ inst.expect(6)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let condition = self.next()?;
+ let o1_id = self.next()?;
+ let o2_id = self.next()?;
+
+ let cond_lexp = self.lookup_expression.lookup(condition)?;
+ let cond_handle = get_expr_handle!(condition, cond_lexp);
+ let o1_lexp = self.lookup_expression.lookup(o1_id)?;
+ let o1_handle = get_expr_handle!(o1_id, o1_lexp);
+ let o2_lexp = self.lookup_expression.lookup(o2_id)?;
+ let o2_handle = get_expr_handle!(o2_id, o2_lexp);
+
+ let expr = crate::Expression::Select {
+ condition: cond_handle,
+ accept: o1_handle,
+ reject: o2_handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::VectorShuffle => {
+ inst.expect_at_least(5)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let v1_id = self.next()?;
+ let v2_id = self.next()?;
+
+ let v1_lexp = self.lookup_expression.lookup(v1_id)?;
+ let v1_lty = self.lookup_type.lookup(v1_lexp.type_id)?;
+ let v1_handle = get_expr_handle!(v1_id, v1_lexp);
+ let n1 = match ctx.type_arena[v1_lty.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as u32,
+ _ => return Err(Error::InvalidInnerType(v1_lexp.type_id)),
+ };
+ let v2_lexp = self.lookup_expression.lookup(v2_id)?;
+ let v2_lty = self.lookup_type.lookup(v2_lexp.type_id)?;
+ let v2_handle = get_expr_handle!(v2_id, v2_lexp);
+ let n2 = match ctx.type_arena[v2_lty.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as u32,
+ _ => return Err(Error::InvalidInnerType(v2_lexp.type_id)),
+ };
+
+ self.temp_bytes.clear();
+ let mut max_component = 0;
+ for _ in 5..inst.wc as usize {
+ let mut index = self.next()?;
+ if index == u32::MAX {
+ // treat Undefined as X
+ index = 0;
+ }
+ max_component = max_component.max(index);
+ self.temp_bytes.push(index as u8);
+ }
+
+ // Check for swizzle first.
+ let expr = if max_component < n1 {
+ use crate::SwizzleComponent as Sc;
+ let size = match self.temp_bytes.len() {
+ 2 => crate::VectorSize::Bi,
+ 3 => crate::VectorSize::Tri,
+ _ => crate::VectorSize::Quad,
+ };
+ let mut pattern = [Sc::X; 4];
+ for (pat, index) in pattern.iter_mut().zip(self.temp_bytes.drain(..)) {
+ *pat = match index {
+ 0 => Sc::X,
+ 1 => Sc::Y,
+ 2 => Sc::Z,
+ _ => Sc::W,
+ };
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector: v1_handle,
+ pattern,
+ }
+ } else {
+ // Fall back to access + compose
+ let mut components = Vec::with_capacity(self.temp_bytes.len());
+ for index in self.temp_bytes.drain(..).map(|i| i as u32) {
+ let expr = if index < n1 {
+ crate::Expression::AccessIndex {
+ base: v1_handle,
+ index,
+ }
+ } else if index < n1 + n2 {
+ crate::Expression::AccessIndex {
+ base: v2_handle,
+ index: index - n1,
+ }
+ } else {
+ return Err(Error::InvalidAccessIndex(index));
+ };
+ components.push(ctx.expressions.append(expr, span));
+ }
+ crate::Expression::Compose {
+ ty: self.lookup_type.lookup(result_type_id)?.handle,
+ components,
+ }
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Bitcast
+ | Op::ConvertSToF
+ | Op::ConvertUToF
+ | Op::ConvertFToU
+ | Op::ConvertFToS
+ | Op::FConvert
+ | Op::UConvert
+ | Op::SConvert => {
+ inst.expect(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let value_id = self.next()?;
+
+ let value_lexp = self.lookup_expression.lookup(value_id)?;
+ let ty_lookup = self.lookup_type.lookup(result_type_id)?;
+ let (kind, width) = match ctx.type_arena[ty_lookup.handle].inner {
+ crate::TypeInner::Scalar { kind, width }
+ | crate::TypeInner::Vector { kind, width, .. } => (kind, width),
+ crate::TypeInner::Matrix { width, .. } => (crate::ScalarKind::Float, width),
+ _ => return Err(Error::InvalidAsType(ty_lookup.handle)),
+ };
+
+ let expr = crate::Expression::As {
+ expr: get_expr_handle!(value_id, value_lexp),
+ kind,
+ convert: if kind == crate::ScalarKind::Bool {
+ Some(crate::BOOL_WIDTH)
+ } else if inst.op == Op::Bitcast {
+ None
+ } else {
+ Some(width)
+ },
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::FunctionCall => {
+ inst.expect_at_least(4)?;
+ block.extend(emitter.finish(ctx.expressions));
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let func_id = self.next()?;
+
+ let mut arguments = Vec::with_capacity(inst.wc as usize - 4);
+ for _ in 0..arguments.capacity() {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ arguments.push(get_expr_handle!(arg_id, lexp));
+ }
+
+ // We just need an unique handle here, nothing more.
+ let function = self.add_call(ctx.function_id, func_id);
+
+ let result = if self.lookup_void_type == Some(result_type_id) {
+ None
+ } else {
+ let expr_handle = ctx
+ .expressions
+ .append(crate::Expression::CallResult(function), span);
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expr_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Some(expr_handle)
+ };
+ block.push(
+ crate::Statement::Call {
+ function,
+ arguments,
+ result,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ Op::ExtInst => {
+ use crate::MathFunction as Mf;
+ use spirv::GLOp as Glo;
+
+ let base_wc = 5;
+ inst.expect_at_least(base_wc)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let set_id = self.next()?;
+ if Some(set_id) != self.ext_glsl_id {
+ return Err(Error::UnsupportedExtInstSet(set_id));
+ }
+ let inst_id = self.next()?;
+ let gl_op = Glo::from_u32(inst_id).ok_or(Error::UnsupportedExtInst(inst_id))?;
+
+ let fun = match gl_op {
+ Glo::Round => Mf::Round,
+ Glo::RoundEven => Mf::Round,
+ Glo::Trunc => Mf::Trunc,
+ Glo::FAbs | Glo::SAbs => Mf::Abs,
+ Glo::FSign | Glo::SSign => Mf::Sign,
+ Glo::Floor => Mf::Floor,
+ Glo::Ceil => Mf::Ceil,
+ Glo::Fract => Mf::Fract,
+ Glo::Sin => Mf::Sin,
+ Glo::Cos => Mf::Cos,
+ Glo::Tan => Mf::Tan,
+ Glo::Asin => Mf::Asin,
+ Glo::Acos => Mf::Acos,
+ Glo::Atan => Mf::Atan,
+ Glo::Sinh => Mf::Sinh,
+ Glo::Cosh => Mf::Cosh,
+ Glo::Tanh => Mf::Tanh,
+ Glo::Atan2 => Mf::Atan2,
+ Glo::Asinh => Mf::Asinh,
+ Glo::Acosh => Mf::Acosh,
+ Glo::Atanh => Mf::Atanh,
+ Glo::Radians => Mf::Radians,
+ Glo::Degrees => Mf::Degrees,
+ Glo::Pow => Mf::Pow,
+ Glo::Exp => Mf::Exp,
+ Glo::Log => Mf::Log,
+ Glo::Exp2 => Mf::Exp2,
+ Glo::Log2 => Mf::Log2,
+ Glo::Sqrt => Mf::Sqrt,
+ Glo::InverseSqrt => Mf::InverseSqrt,
+ Glo::MatrixInverse => Mf::Inverse,
+ Glo::Determinant => Mf::Determinant,
+ Glo::Modf => Mf::Modf,
+ Glo::FMin | Glo::UMin | Glo::SMin | Glo::NMin => Mf::Min,
+ Glo::FMax | Glo::UMax | Glo::SMax | Glo::NMax => Mf::Max,
+ Glo::FClamp | Glo::UClamp | Glo::SClamp | Glo::NClamp => Mf::Clamp,
+ Glo::FMix => Mf::Mix,
+ Glo::Step => Mf::Step,
+ Glo::SmoothStep => Mf::SmoothStep,
+ Glo::Fma => Mf::Fma,
+ Glo::Frexp => Mf::Frexp, //TODO: FrexpStruct?
+ Glo::Ldexp => Mf::Ldexp,
+ Glo::Length => Mf::Length,
+ Glo::Distance => Mf::Distance,
+ Glo::Cross => Mf::Cross,
+ Glo::Normalize => Mf::Normalize,
+ Glo::FaceForward => Mf::FaceForward,
+ Glo::Reflect => Mf::Reflect,
+ Glo::Refract => Mf::Refract,
+ Glo::PackUnorm4x8 => Mf::Pack4x8unorm,
+ Glo::PackSnorm4x8 => Mf::Pack4x8snorm,
+ Glo::PackHalf2x16 => Mf::Pack2x16float,
+ Glo::PackUnorm2x16 => Mf::Pack2x16unorm,
+ Glo::PackSnorm2x16 => Mf::Pack2x16snorm,
+ Glo::UnpackUnorm4x8 => Mf::Unpack4x8unorm,
+ Glo::UnpackSnorm4x8 => Mf::Unpack4x8snorm,
+ Glo::UnpackHalf2x16 => Mf::Unpack2x16float,
+ Glo::UnpackUnorm2x16 => Mf::Unpack2x16unorm,
+ Glo::UnpackSnorm2x16 => Mf::Unpack2x16snorm,
+ Glo::FindILsb => Mf::FindLsb,
+ Glo::FindUMsb | Glo::FindSMsb => Mf::FindMsb,
+ _ => return Err(Error::UnsupportedExtInst(inst_id)),
+ };
+
+ let arg_count = fun.argument_count();
+ inst.expect(base_wc + arg_count as u16)?;
+ let arg = {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ get_expr_handle!(arg_id, lexp)
+ };
+ let arg1 = if arg_count > 1 {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ Some(get_expr_handle!(arg_id, lexp))
+ } else {
+ None
+ };
+ let arg2 = if arg_count > 2 {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ Some(get_expr_handle!(arg_id, lexp))
+ } else {
+ None
+ };
+ let arg3 = if arg_count > 3 {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ Some(get_expr_handle!(arg_id, lexp))
+ } else {
+ None
+ };
+
+ let expr = crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ // Relational and Logical Instructions
+ Op::LogicalNot => {
+ inst.expect(4)?;
+ parse_expr_op!(crate::UnaryOperator::Not, UNARY)?;
+ }
+ Op::LogicalOr => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::LogicalOr, BINARY)?;
+ }
+ Op::LogicalAnd => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::LogicalAnd, BINARY)?;
+ }
+ Op::SGreaterThan | Op::SGreaterThanEqual | Op::SLessThan | Op::SLessThanEqual => {
+ inst.expect(5)?;
+ self.parse_expr_int_comparison(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ map_binary_operator(inst.op)?,
+ crate::ScalarKind::Sint,
+ )?;
+ }
+ Op::UGreaterThan | Op::UGreaterThanEqual | Op::ULessThan | Op::ULessThanEqual => {
+ inst.expect(5)?;
+ self.parse_expr_int_comparison(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ map_binary_operator(inst.op)?,
+ crate::ScalarKind::Uint,
+ )?;
+ }
+ Op::FOrdEqual
+ | Op::FUnordEqual
+ | Op::FOrdNotEqual
+ | Op::FUnordNotEqual
+ | Op::FOrdLessThan
+ | Op::FUnordLessThan
+ | Op::FOrdGreaterThan
+ | Op::FUnordGreaterThan
+ | Op::FOrdLessThanEqual
+ | Op::FUnordLessThanEqual
+ | Op::FOrdGreaterThanEqual
+ | Op::FUnordGreaterThanEqual
+ | Op::LogicalEqual
+ | Op::LogicalNotEqual => {
+ inst.expect(5)?;
+ let operator = map_binary_operator(inst.op)?;
+ parse_expr_op!(operator, BINARY)?;
+ }
+ Op::Any | Op::All | Op::IsNan | Op::IsInf | Op::IsFinite | Op::IsNormal => {
+ inst.expect(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let arg_id = self.next()?;
+
+ let arg_lexp = self.lookup_expression.lookup(arg_id)?;
+ let arg_handle = get_expr_handle!(arg_id, arg_lexp);
+
+ let expr = crate::Expression::Relational {
+ fun: map_relational_fun(inst.op)?,
+ argument: arg_handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Kill => {
+ inst.expect(1)?;
+ break Some(crate::Statement::Kill);
+ }
+ Op::Unreachable => {
+ inst.expect(1)?;
+ break None;
+ }
+ Op::Return => {
+ inst.expect(1)?;
+ break Some(crate::Statement::Return { value: None });
+ }
+ Op::ReturnValue => {
+ inst.expect(2)?;
+ let value_id = self.next()?;
+ let value_lexp = self.lookup_expression.lookup(value_id)?;
+ let value_handle = get_expr_handle!(value_id, value_lexp);
+ break Some(crate::Statement::Return {
+ value: Some(value_handle),
+ });
+ }
+ Op::Branch => {
+ inst.expect(2)?;
+ let target_id = self.next()?;
+
+ // If this is a branch to a merge or continue block,
+ // then that ends the current body.
+ if let Some(info) = ctx.mergers.get(&target_id) {
+ block.extend(emitter.finish(ctx.expressions));
+ ctx.blocks.insert(block_id, block);
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::BlockId(block_id));
+
+ merger(body, info);
+
+ return Ok(());
+ }
+
+ // Since the target of the branch has no merge information,
+ // this must be the only branch to that block. This means
+ // we can treat it as an extension of the current `Body`.
+ //
+ // NOTE: it's possible that another branch was already made to this block
+ // setting the body index in which case it SHOULD NOT be overridden.
+ // For example a switch with falltrough, the OpSwitch will set the body to
+ // the respective case and the case may branch to another case in which case
+ // the body index shouldn't be changed
+ ctx.body_for_label.entry(target_id).or_insert(body_idx);
+
+ break None;
+ }
+ Op::BranchConditional => {
+ inst.expect_at_least(4)?;
+
+ let condition = {
+ let condition_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(condition_id)?;
+ get_expr_handle!(condition_id, lexp)
+ };
+
+ block.extend(emitter.finish(ctx.expressions));
+ ctx.blocks.insert(block_id, block);
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::BlockId(block_id));
+
+ let true_id = self.next()?;
+ let false_id = self.next()?;
+
+ let same_target = true_id == false_id;
+
+ // Start a body block for the `accept` branch.
+ let accept = ctx.bodies.len();
+ let mut accept_block = Body::with_parent(body_idx);
+
+ // If the `OpBranchConditional`target is somebody else's
+ // merge or continue block, then put a `Break` or `Continue`
+ // statement in this new body block.
+ if let Some(info) = ctx.mergers.get(&true_id) {
+ merger(
+ match same_target {
+ true => &mut ctx.bodies[body_idx],
+ false => &mut accept_block,
+ },
+ info,
+ )
+ } else {
+ // Note the body index for the block we're branching to.
+ let prev = ctx.body_for_label.insert(
+ true_id,
+ match same_target {
+ true => body_idx,
+ false => accept,
+ },
+ );
+ debug_assert!(prev.is_none());
+ }
+
+ if same_target {
+ return Ok(());
+ }
+
+ ctx.bodies.push(accept_block);
+
+ // Handle the `reject` branch just like the `accept` block.
+ let reject = ctx.bodies.len();
+ let mut reject_block = Body::with_parent(body_idx);
+
+ if let Some(info) = ctx.mergers.get(&false_id) {
+ merger(&mut reject_block, info)
+ } else {
+ let prev = ctx.body_for_label.insert(false_id, reject);
+ debug_assert!(prev.is_none());
+ }
+
+ ctx.bodies.push(reject_block);
+
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::If {
+ condition,
+ accept,
+ reject,
+ });
+
+ // Consume branch weights
+ for _ in 4..inst.wc {
+ let _ = self.next()?;
+ }
+
+ return Ok(());
+ }
+ Op::Switch => {
+ inst.expect_at_least(3)?;
+ let selector = self.next()?;
+ let default_id = self.next()?;
+
+ // If the previous instruction was a `OpSelectionMerge` then we must
+ // promote the `MergeBlockInformation` to a `SwitchMerge`
+ if let Some(merge) = selection_merge_block {
+ ctx.mergers
+ .insert(merge, MergeBlockInformation::SwitchMerge);
+ }
+
+ let default = ctx.bodies.len();
+ ctx.bodies.push(Body::with_parent(body_idx));
+ ctx.body_for_label.entry(default_id).or_insert(default);
+
+ let selector_lexp = &self.lookup_expression[&selector];
+ let selector_lty = self.lookup_type.lookup(selector_lexp.type_id)?;
+ let selector_handle = get_expr_handle!(selector, selector_lexp);
+ let selector = match ctx.type_arena[selector_lty.handle].inner {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ } => {
+ // IR expects a signed integer, so do a bitcast
+ ctx.expressions.append(
+ crate::Expression::As {
+ kind: crate::ScalarKind::Sint,
+ expr: selector_handle,
+ convert: None,
+ },
+ span,
+ )
+ }
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ } => selector_handle,
+ ref other => unimplemented!("Unexpected selector {:?}", other),
+ };
+
+ // Clear past switch cases to prevent them from entering this one
+ self.switch_cases.clear();
+
+ for _ in 0..(inst.wc - 3) / 2 {
+ let literal = self.next()?;
+ let target = self.next()?;
+
+ let case_body_idx = ctx.bodies.len();
+
+ // Check if any previous case already used this target block id, if so
+ // group them together to reorder them later so that no weird
+ // falltrough cases happen.
+ if let Some(&mut (_, ref mut literals)) = self.switch_cases.get_mut(&target)
+ {
+ literals.push(literal as i32);
+ continue;
+ }
+
+ let mut body = Body::with_parent(body_idx);
+
+ if let Some(info) = ctx.mergers.get(&target) {
+ merger(&mut body, info);
+ }
+
+ ctx.bodies.push(body);
+ ctx.body_for_label.entry(target).or_insert(case_body_idx);
+
+ // Register this target block id as already having been processed and
+ // the respective body index assigned and the first case value
+ self.switch_cases
+ .insert(target, (case_body_idx, vec![literal as i32]));
+ }
+
+ // Loop trough the collected target blocks creating a new case for each
+ // literal pointing to it, only one case will have the true body and all the
+ // others will be empty falltrough so that they all execute the same body
+ // without duplicating code.
+ //
+ // Since `switch_cases` is an indexmap the order of insertion is preserved
+ // this is needed because spir-v defines falltrough order in the switch
+ // instruction.
+ let mut cases = Vec::with_capacity((inst.wc as usize - 3) / 2);
+ for &(case_body_idx, ref literals) in self.switch_cases.values() {
+ let value = literals[0];
+
+ for &literal in literals.iter().skip(1) {
+ let empty_body_idx = ctx.bodies.len();
+ let body = Body::with_parent(body_idx);
+
+ ctx.bodies.push(body);
+
+ cases.push((literal, empty_body_idx));
+ }
+
+ cases.push((value, case_body_idx));
+ }
+
+ block.extend(emitter.finish(ctx.expressions));
+
+ let body = &mut ctx.bodies[body_idx];
+ ctx.blocks.insert(block_id, block);
+ // Make sure the vector has space for at least two more allocations
+ body.data.reserve(2);
+ body.data.push(BodyFragment::BlockId(block_id));
+ body.data.push(BodyFragment::Switch {
+ selector,
+ cases,
+ default,
+ });
+
+ return Ok(());
+ }
+ Op::SelectionMerge => {
+ inst.expect(3)?;
+ let merge_block_id = self.next()?;
+ // TODO: Selection Control Mask
+ let _selection_control = self.next()?;
+
+ // Indicate that the merge block is a continuation of the
+ // current `Body`.
+ ctx.body_for_label.entry(merge_block_id).or_insert(body_idx);
+
+ // Let subsequent branches to the merge block know that
+ // they've reached the end of the selection construct.
+ ctx.mergers
+ .insert(merge_block_id, MergeBlockInformation::SelectionMerge);
+
+ selection_merge_block = Some(merge_block_id);
+ }
+ Op::LoopMerge => {
+ inst.expect_at_least(4)?;
+ let merge_block_id = self.next()?;
+ let continuing = self.next()?;
+
+ // TODO: Loop Control Parameters
+ for _ in 0..inst.wc - 3 {
+ self.next()?;
+ }
+
+ // Indicate that the merge block is a continuation of the
+ // current `Body`.
+ ctx.body_for_label.entry(merge_block_id).or_insert(body_idx);
+ // Let subsequent branches to the merge block know that
+ // they're `Break` statements.
+ ctx.mergers
+ .insert(merge_block_id, MergeBlockInformation::LoopMerge);
+
+ let loop_body_idx = ctx.bodies.len();
+ ctx.bodies.push(Body::with_parent(body_idx));
+
+ let continue_idx = ctx.bodies.len();
+ // The continue block inherits the scope of the loop body
+ ctx.bodies.push(Body::with_parent(loop_body_idx));
+ ctx.body_for_label.entry(continuing).or_insert(continue_idx);
+ // Let subsequent branches to the continue block know that
+ // they're `Continue` statements.
+ ctx.mergers
+ .insert(continuing, MergeBlockInformation::LoopContinue);
+
+ // The loop header always belongs to the loop body
+ ctx.body_for_label.insert(block_id, loop_body_idx);
+
+ let parent_body = &mut ctx.bodies[body_idx];
+ parent_body.data.push(BodyFragment::Loop {
+ body: loop_body_idx,
+ continuing: continue_idx,
+ });
+ body_idx = loop_body_idx;
+ }
+ Op::DPdxCoarse => {
+ parse_expr_op!(
+ crate::DerivativeAxis::X,
+ crate::DerivativeControl::Coarse,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdyCoarse => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Y,
+ crate::DerivativeControl::Coarse,
+ DERIVATIVE
+ )?;
+ }
+ Op::FwidthCoarse => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Width,
+ crate::DerivativeControl::Coarse,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdxFine => {
+ parse_expr_op!(
+ crate::DerivativeAxis::X,
+ crate::DerivativeControl::Fine,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdyFine => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Y,
+ crate::DerivativeControl::Fine,
+ DERIVATIVE
+ )?;
+ }
+ Op::FwidthFine => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Width,
+ crate::DerivativeControl::Fine,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdx => {
+ parse_expr_op!(
+ crate::DerivativeAxis::X,
+ crate::DerivativeControl::None,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdy => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Y,
+ crate::DerivativeControl::None,
+ DERIVATIVE
+ )?;
+ }
+ Op::Fwidth => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Width,
+ crate::DerivativeControl::None,
+ DERIVATIVE
+ )?;
+ }
+ Op::ArrayLength => {
+ inst.expect(5)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let structure_id = self.next()?;
+ let member_index = self.next()?;
+
+ // We're assuming that the validation pass, if it's run, will catch if the
+ // wrong types or parameters are supplied here.
+
+ let structure_ptr = self.lookup_expression.lookup(structure_id)?;
+ let structure_handle = get_expr_handle!(structure_id, structure_ptr);
+
+ let member_ptr = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: structure_handle,
+ index: member_index,
+ },
+ span,
+ );
+
+ let length = ctx
+ .expressions
+ .append(crate::Expression::ArrayLength(member_ptr), span);
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: length,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CopyMemory => {
+ inst.expect_at_least(3)?;
+ let target_id = self.next()?;
+ let source_id = self.next()?;
+ let _memory_access = if inst.wc != 3 {
+ inst.expect(4)?;
+ spirv::MemoryAccess::from_bits(self.next()?)
+ .ok_or(Error::InvalidParameter(Op::CopyMemory))?
+ } else {
+ spirv::MemoryAccess::NONE
+ };
+
+ // TODO: check if the source and target types are the same?
+ let target = self.lookup_expression.lookup(target_id)?;
+ let target_handle = get_expr_handle!(target_id, target);
+ let source = self.lookup_expression.lookup(source_id)?;
+ let source_handle = get_expr_handle!(source_id, source);
+
+ // This operation is practically the same as loading and then storing, I think.
+ let value_expr = ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: source_handle,
+ },
+ span,
+ );
+
+ block.extend(emitter.finish(ctx.expressions));
+ block.push(
+ crate::Statement::Store {
+ pointer: target_handle,
+ value: value_expr,
+ },
+ span,
+ );
+
+ emitter.start(ctx.expressions);
+ }
+ Op::ControlBarrier => {
+ inst.expect(4)?;
+ let exec_scope_id = self.next()?;
+ let _mem_scope_raw = self.next()?;
+ let semantics_id = self.next()?;
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let semantics_const = self.lookup_constant.lookup(semantics_id)?;
+ let exec_scope = match ctx.const_arena[exec_scope_const.handle].inner {
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(raw),
+ width: _,
+ } => raw as u32,
+ _ => return Err(Error::InvalidBarrierScope(exec_scope_id)),
+ };
+ let semantics = match ctx.const_arena[semantics_const.handle].inner {
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(raw),
+ width: _,
+ } => raw as u32,
+ _ => return Err(Error::InvalidBarrierMemorySemantics(semantics_id)),
+ };
+ if exec_scope == spirv::Scope::Workgroup as u32 {
+ let mut flags = crate::Barrier::empty();
+ flags.set(
+ crate::Barrier::STORAGE,
+ semantics & spirv::MemorySemantics::UNIFORM_MEMORY.bits() != 0,
+ );
+ flags.set(
+ crate::Barrier::WORK_GROUP,
+ semantics
+ & (spirv::MemorySemantics::SUBGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY)
+ .bits()
+ != 0,
+ );
+ block.push(crate::Statement::Barrier(flags), span);
+ } else {
+ log::warn!("Unsupported barrier execution scope: {}", exec_scope);
+ }
+ }
+ Op::CopyObject => {
+ inst.expect(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let operand_id = self.next()?;
+
+ let lookup = self.lookup_expression.lookup(operand_id)?;
+ let handle = get_expr_handle!(operand_id, lookup);
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
+ }
+ };
+
+ block.extend(emitter.finish(ctx.expressions));
+ if let Some(stmt) = terminator {
+ block.push(stmt, crate::Span::default());
+ }
+
+ // Save this block fragment in `block_ctx.blocks`, and mark it to be
+ // incorporated into the current body at `Statement` assembly time.
+ ctx.blocks.insert(block_id, block);
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::BlockId(block_id));
+ Ok(())
+ }
+
+ fn make_expression_storage(
+ &mut self,
+ globals: &Arena<crate::GlobalVariable>,
+ constants: &Arena<crate::Constant>,
+ ) -> Arena<crate::Expression> {
+ let mut expressions = Arena::new();
+ #[allow(clippy::panic)]
+ {
+ assert!(self.lookup_expression.is_empty());
+ }
+ // register global variables
+ for (&id, var) in self.lookup_variable.iter() {
+ let span = globals.get_span(var.handle);
+ let handle = expressions.append(crate::Expression::GlobalVariable(var.handle), span);
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ type_id: var.type_id,
+ handle,
+ // Setting this to an invalid id will cause get_expr_handle
+ // to default to the main body making sure no load/stores
+ // are added.
+ block_id: 0,
+ },
+ );
+ }
+ // register special constants
+ self.index_constant_expressions.clear();
+ for &con_handle in self.index_constants.iter() {
+ let span = constants.get_span(con_handle);
+ let handle = expressions.append(crate::Expression::Constant(con_handle), span);
+ self.index_constant_expressions.push(handle);
+ }
+ // register constants
+ for (&id, con) in self.lookup_constant.iter() {
+ let span = constants.get_span(con.handle);
+ let handle = expressions.append(crate::Expression::Constant(con.handle), span);
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ type_id: con.type_id,
+ handle,
+ // Setting this to an invalid id will cause get_expr_handle
+ // to default to the main body making sure no load/stores
+ // are added.
+ block_id: 0,
+ },
+ );
+ }
+ // done
+ expressions
+ }
+
+ fn switch(&mut self, state: ModuleState, op: spirv::Op) -> Result<(), Error> {
+ if state < self.state {
+ Err(Error::UnsupportedInstruction(self.state, op))
+ } else {
+ self.state = state;
+ Ok(())
+ }
+ }
+
+ /// Walk the statement tree and patch it in the following cases:
+ /// 1. Function call targets are replaced by `deferred_function_calls` map
+ fn patch_statements(
+ &mut self,
+ statements: &mut crate::Block,
+ expressions: &mut Arena<crate::Expression>,
+ fun_parameter_sampling: &mut [image::SamplingFlags],
+ ) -> Result<(), Error> {
+ use crate::Statement as S;
+ let mut i = 0usize;
+ while i < statements.len() {
+ match statements[i] {
+ S::Emit(_) => {}
+ S::Block(ref mut block) => {
+ self.patch_statements(block, expressions, fun_parameter_sampling)?;
+ }
+ S::If {
+ condition: _,
+ ref mut accept,
+ ref mut reject,
+ } => {
+ self.patch_statements(reject, expressions, fun_parameter_sampling)?;
+ self.patch_statements(accept, expressions, fun_parameter_sampling)?;
+ }
+ S::Switch {
+ selector: _,
+ ref mut cases,
+ } => {
+ for case in cases.iter_mut() {
+ self.patch_statements(&mut case.body, expressions, fun_parameter_sampling)?;
+ }
+ }
+ S::Loop {
+ ref mut body,
+ ref mut continuing,
+ break_if: _,
+ } => {
+ self.patch_statements(body, expressions, fun_parameter_sampling)?;
+ self.patch_statements(continuing, expressions, fun_parameter_sampling)?;
+ }
+ S::Break
+ | S::Continue
+ | S::Return { .. }
+ | S::Kill
+ | S::Barrier(_)
+ | S::Store { .. }
+ | S::ImageStore { .. }
+ | S::Atomic { .. }
+ | S::RayQuery { .. } => {}
+ S::Call {
+ function: ref mut callee,
+ ref arguments,
+ ..
+ } => {
+ let fun_id = self.deferred_function_calls[callee.index()];
+ let fun_lookup = self.lookup_function.lookup(fun_id)?;
+ *callee = fun_lookup.handle;
+
+ // Patch sampling flags
+ for (arg_index, arg) in arguments.iter().enumerate() {
+ let flags = match fun_lookup.parameters_sampling.get(arg_index) {
+ Some(&flags) if !flags.is_empty() => flags,
+ _ => continue,
+ };
+
+ match expressions[*arg] {
+ crate::Expression::GlobalVariable(handle) => {
+ if let Some(sampling) = self.handle_sampling.get_mut(&handle) {
+ *sampling |= flags
+ }
+ }
+ crate::Expression::FunctionArgument(i) => {
+ fun_parameter_sampling[i as usize] |= flags;
+ }
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ }
+ }
+ }
+ }
+ i += 1;
+ }
+ Ok(())
+ }
+
+ fn patch_function(
+ &mut self,
+ handle: Option<Handle<crate::Function>>,
+ fun: &mut crate::Function,
+ ) -> Result<(), Error> {
+ // Note: this search is a bit unfortunate
+ let (fun_id, mut parameters_sampling) = match handle {
+ Some(h) => {
+ let (&fun_id, lookup) = self
+ .lookup_function
+ .iter_mut()
+ .find(|&(_, ref lookup)| lookup.handle == h)
+ .unwrap();
+ (fun_id, mem::take(&mut lookup.parameters_sampling))
+ }
+ None => (0, Vec::new()),
+ };
+
+ for (_, expr) in fun.expressions.iter_mut() {
+ if let crate::Expression::CallResult(ref mut function) = *expr {
+ let fun_id = self.deferred_function_calls[function.index()];
+ *function = self.lookup_function.lookup(fun_id)?.handle;
+ }
+ }
+
+ self.patch_statements(
+ &mut fun.body,
+ &mut fun.expressions,
+ &mut parameters_sampling,
+ )?;
+
+ if let Some(lookup) = self.lookup_function.get_mut(&fun_id) {
+ lookup.parameters_sampling = parameters_sampling;
+ }
+ Ok(())
+ }
+
+ pub fn parse(mut self) -> Result<crate::Module, Error> {
+ let mut module = {
+ if self.next()? != spirv::MAGIC_NUMBER {
+ return Err(Error::InvalidHeader);
+ }
+ let version_raw = self.next()?;
+ let generator = self.next()?;
+ let _bound = self.next()?;
+ let _schema = self.next()?;
+ log::info!("Generated by {} version {:x}", generator, version_raw);
+ crate::Module::default()
+ };
+
+ // register indexing constants
+ self.index_constants.clear();
+ for i in 0..4 {
+ let handle = module.constants.append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Sint(i),
+ },
+ },
+ Default::default(),
+ );
+ self.index_constants.push(handle);
+ }
+
+ self.layouter.clear();
+ self.dummy_functions = Arena::new();
+ self.lookup_function.clear();
+ self.function_call_graph.clear();
+
+ loop {
+ use spirv::Op;
+
+ let inst = match self.next_inst() {
+ Ok(inst) => inst,
+ Err(Error::IncompleteData) => break,
+ Err(other) => return Err(other),
+ };
+ log::debug!("\t{:?} [{}]", inst.op, inst.wc);
+
+ match inst.op {
+ Op::Capability => self.parse_capability(inst),
+ Op::Extension => self.parse_extension(inst),
+ Op::ExtInstImport => self.parse_ext_inst_import(inst),
+ Op::MemoryModel => self.parse_memory_model(inst),
+ Op::EntryPoint => self.parse_entry_point(inst),
+ Op::ExecutionMode => self.parse_execution_mode(inst),
+ Op::String => self.parse_string(inst),
+ Op::Source => self.parse_source(inst),
+ Op::SourceExtension => self.parse_source_extension(inst),
+ Op::Name => self.parse_name(inst),
+ Op::MemberName => self.parse_member_name(inst),
+ Op::ModuleProcessed => self.parse_module_processed(inst),
+ Op::Decorate => self.parse_decorate(inst),
+ Op::MemberDecorate => self.parse_member_decorate(inst),
+ Op::TypeVoid => self.parse_type_void(inst),
+ Op::TypeBool => self.parse_type_bool(inst, &mut module),
+ Op::TypeInt => self.parse_type_int(inst, &mut module),
+ Op::TypeFloat => self.parse_type_float(inst, &mut module),
+ Op::TypeVector => self.parse_type_vector(inst, &mut module),
+ Op::TypeMatrix => self.parse_type_matrix(inst, &mut module),
+ Op::TypeFunction => self.parse_type_function(inst),
+ Op::TypePointer => self.parse_type_pointer(inst, &mut module),
+ Op::TypeArray => self.parse_type_array(inst, &mut module),
+ Op::TypeRuntimeArray => self.parse_type_runtime_array(inst, &mut module),
+ Op::TypeStruct => self.parse_type_struct(inst, &mut module),
+ Op::TypeImage => self.parse_type_image(inst, &mut module),
+ Op::TypeSampledImage => self.parse_type_sampled_image(inst),
+ Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
+ Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
+ Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
+ Op::ConstantNull | Op::Undef => self
+ .parse_null_constant(inst, &module.types, &mut module.constants)
+ .map(|_| ()),
+ Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),
+ Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module),
+ Op::Variable => self.parse_global_variable(inst, &mut module),
+ Op::Function => {
+ self.switch(ModuleState::Function, inst.op)?;
+ inst.expect(5)?;
+ self.parse_function(&mut module)
+ }
+ _ => Err(Error::UnsupportedInstruction(self.state, inst.op)), //TODO
+ }?;
+ }
+
+ log::info!("Patching...");
+ {
+ let mut nodes = petgraph::algo::toposort(&self.function_call_graph, None)
+ .map_err(|cycle| Error::FunctionCallCycle(cycle.node_id()))?;
+ nodes.reverse(); // we need dominated first
+ let mut functions = mem::take(&mut module.functions);
+ for fun_id in nodes {
+ if fun_id > !(functions.len() as u32) {
+ // skip all the fake IDs registered for the entry points
+ continue;
+ }
+ let lookup = self.lookup_function.get_mut(&fun_id).unwrap();
+ // take out the function from the old array
+ let fun = mem::take(&mut functions[lookup.handle]);
+ // add it to the newly formed arena, and adjust the lookup
+ lookup.handle = module
+ .functions
+ .append(fun, functions.get_span(lookup.handle));
+ }
+ }
+ // patch all the functions
+ for (handle, fun) in module.functions.iter_mut() {
+ self.patch_function(Some(handle), fun)?;
+ }
+ for ep in module.entry_points.iter_mut() {
+ self.patch_function(None, &mut ep.function)?;
+ }
+
+ // Check all the images and samplers to have consistent comparison property.
+ for (handle, flags) in self.handle_sampling.drain() {
+ if !image::patch_comparison_type(
+ flags,
+ module.global_variables.get_mut(handle),
+ &mut module.types,
+ ) {
+ return Err(Error::InconsistentComparisonSampling(handle));
+ }
+ }
+
+ if !self.future_decor.is_empty() {
+ log::warn!("Unused item decorations: {:?}", self.future_decor);
+ self.future_decor.clear();
+ }
+ if !self.future_member_decor.is_empty() {
+ log::warn!("Unused member decorations: {:?}", self.future_member_decor);
+ self.future_member_decor.clear();
+ }
+
+ Ok(module)
+ }
+
+ fn parse_capability(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Capability, inst.op)?;
+ inst.expect(2)?;
+ let capability = self.next()?;
+ let cap =
+ spirv::Capability::from_u32(capability).ok_or(Error::UnknownCapability(capability))?;
+ if !SUPPORTED_CAPABILITIES.contains(&cap) {
+ if self.options.strict_capabilities {
+ return Err(Error::UnsupportedCapability(cap));
+ } else {
+ log::warn!("Unknown capability {:?}", cap);
+ }
+ }
+ Ok(())
+ }
+
+ fn parse_extension(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Extension, inst.op)?;
+ inst.expect_at_least(2)?;
+ let (name, left) = self.next_string(inst.wc - 1)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ if !SUPPORTED_EXTENSIONS.contains(&name.as_str()) {
+ return Err(Error::UnsupportedExtension(name));
+ }
+ Ok(())
+ }
+
+ fn parse_ext_inst_import(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Extension, inst.op)?;
+ inst.expect_at_least(3)?;
+ let result_id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 2)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ if !SUPPORTED_EXT_SETS.contains(&name.as_str()) {
+ return Err(Error::UnsupportedExtSet(name));
+ }
+ self.ext_glsl_id = Some(result_id);
+ Ok(())
+ }
+
+ fn parse_memory_model(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::MemoryModel, inst.op)?;
+ inst.expect(3)?;
+ let _addressing_model = self.next()?;
+ let _memory_model = self.next()?;
+ Ok(())
+ }
+
+ fn parse_entry_point(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::EntryPoint, inst.op)?;
+ inst.expect_at_least(4)?;
+ let exec_model = self.next()?;
+ let exec_model = spirv::ExecutionModel::from_u32(exec_model)
+ .ok_or(Error::UnsupportedExecutionModel(exec_model))?;
+ let function_id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 3)?;
+ let ep = EntryPoint {
+ stage: match exec_model {
+ spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex,
+ spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment,
+ spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute,
+ _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)),
+ },
+ name,
+ early_depth_test: None,
+ workgroup_size: [0; 3],
+ variable_ids: self.data.by_ref().take(left as usize).collect(),
+ };
+ self.lookup_entry_point.insert(function_id, ep);
+ Ok(())
+ }
+
+ fn parse_execution_mode(&mut self, inst: Instruction) -> Result<(), Error> {
+ use spirv::ExecutionMode;
+
+ self.switch(ModuleState::ExecutionMode, inst.op)?;
+ inst.expect_at_least(3)?;
+
+ let ep_id = self.next()?;
+ let mode_id = self.next()?;
+ let args: Vec<spirv::Word> = self.data.by_ref().take(inst.wc as usize - 3).collect();
+
+ let ep = self
+ .lookup_entry_point
+ .get_mut(&ep_id)
+ .ok_or(Error::InvalidId(ep_id))?;
+ let mode = spirv::ExecutionMode::from_u32(mode_id)
+ .ok_or(Error::UnsupportedExecutionMode(mode_id))?;
+
+ match mode {
+ ExecutionMode::EarlyFragmentTests => {
+ if ep.early_depth_test.is_none() {
+ ep.early_depth_test = Some(crate::EarlyDepthTest { conservative: None });
+ }
+ }
+ ExecutionMode::DepthUnchanged => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::Unchanged),
+ });
+ }
+ ExecutionMode::DepthGreater => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::GreaterEqual),
+ });
+ }
+ ExecutionMode::DepthLess => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::LessEqual),
+ });
+ }
+ ExecutionMode::DepthReplacing => {
+ // Ignored because it can be deduced from the IR.
+ }
+ ExecutionMode::OriginUpperLeft => {
+ // Ignored because the other option (OriginLowerLeft) is not valid in Vulkan mode.
+ }
+ ExecutionMode::LocalSize => {
+ ep.workgroup_size = [args[0], args[1], args[2]];
+ }
+ _ => {
+ return Err(Error::UnsupportedExecutionMode(mode_id));
+ }
+ }
+
+ Ok(())
+ }
+
+ fn parse_string(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Source, inst.op)?;
+ inst.expect_at_least(3)?;
+ let _id = self.next()?;
+ let (_name, _) = self.next_string(inst.wc - 2)?;
+ Ok(())
+ }
+
+ fn parse_source(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Source, inst.op)?;
+ for _ in 1..inst.wc {
+ let _ = self.next()?;
+ }
+ Ok(())
+ }
+
+ fn parse_source_extension(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Source, inst.op)?;
+ inst.expect_at_least(2)?;
+ let (_name, _) = self.next_string(inst.wc - 1)?;
+ Ok(())
+ }
+
+ fn parse_name(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Name, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 2)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ self.future_decor.entry(id).or_default().name = Some(name);
+ Ok(())
+ }
+
+ fn parse_member_name(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Name, inst.op)?;
+ inst.expect_at_least(4)?;
+ let id = self.next()?;
+ let member = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 3)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+
+ self.future_member_decor
+ .entry((id, member))
+ .or_default()
+ .name = Some(name);
+ Ok(())
+ }
+
+ fn parse_module_processed(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Name, inst.op)?;
+ inst.expect_at_least(2)?;
+ let (_info, left) = self.next_string(inst.wc - 1)?;
+ //Note: string is ignored
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ Ok(())
+ }
+
+ fn parse_decorate(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Annotation, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let mut dec = self.future_decor.remove(&id).unwrap_or_default();
+ self.next_decoration(inst, 2, &mut dec)?;
+ self.future_decor.insert(id, dec);
+ Ok(())
+ }
+
+ fn parse_member_decorate(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Annotation, inst.op)?;
+ inst.expect_at_least(4)?;
+ let id = self.next()?;
+ let member = self.next()?;
+
+ let mut dec = self
+ .future_member_decor
+ .remove(&(id, member))
+ .unwrap_or_default();
+ self.next_decoration(inst, 3, &mut dec)?;
+ self.future_member_decor.insert((id, member), dec);
+ Ok(())
+ }
+
+ fn parse_type_void(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ self.lookup_void_type = Some(id);
+ Ok(())
+ }
+
+ fn parse_type_bool(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ let inner = crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_int(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let width = self.next()?;
+ let sign = self.next()?;
+ let inner = crate::TypeInner::Scalar {
+ kind: match sign {
+ 0 => crate::ScalarKind::Uint,
+ 1 => crate::ScalarKind::Sint,
+ _ => return Err(Error::InvalidSign(sign)),
+ },
+ width: map_width(width)?,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_float(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let width = self.next()?;
+ let inner = crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width: map_width(width)?,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_vector(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let (kind, width) = match module.types[type_lookup.handle].inner {
+ crate::TypeInner::Scalar { kind, width } => (kind, width),
+ _ => return Err(Error::InvalidInnerType(type_id)),
+ };
+ let component_count = self.next()?;
+ let inner = crate::TypeInner::Vector {
+ size: map_vector_size(component_count)?,
+ kind,
+ width,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_matrix(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let vector_type_id = self.next()?;
+ let num_columns = self.next()?;
+ let decor = self.future_decor.remove(&id);
+
+ let vector_type_lookup = self.lookup_type.lookup(vector_type_id)?;
+ let inner = match module.types[vector_type_lookup.handle].inner {
+ crate::TypeInner::Vector { size, width, .. } => crate::TypeInner::Matrix {
+ columns: map_vector_size(num_columns)?,
+ rows: size,
+ width,
+ },
+ _ => return Err(Error::InvalidInnerType(vector_type_id)),
+ };
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(vector_type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_function(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let return_type_id = self.next()?;
+ let parameter_type_ids = self.data.by_ref().take(inst.wc as usize - 3).collect();
+ self.lookup_function_type.insert(
+ id,
+ LookupFunctionType {
+ parameter_type_ids,
+ return_type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_pointer(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let storage_class = self.next()?;
+ let type_id = self.next()?;
+
+ let decor = self.future_decor.remove(&id);
+ let base_lookup_ty = self.lookup_type.lookup(type_id)?;
+ let base_inner = &module.types[base_lookup_ty.handle].inner;
+
+ let space = if let Some(space) = base_inner.pointer_space() {
+ space
+ } else if self
+ .lookup_storage_buffer_types
+ .contains_key(&base_lookup_ty.handle)
+ {
+ crate::AddressSpace::Storage {
+ access: crate::StorageAccess::default(),
+ }
+ } else {
+ match map_storage_class(storage_class)? {
+ ExtendedClass::Global(space) => space,
+ ExtendedClass::Input | ExtendedClass::Output => crate::AddressSpace::Private,
+ }
+ };
+
+ // We don't support pointers to runtime-sized arrays in the `Uniform`
+ // storage class with the `BufferBlock` decoration. Runtime-sized arrays
+ // should be in the StorageBuffer class.
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = *base_inner
+ {
+ match space {
+ crate::AddressSpace::Storage { .. } => {}
+ _ => {
+ return Err(Error::UnsupportedRuntimeArrayStorageClass);
+ }
+ }
+ }
+
+ // Don't bother with pointer stuff for `Handle` types.
+ let lookup_ty = if space == crate::AddressSpace::Handle {
+ base_lookup_ty.clone()
+ } else {
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.and_then(|dec| dec.name),
+ inner: crate::TypeInner::Pointer {
+ base: base_lookup_ty.handle,
+ space,
+ },
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ }
+ };
+ self.lookup_type.insert(id, lookup_ty);
+ Ok(())
+ }
+
+ fn parse_type_array(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+ let length_id = self.next()?;
+ let length_const = self.lookup_constant.lookup(length_id)?;
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ let base = self.lookup_type.lookup(type_id)?.handle;
+
+ self.layouter
+ .update(&module.types, &module.constants)
+ .unwrap();
+
+ // HACK if the underlying type is an image or a sampler, let's assume
+ // that we're dealing with a binding-array
+ //
+ // Note that it's not a strictly correct assumption, but rather a trade
+ // off caused by an impedance mismatch between SPIR-V's and Naga's type
+ // systems - Naga distinguishes between arrays and binding-arrays via
+ // types (i.e. both kinds of arrays are just different types), while
+ // SPIR-V distinguishes between them through usage - e.g. given:
+ //
+ // ```
+ // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f
+ // %uint_256 = OpConstant %uint 256
+ // %image_array = OpTypeArray %image %uint_256
+ // ```
+ //
+ // ```
+ // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f
+ // %uint_256 = OpConstant %uint 256
+ // %image_array = OpTypeArray %image %uint_256
+ // %image_array_ptr = OpTypePointer UniformConstant %image_array
+ // ```
+ //
+ // ... in the first case, `%image_array` should technically correspond
+ // to `TypeInner::Array`, while in the second case it should say
+ // `TypeInner::BindingArray` (kinda, depending on whether `%image_array`
+ // is ever used as a freestanding type or rather always through the
+ // pointer-indirection).
+ //
+ // Anyway, at the moment we don't support other kinds of image / sampler
+ // arrays than those binding-based, so this assumption is pretty safe
+ // for now.
+ let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } =
+ module.types[base].inner
+ {
+ crate::TypeInner::BindingArray {
+ base,
+ size: crate::ArraySize::Constant(length_const.handle),
+ }
+ } else {
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(length_const.handle),
+ stride: match decor.array_stride {
+ Some(stride) => stride.get(),
+ None => self.layouter[base].to_stride(),
+ },
+ }
+ };
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_runtime_array(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ let base = self.lookup_type.lookup(type_id)?.handle;
+
+ self.layouter
+ .update(&module.types, &module.constants)
+ .unwrap();
+
+ // HACK same case as in `parse_type_array()`
+ let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } =
+ module.types[base].inner
+ {
+ crate::TypeInner::BindingArray {
+ base: self.lookup_type.lookup(type_id)?.handle,
+ size: crate::ArraySize::Dynamic,
+ }
+ } else {
+ crate::TypeInner::Array {
+ base: self.lookup_type.lookup(type_id)?.handle,
+ size: crate::ArraySize::Dynamic,
+ stride: match decor.array_stride {
+ Some(stride) => stride.get(),
+ None => self.layouter[base].to_stride(),
+ },
+ }
+ };
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_struct(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(2)?;
+ let id = self.next()?;
+ let parent_decor = self.future_decor.remove(&id);
+ let is_storage_buffer = parent_decor
+ .as_ref()
+ .map_or(false, |decor| decor.storage_buffer);
+
+ self.layouter
+ .update(&module.types, &module.constants)
+ .unwrap();
+
+ let mut members = Vec::<crate::StructMember>::with_capacity(inst.wc as usize - 2);
+ let mut member_lookups = Vec::with_capacity(members.capacity());
+ let mut storage_access = crate::StorageAccess::empty();
+ let mut span = 0;
+ let mut alignment = Alignment::ONE;
+ for i in 0..u32::from(inst.wc) - 2 {
+ let type_id = self.next()?;
+ let ty = self.lookup_type.lookup(type_id)?.handle;
+ let decor = self
+ .future_member_decor
+ .remove(&(id, i))
+ .unwrap_or_default();
+
+ storage_access |= decor.flags.to_storage_access();
+
+ member_lookups.push(LookupMember {
+ type_id,
+ row_major: decor.matrix_major == Some(Majority::Row),
+ });
+
+ let member_alignment = self.layouter[ty].alignment;
+ span = member_alignment.round_up(span);
+ alignment = member_alignment.max(alignment);
+
+ let binding = decor.io_binding().ok();
+ if let Some(offset) = decor.offset {
+ span = offset;
+ }
+ let offset = span;
+
+ span += self.layouter[ty].size;
+
+ let inner = &module.types[ty].inner;
+ if let crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } = *inner
+ {
+ if let Some(stride) = decor.matrix_stride {
+ let expected_stride = Alignment::from(rows) * width as u32;
+ if stride.get() != expected_stride {
+ return Err(Error::UnsupportedMatrixStride {
+ stride: stride.get(),
+ columns: columns as u8,
+ rows: rows as u8,
+ width,
+ });
+ }
+ }
+ }
+
+ members.push(crate::StructMember {
+ name: decor.name,
+ ty,
+ binding,
+ offset,
+ });
+ }
+
+ span = alignment.round_up(span);
+
+ let inner = crate::TypeInner::Struct { span, members };
+
+ let ty_handle = module.types.insert(
+ crate::Type {
+ name: parent_decor.and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ );
+
+ if is_storage_buffer {
+ self.lookup_storage_buffer_types
+ .insert(ty_handle, storage_access);
+ }
+ for (i, member_lookup) in member_lookups.into_iter().enumerate() {
+ self.lookup_member
+ .insert((ty_handle, i as u32), member_lookup);
+ }
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: ty_handle,
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_image(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(9)?;
+
+ let id = self.next()?;
+ let sample_type_id = self.next()?;
+ let dim = self.next()?;
+ let _is_depth = self.next()?;
+ let is_array = self.next()? != 0;
+ let is_msaa = self.next()? != 0;
+ let _is_sampled = self.next()?;
+ let format = self.next()?;
+
+ let dim = map_image_dim(dim)?;
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ // ensure there is a type for texture coordinate without extra components
+ module.types.insert(
+ crate::Type {
+ name: None,
+ inner: {
+ let kind = crate::ScalarKind::Float;
+ let width = 4;
+ match dim.required_coordinate_size() {
+ None => crate::TypeInner::Scalar { kind, width },
+ Some(size) => crate::TypeInner::Vector { size, kind, width },
+ }
+ },
+ },
+ Default::default(),
+ );
+
+ let base_handle = self.lookup_type.lookup(sample_type_id)?.handle;
+ let kind = module.types[base_handle]
+ .inner
+ .scalar_kind()
+ .ok_or(Error::InvalidImageBaseType(base_handle))?;
+
+ let inner = crate::TypeInner::Image {
+ class: if format != 0 {
+ crate::ImageClass::Storage {
+ format: map_image_format(format)?,
+ access: crate::StorageAccess::default(),
+ }
+ } else {
+ crate::ImageClass::Sampled {
+ kind,
+ multi: is_msaa,
+ }
+ },
+ dim,
+ arrayed: is_array,
+ };
+
+ let handle = module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner,
+ },
+ self.span_from_with_op(start),
+ );
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle,
+ base_id: Some(sample_type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_sampled_image(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let image_id = self.next()?;
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: self.lookup_type.lookup(image_id)?.handle,
+ base_id: Some(image_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_sampler(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ let handle = module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner: crate::TypeInner::Sampler { comparison: false },
+ },
+ self.span_from_with_op(start),
+ );
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle,
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_constant(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(4)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+
+ let inner = match module.types[ty].inner {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width,
+ } => {
+ let low = self.next()?;
+ let high = if width > 4 {
+ inst.expect(5)?;
+ self.next()?
+ } else {
+ 0
+ };
+ crate::ConstantInner::Scalar {
+ width,
+ value: crate::ScalarValue::Uint((u64::from(high) << 32) | u64::from(low)),
+ }
+ }
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width,
+ } => {
+ let low = self.next()?;
+ let high = if width > 4 {
+ inst.expect(5)?;
+ self.next()?
+ } else {
+ 0
+ };
+ crate::ConstantInner::Scalar {
+ width,
+ value: crate::ScalarValue::Sint(
+ (i64::from(high as i32) << 32) | ((i64::from(low as i32) << 32) >> 32),
+ ),
+ }
+ }
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width,
+ } => {
+ let low = self.next()?;
+ let extended = match width {
+ 4 => f64::from(f32::from_bits(low)),
+ 8 => {
+ inst.expect(5)?;
+ let high = self.next()?;
+ f64::from_bits((u64::from(high) << 32) | u64::from(low))
+ }
+ _ => return Err(Error::InvalidTypeWidth(width as u32)),
+ };
+ crate::ConstantInner::Scalar {
+ width,
+ value: crate::ScalarValue::Float(extended),
+ }
+ }
+ _ => return Err(Error::UnsupportedType(type_lookup.handle)),
+ };
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ self.lookup_constant.insert(
+ id,
+ LookupConstant {
+ handle: module.constants.append(
+ crate::Constant {
+ specialization: decor.specialization,
+ name: decor.name,
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_composite_constant(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(3)?;
+ let type_id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+ let id = self.next()?;
+
+ let mut components = Vec::with_capacity(inst.wc as usize - 3);
+ for _ in 0..components.capacity() {
+ let component_id = self.next()?;
+ let constant = self.lookup_constant.lookup(component_id)?;
+ components.push(constant.handle);
+ }
+
+ self.lookup_constant.insert(
+ id,
+ LookupConstant {
+ handle: module.constants.append(
+ crate::Constant {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ specialization: None,
+ inner: crate::ConstantInner::Composite { ty, components },
+ },
+ self.span_from_with_op(start),
+ ),
+ type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_null_constant(
+ &mut self,
+ inst: Instruction,
+ types: &UniqueArena<crate::Type>,
+ constants: &mut Arena<crate::Constant>,
+ ) -> Result<(u32, u32, Handle<crate::Constant>), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let span = self.span_from_with_op(start);
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+
+ let inner = null::generate_null_constant(ty, types, constants, span)?;
+ let handle = constants.append(
+ crate::Constant {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ specialization: None, //TODO
+ inner,
+ },
+ span,
+ );
+ self.lookup_constant
+ .insert(id, LookupConstant { handle, type_id });
+ Ok((type_id, id, handle))
+ }
+
+ fn parse_bool_constant(
+ &mut self,
+ inst: Instruction,
+ value: bool,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+
+ self.lookup_constant.insert(
+ id,
+ LookupConstant {
+ handle: module.constants.append(
+ crate::Constant {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ specialization: None, //TODO
+ inner: crate::ConstantInner::boolean(value),
+ },
+ self.span_from_with_op(start),
+ ),
+ type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_global_variable(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(4)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let storage_class = self.next()?;
+ let init = if inst.wc > 4 {
+ inst.expect(5)?;
+ let init_id = self.next()?;
+ let lconst = self.lookup_constant.lookup(init_id)?;
+ Some(lconst.handle)
+ } else {
+ None
+ };
+ let span = self.span_from_with_op(start);
+ let mut dec = self.future_decor.remove(&id).unwrap_or_default();
+
+ let original_ty = self.lookup_type.lookup(type_id)?.handle;
+ let mut ty = original_ty;
+
+ if let crate::TypeInner::Pointer { base, space: _ } = module.types[original_ty].inner {
+ ty = base;
+ }
+
+ if let crate::TypeInner::BindingArray { .. } = module.types[original_ty].inner {
+ // Inside `parse_type_array()` we guess that an array of images or
+ // samplers must be a binding array, and here we validate that guess
+ if dec.desc_set.is_none() || dec.desc_index.is_none() {
+ return Err(Error::NonBindingArrayOfImageOrSamplers);
+ }
+ }
+
+ if let crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: crate::ImageClass::Storage { format, access: _ },
+ } = module.types[ty].inner
+ {
+ // Storage image types in IR have to contain the access, but not in the SPIR-V.
+ // The same image type in SPIR-V can be used (and has to be used) for multiple images.
+ // So we copy the type out and apply the variable access decorations.
+ let access = dec.flags.to_storage_access();
+
+ ty = module.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: crate::ImageClass::Storage { format, access },
+ },
+ },
+ Default::default(),
+ );
+ }
+
+ let ext_class = match self.lookup_storage_buffer_types.get(&ty) {
+ Some(&access) => ExtendedClass::Global(crate::AddressSpace::Storage { access }),
+ None => map_storage_class(storage_class)?,
+ };
+
+ // Fix empty name for gl_PerVertex struct generated by glslang
+ if let crate::TypeInner::Pointer { .. } = module.types[original_ty].inner {
+ if ext_class == ExtendedClass::Input || ext_class == ExtendedClass::Output {
+ if let Some(ref dec_name) = dec.name {
+ if dec_name.is_empty() {
+ dec.name = Some("perVertexStruct".to_string())
+ }
+ }
+ }
+ }
+
+ let (inner, var) = match ext_class {
+ ExtendedClass::Global(mut space) => {
+ if let crate::AddressSpace::Storage { ref mut access } = space {
+ *access &= dec.flags.to_storage_access();
+ }
+ let var = crate::GlobalVariable {
+ binding: dec.resource_binding(),
+ name: dec.name,
+ space,
+ ty,
+ init,
+ };
+ (Variable::Global, var)
+ }
+ ExtendedClass::Input => {
+ let binding = dec.io_binding()?;
+ let mut unsigned_ty = ty;
+ if let crate::Binding::BuiltIn(built_in) = binding {
+ let needs_inner_uint = match built_in {
+ crate::BuiltIn::BaseInstance
+ | crate::BuiltIn::BaseVertex
+ | crate::BuiltIn::InstanceIndex
+ | crate::BuiltIn::SampleIndex
+ | crate::BuiltIn::VertexIndex
+ | crate::BuiltIn::PrimitiveIndex
+ | crate::BuiltIn::LocalInvocationIndex => Some(crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }),
+ crate::BuiltIn::GlobalInvocationId
+ | crate::BuiltIn::LocalInvocationId
+ | crate::BuiltIn::WorkGroupId
+ | crate::BuiltIn::WorkGroupSize => Some(crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }),
+ _ => None,
+ };
+ if let (Some(inner), Some(crate::ScalarKind::Sint)) =
+ (needs_inner_uint, module.types[ty].inner.scalar_kind())
+ {
+ unsigned_ty = module
+ .types
+ .insert(crate::Type { name: None, inner }, Default::default());
+ }
+ }
+
+ let var = crate::GlobalVariable {
+ name: dec.name.clone(),
+ space: crate::AddressSpace::Private,
+ binding: None,
+ ty,
+ init: None,
+ };
+
+ let inner = Variable::Input(crate::FunctionArgument {
+ name: dec.name,
+ ty: unsigned_ty,
+ binding: Some(binding),
+ });
+ (inner, var)
+ }
+ ExtendedClass::Output => {
+ // For output interface blocks, this would be a structure.
+ let binding = dec.io_binding().ok();
+ let init = match binding {
+ Some(crate::Binding::BuiltIn(built_in)) => {
+ match null::generate_default_built_in(
+ Some(built_in),
+ ty,
+ &module.types,
+ &mut module.constants,
+ span,
+ ) {
+ Ok(handle) => Some(handle),
+ Err(e) => {
+ log::warn!("Failed to initialize output built-in: {}", e);
+ None
+ }
+ }
+ }
+ Some(crate::Binding::Location { .. }) => None,
+ None => match module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ // A temporary to avoid borrowing `module.types`
+ let pairs = members
+ .iter()
+ .map(|member| {
+ let built_in = match member.binding {
+ Some(crate::Binding::BuiltIn(built_in)) => Some(built_in),
+ _ => None,
+ };
+ (built_in, member.ty)
+ })
+ .collect::<Vec<_>>();
+
+ let mut components = Vec::with_capacity(members.len());
+ for (built_in, member_ty) in pairs {
+ let handle = null::generate_default_built_in(
+ built_in,
+ member_ty,
+ &module.types,
+ &mut module.constants,
+ span,
+ )?;
+ components.push(handle);
+ }
+ Some(module.constants.append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Composite { ty, components },
+ },
+ span,
+ ))
+ }
+ _ => None,
+ },
+ };
+
+ let var = crate::GlobalVariable {
+ name: dec.name,
+ space: crate::AddressSpace::Private,
+ binding: None,
+ ty,
+ init,
+ };
+ let inner = Variable::Output(crate::FunctionResult { ty, binding });
+ (inner, var)
+ }
+ };
+
+ let handle = module.global_variables.append(var, span);
+
+ if module.types[ty].inner.can_comparison_sample(module) {
+ log::debug!("\t\ttracking {:?} for sampling properties", handle);
+
+ self.handle_sampling
+ .insert(handle, image::SamplingFlags::empty());
+ }
+
+ self.lookup_variable.insert(
+ id,
+ LookupVariable {
+ inner,
+ handle,
+ type_id,
+ },
+ );
+ Ok(())
+ }
+}
+
+pub fn parse_u8_slice(data: &[u8], options: &Options) -> Result<crate::Module, Error> {
+ if data.len() % 4 != 0 {
+ return Err(Error::IncompleteData);
+ }
+
+ let words = data
+ .chunks(4)
+ .map(|c| u32::from_le_bytes(c.try_into().unwrap()));
+ Frontend::new(words, options).parse()
+}
+
+#[cfg(test)]
+mod test {
+ #[test]
+ fn parse() {
+ let bin = vec![
+ // Magic number. Version number: 1.0.
+ 0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00,
+ // Generator number: 0. Bound: 0.
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Reserved word: 0.
+ 0x00, 0x00, 0x00, 0x00, // OpMemoryModel. Logical.
+ 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // GLSL450.
+ 0x01, 0x00, 0x00, 0x00,
+ ];
+ let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap();
+ }
+}
+
+/// Helper function to check if `child` is in the scope of `parent`
+fn is_parent(mut child: usize, parent: usize, block_ctx: &BlockContext) -> bool {
+ loop {
+ if child == parent {
+ // The child is in the scope parent
+ break true;
+ } else if child == 0 {
+ // Searched finished at the root the child isn't in the parent's body
+ break false;
+ }
+
+ child = block_ctx.bodies[child].parent;
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/null.rs b/third_party/rust/naga/src/front/spv/null.rs
new file mode 100644
index 0000000000..85350c563e
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/null.rs
@@ -0,0 +1,175 @@
+use super::Error;
+use crate::arena::{Arena, Handle, UniqueArena};
+
+const fn make_scalar_inner(kind: crate::ScalarKind, width: crate::Bytes) -> crate::ConstantInner {
+ crate::ConstantInner::Scalar {
+ width,
+ value: match kind {
+ crate::ScalarKind::Uint => crate::ScalarValue::Uint(0),
+ crate::ScalarKind::Sint => crate::ScalarValue::Sint(0),
+ crate::ScalarKind::Float => crate::ScalarValue::Float(0.0),
+ crate::ScalarKind::Bool => crate::ScalarValue::Bool(false),
+ },
+ }
+}
+
+pub fn generate_null_constant(
+ ty: Handle<crate::Type>,
+ type_arena: &UniqueArena<crate::Type>,
+ constant_arena: &mut Arena<crate::Constant>,
+ span: crate::Span,
+) -> Result<crate::ConstantInner, Error> {
+ let inner = match type_arena[ty].inner {
+ crate::TypeInner::Scalar { kind, width } => make_scalar_inner(kind, width),
+ crate::TypeInner::Vector { size, kind, width } => {
+ let mut components = Vec::with_capacity(size as usize);
+ for _ in 0..size as usize {
+ components.push(constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: make_scalar_inner(kind, width),
+ },
+ span,
+ ));
+ }
+ crate::ConstantInner::Composite { ty, components }
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ // If we successfully declared a matrix type, we have declared a vector type for it too.
+ let vector_ty = type_arena
+ .get(&crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ kind: crate::ScalarKind::Float,
+ size: rows,
+ width,
+ },
+ })
+ .unwrap();
+ let vector_inner = generate_null_constant(vector_ty, type_arena, constant_arena, span)?;
+ let vector_handle = constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: vector_inner,
+ },
+ span,
+ );
+ crate::ConstantInner::Composite {
+ ty,
+ components: vec![vector_handle; columns as usize],
+ }
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ let mut components = Vec::with_capacity(members.len());
+ // copy out the types to avoid borrowing `members`
+ let member_tys = members.iter().map(|member| member.ty).collect::<Vec<_>>();
+ for member_ty in member_tys {
+ let inner = generate_null_constant(member_ty, type_arena, constant_arena, span)?;
+ components.push(constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ span,
+ ));
+ }
+ crate::ConstantInner::Composite { ty, components }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(handle),
+ ..
+ } => {
+ let size = constant_arena[handle]
+ .to_array_length()
+ .ok_or(Error::InvalidArraySize(handle))?;
+ let inner = generate_null_constant(base, type_arena, constant_arena, span)?;
+ let value = constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ span,
+ );
+ crate::ConstantInner::Composite {
+ ty,
+ components: vec![value; size as usize],
+ }
+ }
+ ref other => {
+ log::warn!("null constant type {:?}", other);
+ return Err(Error::UnsupportedType(ty));
+ }
+ };
+ Ok(inner)
+}
+
+/// Create a default value for an output built-in.
+pub fn generate_default_built_in(
+ built_in: Option<crate::BuiltIn>,
+ ty: Handle<crate::Type>,
+ type_arena: &UniqueArena<crate::Type>,
+ constant_arena: &mut Arena<crate::Constant>,
+ span: crate::Span,
+) -> Result<Handle<crate::Constant>, Error> {
+ let inner = match built_in {
+ Some(crate::BuiltIn::Position { .. }) => {
+ let zero = constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Float(0.0),
+ width: 4,
+ },
+ },
+ span,
+ );
+ let one = constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Float(1.0),
+ width: 4,
+ },
+ },
+ span,
+ );
+ crate::ConstantInner::Composite {
+ ty,
+ components: vec![zero, zero, zero, one],
+ }
+ }
+ Some(crate::BuiltIn::PointSize) => crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Float(1.0),
+ width: 4,
+ },
+ Some(crate::BuiltIn::FragDepth) => crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Float(0.0),
+ width: 4,
+ },
+ Some(crate::BuiltIn::SampleMask) => crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(u64::MAX),
+ width: 4,
+ },
+ //Note: `crate::BuiltIn::ClipDistance` is intentionally left for the default path
+ _ => generate_null_constant(ty, type_arena, constant_arena, span)?,
+ };
+ Ok(constant_arena.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ span,
+ ))
+}
diff --git a/third_party/rust/naga/src/front/type_gen.rs b/third_party/rust/naga/src/front/type_gen.rs
new file mode 100644
index 0000000000..1ee454c448
--- /dev/null
+++ b/third_party/rust/naga/src/front/type_gen.rs
@@ -0,0 +1,314 @@
+/*!
+Type generators.
+*/
+
+use crate::{arena::Handle, span::Span};
+
+impl crate::Module {
+ pub fn generate_atomic_compare_exchange_result(
+ &mut self,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ ) -> Handle<crate::Type> {
+ let bool_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let scalar_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar { kind, width },
+ },
+ Span::UNDEFINED,
+ );
+
+ self.types.insert(
+ crate::Type {
+ name: Some(format!(
+ "__atomic_compare_exchange_result<{kind:?},{width}>"
+ )),
+ inner: crate::TypeInner::Struct {
+ members: vec![
+ crate::StructMember {
+ name: Some("old_value".to_string()),
+ ty: scalar_ty,
+ binding: None,
+ offset: 0,
+ },
+ crate::StructMember {
+ name: Some("exchanged".to_string()),
+ ty: bool_ty,
+ binding: None,
+ offset: 4,
+ },
+ ],
+ span: 8,
+ },
+ },
+ Span::UNDEFINED,
+ )
+ }
+ /// Populate this module's [`SpecialTypes::ray_desc`] type.
+ ///
+ /// [`SpecialTypes::ray_desc`] is the type of the [`descriptor`] operand of
+ /// an [`Initialize`] [`RayQuery`] statement. In WGSL, it is a struct type
+ /// referred to as `RayDesc`.
+ ///
+ /// Backends consume values of this type to drive platform APIs, so if you
+ /// change any its fields, you must update the backends to match. Look for
+ /// backend code dealing with [`RayQueryFunction::Initialize`].
+ ///
+ /// [`SpecialTypes::ray_desc`]: crate::SpecialTypes::ray_desc
+ /// [`descriptor`]: crate::RayQueryFunction::Initialize::descriptor
+ /// [`Initialize`]: crate::RayQueryFunction::Initialize
+ /// [`RayQuery`]: crate::Statement::RayQuery
+ /// [`RayQueryFunction::Initialize`]: crate::RayQueryFunction::Initialize
+ pub fn generate_ray_desc_type(&mut self) -> Handle<crate::Type> {
+ if let Some(handle) = self.special_types.ray_desc {
+ return handle;
+ }
+
+ let width = 4;
+ let ty_flag = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ width,
+ kind: crate::ScalarKind::Uint,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_scalar = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ width,
+ kind: crate::ScalarKind::Float,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_vector = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Float,
+ width,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ let handle = self.types.insert(
+ crate::Type {
+ name: Some("RayDesc".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![
+ crate::StructMember {
+ name: Some("flags".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 0,
+ },
+ crate::StructMember {
+ name: Some("cull_mask".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 4,
+ },
+ crate::StructMember {
+ name: Some("tmin".to_string()),
+ ty: ty_scalar,
+ binding: None,
+ offset: 8,
+ },
+ crate::StructMember {
+ name: Some("tmax".to_string()),
+ ty: ty_scalar,
+ binding: None,
+ offset: 12,
+ },
+ crate::StructMember {
+ name: Some("origin".to_string()),
+ ty: ty_vector,
+ binding: None,
+ offset: 16,
+ },
+ crate::StructMember {
+ name: Some("dir".to_string()),
+ ty: ty_vector,
+ binding: None,
+ offset: 32,
+ },
+ ],
+ span: 48,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ self.special_types.ray_desc = Some(handle);
+ handle
+ }
+
+ /// Populate this module's [`SpecialTypes::ray_intersection`] type.
+ ///
+ /// [`SpecialTypes::ray_intersection`] is the type of a
+ /// `RayQueryGetIntersection` expression. In WGSL, it is a struct type
+ /// referred to as `RayIntersection`.
+ ///
+ /// Backends construct values of this type based on platform APIs, so if you
+ /// change any its fields, you must update the backends to match. Look for
+ /// the backend's handling for [`Expression::RayQueryGetIntersection`].
+ ///
+ /// [`SpecialTypes::ray_intersection`]: crate::SpecialTypes::ray_intersection
+ /// [`Expression::RayQueryGetIntersection`]: crate::Expression::RayQueryGetIntersection
+ pub fn generate_ray_intersection_type(&mut self) -> Handle<crate::Type> {
+ if let Some(handle) = self.special_types.ray_intersection {
+ return handle;
+ }
+
+ let width = 4;
+ let ty_flag = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ width,
+ kind: crate::ScalarKind::Uint,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_scalar = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ width,
+ kind: crate::ScalarKind::Float,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_barycentrics = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ width,
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Float,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_bool = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ width: crate::BOOL_WIDTH,
+ kind: crate::ScalarKind::Bool,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_transform = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ let handle = self.types.insert(
+ crate::Type {
+ name: Some("RayIntersection".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![
+ crate::StructMember {
+ name: Some("kind".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 0,
+ },
+ crate::StructMember {
+ name: Some("t".to_string()),
+ ty: ty_scalar,
+ binding: None,
+ offset: 4,
+ },
+ crate::StructMember {
+ name: Some("instance_custom_index".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 8,
+ },
+ crate::StructMember {
+ name: Some("instance_id".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 12,
+ },
+ crate::StructMember {
+ name: Some("sbt_record_offset".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 16,
+ },
+ crate::StructMember {
+ name: Some("geometry_index".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 20,
+ },
+ crate::StructMember {
+ name: Some("primitive_index".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 24,
+ },
+ crate::StructMember {
+ name: Some("barycentrics".to_string()),
+ ty: ty_barycentrics,
+ binding: None,
+ offset: 28,
+ },
+ crate::StructMember {
+ name: Some("front_face".to_string()),
+ ty: ty_bool,
+ binding: None,
+ offset: 36,
+ },
+ crate::StructMember {
+ name: Some("object_to_world".to_string()),
+ ty: ty_transform,
+ binding: None,
+ offset: 48,
+ },
+ crate::StructMember {
+ name: Some("world_to_object".to_string()),
+ ty: ty_transform,
+ binding: None,
+ offset: 112,
+ },
+ ],
+ span: 176,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ self.special_types.ray_intersection = Some(handle);
+ handle
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs
new file mode 100644
index 0000000000..82f7c609ee
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/error.rs
@@ -0,0 +1,690 @@
+use crate::front::wgsl::parse::lexer::Token;
+use crate::proc::{Alignment, ResolveError};
+use crate::{SourceLocation, Span};
+use codespan_reporting::diagnostic::{Diagnostic, Label};
+use codespan_reporting::files::SimpleFile;
+use codespan_reporting::term;
+use std::borrow::Cow;
+use std::ops::Range;
+use termcolor::{ColorChoice, NoColor, StandardStream};
+use thiserror::Error;
+
+#[derive(Clone, Debug)]
+pub struct ParseError {
+ message: String,
+ labels: Vec<(Span, Cow<'static, str>)>,
+ notes: Vec<String>,
+}
+
+impl ParseError {
+ pub fn labels(&self) -> impl Iterator<Item = (Span, &str)> + ExactSizeIterator + '_ {
+ self.labels
+ .iter()
+ .map(|&(span, ref msg)| (span, msg.as_ref()))
+ }
+
+ pub fn message(&self) -> &str {
+ &self.message
+ }
+
+ fn diagnostic(&self) -> Diagnostic<()> {
+ let diagnostic = Diagnostic::error()
+ .with_message(self.message.to_string())
+ .with_labels(
+ self.labels
+ .iter()
+ .map(|label| {
+ Label::primary((), label.0.to_range().unwrap())
+ .with_message(label.1.to_string())
+ })
+ .collect(),
+ )
+ .with_notes(
+ self.notes
+ .iter()
+ .map(|note| format!("note: {note}"))
+ .collect(),
+ );
+ diagnostic
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ pub fn emit_to_stderr(&self, source: &str) {
+ self.emit_to_stderr_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ pub fn emit_to_stderr_with_path(&self, source: &str, path: &str) {
+ let files = SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+ let writer = StandardStream::stderr(ColorChoice::Auto);
+ term::emit(&mut writer.lock(), &config, &files, &self.diagnostic())
+ .expect("cannot write error");
+ }
+
+ /// Emits a summary of the error to a string.
+ pub fn emit_to_string(&self, source: &str) -> String {
+ self.emit_to_string_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to a string.
+ pub fn emit_to_string_with_path(&self, source: &str, path: &str) -> String {
+ let files = SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+ let mut writer = NoColor::new(Vec::new());
+ term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error");
+ String::from_utf8(writer.into_inner()).unwrap()
+ }
+
+ /// Returns a [`SourceLocation`] for the first label in the error message.
+ pub fn location(&self, source: &str) -> Option<SourceLocation> {
+ self.labels.get(0).map(|label| label.0.location(source))
+ }
+}
+
+impl std::fmt::Display for ParseError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.message)
+ }
+}
+
+impl std::error::Error for ParseError {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ None
+ }
+}
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum ExpectedToken<'a> {
+ Token(Token<'a>),
+ Identifier,
+ Number,
+ Integer,
+ /// A compile-time constant expression.
+ Constant,
+ /// Expected: constant, parenthesized expression, identifier
+ PrimaryExpression,
+ /// Expected: assignment, increment/decrement expression
+ Assignment,
+ /// Expected: 'case', 'default', '}'
+ SwitchItem,
+ /// Expected: ',', ')'
+ WorkgroupSizeSeparator,
+ /// Expected: 'struct', 'let', 'var', 'type', ';', 'fn', eof
+ GlobalItem,
+ /// Expected a type.
+ Type,
+ /// Access of `var`, `let`, `const`.
+ Variable,
+ /// Access of a function
+ Function,
+}
+
+#[derive(Clone, Copy, Debug, Error, PartialEq)]
+pub enum NumberError {
+ #[error("invalid numeric literal format")]
+ Invalid,
+ #[error("numeric literal not representable by target type")]
+ NotRepresentable,
+ #[error("unimplemented f16 type")]
+ UnimplementedF16,
+}
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum InvalidAssignmentType {
+ Other,
+ Swizzle,
+ ImmutableBinding(Span),
+}
+
+#[derive(Clone, Debug)]
+pub enum Error<'a> {
+ Unexpected(Span, ExpectedToken<'a>),
+ UnexpectedComponents(Span),
+ BadNumber(Span, NumberError),
+ /// A negative signed integer literal where both signed and unsigned,
+ /// but only non-negative literals are allowed.
+ NegativeInt(Span),
+ BadU32Constant(Span),
+ BadMatrixScalarKind(Span, crate::ScalarKind, u8),
+ BadAccessor(Span),
+ BadTexture(Span),
+ BadTypeCast {
+ span: Span,
+ from_type: String,
+ to_type: String,
+ },
+ BadTextureSampleType {
+ span: Span,
+ kind: crate::ScalarKind,
+ width: u8,
+ },
+ BadIncrDecrReferenceType(Span),
+ InvalidResolve(ResolveError),
+ InvalidForInitializer(Span),
+ /// A break if appeared outside of a continuing block
+ InvalidBreakIf(Span),
+ InvalidGatherComponent(Span),
+ InvalidConstructorComponentType(Span, i32),
+ InvalidIdentifierUnderscore(Span),
+ ReservedIdentifierPrefix(Span),
+ UnknownAddressSpace(Span),
+ UnknownAttribute(Span),
+ UnknownBuiltin(Span),
+ UnknownAccess(Span),
+ UnknownIdent(Span, &'a str),
+ UnknownScalarType(Span),
+ UnknownType(Span),
+ UnknownStorageFormat(Span),
+ UnknownConservativeDepth(Span),
+ SizeAttributeTooLow(Span, u32),
+ AlignAttributeTooLow(Span, Alignment),
+ NonPowerOfTwoAlignAttribute(Span),
+ InconsistentBinding(Span),
+ TypeNotConstructible(Span),
+ TypeNotInferrable(Span),
+ InitializationTypeMismatch(Span, String, String),
+ MissingType(Span),
+ MissingAttribute(&'static str, Span),
+ InvalidAtomicPointer(Span),
+ InvalidAtomicOperandType(Span),
+ InvalidRayQueryPointer(Span),
+ Pointer(&'static str, Span),
+ NotPointer(Span),
+ NotReference(&'static str, Span),
+ InvalidAssignment {
+ span: Span,
+ ty: InvalidAssignmentType,
+ },
+ ReservedKeyword(Span),
+ /// Redefinition of an identifier (used for both module-scope and local redefinitions).
+ Redefinition {
+ /// Span of the identifier in the previous definition.
+ previous: Span,
+
+ /// Span of the identifier in the new definition.
+ current: Span,
+ },
+ /// A declaration refers to itself directly.
+ RecursiveDeclaration {
+ /// The location of the name of the declaration.
+ ident: Span,
+
+ /// The point at which it is used.
+ usage: Span,
+ },
+ /// A declaration refers to itself indirectly, through one or more other
+ /// definitions.
+ CyclicDeclaration {
+ /// The location of the name of some declaration in the cycle.
+ ident: Span,
+
+ /// The edges of the cycle of references.
+ ///
+ /// Each `(decl, reference)` pair indicates that the declaration whose
+ /// name is `decl` has an identifier at `reference` whose definition is
+ /// the next declaration in the cycle. The last pair's `reference` is
+ /// the same identifier as `ident`, above.
+ path: Vec<(Span, Span)>,
+ },
+ ConstExprUnsupported(Span),
+ InvalidSwitchValue {
+ uint: bool,
+ span: Span,
+ },
+ CalledEntryPoint(Span),
+ WrongArgumentCount {
+ span: Span,
+ expected: Range<u32>,
+ found: u32,
+ },
+ FunctionReturnsVoid(Span),
+ Other,
+}
+
+impl<'a> Error<'a> {
+ pub(crate) fn as_parse_error(&self, source: &'a str) -> ParseError {
+ match *self {
+ Error::Unexpected(unexpected_span, expected) => {
+ let expected_str = match expected {
+ ExpectedToken::Token(token) => {
+ match token {
+ Token::Separator(c) => format!("'{c}'"),
+ Token::Paren(c) => format!("'{c}'"),
+ Token::Attribute => "@".to_string(),
+ Token::Number(_) => "number".to_string(),
+ Token::Word(s) => s.to_string(),
+ Token::Operation(c) => format!("operation ('{c}')"),
+ Token::LogicalOperation(c) => format!("logical operation ('{c}')"),
+ Token::ShiftOperation(c) => format!("bitshift ('{c}{c}')"),
+ Token::AssignmentOperation(c) if c=='<' || c=='>' => format!("bitshift ('{c}{c}=')"),
+ Token::AssignmentOperation(c) => format!("operation ('{c}=')"),
+ Token::IncrementOperation => "increment operation".to_string(),
+ Token::DecrementOperation => "decrement operation".to_string(),
+ Token::Arrow => "->".to_string(),
+ Token::Unknown(c) => format!("unknown ('{c}')"),
+ Token::Trivia => "trivia".to_string(),
+ Token::End => "end".to_string(),
+ }
+ }
+ ExpectedToken::Identifier => "identifier".to_string(),
+ ExpectedToken::Number => "32-bit signed integer literal".to_string(),
+ ExpectedToken::Integer => "unsigned/signed integer literal".to_string(),
+ ExpectedToken::Constant => "compile-time constant".to_string(),
+ ExpectedToken::PrimaryExpression => "expression".to_string(),
+ ExpectedToken::Assignment => "assignment or increment/decrement".to_string(),
+ ExpectedToken::SwitchItem => "switch item ('case' or 'default') or a closing curly bracket to signify the end of the switch statement ('}')".to_string(),
+ ExpectedToken::WorkgroupSizeSeparator => "workgroup size separator (',') or a closing parenthesis".to_string(),
+ ExpectedToken::GlobalItem => "global item ('struct', 'const', 'var', 'alias', ';', 'fn') or the end of the file".to_string(),
+ ExpectedToken::Type => "type".to_string(),
+ ExpectedToken::Variable => "variable access".to_string(),
+ ExpectedToken::Function => "function name".to_string(),
+ };
+ ParseError {
+ message: format!(
+ "expected {}, found '{}'",
+ expected_str, &source[unexpected_span],
+ ),
+ labels: vec![(unexpected_span, format!("expected {expected_str}").into())],
+ notes: vec![],
+ }
+ }
+ Error::UnexpectedComponents(bad_span) => ParseError {
+ message: "unexpected components".to_string(),
+ labels: vec![(bad_span, "unexpected components".into())],
+ notes: vec![],
+ },
+ Error::BadNumber(bad_span, ref err) => ParseError {
+ message: format!("{}: `{}`", err, &source[bad_span],),
+ labels: vec![(bad_span, err.to_string().into())],
+ notes: vec![],
+ },
+ Error::NegativeInt(bad_span) => ParseError {
+ message: format!(
+ "expected non-negative integer literal, found `{}`",
+ &source[bad_span],
+ ),
+ labels: vec![(bad_span, "expected non-negative integer".into())],
+ notes: vec![],
+ },
+ Error::BadU32Constant(bad_span) => ParseError {
+ message: format!(
+ "expected unsigned integer constant expression, found `{}`",
+ &source[bad_span],
+ ),
+ labels: vec![(bad_span, "expected unsigned integer".into())],
+ notes: vec![],
+ },
+ Error::BadMatrixScalarKind(span, kind, width) => ParseError {
+ message: format!(
+ "matrix scalar type must be floating-point, but found `{}`",
+ kind.to_wgsl(width)
+ ),
+ labels: vec![(span, "must be floating-point (e.g. `f32`)".into())],
+ notes: vec![],
+ },
+ Error::BadAccessor(accessor_span) => ParseError {
+ message: format!("invalid field accessor `{}`", &source[accessor_span],),
+ labels: vec![(accessor_span, "invalid accessor".into())],
+ notes: vec![],
+ },
+ Error::UnknownIdent(ident_span, ident) => ParseError {
+ message: format!("no definition in scope for identifier: '{ident}'"),
+ labels: vec![(ident_span, "unknown identifier".into())],
+ notes: vec![],
+ },
+ Error::UnknownScalarType(bad_span) => ParseError {
+ message: format!("unknown scalar type: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown scalar type".into())],
+ notes: vec!["Valid scalar types are f32, f64, i32, u32, bool".into()],
+ },
+ Error::BadTextureSampleType { span, kind, width } => ParseError {
+ message: format!(
+ "texture sample type must be one of f32, i32 or u32, but found {}",
+ kind.to_wgsl(width)
+ ),
+ labels: vec![(span, "must be one of f32, i32 or u32".into())],
+ notes: vec![],
+ },
+ Error::BadIncrDecrReferenceType(span) => ParseError {
+ message:
+ "increment/decrement operation requires reference type to be one of i32 or u32"
+ .to_string(),
+ labels: vec![(span, "must be a reference type of i32 or u32".into())],
+ notes: vec![],
+ },
+ Error::BadTexture(bad_span) => ParseError {
+ message: format!(
+ "expected an image, but found '{}' which is not an image",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "not an image".into())],
+ notes: vec![],
+ },
+ Error::BadTypeCast {
+ span,
+ ref from_type,
+ ref to_type,
+ } => {
+ let msg = format!("cannot cast a {from_type} to a {to_type}");
+ ParseError {
+ message: msg.clone(),
+ labels: vec![(span, msg.into())],
+ notes: vec![],
+ }
+ }
+ Error::InvalidResolve(ref resolve_error) => ParseError {
+ message: resolve_error.to_string(),
+ labels: vec![],
+ notes: vec![],
+ },
+ Error::InvalidForInitializer(bad_span) => ParseError {
+ message: format!(
+ "for(;;) initializer is not an assignment or a function call: '{}'",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "not an assignment or function call".into())],
+ notes: vec![],
+ },
+ Error::InvalidBreakIf(bad_span) => ParseError {
+ message: "A break if is only allowed in a continuing block".to_string(),
+ labels: vec![(bad_span, "not in a continuing block".into())],
+ notes: vec![],
+ },
+ Error::InvalidGatherComponent(bad_span) => ParseError {
+ message: format!(
+ "textureGather component '{}' doesn't exist, must be 0, 1, 2, or 3",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "invalid component".into())],
+ notes: vec![],
+ },
+ Error::InvalidConstructorComponentType(bad_span, component) => ParseError {
+ message: format!("invalid type for constructor component at index [{component}]"),
+ labels: vec![(bad_span, "invalid component type".into())],
+ notes: vec![],
+ },
+ Error::InvalidIdentifierUnderscore(bad_span) => ParseError {
+ message: "Identifier can't be '_'".to_string(),
+ labels: vec![(bad_span, "invalid identifier".into())],
+ notes: vec![
+ "Use phony assignment instead ('_ =' notice the absence of 'let' or 'var')"
+ .to_string(),
+ ],
+ },
+ Error::ReservedIdentifierPrefix(bad_span) => ParseError {
+ message: format!(
+ "Identifier starts with a reserved prefix: '{}'",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "invalid identifier".into())],
+ notes: vec![],
+ },
+ Error::UnknownAddressSpace(bad_span) => ParseError {
+ message: format!("unknown address space: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown address space".into())],
+ notes: vec![],
+ },
+ Error::UnknownAttribute(bad_span) => ParseError {
+ message: format!("unknown attribute: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown attribute".into())],
+ notes: vec![],
+ },
+ Error::UnknownBuiltin(bad_span) => ParseError {
+ message: format!("unknown builtin: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown builtin".into())],
+ notes: vec![],
+ },
+ Error::UnknownAccess(bad_span) => ParseError {
+ message: format!("unknown access: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown access".into())],
+ notes: vec![],
+ },
+ Error::UnknownStorageFormat(bad_span) => ParseError {
+ message: format!("unknown storage format: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown storage format".into())],
+ notes: vec![],
+ },
+ Error::UnknownConservativeDepth(bad_span) => ParseError {
+ message: format!("unknown conservative depth: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown conservative depth".into())],
+ notes: vec![],
+ },
+ Error::UnknownType(bad_span) => ParseError {
+ message: format!("unknown type: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown type".into())],
+ notes: vec![],
+ },
+ Error::SizeAttributeTooLow(bad_span, min_size) => ParseError {
+ message: format!("struct member size must be at least {min_size}"),
+ labels: vec![(bad_span, format!("must be at least {min_size}").into())],
+ notes: vec![],
+ },
+ Error::AlignAttributeTooLow(bad_span, min_align) => ParseError {
+ message: format!("struct member alignment must be at least {min_align}"),
+ labels: vec![(bad_span, format!("must be at least {min_align}").into())],
+ notes: vec![],
+ },
+ Error::NonPowerOfTwoAlignAttribute(bad_span) => ParseError {
+ message: "struct member alignment must be a power of 2".to_string(),
+ labels: vec![(bad_span, "must be a power of 2".into())],
+ notes: vec![],
+ },
+ Error::InconsistentBinding(span) => ParseError {
+ message: "input/output binding is not consistent".to_string(),
+ labels: vec![(span, "input/output binding is not consistent".into())],
+ notes: vec![],
+ },
+ Error::TypeNotConstructible(span) => ParseError {
+ message: format!("type `{}` is not constructible", &source[span]),
+ labels: vec![(span, "type is not constructible".into())],
+ notes: vec![],
+ },
+ Error::TypeNotInferrable(span) => ParseError {
+ message: "type can't be inferred".to_string(),
+ labels: vec![(span, "type can't be inferred".into())],
+ notes: vec![],
+ },
+ Error::InitializationTypeMismatch(name_span, ref expected_ty, ref got_ty) => {
+ ParseError {
+ message: format!(
+ "the type of `{}` is expected to be `{}`, but got `{}`",
+ &source[name_span], expected_ty, got_ty,
+ ),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ }
+ }
+ Error::MissingType(name_span) => ParseError {
+ message: format!("variable `{}` needs a type", &source[name_span]),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::MissingAttribute(name, name_span) => ParseError {
+ message: format!(
+ "variable `{}` needs a '{}' attribute",
+ &source[name_span], name
+ ),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::InvalidAtomicPointer(span) => ParseError {
+ message: "atomic operation is done on a pointer to a non-atomic".to_string(),
+ labels: vec![(span, "atomic pointer is invalid".into())],
+ notes: vec![],
+ },
+ Error::InvalidAtomicOperandType(span) => ParseError {
+ message: "atomic operand type is inconsistent with the operation".to_string(),
+ labels: vec![(span, "atomic operand type is invalid".into())],
+ notes: vec![],
+ },
+ Error::InvalidRayQueryPointer(span) => ParseError {
+ message: "ray query operation is done on a pointer to a non-ray-query".to_string(),
+ labels: vec![(span, "ray query pointer is invalid".into())],
+ notes: vec![],
+ },
+ Error::NotPointer(span) => ParseError {
+ message: "the operand of the `*` operator must be a pointer".to_string(),
+ labels: vec![(span, "expression is not a pointer".into())],
+ notes: vec![],
+ },
+ Error::NotReference(what, span) => ParseError {
+ message: format!("{what} must be a reference"),
+ labels: vec![(span, "expression is not a reference".into())],
+ notes: vec![],
+ },
+ Error::InvalidAssignment { span, ty } => {
+ let (extra_label, notes) = match ty {
+ InvalidAssignmentType::Swizzle => (
+ None,
+ vec![
+ "WGSL does not support assignments to swizzles".into(),
+ "consider assigning each component individually".into(),
+ ],
+ ),
+ InvalidAssignmentType::ImmutableBinding(binding_span) => (
+ Some((binding_span, "this is an immutable binding".into())),
+ vec![format!(
+ "consider declaring '{}' with `var` instead of `let`",
+ &source[binding_span]
+ )],
+ ),
+ InvalidAssignmentType::Other => (None, vec![]),
+ };
+
+ ParseError {
+ message: "invalid left-hand side of assignment".into(),
+ labels: std::iter::once((span, "cannot assign to this expression".into()))
+ .chain(extra_label)
+ .collect(),
+ notes,
+ }
+ }
+ Error::Pointer(what, span) => ParseError {
+ message: format!("{what} must not be a pointer"),
+ labels: vec![(span, "expression is a pointer".into())],
+ notes: vec![],
+ },
+ Error::ReservedKeyword(name_span) => ParseError {
+ message: format!("name `{}` is a reserved keyword", &source[name_span]),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::Redefinition { previous, current } => ParseError {
+ message: format!("redefinition of `{}`", &source[current]),
+ labels: vec![
+ (
+ current,
+ format!("redefinition of `{}`", &source[current]).into(),
+ ),
+ (
+ previous,
+ format!("previous definition of `{}`", &source[previous]).into(),
+ ),
+ ],
+ notes: vec![],
+ },
+ Error::RecursiveDeclaration { ident, usage } => ParseError {
+ message: format!("declaration of `{}` is recursive", &source[ident]),
+ labels: vec![(ident, "".into()), (usage, "uses itself here".into())],
+ notes: vec![],
+ },
+ Error::CyclicDeclaration { ident, ref path } => ParseError {
+ message: format!("declaration of `{}` is cyclic", &source[ident]),
+ labels: path
+ .iter()
+ .enumerate()
+ .flat_map(|(i, &(ident, usage))| {
+ [
+ (ident, "".into()),
+ (
+ usage,
+ if i == path.len() - 1 {
+ "ending the cycle".into()
+ } else {
+ format!("uses `{}`", &source[ident]).into()
+ },
+ ),
+ ]
+ })
+ .collect(),
+ notes: vec![],
+ },
+ Error::ConstExprUnsupported(span) => ParseError {
+ message: "this constant expression is not supported".to_string(),
+ labels: vec![(span, "expression is not supported".into())],
+ notes: vec![
+ "this should be fixed in a future version of Naga".into(),
+ "https://github.com/gfx-rs/naga/issues/1829".into(),
+ ],
+ },
+ Error::InvalidSwitchValue { uint, span } => ParseError {
+ message: "invalid switch value".to_string(),
+ labels: vec![(
+ span,
+ if uint {
+ "expected unsigned integer"
+ } else {
+ "expected signed integer"
+ }
+ .into(),
+ )],
+ notes: vec![if uint {
+ format!("suffix the integer with a `u`: '{}u'", &source[span])
+ } else {
+ let span = span.to_range().unwrap();
+ format!(
+ "remove the `u` suffix: '{}'",
+ &source[span.start..span.end - 1]
+ )
+ }],
+ },
+ Error::CalledEntryPoint(span) => ParseError {
+ message: "entry point cannot be called".to_string(),
+ labels: vec![(span, "entry point cannot be called".into())],
+ notes: vec![],
+ },
+ Error::WrongArgumentCount {
+ span,
+ ref expected,
+ found,
+ } => ParseError {
+ message: format!(
+ "wrong number of arguments: expected {}, found {}",
+ if expected.len() < 2 {
+ format!("{}", expected.start)
+ } else {
+ format!("{}..{}", expected.start, expected.end)
+ },
+ found
+ ),
+ labels: vec![(span, "wrong number of arguments".into())],
+ notes: vec![],
+ },
+ Error::FunctionReturnsVoid(span) => ParseError {
+ message: "function does not return any value".to_string(),
+ labels: vec![(span, "".into())],
+ notes: vec![
+ "perhaps you meant to call the function in a separate statement?".into(),
+ ],
+ },
+ Error::Other => ParseError {
+ message: "other error".to_string(),
+ labels: vec![],
+ notes: vec![],
+ },
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/index.rs b/third_party/rust/naga/src/front/wgsl/index.rs
new file mode 100644
index 0000000000..a5524fe8f1
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/index.rs
@@ -0,0 +1,193 @@
+use super::Error;
+use crate::front::wgsl::parse::ast;
+use crate::{FastHashMap, Handle, Span};
+
+/// A `GlobalDecl` list in which each definition occurs before all its uses.
+pub struct Index<'a> {
+ dependency_order: Vec<Handle<ast::GlobalDecl<'a>>>,
+}
+
+impl<'a> Index<'a> {
+ /// Generate an `Index` for the given translation unit.
+ ///
+ /// Perform a topological sort on `tu`'s global declarations, placing
+ /// referents before the definitions that refer to them.
+ ///
+ /// Return an error if the graph of references between declarations contains
+ /// any cycles.
+ pub fn generate(tu: &ast::TranslationUnit<'a>) -> Result<Self, Error<'a>> {
+ // Produce a map from global definitions' names to their `Handle<GlobalDecl>`s.
+ // While doing so, reject conflicting definitions.
+ let mut globals = FastHashMap::with_capacity_and_hasher(tu.decls.len(), Default::default());
+ for (handle, decl) in tu.decls.iter() {
+ let ident = decl_ident(decl);
+ let name = ident.name;
+ if let Some(old) = globals.insert(name, handle) {
+ return Err(Error::Redefinition {
+ previous: decl_ident(&tu.decls[old]).span,
+ current: ident.span,
+ });
+ }
+ }
+
+ let len = tu.decls.len();
+ let solver = DependencySolver {
+ globals: &globals,
+ module: tu,
+ visited: vec![false; len],
+ temp_visited: vec![false; len],
+ path: Vec::new(),
+ out: Vec::with_capacity(len),
+ };
+ let dependency_order = solver.solve()?;
+
+ Ok(Self { dependency_order })
+ }
+
+ /// Iterate over `GlobalDecl`s, visiting each definition before all its uses.
+ ///
+ /// Produce handles for all of the `GlobalDecl`s of the `TranslationUnit`
+ /// passed to `Index::generate`, ordered so that a given declaration is
+ /// produced before any other declaration that uses it.
+ pub fn visit_ordered(&self) -> impl Iterator<Item = Handle<ast::GlobalDecl<'a>>> + '_ {
+ self.dependency_order.iter().copied()
+ }
+}
+
+/// An edge from a reference to its referent in the current depth-first
+/// traversal.
+///
+/// This is like `ast::Dependency`, except that we've determined which
+/// `GlobalDecl` it refers to.
+struct ResolvedDependency<'a> {
+ /// The referent of some identifier used in the current declaration.
+ decl: Handle<ast::GlobalDecl<'a>>,
+
+ /// Where that use occurs within the current declaration.
+ usage: Span,
+}
+
+/// Local state for ordering a `TranslationUnit`'s module-scope declarations.
+///
+/// Values of this type are used temporarily by `Index::generate`
+/// to perform a depth-first sort on the declarations.
+/// Technically, what we want is a topological sort, but a depth-first sort
+/// has one key benefit - it's much more efficient in storing
+/// the path of each node for error generation.
+struct DependencySolver<'source, 'temp> {
+ /// A map from module-scope definitions' names to their handles.
+ globals: &'temp FastHashMap<&'source str, Handle<ast::GlobalDecl<'source>>>,
+
+ /// The translation unit whose declarations we're ordering.
+ module: &'temp ast::TranslationUnit<'source>,
+
+ /// For each handle, whether we have pushed it onto `out` yet.
+ visited: Vec<bool>,
+
+ /// For each handle, whether it is an predecessor in the current depth-first
+ /// traversal. This is used to detect cycles in the reference graph.
+ temp_visited: Vec<bool>,
+
+ /// The current path in our depth-first traversal. Used for generating
+ /// error messages for non-trivial reference cycles.
+ path: Vec<ResolvedDependency<'source>>,
+
+ /// The list of declaration handles, with declarations before uses.
+ out: Vec<Handle<ast::GlobalDecl<'source>>>,
+}
+
+impl<'a> DependencySolver<'a, '_> {
+ /// Produce the sorted list of declaration handles, and check for cycles.
+ fn solve(mut self) -> Result<Vec<Handle<ast::GlobalDecl<'a>>>, Error<'a>> {
+ for (id, _) in self.module.decls.iter() {
+ if self.visited[id.index()] {
+ continue;
+ }
+
+ self.dfs(id)?;
+ }
+
+ Ok(self.out)
+ }
+
+ /// Ensure that all declarations used by `id` have been added to the
+ /// ordering, and then append `id` itself.
+ fn dfs(&mut self, id: Handle<ast::GlobalDecl<'a>>) -> Result<(), Error<'a>> {
+ let decl = &self.module.decls[id];
+ let id_usize = id.index();
+
+ self.temp_visited[id_usize] = true;
+ for dep in decl.dependencies.iter() {
+ if let Some(&dep_id) = self.globals.get(dep.ident) {
+ self.path.push(ResolvedDependency {
+ decl: dep_id,
+ usage: dep.usage,
+ });
+ let dep_id_usize = dep_id.index();
+
+ if self.temp_visited[dep_id_usize] {
+ // Found a cycle.
+ return if dep_id == id {
+ // A declaration refers to itself directly.
+ Err(Error::RecursiveDeclaration {
+ ident: decl_ident(decl).span,
+ usage: dep.usage,
+ })
+ } else {
+ // A declaration refers to itself indirectly, through
+ // one or more other definitions. Report the entire path
+ // of references.
+ let start_at = self
+ .path
+ .iter()
+ .rev()
+ .enumerate()
+ .find_map(|(i, dep)| (dep.decl == dep_id).then_some(i))
+ .unwrap_or(0);
+
+ Err(Error::CyclicDeclaration {
+ ident: decl_ident(&self.module.decls[dep_id]).span,
+ path: self.path[start_at..]
+ .iter()
+ .map(|curr_dep| {
+ let curr_id = curr_dep.decl;
+ let curr_decl = &self.module.decls[curr_id];
+
+ (decl_ident(curr_decl).span, curr_dep.usage)
+ })
+ .collect(),
+ })
+ };
+ } else if !self.visited[dep_id_usize] {
+ self.dfs(dep_id)?;
+ }
+
+ // Remove this edge from the current path.
+ self.path.pop();
+ }
+
+ // Ignore unresolved identifiers; they may be predeclared objects.
+ }
+
+ // Remove this node from the current path.
+ self.temp_visited[id_usize] = false;
+
+ // Now everything this declaration uses has been visited, and is already
+ // present in `out`. That means we we can append this one to the
+ // ordering, and mark it as visited.
+ self.out.push(id);
+ self.visited[id_usize] = true;
+
+ Ok(())
+ }
+}
+
+const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
+ match decl.kind {
+ ast::GlobalDeclKind::Fn(ref f) => f.name,
+ ast::GlobalDeclKind::Var(ref v) => v.name,
+ ast::GlobalDeclKind::Const(ref c) => c.name,
+ ast::GlobalDeclKind::Struct(ref s) => s.name,
+ ast::GlobalDeclKind::Type(ref t) => t.name,
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/construction.rs b/third_party/rust/naga/src/front/wgsl/lower/construction.rs
new file mode 100644
index 0000000000..723d4441f5
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/lower/construction.rs
@@ -0,0 +1,668 @@
+use crate::front::wgsl::parse::ast;
+use crate::{Handle, Span};
+
+use crate::front::wgsl::error::Error;
+use crate::front::wgsl::lower::{ExpressionContext, Lowerer, OutputContext};
+use crate::proc::TypeResolution;
+
+enum ConcreteConstructorHandle {
+ PartialVector {
+ size: crate::VectorSize,
+ },
+ PartialMatrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ },
+ PartialArray,
+ Type(Handle<crate::Type>),
+}
+
+impl ConcreteConstructorHandle {
+ fn borrow<'a>(&self, module: &'a crate::Module) -> ConcreteConstructor<'a> {
+ match *self {
+ Self::PartialVector { size } => ConcreteConstructor::PartialVector { size },
+ Self::PartialMatrix { columns, rows } => {
+ ConcreteConstructor::PartialMatrix { columns, rows }
+ }
+ Self::PartialArray => ConcreteConstructor::PartialArray,
+ Self::Type(handle) => ConcreteConstructor::Type(handle, &module.types[handle].inner),
+ }
+ }
+}
+
+enum ConcreteConstructor<'a> {
+ PartialVector {
+ size: crate::VectorSize,
+ },
+ PartialMatrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ },
+ PartialArray,
+ Type(Handle<crate::Type>, &'a crate::TypeInner),
+}
+
+impl ConcreteConstructorHandle {
+ fn to_error_string(&self, ctx: ExpressionContext) -> String {
+ match *self {
+ Self::PartialVector { size } => {
+ format!("vec{}<?>", size as u32,)
+ }
+ Self::PartialMatrix { columns, rows } => {
+ format!("mat{}x{}<?>", columns as u32, rows as u32,)
+ }
+ Self::PartialArray => "array<?, ?>".to_string(),
+ Self::Type(ty) => ctx.format_type(ty),
+ }
+ }
+}
+
+enum ComponentsHandle<'a> {
+ None,
+ One {
+ component: Handle<crate::Expression>,
+ span: Span,
+ ty: &'a TypeResolution,
+ },
+ Many {
+ components: Vec<Handle<crate::Expression>>,
+ spans: Vec<Span>,
+ first_component_ty: &'a TypeResolution,
+ },
+}
+
+impl<'a> ComponentsHandle<'a> {
+ fn borrow(self, module: &'a crate::Module) -> Components<'a> {
+ match self {
+ Self::None => Components::None,
+ Self::One {
+ component,
+ span,
+ ty,
+ } => Components::One {
+ component,
+ span,
+ ty_inner: ty.inner_with(&module.types),
+ },
+ Self::Many {
+ components,
+ spans,
+ first_component_ty,
+ } => Components::Many {
+ components,
+ spans,
+ first_component_ty_inner: first_component_ty.inner_with(&module.types),
+ },
+ }
+ }
+}
+
+enum Components<'a> {
+ None,
+ One {
+ component: Handle<crate::Expression>,
+ span: Span,
+ ty_inner: &'a crate::TypeInner,
+ },
+ Many {
+ components: Vec<Handle<crate::Expression>>,
+ spans: Vec<Span>,
+ first_component_ty_inner: &'a crate::TypeInner,
+ },
+}
+
+impl Components<'_> {
+ fn into_components_vec(self) -> Vec<Handle<crate::Expression>> {
+ match self {
+ Self::None => vec![],
+ Self::One { component, .. } => vec![component],
+ Self::Many { components, .. } => components,
+ }
+ }
+}
+
+impl<'source, 'temp> Lowerer<'source, 'temp> {
+ /// Generate Naga IR for a type constructor expression.
+ ///
+ /// The `constructor` value represents the head of the constructor
+ /// expression, which is at least a hint of which type is being built; if
+ /// it's one of the `Partial` variants, we need to consider the argument
+ /// types as well.
+ ///
+ /// This is used for [`Construct`] expressions, but also for [`Call`]
+ /// expressions, once we've determined that the "callable" (in WGSL spec
+ /// terms) is actually a type.
+ ///
+ /// [`Construct`]: ast::Expression::Construct
+ /// [`Call`]: ast::Expression::Call
+ pub fn construct(
+ &mut self,
+ span: Span,
+ constructor: &ast::ConstructorType<'source>,
+ ty_span: Span,
+ components: &[Handle<ast::Expression<'source>>],
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let constructor_h = self.constructor(constructor, ctx.as_output())?;
+
+ let components_h = match *components {
+ [] => ComponentsHandle::None,
+ [component] => {
+ let span = ctx.ast_expressions.get_span(component);
+ let component = self.expression(component, ctx.reborrow())?;
+ ctx.grow_types(component)?;
+ let ty = &ctx.typifier[component];
+
+ ComponentsHandle::One {
+ component,
+ span,
+ ty,
+ }
+ }
+ [component, ref rest @ ..] => {
+ let span = ctx.ast_expressions.get_span(component);
+ let component = self.expression(component, ctx.reborrow())?;
+
+ let components = std::iter::once(Ok(component))
+ .chain(
+ rest.iter()
+ .map(|&component| self.expression(component, ctx.reborrow())),
+ )
+ .collect::<Result<_, _>>()?;
+ let spans = std::iter::once(span)
+ .chain(
+ rest.iter()
+ .map(|&component| ctx.ast_expressions.get_span(component)),
+ )
+ .collect();
+
+ ctx.grow_types(component)?;
+ let ty = &ctx.typifier[component];
+
+ ComponentsHandle::Many {
+ components,
+ spans,
+ first_component_ty: ty,
+ }
+ }
+ };
+
+ let (components, constructor) = (
+ components_h.borrow(ctx.module),
+ constructor_h.borrow(ctx.module),
+ );
+ let expr = match (components, constructor) {
+ // Empty constructor
+ (Components::None, dst_ty) => {
+ let ty = match dst_ty {
+ ConcreteConstructor::Type(ty, _) => ty,
+ _ => return Err(Error::TypeNotInferrable(ty_span)),
+ };
+
+ return match ctx.create_zero_value_constant(ty) {
+ Some(constant) => {
+ Ok(ctx.interrupt_emitter(crate::Expression::Constant(constant), span))
+ }
+ None => Err(Error::TypeNotConstructible(ty_span)),
+ };
+ }
+
+ // Scalar constructor & conversion (scalar -> scalar)
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Scalar { .. },
+ ..
+ },
+ ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { kind, width }),
+ ) => crate::Expression::As {
+ expr: component,
+ kind,
+ convert: Some(width),
+ },
+
+ // Vector conversion (vector -> vector)
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
+ ..
+ },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Vector {
+ size: dst_size,
+ kind: dst_kind,
+ width: dst_width,
+ },
+ ),
+ ) if dst_size == src_size => crate::Expression::As {
+ expr: component,
+ kind: dst_kind,
+ convert: Some(dst_width),
+ },
+
+ // Vector conversion (vector -> vector) - partial
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Vector {
+ size: src_size,
+ kind: src_kind,
+ ..
+ },
+ ..
+ },
+ ConcreteConstructor::PartialVector { size: dst_size },
+ ) if dst_size == src_size => crate::Expression::As {
+ expr: component,
+ kind: src_kind,
+ convert: None,
+ },
+
+ // Matrix conversion (matrix -> matrix)
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Matrix {
+ columns: src_columns,
+ rows: src_rows,
+ ..
+ },
+ ..
+ },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Matrix {
+ columns: dst_columns,
+ rows: dst_rows,
+ width: dst_width,
+ },
+ ),
+ ) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As {
+ expr: component,
+ kind: crate::ScalarKind::Float,
+ convert: Some(dst_width),
+ },
+
+ // Matrix conversion (matrix -> matrix) - partial
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Matrix {
+ columns: src_columns,
+ rows: src_rows,
+ ..
+ },
+ ..
+ },
+ ConcreteConstructor::PartialMatrix {
+ columns: dst_columns,
+ rows: dst_rows,
+ },
+ ) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As {
+ expr: component,
+ kind: crate::ScalarKind::Float,
+ convert: None,
+ },
+
+ // Vector constructor (splat) - infer type
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Scalar { .. },
+ ..
+ },
+ ConcreteConstructor::PartialVector { size },
+ ) => crate::Expression::Splat {
+ size,
+ value: component,
+ },
+
+ // Vector constructor (splat)
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Scalar {
+ kind: src_kind,
+ width: src_width,
+ ..
+ },
+ ..
+ },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Vector {
+ size,
+ kind: dst_kind,
+ width: dst_width,
+ },
+ ),
+ ) if dst_kind == src_kind || dst_width == src_width => crate::Expression::Splat {
+ size,
+ value: component,
+ },
+
+ // Vector constructor (by elements)
+ (
+ Components::Many {
+ components,
+ first_component_ty_inner:
+ &crate::TypeInner::Scalar { kind, width }
+ | &crate::TypeInner::Vector { kind, width, .. },
+ ..
+ },
+ ConcreteConstructor::PartialVector { size },
+ )
+ | (
+ Components::Many {
+ components,
+ first_component_ty_inner:
+ &crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. },
+ ..
+ },
+ ConcreteConstructor::Type(_, &crate::TypeInner::Vector { size, width, kind }),
+ ) => {
+ let inner = crate::TypeInner::Vector { size, kind, width };
+ let ty = ctx.ensure_type_exists(inner);
+ crate::Expression::Compose { ty, components }
+ }
+
+ // Matrix constructor (by elements)
+ (
+ Components::Many {
+ components,
+ first_component_ty_inner: &crate::TypeInner::Scalar { width, .. },
+ ..
+ },
+ ConcreteConstructor::PartialMatrix { columns, rows },
+ )
+ | (
+ Components::Many {
+ components,
+ first_component_ty_inner: &crate::TypeInner::Scalar { .. },
+ ..
+ },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ },
+ ),
+ ) => {
+ let vec_ty = ctx.ensure_type_exists(crate::TypeInner::Vector {
+ width,
+ kind: crate::ScalarKind::Float,
+ size: rows,
+ });
+
+ let components = components
+ .chunks(rows as usize)
+ .map(|vec_components| {
+ ctx.naga_expressions.append(
+ crate::Expression::Compose {
+ ty: vec_ty,
+ components: Vec::from(vec_components),
+ },
+ Default::default(),
+ )
+ })
+ .collect();
+
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ });
+ crate::Expression::Compose { ty, components }
+ }
+
+ // Matrix constructor (by columns)
+ (
+ Components::Many {
+ components,
+ first_component_ty_inner: &crate::TypeInner::Vector { width, .. },
+ ..
+ },
+ ConcreteConstructor::PartialMatrix { columns, rows },
+ )
+ | (
+ Components::Many {
+ components,
+ first_component_ty_inner: &crate::TypeInner::Vector { .. },
+ ..
+ },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ },
+ ),
+ ) => {
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ });
+ crate::Expression::Compose { ty, components }
+ }
+
+ // Array constructor - infer type
+ (components, ConcreteConstructor::PartialArray) => {
+ let components = components.into_components_vec();
+
+ let base = ctx.register_type(components[0])?;
+
+ let size = crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(components.len() as _),
+ },
+ };
+
+ let inner = crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(
+ ctx.module.constants.fetch_or_append(size, Span::UNDEFINED),
+ ),
+ stride: {
+ self.layouter
+ .update(&ctx.module.types, &ctx.module.constants)
+ .unwrap();
+ self.layouter[base].to_stride()
+ },
+ };
+ let ty = ctx.ensure_type_exists(inner);
+
+ crate::Expression::Compose { ty, components }
+ }
+
+ // Array constructor
+ (components, ConcreteConstructor::Type(ty, &crate::TypeInner::Array { .. })) => {
+ let components = components.into_components_vec();
+ crate::Expression::Compose { ty, components }
+ }
+
+ // Struct constructor
+ (components, ConcreteConstructor::Type(ty, &crate::TypeInner::Struct { .. })) => {
+ crate::Expression::Compose {
+ ty,
+ components: components.into_components_vec(),
+ }
+ }
+
+ // ERRORS
+
+ // Bad conversion (type cast)
+ (Components::One { span, ty_inner, .. }, _) => {
+ let from_type = ctx.format_typeinner(ty_inner);
+ return Err(Error::BadTypeCast {
+ span,
+ from_type,
+ to_type: constructor_h.to_error_string(ctx.reborrow()),
+ });
+ }
+
+ // Too many parameters for scalar constructor
+ (
+ Components::Many { spans, .. },
+ ConcreteConstructor::Type(_, &crate::TypeInner::Scalar { .. }),
+ ) => {
+ let span = spans[1].until(spans.last().unwrap());
+ return Err(Error::UnexpectedComponents(span));
+ }
+
+ // Parameters are of the wrong type for vector or matrix constructor
+ (
+ Components::Many { spans, .. },
+ ConcreteConstructor::Type(
+ _,
+ &crate::TypeInner::Vector { .. } | &crate::TypeInner::Matrix { .. },
+ )
+ | ConcreteConstructor::PartialVector { .. }
+ | ConcreteConstructor::PartialMatrix { .. },
+ ) => {
+ return Err(Error::InvalidConstructorComponentType(spans[0], 0));
+ }
+
+ // Other types can't be constructed
+ _ => return Err(Error::TypeNotConstructible(ty_span)),
+ };
+
+ let expr = ctx.naga_expressions.append(expr, span);
+ Ok(expr)
+ }
+
+ /// Build a Naga IR [`ConstantInner`] given a WGSL construction expression.
+ ///
+ /// Given `constructor`, representing the head of a WGSL [`type constructor
+ /// expression`], and a slice of [`ast::Expression`] handles representing
+ /// the constructor's arguments, build a Naga [`ConstantInner`] value
+ /// representing the given value.
+ ///
+ /// If `constructor` is for a composite type, this may entail adding new
+ /// [`Type`]s and [`Constant`]s to [`ctx.module`], if it doesn't already
+ /// have what we need.
+ ///
+ /// If the arguments cannot be evaluated at compile time, return an error.
+ ///
+ /// [`ConstantInner`]: crate::ConstantInner
+ /// [`type constructor expression`]: https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr
+ /// [`Function::expressions`]: ast::Function::expressions
+ /// [`TranslationUnit::global_expressions`]: ast::TranslationUnit::global_expressions
+ /// [`Type`]: crate::Type
+ /// [`Constant`]: crate::Constant
+ /// [`ctx.module`]: OutputContext::module
+ pub fn const_construct(
+ &mut self,
+ span: Span,
+ constructor: &ast::ConstructorType<'source>,
+ components: &[Handle<ast::Expression<'source>>],
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<crate::ConstantInner, Error<'source>> {
+ // TODO: Support zero values, splatting and inference.
+
+ let constructor = self.constructor(constructor, ctx.reborrow())?;
+
+ let c = match constructor {
+ ConcreteConstructorHandle::Type(ty) => {
+ let components = components
+ .iter()
+ .map(|&expr| self.constant(expr, ctx.reborrow()))
+ .collect::<Result<_, _>>()?;
+
+ crate::ConstantInner::Composite { ty, components }
+ }
+ _ => return Err(Error::ConstExprUnsupported(span)),
+ };
+ Ok(c)
+ }
+
+ /// Build a Naga IR [`Type`] for `constructor` if there is enough
+ /// information to do so.
+ ///
+ /// For `Partial` variants of [`ast::ConstructorType`], we don't know the
+ /// component type, so in that case we return the appropriate `Partial`
+ /// variant of [`ConcreteConstructorHandle`].
+ ///
+ /// But for the other `ConstructorType` variants, we have everything we need
+ /// to know to actually produce a Naga IR type. In this case we add to/find
+ /// in [`ctx.module`] a suitable Naga `Type` and return a
+ /// [`ConcreteConstructorHandle::Type`] value holding its handle.
+ ///
+ /// Note that constructing an [`Array`] type may require inserting
+ /// [`Constant`]s as well as `Type`s into `ctx.module`, to represent the
+ /// array's length.
+ ///
+ /// [`Type`]: crate::Type
+ /// [`ctx.module`]: OutputContext::module
+ /// [`Array`]: crate::TypeInner::Array
+ /// [`Constant`]: crate::Constant
+ fn constructor<'out>(
+ &mut self,
+ constructor: &ast::ConstructorType<'source>,
+ mut ctx: OutputContext<'source, '_, 'out>,
+ ) -> Result<ConcreteConstructorHandle, Error<'source>> {
+ let c = match *constructor {
+ ast::ConstructorType::Scalar { width, kind } => {
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Scalar { width, kind });
+ ConcreteConstructorHandle::Type(ty)
+ }
+ ast::ConstructorType::PartialVector { size } => {
+ ConcreteConstructorHandle::PartialVector { size }
+ }
+ ast::ConstructorType::Vector { size, kind, width } => {
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Vector { size, kind, width });
+ ConcreteConstructorHandle::Type(ty)
+ }
+ ast::ConstructorType::PartialMatrix { rows, columns } => {
+ ConcreteConstructorHandle::PartialMatrix { rows, columns }
+ }
+ ast::ConstructorType::Matrix {
+ rows,
+ columns,
+ width,
+ } => {
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ });
+ ConcreteConstructorHandle::Type(ty)
+ }
+ ast::ConstructorType::PartialArray => ConcreteConstructorHandle::PartialArray,
+ ast::ConstructorType::Array { base, size } => {
+ let base = self.resolve_ast_type(base, ctx.reborrow())?;
+ let size = match size {
+ ast::ArraySize::Constant(expr) => {
+ crate::ArraySize::Constant(self.constant(expr, ctx.reborrow())?)
+ }
+ ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
+ };
+
+ self.layouter
+ .update(&ctx.module.types, &ctx.module.constants)
+ .unwrap();
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Array {
+ base,
+ size,
+ stride: self.layouter[base].to_stride(),
+ });
+ ConcreteConstructorHandle::Type(ty)
+ }
+ ast::ConstructorType::Type(ty) => ConcreteConstructorHandle::Type(ty),
+ };
+
+ Ok(c)
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
new file mode 100644
index 0000000000..bc9cce1bee
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
@@ -0,0 +1,2426 @@
+use crate::front::wgsl::error::{Error, ExpectedToken, InvalidAssignmentType};
+use crate::front::wgsl::index::Index;
+use crate::front::wgsl::parse::number::Number;
+use crate::front::wgsl::parse::{ast, conv};
+use crate::front::{Emitter, Typifier};
+use crate::proc::{ensure_block_returns, Alignment, Layouter, ResolveContext, TypeResolution};
+use crate::{Arena, FastHashMap, Handle, Span};
+use indexmap::IndexMap;
+
+mod construction;
+
+/// State for constructing a `crate::Module`.
+pub struct OutputContext<'source, 'temp, 'out> {
+ /// The `TranslationUnit`'s expressions arena.
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+
+ /// The `TranslationUnit`'s types arena.
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ /// The module we're constructing.
+ module: &'out mut crate::Module,
+}
+
+impl<'source> OutputContext<'source, '_, '_> {
+ fn reborrow(&mut self) -> OutputContext<'source, '_, '_> {
+ OutputContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ }
+ }
+
+ fn ensure_type_exists(&mut self, inner: crate::TypeInner) -> Handle<crate::Type> {
+ self.module
+ .types
+ .insert(crate::Type { inner, name: None }, Span::UNDEFINED)
+ }
+}
+
+/// State for lowering a statement within a function.
+pub struct StatementContext<'source, 'temp, 'out> {
+ // WGSL AST values.
+ /// A reference to [`TranslationUnit::expressions`] for the translation unit
+ /// we're lowering.
+ ///
+ /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+
+ /// A reference to [`TranslationUnit::types`] for the translation unit
+ /// we're lowering.
+ ///
+ /// [`TranslationUnit::types`]: ast::TranslationUnit::types
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ /// A map from `ast::Local` handles to the Naga expressions we've built for them.
+ ///
+ /// The Naga expressions are either [`LocalVariable`] or
+ /// [`FunctionArgument`] expressions.
+ ///
+ /// [`LocalVariable`]: crate::Expression::LocalVariable
+ /// [`FunctionArgument`]: crate::Expression::FunctionArgument
+ local_table: &'temp mut FastHashMap<Handle<ast::Local>, TypedExpression>,
+
+ typifier: &'temp mut Typifier,
+ variables: &'out mut Arena<crate::LocalVariable>,
+ naga_expressions: &'out mut Arena<crate::Expression>,
+ /// Stores the names of expressions that are assigned in `let` statement
+ /// Also stores the spans of the names, for use in errors.
+ named_expressions: &'out mut IndexMap<Handle<crate::Expression>, (String, Span)>,
+ arguments: &'out [crate::FunctionArgument],
+ module: &'out mut crate::Module,
+}
+
+impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
+ fn reborrow(&mut self) -> StatementContext<'a, '_, '_> {
+ StatementContext {
+ local_table: self.local_table,
+ globals: self.globals,
+ types: self.types,
+ ast_expressions: self.ast_expressions,
+ typifier: self.typifier,
+ variables: self.variables,
+ naga_expressions: self.naga_expressions,
+ named_expressions: self.named_expressions,
+ arguments: self.arguments,
+ module: self.module,
+ }
+ }
+
+ fn as_expression<'t>(
+ &'t mut self,
+ block: &'t mut crate::Block,
+ emitter: &'t mut Emitter,
+ ) -> ExpressionContext<'a, 't, '_>
+ where
+ 'temp: 't,
+ {
+ ExpressionContext {
+ local_table: self.local_table,
+ globals: self.globals,
+ types: self.types,
+ ast_expressions: self.ast_expressions,
+ typifier: self.typifier,
+ naga_expressions: self.naga_expressions,
+ module: self.module,
+ local_vars: self.variables,
+ arguments: self.arguments,
+ block,
+ emitter,
+ }
+ }
+
+ fn as_output(&mut self) -> OutputContext<'a, '_, '_> {
+ OutputContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ }
+ }
+
+ fn invalid_assignment_type(&self, expr: Handle<crate::Expression>) -> InvalidAssignmentType {
+ if let Some(&(_, span)) = self.named_expressions.get(&expr) {
+ InvalidAssignmentType::ImmutableBinding(span)
+ } else {
+ match self.naga_expressions[expr] {
+ crate::Expression::Swizzle { .. } => InvalidAssignmentType::Swizzle,
+ crate::Expression::Access { base, .. } => self.invalid_assignment_type(base),
+ crate::Expression::AccessIndex { base, .. } => self.invalid_assignment_type(base),
+ _ => InvalidAssignmentType::Other,
+ }
+ }
+ }
+}
+
+/// State for lowering an `ast::Expression` to Naga IR.
+///
+/// Not to be confused with `parser::ExpressionContext`.
+pub struct ExpressionContext<'source, 'temp, 'out> {
+ // WGSL AST values.
+ local_table: &'temp mut FastHashMap<Handle<ast::Local>, TypedExpression>,
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ typifier: &'temp mut Typifier,
+ naga_expressions: &'out mut Arena<crate::Expression>,
+ local_vars: &'out Arena<crate::LocalVariable>,
+ arguments: &'out [crate::FunctionArgument],
+ module: &'out mut crate::Module,
+ block: &'temp mut crate::Block,
+ emitter: &'temp mut Emitter,
+}
+
+impl<'a> ExpressionContext<'a, '_, '_> {
+ fn reborrow(&mut self) -> ExpressionContext<'a, '_, '_> {
+ ExpressionContext {
+ local_table: self.local_table,
+ globals: self.globals,
+ types: self.types,
+ ast_expressions: self.ast_expressions,
+ typifier: self.typifier,
+ naga_expressions: self.naga_expressions,
+ module: self.module,
+ local_vars: self.local_vars,
+ arguments: self.arguments,
+ block: self.block,
+ emitter: self.emitter,
+ }
+ }
+
+ fn as_output(&mut self) -> OutputContext<'a, '_, '_> {
+ OutputContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ }
+ }
+
+ /// Determine the type of `handle`, and add it to the module's arena.
+ ///
+ /// If you just need a `TypeInner` for `handle`'s type, use
+ /// [`grow_types`] and [`resolved_inner`] instead. This function
+ /// should only be used when the type of `handle` needs to appear
+ /// in the module's final `Arena<Type>`, for example, if you're
+ /// creating a [`LocalVariable`] whose type is inferred from its
+ /// initializer.
+ ///
+ /// [`grow_types`]: Self::grow_types
+ /// [`resolved_inner`]: Self::resolved_inner
+ /// [`LocalVariable`]: crate::LocalVariable
+ fn register_type(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Type>, Error<'a>> {
+ self.grow_types(handle)?;
+ Ok(self.typifier.register_type(handle, &mut self.module.types))
+ }
+
+ /// Resolve the types of all expressions up through `handle`.
+ ///
+ /// Ensure that [`self.typifier`] has a [`TypeResolution`] for
+ /// every expression in [`self.naga_expressions`].
+ ///
+ /// This does not add types to any arena. The [`Typifier`]
+ /// documentation explains the steps we take to avoid filling
+ /// arenas with intermediate types.
+ ///
+ /// This function takes `&mut self`, so it can't conveniently
+ /// return a shared reference to the resulting `TypeResolution`:
+ /// the shared reference would extend the mutable borrow, and you
+ /// wouldn't be able to use `self` for anything else. Instead, you
+ /// should call `grow_types` to cover the handles you need, and
+ /// then use `self.typifier[handle]` or
+ /// [`self.resolved_inner(handle)`] to get at their resolutions.
+ ///
+ /// [`self.typifier`]: ExpressionContext::typifier
+ /// [`self.resolved_inner(handle)`]: ExpressionContext::resolved_inner
+ /// [`Typifier`]: Typifier
+ fn grow_types(&mut self, handle: Handle<crate::Expression>) -> Result<&mut Self, Error<'a>> {
+ let resolve_ctx = ResolveContext::with_locals(self.module, self.local_vars, self.arguments);
+ self.typifier
+ .grow(handle, self.naga_expressions, &resolve_ctx)
+ .map_err(Error::InvalidResolve)?;
+ Ok(self)
+ }
+
+ fn resolved_inner(&self, handle: Handle<crate::Expression>) -> &crate::TypeInner {
+ self.typifier[handle].inner_with(&self.module.types)
+ }
+
+ fn image_data(
+ &mut self,
+ image: Handle<crate::Expression>,
+ span: Span,
+ ) -> Result<(crate::ImageClass, bool), Error<'a>> {
+ self.grow_types(image)?;
+ match *self.resolved_inner(image) {
+ crate::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)),
+ _ => Err(Error::BadTexture(span)),
+ }
+ }
+
+ fn prepare_args<'b>(
+ &mut self,
+ args: &'b [Handle<ast::Expression<'a>>],
+ min_args: u32,
+ span: Span,
+ ) -> ArgumentContext<'b, 'a> {
+ ArgumentContext {
+ args: args.iter(),
+ min_args,
+ args_used: 0,
+ total_args: args.len() as u32,
+ span,
+ }
+ }
+
+ /// Insert splats, if needed by the non-'*' operations.
+ fn binary_op_splat(
+ &mut self,
+ op: crate::BinaryOperator,
+ left: &mut Handle<crate::Expression>,
+ right: &mut Handle<crate::Expression>,
+ ) -> Result<(), Error<'a>> {
+ if op != crate::BinaryOperator::Multiply {
+ self.grow_types(*left)?.grow_types(*right)?;
+
+ let left_size = match *self.resolved_inner(*left) {
+ crate::TypeInner::Vector { size, .. } => Some(size),
+ _ => None,
+ };
+
+ match (left_size, self.resolved_inner(*right)) {
+ (Some(size), &crate::TypeInner::Scalar { .. }) => {
+ *right = self.naga_expressions.append(
+ crate::Expression::Splat {
+ size,
+ value: *right,
+ },
+ self.naga_expressions.get_span(*right),
+ );
+ }
+ (None, &crate::TypeInner::Vector { size, .. }) => {
+ *left = self.naga_expressions.append(
+ crate::Expression::Splat { size, value: *left },
+ self.naga_expressions.get_span(*left),
+ );
+ }
+ _ => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Add a single expression to the expression table that is not covered by `self.emitter`.
+ ///
+ /// This is useful for `CallResult` and `AtomicResult` expressions, which should not be covered by
+ /// `Emit` statements.
+ fn interrupt_emitter(
+ &mut self,
+ expression: crate::Expression,
+ span: Span,
+ ) -> Handle<crate::Expression> {
+ self.block
+ .extend(self.emitter.finish(self.naga_expressions));
+ let result = self.naga_expressions.append(expression, span);
+ self.emitter.start(self.naga_expressions);
+ result
+ }
+
+ /// Apply the WGSL Load Rule to `expr`.
+ ///
+ /// If `expr` is has type `ref<SC, T, A>`, perform a load to produce a value of type
+ /// `T`. Otherwise, return `expr` unchanged.
+ fn apply_load_rule(&mut self, expr: TypedExpression) -> Handle<crate::Expression> {
+ if expr.is_reference {
+ let load = crate::Expression::Load {
+ pointer: expr.handle,
+ };
+ let span = self.naga_expressions.get_span(expr.handle);
+ self.naga_expressions.append(load, span)
+ } else {
+ expr.handle
+ }
+ }
+
+ /// Creates a zero value constant of type `ty`
+ ///
+ /// Returns `None` if the given `ty` is not a constructible type
+ fn create_zero_value_constant(
+ &mut self,
+ ty: Handle<crate::Type>,
+ ) -> Option<Handle<crate::Constant>> {
+ let inner = match self.module.types[ty].inner {
+ crate::TypeInner::Scalar { kind, width } => {
+ let value = match kind {
+ crate::ScalarKind::Sint => crate::ScalarValue::Sint(0),
+ crate::ScalarKind::Uint => crate::ScalarValue::Uint(0),
+ crate::ScalarKind::Float => crate::ScalarValue::Float(0.),
+ crate::ScalarKind::Bool => crate::ScalarValue::Bool(false),
+ };
+ crate::ConstantInner::Scalar { width, value }
+ }
+ crate::TypeInner::Vector { size, kind, width } => {
+ let scalar_ty = self.ensure_type_exists(crate::TypeInner::Scalar { width, kind });
+ let component = self.create_zero_value_constant(scalar_ty)?;
+ crate::ConstantInner::Composite {
+ ty,
+ components: (0..size as u8).map(|_| component).collect(),
+ }
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let vec_ty = self.ensure_type_exists(crate::TypeInner::Vector {
+ width,
+ kind: crate::ScalarKind::Float,
+ size: rows,
+ });
+ let component = self.create_zero_value_constant(vec_ty)?;
+ crate::ConstantInner::Composite {
+ ty,
+ components: (0..columns as u8).map(|_| component).collect(),
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ let size = self.module.constants[size].to_array_length()?;
+ let component = self.create_zero_value_constant(base)?;
+ crate::ConstantInner::Composite {
+ ty,
+ components: (0..size).map(|_| component).collect(),
+ }
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ let members = members.clone();
+ crate::ConstantInner::Composite {
+ ty,
+ components: members
+ .iter()
+ .map(|member| self.create_zero_value_constant(member.ty))
+ .collect::<Option<_>>()?,
+ }
+ }
+ _ => return None,
+ };
+
+ let constant = self.module.constants.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ Span::UNDEFINED,
+ );
+ Some(constant)
+ }
+
+ fn format_typeinner(&self, inner: &crate::TypeInner) -> String {
+ inner.to_wgsl(&self.module.types, &self.module.constants)
+ }
+
+ fn format_type(&self, handle: Handle<crate::Type>) -> String {
+ let ty = &self.module.types[handle];
+ match ty.name {
+ Some(ref name) => name.clone(),
+ None => self.format_typeinner(&ty.inner),
+ }
+ }
+
+ fn format_type_resolution(&self, resolution: &TypeResolution) -> String {
+ match *resolution {
+ TypeResolution::Handle(handle) => self.format_type(handle),
+ TypeResolution::Value(ref inner) => self.format_typeinner(inner),
+ }
+ }
+
+ fn ensure_type_exists(&mut self, inner: crate::TypeInner) -> Handle<crate::Type> {
+ self.as_output().ensure_type_exists(inner)
+ }
+}
+
+struct ArgumentContext<'ctx, 'source> {
+ args: std::slice::Iter<'ctx, Handle<ast::Expression<'source>>>,
+ min_args: u32,
+ args_used: u32,
+ total_args: u32,
+ span: Span,
+}
+
+impl<'source> ArgumentContext<'_, 'source> {
+ pub fn finish(self) -> Result<(), Error<'source>> {
+ if self.args.len() == 0 {
+ Ok(())
+ } else {
+ Err(Error::WrongArgumentCount {
+ found: self.total_args,
+ expected: self.min_args..self.args_used + 1,
+ span: self.span,
+ })
+ }
+ }
+
+ pub fn next(&mut self) -> Result<Handle<ast::Expression<'source>>, Error<'source>> {
+ match self.args.next().copied() {
+ Some(arg) => {
+ self.args_used += 1;
+ Ok(arg)
+ }
+ None => Err(Error::WrongArgumentCount {
+ found: self.total_args,
+ expected: self.min_args..self.args_used + 1,
+ span: self.span,
+ }),
+ }
+ }
+}
+
+/// A Naga [`Expression`] handle, with WGSL type information.
+///
+/// Naga and WGSL types are very close, but Naga lacks WGSL's 'reference' types,
+/// which we need to know to apply the Load Rule. This struct carries a Naga
+/// `Handle<Expression>` along with enough information to determine its WGSL type.
+///
+/// [`Expression`]: crate::Expression
+#[derive(Debug, Copy, Clone)]
+struct TypedExpression {
+ /// The handle of the Naga expression.
+ handle: Handle<crate::Expression>,
+
+ /// True if this expression's WGSL type is a reference.
+ ///
+ /// When this is true, `handle` must be a pointer.
+ is_reference: bool,
+}
+
+impl TypedExpression {
+ const fn non_reference(handle: Handle<crate::Expression>) -> TypedExpression {
+ TypedExpression {
+ handle,
+ is_reference: false,
+ }
+ }
+}
+
+enum Composition {
+ Single(u32),
+ Multi(crate::VectorSize, [crate::SwizzleComponent; 4]),
+}
+
+impl Composition {
+ const fn letter_component(letter: char) -> Option<crate::SwizzleComponent> {
+ use crate::SwizzleComponent as Sc;
+ match letter {
+ 'x' | 'r' => Some(Sc::X),
+ 'y' | 'g' => Some(Sc::Y),
+ 'z' | 'b' => Some(Sc::Z),
+ 'w' | 'a' => Some(Sc::W),
+ _ => None,
+ }
+ }
+
+ fn extract_impl(name: &str, name_span: Span) -> Result<u32, Error> {
+ let ch = name.chars().next().ok_or(Error::BadAccessor(name_span))?;
+ match Self::letter_component(ch) {
+ Some(sc) => Ok(sc as u32),
+ None => Err(Error::BadAccessor(name_span)),
+ }
+ }
+
+ fn make(name: &str, name_span: Span) -> Result<Self, Error> {
+ if name.len() > 1 {
+ let mut components = [crate::SwizzleComponent::X; 4];
+ for (comp, ch) in components.iter_mut().zip(name.chars()) {
+ *comp = Self::letter_component(ch).ok_or(Error::BadAccessor(name_span))?;
+ }
+
+ let size = match name.len() {
+ 2 => crate::VectorSize::Bi,
+ 3 => crate::VectorSize::Tri,
+ 4 => crate::VectorSize::Quad,
+ _ => return Err(Error::BadAccessor(name_span)),
+ };
+ Ok(Composition::Multi(size, components))
+ } else {
+ Self::extract_impl(name, name_span).map(Composition::Single)
+ }
+ }
+}
+
+/// An `ast::GlobalDecl` for which we have built the Naga IR equivalent.
+enum LoweredGlobalDecl {
+ Function(Handle<crate::Function>),
+ Var(Handle<crate::GlobalVariable>),
+ Const(Handle<crate::Constant>),
+ Type(Handle<crate::Type>),
+ EntryPoint,
+}
+
+enum ConstantOrInner {
+ Constant(Handle<crate::Constant>),
+ Inner(crate::ConstantInner),
+}
+
+enum Texture {
+ Gather,
+ GatherCompare,
+
+ Sample,
+ SampleBias,
+ SampleCompare,
+ SampleCompareLevel,
+ SampleGrad,
+ SampleLevel,
+ // SampleBaseClampToEdge,
+}
+
+impl Texture {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "textureGather" => Self::Gather,
+ "textureGatherCompare" => Self::GatherCompare,
+
+ "textureSample" => Self::Sample,
+ "textureSampleBias" => Self::SampleBias,
+ "textureSampleCompare" => Self::SampleCompare,
+ "textureSampleCompareLevel" => Self::SampleCompareLevel,
+ "textureSampleGrad" => Self::SampleGrad,
+ "textureSampleLevel" => Self::SampleLevel,
+ // "textureSampleBaseClampToEdge" => Some(Self::SampleBaseClampToEdge),
+ _ => return None,
+ })
+ }
+
+ pub const fn min_argument_count(&self) -> u32 {
+ match *self {
+ Self::Gather => 3,
+ Self::GatherCompare => 4,
+
+ Self::Sample => 3,
+ Self::SampleBias => 5,
+ Self::SampleCompare => 5,
+ Self::SampleCompareLevel => 5,
+ Self::SampleGrad => 6,
+ Self::SampleLevel => 5,
+ // Self::SampleBaseClampToEdge => 3,
+ }
+ }
+}
+
+pub struct Lowerer<'source, 'temp> {
+ index: &'temp Index<'source>,
+ layouter: Layouter,
+}
+
+impl<'source, 'temp> Lowerer<'source, 'temp> {
+ pub fn new(index: &'temp Index<'source>) -> Self {
+ Self {
+ index,
+ layouter: Layouter::default(),
+ }
+ }
+
+ pub fn lower(
+ &mut self,
+ tu: &'temp ast::TranslationUnit<'source>,
+ ) -> Result<crate::Module, Error<'source>> {
+ let mut module = crate::Module::default();
+
+ let mut ctx = OutputContext {
+ ast_expressions: &tu.expressions,
+ globals: &mut FastHashMap::default(),
+ types: &tu.types,
+ module: &mut module,
+ };
+
+ for decl_handle in self.index.visit_ordered() {
+ let span = tu.decls.get_span(decl_handle);
+ let decl = &tu.decls[decl_handle];
+
+ match decl.kind {
+ ast::GlobalDeclKind::Fn(ref f) => {
+ let lowered_decl = self.function(f, span, ctx.reborrow())?;
+ ctx.globals.insert(f.name.name, lowered_decl);
+ }
+ ast::GlobalDeclKind::Var(ref v) => {
+ let ty = self.resolve_ast_type(v.ty, ctx.reborrow())?;
+
+ let init = v
+ .init
+ .map(|init| self.constant(init, ctx.reborrow()))
+ .transpose()?;
+
+ let handle = ctx.module.global_variables.append(
+ crate::GlobalVariable {
+ name: Some(v.name.name.to_string()),
+ space: v.space,
+ binding: v.binding.clone(),
+ ty,
+ init,
+ },
+ span,
+ );
+
+ ctx.globals
+ .insert(v.name.name, LoweredGlobalDecl::Var(handle));
+ }
+ ast::GlobalDeclKind::Const(ref c) => {
+ let inner = self.constant_inner(c.init, ctx.reborrow())?;
+ let inner = match inner {
+ ConstantOrInner::Constant(c) => ctx.module.constants[c].inner.clone(),
+ ConstantOrInner::Inner(inner) => inner,
+ };
+
+ let inferred_type = match inner {
+ crate::ConstantInner::Scalar { width, value } => {
+ ctx.ensure_type_exists(crate::TypeInner::Scalar {
+ width,
+ kind: value.scalar_kind(),
+ })
+ }
+ crate::ConstantInner::Composite { ty, .. } => ty,
+ };
+
+ let handle = ctx.module.constants.append(
+ crate::Constant {
+ name: Some(c.name.name.to_string()),
+ specialization: None,
+ inner,
+ },
+ span,
+ );
+
+ let explicit_ty =
+ c.ty.map(|ty| self.resolve_ast_type(ty, ctx.reborrow()))
+ .transpose()?;
+
+ if let Some(explicit) = explicit_ty {
+ if explicit != inferred_type {
+ let ty = &ctx.module.types[explicit];
+ let explicit = ty.name.clone().unwrap_or_else(|| {
+ ty.inner.to_wgsl(&ctx.module.types, &ctx.module.constants)
+ });
+
+ let ty = &ctx.module.types[inferred_type];
+ let inferred = ty.name.clone().unwrap_or_else(|| {
+ ty.inner.to_wgsl(&ctx.module.types, &ctx.module.constants)
+ });
+
+ return Err(Error::InitializationTypeMismatch(
+ c.name.span,
+ explicit,
+ inferred,
+ ));
+ }
+ }
+
+ ctx.globals
+ .insert(c.name.name, LoweredGlobalDecl::Const(handle));
+ }
+ ast::GlobalDeclKind::Struct(ref s) => {
+ let handle = self.r#struct(s, span, ctx.reborrow())?;
+ ctx.globals
+ .insert(s.name.name, LoweredGlobalDecl::Type(handle));
+ }
+ ast::GlobalDeclKind::Type(ref alias) => {
+ let ty = self.resolve_ast_type(alias.ty, ctx.reborrow())?;
+ ctx.globals
+ .insert(alias.name.name, LoweredGlobalDecl::Type(ty));
+ }
+ }
+ }
+
+ Ok(module)
+ }
+
+ fn function(
+ &mut self,
+ f: &ast::Function<'source>,
+ span: Span,
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<LoweredGlobalDecl, Error<'source>> {
+ let mut local_table = FastHashMap::default();
+ let mut local_variables = Arena::new();
+ let mut expressions = Arena::new();
+ let mut named_expressions = IndexMap::default();
+
+ let arguments = f
+ .arguments
+ .iter()
+ .enumerate()
+ .map(|(i, arg)| {
+ let ty = self.resolve_ast_type(arg.ty, ctx.reborrow())?;
+ let expr = expressions
+ .append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
+ local_table.insert(arg.handle, TypedExpression::non_reference(expr));
+ named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
+
+ Ok(crate::FunctionArgument {
+ name: Some(arg.name.name.to_string()),
+ ty,
+ binding: self.interpolate_default(&arg.binding, ty, ctx.reborrow()),
+ })
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let result = f
+ .result
+ .as_ref()
+ .map(|res| {
+ self.resolve_ast_type(res.ty, ctx.reborrow())
+ .map(|ty| crate::FunctionResult {
+ ty,
+ binding: self.interpolate_default(&res.binding, ty, ctx.reborrow()),
+ })
+ })
+ .transpose()?;
+
+ let mut typifier = Typifier::default();
+ let mut body = self.block(
+ &f.body,
+ StatementContext {
+ local_table: &mut local_table,
+ globals: ctx.globals,
+ ast_expressions: ctx.ast_expressions,
+ typifier: &mut typifier,
+ variables: &mut local_variables,
+ naga_expressions: &mut expressions,
+ named_expressions: &mut named_expressions,
+ types: ctx.types,
+ module: ctx.module,
+ arguments: &arguments,
+ },
+ )?;
+ ensure_block_returns(&mut body);
+
+ let function = crate::Function {
+ name: Some(f.name.name.to_string()),
+ arguments,
+ result,
+ local_variables,
+ expressions,
+ named_expressions: named_expressions
+ .into_iter()
+ .map(|(key, (name, _))| (key, name))
+ .collect(),
+ body,
+ };
+
+ if let Some(ref entry) = f.entry_point {
+ ctx.module.entry_points.push(crate::EntryPoint {
+ name: f.name.name.to_string(),
+ stage: entry.stage,
+ early_depth_test: entry.early_depth_test,
+ workgroup_size: entry.workgroup_size,
+ function,
+ });
+ Ok(LoweredGlobalDecl::EntryPoint)
+ } else {
+ let handle = ctx.module.functions.append(function, span);
+ Ok(LoweredGlobalDecl::Function(handle))
+ }
+ }
+
+ fn block(
+ &mut self,
+ b: &ast::Block<'source>,
+ mut ctx: StatementContext<'source, '_, '_>,
+ ) -> Result<crate::Block, Error<'source>> {
+ let mut block = crate::Block::default();
+
+ for stmt in b.stmts.iter() {
+ self.statement(stmt, &mut block, ctx.reborrow())?;
+ }
+
+ Ok(block)
+ }
+
+ fn statement(
+ &mut self,
+ stmt: &ast::Statement<'source>,
+ block: &mut crate::Block,
+ mut ctx: StatementContext<'source, '_, '_>,
+ ) -> Result<(), Error<'source>> {
+ let out = match stmt.kind {
+ ast::StatementKind::Block(ref block) => {
+ let block = self.block(block, ctx.reborrow())?;
+ crate::Statement::Block(block)
+ }
+ ast::StatementKind::LocalDecl(ref decl) => match *decl {
+ ast::LocalDecl::Let(ref l) => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let value = self.expression(l.init, ctx.as_expression(block, &mut emitter))?;
+
+ let explicit_ty =
+ l.ty.map(|ty| self.resolve_ast_type(ty, ctx.as_output()))
+ .transpose()?;
+
+ if let Some(ty) = explicit_ty {
+ let mut ctx = ctx.as_expression(block, &mut emitter);
+ let init_ty = ctx.register_type(value)?;
+ if !ctx.module.types[ty]
+ .inner
+ .equivalent(&ctx.module.types[init_ty].inner, &ctx.module.types)
+ {
+ return Err(Error::InitializationTypeMismatch(
+ l.name.span,
+ ctx.format_type(ty),
+ ctx.format_type(init_ty),
+ ));
+ }
+ }
+
+ block.extend(emitter.finish(ctx.naga_expressions));
+ ctx.local_table
+ .insert(l.handle, TypedExpression::non_reference(value));
+ ctx.named_expressions
+ .insert(value, (l.name.name.to_string(), l.name.span));
+
+ return Ok(());
+ }
+ ast::LocalDecl::Var(ref v) => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let initializer = match v.init {
+ Some(init) => {
+ let initializer =
+ self.expression(init, ctx.as_expression(block, &mut emitter))?;
+ ctx.as_expression(block, &mut emitter)
+ .grow_types(initializer)?;
+ Some(initializer)
+ }
+ None => None,
+ };
+
+ let explicit_ty =
+ v.ty.map(|ty| self.resolve_ast_type(ty, ctx.as_output()))
+ .transpose()?;
+
+ let ty = match (explicit_ty, initializer) {
+ (Some(explicit), Some(initializer)) => {
+ let ctx = ctx.as_expression(block, &mut emitter);
+ let initializer_ty = ctx.resolved_inner(initializer);
+ if !ctx.module.types[explicit]
+ .inner
+ .equivalent(initializer_ty, &ctx.module.types)
+ {
+ return Err(Error::InitializationTypeMismatch(
+ v.name.span,
+ ctx.format_type(explicit),
+ ctx.format_typeinner(initializer_ty),
+ ));
+ }
+ explicit
+ }
+ (Some(explicit), None) => explicit,
+ (None, Some(initializer)) => ctx
+ .as_expression(block, &mut emitter)
+ .register_type(initializer)?,
+ (None, None) => {
+ return Err(Error::MissingType(v.name.span));
+ }
+ };
+
+ let var = ctx.variables.append(
+ crate::LocalVariable {
+ name: Some(v.name.name.to_string()),
+ ty,
+ init: None,
+ },
+ stmt.span,
+ );
+
+ let handle = ctx
+ .as_expression(block, &mut emitter)
+ .interrupt_emitter(crate::Expression::LocalVariable(var), Span::UNDEFINED);
+ block.extend(emitter.finish(ctx.naga_expressions));
+ ctx.local_table.insert(
+ v.handle,
+ TypedExpression {
+ handle,
+ is_reference: true,
+ },
+ );
+
+ match initializer {
+ Some(initializer) => crate::Statement::Store {
+ pointer: handle,
+ value: initializer,
+ },
+ None => return Ok(()),
+ }
+ }
+ },
+ ast::StatementKind::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let condition =
+ self.expression(condition, ctx.as_expression(block, &mut emitter))?;
+ block.extend(emitter.finish(ctx.naga_expressions));
+
+ let accept = self.block(accept, ctx.reborrow())?;
+ let reject = self.block(reject, ctx.reborrow())?;
+
+ crate::Statement::If {
+ condition,
+ accept,
+ reject,
+ }
+ }
+ ast::StatementKind::Switch {
+ selector,
+ ref cases,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let mut ectx = ctx.as_expression(block, &mut emitter);
+ let selector = self.expression(selector, ectx.reborrow())?;
+
+ ectx.grow_types(selector)?;
+ let uint =
+ ectx.resolved_inner(selector).scalar_kind() == Some(crate::ScalarKind::Uint);
+ block.extend(emitter.finish(ctx.naga_expressions));
+
+ let cases = cases
+ .iter()
+ .map(|case| {
+ Ok(crate::SwitchCase {
+ value: match case.value {
+ ast::SwitchValue::I32(value) if !uint => {
+ crate::SwitchValue::I32(value)
+ }
+ ast::SwitchValue::U32(value) if uint => {
+ crate::SwitchValue::U32(value)
+ }
+ ast::SwitchValue::Default => crate::SwitchValue::Default,
+ _ => {
+ return Err(Error::InvalidSwitchValue {
+ uint,
+ span: case.value_span,
+ });
+ }
+ },
+ body: self.block(&case.body, ctx.reborrow())?,
+ fall_through: case.fall_through,
+ })
+ })
+ .collect::<Result<_, _>>()?;
+
+ crate::Statement::Switch { selector, cases }
+ }
+ ast::StatementKind::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let body = self.block(body, ctx.reborrow())?;
+ let mut continuing = self.block(continuing, ctx.reborrow())?;
+
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+ let break_if = break_if
+ .map(|expr| self.expression(expr, ctx.as_expression(block, &mut emitter)))
+ .transpose()?;
+ continuing.extend(emitter.finish(ctx.naga_expressions));
+
+ crate::Statement::Loop {
+ body,
+ continuing,
+ break_if,
+ }
+ }
+ ast::StatementKind::Break => crate::Statement::Break,
+ ast::StatementKind::Continue => crate::Statement::Continue,
+ ast::StatementKind::Return { value } => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let value = value
+ .map(|expr| self.expression(expr, ctx.as_expression(block, &mut emitter)))
+ .transpose()?;
+ block.extend(emitter.finish(ctx.naga_expressions));
+
+ crate::Statement::Return { value }
+ }
+ ast::StatementKind::Kill => crate::Statement::Kill,
+ ast::StatementKind::Call {
+ ref function,
+ ref arguments,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let _ = self.call(
+ stmt.span,
+ function,
+ arguments,
+ ctx.as_expression(block, &mut emitter),
+ )?;
+ block.extend(emitter.finish(ctx.naga_expressions));
+ return Ok(());
+ }
+ ast::StatementKind::Assign { target, op, value } => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let expr =
+ self.expression_for_reference(target, ctx.as_expression(block, &mut emitter))?;
+ let mut value = self.expression(value, ctx.as_expression(block, &mut emitter))?;
+
+ if !expr.is_reference {
+ let ty = ctx.invalid_assignment_type(expr.handle);
+
+ return Err(Error::InvalidAssignment {
+ span: ctx.ast_expressions.get_span(target),
+ ty,
+ });
+ }
+
+ let value = match op {
+ Some(op) => {
+ let mut ctx = ctx.as_expression(block, &mut emitter);
+ let mut left = ctx.apply_load_rule(expr);
+ ctx.binary_op_splat(op, &mut left, &mut value)?;
+ ctx.naga_expressions.append(
+ crate::Expression::Binary {
+ op,
+ left,
+ right: value,
+ },
+ stmt.span,
+ )
+ }
+ None => value,
+ };
+ block.extend(emitter.finish(ctx.naga_expressions));
+
+ crate::Statement::Store {
+ pointer: expr.handle,
+ value,
+ }
+ }
+ ast::StatementKind::Increment(value) | ast::StatementKind::Decrement(value) => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let op = match stmt.kind {
+ ast::StatementKind::Increment(_) => crate::BinaryOperator::Add,
+ ast::StatementKind::Decrement(_) => crate::BinaryOperator::Subtract,
+ _ => unreachable!(),
+ };
+
+ let value_span = ctx.ast_expressions.get_span(value);
+ let reference =
+ self.expression_for_reference(value, ctx.as_expression(block, &mut emitter))?;
+ let mut ectx = ctx.as_expression(block, &mut emitter);
+
+ ectx.grow_types(reference.handle)?;
+ let (kind, width) = match *ectx.resolved_inner(reference.handle) {
+ crate::TypeInner::ValuePointer {
+ size: None,
+ kind,
+ width,
+ ..
+ } => (kind, width),
+ crate::TypeInner::Pointer { base, .. } => match ectx.module.types[base].inner {
+ crate::TypeInner::Scalar { kind, width } => (kind, width),
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ },
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ };
+ let constant_inner = crate::ConstantInner::Scalar {
+ width,
+ value: match kind {
+ crate::ScalarKind::Sint => crate::ScalarValue::Sint(1),
+ crate::ScalarKind::Uint => crate::ScalarValue::Uint(1),
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ },
+ };
+ let constant = ectx.module.constants.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: constant_inner,
+ },
+ Span::UNDEFINED,
+ );
+
+ let left = ectx.naga_expressions.append(
+ crate::Expression::Load {
+ pointer: reference.handle,
+ },
+ value_span,
+ );
+ let right =
+ ectx.interrupt_emitter(crate::Expression::Constant(constant), Span::UNDEFINED);
+ let value = ectx
+ .naga_expressions
+ .append(crate::Expression::Binary { op, left, right }, stmt.span);
+
+ block.extend(emitter.finish(ctx.naga_expressions));
+ crate::Statement::Store {
+ pointer: reference.handle,
+ value,
+ }
+ }
+ ast::StatementKind::Ignore(expr) => {
+ let mut emitter = Emitter::default();
+ emitter.start(ctx.naga_expressions);
+
+ let _ = self.expression(expr, ctx.as_expression(block, &mut emitter))?;
+ block.extend(emitter.finish(ctx.naga_expressions));
+ return Ok(());
+ }
+ };
+
+ block.push(out, stmt.span);
+
+ Ok(())
+ }
+
+ fn expression(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let expr = self.expression_for_reference(expr, ctx.reborrow())?;
+ Ok(ctx.apply_load_rule(expr))
+ }
+
+ fn expression_for_reference(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<TypedExpression, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let expr = &ctx.ast_expressions[expr];
+
+ let (expr, is_reference) = match *expr {
+ ast::Expression::Literal(literal) => {
+ let inner = match literal {
+ ast::Literal::Number(Number::F32(f)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Float(f as _),
+ },
+ ast::Literal::Number(Number::I32(i)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Sint(i as _),
+ },
+ ast::Literal::Number(Number::U32(u)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(u as _),
+ },
+ ast::Literal::Number(_) => {
+ unreachable!("got abstract numeric type when not expected");
+ }
+ ast::Literal::Bool(b) => crate::ConstantInner::Scalar {
+ width: 1,
+ value: crate::ScalarValue::Bool(b),
+ },
+ };
+ let handle = ctx.module.constants.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ Span::UNDEFINED,
+ );
+ let handle = ctx.interrupt_emitter(crate::Expression::Constant(handle), span);
+ return Ok(TypedExpression::non_reference(handle));
+ }
+ ast::Expression::Ident(ast::IdentExpr::Local(local)) => {
+ return Ok(ctx.local_table[&local])
+ }
+ ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => {
+ return if let Some(global) = ctx.globals.get(name) {
+ let (expr, is_reference) = match *global {
+ LoweredGlobalDecl::Var(handle) => (
+ crate::Expression::GlobalVariable(handle),
+ ctx.module.global_variables[handle].space
+ != crate::AddressSpace::Handle,
+ ),
+ LoweredGlobalDecl::Const(handle) => {
+ (crate::Expression::Constant(handle), false)
+ }
+ _ => {
+ return Err(Error::Unexpected(span, ExpectedToken::Variable));
+ }
+ };
+
+ let handle = ctx.interrupt_emitter(expr, span);
+ Ok(TypedExpression {
+ handle,
+ is_reference,
+ })
+ } else {
+ Err(Error::UnknownIdent(span, name))
+ }
+ }
+ ast::Expression::Construct {
+ ref ty,
+ ty_span,
+ ref components,
+ } => {
+ let handle = self.construct(span, ty, ty_span, components, ctx.reborrow())?;
+ return Ok(TypedExpression::non_reference(handle));
+ }
+ ast::Expression::Unary { op, expr } => {
+ let expr = self.expression(expr, ctx.reborrow())?;
+ (crate::Expression::Unary { op, expr }, false)
+ }
+ ast::Expression::AddrOf(expr) => {
+ // The `&` operator simply converts a reference to a pointer. And since a
+ // reference is required, the Load Rule is not applied.
+ let expr = self.expression_for_reference(expr, ctx.reborrow())?;
+ if !expr.is_reference {
+ return Err(Error::NotReference("the operand of the `&` operator", span));
+ }
+
+ // No code is generated. We just declare the pointer a reference now.
+ return Ok(TypedExpression {
+ is_reference: false,
+ ..expr
+ });
+ }
+ ast::Expression::Deref(expr) => {
+ // The pointer we dereference must be loaded.
+ let pointer = self.expression(expr, ctx.reborrow())?;
+
+ ctx.grow_types(pointer)?;
+ if ctx.resolved_inner(pointer).pointer_space().is_none() {
+ return Err(Error::NotPointer(span));
+ }
+
+ return Ok(TypedExpression {
+ handle: pointer,
+ is_reference: true,
+ });
+ }
+ ast::Expression::Binary { op, left, right } => {
+ // Load both operands.
+ let mut left = self.expression(left, ctx.reborrow())?;
+ let mut right = self.expression(right, ctx.reborrow())?;
+ ctx.binary_op_splat(op, &mut left, &mut right)?;
+ (crate::Expression::Binary { op, left, right }, false)
+ }
+ ast::Expression::Call {
+ ref function,
+ ref arguments,
+ } => {
+ let handle = self
+ .call(span, function, arguments, ctx.reborrow())?
+ .ok_or(Error::FunctionReturnsVoid(function.span))?;
+ return Ok(TypedExpression::non_reference(handle));
+ }
+ ast::Expression::Index { base, index } => {
+ let expr = self.expression_for_reference(base, ctx.reborrow())?;
+ let index = self.expression(index, ctx.reborrow())?;
+
+ ctx.grow_types(expr.handle)?;
+ let wgsl_pointer =
+ ctx.resolved_inner(expr.handle).pointer_space().is_some() && !expr.is_reference;
+
+ if wgsl_pointer {
+ return Err(Error::Pointer(
+ "the value indexed by a `[]` subscripting expression",
+ ctx.ast_expressions.get_span(base),
+ ));
+ }
+
+ if let crate::Expression::Constant(constant) = ctx.naga_expressions[index] {
+ let span = ctx.naga_expressions.get_span(index);
+ let index = match ctx.module.constants[constant].inner {
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(int),
+ ..
+ } => u32::try_from(int).map_err(|_| Error::BadU32Constant(span)),
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Sint(int),
+ ..
+ } => u32::try_from(int).map_err(|_| Error::BadU32Constant(span)),
+ _ => Err(Error::BadU32Constant(span)),
+ }?;
+
+ (
+ crate::Expression::AccessIndex {
+ base: expr.handle,
+ index,
+ },
+ expr.is_reference,
+ )
+ } else {
+ (
+ crate::Expression::Access {
+ base: expr.handle,
+ index,
+ },
+ expr.is_reference,
+ )
+ }
+ }
+ ast::Expression::Member { base, ref field } => {
+ let TypedExpression {
+ handle,
+ is_reference,
+ } = self.expression_for_reference(base, ctx.reborrow())?;
+
+ ctx.grow_types(handle)?;
+ let temp_inner;
+ let (composite, wgsl_pointer) = match *ctx.resolved_inner(handle) {
+ crate::TypeInner::Pointer { base, .. } => {
+ (&ctx.module.types[base].inner, !is_reference)
+ }
+ crate::TypeInner::ValuePointer {
+ size: None,
+ kind,
+ width,
+ ..
+ } => {
+ temp_inner = crate::TypeInner::Scalar { kind, width };
+ (&temp_inner, !is_reference)
+ }
+ crate::TypeInner::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ ..
+ } => {
+ temp_inner = crate::TypeInner::Vector { size, kind, width };
+ (&temp_inner, !is_reference)
+ }
+ ref other => (other, false),
+ };
+
+ if wgsl_pointer {
+ return Err(Error::Pointer(
+ "the value accessed by a `.member` expression",
+ ctx.ast_expressions.get_span(base),
+ ));
+ }
+
+ let access = match *composite {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let index = members
+ .iter()
+ .position(|m| m.name.as_deref() == Some(field.name))
+ .ok_or(Error::BadAccessor(field.span))?
+ as u32;
+
+ (
+ crate::Expression::AccessIndex {
+ base: handle,
+ index,
+ },
+ is_reference,
+ )
+ }
+ crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => {
+ match Composition::make(field.name, field.span)? {
+ Composition::Multi(size, pattern) => {
+ let vector = ctx.apply_load_rule(TypedExpression {
+ handle,
+ is_reference,
+ });
+
+ (
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ },
+ false,
+ )
+ }
+ Composition::Single(index) => (
+ crate::Expression::AccessIndex {
+ base: handle,
+ index,
+ },
+ is_reference,
+ ),
+ }
+ }
+ _ => return Err(Error::BadAccessor(field.span)),
+ };
+
+ access
+ }
+ ast::Expression::Bitcast { expr, to, ty_span } => {
+ let expr = self.expression(expr, ctx.reborrow())?;
+ let to_resolved = self.resolve_ast_type(to, ctx.as_output())?;
+
+ let kind = match ctx.module.types[to_resolved].inner {
+ crate::TypeInner::Scalar { kind, .. } => kind,
+ crate::TypeInner::Vector { kind, .. } => kind,
+ _ => {
+ let ty = &ctx.typifier[expr];
+ return Err(Error::BadTypeCast {
+ from_type: ctx.format_type_resolution(ty),
+ span: ty_span,
+ to_type: ctx.format_type(to_resolved),
+ });
+ }
+ };
+
+ (
+ crate::Expression::As {
+ expr,
+ kind,
+ convert: None,
+ },
+ false,
+ )
+ }
+ };
+
+ let handle = ctx.naga_expressions.append(expr, span);
+ Ok(TypedExpression {
+ handle,
+ is_reference,
+ })
+ }
+
+ /// Generate Naga IR for call expressions and statements, and type
+ /// constructor expressions.
+ ///
+ /// The "function" being called is simply an `Ident` that we know refers to
+ /// some module-scope definition.
+ ///
+ /// - If it is the name of a type, then the expression is a type constructor
+ /// expression: either constructing a value from components, a conversion
+ /// expression, or a zero value expression.
+ ///
+ /// - If it is the name of a function, then we're generating a [`Call`]
+ /// statement. We may be in the midst of generating code for an
+ /// expression, in which case we must generate an `Emit` statement to
+ /// force evaluation of the IR expressions we've generated so far, add the
+ /// `Call` statement to the current block, and then resume generating
+ /// expressions.
+ ///
+ /// [`Call`]: crate::Statement::Call
+ fn call(
+ &mut self,
+ span: Span,
+ function: &ast::Ident<'source>,
+ arguments: &[Handle<ast::Expression<'source>>],
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Option<Handle<crate::Expression>>, Error<'source>> {
+ match ctx.globals.get(function.name) {
+ Some(&LoweredGlobalDecl::Type(ty)) => {
+ let handle = self.construct(
+ span,
+ &ast::ConstructorType::Type(ty),
+ function.span,
+ arguments,
+ ctx.reborrow(),
+ )?;
+ Ok(Some(handle))
+ }
+ Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
+ Err(Error::Unexpected(function.span, ExpectedToken::Function))
+ }
+ Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
+ Some(&LoweredGlobalDecl::Function(function)) => {
+ let arguments = arguments
+ .iter()
+ .map(|&arg| self.expression(arg, ctx.reborrow()))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
+ let result = ctx.module.functions[function].result.is_some().then(|| {
+ ctx.naga_expressions
+ .append(crate::Expression::CallResult(function), span)
+ });
+ ctx.emitter.start(ctx.naga_expressions);
+ ctx.block.push(
+ crate::Statement::Call {
+ function,
+ arguments,
+ result,
+ },
+ span,
+ );
+
+ Ok(result)
+ }
+ None => {
+ let span = function.span;
+ let expr = if let Some(fun) = conv::map_relational_fun(function.name) {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let argument = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::Relational { fun, argument }
+ } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let expr = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::Derivative { axis, ctrl, expr }
+ } else if let Some(fun) = conv::map_standard_fun(function.name) {
+ let expected = fun.argument_count() as _;
+ let mut args = ctx.prepare_args(arguments, expected, span);
+
+ let arg = self.expression(args.next()?, ctx.reborrow())?;
+ let arg1 = args
+ .next()
+ .map(|x| self.expression(x, ctx.reborrow()))
+ .ok()
+ .transpose()?;
+ let arg2 = args
+ .next()
+ .map(|x| self.expression(x, ctx.reborrow()))
+ .ok()
+ .transpose()?;
+ let arg3 = args
+ .next()
+ .map(|x| self.expression(x, ctx.reborrow()))
+ .ok()
+ .transpose()?;
+
+ args.finish()?;
+
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ }
+ } else if let Some(fun) = Texture::map(function.name) {
+ self.texture_sample_helper(fun, arguments, span, ctx.reborrow())?
+ } else {
+ match function.name {
+ "select" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let reject = self.expression(args.next()?, ctx.reborrow())?;
+ let accept = self.expression(args.next()?, ctx.reborrow())?;
+ let condition = self.expression(args.next()?, ctx.reborrow())?;
+
+ args.finish()?;
+
+ crate::Expression::Select {
+ reject,
+ accept,
+ condition,
+ }
+ }
+ "arrayLength" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let expr = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::ArrayLength(expr)
+ }
+ "atomicLoad" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::Load { pointer }
+ }
+ "atomicStore" => {
+ let mut args = ctx.prepare_args(arguments, 2, span);
+ let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?;
+ let value = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
+ ctx.emitter.start(ctx.naga_expressions);
+ ctx.block
+ .push(crate::Statement::Store { pointer, value }, span);
+ return Ok(None);
+ }
+ "atomicAdd" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Add,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicSub" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Subtract,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicAnd" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::And,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicOr" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::InclusiveOr,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicXor" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::ExclusiveOr,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicMin" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Min,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicMax" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Max,
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicExchange" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Exchange { compare: None },
+ arguments,
+ ctx.reborrow(),
+ )?))
+ }
+ "atomicCompareExchangeWeak" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?;
+
+ let compare = self.expression(args.next()?, ctx.reborrow())?;
+
+ let value = args.next()?;
+ let value_span = ctx.ast_expressions.get_span(value);
+ let value = self.expression(value, ctx.reborrow())?;
+ ctx.grow_types(value)?;
+
+ args.finish()?;
+
+ let expression = match *ctx.resolved_inner(value) {
+ crate::TypeInner::Scalar { kind, width } => {
+ crate::Expression::AtomicResult {
+ //TODO: cache this to avoid generating duplicate types
+ ty: ctx
+ .module
+ .generate_atomic_compare_exchange_result(kind, width),
+ comparison: true,
+ }
+ }
+ _ => return Err(Error::InvalidAtomicOperandType(value_span)),
+ };
+
+ let result = ctx.interrupt_emitter(expression, span);
+ ctx.block.push(
+ crate::Statement::Atomic {
+ pointer,
+ fun: crate::AtomicFunction::Exchange {
+ compare: Some(compare),
+ },
+ value,
+ result,
+ },
+ span,
+ );
+ return Ok(Some(result));
+ }
+ "storageBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ ctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::STORAGE), span);
+ return Ok(None);
+ }
+ "workgroupBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ ctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
+ return Ok(None);
+ }
+ "textureStore" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let image = args.next()?;
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = self.expression(image, ctx.reborrow())?;
+
+ let coordinate = self.expression(args.next()?, ctx.reborrow())?;
+
+ let (_, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| self.expression(args.next()?, ctx.reborrow()))
+ .transpose()?;
+
+ let value = self.expression(args.next()?, ctx.reborrow())?;
+
+ args.finish()?;
+
+ ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
+ ctx.emitter.start(ctx.naga_expressions);
+ let stmt = crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ };
+ ctx.block.push(stmt, span);
+ return Ok(None);
+ }
+ "textureLoad" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let image = args.next()?;
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = self.expression(image, ctx.reborrow())?;
+
+ let coordinate = self.expression(args.next()?, ctx.reborrow())?;
+
+ let (class, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| self.expression(args.next()?, ctx.reborrow()))
+ .transpose()?;
+
+ let level = class
+ .is_mipmapped()
+ .then(|| self.expression(args.next()?, ctx.reborrow()))
+ .transpose()?;
+
+ let sample = class
+ .is_multisampled()
+ .then(|| self.expression(args.next()?, ctx.reborrow()))
+ .transpose()?;
+
+ args.finish()?;
+
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ level,
+ sample,
+ }
+ }
+ "textureDimensions" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx.reborrow())?;
+ let level = args
+ .next()
+ .map(|arg| self.expression(arg, ctx.reborrow()))
+ .ok()
+ .transpose()?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::Size { level },
+ }
+ }
+ "textureNumLevels" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumLevels,
+ }
+ }
+ "textureNumLayers" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumLayers,
+ }
+ }
+ "textureNumSamples" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumSamples,
+ }
+ }
+ "rayQueryInitialize" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+ let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?;
+ let acceleration_structure =
+ self.expression(args.next()?, ctx.reborrow())?;
+ let descriptor = self.expression(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ let _ = ctx.module.generate_ray_desc_type();
+ let fun = crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ };
+
+ ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
+ ctx.emitter.start(ctx.naga_expressions);
+ ctx.block
+ .push(crate::Statement::RayQuery { query, fun }, span);
+ return Ok(None);
+ }
+ "rayQueryProceed" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ ctx.block.extend(ctx.emitter.finish(ctx.naga_expressions));
+ let result = ctx
+ .naga_expressions
+ .append(crate::Expression::RayQueryProceedResult, span);
+ let fun = crate::RayQueryFunction::Proceed { result };
+
+ ctx.emitter.start(ctx.naga_expressions);
+ ctx.block
+ .push(crate::Statement::RayQuery { query, fun }, span);
+ return Ok(Some(result));
+ }
+ "rayQueryGetCommittedIntersection" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?;
+ args.finish()?;
+
+ let _ = ctx.module.generate_ray_intersection_type();
+
+ crate::Expression::RayQueryGetIntersection {
+ query,
+ committed: true,
+ }
+ }
+ "RayDesc" => {
+ let ty = ctx.module.generate_ray_desc_type();
+ let handle = self.construct(
+ span,
+ &ast::ConstructorType::Type(ty),
+ function.span,
+ arguments,
+ ctx.reborrow(),
+ )?;
+ return Ok(Some(handle));
+ }
+ _ => return Err(Error::UnknownIdent(function.span, function.name)),
+ }
+ };
+
+ let expr = ctx.naga_expressions.append(expr, span);
+ Ok(Some(expr))
+ }
+ }
+ }
+
+ fn atomic_pointer(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let pointer = self.expression(expr, ctx.reborrow())?;
+
+ ctx.grow_types(pointer)?;
+ match *ctx.resolved_inner(pointer) {
+ crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Ok(pointer),
+ ref other => {
+ log::error!("Pointer type to {:?} passed to atomic op", other);
+ Err(Error::InvalidAtomicPointer(span))
+ }
+ },
+ ref other => {
+ log::error!("Type {:?} passed to atomic op", other);
+ Err(Error::InvalidAtomicPointer(span))
+ }
+ }
+ }
+
+ fn atomic_helper(
+ &mut self,
+ span: Span,
+ fun: crate::AtomicFunction,
+ args: &[Handle<ast::Expression<'source>>],
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(args, 2, span);
+
+ let pointer = self.atomic_pointer(args.next()?, ctx.reborrow())?;
+
+ let value = args.next()?;
+ let value = self.expression(value, ctx.reborrow())?;
+ let ty = ctx.register_type(value)?;
+
+ args.finish()?;
+
+ let result = ctx.interrupt_emitter(
+ crate::Expression::AtomicResult {
+ ty,
+ comparison: false,
+ },
+ span,
+ );
+ ctx.block.push(
+ crate::Statement::Atomic {
+ pointer,
+ fun,
+ value,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
+ fn texture_sample_helper(
+ &mut self,
+ fun: Texture,
+ args: &[Handle<ast::Expression<'source>>],
+ span: Span,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<crate::Expression, Error<'source>> {
+ let mut args = ctx.prepare_args(args, fun.min_argument_count(), span);
+
+ let (image, gather) = match fun {
+ Texture::Gather => {
+ let image_or_component = args.next()?;
+ match self.gather_component(image_or_component, ctx.reborrow())? {
+ Some(component) => {
+ let image = args.next()?;
+ (image, Some(component))
+ }
+ None => (image_or_component, Some(crate::SwizzleComponent::X)),
+ }
+ }
+ Texture::GatherCompare => {
+ let image = args.next()?;
+ (image, Some(crate::SwizzleComponent::X))
+ }
+
+ _ => {
+ let image = args.next()?;
+ (image, None)
+ }
+ };
+
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = self.expression(image, ctx.reborrow())?;
+
+ let sampler = self.expression(args.next()?, ctx.reborrow())?;
+
+ let coordinate = self.expression(args.next()?, ctx.reborrow())?;
+
+ let (_, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| self.expression(args.next()?, ctx.reborrow()))
+ .transpose()?;
+
+ let (level, depth_ref) = match fun {
+ Texture::Gather => (crate::SampleLevel::Zero, None),
+ Texture::GatherCompare => {
+ let reference = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Zero, Some(reference))
+ }
+
+ Texture::Sample => (crate::SampleLevel::Auto, None),
+ Texture::SampleBias => {
+ let bias = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Bias(bias), None)
+ }
+ Texture::SampleCompare => {
+ let reference = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Auto, Some(reference))
+ }
+ Texture::SampleCompareLevel => {
+ let reference = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Zero, Some(reference))
+ }
+ Texture::SampleGrad => {
+ let x = self.expression(args.next()?, ctx.reborrow())?;
+ let y = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Gradient { x, y }, None)
+ }
+ Texture::SampleLevel => {
+ let level = self.expression(args.next()?, ctx.reborrow())?;
+ (crate::SampleLevel::Exact(level), None)
+ }
+ };
+
+ let offset = args
+ .next()
+ .map(|arg| self.constant(arg, ctx.as_output()))
+ .ok()
+ .transpose()?;
+
+ args.finish()?;
+
+ Ok(crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ })
+ }
+
+ fn gather_component(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Option<crate::SwizzleComponent>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+
+ let constant = match self.constant_inner(expr, ctx.as_output()).ok() {
+ Some(ConstantOrInner::Constant(c)) => ctx.module.constants[c].inner.clone(),
+ Some(ConstantOrInner::Inner(inner)) => inner,
+ None => return Ok(None),
+ };
+
+ let int = match constant {
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Sint(i),
+ ..
+ } if i >= 0 => i as u64,
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(i),
+ ..
+ } => i,
+ _ => {
+ return Err(Error::InvalidGatherComponent(span));
+ }
+ };
+
+ crate::SwizzleComponent::XYZW
+ .get(int as usize)
+ .copied()
+ .map(Some)
+ .ok_or(Error::InvalidGatherComponent(span))
+ }
+
+ fn r#struct(
+ &mut self,
+ s: &ast::Struct<'source>,
+ span: Span,
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ let mut offset = 0;
+ let mut struct_alignment = Alignment::ONE;
+ let mut members = Vec::with_capacity(s.members.len());
+
+ for member in s.members.iter() {
+ let ty = self.resolve_ast_type(member.ty, ctx.reborrow())?;
+
+ self.layouter
+ .update(&ctx.module.types, &ctx.module.constants)
+ .unwrap();
+
+ let member_min_size = self.layouter[ty].size;
+ let member_min_alignment = self.layouter[ty].alignment;
+
+ let member_size = if let Some((size, span)) = member.size {
+ if size < member_min_size {
+ return Err(Error::SizeAttributeTooLow(span, member_min_size));
+ } else {
+ size
+ }
+ } else {
+ member_min_size
+ };
+
+ let member_alignment = if let Some((align, span)) = member.align {
+ if let Some(alignment) = Alignment::new(align) {
+ if alignment < member_min_alignment {
+ return Err(Error::AlignAttributeTooLow(span, member_min_alignment));
+ } else {
+ alignment
+ }
+ } else {
+ return Err(Error::NonPowerOfTwoAlignAttribute(span));
+ }
+ } else {
+ member_min_alignment
+ };
+
+ let binding = self.interpolate_default(&member.binding, ty, ctx.reborrow());
+
+ offset = member_alignment.round_up(offset);
+ struct_alignment = struct_alignment.max(member_alignment);
+
+ members.push(crate::StructMember {
+ name: Some(member.name.name.to_owned()),
+ ty,
+ binding,
+ offset,
+ });
+
+ offset += member_size;
+ }
+
+ let size = struct_alignment.round_up(offset);
+ let inner = crate::TypeInner::Struct {
+ members,
+ span: size,
+ };
+
+ let handle = ctx.module.types.insert(
+ crate::Type {
+ name: Some(s.name.name.to_string()),
+ inner,
+ },
+ span,
+ );
+ Ok(handle)
+ }
+
+ /// Return a Naga `Handle<Type>` representing the front-end type `handle`.
+ fn resolve_ast_type(
+ &mut self,
+ handle: Handle<ast::Type<'source>>,
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ let inner = match ctx.types[handle] {
+ ast::Type::Scalar { kind, width } => crate::TypeInner::Scalar { kind, width },
+ ast::Type::Vector { size, kind, width } => {
+ crate::TypeInner::Vector { size, kind, width }
+ }
+ ast::Type::Matrix {
+ rows,
+ columns,
+ width,
+ } => crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ },
+ ast::Type::Atomic { kind, width } => crate::TypeInner::Atomic { kind, width },
+ ast::Type::Pointer { base, space } => {
+ let base = self.resolve_ast_type(base, ctx.reborrow())?;
+ crate::TypeInner::Pointer { base, space }
+ }
+ ast::Type::Array { base, size } => {
+ let base = self.resolve_ast_type(base, ctx.reborrow())?;
+ self.layouter
+ .update(&ctx.module.types, &ctx.module.constants)
+ .unwrap();
+
+ crate::TypeInner::Array {
+ base,
+ size: match size {
+ ast::ArraySize::Constant(constant) => {
+ let constant = self.constant(constant, ctx.reborrow())?;
+ crate::ArraySize::Constant(constant)
+ }
+ ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
+ },
+ stride: self.layouter[base].to_stride(),
+ }
+ }
+ ast::Type::Image {
+ dim,
+ arrayed,
+ class,
+ } => crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ },
+ ast::Type::Sampler { comparison } => crate::TypeInner::Sampler { comparison },
+ ast::Type::AccelerationStructure => crate::TypeInner::AccelerationStructure,
+ ast::Type::RayQuery => crate::TypeInner::RayQuery,
+ ast::Type::BindingArray { base, size } => {
+ let base = self.resolve_ast_type(base, ctx.reborrow())?;
+
+ crate::TypeInner::BindingArray {
+ base,
+ size: match size {
+ ast::ArraySize::Constant(constant) => {
+ let constant = self.constant(constant, ctx.reborrow())?;
+ crate::ArraySize::Constant(constant)
+ }
+ ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
+ },
+ }
+ }
+ ast::Type::RayDesc => {
+ return Ok(ctx.module.generate_ray_desc_type());
+ }
+ ast::Type::RayIntersection => {
+ return Ok(ctx.module.generate_ray_intersection_type());
+ }
+ ast::Type::User(ref ident) => {
+ return match ctx.globals.get(ident.name) {
+ Some(&LoweredGlobalDecl::Type(handle)) => Ok(handle),
+ Some(_) => Err(Error::Unexpected(ident.span, ExpectedToken::Type)),
+ None => Err(Error::UnknownType(ident.span)),
+ }
+ }
+ };
+
+ Ok(ctx.ensure_type_exists(inner))
+ }
+
+ /// Find or construct a Naga [`Constant`] whose value is `expr`.
+ ///
+ /// The `ctx` indicates the Naga [`Module`] to which we should add
+ /// new `Constant`s or [`Type`]s as needed.
+ ///
+ /// [`Module`]: crate::Module
+ /// [`Constant`]: crate::Constant
+ /// [`Type`]: crate::Type
+ fn constant(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Constant>, Error<'source>> {
+ let inner = match self.constant_inner(expr, ctx.reborrow())? {
+ ConstantOrInner::Constant(c) => return Ok(c),
+ ConstantOrInner::Inner(inner) => inner,
+ };
+
+ let c = ctx.module.constants.fetch_or_append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ },
+ Span::UNDEFINED,
+ );
+ Ok(c)
+ }
+
+ fn constant_inner(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: OutputContext<'source, '_, '_>,
+ ) -> Result<ConstantOrInner, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let inner = match ctx.ast_expressions[expr] {
+ ast::Expression::Literal(literal) => match literal {
+ ast::Literal::Number(Number::F32(f)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Float(f as _),
+ },
+ ast::Literal::Number(Number::I32(i)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Sint(i as _),
+ },
+ ast::Literal::Number(Number::U32(u)) => crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(u as _),
+ },
+ ast::Literal::Number(_) => {
+ unreachable!("got abstract numeric type when not expected");
+ }
+ ast::Literal::Bool(b) => crate::ConstantInner::Scalar {
+ width: 1,
+ value: crate::ScalarValue::Bool(b),
+ },
+ },
+ ast::Expression::Ident(ast::IdentExpr::Local(_)) => {
+ return Err(Error::Unexpected(span, ExpectedToken::Constant))
+ }
+ ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => {
+ return if let Some(global) = ctx.globals.get(name) {
+ match *global {
+ LoweredGlobalDecl::Const(handle) => Ok(ConstantOrInner::Constant(handle)),
+ _ => Err(Error::Unexpected(span, ExpectedToken::Constant)),
+ }
+ } else {
+ Err(Error::UnknownIdent(span, name))
+ }
+ }
+ ast::Expression::Construct {
+ ref ty,
+ ref components,
+ ..
+ } => self.const_construct(span, ty, components, ctx.reborrow())?,
+ ast::Expression::Call {
+ ref function,
+ ref arguments,
+ } => match ctx.globals.get(function.name) {
+ Some(&LoweredGlobalDecl::Type(ty)) => self.const_construct(
+ span,
+ &ast::ConstructorType::Type(ty),
+ arguments,
+ ctx.reborrow(),
+ )?,
+ Some(_) => return Err(Error::ConstExprUnsupported(span)),
+ None => return Err(Error::UnknownIdent(function.span, function.name)),
+ },
+ _ => return Err(Error::ConstExprUnsupported(span)),
+ };
+
+ Ok(ConstantOrInner::Inner(inner))
+ }
+
+ fn interpolate_default(
+ &mut self,
+ binding: &Option<crate::Binding>,
+ ty: Handle<crate::Type>,
+ ctx: OutputContext<'source, '_, '_>,
+ ) -> Option<crate::Binding> {
+ let mut binding = binding.clone();
+ if let Some(ref mut binding) = binding {
+ binding.apply_default_interpolation(&ctx.module.types[ty].inner);
+ }
+
+ binding
+ }
+
+ fn ray_query_pointer(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ mut ctx: ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let pointer = self.expression(expr, ctx.reborrow())?;
+
+ ctx.grow_types(pointer)?;
+ match *ctx.resolved_inner(pointer) {
+ crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
+ crate::TypeInner::RayQuery => Ok(pointer),
+ ref other => {
+ log::error!("Pointer type to {:?} passed to ray query op", other);
+ Err(Error::InvalidRayQueryPointer(span))
+ }
+ },
+ ref other => {
+ log::error!("Type {:?} passed to ray query op", other);
+ Err(Error::InvalidRayQueryPointer(span))
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs
new file mode 100644
index 0000000000..eb21fae6c9
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/mod.rs
@@ -0,0 +1,340 @@
+/*!
+Frontend for [WGSL][wgsl] (WebGPU Shading Language).
+
+[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
+*/
+
+mod error;
+mod index;
+mod lower;
+mod parse;
+#[cfg(test)]
+mod tests;
+
+use crate::arena::{Arena, UniqueArena};
+
+use crate::front::wgsl::error::Error;
+use crate::front::wgsl::parse::Parser;
+use thiserror::Error;
+
+pub use crate::front::wgsl::error::ParseError;
+use crate::front::wgsl::lower::Lowerer;
+
+pub struct Frontend {
+ parser: Parser,
+}
+
+impl Frontend {
+ pub const fn new() -> Self {
+ Self {
+ parser: Parser::new(),
+ }
+ }
+
+ pub fn parse(&mut self, source: &str) -> Result<crate::Module, ParseError> {
+ self.inner(source).map_err(|x| x.as_parse_error(source))
+ }
+
+ fn inner<'a>(&mut self, source: &'a str) -> Result<crate::Module, Error<'a>> {
+ let tu = self.parser.parse(source)?;
+ let index = index::Index::generate(&tu)?;
+ let module = Lowerer::new(&index).lower(&tu)?;
+
+ Ok(module)
+ }
+}
+
+pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> {
+ Frontend::new().parse(source)
+}
+
+impl crate::StorageFormat {
+ const fn to_wgsl(self) -> &'static str {
+ use crate::StorageFormat as Sf;
+ match self {
+ Sf::R8Unorm => "r8unorm",
+ Sf::R8Snorm => "r8snorm",
+ Sf::R8Uint => "r8uint",
+ Sf::R8Sint => "r8sint",
+ Sf::R16Uint => "r16uint",
+ Sf::R16Sint => "r16sint",
+ Sf::R16Float => "r16float",
+ Sf::Rg8Unorm => "rg8unorm",
+ Sf::Rg8Snorm => "rg8snorm",
+ Sf::Rg8Uint => "rg8uint",
+ Sf::Rg8Sint => "rg8sint",
+ Sf::R32Uint => "r32uint",
+ Sf::R32Sint => "r32sint",
+ Sf::R32Float => "r32float",
+ Sf::Rg16Uint => "rg16uint",
+ Sf::Rg16Sint => "rg16sint",
+ Sf::Rg16Float => "rg16float",
+ Sf::Rgba8Unorm => "rgba8unorm",
+ Sf::Rgba8Snorm => "rgba8snorm",
+ Sf::Rgba8Uint => "rgba8uint",
+ Sf::Rgba8Sint => "rgba8sint",
+ Sf::Rgb10a2Unorm => "rgb10a2unorm",
+ Sf::Rg11b10Float => "rg11b10float",
+ Sf::Rg32Uint => "rg32uint",
+ Sf::Rg32Sint => "rg32sint",
+ Sf::Rg32Float => "rg32float",
+ Sf::Rgba16Uint => "rgba16uint",
+ Sf::Rgba16Sint => "rgba16sint",
+ Sf::Rgba16Float => "rgba16float",
+ Sf::Rgba32Uint => "rgba32uint",
+ Sf::Rgba32Sint => "rgba32sint",
+ Sf::Rgba32Float => "rgba32float",
+ Sf::R16Unorm => "r16unorm",
+ Sf::R16Snorm => "r16snorm",
+ Sf::Rg16Unorm => "rg16unorm",
+ Sf::Rg16Snorm => "rg16snorm",
+ Sf::Rgba16Unorm => "rgba16unorm",
+ Sf::Rgba16Snorm => "rgba16snorm",
+ }
+ }
+}
+
+impl crate::TypeInner {
+ /// Formats the type as it is written in wgsl.
+ ///
+ /// For example `vec3<f32>`.
+ ///
+ /// Note: The names of a `TypeInner::Struct` is not known. Therefore this method will simply return "struct" for them.
+ fn to_wgsl(
+ &self,
+ types: &UniqueArena<crate::Type>,
+ constants: &Arena<crate::Constant>,
+ ) -> String {
+ use crate::TypeInner as Ti;
+
+ match *self {
+ Ti::Scalar { kind, width } => kind.to_wgsl(width),
+ Ti::Vector { size, kind, width } => {
+ format!("vec{}<{}>", size as u32, kind.to_wgsl(width))
+ }
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ format!(
+ "mat{}x{}<{}>",
+ columns as u32,
+ rows as u32,
+ crate::ScalarKind::Float.to_wgsl(width),
+ )
+ }
+ Ti::Atomic { kind, width } => {
+ format!("atomic<{}>", kind.to_wgsl(width))
+ }
+ Ti::Pointer { base, .. } => {
+ let base = &types[base];
+ let name = base.name.as_deref().unwrap_or("unknown");
+ format!("ptr<{name}>")
+ }
+ Ti::ValuePointer { kind, width, .. } => {
+ format!("ptr<{}>", kind.to_wgsl(width))
+ }
+ Ti::Array { base, size, .. } => {
+ let member_type = &types[base];
+ let base = member_type.name.as_deref().unwrap_or("unknown");
+ match size {
+ crate::ArraySize::Constant(size) => {
+ let constant = &constants[size];
+ let size = constant
+ .name
+ .clone()
+ .unwrap_or_else(|| match constant.inner {
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Uint(size),
+ ..
+ } => size.to_string(),
+ crate::ConstantInner::Scalar {
+ value: crate::ScalarValue::Sint(size),
+ ..
+ } => size.to_string(),
+ _ => "?".to_string(),
+ });
+ format!("array<{base}, {size}>")
+ }
+ crate::ArraySize::Dynamic => format!("array<{base}>"),
+ }
+ }
+ Ti::Struct { .. } => {
+ // TODO: Actually output the struct?
+ "struct".to_string()
+ }
+ Ti::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let dim_suffix = match dim {
+ crate::ImageDimension::D1 => "_1d",
+ crate::ImageDimension::D2 => "_2d",
+ crate::ImageDimension::D3 => "_3d",
+ crate::ImageDimension::Cube => "_cube",
+ };
+ let array_suffix = if arrayed { "_array" } else { "" };
+
+ let class_suffix = match class {
+ crate::ImageClass::Sampled { multi: true, .. } => "_multisampled",
+ crate::ImageClass::Depth { multi: false } => "_depth",
+ crate::ImageClass::Depth { multi: true } => "_depth_multisampled",
+ crate::ImageClass::Sampled { multi: false, .. }
+ | crate::ImageClass::Storage { .. } => "",
+ };
+
+ let type_in_brackets = match class {
+ crate::ImageClass::Sampled { kind, .. } => {
+ // Note: The only valid widths are 4 bytes wide.
+ // The lexer has already verified this, so we can safely assume it here.
+ // https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ let element_type = kind.to_wgsl(4);
+ format!("<{element_type}>")
+ }
+ crate::ImageClass::Depth { multi: _ } => String::new(),
+ crate::ImageClass::Storage { format, access } => {
+ if access.contains(crate::StorageAccess::STORE) {
+ format!("<{},write>", format.to_wgsl())
+ } else {
+ format!("<{}>", format.to_wgsl())
+ }
+ }
+ };
+
+ format!("texture{class_suffix}{dim_suffix}{array_suffix}{type_in_brackets}")
+ }
+ Ti::Sampler { .. } => "sampler".to_string(),
+ Ti::AccelerationStructure => "acceleration_structure".to_string(),
+ Ti::RayQuery => "ray_query".to_string(),
+ Ti::BindingArray { base, size, .. } => {
+ let member_type = &types[base];
+ let base = member_type.name.as_deref().unwrap_or("unknown");
+ match size {
+ crate::ArraySize::Constant(size) => {
+ let size = constants[size].name.as_deref().unwrap_or("unknown");
+ format!("binding_array<{base}, {size}>")
+ }
+ crate::ArraySize::Dynamic => format!("binding_array<{base}>"),
+ }
+ }
+ }
+ }
+}
+
+mod type_inner_tests {
+ #[test]
+ fn to_wgsl() {
+ let mut types = crate::UniqueArena::new();
+ let mut constants = crate::Arena::new();
+ let c = constants.append(
+ crate::Constant {
+ name: Some("C".to_string()),
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(32),
+ },
+ },
+ Default::default(),
+ );
+
+ let mytype1 = types.insert(
+ crate::Type {
+ name: Some("MyType1".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![],
+ span: 0,
+ },
+ },
+ Default::default(),
+ );
+ let mytype2 = types.insert(
+ crate::Type {
+ name: Some("MyType2".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![],
+ span: 0,
+ },
+ },
+ Default::default(),
+ );
+
+ let array = crate::TypeInner::Array {
+ base: mytype1,
+ stride: 4,
+ size: crate::ArraySize::Constant(c),
+ };
+ assert_eq!(array.to_wgsl(&types, &constants), "array<MyType1, C>");
+
+ let mat = crate::TypeInner::Matrix {
+ rows: crate::VectorSize::Quad,
+ columns: crate::VectorSize::Bi,
+ width: 8,
+ };
+ assert_eq!(mat.to_wgsl(&types, &constants), "mat2x4<f64>");
+
+ let ptr = crate::TypeInner::Pointer {
+ base: mytype2,
+ space: crate::AddressSpace::Storage {
+ access: crate::StorageAccess::default(),
+ },
+ };
+ assert_eq!(ptr.to_wgsl(&types, &constants), "ptr<MyType2>");
+
+ let img1 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Float,
+ multi: true,
+ },
+ };
+ assert_eq!(
+ img1.to_wgsl(&types, &constants),
+ "texture_multisampled_2d<f32>"
+ );
+
+ let img2 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ };
+ assert_eq!(img2.to_wgsl(&types, &constants), "texture_depth_cube_array");
+
+ let img3 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: true },
+ };
+ assert_eq!(
+ img3.to_wgsl(&types, &constants),
+ "texture_depth_multisampled_2d"
+ );
+
+ let array = crate::TypeInner::BindingArray {
+ base: mytype1,
+ size: crate::ArraySize::Constant(c),
+ };
+ assert_eq!(
+ array.to_wgsl(&types, &constants),
+ "binding_array<MyType1, C>"
+ );
+ }
+}
+
+impl crate::ScalarKind {
+ /// Format a scalar kind+width as a type is written in wgsl.
+ ///
+ /// Examples: `f32`, `u64`, `bool`.
+ fn to_wgsl(self, width: u8) -> String {
+ let prefix = match self {
+ crate::ScalarKind::Sint => "i",
+ crate::ScalarKind::Uint => "u",
+ crate::ScalarKind::Float => "f",
+ crate::ScalarKind::Bool => return "bool".to_string(),
+ };
+ format!("{}{}", prefix, width * 8)
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
new file mode 100644
index 0000000000..2a56ac6f80
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
@@ -0,0 +1,486 @@
+use crate::front::wgsl::parse::number::Number;
+use crate::{Arena, FastHashSet, Handle, Span};
+use std::hash::Hash;
+
+#[derive(Debug, Default)]
+pub struct TranslationUnit<'a> {
+ pub decls: Arena<GlobalDecl<'a>>,
+ /// The common expressions arena for the entire translation unit.
+ ///
+ /// All functions, global initializers, array lengths, etc. store their
+ /// expressions here. We apportion these out to individual Naga
+ /// [`Function`]s' expression arenas at lowering time. Keeping them all in a
+ /// single arena simplifies handling of things like array lengths (which are
+ /// effectively global and thus don't clearly belong to any function) and
+ /// initializers (which can appear in both function-local and module-scope
+ /// contexts).
+ ///
+ /// [`Function`]: crate::Function
+ pub expressions: Arena<Expression<'a>>,
+
+ /// Non-user-defined types, like `vec4<f32>` or `array<i32, 10>`.
+ ///
+ /// These are referred to by `Handle<ast::Type<'a>>` values.
+ /// User-defined types are referred to by name until lowering.
+ pub types: Arena<Type<'a>>,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct Ident<'a> {
+ pub name: &'a str,
+ pub span: Span,
+}
+
+#[derive(Debug)]
+pub enum IdentExpr<'a> {
+ Unresolved(&'a str),
+ Local(Handle<Local>),
+}
+
+/// A reference to a module-scope definition or predeclared object.
+///
+/// Each [`GlobalDecl`] holds a set of these values, to be resolved to
+/// specific definitions later. To support de-duplication, `Eq` and
+/// `Hash` on a `Dependency` value consider only the name, not the
+/// source location at which the reference occurs.
+#[derive(Debug)]
+pub struct Dependency<'a> {
+ /// The name referred to.
+ pub ident: &'a str,
+
+ /// The location at which the reference to that name occurs.
+ pub usage: Span,
+}
+
+impl Hash for Dependency<'_> {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.ident.hash(state);
+ }
+}
+
+impl PartialEq for Dependency<'_> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ident == other.ident
+ }
+}
+
+impl Eq for Dependency<'_> {}
+
+/// A module-scope declaration.
+#[derive(Debug)]
+pub struct GlobalDecl<'a> {
+ pub kind: GlobalDeclKind<'a>,
+
+ /// Names of all module-scope or predeclared objects this
+ /// declaration uses.
+ pub dependencies: FastHashSet<Dependency<'a>>,
+}
+
+#[derive(Debug)]
+pub enum GlobalDeclKind<'a> {
+ Fn(Function<'a>),
+ Var(GlobalVariable<'a>),
+ Const(Const<'a>),
+ Struct(Struct<'a>),
+ Type(TypeAlias<'a>),
+}
+
+#[derive(Debug)]
+pub struct FunctionArgument<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<crate::Binding>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub struct FunctionResult<'a> {
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<crate::Binding>,
+}
+
+#[derive(Debug)]
+pub struct EntryPoint {
+ pub stage: crate::ShaderStage,
+ pub early_depth_test: Option<crate::EarlyDepthTest>,
+ pub workgroup_size: [u32; 3],
+}
+
+#[cfg(doc)]
+use crate::front::wgsl::lower::{ExpressionContext, StatementContext};
+
+#[derive(Debug)]
+pub struct Function<'a> {
+ pub entry_point: Option<EntryPoint>,
+ pub name: Ident<'a>,
+ pub arguments: Vec<FunctionArgument<'a>>,
+ pub result: Option<FunctionResult<'a>>,
+
+ /// Local variable and function argument arena.
+ ///
+ /// Note that the `Local` here is actually a zero-sized type. The AST keeps
+ /// all the detailed information about locals - names, types, etc. - in
+ /// [`LocalDecl`] statements. For arguments, that information is kept in
+ /// [`arguments`]. This `Arena`'s only role is to assign a unique `Handle`
+ /// to each of them, and track their definitions' spans for use in
+ /// diagnostics.
+ ///
+ /// In the AST, when an [`Ident`] expression refers to a local variable or
+ /// argument, its [`IdentExpr`] holds the referent's `Handle<Local>` in this
+ /// arena.
+ ///
+ /// During lowering, [`LocalDecl`] statements add entries to a per-function
+ /// table that maps `Handle<Local>` values to their Naga representations,
+ /// accessed via [`StatementContext::local_table`] and
+ /// [`ExpressionContext::local_table`]. This table is then consulted when
+ /// lowering subsequent [`Ident`] expressions.
+ ///
+ /// [`LocalDecl`]: StatementKind::LocalDecl
+ /// [`arguments`]: Function::arguments
+ /// [`Ident`]: Expression::Ident
+ /// [`StatementContext::local_table`]: StatementContext::local_table
+ /// [`ExpressionContext::local_table`]: ExpressionContext::local_table
+ pub locals: Arena<Local>,
+
+ pub body: Block<'a>,
+}
+
+#[derive(Debug)]
+pub struct GlobalVariable<'a> {
+ pub name: Ident<'a>,
+ pub space: crate::AddressSpace,
+ pub binding: Option<crate::ResourceBinding>,
+ pub ty: Handle<Type<'a>>,
+ pub init: Option<Handle<Expression<'a>>>,
+}
+
+#[derive(Debug)]
+pub struct StructMember<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<crate::Binding>,
+ pub align: Option<(u32, Span)>,
+ pub size: Option<(u32, Span)>,
+}
+
+#[derive(Debug)]
+pub struct Struct<'a> {
+ pub name: Ident<'a>,
+ pub members: Vec<StructMember<'a>>,
+}
+
+#[derive(Debug)]
+pub struct TypeAlias<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+}
+
+#[derive(Debug)]
+pub struct Const<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Handle<Expression<'a>>,
+}
+
+/// The size of an [`Array`] or [`BindingArray`].
+///
+/// [`Array`]: Type::Array
+/// [`BindingArray`]: Type::BindingArray
+#[derive(Debug, Copy, Clone)]
+pub enum ArraySize<'a> {
+ /// The length as a constant expression.
+ Constant(Handle<Expression<'a>>),
+ Dynamic,
+}
+
+#[derive(Debug)]
+pub enum Type<'a> {
+ Scalar {
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+ Vector {
+ size: crate::VectorSize,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+ Atomic {
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+ Pointer {
+ base: Handle<Type<'a>>,
+ space: crate::AddressSpace,
+ },
+ Array {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+ Image {
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ },
+ Sampler {
+ comparison: bool,
+ },
+ AccelerationStructure,
+ RayQuery,
+ RayDesc,
+ RayIntersection,
+ BindingArray {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+
+ /// A user-defined type, like a struct or a type alias.
+ User(Ident<'a>),
+}
+
+#[derive(Debug, Default)]
+pub struct Block<'a> {
+ pub stmts: Vec<Statement<'a>>,
+}
+
+#[derive(Debug)]
+pub struct Statement<'a> {
+ pub kind: StatementKind<'a>,
+ pub span: Span,
+}
+
+#[derive(Debug)]
+pub enum StatementKind<'a> {
+ LocalDecl(LocalDecl<'a>),
+ Block(Block<'a>),
+ If {
+ condition: Handle<Expression<'a>>,
+ accept: Block<'a>,
+ reject: Block<'a>,
+ },
+ Switch {
+ selector: Handle<Expression<'a>>,
+ cases: Vec<SwitchCase<'a>>,
+ },
+ Loop {
+ body: Block<'a>,
+ continuing: Block<'a>,
+ break_if: Option<Handle<Expression<'a>>>,
+ },
+ Break,
+ Continue,
+ Return {
+ value: Option<Handle<Expression<'a>>>,
+ },
+ Kill,
+ Call {
+ function: Ident<'a>,
+ arguments: Vec<Handle<Expression<'a>>>,
+ },
+ Assign {
+ target: Handle<Expression<'a>>,
+ op: Option<crate::BinaryOperator>,
+ value: Handle<Expression<'a>>,
+ },
+ Increment(Handle<Expression<'a>>),
+ Decrement(Handle<Expression<'a>>),
+ Ignore(Handle<Expression<'a>>),
+}
+
+#[derive(Debug)]
+pub enum SwitchValue {
+ I32(i32),
+ U32(u32),
+ Default,
+}
+
+#[derive(Debug)]
+pub struct SwitchCase<'a> {
+ pub value: SwitchValue,
+ pub value_span: Span,
+ pub body: Block<'a>,
+ pub fall_through: bool,
+}
+
+/// A type at the head of a [`Construct`] expression.
+///
+/// WGSL has two types of [`type constructor expressions`]:
+///
+/// - Those that fully specify the type being constructed, like
+/// `vec3<f32>(x,y,z)`, which obviously constructs a `vec3<f32>`.
+///
+/// - Those that leave the component type of the composite being constructed
+/// implicit, to be inferred from the argument types, like `vec3(x,y,z)`,
+/// which constructs a `vec3<T>` where `T` is the type of `x`, `y`, and `z`.
+///
+/// This enum represents the head type of both cases. The `PartialFoo` variants
+/// represent the second case, where the component type is implicit.
+///
+/// This does not cover structs or types referred to by type aliases. See the
+/// documentation for [`Construct`] and [`Call`] expressions for details.
+///
+/// [`Construct`]: Expression::Construct
+/// [`type constructor expressions`]: https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr
+/// [`Call`]: Expression::Call
+#[derive(Debug)]
+pub enum ConstructorType<'a> {
+ /// A scalar type or conversion: `f32(1)`.
+ Scalar {
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+
+ /// A vector construction whose component type is inferred from the
+ /// argument: `vec3(1.0)`.
+ PartialVector { size: crate::VectorSize },
+
+ /// A vector construction whose component type is written out:
+ /// `vec3<f32>(1.0)`.
+ Vector {
+ size: crate::VectorSize,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+
+ /// A matrix construction whose component type is inferred from the
+ /// argument: `mat2x2(1,2,3,4)`.
+ PartialMatrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ },
+
+ /// A matrix construction whose component type is written out:
+ /// `mat2x2<f32>(1,2,3,4)`.
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+
+ /// An array whose component type and size are inferred from the arguments:
+ /// `array(3,4,5)`.
+ PartialArray,
+
+ /// An array whose component type and size are written out:
+ /// `array<u32, 4>(3,4,5)`.
+ Array {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+
+ /// Constructing a value of a known Naga IR type.
+ ///
+ /// This variant is produced only during lowering, when we have Naga types
+ /// available, never during parsing.
+ Type(Handle<crate::Type>),
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum Literal {
+ Bool(bool),
+ Number(Number),
+}
+
+#[cfg(doc)]
+use crate::front::wgsl::lower::Lowerer;
+
+#[derive(Debug)]
+pub enum Expression<'a> {
+ Literal(Literal),
+ Ident(IdentExpr<'a>),
+
+ /// A type constructor expression.
+ ///
+ /// This is only used for expressions like `KEYWORD(EXPR...)` and
+ /// `KEYWORD<PARAM>(EXPR...)`, where `KEYWORD` is a [type-defining keyword] like
+ /// `vec3`. These keywords cannot be shadowed by user definitions, so we can
+ /// tell that such an expression is a construction immediately.
+ ///
+ /// For ordinary identifiers, we can't tell whether an expression like
+ /// `IDENTIFIER(EXPR, ...)` is a construction expression or a function call
+ /// until we know `IDENTIFIER`'s definition, so we represent those as
+ /// [`Call`] expressions.
+ ///
+ /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords
+ /// [`Call`]: Expression::Call
+ Construct {
+ ty: ConstructorType<'a>,
+ ty_span: Span,
+ components: Vec<Handle<Expression<'a>>>,
+ },
+ Unary {
+ op: crate::UnaryOperator,
+ expr: Handle<Expression<'a>>,
+ },
+ AddrOf(Handle<Expression<'a>>),
+ Deref(Handle<Expression<'a>>),
+ Binary {
+ op: crate::BinaryOperator,
+ left: Handle<Expression<'a>>,
+ right: Handle<Expression<'a>>,
+ },
+
+ /// A function call or type constructor expression.
+ ///
+ /// We can't tell whether an expression like `IDENTIFIER(EXPR, ...)` is a
+ /// construction expression or a function call until we know `IDENTIFIER`'s
+ /// definition, so we represent everything of that form as one of these
+ /// expressions until lowering. At that point, [`Lowerer::call`] has
+ /// everything's definition in hand, and can decide whether to emit a Naga
+ /// [`Constant`], [`As`], [`Splat`], or [`Compose`] expression.
+ ///
+ /// [`Lowerer::call`]: Lowerer::call
+ /// [`Constant`]: crate::Expression::Constant
+ /// [`As`]: crate::Expression::As
+ /// [`Splat`]: crate::Expression::Splat
+ /// [`Compose`]: crate::Expression::Compose
+ Call {
+ function: Ident<'a>,
+ arguments: Vec<Handle<Expression<'a>>>,
+ },
+ Index {
+ base: Handle<Expression<'a>>,
+ index: Handle<Expression<'a>>,
+ },
+ Member {
+ base: Handle<Expression<'a>>,
+ field: Ident<'a>,
+ },
+ Bitcast {
+ expr: Handle<Expression<'a>>,
+ to: Handle<Type<'a>>,
+ ty_span: Span,
+ },
+}
+
+#[derive(Debug)]
+pub struct LocalVariable<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Option<Handle<Expression<'a>>>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub struct Let<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Handle<Expression<'a>>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub enum LocalDecl<'a> {
+ Var(LocalVariable<'a>),
+ Let(Let<'a>),
+}
+
+#[derive(Debug)]
+/// A placeholder for a local variable declaration.
+///
+/// See [`Function::locals`] for more information.
+pub struct Local;
diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
new file mode 100644
index 0000000000..a69dbbe470
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
@@ -0,0 +1,236 @@
+use super::Error;
+use crate::Span;
+
+pub fn map_address_space(word: &str, span: Span) -> Result<crate::AddressSpace, Error<'_>> {
+ match word {
+ "private" => Ok(crate::AddressSpace::Private),
+ "workgroup" => Ok(crate::AddressSpace::WorkGroup),
+ "uniform" => Ok(crate::AddressSpace::Uniform),
+ "storage" => Ok(crate::AddressSpace::Storage {
+ access: crate::StorageAccess::default(),
+ }),
+ "push_constant" => Ok(crate::AddressSpace::PushConstant),
+ "function" => Ok(crate::AddressSpace::Function),
+ _ => Err(Error::UnknownAddressSpace(span)),
+ }
+}
+
+pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>> {
+ Ok(match word {
+ "position" => crate::BuiltIn::Position { invariant: false },
+ // vertex
+ "vertex_index" => crate::BuiltIn::VertexIndex,
+ "instance_index" => crate::BuiltIn::InstanceIndex,
+ "view_index" => crate::BuiltIn::ViewIndex,
+ // fragment
+ "front_facing" => crate::BuiltIn::FrontFacing,
+ "frag_depth" => crate::BuiltIn::FragDepth,
+ "primitive_index" => crate::BuiltIn::PrimitiveIndex,
+ "sample_index" => crate::BuiltIn::SampleIndex,
+ "sample_mask" => crate::BuiltIn::SampleMask,
+ // compute
+ "global_invocation_id" => crate::BuiltIn::GlobalInvocationId,
+ "local_invocation_id" => crate::BuiltIn::LocalInvocationId,
+ "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
+ "workgroup_id" => crate::BuiltIn::WorkGroupId,
+ "num_workgroups" => crate::BuiltIn::NumWorkGroups,
+ _ => return Err(Error::UnknownBuiltin(span)),
+ })
+}
+
+pub fn map_interpolation(word: &str, span: Span) -> Result<crate::Interpolation, Error<'_>> {
+ match word {
+ "linear" => Ok(crate::Interpolation::Linear),
+ "flat" => Ok(crate::Interpolation::Flat),
+ "perspective" => Ok(crate::Interpolation::Perspective),
+ _ => Err(Error::UnknownAttribute(span)),
+ }
+}
+
+pub fn map_sampling(word: &str, span: Span) -> Result<crate::Sampling, Error<'_>> {
+ match word {
+ "center" => Ok(crate::Sampling::Center),
+ "centroid" => Ok(crate::Sampling::Centroid),
+ "sample" => Ok(crate::Sampling::Sample),
+ _ => Err(Error::UnknownAttribute(span)),
+ }
+}
+
+pub fn map_storage_format(word: &str, span: Span) -> Result<crate::StorageFormat, Error<'_>> {
+ use crate::StorageFormat as Sf;
+ Ok(match word {
+ "r8unorm" => Sf::R8Unorm,
+ "r8snorm" => Sf::R8Snorm,
+ "r8uint" => Sf::R8Uint,
+ "r8sint" => Sf::R8Sint,
+ "r16unorm" => Sf::R16Unorm,
+ "r16snorm" => Sf::R16Snorm,
+ "r16uint" => Sf::R16Uint,
+ "r16sint" => Sf::R16Sint,
+ "r16float" => Sf::R16Float,
+ "rg8unorm" => Sf::Rg8Unorm,
+ "rg8snorm" => Sf::Rg8Snorm,
+ "rg8uint" => Sf::Rg8Uint,
+ "rg8sint" => Sf::Rg8Sint,
+ "r32uint" => Sf::R32Uint,
+ "r32sint" => Sf::R32Sint,
+ "r32float" => Sf::R32Float,
+ "rg16unorm" => Sf::Rg16Unorm,
+ "rg16snorm" => Sf::Rg16Snorm,
+ "rg16uint" => Sf::Rg16Uint,
+ "rg16sint" => Sf::Rg16Sint,
+ "rg16float" => Sf::Rg16Float,
+ "rgba8unorm" => Sf::Rgba8Unorm,
+ "rgba8snorm" => Sf::Rgba8Snorm,
+ "rgba8uint" => Sf::Rgba8Uint,
+ "rgba8sint" => Sf::Rgba8Sint,
+ "rgb10a2unorm" => Sf::Rgb10a2Unorm,
+ "rg11b10float" => Sf::Rg11b10Float,
+ "rg32uint" => Sf::Rg32Uint,
+ "rg32sint" => Sf::Rg32Sint,
+ "rg32float" => Sf::Rg32Float,
+ "rgba16unorm" => Sf::Rgba16Unorm,
+ "rgba16snorm" => Sf::Rgba16Snorm,
+ "rgba16uint" => Sf::Rgba16Uint,
+ "rgba16sint" => Sf::Rgba16Sint,
+ "rgba16float" => Sf::Rgba16Float,
+ "rgba32uint" => Sf::Rgba32Uint,
+ "rgba32sint" => Sf::Rgba32Sint,
+ "rgba32float" => Sf::Rgba32Float,
+ _ => return Err(Error::UnknownStorageFormat(span)),
+ })
+}
+
+pub fn get_scalar_type(word: &str) -> Option<(crate::ScalarKind, crate::Bytes)> {
+ match word {
+ // "f16" => Some((crate::ScalarKind::Float, 2)),
+ "f32" => Some((crate::ScalarKind::Float, 4)),
+ "f64" => Some((crate::ScalarKind::Float, 8)),
+ "i32" => Some((crate::ScalarKind::Sint, 4)),
+ "u32" => Some((crate::ScalarKind::Uint, 4)),
+ "bool" => Some((crate::ScalarKind::Bool, crate::BOOL_WIDTH)),
+ _ => None,
+ }
+}
+
+pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ match word {
+ "dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)),
+ "dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)),
+ "fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)),
+ "dpdxFine" => Some((Axis::X, Ctrl::Fine)),
+ "dpdyFine" => Some((Axis::Y, Ctrl::Fine)),
+ "fwidthFine" => Some((Axis::Width, Ctrl::Fine)),
+ "dpdx" => Some((Axis::X, Ctrl::None)),
+ "dpdy" => Some((Axis::Y, Ctrl::None)),
+ "fwidth" => Some((Axis::Width, Ctrl::None)),
+ _ => None,
+ }
+}
+
+pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> {
+ match word {
+ "any" => Some(crate::RelationalFunction::Any),
+ "all" => Some(crate::RelationalFunction::All),
+ _ => None,
+ }
+}
+
+pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
+ use crate::MathFunction as Mf;
+ Some(match word {
+ // comparison
+ "abs" => Mf::Abs,
+ "min" => Mf::Min,
+ "max" => Mf::Max,
+ "clamp" => Mf::Clamp,
+ "saturate" => Mf::Saturate,
+ // trigonometry
+ "cos" => Mf::Cos,
+ "cosh" => Mf::Cosh,
+ "sin" => Mf::Sin,
+ "sinh" => Mf::Sinh,
+ "tan" => Mf::Tan,
+ "tanh" => Mf::Tanh,
+ "acos" => Mf::Acos,
+ "acosh" => Mf::Acosh,
+ "asin" => Mf::Asin,
+ "asinh" => Mf::Asinh,
+ "atan" => Mf::Atan,
+ "atanh" => Mf::Atanh,
+ "atan2" => Mf::Atan2,
+ "radians" => Mf::Radians,
+ "degrees" => Mf::Degrees,
+ // decomposition
+ "ceil" => Mf::Ceil,
+ "floor" => Mf::Floor,
+ "round" => Mf::Round,
+ "fract" => Mf::Fract,
+ "trunc" => Mf::Trunc,
+ "modf" => Mf::Modf,
+ "frexp" => Mf::Frexp,
+ "ldexp" => Mf::Ldexp,
+ // exponent
+ "exp" => Mf::Exp,
+ "exp2" => Mf::Exp2,
+ "log" => Mf::Log,
+ "log2" => Mf::Log2,
+ "pow" => Mf::Pow,
+ // geometry
+ "dot" => Mf::Dot,
+ "outerProduct" => Mf::Outer,
+ "cross" => Mf::Cross,
+ "distance" => Mf::Distance,
+ "length" => Mf::Length,
+ "normalize" => Mf::Normalize,
+ "faceForward" => Mf::FaceForward,
+ "reflect" => Mf::Reflect,
+ "refract" => Mf::Refract,
+ // computational
+ "sign" => Mf::Sign,
+ "fma" => Mf::Fma,
+ "mix" => Mf::Mix,
+ "step" => Mf::Step,
+ "smoothstep" => Mf::SmoothStep,
+ "sqrt" => Mf::Sqrt,
+ "inverseSqrt" => Mf::InverseSqrt,
+ "transpose" => Mf::Transpose,
+ "determinant" => Mf::Determinant,
+ // bits
+ "countTrailingZeros" => Mf::CountTrailingZeros,
+ "countLeadingZeros" => Mf::CountLeadingZeros,
+ "countOneBits" => Mf::CountOneBits,
+ "reverseBits" => Mf::ReverseBits,
+ "extractBits" => Mf::ExtractBits,
+ "insertBits" => Mf::InsertBits,
+ "firstTrailingBit" => Mf::FindLsb,
+ "firstLeadingBit" => Mf::FindMsb,
+ // data packing
+ "pack4x8snorm" => Mf::Pack4x8snorm,
+ "pack4x8unorm" => Mf::Pack4x8unorm,
+ "pack2x16snorm" => Mf::Pack2x16snorm,
+ "pack2x16unorm" => Mf::Pack2x16unorm,
+ "pack2x16float" => Mf::Pack2x16float,
+ // data unpacking
+ "unpack4x8snorm" => Mf::Unpack4x8snorm,
+ "unpack4x8unorm" => Mf::Unpack4x8unorm,
+ "unpack2x16snorm" => Mf::Unpack2x16snorm,
+ "unpack2x16unorm" => Mf::Unpack2x16unorm,
+ "unpack2x16float" => Mf::Unpack2x16float,
+ _ => return None,
+ })
+}
+
+pub fn map_conservative_depth(
+ word: &str,
+ span: Span,
+) -> Result<crate::ConservativeDepth, Error<'_>> {
+ use crate::ConservativeDepth as Cd;
+ match word {
+ "greater_equal" => Ok(Cd::GreaterEqual),
+ "less_equal" => Ok(Cd::LessEqual),
+ "unchanged" => Ok(Cd::Unchanged),
+ _ => Err(Error::UnknownConservativeDepth(span)),
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/lexer.rs b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs
new file mode 100644
index 0000000000..4654d4087f
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs
@@ -0,0 +1,723 @@
+use super::{number::consume_number, Error, ExpectedToken};
+use crate::front::wgsl::error::NumberError;
+use crate::front::wgsl::parse::{conv, Number};
+use crate::Span;
+
+type TokenSpan<'a> = (Token<'a>, Span);
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum Token<'a> {
+ Separator(char),
+ Paren(char),
+ Attribute,
+ Number(Result<Number, NumberError>),
+ Word(&'a str),
+ Operation(char),
+ LogicalOperation(char),
+ ShiftOperation(char),
+ AssignmentOperation(char),
+ IncrementOperation,
+ DecrementOperation,
+ Arrow,
+ Unknown(char),
+ Trivia,
+ End,
+}
+
+fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) {
+ let pos = input.find(|c| !what(c)).unwrap_or(input.len());
+ input.split_at(pos)
+}
+
+/// Return the token at the start of `input`.
+///
+/// If `generic` is `false`, then the bit shift operators `>>` or `<<`
+/// are valid lookahead tokens for the current parser state (see [§3.1
+/// Parsing] in the WGSL specification). In other words:
+///
+/// - If `generic` is `true`, then we are expecting an angle bracket
+/// around a generic type parameter, like the `<` and `>` in
+/// `vec3<f32>`, so interpret `<` and `>` as `Token::Paren` tokens,
+/// even if they're part of `<<` or `>>` sequences.
+///
+/// - Otherwise, interpret `<<` and `>>` as shift operators:
+/// `Token::LogicalOperation` tokens.
+///
+/// [§3.1 Parsing]: https://gpuweb.github.io/gpuweb/wgsl/#parsing
+fn consume_token(input: &str, generic: bool) -> (Token<'_>, &str) {
+ let mut chars = input.chars();
+ let cur = match chars.next() {
+ Some(c) => c,
+ None => return (Token::End, ""),
+ };
+ match cur {
+ ':' | ';' | ',' => (Token::Separator(cur), chars.as_str()),
+ '.' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('0'..='9') => consume_number(input),
+ _ => (Token::Separator(cur), og_chars),
+ }
+ }
+ '@' => (Token::Attribute, chars.as_str()),
+ '(' | ')' | '{' | '}' | '[' | ']' => (Token::Paren(cur), chars.as_str()),
+ '<' | '>' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') if !generic => (Token::LogicalOperation(cur), chars.as_str()),
+ Some(c) if c == cur && !generic => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::ShiftOperation(cur), og_chars),
+ }
+ }
+ _ => (Token::Paren(cur), og_chars),
+ }
+ }
+ '0'..='9' => consume_number(input),
+ '/' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('/') => {
+ let _ = chars.position(is_comment_end);
+ (Token::Trivia, chars.as_str())
+ }
+ Some('*') => {
+ let mut depth = 1;
+ let mut prev = None;
+
+ for c in &mut chars {
+ match (prev, c) {
+ (Some('*'), '/') => {
+ prev = None;
+ depth -= 1;
+ if depth == 0 {
+ return (Token::Trivia, chars.as_str());
+ }
+ }
+ (Some('/'), '*') => {
+ prev = None;
+ depth += 1;
+ }
+ _ => {
+ prev = Some(c);
+ }
+ }
+ }
+
+ (Token::End, "")
+ }
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '-' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('>') => (Token::Arrow, chars.as_str()),
+ Some('0'..='9' | '.') => consume_number(input),
+ Some('-') => (Token::DecrementOperation, chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '+' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('+') => (Token::IncrementOperation, chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '*' | '%' | '^' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '~' => (Token::Operation(cur), chars.as_str()),
+ '=' | '!' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::LogicalOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '&' | '|' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some(c) if c == cur => (Token::LogicalOperation(cur), chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ _ if is_blankspace(cur) => {
+ let (_, rest) = consume_any(input, is_blankspace);
+ (Token::Trivia, rest)
+ }
+ _ if is_word_start(cur) => {
+ let (word, rest) = consume_any(input, is_word_part);
+ (Token::Word(word), rest)
+ }
+ _ => (Token::Unknown(cur), chars.as_str()),
+ }
+}
+
+/// Returns whether or not a char is a comment end
+/// (Unicode Pattern_White_Space excluding U+0020, U+0009, U+200E and U+200F)
+const fn is_comment_end(c: char) -> bool {
+ match c {
+ '\u{000a}'..='\u{000d}' | '\u{0085}' | '\u{2028}' | '\u{2029}' => true,
+ _ => false,
+ }
+}
+
+/// Returns whether or not a char is a blankspace (Unicode Pattern_White_Space)
+const fn is_blankspace(c: char) -> bool {
+ match c {
+ '\u{0020}'
+ | '\u{0009}'..='\u{000d}'
+ | '\u{0085}'
+ | '\u{200e}'
+ | '\u{200f}'
+ | '\u{2028}'
+ | '\u{2029}' => true,
+ _ => false,
+ }
+}
+
+/// Returns whether or not a char is a word start (Unicode XID_Start + '_')
+fn is_word_start(c: char) -> bool {
+ c == '_' || unicode_xid::UnicodeXID::is_xid_start(c)
+}
+
+/// Returns whether or not a char is a word part (Unicode XID_Continue)
+fn is_word_part(c: char) -> bool {
+ unicode_xid::UnicodeXID::is_xid_continue(c)
+}
+
+#[derive(Clone)]
+pub(in crate::front::wgsl) struct Lexer<'a> {
+ input: &'a str,
+ pub(in crate::front::wgsl) source: &'a str,
+ // The byte offset of the end of the last non-trivia token.
+ last_end_offset: usize,
+}
+
+impl<'a> Lexer<'a> {
+ pub(in crate::front::wgsl) const fn new(input: &'a str) -> Self {
+ Lexer {
+ input,
+ source: input,
+ last_end_offset: 0,
+ }
+ }
+
+ /// Calls the function with a lexer and returns the result of the function as well as the span for everything the function parsed
+ ///
+ /// # Examples
+ /// ```ignore
+ /// let lexer = Lexer::new("5");
+ /// let (value, span) = lexer.capture_span(Lexer::next_uint_literal);
+ /// assert_eq!(value, 5);
+ /// ```
+ #[inline]
+ pub fn capture_span<T, E>(
+ &mut self,
+ inner: impl FnOnce(&mut Self) -> Result<T, E>,
+ ) -> Result<(T, Span), E> {
+ let start = self.current_byte_offset();
+ let res = inner(self)?;
+ let end = self.current_byte_offset();
+ Ok((res, Span::from(start..end)))
+ }
+
+ pub(in crate::front::wgsl) fn start_byte_offset(&mut self) -> usize {
+ loop {
+ // Eat all trivia because `next` doesn't eat trailing trivia.
+ let (token, rest) = consume_token(self.input, false);
+ if let Token::Trivia = token {
+ self.input = rest;
+ } else {
+ return self.current_byte_offset();
+ }
+ }
+ }
+
+ fn peek_token_and_rest(&mut self) -> (TokenSpan<'a>, &'a str) {
+ let mut cloned = self.clone();
+ let token = cloned.next();
+ let rest = cloned.input;
+ (token, rest)
+ }
+
+ const fn current_byte_offset(&self) -> usize {
+ self.source.len() - self.input.len()
+ }
+
+ pub(in crate::front::wgsl) fn span_from(&self, offset: usize) -> Span {
+ Span::from(offset..self.last_end_offset)
+ }
+
+ /// Return the next non-whitespace token from `self`.
+ ///
+ /// Assume we are a parse state where bit shift operators may
+ /// occur, but not angle brackets.
+ #[must_use]
+ pub(in crate::front::wgsl) fn next(&mut self) -> TokenSpan<'a> {
+ self.next_impl(false)
+ }
+
+ /// Return the next non-whitespace token from `self`.
+ ///
+ /// Assume we are in a parse state where angle brackets may occur,
+ /// but not bit shift operators.
+ #[must_use]
+ pub(in crate::front::wgsl) fn next_generic(&mut self) -> TokenSpan<'a> {
+ self.next_impl(true)
+ }
+
+ /// Return the next non-whitespace token from `self`, with a span.
+ ///
+ /// See [`consume_token`] for the meaning of `generic`.
+ fn next_impl(&mut self, generic: bool) -> TokenSpan<'a> {
+ let mut start_byte_offset = self.current_byte_offset();
+ loop {
+ let (token, rest) = consume_token(self.input, generic);
+ self.input = rest;
+ match token {
+ Token::Trivia => start_byte_offset = self.current_byte_offset(),
+ _ => {
+ self.last_end_offset = self.current_byte_offset();
+ return (token, self.span_from(start_byte_offset));
+ }
+ }
+ }
+ }
+
+ #[must_use]
+ pub(in crate::front::wgsl) fn peek(&mut self) -> TokenSpan<'a> {
+ let (token, _) = self.peek_token_and_rest();
+ token
+ }
+
+ pub(in crate::front::wgsl) fn expect_span(
+ &mut self,
+ expected: Token<'a>,
+ ) -> Result<Span, Error<'a>> {
+ let next = self.next();
+ if next.0 == expected {
+ Ok(next.1)
+ } else {
+ Err(Error::Unexpected(next.1, ExpectedToken::Token(expected)))
+ }
+ }
+
+ pub(in crate::front::wgsl) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> {
+ self.expect_span(expected)?;
+ Ok(())
+ }
+
+ pub(in crate::front::wgsl) fn expect_generic_paren(
+ &mut self,
+ expected: char,
+ ) -> Result<(), Error<'a>> {
+ let next = self.next_generic();
+ if next.0 == Token::Paren(expected) {
+ Ok(())
+ } else {
+ Err(Error::Unexpected(
+ next.1,
+ ExpectedToken::Token(Token::Paren(expected)),
+ ))
+ }
+ }
+
+ /// If the next token matches it is skipped and true is returned
+ pub(in crate::front::wgsl) fn skip(&mut self, what: Token<'_>) -> bool {
+ let (peeked_token, rest) = self.peek_token_and_rest();
+ if peeked_token.0 == what {
+ self.input = rest;
+ true
+ } else {
+ false
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_ident_with_span(
+ &mut self,
+ ) -> Result<(&'a str, Span), Error<'a>> {
+ match self.next() {
+ (Token::Word(word), span) if word == "_" => {
+ Err(Error::InvalidIdentifierUnderscore(span))
+ }
+ (Token::Word(word), span) if word.starts_with("__") => {
+ Err(Error::ReservedIdentifierPrefix(span))
+ }
+ (Token::Word(word), span) => Ok((word, span)),
+ other => Err(Error::Unexpected(other.1, ExpectedToken::Identifier)),
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_ident(
+ &mut self,
+ ) -> Result<super::ast::Ident<'a>, Error<'a>> {
+ let ident = self
+ .next_ident_with_span()
+ .map(|(name, span)| super::ast::Ident { name, span })?;
+
+ if crate::keywords::wgsl::RESERVED.contains(&ident.name) {
+ return Err(Error::ReservedKeyword(ident.span));
+ }
+
+ Ok(ident)
+ }
+
+ /// Parses a generic scalar type, for example `<f32>`.
+ pub(in crate::front::wgsl) fn next_scalar_generic(
+ &mut self,
+ ) -> Result<(crate::ScalarKind, crate::Bytes), Error<'a>> {
+ self.expect_generic_paren('<')?;
+ let pair = match self.next() {
+ (Token::Word(word), span) => {
+ conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(span))
+ }
+ (_, span) => Err(Error::UnknownScalarType(span)),
+ }?;
+ self.expect_generic_paren('>')?;
+ Ok(pair)
+ }
+
+ /// Parses a generic scalar type, for example `<f32>`.
+ ///
+ /// Returns the span covering the inner type, excluding the brackets.
+ pub(in crate::front::wgsl) fn next_scalar_generic_with_span(
+ &mut self,
+ ) -> Result<(crate::ScalarKind, crate::Bytes, Span), Error<'a>> {
+ self.expect_generic_paren('<')?;
+ let pair = match self.next() {
+ (Token::Word(word), span) => conv::get_scalar_type(word)
+ .map(|(a, b)| (a, b, span))
+ .ok_or(Error::UnknownScalarType(span)),
+ (_, span) => Err(Error::UnknownScalarType(span)),
+ }?;
+ self.expect_generic_paren('>')?;
+ Ok(pair)
+ }
+
+ pub(in crate::front::wgsl) fn next_storage_access(
+ &mut self,
+ ) -> Result<crate::StorageAccess, Error<'a>> {
+ let (ident, span) = self.next_ident_with_span()?;
+ match ident {
+ "read" => Ok(crate::StorageAccess::LOAD),
+ "write" => Ok(crate::StorageAccess::STORE),
+ "read_write" => Ok(crate::StorageAccess::LOAD | crate::StorageAccess::STORE),
+ _ => Err(Error::UnknownAccess(span)),
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_format_generic(
+ &mut self,
+ ) -> Result<(crate::StorageFormat, crate::StorageAccess), Error<'a>> {
+ self.expect(Token::Paren('<'))?;
+ let (ident, ident_span) = self.next_ident_with_span()?;
+ let format = conv::map_storage_format(ident, ident_span)?;
+ self.expect(Token::Separator(','))?;
+ let access = self.next_storage_access()?;
+ self.expect(Token::Paren('>'))?;
+ Ok((format, access))
+ }
+
+ pub(in crate::front::wgsl) fn open_arguments(&mut self) -> Result<(), Error<'a>> {
+ self.expect(Token::Paren('('))
+ }
+
+ pub(in crate::front::wgsl) fn close_arguments(&mut self) -> Result<(), Error<'a>> {
+ let _ = self.skip(Token::Separator(','));
+ self.expect(Token::Paren(')'))
+ }
+
+ pub(in crate::front::wgsl) fn next_argument(&mut self) -> Result<bool, Error<'a>> {
+ let paren = Token::Paren(')');
+ if self.skip(Token::Separator(',')) {
+ Ok(!self.skip(paren))
+ } else {
+ self.expect(paren).map(|()| false)
+ }
+ }
+}
+
+#[cfg(test)]
+fn sub_test(source: &str, expected_tokens: &[Token]) {
+ let mut lex = Lexer::new(source);
+ for &token in expected_tokens {
+ assert_eq!(lex.next().0, token);
+ }
+ assert_eq!(lex.next().0, Token::End);
+}
+
+#[test]
+fn test_numbers() {
+ // WGSL spec examples //
+
+ // decimal integer
+ sub_test(
+ "0x123 0X123u 1u 123 0 0i 0x3f",
+ &[
+ Token::Number(Ok(Number::I32(291))),
+ Token::Number(Ok(Number::U32(291))),
+ Token::Number(Ok(Number::U32(1))),
+ Token::Number(Ok(Number::I32(123))),
+ Token::Number(Ok(Number::I32(0))),
+ Token::Number(Ok(Number::I32(0))),
+ Token::Number(Ok(Number::I32(63))),
+ ],
+ );
+ // decimal floating point
+ sub_test(
+ "0.e+4f 01. .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h",
+ &[
+ Token::Number(Ok(Number::F32(0.))),
+ Token::Number(Ok(Number::F32(1.))),
+ Token::Number(Ok(Number::F32(0.01))),
+ Token::Number(Ok(Number::F32(12.34))),
+ Token::Number(Ok(Number::F32(0.))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ Token::Number(Ok(Number::F32(0.001))),
+ Token::Number(Ok(Number::F32(43.75))),
+ Token::Number(Ok(Number::F32(16.))),
+ Token::Number(Ok(Number::F32(0.1875))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ Token::Number(Ok(Number::F32(0.12109375))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ ],
+ );
+
+ // MIN / MAX //
+
+ // min / max decimal signed integer
+ sub_test(
+ "-2147483648i 2147483647i -2147483649i 2147483648i",
+ &[
+ Token::Number(Ok(Number::I32(i32::MIN))),
+ Token::Number(Ok(Number::I32(i32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+ // min / max decimal unsigned integer
+ sub_test(
+ "0u 4294967295u -1u 4294967296u",
+ &[
+ Token::Number(Ok(Number::U32(u32::MIN))),
+ Token::Number(Ok(Number::U32(u32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ // min / max hexadecimal signed integer
+ sub_test(
+ "-0x80000000i 0x7FFFFFFFi -0x80000001i 0x80000000i",
+ &[
+ Token::Number(Ok(Number::I32(i32::MIN))),
+ Token::Number(Ok(Number::I32(i32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+ // min / max hexadecimal unsigned integer
+ sub_test(
+ "0x0u 0xFFFFFFFFu -0x1u 0x100000000u",
+ &[
+ Token::Number(Ok(Number::U32(u32::MIN))),
+ Token::Number(Ok(Number::U32(u32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ /// ≈ 2^-126 * 2^−23 (= 2^−149)
+ const SMALLEST_POSITIVE_SUBNORMAL_F32: f32 = 1e-45;
+ /// ≈ 2^-126 * (1 − 2^−23)
+ const LARGEST_SUBNORMAL_F32: f32 = 1.1754942e-38;
+ /// ≈ 2^-126
+ const SMALLEST_POSITIVE_NORMAL_F32: f32 = f32::MIN_POSITIVE;
+ /// ≈ 1 − 2^−24
+ const LARGEST_F32_LESS_THAN_ONE: f32 = 0.99999994;
+ /// ≈ 1 + 2^−23
+ const SMALLEST_F32_LARGER_THAN_ONE: f32 = 1.0000001;
+ /// ≈ -(2^127 * (2 − 2^−23))
+ const SMALLEST_NORMAL_F32: f32 = f32::MIN;
+ /// ≈ 2^127 * (2 − 2^−23)
+ const LARGEST_NORMAL_F32: f32 = f32::MAX;
+
+ // decimal floating point
+ sub_test(
+ "1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f -3.40282347e+38f 3.40282347e+38f",
+ &[
+ Token::Number(Ok(Number::F32(
+ SMALLEST_POSITIVE_SUBNORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_SUBNORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_POSITIVE_NORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_F32_LESS_THAN_ONE,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_F32_LARGER_THAN_ONE,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_NORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_NORMAL_F32,
+ ))),
+ ],
+ );
+ sub_test(
+ "-3.40282367e+38f 3.40282367e+38f",
+ &[
+ Token::Number(Err(NumberError::NotRepresentable)), // ≈ -2^128
+ Token::Number(Err(NumberError::NotRepresentable)), // ≈ 2^128
+ ],
+ );
+
+ // hexadecimal floating point
+ sub_test(
+ "0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f -0xFFFFFFp+104f 0xFFFFFFp+104f",
+ &[
+ Token::Number(Ok(Number::F32(
+ SMALLEST_POSITIVE_SUBNORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_SUBNORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_POSITIVE_NORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_F32_LESS_THAN_ONE,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_F32_LARGER_THAN_ONE,
+ ))),
+ Token::Number(Ok(Number::F32(
+ SMALLEST_NORMAL_F32,
+ ))),
+ Token::Number(Ok(Number::F32(
+ LARGEST_NORMAL_F32,
+ ))),
+ ],
+ );
+ sub_test(
+ "-0x1p128f 0x1p128f 0x1.000001p0f",
+ &[
+ Token::Number(Err(NumberError::NotRepresentable)), // = -2^128
+ Token::Number(Err(NumberError::NotRepresentable)), // = 2^128
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+}
+
+#[test]
+fn test_tokens() {
+ sub_test("id123_OK", &[Token::Word("id123_OK")]);
+ sub_test(
+ "92No",
+ &[Token::Number(Ok(Number::I32(92))), Token::Word("No")],
+ );
+ sub_test(
+ "2u3o",
+ &[
+ Token::Number(Ok(Number::U32(2))),
+ Token::Number(Ok(Number::I32(3))),
+ Token::Word("o"),
+ ],
+ );
+ sub_test(
+ "2.4f44po",
+ &[
+ Token::Number(Ok(Number::F32(2.4))),
+ Token::Number(Ok(Number::I32(44))),
+ Token::Word("po"),
+ ],
+ );
+ sub_test(
+ "Δέλτα réflexion Кызыл 𐰓𐰏𐰇 朝焼け سلام 검정 שָׁלוֹם गुलाबी փիրուզ",
+ &[
+ Token::Word("Δέλτα"),
+ Token::Word("réflexion"),
+ Token::Word("Кызыл"),
+ Token::Word("𐰓𐰏𐰇"),
+ Token::Word("朝焼け"),
+ Token::Word("سلام"),
+ Token::Word("검정"),
+ Token::Word("שָׁלוֹם"),
+ Token::Word("गुलाबी"),
+ Token::Word("փիրուզ"),
+ ],
+ );
+ sub_test("æNoø", &[Token::Word("æNoø")]);
+ sub_test("No¾", &[Token::Word("No"), Token::Unknown('¾')]);
+ sub_test("No好", &[Token::Word("No好")]);
+ sub_test("_No", &[Token::Word("_No")]);
+ sub_test(
+ "*/*/***/*//=/*****//",
+ &[
+ Token::Operation('*'),
+ Token::AssignmentOperation('/'),
+ Token::Operation('/'),
+ ],
+ );
+}
+
+#[test]
+fn test_variable_decl() {
+ sub_test(
+ "@group(0 ) var< uniform> texture: texture_multisampled_2d <f32 >;",
+ &[
+ Token::Attribute,
+ Token::Word("group"),
+ Token::Paren('('),
+ Token::Number(Ok(Number::I32(0))),
+ Token::Paren(')'),
+ Token::Word("var"),
+ Token::Paren('<'),
+ Token::Word("uniform"),
+ Token::Paren('>'),
+ Token::Word("texture"),
+ Token::Separator(':'),
+ Token::Word("texture_multisampled_2d"),
+ Token::Paren('<'),
+ Token::Word("f32"),
+ Token::Paren('>'),
+ Token::Separator(';'),
+ ],
+ );
+ sub_test(
+ "var<storage,read_write> buffer: array<u32>;",
+ &[
+ Token::Word("var"),
+ Token::Paren('<'),
+ Token::Word("storage"),
+ Token::Separator(','),
+ Token::Word("read_write"),
+ Token::Paren('>'),
+ Token::Word("buffer"),
+ Token::Separator(':'),
+ Token::Word("array"),
+ Token::Paren('<'),
+ Token::Word("u32"),
+ Token::Paren('>'),
+ Token::Separator(';'),
+ ],
+ );
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
new file mode 100644
index 0000000000..cf9892a1f3
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
@@ -0,0 +1,2298 @@
+use crate::front::wgsl::error::{Error, ExpectedToken};
+use crate::front::wgsl::parse::lexer::{Lexer, Token};
+use crate::front::wgsl::parse::number::Number;
+use crate::front::SymbolTable;
+use crate::{Arena, FastHashSet, Handle, Span};
+
+pub mod ast;
+pub mod conv;
+pub mod lexer;
+pub mod number;
+
+/// State for constructing an AST expression.
+///
+/// Not to be confused with `lower::ExpressionContext`.
+struct ExpressionContext<'input, 'temp, 'out> {
+ /// The [`TranslationUnit::expressions`] arena to which we should contribute
+ /// expressions.
+ ///
+ /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions
+ expressions: &'out mut Arena<ast::Expression<'input>>,
+
+ /// The [`TranslationUnit::types`] arena to which we should contribute new
+ /// types.
+ ///
+ /// [`TranslationUnit::types`]: ast::TranslationUnit::types
+ types: &'out mut Arena<ast::Type<'input>>,
+
+ /// A map from identifiers in scope to the locals/arguments they represent.
+ ///
+ /// The handles refer to the [`Function::locals`] area; see that field's
+ /// documentation for details.
+ ///
+ /// [`Function::locals`]: ast::Function::locals
+ local_table: &'temp mut SymbolTable<&'input str, Handle<ast::Local>>,
+
+ /// The [`Function::locals`] arena for the function we're building.
+ ///
+ /// [`Function::locals`]: ast::Function::locals
+ locals: &'out mut Arena<ast::Local>,
+
+ /// Identifiers used by the current global declaration that have no local definition.
+ ///
+ /// This becomes the [`GlobalDecl`]'s [`dependencies`] set.
+ ///
+ /// Note that we don't know at parse time what kind of [`GlobalDecl`] the
+ /// name refers to. We can't look up names until we've seen the entire
+ /// translation unit.
+ ///
+ /// [`GlobalDecl`]: ast::GlobalDecl
+ /// [`dependencies`]: ast::GlobalDecl::dependencies
+ unresolved: &'out mut FastHashSet<ast::Dependency<'input>>,
+}
+
+impl<'a> ExpressionContext<'a, '_, '_> {
+ fn reborrow(&mut self) -> ExpressionContext<'a, '_, '_> {
+ ExpressionContext {
+ expressions: self.expressions,
+ types: self.types,
+ local_table: self.local_table,
+ locals: self.locals,
+ unresolved: self.unresolved,
+ }
+ }
+
+ fn parse_binary_op(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ classifier: impl Fn(Token<'a>) -> Option<crate::BinaryOperator>,
+ mut parser: impl FnMut(
+ &mut Lexer<'a>,
+ ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let start = lexer.start_byte_offset();
+ let mut accumulator = parser(lexer, self.reborrow())?;
+ while let Some(op) = classifier(lexer.peek().0) {
+ let _ = lexer.next();
+ let left = accumulator;
+ let right = parser(lexer, self.reborrow())?;
+ accumulator = self.expressions.append(
+ ast::Expression::Binary { op, left, right },
+ lexer.span_from(start),
+ );
+ }
+ Ok(accumulator)
+ }
+}
+
+/// Which grammar rule we are in the midst of parsing.
+///
+/// This is used for error checking. `Parser` maintains a stack of
+/// these and (occasionally) checks that it is being pushed and popped
+/// as expected.
+#[derive(Clone, Debug, PartialEq)]
+enum Rule {
+ Attribute,
+ VariableDecl,
+ TypeDecl,
+ FunctionDecl,
+ Block,
+ Statement,
+ PrimaryExpr,
+ SingularExpr,
+ UnaryExpr,
+ GeneralExpr,
+}
+
+#[derive(Default)]
+struct BindingParser {
+ location: Option<u32>,
+ built_in: Option<crate::BuiltIn>,
+ interpolation: Option<crate::Interpolation>,
+ sampling: Option<crate::Sampling>,
+ invariant: bool,
+}
+
+impl BindingParser {
+ fn parse<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ name: &'a str,
+ name_span: Span,
+ ) -> Result<(), Error<'a>> {
+ match name {
+ "location" => {
+ lexer.expect(Token::Paren('('))?;
+ self.location = Some(Parser::non_negative_i32_literal(lexer)?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "builtin" => {
+ lexer.expect(Token::Paren('('))?;
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.built_in = Some(conv::map_built_in(raw, span)?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "interpolate" => {
+ lexer.expect(Token::Paren('('))?;
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.interpolation = Some(conv::map_interpolation(raw, span)?);
+ if lexer.skip(Token::Separator(',')) {
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.sampling = Some(conv::map_sampling(raw, span)?);
+ }
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "invariant" => self.invariant = true,
+ _ => return Err(Error::UnknownAttribute(name_span)),
+ }
+ Ok(())
+ }
+
+ const fn finish<'a>(self, span: Span) -> Result<Option<crate::Binding>, Error<'a>> {
+ match (
+ self.location,
+ self.built_in,
+ self.interpolation,
+ self.sampling,
+ self.invariant,
+ ) {
+ (None, None, None, None, false) => Ok(None),
+ (Some(location), None, interpolation, sampling, false) => {
+ // Before handing over the completed `Module`, we call
+ // `apply_default_interpolation` to ensure that the interpolation and
+ // sampling have been explicitly specified on all vertex shader output and fragment
+ // shader input user bindings, so leaving them potentially `None` here is fine.
+ Ok(Some(crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ }))
+ }
+ (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant) => {
+ Ok(Some(crate::Binding::BuiltIn(crate::BuiltIn::Position {
+ invariant,
+ })))
+ }
+ (None, Some(built_in), None, None, false) => {
+ Ok(Some(crate::Binding::BuiltIn(built_in)))
+ }
+ (_, _, _, _, _) => Err(Error::InconsistentBinding(span)),
+ }
+ }
+}
+
+pub struct Parser {
+ rules: Vec<(Rule, usize)>,
+}
+
+impl Parser {
+ pub const fn new() -> Self {
+ Parser { rules: Vec::new() }
+ }
+
+ fn reset(&mut self) {
+ self.rules.clear();
+ }
+
+ fn push_rule_span(&mut self, rule: Rule, lexer: &mut Lexer<'_>) {
+ self.rules.push((rule, lexer.start_byte_offset()));
+ }
+
+ fn pop_rule_span(&mut self, lexer: &Lexer<'_>) -> Span {
+ let (_, initial) = self.rules.pop().unwrap();
+ lexer.span_from(initial)
+ }
+
+ fn peek_rule_span(&mut self, lexer: &Lexer<'_>) -> Span {
+ let &(_, initial) = self.rules.last().unwrap();
+ lexer.span_from(initial)
+ }
+
+ fn switch_value<'a>(lexer: &mut Lexer<'a>) -> Result<(ast::SwitchValue, Span), Error<'a>> {
+ let token_span = lexer.next();
+ match token_span.0 {
+ Token::Word("default") => Ok((ast::SwitchValue::Default, token_span.1)),
+ Token::Number(Ok(Number::U32(num))) => Ok((ast::SwitchValue::U32(num), token_span.1)),
+ Token::Number(Ok(Number::I32(num))) => Ok((ast::SwitchValue::I32(num), token_span.1)),
+ Token::Number(Err(e)) => Err(Error::BadNumber(token_span.1, e)),
+ _ => Err(Error::Unexpected(token_span.1, ExpectedToken::Integer)),
+ }
+ }
+
+ /// Parse a non-negative signed integer literal.
+ /// This is for attributes like `size`, `location` and others.
+ fn non_negative_i32_literal<'a>(lexer: &mut Lexer<'a>) -> Result<u32, Error<'a>> {
+ match lexer.next() {
+ (Token::Number(Ok(Number::I32(num))), span) => {
+ u32::try_from(num).map_err(|_| Error::NegativeInt(span))
+ }
+ (Token::Number(Err(e)), span) => Err(Error::BadNumber(span, e)),
+ other => Err(Error::Unexpected(other.1, ExpectedToken::Number)),
+ }
+ }
+
+ /// Parse a non-negative integer literal that may be either signed or unsigned.
+ /// This is for the `workgroup_size` attribute and array lengths.
+ /// Note: these values should be no larger than [`i32::MAX`], but this is not checked here.
+ fn generic_non_negative_int_literal<'a>(lexer: &mut Lexer<'a>) -> Result<u32, Error<'a>> {
+ match lexer.next() {
+ (Token::Number(Ok(Number::I32(num))), span) => {
+ u32::try_from(num).map_err(|_| Error::NegativeInt(span))
+ }
+ (Token::Number(Ok(Number::U32(num))), _) => Ok(num),
+ (Token::Number(Err(e)), span) => Err(Error::BadNumber(span, e)),
+ other => Err(Error::Unexpected(other.1, ExpectedToken::Number)),
+ }
+ }
+
+ /// Decide if we're looking at a construction expression, and return its
+ /// type if so.
+ ///
+ /// If the identifier `word` is a [type-defining keyword], then return a
+ /// [`ConstructorType`] value describing the type to build. Return an error
+ /// if the type is not constructible (like `sampler`).
+ ///
+ /// If `word` isn't a type name, then return `None`.
+ ///
+ /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords
+ /// [`ConstructorType`]: ast::ConstructorType
+ fn constructor_type<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ word: &'a str,
+ span: Span,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Option<ast::ConstructorType<'a>>, Error<'a>> {
+ if let Some((kind, width)) = conv::get_scalar_type(word) {
+ return Ok(Some(ast::ConstructorType::Scalar { kind, width }));
+ }
+
+ let partial = match word {
+ "vec2" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Bi,
+ },
+ "vec2i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ }))
+ }
+ "vec2u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }))
+ }
+ "vec2f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ }))
+ }
+ "vec3" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Tri,
+ },
+ "vec3i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ }))
+ }
+ "vec3u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }))
+ }
+ "vec3f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ }))
+ }
+ "vec4" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Quad,
+ },
+ "vec4i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ }))
+ }
+ "vec4u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }))
+ }
+ "vec4f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ }))
+ }
+ "mat2x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat2x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat2x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat2x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat2x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat2x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "mat3x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat3x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat3x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat3x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat3x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat3x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "mat4x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat4x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat4x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat4x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat4x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat4x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "array" => ast::ConstructorType::PartialArray,
+ "atomic"
+ | "binding_array"
+ | "sampler"
+ | "sampler_comparison"
+ | "texture_1d"
+ | "texture_1d_array"
+ | "texture_2d"
+ | "texture_2d_array"
+ | "texture_3d"
+ | "texture_cube"
+ | "texture_cube_array"
+ | "texture_multisampled_2d"
+ | "texture_multisampled_2d_array"
+ | "texture_depth_2d"
+ | "texture_depth_2d_array"
+ | "texture_depth_cube"
+ | "texture_depth_cube_array"
+ | "texture_depth_multisampled_2d"
+ | "texture_storage_1d"
+ | "texture_storage_1d_array"
+ | "texture_storage_2d"
+ | "texture_storage_2d_array"
+ | "texture_storage_3d" => return Err(Error::TypeNotConstructible(span)),
+ _ => return Ok(None),
+ };
+
+ // parse component type if present
+ match (lexer.peek().0, partial) {
+ (Token::Paren('<'), ast::ConstructorType::PartialVector { size }) => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ Ok(Some(ast::ConstructorType::Vector { size, kind, width }))
+ }
+ (Token::Paren('<'), ast::ConstructorType::PartialMatrix { columns, rows }) => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ match kind {
+ crate::ScalarKind::Float => Ok(Some(ast::ConstructorType::Matrix {
+ columns,
+ rows,
+ width,
+ })),
+ _ => Err(Error::BadMatrixScalarKind(span, kind, width)),
+ }
+ }
+ (Token::Paren('<'), ast::ConstructorType::PartialArray) => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx.reborrow())?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let expr = self.unary_expression(lexer, ctx.reborrow())?;
+ ast::ArraySize::Constant(expr)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ Ok(Some(ast::ConstructorType::Array { base, size }))
+ }
+ (_, partial) => Ok(Some(partial)),
+ }
+ }
+
+ /// Expects `name` to be consumed (not in lexer).
+ fn arguments<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Vec<Handle<ast::Expression<'a>>>, Error<'a>> {
+ lexer.open_arguments()?;
+ let mut arguments = Vec::new();
+ loop {
+ if !arguments.is_empty() {
+ if !lexer.next_argument()? {
+ break;
+ }
+ } else if lexer.skip(Token::Paren(')')) {
+ break;
+ }
+ let arg = self.general_expression(lexer, ctx.reborrow())?;
+ arguments.push(arg);
+ }
+
+ Ok(arguments)
+ }
+
+ /// Expects [`Rule::PrimaryExpr`] or [`Rule::SingularExpr`] on top; does not pop it.
+ /// Expects `name` to be consumed (not in lexer).
+ fn function_call<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ name: &'a str,
+ name_span: Span,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ assert!(self.rules.last().is_some());
+
+ let expr = match name {
+ // bitcast looks like a function call, but it's an operator and must be handled differently.
+ "bitcast" => {
+ lexer.expect_generic_paren('<')?;
+ let start = lexer.start_byte_offset();
+ let to = self.type_decl(lexer, ctx.reborrow())?;
+ let span = lexer.span_from(start);
+ lexer.expect_generic_paren('>')?;
+
+ lexer.open_arguments()?;
+ let expr = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.close_arguments()?;
+
+ ast::Expression::Bitcast {
+ expr,
+ to,
+ ty_span: span,
+ }
+ }
+ // everything else must be handled later, since they can be hidden by user-defined functions.
+ _ => {
+ let arguments = self.arguments(lexer, ctx.reborrow())?;
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: name_span,
+ });
+ ast::Expression::Call {
+ function: ast::Ident {
+ name,
+ span: name_span,
+ },
+ arguments,
+ }
+ }
+ };
+
+ let span = self.peek_rule_span(lexer);
+ let expr = ctx.expressions.append(expr, span);
+ Ok(expr)
+ }
+
+ fn ident_expr<'a>(
+ &mut self,
+ name: &'a str,
+ name_span: Span,
+ ctx: ExpressionContext<'a, '_, '_>,
+ ) -> ast::IdentExpr<'a> {
+ match ctx.local_table.lookup(name) {
+ Some(&local) => ast::IdentExpr::Local(local),
+ None => {
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: name_span,
+ });
+ ast::IdentExpr::Unresolved(name)
+ }
+ }
+ }
+
+ fn primary_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::PrimaryExpr, lexer);
+
+ let expr = match lexer.peek() {
+ (Token::Paren('('), _) => {
+ let _ = lexer.next();
+ let expr = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(')'))?;
+ self.pop_rule_span(lexer);
+ return Ok(expr);
+ }
+ (Token::Word("true"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Bool(true))
+ }
+ (Token::Word("false"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Bool(false))
+ }
+ (Token::Number(res), span) => {
+ let _ = lexer.next();
+ let num = res.map_err(|err| Error::BadNumber(span, err))?;
+ ast::Expression::Literal(ast::Literal::Number(num))
+ }
+ (Token::Word("RAY_FLAG_NONE"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
+ }
+ (Token::Word("RAY_FLAG_TERMINATE_ON_FIRST_HIT"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(4)))
+ }
+ (Token::Word("RAY_QUERY_INTERSECTION_NONE"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
+ }
+ (Token::Word(word), span) => {
+ let start = lexer.start_byte_offset();
+ let _ = lexer.next();
+
+ if let Some(ty) = self.constructor_type(lexer, word, span, ctx.reborrow())? {
+ let ty_span = lexer.span_from(start);
+ let components = self.arguments(lexer, ctx.reborrow())?;
+ ast::Expression::Construct {
+ ty,
+ ty_span,
+ components,
+ }
+ } else if let Token::Paren('(') = lexer.peek().0 {
+ self.pop_rule_span(lexer);
+ return self.function_call(lexer, word, span, ctx);
+ } else if word == "bitcast" {
+ self.pop_rule_span(lexer);
+ return self.function_call(lexer, word, span, ctx);
+ } else {
+ let ident = self.ident_expr(word, span, ctx.reborrow());
+ ast::Expression::Ident(ident)
+ }
+ }
+ other => return Err(Error::Unexpected(other.1, ExpectedToken::PrimaryExpression)),
+ };
+
+ let span = self.pop_rule_span(lexer);
+ let expr = ctx.expressions.append(expr, span);
+ Ok(expr)
+ }
+
+ fn postfix<'a>(
+ &mut self,
+ span_start: usize,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ expr: Handle<ast::Expression<'a>>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let mut expr = expr;
+
+ loop {
+ let expression = match lexer.peek().0 {
+ Token::Separator('.') => {
+ let _ = lexer.next();
+ let field = lexer.next_ident()?;
+
+ ast::Expression::Member { base: expr, field }
+ }
+ Token::Paren('[') => {
+ let _ = lexer.next();
+ let index = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(']'))?;
+
+ ast::Expression::Index { base: expr, index }
+ }
+ _ => break,
+ };
+
+ let span = lexer.span_from(span_start);
+ expr = ctx.expressions.append(expression, span);
+ }
+
+ Ok(expr)
+ }
+
+ /// Parse a `unary_expression`.
+ fn unary_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::UnaryExpr, lexer);
+ //TODO: refactor this to avoid backing up
+ let expr = match lexer.peek().0 {
+ Token::Operation('-') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx.reborrow())?;
+ let expr = ast::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr,
+ };
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('!' | '~') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx.reborrow())?;
+ let expr = ast::Expression::Unary {
+ op: crate::UnaryOperator::Not,
+ expr,
+ };
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('*') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx.reborrow())?;
+ let expr = ast::Expression::Deref(expr);
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('&') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx.reborrow())?;
+ let expr = ast::Expression::AddrOf(expr);
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ _ => self.singular_expression(lexer, ctx.reborrow())?,
+ };
+
+ self.pop_rule_span(lexer);
+ Ok(expr)
+ }
+
+ /// Parse a `singular_expression`.
+ fn singular_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let start = lexer.start_byte_offset();
+ self.push_rule_span(Rule::SingularExpr, lexer);
+ let primary_expr = self.primary_expression(lexer, ctx.reborrow())?;
+ let singular_expr = self.postfix(start, lexer, ctx.reborrow(), primary_expr)?;
+ self.pop_rule_span(lexer);
+
+ Ok(singular_expr)
+ }
+
+ fn equality_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut context: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ // equality_expression
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('=') => Some(crate::BinaryOperator::Equal),
+ Token::LogicalOperation('!') => Some(crate::BinaryOperator::NotEqual),
+ _ => None,
+ },
+ // relational_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Paren('<') => Some(crate::BinaryOperator::Less),
+ Token::Paren('>') => Some(crate::BinaryOperator::Greater),
+ Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual),
+ Token::LogicalOperation('>') => Some(crate::BinaryOperator::GreaterEqual),
+ _ => None,
+ },
+ // shift_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::ShiftOperation('<') => {
+ Some(crate::BinaryOperator::ShiftLeft)
+ }
+ Token::ShiftOperation('>') => {
+ Some(crate::BinaryOperator::ShiftRight)
+ }
+ _ => None,
+ },
+ // additive_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('+') => Some(crate::BinaryOperator::Add),
+ Token::Operation('-') => {
+ Some(crate::BinaryOperator::Subtract)
+ }
+ _ => None,
+ },
+ // multiplicative_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('*') => {
+ Some(crate::BinaryOperator::Multiply)
+ }
+ Token::Operation('/') => {
+ Some(crate::BinaryOperator::Divide)
+ }
+ Token::Operation('%') => {
+ Some(crate::BinaryOperator::Modulo)
+ }
+ _ => None,
+ },
+ |lexer, context| self.unary_expression(lexer, context),
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ }
+
+ fn general_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.general_expression_with_span(lexer, ctx.reborrow())
+ .map(|(expr, _)| expr)
+ }
+
+ fn general_expression_with_span<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut context: ExpressionContext<'a, '_, '_>,
+ ) -> Result<(Handle<ast::Expression<'a>>, Span), Error<'a>> {
+ self.push_rule_span(Rule::GeneralExpr, lexer);
+ // logical_or_expression
+ let handle = context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr),
+ _ => None,
+ },
+ // logical_and_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd),
+ _ => None,
+ },
+ // inclusive_or_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('|') => Some(crate::BinaryOperator::InclusiveOr),
+ _ => None,
+ },
+ // exclusive_or_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('^') => {
+ Some(crate::BinaryOperator::ExclusiveOr)
+ }
+ _ => None,
+ },
+ // and_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('&') => {
+ Some(crate::BinaryOperator::And)
+ }
+ _ => None,
+ },
+ |lexer, context| {
+ self.equality_expression(lexer, context)
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )?;
+ Ok((handle, self.pop_rule_span(lexer)))
+ }
+
+ fn variable_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<ast::GlobalVariable<'a>, Error<'a>> {
+ self.push_rule_span(Rule::VariableDecl, lexer);
+ let mut space = crate::AddressSpace::Handle;
+
+ if lexer.skip(Token::Paren('<')) {
+ let (class_str, span) = lexer.next_ident_with_span()?;
+ space = match class_str {
+ "storage" => {
+ let access = if lexer.skip(Token::Separator(',')) {
+ lexer.next_storage_access()?
+ } else {
+ // defaulting to `read`
+ crate::StorageAccess::LOAD
+ };
+ crate::AddressSpace::Storage { access }
+ }
+ _ => conv::map_address_space(class_str, span)?,
+ };
+ lexer.expect(Token::Paren('>'))?;
+ }
+ let name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(':'))?;
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+
+ let init = if lexer.skip(Token::Operation('=')) {
+ let handle = self.general_expression(lexer, ctx.reborrow())?;
+ Some(handle)
+ } else {
+ None
+ };
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+
+ Ok(ast::GlobalVariable {
+ name,
+ space,
+ binding: None,
+ ty,
+ init,
+ })
+ }
+
+ fn struct_body<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Vec<ast::StructMember<'a>>, Error<'a>> {
+ let mut members = Vec::new();
+
+ lexer.expect(Token::Paren('{'))?;
+ let mut ready = true;
+ while !lexer.skip(Token::Paren('}')) {
+ if !ready {
+ return Err(Error::Unexpected(
+ lexer.next().1,
+ ExpectedToken::Token(Token::Separator(',')),
+ ));
+ }
+ let (mut size, mut align) = (None, None);
+ self.push_rule_span(Rule::Attribute, lexer);
+ let mut bind_parser = BindingParser::default();
+ while lexer.skip(Token::Attribute) {
+ match lexer.next_ident_with_span()? {
+ ("size", _) => {
+ lexer.expect(Token::Paren('('))?;
+ let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
+ lexer.expect(Token::Paren(')'))?;
+ size = Some((value, span));
+ }
+ ("align", _) => {
+ lexer.expect(Token::Paren('('))?;
+ let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
+ lexer.expect(Token::Paren(')'))?;
+ align = Some((value, span));
+ }
+ (word, word_span) => bind_parser.parse(lexer, word, word_span)?,
+ }
+ }
+
+ let bind_span = self.pop_rule_span(lexer);
+ let binding = bind_parser.finish(bind_span)?;
+
+ let name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(':'))?;
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+ ready = lexer.skip(Token::Separator(','));
+
+ members.push(ast::StructMember {
+ name,
+ ty,
+ binding,
+ size,
+ align,
+ });
+ }
+
+ Ok(members)
+ }
+
+ fn matrix_scalar_type<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ ) -> Result<ast::Type<'a>, Error<'a>> {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ match kind {
+ crate::ScalarKind::Float => Ok(ast::Type::Matrix {
+ columns,
+ rows,
+ width,
+ }),
+ _ => Err(Error::BadMatrixScalarKind(span, kind, width)),
+ }
+ }
+
+ fn type_decl_impl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ word: &'a str,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Option<ast::Type<'a>>, Error<'a>> {
+ if let Some((kind, width)) = conv::get_scalar_type(word) {
+ return Ok(Some(ast::Type::Scalar { kind, width }));
+ }
+
+ Ok(Some(match word {
+ "vec2" => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ kind,
+ width,
+ }
+ }
+ "vec2i" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ "vec2u" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ "vec2f" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ "vec3" => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ kind,
+ width,
+ }
+ }
+ "vec3i" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ "vec3u" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ "vec3f" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ "vec4" => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ kind,
+ width,
+ }
+ }
+ "vec4i" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ "vec4u" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ "vec4f" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ "mat2x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Bi)?
+ }
+ "mat2x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat2x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Tri)?
+ }
+ "mat2x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat2x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Quad)?
+ }
+ "mat2x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "mat3x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Bi)?
+ }
+ "mat3x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat3x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Tri)?
+ }
+ "mat3x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat3x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Quad)?
+ }
+ "mat3x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "mat4x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Bi)?
+ }
+ "mat4x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat4x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Tri)?
+ }
+ "mat4x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat4x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Quad)?
+ }
+ "mat4x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "atomic" => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ ast::Type::Atomic { kind, width }
+ }
+ "ptr" => {
+ lexer.expect_generic_paren('<')?;
+ let (ident, span) = lexer.next_ident_with_span()?;
+ let mut space = conv::map_address_space(ident, span)?;
+ lexer.expect(Token::Separator(','))?;
+ let base = self.type_decl(lexer, ctx)?;
+ if let crate::AddressSpace::Storage { ref mut access } = space {
+ *access = if lexer.skip(Token::Separator(',')) {
+ lexer.next_storage_access()?
+ } else {
+ crate::StorageAccess::LOAD
+ };
+ }
+ lexer.expect_generic_paren('>')?;
+ ast::Type::Pointer { base, space }
+ }
+ "array" => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx.reborrow())?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let size = self.unary_expression(lexer, ctx.reborrow())?;
+ ast::ArraySize::Constant(size)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ ast::Type::Array { base, size }
+ }
+ "binding_array" => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx.reborrow())?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let size = self.unary_expression(lexer, ctx.reborrow())?;
+ ast::ArraySize::Constant(size)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ ast::Type::BindingArray { base, size }
+ }
+ "sampler" => ast::Type::Sampler { comparison: false },
+ "sampler_comparison" => ast::Type::Sampler { comparison: true },
+ "texture_1d" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_1d_array" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: true,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_2d" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_2d_array" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_3d" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D3,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_cube" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_cube_array" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ "texture_multisampled_2d" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: true },
+ }
+ }
+ "texture_multisampled_2d_array" => {
+ let (kind, width, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(kind, width, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Sampled { kind, multi: true },
+ }
+ }
+ "texture_depth_2d" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_2d_array" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_cube" => ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_cube_array" => ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_multisampled_2d" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: true },
+ },
+ "texture_storage_1d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_1d_array" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: true,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_2d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_2d_array" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_3d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D3,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "acceleration_structure" => ast::Type::AccelerationStructure,
+ "ray_query" => ast::Type::RayQuery,
+ "RayDesc" => ast::Type::RayDesc,
+ "RayIntersection" => ast::Type::RayIntersection,
+ _ => return Ok(None),
+ }))
+ }
+
+ const fn check_texture_sample_type(
+ kind: crate::ScalarKind,
+ width: u8,
+ span: Span,
+ ) -> Result<(), Error<'static>> {
+ use crate::ScalarKind::*;
+ // Validate according to https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ match (kind, width) {
+ (Float | Sint | Uint, 4) => Ok(()),
+ _ => Err(Error::BadTextureSampleType { span, kind, width }),
+ }
+ }
+
+ /// Parse type declaration of a given name.
+ fn type_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Type<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::TypeDecl, lexer);
+
+ let (name, span) = lexer.next_ident_with_span()?;
+
+ let ty = match self.type_decl_impl(lexer, name, ctx.reborrow())? {
+ Some(ty) => ty,
+ None => {
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: span,
+ });
+ ast::Type::User(ast::Ident { name, span })
+ }
+ };
+
+ self.pop_rule_span(lexer);
+
+ let handle = ctx.types.append(ty, Span::UNDEFINED);
+ Ok(handle)
+ }
+
+ fn assignment_op_and_rhs<'a, 'out>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, 'out>,
+ block: &'out mut ast::Block<'a>,
+ target: Handle<ast::Expression<'a>>,
+ span_start: usize,
+ ) -> Result<(), Error<'a>> {
+ use crate::BinaryOperator as Bo;
+
+ let op = lexer.next();
+ let (op, value) = match op {
+ (Token::Operation('='), _) => {
+ let value = self.general_expression(lexer, ctx.reborrow())?;
+ (None, value)
+ }
+ (Token::AssignmentOperation(c), _) => {
+ let op = match c {
+ '<' => Bo::ShiftLeft,
+ '>' => Bo::ShiftRight,
+ '+' => Bo::Add,
+ '-' => Bo::Subtract,
+ '*' => Bo::Multiply,
+ '/' => Bo::Divide,
+ '%' => Bo::Modulo,
+ '&' => Bo::And,
+ '|' => Bo::InclusiveOr,
+ '^' => Bo::ExclusiveOr,
+ // Note: `consume_token` shouldn't produce any other assignment ops
+ _ => unreachable!(),
+ };
+
+ let value = self.general_expression(lexer, ctx.reborrow())?;
+ (Some(op), value)
+ }
+ token @ (Token::IncrementOperation | Token::DecrementOperation, _) => {
+ let op = match token.0 {
+ Token::IncrementOperation => ast::StatementKind::Increment,
+ Token::DecrementOperation => ast::StatementKind::Decrement,
+ _ => unreachable!(),
+ };
+
+ let span = lexer.span_from(span_start);
+ block.stmts.push(ast::Statement {
+ kind: op(target),
+ span,
+ });
+ return Ok(());
+ }
+ _ => return Err(Error::Unexpected(op.1, ExpectedToken::Assignment)),
+ };
+
+ let span = lexer.span_from(span_start);
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Assign { target, op, value },
+ span,
+ });
+ Ok(())
+ }
+
+ /// Parse an assignment statement (will also parse increment and decrement statements)
+ fn assignment_statement<'a, 'out>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, 'out>,
+ block: &'out mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ let span_start = lexer.start_byte_offset();
+ let target = self.general_expression(lexer, ctx.reborrow())?;
+ self.assignment_op_and_rhs(lexer, ctx, block, target, span_start)
+ }
+
+ /// Parse a function call statement.
+ /// Expects `ident` to be consumed (not in the lexer).
+ fn function_statement<'a, 'out>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ident: &'a str,
+ ident_span: Span,
+ span_start: usize,
+ mut context: ExpressionContext<'a, '_, 'out>,
+ block: &'out mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ self.push_rule_span(Rule::SingularExpr, lexer);
+
+ context.unresolved.insert(ast::Dependency {
+ ident,
+ usage: ident_span,
+ });
+ let arguments = self.arguments(lexer, context.reborrow())?;
+ let span = lexer.span_from(span_start);
+
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Call {
+ function: ast::Ident {
+ name: ident,
+ span: ident_span,
+ },
+ arguments,
+ },
+ span,
+ });
+
+ self.pop_rule_span(lexer);
+
+ Ok(())
+ }
+
+ fn function_call_or_assignment_statement<'a, 'out>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut context: ExpressionContext<'a, '_, 'out>,
+ block: &'out mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ let span_start = lexer.start_byte_offset();
+ match lexer.peek() {
+ (Token::Word(name), span) => {
+ // A little hack for 2 token lookahead.
+ let cloned = lexer.clone();
+ let _ = lexer.next();
+ match lexer.peek() {
+ (Token::Paren('('), _) => self.function_statement(
+ lexer,
+ name,
+ span,
+ span_start,
+ context.reborrow(),
+ block,
+ ),
+ _ => {
+ *lexer = cloned;
+ self.assignment_statement(lexer, context.reborrow(), block)
+ }
+ }
+ }
+ _ => self.assignment_statement(lexer, context.reborrow(), block),
+ }
+ }
+
+ fn statement<'a, 'out>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, 'out>,
+ block: &'out mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ self.push_rule_span(Rule::Statement, lexer);
+ match lexer.peek() {
+ (Token::Separator(';'), _) => {
+ let _ = lexer.next();
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ (Token::Paren('{'), _) => {
+ let (inner, span) = self.block(lexer, ctx.reborrow())?;
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(inner),
+ span,
+ });
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ (Token::Word(word), _) => {
+ let kind = match word {
+ "_" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Operation('='))?;
+ let expr = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Separator(';'))?;
+
+ ast::StatementKind::Ignore(expr)
+ }
+ "let" => {
+ let _ = lexer.next();
+ let name = lexer.next_ident()?;
+
+ let given_ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+ Some(ty)
+ } else {
+ None
+ };
+ lexer.expect(Token::Operation('='))?;
+ let expr_id = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Separator(';'))?;
+
+ let handle = ctx.locals.append(ast::Local, name.span);
+ if let Some(old) = ctx.local_table.add(name.name, handle) {
+ return Err(Error::Redefinition {
+ previous: ctx.locals.get_span(old),
+ current: name.span,
+ });
+ }
+
+ ast::StatementKind::LocalDecl(ast::LocalDecl::Let(ast::Let {
+ name,
+ ty: given_ty,
+ init: expr_id,
+ handle,
+ }))
+ }
+ "var" => {
+ let _ = lexer.next();
+
+ let name = lexer.next_ident()?;
+ let ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+ Some(ty)
+ } else {
+ None
+ };
+
+ let init = if lexer.skip(Token::Operation('=')) {
+ let init = self.general_expression(lexer, ctx.reborrow())?;
+ Some(init)
+ } else {
+ None
+ };
+
+ lexer.expect(Token::Separator(';'))?;
+
+ let handle = ctx.locals.append(ast::Local, name.span);
+ if let Some(old) = ctx.local_table.add(name.name, handle) {
+ return Err(Error::Redefinition {
+ previous: ctx.locals.get_span(old),
+ current: name.span,
+ });
+ }
+
+ ast::StatementKind::LocalDecl(ast::LocalDecl::Var(ast::LocalVariable {
+ name,
+ ty,
+ init,
+ handle,
+ }))
+ }
+ "return" => {
+ let _ = lexer.next();
+ let value = if lexer.peek().0 != Token::Separator(';') {
+ let handle = self.general_expression(lexer, ctx.reborrow())?;
+ Some(handle)
+ } else {
+ None
+ };
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Return { value }
+ }
+ "if" => {
+ let _ = lexer.next();
+ let condition = self.general_expression(lexer, ctx.reborrow())?;
+
+ let accept = self.block(lexer, ctx.reborrow())?.0;
+
+ let mut elsif_stack = Vec::new();
+ let mut elseif_span_start = lexer.start_byte_offset();
+ let mut reject = loop {
+ if !lexer.skip(Token::Word("else")) {
+ break ast::Block::default();
+ }
+
+ if !lexer.skip(Token::Word("if")) {
+ // ... else { ... }
+ break self.block(lexer, ctx.reborrow())?.0;
+ }
+
+ // ... else if (...) { ... }
+ let other_condition = self.general_expression(lexer, ctx.reborrow())?;
+ let other_block = self.block(lexer, ctx.reborrow())?;
+ elsif_stack.push((elseif_span_start, other_condition, other_block));
+ elseif_span_start = lexer.start_byte_offset();
+ };
+
+ // reverse-fold the else-if blocks
+ //Note: we may consider uplifting this to the IR
+ for (other_span_start, other_cond, other_block) in
+ elsif_stack.into_iter().rev()
+ {
+ let sub_stmt = ast::StatementKind::If {
+ condition: other_cond,
+ accept: other_block.0,
+ reject,
+ };
+ reject = ast::Block::default();
+ let span = lexer.span_from(other_span_start);
+ reject.stmts.push(ast::Statement {
+ kind: sub_stmt,
+ span,
+ })
+ }
+
+ ast::StatementKind::If {
+ condition,
+ accept,
+ reject,
+ }
+ }
+ "switch" => {
+ let _ = lexer.next();
+ let selector = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren('{'))?;
+ let mut cases = Vec::new();
+
+ loop {
+ // cases + default
+ match lexer.next() {
+ (Token::Word("case"), _) => {
+ // parse a list of values
+ let (value, value_span) = loop {
+ let (value, value_span) = Self::switch_value(lexer)?;
+ if lexer.skip(Token::Separator(',')) {
+ if lexer.skip(Token::Separator(':')) {
+ break (value, value_span);
+ }
+ } else {
+ lexer.skip(Token::Separator(':'));
+ break (value, value_span);
+ }
+ cases.push(ast::SwitchCase {
+ value,
+ value_span,
+ body: ast::Block::default(),
+ fall_through: true,
+ });
+ };
+
+ let body = self.block(lexer, ctx.reborrow())?.0;
+
+ cases.push(ast::SwitchCase {
+ value,
+ value_span,
+ body,
+ fall_through: false,
+ });
+ }
+ (Token::Word("default"), value_span) => {
+ lexer.skip(Token::Separator(':'));
+ let body = self.block(lexer, ctx.reborrow())?.0;
+ cases.push(ast::SwitchCase {
+ value: ast::SwitchValue::Default,
+ value_span,
+ body,
+ fall_through: false,
+ });
+ }
+ (Token::Paren('}'), _) => break,
+ (_, span) => {
+ return Err(Error::Unexpected(span, ExpectedToken::SwitchItem))
+ }
+ }
+ }
+
+ ast::StatementKind::Switch { selector, cases }
+ }
+ "loop" => self.r#loop(lexer, ctx.reborrow())?,
+ "while" => {
+ let _ = lexer.next();
+ let mut body = ast::Block::default();
+
+ let (condition, span) = lexer.capture_span(|lexer| {
+ let condition = self.general_expression(lexer, ctx.reborrow())?;
+ Ok(condition)
+ })?;
+ let mut reject = ast::Block::default();
+ reject.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Break,
+ span,
+ });
+
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::If {
+ condition,
+ accept: ast::Block::default(),
+ reject,
+ },
+ span,
+ });
+
+ let (block, span) = self.block(lexer, ctx.reborrow())?;
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(block),
+ span,
+ });
+
+ ast::StatementKind::Loop {
+ body,
+ continuing: ast::Block::default(),
+ break_if: None,
+ }
+ }
+ "for" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Paren('('))?;
+
+ ctx.local_table.push_scope();
+
+ if !lexer.skip(Token::Separator(';')) {
+ let num_statements = block.stmts.len();
+ let (_, span) = lexer.capture_span(|lexer| {
+ self.statement(lexer, ctx.reborrow(), block)
+ })?;
+
+ if block.stmts.len() != num_statements {
+ match block.stmts.last().unwrap().kind {
+ ast::StatementKind::Call { .. }
+ | ast::StatementKind::Assign { .. }
+ | ast::StatementKind::LocalDecl(_) => {}
+ _ => return Err(Error::InvalidForInitializer(span)),
+ }
+ }
+ };
+
+ let mut body = ast::Block::default();
+ if !lexer.skip(Token::Separator(';')) {
+ let (condition, span) = lexer.capture_span(|lexer| {
+ let condition = self.general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Separator(';'))?;
+ Ok(condition)
+ })?;
+ let mut reject = ast::Block::default();
+ reject.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Break,
+ span,
+ });
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::If {
+ condition,
+ accept: ast::Block::default(),
+ reject,
+ },
+ span,
+ });
+ };
+
+ let mut continuing = ast::Block::default();
+ if !lexer.skip(Token::Paren(')')) {
+ self.function_call_or_assignment_statement(
+ lexer,
+ ctx.reborrow(),
+ &mut continuing,
+ )?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+
+ let (block, span) = self.block(lexer, ctx.reborrow())?;
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(block),
+ span,
+ });
+
+ ctx.local_table.pop_scope();
+
+ ast::StatementKind::Loop {
+ body,
+ continuing,
+ break_if: None,
+ }
+ }
+ "break" => {
+ let (_, span) = lexer.next();
+ // Check if the next token is an `if`, this indicates
+ // that the user tried to type out a `break if` which
+ // is illegal in this position.
+ let (peeked_token, peeked_span) = lexer.peek();
+ if let Token::Word("if") = peeked_token {
+ let span = span.until(&peeked_span);
+ return Err(Error::InvalidBreakIf(span));
+ }
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Break
+ }
+ "continue" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Continue
+ }
+ "discard" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Kill
+ }
+ // assignment or a function call
+ _ => {
+ self.function_call_or_assignment_statement(lexer, ctx.reborrow(), block)?;
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ };
+
+ let span = self.pop_rule_span(lexer);
+ block.stmts.push(ast::Statement { kind, span });
+ }
+ _ => {
+ self.assignment_statement(lexer, ctx.reborrow(), block)?;
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+ }
+ }
+ Ok(())
+ }
+
+ fn r#loop<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<ast::StatementKind<'a>, Error<'a>> {
+ let _ = lexer.next();
+ let mut body = ast::Block::default();
+ let mut continuing = ast::Block::default();
+ let mut break_if = None;
+
+ lexer.expect(Token::Paren('{'))?;
+
+ ctx.local_table.push_scope();
+
+ loop {
+ if lexer.skip(Token::Word("continuing")) {
+ // Branch for the `continuing` block, this must be
+ // the last thing in the loop body
+
+ // Expect a opening brace to start the continuing block
+ lexer.expect(Token::Paren('{'))?;
+ loop {
+ if lexer.skip(Token::Word("break")) {
+ // Branch for the `break if` statement, this statement
+ // has the form `break if <expr>;` and must be the last
+ // statement in a continuing block
+
+ // The break must be followed by an `if` to form
+ // the break if
+ lexer.expect(Token::Word("if"))?;
+
+ let condition = self.general_expression(lexer, ctx.reborrow())?;
+ // Set the condition of the break if to the newly parsed
+ // expression
+ break_if = Some(condition);
+
+ // Expect a semicolon to close the statement
+ lexer.expect(Token::Separator(';'))?;
+ // Expect a closing brace to close the continuing block,
+ // since the break if must be the last statement
+ lexer.expect(Token::Paren('}'))?;
+ // Stop parsing the continuing block
+ break;
+ } else if lexer.skip(Token::Paren('}')) {
+ // If we encounter a closing brace it means we have reached
+ // the end of the continuing block and should stop processing
+ break;
+ } else {
+ // Otherwise try to parse a statement
+ self.statement(lexer, ctx.reborrow(), &mut continuing)?;
+ }
+ }
+ // Since the continuing block must be the last part of the loop body,
+ // we expect to see a closing brace to end the loop body
+ lexer.expect(Token::Paren('}'))?;
+ break;
+ }
+ if lexer.skip(Token::Paren('}')) {
+ // If we encounter a closing brace it means we have reached
+ // the end of the loop body and should stop processing
+ break;
+ }
+ // Otherwise try to parse a statement
+ self.statement(lexer, ctx.reborrow(), &mut body)?;
+ }
+
+ ctx.local_table.pop_scope();
+
+ Ok(ast::StatementKind::Loop {
+ body,
+ continuing,
+ break_if,
+ })
+ }
+
+ /// compound_statement
+ fn block<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<(ast::Block<'a>, Span), Error<'a>> {
+ self.push_rule_span(Rule::Block, lexer);
+
+ ctx.local_table.push_scope();
+
+ lexer.expect(Token::Paren('{'))?;
+ let mut statements = ast::Block::default();
+ while !lexer.skip(Token::Paren('}')) {
+ self.statement(lexer, ctx.reborrow(), &mut statements)?;
+ }
+
+ ctx.local_table.pop_scope();
+
+ let span = self.pop_rule_span(lexer);
+ Ok((statements, span))
+ }
+
+ fn varying_binding<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ) -> Result<Option<crate::Binding>, Error<'a>> {
+ let mut bind_parser = BindingParser::default();
+ self.push_rule_span(Rule::Attribute, lexer);
+
+ while lexer.skip(Token::Attribute) {
+ let (word, span) = lexer.next_ident_with_span()?;
+ bind_parser.parse(lexer, word, span)?;
+ }
+
+ let span = self.pop_rule_span(lexer);
+ bind_parser.finish(span)
+ }
+
+ fn function_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ out: &mut ast::TranslationUnit<'a>,
+ dependencies: &mut FastHashSet<ast::Dependency<'a>>,
+ ) -> Result<ast::Function<'a>, Error<'a>> {
+ self.push_rule_span(Rule::FunctionDecl, lexer);
+ // read function name
+ let fun_name = lexer.next_ident()?;
+
+ let mut locals = Arena::new();
+
+ let mut ctx = ExpressionContext {
+ expressions: &mut out.expressions,
+ local_table: &mut SymbolTable::default(),
+ locals: &mut locals,
+ types: &mut out.types,
+ unresolved: dependencies,
+ };
+
+ // read parameter list
+ let mut arguments = Vec::new();
+ lexer.expect(Token::Paren('('))?;
+ let mut ready = true;
+ while !lexer.skip(Token::Paren(')')) {
+ if !ready {
+ return Err(Error::Unexpected(
+ lexer.next().1,
+ ExpectedToken::Token(Token::Separator(',')),
+ ));
+ }
+ let binding = self.varying_binding(lexer)?;
+
+ let param_name = lexer.next_ident()?;
+
+ lexer.expect(Token::Separator(':'))?;
+ let param_type = self.type_decl(lexer, ctx.reborrow())?;
+
+ let handle = ctx.locals.append(ast::Local, param_name.span);
+ ctx.local_table.add(param_name.name, handle);
+ arguments.push(ast::FunctionArgument {
+ name: param_name,
+ ty: param_type,
+ binding,
+ handle,
+ });
+ ready = lexer.skip(Token::Separator(','));
+ }
+ // read return type
+ let result = if lexer.skip(Token::Arrow) && !lexer.skip(Token::Word("void")) {
+ let binding = self.varying_binding(lexer)?;
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+ Some(ast::FunctionResult { ty, binding })
+ } else {
+ None
+ };
+
+ // read body
+ let body = self.block(lexer, ctx)?.0;
+
+ let fun = ast::Function {
+ entry_point: None,
+ name: fun_name,
+ arguments,
+ result,
+ body,
+ locals,
+ };
+
+ // done
+ self.pop_rule_span(lexer);
+
+ Ok(fun)
+ }
+
+ fn global_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ out: &mut ast::TranslationUnit<'a>,
+ ) -> Result<(), Error<'a>> {
+ // read attributes
+ let mut binding = None;
+ let mut stage = None;
+ let mut workgroup_size = [0u32; 3];
+ let mut early_depth_test = None;
+ let (mut bind_index, mut bind_group) = (None, None);
+
+ self.push_rule_span(Rule::Attribute, lexer);
+ while lexer.skip(Token::Attribute) {
+ match lexer.next_ident_with_span()? {
+ ("binding", _) => {
+ lexer.expect(Token::Paren('('))?;
+ bind_index = Some(Self::non_negative_i32_literal(lexer)?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ ("group", _) => {
+ lexer.expect(Token::Paren('('))?;
+ bind_group = Some(Self::non_negative_i32_literal(lexer)?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ ("vertex", _) => {
+ stage = Some(crate::ShaderStage::Vertex);
+ }
+ ("fragment", _) => {
+ stage = Some(crate::ShaderStage::Fragment);
+ }
+ ("compute", _) => {
+ stage = Some(crate::ShaderStage::Compute);
+ }
+ ("workgroup_size", _) => {
+ lexer.expect(Token::Paren('('))?;
+ workgroup_size = [1u32; 3];
+ for (i, size) in workgroup_size.iter_mut().enumerate() {
+ *size = Self::generic_non_negative_int_literal(lexer)?;
+ match lexer.next() {
+ (Token::Paren(')'), _) => break,
+ (Token::Separator(','), _) if i != 2 => (),
+ other => {
+ return Err(Error::Unexpected(
+ other.1,
+ ExpectedToken::WorkgroupSizeSeparator,
+ ))
+ }
+ }
+ }
+ }
+ ("early_depth_test", _) => {
+ let conservative = if lexer.skip(Token::Paren('(')) {
+ let (ident, ident_span) = lexer.next_ident_with_span()?;
+ let value = conv::map_conservative_depth(ident, ident_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(value)
+ } else {
+ None
+ };
+ early_depth_test = Some(crate::EarlyDepthTest { conservative });
+ }
+ (_, word_span) => return Err(Error::UnknownAttribute(word_span)),
+ }
+ }
+
+ let attrib_span = self.pop_rule_span(lexer);
+ match (bind_group, bind_index) {
+ (Some(group), Some(index)) => {
+ binding = Some(crate::ResourceBinding {
+ group,
+ binding: index,
+ });
+ }
+ (Some(_), None) => return Err(Error::MissingAttribute("binding", attrib_span)),
+ (None, Some(_)) => return Err(Error::MissingAttribute("group", attrib_span)),
+ (None, None) => {}
+ }
+
+ let mut dependencies = FastHashSet::default();
+ let mut ctx = ExpressionContext {
+ expressions: &mut out.expressions,
+ local_table: &mut SymbolTable::default(),
+ locals: &mut Arena::new(),
+ types: &mut out.types,
+ unresolved: &mut dependencies,
+ };
+
+ // read item
+ let start = lexer.start_byte_offset();
+ let kind = match lexer.next() {
+ (Token::Separator(';'), _) => None,
+ (Token::Word("struct"), _) => {
+ let name = lexer.next_ident()?;
+
+ let members = self.struct_body(lexer, ctx)?;
+ Some(ast::GlobalDeclKind::Struct(ast::Struct { name, members }))
+ }
+ (Token::Word("alias"), _) => {
+ let name = lexer.next_ident()?;
+
+ lexer.expect(Token::Operation('='))?;
+ let ty = self.type_decl(lexer, ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+ Some(ast::GlobalDeclKind::Type(ast::TypeAlias { name, ty }))
+ }
+ (Token::Word("const"), _) => {
+ let name = lexer.next_ident()?;
+
+ let ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, ctx.reborrow())?;
+ Some(ty)
+ } else {
+ None
+ };
+
+ lexer.expect(Token::Operation('='))?;
+ let init = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+
+ Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init }))
+ }
+ (Token::Word("var"), _) => {
+ let mut var = self.variable_decl(lexer, ctx)?;
+ var.binding = binding.take();
+ Some(ast::GlobalDeclKind::Var(var))
+ }
+ (Token::Word("fn"), _) => {
+ let function = self.function_decl(lexer, out, &mut dependencies)?;
+ Some(ast::GlobalDeclKind::Fn(ast::Function {
+ entry_point: stage.map(|stage| ast::EntryPoint {
+ stage,
+ early_depth_test,
+ workgroup_size,
+ }),
+ ..function
+ }))
+ }
+ (Token::End, _) => return Ok(()),
+ other => return Err(Error::Unexpected(other.1, ExpectedToken::GlobalItem)),
+ };
+
+ if let Some(kind) = kind {
+ out.decls.append(
+ ast::GlobalDecl { kind, dependencies },
+ lexer.span_from(start),
+ );
+ }
+
+ if !self.rules.is_empty() {
+ log::error!("Reached the end of global decl, but rule stack is not empty");
+ log::error!("Rules: {:?}", self.rules);
+ return Err(Error::Other);
+ };
+
+ match binding {
+ None => Ok(()),
+ // we had the attribute but no var?
+ Some(_) => Err(Error::Other),
+ }
+ }
+
+ pub fn parse<'a>(&mut self, source: &'a str) -> Result<ast::TranslationUnit<'a>, Error<'a>> {
+ self.reset();
+
+ let mut lexer = Lexer::new(source);
+ let mut tu = ast::TranslationUnit::default();
+ loop {
+ match self.global_decl(&mut lexer, &mut tu) {
+ Err(error) => return Err(error),
+ Ok(()) => {
+ if lexer.peek().0 == Token::End {
+ break;
+ }
+ }
+ }
+ }
+
+ Ok(tu)
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/number.rs b/third_party/rust/naga/src/front/wgsl/parse/number.rs
new file mode 100644
index 0000000000..57a2be6142
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/number.rs
@@ -0,0 +1,443 @@
+use std::borrow::Cow;
+
+use crate::front::wgsl::error::NumberError;
+use crate::front::wgsl::parse::lexer::Token;
+
+/// When using this type assume no Abstract Int/Float for now
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum Number {
+ /// Abstract Int (-2^63 ≤ i < 2^63)
+ AbstractInt(i64),
+ /// Abstract Float (IEEE-754 binary64)
+ AbstractFloat(f64),
+ /// Concrete i32
+ I32(i32),
+ /// Concrete u32
+ U32(u32),
+ /// Concrete f32
+ F32(f32),
+}
+
+impl Number {
+ /// Convert abstract numbers to a plausible concrete counterpart.
+ ///
+ /// Return concrete numbers unchanged. If the conversion would be
+ /// lossy, return an error.
+ fn abstract_to_concrete(self) -> Result<Number, NumberError> {
+ match self {
+ Number::AbstractInt(num) => i32::try_from(num)
+ .map(Number::I32)
+ .map_err(|_| NumberError::NotRepresentable),
+ Number::AbstractFloat(num) => {
+ let num = num as f32;
+ if num.is_finite() {
+ Ok(Number::F32(num))
+ } else {
+ Err(NumberError::NotRepresentable)
+ }
+ }
+ num => Ok(num),
+ }
+ }
+}
+
+// TODO: when implementing Creation-Time Expressions, remove the ability to match the minus sign
+
+pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
+ let (result, rest) = parse(input);
+ (
+ Token::Number(result.and_then(Number::abstract_to_concrete)),
+ rest,
+ )
+}
+
+enum Kind {
+ Int(IntKind),
+ Float(FloatKind),
+}
+
+enum IntKind {
+ I32,
+ U32,
+}
+
+enum FloatKind {
+ F32,
+ F16,
+}
+
+// The following regexes (from the WGSL spec) will be matched:
+
+// int_literal:
+// | / 0 [iu]? /
+// | / [1-9][0-9]* [iu]? /
+// | / 0[xX][0-9a-fA-F]+ [iu]? /
+
+// decimal_float_literal:
+// | / 0 [fh] /
+// | / [1-9][0-9]* [fh] /
+// | / [0-9]* \.[0-9]+ ([eE][+-]?[0-9]+)? [fh]? /
+// | / [0-9]+ \.[0-9]* ([eE][+-]?[0-9]+)? [fh]? /
+// | / [0-9]+ [eE][+-]?[0-9]+ [fh]? /
+
+// hex_float_literal:
+// | / 0[xX][0-9a-fA-F]* \.[0-9a-fA-F]+ ([pP][+-]?[0-9]+ [fh]?)? /
+// | / 0[xX][0-9a-fA-F]+ \.[0-9a-fA-F]* ([pP][+-]?[0-9]+ [fh]?)? /
+// | / 0[xX][0-9a-fA-F]+ [pP][+-]?[0-9]+ [fh]? /
+
+// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing
+// -?(?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?))
+
+fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
+ /// returns `true` and consumes `X` bytes from the given byte buffer
+ /// if the given `X` nr of patterns are found at the start of the buffer
+ macro_rules! consume {
+ ($bytes:ident, $($pattern:pat),*) => {
+ match $bytes {
+ &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true },
+ _ => false,
+ }
+ };
+ }
+
+ /// consumes one byte from the given byte buffer
+ /// if one of the given patterns are found at the start of the buffer
+ /// returning the corresponding expr for the matched pattern
+ macro_rules! consume_map {
+ ($bytes:ident, [$($pattern:pat_param => $to:expr),*]) => {
+ match $bytes {
+ $( &[$pattern, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
+ _ => None,
+ }
+ };
+ }
+
+ /// consumes all consecutive bytes matched by the `0-9` pattern from the given byte buffer
+ /// returning the number of consumed bytes
+ macro_rules! consume_dec_digits {
+ ($bytes:ident) => {{
+ let start_len = $bytes.len();
+ while let &[b'0'..=b'9', ref rest @ ..] = $bytes {
+ $bytes = rest;
+ }
+ start_len - $bytes.len()
+ }};
+ }
+
+ /// consumes all consecutive bytes matched by the `0-9 | a-f | A-F` pattern from the given byte buffer
+ /// returning the number of consumed bytes
+ macro_rules! consume_hex_digits {
+ ($bytes:ident) => {{
+ let start_len = $bytes.len();
+ while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes {
+ $bytes = rest;
+ }
+ start_len - $bytes.len()
+ }};
+ }
+
+ /// maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str`
+ macro_rules! rest_to_str {
+ ($bytes:ident) => {
+ &input[input.len() - $bytes.len()..]
+ };
+ }
+
+ struct ExtractSubStr<'a>(&'a str);
+
+ impl<'a> ExtractSubStr<'a> {
+ /// given an `input` and a `start` (tail of the `input`)
+ /// creates a new [`ExtractSubStr`](`Self`)
+ fn start(input: &'a str, start: &'a [u8]) -> Self {
+ let start = input.len() - start.len();
+ Self(&input[start..])
+ }
+ /// given an `end` (tail of the initial `input`)
+ /// returns a substring of `input`
+ fn end(&self, end: &'a [u8]) -> &'a str {
+ let end = self.0.len() - end.len();
+ &self.0[..end]
+ }
+ }
+
+ let mut bytes = input.as_bytes();
+
+ let general_extract = ExtractSubStr::start(input, bytes);
+
+ let is_negative = consume!(bytes, b'-');
+
+ if consume!(bytes, b'0', b'x' | b'X') {
+ let digits_extract = ExtractSubStr::start(input, bytes);
+
+ let consumed = consume_hex_digits!(bytes);
+
+ if consume!(bytes, b'.') {
+ let consumed_after_period = consume_hex_digits!(bytes);
+
+ if consumed + consumed_after_period == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let significand = general_extract.end(bytes);
+
+ if consume!(bytes, b'p' | b'P') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
+
+ (parse_hex_float(number, kind), rest_to_str!(bytes))
+ } else {
+ (
+ parse_hex_float_missing_exponent(significand, None),
+ rest_to_str!(bytes),
+ )
+ }
+ } else {
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let significand = general_extract.end(bytes);
+ let digits = digits_extract.end(bytes);
+
+ let exp_extract = ExtractSubStr::start(input, bytes);
+
+ if consume!(bytes, b'p' | b'P') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let exponent = exp_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
+
+ (
+ parse_hex_float_missing_period(significand, exponent, kind),
+ rest_to_str!(bytes),
+ )
+ } else {
+ let kind = consume_map!(bytes, [b'i' => IntKind::I32, b'u' => IntKind::U32]);
+
+ (
+ parse_hex_int(is_negative, digits, kind),
+ rest_to_str!(bytes),
+ )
+ }
+ }
+ } else {
+ let is_first_zero = bytes.first() == Some(&b'0');
+
+ let consumed = consume_dec_digits!(bytes);
+
+ if consume!(bytes, b'.') {
+ let consumed_after_period = consume_dec_digits!(bytes);
+
+ if consumed + consumed_after_period == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ if consume!(bytes, b'e' | b'E') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
+
+ (parse_dec_float(number, kind), rest_to_str!(bytes))
+ } else {
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ if consume!(bytes, b'e' | b'E') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
+
+ (parse_dec_float(number, kind), rest_to_str!(bytes))
+ } else {
+ // make sure the multi-digit numbers don't start with zero
+ if consumed > 1 && is_first_zero {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let digits_with_sign = general_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [
+ b'i' => Kind::Int(IntKind::I32),
+ b'u' => Kind::Int(IntKind::U32),
+ b'f' => Kind::Float(FloatKind::F32),
+ b'h' => Kind::Float(FloatKind::F16)
+ ]);
+
+ (
+ parse_dec(is_negative, digits_with_sign, kind),
+ rest_to_str!(bytes),
+ )
+ }
+ }
+ }
+}
+
+fn parse_hex_float_missing_exponent(
+ // format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ )
+ significand: &str,
+ kind: Option<FloatKind>,
+) -> Result<Number, NumberError> {
+ let hexf_input = format!("{}{}", significand, "p0");
+ parse_hex_float(&hexf_input, kind)
+}
+
+fn parse_hex_float_missing_period(
+ // format: -?0[xX] [0-9a-fA-F]+
+ significand: &str,
+ // format: [pP][+-]?[0-9]+
+ exponent: &str,
+ kind: Option<FloatKind>,
+) -> Result<Number, NumberError> {
+ let hexf_input = format!("{significand}.{exponent}");
+ parse_hex_float(&hexf_input, kind)
+}
+
+fn parse_hex_int(
+ is_negative: bool,
+ // format: [0-9a-fA-F]+
+ digits: &str,
+ kind: Option<IntKind>,
+) -> Result<Number, NumberError> {
+ let digits_with_sign = if is_negative {
+ Cow::Owned(format!("-{digits}"))
+ } else {
+ Cow::Borrowed(digits)
+ };
+ parse_int(&digits_with_sign, kind, 16, is_negative)
+}
+
+fn parse_dec(
+ is_negative: bool,
+ // format: -? ( [0-9] | [1-9][0-9]+ )
+ digits_with_sign: &str,
+ kind: Option<Kind>,
+) -> Result<Number, NumberError> {
+ match kind {
+ None => parse_int(digits_with_sign, None, 10, is_negative),
+ Some(Kind::Int(kind)) => parse_int(digits_with_sign, Some(kind), 10, is_negative),
+ Some(Kind::Float(kind)) => parse_dec_float(digits_with_sign, Some(kind)),
+ }
+}
+
+// Float parsing notes
+
+// The following chapters of IEEE 754-2019 are relevant:
+//
+// 7.4 Overflow (largest finite number is exceeded by what would have been
+// the rounded floating-point result were the exponent range unbounded)
+//
+// 7.5 Underflow (tiny non-zero result is detected;
+// for decimal formats tininess is detected before rounding when a non-zero result
+// computed as though both the exponent range and the precision were unbounded
+// would lie strictly between 2^−126)
+//
+// 7.6 Inexact (rounded result differs from what would have been computed
+// were both exponent range and precision unbounded)
+
+// The WGSL spec requires us to error:
+// on overflow for decimal floating point literals
+// on overflow and inexact for hexadecimal floating point literals
+// (underflow is not mentioned)
+
+// hexf_parse errors on overflow, underflow, inexact
+// rust std lib float from str handles overflow, underflow, inexact transparently (rounds and will not error)
+
+// Therefore we only check for overflow manually for decimal floating point literals
+
+// input format: -?0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+
+fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
+ match kind {
+ None => match hexf_parse::parse_hexf64(input, false) {
+ Ok(num) => Ok(Number::AbstractFloat(num)),
+ // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
+ _ => Err(NumberError::NotRepresentable),
+ },
+ Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
+ Ok(num) => Ok(Number::F32(num)),
+ // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
+ _ => Err(NumberError::NotRepresentable),
+ },
+ Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
+ }
+}
+
+// input format: -? ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)?
+// | -? [0-9]+ [eE][+-]?[0-9]+
+fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
+ match kind {
+ None => {
+ let num = input.parse::<f64>().unwrap(); // will never fail
+ num.is_finite()
+ .then_some(Number::AbstractFloat(num))
+ .ok_or(NumberError::NotRepresentable)
+ }
+ Some(FloatKind::F32) => {
+ let num = input.parse::<f32>().unwrap(); // will never fail
+ num.is_finite()
+ .then_some(Number::F32(num))
+ .ok_or(NumberError::NotRepresentable)
+ }
+ Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
+ }
+}
+
+fn parse_int(
+ input: &str,
+ kind: Option<IntKind>,
+ radix: u32,
+ is_negative: bool,
+) -> Result<Number, NumberError> {
+ fn map_err(e: core::num::ParseIntError) -> NumberError {
+ match *e.kind() {
+ core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
+ NumberError::NotRepresentable
+ }
+ _ => unreachable!(),
+ }
+ }
+ match kind {
+ None => match i64::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::AbstractInt(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ Some(IntKind::I32) => match i32::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::I32(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ Some(IntKind::U32) if is_negative => Err(NumberError::NotRepresentable),
+ Some(IntKind::U32) => match u32::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::U32(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/tests.rs b/third_party/rust/naga/src/front/wgsl/tests.rs
new file mode 100644
index 0000000000..31b5890707
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/tests.rs
@@ -0,0 +1,483 @@
+use super::parse_str;
+
+#[test]
+fn parse_comment() {
+ parse_str(
+ "//
+ ////
+ ///////////////////////////////////////////////////////// asda
+ //////////////////// dad ////////// /
+ /////////////////////////////////////////////////////////////////////////////////////////////////////
+ //
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_types() {
+ parse_str("const a : i32 = 2;").unwrap();
+ assert!(parse_str("const a : x32 = 2;").is_err());
+ parse_str("var t: texture_2d<f32>;").unwrap();
+ parse_str("var t: texture_cube_array<i32>;").unwrap();
+ parse_str("var t: texture_multisampled_2d<u32>;").unwrap();
+ parse_str("var t: texture_storage_1d<rgba8uint,write>;").unwrap();
+ parse_str("var t: texture_storage_3d<r32float,read>;").unwrap();
+}
+
+#[test]
+fn parse_type_inference() {
+ parse_str(
+ "
+ fn foo() {
+ let a = 2u;
+ let b: u32 = a;
+ var x = 3.;
+ var y = vec2<f32>(1, 2);
+ }",
+ )
+ .unwrap();
+ assert!(parse_str(
+ "
+ fn foo() { let c : i32 = 2.0; }",
+ )
+ .is_err());
+}
+
+#[test]
+fn parse_type_cast() {
+ parse_str(
+ "
+ const a : i32 = 2;
+ fn main() {
+ var x: f32 = f32(a);
+ x = f32(i32(a + 1) / 2);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(1.0, 2.0);
+ let y: vec2<u32> = vec2<u32>(x);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(0.0);
+ }
+ ",
+ )
+ .unwrap();
+ assert!(parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(0);
+ }
+ ",
+ )
+ .is_err());
+}
+
+#[test]
+fn parse_struct() {
+ parse_str(
+ "
+ struct Foo { x: i32 }
+ struct Bar {
+ @size(16) x: vec2<i32>,
+ @align(16) y: f32,
+ @size(32) @align(128) z: vec3<f32>,
+ };
+ struct Empty {}
+ var<storage,read_write> s: Foo;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_standard_fun() {
+ parse_str(
+ "
+ fn main() {
+ var x: i32 = min(max(1, 2), 3);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_statement() {
+ parse_str(
+ "
+ fn main() {
+ ;
+ {}
+ {;}
+ }
+ ",
+ )
+ .unwrap();
+
+ parse_str(
+ "
+ fn foo() {}
+ fn bar() { foo(); }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_if() {
+ parse_str(
+ "
+ fn main() {
+ if true {
+ discard;
+ } else {}
+ if 0 != 1 {}
+ if false {
+ return;
+ } else if true {
+ return;
+ } else {}
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_parentheses_if() {
+ parse_str(
+ "
+ fn main() {
+ if (true) {
+ discard;
+ } else {}
+ if (0 != 1) {}
+ if (false) {
+ return;
+ } else if (true) {
+ return;
+ } else {}
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_loop() {
+ parse_str(
+ "
+ fn main() {
+ var i: i32 = 0;
+ loop {
+ if i == 1 { break; }
+ continuing { i = 1; }
+ }
+ loop {
+ if i == 0 { continue; }
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ var found: bool = false;
+ var i: i32 = 0;
+ while !found {
+ if i == 10 {
+ found = true;
+ }
+
+ i = i + 1;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ while true {
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ var a: i32 = 0;
+ for(var i: i32 = 0; i < 4; i = i + 1) {
+ a = a + 2;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ for(;;) {
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1: { pos = 0.0; }
+ case 2: { pos = 1.0; }
+ default: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch_optional_colon_in_case() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1 { pos = 0.0; }
+ case 2 { pos = 1.0; }
+ default { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch_default_in_case() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1: { pos = 0.0; }
+ case 2: {}
+ case default, 3: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_parentheses_switch() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch pos > 1.0 {
+ default: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_load() {
+ parse_str(
+ "
+ var t: texture_3d<u32>;
+ fn foo() {
+ let r: vec4<u32> = textureLoad(t, vec3<u32>(0.0, 1.0, 2.0), 1);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ var t: texture_multisampled_2d_array<i32>;
+ fn foo() {
+ let r: vec4<i32> = textureLoad(t, vec2<i32>(10, 20), 2, 3);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ var t: texture_storage_1d_array<r32float,read>;
+ fn foo() {
+ let r: vec4<f32> = textureLoad(t, 10, 2);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_store() {
+ parse_str(
+ "
+ var t: texture_storage_2d<rgba8unorm,write>;
+ fn foo() {
+ textureStore(t, vec2<i32>(10, 20), vec4<f32>(0.0, 1.0, 2.0, 3.0));
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_query() {
+ parse_str(
+ "
+ var t: texture_multisampled_2d_array<f32>;
+ fn foo() {
+ var dim: vec2<u32> = textureDimensions(t);
+ dim = textureDimensions(t, 0);
+ let layers: u32 = textureNumLayers(t);
+ let samples: u32 = textureNumSamples(t);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_postfix() {
+ parse_str(
+ "fn foo() {
+ let x: f32 = vec4<f32>(1.0, 2.0, 3.0, 4.0).xyz.rgbr.aaaa.wz.g;
+ let y: f32 = fract(vec2<f32>(0.5, x)).x;
+ }",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_expressions() {
+ parse_str("fn foo() {
+ let x: f32 = select(0.0, 1.0, true);
+ let y: vec2<f32> = select(vec2<f32>(1.0, 1.0), vec2<f32>(x, x), vec2<bool>(x < 0.5, x > 0.5));
+ let z: bool = !(0.0 == 1.0);
+ }").unwrap();
+}
+
+#[test]
+fn parse_pointers() {
+ parse_str(
+ "fn foo() {
+ var x: f32 = 1.0;
+ let px = &x;
+ let py = frexp(0.5, px);
+ }",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_struct_instantiation() {
+ parse_str(
+ "
+ struct Foo {
+ a: f32,
+ b: vec3<f32>,
+ }
+
+ @fragment
+ fn fs_main() {
+ var foo: Foo = Foo(0.0, vec3<f32>(0.0, 1.0, 42.0));
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_array_length() {
+ parse_str(
+ "
+ struct Foo {
+ data: array<u32>
+ } // this is used as both input and output for convenience
+
+ @group(0) @binding(0)
+ var<storage> foo: Foo;
+
+ @group(0) @binding(1)
+ var<storage> bar: array<u32>;
+
+ fn baz() {
+ var x: u32 = arrayLength(foo.data);
+ var y: u32 = arrayLength(bar);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_storage_buffers() {
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,read> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,write> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,read_write> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_alias() {
+ parse_str(
+ "
+ alias Vec4 = vec4<f32>;
+ ",
+ )
+ .unwrap();
+}
diff --git a/third_party/rust/naga/src/keywords/mod.rs b/third_party/rust/naga/src/keywords/mod.rs
new file mode 100644
index 0000000000..d54a1704f7
--- /dev/null
+++ b/third_party/rust/naga/src/keywords/mod.rs
@@ -0,0 +1,6 @@
+/*!
+Lists of reserved keywords for each shading language with a [frontend][crate::front] or [backend][crate::back].
+*/
+
+#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
+pub mod wgsl;
diff --git a/third_party/rust/naga/src/keywords/wgsl.rs b/third_party/rust/naga/src/keywords/wgsl.rs
new file mode 100644
index 0000000000..7b47a13128
--- /dev/null
+++ b/third_party/rust/naga/src/keywords/wgsl.rs
@@ -0,0 +1,229 @@
+/*!
+Keywords for [WGSL][wgsl] (WebGPU Shading Language).
+
+[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
+*/
+
+// https://gpuweb.github.io/gpuweb/wgsl/#keyword-summary
+// last sync: https://github.com/gpuweb/gpuweb/blob/39f2321f547c8f0b7f473cf1d47fba30b1691303/wgsl/index.bs
+pub const RESERVED: &[&str] = &[
+ // Type-defining Keywords
+ "array",
+ "atomic",
+ "bool",
+ "f32",
+ "f16",
+ "i32",
+ "mat2x2",
+ "mat2x3",
+ "mat2x4",
+ "mat3x2",
+ "mat3x3",
+ "mat3x4",
+ "mat4x2",
+ "mat4x3",
+ "mat4x4",
+ "ptr",
+ "sampler",
+ "sampler_comparison",
+ "texture_1d",
+ "texture_2d",
+ "texture_2d_array",
+ "texture_3d",
+ "texture_cube",
+ "texture_cube_array",
+ "texture_multisampled_2d",
+ "texture_storage_1d",
+ "texture_storage_2d",
+ "texture_storage_2d_array",
+ "texture_storage_3d",
+ "texture_depth_2d",
+ "texture_depth_2d_array",
+ "texture_depth_cube",
+ "texture_depth_cube_array",
+ "texture_depth_multisampled_2d",
+ "u32",
+ "vec2",
+ "vec3",
+ "vec4",
+ // Other Keywords
+ "alias",
+ "bitcast",
+ "break",
+ "case",
+ "const",
+ "continue",
+ "continuing",
+ "default",
+ "discard",
+ "else",
+ "enable",
+ "false",
+ "fn",
+ "for",
+ "if",
+ "let",
+ "loop",
+ "override",
+ "return",
+ "static_assert",
+ "struct",
+ "switch",
+ "true",
+ "type",
+ "var",
+ "while",
+ // Reserved Words
+ "CompileShader",
+ "ComputeShader",
+ "DomainShader",
+ "GeometryShader",
+ "Hullshader",
+ "NULL",
+ "Self",
+ "abstract",
+ "active",
+ "alignas",
+ "alignof",
+ "as",
+ "asm",
+ "asm_fragment",
+ "async",
+ "attribute",
+ "auto",
+ "await",
+ "become",
+ "binding_array",
+ "cast",
+ "catch",
+ "class",
+ "co_await",
+ "co_return",
+ "co_yield",
+ "coherent",
+ "column_major",
+ "common",
+ "compile",
+ "compile_fragment",
+ "concept",
+ "const_cast",
+ "consteval",
+ "constexpr",
+ "constinit",
+ "crate",
+ "debugger",
+ "decltype",
+ "delete",
+ "demote",
+ "demote_to_helper",
+ "do",
+ "dynamic_cast",
+ "enum",
+ "explicit",
+ "export",
+ "extends",
+ "extern",
+ "external",
+ "fallthrough",
+ "filter",
+ "final",
+ "finally",
+ "friend",
+ "from",
+ "fxgroup",
+ "get",
+ "goto",
+ "groupshared",
+ "handle",
+ "highp",
+ "impl",
+ "implements",
+ "import",
+ "inline",
+ "inout",
+ "instanceof",
+ "interface",
+ "layout",
+ "lowp",
+ "macro",
+ "macro_rules",
+ "match",
+ "mediump",
+ "meta",
+ "mod",
+ "module",
+ "move",
+ "mut",
+ "mutable",
+ "namespace",
+ "new",
+ "nil",
+ "noexcept",
+ "noinline",
+ "nointerpolation",
+ "noperspective",
+ "null",
+ "nullptr",
+ "of",
+ "operator",
+ "package",
+ "packoffset",
+ "partition",
+ "pass",
+ "patch",
+ "pixelfragment",
+ "precise",
+ "precision",
+ "premerge",
+ "priv",
+ "protected",
+ "pub",
+ "public",
+ "readonly",
+ "ref",
+ "regardless",
+ "register",
+ "reinterpret_cast",
+ "requires",
+ "resource",
+ "restrict",
+ "self",
+ "set",
+ "shared",
+ "signed",
+ "sizeof",
+ "smooth",
+ "snorm",
+ "static",
+ "static_assert",
+ "static_cast",
+ "std",
+ "subroutine",
+ "super",
+ "target",
+ "template",
+ "this",
+ "thread_local",
+ "throw",
+ "trait",
+ "try",
+ "typedef",
+ "typeid",
+ "typename",
+ "typeof",
+ "union",
+ "unless",
+ "unorm",
+ "unsafe",
+ "unsized",
+ "use",
+ "using",
+ "varying",
+ "virtual",
+ "volatile",
+ "wgsl",
+ "where",
+ "with",
+ "writeonly",
+ "yield",
+];
diff --git a/third_party/rust/naga/src/lib.rs b/third_party/rust/naga/src/lib.rs
new file mode 100644
index 0000000000..a70015d16d
--- /dev/null
+++ b/third_party/rust/naga/src/lib.rs
@@ -0,0 +1,1892 @@
+/*! Universal shader translator.
+
+The central structure of the crate is [`Module`]. A `Module` contains:
+
+- [`Function`]s, which have arguments, a return type, local variables, and a body,
+
+- [`EntryPoint`]s, which are specialized functions that can serve as the entry
+ point for pipeline stages like vertex shading or fragment shading,
+
+- [`Constant`]s and [`GlobalVariable`]s used by `EntryPoint`s and `Function`s, and
+
+- [`Type`]s used by the above.
+
+The body of an `EntryPoint` or `Function` is represented using two types:
+
+- An [`Expression`] produces a value, but has no side effects or control flow.
+ `Expressions` include variable references, unary and binary operators, and so
+ on.
+
+- A [`Statement`] can have side effects and structured control flow.
+ `Statement`s do not produce a value, other than by storing one in some
+ designated place. `Statements` include blocks, conditionals, and loops, but also
+ operations that have side effects, like stores and function calls.
+
+`Statement`s form a tree, with pointers into the DAG of `Expression`s.
+
+Restricting side effects to statements simplifies analysis and code generation.
+A Naga backend can generate code to evaluate an `Expression` however and
+whenever it pleases, as long as it is certain to observe the side effects of all
+previously executed `Statement`s.
+
+Many `Statement` variants use the [`Block`] type, which is `Vec<Statement>`,
+with optional span info, representing a series of statements executed in order. The body of an
+`EntryPoint`s or `Function` is a `Block`, and `Statement` has a
+[`Block`][Statement::Block] variant.
+
+If the `clone` feature is enabled, [`Arena`], [`UniqueArena`], [`Type`], [`TypeInner`],
+[`Constant`], [`Function`], [`EntryPoint`] and [`Module`] can be cloned.
+
+## Arenas
+
+To improve translator performance and reduce memory usage, most structures are
+stored in an [`Arena`]. An `Arena<T>` stores a series of `T` values, indexed by
+[`Handle<T>`](Handle) values, which are just wrappers around integer indexes.
+For example, a `Function`'s expressions are stored in an `Arena<Expression>`,
+and compound expressions refer to their sub-expressions via `Handle<Expression>`
+values. (When examining the serialized form of a `Module`, note that the first
+element of an `Arena` has an index of 1, not 0.)
+
+A [`UniqueArena`] is just like an `Arena`, except that it stores only a single
+instance of each value. The value type must implement `Eq` and `Hash`. Like an
+`Arena`, inserting a value into a `UniqueArena` returns a `Handle` which can be
+used to efficiently access the value, without a hash lookup. Inserting a value
+multiple times returns the same `Handle`.
+
+If the `span` feature is enabled, both `Arena` and `UniqueArena` can associate a
+source code span with each element.
+
+## Function Calls
+
+Naga's representation of function calls is unusual. Most languages treat
+function calls as expressions, but because calls may have side effects, Naga
+represents them as a kind of statement, [`Statement::Call`]. If the function
+returns a value, a call statement designates a particular [`Expression::CallResult`]
+expression to represent its return value, for use by subsequent statements and
+expressions.
+
+## `Expression` evaluation time
+
+It is essential to know when an [`Expression`] should be evaluated, because its
+value may depend on previous [`Statement`]s' effects. But whereas the order of
+execution for a tree of `Statement`s is apparent from its structure, it is not
+so clear for `Expressions`, since an expression may be referred to by any number
+of `Statement`s and other `Expression`s.
+
+Naga's rules for when `Expression`s are evaluated are as follows:
+
+- [`Constant`](Expression::Constant) expressions are considered to be
+ implicitly evaluated before execution begins.
+
+- [`FunctionArgument`] and [`LocalVariable`] expressions are considered
+ implicitly evaluated upon entry to the function to which they belong.
+ Function arguments cannot be assigned to, and `LocalVariable` expressions
+ produce a *pointer to* the variable's value (for use with [`Load`] and
+ [`Store`]). Neither varies while the function executes, so it suffices to
+ consider these expressions evaluated once on entry.
+
+- Similarly, [`GlobalVariable`] expressions are considered implicitly
+ evaluated before execution begins, since their value does not change while
+ code executes, for one of two reasons:
+
+ - Most `GlobalVariable` expressions produce a pointer to the variable's
+ value, for use with [`Load`] and [`Store`], as `LocalVariable`
+ expressions do. Although the variable's value may change, its address
+ does not.
+
+ - A `GlobalVariable` expression referring to a global in the
+ [`AddressSpace::Handle`] address space produces the value directly, not
+ a pointer. Such global variables hold opaque types like shaders or
+ images, and cannot be assigned to.
+
+- A [`CallResult`] expression that is the `result` of a [`Statement::Call`],
+ representing the call's return value, is evaluated when the `Call` statement
+ is executed.
+
+- Similarly, an [`AtomicResult`] expression that is the `result` of an
+ [`Atomic`] statement, representing the result of the atomic operation, is
+ evaluated when the `Atomic` statement is executed.
+
+- A [`RayQueryProceedResult`] expression, which is a boolean
+ indicating if the ray query is finished, is evaluated when the
+ [`RayQuery`] statement whose [`Proceed::result`] points to it is
+ executed.
+
+- All other expressions are evaluated when the (unique) [`Statement::Emit`]
+ statement that covers them is executed.
+
+Now, strictly speaking, not all `Expression` variants actually care when they're
+evaluated. For example, you can evaluate a [`BinaryOperator::Add`] expression
+any time you like, as long as you give it the right operands. It's really only a
+very small set of expressions that are affected by timing:
+
+- [`Load`], [`ImageSample`], and [`ImageLoad`] expressions are influenced by
+ stores to the variables or images they access, and must execute at the
+ proper time relative to them.
+
+- [`Derivative`] expressions are sensitive to control flow uniformity: they
+ must not be moved out of an area of uniform control flow into a non-uniform
+ area.
+
+- More generally, any expression that's used by more than one other expression
+ or statement should probably be evaluated only once, and then stored in a
+ variable to be cited at each point of use.
+
+Naga tries to help back ends handle all these cases correctly in a somewhat
+circuitous way. The [`ModuleInfo`] structure returned by [`Validator::validate`]
+provides a reference count for each expression in each function in the module.
+Naturally, any expression with a reference count of two or more deserves to be
+evaluated and stored in a temporary variable at the point that the `Emit`
+statement covering it is executed. But if we selectively lower the reference
+count threshold to _one_ for the sensitive expression types listed above, so
+that we _always_ generate a temporary variable and save their value, then the
+same code that manages multiply referenced expressions will take care of
+introducing temporaries for time-sensitive expressions as well. The
+`Expression::bake_ref_count` method (private to the back ends) is meant to help
+with this.
+
+## `Expression` scope
+
+Each `Expression` has a *scope*, which is the region of the function within
+which it can be used by `Statement`s and other `Expression`s. It is a validation
+error to use an `Expression` outside its scope.
+
+An expression's scope is defined as follows:
+
+- The scope of a [`Constant`], [`GlobalVariable`], [`FunctionArgument`] or
+ [`LocalVariable`] expression covers the entire `Function` in which it
+ occurs.
+
+- The scope of an expression evaluated by an [`Emit`] statement covers the
+ subsequent expressions in that `Emit`, the subsequent statements in the `Block`
+ to which that `Emit` belongs (if any) and their sub-statements (if any).
+
+- The `result` expression of a [`Call`] or [`Atomic`] statement has a scope
+ covering the subsequent statements in the `Block` in which the statement
+ occurs (if any) and their sub-statements (if any).
+
+For example, this implies that an expression evaluated by some statement in a
+nested `Block` is not available in the `Block`'s parents. Such a value would
+need to be stored in a local variable to be carried upwards in the statement
+tree.
+
+[`AtomicResult`]: Expression::AtomicResult
+[`RayQueryProceedResult`]: Expression::RayQueryProceedResult
+[`CallResult`]: Expression::CallResult
+[`Constant`]: Expression::Constant
+[`Derivative`]: Expression::Derivative
+[`FunctionArgument`]: Expression::FunctionArgument
+[`GlobalVariable`]: Expression::GlobalVariable
+[`ImageLoad`]: Expression::ImageLoad
+[`ImageSample`]: Expression::ImageSample
+[`Load`]: Expression::Load
+[`LocalVariable`]: Expression::LocalVariable
+
+[`Atomic`]: Statement::Atomic
+[`Call`]: Statement::Call
+[`Emit`]: Statement::Emit
+[`Store`]: Statement::Store
+[`RayQuery`]: Statement::RayQuery
+
+[`Proceed::result`]: RayQueryFunction::Proceed::result
+
+[`Validator::validate`]: valid::Validator::validate
+[`ModuleInfo`]: valid::ModuleInfo
+*/
+
+#![allow(
+ clippy::new_without_default,
+ clippy::unneeded_field_pattern,
+ clippy::match_like_matches_macro,
+ clippy::collapsible_if,
+ clippy::derive_partial_eq_without_eq,
+ clippy::needless_borrowed_reference
+)]
+#![warn(
+ trivial_casts,
+ trivial_numeric_casts,
+ unused_extern_crates,
+ unused_qualifications,
+ clippy::pattern_type_mismatch,
+ clippy::missing_const_for_fn,
+ clippy::rest_pat_in_fully_bound_structs,
+ clippy::match_wildcard_for_single_variants
+)]
+#![cfg_attr(not(test), deny(clippy::panic))]
+
+mod arena;
+pub mod back;
+mod block;
+pub mod front;
+pub mod keywords;
+pub mod proc;
+mod span;
+pub mod valid;
+
+pub use crate::arena::{Arena, Handle, Range, UniqueArena};
+
+pub use crate::span::{SourceLocation, Span, SpanContext, WithSpan};
+#[cfg(feature = "arbitrary")]
+use arbitrary::Arbitrary;
+#[cfg(feature = "deserialize")]
+use serde::Deserialize;
+#[cfg(feature = "serialize")]
+use serde::Serialize;
+
+/// Width of a boolean type, in bytes.
+pub const BOOL_WIDTH: Bytes = 1;
+
+/// Hash map that is faster but not resilient to DoS attacks.
+pub type FastHashMap<K, T> = rustc_hash::FxHashMap<K, T>;
+/// Hash set that is faster but not resilient to DoS attacks.
+pub type FastHashSet<K> = rustc_hash::FxHashSet<K>;
+
+/// Map of expressions that have associated variable names
+pub(crate) type NamedExpressions = indexmap::IndexMap<
+ Handle<Expression>,
+ String,
+ std::hash::BuildHasherDefault<rustc_hash::FxHasher>,
+>;
+
+/// Early fragment tests.
+///
+/// In a standard situation, if a driver determines that it is possible to switch on early depth test, it will.
+///
+/// Typical situations when early depth test is switched off:
+/// - Calling `discard` in a shader.
+/// - Writing to the depth buffer, unless ConservativeDepth is enabled.
+///
+/// To use in a shader:
+/// - GLSL: `layout(early_fragment_tests) in;`
+/// - HLSL: `Attribute earlydepthstencil`
+/// - SPIR-V: `ExecutionMode EarlyFragmentTests`
+/// - WGSL: `@early_depth_test`
+///
+/// For more, see:
+/// - <https://www.khronos.org/opengl/wiki/Early_Fragment_Test#Explicit_specification>
+/// - <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-earlydepthstencil>
+/// - <https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#Execution_Mode>
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct EarlyDepthTest {
+ conservative: Option<ConservativeDepth>,
+}
+/// Enables adjusting depth without disabling early Z.
+///
+/// To use in a shader:
+/// - GLSL: `layout (depth_<greater/less/unchanged/any>) out float gl_FragDepth;`
+/// - `depth_any` option behaves as if the layout qualifier was not present.
+/// - HLSL: `SV_DepthGreaterEqual`/`SV_DepthLessEqual`/`SV_Depth`
+/// - SPIR-V: `ExecutionMode Depth<Greater/Less/Unchanged>`
+/// - WGSL: `@early_depth_test(greater_equal/less_equal/unchanged)`
+///
+/// For more, see:
+/// - <https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt>
+/// - <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-semantics#system-value-semantics>
+/// - <https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#Execution_Mode>
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ConservativeDepth {
+ /// Shader may rewrite depth only with a value greater than calculated.
+ GreaterEqual,
+
+ /// Shader may rewrite depth smaller than one that would have been written without the modification.
+ LessEqual,
+
+ /// Shader may not rewrite depth value.
+ Unchanged,
+}
+
+/// Stage of the programmable pipeline.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+#[allow(missing_docs)] // The names are self evident
+pub enum ShaderStage {
+ Vertex,
+ Fragment,
+ Compute,
+}
+
+/// Addressing space of variables.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum AddressSpace {
+ /// Function locals.
+ Function,
+ /// Private data, per invocation, mutable.
+ Private,
+ /// Workgroup shared data, mutable.
+ WorkGroup,
+ /// Uniform buffer data.
+ Uniform,
+ /// Storage buffer data, potentially mutable.
+ Storage { access: StorageAccess },
+ /// Opaque handles, such as samplers and images.
+ Handle,
+ /// Push constants.
+ PushConstant,
+}
+
+/// Built-in inputs and outputs.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum BuiltIn {
+ Position { invariant: bool },
+ ViewIndex,
+ // vertex
+ BaseInstance,
+ BaseVertex,
+ ClipDistance,
+ CullDistance,
+ InstanceIndex,
+ PointSize,
+ VertexIndex,
+ // fragment
+ FragDepth,
+ PointCoord,
+ FrontFacing,
+ PrimitiveIndex,
+ SampleIndex,
+ SampleMask,
+ // compute
+ GlobalInvocationId,
+ LocalInvocationId,
+ LocalInvocationIndex,
+ WorkGroupId,
+ WorkGroupSize,
+ NumWorkGroups,
+}
+
+/// Number of bytes per scalar.
+pub type Bytes = u8;
+
+/// Number of components in a vector.
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum VectorSize {
+ /// 2D vector
+ Bi = 2,
+ /// 3D vector
+ Tri = 3,
+ /// 4D vector
+ Quad = 4,
+}
+
+/// Primitive type for a scalar.
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ScalarKind {
+ /// Signed integer type.
+ Sint,
+ /// Unsigned integer type.
+ Uint,
+ /// Floating point type.
+ Float,
+ /// Boolean type.
+ Bool,
+}
+
+/// Size of an array.
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ArraySize {
+ /// The array size is constant.
+ Constant(Handle<Constant>),
+ /// The array size can change at runtime.
+ Dynamic,
+}
+
+/// The interpolation qualifier of a binding or struct field.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Interpolation {
+ /// The value will be interpolated in a perspective-correct fashion.
+ /// Also known as "smooth" in glsl.
+ Perspective,
+ /// Indicates that linear, non-perspective, correct
+ /// interpolation must be used.
+ /// Also known as "no_perspective" in glsl.
+ Linear,
+ /// Indicates that no interpolation will be performed.
+ Flat,
+}
+
+/// The sampling qualifiers of a binding or struct field.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Sampling {
+ /// Interpolate the value at the center of the pixel.
+ Center,
+
+ /// Interpolate the value at a point that lies within all samples covered by
+ /// the fragment within the current primitive. In multisampling, use a
+ /// single value for all samples in the primitive.
+ Centroid,
+
+ /// Interpolate the value at each sample location. In multisampling, invoke
+ /// the fragment shader once per sample.
+ Sample,
+}
+
+/// Member of a user-defined structure.
+// Clone is used only for error reporting and is not intended for end users
+#[derive(Clone, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct StructMember {
+ pub name: Option<String>,
+ /// Type of the field.
+ pub ty: Handle<Type>,
+ /// For I/O structs, defines the binding.
+ pub binding: Option<Binding>,
+ /// Offset from the beginning from the struct.
+ pub offset: u32,
+}
+
+/// The number of dimensions an image has.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ImageDimension {
+ /// 1D image
+ D1,
+ /// 2D image
+ D2,
+ /// 3D image
+ D3,
+ /// Cube map
+ Cube,
+}
+
+bitflags::bitflags! {
+ /// Flags describing an image.
+ #[cfg_attr(feature = "serialize", derive(Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(Deserialize))]
+ #[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+ #[derive(Default)]
+ pub struct StorageAccess: u32 {
+ /// Storage can be used as a source for load ops.
+ const LOAD = 0x1;
+ /// Storage can be used as a target for store ops.
+ const STORE = 0x2;
+ }
+}
+
+/// Image storage format.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum StorageFormat {
+ // 8-bit formats
+ R8Unorm,
+ R8Snorm,
+ R8Uint,
+ R8Sint,
+
+ // 16-bit formats
+ R16Uint,
+ R16Sint,
+ R16Float,
+ Rg8Unorm,
+ Rg8Snorm,
+ Rg8Uint,
+ Rg8Sint,
+
+ // 32-bit formats
+ R32Uint,
+ R32Sint,
+ R32Float,
+ Rg16Uint,
+ Rg16Sint,
+ Rg16Float,
+ Rgba8Unorm,
+ Rgba8Snorm,
+ Rgba8Uint,
+ Rgba8Sint,
+
+ // Packed 32-bit formats
+ Rgb10a2Unorm,
+ Rg11b10Float,
+
+ // 64-bit formats
+ Rg32Uint,
+ Rg32Sint,
+ Rg32Float,
+ Rgba16Uint,
+ Rgba16Sint,
+ Rgba16Float,
+
+ // 128-bit formats
+ Rgba32Uint,
+ Rgba32Sint,
+ Rgba32Float,
+
+ // Normalized 16-bit per channel formats
+ R16Unorm,
+ R16Snorm,
+ Rg16Unorm,
+ Rg16Snorm,
+ Rgba16Unorm,
+ Rgba16Snorm,
+}
+
+/// Sub-class of the image type.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ImageClass {
+ /// Regular sampled image.
+ Sampled {
+ /// Kind of values to sample.
+ kind: ScalarKind,
+ /// Multi-sampled image.
+ ///
+ /// A multi-sampled image holds several samples per texel. Multi-sampled
+ /// images cannot have mipmaps.
+ multi: bool,
+ },
+ /// Depth comparison image.
+ Depth {
+ /// Multi-sampled depth image.
+ multi: bool,
+ },
+ /// Storage image.
+ Storage {
+ format: StorageFormat,
+ access: StorageAccess,
+ },
+}
+
+/// A data type declared in the module.
+#[derive(Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct Type {
+ /// The name of the type, if any.
+ pub name: Option<String>,
+ /// Inner structure that depends on the kind of the type.
+ pub inner: TypeInner,
+}
+
+/// Enum with additional information, depending on the kind of type.
+#[derive(Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum TypeInner {
+ /// Number of integral or floating-point kind.
+ Scalar { kind: ScalarKind, width: Bytes },
+ /// Vector of numbers.
+ Vector {
+ size: VectorSize,
+ kind: ScalarKind,
+ width: Bytes,
+ },
+ /// Matrix of floats.
+ Matrix {
+ columns: VectorSize,
+ rows: VectorSize,
+ width: Bytes,
+ },
+ /// Atomic scalar.
+ Atomic { kind: ScalarKind, width: Bytes },
+ /// Pointer to another type.
+ ///
+ /// Pointers to scalars and vectors should be treated as equivalent to
+ /// [`ValuePointer`] types. Use the [`TypeInner::equivalent`] method to
+ /// compare types in a way that treats pointers correctly.
+ ///
+ /// ## Pointers to non-`SIZED` types
+ ///
+ /// The `base` type of a pointer may be a non-[`SIZED`] type like a
+ /// dynamically-sized [`Array`], or a [`Struct`] whose last member is a
+ /// dynamically sized array. Such pointers occur as the types of
+ /// [`GlobalVariable`] or [`AccessIndex`] expressions referring to
+ /// dynamically-sized arrays.
+ ///
+ /// However, among pointers to non-`SIZED` types, only pointers to `Struct`s
+ /// are [`DATA`]. Pointers to dynamically sized `Array`s cannot be passed as
+ /// arguments, stored in variables, or held in arrays or structures. Their
+ /// only use is as the types of `AccessIndex` expressions.
+ ///
+ /// [`SIZED`]: valid::TypeFlags::SIZED
+ /// [`DATA`]: valid::TypeFlags::DATA
+ /// [`Array`]: TypeInner::Array
+ /// [`Struct`]: TypeInner::Struct
+ /// [`ValuePointer`]: TypeInner::ValuePointer
+ /// [`GlobalVariable`]: Expression::GlobalVariable
+ /// [`AccessIndex`]: Expression::AccessIndex
+ Pointer {
+ base: Handle<Type>,
+ space: AddressSpace,
+ },
+
+ /// Pointer to a scalar or vector.
+ ///
+ /// A `ValuePointer` type is equivalent to a `Pointer` whose `base` is a
+ /// `Scalar` or `Vector` type. This is for use in [`TypeResolution::Value`]
+ /// variants; see the documentation for [`TypeResolution`] for details.
+ ///
+ /// Use the [`TypeInner::equivalent`] method to compare types that could be
+ /// pointers, to ensure that `Pointer` and `ValuePointer` types are
+ /// recognized as equivalent.
+ ///
+ /// [`TypeResolution`]: proc::TypeResolution
+ /// [`TypeResolution::Value`]: proc::TypeResolution::Value
+ ValuePointer {
+ size: Option<VectorSize>,
+ kind: ScalarKind,
+ width: Bytes,
+ space: AddressSpace,
+ },
+
+ /// Homogenous list of elements.
+ ///
+ /// The `base` type must be a [`SIZED`], [`DATA`] type.
+ ///
+ /// ## Dynamically sized arrays
+ ///
+ /// An `Array` is [`SIZED`] unless its `size` is [`Dynamic`].
+ /// Dynamically-sized arrays may only appear in a few situations:
+ ///
+ /// - They may appear as the type of a [`GlobalVariable`], or as the last
+ /// member of a [`Struct`].
+ ///
+ /// - They may appear as the base type of a [`Pointer`]. An
+ /// [`AccessIndex`] expression referring to a struct's final
+ /// unsized array member would have such a pointer type. However, such
+ /// pointer types may only appear as the types of such intermediate
+ /// expressions. They are not [`DATA`], and cannot be stored in
+ /// variables, held in arrays or structs, or passed as parameters.
+ ///
+ /// [`SIZED`]: crate::valid::TypeFlags::SIZED
+ /// [`DATA`]: crate::valid::TypeFlags::DATA
+ /// [`Dynamic`]: ArraySize::Dynamic
+ /// [`Struct`]: TypeInner::Struct
+ /// [`Pointer`]: TypeInner::Pointer
+ /// [`AccessIndex`]: Expression::AccessIndex
+ Array {
+ base: Handle<Type>,
+ size: ArraySize,
+ stride: u32,
+ },
+
+ /// User-defined structure.
+ ///
+ /// There must always be at least one member.
+ ///
+ /// A `Struct` type is [`DATA`], and the types of its members must be
+ /// `DATA` as well.
+ ///
+ /// Member types must be [`SIZED`], except for the final member of a
+ /// struct, which may be a dynamically sized [`Array`]. The
+ /// `Struct` type itself is `SIZED` when all its members are `SIZED`.
+ ///
+ /// [`DATA`]: crate::valid::TypeFlags::DATA
+ /// [`SIZED`]: crate::valid::TypeFlags::SIZED
+ /// [`Array`]: TypeInner::Array
+ Struct {
+ members: Vec<StructMember>,
+ //TODO: should this be unaligned?
+ span: u32,
+ },
+ /// Possibly multidimensional array of texels.
+ Image {
+ dim: ImageDimension,
+ arrayed: bool,
+ //TODO: consider moving `multisampled: bool` out
+ class: ImageClass,
+ },
+ /// Can be used to sample values from images.
+ Sampler { comparison: bool },
+
+ /// Opaque object representing an acceleration structure of geometry.
+ AccelerationStructure,
+
+ /// Locally used handle for ray queries.
+ RayQuery,
+
+ /// Array of bindings.
+ ///
+ /// A `BindingArray` represents an array where each element draws its value
+ /// from a separate bound resource. The array's element type `base` may be
+ /// [`Image`], [`Sampler`], or any type that would be permitted for a global
+ /// in the [`Uniform`] or [`Storage`] address spaces. Only global variables
+ /// may be binding arrays; on the host side, their values are provided by
+ /// [`TextureViewArray`], [`SamplerArray`], or [`BufferArray`]
+ /// bindings.
+ ///
+ /// Since each element comes from a distinct resource, a binding array of
+ /// images could have images of varying sizes (but not varying dimensions;
+ /// they must all have the same `Image` type). Or, a binding array of
+ /// buffers could have elements that are dynamically sized arrays, each with
+ /// a different length.
+ ///
+ /// Binding arrays are not [`DATA`]. This means that all binding array
+ /// globals must be placed in the [`Handle`] address space. Referring to
+ /// such a global produces a `BindingArray` value directly; there are never
+ /// pointers to binding arrays. The only operation permitted on
+ /// `BindingArray` values is indexing, which yields the element by value,
+ /// not a pointer to the element. (This means that buffer array contents
+ /// cannot be stored to; [naga#1864] covers lifting this restriction.)
+ ///
+ /// Unlike textures and samplers, binding arrays are not [`ARGUMENT`], so
+ /// they cannot be passed as arguments to functions.
+ ///
+ /// Naga's WGSL front end supports binding arrays with the type syntax
+ /// `binding_array<T, N>`.
+ ///
+ /// [`Image`]: TypeInner::Image
+ /// [`Sampler`]: TypeInner::Sampler
+ /// [`Uniform`]: AddressSpace::Uniform
+ /// [`Storage`]: AddressSpace::Storage
+ /// [`TextureViewArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.TextureViewArray
+ /// [`SamplerArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.SamplerArray
+ /// [`BufferArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.BufferArray
+ /// [`DATA`]: crate::valid::TypeFlags::DATA
+ /// [`Handle`]: AddressSpace::Handle
+ /// [`ARGUMENT`]: crate::valid::TypeFlags::ARGUMENT
+ /// [naga#1864]: https://github.com/gfx-rs/naga/issues/1864
+ BindingArray { base: Handle<Type>, size: ArraySize },
+}
+
+/// Constant value.
+#[derive(Debug, PartialEq)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct Constant {
+ pub name: Option<String>,
+ pub specialization: Option<u32>,
+ pub inner: ConstantInner,
+}
+
+/// A literal scalar value, used in constants.
+#[derive(Debug, Clone, Copy, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ScalarValue {
+ Sint(i64),
+ Uint(u64),
+ Float(f64),
+ Bool(bool),
+}
+
+/// Additional information, dependent on the kind of constant.
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ConstantInner {
+ Scalar {
+ width: Bytes,
+ value: ScalarValue,
+ },
+ Composite {
+ ty: Handle<Type>,
+ components: Vec<Handle<Constant>>,
+ },
+}
+
+/// Describes how an input/output variable is to be bound.
+#[derive(Clone, Debug, Eq, PartialEq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Binding {
+ /// Built-in shader variable.
+ BuiltIn(BuiltIn),
+
+ /// Indexed location.
+ ///
+ /// Values passed from the [`Vertex`] stage to the [`Fragment`] stage must
+ /// have their `interpolation` defaulted (i.e. not `None`) by the front end
+ /// as appropriate for that language.
+ ///
+ /// For other stages, we permit interpolations even though they're ignored.
+ /// When a front end is parsing a struct type, it usually doesn't know what
+ /// stages will be using it for IO, so it's easiest if it can apply the
+ /// defaults to anything with a `Location` binding, just in case.
+ ///
+ /// For anything other than floating-point scalars and vectors, the
+ /// interpolation must be `Flat`.
+ ///
+ /// [`Vertex`]: crate::ShaderStage::Vertex
+ /// [`Fragment`]: crate::ShaderStage::Fragment
+ Location {
+ location: u32,
+ interpolation: Option<Interpolation>,
+ sampling: Option<Sampling>,
+ },
+}
+
+/// Pipeline binding information for global resources.
+#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct ResourceBinding {
+ /// The bind group index.
+ pub group: u32,
+ /// Binding number within the group.
+ pub binding: u32,
+}
+
+/// Variable defined at module level.
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct GlobalVariable {
+ /// Name of the variable, if any.
+ pub name: Option<String>,
+ /// How this variable is to be stored.
+ pub space: AddressSpace,
+ /// For resources, defines the binding point.
+ pub binding: Option<ResourceBinding>,
+ /// The type of this variable.
+ pub ty: Handle<Type>,
+ /// Initial value for this variable.
+ pub init: Option<Handle<Constant>>,
+}
+
+/// Variable defined at function level.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct LocalVariable {
+ /// Name of the variable, if any.
+ pub name: Option<String>,
+ /// The type of this variable.
+ pub ty: Handle<Type>,
+ /// Initial value for this variable.
+ pub init: Option<Handle<Constant>>,
+}
+
+/// Operation that can be applied on a single value.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum UnaryOperator {
+ Negate,
+ Not,
+}
+
+/// Operation that can be applied on two values.
+///
+/// ## Arithmetic type rules
+///
+/// The arithmetic operations `Add`, `Subtract`, `Multiply`, `Divide`, and
+/// `Modulo` can all be applied to [`Scalar`] types other than [`Bool`], or
+/// [`Vector`]s thereof. Both operands must have the same type.
+///
+/// `Add` and `Subtract` can also be applied to [`Matrix`] values. Both operands
+/// must have the same type.
+///
+/// `Multiply` supports additional cases:
+///
+/// - A [`Matrix`] or [`Vector`] can be multiplied by a scalar [`Float`],
+/// either on the left or the right.
+///
+/// - A [`Matrix`] on the left can be multiplied by a [`Vector`] on the right
+/// if the matrix has as many columns as the vector has components (`matCxR
+/// * VecC`).
+///
+/// - A [`Vector`] on the left can be multiplied by a [`Matrix`] on the right
+/// if the matrix has as many rows as the vector has components (`VecR *
+/// matCxR`).
+///
+/// - Two matrices can be multiplied if the left operand has as many columns
+/// as the right operand has rows (`matNxR * matCxN`).
+///
+/// In all the above `Multiply` cases, the byte widths of the underlying scalar
+/// types of both operands must be the same.
+///
+/// Note that `Multiply` supports mixed vector and scalar operations directly,
+/// whereas the other arithmetic operations require an explicit [`Splat`] for
+/// mixed-type use.
+///
+/// [`Scalar`]: TypeInner::Scalar
+/// [`Vector`]: TypeInner::Vector
+/// [`Matrix`]: TypeInner::Matrix
+/// [`Float`]: ScalarKind::Float
+/// [`Bool`]: ScalarKind::Bool
+/// [`Splat`]: Expression::Splat
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum BinaryOperator {
+ Add,
+ Subtract,
+ Multiply,
+ Divide,
+ /// Equivalent of the WGSL's `%` operator or SPIR-V's `OpFRem`
+ Modulo,
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+ And,
+ ExclusiveOr,
+ InclusiveOr,
+ LogicalAnd,
+ LogicalOr,
+ ShiftLeft,
+ /// Right shift carries the sign of signed integers only.
+ ShiftRight,
+}
+
+/// Function on an atomic value.
+///
+/// Note: these do not include load/store, which use the existing
+/// [`Expression::Load`] and [`Statement::Store`].
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum AtomicFunction {
+ Add,
+ Subtract,
+ And,
+ ExclusiveOr,
+ InclusiveOr,
+ Min,
+ Max,
+ Exchange { compare: Option<Handle<Expression>> },
+}
+
+/// Hint at which precision to compute a derivative.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum DerivativeControl {
+ Coarse,
+ Fine,
+ None,
+}
+
+/// Axis on which to compute a derivative.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum DerivativeAxis {
+ X,
+ Y,
+ Width,
+}
+
+/// Built-in shader function for testing relation between values.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum RelationalFunction {
+ All,
+ Any,
+ IsNan,
+ IsInf,
+ IsFinite,
+ IsNormal,
+}
+
+/// Built-in shader function for math.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum MathFunction {
+ // comparison
+ Abs,
+ Min,
+ Max,
+ Clamp,
+ Saturate,
+ // trigonometry
+ Cos,
+ Cosh,
+ Sin,
+ Sinh,
+ Tan,
+ Tanh,
+ Acos,
+ Asin,
+ Atan,
+ Atan2,
+ Asinh,
+ Acosh,
+ Atanh,
+ Radians,
+ Degrees,
+ // decomposition
+ Ceil,
+ Floor,
+ Round,
+ Fract,
+ Trunc,
+ Modf,
+ Frexp,
+ Ldexp,
+ // exponent
+ Exp,
+ Exp2,
+ Log,
+ Log2,
+ Pow,
+ // geometry
+ Dot,
+ Outer,
+ Cross,
+ Distance,
+ Length,
+ Normalize,
+ FaceForward,
+ Reflect,
+ Refract,
+ // computational
+ Sign,
+ Fma,
+ Mix,
+ Step,
+ SmoothStep,
+ Sqrt,
+ InverseSqrt,
+ Inverse,
+ Transpose,
+ Determinant,
+ // bits
+ CountTrailingZeros,
+ CountLeadingZeros,
+ CountOneBits,
+ ReverseBits,
+ ExtractBits,
+ InsertBits,
+ FindLsb,
+ FindMsb,
+ // data packing
+ Pack4x8snorm,
+ Pack4x8unorm,
+ Pack2x16snorm,
+ Pack2x16unorm,
+ Pack2x16float,
+ // data unpacking
+ Unpack4x8snorm,
+ Unpack4x8unorm,
+ Unpack2x16snorm,
+ Unpack2x16unorm,
+ Unpack2x16float,
+}
+
+/// Sampling modifier to control the level of detail.
+#[derive(Clone, Copy, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum SampleLevel {
+ Auto,
+ Zero,
+ Exact(Handle<Expression>),
+ Bias(Handle<Expression>),
+ Gradient {
+ x: Handle<Expression>,
+ y: Handle<Expression>,
+ },
+}
+
+/// Type of an image query.
+#[derive(Clone, Copy, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ImageQuery {
+ /// Get the size at the specified level.
+ Size {
+ /// If `None`, the base level is considered.
+ level: Option<Handle<Expression>>,
+ },
+ /// Get the number of mipmap levels.
+ NumLevels,
+ /// Get the number of array layers.
+ NumLayers,
+ /// Get the number of samples.
+ NumSamples,
+}
+
+/// Component selection for a vector swizzle.
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum SwizzleComponent {
+ ///
+ X = 0,
+ ///
+ Y = 1,
+ ///
+ Z = 2,
+ ///
+ W = 3,
+}
+
+bitflags::bitflags! {
+ /// Memory barrier flags.
+ #[cfg_attr(feature = "serialize", derive(Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(Deserialize))]
+ #[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+ #[derive(Default)]
+ pub struct Barrier: u32 {
+ /// Barrier affects all `AddressSpace::Storage` accesses.
+ const STORAGE = 0x1;
+ /// Barrier affects all `AddressSpace::WorkGroup` accesses.
+ const WORK_GROUP = 0x2;
+ }
+}
+
+/// An expression that can be evaluated to obtain a value.
+///
+/// This is a Single Static Assignment (SSA) scheme similar to SPIR-V.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Expression {
+ /// Array access with a computed index.
+ ///
+ /// ## Typing rules
+ ///
+ /// The `base` operand must be some composite type: [`Vector`], [`Matrix`],
+ /// [`Array`], a [`Pointer`] to one of those, or a [`ValuePointer`] with a
+ /// `size`.
+ ///
+ /// The `index` operand must be an integer, signed or unsigned.
+ ///
+ /// Indexing a [`Vector`] or [`Array`] produces a value of its element type.
+ /// Indexing a [`Matrix`] produces a [`Vector`].
+ ///
+ /// Indexing a [`Pointer`] to any of the above produces a pointer to the
+ /// element/component type, in the same [`space`]. In the case of [`Array`],
+ /// the result is an actual [`Pointer`], but for vectors and matrices, there
+ /// may not be any type in the arena representing the component's type, so
+ /// those produce [`ValuePointer`] types equivalent to the appropriate
+ /// [`Pointer`].
+ ///
+ /// ## Dynamic indexing restrictions
+ ///
+ /// To accommodate restrictions in some of the shader languages that Naga
+ /// targets, it is not permitted to subscript a matrix or array with a
+ /// dynamically computed index unless that matrix or array appears behind a
+ /// pointer. In other words, if the inner type of `base` is [`Array`] or
+ /// [`Matrix`], then `index` must be a constant. But if the type of `base`
+ /// is a [`Pointer`] to an array or matrix or a [`ValuePointer`] with a
+ /// `size`, then the index may be any expression of integer type.
+ ///
+ /// You can use the [`Expression::is_dynamic_index`] method to determine
+ /// whether a given index expression requires matrix or array base operands
+ /// to be behind a pointer.
+ ///
+ /// (It would be simpler to always require the use of `AccessIndex` when
+ /// subscripting arrays and matrices that are not behind pointers, but to
+ /// accommodate existing front ends, Naga also permits `Access`, with a
+ /// restricted `index`.)
+ ///
+ /// [`Vector`]: TypeInner::Vector
+ /// [`Matrix`]: TypeInner::Matrix
+ /// [`Array`]: TypeInner::Array
+ /// [`Pointer`]: TypeInner::Pointer
+ /// [`space`]: TypeInner::Pointer::space
+ /// [`ValuePointer`]: TypeInner::ValuePointer
+ /// [`Float`]: ScalarKind::Float
+ Access {
+ base: Handle<Expression>,
+ index: Handle<Expression>,
+ },
+ /// Access the same types as [`Access`], plus [`Struct`] with a known index.
+ ///
+ /// [`Access`]: Expression::Access
+ /// [`Struct`]: TypeInner::Struct
+ AccessIndex {
+ base: Handle<Expression>,
+ index: u32,
+ },
+ /// Constant value.
+ Constant(Handle<Constant>),
+ /// Splat scalar into a vector.
+ Splat {
+ size: VectorSize,
+ value: Handle<Expression>,
+ },
+ /// Vector swizzle.
+ Swizzle {
+ size: VectorSize,
+ vector: Handle<Expression>,
+ pattern: [SwizzleComponent; 4],
+ },
+ /// Composite expression.
+ Compose {
+ ty: Handle<Type>,
+ components: Vec<Handle<Expression>>,
+ },
+
+ /// Reference a function parameter, by its index.
+ ///
+ /// A `FunctionArgument` expression evaluates to a pointer to the argument's
+ /// value. You must use a [`Load`] expression to retrieve its value, or a
+ /// [`Store`] statement to assign it a new value.
+ ///
+ /// [`Load`]: Expression::Load
+ /// [`Store`]: Statement::Store
+ FunctionArgument(u32),
+
+ /// Reference a global variable.
+ ///
+ /// If the given `GlobalVariable`'s [`space`] is [`AddressSpace::Handle`],
+ /// then the variable stores some opaque type like a sampler or an image,
+ /// and a `GlobalVariable` expression referring to it produces the
+ /// variable's value directly.
+ ///
+ /// For any other address space, a `GlobalVariable` expression produces a
+ /// pointer to the variable's value. You must use a [`Load`] expression to
+ /// retrieve its value, or a [`Store`] statement to assign it a new value.
+ ///
+ /// [`space`]: GlobalVariable::space
+ /// [`Load`]: Expression::Load
+ /// [`Store`]: Statement::Store
+ GlobalVariable(Handle<GlobalVariable>),
+
+ /// Reference a local variable.
+ ///
+ /// A `LocalVariable` expression evaluates to a pointer to the variable's value.
+ /// You must use a [`Load`](Expression::Load) expression to retrieve its value,
+ /// or a [`Store`](Statement::Store) statement to assign it a new value.
+ LocalVariable(Handle<LocalVariable>),
+
+ /// Load a value indirectly.
+ ///
+ /// For [`TypeInner::Atomic`] the result is a corresponding scalar.
+ /// For other types behind the `pointer<T>`, the result is `T`.
+ Load { pointer: Handle<Expression> },
+ /// Sample a point from a sampled or a depth image.
+ ImageSample {
+ image: Handle<Expression>,
+ sampler: Handle<Expression>,
+ /// If Some(), this operation is a gather operation
+ /// on the selected component.
+ gather: Option<SwizzleComponent>,
+ coordinate: Handle<Expression>,
+ array_index: Option<Handle<Expression>>,
+ offset: Option<Handle<Constant>>,
+ level: SampleLevel,
+ depth_ref: Option<Handle<Expression>>,
+ },
+
+ /// Load a texel from an image.
+ ///
+ /// For most images, this returns a four-element vector of the same
+ /// [`ScalarKind`] as the image. If the format of the image does not have
+ /// four components, default values are provided: the first three components
+ /// (typically R, G, and B) default to zero, and the final component
+ /// (typically alpha) defaults to one.
+ ///
+ /// However, if the image's [`class`] is [`Depth`], then this returns a
+ /// [`Float`] scalar value.
+ ///
+ /// [`ScalarKind`]: ScalarKind
+ /// [`class`]: TypeInner::Image::class
+ /// [`Depth`]: ImageClass::Depth
+ /// [`Float`]: ScalarKind::Float
+ ImageLoad {
+ /// The image to load a texel from. This must have type [`Image`]. (This
+ /// will necessarily be a [`GlobalVariable`] or [`FunctionArgument`]
+ /// expression, since no other expressions are allowed to have that
+ /// type.)
+ ///
+ /// [`Image`]: TypeInner::Image
+ /// [`GlobalVariable`]: Expression::GlobalVariable
+ /// [`FunctionArgument`]: Expression::FunctionArgument
+ image: Handle<Expression>,
+
+ /// The coordinate of the texel we wish to load. This must be a scalar
+ /// for [`D1`] images, a [`Bi`] vector for [`D2`] images, and a [`Tri`]
+ /// vector for [`D3`] images. (Array indices, sample indices, and
+ /// explicit level-of-detail values are supplied separately.) Its
+ /// component type must be [`Sint`].
+ ///
+ /// [`D1`]: ImageDimension::D1
+ /// [`D2`]: ImageDimension::D2
+ /// [`D3`]: ImageDimension::D3
+ /// [`Bi`]: VectorSize::Bi
+ /// [`Tri`]: VectorSize::Tri
+ /// [`Sint`]: ScalarKind::Sint
+ coordinate: Handle<Expression>,
+
+ /// The index into an arrayed image. If the [`arrayed`] flag in
+ /// `image`'s type is `true`, then this must be `Some(expr)`, where
+ /// `expr` is a [`Sint`] scalar. Otherwise, it must be `None`.
+ ///
+ /// [`arrayed`]: TypeInner::Image::arrayed
+ /// [`Sint`]: ScalarKind::Sint
+ array_index: Option<Handle<Expression>>,
+
+ /// A sample index, for multisampled [`Sampled`] and [`Depth`] images.
+ ///
+ /// [`Sampled`]: ImageClass::Sampled
+ /// [`Depth`]: ImageClass::Depth
+ sample: Option<Handle<Expression>>,
+
+ /// A level of detail, for mipmapped images.
+ ///
+ /// This must be present when accessing non-multisampled
+ /// [`Sampled`] and [`Depth`] images, even if only the
+ /// full-resolution level is present (in which case the only
+ /// valid level is zero).
+ ///
+ /// [`Sampled`]: ImageClass::Sampled
+ /// [`Depth`]: ImageClass::Depth
+ level: Option<Handle<Expression>>,
+ },
+
+ /// Query information from an image.
+ ImageQuery {
+ image: Handle<Expression>,
+ query: ImageQuery,
+ },
+ /// Apply an unary operator.
+ Unary {
+ op: UnaryOperator,
+ expr: Handle<Expression>,
+ },
+ /// Apply a binary operator.
+ Binary {
+ op: BinaryOperator,
+ left: Handle<Expression>,
+ right: Handle<Expression>,
+ },
+ /// Select between two values based on a condition.
+ ///
+ /// Note that, because expressions have no side effects, it is unobservable
+ /// whether the non-selected branch is evaluated.
+ Select {
+ /// Boolean expression
+ condition: Handle<Expression>,
+ accept: Handle<Expression>,
+ reject: Handle<Expression>,
+ },
+ /// Compute the derivative on an axis.
+ Derivative {
+ axis: DerivativeAxis,
+ ctrl: DerivativeControl,
+ //modifier,
+ expr: Handle<Expression>,
+ },
+ /// Call a relational function.
+ Relational {
+ fun: RelationalFunction,
+ argument: Handle<Expression>,
+ },
+ /// Call a math function
+ Math {
+ fun: MathFunction,
+ arg: Handle<Expression>,
+ arg1: Option<Handle<Expression>>,
+ arg2: Option<Handle<Expression>>,
+ arg3: Option<Handle<Expression>>,
+ },
+ /// Cast a simple type to another kind.
+ As {
+ /// Source expression, which can only be a scalar or a vector.
+ expr: Handle<Expression>,
+ /// Target scalar kind.
+ kind: ScalarKind,
+ /// If provided, converts to the specified byte width.
+ /// Otherwise, bitcast.
+ convert: Option<Bytes>,
+ },
+ /// Result of calling another function.
+ CallResult(Handle<Function>),
+ /// Result of an atomic operation.
+ AtomicResult { ty: Handle<Type>, comparison: bool },
+ /// Get the length of an array.
+ /// The expression must resolve to a pointer to an array with a dynamic size.
+ ///
+ /// This doesn't match the semantics of spirv's `OpArrayLength`, which must be passed
+ /// a pointer to a structure containing a runtime array in its' last field.
+ ArrayLength(Handle<Expression>),
+
+ /// Result of a [`Proceed`] [`RayQuery`] statement.
+ ///
+ /// [`Proceed`]: RayQueryFunction::Proceed
+ /// [`RayQuery`]: Statement::RayQuery
+ RayQueryProceedResult,
+
+ /// Return an intersection found by `query`.
+ ///
+ /// If `committed` is true, return the committed result available when
+ RayQueryGetIntersection {
+ query: Handle<Expression>,
+ committed: bool,
+ },
+}
+
+pub use block::Block;
+
+/// The value of the switch case.
+// Clone is used only for error reporting and is not intended for end users
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum SwitchValue {
+ I32(i32),
+ U32(u32),
+ Default,
+}
+
+/// A case for a switch statement.
+// Clone is used only for error reporting and is not intended for end users
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct SwitchCase {
+ /// Value, upon which the case is considered true.
+ pub value: SwitchValue,
+ /// Body of the case.
+ pub body: Block,
+ /// If true, the control flow continues to the next case in the list,
+ /// or default.
+ pub fall_through: bool,
+}
+
+/// An operation that a [`RayQuery` statement] applies to its [`query`] operand.
+///
+/// [`RayQuery` statement]: Statement::RayQuery
+/// [`query`]: Statement::RayQuery::query
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum RayQueryFunction {
+ /// Initialize the `RayQuery` object.
+ Initialize {
+ /// The acceleration structure within which this query should search for hits.
+ ///
+ /// The expression must be an [`AccelerationStructure`].
+ ///
+ /// [`AccelerationStructure`]: TypeInner::AccelerationStructure
+ acceleration_structure: Handle<Expression>,
+
+ #[allow(rustdoc::private_intra_doc_links)]
+ /// A struct of detailed parameters for the ray query.
+ ///
+ /// This expression should have the struct type given in
+ /// [`SpecialTypes::ray_desc`]. This is available in the WGSL
+ /// front end as the `RayDesc` type.
+ descriptor: Handle<Expression>,
+ },
+
+ /// Start or continue the query given by the statement's [`query`] operand.
+ ///
+ /// After executing this statement, the `result` expression is a
+ /// [`Bool`] scalar indicating whether there are more intersection
+ /// candidates to consider.
+ ///
+ /// [`query`]: Statement::RayQuery::query
+ /// [`Bool`]: ScalarKind::Bool
+ Proceed {
+ result: Handle<Expression>,
+ },
+
+ Terminate,
+}
+
+//TODO: consider removing `Clone`. It's not valid to clone `Statement::Emit` anyway.
+/// Instructions which make up an executable block.
+// Clone is used only for error reporting and is not intended for end users
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Statement {
+ /// Emit a range of expressions, visible to all statements that follow in this block.
+ ///
+ /// See the [module-level documentation][emit] for details.
+ ///
+ /// [emit]: index.html#expression-evaluation-time
+ Emit(Range<Expression>),
+ /// A block containing more statements, to be executed sequentially.
+ Block(Block),
+ /// Conditionally executes one of two blocks, based on the value of the condition.
+ If {
+ condition: Handle<Expression>, //bool
+ accept: Block,
+ reject: Block,
+ },
+ /// Conditionally executes one of multiple blocks, based on the value of the selector.
+ ///
+ /// Each case must have a distinct [`value`], exactly one of which must be
+ /// [`Default`]. The `Default` may appear at any position, and covers all
+ /// values not explicitly appearing in other cases. A `Default` appearing in
+ /// the midst of the list of cases does not shadow the cases that follow.
+ ///
+ /// Some backend languages don't support fallthrough (HLSL due to FXC,
+ /// WGSL), and may translate fallthrough cases in the IR by duplicating
+ /// code. However, all backend languages do support cases selected by
+ /// multiple values, like `case 1: case 2: case 3: { ... }`. This is
+ /// represented in the IR as a series of fallthrough cases with empty
+ /// bodies, except for the last.
+ ///
+ /// [`value`]: SwitchCase::value
+ /// [`body`]: SwitchCase::body
+ /// [`Default`]: SwitchValue::Default
+ Switch {
+ selector: Handle<Expression>, //int
+ cases: Vec<SwitchCase>,
+ },
+
+ /// Executes a block repeatedly.
+ ///
+ /// Each iteration of the loop executes the `body` block, followed by the
+ /// `continuing` block.
+ ///
+ /// Executing a [`Break`], [`Return`] or [`Kill`] statement exits the loop.
+ ///
+ /// A [`Continue`] statement in `body` jumps to the `continuing` block. The
+ /// `continuing` block is meant to be used to represent structures like the
+ /// third expression of a C-style `for` loop head, to which `continue`
+ /// statements in the loop's body jump.
+ ///
+ /// The `continuing` block and its substatements must not contain `Return`
+ /// or `Kill` statements, or any `Break` or `Continue` statements targeting
+ /// this loop. (It may have `Break` and `Continue` statements targeting
+ /// loops or switches nested within the `continuing` block.)
+ ///
+ /// If present, `break_if` is an expression which is evaluated after the
+ /// continuing block. If its value is true, control continues after the
+ /// `Loop` statement, rather than branching back to the top of body as
+ /// usual. The `break_if` expression corresponds to a "break if" statement
+ /// in WGSL, or a loop whose back edge is an `OpBranchConditional`
+ /// instruction in SPIR-V.
+ ///
+ /// [`Break`]: Statement::Break
+ /// [`Continue`]: Statement::Continue
+ /// [`Kill`]: Statement::Kill
+ /// [`Return`]: Statement::Return
+ /// [`break if`]: Self::Loop::break_if
+ Loop {
+ body: Block,
+ continuing: Block,
+ break_if: Option<Handle<Expression>>,
+ },
+
+ /// Exits the innermost enclosing [`Loop`] or [`Switch`].
+ ///
+ /// A `Break` statement may only appear within a [`Loop`] or [`Switch`]
+ /// statement. It may not break out of a [`Loop`] from within the loop's
+ /// `continuing` block.
+ ///
+ /// [`Loop`]: Statement::Loop
+ /// [`Switch`]: Statement::Switch
+ Break,
+
+ /// Skips to the `continuing` block of the innermost enclosing [`Loop`].
+ ///
+ /// A `Continue` statement may only appear within the `body` block of the
+ /// innermost enclosing [`Loop`] statement. It must not appear within that
+ /// loop's `continuing` block.
+ ///
+ /// [`Loop`]: Statement::Loop
+ Continue,
+
+ /// Returns from the function (possibly with a value).
+ ///
+ /// `Return` statements are forbidden within the `continuing` block of a
+ /// [`Loop`] statement.
+ ///
+ /// [`Loop`]: Statement::Loop
+ Return { value: Option<Handle<Expression>> },
+
+ /// Aborts the current shader execution.
+ ///
+ /// `Kill` statements are forbidden within the `continuing` block of a
+ /// [`Loop`] statement.
+ ///
+ /// [`Loop`]: Statement::Loop
+ Kill,
+
+ /// Synchronize invocations within the work group.
+ /// The `Barrier` flags control which memory accesses should be synchronized.
+ /// If empty, this becomes purely an execution barrier.
+ Barrier(Barrier),
+ /// Stores a value at an address.
+ ///
+ /// For [`TypeInner::Atomic`] type behind the pointer, the value
+ /// has to be a corresponding scalar.
+ /// For other types behind the `pointer<T>`, the value is `T`.
+ ///
+ /// This statement is a barrier for any operations on the
+ /// `Expression::LocalVariable` or `Expression::GlobalVariable`
+ /// that is the destination of an access chain, started
+ /// from the `pointer`.
+ Store {
+ pointer: Handle<Expression>,
+ value: Handle<Expression>,
+ },
+ /// Stores a texel value to an image.
+ ///
+ /// The `image`, `coordinate`, and `array_index` fields have the same
+ /// meanings as the corresponding operands of an [`ImageLoad`] expression;
+ /// see that documentation for details. Storing into multisampled images or
+ /// images with mipmaps is not supported, so there are no `level` or
+ /// `sample` operands.
+ ///
+ /// This statement is a barrier for any operations on the corresponding
+ /// [`Expression::GlobalVariable`] for this image.
+ ///
+ /// [`ImageLoad`]: Expression::ImageLoad
+ ImageStore {
+ image: Handle<Expression>,
+ coordinate: Handle<Expression>,
+ array_index: Option<Handle<Expression>>,
+ value: Handle<Expression>,
+ },
+ /// Atomic function.
+ Atomic {
+ /// Pointer to an atomic value.
+ pointer: Handle<Expression>,
+ /// Function to run on the atomic.
+ fun: AtomicFunction,
+ /// Value to use in the function.
+ value: Handle<Expression>,
+ /// [`AtomicResult`] expression representing this function's result.
+ ///
+ /// [`AtomicResult`]: crate::Expression::AtomicResult
+ result: Handle<Expression>,
+ },
+ /// Calls a function.
+ ///
+ /// If the `result` is `Some`, the corresponding expression has to be
+ /// `Expression::CallResult`, and this statement serves as a barrier for any
+ /// operations on that expression.
+ Call {
+ function: Handle<Function>,
+ arguments: Vec<Handle<Expression>>,
+ result: Option<Handle<Expression>>,
+ },
+ RayQuery {
+ /// The [`RayQuery`] object this statement operates on.
+ ///
+ /// [`RayQuery`]: TypeInner::RayQuery
+ query: Handle<Expression>,
+
+ /// The specific operation we're performing on `query`.
+ fun: RayQueryFunction,
+ },
+}
+
+/// A function argument.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct FunctionArgument {
+ /// Name of the argument, if any.
+ pub name: Option<String>,
+ /// Type of the argument.
+ pub ty: Handle<Type>,
+ /// For entry points, an argument has to have a binding
+ /// unless it's a structure.
+ pub binding: Option<Binding>,
+}
+
+/// A function result.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct FunctionResult {
+ /// Type of the result.
+ pub ty: Handle<Type>,
+ /// For entry points, the result has to have a binding
+ /// unless it's a structure.
+ pub binding: Option<Binding>,
+}
+
+/// A function defined in the module.
+#[derive(Debug, Default)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct Function {
+ /// Name of the function, if any.
+ pub name: Option<String>,
+ /// Information about function argument.
+ pub arguments: Vec<FunctionArgument>,
+ /// The result of this function, if any.
+ pub result: Option<FunctionResult>,
+ /// Local variables defined and used in the function.
+ pub local_variables: Arena<LocalVariable>,
+ /// Expressions used inside this function.
+ ///
+ /// An `Expression` must occur before all other `Expression`s that use its
+ /// value.
+ pub expressions: Arena<Expression>,
+ /// Map of expressions that have associated variable names
+ pub named_expressions: NamedExpressions,
+ /// Block of instructions comprising the body of the function.
+ pub body: Block,
+}
+
+/// The main function for a pipeline stage.
+///
+/// An [`EntryPoint`] is a [`Function`] that serves as the main function for a
+/// graphics or compute pipeline stage. For example, an `EntryPoint` whose
+/// [`stage`] is [`ShaderStage::Vertex`] can serve as a graphics pipeline's
+/// vertex shader.
+///
+/// Since an entry point is called directly by the graphics or compute pipeline,
+/// not by other WGSL functions, you must specify what the pipeline should pass
+/// as the entry point's arguments, and what values it will return. For example,
+/// a vertex shader needs a vertex's attributes as its arguments, but if it's
+/// used for instanced draw calls, it will also want to know the instance id.
+/// The vertex shader's return value will usually include an output vertex
+/// position, and possibly other attributes to be interpolated and passed along
+/// to a fragment shader.
+///
+/// To specify this, the arguments and result of an `EntryPoint`'s [`function`]
+/// must each have a [`Binding`], or be structs whose members all have
+/// `Binding`s. This associates every value passed to or returned from the entry
+/// point with either a [`BuiltIn`] or a [`Location`]:
+///
+/// - A [`BuiltIn`] has special semantics, usually specific to its pipeline
+/// stage. For example, the result of a vertex shader can include a
+/// [`BuiltIn::Position`] value, which determines the position of a vertex
+/// of a rendered primitive. Or, a compute shader might take an argument
+/// whose binding is [`BuiltIn::WorkGroupSize`], through which the compute
+/// pipeline would pass the number of invocations in your workgroup.
+///
+/// - A [`Location`] indicates user-defined IO to be passed from one pipeline
+/// stage to the next. For example, a vertex shader might also produce a
+/// `uv` texture location as a user-defined IO value.
+///
+/// In other words, the pipeline stage's input and output interface are
+/// determined by the bindings of the arguments and result of the `EntryPoint`'s
+/// [`function`].
+///
+/// [`Function`]: crate::Function
+/// [`Location`]: Binding::Location
+/// [`function`]: EntryPoint::function
+/// [`stage`]: EntryPoint::stage
+#[derive(Debug)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct EntryPoint {
+ /// Name of this entry point, visible externally.
+ ///
+ /// Entry point names for a given `stage` must be distinct within a module.
+ pub name: String,
+ /// Shader stage.
+ pub stage: ShaderStage,
+ /// Early depth test for fragment stages.
+ pub early_depth_test: Option<EarlyDepthTest>,
+ /// Workgroup size for compute stages
+ pub workgroup_size: [u32; 3],
+ /// The entrance function.
+ pub function: Function,
+}
+
+/// Set of special types that can be optionally generated by the frontends.
+#[derive(Debug, Default)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct SpecialTypes {
+ /// Type for `RayDesc`.
+ ///
+ /// Call [`Module::generate_ray_desc_type`] to populate this if
+ /// needed and return the handle.
+ pub ray_desc: Option<Handle<Type>>,
+
+ /// Type for `RayIntersection`.
+ ///
+ /// Call [`Module::generate_ray_intersection_type`] to populate
+ /// this if needed and return the handle.
+ pub ray_intersection: Option<Handle<Type>>,
+}
+
+/// Shader module.
+///
+/// A module is a set of constants, global variables and functions, as well as
+/// the types required to define them.
+///
+/// Some functions are marked as entry points, to be used in a certain shader stage.
+///
+/// To create a new module, use the `Default` implementation.
+/// Alternatively, you can load an existing shader using one of the [available front ends][front].
+///
+/// When finished, you can export modules using one of the [available backends][back].
+#[derive(Debug, Default)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct Module {
+ /// Arena for the types defined in this module.
+ pub types: UniqueArena<Type>,
+ /// Dictionary of special type handles.
+ pub special_types: SpecialTypes,
+ /// Arena for the constants defined in this module.
+ pub constants: Arena<Constant>,
+ /// Arena for the global variables defined in this module.
+ pub global_variables: Arena<GlobalVariable>,
+ /// Arena for the functions defined in this module.
+ ///
+ /// Each function must appear in this arena strictly before all its callers.
+ /// Recursion is not supported.
+ pub functions: Arena<Function>,
+ /// Entry points.
+ pub entry_points: Vec<EntryPoint>,
+}
diff --git a/third_party/rust/naga/src/proc/index.rs b/third_party/rust/naga/src/proc/index.rs
new file mode 100644
index 0000000000..3fea79ec01
--- /dev/null
+++ b/third_party/rust/naga/src/proc/index.rs
@@ -0,0 +1,437 @@
+/*!
+Definitions for index bounds checking.
+*/
+
+use crate::{valid, Handle, UniqueArena};
+use bit_set::BitSet;
+
+/// How should code generated by Naga do bounds checks?
+///
+/// When a vector, matrix, or array index is out of bounds—either negative, or
+/// greater than or equal to the number of elements in the type—WGSL requires
+/// that some other index of the implementation's choice that is in bounds is
+/// used instead. (There are no types with zero elements.)
+///
+/// Similarly, when out-of-bounds coordinates, array indices, or sample indices
+/// are presented to the WGSL `textureLoad` and `textureStore` operations, the
+/// operation is redirected to do something safe.
+///
+/// Different users of Naga will prefer different defaults:
+///
+/// - When used as part of a WebGPU implementation, the WGSL specification
+/// requires the `Restrict` behavior for array, vector, and matrix accesses,
+/// and either the `Restrict` or `ReadZeroSkipWrite` behaviors for texture
+/// accesses.
+///
+/// - When used by the `wgpu` crate for native development, `wgpu` selects
+/// `ReadZeroSkipWrite` as its default.
+///
+/// - Naga's own default is `Unchecked`, so that shader translations
+/// are as faithful to the original as possible.
+///
+/// Sometimes the underlying hardware and drivers can perform bounds checks
+/// themselves, in a way that performs better than the checks Naga would inject.
+/// If you're using native checks like this, then having Naga inject its own
+/// checks as well would be redundant, and the `Unchecked` policy is
+/// appropriate.
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum BoundsCheckPolicy {
+ /// Replace out-of-bounds indexes with some arbitrary in-bounds index.
+ ///
+ /// (This does not necessarily mean clamping. For example, interpreting the
+ /// index as unsigned and taking the minimum with the largest valid index
+ /// would also be a valid implementation. That would map negative indices to
+ /// the last element, not the first.)
+ Restrict,
+
+ /// Out-of-bounds reads return zero, and writes have no effect.
+ ///
+ /// When applied to a chain of accesses, like `a[i][j].b[k]`, all index
+ /// expressions are evaluated, regardless of whether prior or later index
+ /// expressions were in bounds. But all the accesses per se are skipped
+ /// if any index is out of bounds.
+ ReadZeroSkipWrite,
+
+ /// Naga adds no checks to indexing operations. Generate the fastest code
+ /// possible. This is the default for Naga, as a translator, but consumers
+ /// should consider defaulting to a safer behavior.
+ Unchecked,
+}
+
+/// Policies for injecting bounds checks during code generation.
+#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BoundsCheckPolicies {
+ /// How should the generated code handle array, vector, or matrix indices
+ /// that are out of range?
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub index: BoundsCheckPolicy,
+
+ /// How should the generated code handle array, vector, or matrix indices
+ /// that are out of range, when those values live in a [`GlobalVariable`] in
+ /// the [`Storage`] or [`Uniform`] address spaces?
+ ///
+ /// Some graphics hardware provides "robust buffer access", a feature that
+ /// ensures that using a pointer cannot access memory outside the 'buffer'
+ /// that it was derived from. In Naga terms, this means that the hardware
+ /// ensures that pointers computed by applying [`Access`] and
+ /// [`AccessIndex`] expressions to a [`GlobalVariable`] whose [`space`] is
+ /// [`Storage`] or [`Uniform`] will never read or write memory outside that
+ /// global variable.
+ ///
+ /// When hardware offers such a feature, it is probably undesirable to have
+ /// Naga inject bounds checking code for such accesses, since the hardware
+ /// can probably provide the same protection more efficiently. However,
+ /// bounds checks are still needed on accesses to indexable values that do
+ /// not live in buffers, like local variables.
+ ///
+ /// So, this option provides a separate policy that applies only to accesses
+ /// to storage and uniform globals. When depending on hardware bounds
+ /// checking, this policy can be `Unchecked` to avoid unnecessary overhead.
+ ///
+ /// When special hardware support is not available, this should probably be
+ /// the same as `index_bounds_check_policy`.
+ ///
+ /// [`GlobalVariable`]: crate::GlobalVariable
+ /// [`space`]: crate::GlobalVariable::space
+ /// [`Restrict`]: crate::back::BoundsCheckPolicy::Restrict
+ /// [`ReadZeroSkipWrite`]: crate::back::BoundsCheckPolicy::ReadZeroSkipWrite
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`Uniform`]: crate::AddressSpace::Uniform
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub buffer: BoundsCheckPolicy,
+
+ /// How should the generated code handle image texel references that are out
+ /// of range?
+ ///
+ /// This controls the behavior of [`ImageLoad`] expressions and
+ /// [`ImageStore`] statements when a coordinate, texture array index, level
+ /// of detail, or multisampled sample number is out of range.
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`ImageStore`]: crate::Statement::ImageStore
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub image: BoundsCheckPolicy,
+
+ /// How should the generated code handle binding array indexes that are out of bounds.
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub binding_array: BoundsCheckPolicy,
+}
+
+/// The default `BoundsCheckPolicy` is `Unchecked`.
+impl Default for BoundsCheckPolicy {
+ fn default() -> Self {
+ BoundsCheckPolicy::Unchecked
+ }
+}
+
+impl BoundsCheckPolicies {
+ /// Determine which policy applies to `base`.
+ ///
+ /// `base` is the "base" expression (the expression being indexed) of a `Access`
+ /// and `AccessIndex` expression. This is either a pointer, a value, being directly
+ /// indexed, or a binding array.
+ ///
+ /// See the documentation for [`BoundsCheckPolicy`] for details about
+ /// when each policy applies.
+ pub fn choose_policy(
+ &self,
+ base: Handle<crate::Expression>,
+ types: &UniqueArena<crate::Type>,
+ info: &valid::FunctionInfo,
+ ) -> BoundsCheckPolicy {
+ let ty = info[base].ty.inner_with(types);
+
+ if let crate::TypeInner::BindingArray { .. } = *ty {
+ return self.binding_array;
+ }
+
+ match ty.pointer_space() {
+ Some(crate::AddressSpace::Storage { access: _ } | crate::AddressSpace::Uniform) => {
+ self.buffer
+ }
+ // This covers other address spaces, but also accessing vectors and
+ // matrices by value, where no pointer is involved.
+ _ => self.index,
+ }
+ }
+
+ /// Return `true` if any of `self`'s policies are `policy`.
+ pub fn contains(&self, policy: BoundsCheckPolicy) -> bool {
+ self.index == policy || self.buffer == policy || self.image == policy
+ }
+}
+
+/// An index that may be statically known, or may need to be computed at runtime.
+///
+/// This enum lets us handle both [`Access`] and [`AccessIndex`] expressions
+/// with the same code.
+///
+/// [`Access`]: crate::Expression::Access
+/// [`AccessIndex`]: crate::Expression::AccessIndex
+#[derive(Clone, Copy, Debug)]
+pub enum GuardedIndex {
+ Known(u32),
+ Expression(Handle<crate::Expression>),
+}
+
+/// Build a set of expressions used as indices, to cache in temporary variables when
+/// emitted.
+///
+/// Given the bounds-check policies `policies`, construct a `BitSet` containing the handle
+/// indices of all the expressions in `function` that are ever used as guarded indices
+/// under the [`ReadZeroSkipWrite`] policy. The `module` argument must be the module to
+/// which `function` belongs, and `info` should be that function's analysis results.
+///
+/// Such index expressions will be used twice in the generated code: first for the
+/// comparison to see if the index is in bounds, and then for the access itself, should
+/// the comparison succeed. To avoid computing the expressions twice, the generated code
+/// should cache them in temporary variables.
+///
+/// Why do we need to build such a set in advance, instead of just processing access
+/// expressions as we encounter them? Whether an expression needs to be cached depends on
+/// whether it appears as something like the [`index`] operand of an [`Access`] expression
+/// or the [`level`] operand of an [`ImageLoad`] expression, and on the index bounds check
+/// policies that apply to those accesses. But [`Emit`] statements just identify a range
+/// of expressions by index; there's no good way to tell what an expression is used
+/// for. The only way to do it is to just iterate over all the expressions looking for
+/// relevant `Access` expressions --- which is what this function does.
+///
+/// Simple expressions like variable loads and constants don't make sense to cache: it's
+/// no better than just re-evaluating them. But constants are not covered by `Emit`
+/// statements, and `Load`s are always cached to ensure they occur at the right time, so
+/// we don't bother filtering them out from this set.
+///
+/// Fortunately, we don't need to deal with [`ImageStore`] statements here. When we emit
+/// code for a statement, the writer isn't in the middle of an expression, so we can just
+/// emit declarations for temporaries, initialized appropriately.
+///
+/// None of these concerns apply for SPIR-V output, since it's easy to just reuse an
+/// instruction ID in two places; that has the same semantics as a temporary variable, and
+/// it's inherent in the design of SPIR-V. This function is more useful for text-based
+/// back ends.
+///
+/// [`ReadZeroSkipWrite`]: BoundsCheckPolicy::ReadZeroSkipWrite
+/// [`index`]: crate::Expression::Access::index
+/// [`Access`]: crate::Expression::Access
+/// [`level`]: crate::Expression::ImageLoad::level
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+/// [`Emit`]: crate::Statement::Emit
+/// [`ImageStore`]: crate::Statement::ImageStore
+pub fn find_checked_indexes(
+ module: &crate::Module,
+ function: &crate::Function,
+ info: &crate::valid::FunctionInfo,
+ policies: BoundsCheckPolicies,
+) -> BitSet {
+ use crate::Expression as Ex;
+
+ let mut guarded_indices = BitSet::new();
+
+ // Don't bother scanning if we never need `ReadZeroSkipWrite`.
+ if policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) {
+ for (_handle, expr) in function.expressions.iter() {
+ // There's no need to handle `AccessIndex` expressions, as their
+ // indices never need to be cached.
+ match *expr {
+ Ex::Access { base, index } => {
+ if policies.choose_policy(base, &module.types, info)
+ == BoundsCheckPolicy::ReadZeroSkipWrite
+ && access_needs_check(
+ base,
+ GuardedIndex::Expression(index),
+ module,
+ function,
+ info,
+ )
+ .is_some()
+ {
+ guarded_indices.insert(index.index());
+ }
+ }
+ Ex::ImageLoad {
+ coordinate,
+ array_index,
+ sample,
+ level,
+ ..
+ } => {
+ if policies.image == BoundsCheckPolicy::ReadZeroSkipWrite {
+ guarded_indices.insert(coordinate.index());
+ if let Some(array_index) = array_index {
+ guarded_indices.insert(array_index.index());
+ }
+ if let Some(sample) = sample {
+ guarded_indices.insert(sample.index());
+ }
+ if let Some(level) = level {
+ guarded_indices.insert(level.index());
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+
+ guarded_indices
+}
+
+/// Determine whether `index` is statically known to be in bounds for `base`.
+///
+/// If we can't be sure that the index is in bounds, return the limit within
+/// which valid indices must fall.
+///
+/// The return value is one of the following:
+///
+/// - `Some(Known(n))` indicates that `n` is the largest valid index.
+///
+/// - `Some(Computed(global))` indicates that the largest valid index is one
+/// less than the length of the array that is the last member of the
+/// struct held in `global`.
+///
+/// - `None` indicates that the index need not be checked, either because it
+/// is statically known to be in bounds, or because the applicable policy
+/// is `Unchecked`.
+///
+/// This function only handles subscriptable types: arrays, vectors, and
+/// matrices. It does not handle struct member indices; those never require
+/// run-time checks, so it's best to deal with them further up the call
+/// chain.
+pub fn access_needs_check(
+ base: Handle<crate::Expression>,
+ mut index: GuardedIndex,
+ module: &crate::Module,
+ function: &crate::Function,
+ info: &crate::valid::FunctionInfo,
+) -> Option<IndexableLength> {
+ let base_inner = info[base].ty.inner_with(&module.types);
+ // Unwrap safety: `Err` here indicates unindexable base types and invalid
+ // length constants, but `access_needs_check` is only used by back ends, so
+ // validation should have caught those problems.
+ let length = base_inner.indexable_length(module).unwrap();
+ index.try_resolve_to_constant(function, module);
+ if let (&GuardedIndex::Known(index), &IndexableLength::Known(length)) = (&index, &length) {
+ if index < length {
+ // Index is statically known to be in bounds, no check needed.
+ return None;
+ }
+ };
+
+ Some(length)
+}
+
+impl GuardedIndex {
+ /// Make A `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible.
+ ///
+ /// If the expression is a [`Constant`] whose value is a non-specialized, scalar
+ /// integer constant that can be converted to a `u32`, do so and return a
+ /// `GuardedIndex::Known`. Otherwise, return the `GuardedIndex::Expression`
+ /// unchanged.
+ ///
+ /// Return values that are already `Known` unchanged.
+ ///
+ /// [`Constant`]: crate::Expression::Constant
+ fn try_resolve_to_constant(&mut self, function: &crate::Function, module: &crate::Module) {
+ if let GuardedIndex::Expression(expr) = *self {
+ if let crate::Expression::Constant(handle) = function.expressions[expr] {
+ if let Some(value) = module.constants[handle].to_array_length() {
+ *self = GuardedIndex::Known(value);
+ }
+ }
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)]
+pub enum IndexableLengthError {
+ #[error("Type is not indexable, and has no length (validation error)")]
+ TypeNotIndexable,
+ #[error("Array length constant {0:?} is invalid")]
+ InvalidArrayLength(Handle<crate::Constant>),
+}
+
+impl crate::TypeInner {
+ /// Return the length of a subscriptable type.
+ ///
+ /// The `self` parameter should be a handle to a vector, matrix, or array
+ /// type, a pointer to one of those, or a value pointer. Arrays may be
+ /// fixed-size, dynamically sized, or sized by a specializable constant.
+ /// This function does not handle struct member references, as with
+ /// `AccessIndex`.
+ ///
+ /// The value returned is appropriate for bounds checks on subscripting.
+ ///
+ /// Return an error if `self` does not describe a subscriptable type at all.
+ pub fn indexable_length(
+ &self,
+ module: &crate::Module,
+ ) -> Result<IndexableLength, IndexableLengthError> {
+ use crate::TypeInner as Ti;
+ let known_length = match *self {
+ Ti::Vector { size, .. } => size as _,
+ Ti::Matrix { columns, .. } => columns as _,
+ Ti::Array { size, .. } | Ti::BindingArray { size, .. } => {
+ return size.to_indexable_length(module);
+ }
+ Ti::ValuePointer {
+ size: Some(size), ..
+ } => size as _,
+ Ti::Pointer { base, .. } => {
+ // When assigning types to expressions, ResolveContext::Resolve
+ // does a separate sub-match here instead of a full recursion,
+ // so we'll do the same.
+ let base_inner = &module.types[base].inner;
+ match *base_inner {
+ Ti::Vector { size, .. } => size as _,
+ Ti::Matrix { columns, .. } => columns as _,
+ Ti::Array { size, .. } => return size.to_indexable_length(module),
+ _ => return Err(IndexableLengthError::TypeNotIndexable),
+ }
+ }
+ _ => return Err(IndexableLengthError::TypeNotIndexable),
+ };
+ Ok(IndexableLength::Known(known_length))
+ }
+}
+
+/// The number of elements in an indexable type.
+///
+/// This summarizes the length of vectors, matrices, and arrays in a way that is
+/// convenient for indexing and bounds-checking code.
+#[derive(Debug)]
+pub enum IndexableLength {
+ /// Values of this type always have the given number of elements.
+ Known(u32),
+
+ /// The number of elements is determined at runtime.
+ Dynamic,
+}
+
+impl crate::ArraySize {
+ pub fn to_indexable_length(
+ self,
+ module: &crate::Module,
+ ) -> Result<IndexableLength, IndexableLengthError> {
+ Ok(match self {
+ Self::Constant(k) => {
+ let constant = &module.constants[k];
+ if constant.specialization.is_some() {
+ // Specializable constants are not supported as array lengths.
+ // See valid::TypeError::UnsupportedSpecializedArrayLength.
+ return Err(IndexableLengthError::InvalidArrayLength(k));
+ }
+ let length = constant
+ .to_array_length()
+ .ok_or(IndexableLengthError::InvalidArrayLength(k))?;
+ IndexableLength::Known(length)
+ }
+ Self::Dynamic => IndexableLength::Dynamic,
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/proc/layouter.rs b/third_party/rust/naga/src/proc/layouter.rs
new file mode 100644
index 0000000000..65369d1cc8
--- /dev/null
+++ b/third_party/rust/naga/src/proc/layouter.rs
@@ -0,0 +1,256 @@
+use crate::arena::{Arena, Handle, UniqueArena};
+use std::{fmt::Display, num::NonZeroU32, ops};
+
+/// A newtype struct where its only valid values are powers of 2
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Alignment(NonZeroU32);
+
+impl Alignment {
+ pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) });
+ pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) });
+ pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) });
+ pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) });
+ pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) });
+
+ pub const MIN_UNIFORM: Self = Self::SIXTEEN;
+
+ pub const fn new(n: u32) -> Option<Self> {
+ if n.is_power_of_two() {
+ // SAFETY: value can't be 0 since we just checked if it's a power of 2
+ Some(Self(unsafe { NonZeroU32::new_unchecked(n) }))
+ } else {
+ None
+ }
+ }
+
+ /// # Panics
+ /// If `width` is not a power of 2
+ pub fn from_width(width: u8) -> Self {
+ Self::new(width as u32).unwrap()
+ }
+
+ /// Returns whether or not `n` is a multiple of this alignment.
+ pub const fn is_aligned(&self, n: u32) -> bool {
+ // equivalent to: `n % self.0.get() == 0` but much faster
+ n & (self.0.get() - 1) == 0
+ }
+
+ /// Round `n` up to the nearest alignment boundary.
+ pub const fn round_up(&self, n: u32) -> u32 {
+ // equivalent to:
+ // match n % self.0.get() {
+ // 0 => n,
+ // rem => n + (self.0.get() - rem),
+ // }
+ let mask = self.0.get() - 1;
+ (n + mask) & !mask
+ }
+}
+
+impl Display for Alignment {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.get().fmt(f)
+ }
+}
+
+impl ops::Mul<u32> for Alignment {
+ type Output = u32;
+
+ fn mul(self, rhs: u32) -> Self::Output {
+ self.0.get() * rhs
+ }
+}
+
+impl ops::Mul for Alignment {
+ type Output = Alignment;
+
+ fn mul(self, rhs: Alignment) -> Self::Output {
+ // SAFETY: both lhs and rhs are powers of 2, the result will be a power of 2
+ Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) })
+ }
+}
+
+impl From<crate::VectorSize> for Alignment {
+ fn from(size: crate::VectorSize) -> Self {
+ match size {
+ crate::VectorSize::Bi => Alignment::TWO,
+ crate::VectorSize::Tri => Alignment::FOUR,
+ crate::VectorSize::Quad => Alignment::FOUR,
+ }
+ }
+}
+
+/// Size and alignment information for a type.
+#[derive(Clone, Copy, Debug, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct TypeLayout {
+ pub size: u32,
+ pub alignment: Alignment,
+}
+
+impl TypeLayout {
+ /// Produce the stride as if this type is a base of an array.
+ pub const fn to_stride(&self) -> u32 {
+ self.alignment.round_up(self.size)
+ }
+}
+
+/// Helper processor that derives the sizes of all types.
+///
+/// `Layouter` uses the default layout algorithm/table, described in
+/// [WGSL §4.3.7, "Memory Layout"]
+///
+/// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the
+/// layout of the type whose handle is `handle`.
+///
+/// [WGSL §4.3.7, "Memory Layout"](https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts)
+#[derive(Debug, Default)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Layouter {
+ /// Layouts for types in an arena, indexed by `Handle` index.
+ layouts: Vec<TypeLayout>,
+}
+
+impl ops::Index<Handle<crate::Type>> for Layouter {
+ type Output = TypeLayout;
+ fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
+ &self.layouts[handle.index()]
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
+pub enum LayoutErrorInner {
+ #[error("Array element type {0:?} doesn't exist")]
+ InvalidArrayElementType(Handle<crate::Type>),
+ #[error("Struct member[{0}] type {1:?} doesn't exist")]
+ InvalidStructMemberType(u32, Handle<crate::Type>),
+ #[error("Type width must be a power of two")]
+ NonPowerOfTwoWidth,
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
+#[error("Error laying out type {ty:?}: {inner}")]
+pub struct LayoutError {
+ pub ty: Handle<crate::Type>,
+ pub inner: LayoutErrorInner,
+}
+
+impl LayoutErrorInner {
+ const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
+ LayoutError { ty, inner: self }
+ }
+}
+
+impl Layouter {
+ /// Remove all entries from this `Layouter`, retaining storage.
+ pub fn clear(&mut self) {
+ self.layouts.clear();
+ }
+
+ /// Extend this `Layouter` with layouts for any new entries in `types`.
+ ///
+ /// Ensure that every type in `types` has a corresponding [TypeLayout] in
+ /// [`self.layouts`].
+ ///
+ /// Some front ends need to be able to compute layouts for existing types
+ /// while module construction is still in progress and new types are still
+ /// being added. This function assumes that the `TypeLayout` values already
+ /// present in `self.layouts` cover their corresponding entries in `types`,
+ /// and extends `self.layouts` as needed to cover the rest. Thus, a front
+ /// end can call this function at any time, passing its current type and
+ /// constant arenas, and then assume that layouts are available for all
+ /// types.
+ #[allow(clippy::or_fun_call)]
+ pub fn update(
+ &mut self,
+ types: &UniqueArena<crate::Type>,
+ constants: &Arena<crate::Constant>,
+ ) -> Result<(), LayoutError> {
+ use crate::TypeInner as Ti;
+
+ for (ty_handle, ty) in types.iter().skip(self.layouts.len()) {
+ let size = ty.inner.size(constants);
+ let layout = match ty.inner {
+ Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => {
+ let alignment = Alignment::new(width as u32)
+ .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
+ TypeLayout { size, alignment }
+ }
+ Ti::Vector {
+ size: vec_size,
+ width,
+ ..
+ } => {
+ let alignment = Alignment::new(width as u32)
+ .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
+ TypeLayout {
+ size,
+ alignment: Alignment::from(vec_size) * alignment,
+ }
+ }
+ Ti::Matrix {
+ columns: _,
+ rows,
+ width,
+ } => {
+ let alignment = Alignment::new(width as u32)
+ .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
+ TypeLayout {
+ size,
+ alignment: Alignment::from(rows) * alignment,
+ }
+ }
+ Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
+ size,
+ alignment: Alignment::ONE,
+ },
+ Ti::Array {
+ base,
+ stride: _,
+ size: _,
+ } => TypeLayout {
+ size,
+ alignment: if base < ty_handle {
+ self[base].alignment
+ } else {
+ return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
+ },
+ },
+ Ti::Struct { span, ref members } => {
+ let mut alignment = Alignment::ONE;
+ for (index, member) in members.iter().enumerate() {
+ alignment = if member.ty < ty_handle {
+ alignment.max(self[member.ty].alignment)
+ } else {
+ return Err(LayoutErrorInner::InvalidStructMemberType(
+ index as u32,
+ member.ty,
+ )
+ .with(ty_handle));
+ };
+ }
+ TypeLayout {
+ size: span,
+ alignment,
+ }
+ }
+ Ti::Image { .. }
+ | Ti::Sampler { .. }
+ | Ti::AccelerationStructure
+ | Ti::RayQuery
+ | Ti::BindingArray { .. } => TypeLayout {
+ size,
+ alignment: Alignment::ONE,
+ },
+ };
+ debug_assert!(size <= layout.size);
+ self.layouts.push(layout);
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs
new file mode 100644
index 0000000000..a775272a19
--- /dev/null
+++ b/third_party/rust/naga/src/proc/mod.rs
@@ -0,0 +1,493 @@
+/*!
+[`Module`](super::Module) processing functionality.
+*/
+
+pub mod index;
+mod layouter;
+mod namer;
+mod terminator;
+mod typifier;
+
+use std::cmp::PartialEq;
+
+pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
+pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
+pub use namer::{EntryPointIndex, NameKey, Namer};
+pub use terminator::ensure_block_returns;
+pub use typifier::{ResolveContext, ResolveError, TypeResolution};
+
+impl From<super::StorageFormat> for super::ScalarKind {
+ fn from(format: super::StorageFormat) -> Self {
+ use super::{ScalarKind as Sk, StorageFormat as Sf};
+ match format {
+ Sf::R8Unorm => Sk::Float,
+ Sf::R8Snorm => Sk::Float,
+ Sf::R8Uint => Sk::Uint,
+ Sf::R8Sint => Sk::Sint,
+ Sf::R16Uint => Sk::Uint,
+ Sf::R16Sint => Sk::Sint,
+ Sf::R16Float => Sk::Float,
+ Sf::Rg8Unorm => Sk::Float,
+ Sf::Rg8Snorm => Sk::Float,
+ Sf::Rg8Uint => Sk::Uint,
+ Sf::Rg8Sint => Sk::Sint,
+ Sf::R32Uint => Sk::Uint,
+ Sf::R32Sint => Sk::Sint,
+ Sf::R32Float => Sk::Float,
+ Sf::Rg16Uint => Sk::Uint,
+ Sf::Rg16Sint => Sk::Sint,
+ Sf::Rg16Float => Sk::Float,
+ Sf::Rgba8Unorm => Sk::Float,
+ Sf::Rgba8Snorm => Sk::Float,
+ Sf::Rgba8Uint => Sk::Uint,
+ Sf::Rgba8Sint => Sk::Sint,
+ Sf::Rgb10a2Unorm => Sk::Float,
+ Sf::Rg11b10Float => Sk::Float,
+ Sf::Rg32Uint => Sk::Uint,
+ Sf::Rg32Sint => Sk::Sint,
+ Sf::Rg32Float => Sk::Float,
+ Sf::Rgba16Uint => Sk::Uint,
+ Sf::Rgba16Sint => Sk::Sint,
+ Sf::Rgba16Float => Sk::Float,
+ Sf::Rgba32Uint => Sk::Uint,
+ Sf::Rgba32Sint => Sk::Sint,
+ Sf::Rgba32Float => Sk::Float,
+ Sf::R16Unorm => Sk::Float,
+ Sf::R16Snorm => Sk::Float,
+ Sf::Rg16Unorm => Sk::Float,
+ Sf::Rg16Snorm => Sk::Float,
+ Sf::Rgba16Unorm => Sk::Float,
+ Sf::Rgba16Snorm => Sk::Float,
+ }
+ }
+}
+
+impl super::ScalarValue {
+ pub const fn scalar_kind(&self) -> super::ScalarKind {
+ match *self {
+ Self::Uint(_) => super::ScalarKind::Uint,
+ Self::Sint(_) => super::ScalarKind::Sint,
+ Self::Float(_) => super::ScalarKind::Float,
+ Self::Bool(_) => super::ScalarKind::Bool,
+ }
+ }
+}
+
+impl super::ScalarKind {
+ pub const fn is_numeric(self) -> bool {
+ match self {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint | crate::ScalarKind::Float => true,
+ crate::ScalarKind::Bool => false,
+ }
+ }
+}
+
+pub const POINTER_SPAN: u32 = 4;
+
+impl super::TypeInner {
+ pub const fn scalar_kind(&self) -> Option<super::ScalarKind> {
+ match *self {
+ super::TypeInner::Scalar { kind, .. } | super::TypeInner::Vector { kind, .. } => {
+ Some(kind)
+ }
+ super::TypeInner::Matrix { .. } => Some(super::ScalarKind::Float),
+ _ => None,
+ }
+ }
+
+ pub const fn pointer_space(&self) -> Option<crate::AddressSpace> {
+ match *self {
+ Self::Pointer { space, .. } => Some(space),
+ Self::ValuePointer { space, .. } => Some(space),
+ _ => None,
+ }
+ }
+
+ /// Get the size of this type.
+ pub fn size(&self, constants: &super::Arena<super::Constant>) -> u32 {
+ match *self {
+ Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32,
+ Self::Vector {
+ size,
+ kind: _,
+ width,
+ } => size as u32 * width as u32,
+ // matrices are treated as arrays of aligned columns
+ Self::Matrix {
+ columns,
+ rows,
+ width,
+ } => Alignment::from(rows) * width as u32 * columns as u32,
+ Self::Pointer { .. } | Self::ValuePointer { .. } => POINTER_SPAN,
+ Self::Array {
+ base: _,
+ size,
+ stride,
+ } => {
+ let count = match size {
+ super::ArraySize::Constant(handle) => {
+ constants[handle].to_array_length().unwrap_or(1)
+ }
+ // A dynamically-sized array has to have at least one element
+ super::ArraySize::Dynamic => 1,
+ };
+ count * stride
+ }
+ Self::Struct { span, .. } => span,
+ Self::Image { .. }
+ | Self::Sampler { .. }
+ | Self::AccelerationStructure
+ | Self::RayQuery
+ | Self::BindingArray { .. } => 0,
+ }
+ }
+
+ /// Return the canonical form of `self`, or `None` if it's already in
+ /// canonical form.
+ ///
+ /// Certain types have multiple representations in `TypeInner`. This
+ /// function converts all forms of equivalent types to a single
+ /// representative of their class, so that simply applying `Eq` to the
+ /// result indicates whether the types are equivalent, as far as Naga IR is
+ /// concerned.
+ pub fn canonical_form(
+ &self,
+ types: &crate::UniqueArena<crate::Type>,
+ ) -> Option<crate::TypeInner> {
+ use crate::TypeInner as Ti;
+ match *self {
+ Ti::Pointer { base, space } => match types[base].inner {
+ Ti::Scalar { kind, width } => Some(Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space,
+ }),
+ Ti::Vector { size, kind, width } => Some(Ti::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ space,
+ }),
+ _ => None,
+ },
+ _ => None,
+ }
+ }
+
+ /// Compare `self` and `rhs` as types.
+ ///
+ /// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats
+ /// `ValuePointer` and `Pointer` types as equivalent.
+ ///
+ /// When you know that one side of the comparison is never a pointer, it's
+ /// fine to not bother with canonicalization, and just compare `TypeInner`
+ /// values with `==`.
+ pub fn equivalent(
+ &self,
+ rhs: &crate::TypeInner,
+ types: &crate::UniqueArena<crate::Type>,
+ ) -> bool {
+ let left = self.canonical_form(types);
+ let right = rhs.canonical_form(types);
+ left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs)
+ }
+
+ pub fn is_dynamically_sized(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
+ use crate::TypeInner as Ti;
+ match *self {
+ Ti::Array { size, .. } => size == crate::ArraySize::Dynamic,
+ Ti::Struct { ref members, .. } => members
+ .last()
+ .map(|last| types[last.ty].inner.is_dynamically_sized(types))
+ .unwrap_or(false),
+ _ => false,
+ }
+ }
+}
+
+impl super::AddressSpace {
+ pub fn access(self) -> crate::StorageAccess {
+ use crate::StorageAccess as Sa;
+ match self {
+ crate::AddressSpace::Function
+ | crate::AddressSpace::Private
+ | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
+ crate::AddressSpace::Uniform => Sa::LOAD,
+ crate::AddressSpace::Storage { access } => access,
+ crate::AddressSpace::Handle => Sa::LOAD,
+ crate::AddressSpace::PushConstant => Sa::LOAD,
+ }
+ }
+}
+
+impl super::MathFunction {
+ pub const fn argument_count(&self) -> usize {
+ match *self {
+ // comparison
+ Self::Abs => 1,
+ Self::Min => 2,
+ Self::Max => 2,
+ Self::Clamp => 3,
+ Self::Saturate => 1,
+ // trigonometry
+ Self::Cos => 1,
+ Self::Cosh => 1,
+ Self::Sin => 1,
+ Self::Sinh => 1,
+ Self::Tan => 1,
+ Self::Tanh => 1,
+ Self::Acos => 1,
+ Self::Asin => 1,
+ Self::Atan => 1,
+ Self::Atan2 => 2,
+ Self::Asinh => 1,
+ Self::Acosh => 1,
+ Self::Atanh => 1,
+ Self::Radians => 1,
+ Self::Degrees => 1,
+ // decomposition
+ Self::Ceil => 1,
+ Self::Floor => 1,
+ Self::Round => 1,
+ Self::Fract => 1,
+ Self::Trunc => 1,
+ Self::Modf => 2,
+ Self::Frexp => 2,
+ Self::Ldexp => 2,
+ // exponent
+ Self::Exp => 1,
+ Self::Exp2 => 1,
+ Self::Log => 1,
+ Self::Log2 => 1,
+ Self::Pow => 2,
+ // geometry
+ Self::Dot => 2,
+ Self::Outer => 2,
+ Self::Cross => 2,
+ Self::Distance => 2,
+ Self::Length => 1,
+ Self::Normalize => 1,
+ Self::FaceForward => 3,
+ Self::Reflect => 2,
+ Self::Refract => 3,
+ // computational
+ Self::Sign => 1,
+ Self::Fma => 3,
+ Self::Mix => 3,
+ Self::Step => 2,
+ Self::SmoothStep => 3,
+ Self::Sqrt => 1,
+ Self::InverseSqrt => 1,
+ Self::Inverse => 1,
+ Self::Transpose => 1,
+ Self::Determinant => 1,
+ // bits
+ Self::CountTrailingZeros => 1,
+ Self::CountLeadingZeros => 1,
+ Self::CountOneBits => 1,
+ Self::ReverseBits => 1,
+ Self::ExtractBits => 3,
+ Self::InsertBits => 4,
+ Self::FindLsb => 1,
+ Self::FindMsb => 1,
+ // data packing
+ Self::Pack4x8snorm => 1,
+ Self::Pack4x8unorm => 1,
+ Self::Pack2x16snorm => 1,
+ Self::Pack2x16unorm => 1,
+ Self::Pack2x16float => 1,
+ // data unpacking
+ Self::Unpack4x8snorm => 1,
+ Self::Unpack4x8unorm => 1,
+ Self::Unpack2x16snorm => 1,
+ Self::Unpack2x16unorm => 1,
+ Self::Unpack2x16float => 1,
+ }
+ }
+}
+
+impl crate::Expression {
+ /// Returns true if the expression is considered emitted at the start of a function.
+ pub const fn needs_pre_emit(&self) -> bool {
+ match *self {
+ Self::Constant(_)
+ | Self::FunctionArgument(_)
+ | Self::GlobalVariable(_)
+ | Self::LocalVariable(_) => true,
+ _ => false,
+ }
+ }
+
+ /// Return true if this expression is a dynamic array index, for [`Access`].
+ ///
+ /// This method returns true if this expression is a dynamically computed
+ /// index, and as such can only be used to index matrices and arrays when
+ /// they appear behind a pointer. See the documentation for [`Access`] for
+ /// details.
+ ///
+ /// Note, this does not check the _type_ of the given expression. It's up to
+ /// the caller to establish that the `Access` expression is well-typed
+ /// through other means, like [`ResolveContext`].
+ ///
+ /// [`Access`]: crate::Expression::Access
+ /// [`ResolveContext`]: crate::proc::ResolveContext
+ pub fn is_dynamic_index(&self, module: &crate::Module) -> bool {
+ if let Self::Constant(handle) = *self {
+ let constant = &module.constants[handle];
+ constant.specialization.is_some()
+ } else {
+ true
+ }
+ }
+}
+
+impl crate::Function {
+ /// Return the global variable being accessed by the expression `pointer`.
+ ///
+ /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
+ /// expressions that ultimately access some part of a `GlobalVariable`,
+ /// return a handle for that global.
+ ///
+ /// If the expression does not ultimately access a global variable, return
+ /// `None`.
+ pub fn originating_global(
+ &self,
+ mut pointer: crate::Handle<crate::Expression>,
+ ) -> Option<crate::Handle<crate::GlobalVariable>> {
+ loop {
+ pointer = match self.expressions[pointer] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ crate::Expression::GlobalVariable(handle) => return Some(handle),
+ crate::Expression::LocalVariable(_) => return None,
+ crate::Expression::FunctionArgument(_) => return None,
+ // There are no other expressions that produce pointer values.
+ _ => unreachable!(),
+ }
+ }
+ }
+}
+
+impl crate::SampleLevel {
+ pub const fn implicit_derivatives(&self) -> bool {
+ match *self {
+ Self::Auto | Self::Bias(_) => true,
+ Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
+ }
+ }
+}
+
+impl crate::Constant {
+ /// Interpret this constant as an array length, and return it as a `u32`.
+ ///
+ /// Ignore any specialization available for this constant; return its
+ /// unspecialized value.
+ ///
+ /// If the constant has an inappropriate kind (non-scalar or non-integer) or
+ /// value (negative, out of range for u32), return `None`. This usually
+ /// indicates an error, but only the caller has enough information to report
+ /// the error helpfully: in back ends, it's a validation error, but in front
+ /// ends, it may indicate ill-formed input (for example, a SPIR-V
+ /// `OpArrayType` referring to an inappropriate `OpConstant`). So we return
+ /// `Option` and let the caller sort things out.
+ pub(crate) fn to_array_length(&self) -> Option<u32> {
+ match self.inner {
+ crate::ConstantInner::Scalar { value, width: _ } => match value {
+ crate::ScalarValue::Uint(value) => value.try_into().ok(),
+ // Accept a signed integer size to avoid
+ // requiring an explicit uint
+ // literal. Type inference should make
+ // this unnecessary.
+ crate::ScalarValue::Sint(value) => value.try_into().ok(),
+ _ => None,
+ },
+ // caught by type validation
+ crate::ConstantInner::Composite { .. } => None,
+ }
+ }
+}
+
+impl crate::Binding {
+ pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
+ match *self {
+ crate::Binding::BuiltIn(built_in) => Some(built_in),
+ Self::Location { .. } => None,
+ }
+ }
+}
+
+//TODO: should we use an existing crate for hashable floats?
+impl PartialEq for crate::ScalarValue {
+ fn eq(&self, other: &Self) -> bool {
+ match (*self, *other) {
+ (Self::Uint(a), Self::Uint(b)) => a == b,
+ (Self::Sint(a), Self::Sint(b)) => a == b,
+ (Self::Float(a), Self::Float(b)) => a.to_bits() == b.to_bits(),
+ (Self::Bool(a), Self::Bool(b)) => a == b,
+ _ => false,
+ }
+ }
+}
+impl Eq for crate::ScalarValue {}
+impl std::hash::Hash for crate::ScalarValue {
+ fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
+ match *self {
+ Self::Sint(v) => v.hash(hasher),
+ Self::Uint(v) => v.hash(hasher),
+ Self::Float(v) => v.to_bits().hash(hasher),
+ Self::Bool(v) => v.hash(hasher),
+ }
+ }
+}
+
+impl super::SwizzleComponent {
+ pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
+
+ pub const fn index(&self) -> u32 {
+ match *self {
+ Self::X => 0,
+ Self::Y => 1,
+ Self::Z => 2,
+ Self::W => 3,
+ }
+ }
+ pub const fn from_index(idx: u32) -> Self {
+ match idx {
+ 0 => Self::X,
+ 1 => Self::Y,
+ 2 => Self::Z,
+ _ => Self::W,
+ }
+ }
+}
+
+impl super::ImageClass {
+ pub const fn is_multisampled(self) -> bool {
+ match self {
+ crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
+ crate::ImageClass::Storage { .. } => false,
+ }
+ }
+
+ pub const fn is_mipmapped(self) -> bool {
+ match self {
+ crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
+ crate::ImageClass::Storage { .. } => false,
+ }
+ }
+}
+
+#[test]
+fn test_matrix_size() {
+ let constants = crate::Arena::new();
+ assert_eq!(
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ width: 4
+ }
+ .size(&constants),
+ 48,
+ );
+}
diff --git a/third_party/rust/naga/src/proc/namer.rs b/third_party/rust/naga/src/proc/namer.rs
new file mode 100644
index 0000000000..053126b8ac
--- /dev/null
+++ b/third_party/rust/naga/src/proc/namer.rs
@@ -0,0 +1,261 @@
+use crate::{arena::Handle, FastHashMap, FastHashSet};
+use std::borrow::Cow;
+
+pub type EntryPointIndex = u16;
+const SEPARATOR: char = '_';
+
+#[derive(Debug, Eq, Hash, PartialEq)]
+pub enum NameKey {
+ Constant(Handle<crate::Constant>),
+ GlobalVariable(Handle<crate::GlobalVariable>),
+ Type(Handle<crate::Type>),
+ StructMember(Handle<crate::Type>, u32),
+ Function(Handle<crate::Function>),
+ FunctionArgument(Handle<crate::Function>, u32),
+ FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>),
+ EntryPoint(EntryPointIndex),
+ EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>),
+ EntryPointArgument(EntryPointIndex, u32),
+}
+
+/// This processor assigns names to all the things in a module
+/// that may need identifiers in a textual backend.
+#[derive(Default)]
+pub struct Namer {
+ /// The last numeric suffix used for each base name. Zero means "no suffix".
+ unique: FastHashMap<String, u32>,
+ keywords: FastHashSet<String>,
+ reserved_prefixes: Vec<String>,
+}
+
+impl Namer {
+ /// Return a form of `string` suitable for use as the base of an identifier.
+ ///
+ /// - Drop leading digits.
+ /// - Retain only alphanumeric and `_` characters.
+ /// - Avoid prefixes in [`Namer::reserved_prefixes`].
+ ///
+ /// The return value is a valid identifier prefix in all of Naga's output languages,
+ /// and it never ends with a `SEPARATOR` character.
+ /// It is used as a key into the unique table.
+ fn sanitize<'s>(&self, string: &'s str) -> Cow<'s, str> {
+ let string = string
+ .trim_start_matches(|c: char| c.is_numeric())
+ .trim_end_matches(SEPARATOR);
+
+ let base = if !string.is_empty()
+ && string
+ .chars()
+ .all(|c: char| c.is_ascii_alphanumeric() || c == '_')
+ {
+ Cow::Borrowed(string)
+ } else {
+ let mut filtered = string
+ .chars()
+ .filter(|&c| c.is_ascii_alphanumeric() || c == '_')
+ .collect::<String>();
+ let stripped_len = filtered.trim_end_matches(SEPARATOR).len();
+ filtered.truncate(stripped_len);
+ if filtered.is_empty() {
+ filtered.push_str("unnamed");
+ }
+ Cow::Owned(filtered)
+ };
+
+ for prefix in &self.reserved_prefixes {
+ if base.starts_with(prefix) {
+ return format!("gen_{base}").into();
+ }
+ }
+
+ base
+ }
+
+ /// Return a new identifier based on `label_raw`.
+ ///
+ /// The result:
+ /// - is a valid identifier even if `label_raw` is not
+ /// - conflicts with no keywords listed in `Namer::keywords`, and
+ /// - is different from any identifier previously constructed by this
+ /// `Namer`.
+ ///
+ /// Guarantee uniqueness by applying a numeric suffix when necessary. If `label_raw`
+ /// itself ends with digits, separate them from the suffix with an underscore.
+ pub fn call(&mut self, label_raw: &str) -> String {
+ use std::fmt::Write as _; // for write!-ing to Strings
+
+ let base = self.sanitize(label_raw);
+ debug_assert!(!base.is_empty() && !base.ends_with(SEPARATOR));
+
+ // This would seem to be a natural place to use `HashMap::entry`. However, `entry`
+ // requires an owned key, and we'd like to avoid heap-allocating strings we're
+ // just going to throw away. The approach below double-hashes only when we create
+ // a new entry, in which case the heap allocation of the owned key was more
+ // expensive anyway.
+ match self.unique.get_mut(base.as_ref()) {
+ Some(count) => {
+ *count += 1;
+ // Add the suffix. This may fit in base's existing allocation.
+ let mut suffixed = base.into_owned();
+ write!(suffixed, "{}{}", SEPARATOR, *count).unwrap();
+ suffixed
+ }
+ None => {
+ let mut suffixed = base.to_string();
+ if base.ends_with(char::is_numeric) || self.keywords.contains(base.as_ref()) {
+ suffixed.push(SEPARATOR);
+ }
+ debug_assert!(!self.keywords.contains(&suffixed));
+ // `self.unique` wants to own its keys. This allocates only if we haven't
+ // already done so earlier.
+ self.unique.insert(base.into_owned(), 0);
+ suffixed
+ }
+ }
+ }
+
+ pub fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String {
+ self.call(match *label {
+ Some(ref name) => name,
+ None => fallback,
+ })
+ }
+
+ /// Enter a local namespace for things like structs.
+ ///
+ /// Struct member names only need to be unique amongst themselves, not
+ /// globally. This function temporarily establishes a fresh, empty naming
+ /// context for the duration of the call to `body`.
+ fn namespace(&mut self, capacity: usize, body: impl FnOnce(&mut Self)) {
+ let fresh = FastHashMap::with_capacity_and_hasher(capacity, Default::default());
+ let outer = std::mem::replace(&mut self.unique, fresh);
+ body(self);
+ self.unique = outer;
+ }
+
+ pub fn reset(
+ &mut self,
+ module: &crate::Module,
+ reserved_keywords: &[&str],
+ reserved_prefixes: &[&str],
+ output: &mut FastHashMap<NameKey, String>,
+ ) {
+ self.reserved_prefixes.clear();
+ self.reserved_prefixes
+ .extend(reserved_prefixes.iter().map(|string| string.to_string()));
+
+ self.unique.clear();
+ self.keywords.clear();
+ self.keywords
+ .extend(reserved_keywords.iter().map(|string| (string.to_string())));
+ let mut temp = String::new();
+
+ for (ty_handle, ty) in module.types.iter() {
+ let ty_name = self.call_or(&ty.name, "type");
+ output.insert(NameKey::Type(ty_handle), ty_name);
+
+ if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
+ // struct members have their own namespace, because access is always prefixed
+ self.namespace(members.len(), |namer| {
+ for (index, member) in members.iter().enumerate() {
+ let name = namer.call_or(&member.name, "member");
+ output.insert(NameKey::StructMember(ty_handle, index as u32), name);
+ }
+ })
+ }
+ }
+
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let ep_name = self.call(&ep.name);
+ output.insert(NameKey::EntryPoint(ep_index as _), ep_name);
+ for (index, arg) in ep.function.arguments.iter().enumerate() {
+ let name = self.call_or(&arg.name, "param");
+ output.insert(
+ NameKey::EntryPointArgument(ep_index as _, index as u32),
+ name,
+ );
+ }
+ for (handle, var) in ep.function.local_variables.iter() {
+ let name = self.call_or(&var.name, "local");
+ output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name);
+ }
+ }
+
+ for (fun_handle, fun) in module.functions.iter() {
+ let fun_name = self.call_or(&fun.name, "function");
+ output.insert(NameKey::Function(fun_handle), fun_name);
+ for (index, arg) in fun.arguments.iter().enumerate() {
+ let name = self.call_or(&arg.name, "param");
+ output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name);
+ }
+ for (handle, var) in fun.local_variables.iter() {
+ let name = self.call_or(&var.name, "local");
+ output.insert(NameKey::FunctionLocal(fun_handle, handle), name);
+ }
+ }
+
+ for (handle, var) in module.global_variables.iter() {
+ let name = self.call_or(&var.name, "global");
+ output.insert(NameKey::GlobalVariable(handle), name);
+ }
+
+ for (handle, constant) in module.constants.iter() {
+ let label = match constant.name {
+ Some(ref name) => name,
+ None => {
+ use std::fmt::Write;
+ // Try to be more descriptive about the constant values
+ temp.clear();
+ match constant.inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Sint(v),
+ } => write!(temp, "const_{v}i"),
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Uint(v),
+ } => write!(temp, "const_{v}u"),
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Float(v),
+ } => {
+ let abs = v.abs();
+ write!(
+ temp,
+ "const_{}{}",
+ if v < 0.0 { "n" } else { "" },
+ abs.trunc(),
+ )
+ .unwrap();
+ let fract = abs.fract();
+ if fract == 0.0 {
+ write!(temp, "f")
+ } else {
+ write!(temp, "_{:02}f", (fract * 100.0) as i8)
+ }
+ }
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Bool(v),
+ } => write!(temp, "const_{v}"),
+ crate::ConstantInner::Composite { ty, components: _ } => {
+ write!(temp, "const_{}", output[&NameKey::Type(ty)])
+ }
+ }
+ .unwrap();
+ &temp
+ }
+ };
+ let name = self.call(label);
+ output.insert(NameKey::Constant(handle), name);
+ }
+ }
+}
+
+#[test]
+fn test() {
+ let mut namer = Namer::default();
+ assert_eq!(namer.call("x"), "x");
+ assert_eq!(namer.call("x"), "x_1");
+ assert_eq!(namer.call("x1"), "x1_");
+}
diff --git a/third_party/rust/naga/src/proc/terminator.rs b/third_party/rust/naga/src/proc/terminator.rs
new file mode 100644
index 0000000000..ca0c3f10bc
--- /dev/null
+++ b/third_party/rust/naga/src/proc/terminator.rs
@@ -0,0 +1,43 @@
+/// Ensure that the given block has return statements
+/// at the end of its control flow.
+///
+/// Note: we don't want to blindly append a return statement
+/// to the end, because it may be either redundant or invalid,
+/// e.g. when the user already has returns in if/else branches.
+pub fn ensure_block_returns(block: &mut crate::Block) {
+ use crate::Statement as S;
+ match block.last_mut() {
+ Some(&mut S::Block(ref mut b)) => {
+ ensure_block_returns(b);
+ }
+ Some(&mut S::If {
+ condition: _,
+ ref mut accept,
+ ref mut reject,
+ }) => {
+ ensure_block_returns(accept);
+ ensure_block_returns(reject);
+ }
+ Some(&mut S::Switch {
+ selector: _,
+ ref mut cases,
+ }) => {
+ for case in cases.iter_mut() {
+ if !case.fall_through {
+ ensure_block_returns(&mut case.body);
+ }
+ }
+ }
+ Some(&mut (S::Emit(_) | S::Break | S::Continue | S::Return { .. } | S::Kill)) => (),
+ Some(
+ &mut (S::Loop { .. }
+ | S::Store { .. }
+ | S::ImageStore { .. }
+ | S::Call { .. }
+ | S::RayQuery { .. }
+ | S::Atomic { .. }
+ | S::Barrier(_)),
+ )
+ | None => block.push(S::Return { value: None }, Default::default()),
+ }
+}
diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs
new file mode 100644
index 0000000000..0bb9019a29
--- /dev/null
+++ b/third_party/rust/naga/src/proc/typifier.rs
@@ -0,0 +1,909 @@
+use crate::arena::{Arena, Handle, UniqueArena};
+
+use thiserror::Error;
+
+/// The result of computing an expression's type.
+///
+/// This is the (Rust) type returned by [`ResolveContext::resolve`] to represent
+/// the (Naga) type it ascribes to some expression.
+///
+/// You might expect such a function to simply return a `Handle<Type>`. However,
+/// we want type resolution to be a read-only process, and that would limit the
+/// possible results to types already present in the expression's associated
+/// `UniqueArena<Type>`. Naga IR does have certain expressions whose types are
+/// not certain to be present.
+///
+/// So instead, type resolution returns a `TypeResolution` enum: either a
+/// [`Handle`], referencing some type in the arena, or a [`Value`], holding a
+/// free-floating [`TypeInner`]. This extends the range to cover anything that
+/// can be represented with a `TypeInner` referring to the existing arena.
+///
+/// What sorts of expressions can have types not available in the arena?
+///
+/// - An [`Access`] or [`AccessIndex`] expression applied to a [`Vector`] or
+/// [`Matrix`] must have a [`Scalar`] or [`Vector`] type. But since `Vector`
+/// and `Matrix` represent their element and column types implicitly, not
+/// via a handle, there may not be a suitable type in the expression's
+/// associated arena. Instead, resolving such an expression returns a
+/// `TypeResolution::Value(TypeInner::X { ... })`, where `X` is `Scalar` or
+/// `Vector`.
+///
+/// - Similarly, the type of an [`Access`] or [`AccessIndex`] expression
+/// applied to a *pointer to* a vector or matrix must produce a *pointer to*
+/// a scalar or vector type. These cannot be represented with a
+/// [`TypeInner::Pointer`], since the `Pointer`'s `base` must point into the
+/// arena, and as before, we cannot assume that a suitable scalar or vector
+/// type is there. So we take things one step further and provide
+/// [`TypeInner::ValuePointer`], specifically for the case of pointers to
+/// scalars or vectors. This type fits in a `TypeInner` and is exactly
+/// equivalent to a `Pointer` to a `Vector` or `Scalar`.
+///
+/// So, for example, the type of an `Access` expression applied to a value of type:
+///
+/// ```ignore
+/// TypeInner::Matrix { columns, rows, width }
+/// ```
+///
+/// might be:
+///
+/// ```ignore
+/// TypeResolution::Value(TypeInner::Vector {
+/// size: rows,
+/// kind: ScalarKind::Float,
+/// width,
+/// })
+/// ```
+///
+/// and the type of an access to a pointer of address space `space` to such a
+/// matrix might be:
+///
+/// ```ignore
+/// TypeResolution::Value(TypeInner::ValuePointer {
+/// size: Some(rows),
+/// kind: ScalarKind::Float,
+/// width,
+/// space,
+/// })
+/// ```
+///
+/// [`Handle`]: TypeResolution::Handle
+/// [`Value`]: TypeResolution::Value
+///
+/// [`Access`]: crate::Expression::Access
+/// [`AccessIndex`]: crate::Expression::AccessIndex
+///
+/// [`TypeInner`]: crate::TypeInner
+/// [`Matrix`]: crate::TypeInner::Matrix
+/// [`Pointer`]: crate::TypeInner::Pointer
+/// [`Scalar`]: crate::TypeInner::Scalar
+/// [`ValuePointer`]: crate::TypeInner::ValuePointer
+/// [`Vector`]: crate::TypeInner::Vector
+///
+/// [`TypeInner::Pointer`]: crate::TypeInner::Pointer
+/// [`TypeInner::ValuePointer`]: crate::TypeInner::ValuePointer
+#[derive(Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum TypeResolution {
+ /// A type stored in the associated arena.
+ Handle(Handle<crate::Type>),
+
+ /// A free-floating [`TypeInner`], representing a type that may not be
+ /// available in the associated arena. However, the `TypeInner` itself may
+ /// contain `Handle<Type>` values referring to types from the arena.
+ ///
+ /// [`TypeInner`]: crate::TypeInner
+ Value(crate::TypeInner),
+}
+
+impl TypeResolution {
+ pub const fn handle(&self) -> Option<Handle<crate::Type>> {
+ match *self {
+ Self::Handle(handle) => Some(handle),
+ Self::Value(_) => None,
+ }
+ }
+
+ pub fn inner_with<'a>(&'a self, arena: &'a UniqueArena<crate::Type>) -> &'a crate::TypeInner {
+ match *self {
+ Self::Handle(handle) => &arena[handle].inner,
+ Self::Value(ref inner) => inner,
+ }
+ }
+}
+
+// Clone is only implemented for numeric variants of `TypeInner`.
+impl Clone for TypeResolution {
+ fn clone(&self) -> Self {
+ use crate::TypeInner as Ti;
+ match *self {
+ Self::Handle(handle) => Self::Handle(handle),
+ Self::Value(ref v) => Self::Value(match *v {
+ Ti::Scalar { kind, width } => Ti::Scalar { kind, width },
+ Ti::Vector { size, kind, width } => Ti::Vector { size, kind, width },
+ Ti::Matrix {
+ rows,
+ columns,
+ width,
+ } => Ti::Matrix {
+ rows,
+ columns,
+ width,
+ },
+ Ti::Pointer { base, space } => Ti::Pointer { base, space },
+ Ti::ValuePointer {
+ size,
+ kind,
+ width,
+ space,
+ } => Ti::ValuePointer {
+ size,
+ kind,
+ width,
+ space,
+ },
+ _ => unreachable!("Unexpected clone type: {:?}", v),
+ }),
+ }
+ }
+}
+
+impl crate::ConstantInner {
+ pub const fn resolve_type(&self) -> TypeResolution {
+ match *self {
+ Self::Scalar { width, ref value } => TypeResolution::Value(crate::TypeInner::Scalar {
+ kind: value.scalar_kind(),
+ width,
+ }),
+ Self::Composite { ty, components: _ } => TypeResolution::Handle(ty),
+ }
+ }
+}
+
+#[derive(Clone, Debug, Error, PartialEq)]
+pub enum ResolveError {
+ #[error("Index {index} is out of bounds for expression {expr:?}")]
+ OutOfBoundsIndex {
+ expr: Handle<crate::Expression>,
+ index: u32,
+ },
+ #[error("Invalid access into expression {expr:?}, indexed: {indexed}")]
+ InvalidAccess {
+ expr: Handle<crate::Expression>,
+ indexed: bool,
+ },
+ #[error("Invalid sub-access into type {ty:?}, indexed: {indexed}")]
+ InvalidSubAccess {
+ ty: Handle<crate::Type>,
+ indexed: bool,
+ },
+ #[error("Invalid scalar {0:?}")]
+ InvalidScalar(Handle<crate::Expression>),
+ #[error("Invalid vector {0:?}")]
+ InvalidVector(Handle<crate::Expression>),
+ #[error("Invalid pointer {0:?}")]
+ InvalidPointer(Handle<crate::Expression>),
+ #[error("Invalid image {0:?}")]
+ InvalidImage(Handle<crate::Expression>),
+ #[error("Function {name} not defined")]
+ FunctionNotDefined { name: String },
+ #[error("Function without return type")]
+ FunctionReturnsVoid,
+ #[error("Incompatible operands: {0}")]
+ IncompatibleOperands(String),
+ #[error("Function argument {0} doesn't exist")]
+ FunctionArgumentNotFound(u32),
+ #[error("Special type is not registered within the module")]
+ MissingSpecialType,
+}
+
+pub struct ResolveContext<'a> {
+ pub constants: &'a Arena<crate::Constant>,
+ pub types: &'a UniqueArena<crate::Type>,
+ pub special_types: &'a crate::SpecialTypes,
+ pub global_vars: &'a Arena<crate::GlobalVariable>,
+ pub local_vars: &'a Arena<crate::LocalVariable>,
+ pub functions: &'a Arena<crate::Function>,
+ pub arguments: &'a [crate::FunctionArgument],
+}
+
+impl<'a> ResolveContext<'a> {
+ /// Initialize a resolve context from the module.
+ pub const fn with_locals(
+ module: &'a crate::Module,
+ local_vars: &'a Arena<crate::LocalVariable>,
+ arguments: &'a [crate::FunctionArgument],
+ ) -> Self {
+ Self {
+ constants: &module.constants,
+ types: &module.types,
+ special_types: &module.special_types,
+ global_vars: &module.global_variables,
+ local_vars,
+ functions: &module.functions,
+ arguments,
+ }
+ }
+
+ /// Determine the type of `expr`.
+ ///
+ /// The `past` argument must be a closure that can resolve the types of any
+ /// expressions that `expr` refers to. These can be gathered by caching the
+ /// results of prior calls to `resolve`, perhaps as done by the
+ /// [`front::Typifier`] utility type.
+ ///
+ /// Type resolution is a read-only process: this method takes `self` by
+ /// shared reference. However, this means that we cannot add anything to
+ /// `self.types` that we might need to describe `expr`. To work around this,
+ /// this method returns a [`TypeResolution`], rather than simply returning a
+ /// `Handle<Type>`; see the documentation for [`TypeResolution`] for
+ /// details.
+ ///
+ /// [`front::Typifier`]: crate::front::Typifier
+ pub fn resolve(
+ &self,
+ expr: &crate::Expression,
+ past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>,
+ ) -> Result<TypeResolution, ResolveError> {
+ use crate::TypeInner as Ti;
+ let types = self.types;
+ Ok(match *expr {
+ crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) {
+ // Arrays and matrices can only be indexed dynamically behind a
+ // pointer, but that's a validation error, not a type error, so
+ // go ahead provide a type here.
+ Ti::Array { base, .. } => TypeResolution::Handle(base),
+ Ti::Matrix { rows, width, .. } => TypeResolution::Value(Ti::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ }),
+ Ti::Vector {
+ size: _,
+ kind,
+ width,
+ } => TypeResolution::Value(Ti::Scalar { kind, width }),
+ Ti::ValuePointer {
+ size: Some(_),
+ kind,
+ width,
+ space,
+ } => TypeResolution::Value(Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space,
+ }),
+ Ti::Pointer { base, space } => {
+ TypeResolution::Value(match types[base].inner {
+ Ti::Array { base, .. } => Ti::Pointer { base, space },
+ Ti::Vector {
+ size: _,
+ kind,
+ width,
+ } => Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space,
+ },
+ // Matrices are only dynamically indexed behind a pointer
+ Ti::Matrix {
+ columns: _,
+ rows,
+ width,
+ } => Ti::ValuePointer {
+ kind: crate::ScalarKind::Float,
+ size: Some(rows),
+ width,
+ space,
+ },
+ ref other => {
+ log::error!("Access sub-type {:?}", other);
+ return Err(ResolveError::InvalidSubAccess {
+ ty: base,
+ indexed: false,
+ });
+ }
+ })
+ }
+ Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
+ ref other => {
+ log::error!("Access type {:?}", other);
+ return Err(ResolveError::InvalidAccess {
+ expr: base,
+ indexed: false,
+ });
+ }
+ },
+ crate::Expression::AccessIndex { base, index } => {
+ match *past(base)?.inner_with(types) {
+ Ti::Vector { size, kind, width } => {
+ if index >= size as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ TypeResolution::Value(Ti::Scalar { kind, width })
+ }
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ if index >= columns as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ TypeResolution::Value(crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ })
+ }
+ Ti::Array { base, .. } => TypeResolution::Handle(base),
+ Ti::Struct { ref members, .. } => {
+ let member = members
+ .get(index as usize)
+ .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
+ TypeResolution::Handle(member.ty)
+ }
+ Ti::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ space,
+ } => {
+ if index >= size as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ TypeResolution::Value(Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space,
+ })
+ }
+ Ti::Pointer {
+ base: ty_base,
+ space,
+ } => TypeResolution::Value(match types[ty_base].inner {
+ Ti::Array { base, .. } => Ti::Pointer { base, space },
+ Ti::Vector { size, kind, width } => {
+ if index >= size as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space,
+ }
+ }
+ Ti::Matrix {
+ rows,
+ columns,
+ width,
+ } => {
+ if index >= columns as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ Ti::ValuePointer {
+ size: Some(rows),
+ kind: crate::ScalarKind::Float,
+ width,
+ space,
+ }
+ }
+ Ti::Struct { ref members, .. } => {
+ let member = members
+ .get(index as usize)
+ .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
+ Ti::Pointer {
+ base: member.ty,
+ space,
+ }
+ }
+ ref other => {
+ log::error!("Access index sub-type {:?}", other);
+ return Err(ResolveError::InvalidSubAccess {
+ ty: ty_base,
+ indexed: true,
+ });
+ }
+ }),
+ Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
+ ref other => {
+ log::error!("Access index type {:?}", other);
+ return Err(ResolveError::InvalidAccess {
+ expr: base,
+ indexed: true,
+ });
+ }
+ }
+ }
+ crate::Expression::Constant(h) => match self.constants[h].inner {
+ crate::ConstantInner::Scalar { width, ref value } => {
+ TypeResolution::Value(Ti::Scalar {
+ kind: value.scalar_kind(),
+ width,
+ })
+ }
+ crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty),
+ },
+ crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
+ Ti::Scalar { kind, width } => {
+ TypeResolution::Value(Ti::Vector { size, kind, width })
+ }
+ ref other => {
+ log::error!("Scalar type {:?}", other);
+ return Err(ResolveError::InvalidScalar(value));
+ }
+ },
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern: _,
+ } => match *past(vector)?.inner_with(types) {
+ Ti::Vector {
+ size: _,
+ kind,
+ width,
+ } => TypeResolution::Value(Ti::Vector { size, kind, width }),
+ ref other => {
+ log::error!("Vector type {:?}", other);
+ return Err(ResolveError::InvalidVector(vector));
+ }
+ },
+ crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty),
+ crate::Expression::FunctionArgument(index) => {
+ let arg = self
+ .arguments
+ .get(index as usize)
+ .ok_or(ResolveError::FunctionArgumentNotFound(index))?;
+ TypeResolution::Handle(arg.ty)
+ }
+ crate::Expression::GlobalVariable(h) => {
+ let var = &self.global_vars[h];
+ if var.space == crate::AddressSpace::Handle {
+ TypeResolution::Handle(var.ty)
+ } else {
+ TypeResolution::Value(Ti::Pointer {
+ base: var.ty,
+ space: var.space,
+ })
+ }
+ }
+ crate::Expression::LocalVariable(h) => {
+ let var = &self.local_vars[h];
+ TypeResolution::Value(Ti::Pointer {
+ base: var.ty,
+ space: crate::AddressSpace::Function,
+ })
+ }
+ crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
+ Ti::Pointer { base, space: _ } => {
+ if let Ti::Atomic { kind, width } = types[base].inner {
+ TypeResolution::Value(Ti::Scalar { kind, width })
+ } else {
+ TypeResolution::Handle(base)
+ }
+ }
+ Ti::ValuePointer {
+ size,
+ kind,
+ width,
+ space: _,
+ } => TypeResolution::Value(match size {
+ Some(size) => Ti::Vector { size, kind, width },
+ None => Ti::Scalar { kind, width },
+ }),
+ ref other => {
+ log::error!("Pointer type {:?}", other);
+ return Err(ResolveError::InvalidPointer(pointer));
+ }
+ },
+ crate::Expression::ImageSample {
+ image,
+ gather: Some(_),
+ ..
+ } => match *past(image)?.inner_with(types) {
+ Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector {
+ kind: match class {
+ crate::ImageClass::Sampled { kind, multi: _ } => kind,
+ _ => crate::ScalarKind::Float,
+ },
+ width: 4,
+ size: crate::VectorSize::Quad,
+ }),
+ ref other => {
+ log::error!("Image type {:?}", other);
+ return Err(ResolveError::InvalidImage(image));
+ }
+ },
+ crate::Expression::ImageSample { image, .. }
+ | crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) {
+ Ti::Image { class, .. } => TypeResolution::Value(match class {
+ crate::ImageClass::Depth { multi: _ } => Ti::Scalar {
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ crate::ImageClass::Sampled { kind, multi: _ } => Ti::Vector {
+ kind,
+ width: 4,
+ size: crate::VectorSize::Quad,
+ },
+ crate::ImageClass::Storage { format, .. } => Ti::Vector {
+ kind: format.into(),
+ width: 4,
+ size: crate::VectorSize::Quad,
+ },
+ }),
+ ref other => {
+ log::error!("Image type {:?}", other);
+ return Err(ResolveError::InvalidImage(image));
+ }
+ },
+ crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query {
+ crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) {
+ Ti::Image { dim, .. } => match dim {
+ crate::ImageDimension::D1 => Ti::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ crate::ImageDimension::D2 | crate::ImageDimension::Cube => Ti::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ crate::ImageDimension::D3 => Ti::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ },
+ ref other => {
+ log::error!("Image type {:?}", other);
+ return Err(ResolveError::InvalidImage(image));
+ }
+ },
+ crate::ImageQuery::NumLevels
+ | crate::ImageQuery::NumLayers
+ | crate::ImageQuery::NumSamples => Ti::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ }),
+ crate::Expression::Unary { expr, .. } => past(expr)?.clone(),
+ crate::Expression::Binary { op, left, right } => match op {
+ crate::BinaryOperator::Add
+ | crate::BinaryOperator::Subtract
+ | crate::BinaryOperator::Divide
+ | crate::BinaryOperator::Modulo => past(left)?.clone(),
+ crate::BinaryOperator::Multiply => {
+ let (res_left, res_right) = (past(left)?, past(right)?);
+ match (res_left.inner_with(types), res_right.inner_with(types)) {
+ (
+ &Ti::Matrix {
+ columns: _,
+ rows,
+ width,
+ },
+ &Ti::Matrix { columns, .. },
+ ) => TypeResolution::Value(Ti::Matrix {
+ columns,
+ rows,
+ width,
+ }),
+ (
+ &Ti::Matrix {
+ columns: _,
+ rows,
+ width,
+ },
+ &Ti::Vector { .. },
+ ) => TypeResolution::Value(Ti::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ }),
+ (
+ &Ti::Vector { .. },
+ &Ti::Matrix {
+ columns,
+ rows: _,
+ width,
+ },
+ ) => TypeResolution::Value(Ti::Vector {
+ size: columns,
+ kind: crate::ScalarKind::Float,
+ width,
+ }),
+ (&Ti::Scalar { .. }, _) => res_right.clone(),
+ (_, &Ti::Scalar { .. }) => res_left.clone(),
+ (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
+ (tl, tr) => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{tl:?} * {tr:?}"
+ )))
+ }
+ }
+ }
+ crate::BinaryOperator::Equal
+ | crate::BinaryOperator::NotEqual
+ | crate::BinaryOperator::Less
+ | crate::BinaryOperator::LessEqual
+ | crate::BinaryOperator::Greater
+ | crate::BinaryOperator::GreaterEqual
+ | crate::BinaryOperator::LogicalAnd
+ | crate::BinaryOperator::LogicalOr => {
+ let kind = crate::ScalarKind::Bool;
+ let width = crate::BOOL_WIDTH;
+ let inner = match *past(left)?.inner_with(types) {
+ Ti::Scalar { .. } => Ti::Scalar { kind, width },
+ Ti::Vector { size, .. } => Ti::Vector { size, kind, width },
+ ref other => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{op:?}({other:?}, _)"
+ )))
+ }
+ };
+ TypeResolution::Value(inner)
+ }
+ crate::BinaryOperator::And
+ | crate::BinaryOperator::ExclusiveOr
+ | crate::BinaryOperator::InclusiveOr
+ | crate::BinaryOperator::ShiftLeft
+ | crate::BinaryOperator::ShiftRight => past(left)?.clone(),
+ },
+ crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
+ crate::Expression::Select { accept, .. } => past(accept)?.clone(),
+ crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
+ crate::Expression::Relational { fun, argument } => match fun {
+ crate::RelationalFunction::All | crate::RelationalFunction::Any => {
+ TypeResolution::Value(Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ })
+ }
+ crate::RelationalFunction::IsNan
+ | crate::RelationalFunction::IsInf
+ | crate::RelationalFunction::IsFinite
+ | crate::RelationalFunction::IsNormal => match *past(argument)?.inner_with(types) {
+ Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ }),
+ Ti::Vector { size, .. } => TypeResolution::Value(Ti::Vector {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ size,
+ }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{fun:?}({other:?})"
+ )))
+ }
+ },
+ },
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2: _,
+ arg3: _,
+ } => {
+ use crate::MathFunction as Mf;
+ let res_arg = past(arg)?;
+ match fun {
+ // comparison
+ Mf::Abs |
+ Mf::Min |
+ Mf::Max |
+ Mf::Clamp |
+ Mf::Saturate |
+ // trigonometry
+ Mf::Cos |
+ Mf::Cosh |
+ Mf::Sin |
+ Mf::Sinh |
+ Mf::Tan |
+ Mf::Tanh |
+ Mf::Acos |
+ Mf::Asin |
+ Mf::Atan |
+ Mf::Atan2 |
+ Mf::Asinh |
+ Mf::Acosh |
+ Mf::Atanh |
+ Mf::Radians |
+ Mf::Degrees |
+ // decomposition
+ Mf::Ceil |
+ Mf::Floor |
+ Mf::Round |
+ Mf::Fract |
+ Mf::Trunc |
+ Mf::Modf |
+ Mf::Frexp |
+ Mf::Ldexp |
+ // exponent
+ Mf::Exp |
+ Mf::Exp2 |
+ Mf::Log |
+ Mf::Log2 |
+ Mf::Pow => res_arg.clone(),
+ // geometry
+ Mf::Dot => match *res_arg.inner_with(types) {
+ Ti::Vector {
+ kind,
+ size: _,
+ width,
+ } => TypeResolution::Value(Ti::Scalar { kind, width }),
+ ref other =>
+ return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?}, _)")
+ )),
+ },
+ Mf::Outer => {
+ let arg1 = arg1.ok_or_else(|| ResolveError::IncompatibleOperands(
+ format!("{fun:?}(_, None)")
+ ))?;
+ match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) {
+ (&Ti::Vector {kind: _, size: columns,width}, &Ti::Vector{ size: rows, .. }) => TypeResolution::Value(Ti::Matrix { columns, rows, width }),
+ (left, right) =>
+ return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({left:?}, {right:?})")
+ )),
+ }
+ },
+ Mf::Cross => res_arg.clone(),
+ Mf::Distance |
+ Mf::Length => match *res_arg.inner_with(types) {
+ Ti::Scalar {width,kind} |
+ Ti::Vector {width,kind,size:_} => TypeResolution::Value(Ti::Scalar { kind, width }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ Mf::Normalize |
+ Mf::FaceForward |
+ Mf::Reflect |
+ Mf::Refract => res_arg.clone(),
+ // computational
+ Mf::Sign |
+ Mf::Fma |
+ Mf::Mix |
+ Mf::Step |
+ Mf::SmoothStep |
+ Mf::Sqrt |
+ Mf::InverseSqrt => res_arg.clone(),
+ Mf::Transpose => match *res_arg.inner_with(types) {
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } => TypeResolution::Value(Ti::Matrix {
+ columns: rows,
+ rows: columns,
+ width,
+ }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ Mf::Inverse => match *res_arg.inner_with(types) {
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } if columns == rows => TypeResolution::Value(Ti::Matrix {
+ columns,
+ rows,
+ width,
+ }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ Mf::Determinant => match *res_arg.inner_with(types) {
+ Ti::Matrix {
+ width,
+ ..
+ } => TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Float, width }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ // bits
+ Mf::CountTrailingZeros |
+ Mf::CountLeadingZeros |
+ Mf::CountOneBits |
+ Mf::ReverseBits |
+ Mf::ExtractBits |
+ Mf::InsertBits |
+ Mf::FindLsb |
+ Mf::FindMsb => match *res_arg.inner_with(types) {
+ Ti::Scalar { kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } =>
+ TypeResolution::Value(Ti::Scalar { kind, width }),
+ Ti::Vector { size, kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } =>
+ TypeResolution::Value(Ti::Vector { size, kind, width }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ // data packing
+ Mf::Pack4x8snorm |
+ Mf::Pack4x8unorm |
+ Mf::Pack2x16snorm |
+ Mf::Pack2x16unorm |
+ Mf::Pack2x16float => TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Uint, width: 4 }),
+ // data unpacking
+ Mf::Unpack4x8snorm |
+ Mf::Unpack4x8unorm => TypeResolution::Value(Ti::Vector { size: crate::VectorSize::Quad, kind: crate::ScalarKind::Float, width: 4 }),
+ Mf::Unpack2x16snorm |
+ Mf::Unpack2x16unorm |
+ Mf::Unpack2x16float => TypeResolution::Value(Ti::Vector { size: crate::VectorSize::Bi, kind: crate::ScalarKind::Float, width: 4 }),
+ }
+ }
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => match *past(expr)?.inner_with(types) {
+ Ti::Scalar { kind: _, width } => TypeResolution::Value(Ti::Scalar {
+ kind,
+ width: convert.unwrap_or(width),
+ }),
+ Ti::Vector {
+ kind: _,
+ size,
+ width,
+ } => TypeResolution::Value(Ti::Vector {
+ kind,
+ size,
+ width: convert.unwrap_or(width),
+ }),
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } => TypeResolution::Value(Ti::Matrix {
+ columns,
+ rows,
+ width: convert.unwrap_or(width),
+ }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{other:?} as {kind:?}"
+ )))
+ }
+ },
+ crate::Expression::CallResult(function) => {
+ let result = self.functions[function]
+ .result
+ .as_ref()
+ .ok_or(ResolveError::FunctionReturnsVoid)?;
+ TypeResolution::Handle(result.ty)
+ }
+ crate::Expression::ArrayLength(_) => TypeResolution::Value(Ti::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }),
+ crate::Expression::RayQueryProceedResult => TypeResolution::Value(Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ }),
+ crate::Expression::RayQueryGetIntersection { .. } => {
+ let result = self
+ .special_types
+ .ray_intersection
+ .ok_or(ResolveError::MissingSpecialType)?;
+ TypeResolution::Handle(result)
+ }
+ })
+ }
+}
+
+#[test]
+fn test_error_size() {
+ use std::mem::size_of;
+ assert_eq!(size_of::<ResolveError>(), 32);
+}
diff --git a/third_party/rust/naga/src/span.rs b/third_party/rust/naga/src/span.rs
new file mode 100644
index 0000000000..90bf3d174e
--- /dev/null
+++ b/third_party/rust/naga/src/span.rs
@@ -0,0 +1,523 @@
+use crate::{Arena, Handle, UniqueArena};
+use std::{error::Error, fmt, ops::Range};
+
+/// A source code span, used for error reporting.
+#[derive(Clone, Copy, Debug, PartialEq, Default)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct Span {
+ start: u32,
+ end: u32,
+}
+
+impl Span {
+ pub const UNDEFINED: Self = Self { start: 0, end: 0 };
+ /// Creates a new `Span` from a range of byte indices
+ ///
+ /// Note: end is exclusive, it doesn't belong to the `Span`
+ pub const fn new(start: u32, end: u32) -> Self {
+ Span { start, end }
+ }
+
+ /// Returns a new `Span` starting at `self` and ending at `other`
+ pub const fn until(&self, other: &Self) -> Self {
+ Span {
+ start: self.start,
+ end: other.end,
+ }
+ }
+
+ /// Modifies `self` to contain the smallest `Span` possible that
+ /// contains both `self` and `other`
+ pub fn subsume(&mut self, other: Self) {
+ *self = if !self.is_defined() {
+ // self isn't defined so use other
+ other
+ } else if !other.is_defined() {
+ // other isn't defined so don't try to subsume
+ *self
+ } else {
+ // Both self and other are defined so calculate the span that contains them both
+ Span {
+ start: self.start.min(other.start),
+ end: self.end.max(other.end),
+ }
+ }
+ }
+
+ /// Returns the smallest `Span` possible that contains all the `Span`s
+ /// defined in the `from` iterator
+ pub fn total_span<T: Iterator<Item = Self>>(from: T) -> Self {
+ let mut span: Self = Default::default();
+ for other in from {
+ span.subsume(other);
+ }
+ span
+ }
+
+ /// Converts `self` to a range if the span is not unknown
+ pub fn to_range(self) -> Option<Range<usize>> {
+ if self.is_defined() {
+ Some(self.start as usize..self.end as usize)
+ } else {
+ None
+ }
+ }
+
+ /// Check whether `self` was defined or is a default/unknown span
+ pub fn is_defined(&self) -> bool {
+ *self != Self::default()
+ }
+
+ /// Return a [`SourceLocation`] for this span in the provided source.
+ pub fn location(&self, source: &str) -> SourceLocation {
+ let prefix = &source[..self.start as usize];
+ let line_number = prefix.matches('\n').count() as u32 + 1;
+ let line_start = prefix.rfind('\n').map(|pos| pos + 1).unwrap_or(0);
+ let line_position = source[line_start..self.start as usize].chars().count() as u32 + 1;
+
+ SourceLocation {
+ line_number,
+ line_position,
+ offset: self.start,
+ length: self.end - self.start,
+ }
+ }
+}
+
+impl From<Range<usize>> for Span {
+ fn from(range: Range<usize>) -> Self {
+ Span {
+ start: range.start as u32,
+ end: range.end as u32,
+ }
+ }
+}
+
+impl std::ops::Index<Span> for str {
+ type Output = str;
+
+ #[inline]
+ fn index(&self, span: Span) -> &str {
+ &self[span.start as usize..span.end as usize]
+ }
+}
+
+/// A human-readable representation for a span, tailored for text source.
+///
+/// Corresponds to the positional members of [`GPUCompilationMessage`][gcm] from
+/// the WebGPU specification, except that `offset` and `length` are in bytes
+/// (UTF-8 code units), instead of UTF-16 code units.
+///
+/// [gcm]: https://www.w3.org/TR/webgpu/#gpucompilationmessage
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub struct SourceLocation {
+ /// 1-based line number.
+ pub line_number: u32,
+ /// 1-based column of the start of this span
+ pub line_position: u32,
+ /// 0-based Offset in code units (in bytes) of the start of the span.
+ pub offset: u32,
+ /// Length in code units (in bytes) of the span.
+ pub length: u32,
+}
+
+/// A source code span together with "context", a user-readable description of what part of the error it refers to.
+pub type SpanContext = (Span, String);
+
+/// Wrapper class for [`Error`], augmenting it with a list of [`SpanContext`]s.
+#[derive(Debug, Clone)]
+pub struct WithSpan<E> {
+ inner: E,
+ #[cfg(feature = "span")]
+ spans: Vec<SpanContext>,
+}
+
+impl<E> fmt::Display for WithSpan<E>
+where
+ E: fmt::Display,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+ self.inner.fmt(f)
+ }
+}
+
+#[cfg(test)]
+impl<E> PartialEq for WithSpan<E>
+where
+ E: PartialEq,
+{
+ fn eq(&self, other: &Self) -> bool {
+ self.inner.eq(&other.inner)
+ }
+}
+
+impl<E> Error for WithSpan<E>
+where
+ E: Error,
+{
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ self.inner.source()
+ }
+}
+
+impl<E> WithSpan<E> {
+ /// Create a new [`WithSpan`] from an [`Error`], containing no spans.
+ pub const fn new(inner: E) -> Self {
+ Self {
+ inner,
+ #[cfg(feature = "span")]
+ spans: Vec::new(),
+ }
+ }
+
+ /// Reverse of [`Self::new`], discards span information and returns an inner error.
+ #[allow(clippy::missing_const_for_fn)] // ignore due to requirement of #![feature(const_precise_live_drops)]
+ pub fn into_inner(self) -> E {
+ self.inner
+ }
+
+ pub const fn as_inner(&self) -> &E {
+ &self.inner
+ }
+
+ /// Iterator over stored [`SpanContext`]s.
+ pub fn spans(&self) -> impl Iterator<Item = &SpanContext> + ExactSizeIterator {
+ #[cfg(feature = "span")]
+ return self.spans.iter();
+ #[cfg(not(feature = "span"))]
+ return std::iter::empty();
+ }
+
+ /// Add a new span with description.
+ #[cfg_attr(not(feature = "span"), allow(unused_variables, unused_mut))]
+ pub fn with_span<S>(mut self, span: Span, description: S) -> Self
+ where
+ S: ToString,
+ {
+ #[cfg(feature = "span")]
+ if span.is_defined() {
+ self.spans.push((span, description.to_string()));
+ }
+ self
+ }
+
+ /// Add a [`SpanContext`].
+ pub fn with_context(self, span_context: SpanContext) -> Self {
+ let (span, description) = span_context;
+ self.with_span(span, description)
+ }
+
+ /// Add a [`Handle`] from either [`Arena`] or [`UniqueArena`], borrowing its span information from there
+ /// and annotating with a type and the handle representation.
+ pub(crate) fn with_handle<T, A: SpanProvider<T>>(self, handle: Handle<T>, arena: &A) -> Self {
+ self.with_context(arena.get_span_context(handle))
+ }
+
+ /// Convert inner error using [`From`].
+ pub fn into_other<E2>(self) -> WithSpan<E2>
+ where
+ E2: From<E>,
+ {
+ WithSpan {
+ inner: self.inner.into(),
+ #[cfg(feature = "span")]
+ spans: self.spans,
+ }
+ }
+
+ /// Convert inner error into another type. Joins span information contained in `self`
+ /// with what is returned from `func`.
+ pub fn and_then<F, E2>(self, func: F) -> WithSpan<E2>
+ where
+ F: FnOnce(E) -> WithSpan<E2>,
+ {
+ #[cfg_attr(not(feature = "span"), allow(unused_mut))]
+ let mut res = func(self.inner);
+ #[cfg(feature = "span")]
+ res.spans.extend(self.spans);
+ res
+ }
+
+ #[cfg(feature = "span")]
+ /// Return a [`SourceLocation`] for our first span, if we have one.
+ pub fn location(&self, source: &str) -> Option<SourceLocation> {
+ if self.spans.is_empty() {
+ return None;
+ }
+
+ Some(self.spans[0].0.location(source))
+ }
+
+ #[cfg(not(feature = "span"))]
+ /// Return a [`SourceLocation`] for our first span, if we have one.
+ pub fn location(&self, _source: &str) -> Option<SourceLocation> {
+ None
+ }
+
+ #[cfg(feature = "span")]
+ fn diagnostic(&self) -> codespan_reporting::diagnostic::Diagnostic<()>
+ where
+ E: Error,
+ {
+ use codespan_reporting::diagnostic::{Diagnostic, Label};
+ let diagnostic = Diagnostic::error()
+ .with_message(self.inner.to_string())
+ .with_labels(
+ self.spans()
+ .map(|&(span, ref desc)| {
+ Label::primary((), span.to_range().unwrap()).with_message(desc.to_owned())
+ })
+ .collect(),
+ )
+ .with_notes({
+ let mut notes = Vec::new();
+ let mut source: &dyn Error = &self.inner;
+ while let Some(next) = Error::source(source) {
+ notes.push(next.to_string());
+ source = next;
+ }
+ notes
+ });
+ diagnostic
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ #[cfg(feature = "span")]
+ pub fn emit_to_stderr(&self, source: &str)
+ where
+ E: Error,
+ {
+ self.emit_to_stderr_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ #[cfg(feature = "span")]
+ pub fn emit_to_stderr_with_path(&self, source: &str, path: &str)
+ where
+ E: Error,
+ {
+ use codespan_reporting::{files, term};
+ use term::termcolor::{ColorChoice, StandardStream};
+
+ let files = files::SimpleFile::new(path, source);
+ let config = term::Config::default();
+ let writer = StandardStream::stderr(ColorChoice::Auto);
+ term::emit(&mut writer.lock(), &config, &files, &self.diagnostic())
+ .expect("cannot write error");
+ }
+
+ /// Emits a summary of the error to a string.
+ #[cfg(feature = "span")]
+ pub fn emit_to_string(&self, source: &str) -> String
+ where
+ E: Error,
+ {
+ self.emit_to_string_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to a string.
+ #[cfg(feature = "span")]
+ pub fn emit_to_string_with_path(&self, source: &str, path: &str) -> String
+ where
+ E: Error,
+ {
+ use codespan_reporting::{files, term};
+ use term::termcolor::NoColor;
+
+ let files = files::SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+ let mut writer = NoColor::new(Vec::new());
+ term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error");
+ String::from_utf8(writer.into_inner()).unwrap()
+ }
+}
+
+/// Convenience trait for [`Error`] to be able to apply spans to anything.
+pub(crate) trait AddSpan: Sized {
+ type Output;
+ /// See [`WithSpan::new`].
+ fn with_span(self) -> Self::Output;
+ /// See [`WithSpan::with_span`].
+ fn with_span_static(self, span: Span, description: &'static str) -> Self::Output;
+ /// See [`WithSpan::with_context`].
+ fn with_span_context(self, span_context: SpanContext) -> Self::Output;
+ /// See [`WithSpan::with_handle`].
+ fn with_span_handle<T, A: SpanProvider<T>>(self, handle: Handle<T>, arena: &A) -> Self::Output;
+}
+
+/// Trait abstracting over getting a span from an [`Arena`] or a [`UniqueArena`].
+pub(crate) trait SpanProvider<T> {
+ fn get_span(&self, handle: Handle<T>) -> Span;
+ fn get_span_context(&self, handle: Handle<T>) -> SpanContext {
+ match self.get_span(handle) {
+ x if !x.is_defined() => (Default::default(), "".to_string()),
+ known => (
+ known,
+ format!("{} {:?}", std::any::type_name::<T>(), handle),
+ ),
+ }
+ }
+}
+
+impl<T> SpanProvider<T> for Arena<T> {
+ fn get_span(&self, handle: Handle<T>) -> Span {
+ self.get_span(handle)
+ }
+}
+
+impl<T> SpanProvider<T> for UniqueArena<T> {
+ fn get_span(&self, handle: Handle<T>) -> Span {
+ self.get_span(handle)
+ }
+}
+
+impl<E> AddSpan for E
+where
+ E: Error,
+{
+ type Output = WithSpan<Self>;
+ fn with_span(self) -> WithSpan<Self> {
+ WithSpan::new(self)
+ }
+
+ fn with_span_static(self, span: Span, description: &'static str) -> WithSpan<Self> {
+ WithSpan::new(self).with_span(span, description)
+ }
+
+ fn with_span_context(self, span_context: SpanContext) -> WithSpan<Self> {
+ WithSpan::new(self).with_context(span_context)
+ }
+
+ fn with_span_handle<T, A: SpanProvider<T>>(
+ self,
+ handle: Handle<T>,
+ arena: &A,
+ ) -> WithSpan<Self> {
+ WithSpan::new(self).with_handle(handle, arena)
+ }
+}
+
+/// Convenience trait for [`Result`], adding a [`MapErrWithSpan::map_err_inner`]
+/// mapping to [`WithSpan::and_then`].
+pub trait MapErrWithSpan<E, E2>: Sized {
+ type Output: Sized;
+ fn map_err_inner<F, E3>(self, func: F) -> Self::Output
+ where
+ F: FnOnce(E) -> WithSpan<E3>,
+ E2: From<E3>;
+}
+
+impl<T, E, E2> MapErrWithSpan<E, E2> for Result<T, WithSpan<E>> {
+ type Output = Result<T, WithSpan<E2>>;
+ fn map_err_inner<F, E3>(self, func: F) -> Result<T, WithSpan<E2>>
+ where
+ F: FnOnce(E) -> WithSpan<E3>,
+ E2: From<E3>,
+ {
+ self.map_err(|e| e.and_then(func).into_other::<E2>())
+ }
+}
+
+#[test]
+fn span_location() {
+ let source = "12\n45\n\n89\n";
+ assert_eq!(
+ Span { start: 0, end: 1 }.location(source),
+ SourceLocation {
+ line_number: 1,
+ line_position: 1,
+ offset: 0,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 1, end: 2 }.location(source),
+ SourceLocation {
+ line_number: 1,
+ line_position: 2,
+ offset: 1,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 2, end: 3 }.location(source),
+ SourceLocation {
+ line_number: 1,
+ line_position: 3,
+ offset: 2,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 3, end: 5 }.location(source),
+ SourceLocation {
+ line_number: 2,
+ line_position: 1,
+ offset: 3,
+ length: 2
+ }
+ );
+ assert_eq!(
+ Span { start: 4, end: 6 }.location(source),
+ SourceLocation {
+ line_number: 2,
+ line_position: 2,
+ offset: 4,
+ length: 2
+ }
+ );
+ assert_eq!(
+ Span { start: 5, end: 6 }.location(source),
+ SourceLocation {
+ line_number: 2,
+ line_position: 3,
+ offset: 5,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 6, end: 7 }.location(source),
+ SourceLocation {
+ line_number: 3,
+ line_position: 1,
+ offset: 6,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 7, end: 8 }.location(source),
+ SourceLocation {
+ line_number: 4,
+ line_position: 1,
+ offset: 7,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 8, end: 9 }.location(source),
+ SourceLocation {
+ line_number: 4,
+ line_position: 2,
+ offset: 8,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 9, end: 10 }.location(source),
+ SourceLocation {
+ line_number: 4,
+ line_position: 3,
+ offset: 9,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 10, end: 11 }.location(source),
+ SourceLocation {
+ line_number: 5,
+ line_position: 1,
+ offset: 10,
+ length: 1
+ }
+ );
+}
diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs
new file mode 100644
index 0000000000..e9b155b6eb
--- /dev/null
+++ b/third_party/rust/naga/src/valid/analyzer.rs
@@ -0,0 +1,1200 @@
+/*! Module analyzer.
+
+Figures out the following properties:
+ - control flow uniformity
+ - texture/sampler pairs
+ - expression reference counts
+!*/
+
+use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
+use crate::span::{AddSpan as _, WithSpan};
+use crate::{
+ arena::{Arena, Handle},
+ proc::{ResolveContext, TypeResolution},
+};
+use std::ops;
+
+pub type NonUniformResult = Option<Handle<crate::Expression>>;
+
+bitflags::bitflags! {
+ /// Kinds of expressions that require uniform control flow.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct UniformityRequirements: u8 {
+ const WORK_GROUP_BARRIER = 0x1;
+ const DERIVATIVE = 0x2;
+ const IMPLICIT_LEVEL = 0x4;
+ }
+}
+
+/// Uniform control flow characteristics.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Uniformity {
+ /// A child expression with non-uniform result.
+ ///
+ /// This means, when the relevant invocations are scheduled on a compute unit,
+ /// they have to use vector registers to store an individual value
+ /// per invocation.
+ ///
+ /// Whenever the control flow is conditioned on such value,
+ /// the hardware needs to keep track of the mask of invocations,
+ /// and process all branches of the control flow.
+ ///
+ /// Any operations that depend on non-uniform results also produce non-uniform.
+ pub non_uniform_result: NonUniformResult,
+ /// If this expression requires uniform control flow, store the reason here.
+ pub requirements: UniformityRequirements,
+}
+
+impl Uniformity {
+ const fn new() -> Self {
+ Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+}
+
+bitflags::bitflags! {
+ struct ExitFlags: u8 {
+ /// Control flow may return from the function, which makes all the
+ /// subsequent statements within the current function (only!)
+ /// to be executed in a non-uniform control flow.
+ const MAY_RETURN = 0x1;
+ /// Control flow may be killed. Anything after `Statement::Kill` is
+ /// considered inside non-uniform context.
+ const MAY_KILL = 0x2;
+ }
+}
+
+/// Uniformity characteristics of a function.
+#[cfg_attr(test, derive(Debug, PartialEq))]
+struct FunctionUniformity {
+ result: Uniformity,
+ exit: ExitFlags,
+}
+
+impl ops::BitOr for FunctionUniformity {
+ type Output = Self;
+ fn bitor(self, other: Self) -> Self {
+ FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: self
+ .result
+ .non_uniform_result
+ .or(other.result.non_uniform_result),
+ requirements: self.result.requirements | other.result.requirements,
+ },
+ exit: self.exit | other.exit,
+ }
+ }
+}
+
+impl FunctionUniformity {
+ const fn new() -> Self {
+ FunctionUniformity {
+ result: Uniformity::new(),
+ exit: ExitFlags::empty(),
+ }
+ }
+
+ /// Returns a disruptor based on the stored exit flags, if any.
+ const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
+ if self.exit.contains(ExitFlags::MAY_RETURN) {
+ Some(UniformityDisruptor::Return)
+ } else if self.exit.contains(ExitFlags::MAY_KILL) {
+ Some(UniformityDisruptor::Discard)
+ } else {
+ None
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Indicates how a global variable is used.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct GlobalUse: u8 {
+ /// Data will be read from the variable.
+ const READ = 0x1;
+ /// Data will be written to the variable.
+ const WRITE = 0x2;
+ /// The information about the data is queried.
+ const QUERY = 0x4;
+ }
+}
+
+#[derive(Clone, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct SamplingKey {
+ pub image: Handle<crate::GlobalVariable>,
+ pub sampler: Handle<crate::GlobalVariable>,
+}
+
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct ExpressionInfo {
+ pub uniformity: Uniformity,
+ pub ref_count: usize,
+ assignable_global: Option<Handle<crate::GlobalVariable>>,
+ pub ty: TypeResolution,
+}
+
+impl ExpressionInfo {
+ const fn new() -> Self {
+ ExpressionInfo {
+ uniformity: Uniformity::new(),
+ ref_count: 0,
+ assignable_global: None,
+ // this doesn't matter at this point, will be overwritten
+ ty: TypeResolution::Value(crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: 0,
+ }),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+enum GlobalOrArgument {
+ Global(Handle<crate::GlobalVariable>),
+ Argument(u32),
+}
+
+impl GlobalOrArgument {
+ fn from_expression(
+ expression_arena: &Arena<crate::Expression>,
+ expression: Handle<crate::Expression>,
+ ) -> Result<GlobalOrArgument, ExpressionError> {
+ Ok(match expression_arena[expression] {
+ crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
+ crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
+ crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
+ _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
+ },
+ _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
+ })
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+struct Sampling {
+ image: GlobalOrArgument,
+ sampler: GlobalOrArgument,
+}
+
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct FunctionInfo {
+ /// Validation flags.
+ #[allow(dead_code)]
+ flags: ValidationFlags,
+ /// Set of shader stages where calling this function is valid.
+ pub available_stages: ShaderStages,
+ /// Uniformity characteristics.
+ pub uniformity: Uniformity,
+ /// Function may kill the invocation.
+ pub may_kill: bool,
+
+ /// All pairs of (texture, sampler) globals that may be used together in
+ /// sampling operations by this function and its callees. This includes
+ /// pairings that arise when this function passes textures and samplers as
+ /// arguments to its callees.
+ ///
+ /// This table does not include uses of textures and samplers passed as
+ /// arguments to this function itself, since we do not know which globals
+ /// those will be. However, this table *is* exhaustive when computed for an
+ /// entry point function: entry points never receive textures or samplers as
+ /// arguments, so all an entry point's sampling can be reported in terms of
+ /// globals.
+ ///
+ /// The GLSL back end uses this table to construct reflection info that
+ /// clients need to construct texture-combined sampler values.
+ pub sampling_set: crate::FastHashSet<SamplingKey>,
+
+ /// How this function and its callees use this module's globals.
+ ///
+ /// This is indexed by `Handle<GlobalVariable>` indices. However,
+ /// `FunctionInfo` implements `std::ops::Index<Handle<GlobalVariable>>`,
+ /// so you can simply index this struct with a global handle to retrieve
+ /// its usage information.
+ global_uses: Box<[GlobalUse]>,
+
+ /// Information about each expression in this function's body.
+ ///
+ /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
+ /// implements `std::ops::Index<Handle<Expression>>`, so you can simply
+ /// index this struct with an expression handle to retrieve its
+ /// `ExpressionInfo`.
+ expressions: Box<[ExpressionInfo]>,
+
+ /// All (texture, sampler) pairs that may be used together in sampling
+ /// operations by this function and its callees, whether they are accessed
+ /// as globals or passed as arguments.
+ ///
+ /// Participants are represented by [`GlobalVariable`] handles whenever
+ /// possible, and otherwise by indices of this function's arguments.
+ ///
+ /// When analyzing a function call, we combine this data about the callee
+ /// with the actual arguments being passed to produce the callers' own
+ /// `sampling_set` and `sampling` tables.
+ ///
+ /// [`GlobalVariable`]: crate::GlobalVariable
+ sampling: crate::FastHashSet<Sampling>,
+}
+
+impl FunctionInfo {
+ pub const fn global_variable_count(&self) -> usize {
+ self.global_uses.len()
+ }
+ pub const fn expression_count(&self) -> usize {
+ self.expressions.len()
+ }
+ pub fn dominates_global_use(&self, other: &Self) -> bool {
+ for (self_global_uses, other_global_uses) in
+ self.global_uses.iter().zip(other.global_uses.iter())
+ {
+ if !self_global_uses.contains(*other_global_uses) {
+ return false;
+ }
+ }
+ true
+ }
+}
+
+impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
+ type Output = GlobalUse;
+ fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
+ &self.global_uses[handle.index()]
+ }
+}
+
+impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
+ type Output = ExpressionInfo;
+ fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
+ &self.expressions[handle.index()]
+ }
+}
+
+/// Disruptor of the uniform control flow.
+#[derive(Clone, Copy, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum UniformityDisruptor {
+ #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
+ Expression(Handle<crate::Expression>),
+ #[error("There is a Return earlier in the control flow of the function")]
+ Return,
+ #[error("There is a Discard earlier in the entry point across all called functions")]
+ Discard,
+}
+
+impl FunctionInfo {
+ /// Adds a value-type reference to an expression.
+ #[must_use]
+ fn add_ref_impl(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ global_use: GlobalUse,
+ ) -> NonUniformResult {
+ let info = &mut self.expressions[handle.index()];
+ info.ref_count += 1;
+ // mark the used global as read
+ if let Some(global) = info.assignable_global {
+ self.global_uses[global.index()] |= global_use;
+ }
+ info.uniformity.non_uniform_result
+ }
+
+ /// Adds a value-type reference to an expression.
+ #[must_use]
+ fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult {
+ self.add_ref_impl(handle, GlobalUse::READ)
+ }
+
+ /// Adds a potentially assignable reference to an expression.
+ /// These are destinations for `Store` and `ImageStore` statements,
+ /// which can transit through `Access` and `AccessIndex`.
+ #[must_use]
+ fn add_assignable_ref(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
+ ) -> NonUniformResult {
+ let info = &mut self.expressions[handle.index()];
+ info.ref_count += 1;
+ // propagate the assignable global up the chain, till it either hits
+ // a value-type expression, or the assignment statement.
+ if let Some(global) = info.assignable_global {
+ if let Some(_old) = assignable_global.replace(global) {
+ unreachable!()
+ }
+ }
+ info.uniformity.non_uniform_result
+ }
+
+ /// Inherit information from a called function.
+ fn process_call(
+ &mut self,
+ callee: &Self,
+ arguments: &[Handle<crate::Expression>],
+ expression_arena: &Arena<crate::Expression>,
+ ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
+ self.sampling_set
+ .extend(callee.sampling_set.iter().cloned());
+ for sampling in callee.sampling.iter() {
+ // If the callee was passed the texture or sampler as an argument,
+ // we may now be able to determine which globals those referred to.
+ let image_storage = match sampling.image {
+ GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
+ GlobalOrArgument::Argument(i) => {
+ let handle = arguments[i as usize];
+ GlobalOrArgument::from_expression(expression_arena, handle).map_err(
+ |source| {
+ FunctionError::Expression { handle, source }
+ .with_span_handle(handle, expression_arena)
+ },
+ )?
+ }
+ };
+
+ let sampler_storage = match sampling.sampler {
+ GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
+ GlobalOrArgument::Argument(i) => {
+ let handle = arguments[i as usize];
+ GlobalOrArgument::from_expression(expression_arena, handle).map_err(
+ |source| {
+ FunctionError::Expression { handle, source }
+ .with_span_handle(handle, expression_arena)
+ },
+ )?
+ }
+ };
+
+ // If we've managed to pin both the image and sampler down to
+ // specific globals, record that in our `sampling_set`. Otherwise,
+ // record as much as we do know in our own `sampling` table, for our
+ // callers to sort out.
+ match (image_storage, sampler_storage) {
+ (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
+ self.sampling_set.insert(SamplingKey { image, sampler });
+ }
+ (image, sampler) => {
+ self.sampling.insert(Sampling { image, sampler });
+ }
+ }
+ }
+
+ // Inherit global use from our callees.
+ for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
+ *mine |= *other;
+ }
+
+ Ok(FunctionUniformity {
+ result: callee.uniformity.clone(),
+ exit: if callee.may_kill {
+ ExitFlags::MAY_KILL
+ } else {
+ ExitFlags::empty()
+ },
+ })
+ }
+
+ /// Computes the expression info and stores it in `self.expressions`.
+ /// Also, bumps the reference counts on dependent expressions.
+ #[allow(clippy::or_fun_call)]
+ fn process_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ expression: &crate::Expression,
+ expression_arena: &Arena<crate::Expression>,
+ other_functions: &[FunctionInfo],
+ resolve_context: &ResolveContext,
+ capabilities: super::Capabilities,
+ ) -> Result<(), ExpressionError> {
+ use crate::{Expression as E, SampleLevel as Sl};
+
+ let mut assignable_global = None;
+ let uniformity = match *expression {
+ E::Access { base, index } => {
+ let base_ty = self[base].ty.inner_with(resolve_context.types);
+
+ // build up the caps needed if this is indexed non-uniformly
+ let mut needed_caps = super::Capabilities::empty();
+ let is_binding_array = match *base_ty {
+ crate::TypeInner::BindingArray {
+ base: array_element_ty_handle,
+ ..
+ } => {
+ // these are nasty aliases, but these idents are too long and break rustfmt
+ let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
+ let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
+ let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;
+
+ // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it.
+ let array_element_ty =
+ &resolve_context.types[array_element_ty_handle].inner;
+
+ needed_caps |= match *array_element_ty {
+ // If we're an image, use the appropriate limit.
+ crate::TypeInner::Image { class, .. } => match class {
+ crate::ImageClass::Storage { .. } => ub_st,
+ _ => st_sb,
+ },
+ crate::TypeInner::Sampler { .. } => sampler,
+ // If we're anything but an image, assume we're a buffer and use the address space.
+ _ => {
+ if let E::GlobalVariable(global_handle) = expression_arena[base] {
+ let global = &resolve_context.global_vars[global_handle];
+ match global.space {
+ crate::AddressSpace::Uniform => ub_st,
+ crate::AddressSpace::Storage { .. } => st_sb,
+ _ => unreachable!(),
+ }
+ } else {
+ unreachable!()
+ }
+ }
+ };
+
+ true
+ }
+ _ => false,
+ };
+
+ if self[index].uniformity.non_uniform_result.is_some()
+ && !capabilities.contains(needed_caps)
+ && is_binding_array
+ {
+ return Err(ExpressionError::MissingCapabilities(needed_caps));
+ }
+
+ Uniformity {
+ non_uniform_result: self
+ .add_assignable_ref(base, &mut assignable_global)
+ .or(self.add_ref(index)),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::AccessIndex { base, .. } => Uniformity {
+ non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
+ requirements: UniformityRequirements::empty(),
+ },
+ // always uniform
+ E::Constant(_) => Uniformity::new(),
+ E::Splat { size: _, value } => Uniformity {
+ non_uniform_result: self.add_ref(value),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Swizzle { vector, .. } => Uniformity {
+ non_uniform_result: self.add_ref(vector),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Compose { ref components, .. } => {
+ let non_uniform_result = components
+ .iter()
+ .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
+ Uniformity {
+ non_uniform_result,
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ // depends on the builtin or interpolation
+ E::FunctionArgument(index) => {
+ let arg = &resolve_context.arguments[index as usize];
+ let uniform = match arg.binding {
+ Some(crate::Binding::BuiltIn(built_in)) => match built_in {
+ // per-polygon built-ins are uniform
+ crate::BuiltIn::FrontFacing
+ // per-work-group built-ins are uniform
+ | crate::BuiltIn::WorkGroupId
+ | crate::BuiltIn::WorkGroupSize
+ | crate::BuiltIn::NumWorkGroups => true,
+ _ => false,
+ },
+ // only flat inputs are uniform
+ Some(crate::Binding::Location {
+ interpolation: Some(crate::Interpolation::Flat),
+ ..
+ }) => true,
+ _ => false,
+ };
+ Uniformity {
+ non_uniform_result: if uniform { None } else { Some(handle) },
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ // depends on the address space
+ E::GlobalVariable(gh) => {
+ use crate::AddressSpace as As;
+ assignable_global = Some(gh);
+ let var = &resolve_context.global_vars[gh];
+ let uniform = match var.space {
+ // local data is non-uniform
+ As::Function | As::Private => false,
+ // workgroup memory is exclusively accessed by the group
+ As::WorkGroup => true,
+ // uniform data
+ As::Uniform | As::PushConstant => true,
+ // storage data is only uniform when read-only
+ As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
+ As::Handle => false,
+ };
+ Uniformity {
+ non_uniform_result: if uniform { None } else { Some(handle) },
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::LocalVariable(_) => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Load { pointer } => Uniformity {
+ non_uniform_result: self.add_ref(pointer),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::ImageSample {
+ image,
+ sampler,
+ gather: _,
+ coordinate,
+ array_index,
+ offset: _,
+ level,
+ depth_ref,
+ } => {
+ let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
+ let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
+
+ match (image_storage, sampler_storage) {
+ (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
+ self.sampling_set.insert(SamplingKey { image, sampler });
+ }
+ _ => {
+ self.sampling.insert(Sampling {
+ image: image_storage,
+ sampler: sampler_storage,
+ });
+ }
+ }
+
+ // "nur" == "Non-Uniform Result"
+ let array_nur = array_index.and_then(|h| self.add_ref(h));
+ let level_nur = match level {
+ Sl::Auto | Sl::Zero => None,
+ Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
+ Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
+ };
+ let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self
+ .add_ref(image)
+ .or(self.add_ref(sampler))
+ .or(self.add_ref(coordinate))
+ .or(array_nur)
+ .or(level_nur)
+ .or(dref_nur),
+ requirements: if level.implicit_derivatives() {
+ UniformityRequirements::IMPLICIT_LEVEL
+ } else {
+ UniformityRequirements::empty()
+ },
+ }
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let array_nur = array_index.and_then(|h| self.add_ref(h));
+ let sample_nur = sample.and_then(|h| self.add_ref(h));
+ let level_nur = level.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self
+ .add_ref(image)
+ .or(self.add_ref(coordinate))
+ .or(array_nur)
+ .or(sample_nur)
+ .or(level_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::ImageQuery { image, query } => {
+ let query_nur = match query {
+ crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
+ _ => None,
+ };
+ Uniformity {
+ non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::Unary { expr, .. } => Uniformity {
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Binary { left, right, .. } => Uniformity {
+ non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => Uniformity {
+ non_uniform_result: self
+ .add_ref(condition)
+ .or(self.add_ref(accept))
+ .or(self.add_ref(reject)),
+ requirements: UniformityRequirements::empty(),
+ },
+ // explicit derivatives require uniform
+ E::Derivative { expr, .. } => Uniformity {
+ //Note: taking a derivative of a uniform doesn't make it non-uniform
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::DERIVATIVE,
+ },
+ E::Relational { argument, .. } => Uniformity {
+ non_uniform_result: self.add_ref(argument),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Math {
+ arg, arg1, arg2, ..
+ } => {
+ let arg1_nur = arg1.and_then(|h| self.add_ref(h));
+ let arg2_nur = arg2.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::As { expr, .. } => Uniformity {
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
+ E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::ArrayLength(expr) => Uniformity {
+ non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::RayQueryGetIntersection {
+ query,
+ committed: _,
+ } => Uniformity {
+ non_uniform_result: self.add_ref(query),
+ requirements: UniformityRequirements::empty(),
+ },
+ };
+
+ let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
+ self.expressions[handle.index()] = ExpressionInfo {
+ uniformity,
+ ref_count: 0,
+ assignable_global,
+ ty,
+ };
+ Ok(())
+ }
+
+ /// Analyzes the uniformity requirements of a block (as a sequence of statements).
+ /// Returns the uniformity characteristics at the *function* level, i.e.
+ /// whether or not the function requires to be called in uniform control flow,
+ /// and whether the produced result is not disrupting the control flow.
+ ///
+ /// The parent control flow is uniform if `disruptor.is_none()`.
+ ///
+ /// Returns a `NonUniformControlFlow` error if any of the expressions in the block
+ /// require uniformity, but the current flow is non-uniform.
+ #[allow(clippy::or_fun_call)]
+ fn process_block(
+ &mut self,
+ statements: &crate::Block,
+ other_functions: &[FunctionInfo],
+ mut disruptor: Option<UniformityDisruptor>,
+ expression_arena: &Arena<crate::Expression>,
+ ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
+ use crate::Statement as S;
+
+ let mut combined_uniformity = FunctionUniformity::new();
+ for statement in statements {
+ let uniformity = match *statement {
+ S::Emit(ref range) => {
+ let mut requirements = UniformityRequirements::empty();
+ for expr in range.clone() {
+ let req = self.expressions[expr.index()].uniformity.requirements;
+ #[cfg(feature = "validate")]
+ if self
+ .flags
+ .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
+ && !req.is_empty()
+ {
+ if let Some(cause) = disruptor {
+ return Err(FunctionError::NonUniformControlFlow(req, expr, cause)
+ .with_span_handle(expr, expression_arena));
+ }
+ }
+ requirements |= req;
+ }
+ FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements,
+ },
+ exit: ExitFlags::empty(),
+ }
+ }
+ S::Break | S::Continue => FunctionUniformity::new(),
+ S::Kill => FunctionUniformity {
+ result: Uniformity::new(),
+ exit: if disruptor.is_some() {
+ ExitFlags::MAY_KILL
+ } else {
+ ExitFlags::empty()
+ },
+ },
+ S::Barrier(_) => FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::WORK_GROUP_BARRIER,
+ },
+ exit: ExitFlags::empty(),
+ },
+ S::Block(ref b) => {
+ self.process_block(b, other_functions, disruptor, expression_arena)?
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let condition_nur = self.add_ref(condition);
+ let branch_disruptor =
+ disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
+ let accept_uniformity = self.process_block(
+ accept,
+ other_functions,
+ branch_disruptor,
+ expression_arena,
+ )?;
+ let reject_uniformity = self.process_block(
+ reject,
+ other_functions,
+ branch_disruptor,
+ expression_arena,
+ )?;
+ accept_uniformity | reject_uniformity
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ let selector_nur = self.add_ref(selector);
+ let branch_disruptor =
+ disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
+ let mut uniformity = FunctionUniformity::new();
+ let mut case_disruptor = branch_disruptor;
+ for case in cases.iter() {
+ let case_uniformity = self.process_block(
+ &case.body,
+ other_functions,
+ case_disruptor,
+ expression_arena,
+ )?;
+ case_disruptor = if case.fall_through {
+ case_disruptor.or(case_uniformity.exit_disruptor())
+ } else {
+ branch_disruptor
+ };
+ uniformity = uniformity | case_uniformity;
+ }
+ uniformity
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if: _,
+ } => {
+ let body_uniformity =
+ self.process_block(body, other_functions, disruptor, expression_arena)?;
+ let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
+ let continuing_uniformity = self.process_block(
+ continuing,
+ other_functions,
+ continuing_disruptor,
+ expression_arena,
+ )?;
+ body_uniformity | continuing_uniformity
+ }
+ S::Return { value } => FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: if disruptor.is_some() {
+ ExitFlags::MAY_RETURN
+ } else {
+ ExitFlags::empty()
+ },
+ },
+ // Here and below, the used expressions are already emitted,
+ // and their results do not affect the function return value,
+ // so we can ignore their non-uniformity.
+ S::Store { pointer, value } => {
+ let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
+ let _ = self.add_ref(value);
+ FunctionUniformity::new()
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ let _ = self.add_ref_impl(image, GlobalUse::WRITE);
+ if let Some(expr) = array_index {
+ let _ = self.add_ref(expr);
+ }
+ let _ = self.add_ref(coordinate);
+ let _ = self.add_ref(value);
+ FunctionUniformity::new()
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result: _,
+ } => {
+ for &argument in arguments {
+ let _ = self.add_ref(argument);
+ }
+ let info = &other_functions[function.index()];
+ //Note: the result is validated by the Validator, not here
+ self.process_call(info, arguments, expression_arena)?
+ }
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result: _,
+ } => {
+ let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
+ let _ = self.add_ref(value);
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ let _ = self.add_ref(cmp);
+ }
+ FunctionUniformity::new()
+ }
+ S::RayQuery { query, ref fun } => {
+ let _ = self.add_ref(query);
+ if let crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } = *fun
+ {
+ let _ = self.add_ref(acceleration_structure);
+ let _ = self.add_ref(descriptor);
+ }
+ FunctionUniformity::new()
+ }
+ };
+
+ disruptor = disruptor.or(uniformity.exit_disruptor());
+ combined_uniformity = combined_uniformity | uniformity;
+ }
+ Ok(combined_uniformity)
+ }
+}
+
+impl ModuleInfo {
+ /// Builds the `FunctionInfo` based on the function, and validates the
+ /// uniform control flow if required by the expressions of this function.
+ pub(super) fn process_function(
+ &self,
+ fun: &crate::Function,
+ module: &crate::Module,
+ flags: ValidationFlags,
+ capabilities: super::Capabilities,
+ ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
+ let mut info = FunctionInfo {
+ flags,
+ available_stages: ShaderStages::all(),
+ uniformity: Uniformity::new(),
+ may_kill: false,
+ sampling_set: crate::FastHashSet::default(),
+ global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
+ expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
+ sampling: crate::FastHashSet::default(),
+ };
+ let resolve_context =
+ ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
+
+ for (handle, expr) in fun.expressions.iter() {
+ if let Err(source) = info.process_expression(
+ handle,
+ expr,
+ &fun.expressions,
+ &self.functions,
+ &resolve_context,
+ capabilities,
+ ) {
+ return Err(FunctionError::Expression { handle, source }
+ .with_span_handle(handle, &fun.expressions));
+ }
+ }
+
+ let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?;
+ info.uniformity = uniformity.result;
+ info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
+
+ Ok(info)
+ }
+
+ pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
+ &self.entry_points[index]
+ }
+}
+
+#[test]
+#[cfg(feature = "validate")]
+fn uniform_control_flow() {
+ use crate::{Expression as E, Statement as S};
+
+ let mut constant_arena = Arena::new();
+ let constant = constant_arena.append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(0),
+ },
+ },
+ Default::default(),
+ );
+ let mut type_arena = crate::UniqueArena::new();
+ let ty = type_arena.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+ let mut global_var_arena = Arena::new();
+ let non_uniform_global = global_var_arena.append(
+ crate::GlobalVariable {
+ name: None,
+ init: None,
+ ty,
+ space: crate::AddressSpace::Handle,
+ binding: None,
+ },
+ Default::default(),
+ );
+ let uniform_global = global_var_arena.append(
+ crate::GlobalVariable {
+ name: None,
+ init: None,
+ ty,
+ binding: None,
+ space: crate::AddressSpace::Uniform,
+ },
+ Default::default(),
+ );
+
+ let mut expressions = Arena::new();
+ // checks the uniform control flow
+ let constant_expr = expressions.append(E::Constant(constant), Default::default());
+ // checks the non-uniform control flow
+ let derivative_expr = expressions.append(
+ E::Derivative {
+ axis: crate::DerivativeAxis::X,
+ ctrl: crate::DerivativeControl::None,
+ expr: constant_expr,
+ },
+ Default::default(),
+ );
+ let emit_range_constant_derivative = expressions.range_from(0);
+ let non_uniform_global_expr =
+ expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
+ let uniform_global_expr =
+ expressions.append(E::GlobalVariable(uniform_global), Default::default());
+ let emit_range_globals = expressions.range_from(2);
+
+ // checks the QUERY flag
+ let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
+ // checks the transitive WRITE flag
+ let access_expr = expressions.append(
+ E::AccessIndex {
+ base: non_uniform_global_expr,
+ index: 1,
+ },
+ Default::default(),
+ );
+ let emit_range_query_access_globals = expressions.range_from(2);
+
+ let mut info = FunctionInfo {
+ flags: ValidationFlags::all(),
+ available_stages: ShaderStages::all(),
+ uniformity: Uniformity::new(),
+ may_kill: false,
+ sampling_set: crate::FastHashSet::default(),
+ global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
+ expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
+ sampling: crate::FastHashSet::default(),
+ };
+ let resolve_context = ResolveContext {
+ constants: &constant_arena,
+ types: &type_arena,
+ special_types: &crate::SpecialTypes::default(),
+ global_vars: &global_var_arena,
+ local_vars: &Arena::new(),
+ functions: &Arena::new(),
+ arguments: &[],
+ };
+ for (handle, expression) in expressions.iter() {
+ info.process_expression(
+ handle,
+ expression,
+ &expressions,
+ &[],
+ &resolve_context,
+ super::Capabilities::empty(),
+ )
+ .unwrap();
+ }
+ assert_eq!(info[non_uniform_global_expr].ref_count, 1);
+ assert_eq!(info[uniform_global_expr].ref_count, 1);
+ assert_eq!(info[query_expr].ref_count, 0);
+ assert_eq!(info[access_expr].ref_count, 0);
+ assert_eq!(info[non_uniform_global], GlobalUse::empty());
+ assert_eq!(info[uniform_global], GlobalUse::QUERY);
+
+ let stmt_emit1 = S::Emit(emit_range_globals.clone());
+ let stmt_if_uniform = S::If {
+ condition: uniform_global_expr,
+ accept: crate::Block::new(),
+ reject: vec![
+ S::Emit(emit_range_constant_derivative.clone()),
+ S::Store {
+ pointer: constant_expr,
+ value: derivative_expr,
+ },
+ ]
+ .into(),
+ };
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit1, stmt_if_uniform].into(),
+ &[],
+ None,
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::DERIVATIVE,
+ },
+ exit: ExitFlags::empty(),
+ }),
+ );
+ assert_eq!(info[constant_expr].ref_count, 2);
+ assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
+
+ let stmt_emit2 = S::Emit(emit_range_globals.clone());
+ let stmt_if_non_uniform = S::If {
+ condition: non_uniform_global_expr,
+ accept: vec![
+ S::Emit(emit_range_constant_derivative),
+ S::Store {
+ pointer: constant_expr,
+ value: derivative_expr,
+ },
+ ]
+ .into(),
+ reject: crate::Block::new(),
+ };
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit2, stmt_if_non_uniform].into(),
+ &[],
+ None,
+ &expressions
+ ),
+ Err(FunctionError::NonUniformControlFlow(
+ UniformityRequirements::DERIVATIVE,
+ derivative_expr,
+ UniformityDisruptor::Expression(non_uniform_global_expr)
+ )
+ .with_span()),
+ );
+ assert_eq!(info[derivative_expr].ref_count, 1);
+ assert_eq!(info[non_uniform_global], GlobalUse::READ);
+
+ let stmt_emit3 = S::Emit(emit_range_globals);
+ let stmt_return_non_uniform = S::Return {
+ value: Some(non_uniform_global_expr),
+ };
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit3, stmt_return_non_uniform].into(),
+ &[],
+ Some(UniformityDisruptor::Return),
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: Some(non_uniform_global_expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: ExitFlags::MAY_RETURN,
+ }),
+ );
+ assert_eq!(info[non_uniform_global_expr].ref_count, 3);
+
+ // Check that uniformity requirements reach through a pointer
+ let stmt_emit4 = S::Emit(emit_range_query_access_globals);
+ let stmt_assign = S::Store {
+ pointer: access_expr,
+ value: query_expr,
+ };
+ let stmt_return_pointer = S::Return {
+ value: Some(access_expr),
+ };
+ let stmt_kill = S::Kill;
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
+ &[],
+ Some(UniformityDisruptor::Discard),
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: Some(non_uniform_global_expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: ExitFlags::all(),
+ }),
+ );
+ assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
+}
diff --git a/third_party/rust/naga/src/valid/compose.rs b/third_party/rust/naga/src/valid/compose.rs
new file mode 100644
index 0000000000..e77d538255
--- /dev/null
+++ b/third_party/rust/naga/src/valid/compose.rs
@@ -0,0 +1,138 @@
+#[cfg(feature = "validate")]
+use crate::{
+ arena::{Arena, UniqueArena},
+ proc::TypeResolution,
+};
+
+use crate::arena::Handle;
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum ComposeError {
+ #[error("Composing of type {0:?} can't be done")]
+ Type(Handle<crate::Type>),
+ #[error("Composing expects {expected} components but {given} were given")]
+ ComponentCount { given: u32, expected: u32 },
+ #[error("Composing {index}'s component type is not expected")]
+ ComponentType { index: u32 },
+}
+
+#[cfg(feature = "validate")]
+pub fn validate_compose(
+ self_ty_handle: Handle<crate::Type>,
+ constant_arena: &Arena<crate::Constant>,
+ type_arena: &UniqueArena<crate::Type>,
+ component_resolutions: impl ExactSizeIterator<Item = TypeResolution>,
+) -> Result<(), ComposeError> {
+ use crate::TypeInner as Ti;
+
+ match type_arena[self_ty_handle].inner {
+ // vectors are composed from scalars or other vectors
+ Ti::Vector { size, kind, width } => {
+ let mut total = 0;
+ for (index, comp_res) in component_resolutions.enumerate() {
+ total += match *comp_res.inner_with(type_arena) {
+ Ti::Scalar {
+ kind: comp_kind,
+ width: comp_width,
+ } if comp_kind == kind && comp_width == width => 1,
+ Ti::Vector {
+ size: comp_size,
+ kind: comp_kind,
+ width: comp_width,
+ } if comp_kind == kind && comp_width == width => comp_size as u32,
+ ref other => {
+ log::error!("Vector component[{}] type {:?}", index, other);
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ };
+ }
+ if size as u32 != total {
+ return Err(ComposeError::ComponentCount {
+ expected: size as u32,
+ given: total,
+ });
+ }
+ }
+ // matrix are composed from column vectors
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let inner = Ti::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ };
+ if columns as usize != component_resolutions.len() {
+ return Err(ComposeError::ComponentCount {
+ expected: columns as u32,
+ given: component_resolutions.len() as u32,
+ });
+ }
+ for (index, comp_res) in component_resolutions.enumerate() {
+ if comp_res.inner_with(type_arena) != &inner {
+ log::error!("Matrix component[{}] type {:?}", index, comp_res);
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ }
+ }
+ Ti::Array {
+ base,
+ size: crate::ArraySize::Constant(handle),
+ stride: _,
+ } => {
+ let count = constant_arena[handle].to_array_length().unwrap();
+ if count as usize != component_resolutions.len() {
+ return Err(ComposeError::ComponentCount {
+ expected: count,
+ given: component_resolutions.len() as u32,
+ });
+ }
+ for (index, comp_res) in component_resolutions.enumerate() {
+ let base_inner = &type_arena[base].inner;
+ let comp_res_inner = comp_res.inner_with(type_arena);
+ // We don't support arrays of pointers, but it seems best not to
+ // embed that assumption here, so use `TypeInner::equivalent`.
+ if !base_inner.equivalent(comp_res_inner, type_arena) {
+ log::error!("Array component[{}] type {:?}", index, comp_res);
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ }
+ }
+ Ti::Struct { ref members, .. } => {
+ if members.len() != component_resolutions.len() {
+ return Err(ComposeError::ComponentCount {
+ given: component_resolutions.len() as u32,
+ expected: members.len() as u32,
+ });
+ }
+ for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate()
+ {
+ let member_inner = &type_arena[member.ty].inner;
+ let comp_res_inner = comp_res.inner_with(type_arena);
+ // We don't support pointers in structs, but it seems best not to embed
+ // that assumption here, so use `TypeInner::equivalent`.
+ if !comp_res_inner.equivalent(member_inner, type_arena) {
+ log::error!("Struct component[{}] type {:?}", index, comp_res);
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ }
+ }
+ ref other => {
+ log::error!("Composing of {:?}", other);
+ return Err(ComposeError::Type(self_ty_handle));
+ }
+ }
+
+ Ok(())
+}
diff --git a/third_party/rust/naga/src/valid/expression.rs b/third_party/rust/naga/src/valid/expression.rs
new file mode 100644
index 0000000000..3a81707904
--- /dev/null
+++ b/third_party/rust/naga/src/valid/expression.rs
@@ -0,0 +1,1482 @@
+#[cfg(feature = "validate")]
+use std::ops::Index;
+
+#[cfg(feature = "validate")]
+use super::{
+ compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ShaderStages,
+ TypeFlags,
+};
+#[cfg(feature = "validate")]
+use crate::arena::UniqueArena;
+
+use crate::{
+ arena::Handle,
+ proc::{IndexableLengthError, ResolveError},
+};
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum ExpressionError {
+ #[error("Doesn't exist")]
+ DoesntExist,
+ #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
+ NotInScope,
+ #[error("Base type {0:?} is not compatible with this expression")]
+ InvalidBaseType(Handle<crate::Expression>),
+ #[error("Accessing with index {0:?} can't be done")]
+ InvalidIndexType(Handle<crate::Expression>),
+ #[error("Accessing index {1:?} is out of {0:?} bounds")]
+ IndexOutOfBounds(Handle<crate::Expression>, crate::ScalarValue),
+ #[error("The expression {0:?} may only be indexed by a constant")]
+ IndexMustBeConstant(Handle<crate::Expression>),
+ #[error("Function argument {0:?} doesn't exist")]
+ FunctionArgumentDoesntExist(u32),
+ #[error("Loading of {0:?} can't be done")]
+ InvalidPointerType(Handle<crate::Expression>),
+ #[error("Array length of {0:?} can't be done")]
+ InvalidArrayType(Handle<crate::Expression>),
+ #[error("Get intersection of {0:?} can't be done")]
+ InvalidRayQueryType(Handle<crate::Expression>),
+ #[error("Splatting {0:?} can't be done")]
+ InvalidSplatType(Handle<crate::Expression>),
+ #[error("Swizzling {0:?} can't be done")]
+ InvalidVectorType(Handle<crate::Expression>),
+ #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
+ InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
+ #[error(transparent)]
+ Compose(#[from] super::ComposeError),
+ #[error(transparent)]
+ IndexableLength(#[from] IndexableLengthError),
+ #[error("Operation {0:?} can't work with {1:?}")]
+ InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
+ #[error("Operation {0:?} can't work with {1:?} and {2:?}")]
+ InvalidBinaryOperandTypes(
+ crate::BinaryOperator,
+ Handle<crate::Expression>,
+ Handle<crate::Expression>,
+ ),
+ #[error("Selecting is not possible")]
+ InvalidSelectTypes,
+ #[error("Relational argument {0:?} is not a boolean vector")]
+ InvalidBooleanVector(Handle<crate::Expression>),
+ #[error("Relational argument {0:?} is not a float")]
+ InvalidFloatArgument(Handle<crate::Expression>),
+ #[error("Type resolution failed")]
+ Type(#[from] ResolveError),
+ #[error("Not a global variable")]
+ ExpectedGlobalVariable,
+ #[error("Not a global variable or a function argument")]
+ ExpectedGlobalOrArgument,
+ #[error("Needs to be an binding array instead of {0:?}")]
+ ExpectedBindingArrayType(Handle<crate::Type>),
+ #[error("Needs to be an image instead of {0:?}")]
+ ExpectedImageType(Handle<crate::Type>),
+ #[error("Needs to be an image instead of {0:?}")]
+ ExpectedSamplerType(Handle<crate::Type>),
+ #[error("Unable to operate on image class {0:?}")]
+ InvalidImageClass(crate::ImageClass),
+ #[error("Derivatives can only be taken from scalar and vector floats")]
+ InvalidDerivative,
+ #[error("Image array index parameter is misplaced")]
+ InvalidImageArrayIndex,
+ #[error("Inappropriate sample or level-of-detail index for texel access")]
+ InvalidImageOtherIndex,
+ #[error("Image array index type of {0:?} is not an integer scalar")]
+ InvalidImageArrayIndexType(Handle<crate::Expression>),
+ #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
+ InvalidImageOtherIndexType(Handle<crate::Expression>),
+ #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
+ InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
+ #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
+ ComparisonSamplingMismatch {
+ image: crate::ImageClass,
+ sampler: bool,
+ has_ref: bool,
+ },
+ #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
+ InvalidSampleOffset(crate::ImageDimension, Handle<crate::Constant>),
+ #[error("Depth reference {0:?} is not a scalar float")]
+ InvalidDepthReference(Handle<crate::Expression>),
+ #[error("Depth sample level can only be Auto or Zero")]
+ InvalidDepthSampleLevel,
+ #[error("Gather level can only be Zero")]
+ InvalidGatherLevel,
+ #[error("Gather component {0:?} doesn't exist in the image")]
+ InvalidGatherComponent(crate::SwizzleComponent),
+ #[error("Gather can't be done for image dimension {0:?}")]
+ InvalidGatherDimension(crate::ImageDimension),
+ #[error("Sample level (exact) type {0:?} is not a scalar float")]
+ InvalidSampleLevelExactType(Handle<crate::Expression>),
+ #[error("Sample level (bias) type {0:?} is not a scalar float")]
+ InvalidSampleLevelBiasType(Handle<crate::Expression>),
+ #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
+ InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
+ #[error("Unable to cast")]
+ InvalidCastArgument,
+ #[error("Invalid argument count for {0:?}")]
+ WrongArgumentCount(crate::MathFunction),
+ #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
+ InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
+ #[error("Atomic result type can't be {0:?}")]
+ InvalidAtomicResultType(Handle<crate::Type>),
+ #[error("Shader requires capability {0:?}")]
+ MissingCapabilities(super::Capabilities),
+}
+
+#[cfg(feature = "validate")]
+struct ExpressionTypeResolver<'a> {
+ root: Handle<crate::Expression>,
+ types: &'a UniqueArena<crate::Type>,
+ info: &'a FunctionInfo,
+}
+
+#[cfg(feature = "validate")]
+impl<'a> Index<Handle<crate::Expression>> for ExpressionTypeResolver<'a> {
+ type Output = crate::TypeInner;
+
+ #[allow(clippy::panic)]
+ fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
+ if handle < self.root {
+ self.info[handle].ty.inner_with(self.types)
+ } else {
+ // `Validator::validate_module_handles` should have caught this.
+ panic!(
+ "Depends on {:?}, which has not been processed yet",
+ self.root
+ )
+ }
+ }
+}
+
+#[cfg(feature = "validate")]
+impl super::Validator {
+ pub(super) fn validate_expression(
+ &self,
+ root: Handle<crate::Expression>,
+ expression: &crate::Expression,
+ function: &crate::Function,
+ module: &crate::Module,
+ info: &FunctionInfo,
+ other_infos: &[FunctionInfo],
+ ) -> Result<ShaderStages, ExpressionError> {
+ use crate::{Expression as E, ScalarKind as Sk, TypeInner as Ti};
+
+ let resolver = ExpressionTypeResolver {
+ root,
+ types: &module.types,
+ info,
+ };
+
+ let stages = match *expression {
+ E::Access { base, index } => {
+ let base_type = &resolver[base];
+ // See the documentation for `Expression::Access`.
+ let dynamic_indexing_restricted = match *base_type {
+ Ti::Vector { .. } => false,
+ Ti::Matrix { .. } | Ti::Array { .. } => true,
+ Ti::Pointer { .. }
+ | Ti::ValuePointer { size: Some(_), .. }
+ | Ti::BindingArray { .. } => false,
+ ref other => {
+ log::error!("Indexing of {:?}", other);
+ return Err(ExpressionError::InvalidBaseType(base));
+ }
+ };
+ match resolver[index] {
+ //TODO: only allow one of these
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ width: _,
+ } => {}
+ ref other => {
+ log::error!("Indexing by {:?}", other);
+ return Err(ExpressionError::InvalidIndexType(index));
+ }
+ }
+ if dynamic_indexing_restricted
+ && function.expressions[index].is_dynamic_index(module)
+ {
+ return Err(ExpressionError::IndexMustBeConstant(base));
+ }
+
+ // If we know both the length and the index, we can do the
+ // bounds check now.
+ if let crate::proc::IndexableLength::Known(known_length) =
+ base_type.indexable_length(module)?
+ {
+ if let E::Constant(k) = function.expressions[index] {
+ if let crate::Constant {
+ // We must treat specializable constants as unknown.
+ specialization: None,
+ // Non-scalar indices should have been caught above.
+ inner: crate::ConstantInner::Scalar { value, .. },
+ ..
+ } = module.constants[k]
+ {
+ match value {
+ crate::ScalarValue::Uint(u) if u >= known_length as u64 => {
+ return Err(ExpressionError::IndexOutOfBounds(base, value));
+ }
+ crate::ScalarValue::Sint(s)
+ if s < 0 || s >= known_length as i64 =>
+ {
+ return Err(ExpressionError::IndexOutOfBounds(base, value));
+ }
+ _ => (),
+ }
+ }
+ }
+ }
+
+ ShaderStages::all()
+ }
+ E::AccessIndex { base, index } => {
+ fn resolve_index_limit(
+ module: &crate::Module,
+ top: Handle<crate::Expression>,
+ ty: &crate::TypeInner,
+ top_level: bool,
+ ) -> Result<u32, ExpressionError> {
+ let limit = match *ty {
+ Ti::Vector { size, .. }
+ | Ti::ValuePointer {
+ size: Some(size), ..
+ } => size as u32,
+ Ti::Matrix { columns, .. } => columns as u32,
+ Ti::Array {
+ size: crate::ArraySize::Constant(handle),
+ ..
+ } => module.constants[handle].to_array_length().unwrap(),
+ Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks
+ Ti::Pointer { base, .. } if top_level => {
+ resolve_index_limit(module, top, &module.types[base].inner, false)?
+ }
+ Ti::Struct { ref members, .. } => members.len() as u32,
+ ref other => {
+ log::error!("Indexing of {:?}", other);
+ return Err(ExpressionError::InvalidBaseType(top));
+ }
+ };
+ Ok(limit)
+ }
+
+ let limit = resolve_index_limit(module, base, &resolver[base], true)?;
+ if index >= limit {
+ return Err(ExpressionError::IndexOutOfBounds(
+ base,
+ crate::ScalarValue::Uint(limit as _),
+ ));
+ }
+ ShaderStages::all()
+ }
+ E::Constant(_handle) => ShaderStages::all(),
+ E::Splat { size: _, value } => match resolver[value] {
+ Ti::Scalar { .. } => ShaderStages::all(),
+ ref other => {
+ log::error!("Splat scalar type {:?}", other);
+ return Err(ExpressionError::InvalidSplatType(value));
+ }
+ },
+ E::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ let vec_size = match resolver[vector] {
+ Ti::Vector { size: vec_size, .. } => vec_size,
+ ref other => {
+ log::error!("Swizzle vector type {:?}", other);
+ return Err(ExpressionError::InvalidVectorType(vector));
+ }
+ };
+ for &sc in pattern[..size as usize].iter() {
+ if sc as u8 >= vec_size as u8 {
+ return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
+ }
+ }
+ ShaderStages::all()
+ }
+ E::Compose { ref components, ty } => {
+ validate_compose(
+ ty,
+ &module.constants,
+ &module.types,
+ components.iter().map(|&handle| info[handle].ty.clone()),
+ )?;
+ ShaderStages::all()
+ }
+ E::FunctionArgument(index) => {
+ if index >= function.arguments.len() as u32 {
+ return Err(ExpressionError::FunctionArgumentDoesntExist(index));
+ }
+ ShaderStages::all()
+ }
+ E::GlobalVariable(_handle) => ShaderStages::all(),
+ E::LocalVariable(_handle) => ShaderStages::all(),
+ E::Load { pointer } => {
+ match resolver[pointer] {
+ Ti::Pointer { base, .. }
+ if self.types[base.index()]
+ .flags
+ .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
+ Ti::ValuePointer { .. } => {}
+ ref other => {
+ log::error!("Loading {:?}", other);
+ return Err(ExpressionError::InvalidPointerType(pointer));
+ }
+ }
+ ShaderStages::all()
+ }
+ E::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ // check the validity of expressions
+ let image_ty = Self::global_var_ty(module, function, image)?;
+ let sampler_ty = Self::global_var_ty(module, function, sampler)?;
+
+ let comparison = match module.types[sampler_ty].inner {
+ Ti::Sampler { comparison } => comparison,
+ _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
+ };
+
+ let (class, dim) = match module.types[image_ty].inner {
+ Ti::Image {
+ class,
+ arrayed,
+ dim,
+ } => {
+ // check the array property
+ if arrayed != array_index.is_some() {
+ return Err(ExpressionError::InvalidImageArrayIndex);
+ }
+ if let Some(expr) = array_index {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ width: _,
+ } => {}
+ _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
+ }
+ }
+ (class, dim)
+ }
+ _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
+ };
+
+ // check sampling and comparison properties
+ let image_depth = match class {
+ crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Float,
+ multi: false,
+ } => false,
+ crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
+ multi: false,
+ } if gather.is_some() => false,
+ crate::ImageClass::Depth { multi: false } => true,
+ _ => return Err(ExpressionError::InvalidImageClass(class)),
+ };
+ if comparison != depth_ref.is_some() || (comparison && !image_depth) {
+ return Err(ExpressionError::ComparisonSamplingMismatch {
+ image: class,
+ sampler: comparison,
+ has_ref: depth_ref.is_some(),
+ });
+ }
+
+ // check texture coordinates type
+ let num_components = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
+ };
+ match resolver[coordinate] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } if num_components == 1 => {}
+ Ti::Vector {
+ size,
+ kind: Sk::Float,
+ ..
+ } if size as u32 == num_components => {}
+ _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
+ }
+
+ // check constant offset
+ if let Some(const_handle) = offset {
+ let good = match module.constants[const_handle].inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Sint(_),
+ } => num_components == 1,
+ crate::ConstantInner::Scalar { .. } => false,
+ crate::ConstantInner::Composite { ty, .. } => {
+ match module.types[ty].inner {
+ Ti::Vector {
+ size,
+ kind: Sk::Sint,
+ ..
+ } => size as u32 == num_components,
+ _ => false,
+ }
+ }
+ };
+ if !good {
+ return Err(ExpressionError::InvalidSampleOffset(dim, const_handle));
+ }
+ }
+
+ // check depth reference type
+ if let Some(expr) = depth_ref {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidDepthReference(expr)),
+ }
+ match level {
+ crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
+ _ => return Err(ExpressionError::InvalidDepthSampleLevel),
+ }
+ }
+
+ if let Some(component) = gather {
+ match dim {
+ crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
+ crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
+ return Err(ExpressionError::InvalidGatherDimension(dim))
+ }
+ };
+ let max_component = match class {
+ crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
+ _ => crate::SwizzleComponent::W,
+ };
+ if component > max_component {
+ return Err(ExpressionError::InvalidGatherComponent(component));
+ }
+ match level {
+ crate::SampleLevel::Zero => {}
+ _ => return Err(ExpressionError::InvalidGatherLevel),
+ }
+ }
+
+ // check level properties
+ match level {
+ crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
+ crate::SampleLevel::Zero => ShaderStages::all(),
+ crate::SampleLevel::Exact(expr) => {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)),
+ }
+ ShaderStages::all()
+ }
+ crate::SampleLevel::Bias(expr) => {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
+ }
+ ShaderStages::all()
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ match resolver[x] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } if num_components == 1 => {}
+ Ti::Vector {
+ size,
+ kind: Sk::Float,
+ ..
+ } if size as u32 == num_components => {}
+ _ => {
+ return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
+ }
+ }
+ match resolver[y] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } if num_components == 1 => {}
+ Ti::Vector {
+ size,
+ kind: Sk::Float,
+ ..
+ } if size as u32 == num_components => {}
+ _ => {
+ return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
+ }
+ }
+ ShaderStages::all()
+ }
+ }
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let ty = Self::global_var_ty(module, function, image)?;
+ match module.types[ty].inner {
+ Ti::Image {
+ class,
+ arrayed,
+ dim,
+ } => {
+ match resolver[coordinate].image_storage_coordinates() {
+ Some(coord_dim) if coord_dim == dim => {}
+ _ => {
+ return Err(ExpressionError::InvalidImageCoordinateType(
+ dim, coordinate,
+ ))
+ }
+ };
+ if arrayed != array_index.is_some() {
+ return Err(ExpressionError::InvalidImageArrayIndex);
+ }
+ if let Some(expr) = array_index {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ width: _,
+ } => {}
+ _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
+ }
+ }
+
+ match (sample, class.is_multisampled()) {
+ (None, false) => {}
+ (Some(sample), true) => {
+ if resolver[sample].scalar_kind() != Some(Sk::Sint) {
+ return Err(ExpressionError::InvalidImageOtherIndexType(
+ sample,
+ ));
+ }
+ }
+ _ => {
+ return Err(ExpressionError::InvalidImageOtherIndex);
+ }
+ }
+
+ match (level, class.is_mipmapped()) {
+ (None, false) => {}
+ (Some(level), true) => {
+ if resolver[level].scalar_kind() != Some(Sk::Sint) {
+ return Err(ExpressionError::InvalidImageOtherIndexType(level));
+ }
+ }
+ _ => {
+ return Err(ExpressionError::InvalidImageOtherIndex);
+ }
+ }
+ }
+ _ => return Err(ExpressionError::ExpectedImageType(ty)),
+ }
+ ShaderStages::all()
+ }
+ E::ImageQuery { image, query } => {
+ let ty = Self::global_var_ty(module, function, image)?;
+ match module.types[ty].inner {
+ Ti::Image { class, arrayed, .. } => {
+ let good = match query {
+ crate::ImageQuery::NumLayers => arrayed,
+ crate::ImageQuery::Size { level: None } => true,
+ crate::ImageQuery::Size { level: Some(_) }
+ | crate::ImageQuery::NumLevels => class.is_mipmapped(),
+ crate::ImageQuery::NumSamples => class.is_multisampled(),
+ };
+ if !good {
+ return Err(ExpressionError::InvalidImageClass(class));
+ }
+ }
+ _ => return Err(ExpressionError::ExpectedImageType(ty)),
+ }
+ ShaderStages::all()
+ }
+ E::Unary { op, expr } => {
+ use crate::UnaryOperator as Uo;
+ let inner = &resolver[expr];
+ match (op, inner.scalar_kind()) {
+ (_, Some(Sk::Sint | Sk::Bool))
+ //TODO: restrict Negate for bools?
+ | (Uo::Negate, Some(Sk::Float))
+ | (Uo::Not, Some(Sk::Uint)) => {}
+ other => {
+ log::error!("Op {:?} kind {:?}", op, other);
+ return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
+ }
+ }
+ ShaderStages::all()
+ }
+ E::Binary { op, left, right } => {
+ use crate::BinaryOperator as Bo;
+ let left_inner = &resolver[left];
+ let right_inner = &resolver[right];
+ let good = match op {
+ Bo::Add | Bo::Subtract => match *left_inner {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind {
+ Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
+ Sk::Bool => false,
+ },
+ Ti::Matrix { .. } => left_inner == right_inner,
+ _ => false,
+ },
+ Bo::Divide | Bo::Modulo => match *left_inner {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind {
+ Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
+ Sk::Bool => false,
+ },
+ _ => false,
+ },
+ Bo::Multiply => {
+ let kind_allowed = match left_inner.scalar_kind() {
+ Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
+ Some(Sk::Bool) | None => false,
+ };
+ let types_match = match (left_inner, right_inner) {
+ // Straight scalar and mixed scalar/vector.
+ (&Ti::Scalar { kind: kind1, .. }, &Ti::Scalar { kind: kind2, .. })
+ | (&Ti::Vector { kind: kind1, .. }, &Ti::Scalar { kind: kind2, .. })
+ | (&Ti::Scalar { kind: kind1, .. }, &Ti::Vector { kind: kind2, .. }) => {
+ kind1 == kind2
+ }
+ // Scalar/matrix.
+ (
+ &Ti::Scalar {
+ kind: Sk::Float, ..
+ },
+ &Ti::Matrix { .. },
+ )
+ | (
+ &Ti::Matrix { .. },
+ &Ti::Scalar {
+ kind: Sk::Float, ..
+ },
+ ) => true,
+ // Vector/vector.
+ (
+ &Ti::Vector {
+ kind: kind1,
+ size: size1,
+ ..
+ },
+ &Ti::Vector {
+ kind: kind2,
+ size: size2,
+ ..
+ },
+ ) => kind1 == kind2 && size1 == size2,
+ // Matrix * vector.
+ (
+ &Ti::Matrix { columns, .. },
+ &Ti::Vector {
+ kind: Sk::Float,
+ size,
+ ..
+ },
+ ) => columns == size,
+ // Vector * matrix.
+ (
+ &Ti::Vector {
+ kind: Sk::Float,
+ size,
+ ..
+ },
+ &Ti::Matrix { rows, .. },
+ ) => size == rows,
+ (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
+ columns == rows
+ }
+ _ => false,
+ };
+ let left_width = match *left_inner {
+ Ti::Scalar { width, .. }
+ | Ti::Vector { width, .. }
+ | Ti::Matrix { width, .. } => width,
+ _ => 0,
+ };
+ let right_width = match *right_inner {
+ Ti::Scalar { width, .. }
+ | Ti::Vector { width, .. }
+ | Ti::Matrix { width, .. } => width,
+ _ => 0,
+ };
+ kind_allowed && types_match && left_width == right_width
+ }
+ Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
+ Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
+ match *left_inner {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind {
+ Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
+ Sk::Bool => false,
+ },
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ }
+ }
+ Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
+ Ti::Scalar { kind: Sk::Bool, .. } | Ti::Vector { kind: Sk::Bool, .. } => {
+ left_inner == right_inner
+ }
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ },
+ Bo::And | Bo::InclusiveOr => match *left_inner {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind {
+ Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
+ Sk::Float => false,
+ },
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ },
+ Bo::ExclusiveOr => match *left_inner {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind {
+ Sk::Sint | Sk::Uint => left_inner == right_inner,
+ Sk::Bool | Sk::Float => false,
+ },
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ },
+ Bo::ShiftLeft | Bo::ShiftRight => {
+ let (base_size, base_kind) = match *left_inner {
+ Ti::Scalar { kind, .. } => (Ok(None), kind),
+ Ti::Vector { size, kind, .. } => (Ok(Some(size)), kind),
+ ref other => {
+ log::error!("Op {:?} base type {:?}", op, other);
+ (Err(()), Sk::Bool)
+ }
+ };
+ let shift_size = match *right_inner {
+ Ti::Scalar { kind: Sk::Uint, .. } => Ok(None),
+ Ti::Vector {
+ size,
+ kind: Sk::Uint,
+ ..
+ } => Ok(Some(size)),
+ ref other => {
+ log::error!("Op {:?} shift type {:?}", op, other);
+ Err(())
+ }
+ };
+ match base_kind {
+ Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
+ Sk::Float | Sk::Bool => false,
+ }
+ }
+ };
+ if !good {
+ log::error!(
+ "Left: {:?} of type {:?}",
+ function.expressions[left],
+ left_inner
+ );
+ log::error!(
+ "Right: {:?} of type {:?}",
+ function.expressions[right],
+ right_inner
+ );
+ return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right));
+ }
+ ShaderStages::all()
+ }
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ let accept_inner = &resolver[accept];
+ let reject_inner = &resolver[reject];
+ let condition_good = match resolver[condition] {
+ Ti::Scalar {
+ kind: Sk::Bool,
+ width: _,
+ } => {
+ // When `condition` is a single boolean, `accept` and
+ // `reject` can be vectors or scalars.
+ match *accept_inner {
+ Ti::Scalar { .. } | Ti::Vector { .. } => true,
+ _ => false,
+ }
+ }
+ Ti::Vector {
+ size,
+ kind: Sk::Bool,
+ width: _,
+ } => match *accept_inner {
+ Ti::Vector {
+ size: other_size, ..
+ } => size == other_size,
+ _ => false,
+ },
+ _ => false,
+ };
+ if !condition_good || accept_inner != reject_inner {
+ return Err(ExpressionError::InvalidSelectTypes);
+ }
+ ShaderStages::all()
+ }
+ E::Derivative { expr, .. } => {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ }
+ | Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidDerivative),
+ }
+ ShaderStages::FRAGMENT
+ }
+ E::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+ let argument_inner = &resolver[argument];
+ match fun {
+ Rf::All | Rf::Any => match *argument_inner {
+ Ti::Vector { kind: Sk::Bool, .. } => {}
+ ref other => {
+ log::error!("All/Any of type {:?}", other);
+ return Err(ExpressionError::InvalidBooleanVector(argument));
+ }
+ },
+ Rf::IsNan | Rf::IsInf | Rf::IsFinite | Rf::IsNormal => match *argument_inner {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ }
+ | Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ ref other => {
+ log::error!("Float test of type {:?}", other);
+ return Err(ExpressionError::InvalidFloatArgument(argument));
+ }
+ },
+ }
+ ShaderStages::all()
+ }
+ E::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let resolve = |arg| &resolver[arg];
+ let arg_ty = resolve(arg);
+ let arg1_ty = arg1.map(resolve);
+ let arg2_ty = arg2.map(resolve);
+ let arg3_ty = arg3.map(resolve);
+ match fun {
+ Mf::Abs => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ let good = match *arg_ty {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ }
+ Mf::Min | Mf::Max => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let good = match *arg_ty {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Clamp => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let good = match *arg_ty {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ if arg2_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ));
+ }
+ }
+ Mf::Saturate
+ | Mf::Cos
+ | Mf::Cosh
+ | Mf::Sin
+ | Mf::Sinh
+ | Mf::Tan
+ | Mf::Tanh
+ | Mf::Acos
+ | Mf::Asin
+ | Mf::Atan
+ | Mf::Asinh
+ | Mf::Acosh
+ | Mf::Atanh
+ | Mf::Radians
+ | Mf::Degrees
+ | Mf::Ceil
+ | Mf::Floor
+ | Mf::Round
+ | Mf::Fract
+ | Mf::Trunc
+ | Mf::Exp
+ | Mf::Exp2
+ | Mf::Log
+ | Mf::Log2
+ | Mf::Length
+ | Mf::Sign
+ | Mf::Sqrt
+ | Mf::InverseSqrt => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ }
+ | Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ }
+ | Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Modf | Mf::Frexp | Mf::Ldexp => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let (size0, width0) = match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ } => (None, width),
+ Ti::Vector {
+ kind: Sk::Float,
+ size,
+ width,
+ } => (Some(size), width),
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ };
+ let good = match *arg1_ty {
+ Ti::Pointer { base, space: _ } => module.types[base].inner == *arg_ty,
+ Ti::ValuePointer {
+ size,
+ kind: Sk::Float,
+ width,
+ space: _,
+ } => size == size0 && width == width0,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Dot => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Vector {
+ kind: Sk::Float | Sk::Sint | Sk::Uint,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Outer | Mf::Cross | Mf::Reflect => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Refract => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+
+ match *arg_ty {
+ Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+
+ match (arg_ty, arg2_ty) {
+ (
+ &Ti::Vector {
+ width: vector_width,
+ ..
+ },
+ &Ti::Scalar {
+ width: scalar_width,
+ kind: Sk::Float,
+ },
+ ) if vector_width == scalar_width => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ))
+ }
+ }
+ }
+ Mf::Normalize => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::FaceForward | Mf::Fma | Mf::SmoothStep => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ }
+ | Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ if arg2_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ));
+ }
+ }
+ Mf::Mix => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let arg_width = match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ }
+ | Ti::Vector {
+ kind: Sk::Float,
+ width,
+ ..
+ } => width,
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ };
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ // the last argument can always be a scalar
+ match *arg2_ty {
+ Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ } if width == arg_width => {}
+ _ if arg2_ty == arg_ty => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ));
+ }
+ }
+ }
+ Mf::Inverse | Mf::Determinant => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ let good = match *arg_ty {
+ Ti::Matrix { columns, rows, .. } => columns == rows,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ }
+ Mf::Transpose => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Matrix { .. } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::CountTrailingZeros
+ | Mf::CountLeadingZeros
+ | Mf::CountOneBits
+ | Mf::ReverseBits
+ | Mf::FindLsb
+ | Mf::FindMsb => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ }
+ | Ti::Vector {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::InsertBits => {
+ let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ }
+ | Ti::Vector {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ match *arg2_ty {
+ Ti::Scalar { kind: Sk::Uint, .. } => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ))
+ }
+ }
+ match *arg3_ty {
+ Ti::Scalar { kind: Sk::Uint, .. } => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg3.unwrap(),
+ ))
+ }
+ }
+ }
+ Mf::ExtractBits => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ }
+ | Ti::Vector {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ match *arg1_ty {
+ Ti::Scalar { kind: Sk::Uint, .. } => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg1.unwrap(),
+ ))
+ }
+ }
+ match *arg2_ty {
+ Ti::Scalar { kind: Sk::Uint, .. } => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ))
+ }
+ }
+ }
+ Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Vector {
+ size: crate::VectorSize::Bi,
+ kind: Sk::Float,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::Pack4x8snorm | Mf::Pack4x8unorm => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Vector {
+ size: crate::VectorSize::Quad,
+ kind: Sk::Float,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::Unpack2x16float
+ | Mf::Unpack2x16snorm
+ | Mf::Unpack2x16unorm
+ | Mf::Unpack4x8snorm
+ | Mf::Unpack4x8unorm => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Scalar { kind: Sk::Uint, .. } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ }
+ ShaderStages::all()
+ }
+ E::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let base_width = match resolver[expr] {
+ crate::TypeInner::Scalar { width, .. }
+ | crate::TypeInner::Vector { width, .. }
+ | crate::TypeInner::Matrix { width, .. } => width,
+ _ => return Err(ExpressionError::InvalidCastArgument),
+ };
+ let width = convert.unwrap_or(base_width);
+ if self.check_width(kind, width).is_err() {
+ return Err(ExpressionError::InvalidCastArgument);
+ }
+ ShaderStages::all()
+ }
+ E::CallResult(function) => other_infos[function.index()].available_stages,
+ E::AtomicResult { ty, comparison } => {
+ let scalar_predicate = |ty: &crate::TypeInner| match ty {
+ &crate::TypeInner::Scalar {
+ kind: kind @ (crate::ScalarKind::Uint | crate::ScalarKind::Sint),
+ width,
+ } => self.check_width(kind, width).is_ok(),
+ _ => false,
+ };
+ let good = match &module.types[ty].inner {
+ ty if !comparison => scalar_predicate(ty),
+ &crate::TypeInner::Struct { ref members, .. } if comparison => {
+ validate_atomic_compare_exchange_struct(
+ &module.types,
+ members,
+ scalar_predicate,
+ )
+ }
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidAtomicResultType(ty));
+ }
+ ShaderStages::all()
+ }
+ E::ArrayLength(expr) => match resolver[expr] {
+ Ti::Pointer { base, .. } => {
+ let base_ty = &resolver.types[base];
+ if let Ti::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = base_ty.inner
+ {
+ ShaderStages::all()
+ } else {
+ return Err(ExpressionError::InvalidArrayType(expr));
+ }
+ }
+ ref other => {
+ log::error!("Array length of {:?}", other);
+ return Err(ExpressionError::InvalidArrayType(expr));
+ }
+ },
+ E::RayQueryProceedResult => ShaderStages::all(),
+ E::RayQueryGetIntersection {
+ query,
+ committed: _,
+ } => match resolver[query] {
+ Ti::Pointer {
+ base,
+ space: crate::AddressSpace::Function,
+ } => match resolver.types[base].inner {
+ Ti::RayQuery => ShaderStages::all(),
+ ref other => {
+ log::error!("Intersection result of a pointer to {:?}", other);
+ return Err(ExpressionError::InvalidRayQueryType(query));
+ }
+ },
+ ref other => {
+ log::error!("Intersection result of {:?}", other);
+ return Err(ExpressionError::InvalidRayQueryType(query));
+ }
+ },
+ };
+ Ok(stages)
+ }
+
+ fn global_var_ty(
+ module: &crate::Module,
+ function: &crate::Function,
+ expr: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Type>, ExpressionError> {
+ use crate::Expression as Ex;
+
+ match function.expressions[expr] {
+ Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
+ Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
+ Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
+ match function.expressions[base] {
+ Ex::GlobalVariable(var_handle) => {
+ let array_ty = module.global_variables[var_handle].ty;
+
+ match module.types[array_ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => Ok(base),
+ _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
+ }
+ }
+ _ => Err(ExpressionError::ExpectedGlobalVariable),
+ }
+ }
+ _ => Err(ExpressionError::ExpectedGlobalVariable),
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs
new file mode 100644
index 0000000000..8ebcddd9e3
--- /dev/null
+++ b/third_party/rust/naga/src/valid/function.rs
@@ -0,0 +1,1048 @@
+use crate::arena::Handle;
+#[cfg(feature = "validate")]
+use crate::arena::{Arena, UniqueArena};
+
+#[cfg(feature = "validate")]
+use super::validate_atomic_compare_exchange_struct;
+
+use super::{
+ analyzer::{UniformityDisruptor, UniformityRequirements},
+ ExpressionError, FunctionInfo, ModuleInfo,
+};
+use crate::span::WithSpan;
+#[cfg(feature = "validate")]
+use crate::span::{AddSpan as _, MapErrWithSpan as _};
+
+#[cfg(feature = "validate")]
+use bit_set::BitSet;
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum CallError {
+ #[error("Argument {index} expression is invalid")]
+ Argument {
+ index: usize,
+ source: ExpressionError,
+ },
+ #[error("Result expression {0:?} has already been introduced earlier")]
+ ResultAlreadyInScope(Handle<crate::Expression>),
+ #[error("Result value is invalid")]
+ ResultValue(#[source] ExpressionError),
+ #[error("Requires {required} arguments, but {seen} are provided")]
+ ArgumentCount { required: usize, seen: usize },
+ #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
+ ArgumentType {
+ index: usize,
+ required: Handle<crate::Type>,
+ seen_expression: Handle<crate::Expression>,
+ },
+ #[error("The emitted expression doesn't match the call")]
+ ExpressionMismatch(Option<Handle<crate::Expression>>),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum AtomicError {
+ #[error("Pointer {0:?} to atomic is invalid.")]
+ InvalidPointer(Handle<crate::Expression>),
+ #[error("Operand {0:?} has invalid type.")]
+ InvalidOperand(Handle<crate::Expression>),
+ #[error("Result type for {0:?} doesn't match the statement")]
+ ResultTypeMismatch(Handle<crate::Expression>),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum LocalVariableError {
+ #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
+ InvalidType(Handle<crate::Type>),
+ #[error("Initializer doesn't match the variable type")]
+ InitializerType,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum FunctionError {
+ #[error("Expression {handle:?} is invalid")]
+ Expression {
+ handle: Handle<crate::Expression>,
+ source: ExpressionError,
+ },
+ #[error("Expression {0:?} can't be introduced - it's already in scope")]
+ ExpressionAlreadyInScope(Handle<crate::Expression>),
+ #[error("Local variable {handle:?} '{name}' is invalid")]
+ LocalVariable {
+ handle: Handle<crate::LocalVariable>,
+ name: String,
+ source: LocalVariableError,
+ },
+ #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
+ InvalidArgumentType { index: usize, name: String },
+ #[error("The function's given return type cannot be returned from functions")]
+ NonConstructibleReturnType,
+ #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
+ InvalidArgumentPointerSpace {
+ index: usize,
+ name: String,
+ space: crate::AddressSpace,
+ },
+ #[error("There are instructions after `return`/`break`/`continue`")]
+ InstructionsAfterReturn,
+ #[error("The `break` is used outside of a `loop` or `switch` context")]
+ BreakOutsideOfLoopOrSwitch,
+ #[error("The `continue` is used outside of a `loop` context")]
+ ContinueOutsideOfLoop,
+ #[error("The `return` is called within a `continuing` block")]
+ InvalidReturnSpot,
+ #[error("The `return` value {0:?} does not match the function return value")]
+ InvalidReturnType(Option<Handle<crate::Expression>>),
+ #[error("The `if` condition {0:?} is not a boolean scalar")]
+ InvalidIfType(Handle<crate::Expression>),
+ #[error("The `switch` value {0:?} is not an integer scalar")]
+ InvalidSwitchType(Handle<crate::Expression>),
+ #[error("Multiple `switch` cases for {0:?} are present")]
+ ConflictingSwitchCase(crate::SwitchValue),
+ #[error("The `switch` contains cases with conflicting types")]
+ ConflictingCaseType,
+ #[error("The `switch` is missing a `default` case")]
+ MissingDefaultCase,
+ #[error("Multiple `default` cases are present")]
+ MultipleDefaultCases,
+ #[error("The last `switch` case contains a `falltrough`")]
+ LastCaseFallTrough,
+ #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
+ InvalidStorePointer(Handle<crate::Expression>),
+ #[error("The value {0:?} can not be stored")]
+ InvalidStoreValue(Handle<crate::Expression>),
+ #[error("Store of {value:?} into {pointer:?} doesn't have matching types")]
+ InvalidStoreTypes {
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ },
+ #[error("Image store parameters are invalid")]
+ InvalidImageStore(#[source] ExpressionError),
+ #[error("Call to {function:?} is invalid")]
+ InvalidCall {
+ function: Handle<crate::Function>,
+ #[source]
+ error: CallError,
+ },
+ #[error("Atomic operation is invalid")]
+ InvalidAtomic(#[from] AtomicError),
+ #[error("Ray Query {0:?} is not a local variable")]
+ InvalidRayQueryExpression(Handle<crate::Expression>),
+ #[error("Acceleration structure {0:?} is not a matching expression")]
+ InvalidAccelerationStructure(Handle<crate::Expression>),
+ #[error("Ray descriptor {0:?} is not a matching expression")]
+ InvalidRayDescriptor(Handle<crate::Expression>),
+ #[error("Ray Query {0:?} does not have a matching type")]
+ InvalidRayQueryType(Handle<crate::Type>),
+ #[error(
+ "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
+ )]
+ NonUniformControlFlow(
+ UniformityRequirements,
+ Handle<crate::Expression>,
+ UniformityDisruptor,
+ ),
+ #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
+ PipelineInputRegularFunction { name: String },
+ #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
+ PipelineOutputRegularFunction,
+}
+
+bitflags::bitflags! {
+ #[repr(transparent)]
+ struct ControlFlowAbility: u8 {
+ /// The control can return out of this block.
+ const RETURN = 0x1;
+ /// The control can break.
+ const BREAK = 0x2;
+ /// The control can continue.
+ const CONTINUE = 0x4;
+ }
+}
+
+#[cfg(feature = "validate")]
+struct BlockInfo {
+ stages: super::ShaderStages,
+ finished: bool,
+}
+
+#[cfg(feature = "validate")]
+struct BlockContext<'a> {
+ abilities: ControlFlowAbility,
+ info: &'a FunctionInfo,
+ expressions: &'a Arena<crate::Expression>,
+ types: &'a UniqueArena<crate::Type>,
+ local_vars: &'a Arena<crate::LocalVariable>,
+ global_vars: &'a Arena<crate::GlobalVariable>,
+ functions: &'a Arena<crate::Function>,
+ special_types: &'a crate::SpecialTypes,
+ prev_infos: &'a [FunctionInfo],
+ return_type: Option<Handle<crate::Type>>,
+}
+
+#[cfg(feature = "validate")]
+impl<'a> BlockContext<'a> {
+ fn new(
+ fun: &'a crate::Function,
+ module: &'a crate::Module,
+ info: &'a FunctionInfo,
+ prev_infos: &'a [FunctionInfo],
+ ) -> Self {
+ Self {
+ abilities: ControlFlowAbility::RETURN,
+ info,
+ expressions: &fun.expressions,
+ types: &module.types,
+ local_vars: &fun.local_variables,
+ global_vars: &module.global_variables,
+ functions: &module.functions,
+ special_types: &module.special_types,
+ prev_infos,
+ return_type: fun.result.as_ref().map(|fr| fr.ty),
+ }
+ }
+
+ const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
+ BlockContext { abilities, ..*self }
+ }
+
+ fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
+ &self.expressions[handle]
+ }
+
+ fn resolve_type_impl(
+ &self,
+ handle: Handle<crate::Expression>,
+ valid_expressions: &BitSet,
+ ) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
+ if handle.index() >= self.expressions.len() {
+ Err(ExpressionError::DoesntExist.with_span())
+ } else if !valid_expressions.contains(handle.index()) {
+ Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
+ } else {
+ Ok(self.info[handle].ty.inner_with(self.types))
+ }
+ }
+
+ fn resolve_type(
+ &self,
+ handle: Handle<crate::Expression>,
+ valid_expressions: &BitSet,
+ ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
+ self.resolve_type_impl(handle, valid_expressions)
+ .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
+ }
+
+ fn resolve_pointer_type(
+ &self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<&crate::TypeInner, FunctionError> {
+ if handle.index() >= self.expressions.len() {
+ Err(FunctionError::Expression {
+ handle,
+ source: ExpressionError::DoesntExist,
+ })
+ } else {
+ Ok(self.info[handle].ty.inner_with(self.types))
+ }
+ }
+}
+
+impl super::Validator {
+ #[cfg(feature = "validate")]
+ fn validate_call(
+ &mut self,
+ function: Handle<crate::Function>,
+ arguments: &[Handle<crate::Expression>],
+ result: Option<Handle<crate::Expression>>,
+ context: &BlockContext,
+ ) -> Result<super::ShaderStages, WithSpan<CallError>> {
+ let fun = &context.functions[function];
+ if fun.arguments.len() != arguments.len() {
+ return Err(CallError::ArgumentCount {
+ required: fun.arguments.len(),
+ seen: arguments.len(),
+ }
+ .with_span());
+ }
+ for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
+ let ty = context
+ .resolve_type_impl(expr, &self.valid_expression_set)
+ .map_err_inner(|source| {
+ CallError::Argument { index, source }
+ .with_span_handle(expr, context.expressions)
+ })?;
+ let arg_inner = &context.types[arg.ty].inner;
+ if !ty.equivalent(arg_inner, context.types) {
+ return Err(CallError::ArgumentType {
+ index,
+ required: arg.ty,
+ seen_expression: expr,
+ }
+ .with_span_handle(expr, context.expressions));
+ }
+ }
+
+ if let Some(expr) = result {
+ if self.valid_expression_set.insert(expr.index()) {
+ self.valid_expression_list.push(expr);
+ } else {
+ return Err(CallError::ResultAlreadyInScope(expr)
+ .with_span_handle(expr, context.expressions));
+ }
+ match context.expressions[expr] {
+ crate::Expression::CallResult(callee)
+ if fun.result.is_some() && callee == function => {}
+ _ => {
+ return Err(CallError::ExpressionMismatch(result)
+ .with_span_handle(expr, context.expressions))
+ }
+ }
+ } else if fun.result.is_some() {
+ return Err(CallError::ExpressionMismatch(result).with_span());
+ }
+
+ let callee_info = &context.prev_infos[function.index()];
+ Ok(callee_info.available_stages)
+ }
+
+ #[cfg(feature = "validate")]
+ fn emit_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ if self.valid_expression_set.insert(handle.index()) {
+ self.valid_expression_list.push(handle);
+ Ok(())
+ } else {
+ Err(FunctionError::ExpressionAlreadyInScope(handle)
+ .with_span_handle(handle, context.expressions))
+ }
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_atomic(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ fun: &crate::AtomicFunction,
+ value: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
+ let (ptr_kind, ptr_width) = match *pointer_inner {
+ crate::TypeInner::Pointer { base, .. } => match context.types[base].inner {
+ crate::TypeInner::Atomic { kind, width } => (kind, width),
+ ref other => {
+ log::error!("Atomic pointer to type {:?}", other);
+ return Err(AtomicError::InvalidPointer(pointer)
+ .with_span_handle(pointer, context.expressions)
+ .into_other());
+ }
+ },
+ ref other => {
+ log::error!("Atomic on type {:?}", other);
+ return Err(AtomicError::InvalidPointer(pointer)
+ .with_span_handle(pointer, context.expressions)
+ .into_other());
+ }
+ };
+
+ let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
+ match *value_inner {
+ crate::TypeInner::Scalar { width, kind } if kind == ptr_kind && width == ptr_width => {}
+ ref other => {
+ log::error!("Atomic operand type {:?}", other);
+ return Err(AtomicError::InvalidOperand(value)
+ .with_span_handle(value, context.expressions)
+ .into_other());
+ }
+ }
+
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ if context.resolve_type(cmp, &self.valid_expression_set)? != value_inner {
+ log::error!("Atomic exchange comparison has a different type from the value");
+ return Err(AtomicError::InvalidOperand(cmp)
+ .with_span_handle(cmp, context.expressions)
+ .into_other());
+ }
+ }
+
+ self.emit_expression(result, context)?;
+ match context.expressions[result] {
+ crate::Expression::AtomicResult { ty, comparison }
+ if {
+ let scalar_predicate = |ty: &crate::TypeInner| {
+ *ty == crate::TypeInner::Scalar {
+ kind: ptr_kind,
+ width: ptr_width,
+ }
+ };
+ match &context.types[ty].inner {
+ ty if !comparison => scalar_predicate(ty),
+ &crate::TypeInner::Struct { ref members, .. } if comparison => {
+ validate_atomic_compare_exchange_struct(
+ context.types,
+ members,
+ scalar_predicate,
+ )
+ }
+ _ => false,
+ }
+ } => {}
+ _ => {
+ return Err(AtomicError::ResultTypeMismatch(result)
+ .with_span_handle(result, context.expressions)
+ .into_other())
+ }
+ }
+ Ok(())
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_block_impl(
+ &mut self,
+ statements: &crate::Block,
+ context: &BlockContext,
+ ) -> Result<BlockInfo, WithSpan<FunctionError>> {
+ use crate::{Statement as S, TypeInner as Ti};
+ let mut finished = false;
+ let mut stages = super::ShaderStages::all();
+ for (statement, &span) in statements.span_iter() {
+ if finished {
+ return Err(FunctionError::InstructionsAfterReturn
+ .with_span_static(span, "instructions after return"));
+ }
+ match *statement {
+ S::Emit(ref range) => {
+ for handle in range.clone() {
+ self.emit_expression(handle, context)?;
+ }
+ }
+ S::Block(ref block) => {
+ let info = self.validate_block(block, context)?;
+ stages &= info.stages;
+ finished = info.finished;
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ match *context.resolve_type(condition, &self.valid_expression_set)? {
+ Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: _,
+ } => {}
+ _ => {
+ return Err(FunctionError::InvalidIfType(condition)
+ .with_span_handle(condition, context.expressions))
+ }
+ }
+ stages &= self.validate_block(accept, context)?.stages;
+ stages &= self.validate_block(reject, context)?.stages;
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ let uint = match context
+ .resolve_type(selector, &self.valid_expression_set)?
+ .scalar_kind()
+ {
+ Some(crate::ScalarKind::Uint) => true,
+ Some(crate::ScalarKind::Sint) => false,
+ _ => {
+ return Err(FunctionError::InvalidSwitchType(selector)
+ .with_span_handle(selector, context.expressions))
+ }
+ };
+ self.switch_values.clear();
+ for case in cases {
+ match case.value {
+ crate::SwitchValue::I32(_) if !uint => {}
+ crate::SwitchValue::U32(_) if uint => {}
+ crate::SwitchValue::Default => {}
+ _ => {
+ return Err(FunctionError::ConflictingCaseType.with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "conflicting switch arm here",
+ ));
+ }
+ };
+ if !self.switch_values.insert(case.value) {
+ return Err(match case.value {
+ crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
+ .with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "duplicated switch arm here",
+ ),
+ _ => FunctionError::ConflictingSwitchCase(case.value)
+ .with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "conflicting switch arm here",
+ ),
+ });
+ }
+ }
+ if !self.switch_values.contains(&crate::SwitchValue::Default) {
+ return Err(FunctionError::MissingDefaultCase
+ .with_span_static(span, "missing default case"));
+ }
+ if let Some(case) = cases.last() {
+ if case.fall_through {
+ return Err(FunctionError::LastCaseFallTrough.with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "bad switch arm here",
+ ));
+ }
+ }
+ let pass_through_abilities = context.abilities
+ & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
+ let sub_context =
+ context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
+ for case in cases {
+ stages &= self.validate_block(&case.body, &sub_context)?.stages;
+ }
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ // special handling for block scoping is needed here,
+ // because the continuing{} block inherits the scope
+ let base_expression_count = self.valid_expression_list.len();
+ let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
+ stages &= self
+ .validate_block_impl(
+ body,
+ &context.with_abilities(
+ pass_through_abilities
+ | ControlFlowAbility::BREAK
+ | ControlFlowAbility::CONTINUE,
+ ),
+ )?
+ .stages;
+ stages &= self
+ .validate_block_impl(
+ continuing,
+ &context.with_abilities(ControlFlowAbility::empty()),
+ )?
+ .stages;
+
+ if let Some(condition) = break_if {
+ match *context.resolve_type(condition, &self.valid_expression_set)? {
+ Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: _,
+ } => {}
+ _ => {
+ return Err(FunctionError::InvalidIfType(condition)
+ .with_span_handle(condition, context.expressions))
+ }
+ }
+ }
+
+ for handle in self.valid_expression_list.drain(base_expression_count..) {
+ self.valid_expression_set.remove(handle.index());
+ }
+ }
+ S::Break => {
+ if !context.abilities.contains(ControlFlowAbility::BREAK) {
+ return Err(FunctionError::BreakOutsideOfLoopOrSwitch
+ .with_span_static(span, "invalid break"));
+ }
+ finished = true;
+ }
+ S::Continue => {
+ if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
+ return Err(FunctionError::ContinueOutsideOfLoop
+ .with_span_static(span, "invalid continue"));
+ }
+ finished = true;
+ }
+ S::Return { value } => {
+ if !context.abilities.contains(ControlFlowAbility::RETURN) {
+ return Err(FunctionError::InvalidReturnSpot
+ .with_span_static(span, "invalid return"));
+ }
+ let value_ty = value
+ .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
+ .transpose()?;
+ let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
+ // We can't return pointers, but it seems best not to embed that
+ // assumption here, so use `TypeInner::equivalent` for comparison.
+ let okay = match (value_ty, expected_ty) {
+ (None, None) => true,
+ (Some(value_inner), Some(expected_inner)) => {
+ value_inner.equivalent(expected_inner, context.types)
+ }
+ (_, _) => false,
+ };
+
+ if !okay {
+ log::error!(
+ "Returning {:?} where {:?} is expected",
+ value_ty,
+ expected_ty
+ );
+ if let Some(handle) = value {
+ return Err(FunctionError::InvalidReturnType(value)
+ .with_span_handle(handle, context.expressions));
+ } else {
+ return Err(FunctionError::InvalidReturnType(value)
+ .with_span_static(span, "invalid return"));
+ }
+ }
+ finished = true;
+ }
+ S::Kill => {
+ stages &= super::ShaderStages::FRAGMENT;
+ finished = true;
+ }
+ S::Barrier(_) => {
+ stages &= super::ShaderStages::COMPUTE;
+ }
+ S::Store { pointer, value } => {
+ let mut current = pointer;
+ loop {
+ let _ = context
+ .resolve_pointer_type(current)
+ .map_err(|e| e.with_span())?;
+ match context.expressions[current] {
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => current = base,
+ crate::Expression::LocalVariable(_)
+ | crate::Expression::GlobalVariable(_)
+ | crate::Expression::FunctionArgument(_) => break,
+ _ => {
+ return Err(FunctionError::InvalidStorePointer(current)
+ .with_span_handle(pointer, context.expressions))
+ }
+ }
+ }
+
+ let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
+ match *value_ty {
+ Ti::Image { .. } | Ti::Sampler { .. } => {
+ return Err(FunctionError::InvalidStoreValue(value)
+ .with_span_handle(value, context.expressions));
+ }
+ _ => {}
+ }
+
+ let pointer_ty = context
+ .resolve_pointer_type(pointer)
+ .map_err(|e| e.with_span())?;
+
+ let good = match *pointer_ty {
+ Ti::Pointer { base, space: _ } => match context.types[base].inner {
+ Ti::Atomic { kind, width } => *value_ty == Ti::Scalar { kind, width },
+ ref other => value_ty == other,
+ },
+ Ti::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ space: _,
+ } => *value_ty == Ti::Vector { size, kind, width },
+ Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space: _,
+ } => *value_ty == Ti::Scalar { kind, width },
+ _ => false,
+ };
+ if !good {
+ return Err(FunctionError::InvalidStoreTypes { pointer, value }
+ .with_span()
+ .with_handle(pointer, context.expressions)
+ .with_handle(value, context.expressions));
+ }
+
+ if let Some(space) = pointer_ty.pointer_space() {
+ if !space.access().contains(crate::StorageAccess::STORE) {
+ return Err(FunctionError::InvalidStorePointer(pointer)
+ .with_span_static(
+ context.expressions.get_span(pointer),
+ "writing to this location is not permitted",
+ ));
+ }
+ }
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ //Note: this code uses a lot of `FunctionError::InvalidImageStore`,
+ // and could probably be refactored.
+ let var = match *context.get_expression(image) {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &context.global_vars[var_handle]
+ }
+ // We're looking at a binding index situation, so punch through the index and look at the global behind it.
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => {
+ match *context.get_expression(base) {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &context.global_vars[var_handle]
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedGlobalVariable,
+ )
+ .with_span_handle(image, context.expressions))
+ }
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedGlobalVariable,
+ )
+ .with_span_handle(image, context.expressions))
+ }
+ };
+
+ // Punch through a binding array to get the underlying type
+ let global_ty = match context.types[var.ty].inner {
+ Ti::BindingArray { base, .. } => &context.types[base].inner,
+ ref inner => inner,
+ };
+
+ let value_ty = match *global_ty {
+ Ti::Image {
+ class,
+ arrayed,
+ dim,
+ } => {
+ match context
+ .resolve_type(coordinate, &self.valid_expression_set)?
+ .image_storage_coordinates()
+ {
+ Some(coord_dim) if coord_dim == dim => {}
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageCoordinateType(
+ dim, coordinate,
+ ),
+ )
+ .with_span_handle(coordinate, context.expressions));
+ }
+ };
+ if arrayed != array_index.is_some() {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageArrayIndex,
+ )
+ .with_span_handle(coordinate, context.expressions));
+ }
+ if let Some(expr) = array_index {
+ match *context.resolve_type(expr, &self.valid_expression_set)? {
+ Ti::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ width: _,
+ } => {}
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageArrayIndexType(expr),
+ )
+ .with_span_handle(expr, context.expressions));
+ }
+ }
+ }
+ match class {
+ crate::ImageClass::Storage { format, .. } => {
+ crate::TypeInner::Vector {
+ kind: format.into(),
+ size: crate::VectorSize::Quad,
+ width: 4,
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageClass(class),
+ )
+ .with_span_handle(image, context.expressions));
+ }
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedImageType(var.ty),
+ )
+ .with_span()
+ .with_handle(var.ty, context.types)
+ .with_handle(image, context.expressions))
+ }
+ };
+
+ if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
+ return Err(FunctionError::InvalidStoreValue(value)
+ .with_span_handle(value, context.expressions));
+ }
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result,
+ } => match self.validate_call(function, arguments, result, context) {
+ Ok(callee_stages) => stages &= callee_stages,
+ Err(error) => {
+ return Err(error.and_then(|error| {
+ FunctionError::InvalidCall { function, error }
+ .with_span_static(span, "invalid function call")
+ }))
+ }
+ },
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ self.validate_atomic(pointer, fun, value, result, context)?;
+ }
+ S::RayQuery { query, ref fun } => {
+ let query_var = match *context.get_expression(query) {
+ crate::Expression::LocalVariable(var) => &context.local_vars[var],
+ ref other => {
+ log::error!("Unexpected ray query expression {other:?}");
+ return Err(FunctionError::InvalidRayQueryExpression(query)
+ .with_span_static(span, "invalid query expression"));
+ }
+ };
+ match context.types[query_var.ty].inner {
+ Ti::RayQuery => {}
+ ref other => {
+ log::error!("Unexpected ray query type {other:?}");
+ return Err(FunctionError::InvalidRayQueryType(query_var.ty)
+ .with_span_static(span, "invalid query type"));
+ }
+ }
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ match *context
+ .resolve_type(acceleration_structure, &self.valid_expression_set)?
+ {
+ Ti::AccelerationStructure => {}
+ _ => {
+ return Err(FunctionError::InvalidAccelerationStructure(
+ acceleration_structure,
+ )
+ .with_span_static(span, "invalid acceleration structure"))
+ }
+ }
+ let desc_ty_given =
+ context.resolve_type(descriptor, &self.valid_expression_set)?;
+ let desc_ty_expected = context
+ .special_types
+ .ray_desc
+ .map(|handle| &context.types[handle].inner);
+ if Some(desc_ty_given) != desc_ty_expected {
+ return Err(FunctionError::InvalidRayDescriptor(descriptor)
+ .with_span_static(span, "invalid ray descriptor"));
+ }
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ self.emit_expression(result, context)?;
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ }
+ }
+ }
+ Ok(BlockInfo { stages, finished })
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_block(
+ &mut self,
+ statements: &crate::Block,
+ context: &BlockContext,
+ ) -> Result<BlockInfo, WithSpan<FunctionError>> {
+ let base_expression_count = self.valid_expression_list.len();
+ let info = self.validate_block_impl(statements, context)?;
+ for handle in self.valid_expression_list.drain(base_expression_count..) {
+ self.valid_expression_set.remove(handle.index());
+ }
+ Ok(info)
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_local_var(
+ &self,
+ var: &crate::LocalVariable,
+ types: &UniqueArena<crate::Type>,
+ constants: &Arena<crate::Constant>,
+ ) -> Result<(), LocalVariableError> {
+ log::debug!("var {:?}", var);
+ let type_info = self
+ .types
+ .get(var.ty.index())
+ .ok_or(LocalVariableError::InvalidType(var.ty))?;
+ if !type_info
+ .flags
+ .contains(super::TypeFlags::DATA | super::TypeFlags::SIZED)
+ {
+ return Err(LocalVariableError::InvalidType(var.ty));
+ }
+
+ if let Some(const_handle) = var.init {
+ match constants[const_handle].inner {
+ crate::ConstantInner::Scalar { width, ref value } => {
+ let ty_inner = crate::TypeInner::Scalar {
+ width,
+ kind: value.scalar_kind(),
+ };
+ if types[var.ty].inner != ty_inner {
+ return Err(LocalVariableError::InitializerType);
+ }
+ }
+ crate::ConstantInner::Composite { ty, components: _ } => {
+ if ty != var.ty {
+ return Err(LocalVariableError::InitializerType);
+ }
+ }
+ }
+ }
+ Ok(())
+ }
+
+ pub(super) fn validate_function(
+ &mut self,
+ fun: &crate::Function,
+ module: &crate::Module,
+ mod_info: &ModuleInfo,
+ #[cfg_attr(not(feature = "validate"), allow(unused))] entry_point: bool,
+ ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
+ #[cfg_attr(not(feature = "validate"), allow(unused_mut))]
+ let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
+
+ #[cfg(feature = "validate")]
+ for (var_handle, var) in fun.local_variables.iter() {
+ self.validate_local_var(var, &module.types, &module.constants)
+ .map_err(|source| {
+ FunctionError::LocalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(var.ty, &module.types)
+ .with_handle(var_handle, &fun.local_variables)
+ })?;
+ }
+
+ #[cfg(feature = "validate")]
+ for (index, argument) in fun.arguments.iter().enumerate() {
+ match module.types[argument.ty].inner.pointer_space() {
+ Some(
+ crate::AddressSpace::Private
+ | crate::AddressSpace::Function
+ | crate::AddressSpace::WorkGroup,
+ )
+ | None => {}
+ Some(other) => {
+ return Err(FunctionError::InvalidArgumentPointerSpace {
+ index,
+ name: argument.name.clone().unwrap_or_default(),
+ space: other,
+ }
+ .with_span_handle(argument.ty, &module.types))
+ }
+ }
+ // Check for the least informative error last.
+ if !self.types[argument.ty.index()]
+ .flags
+ .contains(super::TypeFlags::ARGUMENT)
+ {
+ return Err(FunctionError::InvalidArgumentType {
+ index,
+ name: argument.name.clone().unwrap_or_default(),
+ }
+ .with_span_handle(argument.ty, &module.types));
+ }
+
+ if !entry_point && argument.binding.is_some() {
+ return Err(FunctionError::PipelineInputRegularFunction {
+ name: argument.name.clone().unwrap_or_default(),
+ }
+ .with_span_handle(argument.ty, &module.types));
+ }
+ }
+
+ #[cfg(feature = "validate")]
+ if let Some(ref result) = fun.result {
+ if !self.types[result.ty.index()]
+ .flags
+ .contains(super::TypeFlags::CONSTRUCTIBLE)
+ {
+ return Err(FunctionError::NonConstructibleReturnType
+ .with_span_handle(result.ty, &module.types));
+ }
+
+ if !entry_point && result.binding.is_some() {
+ return Err(FunctionError::PipelineOutputRegularFunction
+ .with_span_handle(result.ty, &module.types));
+ }
+ }
+
+ self.valid_expression_set.clear();
+ self.valid_expression_list.clear();
+ for (handle, expr) in fun.expressions.iter() {
+ if expr.needs_pre_emit() {
+ self.valid_expression_set.insert(handle.index());
+ }
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
+ match self.validate_expression(
+ handle,
+ expr,
+ fun,
+ module,
+ &info,
+ &mod_info.functions,
+ ) {
+ Ok(stages) => info.available_stages &= stages,
+ Err(source) => {
+ return Err(FunctionError::Expression { handle, source }
+ .with_span_handle(handle, &fun.expressions))
+ }
+ }
+ }
+ }
+
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::BLOCKS) {
+ let stages = self
+ .validate_block(
+ &fun.body,
+ &BlockContext::new(fun, module, &info, &mod_info.functions),
+ )?
+ .stages;
+ info.available_stages &= stages;
+ }
+ Ok(info)
+ }
+}
diff --git a/third_party/rust/naga/src/valid/handles.rs b/third_party/rust/naga/src/valid/handles.rs
new file mode 100644
index 0000000000..fdd43cd585
--- /dev/null
+++ b/third_party/rust/naga/src/valid/handles.rs
@@ -0,0 +1,652 @@
+//! Implementation of `Validator::validate_module_handles`.
+
+use crate::{
+ arena::{BadHandle, BadRangeError},
+ Handle,
+};
+
+#[cfg(feature = "validate")]
+use crate::{Arena, UniqueArena};
+
+#[cfg(feature = "validate")]
+use super::{TypeError, ValidationError};
+
+#[cfg(feature = "validate")]
+use std::{convert::TryInto, hash::Hash, num::NonZeroU32};
+
+#[cfg(feature = "validate")]
+impl super::Validator {
+ /// Validates that all handles within `module` are:
+ ///
+ /// * Valid, in the sense that they contain indices within each arena structure inside the
+ /// [`crate::Module`] type.
+ /// * No arena contents contain any items that have forward dependencies; that is, the value
+ /// associated with a handle only may contain references to handles in the same arena that
+ /// were constructed before it.
+ ///
+ /// By validating the above conditions, we free up subsequent logic to assume that handle
+ /// accesses are infallible.
+ ///
+ /// # Errors
+ ///
+ /// Errors returned by this method are intentionally sparse, for simplicity of implementation.
+ /// It is expected that only buggy frontends or fuzzers should ever emit IR that fails this
+ /// validation pass.
+ pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> {
+ let &crate::Module {
+ ref constants,
+ ref entry_points,
+ ref functions,
+ ref global_variables,
+ ref types,
+ ref special_types,
+ } = module;
+
+ // NOTE: Types being first is important. All other forms of validation depend on this.
+ for (this_handle, ty) in types.iter() {
+ let &crate::Type {
+ ref name,
+ ref inner,
+ } = ty;
+
+ let validate_array_size = |size| {
+ match size {
+ crate::ArraySize::Constant(constant) => {
+ let &crate::Constant {
+ name: _,
+ specialization: _,
+ ref inner,
+ } = constants.try_get(constant)?;
+ if !matches!(inner, &crate::ConstantInner::Scalar { .. }) {
+ return Err(ValidationError::Type {
+ handle: this_handle,
+ name: name.clone().unwrap_or_default(),
+ source: TypeError::InvalidArraySizeConstant(constant),
+ });
+ }
+ }
+ crate::ArraySize::Dynamic => (),
+ };
+ Ok(this_handle)
+ };
+
+ match *inner {
+ crate::TypeInner::Scalar { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. }
+ | crate::TypeInner::ValuePointer { .. }
+ | crate::TypeInner::Atomic { .. }
+ | crate::TypeInner::Image { .. }
+ | crate::TypeInner::Sampler { .. }
+ | crate::TypeInner::AccelerationStructure
+ | crate::TypeInner::RayQuery => (),
+ crate::TypeInner::Pointer { base, space: _ } => {
+ this_handle.check_dep(base)?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ }
+ | crate::TypeInner::BindingArray { base, size } => {
+ this_handle.check_dep(base)?;
+ validate_array_size(size)?;
+ }
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => {
+ this_handle.check_dep_iter(members.iter().map(|m| m.ty))?;
+ }
+ }
+ }
+
+ let validate_type = |handle| Self::validate_type_handle(handle, types);
+
+ for (this_handle, constant) in constants.iter() {
+ let &crate::Constant {
+ name: _,
+ specialization: _,
+ ref inner,
+ } = constant;
+ match *inner {
+ crate::ConstantInner::Scalar { .. } => (),
+ crate::ConstantInner::Composite { ty, ref components } => {
+ validate_type(ty)?;
+ this_handle.check_dep_iter(components.iter().copied())?;
+ }
+ }
+ }
+
+ let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
+
+ for (_handle, global_variable) in global_variables.iter() {
+ let &crate::GlobalVariable {
+ name: _,
+ space: _,
+ binding: _,
+ ty,
+ init,
+ } = global_variable;
+ validate_type(ty)?;
+ if let Some(init_expr) = init {
+ validate_constant(init_expr)?;
+ }
+ }
+
+ let validate_function = |function_handle, function: &_| -> Result<_, InvalidHandleError> {
+ let &crate::Function {
+ name: _,
+ ref arguments,
+ ref result,
+ ref local_variables,
+ ref expressions,
+ ref named_expressions,
+ ref body,
+ } = function;
+
+ for arg in arguments.iter() {
+ let &crate::FunctionArgument {
+ name: _,
+ ty,
+ binding: _,
+ } = arg;
+ validate_type(ty)?;
+ }
+
+ if let &Some(crate::FunctionResult { ty, binding: _ }) = result {
+ validate_type(ty)?;
+ }
+
+ for (_handle, local_variable) in local_variables.iter() {
+ let &crate::LocalVariable { name: _, ty, init } = local_variable;
+ validate_type(ty)?;
+ if let Some(init_constant) = init {
+ validate_constant(init_constant)?;
+ }
+ }
+
+ for handle in named_expressions.keys().copied() {
+ Self::validate_expression_handle(handle, expressions)?;
+ }
+
+ for handle_and_expr in expressions.iter() {
+ Self::validate_expression_handles(
+ handle_and_expr,
+ constants,
+ types,
+ local_variables,
+ global_variables,
+ functions,
+ function_handle,
+ )?;
+ }
+
+ Self::validate_block_handles(body, expressions, functions)?;
+
+ Ok(())
+ };
+
+ for entry_point in entry_points.iter() {
+ validate_function(None, &entry_point.function)?;
+ }
+
+ for (function_handle, function) in functions.iter() {
+ validate_function(Some(function_handle), function)?;
+ }
+
+ if let Some(ty) = special_types.ray_desc {
+ validate_type(ty)?;
+ }
+ if let Some(ty) = special_types.ray_intersection {
+ validate_type(ty)?;
+ }
+
+ Ok(())
+ }
+
+ fn validate_type_handle(
+ handle: Handle<crate::Type>,
+ types: &UniqueArena<crate::Type>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for_uniq(types).map(|_| ())
+ }
+
+ fn validate_constant_handle(
+ handle: Handle<crate::Constant>,
+ constants: &Arena<crate::Constant>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for(constants).map(|_| ())
+ }
+
+ fn validate_expression_handle(
+ handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for(expressions).map(|_| ())
+ }
+
+ fn validate_function_handle(
+ handle: Handle<crate::Function>,
+ functions: &Arena<crate::Function>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for(functions).map(|_| ())
+ }
+
+ fn validate_expression_handles(
+ (handle, expression): (Handle<crate::Expression>, &crate::Expression),
+ constants: &Arena<crate::Constant>,
+ types: &UniqueArena<crate::Type>,
+ local_variables: &Arena<crate::LocalVariable>,
+ global_variables: &Arena<crate::GlobalVariable>,
+ functions: &Arena<crate::Function>,
+ // The handle of the current function or `None` if it's an entry point
+ current_function: Option<Handle<crate::Function>>,
+ ) -> Result<(), InvalidHandleError> {
+ let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
+ let validate_type = |handle| Self::validate_type_handle(handle, types);
+
+ match *expression {
+ crate::Expression::Access { base, index } => {
+ handle.check_dep(base)?.check_dep(index)?;
+ }
+ crate::Expression::AccessIndex { base, .. } => {
+ handle.check_dep(base)?;
+ }
+ crate::Expression::Constant(constant) => {
+ validate_constant(constant)?;
+ }
+ crate::Expression::Splat { value, .. } => {
+ handle.check_dep(value)?;
+ }
+ crate::Expression::Swizzle { vector, .. } => {
+ handle.check_dep(vector)?;
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ validate_type(ty)?;
+ handle.check_dep_iter(components.iter().copied())?;
+ }
+ crate::Expression::FunctionArgument(_arg_idx) => (),
+ crate::Expression::GlobalVariable(global_variable) => {
+ global_variable.check_valid_for(global_variables)?;
+ }
+ crate::Expression::LocalVariable(local_variable) => {
+ local_variable.check_valid_for(local_variables)?;
+ }
+ crate::Expression::Load { pointer } => {
+ handle.check_dep(pointer)?;
+ }
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather: _,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ if let Some(offset) = offset {
+ validate_constant(offset)?;
+ }
+
+ handle
+ .check_dep(image)?
+ .check_dep(sampler)?
+ .check_dep(coordinate)?
+ .check_dep_opt(array_index)?;
+
+ match level {
+ crate::SampleLevel::Auto | crate::SampleLevel::Zero => (),
+ crate::SampleLevel::Exact(expr) => {
+ handle.check_dep(expr)?;
+ }
+ crate::SampleLevel::Bias(expr) => {
+ handle.check_dep(expr)?;
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ handle.check_dep(x)?.check_dep(y)?;
+ }
+ };
+
+ handle.check_dep_opt(depth_ref)?;
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ handle
+ .check_dep(image)?
+ .check_dep(coordinate)?
+ .check_dep_opt(array_index)?
+ .check_dep_opt(sample)?
+ .check_dep_opt(level)?;
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ handle.check_dep(image)?;
+ match query {
+ crate::ImageQuery::Size { level } => {
+ handle.check_dep_opt(level)?;
+ }
+ crate::ImageQuery::NumLevels
+ | crate::ImageQuery::NumLayers
+ | crate::ImageQuery::NumSamples => (),
+ };
+ }
+ crate::Expression::Unary {
+ op: _,
+ expr: operand,
+ } => {
+ handle.check_dep(operand)?;
+ }
+ crate::Expression::Binary { op: _, left, right } => {
+ handle.check_dep(left)?.check_dep(right)?;
+ }
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ handle
+ .check_dep(condition)?
+ .check_dep(accept)?
+ .check_dep(reject)?;
+ }
+ crate::Expression::Derivative { expr: argument, .. } => {
+ handle.check_dep(argument)?;
+ }
+ crate::Expression::Relational { fun: _, argument } => {
+ handle.check_dep(argument)?;
+ }
+ crate::Expression::Math {
+ fun: _,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ handle
+ .check_dep(arg)?
+ .check_dep_opt(arg1)?
+ .check_dep_opt(arg2)?
+ .check_dep_opt(arg3)?;
+ }
+ crate::Expression::As {
+ expr: input,
+ kind: _,
+ convert: _,
+ } => {
+ handle.check_dep(input)?;
+ }
+ crate::Expression::CallResult(function) => {
+ Self::validate_function_handle(function, functions)?;
+ if let Some(handle) = current_function {
+ handle.check_dep(function)?;
+ }
+ }
+ crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult => (),
+ crate::Expression::ArrayLength(array) => {
+ handle.check_dep(array)?;
+ }
+ crate::Expression::RayQueryGetIntersection {
+ query,
+ committed: _,
+ } => {
+ handle.check_dep(query)?;
+ }
+ }
+ Ok(())
+ }
+
+ fn validate_block_handles(
+ block: &crate::Block,
+ expressions: &Arena<crate::Expression>,
+ functions: &Arena<crate::Function>,
+ ) -> Result<(), InvalidHandleError> {
+ let validate_block = |block| Self::validate_block_handles(block, expressions, functions);
+ let validate_expr = |handle| Self::validate_expression_handle(handle, expressions);
+ let validate_expr_opt = |handle_opt| {
+ if let Some(handle) = handle_opt {
+ validate_expr(handle)?;
+ }
+ Ok(())
+ };
+
+ block.iter().try_for_each(|stmt| match *stmt {
+ crate::Statement::Emit(ref expr_range) => {
+ expr_range.check_valid_for(expressions)?;
+ Ok(())
+ }
+ crate::Statement::Block(ref block) => {
+ validate_block(block)?;
+ Ok(())
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ validate_expr(condition)?;
+ validate_block(accept)?;
+ validate_block(reject)?;
+ Ok(())
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ validate_expr(selector)?;
+ for &crate::SwitchCase {
+ value: _,
+ ref body,
+ fall_through: _,
+ } in cases
+ {
+ validate_block(body)?;
+ }
+ Ok(())
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ validate_block(body)?;
+ validate_block(continuing)?;
+ validate_expr_opt(break_if)?;
+ Ok(())
+ }
+ crate::Statement::Return { value } => validate_expr_opt(value),
+ crate::Statement::Store { pointer, value } => {
+ validate_expr(pointer)?;
+ validate_expr(value)?;
+ Ok(())
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ validate_expr(image)?;
+ validate_expr(coordinate)?;
+ validate_expr_opt(array_index)?;
+ validate_expr(value)?;
+ Ok(())
+ }
+ crate::Statement::Atomic {
+ pointer,
+ fun,
+ value,
+ result,
+ } => {
+ validate_expr(pointer)?;
+ match fun {
+ crate::AtomicFunction::Add
+ | crate::AtomicFunction::Subtract
+ | crate::AtomicFunction::And
+ | crate::AtomicFunction::ExclusiveOr
+ | crate::AtomicFunction::InclusiveOr
+ | crate::AtomicFunction::Min
+ | crate::AtomicFunction::Max => (),
+ crate::AtomicFunction::Exchange { compare } => validate_expr_opt(compare)?,
+ };
+ validate_expr(value)?;
+ validate_expr(result)?;
+ Ok(())
+ }
+ crate::Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ Self::validate_function_handle(function, functions)?;
+ for arg in arguments.iter().copied() {
+ validate_expr(arg)?;
+ }
+ validate_expr_opt(result)?;
+ Ok(())
+ }
+ crate::Statement::RayQuery { query, ref fun } => {
+ validate_expr(query)?;
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ validate_expr(acceleration_structure)?;
+ validate_expr(descriptor)?;
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ validate_expr(result)?;
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ Ok(())
+ }
+ crate::Statement::Break
+ | crate::Statement::Continue
+ | crate::Statement::Kill
+ | crate::Statement::Barrier(_) => Ok(()),
+ })
+ }
+}
+
+#[cfg(feature = "validate")]
+impl From<BadHandle> for ValidationError {
+ fn from(source: BadHandle) -> Self {
+ Self::InvalidHandle(source.into())
+ }
+}
+
+#[cfg(feature = "validate")]
+impl From<FwdDepError> for ValidationError {
+ fn from(source: FwdDepError) -> Self {
+ Self::InvalidHandle(source.into())
+ }
+}
+
+#[cfg(feature = "validate")]
+impl From<BadRangeError> for ValidationError {
+ fn from(source: BadRangeError) -> Self {
+ Self::InvalidHandle(source.into())
+ }
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum InvalidHandleError {
+ #[error(transparent)]
+ BadHandle(#[from] BadHandle),
+ #[error(transparent)]
+ ForwardDependency(#[from] FwdDepError),
+ #[error(transparent)]
+ BadRange(#[from] BadRangeError),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[error(
+ "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
+ which has not been processed yet"
+)]
+pub struct FwdDepError {
+ // This error is used for many `Handle` types, but there's no point in making this generic, so
+ // we just flatten them all to `Handle<()>` here.
+ subject: Handle<()>,
+ subject_kind: &'static str,
+ depends_on: Handle<()>,
+ depends_on_kind: &'static str,
+}
+
+#[cfg(feature = "validate")]
+impl<T> Handle<T> {
+ /// Check that `self` is valid within `arena` using [`Arena::check_contains_handle`].
+ pub(self) fn check_valid_for(self, arena: &Arena<T>) -> Result<(), InvalidHandleError> {
+ arena.check_contains_handle(self)?;
+ Ok(())
+ }
+
+ /// Check that `self` is valid within `arena` using [`UniqueArena::check_contains_handle`].
+ pub(self) fn check_valid_for_uniq(
+ self,
+ arena: &UniqueArena<T>,
+ ) -> Result<(), InvalidHandleError>
+ where
+ T: Eq + Hash,
+ {
+ arena.check_contains_handle(self)?;
+ Ok(())
+ }
+
+ /// Check that `depends_on` was constructed before `self` by comparing handle indices.
+ ///
+ /// If `self` is a valid handle (i.e., it has been validated using [`Self::check_valid_for`])
+ /// and this function returns [`Ok`], then it may be assumed that `depends_on` is also valid.
+ /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating
+ /// recursive definitions of arena-based values in linear time.
+ ///
+ /// # Errors
+ ///
+ /// If `depends_on`'s handle is from the same [`Arena`] as `self'`s, but not constructed earlier
+ /// than `self`'s, this function returns an error.
+ pub(self) fn check_dep(self, depends_on: Self) -> Result<Self, FwdDepError> {
+ if depends_on < self {
+ Ok(self)
+ } else {
+ let erase_handle_type = |handle: Handle<_>| {
+ Handle::new(NonZeroU32::new((handle.index() + 1).try_into().unwrap()).unwrap())
+ };
+ Err(FwdDepError {
+ subject: erase_handle_type(self),
+ subject_kind: std::any::type_name::<T>(),
+ depends_on: erase_handle_type(depends_on),
+ depends_on_kind: std::any::type_name::<T>(),
+ })
+ }
+ }
+
+ /// Like [`Self::check_dep`], except for [`Option`]al handle values.
+ pub(self) fn check_dep_opt(self, depends_on: Option<Self>) -> Result<Self, FwdDepError> {
+ self.check_dep_iter(depends_on.into_iter())
+ }
+
+ /// Like [`Self::check_dep`], except for [`Iterator`]s over handle values.
+ pub(self) fn check_dep_iter(
+ self,
+ depends_on: impl Iterator<Item = Self>,
+ ) -> Result<Self, FwdDepError> {
+ for handle in depends_on {
+ self.check_dep(handle)?;
+ }
+ Ok(self)
+ }
+}
+
+#[cfg(feature = "validate")]
+impl<T> crate::arena::Range<T> {
+ pub(self) fn check_valid_for(&self, arena: &Arena<T>) -> Result<(), BadRangeError> {
+ arena.check_contains_range(self)
+ }
+}
diff --git a/third_party/rust/naga/src/valid/interface.rs b/third_party/rust/naga/src/valid/interface.rs
new file mode 100644
index 0000000000..268490c52e
--- /dev/null
+++ b/third_party/rust/naga/src/valid/interface.rs
@@ -0,0 +1,676 @@
+use super::{
+ analyzer::{FunctionInfo, GlobalUse},
+ Capabilities, Disalignment, FunctionError, ModuleInfo,
+};
+use crate::arena::{Handle, UniqueArena};
+
+use crate::span::{AddSpan as _, MapErrWithSpan as _, SpanProvider as _, WithSpan};
+use bit_set::BitSet;
+
+#[cfg(feature = "validate")]
+const MAX_WORKGROUP_SIZE: u32 = 0x4000;
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum GlobalVariableError {
+ #[error("Usage isn't compatible with address space {0:?}")]
+ InvalidUsage(crate::AddressSpace),
+ #[error("Type isn't compatible with address space {0:?}")]
+ InvalidType(crate::AddressSpace),
+ #[error("Type flags {seen:?} do not meet the required {required:?}")]
+ MissingTypeFlags {
+ required: super::TypeFlags,
+ seen: super::TypeFlags,
+ },
+ #[error("Capability {0:?} is not supported")]
+ UnsupportedCapability(Capabilities),
+ #[error("Binding decoration is missing or not applicable")]
+ InvalidBinding,
+ #[error("Alignment requirements for address space {0:?} are not met by {1:?}")]
+ Alignment(
+ crate::AddressSpace,
+ Handle<crate::Type>,
+ #[source] Disalignment,
+ ),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum VaryingError {
+ #[error("The type {0:?} does not match the varying")]
+ InvalidType(Handle<crate::Type>),
+ #[error("The type {0:?} cannot be used for user-defined entry point inputs or outputs")]
+ NotIOShareableType(Handle<crate::Type>),
+ #[error("Interpolation is not valid")]
+ InvalidInterpolation,
+ #[error("Interpolation must be specified on vertex shader outputs and fragment shader inputs")]
+ MissingInterpolation,
+ #[error("Built-in {0:?} is not available at this stage")]
+ InvalidBuiltInStage(crate::BuiltIn),
+ #[error("Built-in type for {0:?} is invalid")]
+ InvalidBuiltInType(crate::BuiltIn),
+ #[error("Entry point arguments and return values must all have bindings")]
+ MissingBinding,
+ #[error("Struct member {0} is missing a binding")]
+ MemberMissingBinding(u32),
+ #[error("Multiple bindings at location {location} are present")]
+ BindingCollision { location: u32 },
+ #[error("Built-in {0:?} is present more than once")]
+ DuplicateBuiltIn(crate::BuiltIn),
+ #[error("Capability {0:?} is not supported")]
+ UnsupportedCapability(Capabilities),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum EntryPointError {
+ #[error("Multiple conflicting entry points")]
+ Conflict,
+ #[error("Vertex shaders must return a `@builtin(position)` output value")]
+ MissingVertexOutputPosition,
+ #[error("Early depth test is not applicable")]
+ UnexpectedEarlyDepthTest,
+ #[error("Workgroup size is not applicable")]
+ UnexpectedWorkgroupSize,
+ #[error("Workgroup size is out of range")]
+ OutOfRangeWorkgroupSize,
+ #[error("Uses operations forbidden at this stage")]
+ ForbiddenStageOperations,
+ #[error("Global variable {0:?} is used incorrectly as {1:?}")]
+ InvalidGlobalUsage(Handle<crate::GlobalVariable>, GlobalUse),
+ #[error("Bindings for {0:?} conflict with other resource")]
+ BindingCollision(Handle<crate::GlobalVariable>),
+ #[error("Argument {0} varying error")]
+ Argument(u32, #[source] VaryingError),
+ #[error(transparent)]
+ Result(#[from] VaryingError),
+ #[error("Location {location} interpolation of an integer has to be flat")]
+ InvalidIntegerInterpolation { location: u32 },
+ #[error(transparent)]
+ Function(#[from] FunctionError),
+}
+
+#[cfg(feature = "validate")]
+fn storage_usage(access: crate::StorageAccess) -> GlobalUse {
+ let mut storage_usage = GlobalUse::QUERY;
+ if access.contains(crate::StorageAccess::LOAD) {
+ storage_usage |= GlobalUse::READ;
+ }
+ if access.contains(crate::StorageAccess::STORE) {
+ storage_usage |= GlobalUse::WRITE;
+ }
+ storage_usage
+}
+
+struct VaryingContext<'a> {
+ stage: crate::ShaderStage,
+ output: bool,
+ types: &'a UniqueArena<crate::Type>,
+ type_info: &'a Vec<super::r#type::TypeInfo>,
+ location_mask: &'a mut BitSet,
+ built_ins: &'a mut crate::FastHashSet<crate::BuiltIn>,
+ capabilities: Capabilities,
+
+ #[cfg(feature = "validate")]
+ flags: super::ValidationFlags,
+}
+
+impl VaryingContext<'_> {
+ fn validate_impl(
+ &mut self,
+ ty: Handle<crate::Type>,
+ binding: &crate::Binding,
+ ) -> Result<(), VaryingError> {
+ use crate::{
+ BuiltIn as Bi, ScalarKind as Sk, ShaderStage as St, TypeInner as Ti, VectorSize as Vs,
+ };
+
+ let ty_inner = &self.types[ty].inner;
+ match *binding {
+ crate::Binding::BuiltIn(built_in) => {
+ // Ignore the `invariant` field for the sake of duplicate checks,
+ // but use the original in error messages.
+ let canonical = if let crate::BuiltIn::Position { .. } = built_in {
+ crate::BuiltIn::Position { invariant: false }
+ } else {
+ built_in
+ };
+
+ if self.built_ins.contains(&canonical) {
+ return Err(VaryingError::DuplicateBuiltIn(built_in));
+ }
+ self.built_ins.insert(canonical);
+
+ let required = match built_in {
+ Bi::ClipDistance => Capabilities::CLIP_DISTANCE,
+ Bi::CullDistance => Capabilities::CULL_DISTANCE,
+ Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
+ Bi::ViewIndex => Capabilities::MULTIVIEW,
+ Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
+ _ => Capabilities::empty(),
+ };
+ if !self.capabilities.contains(required) {
+ return Err(VaryingError::UnsupportedCapability(required));
+ }
+
+ let width = 4;
+ let (visible, type_good) = match built_in {
+ Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
+ self.stage == St::Vertex && !self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ Bi::ClipDistance | Bi::CullDistance => (
+ self.stage == St::Vertex && self.output,
+ match *ty_inner {
+ Ti::Array { base, .. } => {
+ self.types[base].inner
+ == Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ }
+ }
+ _ => false,
+ },
+ ),
+ Bi::PointSize => (
+ self.stage == St::Vertex && self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ },
+ ),
+ Bi::PointCoord => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner
+ == Ti::Vector {
+ size: Vs::Bi,
+ kind: Sk::Float,
+ width,
+ },
+ ),
+ Bi::Position { .. } => (
+ match self.stage {
+ St::Vertex => self.output,
+ St::Fragment => !self.output,
+ St::Compute => false,
+ },
+ *ty_inner
+ == Ti::Vector {
+ size: Vs::Quad,
+ kind: Sk::Float,
+ width,
+ },
+ ),
+ Bi::ViewIndex => (
+ match self.stage {
+ St::Vertex | St::Fragment => !self.output,
+ St::Compute => false,
+ },
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Sint,
+ width,
+ },
+ ),
+ Bi::FragDepth => (
+ self.stage == St::Fragment && self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ },
+ ),
+ Bi::FrontFacing => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Bool,
+ width: crate::BOOL_WIDTH,
+ },
+ ),
+ Bi::PrimitiveIndex => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ Bi::SampleIndex => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ Bi::SampleMask => (
+ self.stage == St::Fragment,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ Bi::LocalInvocationIndex => (
+ self.stage == St::Compute && !self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ Bi::GlobalInvocationId
+ | Bi::LocalInvocationId
+ | Bi::WorkGroupId
+ | Bi::WorkGroupSize
+ | Bi::NumWorkGroups => (
+ self.stage == St::Compute && !self.output,
+ *ty_inner
+ == Ti::Vector {
+ size: Vs::Tri,
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ };
+
+ if !visible {
+ return Err(VaryingError::InvalidBuiltInStage(built_in));
+ }
+ if !type_good {
+ log::warn!("Wrong builtin type: {:?}", ty_inner);
+ return Err(VaryingError::InvalidBuiltInType(built_in));
+ }
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ } => {
+ // Only IO-shareable types may be stored in locations.
+ if !self.type_info[ty.index()]
+ .flags
+ .contains(super::TypeFlags::IO_SHAREABLE)
+ {
+ return Err(VaryingError::NotIOShareableType(ty));
+ }
+ if !self.location_mask.insert(location as usize) {
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(VaryingError::BindingCollision { location });
+ }
+ }
+
+ let needs_interpolation = match self.stage {
+ crate::ShaderStage::Vertex => self.output,
+ crate::ShaderStage::Fragment => !self.output,
+ crate::ShaderStage::Compute => false,
+ };
+
+ // It doesn't make sense to specify a sampling when `interpolation` is `Flat`, but
+ // SPIR-V and GLSL both explicitly tolerate such combinations of decorators /
+ // qualifiers, so we won't complain about that here.
+ let _ = sampling;
+
+ let required = match sampling {
+ Some(crate::Sampling::Sample) => Capabilities::MULTISAMPLED_SHADING,
+ _ => Capabilities::empty(),
+ };
+ if !self.capabilities.contains(required) {
+ return Err(VaryingError::UnsupportedCapability(required));
+ }
+
+ match ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Float) => {
+ if needs_interpolation && interpolation.is_none() {
+ return Err(VaryingError::MissingInterpolation);
+ }
+ }
+ Some(_) => {
+ if needs_interpolation && interpolation != Some(crate::Interpolation::Flat)
+ {
+ return Err(VaryingError::InvalidInterpolation);
+ }
+ }
+ None => return Err(VaryingError::InvalidType(ty)),
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ fn validate(
+ &mut self,
+ ty: Handle<crate::Type>,
+ binding: Option<&crate::Binding>,
+ ) -> Result<(), WithSpan<VaryingError>> {
+ let span_context = self.types.get_span_context(ty);
+ match binding {
+ Some(binding) => self
+ .validate_impl(ty, binding)
+ .map_err(|e| e.with_span_context(span_context)),
+ None => {
+ match self.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (index, member) in members.iter().enumerate() {
+ let span_context = self.types.get_span_context(ty);
+ match member.binding {
+ None => {
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(VaryingError::MemberMissingBinding(
+ index as u32,
+ )
+ .with_span_context(span_context));
+ }
+ #[cfg(not(feature = "validate"))]
+ let _ = index;
+ }
+ Some(ref binding) => self
+ .validate_impl(member.ty, binding)
+ .map_err(|e| e.with_span_context(span_context))?,
+ }
+ }
+ }
+ _ =>
+ {
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(VaryingError::MissingBinding.with_span());
+ }
+ }
+ }
+ Ok(())
+ }
+ }
+ }
+}
+
+impl super::Validator {
+ #[cfg(feature = "validate")]
+ pub(super) fn validate_global_var(
+ &self,
+ var: &crate::GlobalVariable,
+ types: &UniqueArena<crate::Type>,
+ ) -> Result<(), GlobalVariableError> {
+ use super::TypeFlags;
+
+ log::debug!("var {:?}", var);
+ let type_info = &self.types[var.ty.index()];
+
+ let (required_type_flags, is_resource) = match var.space {
+ crate::AddressSpace::Function => {
+ return Err(GlobalVariableError::InvalidUsage(var.space))
+ }
+ crate::AddressSpace::Storage { .. } => {
+ if let Err((ty_handle, disalignment)) = type_info.storage_layout {
+ if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
+ return Err(GlobalVariableError::Alignment(
+ var.space,
+ ty_handle,
+ disalignment,
+ ));
+ }
+ }
+ (TypeFlags::DATA | TypeFlags::HOST_SHAREABLE, true)
+ }
+ crate::AddressSpace::Uniform => {
+ if let Err((ty_handle, disalignment)) = type_info.uniform_layout {
+ if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
+ return Err(GlobalVariableError::Alignment(
+ var.space,
+ ty_handle,
+ disalignment,
+ ));
+ }
+ }
+ (
+ TypeFlags::DATA
+ | TypeFlags::COPY
+ | TypeFlags::SIZED
+ | TypeFlags::HOST_SHAREABLE,
+ true,
+ )
+ }
+ crate::AddressSpace::Handle => {
+ match types[var.ty].inner {
+ crate::TypeInner::Image { .. }
+ | crate::TypeInner::Sampler { .. }
+ | crate::TypeInner::BindingArray { .. }
+ | crate::TypeInner::AccelerationStructure
+ | crate::TypeInner::RayQuery => {}
+ _ => {
+ return Err(GlobalVariableError::InvalidType(var.space));
+ }
+ };
+ let inner_ty = match &types[var.ty].inner {
+ &crate::TypeInner::BindingArray { base, .. } => &types[base].inner,
+ ty => ty,
+ };
+ if let crate::TypeInner::Image {
+ class:
+ crate::ImageClass::Storage {
+ format:
+ crate::StorageFormat::R16Unorm
+ | crate::StorageFormat::R16Snorm
+ | crate::StorageFormat::Rg16Unorm
+ | crate::StorageFormat::Rg16Snorm
+ | crate::StorageFormat::Rgba16Unorm
+ | crate::StorageFormat::Rgba16Snorm,
+ ..
+ },
+ ..
+ } = *inner_ty
+ {
+ if !self
+ .capabilities
+ .contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS)
+ {
+ return Err(GlobalVariableError::UnsupportedCapability(
+ Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
+ ));
+ }
+ }
+
+ (TypeFlags::empty(), true)
+ }
+ crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {
+ (TypeFlags::DATA | TypeFlags::SIZED, false)
+ }
+ crate::AddressSpace::PushConstant => {
+ if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) {
+ return Err(GlobalVariableError::UnsupportedCapability(
+ Capabilities::PUSH_CONSTANT,
+ ));
+ }
+ (
+ TypeFlags::DATA
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::SIZED,
+ false,
+ )
+ }
+ };
+
+ if !type_info.flags.contains(required_type_flags) {
+ return Err(GlobalVariableError::MissingTypeFlags {
+ seen: type_info.flags,
+ required: required_type_flags,
+ });
+ }
+
+ if is_resource != var.binding.is_some() {
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(GlobalVariableError::InvalidBinding);
+ }
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn validate_entry_point(
+ &mut self,
+ ep: &crate::EntryPoint,
+ module: &crate::Module,
+ mod_info: &ModuleInfo,
+ ) -> Result<FunctionInfo, WithSpan<EntryPointError>> {
+ #[cfg(feature = "validate")]
+ if ep.early_depth_test.is_some() {
+ let required = Capabilities::EARLY_DEPTH_TEST;
+ if !self.capabilities.contains(required) {
+ return Err(
+ EntryPointError::Result(VaryingError::UnsupportedCapability(required))
+ .with_span(),
+ );
+ }
+
+ if ep.stage != crate::ShaderStage::Fragment {
+ return Err(EntryPointError::UnexpectedEarlyDepthTest.with_span());
+ }
+ }
+
+ #[cfg(feature = "validate")]
+ if ep.stage == crate::ShaderStage::Compute {
+ if ep
+ .workgroup_size
+ .iter()
+ .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE)
+ {
+ return Err(EntryPointError::OutOfRangeWorkgroupSize.with_span());
+ }
+ } else if ep.workgroup_size != [0; 3] {
+ return Err(EntryPointError::UnexpectedWorkgroupSize.with_span());
+ }
+
+ let info = self
+ .validate_function(&ep.function, module, mod_info, true)
+ .map_err(WithSpan::into_other)?;
+
+ #[cfg(feature = "validate")]
+ {
+ use super::ShaderStages;
+
+ let stage_bit = match ep.stage {
+ crate::ShaderStage::Vertex => ShaderStages::VERTEX,
+ crate::ShaderStage::Fragment => ShaderStages::FRAGMENT,
+ crate::ShaderStage::Compute => ShaderStages::COMPUTE,
+ };
+
+ if !info.available_stages.contains(stage_bit) {
+ return Err(EntryPointError::ForbiddenStageOperations.with_span());
+ }
+ }
+
+ self.location_mask.clear();
+ let mut argument_built_ins = crate::FastHashSet::default();
+ // TODO: add span info to function arguments
+ for (index, fa) in ep.function.arguments.iter().enumerate() {
+ let mut ctx = VaryingContext {
+ stage: ep.stage,
+ output: false,
+ types: &module.types,
+ type_info: &self.types,
+ location_mask: &mut self.location_mask,
+ built_ins: &mut argument_built_ins,
+ capabilities: self.capabilities,
+
+ #[cfg(feature = "validate")]
+ flags: self.flags,
+ };
+ ctx.validate(fa.ty, fa.binding.as_ref())
+ .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
+ }
+
+ self.location_mask.clear();
+ if let Some(ref fr) = ep.function.result {
+ let mut result_built_ins = crate::FastHashSet::default();
+ let mut ctx = VaryingContext {
+ stage: ep.stage,
+ output: true,
+ types: &module.types,
+ type_info: &self.types,
+ location_mask: &mut self.location_mask,
+ built_ins: &mut result_built_ins,
+ capabilities: self.capabilities,
+
+ #[cfg(feature = "validate")]
+ flags: self.flags,
+ };
+ ctx.validate(fr.ty, fr.binding.as_ref())
+ .map_err_inner(|e| EntryPointError::Result(e).with_span())?;
+
+ #[cfg(feature = "validate")]
+ if ep.stage == crate::ShaderStage::Vertex
+ && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false })
+ {
+ return Err(EntryPointError::MissingVertexOutputPosition.with_span());
+ }
+ } else if ep.stage == crate::ShaderStage::Vertex {
+ #[cfg(feature = "validate")]
+ return Err(EntryPointError::MissingVertexOutputPosition.with_span());
+ }
+
+ for bg in self.bind_group_masks.iter_mut() {
+ bg.clear();
+ }
+
+ #[cfg(feature = "validate")]
+ for (var_handle, var) in module.global_variables.iter() {
+ let usage = info[var_handle];
+ if usage.is_empty() {
+ continue;
+ }
+
+ let allowed_usage = match var.space {
+ crate::AddressSpace::Function => unreachable!(),
+ crate::AddressSpace::Uniform => GlobalUse::READ | GlobalUse::QUERY,
+ crate::AddressSpace::Storage { access } => storage_usage(access),
+ crate::AddressSpace::Handle => match module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => storage_usage(access),
+ _ => GlobalUse::READ | GlobalUse::QUERY,
+ },
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => storage_usage(access),
+ _ => GlobalUse::READ | GlobalUse::QUERY,
+ },
+ crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => GlobalUse::all(),
+ crate::AddressSpace::PushConstant => GlobalUse::READ,
+ };
+ if !allowed_usage.contains(usage) {
+ log::warn!("\tUsage error for: {:?}", var);
+ log::warn!(
+ "\tAllowed usage: {:?}, requested: {:?}",
+ allowed_usage,
+ usage
+ );
+ return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage)
+ .with_span_handle(var_handle, &module.global_variables));
+ }
+
+ if let Some(ref bind) = var.binding {
+ while self.bind_group_masks.len() <= bind.group as usize {
+ self.bind_group_masks.push(BitSet::new());
+ }
+ if !self.bind_group_masks[bind.group as usize].insert(bind.binding as usize) {
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(EntryPointError::BindingCollision(var_handle)
+ .with_span_handle(var_handle, &module.global_variables));
+ }
+ }
+ }
+ }
+
+ Ok(info)
+ }
+}
diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs
new file mode 100644
index 0000000000..6b3a2e1456
--- /dev/null
+++ b/third_party/rust/naga/src/valid/mod.rs
@@ -0,0 +1,467 @@
+/*!
+Shader validator.
+*/
+
+mod analyzer;
+mod compose;
+mod expression;
+mod function;
+mod handles;
+mod interface;
+mod r#type;
+
+#[cfg(feature = "validate")]
+use crate::arena::{Arena, UniqueArena};
+
+use crate::{
+ arena::Handle,
+ proc::{LayoutError, Layouter},
+ FastHashSet,
+};
+use bit_set::BitSet;
+use std::ops;
+
+//TODO: analyze the model at the same time as we validate it,
+// merge the corresponding matches over expressions and statements.
+
+use crate::span::{AddSpan as _, WithSpan};
+pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
+pub use compose::ComposeError;
+pub use expression::ExpressionError;
+pub use function::{CallError, FunctionError, LocalVariableError};
+pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
+pub use r#type::{Disalignment, TypeError, TypeFlags};
+
+use self::handles::InvalidHandleError;
+
+bitflags::bitflags! {
+ /// Validation flags.
+ ///
+ /// If you are working with trusted shaders, then you may be able
+ /// to save some time by skipping validation.
+ ///
+ /// If you do not perform full validation, invalid shaders may
+ /// cause Naga to panic. If you do perform full validation and
+ /// [`Validator::validate`] returns `Ok`, then Naga promises that
+ /// code generation will either succeed or return an error; it
+ /// should never panic.
+ ///
+ /// The default value for `ValidationFlags` is
+ /// `ValidationFlags::all()`. If Naga's `"validate"` feature is
+ /// enabled, this requests full validation; otherwise, this
+ /// requests no validation. (The `"validate"` feature is disabled
+ /// by default.)
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct ValidationFlags: u8 {
+ /// Expressions.
+ #[cfg(feature = "validate")]
+ const EXPRESSIONS = 0x1;
+ /// Statements and blocks of them.
+ #[cfg(feature = "validate")]
+ const BLOCKS = 0x2;
+ /// Uniformity of control flow for operations that require it.
+ #[cfg(feature = "validate")]
+ const CONTROL_FLOW_UNIFORMITY = 0x4;
+ /// Host-shareable structure layouts.
+ #[cfg(feature = "validate")]
+ const STRUCT_LAYOUTS = 0x8;
+ /// Constants.
+ #[cfg(feature = "validate")]
+ const CONSTANTS = 0x10;
+ /// Group, binding, and location attributes.
+ #[cfg(feature = "validate")]
+ const BINDINGS = 0x20;
+ }
+}
+
+impl Default for ValidationFlags {
+ fn default() -> Self {
+ Self::all()
+ }
+}
+
+bitflags::bitflags! {
+ /// Allowed IR capabilities.
+ #[must_use]
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct Capabilities: u16 {
+ /// Support for [`AddressSpace:PushConstant`].
+ const PUSH_CONSTANT = 0x1;
+ /// Float values with width = 8.
+ const FLOAT64 = 0x2;
+ /// Support for [`Builtin:PrimitiveIndex`].
+ const PRIMITIVE_INDEX = 0x4;
+ /// Support for non-uniform indexing of sampled textures and storage buffer arrays.
+ const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8;
+ /// Support for non-uniform indexing of uniform buffers and storage texture arrays.
+ const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10;
+ /// Support for non-uniform indexing of samplers.
+ const SAMPLER_NON_UNIFORM_INDEXING = 0x20;
+ /// Support for [`Builtin::ClipDistance`].
+ const CLIP_DISTANCE = 0x40;
+ /// Support for [`Builtin::CullDistance`].
+ const CULL_DISTANCE = 0x80;
+ /// Support for 16-bit normalized storage texture formats.
+ const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100;
+ /// Support for [`BuiltIn::ViewIndex`].
+ const MULTIVIEW = 0x200;
+ /// Support for `early_depth_test`.
+ const EARLY_DEPTH_TEST = 0x400;
+ /// Support for [`Builtin::SampleIndex`] and [`Sampling::Sample`].
+ const MULTISAMPLED_SHADING = 0x800;
+ /// Support for ray queries and acceleration structures.
+ const RAY_QUERY = 0x1000;
+ }
+}
+
+impl Default for Capabilities {
+ fn default() -> Self {
+ Self::MULTISAMPLED_SHADING
+ }
+}
+
+bitflags::bitflags! {
+ /// Validation flags.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct ShaderStages: u8 {
+ const VERTEX = 0x1;
+ const FRAGMENT = 0x2;
+ const COMPUTE = 0x4;
+ }
+}
+
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct ModuleInfo {
+ type_flags: Vec<TypeFlags>,
+ functions: Vec<FunctionInfo>,
+ entry_points: Vec<FunctionInfo>,
+}
+
+impl ops::Index<Handle<crate::Type>> for ModuleInfo {
+ type Output = TypeFlags;
+ fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
+ &self.type_flags[handle.index()]
+ }
+}
+
+impl ops::Index<Handle<crate::Function>> for ModuleInfo {
+ type Output = FunctionInfo;
+ fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
+ &self.functions[handle.index()]
+ }
+}
+
+#[derive(Debug)]
+pub struct Validator {
+ flags: ValidationFlags,
+ capabilities: Capabilities,
+ types: Vec<r#type::TypeInfo>,
+ layouter: Layouter,
+ location_mask: BitSet,
+ bind_group_masks: Vec<BitSet>,
+ #[allow(dead_code)]
+ switch_values: FastHashSet<crate::SwitchValue>,
+ valid_expression_list: Vec<Handle<crate::Expression>>,
+ valid_expression_set: BitSet,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum ConstantError {
+ #[error("The type doesn't match the constant")]
+ InvalidType,
+ #[error("The component handle {0:?} can not be resolved")]
+ UnresolvedComponent(Handle<crate::Constant>),
+ #[error("The array size handle {0:?} can not be resolved")]
+ UnresolvedSize(Handle<crate::Constant>),
+ #[error(transparent)]
+ Compose(#[from] ComposeError),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum ValidationError {
+ #[error(transparent)]
+ InvalidHandle(#[from] InvalidHandleError),
+ #[error(transparent)]
+ Layouter(#[from] LayoutError),
+ #[error("Type {handle:?} '{name}' is invalid")]
+ Type {
+ handle: Handle<crate::Type>,
+ name: String,
+ source: TypeError,
+ },
+ #[error("Constant {handle:?} '{name}' is invalid")]
+ Constant {
+ handle: Handle<crate::Constant>,
+ name: String,
+ source: ConstantError,
+ },
+ #[error("Global variable {handle:?} '{name}' is invalid")]
+ GlobalVariable {
+ handle: Handle<crate::GlobalVariable>,
+ name: String,
+ source: GlobalVariableError,
+ },
+ #[error("Function {handle:?} '{name}' is invalid")]
+ Function {
+ handle: Handle<crate::Function>,
+ name: String,
+ source: FunctionError,
+ },
+ #[error("Entry point {name} at {stage:?} is invalid")]
+ EntryPoint {
+ stage: crate::ShaderStage,
+ name: String,
+ source: EntryPointError,
+ },
+ #[error("Module is corrupted")]
+ Corrupted,
+}
+
+impl crate::TypeInner {
+ #[cfg(feature = "validate")]
+ const fn is_sized(&self) -> bool {
+ match *self {
+ Self::Scalar { .. }
+ | Self::Vector { .. }
+ | Self::Matrix { .. }
+ | Self::Array {
+ size: crate::ArraySize::Constant(_),
+ ..
+ }
+ | Self::Atomic { .. }
+ | Self::Pointer { .. }
+ | Self::ValuePointer { .. }
+ | Self::Struct { .. } => true,
+ Self::Array { .. }
+ | Self::Image { .. }
+ | Self::Sampler { .. }
+ | Self::AccelerationStructure
+ | Self::RayQuery
+ | Self::BindingArray { .. } => false,
+ }
+ }
+
+ /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
+ #[cfg(feature = "validate")]
+ const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
+ match *self {
+ Self::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ ..
+ } => Some(crate::ImageDimension::D1),
+ Self::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ ..
+ } => Some(crate::ImageDimension::D2),
+ Self::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ ..
+ } => Some(crate::ImageDimension::D3),
+ _ => None,
+ }
+ }
+}
+
+impl Validator {
+ /// Construct a new validator instance.
+ pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
+ Validator {
+ flags,
+ capabilities,
+ types: Vec::new(),
+ layouter: Layouter::default(),
+ location_mask: BitSet::new(),
+ bind_group_masks: Vec::new(),
+ switch_values: FastHashSet::default(),
+ valid_expression_list: Vec::new(),
+ valid_expression_set: BitSet::new(),
+ }
+ }
+
+ /// Reset the validator internals
+ pub fn reset(&mut self) {
+ self.types.clear();
+ self.layouter.clear();
+ self.location_mask.clear();
+ self.bind_group_masks.clear();
+ self.switch_values.clear();
+ self.valid_expression_list.clear();
+ self.valid_expression_set.clear();
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_constant(
+ &self,
+ handle: Handle<crate::Constant>,
+ constants: &Arena<crate::Constant>,
+ types: &UniqueArena<crate::Type>,
+ ) -> Result<(), ConstantError> {
+ let con = &constants[handle];
+ match con.inner {
+ crate::ConstantInner::Scalar { width, ref value } => {
+ if self.check_width(value.scalar_kind(), width).is_err() {
+ return Err(ConstantError::InvalidType);
+ }
+ }
+ crate::ConstantInner::Composite { ty, ref components } => {
+ match types[ty].inner {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(size_handle),
+ ..
+ } if handle <= size_handle => {
+ return Err(ConstantError::UnresolvedSize(size_handle));
+ }
+ _ => {}
+ }
+ if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) {
+ return Err(ConstantError::UnresolvedComponent(comp));
+ }
+ compose::validate_compose(
+ ty,
+ constants,
+ types,
+ components
+ .iter()
+ .map(|&component| constants[component].inner.resolve_type()),
+ )?;
+ }
+ }
+ Ok(())
+ }
+
+ /// Check the given module to be valid.
+ pub fn validate(
+ &mut self,
+ module: &crate::Module,
+ ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
+ self.reset();
+ self.reset_types(module.types.len());
+
+ #[cfg(feature = "validate")]
+ Self::validate_module_handles(module).map_err(|e| e.with_span())?;
+
+ self.layouter
+ .update(&module.types, &module.constants)
+ .map_err(|e| {
+ let handle = e.ty;
+ ValidationError::from(e).with_span_handle(handle, &module.types)
+ })?;
+
+ #[cfg(feature = "validate")]
+ if self.flags.contains(ValidationFlags::CONSTANTS) {
+ for (handle, constant) in module.constants.iter() {
+ self.validate_constant(handle, &module.constants, &module.types)
+ .map_err(|source| {
+ ValidationError::Constant {
+ handle,
+ name: constant.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.constants)
+ })?
+ }
+ }
+
+ let mut mod_info = ModuleInfo {
+ type_flags: Vec::with_capacity(module.types.len()),
+ functions: Vec::with_capacity(module.functions.len()),
+ entry_points: Vec::with_capacity(module.entry_points.len()),
+ };
+
+ for (handle, ty) in module.types.iter() {
+ let ty_info = self
+ .validate_type(handle, &module.types, &module.constants)
+ .map_err(|source| {
+ ValidationError::Type {
+ handle,
+ name: ty.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.types)
+ })?;
+ mod_info.type_flags.push(ty_info.flags);
+ self.types[handle.index()] = ty_info;
+ }
+
+ #[cfg(feature = "validate")]
+ for (var_handle, var) in module.global_variables.iter() {
+ self.validate_global_var(var, &module.types)
+ .map_err(|source| {
+ ValidationError::GlobalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(var_handle, &module.global_variables)
+ })?;
+ }
+
+ for (handle, fun) in module.functions.iter() {
+ match self.validate_function(fun, module, &mod_info, false) {
+ Ok(info) => mod_info.functions.push(info),
+ Err(error) => {
+ return Err(error.and_then(|source| {
+ ValidationError::Function {
+ handle,
+ name: fun.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.functions)
+ }))
+ }
+ }
+ }
+
+ let mut ep_map = FastHashSet::default();
+ for ep in module.entry_points.iter() {
+ if !ep_map.insert((ep.stage, &ep.name)) {
+ return Err(ValidationError::EntryPoint {
+ stage: ep.stage,
+ name: ep.name.clone(),
+ source: EntryPointError::Conflict,
+ }
+ .with_span()); // TODO: keep some EP span information?
+ }
+
+ match self.validate_entry_point(ep, module, &mod_info) {
+ Ok(info) => mod_info.entry_points.push(info),
+ Err(error) => {
+ return Err(error.and_then(|source| {
+ ValidationError::EntryPoint {
+ stage: ep.stage,
+ name: ep.name.clone(),
+ source,
+ }
+ .with_span()
+ }));
+ }
+ }
+ }
+
+ Ok(mod_info)
+ }
+}
+
+#[cfg(feature = "validate")]
+fn validate_atomic_compare_exchange_struct(
+ types: &UniqueArena<crate::Type>,
+ members: &[crate::StructMember],
+ scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
+) -> bool {
+ members.len() == 2
+ && members[0].name.as_deref() == Some("old_value")
+ && scalar_predicate(&types[members[0].ty].inner)
+ && members[1].name.as_deref() == Some("exchanged")
+ && types[members[1].ty].inner
+ == crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ }
+}
diff --git a/third_party/rust/naga/src/valid/type.rs b/third_party/rust/naga/src/valid/type.rs
new file mode 100644
index 0000000000..d8dd37d09b
--- /dev/null
+++ b/third_party/rust/naga/src/valid/type.rs
@@ -0,0 +1,636 @@
+use super::Capabilities;
+use crate::{
+ arena::{Arena, Handle, UniqueArena},
+ proc::Alignment,
+};
+
+bitflags::bitflags! {
+ /// Flags associated with [`Type`]s by [`Validator`].
+ ///
+ /// [`Type`]: crate::Type
+ /// [`Validator`]: crate::valid::Validator
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[repr(transparent)]
+ pub struct TypeFlags: u8 {
+ /// Can be used for data variables.
+ ///
+ /// This flag is required on types of local variables, function
+ /// arguments, array elements, and struct members.
+ ///
+ /// This includes all types except `Image`, `Sampler`,
+ /// and some `Pointer` types.
+ const DATA = 0x1;
+
+ /// The data type has a size known by pipeline creation time.
+ ///
+ /// Unsized types are quite restricted. The only unsized types permitted
+ /// by Naga, other than the non-[`DATA`] types like [`Image`] and
+ /// [`Sampler`], are dynamically-sized [`Array`s], and [`Struct`s] whose
+ /// last members are such arrays. See the documentation for those types
+ /// for details.
+ ///
+ /// [`DATA`]: TypeFlags::DATA
+ /// [`Image`]: crate::Type::Image
+ /// [`Sampler`]: crate::Type::Sampler
+ /// [`Array`]: crate::Type::Array
+ /// [`Struct`]: crate::Type::struct
+ const SIZED = 0x2;
+
+ /// The data can be copied around.
+ const COPY = 0x4;
+
+ /// Can be be used for user-defined IO between pipeline stages.
+ ///
+ /// This covers anything that can be in [`Location`] binding:
+ /// non-bool scalars and vectors, matrices, and structs and
+ /// arrays containing only interface types.
+ const IO_SHAREABLE = 0x8;
+
+ /// Can be used for host-shareable structures.
+ const HOST_SHAREABLE = 0x10;
+
+ /// This type can be passed as a function argument.
+ const ARGUMENT = 0x40;
+
+ /// A WGSL [constructible] type.
+ ///
+ /// The constructible types are scalars, vectors, matrices, fixed-size
+ /// arrays of constructible types, and structs whose members are all
+ /// constructible.
+ ///
+ /// [constructible]: https://gpuweb.github.io/gpuweb/wgsl/#constructible
+ const CONSTRUCTIBLE = 0x80;
+ }
+}
+
+#[derive(Clone, Copy, Debug, thiserror::Error)]
+pub enum Disalignment {
+ #[error("The array stride {stride} is not a multiple of the required alignment {alignment}")]
+ ArrayStride { stride: u32, alignment: Alignment },
+ #[error("The struct span {span}, is not a multiple of the required alignment {alignment}")]
+ StructSpan { span: u32, alignment: Alignment },
+ #[error("The struct member[{index}] offset {offset} is not a multiple of the required alignment {alignment}")]
+ MemberOffset {
+ index: u32,
+ offset: u32,
+ alignment: Alignment,
+ },
+ #[error("The struct member[{index}] offset {offset} must be at least {expected}")]
+ MemberOffsetAfterStruct {
+ index: u32,
+ offset: u32,
+ expected: u32,
+ },
+ #[error("The struct member[{index}] is not statically sized")]
+ UnsizedMember { index: u32 },
+ #[error("The type is not host-shareable")]
+ NonHostShareable,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum TypeError {
+ #[error("Capability {0:?} is required")]
+ MissingCapability(Capabilities),
+ #[error("The {0:?} scalar width {1} is not supported")]
+ InvalidWidth(crate::ScalarKind, crate::Bytes),
+ #[error("The {0:?} scalar width {1} is not supported for an atomic")]
+ InvalidAtomicWidth(crate::ScalarKind, crate::Bytes),
+ #[error("The base handle {0:?} can not be resolved")]
+ UnresolvedBase(Handle<crate::Type>),
+ #[error("Invalid type for pointer target {0:?}")]
+ InvalidPointerBase(Handle<crate::Type>),
+ #[error("Unsized types like {base:?} must be in the `Storage` address space, not `{space:?}`")]
+ InvalidPointerToUnsized {
+ base: Handle<crate::Type>,
+ space: crate::AddressSpace,
+ },
+ #[error("Expected data type, found {0:?}")]
+ InvalidData(Handle<crate::Type>),
+ #[error("Base type {0:?} for the array is invalid")]
+ InvalidArrayBaseType(Handle<crate::Type>),
+ #[error("The constant {0:?} can not be used for an array size")]
+ InvalidArraySizeConstant(Handle<crate::Constant>),
+ #[error("The constant {0:?} is specialized, and cannot be used as an array size")]
+ UnsupportedSpecializedArrayLength(Handle<crate::Constant>),
+ #[error("Array type {0:?} must have a length of one or more")]
+ NonPositiveArrayLength(Handle<crate::Constant>),
+ #[error("Array stride {stride} does not match the expected {expected}")]
+ InvalidArrayStride { stride: u32, expected: u32 },
+ #[error("Field '{0}' can't be dynamically-sized, has type {1:?}")]
+ InvalidDynamicArray(String, Handle<crate::Type>),
+ #[error("Structure member[{index}] at {offset} overlaps the previous member")]
+ MemberOverlap { index: u32, offset: u32 },
+ #[error(
+ "Structure member[{index}] at {offset} and size {size} crosses the structure boundary of size {span}"
+ )]
+ MemberOutOfBounds {
+ index: u32,
+ offset: u32,
+ size: u32,
+ span: u32,
+ },
+ #[error("Structure types must have at least one member")]
+ EmptyStruct,
+}
+
+// Only makes sense if `flags.contains(HOST_SHAREABLE)`
+type LayoutCompatibility = Result<Alignment, (Handle<crate::Type>, Disalignment)>;
+
+fn check_member_layout(
+ accum: &mut LayoutCompatibility,
+ member: &crate::StructMember,
+ member_index: u32,
+ member_layout: LayoutCompatibility,
+ parent_handle: Handle<crate::Type>,
+) {
+ *accum = match (*accum, member_layout) {
+ (Ok(cur_alignment), Ok(alignment)) => {
+ if alignment.is_aligned(member.offset) {
+ Ok(cur_alignment.max(alignment))
+ } else {
+ Err((
+ parent_handle,
+ Disalignment::MemberOffset {
+ index: member_index,
+ offset: member.offset,
+ alignment,
+ },
+ ))
+ }
+ }
+ (Err(e), _) | (_, Err(e)) => Err(e),
+ };
+}
+
+/// Determine whether a pointer in `space` can be passed as an argument.
+///
+/// If a pointer in `space` is permitted to be passed as an argument to a
+/// user-defined function, return `TypeFlags::ARGUMENT`. Otherwise, return
+/// `TypeFlags::empty()`.
+///
+/// Pointers passed as arguments to user-defined functions must be in the
+/// `Function`, `Private`, or `Workgroup` storage space.
+const fn ptr_space_argument_flag(space: crate::AddressSpace) -> TypeFlags {
+ use crate::AddressSpace as As;
+ match space {
+ As::Function | As::Private | As::WorkGroup => TypeFlags::ARGUMENT,
+ As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant => TypeFlags::empty(),
+ }
+}
+
+#[derive(Clone, Debug)]
+pub(super) struct TypeInfo {
+ pub flags: TypeFlags,
+ pub uniform_layout: LayoutCompatibility,
+ pub storage_layout: LayoutCompatibility,
+}
+
+impl TypeInfo {
+ const fn dummy() -> Self {
+ TypeInfo {
+ flags: TypeFlags::empty(),
+ uniform_layout: Ok(Alignment::ONE),
+ storage_layout: Ok(Alignment::ONE),
+ }
+ }
+
+ const fn new(flags: TypeFlags, alignment: Alignment) -> Self {
+ TypeInfo {
+ flags,
+ uniform_layout: Ok(alignment),
+ storage_layout: Ok(alignment),
+ }
+ }
+}
+
+impl super::Validator {
+ const fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> {
+ if self.capabilities.contains(capability) {
+ Ok(())
+ } else {
+ Err(TypeError::MissingCapability(capability))
+ }
+ }
+
+ pub(super) fn check_width(
+ &self,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ ) -> Result<(), TypeError> {
+ let good = match kind {
+ crate::ScalarKind::Bool => width == crate::BOOL_WIDTH,
+ crate::ScalarKind::Float => {
+ if width == 8 {
+ self.require_type_capability(Capabilities::FLOAT64)?;
+ true
+ } else {
+ width == 4
+ }
+ }
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4,
+ };
+ if good {
+ Ok(())
+ } else {
+ Err(TypeError::InvalidWidth(kind, width))
+ }
+ }
+
+ pub(super) fn reset_types(&mut self, size: usize) {
+ self.types.clear();
+ self.types.resize(size, TypeInfo::dummy());
+ self.layouter.clear();
+ }
+
+ pub(super) fn validate_type(
+ &self,
+ handle: Handle<crate::Type>,
+ types: &UniqueArena<crate::Type>,
+ constants: &Arena<crate::Constant>,
+ ) -> Result<TypeInfo, TypeError> {
+ use crate::TypeInner as Ti;
+ Ok(match types[handle].inner {
+ Ti::Scalar { kind, width } => {
+ self.check_width(kind, width)?;
+ let shareable = if kind.is_numeric() {
+ TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE
+ } else {
+ TypeFlags::empty()
+ };
+ TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE
+ | shareable,
+ Alignment::from_width(width),
+ )
+ }
+ Ti::Vector { size, kind, width } => {
+ self.check_width(kind, width)?;
+ let shareable = if kind.is_numeric() {
+ TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE
+ } else {
+ TypeFlags::empty()
+ };
+ TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE
+ | shareable,
+ Alignment::from(size) * Alignment::from_width(width),
+ )
+ }
+ Ti::Matrix {
+ columns: _,
+ rows,
+ width,
+ } => {
+ self.check_width(crate::ScalarKind::Float, width)?;
+ TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE,
+ Alignment::from(rows) * Alignment::from_width(width),
+ )
+ }
+ Ti::Atomic { kind, width } => {
+ let good = match kind {
+ crate::ScalarKind::Bool | crate::ScalarKind::Float => false,
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4,
+ };
+ if !good {
+ return Err(TypeError::InvalidAtomicWidth(kind, width));
+ }
+ TypeInfo::new(
+ TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE,
+ Alignment::from_width(width),
+ )
+ }
+ Ti::Pointer { base, space } => {
+ use crate::AddressSpace as As;
+
+ if base >= handle {
+ return Err(TypeError::UnresolvedBase(base));
+ }
+
+ let base_info = &self.types[base.index()];
+ if !base_info.flags.contains(TypeFlags::DATA) {
+ return Err(TypeError::InvalidPointerBase(base));
+ }
+
+ // Runtime-sized values can only live in the `Storage` storage
+ // space, so it's useless to have a pointer to such a type in
+ // any other space.
+ //
+ // Detecting this problem here prevents the definition of
+ // functions like:
+ //
+ // fn f(p: ptr<workgroup, UnsizedType>) -> ... { ... }
+ //
+ // which would otherwise be permitted, but uncallable. (They
+ // may also present difficulties in code generation).
+ if !base_info.flags.contains(TypeFlags::SIZED) {
+ match space {
+ As::Storage { .. } => {}
+ _ => {
+ return Err(TypeError::InvalidPointerToUnsized { base, space });
+ }
+ }
+ }
+
+ // `Validator::validate_function` actually checks the storage
+ // space of pointer arguments explicitly before checking the
+ // `ARGUMENT` flag, to give better error messages. But it seems
+ // best to set `ARGUMENT` accurately anyway.
+ let argument_flag = ptr_space_argument_flag(space);
+
+ // Pointers cannot be stored in variables, structure members, or
+ // array elements, so we do not mark them as `DATA`.
+ TypeInfo::new(
+ argument_flag | TypeFlags::SIZED | TypeFlags::COPY,
+ Alignment::ONE,
+ )
+ }
+ Ti::ValuePointer {
+ size: _,
+ kind,
+ width,
+ space,
+ } => {
+ // ValuePointer should be treated the same way as the equivalent
+ // Pointer / Scalar / Vector combination, so each step in those
+ // variants' match arms should have a counterpart here.
+ //
+ // However, some cases are trivial: All our implicit base types
+ // are DATA and SIZED, so we can never return
+ // `InvalidPointerBase` or `InvalidPointerToUnsized`.
+ self.check_width(kind, width)?;
+
+ // `Validator::validate_function` actually checks the storage
+ // space of pointer arguments explicitly before checking the
+ // `ARGUMENT` flag, to give better error messages. But it seems
+ // best to set `ARGUMENT` accurately anyway.
+ let argument_flag = ptr_space_argument_flag(space);
+
+ // Pointers cannot be stored in variables, structure members, or
+ // array elements, so we do not mark them as `DATA`.
+ TypeInfo::new(
+ argument_flag | TypeFlags::SIZED | TypeFlags::COPY,
+ Alignment::ONE,
+ )
+ }
+ Ti::Array { base, size, stride } => {
+ if base >= handle {
+ return Err(TypeError::UnresolvedBase(base));
+ }
+ let base_info = &self.types[base.index()];
+ if !base_info.flags.contains(TypeFlags::DATA | TypeFlags::SIZED) {
+ return Err(TypeError::InvalidArrayBaseType(base));
+ }
+
+ let base_layout = self.layouter[base];
+ let general_alignment = base_layout.alignment;
+ let uniform_layout = match base_info.uniform_layout {
+ Ok(base_alignment) => {
+ let alignment = base_alignment
+ .max(general_alignment)
+ .max(Alignment::MIN_UNIFORM);
+ if alignment.is_aligned(stride) {
+ Ok(alignment)
+ } else {
+ Err((handle, Disalignment::ArrayStride { stride, alignment }))
+ }
+ }
+ Err(e) => Err(e),
+ };
+ let storage_layout = match base_info.storage_layout {
+ Ok(base_alignment) => {
+ let alignment = base_alignment.max(general_alignment);
+ if alignment.is_aligned(stride) {
+ Ok(alignment)
+ } else {
+ Err((handle, Disalignment::ArrayStride { stride, alignment }))
+ }
+ }
+ Err(e) => Err(e),
+ };
+
+ let type_info_mask = match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let constant = &constants[const_handle];
+ let length_is_positive = match *constant {
+ crate::Constant {
+ specialization: Some(_),
+ ..
+ } => {
+ // Many of our back ends don't seem to support
+ // specializable array lengths. If you want to try to make
+ // this work, be sure to address all uses of
+ // `Constant::to_array_length`, which ignores
+ // specialization.
+ return Err(TypeError::UnsupportedSpecializedArrayLength(
+ const_handle,
+ ));
+ }
+ crate::Constant {
+ inner:
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Uint(length),
+ },
+ ..
+ } => length > 0,
+ // Accept a signed integer size to avoid
+ // requiring an explicit uint
+ // literal. Type inference should make
+ // this unnecessary.
+ crate::Constant {
+ inner:
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Sint(length),
+ },
+ ..
+ } => length > 0,
+ _ => {
+ log::warn!("Array size {:?}", constant);
+ return Err(TypeError::InvalidArraySizeConstant(const_handle));
+ }
+ };
+
+ if !length_is_positive {
+ return Err(TypeError::NonPositiveArrayLength(const_handle));
+ }
+
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE
+ }
+ crate::ArraySize::Dynamic => {
+ // Non-SIZED types may only appear as the last element of a structure.
+ // This is enforced by checks for SIZED-ness for all compound types,
+ // and a special case for structs.
+ TypeFlags::DATA | TypeFlags::COPY | TypeFlags::HOST_SHAREABLE
+ }
+ };
+
+ TypeInfo {
+ flags: base_info.flags & type_info_mask,
+ uniform_layout,
+ storage_layout,
+ }
+ }
+ Ti::Struct { ref members, span } => {
+ if members.is_empty() {
+ return Err(TypeError::EmptyStruct);
+ }
+
+ let mut ti = TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::IO_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE,
+ Alignment::ONE,
+ );
+ ti.uniform_layout = Ok(Alignment::MIN_UNIFORM);
+
+ let mut min_offset = 0;
+
+ let mut prev_struct_data: Option<(u32, u32)> = None;
+
+ for (i, member) in members.iter().enumerate() {
+ if member.ty >= handle {
+ return Err(TypeError::UnresolvedBase(member.ty));
+ }
+ let base_info = &self.types[member.ty.index()];
+ if !base_info.flags.contains(TypeFlags::DATA) {
+ return Err(TypeError::InvalidData(member.ty));
+ }
+ if !base_info.flags.contains(TypeFlags::HOST_SHAREABLE) {
+ if ti.uniform_layout.is_ok() {
+ ti.uniform_layout = Err((member.ty, Disalignment::NonHostShareable));
+ }
+ if ti.storage_layout.is_ok() {
+ ti.storage_layout = Err((member.ty, Disalignment::NonHostShareable));
+ }
+ }
+ ti.flags &= base_info.flags;
+
+ if member.offset < min_offset {
+ // HACK: this could be nicer. We want to allow some structures
+ // to not bother with offsets/alignments if they are never
+ // used for host sharing.
+ if member.offset == 0 {
+ ti.flags.set(TypeFlags::HOST_SHAREABLE, false);
+ } else {
+ return Err(TypeError::MemberOverlap {
+ index: i as u32,
+ offset: member.offset,
+ });
+ }
+ }
+
+ let base_size = types[member.ty].inner.size(constants);
+ min_offset = member.offset + base_size;
+ if min_offset > span {
+ return Err(TypeError::MemberOutOfBounds {
+ index: i as u32,
+ offset: member.offset,
+ size: base_size,
+ span,
+ });
+ }
+
+ check_member_layout(
+ &mut ti.uniform_layout,
+ member,
+ i as u32,
+ base_info.uniform_layout,
+ handle,
+ );
+ check_member_layout(
+ &mut ti.storage_layout,
+ member,
+ i as u32,
+ base_info.storage_layout,
+ handle,
+ );
+
+ // Validate rule: If a structure member itself has a structure type S,
+ // then the number of bytes between the start of that member and
+ // the start of any following member must be at least roundUp(16, SizeOf(S)).
+ if let Some((span, offset)) = prev_struct_data {
+ let diff = member.offset - offset;
+ let min = Alignment::MIN_UNIFORM.round_up(span);
+ if diff < min {
+ ti.uniform_layout = Err((
+ handle,
+ Disalignment::MemberOffsetAfterStruct {
+ index: i as u32,
+ offset: member.offset,
+ expected: offset + min,
+ },
+ ));
+ }
+ };
+
+ prev_struct_data = match types[member.ty].inner {
+ crate::TypeInner::Struct { span, .. } => Some((span, member.offset)),
+ _ => None,
+ };
+
+ // The last field may be an unsized array.
+ if !base_info.flags.contains(TypeFlags::SIZED) {
+ let is_array = match types[member.ty].inner {
+ crate::TypeInner::Array { .. } => true,
+ _ => false,
+ };
+ if !is_array || i + 1 != members.len() {
+ let name = member.name.clone().unwrap_or_default();
+ return Err(TypeError::InvalidDynamicArray(name, member.ty));
+ }
+ if ti.uniform_layout.is_ok() {
+ ti.uniform_layout =
+ Err((handle, Disalignment::UnsizedMember { index: i as u32 }));
+ }
+ }
+ }
+
+ let alignment = self.layouter[handle].alignment;
+ if !alignment.is_aligned(span) {
+ ti.uniform_layout = Err((handle, Disalignment::StructSpan { span, alignment }));
+ ti.storage_layout = Err((handle, Disalignment::StructSpan { span, alignment }));
+ }
+
+ ti
+ }
+ Ti::Image { .. } | Ti::Sampler { .. } => {
+ TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE)
+ }
+ Ti::AccelerationStructure => {
+ self.require_type_capability(Capabilities::RAY_QUERY)?;
+ TypeInfo::new(TypeFlags::empty(), Alignment::ONE)
+ }
+ Ti::RayQuery => {
+ self.require_type_capability(Capabilities::RAY_QUERY)?;
+ TypeInfo::new(TypeFlags::DATA | TypeFlags::SIZED, Alignment::ONE)
+ }
+ Ti::BindingArray { .. } => TypeInfo::new(TypeFlags::empty(), Alignment::ONE),
+ })
+ }
+}