summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-19 00:47:55 +0000
commit26a029d407be480d791972afb5975cf62c9360a6 (patch)
treef435a8308119effd964b339f76abb83a57c29483 /third_party/rust/naga/src
parentInitial commit. (diff)
downloadfirefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz
firefox-26a029d407be480d791972afb5975cf62c9360a6.zip
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r--third_party/rust/naga/src/arena.rs772
-rw-r--r--third_party/rust/naga/src/back/dot/mod.rs703
-rw-r--r--third_party/rust/naga/src/back/glsl/features.rs536
-rw-r--r--third_party/rust/naga/src/back/glsl/keywords.rs484
-rw-r--r--third_party/rust/naga/src/back/glsl/mod.rs4532
-rw-r--r--third_party/rust/naga/src/back/hlsl/conv.rs222
-rw-r--r--third_party/rust/naga/src/back/hlsl/help.rs1138
-rw-r--r--third_party/rust/naga/src/back/hlsl/keywords.rs904
-rw-r--r--third_party/rust/naga/src/back/hlsl/mod.rs302
-rw-r--r--third_party/rust/naga/src/back/hlsl/storage.rs494
-rw-r--r--third_party/rust/naga/src/back/hlsl/writer.rs3366
-rw-r--r--third_party/rust/naga/src/back/mod.rs273
-rw-r--r--third_party/rust/naga/src/back/msl/keywords.rs342
-rw-r--r--third_party/rust/naga/src/back/msl/mod.rs541
-rw-r--r--third_party/rust/naga/src/back/msl/sampler.rs176
-rw-r--r--third_party/rust/naga/src/back/msl/writer.rs4659
-rw-r--r--third_party/rust/naga/src/back/spv/block.rs2368
-rw-r--r--third_party/rust/naga/src/back/spv/helpers.rs109
-rw-r--r--third_party/rust/naga/src/back/spv/image.rs1210
-rw-r--r--third_party/rust/naga/src/back/spv/index.rs421
-rw-r--r--third_party/rust/naga/src/back/spv/instructions.rs1100
-rw-r--r--third_party/rust/naga/src/back/spv/layout.rs210
-rw-r--r--third_party/rust/naga/src/back/spv/mod.rs748
-rw-r--r--third_party/rust/naga/src/back/spv/ray.rs261
-rw-r--r--third_party/rust/naga/src/back/spv/recyclable.rs67
-rw-r--r--third_party/rust/naga/src/back/spv/selection.rs257
-rw-r--r--third_party/rust/naga/src/back/spv/writer.rs2063
-rw-r--r--third_party/rust/naga/src/back/wgsl/mod.rs52
-rw-r--r--third_party/rust/naga/src/back/wgsl/writer.rs1961
-rw-r--r--third_party/rust/naga/src/block.rs123
-rw-r--r--third_party/rust/naga/src/compact/expressions.rs389
-rw-r--r--third_party/rust/naga/src/compact/functions.rs106
-rw-r--r--third_party/rust/naga/src/compact/handle_set_map.rs170
-rw-r--r--third_party/rust/naga/src/compact/mod.rs307
-rw-r--r--third_party/rust/naga/src/compact/statements.rs300
-rw-r--r--third_party/rust/naga/src/compact/types.rs102
-rw-r--r--third_party/rust/naga/src/front/glsl/ast.rs394
-rw-r--r--third_party/rust/naga/src/front/glsl/builtins.rs2314
-rw-r--r--third_party/rust/naga/src/front/glsl/context.rs1506
-rw-r--r--third_party/rust/naga/src/front/glsl/error.rs191
-rw-r--r--third_party/rust/naga/src/front/glsl/functions.rs1602
-rw-r--r--third_party/rust/naga/src/front/glsl/lex.rs301
-rw-r--r--third_party/rust/naga/src/front/glsl/mod.rs232
-rw-r--r--third_party/rust/naga/src/front/glsl/offset.rs173
-rw-r--r--third_party/rust/naga/src/front/glsl/parser.rs431
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/declarations.rs677
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/expressions.rs542
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/functions.rs656
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/types.rs443
-rw-r--r--third_party/rust/naga/src/front/glsl/parser_tests.rs858
-rw-r--r--third_party/rust/naga/src/front/glsl/token.rs137
-rw-r--r--third_party/rust/naga/src/front/glsl/types.rs360
-rw-r--r--third_party/rust/naga/src/front/glsl/variables.rs646
-rw-r--r--third_party/rust/naga/src/front/interpolator.rs62
-rw-r--r--third_party/rust/naga/src/front/mod.rs328
-rw-r--r--third_party/rust/naga/src/front/spv/convert.rs179
-rw-r--r--third_party/rust/naga/src/front/spv/error.rs154
-rw-r--r--third_party/rust/naga/src/front/spv/function.rs674
-rw-r--r--third_party/rust/naga/src/front/spv/image.rs767
-rw-r--r--third_party/rust/naga/src/front/spv/mod.rs5356
-rw-r--r--third_party/rust/naga/src/front/spv/null.rs31
-rw-r--r--third_party/rust/naga/src/front/type_gen.rs437
-rw-r--r--third_party/rust/naga/src/front/wgsl/error.rs775
-rw-r--r--third_party/rust/naga/src/front/wgsl/index.rs193
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/construction.rs616
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/conversion.rs503
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/mod.rs2760
-rw-r--r--third_party/rust/naga/src/front/wgsl/mod.rs49
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/ast.rs491
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/conv.rs254
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/lexer.rs739
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/mod.rs2350
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/number.rs420
-rw-r--r--third_party/rust/naga/src/front/wgsl/tests.rs637
-rw-r--r--third_party/rust/naga/src/front/wgsl/to_wgsl.rs283
-rw-r--r--third_party/rust/naga/src/keywords/mod.rs6
-rw-r--r--third_party/rust/naga/src/keywords/wgsl.rs229
-rw-r--r--third_party/rust/naga/src/lib.rs2072
-rw-r--r--third_party/rust/naga/src/proc/constant_evaluator.rs2475
-rw-r--r--third_party/rust/naga/src/proc/emitter.rs39
-rw-r--r--third_party/rust/naga/src/proc/index.rs435
-rw-r--r--third_party/rust/naga/src/proc/layouter.rs251
-rw-r--r--third_party/rust/naga/src/proc/mod.rs809
-rw-r--r--third_party/rust/naga/src/proc/namer.rs281
-rw-r--r--third_party/rust/naga/src/proc/terminator.rs44
-rw-r--r--third_party/rust/naga/src/proc/typifier.rs893
-rw-r--r--third_party/rust/naga/src/span.rs501
-rw-r--r--third_party/rust/naga/src/valid/analyzer.rs1281
-rw-r--r--third_party/rust/naga/src/valid/compose.rs128
-rw-r--r--third_party/rust/naga/src/valid/expression.rs1797
-rw-r--r--third_party/rust/naga/src/valid/function.rs1056
-rw-r--r--third_party/rust/naga/src/valid/handles.rs699
-rw-r--r--third_party/rust/naga/src/valid/interface.rs709
-rw-r--r--third_party/rust/naga/src/valid/mod.rs477
-rw-r--r--third_party/rust/naga/src/valid/type.rs643
95 files changed, 76114 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/arena.rs b/third_party/rust/naga/src/arena.rs
new file mode 100644
index 0000000000..c37538667f
--- /dev/null
+++ b/third_party/rust/naga/src/arena.rs
@@ -0,0 +1,772 @@
+use std::{cmp::Ordering, fmt, hash, marker::PhantomData, num::NonZeroU32, ops};
+
+/// An unique index in the arena array that a handle points to.
+/// The "non-zero" part ensures that an `Option<Handle<T>>` has
+/// the same size and representation as `Handle<T>`.
+type Index = NonZeroU32;
+
+use crate::{FastIndexSet, Span};
+
+#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)]
+#[error("Handle {index} of {kind} is either not present, or inaccessible yet")]
+pub struct BadHandle {
+ pub kind: &'static str,
+ pub index: usize,
+}
+
+impl BadHandle {
+ fn new<T>(handle: Handle<T>) -> Self {
+ Self {
+ kind: std::any::type_name::<T>(),
+ index: handle.index(),
+ }
+ }
+}
+
+/// A strongly typed reference to an arena item.
+///
+/// A `Handle` value can be used as an index into an [`Arena`] or [`UniqueArena`].
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(
+ any(feature = "serialize", feature = "deserialize"),
+ serde(transparent)
+)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct Handle<T> {
+ index: Index,
+ #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))]
+ marker: PhantomData<T>,
+}
+
+impl<T> Clone for Handle<T> {
+ fn clone(&self) -> Self {
+ *self
+ }
+}
+
+impl<T> Copy for Handle<T> {}
+
+impl<T> PartialEq for Handle<T> {
+ fn eq(&self, other: &Self) -> bool {
+ self.index == other.index
+ }
+}
+
+impl<T> Eq for Handle<T> {}
+
+impl<T> PartialOrd for Handle<T> {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl<T> Ord for Handle<T> {
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.index.cmp(&other.index)
+ }
+}
+
+impl<T> fmt::Debug for Handle<T> {
+ fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ write!(formatter, "[{}]", self.index)
+ }
+}
+
+impl<T> hash::Hash for Handle<T> {
+ fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
+ self.index.hash(hasher)
+ }
+}
+
+impl<T> Handle<T> {
+ #[cfg(test)]
+ pub const DUMMY: Self = Handle {
+ index: unsafe { NonZeroU32::new_unchecked(u32::MAX) },
+ marker: PhantomData,
+ };
+
+ pub(crate) const fn new(index: Index) -> Self {
+ Handle {
+ index,
+ marker: PhantomData,
+ }
+ }
+
+ /// Returns the zero-based index of this handle.
+ pub const fn index(self) -> usize {
+ let index = self.index.get() - 1;
+ index as usize
+ }
+
+ /// Convert a `usize` index into a `Handle<T>`.
+ fn from_usize(index: usize) -> Self {
+ let handle_index = u32::try_from(index + 1)
+ .ok()
+ .and_then(Index::new)
+ .expect("Failed to insert into arena. Handle overflows");
+ Handle::new(handle_index)
+ }
+
+ /// Convert a `usize` index into a `Handle<T>`, without range checks.
+ const unsafe fn from_usize_unchecked(index: usize) -> Self {
+ Handle::new(Index::new_unchecked((index + 1) as u32))
+ }
+}
+
+/// A strongly typed range of handles.
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(
+ any(feature = "serialize", feature = "deserialize"),
+ serde(transparent)
+)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct Range<T> {
+ inner: ops::Range<u32>,
+ #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))]
+ marker: PhantomData<T>,
+}
+
+impl<T> Range<T> {
+ pub(crate) const fn erase_type(self) -> Range<()> {
+ let Self { inner, marker: _ } = self;
+ Range {
+ inner,
+ marker: PhantomData,
+ }
+ }
+}
+
+// NOTE: Keep this diagnostic in sync with that of [`BadHandle`].
+#[derive(Clone, Debug, thiserror::Error)]
+#[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")]
+pub struct BadRangeError {
+ // This error is used for many `Handle` types, but there's no point in making this generic, so
+ // we just flatten them all to `Handle<()>` here.
+ kind: &'static str,
+ range: Range<()>,
+}
+
+impl BadRangeError {
+ pub fn new<T>(range: Range<T>) -> Self {
+ Self {
+ kind: std::any::type_name::<T>(),
+ range: range.erase_type(),
+ }
+ }
+}
+
+impl<T> Clone for Range<T> {
+ fn clone(&self) -> Self {
+ Range {
+ inner: self.inner.clone(),
+ marker: self.marker,
+ }
+ }
+}
+
+impl<T> fmt::Debug for Range<T> {
+ fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ write!(formatter, "[{}..{}]", self.inner.start + 1, self.inner.end)
+ }
+}
+
+impl<T> Iterator for Range<T> {
+ type Item = Handle<T>;
+ fn next(&mut self) -> Option<Self::Item> {
+ if self.inner.start < self.inner.end {
+ self.inner.start += 1;
+ Some(Handle {
+ index: NonZeroU32::new(self.inner.start).unwrap(),
+ marker: self.marker,
+ })
+ } else {
+ None
+ }
+ }
+}
+
+impl<T> Range<T> {
+ /// Return a range enclosing handles `first` through `last`, inclusive.
+ pub fn new_from_bounds(first: Handle<T>, last: Handle<T>) -> Self {
+ Self {
+ inner: (first.index() as u32)..(last.index() as u32 + 1),
+ marker: Default::default(),
+ }
+ }
+
+ /// return the first and last handles included in `self`.
+ ///
+ /// If `self` is an empty range, there are no handles included, so
+ /// return `None`.
+ pub fn first_and_last(&self) -> Option<(Handle<T>, Handle<T>)> {
+ if self.inner.start < self.inner.end {
+ Some((
+ // `Range::new_from_bounds` expects a 1-based, start- and
+ // end-inclusive range, but `self.inner` is a zero-based,
+ // end-exclusive range.
+ Handle::new(Index::new(self.inner.start + 1).unwrap()),
+ Handle::new(Index::new(self.inner.end).unwrap()),
+ ))
+ } else {
+ None
+ }
+ }
+
+ /// Return the zero-based index range covered by `self`.
+ pub fn zero_based_index_range(&self) -> ops::Range<u32> {
+ self.inner.clone()
+ }
+
+ /// Construct a `Range` that covers the zero-based indices in `inner`.
+ pub fn from_zero_based_index_range(inner: ops::Range<u32>, arena: &Arena<T>) -> Self {
+ // Since `inner` is a `Range<u32>`, we only need to check that
+ // the start and end are well-ordered, and that the end fits
+ // within `arena`.
+ assert!(inner.start <= inner.end);
+ assert!(inner.end as usize <= arena.len());
+ Self {
+ inner,
+ marker: Default::default(),
+ }
+ }
+}
+
+/// An arena holding some kind of component (e.g., type, constant,
+/// instruction, etc.) that can be referenced.
+///
+/// Adding new items to the arena produces a strongly-typed [`Handle`].
+/// The arena can be indexed using the given handle to obtain
+/// a reference to the stored item.
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "serialize", serde(transparent))]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Arena<T> {
+ /// Values of this arena.
+ data: Vec<T>,
+ #[cfg_attr(feature = "serialize", serde(skip))]
+ span_info: Vec<Span>,
+}
+
+impl<T> Default for Arena<T> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<T: fmt::Debug> fmt::Debug for Arena<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.debug_map().entries(self.iter()).finish()
+ }
+}
+
+impl<T> Arena<T> {
+ /// Create a new arena with no initial capacity allocated.
+ pub const fn new() -> Self {
+ Arena {
+ data: Vec::new(),
+ span_info: Vec::new(),
+ }
+ }
+
+ /// Extracts the inner vector.
+ #[allow(clippy::missing_const_for_fn)] // ignore due to requirement of #![feature(const_precise_live_drops)]
+ pub fn into_inner(self) -> Vec<T> {
+ self.data
+ }
+
+ /// Returns the current number of items stored in this arena.
+ pub fn len(&self) -> usize {
+ self.data.len()
+ }
+
+ /// Returns `true` if the arena contains no elements.
+ pub fn is_empty(&self) -> bool {
+ self.data.is_empty()
+ }
+
+ /// Returns an iterator over the items stored in this arena, returning both
+ /// the item's handle and a reference to it.
+ pub fn iter(&self) -> impl DoubleEndedIterator<Item = (Handle<T>, &T)> {
+ self.data
+ .iter()
+ .enumerate()
+ .map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) })
+ }
+
+ /// Returns a iterator over the items stored in this arena,
+ /// returning both the item's handle and a mutable reference to it.
+ pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T)> {
+ self.data
+ .iter_mut()
+ .enumerate()
+ .map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) })
+ }
+
+ /// Adds a new value to the arena, returning a typed handle.
+ pub fn append(&mut self, value: T, span: Span) -> Handle<T> {
+ let index = self.data.len();
+ self.data.push(value);
+ self.span_info.push(span);
+ Handle::from_usize(index)
+ }
+
+ /// Fetch a handle to an existing type.
+ pub fn fetch_if<F: Fn(&T) -> bool>(&self, fun: F) -> Option<Handle<T>> {
+ self.data
+ .iter()
+ .position(fun)
+ .map(|index| unsafe { Handle::from_usize_unchecked(index) })
+ }
+
+ /// Adds a value with a custom check for uniqueness:
+ /// returns a handle pointing to
+ /// an existing element if the check succeeds, or adds a new
+ /// element otherwise.
+ pub fn fetch_if_or_append<F: Fn(&T, &T) -> bool>(
+ &mut self,
+ value: T,
+ span: Span,
+ fun: F,
+ ) -> Handle<T> {
+ if let Some(index) = self.data.iter().position(|d| fun(d, &value)) {
+ unsafe { Handle::from_usize_unchecked(index) }
+ } else {
+ self.append(value, span)
+ }
+ }
+
+ /// Adds a value with a check for uniqueness, where the check is plain comparison.
+ pub fn fetch_or_append(&mut self, value: T, span: Span) -> Handle<T>
+ where
+ T: PartialEq,
+ {
+ self.fetch_if_or_append(value, span, T::eq)
+ }
+
+ pub fn try_get(&self, handle: Handle<T>) -> Result<&T, BadHandle> {
+ self.data
+ .get(handle.index())
+ .ok_or_else(|| BadHandle::new(handle))
+ }
+
+ /// Get a mutable reference to an element in the arena.
+ pub fn get_mut(&mut self, handle: Handle<T>) -> &mut T {
+ self.data.get_mut(handle.index()).unwrap()
+ }
+
+ /// Get the range of handles from a particular number of elements to the end.
+ pub fn range_from(&self, old_length: usize) -> Range<T> {
+ Range {
+ inner: old_length as u32..self.data.len() as u32,
+ marker: PhantomData,
+ }
+ }
+
+ /// Clears the arena keeping all allocations
+ pub fn clear(&mut self) {
+ self.data.clear()
+ }
+
+ pub fn get_span(&self, handle: Handle<T>) -> Span {
+ *self
+ .span_info
+ .get(handle.index())
+ .unwrap_or(&Span::default())
+ }
+
+ /// Assert that `handle` is valid for this arena.
+ pub fn check_contains_handle(&self, handle: Handle<T>) -> Result<(), BadHandle> {
+ if handle.index() < self.data.len() {
+ Ok(())
+ } else {
+ Err(BadHandle::new(handle))
+ }
+ }
+
+ /// Assert that `range` is valid for this arena.
+ pub fn check_contains_range(&self, range: &Range<T>) -> Result<(), BadRangeError> {
+ // Since `range.inner` is a `Range<u32>`, we only need to check that the
+ // start precedes the end, and that the end is in range.
+ if range.inner.start > range.inner.end {
+ return Err(BadRangeError::new(range.clone()));
+ }
+
+ // Empty ranges are tolerated: they can be produced by compaction.
+ if range.inner.start == range.inner.end {
+ return Ok(());
+ }
+
+ // `range.inner` is zero-based, but end-exclusive, so `range.inner.end`
+ // is actually the right one-based index for the last handle within the
+ // range.
+ let last_handle = Handle::new(range.inner.end.try_into().unwrap());
+ if self.check_contains_handle(last_handle).is_err() {
+ return Err(BadRangeError::new(range.clone()));
+ }
+
+ Ok(())
+ }
+
+ #[cfg(feature = "compact")]
+ pub(crate) fn retain_mut<P>(&mut self, mut predicate: P)
+ where
+ P: FnMut(Handle<T>, &mut T) -> bool,
+ {
+ let mut index = 0;
+ let mut retained = 0;
+ self.data.retain_mut(|elt| {
+ let handle = Handle::new(Index::new(index as u32 + 1).unwrap());
+ let keep = predicate(handle, elt);
+
+ // Since `predicate` needs mutable access to each element,
+ // we can't feasibly call it twice, so we have to compact
+ // spans by hand in parallel as part of this iteration.
+ if keep {
+ self.span_info[retained] = self.span_info[index];
+ retained += 1;
+ }
+
+ index += 1;
+ keep
+ });
+
+ self.span_info.truncate(retained);
+ }
+}
+
+#[cfg(feature = "deserialize")]
+impl<'de, T> serde::Deserialize<'de> for Arena<T>
+where
+ T: serde::Deserialize<'de>,
+{
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ let data = Vec::deserialize(deserializer)?;
+ let span_info = std::iter::repeat(Span::default())
+ .take(data.len())
+ .collect();
+
+ Ok(Self { data, span_info })
+ }
+}
+
+impl<T> ops::Index<Handle<T>> for Arena<T> {
+ type Output = T;
+ fn index(&self, handle: Handle<T>) -> &T {
+ &self.data[handle.index()]
+ }
+}
+
+impl<T> ops::IndexMut<Handle<T>> for Arena<T> {
+ fn index_mut(&mut self, handle: Handle<T>) -> &mut T {
+ &mut self.data[handle.index()]
+ }
+}
+
+impl<T> ops::Index<Range<T>> for Arena<T> {
+ type Output = [T];
+ fn index(&self, range: Range<T>) -> &[T] {
+ &self.data[range.inner.start as usize..range.inner.end as usize]
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn append_non_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.append(0, Default::default());
+ let t2 = arena.append(0, Default::default());
+ assert!(t1 != t2);
+ assert!(arena[t1] == arena[t2]);
+ }
+
+ #[test]
+ fn append_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.append(0, Default::default());
+ let t2 = arena.append(1, Default::default());
+ assert!(t1 != t2);
+ assert!(arena[t1] != arena[t2]);
+ }
+
+ #[test]
+ fn fetch_or_append_non_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.fetch_or_append(0, Default::default());
+ let t2 = arena.fetch_or_append(0, Default::default());
+ assert!(t1 == t2);
+ assert!(arena[t1] == arena[t2])
+ }
+
+ #[test]
+ fn fetch_or_append_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.fetch_or_append(0, Default::default());
+ let t2 = arena.fetch_or_append(1, Default::default());
+ assert!(t1 != t2);
+ assert!(arena[t1] != arena[t2]);
+ }
+}
+
+/// An arena whose elements are guaranteed to be unique.
+///
+/// A `UniqueArena` holds a set of unique values of type `T`, each with an
+/// associated [`Span`]. Inserting a value returns a `Handle<T>`, which can be
+/// used to index the `UniqueArena` and obtain shared access to the `T` element.
+/// Access via a `Handle` is an array lookup - no hash lookup is necessary.
+///
+/// The element type must implement `Eq` and `Hash`. Insertions of equivalent
+/// elements, according to `Eq`, all return the same `Handle`.
+///
+/// Once inserted, elements may not be mutated.
+///
+/// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like,
+/// `UniqueArena` is `HashSet`-like.
+#[cfg_attr(feature = "clone", derive(Clone))]
+pub struct UniqueArena<T> {
+ set: FastIndexSet<T>,
+
+ /// Spans for the elements, indexed by handle.
+ ///
+ /// The length of this vector is always equal to `set.len()`. `FastIndexSet`
+ /// promises that its elements "are indexed in a compact range, without
+ /// holes in the range 0..set.len()", so we can always use the indices
+ /// returned by insertion as indices into this vector.
+ span_info: Vec<Span>,
+}
+
+impl<T> UniqueArena<T> {
+ /// Create a new arena with no initial capacity allocated.
+ pub fn new() -> Self {
+ UniqueArena {
+ set: FastIndexSet::default(),
+ span_info: Vec::new(),
+ }
+ }
+
+ /// Return the current number of items stored in this arena.
+ pub fn len(&self) -> usize {
+ self.set.len()
+ }
+
+ /// Return `true` if the arena contains no elements.
+ pub fn is_empty(&self) -> bool {
+ self.set.is_empty()
+ }
+
+ /// Clears the arena, keeping all allocations.
+ pub fn clear(&mut self) {
+ self.set.clear();
+ self.span_info.clear();
+ }
+
+ /// Return the span associated with `handle`.
+ ///
+ /// If a value has been inserted multiple times, the span returned is the
+ /// one provided with the first insertion.
+ pub fn get_span(&self, handle: Handle<T>) -> Span {
+ *self
+ .span_info
+ .get(handle.index())
+ .unwrap_or(&Span::default())
+ }
+
+ #[cfg(feature = "compact")]
+ pub(crate) fn drain_all(&mut self) -> UniqueArenaDrain<T> {
+ UniqueArenaDrain {
+ inner_elts: self.set.drain(..),
+ inner_spans: self.span_info.drain(..),
+ index: Index::new(1).unwrap(),
+ }
+ }
+}
+
+#[cfg(feature = "compact")]
+pub(crate) struct UniqueArenaDrain<'a, T> {
+ inner_elts: indexmap::set::Drain<'a, T>,
+ inner_spans: std::vec::Drain<'a, Span>,
+ index: Index,
+}
+
+#[cfg(feature = "compact")]
+impl<'a, T> Iterator for UniqueArenaDrain<'a, T> {
+ type Item = (Handle<T>, T, Span);
+
+ fn next(&mut self) -> Option<Self::Item> {
+ match self.inner_elts.next() {
+ Some(elt) => {
+ let handle = Handle::new(self.index);
+ self.index = self.index.checked_add(1).unwrap();
+ let span = self.inner_spans.next().unwrap();
+ Some((handle, elt, span))
+ }
+ None => None,
+ }
+ }
+}
+
+impl<T: Eq + hash::Hash> UniqueArena<T> {
+ /// Returns an iterator over the items stored in this arena, returning both
+ /// the item's handle and a reference to it.
+ pub fn iter(&self) -> impl DoubleEndedIterator<Item = (Handle<T>, &T)> {
+ self.set.iter().enumerate().map(|(i, v)| {
+ let position = i + 1;
+ let index = unsafe { Index::new_unchecked(position as u32) };
+ (Handle::new(index), v)
+ })
+ }
+
+ /// Insert a new value into the arena.
+ ///
+ /// Return a [`Handle<T>`], which can be used to index this arena to get a
+ /// shared reference to the element.
+ ///
+ /// If this arena already contains an element that is `Eq` to `value`,
+ /// return a `Handle` to the existing element, and drop `value`.
+ ///
+ /// If `value` is inserted into the arena, associate `span` with
+ /// it. An element's span can be retrieved with the [`get_span`]
+ /// method.
+ ///
+ /// [`Handle<T>`]: Handle
+ /// [`get_span`]: UniqueArena::get_span
+ pub fn insert(&mut self, value: T, span: Span) -> Handle<T> {
+ let (index, added) = self.set.insert_full(value);
+
+ if added {
+ debug_assert!(index == self.span_info.len());
+ self.span_info.push(span);
+ }
+
+ debug_assert!(self.set.len() == self.span_info.len());
+
+ Handle::from_usize(index)
+ }
+
+ /// Replace an old value with a new value.
+ ///
+ /// # Panics
+ ///
+ /// - if the old value is not in the arena
+ /// - if the new value already exists in the arena
+ pub fn replace(&mut self, old: Handle<T>, new: T) {
+ let (index, added) = self.set.insert_full(new);
+ assert!(added && index == self.set.len() - 1);
+
+ self.set.swap_remove_index(old.index()).unwrap();
+ }
+
+ /// Return this arena's handle for `value`, if present.
+ ///
+ /// If this arena already contains an element equal to `value`,
+ /// return its handle. Otherwise, return `None`.
+ pub fn get(&self, value: &T) -> Option<Handle<T>> {
+ self.set
+ .get_index_of(value)
+ .map(|index| unsafe { Handle::from_usize_unchecked(index) })
+ }
+
+ /// Return this arena's value at `handle`, if that is a valid handle.
+ pub fn get_handle(&self, handle: Handle<T>) -> Result<&T, BadHandle> {
+ self.set
+ .get_index(handle.index())
+ .ok_or_else(|| BadHandle::new(handle))
+ }
+
+ /// Assert that `handle` is valid for this arena.
+ pub fn check_contains_handle(&self, handle: Handle<T>) -> Result<(), BadHandle> {
+ if handle.index() < self.set.len() {
+ Ok(())
+ } else {
+ Err(BadHandle::new(handle))
+ }
+ }
+}
+
+impl<T> Default for UniqueArena<T> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<T: fmt::Debug + Eq + hash::Hash> fmt::Debug for UniqueArena<T> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.debug_map().entries(self.iter()).finish()
+ }
+}
+
+impl<T> ops::Index<Handle<T>> for UniqueArena<T> {
+ type Output = T;
+ fn index(&self, handle: Handle<T>) -> &T {
+ &self.set[handle.index()]
+ }
+}
+
+#[cfg(feature = "serialize")]
+impl<T> serde::Serialize for UniqueArena<T>
+where
+ T: Eq + hash::Hash + serde::Serialize,
+{
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: serde::Serializer,
+ {
+ self.set.serialize(serializer)
+ }
+}
+
+#[cfg(feature = "deserialize")]
+impl<'de, T> serde::Deserialize<'de> for UniqueArena<T>
+where
+ T: Eq + hash::Hash + serde::Deserialize<'de>,
+{
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ let set = FastIndexSet::deserialize(deserializer)?;
+ let span_info = std::iter::repeat(Span::default()).take(set.len()).collect();
+
+ Ok(Self { set, span_info })
+ }
+}
+
+//Note: largely borrowed from `HashSet` implementation
+#[cfg(feature = "arbitrary")]
+impl<'a, T> arbitrary::Arbitrary<'a> for UniqueArena<T>
+where
+ T: Eq + hash::Hash + arbitrary::Arbitrary<'a>,
+{
+ fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
+ let mut arena = Self::default();
+ for elem in u.arbitrary_iter()? {
+ arena.set.insert(elem?);
+ arena.span_info.push(Span::UNDEFINED);
+ }
+ Ok(arena)
+ }
+
+ fn arbitrary_take_rest(u: arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
+ let mut arena = Self::default();
+ for elem in u.arbitrary_take_rest_iter()? {
+ arena.set.insert(elem?);
+ arena.span_info.push(Span::UNDEFINED);
+ }
+ Ok(arena)
+ }
+
+ #[inline]
+ fn size_hint(depth: usize) -> (usize, Option<usize>) {
+ let depth_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
+ arbitrary::size_hint::and(depth_hint, (0, None))
+ }
+}
diff --git a/third_party/rust/naga/src/back/dot/mod.rs b/third_party/rust/naga/src/back/dot/mod.rs
new file mode 100644
index 0000000000..1556371df1
--- /dev/null
+++ b/third_party/rust/naga/src/back/dot/mod.rs
@@ -0,0 +1,703 @@
+/*!
+Backend for [DOT][dot] (Graphviz).
+
+This backend writes a graph in the DOT language, for the ease
+of IR inspection and debugging.
+
+[dot]: https://graphviz.org/doc/info/lang.html
+*/
+
+use crate::{
+ arena::Handle,
+ valid::{FunctionInfo, ModuleInfo},
+};
+
+use std::{
+ borrow::Cow,
+ fmt::{Error as FmtError, Write as _},
+};
+
+/// Configuration options for the dot backend
+#[derive(Clone, Default)]
+pub struct Options {
+ /// Only emit function bodies
+ pub cfg_only: bool,
+}
+
+/// Identifier used to address a graph node
+type NodeId = usize;
+
+/// Stores the target nodes for control flow statements
+#[derive(Default, Clone, Copy)]
+struct Targets {
+ /// The node, if some, where continue operations will land
+ continue_target: Option<usize>,
+ /// The node, if some, where break operations will land
+ break_target: Option<usize>,
+}
+
+/// Stores information about the graph of statements
+#[derive(Default)]
+struct StatementGraph {
+ /// List of node names
+ nodes: Vec<&'static str>,
+ /// List of edges of the control flow, the items are defined as
+ /// (from, to, label)
+ flow: Vec<(NodeId, NodeId, &'static str)>,
+ /// List of implicit edges of the control flow, used for jump
+ /// operations such as continue or break, the items are defined as
+ /// (from, to, label, color_id)
+ jumps: Vec<(NodeId, NodeId, &'static str, usize)>,
+ /// List of dependency relationships between a statement node and
+ /// expressions
+ dependencies: Vec<(NodeId, Handle<crate::Expression>, &'static str)>,
+ /// List of expression emitted by statement node
+ emits: Vec<(NodeId, Handle<crate::Expression>)>,
+ /// List of function call by statement node
+ calls: Vec<(NodeId, Handle<crate::Function>)>,
+}
+
+impl StatementGraph {
+ /// Adds a new block to the statement graph, returning the first and last node, respectively
+ fn add(&mut self, block: &[crate::Statement], targets: Targets) -> (NodeId, NodeId) {
+ use crate::Statement as S;
+
+ // The first node of the block isn't a statement but a virtual node
+ let root = self.nodes.len();
+ self.nodes.push(if root == 0 { "Root" } else { "Node" });
+ // Track the last placed node, this will be returned to the caller and
+ // will also be used to generate the control flow edges
+ let mut last_node = root;
+ for statement in block {
+ // Reserve a new node for the current statement and link it to the
+ // node of the previous statement
+ let id = self.nodes.len();
+ self.flow.push((last_node, id, ""));
+ self.nodes.push(""); // reserve space
+
+ // Track the node identifier for the merge node, the merge node is
+ // the last node of a statement, normally this is the node itself,
+ // but for control flow statements such as `if`s and `switch`s this
+ // is a virtual node where all branches merge back.
+ let mut merge_id = id;
+
+ self.nodes[id] = match *statement {
+ S::Emit(ref range) => {
+ for handle in range.clone() {
+ self.emits.push((id, handle));
+ }
+ "Emit"
+ }
+ S::Kill => "Kill", //TODO: link to the beginning
+ S::Break => {
+ // Try to link to the break target, otherwise produce
+ // a broken connection
+ if let Some(target) = targets.break_target {
+ self.jumps.push((id, target, "Break", 5))
+ } else {
+ self.jumps.push((id, root, "Broken", 7))
+ }
+ "Break"
+ }
+ S::Continue => {
+ // Try to link to the continue target, otherwise produce
+ // a broken connection
+ if let Some(target) = targets.continue_target {
+ self.jumps.push((id, target, "Continue", 5))
+ } else {
+ self.jumps.push((id, root, "Broken", 7))
+ }
+ "Continue"
+ }
+ S::Barrier(_flags) => "Barrier",
+ S::Block(ref b) => {
+ let (other, last) = self.add(b, targets);
+ self.flow.push((id, other, ""));
+ // All following nodes should connect to the end of the block
+ // statement so change the merge id to it.
+ merge_id = last;
+ "Block"
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ self.dependencies.push((id, condition, "condition"));
+ let (accept_id, accept_last) = self.add(accept, targets);
+ self.flow.push((id, accept_id, "accept"));
+ let (reject_id, reject_last) = self.add(reject, targets);
+ self.flow.push((id, reject_id, "reject"));
+
+ // Create a merge node, link the branches to it and set it
+ // as the merge node to make the next statement node link to it
+ merge_id = self.nodes.len();
+ self.nodes.push("Merge");
+ self.flow.push((accept_last, merge_id, ""));
+ self.flow.push((reject_last, merge_id, ""));
+
+ "If"
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ self.dependencies.push((id, selector, "selector"));
+
+ // Create a merge node and set it as the merge node to make
+ // the next statement node link to it
+ merge_id = self.nodes.len();
+ self.nodes.push("Merge");
+
+ // Create a new targets structure and set the break target
+ // to the merge node
+ let mut targets = targets;
+ targets.break_target = Some(merge_id);
+
+ for case in cases {
+ let (case_id, case_last) = self.add(&case.body, targets);
+ let label = match case.value {
+ crate::SwitchValue::Default => "default",
+ _ => "case",
+ };
+ self.flow.push((id, case_id, label));
+ // Link the last node of the branch to the merge node
+ self.flow.push((case_last, merge_id, ""));
+ }
+ "Switch"
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ // Create a new targets structure and set the break target
+ // to the merge node, this must happen before generating the
+ // continuing block since it can break.
+ let mut targets = targets;
+ targets.break_target = Some(id);
+
+ let (continuing_id, continuing_last) = self.add(continuing, targets);
+
+ // Set the the continue target to the beginning
+ // of the newly generated continuing block
+ targets.continue_target = Some(continuing_id);
+
+ let (body_id, body_last) = self.add(body, targets);
+
+ self.flow.push((id, body_id, "body"));
+
+ // Link the last node of the body to the continuing block
+ self.flow.push((body_last, continuing_id, "continuing"));
+ // Link the last node of the continuing block back to the
+ // beginning of the loop body
+ self.flow.push((continuing_last, body_id, "continuing"));
+
+ if let Some(expr) = break_if {
+ self.dependencies.push((continuing_id, expr, "break if"));
+ }
+
+ "Loop"
+ }
+ S::Return { value } => {
+ if let Some(expr) = value {
+ self.dependencies.push((id, expr, "value"));
+ }
+ "Return"
+ }
+ S::Store { pointer, value } => {
+ self.dependencies.push((id, value, "value"));
+ self.emits.push((id, pointer));
+ "Store"
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ self.dependencies.push((id, image, "image"));
+ self.dependencies.push((id, coordinate, "coordinate"));
+ if let Some(expr) = array_index {
+ self.dependencies.push((id, expr, "array_index"));
+ }
+ self.dependencies.push((id, value, "value"));
+ "ImageStore"
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ for &arg in arguments {
+ self.dependencies.push((id, arg, "arg"));
+ }
+ if let Some(expr) = result {
+ self.emits.push((id, expr));
+ }
+ self.calls.push((id, function));
+ "Call"
+ }
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ self.emits.push((id, result));
+ self.dependencies.push((id, pointer, "pointer"));
+ self.dependencies.push((id, value, "value"));
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ self.dependencies.push((id, cmp, "cmp"));
+ }
+ "Atomic"
+ }
+ S::WorkGroupUniformLoad { pointer, result } => {
+ self.emits.push((id, result));
+ self.dependencies.push((id, pointer, "pointer"));
+ "WorkGroupUniformLoad"
+ }
+ S::RayQuery { query, ref fun } => {
+ self.dependencies.push((id, query, "query"));
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ self.dependencies.push((
+ id,
+ acceleration_structure,
+ "acceleration_structure",
+ ));
+ self.dependencies.push((id, descriptor, "descriptor"));
+ "RayQueryInitialize"
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ self.emits.push((id, result));
+ "RayQueryProceed"
+ }
+ crate::RayQueryFunction::Terminate => "RayQueryTerminate",
+ }
+ }
+ };
+ // Set the last node to the merge node
+ last_node = merge_id;
+ }
+ (root, last_node)
+ }
+}
+
+#[allow(clippy::manual_unwrap_or)]
+fn name(option: &Option<String>) -> &str {
+ match *option {
+ Some(ref name) => name,
+ None => "",
+ }
+}
+
+/// set39 color scheme from <https://graphviz.org/doc/info/colors.html>
+const COLORS: &[&str] = &[
+ "white", // pattern starts at 1
+ "#8dd3c7", "#ffffb3", "#bebada", "#fb8072", "#80b1d3", "#fdb462", "#b3de69", "#fccde5",
+ "#d9d9d9",
+];
+
+fn write_fun(
+ output: &mut String,
+ prefix: String,
+ fun: &crate::Function,
+ info: Option<&FunctionInfo>,
+ options: &Options,
+) -> Result<(), FmtError> {
+ writeln!(output, "\t\tnode [ style=filled ]")?;
+
+ if !options.cfg_only {
+ for (handle, var) in fun.local_variables.iter() {
+ writeln!(
+ output,
+ "\t\t{}_l{} [ shape=hexagon label=\"{:?} '{}'\" ]",
+ prefix,
+ handle.index(),
+ handle,
+ name(&var.name),
+ )?;
+ }
+
+ write_function_expressions(output, &prefix, fun, info)?;
+ }
+
+ let mut sg = StatementGraph::default();
+ sg.add(&fun.body, Targets::default());
+ for (index, label) in sg.nodes.into_iter().enumerate() {
+ writeln!(
+ output,
+ "\t\t{prefix}_s{index} [ shape=square label=\"{label}\" ]",
+ )?;
+ }
+ for (from, to, label) in sg.flow {
+ writeln!(
+ output,
+ "\t\t{prefix}_s{from} -> {prefix}_s{to} [ arrowhead=tee label=\"{label}\" ]",
+ )?;
+ }
+ for (from, to, label, color_id) in sg.jumps {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> {}_s{} [ arrowhead=tee style=dashed color=\"{}\" label=\"{}\" ]",
+ prefix, from, prefix, to, COLORS[color_id], label,
+ )?;
+ }
+
+ if !options.cfg_only {
+ for (to, expr, label) in sg.dependencies {
+ writeln!(
+ output,
+ "\t\t{}_e{} -> {}_s{} [ label=\"{}\" ]",
+ prefix,
+ expr.index(),
+ prefix,
+ to,
+ label,
+ )?;
+ }
+ for (from, to) in sg.emits {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> {}_e{} [ style=dotted ]",
+ prefix,
+ from,
+ prefix,
+ to.index(),
+ )?;
+ }
+ }
+
+ for (from, function) in sg.calls {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> f{}_s0",
+ prefix,
+ from,
+ function.index(),
+ )?;
+ }
+
+ Ok(())
+}
+
+fn write_function_expressions(
+ output: &mut String,
+ prefix: &str,
+ fun: &crate::Function,
+ info: Option<&FunctionInfo>,
+) -> Result<(), FmtError> {
+ enum Payload<'a> {
+ Arguments(&'a [Handle<crate::Expression>]),
+ Local(Handle<crate::LocalVariable>),
+ Global(Handle<crate::GlobalVariable>),
+ }
+
+ let mut edges = crate::FastHashMap::<&str, _>::default();
+ let mut payload = None;
+ for (handle, expression) in fun.expressions.iter() {
+ use crate::Expression as E;
+ let (label, color_id) = match *expression {
+ E::Literal(_) => ("Literal".into(), 2),
+ E::Constant(_) => ("Constant".into(), 2),
+ E::ZeroValue(_) => ("ZeroValue".into(), 2),
+ E::Compose { ref components, .. } => {
+ payload = Some(Payload::Arguments(components));
+ ("Compose".into(), 3)
+ }
+ E::Access { base, index } => {
+ edges.insert("base", base);
+ edges.insert("index", index);
+ ("Access".into(), 1)
+ }
+ E::AccessIndex { base, index } => {
+ edges.insert("base", base);
+ (format!("AccessIndex[{index}]").into(), 1)
+ }
+ E::Splat { size, value } => {
+ edges.insert("value", value);
+ (format!("Splat{size:?}").into(), 3)
+ }
+ E::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ edges.insert("vector", vector);
+ (format!("Swizzle{:?}", &pattern[..size as usize]).into(), 3)
+ }
+ E::FunctionArgument(index) => (format!("Argument[{index}]").into(), 1),
+ E::GlobalVariable(h) => {
+ payload = Some(Payload::Global(h));
+ ("Global".into(), 2)
+ }
+ E::LocalVariable(h) => {
+ payload = Some(Payload::Local(h));
+ ("Local".into(), 1)
+ }
+ E::Load { pointer } => {
+ edges.insert("pointer", pointer);
+ ("Load".into(), 4)
+ }
+ E::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset: _,
+ level,
+ depth_ref,
+ } => {
+ edges.insert("image", image);
+ edges.insert("sampler", sampler);
+ edges.insert("coordinate", coordinate);
+ if let Some(expr) = array_index {
+ edges.insert("array_index", expr);
+ }
+ match level {
+ crate::SampleLevel::Auto => {}
+ crate::SampleLevel::Zero => {}
+ crate::SampleLevel::Exact(expr) => {
+ edges.insert("level", expr);
+ }
+ crate::SampleLevel::Bias(expr) => {
+ edges.insert("bias", expr);
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ edges.insert("grad_x", x);
+ edges.insert("grad_y", y);
+ }
+ }
+ if let Some(expr) = depth_ref {
+ edges.insert("depth_ref", expr);
+ }
+ let string = match gather {
+ Some(component) => Cow::Owned(format!("ImageGather{component:?}")),
+ _ => Cow::Borrowed("ImageSample"),
+ };
+ (string, 5)
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ edges.insert("image", image);
+ edges.insert("coordinate", coordinate);
+ if let Some(expr) = array_index {
+ edges.insert("array_index", expr);
+ }
+ if let Some(sample) = sample {
+ edges.insert("sample", sample);
+ }
+ if let Some(level) = level {
+ edges.insert("level", level);
+ }
+ ("ImageLoad".into(), 5)
+ }
+ E::ImageQuery { image, query } => {
+ edges.insert("image", image);
+ let args = match query {
+ crate::ImageQuery::Size { level } => {
+ if let Some(expr) = level {
+ edges.insert("level", expr);
+ }
+ Cow::from("ImageSize")
+ }
+ _ => Cow::Owned(format!("{query:?}")),
+ };
+ (args, 7)
+ }
+ E::Unary { op, expr } => {
+ edges.insert("expr", expr);
+ (format!("{op:?}").into(), 6)
+ }
+ E::Binary { op, left, right } => {
+ edges.insert("left", left);
+ edges.insert("right", right);
+ (format!("{op:?}").into(), 6)
+ }
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ edges.insert("condition", condition);
+ edges.insert("accept", accept);
+ edges.insert("reject", reject);
+ ("Select".into(), 3)
+ }
+ E::Derivative { axis, ctrl, expr } => {
+ edges.insert("", expr);
+ (format!("d{axis:?}{ctrl:?}").into(), 8)
+ }
+ E::Relational { fun, argument } => {
+ edges.insert("arg", argument);
+ (format!("{fun:?}").into(), 6)
+ }
+ E::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ edges.insert("arg", arg);
+ if let Some(expr) = arg1 {
+ edges.insert("arg1", expr);
+ }
+ if let Some(expr) = arg2 {
+ edges.insert("arg2", expr);
+ }
+ if let Some(expr) = arg3 {
+ edges.insert("arg3", expr);
+ }
+ (format!("{fun:?}").into(), 7)
+ }
+ E::As {
+ kind,
+ expr,
+ convert,
+ } => {
+ edges.insert("", expr);
+ let string = match convert {
+ Some(width) => format!("Convert<{kind:?},{width}>"),
+ None => format!("Bitcast<{kind:?}>"),
+ };
+ (string.into(), 3)
+ }
+ E::CallResult(_function) => ("CallResult".into(), 4),
+ E::AtomicResult { .. } => ("AtomicResult".into(), 4),
+ E::WorkGroupUniformLoadResult { .. } => ("WorkGroupUniformLoadResult".into(), 4),
+ E::ArrayLength(expr) => {
+ edges.insert("", expr);
+ ("ArrayLength".into(), 7)
+ }
+ E::RayQueryProceedResult => ("rayQueryProceedResult".into(), 4),
+ E::RayQueryGetIntersection { query, committed } => {
+ edges.insert("", query);
+ let ty = if committed { "Committed" } else { "Candidate" };
+ (format!("rayQueryGet{}Intersection", ty).into(), 4)
+ }
+ };
+
+ // give uniform expressions an outline
+ let color_attr = match info {
+ Some(info) if info[handle].uniformity.non_uniform_result.is_none() => "fillcolor",
+ _ => "color",
+ };
+ writeln!(
+ output,
+ "\t\t{}_e{} [ {}=\"{}\" label=\"{:?} {}\" ]",
+ prefix,
+ handle.index(),
+ color_attr,
+ COLORS[color_id],
+ handle,
+ label,
+ )?;
+
+ for (key, edge) in edges.drain() {
+ writeln!(
+ output,
+ "\t\t{}_e{} -> {}_e{} [ label=\"{}\" ]",
+ prefix,
+ edge.index(),
+ prefix,
+ handle.index(),
+ key,
+ )?;
+ }
+ match payload.take() {
+ Some(Payload::Arguments(list)) => {
+ write!(output, "\t\t{{")?;
+ for &comp in list {
+ write!(output, " {}_e{}", prefix, comp.index())?;
+ }
+ writeln!(output, " }} -> {}_e{}", prefix, handle.index())?;
+ }
+ Some(Payload::Local(h)) => {
+ writeln!(
+ output,
+ "\t\t{}_l{} -> {}_e{}",
+ prefix,
+ h.index(),
+ prefix,
+ handle.index(),
+ )?;
+ }
+ Some(Payload::Global(h)) => {
+ writeln!(
+ output,
+ "\t\tg{} -> {}_e{} [fillcolor=gray]",
+ h.index(),
+ prefix,
+ handle.index(),
+ )?;
+ }
+ None => {}
+ }
+ }
+
+ Ok(())
+}
+
+/// Write shader module to a [`String`].
+pub fn write(
+ module: &crate::Module,
+ mod_info: Option<&ModuleInfo>,
+ options: Options,
+) -> Result<String, FmtError> {
+ use std::fmt::Write as _;
+
+ let mut output = String::new();
+ output += "digraph Module {\n";
+
+ if !options.cfg_only {
+ writeln!(output, "\tsubgraph cluster_globals {{")?;
+ writeln!(output, "\t\tlabel=\"Globals\"")?;
+ for (handle, var) in module.global_variables.iter() {
+ writeln!(
+ output,
+ "\t\tg{} [ shape=hexagon label=\"{:?} {:?}/'{}'\" ]",
+ handle.index(),
+ handle,
+ var.space,
+ name(&var.name),
+ )?;
+ }
+ writeln!(output, "\t}}")?;
+ }
+
+ for (handle, fun) in module.functions.iter() {
+ let prefix = format!("f{}", handle.index());
+ writeln!(output, "\tsubgraph cluster_{prefix} {{")?;
+ writeln!(
+ output,
+ "\t\tlabel=\"Function{:?}/'{}'\"",
+ handle,
+ name(&fun.name)
+ )?;
+ let info = mod_info.map(|a| &a[handle]);
+ write_fun(&mut output, prefix, fun, info, &options)?;
+ writeln!(output, "\t}}")?;
+ }
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let prefix = format!("ep{ep_index}");
+ writeln!(output, "\tsubgraph cluster_{prefix} {{")?;
+ writeln!(output, "\t\tlabel=\"{:?}/'{}'\"", ep.stage, ep.name)?;
+ let info = mod_info.map(|a| a.get_entry_point(ep_index));
+ write_fun(&mut output, prefix, &ep.function, info, &options)?;
+ writeln!(output, "\t}}")?;
+ }
+
+ output += "}\n";
+ Ok(output)
+}
diff --git a/third_party/rust/naga/src/back/glsl/features.rs b/third_party/rust/naga/src/back/glsl/features.rs
new file mode 100644
index 0000000000..e7de05f695
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/features.rs
@@ -0,0 +1,536 @@
+use super::{BackendResult, Error, Version, Writer};
+use crate::{
+ back::glsl::{Options, WriterFlags},
+ AddressSpace, Binding, Expression, Handle, ImageClass, ImageDimension, Interpolation, Sampling,
+ Scalar, ScalarKind, ShaderStage, StorageFormat, Type, TypeInner,
+};
+use std::fmt::Write;
+
+bitflags::bitflags! {
+ /// Structure used to encode additions to GLSL that aren't supported by all versions.
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct Features: u32 {
+ /// Buffer address space support.
+ const BUFFER_STORAGE = 1;
+ const ARRAY_OF_ARRAYS = 1 << 1;
+ /// 8 byte floats.
+ const DOUBLE_TYPE = 1 << 2;
+ /// More image formats.
+ const FULL_IMAGE_FORMATS = 1 << 3;
+ const MULTISAMPLED_TEXTURES = 1 << 4;
+ const MULTISAMPLED_TEXTURE_ARRAYS = 1 << 5;
+ const CUBE_TEXTURES_ARRAY = 1 << 6;
+ const COMPUTE_SHADER = 1 << 7;
+ /// Image load and early depth tests.
+ const IMAGE_LOAD_STORE = 1 << 8;
+ const CONSERVATIVE_DEPTH = 1 << 9;
+ /// Interpolation and auxiliary qualifiers.
+ ///
+ /// Perspective, Flat, and Centroid are available in all GLSL versions we support.
+ const NOPERSPECTIVE_QUALIFIER = 1 << 11;
+ const SAMPLE_QUALIFIER = 1 << 12;
+ const CLIP_DISTANCE = 1 << 13;
+ const CULL_DISTANCE = 1 << 14;
+ /// Sample ID.
+ const SAMPLE_VARIABLES = 1 << 15;
+ /// Arrays with a dynamic length.
+ const DYNAMIC_ARRAY_SIZE = 1 << 16;
+ const MULTI_VIEW = 1 << 17;
+ /// Texture samples query
+ const TEXTURE_SAMPLES = 1 << 18;
+ /// Texture levels query
+ const TEXTURE_LEVELS = 1 << 19;
+ /// Image size query
+ const IMAGE_SIZE = 1 << 20;
+ /// Dual source blending
+ const DUAL_SOURCE_BLENDING = 1 << 21;
+ /// Instance index
+ ///
+ /// We can always support this, either through the language or a polyfill
+ const INSTANCE_INDEX = 1 << 22;
+ }
+}
+
+/// Helper structure used to store the required [`Features`] needed to output a
+/// [`Module`](crate::Module)
+///
+/// Provides helper methods to check for availability and writing required extensions
+pub struct FeaturesManager(Features);
+
+impl FeaturesManager {
+ /// Creates a new [`FeaturesManager`] instance
+ pub const fn new() -> Self {
+ Self(Features::empty())
+ }
+
+ /// Adds to the list of required [`Features`]
+ pub fn request(&mut self, features: Features) {
+ self.0 |= features
+ }
+
+ /// Checks if the list of features [`Features`] contains the specified [`Features`]
+ pub fn contains(&mut self, features: Features) -> bool {
+ self.0.contains(features)
+ }
+
+ /// Checks that all required [`Features`] are available for the specified
+ /// [`Version`] otherwise returns an [`Error::MissingFeatures`].
+ pub fn check_availability(&self, version: Version) -> BackendResult {
+ // Will store all the features that are unavailable
+ let mut missing = Features::empty();
+
+ // Helper macro to check for feature availability
+ macro_rules! check_feature {
+ // Used when only core glsl supports the feature
+ ($feature:ident, $core:literal) => {
+ if self.0.contains(Features::$feature)
+ && (version < Version::Desktop($core) || version.is_es())
+ {
+ missing |= Features::$feature;
+ }
+ };
+ // Used when both core and es support the feature
+ ($feature:ident, $core:literal, $es:literal) => {
+ if self.0.contains(Features::$feature)
+ && (version < Version::Desktop($core) || version < Version::new_gles($es))
+ {
+ missing |= Features::$feature;
+ }
+ };
+ }
+
+ check_feature!(COMPUTE_SHADER, 420, 310);
+ check_feature!(BUFFER_STORAGE, 400, 310);
+ check_feature!(DOUBLE_TYPE, 150);
+ check_feature!(CUBE_TEXTURES_ARRAY, 130, 310);
+ check_feature!(MULTISAMPLED_TEXTURES, 150, 300);
+ check_feature!(MULTISAMPLED_TEXTURE_ARRAYS, 150, 310);
+ check_feature!(ARRAY_OF_ARRAYS, 120, 310);
+ check_feature!(IMAGE_LOAD_STORE, 130, 310);
+ check_feature!(CONSERVATIVE_DEPTH, 130, 300);
+ check_feature!(NOPERSPECTIVE_QUALIFIER, 130);
+ check_feature!(SAMPLE_QUALIFIER, 400, 320);
+ check_feature!(CLIP_DISTANCE, 130, 300 /* with extension */);
+ check_feature!(CULL_DISTANCE, 450, 300 /* with extension */);
+ check_feature!(SAMPLE_VARIABLES, 400, 300);
+ check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
+ check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */);
+ match version {
+ Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
+ _ => check_feature!(MULTI_VIEW, 140, 310),
+ };
+ // Only available on glsl core, this means that opengl es can't query the number
+ // of samples nor levels in a image and neither do bound checks on the sample nor
+ // the level argument of texelFecth
+ check_feature!(TEXTURE_SAMPLES, 150);
+ check_feature!(TEXTURE_LEVELS, 130);
+ check_feature!(IMAGE_SIZE, 430, 310);
+
+ // Return an error if there are missing features
+ if missing.is_empty() {
+ Ok(())
+ } else {
+ Err(Error::MissingFeatures(missing))
+ }
+ }
+
+ /// Helper method used to write all needed extensions
+ ///
+ /// # Notes
+ /// This won't check for feature availability so it might output extensions that aren't even
+ /// supported.[`check_availability`](Self::check_availability) will check feature availability
+ pub fn write(&self, options: &Options, mut out: impl Write) -> BackendResult {
+ if self.0.contains(Features::COMPUTE_SHADER) && !options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_compute_shader.txt
+ writeln!(out, "#extension GL_ARB_compute_shader : require")?;
+ }
+
+ if self.0.contains(Features::BUFFER_STORAGE) && !options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_storage_buffer_object.txt
+ writeln!(
+ out,
+ "#extension GL_ARB_shader_storage_buffer_object : require"
+ )?;
+ }
+
+ if self.0.contains(Features::DOUBLE_TYPE) && options.version < Version::Desktop(400) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_gpu_shader_fp64.txt
+ writeln!(out, "#extension GL_ARB_gpu_shader_fp64 : require")?;
+ }
+
+ if self.0.contains(Features::CUBE_TEXTURES_ARRAY) {
+ if options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_cube_map_array.txt
+ writeln!(out, "#extension GL_EXT_texture_cube_map_array : require")?;
+ } else if options.version < Version::Desktop(400) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_cube_map_array.txt
+ writeln!(out, "#extension GL_ARB_texture_cube_map_array : require")?;
+ }
+ }
+
+ if self.0.contains(Features::MULTISAMPLED_TEXTURE_ARRAYS) && options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_texture_storage_multisample_2d_array.txt
+ writeln!(
+ out,
+ "#extension GL_OES_texture_storage_multisample_2d_array : require"
+ )?;
+ }
+
+ if self.0.contains(Features::ARRAY_OF_ARRAYS) && options.version < Version::Desktop(430) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_arrays_of_arrays.txt
+ writeln!(out, "#extension ARB_arrays_of_arrays : require")?;
+ }
+
+ if self.0.contains(Features::IMAGE_LOAD_STORE) {
+ if self.0.contains(Features::FULL_IMAGE_FORMATS) && options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/NV/NV_image_formats.txt
+ writeln!(out, "#extension GL_NV_image_formats : require")?;
+ }
+
+ if options.version < Version::Desktop(420) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_image_load_store.txt
+ writeln!(out, "#extension GL_ARB_shader_image_load_store : require")?;
+ }
+ }
+
+ if self.0.contains(Features::CONSERVATIVE_DEPTH) {
+ if options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_conservative_depth.txt
+ writeln!(out, "#extension GL_EXT_conservative_depth : require")?;
+ }
+
+ if options.version < Version::Desktop(420) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt
+ writeln!(out, "#extension GL_ARB_conservative_depth : require")?;
+ }
+ }
+
+ if (self.0.contains(Features::CLIP_DISTANCE) || self.0.contains(Features::CULL_DISTANCE))
+ && options.version.is_es()
+ {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_clip_cull_distance.txt
+ writeln!(out, "#extension GL_EXT_clip_cull_distance : require")?;
+ }
+
+ if self.0.contains(Features::SAMPLE_VARIABLES) && options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_sample_variables.txt
+ writeln!(out, "#extension GL_OES_sample_variables : require")?;
+ }
+
+ if self.0.contains(Features::MULTI_VIEW) {
+ if let Version::Embedded { is_webgl: true, .. } = options.version {
+ // https://www.khronos.org/registry/OpenGL/extensions/OVR/OVR_multiview2.txt
+ writeln!(out, "#extension GL_OVR_multiview2 : require")?;
+ } else {
+ // https://github.com/KhronosGroup/GLSL/blob/master/extensions/ext/GL_EXT_multiview.txt
+ writeln!(out, "#extension GL_EXT_multiview : require")?;
+ }
+ }
+
+ if self.0.contains(Features::TEXTURE_SAMPLES) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_texture_image_samples.txt
+ writeln!(
+ out,
+ "#extension GL_ARB_shader_texture_image_samples : require"
+ )?;
+ }
+
+ if self.0.contains(Features::TEXTURE_LEVELS) && options.version < Version::Desktop(430) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_query_levels.txt
+ writeln!(out, "#extension GL_ARB_texture_query_levels : require")?;
+ }
+ if self.0.contains(Features::DUAL_SOURCE_BLENDING) && options.version.is_es() {
+ // https://registry.khronos.org/OpenGL/extensions/EXT/EXT_blend_func_extended.txt
+ writeln!(out, "#extension GL_EXT_blend_func_extended : require")?;
+ }
+
+ if self.0.contains(Features::INSTANCE_INDEX) {
+ if options.writer_flags.contains(WriterFlags::DRAW_PARAMETERS) {
+ // https://registry.khronos.org/OpenGL/extensions/ARB/ARB_shader_draw_parameters.txt
+ writeln!(out, "#extension GL_ARB_shader_draw_parameters : require")?;
+ }
+ }
+
+ Ok(())
+ }
+}
+
+impl<'a, W> Writer<'a, W> {
+ /// Helper method that searches the module for all the needed [`Features`]
+ ///
+ /// # Errors
+ /// If the version doesn't support any of the needed [`Features`] a
+ /// [`Error::MissingFeatures`] will be returned
+ pub(super) fn collect_required_features(&mut self) -> BackendResult {
+ let ep_info = self.info.get_entry_point(self.entry_point_idx as usize);
+
+ if let Some(depth_test) = self.entry_point.early_depth_test {
+ // If IMAGE_LOAD_STORE is supported for this version of GLSL
+ if self.options.version.supports_early_depth_test() {
+ self.features.request(Features::IMAGE_LOAD_STORE);
+ }
+
+ if depth_test.conservative.is_some() {
+ self.features.request(Features::CONSERVATIVE_DEPTH);
+ }
+ }
+
+ for arg in self.entry_point.function.arguments.iter() {
+ self.varying_required_features(arg.binding.as_ref(), arg.ty);
+ }
+ if let Some(ref result) = self.entry_point.function.result {
+ self.varying_required_features(result.binding.as_ref(), result.ty);
+ }
+
+ if let ShaderStage::Compute = self.entry_point.stage {
+ self.features.request(Features::COMPUTE_SHADER)
+ }
+
+ if self.multiview.is_some() {
+ self.features.request(Features::MULTI_VIEW);
+ }
+
+ for (ty_handle, ty) in self.module.types.iter() {
+ match ty.inner {
+ TypeInner::Scalar(scalar)
+ | TypeInner::Vector { scalar, .. }
+ | TypeInner::Matrix { scalar, .. } => self.scalar_required_features(scalar),
+ TypeInner::Array { base, size, .. } => {
+ if let TypeInner::Array { .. } = self.module.types[base].inner {
+ self.features.request(Features::ARRAY_OF_ARRAYS)
+ }
+
+ // If the array is dynamically sized
+ if size == crate::ArraySize::Dynamic {
+ let mut is_used = false;
+
+ // Check if this type is used in a global that is needed by the current entrypoint
+ for (global_handle, global) in self.module.global_variables.iter() {
+ // Skip unused globals
+ if ep_info[global_handle].is_empty() {
+ continue;
+ }
+
+ // If this array is the type of a global, then this array is used
+ if global.ty == ty_handle {
+ is_used = true;
+ break;
+ }
+
+ // If the type of this global is a struct
+ if let crate::TypeInner::Struct { ref members, .. } =
+ self.module.types[global.ty].inner
+ {
+ // Check the last element of the struct to see if it's type uses
+ // this array
+ if let Some(last) = members.last() {
+ if last.ty == ty_handle {
+ is_used = true;
+ break;
+ }
+ }
+ }
+ }
+
+ // If this dynamically size array is used, we need dynamic array size support
+ if is_used {
+ self.features.request(Features::DYNAMIC_ARRAY_SIZE);
+ }
+ }
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ if arrayed && dim == ImageDimension::Cube {
+ self.features.request(Features::CUBE_TEXTURES_ARRAY)
+ }
+
+ match class {
+ ImageClass::Sampled { multi: true, .. }
+ | ImageClass::Depth { multi: true } => {
+ self.features.request(Features::MULTISAMPLED_TEXTURES);
+ if arrayed {
+ self.features.request(Features::MULTISAMPLED_TEXTURE_ARRAYS);
+ }
+ }
+ ImageClass::Storage { format, .. } => match format {
+ StorageFormat::R8Unorm
+ | StorageFormat::R8Snorm
+ | StorageFormat::R8Uint
+ | StorageFormat::R8Sint
+ | StorageFormat::R16Uint
+ | StorageFormat::R16Sint
+ | StorageFormat::R16Float
+ | StorageFormat::Rg8Unorm
+ | StorageFormat::Rg8Snorm
+ | StorageFormat::Rg8Uint
+ | StorageFormat::Rg8Sint
+ | StorageFormat::Rg16Uint
+ | StorageFormat::Rg16Sint
+ | StorageFormat::Rg16Float
+ | StorageFormat::Rgb10a2Uint
+ | StorageFormat::Rgb10a2Unorm
+ | StorageFormat::Rg11b10Float
+ | StorageFormat::Rg32Uint
+ | StorageFormat::Rg32Sint
+ | StorageFormat::Rg32Float => {
+ self.features.request(Features::FULL_IMAGE_FORMATS)
+ }
+ _ => {}
+ },
+ ImageClass::Sampled { multi: false, .. }
+ | ImageClass::Depth { multi: false } => {}
+ }
+ }
+ _ => {}
+ }
+ }
+
+ let mut push_constant_used = false;
+
+ for (handle, global) in self.module.global_variables.iter() {
+ if ep_info[handle].is_empty() {
+ continue;
+ }
+ match global.space {
+ AddressSpace::WorkGroup => self.features.request(Features::COMPUTE_SHADER),
+ AddressSpace::Storage { .. } => self.features.request(Features::BUFFER_STORAGE),
+ AddressSpace::PushConstant => {
+ if push_constant_used {
+ return Err(Error::MultiplePushConstants);
+ }
+ push_constant_used = true;
+ }
+ _ => {}
+ }
+ }
+
+ // We will need to pass some of the members to a closure, so we need
+ // to separate them otherwise the borrow checker will complain, this
+ // shouldn't be needed in rust 2021
+ let &mut Self {
+ module,
+ info,
+ ref mut features,
+ entry_point,
+ entry_point_idx,
+ ref policies,
+ ..
+ } = self;
+
+ // Loop trough all expressions in both functions and the entry point
+ // to check for needed features
+ for (expressions, info) in module
+ .functions
+ .iter()
+ .map(|(h, f)| (&f.expressions, &info[h]))
+ .chain(std::iter::once((
+ &entry_point.function.expressions,
+ info.get_entry_point(entry_point_idx as usize),
+ )))
+ {
+ for (_, expr) in expressions.iter() {
+ match *expr {
+ // Check for queries that need aditonal features
+ Expression::ImageQuery {
+ image,
+ query,
+ ..
+ } => match query {
+ // Storage images use `imageSize` which is only available
+ // in glsl > 420
+ //
+ // layers queries are also implemented as size queries
+ crate::ImageQuery::Size { .. } | crate::ImageQuery::NumLayers => {
+ if let TypeInner::Image {
+ class: crate::ImageClass::Storage { .. }, ..
+ } = *info[image].ty.inner_with(&module.types) {
+ features.request(Features::IMAGE_SIZE)
+ }
+ },
+ crate::ImageQuery::NumLevels => features.request(Features::TEXTURE_LEVELS),
+ crate::ImageQuery::NumSamples => features.request(Features::TEXTURE_SAMPLES),
+ }
+ ,
+ // Check for image loads that needs bound checking on the sample
+ // or level argument since this requires a feature
+ Expression::ImageLoad {
+ sample, level, ..
+ } => {
+ if policies.image_load != crate::proc::BoundsCheckPolicy::Unchecked {
+ if sample.is_some() {
+ features.request(Features::TEXTURE_SAMPLES)
+ }
+
+ if level.is_some() {
+ features.request(Features::TEXTURE_LEVELS)
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+
+ self.features.check_availability(self.options.version)
+ }
+
+ /// Helper method that checks the [`Features`] needed by a scalar
+ fn scalar_required_features(&mut self, scalar: Scalar) {
+ if scalar.kind == ScalarKind::Float && scalar.width == 8 {
+ self.features.request(Features::DOUBLE_TYPE);
+ }
+ }
+
+ fn varying_required_features(&mut self, binding: Option<&Binding>, ty: Handle<Type>) {
+ match self.module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for member in members {
+ self.varying_required_features(member.binding.as_ref(), member.ty);
+ }
+ }
+ _ => {
+ if let Some(binding) = binding {
+ match *binding {
+ Binding::BuiltIn(built_in) => match built_in {
+ crate::BuiltIn::ClipDistance => {
+ self.features.request(Features::CLIP_DISTANCE)
+ }
+ crate::BuiltIn::CullDistance => {
+ self.features.request(Features::CULL_DISTANCE)
+ }
+ crate::BuiltIn::SampleIndex => {
+ self.features.request(Features::SAMPLE_VARIABLES)
+ }
+ crate::BuiltIn::ViewIndex => {
+ self.features.request(Features::MULTI_VIEW)
+ }
+ crate::BuiltIn::InstanceIndex => {
+ self.features.request(Features::INSTANCE_INDEX)
+ }
+ _ => {}
+ },
+ Binding::Location {
+ location: _,
+ interpolation,
+ sampling,
+ second_blend_source,
+ } => {
+ if interpolation == Some(Interpolation::Linear) {
+ self.features.request(Features::NOPERSPECTIVE_QUALIFIER);
+ }
+ if sampling == Some(Sampling::Sample) {
+ self.features.request(Features::SAMPLE_QUALIFIER);
+ }
+ if second_blend_source {
+ self.features.request(Features::DUAL_SOURCE_BLENDING);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/glsl/keywords.rs b/third_party/rust/naga/src/back/glsl/keywords.rs
new file mode 100644
index 0000000000..857c935e68
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/keywords.rs
@@ -0,0 +1,484 @@
+pub const RESERVED_KEYWORDS: &[&str] = &[
+ //
+ // GLSL 4.6 keywords, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L2004-L2322
+ // GLSL ES 3.2 keywords, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/es/3.2/GLSL_ES_Specification_3.20.html#L2166-L2478
+ //
+ // Note: The GLSL ES 3.2 keywords are the same as GLSL 4.6 keywords with some residing in the reserved section.
+ // The only exception are the missing Vulkan keywords which I think is an oversight (see https://github.com/KhronosGroup/OpenGL-Registry/issues/585).
+ //
+ "const",
+ "uniform",
+ "buffer",
+ "shared",
+ "attribute",
+ "varying",
+ "coherent",
+ "volatile",
+ "restrict",
+ "readonly",
+ "writeonly",
+ "atomic_uint",
+ "layout",
+ "centroid",
+ "flat",
+ "smooth",
+ "noperspective",
+ "patch",
+ "sample",
+ "invariant",
+ "precise",
+ "break",
+ "continue",
+ "do",
+ "for",
+ "while",
+ "switch",
+ "case",
+ "default",
+ "if",
+ "else",
+ "subroutine",
+ "in",
+ "out",
+ "inout",
+ "int",
+ "void",
+ "bool",
+ "true",
+ "false",
+ "float",
+ "double",
+ "discard",
+ "return",
+ "vec2",
+ "vec3",
+ "vec4",
+ "ivec2",
+ "ivec3",
+ "ivec4",
+ "bvec2",
+ "bvec3",
+ "bvec4",
+ "uint",
+ "uvec2",
+ "uvec3",
+ "uvec4",
+ "dvec2",
+ "dvec3",
+ "dvec4",
+ "mat2",
+ "mat3",
+ "mat4",
+ "mat2x2",
+ "mat2x3",
+ "mat2x4",
+ "mat3x2",
+ "mat3x3",
+ "mat3x4",
+ "mat4x2",
+ "mat4x3",
+ "mat4x4",
+ "dmat2",
+ "dmat3",
+ "dmat4",
+ "dmat2x2",
+ "dmat2x3",
+ "dmat2x4",
+ "dmat3x2",
+ "dmat3x3",
+ "dmat3x4",
+ "dmat4x2",
+ "dmat4x3",
+ "dmat4x4",
+ "lowp",
+ "mediump",
+ "highp",
+ "precision",
+ "sampler1D",
+ "sampler1DShadow",
+ "sampler1DArray",
+ "sampler1DArrayShadow",
+ "isampler1D",
+ "isampler1DArray",
+ "usampler1D",
+ "usampler1DArray",
+ "sampler2D",
+ "sampler2DShadow",
+ "sampler2DArray",
+ "sampler2DArrayShadow",
+ "isampler2D",
+ "isampler2DArray",
+ "usampler2D",
+ "usampler2DArray",
+ "sampler2DRect",
+ "sampler2DRectShadow",
+ "isampler2DRect",
+ "usampler2DRect",
+ "sampler2DMS",
+ "isampler2DMS",
+ "usampler2DMS",
+ "sampler2DMSArray",
+ "isampler2DMSArray",
+ "usampler2DMSArray",
+ "sampler3D",
+ "isampler3D",
+ "usampler3D",
+ "samplerCube",
+ "samplerCubeShadow",
+ "isamplerCube",
+ "usamplerCube",
+ "samplerCubeArray",
+ "samplerCubeArrayShadow",
+ "isamplerCubeArray",
+ "usamplerCubeArray",
+ "samplerBuffer",
+ "isamplerBuffer",
+ "usamplerBuffer",
+ "image1D",
+ "iimage1D",
+ "uimage1D",
+ "image1DArray",
+ "iimage1DArray",
+ "uimage1DArray",
+ "image2D",
+ "iimage2D",
+ "uimage2D",
+ "image2DArray",
+ "iimage2DArray",
+ "uimage2DArray",
+ "image2DRect",
+ "iimage2DRect",
+ "uimage2DRect",
+ "image2DMS",
+ "iimage2DMS",
+ "uimage2DMS",
+ "image2DMSArray",
+ "iimage2DMSArray",
+ "uimage2DMSArray",
+ "image3D",
+ "iimage3D",
+ "uimage3D",
+ "imageCube",
+ "iimageCube",
+ "uimageCube",
+ "imageCubeArray",
+ "iimageCubeArray",
+ "uimageCubeArray",
+ "imageBuffer",
+ "iimageBuffer",
+ "uimageBuffer",
+ "struct",
+ // Vulkan keywords
+ "texture1D",
+ "texture1DArray",
+ "itexture1D",
+ "itexture1DArray",
+ "utexture1D",
+ "utexture1DArray",
+ "texture2D",
+ "texture2DArray",
+ "itexture2D",
+ "itexture2DArray",
+ "utexture2D",
+ "utexture2DArray",
+ "texture2DRect",
+ "itexture2DRect",
+ "utexture2DRect",
+ "texture2DMS",
+ "itexture2DMS",
+ "utexture2DMS",
+ "texture2DMSArray",
+ "itexture2DMSArray",
+ "utexture2DMSArray",
+ "texture3D",
+ "itexture3D",
+ "utexture3D",
+ "textureCube",
+ "itextureCube",
+ "utextureCube",
+ "textureCubeArray",
+ "itextureCubeArray",
+ "utextureCubeArray",
+ "textureBuffer",
+ "itextureBuffer",
+ "utextureBuffer",
+ "sampler",
+ "samplerShadow",
+ "subpassInput",
+ "isubpassInput",
+ "usubpassInput",
+ "subpassInputMS",
+ "isubpassInputMS",
+ "usubpassInputMS",
+ // Reserved keywords
+ "common",
+ "partition",
+ "active",
+ "asm",
+ "class",
+ "union",
+ "enum",
+ "typedef",
+ "template",
+ "this",
+ "resource",
+ "goto",
+ "inline",
+ "noinline",
+ "public",
+ "static",
+ "extern",
+ "external",
+ "interface",
+ "long",
+ "short",
+ "half",
+ "fixed",
+ "unsigned",
+ "superp",
+ "input",
+ "output",
+ "hvec2",
+ "hvec3",
+ "hvec4",
+ "fvec2",
+ "fvec3",
+ "fvec4",
+ "filter",
+ "sizeof",
+ "cast",
+ "namespace",
+ "using",
+ "sampler3DRect",
+ //
+ // GLSL 4.6 Built-In Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13314
+ //
+ // Angle and Trigonometry Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13469-L13561C5
+ "radians",
+ "degrees",
+ "sin",
+ "cos",
+ "tan",
+ "asin",
+ "acos",
+ "atan",
+ "sinh",
+ "cosh",
+ "tanh",
+ "asinh",
+ "acosh",
+ "atanh",
+ // Exponential Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13569-L13620
+ "pow",
+ "exp",
+ "log",
+ "exp2",
+ "log2",
+ "sqrt",
+ "inversesqrt",
+ // Common Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13628-L13908
+ "abs",
+ "sign",
+ "floor",
+ "trunc",
+ "round",
+ "roundEven",
+ "ceil",
+ "fract",
+ "mod",
+ "modf",
+ "min",
+ "max",
+ "clamp",
+ "mix",
+ "step",
+ "smoothstep",
+ "isnan",
+ "isinf",
+ "floatBitsToInt",
+ "floatBitsToUint",
+ "intBitsToFloat",
+ "uintBitsToFloat",
+ "fma",
+ "frexp",
+ "ldexp",
+ // Floating-Point Pack and Unpack Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13916-L14007
+ "packUnorm2x16",
+ "packSnorm2x16",
+ "packUnorm4x8",
+ "packSnorm4x8",
+ "unpackUnorm2x16",
+ "unpackSnorm2x16",
+ "unpackUnorm4x8",
+ "unpackSnorm4x8",
+ "packHalf2x16",
+ "unpackHalf2x16",
+ "packDouble2x32",
+ "unpackDouble2x32",
+ // Geometric Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14014-L14121
+ "length",
+ "distance",
+ "dot",
+ "cross",
+ "normalize",
+ "ftransform",
+ "faceforward",
+ "reflect",
+ "refract",
+ // Matrix Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14151-L14215
+ "matrixCompMult",
+ "outerProduct",
+ "transpose",
+ "determinant",
+ "inverse",
+ // Vector Relational Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14259-L14322
+ "lessThan",
+ "lessThanEqual",
+ "greaterThan",
+ "greaterThanEqual",
+ "equal",
+ "notEqual",
+ "any",
+ "all",
+ "not",
+ // Integer Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14335-L14432
+ "uaddCarry",
+ "usubBorrow",
+ "umulExtended",
+ "imulExtended",
+ "bitfieldExtract",
+ "bitfieldInsert",
+ "bitfieldReverse",
+ "bitCount",
+ "findLSB",
+ "findMSB",
+ // Texture Query Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14645-L14732
+ "textureSize",
+ "textureQueryLod",
+ "textureQueryLevels",
+ "textureSamples",
+ // Texel Lookup Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14736-L14997
+ "texture",
+ "textureProj",
+ "textureLod",
+ "textureOffset",
+ "texelFetch",
+ "texelFetchOffset",
+ "textureProjOffset",
+ "textureLodOffset",
+ "textureProjLod",
+ "textureProjLodOffset",
+ "textureGrad",
+ "textureGradOffset",
+ "textureProjGrad",
+ "textureProjGradOffset",
+ // Texture Gather Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15077-L15154
+ "textureGather",
+ "textureGatherOffset",
+ "textureGatherOffsets",
+ // Compatibility Profile Texture Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15161-L15220
+ "texture1D",
+ "texture1DProj",
+ "texture1DLod",
+ "texture1DProjLod",
+ "texture2D",
+ "texture2DProj",
+ "texture2DLod",
+ "texture2DProjLod",
+ "texture3D",
+ "texture3DProj",
+ "texture3DLod",
+ "texture3DProjLod",
+ "textureCube",
+ "textureCubeLod",
+ "shadow1D",
+ "shadow2D",
+ "shadow1DProj",
+ "shadow2DProj",
+ "shadow1DLod",
+ "shadow2DLod",
+ "shadow1DProjLod",
+ "shadow2DProjLod",
+ // Atomic Counter Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15241-L15531
+ "atomicCounterIncrement",
+ "atomicCounterDecrement",
+ "atomicCounter",
+ "atomicCounterAdd",
+ "atomicCounterSubtract",
+ "atomicCounterMin",
+ "atomicCounterMax",
+ "atomicCounterAnd",
+ "atomicCounterOr",
+ "atomicCounterXor",
+ "atomicCounterExchange",
+ "atomicCounterCompSwap",
+ // Atomic Memory Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15563-L15624
+ "atomicAdd",
+ "atomicMin",
+ "atomicMax",
+ "atomicAnd",
+ "atomicOr",
+ "atomicXor",
+ "atomicExchange",
+ "atomicCompSwap",
+ // Image Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15763-L15878
+ "imageSize",
+ "imageSamples",
+ "imageLoad",
+ "imageStore",
+ "imageAtomicAdd",
+ "imageAtomicMin",
+ "imageAtomicMax",
+ "imageAtomicAnd",
+ "imageAtomicOr",
+ "imageAtomicXor",
+ "imageAtomicExchange",
+ "imageAtomicCompSwap",
+ // Geometry Shader Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15886-L15932
+ "EmitStreamVertex",
+ "EndStreamPrimitive",
+ "EmitVertex",
+ "EndPrimitive",
+ // Fragment Processing Functions, Derivative Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16041-L16114
+ "dFdx",
+ "dFdy",
+ "dFdxFine",
+ "dFdyFine",
+ "dFdxCoarse",
+ "dFdyCoarse",
+ "fwidth",
+ "fwidthFine",
+ "fwidthCoarse",
+ // Fragment Processing Functions, Interpolation Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16150-L16198
+ "interpolateAtCentroid",
+ "interpolateAtSample",
+ "interpolateAtOffset",
+ // Noise Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16214-L16243
+ "noise1",
+ "noise2",
+ "noise3",
+ "noise4",
+ // Shader Invocation Control Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16255-L16276
+ "barrier",
+ // Shader Memory Control Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16336-L16382
+ "memoryBarrier",
+ "memoryBarrierAtomicCounter",
+ "memoryBarrierBuffer",
+ "memoryBarrierShared",
+ "memoryBarrierImage",
+ "groupMemoryBarrier",
+ // Subpass-Input Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16451-L16470
+ "subpassLoad",
+ // Shader Invocation Group Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16483-L16511
+ "anyInvocation",
+ "allInvocations",
+ "allInvocationsEqual",
+ //
+ // entry point name (should not be shadowed)
+ //
+ "main",
+ // Naga utilities:
+ super::MODF_FUNCTION,
+ super::FREXP_FUNCTION,
+ super::FIRST_INSTANCE_BINDING,
+];
diff --git a/third_party/rust/naga/src/back/glsl/mod.rs b/third_party/rust/naga/src/back/glsl/mod.rs
new file mode 100644
index 0000000000..e346d43257
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/mod.rs
@@ -0,0 +1,4532 @@
+/*!
+Backend for [GLSL][glsl] (OpenGL Shading Language).
+
+The main structure is [`Writer`], it maintains internal state that is used
+to output a [`Module`](crate::Module) into glsl
+
+# Supported versions
+### Core
+- 330
+- 400
+- 410
+- 420
+- 430
+- 450
+
+### ES
+- 300
+- 310
+
+[glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php
+*/
+
+// GLSL is mostly a superset of C but it also removes some parts of it this is a list of relevant
+// aspects for this backend.
+//
+// The most notable change is the introduction of the version preprocessor directive that must
+// always be the first line of a glsl file and is written as
+// `#version number profile`
+// `number` is the version itself (i.e. 300) and `profile` is the
+// shader profile we only support "core" and "es", the former is used in desktop applications and
+// the later is used in embedded contexts, mobile devices and browsers. Each one as it's own
+// versions (at the time of writing this the latest version for "core" is 460 and for "es" is 320)
+//
+// Other important preprocessor addition is the extension directive which is written as
+// `#extension name: behaviour`
+// Extensions provide increased features in a plugin fashion but they aren't required to be
+// supported hence why they are called extensions, that's why `behaviour` is used it specifies
+// whether the extension is strictly required or if it should only be enabled if needed. In our case
+// when we use extensions we set behaviour to `require` always.
+//
+// The only thing that glsl removes that makes a difference are pointers.
+//
+// Additions that are relevant for the backend are the discard keyword, the introduction of
+// vector, matrices, samplers, image types and functions that provide common shader operations
+
+pub use features::Features;
+
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, ShaderStage, TypeInner,
+};
+use features::FeaturesManager;
+use std::{
+ cmp::Ordering,
+ fmt,
+ fmt::{Error as FmtError, Write},
+ mem,
+};
+use thiserror::Error;
+
+/// Contains the features related code and the features querying method
+mod features;
+/// Contains a constant with a slice of all the reserved keywords RESERVED_KEYWORDS
+mod keywords;
+
+/// List of supported `core` GLSL versions.
+pub const SUPPORTED_CORE_VERSIONS: &[u16] = &[140, 150, 330, 400, 410, 420, 430, 440, 450, 460];
+/// List of supported `es` GLSL versions.
+pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320];
+
+/// The suffix of the variable that will hold the calculated clamped level
+/// of detail for bounds checking in `ImageLoad`
+const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod";
+
+pub(crate) const MODF_FUNCTION: &str = "naga_modf";
+pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
+
+// Must match code in glsl_built_in
+pub const FIRST_INSTANCE_BINDING: &str = "naga_vs_first_instance";
+
+/// Mapping between resources and bindings.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, u8>;
+
+impl crate::AtomicFunction {
+ const fn to_glsl(self) -> &'static str {
+ match self {
+ Self::Add | Self::Subtract => "Add",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { compare: Some(_) } => "", //TODO
+ }
+ }
+}
+
+impl crate::AddressSpace {
+ const fn is_buffer(&self) -> bool {
+ match *self {
+ crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => true,
+ _ => false,
+ }
+ }
+
+ /// Whether a variable with this address space can be initialized
+ const fn initializable(&self) -> bool {
+ match *self {
+ crate::AddressSpace::Function | crate::AddressSpace::Private => true,
+ crate::AddressSpace::WorkGroup
+ | crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::Handle
+ | crate::AddressSpace::PushConstant => false,
+ }
+ }
+}
+
+/// A GLSL version.
+#[derive(Debug, Copy, Clone, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum Version {
+ /// `core` GLSL.
+ Desktop(u16),
+ /// `es` GLSL.
+ Embedded { version: u16, is_webgl: bool },
+}
+
+impl Version {
+ /// Create a new gles version
+ pub const fn new_gles(version: u16) -> Self {
+ Self::Embedded {
+ version,
+ is_webgl: false,
+ }
+ }
+
+ /// Returns true if self is `Version::Embedded` (i.e. is a es version)
+ const fn is_es(&self) -> bool {
+ match *self {
+ Version::Desktop(_) => false,
+ Version::Embedded { .. } => true,
+ }
+ }
+
+ /// Returns true if targeting WebGL
+ const fn is_webgl(&self) -> bool {
+ match *self {
+ Version::Desktop(_) => false,
+ Version::Embedded { is_webgl, .. } => is_webgl,
+ }
+ }
+
+ /// Checks the list of currently supported versions and returns true if it contains the
+ /// specified version
+ ///
+ /// # Notes
+ /// As an invalid version number will never be added to the supported version list
+ /// so this also checks for version validity
+ fn is_supported(&self) -> bool {
+ match *self {
+ Version::Desktop(v) => SUPPORTED_CORE_VERSIONS.contains(&v),
+ Version::Embedded { version: v, .. } => SUPPORTED_ES_VERSIONS.contains(&v),
+ }
+ }
+
+ fn supports_io_locations(&self) -> bool {
+ *self >= Version::Desktop(330) || *self >= Version::new_gles(300)
+ }
+
+ /// Checks if the version supports all of the explicit layouts:
+ /// - `location=` qualifiers for bindings
+ /// - `binding=` qualifiers for resources
+ ///
+ /// Note: `location=` for vertex inputs and fragment outputs is supported
+ /// unconditionally for GLES 300.
+ fn supports_explicit_locations(&self) -> bool {
+ *self >= Version::Desktop(410) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_early_depth_test(&self) -> bool {
+ *self >= Version::Desktop(130) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_std430_layout(&self) -> bool {
+ *self >= Version::Desktop(430) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_fma_function(&self) -> bool {
+ *self >= Version::Desktop(400) || *self >= Version::new_gles(320)
+ }
+
+ fn supports_integer_functions(&self) -> bool {
+ *self >= Version::Desktop(400) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_frexp_function(&self) -> bool {
+ *self >= Version::Desktop(400) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_derivative_control(&self) -> bool {
+ *self >= Version::Desktop(450)
+ }
+}
+
+impl PartialOrd for Version {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ match (*self, *other) {
+ (Version::Desktop(x), Version::Desktop(y)) => Some(x.cmp(&y)),
+ (Version::Embedded { version: x, .. }, Version::Embedded { version: y, .. }) => {
+ Some(x.cmp(&y))
+ }
+ _ => None,
+ }
+ }
+}
+
+impl fmt::Display for Version {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match *self {
+ Version::Desktop(v) => write!(f, "{v} core"),
+ Version::Embedded { version: v, .. } => write!(f, "{v} es"),
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Configuration flags for the [`Writer`].
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct WriterFlags: u32 {
+ /// Flip output Y and extend Z from (0, 1) to (-1, 1).
+ const ADJUST_COORDINATE_SPACE = 0x1;
+ /// Supports GL_EXT_texture_shadow_lod on the host, which provides
+ /// additional functions on shadows and arrays of shadows.
+ const TEXTURE_SHADOW_LOD = 0x2;
+ /// Supports ARB_shader_draw_parameters on the host, which provides
+ /// support for `gl_BaseInstanceARB`, `gl_BaseVertexARB`, and `gl_DrawIDARB`.
+ const DRAW_PARAMETERS = 0x4;
+ /// Include unused global variables, constants and functions. By default the output will exclude
+ /// global variables that are not used in the specified entrypoint (including indirect use),
+ /// all constant declarations, and functions that use excluded global variables.
+ const INCLUDE_UNUSED_ITEMS = 0x10;
+ /// Emit `PointSize` output builtin to vertex shaders, which is
+ /// required for drawing with `PointList` topology.
+ ///
+ /// https://registry.khronos.org/OpenGL/specs/es/3.2/GLSL_ES_Specification_3.20.html#built-in-language-variables
+ /// The variable gl_PointSize is intended for a shader to write the size of the point to be rasterized. It is measured in pixels.
+ /// If gl_PointSize is not written to, its value is undefined in subsequent pipe stages.
+ const FORCE_POINT_SIZE = 0x20;
+ }
+}
+
+/// Configuration used in the [`Writer`].
+#[derive(Debug, Clone)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// The GLSL version to be used.
+ pub version: Version,
+ /// Configuration flags for the [`Writer`].
+ pub writer_flags: WriterFlags,
+ /// Map of resources association to binding locations.
+ pub binding_map: BindingMap,
+ /// Should workgroup variables be zero initialized (by polyfilling)?
+ pub zero_initialize_workgroup_memory: bool,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ version: Version::new_gles(310),
+ writer_flags: WriterFlags::ADJUST_COORDINATE_SPACE,
+ binding_map: BindingMap::default(),
+ zero_initialize_workgroup_memory: true,
+ }
+ }
+}
+
+/// A subset of options meant to be changed per pipeline.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// The stage of the entry point.
+ pub shader_stage: ShaderStage,
+ /// The name of the entry point.
+ ///
+ /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown.
+ pub entry_point: String,
+ /// How many views to render to, if doing multiview rendering.
+ pub multiview: Option<std::num::NonZeroU32>,
+}
+
+#[derive(Debug)]
+pub struct VaryingLocation {
+ /// The location of the global.
+ /// This corresponds to `layout(location = ..)` in GLSL.
+ pub location: u32,
+ /// The index which can be used for dual source blending.
+ /// This corresponds to `layout(index = ..)` in GLSL.
+ pub index: u32,
+}
+
+/// Reflection info for texture mappings and uniforms.
+#[derive(Debug)]
+pub struct ReflectionInfo {
+ /// Mapping between texture names and variables/samplers.
+ pub texture_mapping: crate::FastHashMap<String, TextureMapping>,
+ /// Mapping between uniform variables and names.
+ pub uniforms: crate::FastHashMap<Handle<crate::GlobalVariable>, String>,
+ /// Mapping between names and attribute locations.
+ pub varying: crate::FastHashMap<String, VaryingLocation>,
+ /// List of push constant items in the shader.
+ pub push_constant_items: Vec<PushConstantItem>,
+}
+
+/// Mapping between a texture and its sampler, if it exists.
+///
+/// GLSL pre-Vulkan has no concept of separate textures and samplers. Instead, everything is a
+/// `gsamplerN` where `g` is the scalar type and `N` is the dimension. But naga uses separate textures
+/// and samplers in the IR, so the backend produces a [`FastHashMap`](crate::FastHashMap) with the texture name
+/// as a key and a [`TextureMapping`] as a value. This way, the user knows where to bind.
+///
+/// [`Storage`](crate::ImageClass::Storage) images produce `gimageN` and don't have an associated sampler,
+/// so the [`sampler`](Self::sampler) field will be [`None`].
+#[derive(Debug, Clone)]
+pub struct TextureMapping {
+ /// Handle to the image global variable.
+ pub texture: Handle<crate::GlobalVariable>,
+ /// Handle to the associated sampler global variable, if it exists.
+ pub sampler: Option<Handle<crate::GlobalVariable>>,
+}
+
+/// All information to bind a single uniform value to the shader.
+///
+/// Push constants are emulated using traditional uniforms in OpenGL.
+///
+/// These are composed of a set of primitives (scalar, vector, matrix) that
+/// are given names. Because they are not backed by the concept of a buffer,
+/// we must do the work of calculating the offset of each primitive in the
+/// push constant block.
+#[derive(Debug, Clone)]
+pub struct PushConstantItem {
+ /// GL uniform name for the item. This name is the same as if you were
+ /// to access it directly from a GLSL shader.
+ ///
+ /// The with the following example, the following names will be generated,
+ /// one name per GLSL uniform.
+ ///
+ /// ```glsl
+ /// struct InnerStruct {
+ /// value: f32,
+ /// }
+ ///
+ /// struct PushConstant {
+ /// InnerStruct inner;
+ /// vec4 array[2];
+ /// }
+ ///
+ /// uniform PushConstants _push_constant_binding_cs;
+ /// ```
+ ///
+ /// ```text
+ /// - _push_constant_binding_cs.inner.value
+ /// - _push_constant_binding_cs.array[0]
+ /// - _push_constant_binding_cs.array[1]
+ /// ```
+ ///
+ pub access_path: String,
+ /// Type of the uniform. This will only ever be a scalar, vector, or matrix.
+ pub ty: Handle<crate::Type>,
+ /// The offset in the push constant memory block this uniform maps to.
+ ///
+ /// The size of the uniform can be derived from the type.
+ pub offset: u32,
+}
+
+/// Helper structure that generates a number
+#[derive(Default)]
+struct IdGenerator(u32);
+
+impl IdGenerator {
+ /// Generates a number that's guaranteed to be unique for this `IdGenerator`
+ fn generate(&mut self) -> u32 {
+ // It's just an increasing number but it does the job
+ let ret = self.0;
+ self.0 += 1;
+ ret
+ }
+}
+
+/// Assorted options needed for generating varyings.
+#[derive(Clone, Copy)]
+struct VaryingOptions {
+ output: bool,
+ targeting_webgl: bool,
+ draw_parameters: bool,
+}
+
+impl VaryingOptions {
+ const fn from_writer_options(options: &Options, output: bool) -> Self {
+ Self {
+ output,
+ targeting_webgl: options.version.is_webgl(),
+ draw_parameters: options.writer_flags.contains(WriterFlags::DRAW_PARAMETERS),
+ }
+ }
+}
+
+/// Helper wrapper used to get a name for a varying
+///
+/// Varying have different naming schemes depending on their binding:
+/// - Varyings with builtin bindings get the from [`glsl_built_in`].
+/// - Varyings with location bindings are named `_S_location_X` where `S` is a
+/// prefix identifying which pipeline stage the varying connects, and `X` is
+/// the location.
+struct VaryingName<'a> {
+ binding: &'a crate::Binding,
+ stage: ShaderStage,
+ options: VaryingOptions,
+}
+impl fmt::Display for VaryingName<'_> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match *self.binding {
+ crate::Binding::Location {
+ second_blend_source: true,
+ ..
+ } => {
+ write!(f, "_fs2p_location1",)
+ }
+ crate::Binding::Location { location, .. } => {
+ let prefix = match (self.stage, self.options.output) {
+ (ShaderStage::Compute, _) => unreachable!(),
+ // pipeline to vertex
+ (ShaderStage::Vertex, false) => "p2vs",
+ // vertex to fragment
+ (ShaderStage::Vertex, true) | (ShaderStage::Fragment, false) => "vs2fs",
+ // fragment to pipeline
+ (ShaderStage::Fragment, true) => "fs2p",
+ };
+ write!(f, "_{prefix}_location{location}",)
+ }
+ crate::Binding::BuiltIn(built_in) => {
+ write!(f, "{}", glsl_built_in(built_in, self.options))
+ }
+ }
+ }
+}
+
+impl ShaderStage {
+ const fn to_str(self) -> &'static str {
+ match self {
+ ShaderStage::Compute => "cs",
+ ShaderStage::Fragment => "fs",
+ ShaderStage::Vertex => "vs",
+ }
+ }
+}
+
+/// Shorthand result used internally by the backend
+type BackendResult<T = ()> = Result<T, Error>;
+
+/// A GLSL compilation error.
+#[derive(Debug, Error)]
+pub enum Error {
+ /// A error occurred while writing to the output.
+ #[error("Format error")]
+ FmtError(#[from] FmtError),
+ /// The specified [`Version`] doesn't have all required [`Features`].
+ ///
+ /// Contains the missing [`Features`].
+ #[error("The selected version doesn't support {0:?}")]
+ MissingFeatures(Features),
+ /// [`AddressSpace::PushConstant`](crate::AddressSpace::PushConstant) was used more than
+ /// once in the entry point, which isn't supported.
+ #[error("Multiple push constants aren't supported")]
+ MultiplePushConstants,
+ /// The specified [`Version`] isn't supported.
+ #[error("The specified version isn't supported")]
+ VersionNotSupported,
+ /// The entry point couldn't be found.
+ #[error("The requested entry point couldn't be found")]
+ EntryPointNotFound,
+ /// A call was made to an unsupported external.
+ #[error("A call was made to an unsupported external: {0}")]
+ UnsupportedExternal(String),
+ /// A scalar with an unsupported width was requested.
+ #[error("A scalar with an unsupported width was requested: {0:?}")]
+ UnsupportedScalar(crate::Scalar),
+ /// A image was used with multiple samplers, which isn't supported.
+ #[error("A image was used with multiple samplers")]
+ ImageMultipleSamplers,
+ #[error("{0}")]
+ Custom(String),
+}
+
+/// Binary operation with a different logic on the GLSL side.
+enum BinaryOperation {
+ /// Vector comparison should use the function like `greaterThan()`, etc.
+ VectorCompare,
+ /// Vector component wise operation; used to polyfill unsupported ops like `|` and `&` for `bvecN`'s
+ VectorComponentWise,
+ /// GLSL `%` is SPIR-V `OpUMod/OpSMod` and `mod()` is `OpFMod`, but [`BinaryOperator::Modulo`](crate::BinaryOperator::Modulo) is `OpFRem`.
+ Modulo,
+ /// Any plain operation. No additional logic required.
+ Other,
+}
+
+/// Writer responsible for all code generation.
+pub struct Writer<'a, W> {
+ // Inputs
+ /// The module being written.
+ module: &'a crate::Module,
+ /// The module analysis.
+ info: &'a valid::ModuleInfo,
+ /// The output writer.
+ out: W,
+ /// User defined configuration to be used.
+ options: &'a Options,
+ /// The bound checking policies to be used
+ policies: proc::BoundsCheckPolicies,
+
+ // Internal State
+ /// Features manager used to store all the needed features and write them.
+ features: FeaturesManager,
+ namer: proc::Namer,
+ /// A map with all the names needed for writing the module
+ /// (generated by a [`Namer`](crate::proc::Namer)).
+ names: crate::FastHashMap<NameKey, String>,
+ /// A map with the names of global variables needed for reflections.
+ reflection_names_globals: crate::FastHashMap<Handle<crate::GlobalVariable>, String>,
+ /// The selected entry point.
+ entry_point: &'a crate::EntryPoint,
+ /// The index of the selected entry point.
+ entry_point_idx: proc::EntryPointIndex,
+ /// A generator for unique block numbers.
+ block_id: IdGenerator,
+ /// Set of expressions that have associated temporary variables.
+ named_expressions: crate::NamedExpressions,
+ /// Set of expressions that need to be baked to avoid unnecessary repetition in output
+ need_bake_expressions: back::NeedBakeExpressions,
+ /// How many views to render to, if doing multiview rendering.
+ multiview: Option<std::num::NonZeroU32>,
+ /// Mapping of varying variables to their location. Needed for reflections.
+ varying: crate::FastHashMap<String, VaryingLocation>,
+}
+
+impl<'a, W: Write> Writer<'a, W> {
+ /// Creates a new [`Writer`] instance.
+ ///
+ /// # Errors
+ /// - If the version specified is invalid or supported.
+ /// - If the entry point couldn't be found in the module.
+ /// - If the version specified doesn't support some used features.
+ pub fn new(
+ out: W,
+ module: &'a crate::Module,
+ info: &'a valid::ModuleInfo,
+ options: &'a Options,
+ pipeline_options: &'a PipelineOptions,
+ policies: proc::BoundsCheckPolicies,
+ ) -> Result<Self, Error> {
+ // Check if the requested version is supported
+ if !options.version.is_supported() {
+ log::error!("Version {}", options.version);
+ return Err(Error::VersionNotSupported);
+ }
+
+ // Try to find the entry point and corresponding index
+ let ep_idx = module
+ .entry_points
+ .iter()
+ .position(|ep| {
+ pipeline_options.shader_stage == ep.stage && pipeline_options.entry_point == ep.name
+ })
+ .ok_or(Error::EntryPointNotFound)?;
+
+ // Generate a map with names required to write the module
+ let mut names = crate::FastHashMap::default();
+ let mut namer = proc::Namer::default();
+ namer.reset(
+ module,
+ keywords::RESERVED_KEYWORDS,
+ &[],
+ &[],
+ &[
+ "gl_", // all GL built-in variables
+ "_group", // all normal bindings
+ "_push_constant_binding_", // all push constant bindings
+ ],
+ &mut names,
+ );
+
+ // Build the instance
+ let mut this = Self {
+ module,
+ info,
+ out,
+ options,
+ policies,
+
+ namer,
+ features: FeaturesManager::new(),
+ names,
+ reflection_names_globals: crate::FastHashMap::default(),
+ entry_point: &module.entry_points[ep_idx],
+ entry_point_idx: ep_idx as u16,
+ multiview: pipeline_options.multiview,
+ block_id: IdGenerator::default(),
+ named_expressions: Default::default(),
+ need_bake_expressions: Default::default(),
+ varying: Default::default(),
+ };
+
+ // Find all features required to print this module
+ this.collect_required_features()?;
+
+ Ok(this)
+ }
+
+ /// Writes the [`Module`](crate::Module) as glsl to the output
+ ///
+ /// # Notes
+ /// If an error occurs while writing, the output might have been written partially
+ ///
+ /// # Panics
+ /// Might panic if the module is invalid
+ pub fn write(&mut self) -> Result<ReflectionInfo, Error> {
+ // We use `writeln!(self.out)` throughout the write to add newlines
+ // to make the output more readable
+
+ let es = self.options.version.is_es();
+
+ // Write the version (It must be the first thing or it isn't a valid glsl output)
+ writeln!(self.out, "#version {}", self.options.version)?;
+ // Write all the needed extensions
+ //
+ // This used to be the last thing being written as it allowed to search for features while
+ // writing the module saving some loops but some older versions (420 or less) required the
+ // extensions to appear before being used, even though extensions are part of the
+ // preprocessor not the processor ¯\_(ツ)_/¯
+ self.features.write(self.options, &mut self.out)?;
+
+ // Write the additional extensions
+ if self
+ .options
+ .writer_flags
+ .contains(WriterFlags::TEXTURE_SHADOW_LOD)
+ {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_shadow_lod.txt
+ writeln!(self.out, "#extension GL_EXT_texture_shadow_lod : require")?;
+ }
+
+ // glsl es requires a precision to be specified for floats and ints
+ // TODO: Should this be user configurable?
+ if es {
+ writeln!(self.out)?;
+ writeln!(self.out, "precision highp float;")?;
+ writeln!(self.out, "precision highp int;")?;
+ writeln!(self.out)?;
+ }
+
+ if self.entry_point.stage == ShaderStage::Compute {
+ let workgroup_size = self.entry_point.workgroup_size;
+ writeln!(
+ self.out,
+ "layout(local_size_x = {}, local_size_y = {}, local_size_z = {}) in;",
+ workgroup_size[0], workgroup_size[1], workgroup_size[2]
+ )?;
+ writeln!(self.out)?;
+ }
+
+ if self.entry_point.stage == ShaderStage::Vertex
+ && !self
+ .options
+ .writer_flags
+ .contains(WriterFlags::DRAW_PARAMETERS)
+ && self.features.contains(Features::INSTANCE_INDEX)
+ {
+ writeln!(self.out, "uniform uint {FIRST_INSTANCE_BINDING};")?;
+ writeln!(self.out)?;
+ }
+
+ // Enable early depth tests if needed
+ if let Some(depth_test) = self.entry_point.early_depth_test {
+ // If early depth test is supported for this version of GLSL
+ if self.options.version.supports_early_depth_test() {
+ writeln!(self.out, "layout(early_fragment_tests) in;")?;
+
+ if let Some(conservative) = depth_test.conservative {
+ use crate::ConservativeDepth as Cd;
+
+ let depth = match conservative {
+ Cd::GreaterEqual => "greater",
+ Cd::LessEqual => "less",
+ Cd::Unchanged => "unchanged",
+ };
+ writeln!(self.out, "layout (depth_{depth}) out float gl_FragDepth;")?;
+ }
+ writeln!(self.out)?;
+ } else {
+ log::warn!(
+ "Early depth testing is not supported for this version of GLSL: {}",
+ self.options.version
+ );
+ }
+ }
+
+ if self.entry_point.stage == ShaderStage::Vertex && self.options.version.is_webgl() {
+ if let Some(multiview) = self.multiview.as_ref() {
+ writeln!(self.out, "layout(num_views = {multiview}) in;")?;
+ writeln!(self.out)?;
+ }
+ }
+
+ // Write struct types.
+ //
+ // This are always ordered because the IR is structured in a way that
+ // you can't make a struct without adding all of its members first.
+ for (handle, ty) in self.module.types.iter() {
+ if let TypeInner::Struct { ref members, .. } = ty.inner {
+ // Structures ending with runtime-sized arrays can only be
+ // rendered as shader storage blocks in GLSL, not stand-alone
+ // struct types.
+ if !self.module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&self.module.types)
+ {
+ let name = &self.names[&NameKey::Type(handle)];
+ write!(self.out, "struct {name} ")?;
+ self.write_struct_body(handle, members)?;
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+
+ // Write functions to create special types.
+ for (type_key, struct_ty) in self.module.special_types.predeclared_types.iter() {
+ match type_key {
+ &crate::PredeclaredType::ModfResult { size, width }
+ | &crate::PredeclaredType::FrexpResult { size, width } => {
+ let arg_type_name_owner;
+ let arg_type_name = if let Some(size) = size {
+ arg_type_name_owner =
+ format!("{}vec{}", if width == 8 { "d" } else { "" }, size as u8);
+ &arg_type_name_owner
+ } else if width == 8 {
+ "double"
+ } else {
+ "float"
+ };
+
+ let other_type_name_owner;
+ let (defined_func_name, called_func_name, other_type_name) =
+ if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
+ (MODF_FUNCTION, "modf", arg_type_name)
+ } else {
+ let other_type_name = if let Some(size) = size {
+ other_type_name_owner = format!("ivec{}", size as u8);
+ &other_type_name_owner
+ } else {
+ "int"
+ };
+ (FREXP_FUNCTION, "frexp", other_type_name)
+ };
+
+ let struct_name = &self.names[&NameKey::Type(*struct_ty)];
+
+ writeln!(self.out)?;
+ if !self.options.version.supports_frexp_function()
+ && matches!(type_key, &crate::PredeclaredType::FrexpResult { .. })
+ {
+ writeln!(
+ self.out,
+ "{struct_name} {defined_func_name}({arg_type_name} arg) {{
+ {other_type_name} other = arg == {arg_type_name}(0) ? {other_type_name}(0) : {other_type_name}({arg_type_name}(1) + log2(arg));
+ {arg_type_name} fract = arg * exp2({arg_type_name}(-other));
+ return {struct_name}(fract, other);
+}}",
+ )?;
+ } else {
+ writeln!(
+ self.out,
+ "{struct_name} {defined_func_name}({arg_type_name} arg) {{
+ {other_type_name} other;
+ {arg_type_name} fract = {called_func_name}(arg, other);
+ return {struct_name}(fract, other);
+}}",
+ )?;
+ }
+ }
+ &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
+ }
+ }
+
+ // Write all named constants
+ let mut constants = self
+ .module
+ .constants
+ .iter()
+ .filter(|&(_, c)| c.name.is_some())
+ .peekable();
+ while let Some((handle, _)) = constants.next() {
+ self.write_global_constant(handle)?;
+ // Add extra newline for readability on last iteration
+ if constants.peek().is_none() {
+ writeln!(self.out)?;
+ }
+ }
+
+ let ep_info = self.info.get_entry_point(self.entry_point_idx as usize);
+
+ // Write the globals
+ //
+ // Unless explicitly disabled with WriterFlags::INCLUDE_UNUSED_ITEMS,
+ // we filter all globals that aren't used by the selected entry point as they might be
+ // interfere with each other (i.e. two globals with the same location but different with
+ // different classes)
+ let include_unused = self
+ .options
+ .writer_flags
+ .contains(WriterFlags::INCLUDE_UNUSED_ITEMS);
+ for (handle, global) in self.module.global_variables.iter() {
+ let is_unused = ep_info[handle].is_empty();
+ if !include_unused && is_unused {
+ continue;
+ }
+
+ match self.module.types[global.ty].inner {
+ // We treat images separately because they might require
+ // writing the storage format
+ TypeInner::Image {
+ mut dim,
+ arrayed,
+ class,
+ } => {
+ // Gather the storage format if needed
+ let storage_format_access = match self.module.types[global.ty].inner {
+ TypeInner::Image {
+ class: crate::ImageClass::Storage { format, access },
+ ..
+ } => Some((format, access)),
+ _ => None,
+ };
+
+ if dim == crate::ImageDimension::D1 && es {
+ dim = crate::ImageDimension::D2
+ }
+
+ // Gether the location if needed
+ let layout_binding = if self.options.version.supports_explicit_locations() {
+ let br = global.binding.as_ref().unwrap();
+ self.options.binding_map.get(br).cloned()
+ } else {
+ None
+ };
+
+ // Write all the layout qualifiers
+ if layout_binding.is_some() || storage_format_access.is_some() {
+ write!(self.out, "layout(")?;
+ if let Some(binding) = layout_binding {
+ write!(self.out, "binding = {binding}")?;
+ }
+ if let Some((format, _)) = storage_format_access {
+ let format_str = glsl_storage_format(format)?;
+ let separator = match layout_binding {
+ Some(_) => ",",
+ None => "",
+ };
+ write!(self.out, "{separator}{format_str}")?;
+ }
+ write!(self.out, ") ")?;
+ }
+
+ if let Some((_, access)) = storage_format_access {
+ self.write_storage_access(access)?;
+ }
+
+ // All images in glsl are `uniform`
+ // The trailing space is important
+ write!(self.out, "uniform ")?;
+
+ // write the type
+ //
+ // This is way we need the leading space because `write_image_type` doesn't add
+ // any spaces at the beginning or end
+ self.write_image_type(dim, arrayed, class)?;
+
+ // Finally write the name and end the global with a `;`
+ // The leading space is important
+ let global_name = self.get_global_name(handle, global);
+ writeln!(self.out, " {global_name};")?;
+ writeln!(self.out)?;
+
+ self.reflection_names_globals.insert(handle, global_name);
+ }
+ // glsl has no concept of samplers so we just ignore it
+ TypeInner::Sampler { .. } => continue,
+ // All other globals are written by `write_global`
+ _ => {
+ self.write_global(handle, global)?;
+ // Add a newline (only for readability)
+ writeln!(self.out)?;
+ }
+ }
+ }
+
+ for arg in self.entry_point.function.arguments.iter() {
+ self.write_varying(arg.binding.as_ref(), arg.ty, false)?;
+ }
+ if let Some(ref result) = self.entry_point.function.result {
+ self.write_varying(result.binding.as_ref(), result.ty, true)?;
+ }
+ writeln!(self.out)?;
+
+ // Write all regular functions
+ for (handle, function) in self.module.functions.iter() {
+ // Check that the function doesn't use globals that aren't supported
+ // by the current entry point
+ if !include_unused && !ep_info.dominates_global_use(&self.info[handle]) {
+ continue;
+ }
+
+ let fun_info = &self.info[handle];
+
+ // Skip functions that that are not compatible with this entry point's stage.
+ //
+ // When validation is enabled, it rejects modules whose entry points try to call
+ // incompatible functions, so if we got this far, then any functions incompatible
+ // with our selected entry point must not be used.
+ //
+ // When validation is disabled, `fun_info.available_stages` is always just
+ // `ShaderStages::all()`, so this will write all functions in the module, and
+ // the downstream GLSL compiler will catch any problems.
+ if !fun_info.available_stages.contains(ep_info.available_stages) {
+ continue;
+ }
+
+ // Write the function
+ self.write_function(back::FunctionType::Function(handle), function, fun_info)?;
+
+ writeln!(self.out)?;
+ }
+
+ self.write_function(
+ back::FunctionType::EntryPoint(self.entry_point_idx),
+ &self.entry_point.function,
+ ep_info,
+ )?;
+
+ // Add newline at the end of file
+ writeln!(self.out)?;
+
+ // Collect all reflection info and return it to the user
+ self.collect_reflection_info()
+ }
+
+ fn write_array_size(
+ &mut self,
+ base: Handle<crate::Type>,
+ size: crate::ArraySize,
+ ) -> BackendResult {
+ write!(self.out, "[")?;
+
+ // Write the array size
+ // Writes nothing if `ArraySize::Dynamic`
+ match size {
+ crate::ArraySize::Constant(size) => {
+ write!(self.out, "{size}")?;
+ }
+ crate::ArraySize::Dynamic => (),
+ }
+
+ write!(self.out, "]")?;
+
+ if let TypeInner::Array {
+ base: next_base,
+ size: next_size,
+ ..
+ } = self.module.types[base].inner
+ {
+ self.write_array_size(next_base, next_size)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_value_type(&mut self, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ // Scalars are simple we just get the full name from `glsl_scalar`
+ TypeInner::Scalar(scalar)
+ | TypeInner::Atomic(scalar)
+ | TypeInner::ValuePointer {
+ size: None,
+ scalar,
+ space: _,
+ } => write!(self.out, "{}", glsl_scalar(scalar)?.full)?,
+ // Vectors are just `gvecN` where `g` is the scalar prefix and `N` is the vector size
+ TypeInner::Vector { size, scalar }
+ | TypeInner::ValuePointer {
+ size: Some(size),
+ scalar,
+ space: _,
+ } => write!(self.out, "{}vec{}", glsl_scalar(scalar)?.prefix, size as u8)?,
+ // Matrices are written with `gmatMxN` where `g` is the scalar prefix (only floats and
+ // doubles are allowed), `M` is the columns count and `N` is the rows count
+ //
+ // glsl supports a matrix shorthand `gmatN` where `N` = `M` but it doesn't justify the
+ // extra branch to write matrices this way
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => write!(
+ self.out,
+ "{}mat{}x{}",
+ glsl_scalar(scalar)?.prefix,
+ columns as u8,
+ rows as u8
+ )?,
+ // GLSL arrays are written as `type name[size]`
+ // Here we only write the size of the array i.e. `[size]`
+ // Base `type` and `name` should be written outside
+ TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?,
+ // Write all variants instead of `_` so that if new variants are added a
+ // no exhaustiveness error is thrown
+ TypeInner::Pointer { .. }
+ | TypeInner::Struct { .. }
+ | TypeInner::Image { .. }
+ | TypeInner::Sampler { .. }
+ | TypeInner::AccelerationStructure
+ | TypeInner::RayQuery
+ | TypeInner::BindingArray { .. } => {
+ return Err(Error::Custom(format!("Unable to write type {inner:?}")))
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_type(&mut self, ty: Handle<crate::Type>) -> BackendResult {
+ match self.module.types[ty].inner {
+ // glsl has no pointer types so just write types as normal and loads are skipped
+ TypeInner::Pointer { base, .. } => self.write_type(base),
+ // glsl structs are written as just the struct name
+ TypeInner::Struct { .. } => {
+ // Get the struct name
+ let name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{name}")?;
+ Ok(())
+ }
+ // glsl array has the size separated from the base type
+ TypeInner::Array { base, .. } => self.write_type(base),
+ ref other => self.write_value_type(other),
+ }
+ }
+
+ /// Helper method to write a image type
+ ///
+ /// # Notes
+ /// Adds no leading or trailing whitespace
+ fn write_image_type(
+ &mut self,
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ ) -> BackendResult {
+ // glsl images consist of four parts the scalar prefix, the image "type", the dimensions
+ // and modifiers
+ //
+ // There exists two image types
+ // - sampler - for sampled images
+ // - image - for storage images
+ //
+ // There are three possible modifiers that can be used together and must be written in
+ // this order to be valid
+ // - MS - used if it's a multisampled image
+ // - Array - used if it's an image array
+ // - Shadow - used if it's a depth image
+ use crate::ImageClass as Ic;
+
+ let (base, kind, ms, comparison) = match class {
+ Ic::Sampled { kind, multi: true } => ("sampler", kind, "MS", ""),
+ Ic::Sampled { kind, multi: false } => ("sampler", kind, "", ""),
+ Ic::Depth { multi: true } => ("sampler", crate::ScalarKind::Float, "MS", ""),
+ Ic::Depth { multi: false } => ("sampler", crate::ScalarKind::Float, "", "Shadow"),
+ Ic::Storage { format, .. } => ("image", format.into(), "", ""),
+ };
+
+ let precision = if self.options.version.is_es() {
+ "highp "
+ } else {
+ ""
+ };
+
+ write!(
+ self.out,
+ "{}{}{}{}{}{}{}",
+ precision,
+ glsl_scalar(crate::Scalar { kind, width: 4 })?.prefix,
+ base,
+ glsl_dimension(dim),
+ ms,
+ if arrayed { "Array" } else { "" },
+ comparison
+ )?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write non images/sampler globals
+ ///
+ /// # Notes
+ /// Adds a newline
+ ///
+ /// # Panics
+ /// If the global has type sampler
+ fn write_global(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ if self.options.version.supports_explicit_locations() {
+ if let Some(ref br) = global.binding {
+ match self.options.binding_map.get(br) {
+ Some(binding) => {
+ let layout = match global.space {
+ crate::AddressSpace::Storage { .. } => {
+ if self.options.version.supports_std430_layout() {
+ "std430, "
+ } else {
+ "std140, "
+ }
+ }
+ crate::AddressSpace::Uniform => "std140, ",
+ _ => "",
+ };
+ write!(self.out, "layout({layout}binding = {binding}) ")?
+ }
+ None => {
+ log::debug!("unassigned binding for {:?}", global.name);
+ if let crate::AddressSpace::Storage { .. } = global.space {
+ if self.options.version.supports_std430_layout() {
+ write!(self.out, "layout(std430) ")?
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if let crate::AddressSpace::Storage { access } = global.space {
+ self.write_storage_access(access)?;
+ }
+
+ if let Some(storage_qualifier) = glsl_storage_qualifier(global.space) {
+ write!(self.out, "{storage_qualifier} ")?;
+ }
+
+ match global.space {
+ crate::AddressSpace::Private => {
+ self.write_simple_global(handle, global)?;
+ }
+ crate::AddressSpace::WorkGroup => {
+ self.write_simple_global(handle, global)?;
+ }
+ crate::AddressSpace::PushConstant => {
+ self.write_simple_global(handle, global)?;
+ }
+ crate::AddressSpace::Uniform => {
+ self.write_interface_block(handle, global)?;
+ }
+ crate::AddressSpace::Storage { .. } => {
+ self.write_interface_block(handle, global)?;
+ }
+ // A global variable in the `Function` address space is a
+ // contradiction in terms.
+ crate::AddressSpace::Function => unreachable!(),
+ // Textures and samplers are handled directly in `Writer::write`.
+ crate::AddressSpace::Handle => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ fn write_simple_global(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ self.write_type(global.ty)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+
+ if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+
+ if global.space.initializable() && is_value_init_supported(self.module, global.ty) {
+ write!(self.out, " = ")?;
+ if let Some(init) = global.init {
+ self.write_const_expr(init)?;
+ } else {
+ self.write_zero_init_value(global.ty)?;
+ }
+ }
+
+ writeln!(self.out, ";")?;
+
+ if let crate::AddressSpace::PushConstant = global.space {
+ let global_name = self.get_global_name(handle, global);
+ self.reflection_names_globals.insert(handle, global_name);
+ }
+
+ Ok(())
+ }
+
+ /// Write an interface block for a single Naga global.
+ ///
+ /// Write `block_name { members }`. Since `block_name` must be unique
+ /// between blocks and structs, we add `_block_ID` where `ID` is a
+ /// `IdGenerator` generated number. Write `members` in the same way we write
+ /// a struct's members.
+ fn write_interface_block(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ // Write the block name, it's just the struct name appended with `_block_ID`
+ let ty_name = &self.names[&NameKey::Type(global.ty)];
+ let block_name = format!(
+ "{}_block_{}{:?}",
+ // avoid double underscores as they are reserved in GLSL
+ ty_name.trim_end_matches('_'),
+ self.block_id.generate(),
+ self.entry_point.stage,
+ );
+ write!(self.out, "{block_name} ")?;
+ self.reflection_names_globals.insert(handle, block_name);
+
+ match self.module.types[global.ty].inner {
+ crate::TypeInner::Struct { ref members, .. }
+ if self.module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&self.module.types) =>
+ {
+ // Structs with dynamically sized arrays must have their
+ // members lifted up as members of the interface block. GLSL
+ // can't write such struct types anyway.
+ self.write_struct_body(global.ty, members)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+ }
+ _ => {
+ // A global of any other type is written as the sole member
+ // of the interface block. Since the interface block is
+ // anonymous, this becomes visible in the global scope.
+ write!(self.out, "{{ ")?;
+ self.write_type(global.ty)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, "; }}")?;
+ }
+ }
+
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// Clears `need_bake_expressions` set before adding to it
+ fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+ for (fun_handle, expr) in func.expressions.iter() {
+ let expr_info = &info[fun_handle];
+ let min_ref_count = func.expressions[fun_handle].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(fun_handle);
+ }
+
+ let inner = expr_info.ty.inner_with(&self.module.types);
+
+ if let Expression::Math { fun, arg, arg1, .. } = *expr {
+ match fun {
+ crate::MathFunction::Dot => {
+ // if the expression is a Dot product with integer arguments,
+ // then the args needs baking as well
+ if let TypeInner::Scalar(crate::Scalar { kind, .. }) = *inner {
+ match kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ }
+ _ => {}
+ }
+ }
+ }
+ crate::MathFunction::CountLeadingZeros => {
+ if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ self.need_bake_expressions.insert(arg);
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+
+ /// Helper method used to get a name for a global
+ ///
+ /// Globals have different naming schemes depending on their binding:
+ /// - Globals without bindings use the name from the [`Namer`](crate::proc::Namer)
+ /// - Globals with resource binding are named `_group_X_binding_Y` where `X`
+ /// is the group and `Y` is the binding
+ fn get_global_name(
+ &self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> String {
+ match (&global.binding, global.space) {
+ (&Some(ref br), _) => {
+ format!(
+ "_group_{}_binding_{}_{}",
+ br.group,
+ br.binding,
+ self.entry_point.stage.to_str()
+ )
+ }
+ (&None, crate::AddressSpace::PushConstant) => {
+ format!("_push_constant_binding_{}", self.entry_point.stage.to_str())
+ }
+ (&None, _) => self.names[&NameKey::GlobalVariable(handle)].clone(),
+ }
+ }
+
+ /// Helper method used to write a name for a global without additional heap allocation
+ fn write_global_name(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ match (&global.binding, global.space) {
+ (&Some(ref br), _) => write!(
+ self.out,
+ "_group_{}_binding_{}_{}",
+ br.group,
+ br.binding,
+ self.entry_point.stage.to_str()
+ )?,
+ (&None, crate::AddressSpace::PushConstant) => write!(
+ self.out,
+ "_push_constant_binding_{}",
+ self.entry_point.stage.to_str()
+ )?,
+ (&None, _) => write!(
+ self.out,
+ "{}",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?,
+ }
+
+ Ok(())
+ }
+
+ /// Write a GLSL global that will carry a Naga entry point's argument or return value.
+ ///
+ /// A Naga entry point's arguments and return value are rendered in GLSL as
+ /// variables at global scope with the `in` and `out` storage qualifiers.
+ /// The code we generate for `main` loads from all the `in` globals into
+ /// appropriately named locals. Before it returns, `main` assigns the
+ /// components of its return value into all the `out` globals.
+ ///
+ /// This function writes a declaration for one such GLSL global,
+ /// representing a value passed into or returned from [`self.entry_point`]
+ /// that has a [`Location`] binding. The global's name is generated based on
+ /// the location index and the shader stages being connected; see
+ /// [`VaryingName`]. This means we don't need to know the names of
+ /// arguments, just their types and bindings.
+ ///
+ /// Emit nothing for entry point arguments or return values with [`BuiltIn`]
+ /// bindings; `main` will read from or assign to the appropriate GLSL
+ /// special variable; these are pre-declared. As an exception, we do declare
+ /// `gl_Position` or `gl_FragCoord` with the `invariant` qualifier if
+ /// needed.
+ ///
+ /// Use `output` together with [`self.entry_point.stage`] to determine which
+ /// shader stages are being connected, and choose the `in` or `out` storage
+ /// qualifier.
+ ///
+ /// [`self.entry_point`]: Writer::entry_point
+ /// [`self.entry_point.stage`]: crate::EntryPoint::stage
+ /// [`Location`]: crate::Binding::Location
+ /// [`BuiltIn`]: crate::Binding::BuiltIn
+ fn write_varying(
+ &mut self,
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ output: bool,
+ ) -> Result<(), Error> {
+ // For a struct, emit a separate global for each member with a binding.
+ if let crate::TypeInner::Struct { ref members, .. } = self.module.types[ty].inner {
+ for member in members {
+ self.write_varying(member.binding.as_ref(), member.ty, output)?;
+ }
+ return Ok(());
+ }
+
+ let binding = match binding {
+ None => return Ok(()),
+ Some(binding) => binding,
+ };
+
+ let (location, interpolation, sampling, second_blend_source) = match *binding {
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source,
+ } => (location, interpolation, sampling, second_blend_source),
+ crate::Binding::BuiltIn(built_in) => {
+ if let crate::BuiltIn::Position { invariant: true } = built_in {
+ match (self.options.version, self.entry_point.stage) {
+ (
+ Version::Embedded {
+ version: 300,
+ is_webgl: true,
+ },
+ ShaderStage::Fragment,
+ ) => {
+ // `invariant gl_FragCoord` is not allowed in WebGL2 and possibly
+ // OpenGL ES in general (waiting on confirmation).
+ //
+ // See https://github.com/KhronosGroup/WebGL/issues/3518
+ }
+ _ => {
+ writeln!(
+ self.out,
+ "invariant {};",
+ glsl_built_in(
+ built_in,
+ VaryingOptions::from_writer_options(self.options, output)
+ )
+ )?;
+ }
+ }
+ }
+ return Ok(());
+ }
+ };
+
+ // Write the interpolation modifier if needed
+ //
+ // We ignore all interpolation and auxiliary modifiers that aren't used in fragment
+ // shaders' input globals or vertex shaders' output globals.
+ let emit_interpolation_and_auxiliary = match self.entry_point.stage {
+ ShaderStage::Vertex => output,
+ ShaderStage::Fragment => !output,
+ ShaderStage::Compute => false,
+ };
+
+ // Write the I/O locations, if allowed
+ let io_location = if self.options.version.supports_explicit_locations()
+ || !emit_interpolation_and_auxiliary
+ {
+ if self.options.version.supports_io_locations() {
+ if second_blend_source {
+ write!(self.out, "layout(location = {location}, index = 1) ")?;
+ } else {
+ write!(self.out, "layout(location = {location}) ")?;
+ }
+ None
+ } else {
+ Some(VaryingLocation {
+ location,
+ index: second_blend_source as u32,
+ })
+ }
+ } else {
+ None
+ };
+
+ // Write the interpolation qualifier.
+ if let Some(interp) = interpolation {
+ if emit_interpolation_and_auxiliary {
+ write!(self.out, "{} ", glsl_interpolation(interp))?;
+ }
+ }
+
+ // Write the sampling auxiliary qualifier.
+ //
+ // Before GLSL 4.2, the `centroid` and `sample` qualifiers were required to appear
+ // immediately before the `in` / `out` qualifier, so we'll just follow that rule
+ // here, regardless of the version.
+ if let Some(sampling) = sampling {
+ if emit_interpolation_and_auxiliary {
+ if let Some(qualifier) = glsl_sampling(sampling) {
+ write!(self.out, "{qualifier} ")?;
+ }
+ }
+ }
+
+ // Write the input/output qualifier.
+ write!(self.out, "{} ", if output { "out" } else { "in" })?;
+
+ // Write the type
+ // `write_type` adds no leading or trailing spaces
+ self.write_type(ty)?;
+
+ // Finally write the global name and end the global with a `;` and a newline
+ // Leading space is important
+ let vname = VaryingName {
+ binding: &crate::Binding::Location {
+ location,
+ interpolation: None,
+ sampling: None,
+ second_blend_source,
+ },
+ stage: self.entry_point.stage,
+ options: VaryingOptions::from_writer_options(self.options, output),
+ };
+ writeln!(self.out, " {vname};")?;
+
+ if let Some(location) = io_location {
+ self.varying.insert(vname.to_string(), location);
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write functions (both entry points and regular functions)
+ ///
+ /// # Notes
+ /// Adds a newline
+ fn write_function(
+ &mut self,
+ ty: back::FunctionType,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ ) -> BackendResult {
+ // Create a function context for the function being written
+ let ctx = back::FunctionCtx {
+ ty,
+ info,
+ expressions: &func.expressions,
+ named_expressions: &func.named_expressions,
+ };
+
+ self.named_expressions.clear();
+ self.update_expressions_to_bake(func, info);
+
+ // Write the function header
+ //
+ // glsl headers are the same as in c:
+ // `ret_type name(args)`
+ // `ret_type` is the return type
+ // `name` is the function name
+ // `args` is a comma separated list of `type name`
+ // | - `type` is the argument type
+ // | - `name` is the argument name
+
+ // Start by writing the return type if any otherwise write void
+ // This is the only place where `void` is a valid type
+ // (though it's more a keyword than a type)
+ if let back::FunctionType::EntryPoint(_) = ctx.ty {
+ write!(self.out, "void")?;
+ } else if let Some(ref result) = func.result {
+ self.write_type(result.ty)?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[result.ty].inner {
+ self.write_array_size(base, size)?
+ }
+ } else {
+ write!(self.out, "void")?;
+ }
+
+ // Write the function name and open parentheses for the argument list
+ let function_name = match ctx.ty {
+ back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
+ back::FunctionType::EntryPoint(_) => "main",
+ };
+ write!(self.out, " {function_name}(")?;
+
+ // Write the comma separated argument list
+ //
+ // We need access to `Self` here so we use the reference passed to the closure as an
+ // argument instead of capturing as that would cause a borrow checker error
+ let arguments = match ctx.ty {
+ back::FunctionType::EntryPoint(_) => &[][..],
+ back::FunctionType::Function(_) => &func.arguments,
+ };
+ let arguments: Vec<_> = arguments
+ .iter()
+ .enumerate()
+ .filter(|&(_, arg)| match self.module.types[arg.ty].inner {
+ TypeInner::Sampler { .. } => false,
+ _ => true,
+ })
+ .collect();
+ self.write_slice(&arguments, |this, _, &(i, arg)| {
+ // Write the argument type
+ match this.module.types[arg.ty].inner {
+ // We treat images separately because they might require
+ // writing the storage format
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ // Write the storage format if needed
+ if let TypeInner::Image {
+ class: crate::ImageClass::Storage { format, .. },
+ ..
+ } = this.module.types[arg.ty].inner
+ {
+ write!(this.out, "layout({}) ", glsl_storage_format(format)?)?;
+ }
+
+ // write the type
+ //
+ // This is way we need the leading space because `write_image_type` doesn't add
+ // any spaces at the beginning or end
+ this.write_image_type(dim, arrayed, class)?;
+ }
+ TypeInner::Pointer { base, .. } => {
+ // write parameter qualifiers
+ write!(this.out, "inout ")?;
+ this.write_type(base)?;
+ }
+ // All other types are written by `write_type`
+ _ => {
+ this.write_type(arg.ty)?;
+ }
+ }
+
+ // Write the argument name
+ // The leading space is important
+ write!(this.out, " {}", &this.names[&ctx.argument_key(i as u32)])?;
+
+ // Write array size
+ match this.module.types[arg.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ this.write_array_size(base, size)?;
+ }
+ TypeInner::Pointer { base, .. } => {
+ if let TypeInner::Array { base, size, .. } = this.module.types[base].inner {
+ this.write_array_size(base, size)?;
+ }
+ }
+ _ => {}
+ }
+
+ Ok(())
+ })?;
+
+ // Close the parentheses and open braces to start the function body
+ writeln!(self.out, ") {{")?;
+
+ if self.options.zero_initialize_workgroup_memory
+ && ctx.ty.is_compute_entry_point(self.module)
+ {
+ self.write_workgroup_variables_initialization(&ctx)?;
+ }
+
+ // Compose the function arguments from globals, in case of an entry point.
+ if let back::FunctionType::EntryPoint(ep_index) = ctx.ty {
+ let stage = self.module.entry_points[ep_index as usize].stage;
+ for (index, arg) in func.arguments.iter().enumerate() {
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(arg.ty)?;
+ let name = &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
+ write!(self.out, " {name}")?;
+ write!(self.out, " = ")?;
+ match self.module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ self.write_type(arg.ty)?;
+ write!(self.out, "(")?;
+ for (index, member) in members.iter().enumerate() {
+ let varying_name = VaryingName {
+ binding: member.binding.as_ref().unwrap(),
+ stage,
+ options: VaryingOptions::from_writer_options(self.options, false),
+ };
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{varying_name}")?;
+ }
+ writeln!(self.out, ");")?;
+ }
+ _ => {
+ let varying_name = VaryingName {
+ binding: arg.binding.as_ref().unwrap(),
+ stage,
+ options: VaryingOptions::from_writer_options(self.options, false),
+ };
+ writeln!(self.out, "{varying_name};")?;
+ }
+ }
+ }
+ }
+
+ // Write all function locals
+ // Locals are `type name (= init)?;` where the init part (including the =) are optional
+ //
+ // Always adds a newline
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability) and the type
+ // `write_type` adds no trailing space
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(local.ty)?;
+
+ // Write the local name
+ // The leading space is important
+ write!(self.out, " {}", self.names[&ctx.name_key(handle)])?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = self.module.types[local.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ // Put the equal signal only if there's a initializer
+ // The leading and trailing spaces aren't needed but help with readability
+ write!(self.out, " = ")?;
+
+ // Write the constant
+ // `write_constant` adds no trailing or leading space/newline
+ self.write_expr(init, &ctx)?;
+ } else if is_value_init_supported(self.module, local.ty) {
+ write!(self.out, " = ")?;
+ self.write_zero_init_value(local.ty)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // Write a statement, the indentation should always be 1 when writing the function body
+ // `write_stmt` adds a newline
+ self.write_stmt(sta, &ctx, back::Level(1))?;
+ }
+
+ // Close braces and add a newline
+ writeln!(self.out, "}}")?;
+
+ Ok(())
+ }
+
+ fn write_workgroup_variables_initialization(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ let mut vars = self
+ .module
+ .global_variables
+ .iter()
+ .filter(|&(handle, var)| {
+ !ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ .peekable();
+
+ if vars.peek().is_some() {
+ let level = back::Level(1);
+
+ writeln!(self.out, "{level}if (gl_LocalInvocationID == uvec3(0u)) {{")?;
+
+ for (handle, var) in vars {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{}{} = ", level.next(), name)?;
+ self.write_zero_init_value(var.ty)?;
+ writeln!(self.out, ";")?;
+ }
+
+ writeln!(self.out, "{level}}}")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ }
+
+ Ok(())
+ }
+
+ /// Write a list of comma separated `T` values using a writer function `F`.
+ ///
+ /// The writer function `F` receives a mutable reference to `self` that if needed won't cause
+ /// borrow checker issues (using for example a closure with `self` will cause issues), the
+ /// second argument is the 0 based index of the element on the list, and the last element is
+ /// a reference to the element `T` being written
+ ///
+ /// # Notes
+ /// - Adds no newlines or leading/trailing whitespace
+ /// - The last element won't have a trailing `,`
+ fn write_slice<T, F: FnMut(&mut Self, u32, &T) -> BackendResult>(
+ &mut self,
+ data: &[T],
+ mut f: F,
+ ) -> BackendResult {
+ // Loop through `data` invoking `f` for each element
+ for (index, item) in data.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ f(self, index as u32, item)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global constants
+ fn write_global_constant(&mut self, handle: Handle<crate::Constant>) -> BackendResult {
+ write!(self.out, "const ")?;
+ let constant = &self.module.constants[handle];
+ self.write_type(constant.ty)?;
+ let name = &self.names[&NameKey::Constant(handle)];
+ write!(self.out, " {name}")?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[constant.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_const_expr(constant.init)?;
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
+ /// Helper method used to output a dot product as an arithmetic expression
+ ///
+ fn write_dot_product(
+ &mut self,
+ arg: Handle<crate::Expression>,
+ arg1: Handle<crate::Expression>,
+ size: usize,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ // Write parentheses around the dot product expression to prevent operators
+ // with different precedences from applying earlier.
+ write!(self.out, "(")?;
+
+ // Cycle trough all the components of the vector
+ for index in 0..size {
+ let component = back::COMPONENTS[index];
+ // Write the addition to the previous product
+ // This will print an extra '+' at the beginning but that is fine in glsl
+ write!(self.out, " + ")?;
+ // Write the first vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.write_expr(arg, ctx)?;
+ // Access the current component on the first vector
+ write!(self.out, ".{component} * ")?;
+ // Write the second vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.write_expr(arg1, ctx)?;
+ // Access the current component on the second vector
+ write!(self.out, ".{component}")?;
+ }
+
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_struct_body(
+ &mut self,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ ) -> BackendResult {
+ // glsl structs are written as in C
+ // `struct name() { members };`
+ // | `struct` is a keyword
+ // | `name` is the struct name
+ // | `members` is a semicolon separated list of `type name`
+ // | `type` is the member type
+ // | `name` is the member name
+ writeln!(self.out, "{{")?;
+
+ for (idx, member) in members.iter().enumerate() {
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+
+ match self.module.types[member.ty].inner {
+ TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ self.write_type(base)?;
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, idx as u32)]
+ )?;
+ // Write [size]
+ self.write_array_size(base, size)?;
+ // Newline is important
+ writeln!(self.out, ";")?;
+ }
+ _ => {
+ // Write the member type
+ // Adds no trailing space
+ self.write_type(member.ty)?;
+
+ // Write the member name and put a semicolon
+ // The leading space is important
+ // All members must have a semicolon even the last one
+ writeln!(
+ self.out,
+ " {};",
+ &self.names[&NameKey::StructMember(handle, idx as u32)]
+ )?;
+ }
+ }
+ }
+
+ write!(self.out, "}}")?;
+ Ok(())
+ }
+
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ sta: &crate::Statement,
+ ctx: &back::FunctionCtx,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::Statement;
+
+ match *sta {
+ // This is where we can generate intermediate constants for some expression types.
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let ptr_class = ctx.resolve_type(handle, &self.module.types).pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ // GLSL can't save a pointer-valued expression in a variable,
+ // but we shouldn't ever need to: they should never be named expressions,
+ // and none of the expression types flagged by bake_ref_count can be pointer-valued.
+ None
+ } else if let Some(name) = ctx.named_expressions.get(&handle) {
+ // Front end provides names for all variables at the start of writing.
+ // But we write them to step by step. We need to recache them
+ // Otherwise, we could accidentally write variable name instead of full expression.
+ // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
+ Some(self.namer.call(name))
+ } else if self.need_bake_expressions.contains(&handle) {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else {
+ None
+ };
+
+ // If we are going to write an `ImageLoad` next and the target image
+ // is sampled and we are using the `Restrict` policy for bounds
+ // checking images we need to write a local holding the clamped lod.
+ if let crate::Expression::ImageLoad {
+ image,
+ level: Some(level_expr),
+ ..
+ } = ctx.expressions[handle]
+ {
+ if let TypeInner::Image {
+ class: crate::ImageClass::Sampled { .. },
+ ..
+ } = *ctx.resolve_type(image, &self.module.types)
+ {
+ if let proc::BoundsCheckPolicy::Restrict = self.policies.image_load {
+ write!(self.out, "{level}")?;
+ self.write_clamped_lod(ctx, handle, image, level_expr)?
+ }
+ }
+ }
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.write_named_expr(handle, name, handle, ctx)?;
+ }
+ }
+ }
+ // Blocks are simple we just need to write the block statements between braces
+ // We could also just print the statements but this is more readable and maps more
+ // closely to the IR
+ Statement::Block(ref block) => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ // Ifs are written as in C:
+ // ```
+ // if(condition) {
+ // accept
+ // } else {
+ // reject
+ // }
+ // ```
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "if (")?;
+ self.write_expr(condition, ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ // Switch are written as in C:
+ // ```
+ // switch (selector) {
+ // // Fallthrough
+ // case label:
+ // block
+ // // Non fallthrough
+ // case label:
+ // block
+ // break;
+ // default:
+ // block
+ // }
+ // ```
+ // Where the `default` case happens isn't important but we put it last
+ // so that we don't need to print a `break` for it
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{level}")?;
+ write!(self.out, "switch(")?;
+ self.write_expr(selector, ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ // Write all cases
+ let l2 = level.next();
+ for case in cases {
+ match case.value {
+ crate::SwitchValue::I32(value) => write!(self.out, "{l2}case {value}:")?,
+ crate::SwitchValue::U32(value) => write!(self.out, "{l2}case {value}u:")?,
+ crate::SwitchValue::Default => write!(self.out, "{l2}default:")?,
+ }
+
+ let write_block_braces = !(case.fall_through && case.body.is_empty());
+ if write_block_braces {
+ writeln!(self.out, " {{")?;
+ } else {
+ writeln!(self.out)?;
+ }
+
+ for sta in case.body.iter() {
+ self.write_stmt(sta, ctx, l2.next())?;
+ }
+
+ if !case.fall_through && case.body.last().map_or(true, |s| !s.is_terminator()) {
+ writeln!(self.out, "{}break;", l2.next())?;
+ }
+
+ if write_block_braces {
+ writeln!(self.out, "{l2}}}")?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ // Loops in naga IR are based on wgsl loops, glsl can emulate the behaviour by using a
+ // while true loop and appending the continuing block to the body resulting on:
+ // ```
+ // bool loop_init = true;
+ // while(true) {
+ // if (!loop_init) { <continuing> }
+ // loop_init = false;
+ // <body>
+ // }
+ // ```
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{level}bool {gate_name} = true;")?;
+ writeln!(self.out, "{level}while(true) {{")?;
+ let l2 = level.next();
+ let l3 = l2.next();
+ writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
+ for sta in continuing {
+ self.write_stmt(sta, ctx, l3)?;
+ }
+ if let Some(condition) = break_if {
+ write!(self.out, "{l3}if (")?;
+ self.write_expr(condition, ctx)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", l3.next())?;
+ writeln!(self.out, "{l3}}}")?;
+ }
+ writeln!(self.out, "{l2}}}")?;
+ writeln!(self.out, "{}{} = false;", level.next(), gate_name)?;
+ } else {
+ writeln!(self.out, "{level}while(true) {{")?;
+ }
+ for sta in body {
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ // Break, continue and return as written as in C
+ // `break;`
+ Statement::Break => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "break;")?
+ }
+ // `continue;`
+ Statement::Continue => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "continue;")?
+ }
+ // `return expr;`, `expr` is optional
+ Statement::Return { value } => {
+ write!(self.out, "{level}")?;
+ match ctx.ty {
+ back::FunctionType::Function(_) => {
+ write!(self.out, "return")?;
+ // Write the expression to be returned if needed
+ if let Some(expr) = value {
+ write!(self.out, " ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ back::FunctionType::EntryPoint(ep_index) => {
+ let mut has_point_size = false;
+ let ep = &self.module.entry_points[ep_index as usize];
+ if let Some(ref result) = ep.function.result {
+ let value = value.unwrap();
+ match self.module.types[result.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let temp_struct_name = match ctx.expressions[value] {
+ crate::Expression::Compose { .. } => {
+ let return_struct = "_tmp_return";
+ write!(
+ self.out,
+ "{} {} = ",
+ &self.names[&NameKey::Type(result.ty)],
+ return_struct
+ )?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{level}")?;
+ Some(return_struct)
+ }
+ _ => None,
+ };
+
+ for (index, member) in members.iter().enumerate() {
+ if let Some(crate::Binding::BuiltIn(
+ crate::BuiltIn::PointSize,
+ )) = member.binding
+ {
+ has_point_size = true;
+ }
+
+ let varying_name = VaryingName {
+ binding: member.binding.as_ref().unwrap(),
+ stage: ep.stage,
+ options: VaryingOptions::from_writer_options(
+ self.options,
+ true,
+ ),
+ };
+ write!(self.out, "{varying_name} = ")?;
+
+ if let Some(struct_name) = temp_struct_name {
+ write!(self.out, "{struct_name}")?;
+ } else {
+ self.write_expr(value, ctx)?;
+ }
+
+ // Write field name
+ writeln!(
+ self.out,
+ ".{};",
+ &self.names
+ [&NameKey::StructMember(result.ty, index as u32)]
+ )?;
+ write!(self.out, "{level}")?;
+ }
+ }
+ _ => {
+ let name = VaryingName {
+ binding: result.binding.as_ref().unwrap(),
+ stage: ep.stage,
+ options: VaryingOptions::from_writer_options(
+ self.options,
+ true,
+ ),
+ };
+ write!(self.out, "{name} = ")?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{level}")?;
+ }
+ }
+ }
+
+ let is_vertex_stage = self.module.entry_points[ep_index as usize].stage
+ == ShaderStage::Vertex;
+ if is_vertex_stage
+ && self
+ .options
+ .writer_flags
+ .contains(WriterFlags::ADJUST_COORDINATE_SPACE)
+ {
+ writeln!(
+ self.out,
+ "gl_Position.yz = vec2(-gl_Position.y, gl_Position.z * 2.0 - gl_Position.w);",
+ )?;
+ write!(self.out, "{level}")?;
+ }
+
+ if is_vertex_stage
+ && self
+ .options
+ .writer_flags
+ .contains(WriterFlags::FORCE_POINT_SIZE)
+ && !has_point_size
+ {
+ writeln!(self.out, "gl_PointSize = 1.0;")?;
+ write!(self.out, "{level}")?;
+ }
+ writeln!(self.out, "return;")?;
+ }
+ }
+ }
+ // This is one of the places were glsl adds to the syntax of C in this case the discard
+ // keyword which ceases all further processing in a fragment shader, it's called OpKill
+ // in spir-v that's why it's called `Statement::Kill`
+ Statement::Kill => writeln!(self.out, "{level}discard;")?,
+ Statement::Barrier(flags) => {
+ self.write_barrier(flags, level)?;
+ }
+ // Stores in glsl are just variable assignments written as `pointer = value;`
+ Statement::Store { pointer, value } => {
+ write!(self.out, "{level}")?;
+ self.write_expr(pointer, ctx)?;
+ write!(self.out, " = ")?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?
+ }
+ Statement::WorkGroupUniformLoad { pointer, result } => {
+ // GLSL doesn't have pointers, which means that this backend needs to ensure that
+ // the actual "loading" is happening between the two barriers.
+ // This is done in `Emit` by never emitting a variable name for pointer variables
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+
+ let result_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ write!(self.out, "{level}")?;
+ // Expressions cannot have side effects, so just writing the expression here is fine.
+ self.write_named_expr(pointer, result_name, result, ctx)?;
+
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ }
+ // Stores a value into an image.
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{level}")?;
+ self.write_image_store(ctx, image, coordinate, array_index, value)?
+ }
+ // A `Call` is written `name(arguments)` where `arguments` is a comma separated expressions list
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ let result = self.module.functions[function].result.as_ref().unwrap();
+ self.write_type(result.ty)?;
+ write!(self.out, " {name}")?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[result.ty].inner
+ {
+ self.write_array_size(base, size)?
+ }
+ write!(self.out, " = ")?;
+ self.named_expressions.insert(expr, name);
+ }
+ write!(self.out, "{}(", &self.names[&NameKey::Function(function)])?;
+ let arguments: Vec<_> = arguments
+ .iter()
+ .enumerate()
+ .filter_map(|(i, arg)| {
+ let arg_ty = self.module.functions[function].arguments[i].ty;
+ match self.module.types[arg_ty].inner {
+ TypeInner::Sampler { .. } => None,
+ _ => Some(*arg),
+ }
+ })
+ .collect();
+ self.write_slice(&arguments, |this, _, arg| this.write_expr(*arg, ctx))?;
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ let res_ty = ctx.resolve_type(result, &self.module.types);
+ self.write_value_type(res_ty)?;
+ write!(self.out, " {res_name} = ")?;
+ self.named_expressions.insert(result, res_name);
+
+ let fun_str = fun.to_glsl();
+ write!(self.out, "atomic{fun_str}(")?;
+ self.write_expr(pointer, ctx)?;
+ write!(self.out, ", ")?;
+ // handle the special cases
+ match *fun {
+ crate::AtomicFunction::Subtract => {
+ // we just wrote `InterlockedAdd`, so negate the argument
+ write!(self.out, "-")?;
+ }
+ crate::AtomicFunction::Exchange { compare: Some(_) } => {
+ return Err(Error::Custom(
+ "atomic CompareExchange is not implemented".to_string(),
+ ));
+ }
+ _ => {}
+ }
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ Statement::RayQuery { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Write a const expression.
+ ///
+ /// Write `expr`, a handle to an [`Expression`] in the current [`Module`]'s
+ /// constant expression arena, as GLSL expression.
+ ///
+ /// # Notes
+ /// Adds no newlines or leading/trailing whitespace
+ ///
+ /// [`Expression`]: crate::Expression
+ /// [`Module`]: crate::Module
+ fn write_const_expr(&mut self, expr: Handle<crate::Expression>) -> BackendResult {
+ self.write_possibly_const_expr(
+ expr,
+ &self.module.const_expressions,
+ |expr| &self.info[expr],
+ |writer, expr| writer.write_const_expr(expr),
+ )
+ }
+
+ /// Write [`Expression`] variants that can occur in both runtime and const expressions.
+ ///
+ /// Write `expr`, a handle to an [`Expression`] in the arena `expressions`,
+ /// as as GLSL expression. This must be one of the [`Expression`] variants
+ /// that is allowed to occur in constant expressions.
+ ///
+ /// Use `write_expression` to write subexpressions.
+ ///
+ /// This is the common code for `write_expr`, which handles arbitrary
+ /// runtime expressions, and `write_const_expr`, which only handles
+ /// const-expressions. Each of those callers passes itself (essentially) as
+ /// the `write_expression` callback, so that subexpressions are restricted
+ /// to the appropriate variants.
+ ///
+ /// # Notes
+ /// Adds no newlines or leading/trailing whitespace
+ ///
+ /// [`Expression`]: crate::Expression
+ fn write_possibly_const_expr<'w, I, E>(
+ &'w mut self,
+ expr: Handle<crate::Expression>,
+ expressions: &crate::Arena<crate::Expression>,
+ info: I,
+ write_expression: E,
+ ) -> BackendResult
+ where
+ I: Fn(Handle<crate::Expression>) -> &'w proc::TypeResolution,
+ E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
+ {
+ use crate::Expression;
+
+ match expressions[expr] {
+ Expression::Literal(literal) => {
+ match literal {
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero which is needed for a valid glsl float constant
+ crate::Literal::F64(value) => write!(self.out, "{:?}LF", value)?,
+ crate::Literal::F32(value) => write!(self.out, "{:?}", value)?,
+ // Unsigned integers need a `u` at the end
+ //
+ // While `core` doesn't necessarily need it, it's allowed and since `es` needs it we
+ // always write it as the extra branch wouldn't have any benefit in readability
+ crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
+ crate::Literal::I32(value) => write!(self.out, "{}", value)?,
+ crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
+ crate::Literal::I64(_) => {
+ return Err(Error::Custom("GLSL has no 64-bit integer type".into()));
+ }
+ crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
+ return Err(Error::Custom(
+ "Abstract types should not appear in IR presented to backends".into(),
+ ));
+ }
+ }
+ }
+ Expression::Constant(handle) => {
+ let constant = &self.module.constants[handle];
+ if constant.name.is_some() {
+ write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.write_const_expr(constant.init)?;
+ }
+ }
+ Expression::ZeroValue(ty) => {
+ self.write_zero_init_value(ty)?;
+ }
+ Expression::Compose { ty, ref components } => {
+ self.write_type(ty)?;
+
+ if let TypeInner::Array { base, size, .. } = self.module.types[ty].inner {
+ self.write_array_size(base, size)?;
+ }
+
+ write!(self.out, "(")?;
+ for (index, component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ write_expression(self, *component)?;
+ }
+ write!(self.out, ")")?
+ }
+ // `Splat` needs to actually write down a vector, it's not always inferred in GLSL.
+ Expression::Splat { size: _, value } => {
+ let resolved = info(expr).inner_with(&self.module.types);
+ self.write_value_type(resolved)?;
+ write!(self.out, "(")?;
+ write_expression(self, value)?;
+ write!(self.out, ")")?
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write expressions
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ fn write_expr(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+
+ match ctx.expressions[expr] {
+ Expression::Literal(_)
+ | Expression::Constant(_)
+ | Expression::ZeroValue(_)
+ | Expression::Compose { .. }
+ | Expression::Splat { .. } => {
+ self.write_possibly_const_expr(
+ expr,
+ ctx.expressions,
+ |expr| &ctx.info[expr].ty,
+ |writer, expr| writer.write_expr(expr, ctx),
+ )?;
+ }
+ // `Access` is applied to arrays, vectors and matrices and is written as indexing
+ Expression::Access { base, index } => {
+ self.write_expr(base, ctx)?;
+ write!(self.out, "[")?;
+ self.write_expr(index, ctx)?;
+ write!(self.out, "]")?
+ }
+ // `AccessIndex` is the same as `Access` except that the index is a constant and it can
+ // be applied to structs, in this case we need to find the name of the field at that
+ // index and write `base.field_name`
+ Expression::AccessIndex { base, index } => {
+ self.write_expr(base, ctx)?;
+
+ let base_ty_res = &ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&self.module.types);
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, space: _ } => {
+ resolved = &self.module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ match *resolved {
+ TypeInner::Vector { .. } => {
+ // Write vector access as a swizzle
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
+ TypeInner::Struct { .. } => {
+ // This will never panic in case the type is a `Struct`, this is not true
+ // for other types so we can only check while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ self.out,
+ ".{}",
+ &self.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
+ }
+ }
+ // `Swizzle` adds a few letters behind the dot.
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(vector, ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ // Function arguments are written as the argument name
+ Expression::FunctionArgument(pos) => {
+ write!(self.out, "{}", &self.names[&ctx.argument_key(pos)])?
+ }
+ // Global variables need some special work for their name but
+ // `get_global_name` does the work for us
+ Expression::GlobalVariable(handle) => {
+ let global = &self.module.global_variables[handle];
+ self.write_global_name(handle, global)?
+ }
+ // A local is written as it's name
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&ctx.name_key(handle)])?
+ }
+ // glsl has no pointers so there's no load operation, just write the pointer expression
+ Expression::Load { pointer } => self.write_expr(pointer, ctx)?,
+ // `ImageSample` is a bit complicated compared to the rest of the IR.
+ //
+ // First there are three variations depending whether the sample level is explicitly set,
+ // if it's automatic or it it's bias:
+ // `texture(image, coordinate)` - Automatic sample level
+ // `texture(image, coordinate, bias)` - Bias sample level
+ // `textureLod(image, coordinate, level)` - Zero or Exact sample level
+ //
+ // Furthermore if `depth_ref` is some we need to append it to the coordinate vector
+ Expression::ImageSample {
+ image,
+ sampler: _, //TODO?
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ let dim = match *ctx.resolve_type(image, &self.module.types) {
+ TypeInner::Image { dim, .. } => dim,
+ _ => unreachable!(),
+ };
+
+ if dim == crate::ImageDimension::Cube
+ && array_index.is_some()
+ && depth_ref.is_some()
+ {
+ match level {
+ crate::SampleLevel::Zero
+ | crate::SampleLevel::Exact(_)
+ | crate::SampleLevel::Gradient { .. }
+ | crate::SampleLevel::Bias(_) => {
+ return Err(Error::Custom(String::from(
+ "gsamplerCubeArrayShadow isn't supported in textureGrad, \
+ textureLod or texture with bias",
+ )))
+ }
+ crate::SampleLevel::Auto => {}
+ }
+ }
+
+ // textureLod on sampler2DArrayShadow and samplerCubeShadow does not exist in GLSL.
+ // To emulate this, we will have to use textureGrad with a constant gradient of 0.
+ let workaround_lod_array_shadow_as_grad = (array_index.is_some()
+ || dim == crate::ImageDimension::Cube)
+ && depth_ref.is_some()
+ && gather.is_none()
+ && !self
+ .options
+ .writer_flags
+ .contains(WriterFlags::TEXTURE_SHADOW_LOD);
+
+ //Write the function to be used depending on the sample level
+ let fun_name = match level {
+ crate::SampleLevel::Zero if gather.is_some() => "textureGather",
+ crate::SampleLevel::Auto | crate::SampleLevel::Bias(_) => "texture",
+ crate::SampleLevel::Zero | crate::SampleLevel::Exact(_) => {
+ if workaround_lod_array_shadow_as_grad {
+ "textureGrad"
+ } else {
+ "textureLod"
+ }
+ }
+ crate::SampleLevel::Gradient { .. } => "textureGrad",
+ };
+ let offset_name = match offset {
+ Some(_) => "Offset",
+ None => "",
+ };
+
+ write!(self.out, "{fun_name}{offset_name}(")?;
+
+ // Write the image that will be used
+ self.write_expr(image, ctx)?;
+ // The space here isn't required but it helps with readability
+ write!(self.out, ", ")?;
+
+ // We need to get the coordinates vector size to later build a vector that's `size + 1`
+ // if `depth_ref` is some, if it isn't a vector we panic as that's not a valid expression
+ let mut coord_dim = match *ctx.resolve_type(coordinate, &self.module.types) {
+ TypeInner::Vector { size, .. } => size as u8,
+ TypeInner::Scalar { .. } => 1,
+ _ => unreachable!(),
+ };
+
+ if array_index.is_some() {
+ coord_dim += 1;
+ }
+ let merge_depth_ref = depth_ref.is_some() && gather.is_none() && coord_dim < 4;
+ if merge_depth_ref {
+ coord_dim += 1;
+ }
+
+ let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es();
+ let is_vec = tex_1d_hack || coord_dim != 1;
+ // Compose a new texture coordinates vector
+ if is_vec {
+ write!(self.out, "vec{}(", coord_dim + tex_1d_hack as u8)?;
+ }
+ self.write_expr(coordinate, ctx)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0.0")?;
+ }
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ if merge_depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(depth_ref.unwrap(), ctx)?;
+ }
+ if is_vec {
+ write!(self.out, ")")?;
+ }
+
+ if let (Some(expr), false) = (depth_ref, merge_depth_ref) {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+
+ match level {
+ // Auto needs no more arguments
+ crate::SampleLevel::Auto => (),
+ // Zero needs level set to 0
+ crate::SampleLevel::Zero => {
+ if workaround_lod_array_shadow_as_grad {
+ let vec_dim = match dim {
+ crate::ImageDimension::Cube => 3,
+ _ => 2,
+ };
+ write!(self.out, ", vec{vec_dim}(0.0), vec{vec_dim}(0.0)")?;
+ } else if gather.is_none() {
+ write!(self.out, ", 0.0")?;
+ }
+ }
+ // Exact and bias require another argument
+ crate::SampleLevel::Exact(expr) => {
+ if workaround_lod_array_shadow_as_grad {
+ log::warn!("Unable to `textureLod` a shadow array, ignoring the LOD");
+ write!(self.out, ", vec2(0,0), vec2(0,0)")?;
+ } else {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ }
+ crate::SampleLevel::Bias(_) => {
+ // This needs to be done after the offset writing
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ // If we are using sampler2D to replace sampler1D, we also
+ // need to make sure to use vec2 gradients
+ if tex_1d_hack {
+ write!(self.out, ", vec2(")?;
+ self.write_expr(x, ctx)?;
+ write!(self.out, ", 0.0)")?;
+ write!(self.out, ", vec2(")?;
+ self.write_expr(y, ctx)?;
+ write!(self.out, ", 0.0)")?;
+ } else {
+ write!(self.out, ", ")?;
+ self.write_expr(x, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(y, ctx)?;
+ }
+ }
+ }
+
+ if let Some(constant) = offset {
+ write!(self.out, ", ")?;
+ if tex_1d_hack {
+ write!(self.out, "ivec2(")?;
+ }
+ self.write_const_expr(constant)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0)")?;
+ }
+ }
+
+ // Bias is always the last argument
+ if let crate::SampleLevel::Bias(expr) = level {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+
+ if let (Some(component), None) = (gather, depth_ref) {
+ write!(self.out, ", {}", component as usize)?;
+ }
+
+ // End the function
+ write!(self.out, ")")?
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => self.write_image_load(expr, ctx, image, coordinate, array_index, sample, level)?,
+ // Query translates into one of the:
+ // - textureSize/imageSize
+ // - textureQueryLevels
+ // - textureSamples/imageSamples
+ Expression::ImageQuery { image, query } => {
+ use crate::ImageClass;
+
+ // This will only panic if the module is invalid
+ let (dim, class) = match *ctx.resolve_type(image, &self.module.types) {
+ TypeInner::Image {
+ dim,
+ arrayed: _,
+ class,
+ } => (dim, class),
+ _ => unreachable!(),
+ };
+ let components = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 => 3,
+ crate::ImageDimension::Cube => 2,
+ };
+
+ if let crate::ImageQuery::Size { .. } = query {
+ match components {
+ 1 => write!(self.out, "uint(")?,
+ _ => write!(self.out, "uvec{components}(")?,
+ }
+ } else {
+ write!(self.out, "uint(")?;
+ }
+
+ match query {
+ crate::ImageQuery::Size { level } => {
+ match class {
+ ImageClass::Sampled { multi, .. } | ImageClass::Depth { multi } => {
+ write!(self.out, "textureSize(")?;
+ self.write_expr(image, ctx)?;
+ if let Some(expr) = level {
+ let cast_to_int = matches!(
+ *ctx.resolve_type(expr, &self.module.types),
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ })
+ );
+
+ write!(self.out, ", ")?;
+
+ if cast_to_int {
+ write!(self.out, "int(")?;
+ }
+
+ self.write_expr(expr, ctx)?;
+
+ if cast_to_int {
+ write!(self.out, ")")?;
+ }
+ } else if !multi {
+ // All textureSize calls requires an lod argument
+ // except for multisampled samplers
+ write!(self.out, ", 0")?;
+ }
+ }
+ ImageClass::Storage { .. } => {
+ write!(self.out, "imageSize(")?;
+ self.write_expr(image, ctx)?;
+ }
+ }
+ write!(self.out, ")")?;
+ if components != 1 || self.options.version.is_es() {
+ write!(self.out, ".{}", &"xyz"[..components])?;
+ }
+ }
+ crate::ImageQuery::NumLevels => {
+ write!(self.out, "textureQueryLevels(",)?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ")",)?;
+ }
+ crate::ImageQuery::NumLayers => {
+ let fun_name = match class {
+ ImageClass::Sampled { .. } | ImageClass::Depth { .. } => "textureSize",
+ ImageClass::Storage { .. } => "imageSize",
+ };
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(image, ctx)?;
+ // All textureSize calls requires an lod argument
+ // except for multisampled samplers
+ if class.is_multisampled() {
+ write!(self.out, ", 0")?;
+ }
+ write!(self.out, ")")?;
+ if components != 1 || self.options.version.is_es() {
+ write!(self.out, ".{}", back::COMPONENTS[components])?;
+ }
+ }
+ crate::ImageQuery::NumSamples => {
+ let fun_name = match class {
+ ImageClass::Sampled { .. } | ImageClass::Depth { .. } => {
+ "textureSamples"
+ }
+ ImageClass::Storage { .. } => "imageSamples",
+ };
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ")",)?;
+ }
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::Unary { op, expr } => {
+ let operator_or_fn = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::LogicalNot => {
+ match *ctx.resolve_type(expr, &self.module.types) {
+ TypeInner::Vector { .. } => "not",
+ _ => "!",
+ }
+ }
+ crate::UnaryOperator::BitwiseNot => "~",
+ };
+ write!(self.out, "{operator_or_fn}(")?;
+
+ self.write_expr(expr, ctx)?;
+
+ write!(self.out, ")")?
+ }
+ // `Binary` we just write `left op right`, except when dealing with
+ // comparison operations on vectors as they are implemented with
+ // builtin functions.
+ // Once again we wrap everything in parentheses to avoid precedence issues
+ Expression::Binary {
+ mut op,
+ left,
+ right,
+ } => {
+ // Holds `Some(function_name)` if the binary operation is
+ // implemented as a function call
+ use crate::{BinaryOperator as Bo, ScalarKind as Sk, TypeInner as Ti};
+
+ let left_inner = ctx.resolve_type(left, &self.module.types);
+ let right_inner = ctx.resolve_type(right, &self.module.types);
+
+ let function = match (left_inner, right_inner) {
+ (&Ti::Vector { scalar, .. }, &Ti::Vector { .. }) => match op {
+ Bo::Less
+ | Bo::LessEqual
+ | Bo::Greater
+ | Bo::GreaterEqual
+ | Bo::Equal
+ | Bo::NotEqual => BinaryOperation::VectorCompare,
+ Bo::Modulo if scalar.kind == Sk::Float => BinaryOperation::Modulo,
+ Bo::And if scalar.kind == Sk::Bool => {
+ op = crate::BinaryOperator::LogicalAnd;
+ BinaryOperation::VectorComponentWise
+ }
+ Bo::InclusiveOr if scalar.kind == Sk::Bool => {
+ op = crate::BinaryOperator::LogicalOr;
+ BinaryOperation::VectorComponentWise
+ }
+ _ => BinaryOperation::Other,
+ },
+ _ => match (left_inner.scalar_kind(), right_inner.scalar_kind()) {
+ (Some(Sk::Float), _) | (_, Some(Sk::Float)) => match op {
+ Bo::Modulo => BinaryOperation::Modulo,
+ _ => BinaryOperation::Other,
+ },
+ (Some(Sk::Bool), Some(Sk::Bool)) => match op {
+ Bo::InclusiveOr => {
+ op = crate::BinaryOperator::LogicalOr;
+ BinaryOperation::Other
+ }
+ Bo::And => {
+ op = crate::BinaryOperator::LogicalAnd;
+ BinaryOperation::Other
+ }
+ _ => BinaryOperation::Other,
+ },
+ _ => BinaryOperation::Other,
+ },
+ };
+
+ match function {
+ BinaryOperation::VectorCompare => {
+ let op_str = match op {
+ Bo::Less => "lessThan(",
+ Bo::LessEqual => "lessThanEqual(",
+ Bo::Greater => "greaterThan(",
+ Bo::GreaterEqual => "greaterThanEqual(",
+ Bo::Equal => "equal(",
+ Bo::NotEqual => "notEqual(",
+ _ => unreachable!(),
+ };
+ write!(self.out, "{op_str}")?;
+ self.write_expr(left, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, ")")?;
+ }
+ BinaryOperation::VectorComponentWise => {
+ self.write_value_type(left_inner)?;
+ write!(self.out, "(")?;
+
+ let size = match *left_inner {
+ Ti::Vector { size, .. } => size,
+ _ => unreachable!(),
+ };
+
+ for i in 0..size as usize {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+
+ self.write_expr(left, ctx)?;
+ write!(self.out, ".{}", back::COMPONENTS[i])?;
+
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+
+ self.write_expr(right, ctx)?;
+ write!(self.out, ".{}", back::COMPONENTS[i])?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+ BinaryOperation::Modulo => {
+ write!(self.out, "(")?;
+
+ // write `e1 - e2 * trunc(e1 / e2)`
+ self.write_expr(left, ctx)?;
+ write!(self.out, " - ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, " * ")?;
+ write!(self.out, "trunc(")?;
+ self.write_expr(left, ctx)?;
+ write!(self.out, " / ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, ")")?;
+
+ write!(self.out, ")")?;
+ }
+ BinaryOperation::Other => {
+ write!(self.out, "(")?;
+
+ self.write_expr(left, ctx)?;
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+ self.write_expr(right, ctx)?;
+
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ // `Select` is written as `condition ? accept : reject`
+ // We wrap everything in parentheses to avoid precedence issues
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ let cond_ty = ctx.resolve_type(condition, &self.module.types);
+ let vec_select = if let TypeInner::Vector { .. } = *cond_ty {
+ true
+ } else {
+ false
+ };
+
+ // TODO: Boolean mix on desktop required GL_EXT_shader_integer_mix
+ if vec_select {
+ // Glsl defines that for mix when the condition is a boolean the first element
+ // is picked if condition is false and the second if condition is true
+ write!(self.out, "mix(")?;
+ self.write_expr(reject, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(accept, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(condition, ctx)?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(condition, ctx)?;
+ write!(self.out, " ? ")?;
+ self.write_expr(accept, ctx)?;
+ write!(self.out, " : ")?;
+ self.write_expr(reject, ctx)?;
+ }
+
+ write!(self.out, ")")?
+ }
+ // `Derivative` is a function call to a glsl provided function
+ Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ let fun_name = if self.options.version.supports_derivative_control() {
+ match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => "dFdxCoarse",
+ (Axis::X, Ctrl::Fine) => "dFdxFine",
+ (Axis::X, Ctrl::None) => "dFdx",
+ (Axis::Y, Ctrl::Coarse) => "dFdyCoarse",
+ (Axis::Y, Ctrl::Fine) => "dFdyFine",
+ (Axis::Y, Ctrl::None) => "dFdy",
+ (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
+ (Axis::Width, Ctrl::Fine) => "fwidthFine",
+ (Axis::Width, Ctrl::None) => "fwidth",
+ }
+ } else {
+ match axis {
+ Axis::X => "dFdx",
+ Axis::Y => "dFdy",
+ Axis::Width => "fwidth",
+ }
+ };
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?
+ }
+ // `Relational` is a normal function call to some glsl provided functions
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_name = match fun {
+ Rf::IsInf => "isinf",
+ Rf::IsNan => "isnan",
+ Rf::All => "all",
+ Rf::Any => "any",
+ };
+ write!(self.out, "{fun_name}(")?;
+
+ self.write_expr(argument, ctx)?;
+
+ write!(self.out, ")")?
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let fun_name = match fun {
+ // comparison
+ Mf::Abs => "abs",
+ Mf::Min => "min",
+ Mf::Max => "max",
+ Mf::Clamp => "clamp",
+ Mf::Saturate => {
+ write!(self.out, "clamp(")?;
+
+ self.write_expr(arg, ctx)?;
+
+ match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Vector { size, .. } => write!(
+ self.out,
+ ", vec{}(0.0), vec{0}(1.0)",
+ back::vector_size_str(size)
+ )?,
+ _ => write!(self.out, ", 0.0, 1.0")?,
+ }
+
+ write!(self.out, ")")?;
+
+ return Ok(());
+ }
+ // trigonometry
+ Mf::Cos => "cos",
+ Mf::Cosh => "cosh",
+ Mf::Sin => "sin",
+ Mf::Sinh => "sinh",
+ Mf::Tan => "tan",
+ Mf::Tanh => "tanh",
+ Mf::Acos => "acos",
+ Mf::Asin => "asin",
+ Mf::Atan => "atan",
+ Mf::Asinh => "asinh",
+ Mf::Acosh => "acosh",
+ Mf::Atanh => "atanh",
+ Mf::Radians => "radians",
+ Mf::Degrees => "degrees",
+ // glsl doesn't have atan2 function
+ // use two-argument variation of the atan function
+ Mf::Atan2 => "atan",
+ // decomposition
+ Mf::Ceil => "ceil",
+ Mf::Floor => "floor",
+ Mf::Round => "roundEven",
+ Mf::Fract => "fract",
+ Mf::Trunc => "trunc",
+ Mf::Modf => MODF_FUNCTION,
+ Mf::Frexp => FREXP_FUNCTION,
+ Mf::Ldexp => "ldexp",
+ // exponent
+ Mf::Exp => "exp",
+ Mf::Exp2 => "exp2",
+ Mf::Log => "log",
+ Mf::Log2 => "log2",
+ Mf::Pow => "pow",
+ // geometry
+ Mf::Dot => match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Vector {
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Float,
+ ..
+ },
+ ..
+ } => "dot",
+ crate::TypeInner::Vector { size, .. } => {
+ return self.write_dot_product(arg, arg1.unwrap(), size as usize, ctx)
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => "outerProduct",
+ Mf::Cross => "cross",
+ Mf::Distance => "distance",
+ Mf::Length => "length",
+ Mf::Normalize => "normalize",
+ Mf::FaceForward => "faceforward",
+ Mf::Reflect => "reflect",
+ Mf::Refract => "refract",
+ // computational
+ Mf::Sign => "sign",
+ Mf::Fma => {
+ if self.options.version.supports_fma_function() {
+ // Use the fma function when available
+ "fma"
+ } else {
+ // No fma support. Transform the function call into an arithmetic expression
+ write!(self.out, "(")?;
+
+ self.write_expr(arg, ctx)?;
+ write!(self.out, " * ")?;
+
+ let arg1 =
+ arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?;
+ self.write_expr(arg1, ctx)?;
+ write!(self.out, " + ")?;
+
+ let arg2 =
+ arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?;
+ self.write_expr(arg2, ctx)?;
+ write!(self.out, ")")?;
+
+ return Ok(());
+ }
+ }
+ Mf::Mix => "mix",
+ Mf::Step => "step",
+ Mf::SmoothStep => "smoothstep",
+ Mf::Sqrt => "sqrt",
+ Mf::InverseSqrt => "inversesqrt",
+ Mf::Inverse => "inverse",
+ Mf::Transpose => "transpose",
+ Mf::Determinant => "determinant",
+ // bits
+ Mf::CountTrailingZeros => {
+ match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Vector { size, scalar, .. } => {
+ let s = back::vector_size_str(size);
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "min(uvec{s}(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), uvec{s}(32u))")?;
+ } else {
+ write!(self.out, "ivec{s}(min(uvec{s}(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), uvec{s}(32u)))")?;
+ }
+ }
+ crate::TypeInner::Scalar(scalar) => {
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "min(uint(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), 32u)")?;
+ } else {
+ write!(self.out, "int(min(uint(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), 32u))")?;
+ }
+ }
+ _ => unreachable!(),
+ };
+ return Ok(());
+ }
+ Mf::CountLeadingZeros => {
+ if self.options.version.supports_integer_functions() {
+ match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Vector { size, scalar } => {
+ let s = back::vector_size_str(size);
+
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "uvec{s}(ivec{s}(31) - findMSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "mix(ivec{s}(31) - findMSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, "), ivec{s}(0), lessThan(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ", ivec{s}(0)))")?;
+ }
+ }
+ crate::TypeInner::Scalar(scalar) => {
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "uint(31 - findMSB(")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, " < 0 ? 0 : 31 - findMSB(")?;
+ }
+
+ self.write_expr(arg, ctx)?;
+ write!(self.out, "))")?;
+ }
+ _ => unreachable!(),
+ };
+ } else {
+ match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Vector { size, scalar } => {
+ let s = back::vector_size_str(size);
+
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "uvec{s}(")?;
+ write!(self.out, "vec{s}(31.0) - floor(log2(vec{s}(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5)))")?;
+ } else {
+ write!(self.out, "ivec{s}(")?;
+ write!(self.out, "mix(vec{s}(31.0) - floor(log2(vec{s}(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5)), ")?;
+ write!(self.out, "vec{s}(0.0), lessThan(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ", ivec{s}(0u))))")?;
+ }
+ }
+ crate::TypeInner::Scalar(scalar) => {
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "uint(31.0 - floor(log2(float(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5)))")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, " < 0 ? 0 : int(")?;
+ write!(self.out, "31.0 - floor(log2(float(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5))))")?;
+ }
+ }
+ _ => unreachable!(),
+ };
+ }
+
+ return Ok(());
+ }
+ Mf::CountOneBits => "bitCount",
+ Mf::ReverseBits => "bitfieldReverse",
+ Mf::ExtractBits => "bitfieldExtract",
+ Mf::InsertBits => "bitfieldInsert",
+ Mf::FindLsb => "findLSB",
+ Mf::FindMsb => "findMSB",
+ // data packing
+ Mf::Pack4x8snorm => "packSnorm4x8",
+ Mf::Pack4x8unorm => "packUnorm4x8",
+ Mf::Pack2x16snorm => "packSnorm2x16",
+ Mf::Pack2x16unorm => "packUnorm2x16",
+ Mf::Pack2x16float => "packHalf2x16",
+ // data unpacking
+ Mf::Unpack4x8snorm => "unpackSnorm4x8",
+ Mf::Unpack4x8unorm => "unpackUnorm4x8",
+ Mf::Unpack2x16snorm => "unpackSnorm2x16",
+ Mf::Unpack2x16unorm => "unpackUnorm2x16",
+ Mf::Unpack2x16float => "unpackHalf2x16",
+ };
+
+ let extract_bits = fun == Mf::ExtractBits;
+ let insert_bits = fun == Mf::InsertBits;
+
+ // Some GLSL functions always return signed integers (like findMSB),
+ // so they need to be cast to uint if the argument is also an uint.
+ let ret_might_need_int_to_uint =
+ matches!(fun, Mf::FindLsb | Mf::FindMsb | Mf::CountOneBits | Mf::Abs);
+
+ // Some GLSL functions only accept signed integers (like abs),
+ // so they need their argument cast from uint to int.
+ let arg_might_need_uint_to_int = matches!(fun, Mf::Abs);
+
+ // Check if the argument is an unsigned integer and return the vector size
+ // in case it's a vector
+ let maybe_uint_size = match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ }) => Some(None),
+ crate::TypeInner::Vector {
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ },
+ size,
+ } => Some(Some(size)),
+ _ => None,
+ };
+
+ // Cast to uint if the function needs it
+ if ret_might_need_int_to_uint {
+ if let Some(maybe_size) = maybe_uint_size {
+ match maybe_size {
+ Some(size) => write!(self.out, "uvec{}(", size as u8)?,
+ None => write!(self.out, "uint(")?,
+ }
+ }
+ }
+
+ write!(self.out, "{fun_name}(")?;
+
+ // Cast to int if the function needs it
+ if arg_might_need_uint_to_int {
+ if let Some(maybe_size) = maybe_uint_size {
+ match maybe_size {
+ Some(size) => write!(self.out, "ivec{}(", size as u8)?,
+ None => write!(self.out, "int(")?,
+ }
+ }
+ }
+
+ self.write_expr(arg, ctx)?;
+
+ // Close the cast from uint to int
+ if arg_might_need_uint_to_int && maybe_uint_size.is_some() {
+ write!(self.out, ")")?
+ }
+
+ if let Some(arg) = arg1 {
+ write!(self.out, ", ")?;
+ if extract_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ if let Some(arg) = arg2 {
+ write!(self.out, ", ")?;
+ if extract_bits || insert_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ if let Some(arg) = arg3 {
+ write!(self.out, ", ")?;
+ if insert_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ write!(self.out, ")")?;
+
+ // Close the cast from int to uint
+ if ret_might_need_int_to_uint && maybe_uint_size.is_some() {
+ write!(self.out, ")")?
+ }
+ }
+ // `As` is always a call.
+ // If `convert` is true the function name is the type
+ // Else the function name is one of the glsl provided bitcast functions
+ Expression::As {
+ expr,
+ kind: target_kind,
+ convert,
+ } => {
+ let inner = ctx.resolve_type(expr, &self.module.types);
+ match convert {
+ Some(width) => {
+ // this is similar to `write_type`, but with the target kind
+ let scalar = glsl_scalar(crate::Scalar {
+ kind: target_kind,
+ width,
+ })?;
+ match *inner {
+ TypeInner::Matrix { columns, rows, .. } => write!(
+ self.out,
+ "{}mat{}x{}",
+ scalar.prefix, columns as u8, rows as u8
+ )?,
+ TypeInner::Vector { size, .. } => {
+ write!(self.out, "{}vec{}", scalar.prefix, size as u8)?
+ }
+ _ => write!(self.out, "{}", scalar.full)?,
+ }
+
+ write!(self.out, "(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?
+ }
+ None => {
+ use crate::ScalarKind as Sk;
+
+ let target_vector_type = match *inner {
+ TypeInner::Vector { size, scalar } => Some(TypeInner::Vector {
+ size,
+ scalar: crate::Scalar {
+ kind: target_kind,
+ width: scalar.width,
+ },
+ }),
+ _ => None,
+ };
+
+ let source_kind = inner.scalar_kind().unwrap();
+
+ match (source_kind, target_kind, target_vector_type) {
+ // No conversion needed
+ (Sk::Sint, Sk::Sint, _)
+ | (Sk::Uint, Sk::Uint, _)
+ | (Sk::Float, Sk::Float, _)
+ | (Sk::Bool, Sk::Bool, _) => {
+ self.write_expr(expr, ctx)?;
+ return Ok(());
+ }
+
+ // Cast to/from floats
+ (Sk::Float, Sk::Sint, _) => write!(self.out, "floatBitsToInt")?,
+ (Sk::Float, Sk::Uint, _) => write!(self.out, "floatBitsToUint")?,
+ (Sk::Sint, Sk::Float, _) => write!(self.out, "intBitsToFloat")?,
+ (Sk::Uint, Sk::Float, _) => write!(self.out, "uintBitsToFloat")?,
+
+ // Cast between vector types
+ (_, _, Some(vector)) => {
+ self.write_value_type(&vector)?;
+ }
+
+ // There is no way to bitcast between Uint/Sint in glsl. Use constructor conversion
+ (Sk::Uint | Sk::Bool, Sk::Sint, None) => write!(self.out, "int")?,
+ (Sk::Sint | Sk::Bool, Sk::Uint, None) => write!(self.out, "uint")?,
+ (Sk::Bool, Sk::Float, None) => write!(self.out, "float")?,
+ (Sk::Sint | Sk::Uint | Sk::Float, Sk::Bool, None) => {
+ write!(self.out, "bool")?
+ }
+
+ (Sk::AbstractInt | Sk::AbstractFloat, _, _)
+ | (_, Sk::AbstractInt | Sk::AbstractFloat, _) => unreachable!(),
+ };
+
+ write!(self.out, "(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ // These expressions never show up in `Emit`.
+ Expression::CallResult(_)
+ | Expression::AtomicResult { .. }
+ | Expression::RayQueryProceedResult
+ | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
+ // `ArrayLength` is written as `expr.length()` and we convert it to a uint
+ Expression::ArrayLength(expr) => {
+ write!(self.out, "uint(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ".length())")?
+ }
+ // not supported yet
+ Expression::RayQueryGetIntersection { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper function to write the local holding the clamped lod
+ fn write_clamped_lod(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ expr: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ level_expr: Handle<crate::Expression>,
+ ) -> Result<(), Error> {
+ // Define our local and start a call to `clamp`
+ write!(
+ self.out,
+ "int {}{}{} = clamp(",
+ back::BAKE_PREFIX,
+ expr.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ // Write the lod that will be clamped
+ self.write_expr(level_expr, ctx)?;
+ // Set the min value to 0 and start a call to `textureQueryLevels` to get
+ // the maximum value
+ write!(self.out, ", 0, textureQueryLevels(")?;
+ // Write the target image as an argument to `textureQueryLevels`
+ self.write_expr(image, ctx)?;
+ // Close the call to `textureQueryLevels` subtract 1 from it since
+ // the lod argument is 0 based, close the `clamp` call and end the
+ // local declaration statement.
+ writeln!(self.out, ") - 1);")?;
+
+ Ok(())
+ }
+
+ // Helper method used to retrieve how many elements a coordinate vector
+ // for the images operations need.
+ fn get_coordinate_vector_size(&self, dim: crate::ImageDimension, arrayed: bool) -> u8 {
+ // openGL es doesn't have 1D images so we need workaround it
+ let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es();
+ // Get how many components the coordinate vector needs for the dimensions only
+ let tex_coord_size = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 => 3,
+ crate::ImageDimension::Cube => 2,
+ };
+ // Calculate the true size of the coordinate vector by adding 1 for arrayed images
+ // and another 1 if we need to workaround 1D images by making them 2D
+ tex_coord_size + tex_1d_hack as u8 + arrayed as u8
+ }
+
+ /// Helper method to write the coordinate vector for image operations
+ fn write_texture_coord(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ vector_size: u8,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ // Emulate 1D images as 2D for profiles that don't support it (glsl es)
+ tex_1d_hack: bool,
+ ) -> Result<(), Error> {
+ match array_index {
+ // If the image needs an array indice we need to add it to the end of our
+ // coordinate vector, to do so we will use the `ivec(ivec, scalar)`
+ // constructor notation (NOTE: the inner `ivec` can also be a scalar, this
+ // is important for 1D arrayed images).
+ Some(layer_expr) => {
+ write!(self.out, "ivec{vector_size}(")?;
+ self.write_expr(coordinate, ctx)?;
+ write!(self.out, ", ")?;
+ // If we are replacing sampler1D with sampler2D we also need
+ // to add another zero to the coordinates vector for the y component
+ if tex_1d_hack {
+ write!(self.out, "0, ")?;
+ }
+ self.write_expr(layer_expr, ctx)?;
+ write!(self.out, ")")?;
+ }
+ // Otherwise write just the expression (and the 1D hack if needed)
+ None => {
+ let uvec_size = match *ctx.resolve_type(coordinate, &self.module.types) {
+ TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ }) => Some(None),
+ TypeInner::Vector {
+ size,
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ },
+ } => Some(Some(size as u32)),
+ _ => None,
+ };
+ if tex_1d_hack {
+ write!(self.out, "ivec2(")?;
+ } else if uvec_size.is_some() {
+ match uvec_size {
+ Some(None) => write!(self.out, "int(")?,
+ Some(Some(size)) => write!(self.out, "ivec{size}(")?,
+ _ => {}
+ }
+ }
+ self.write_expr(coordinate, ctx)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0)")?;
+ } else if uvec_size.is_some() {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write the `ImageStore` statement
+ fn write_image_store(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ value: Handle<crate::Expression>,
+ ) -> Result<(), Error> {
+ use crate::ImageDimension as IDim;
+
+ // NOTE: openGL requires that `imageStore`s have no effets when the texel is invalid
+ // so we don't need to generate bounds checks (OpenGL 4.2 Core §3.9.20)
+
+ // This will only panic if the module is invalid
+ let dim = match *ctx.resolve_type(image, &self.module.types) {
+ TypeInner::Image { dim, .. } => dim,
+ _ => unreachable!(),
+ };
+
+ // Begin our call to `imageStore`
+ write!(self.out, "imageStore(")?;
+ self.write_expr(image, ctx)?;
+ // Separate the image argument from the coordinates
+ write!(self.out, ", ")?;
+
+ // openGL es doesn't have 1D images so we need workaround it
+ let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es();
+ // Write the coordinate vector
+ self.write_texture_coord(
+ ctx,
+ // Get the size of the coordinate vector
+ self.get_coordinate_vector_size(dim, array_index.is_some()),
+ coordinate,
+ array_index,
+ tex_1d_hack,
+ )?;
+
+ // Separate the coordinate from the value to write and write the expression
+ // of the value to write.
+ write!(self.out, ", ")?;
+ self.write_expr(value, ctx)?;
+ // End the call to `imageStore` and the statement.
+ writeln!(self.out, ");")?;
+
+ Ok(())
+ }
+
+ /// Helper method for writing an `ImageLoad` expression.
+ #[allow(clippy::too_many_arguments)]
+ fn write_image_load(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ level: Option<Handle<crate::Expression>>,
+ ) -> Result<(), Error> {
+ use crate::ImageDimension as IDim;
+
+ // `ImageLoad` is a bit complicated.
+ // There are two functions one for sampled
+ // images another for storage images, the former uses `texelFetch` and the
+ // latter uses `imageLoad`.
+ //
+ // Furthermore we have `level` which is always `Some` for sampled images
+ // and `None` for storage images, so we end up with two functions:
+ // - `texelFetch(image, coordinate, level)` for sampled images
+ // - `imageLoad(image, coordinate)` for storage images
+ //
+ // Finally we also have to consider bounds checking, for storage images
+ // this is easy since openGL requires that invalid texels always return
+ // 0, for sampled images we need to either verify that all arguments are
+ // in bounds (`ReadZeroSkipWrite`) or make them a valid texel (`Restrict`).
+
+ // This will only panic if the module is invalid
+ let (dim, class) = match *ctx.resolve_type(image, &self.module.types) {
+ TypeInner::Image {
+ dim,
+ arrayed: _,
+ class,
+ } => (dim, class),
+ _ => unreachable!(),
+ };
+
+ // Get the name of the function to be used for the load operation
+ // and the policy to be used with it.
+ let (fun_name, policy) = match class {
+ // Sampled images inherit the policy from the user passed policies
+ crate::ImageClass::Sampled { .. } => ("texelFetch", self.policies.image_load),
+ crate::ImageClass::Storage { .. } => {
+ // OpenGL ES 3.1 mentions in Chapter "8.22 Texture Image Loads and Stores" that:
+ // "Invalid image loads will return a vector where the value of R, G, and B components
+ // is 0 and the value of the A component is undefined."
+ //
+ // OpenGL 4.2 Core mentions in Chapter "3.9.20 Texture Image Loads and Stores" that:
+ // "Invalid image loads will return zero."
+ //
+ // So, we only inject bounds checks for ES
+ let policy = if self.options.version.is_es() {
+ self.policies.image_load
+ } else {
+ proc::BoundsCheckPolicy::Unchecked
+ };
+ ("imageLoad", policy)
+ }
+ // TODO: Is there even a function for this?
+ crate::ImageClass::Depth { multi: _ } => {
+ return Err(Error::Custom(
+ "WGSL `textureLoad` from depth textures is not supported in GLSL".to_string(),
+ ))
+ }
+ };
+
+ // openGL es doesn't have 1D images so we need workaround it
+ let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es();
+ // Get the size of the coordinate vector
+ let vector_size = self.get_coordinate_vector_size(dim, array_index.is_some());
+
+ if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy {
+ // To write the bounds checks for `ReadZeroSkipWrite` we will use a
+ // ternary operator since we are in the middle of an expression and
+ // need to return a value.
+ //
+ // NOTE: glsl does short circuit when evaluating logical
+ // expressions so we can be sure that after we test a
+ // condition it will be true for the next ones
+
+ // Write parentheses around the ternary operator to prevent problems with
+ // expressions emitted before or after it having more precedence
+ write!(self.out, "(",)?;
+
+ // The lod check needs to precede the size check since we need
+ // to use the lod to get the size of the image at that level.
+ if let Some(level_expr) = level {
+ self.write_expr(level_expr, ctx)?;
+ write!(self.out, " < textureQueryLevels(",)?;
+ self.write_expr(image, ctx)?;
+ // Chain the next check
+ write!(self.out, ") && ")?;
+ }
+
+ // Check that the sample arguments doesn't exceed the number of samples
+ if let Some(sample_expr) = sample {
+ self.write_expr(sample_expr, ctx)?;
+ write!(self.out, " < textureSamples(",)?;
+ self.write_expr(image, ctx)?;
+ // Chain the next check
+ write!(self.out, ") && ")?;
+ }
+
+ // We now need to write the size checks for the coordinates and array index
+ // first we write the comparison function in case the image is 1D non arrayed
+ // (and no 1D to 2D hack was needed) we are comparing scalars so the less than
+ // operator will suffice, but otherwise we'll be comparing two vectors so we'll
+ // need to use the `lessThan` function but it returns a vector of booleans (one
+ // for each comparison) so we need to fold it all in one scalar boolean, since
+ // we want all comparisons to pass we use the `all` function which will only
+ // return `true` if all the elements of the boolean vector are also `true`.
+ //
+ // So we'll end with one of the following forms
+ // - `coord < textureSize(image, lod)` for 1D images
+ // - `all(lessThan(coord, textureSize(image, lod)))` for normal images
+ // - `all(lessThan(ivec(coord, array_index), textureSize(image, lod)))`
+ // for arrayed images
+ // - `all(lessThan(coord, textureSize(image)))` for multi sampled images
+
+ if vector_size != 1 {
+ write!(self.out, "all(lessThan(")?;
+ }
+
+ // Write the coordinate vector
+ self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?;
+
+ if vector_size != 1 {
+ // If we used the `lessThan` function we need to separate the
+ // coordinates from the image size.
+ write!(self.out, ", ")?;
+ } else {
+ // If we didn't use it (ie. 1D images) we perform the comparison
+ // using the less than operator.
+ write!(self.out, " < ")?;
+ }
+
+ // Call `textureSize` to get our image size
+ write!(self.out, "textureSize(")?;
+ self.write_expr(image, ctx)?;
+ // `textureSize` uses the lod as a second argument for mipmapped images
+ if let Some(level_expr) = level {
+ // Separate the image from the lod
+ write!(self.out, ", ")?;
+ self.write_expr(level_expr, ctx)?;
+ }
+ // Close the `textureSize` call
+ write!(self.out, ")")?;
+
+ if vector_size != 1 {
+ // Close the `all` and `lessThan` calls
+ write!(self.out, "))")?;
+ }
+
+ // Finally end the condition part of the ternary operator
+ write!(self.out, " ? ")?;
+ }
+
+ // Begin the call to the function used to load the texel
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ", ")?;
+
+ // If we are using `Restrict` bounds checking we need to pass valid texel
+ // coordinates, to do so we use the `clamp` function to get a value between
+ // 0 and the image size - 1 (indexing begins at 0)
+ if let proc::BoundsCheckPolicy::Restrict = policy {
+ write!(self.out, "clamp(")?;
+ }
+
+ // Write the coordinate vector
+ self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?;
+
+ // If we are using `Restrict` bounds checking we need to write the rest of the
+ // clamp we initiated before writing the coordinates.
+ if let proc::BoundsCheckPolicy::Restrict = policy {
+ // Write the min value 0
+ if vector_size == 1 {
+ write!(self.out, ", 0")?;
+ } else {
+ write!(self.out, ", ivec{vector_size}(0)")?;
+ }
+ // Start the `textureSize` call to use as the max value.
+ write!(self.out, ", textureSize(")?;
+ self.write_expr(image, ctx)?;
+ // If the image is mipmapped we need to add the lod argument to the
+ // `textureSize` call, but this needs to be the clamped lod, this should
+ // have been generated earlier and put in a local.
+ if class.is_mipmapped() {
+ write!(
+ self.out,
+ ", {}{}{}",
+ back::BAKE_PREFIX,
+ handle.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ }
+ // Close the `textureSize` call
+ write!(self.out, ")")?;
+
+ // Subtract 1 from the `textureSize` call since the coordinates are zero based.
+ if vector_size == 1 {
+ write!(self.out, " - 1")?;
+ } else {
+ write!(self.out, " - ivec{vector_size}(1)")?;
+ }
+
+ // Close the `clamp` call
+ write!(self.out, ")")?;
+
+ // Add the clamped lod (if present) as the second argument to the
+ // image load function.
+ if level.is_some() {
+ write!(
+ self.out,
+ ", {}{}{}",
+ back::BAKE_PREFIX,
+ handle.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ }
+
+ // If a sample argument is needed we need to clamp it between 0 and
+ // the number of samples the image has.
+ if let Some(sample_expr) = sample {
+ write!(self.out, ", clamp(")?;
+ self.write_expr(sample_expr, ctx)?;
+ // Set the min value to 0 and start the call to `textureSamples`
+ write!(self.out, ", 0, textureSamples(")?;
+ self.write_expr(image, ctx)?;
+ // Close the `textureSamples` call, subtract 1 from it since the sample
+ // argument is zero based, and close the `clamp` call
+ writeln!(self.out, ") - 1)")?;
+ }
+ } else if let Some(sample_or_level) = sample.or(level) {
+ // If no bounds checking is need just add the sample or level argument
+ // after the coordinates
+ write!(self.out, ", ")?;
+ self.write_expr(sample_or_level, ctx)?;
+ }
+
+ // Close the image load function.
+ write!(self.out, ")")?;
+
+ // If we were using the `ReadZeroSkipWrite` policy we need to end the first branch
+ // (which is taken if the condition is `true`) with a colon (`:`) and write the
+ // second branch which is just a 0 value.
+ if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy {
+ // Get the kind of the output value.
+ let kind = match class {
+ // Only sampled images can reach here since storage images
+ // don't need bounds checks and depth images aren't implemented
+ crate::ImageClass::Sampled { kind, .. } => kind,
+ _ => unreachable!(),
+ };
+
+ // End the first branch
+ write!(self.out, " : ")?;
+ // Write the 0 value
+ write!(
+ self.out,
+ "{}vec4(",
+ glsl_scalar(crate::Scalar { kind, width: 4 })?.prefix,
+ )?;
+ self.write_zero_init_scalar(kind)?;
+ // Close the zero value constructor
+ write!(self.out, ")")?;
+ // Close the parentheses surrounding our ternary
+ write!(self.out, ")")?;
+ }
+
+ Ok(())
+ }
+
+ fn write_named_expr(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ name: String,
+ // The expression which is being named.
+ // Generally, this is the same as handle, except in WorkGroupUniformLoad
+ named: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ match ctx.info[named].ty {
+ proc::TypeResolution::Handle(ty_handle) => match self.module.types[ty_handle].inner {
+ TypeInner::Struct { .. } => {
+ let ty_name = &self.names[&NameKey::Type(ty_handle)];
+ write!(self.out, "{ty_name}")?;
+ }
+ _ => {
+ self.write_type(ty_handle)?;
+ }
+ },
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(inner)?;
+ }
+ }
+
+ let resolved = ctx.resolve_type(named, &self.module.types);
+
+ write!(self.out, " {name}")?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_expr(handle, ctx)?;
+ writeln!(self.out, ";")?;
+ self.named_expressions.insert(named, name);
+
+ Ok(())
+ }
+
+ /// Helper function that write string with default zero initialization for supported types
+ fn write_zero_init_value(&mut self, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &self.module.types[ty].inner;
+ match *inner {
+ TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
+ self.write_zero_init_scalar(scalar.kind)?;
+ }
+ TypeInner::Vector { scalar, .. } => {
+ self.write_value_type(inner)?;
+ write!(self.out, "(")?;
+ self.write_zero_init_scalar(scalar.kind)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Matrix { .. } => {
+ self.write_value_type(inner)?;
+ write!(self.out, "(")?;
+ self.write_zero_init_scalar(crate::ScalarKind::Float)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Array { base, size, .. } => {
+ let count = match size
+ .to_indexable_length(self.module)
+ .expect("Bad array size")
+ {
+ proc::IndexableLength::Known(count) => count,
+ proc::IndexableLength::Dynamic => return Ok(()),
+ };
+ self.write_type(base)?;
+ self.write_array_size(base, size)?;
+ write!(self.out, "(")?;
+ for _ in 1..count {
+ self.write_zero_init_value(base)?;
+ write!(self.out, ", ")?;
+ }
+ // write last parameter without comma and space
+ self.write_zero_init_value(base)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Struct { ref members, .. } => {
+ let name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{name}(")?;
+ for (index, member) in members.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_zero_init_value(member.ty)?;
+ }
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper function that write string with zero initialization for scalar
+ fn write_zero_init_scalar(&mut self, kind: crate::ScalarKind) -> BackendResult {
+ match kind {
+ crate::ScalarKind::Bool => write!(self.out, "false")?,
+ crate::ScalarKind::Uint => write!(self.out, "0u")?,
+ crate::ScalarKind::Float => write!(self.out, "0.0")?,
+ crate::ScalarKind::Sint => write!(self.out, "0")?,
+ crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => {
+ return Err(Error::Custom(
+ "Abstract types should not appear in IR presented to backends".to_string(),
+ ))
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Issue a memory barrier. Please note that to ensure visibility,
+ /// OpenGL always requires a call to the `barrier()` function after a `memoryBarrier*()`
+ fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
+ if flags.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{level}memoryBarrierBuffer();")?;
+ }
+ if flags.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{level}memoryBarrierShared();")?;
+ }
+ writeln!(self.out, "{level}barrier();")?;
+ Ok(())
+ }
+
+ /// Helper function that return the glsl storage access string of [`StorageAccess`](crate::StorageAccess)
+ ///
+ /// glsl allows adding both `readonly` and `writeonly` but this means that
+ /// they can only be used to query information about the resource which isn't what
+ /// we want here so when storage access is both `LOAD` and `STORE` add no modifiers
+ fn write_storage_access(&mut self, storage_access: crate::StorageAccess) -> BackendResult {
+ if !storage_access.contains(crate::StorageAccess::STORE) {
+ write!(self.out, "readonly ")?;
+ }
+ if !storage_access.contains(crate::StorageAccess::LOAD) {
+ write!(self.out, "writeonly ")?;
+ }
+ Ok(())
+ }
+
+ /// Helper method used to produce the reflection info that's returned to the user
+ fn collect_reflection_info(&mut self) -> Result<ReflectionInfo, Error> {
+ use std::collections::hash_map::Entry;
+ let info = self.info.get_entry_point(self.entry_point_idx as usize);
+ let mut texture_mapping = crate::FastHashMap::default();
+ let mut uniforms = crate::FastHashMap::default();
+
+ for sampling in info.sampling_set.iter() {
+ let tex_name = self.reflection_names_globals[&sampling.image].clone();
+
+ match texture_mapping.entry(tex_name) {
+ Entry::Vacant(v) => {
+ v.insert(TextureMapping {
+ texture: sampling.image,
+ sampler: Some(sampling.sampler),
+ });
+ }
+ Entry::Occupied(e) => {
+ if e.get().sampler != Some(sampling.sampler) {
+ log::error!("Conflicting samplers for {}", e.key());
+ return Err(Error::ImageMultipleSamplers);
+ }
+ }
+ }
+ }
+
+ let mut push_constant_info = None;
+ for (handle, var) in self.module.global_variables.iter() {
+ if info[handle].is_empty() {
+ continue;
+ }
+ match self.module.types[var.ty].inner {
+ crate::TypeInner::Image { .. } => {
+ let tex_name = self.reflection_names_globals[&handle].clone();
+ match texture_mapping.entry(tex_name) {
+ Entry::Vacant(v) => {
+ v.insert(TextureMapping {
+ texture: handle,
+ sampler: None,
+ });
+ }
+ Entry::Occupied(_) => {
+ // already used with a sampler, do nothing
+ }
+ }
+ }
+ _ => match var.space {
+ crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => {
+ let name = self.reflection_names_globals[&handle].clone();
+ uniforms.insert(handle, name);
+ }
+ crate::AddressSpace::PushConstant => {
+ let name = self.reflection_names_globals[&handle].clone();
+ push_constant_info = Some((name, var.ty));
+ }
+ _ => (),
+ },
+ }
+ }
+
+ let mut push_constant_segments = Vec::new();
+ let mut push_constant_items = vec![];
+
+ if let Some((name, ty)) = push_constant_info {
+ // We don't have a layouter available to us, so we need to create one.
+ //
+ // This is potentially a bit wasteful, but the set of types in the program
+ // shouldn't be too large.
+ let mut layouter = crate::proc::Layouter::default();
+ layouter.update(self.module.to_ctx()).unwrap();
+
+ // We start with the name of the binding itself.
+ push_constant_segments.push(name);
+
+ // We then recursively collect all the uniform fields of the push constant.
+ self.collect_push_constant_items(
+ ty,
+ &mut push_constant_segments,
+ &layouter,
+ &mut 0,
+ &mut push_constant_items,
+ );
+ }
+
+ Ok(ReflectionInfo {
+ texture_mapping,
+ uniforms,
+ varying: mem::take(&mut self.varying),
+ push_constant_items,
+ })
+ }
+
+ fn collect_push_constant_items(
+ &mut self,
+ ty: Handle<crate::Type>,
+ segments: &mut Vec<String>,
+ layouter: &crate::proc::Layouter,
+ offset: &mut u32,
+ items: &mut Vec<PushConstantItem>,
+ ) {
+ // At this point in the recursion, `segments` contains the path
+ // needed to access `ty` from the root.
+
+ let layout = &layouter[ty];
+ *offset = layout.alignment.round_up(*offset);
+ match self.module.types[ty].inner {
+ // All these types map directly to GL uniforms.
+ TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => {
+ // Build the full name, by combining all current segments.
+ let name: String = segments.iter().map(String::as_str).collect();
+ items.push(PushConstantItem {
+ access_path: name,
+ offset: *offset,
+ ty,
+ });
+ *offset += layout.size;
+ }
+ // Arrays are recursed into.
+ TypeInner::Array { base, size, .. } => {
+ let crate::ArraySize::Constant(count) = size else {
+ unreachable!("Cannot have dynamic arrays in push constants");
+ };
+
+ for i in 0..count.get() {
+ // Add the array accessor and recurse.
+ segments.push(format!("[{}]", i));
+ self.collect_push_constant_items(base, segments, layouter, offset, items);
+ segments.pop();
+ }
+
+ // Ensure the stride is kept by rounding up to the alignment.
+ *offset = layout.alignment.round_up(*offset)
+ }
+ TypeInner::Struct { ref members, .. } => {
+ for (index, member) in members.iter().enumerate() {
+ // Add struct accessor and recurse.
+ segments.push(format!(
+ ".{}",
+ self.names[&NameKey::StructMember(ty, index as u32)]
+ ));
+ self.collect_push_constant_items(member.ty, segments, layouter, offset, items);
+ segments.pop();
+ }
+
+ // Ensure ending padding is kept by rounding up to the alignment.
+ *offset = layout.alignment.round_up(*offset)
+ }
+ _ => unreachable!(),
+ }
+ }
+}
+
+/// Structure returned by [`glsl_scalar`]
+///
+/// It contains both a prefix used in other types and the full type name
+struct ScalarString<'a> {
+ /// The prefix used to compose other types
+ prefix: &'a str,
+ /// The name of the scalar type
+ full: &'a str,
+}
+
+/// Helper function that returns scalar related strings
+///
+/// Check [`ScalarString`] for the information provided
+///
+/// # Errors
+/// If a [`Float`](crate::ScalarKind::Float) with an width that isn't 4 or 8
+const fn glsl_scalar(scalar: crate::Scalar) -> Result<ScalarString<'static>, Error> {
+ use crate::ScalarKind as Sk;
+
+ Ok(match scalar.kind {
+ Sk::Sint => ScalarString {
+ prefix: "i",
+ full: "int",
+ },
+ Sk::Uint => ScalarString {
+ prefix: "u",
+ full: "uint",
+ },
+ Sk::Float => match scalar.width {
+ 4 => ScalarString {
+ prefix: "",
+ full: "float",
+ },
+ 8 => ScalarString {
+ prefix: "d",
+ full: "double",
+ },
+ _ => return Err(Error::UnsupportedScalar(scalar)),
+ },
+ Sk::Bool => ScalarString {
+ prefix: "b",
+ full: "bool",
+ },
+ Sk::AbstractInt | Sk::AbstractFloat => {
+ return Err(Error::UnsupportedScalar(scalar));
+ }
+ })
+}
+
+/// Helper function that returns the glsl variable name for a builtin
+const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'static str {
+ use crate::BuiltIn as Bi;
+
+ match built_in {
+ Bi::Position { .. } => {
+ if options.output {
+ "gl_Position"
+ } else {
+ "gl_FragCoord"
+ }
+ }
+ Bi::ViewIndex if options.targeting_webgl => "int(gl_ViewID_OVR)",
+ Bi::ViewIndex => "gl_ViewIndex",
+ // vertex
+ Bi::BaseInstance => "uint(gl_BaseInstance)",
+ Bi::BaseVertex => "uint(gl_BaseVertex)",
+ Bi::ClipDistance => "gl_ClipDistance",
+ Bi::CullDistance => "gl_CullDistance",
+ Bi::InstanceIndex => {
+ if options.draw_parameters {
+ "(uint(gl_InstanceID) + uint(gl_BaseInstanceARB))"
+ } else {
+ // Must match FIRST_INSTANCE_BINDING
+ "(uint(gl_InstanceID) + naga_vs_first_instance)"
+ }
+ }
+ Bi::PointSize => "gl_PointSize",
+ Bi::VertexIndex => "uint(gl_VertexID)",
+ // fragment
+ Bi::FragDepth => "gl_FragDepth",
+ Bi::PointCoord => "gl_PointCoord",
+ Bi::FrontFacing => "gl_FrontFacing",
+ Bi::PrimitiveIndex => "uint(gl_PrimitiveID)",
+ Bi::SampleIndex => "gl_SampleID",
+ Bi::SampleMask => {
+ if options.output {
+ "gl_SampleMask"
+ } else {
+ "gl_SampleMaskIn"
+ }
+ }
+ // compute
+ Bi::GlobalInvocationId => "gl_GlobalInvocationID",
+ Bi::LocalInvocationId => "gl_LocalInvocationID",
+ Bi::LocalInvocationIndex => "gl_LocalInvocationIndex",
+ Bi::WorkGroupId => "gl_WorkGroupID",
+ Bi::WorkGroupSize => "gl_WorkGroupSize",
+ Bi::NumWorkGroups => "gl_NumWorkGroups",
+ }
+}
+
+/// Helper function that returns the string corresponding to the address space
+const fn glsl_storage_qualifier(space: crate::AddressSpace) -> Option<&'static str> {
+ use crate::AddressSpace as As;
+
+ match space {
+ As::Function => None,
+ As::Private => None,
+ As::Storage { .. } => Some("buffer"),
+ As::Uniform => Some("uniform"),
+ As::Handle => Some("uniform"),
+ As::WorkGroup => Some("shared"),
+ As::PushConstant => Some("uniform"),
+ }
+}
+
+/// Helper function that returns the string corresponding to the glsl interpolation qualifier
+const fn glsl_interpolation(interpolation: crate::Interpolation) -> &'static str {
+ use crate::Interpolation as I;
+
+ match interpolation {
+ I::Perspective => "smooth",
+ I::Linear => "noperspective",
+ I::Flat => "flat",
+ }
+}
+
+/// Return the GLSL auxiliary qualifier for the given sampling value.
+const fn glsl_sampling(sampling: crate::Sampling) -> Option<&'static str> {
+ use crate::Sampling as S;
+
+ match sampling {
+ S::Center => None,
+ S::Centroid => Some("centroid"),
+ S::Sample => Some("sample"),
+ }
+}
+
+/// Helper function that returns the glsl dimension string of [`ImageDimension`](crate::ImageDimension)
+const fn glsl_dimension(dim: crate::ImageDimension) -> &'static str {
+ use crate::ImageDimension as IDim;
+
+ match dim {
+ IDim::D1 => "1D",
+ IDim::D2 => "2D",
+ IDim::D3 => "3D",
+ IDim::Cube => "Cube",
+ }
+}
+
+/// Helper function that returns the glsl storage format string of [`StorageFormat`](crate::StorageFormat)
+fn glsl_storage_format(format: crate::StorageFormat) -> Result<&'static str, Error> {
+ use crate::StorageFormat as Sf;
+
+ Ok(match format {
+ Sf::R8Unorm => "r8",
+ Sf::R8Snorm => "r8_snorm",
+ Sf::R8Uint => "r8ui",
+ Sf::R8Sint => "r8i",
+ Sf::R16Uint => "r16ui",
+ Sf::R16Sint => "r16i",
+ Sf::R16Float => "r16f",
+ Sf::Rg8Unorm => "rg8",
+ Sf::Rg8Snorm => "rg8_snorm",
+ Sf::Rg8Uint => "rg8ui",
+ Sf::Rg8Sint => "rg8i",
+ Sf::R32Uint => "r32ui",
+ Sf::R32Sint => "r32i",
+ Sf::R32Float => "r32f",
+ Sf::Rg16Uint => "rg16ui",
+ Sf::Rg16Sint => "rg16i",
+ Sf::Rg16Float => "rg16f",
+ Sf::Rgba8Unorm => "rgba8",
+ Sf::Rgba8Snorm => "rgba8_snorm",
+ Sf::Rgba8Uint => "rgba8ui",
+ Sf::Rgba8Sint => "rgba8i",
+ Sf::Rgb10a2Uint => "rgb10_a2ui",
+ Sf::Rgb10a2Unorm => "rgb10_a2",
+ Sf::Rg11b10Float => "r11f_g11f_b10f",
+ Sf::Rg32Uint => "rg32ui",
+ Sf::Rg32Sint => "rg32i",
+ Sf::Rg32Float => "rg32f",
+ Sf::Rgba16Uint => "rgba16ui",
+ Sf::Rgba16Sint => "rgba16i",
+ Sf::Rgba16Float => "rgba16f",
+ Sf::Rgba32Uint => "rgba32ui",
+ Sf::Rgba32Sint => "rgba32i",
+ Sf::Rgba32Float => "rgba32f",
+ Sf::R16Unorm => "r16",
+ Sf::R16Snorm => "r16_snorm",
+ Sf::Rg16Unorm => "rg16",
+ Sf::Rg16Snorm => "rg16_snorm",
+ Sf::Rgba16Unorm => "rgba16",
+ Sf::Rgba16Snorm => "rgba16_snorm",
+
+ Sf::Bgra8Unorm => {
+ return Err(Error::Custom(
+ "Support format BGRA8 is not implemented".into(),
+ ))
+ }
+ })
+}
+
+fn is_value_init_supported(module: &crate::Module, ty: Handle<crate::Type>) -> bool {
+ match module.types[ty].inner {
+ TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => true,
+ TypeInner::Array { base, size, .. } => {
+ size != crate::ArraySize::Dynamic && is_value_init_supported(module, base)
+ }
+ TypeInner::Struct { ref members, .. } => members
+ .iter()
+ .all(|member| is_value_init_supported(module, member.ty)),
+ _ => false,
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs
new file mode 100644
index 0000000000..b6918ddc42
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/conv.rs
@@ -0,0 +1,222 @@
+use std::borrow::Cow;
+
+use crate::proc::Alignment;
+
+use super::Error;
+
+impl crate::ScalarKind {
+ pub(super) fn to_hlsl_cast(self) -> &'static str {
+ match self {
+ Self::Float => "asfloat",
+ Self::Sint => "asint",
+ Self::Uint => "asuint",
+ Self::Bool | Self::AbstractInt | Self::AbstractFloat => unreachable!(),
+ }
+ }
+}
+
+impl crate::Scalar {
+ /// Helper function that returns scalar related strings
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar>
+ pub(super) const fn to_hlsl_str(self) -> Result<&'static str, Error> {
+ match self.kind {
+ crate::ScalarKind::Sint => Ok("int"),
+ crate::ScalarKind::Uint => Ok("uint"),
+ crate::ScalarKind::Float => match self.width {
+ 2 => Ok("half"),
+ 4 => Ok("float"),
+ 8 => Ok("double"),
+ _ => Err(Error::UnsupportedScalar(self)),
+ },
+ crate::ScalarKind::Bool => Ok("bool"),
+ crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => {
+ Err(Error::UnsupportedScalar(self))
+ }
+ }
+ }
+}
+
+impl crate::TypeInner {
+ pub(super) const fn is_matrix(&self) -> bool {
+ match *self {
+ Self::Matrix { .. } => true,
+ _ => false,
+ }
+ }
+
+ pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> u32 {
+ match *self {
+ Self::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ let stride = Alignment::from(rows) * scalar.width as u32;
+ let last_row_size = rows as u32 * scalar.width as u32;
+ ((columns as u32 - 1) * stride) + last_row_size
+ }
+ Self::Array { base, size, stride } => {
+ let count = match size {
+ crate::ArraySize::Constant(size) => size.get(),
+ // A dynamically-sized array has to have at least one element
+ crate::ArraySize::Dynamic => 1,
+ };
+ let last_el_size = gctx.types[base].inner.size_hlsl(gctx);
+ ((count - 1) * stride) + last_el_size
+ }
+ _ => self.size(gctx),
+ }
+ }
+
+ /// Used to generate the name of the wrapped type constructor
+ pub(super) fn hlsl_type_id<'a>(
+ base: crate::Handle<crate::Type>,
+ gctx: crate::proc::GlobalCtx,
+ names: &'a crate::FastHashMap<crate::proc::NameKey, String>,
+ ) -> Result<Cow<'a, str>, Error> {
+ Ok(match gctx.types[base].inner {
+ crate::TypeInner::Scalar(scalar) => Cow::Borrowed(scalar.to_hlsl_str()?),
+ crate::TypeInner::Vector { size, scalar } => Cow::Owned(format!(
+ "{}{}",
+ scalar.to_hlsl_str()?,
+ crate::back::vector_size_str(size)
+ )),
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => Cow::Owned(format!(
+ "{}{}x{}",
+ scalar.to_hlsl_str()?,
+ crate::back::vector_size_str(columns),
+ crate::back::vector_size_str(rows),
+ )),
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => Cow::Owned(format!(
+ "array{size}_{}_",
+ Self::hlsl_type_id(base, gctx, names)?
+ )),
+ crate::TypeInner::Struct { .. } => {
+ Cow::Borrowed(&names[&crate::proc::NameKey::Type(base)])
+ }
+ _ => unreachable!(),
+ })
+ }
+}
+
+impl crate::StorageFormat {
+ pub(super) const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::R16Float => "float",
+ Self::R8Unorm | Self::R16Unorm => "unorm float",
+ Self::R8Snorm | Self::R16Snorm => "snorm float",
+ Self::R8Uint | Self::R16Uint => "uint",
+ Self::R8Sint | Self::R16Sint => "int",
+
+ Self::Rg16Float => "float2",
+ Self::Rg8Unorm | Self::Rg16Unorm => "unorm float2",
+ Self::Rg8Snorm | Self::Rg16Snorm => "snorm float2",
+
+ Self::Rg8Sint | Self::Rg16Sint => "int2",
+ Self::Rg8Uint | Self::Rg16Uint => "uint2",
+
+ Self::Rg11b10Float => "float3",
+
+ Self::Rgba16Float | Self::R32Float | Self::Rg32Float | Self::Rgba32Float => "float4",
+ Self::Rgba8Unorm | Self::Bgra8Unorm | Self::Rgba16Unorm | Self::Rgb10a2Unorm => {
+ "unorm float4"
+ }
+ Self::Rgba8Snorm | Self::Rgba16Snorm => "snorm float4",
+
+ Self::Rgba8Uint
+ | Self::Rgba16Uint
+ | Self::R32Uint
+ | Self::Rg32Uint
+ | Self::Rgba32Uint
+ | Self::Rgb10a2Uint => "uint4",
+ Self::Rgba8Sint
+ | Self::Rgba16Sint
+ | Self::R32Sint
+ | Self::Rg32Sint
+ | Self::Rgba32Sint => "int4",
+ }
+ }
+}
+
+impl crate::BuiltIn {
+ pub(super) fn to_hlsl_str(self) -> Result<&'static str, Error> {
+ Ok(match self {
+ Self::Position { .. } => "SV_Position",
+ // vertex
+ Self::ClipDistance => "SV_ClipDistance",
+ Self::CullDistance => "SV_CullDistance",
+ Self::InstanceIndex => "SV_InstanceID",
+ Self::VertexIndex => "SV_VertexID",
+ // fragment
+ Self::FragDepth => "SV_Depth",
+ Self::FrontFacing => "SV_IsFrontFace",
+ Self::PrimitiveIndex => "SV_PrimitiveID",
+ Self::SampleIndex => "SV_SampleIndex",
+ Self::SampleMask => "SV_Coverage",
+ // compute
+ Self::GlobalInvocationId => "SV_DispatchThreadID",
+ Self::LocalInvocationId => "SV_GroupThreadID",
+ Self::LocalInvocationIndex => "SV_GroupIndex",
+ Self::WorkGroupId => "SV_GroupID",
+ // The specific semantic we use here doesn't matter, because references
+ // to this field will get replaced with references to `SPECIAL_CBUF_VAR`
+ // in `Writer::write_expr`.
+ Self::NumWorkGroups => "SV_GroupID",
+ Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
+ return Err(Error::Unimplemented(format!("builtin {self:?}")))
+ }
+ Self::PointSize | Self::ViewIndex | Self::PointCoord => {
+ return Err(Error::Custom(format!("Unsupported builtin {self:?}")))
+ }
+ })
+ }
+}
+
+impl crate::Interpolation {
+ /// Return the string corresponding to the HLSL interpolation qualifier.
+ pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
+ match self {
+ // Would be "linear", but it's the default interpolation in SM4 and up
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-struct#interpolation-modifiers-introduced-in-shader-model-4
+ Self::Perspective => None,
+ Self::Linear => Some("noperspective"),
+ Self::Flat => Some("nointerpolation"),
+ }
+ }
+}
+
+impl crate::Sampling {
+ /// Return the HLSL auxiliary qualifier for the given sampling value.
+ pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
+ match self {
+ Self::Center => None,
+ Self::Centroid => Some("centroid"),
+ Self::Sample => Some("sample"),
+ }
+ }
+}
+
+impl crate::AtomicFunction {
+ /// Return the HLSL suffix for the `InterlockedXxx` method.
+ pub(super) const fn to_hlsl_suffix(self) -> &'static str {
+ match self {
+ Self::Add | Self::Subtract => "Add",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { .. } => "", //TODO
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs
new file mode 100644
index 0000000000..fa6062a1ad
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/help.rs
@@ -0,0 +1,1138 @@
+/*!
+Helpers for the hlsl backend
+
+Important note about `Expression::ImageQuery`/`Expression::ArrayLength` and hlsl backend:
+
+Due to implementation of `GetDimensions` function in hlsl (<https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>)
+backend can't work with it as an expression.
+Instead, it generates a unique wrapped function per `Expression::ImageQuery`, based on texture info and query function.
+See `WrappedImageQuery` struct that represents a unique function and will be generated before writing all statements and expressions.
+This allowed to works with `Expression::ImageQuery` as expression and write wrapped function.
+
+For example:
+```wgsl
+let dim_1d = textureDimensions(image_1d);
+```
+
+```hlsl
+int NagaDimensions1D(Texture1D<float4>)
+{
+ uint4 ret;
+ image_1d.GetDimensions(ret.x);
+ return ret.x;
+}
+
+int dim_1d = NagaDimensions1D(image_1d);
+```
+*/
+
+use super::{super::FunctionCtx, BackendResult};
+use crate::{arena::Handle, proc::NameKey};
+use std::fmt::Write;
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedArrayLength {
+ pub(super) writable: bool,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedImageQuery {
+ pub(super) dim: crate::ImageDimension,
+ pub(super) arrayed: bool,
+ pub(super) class: crate::ImageClass,
+ pub(super) query: ImageQuery,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedConstructor {
+ pub(super) ty: Handle<crate::Type>,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedStructMatrixAccess {
+ pub(super) ty: Handle<crate::Type>,
+ pub(super) index: u32,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedMatCx2 {
+ pub(super) columns: crate::VectorSize,
+}
+
+/// HLSL backend requires its own `ImageQuery` enum.
+///
+/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
+/// IR version can't be unique per function, because it's store mipmap level as an expression.
+///
+/// For example:
+/// ```wgsl
+/// let dim_cube_array_lod = textureDimensions(image_cube_array, 1);
+/// let dim_cube_array_lod2 = textureDimensions(image_cube_array, 1);
+/// ```
+///
+/// ```ir
+/// ImageQuery {
+/// image: [1],
+/// query: Size {
+/// level: Some(
+/// [1],
+/// ),
+/// },
+/// },
+/// ImageQuery {
+/// image: [1],
+/// query: Size {
+/// level: Some(
+/// [2],
+/// ),
+/// },
+/// },
+/// ```
+///
+/// HLSL should generate only 1 function for this case.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) enum ImageQuery {
+ Size,
+ SizeLevel,
+ NumLevels,
+ NumLayers,
+ NumSamples,
+}
+
+impl From<crate::ImageQuery> for ImageQuery {
+ fn from(q: crate::ImageQuery) -> Self {
+ use crate::ImageQuery as Iq;
+ match q {
+ Iq::Size { level: Some(_) } => ImageQuery::SizeLevel,
+ Iq::Size { level: None } => ImageQuery::Size,
+ Iq::NumLevels => ImageQuery::NumLevels,
+ Iq::NumLayers => ImageQuery::NumLayers,
+ Iq::NumSamples => ImageQuery::NumSamples,
+ }
+ }
+}
+
+impl<'a, W: Write> super::Writer<'a, W> {
+ pub(super) fn write_image_type(
+ &mut self,
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ ) -> BackendResult {
+ let access_str = match class {
+ crate::ImageClass::Storage { .. } => "RW",
+ _ => "",
+ };
+ let dim_str = dim.to_hlsl_str();
+ let arrayed_str = if arrayed { "Array" } else { "" };
+ write!(self.out, "{access_str}Texture{dim_str}{arrayed_str}")?;
+ match class {
+ crate::ImageClass::Depth { multi } => {
+ let multi_str = if multi { "MS" } else { "" };
+ write!(self.out, "{multi_str}<float>")?
+ }
+ crate::ImageClass::Sampled { kind, multi } => {
+ let multi_str = if multi { "MS" } else { "" };
+ let scalar_kind_str = crate::Scalar { kind, width: 4 }.to_hlsl_str()?;
+ write!(self.out, "{multi_str}<{scalar_kind_str}4>")?
+ }
+ crate::ImageClass::Storage { format, .. } => {
+ let storage_format_str = format.to_hlsl_str();
+ write!(self.out, "<{storage_format_str}>")?
+ }
+ }
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_array_length_function_name(
+ &mut self,
+ query: WrappedArrayLength,
+ ) -> BackendResult {
+ let access_str = if query.writable { "RW" } else { "" };
+ write!(self.out, "NagaBufferLength{access_str}",)?;
+
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ArrayLength`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer-getdimensions>
+ pub(super) fn write_wrapped_array_length_function(
+ &mut self,
+ wal: WrappedArrayLength,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const ARGUMENT_VARIABLE_NAME: &str = "buffer";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ write!(self.out, "uint ")?;
+ self.write_wrapped_array_length_function_name(wal)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let access_str = if wal.writable { "RW" } else { "" };
+ writeln!(
+ self.out,
+ "{access_str}ByteAddressBuffer {ARGUMENT_VARIABLE_NAME})"
+ )?;
+ // Write function body
+ writeln!(self.out, "{{")?;
+
+ // Write `GetDimensions` function.
+ writeln!(self.out, "{INDENT}uint {RETURN_VARIABLE_NAME};")?;
+ writeln!(
+ self.out,
+ "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions({RETURN_VARIABLE_NAME});"
+ )?;
+
+ // Write return value
+ writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_image_query_function_name(
+ &mut self,
+ query: WrappedImageQuery,
+ ) -> BackendResult {
+ let dim_str = query.dim.to_hlsl_str();
+ let class_str = match query.class {
+ crate::ImageClass::Sampled { multi: true, .. } => "MS",
+ crate::ImageClass::Depth { multi: true } => "DepthMS",
+ crate::ImageClass::Depth { multi: false } => "Depth",
+ crate::ImageClass::Sampled { multi: false, .. } => "",
+ crate::ImageClass::Storage { .. } => "RW",
+ };
+ let arrayed_str = if query.arrayed { "Array" } else { "" };
+ let query_str = match query.query {
+ ImageQuery::Size => "Dimensions",
+ ImageQuery::SizeLevel => "MipDimensions",
+ ImageQuery::NumLevels => "NumLevels",
+ ImageQuery::NumLayers => "NumLayers",
+ ImageQuery::NumSamples => "NumSamples",
+ };
+
+ write!(self.out, "Naga{class_str}{query_str}{dim_str}{arrayed_str}")?;
+
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ImageQuery`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>
+ pub(super) fn write_wrapped_image_query_function(
+ &mut self,
+ module: &crate::Module,
+ wiq: WrappedImageQuery,
+ expr_handle: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ use crate::{
+ back::{COMPONENTS, INDENT},
+ ImageDimension as IDim,
+ };
+
+ const ARGUMENT_VARIABLE_NAME: &str = "tex";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+ const MIP_LEVEL_PARAM: &str = "mip_level";
+
+ // Write function return type and name
+ let ret_ty = func_ctx.resolve_type(expr_handle, &module.types);
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_image_query_function_name(wiq)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ // Texture always first parameter
+ self.write_image_type(wiq.dim, wiq.arrayed, wiq.class)?;
+ write!(self.out, " {ARGUMENT_VARIABLE_NAME}")?;
+ // Mipmap is a second parameter if exists
+ if let ImageQuery::SizeLevel = wiq.query {
+ write!(self.out, ", uint {MIP_LEVEL_PARAM}")?;
+ }
+ writeln!(self.out, ")")?;
+
+ // Write function body
+ writeln!(self.out, "{{")?;
+
+ let array_coords = usize::from(wiq.arrayed);
+ // extra parameter is the mip level count or the sample count
+ let extra_coords = match wiq.class {
+ crate::ImageClass::Storage { .. } => 0,
+ crate::ImageClass::Sampled { .. } | crate::ImageClass::Depth { .. } => 1,
+ };
+
+ // GetDimensions Overloaded Methods
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions#overloaded-methods
+ let (ret_swizzle, number_of_params) = match wiq.query {
+ ImageQuery::Size | ImageQuery::SizeLevel => {
+ let ret = match wiq.dim {
+ IDim::D1 => "x",
+ IDim::D2 => "xy",
+ IDim::D3 => "xyz",
+ IDim::Cube => "xy",
+ };
+ (ret, ret.len() + array_coords + extra_coords)
+ }
+ ImageQuery::NumLevels | ImageQuery::NumSamples | ImageQuery::NumLayers => {
+ if wiq.arrayed || wiq.dim == IDim::D3 {
+ ("w", 4)
+ } else {
+ ("z", 3)
+ }
+ }
+ };
+
+ // Write `GetDimensions` function.
+ writeln!(self.out, "{INDENT}uint4 {RETURN_VARIABLE_NAME};")?;
+ write!(self.out, "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions(")?;
+ match wiq.query {
+ ImageQuery::SizeLevel => {
+ write!(self.out, "{MIP_LEVEL_PARAM}, ")?;
+ }
+ _ => match wiq.class {
+ crate::ImageClass::Sampled { multi: true, .. }
+ | crate::ImageClass::Depth { multi: true }
+ | crate::ImageClass::Storage { .. } => {}
+ _ => {
+ // Write zero mipmap level for supported types
+ write!(self.out, "0, ")?;
+ }
+ },
+ }
+
+ for component in COMPONENTS[..number_of_params - 1].iter() {
+ write!(self.out, "{RETURN_VARIABLE_NAME}.{component}, ")?;
+ }
+
+ // write last parameter without comma and space for last parameter
+ write!(
+ self.out,
+ "{}.{}",
+ RETURN_VARIABLE_NAME,
+ COMPONENTS[number_of_params - 1]
+ )?;
+
+ writeln!(self.out, ");")?;
+
+ // Write return value
+ writeln!(
+ self.out,
+ "{INDENT}return {RETURN_VARIABLE_NAME}.{ret_swizzle};"
+ )?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_constructor_function_name(
+ &mut self,
+ module: &crate::Module,
+ constructor: WrappedConstructor,
+ ) -> BackendResult {
+ let name = crate::TypeInner::hlsl_type_id(constructor.ty, module.to_ctx(), &self.names)?;
+ write!(self.out, "Construct{name}")?;
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::Compose` for structures.
+ pub(super) fn write_wrapped_constructor_function(
+ &mut self,
+ module: &crate::Module,
+ constructor: WrappedConstructor,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const ARGUMENT_VARIABLE_NAME: &str = "arg";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ if let crate::TypeInner::Array { base, size, .. } = module.types[constructor.ty].inner {
+ write!(self.out, "typedef ")?;
+ self.write_type(module, constructor.ty)?;
+ write!(self.out, " ret_")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ self.write_array_size(module, base, size)?;
+ writeln!(self.out, ";")?;
+
+ write!(self.out, "ret_")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ } else {
+ self.write_type(module, constructor.ty)?;
+ }
+ write!(self.out, " ")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+
+ let mut write_arg = |i, ty| -> BackendResult {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_type(module, ty)?;
+ write!(self.out, " {ARGUMENT_VARIABLE_NAME}{i}")?;
+ if let crate::TypeInner::Array { base, size, .. } = module.types[ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ Ok(())
+ };
+
+ match module.types[constructor.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (i, member) in members.iter().enumerate() {
+ write_arg(i, member.ty)?;
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ for i in 0..size.get() as usize {
+ write_arg(i, base)?;
+ }
+ }
+ _ => unreachable!(),
+ };
+
+ write!(self.out, ")")?;
+
+ // Write function body
+ writeln!(self.out, " {{")?;
+
+ match module.types[constructor.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let struct_name = &self.names[&NameKey::Type(constructor.ty)];
+ writeln!(
+ self.out,
+ "{INDENT}{struct_name} {RETURN_VARIABLE_NAME} = ({struct_name})0;"
+ )?;
+ for (i, member) in members.iter().enumerate() {
+ let field_name = &self.names[&NameKey::StructMember(constructor.ty, i as u32)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix {
+ columns,
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ for j in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}{RETURN_VARIABLE_NAME}.{field_name}_{j} = {ARGUMENT_VARIABLE_NAME}{i}[{j}];"
+ )?;
+ }
+ }
+ ref other => {
+ // We cast arrays of native HLSL `floatCx2`s to arrays of `matCx2`s
+ // (where the inner matrix is represented by a struct with C `float2` members).
+ // See the module-level block comment in mod.rs for details.
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, member.ty)
+ {
+ write!(
+ self.out,
+ "{}{}.{} = (__mat{}x2",
+ INDENT, RETURN_VARIABLE_NAME, field_name, columns as u8
+ )?;
+ if let crate::TypeInner::Array { base, size, .. } = *other {
+ self.write_array_size(module, base, size)?;
+ }
+ writeln!(self.out, "){ARGUMENT_VARIABLE_NAME}{i};",)?;
+ } else {
+ writeln!(
+ self.out,
+ "{INDENT}{RETURN_VARIABLE_NAME}.{field_name} = {ARGUMENT_VARIABLE_NAME}{i};",
+ )?;
+ }
+ }
+ }
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ write!(self.out, "{INDENT}")?;
+ self.write_type(module, base)?;
+ write!(self.out, " {RETURN_VARIABLE_NAME}")?;
+ self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
+ write!(self.out, " = {{ ")?;
+ for i in 0..size.get() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{ARGUMENT_VARIABLE_NAME}{i}")?;
+ }
+ writeln!(self.out, " }};",)?;
+ }
+ _ => unreachable!(),
+ }
+
+ // Write return value
+ writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_get_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "GetMat{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to get a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_get_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+
+ // Write function return type and name
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let ret_ty = &module.types[member.ty].inner;
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_struct_matrix_get_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}")?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ // Write return value
+ write!(self.out, "{INDENT}return ")?;
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, "(")?;
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+ writeln!(self.out, ");")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMat{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const MATRIX_ARGUMENT_VARIABLE_NAME: &str = "mat";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ self.write_type(module, member.ty)?;
+ write!(self.out, " {MATRIX_ARGUMENT_VARIABLE_NAME}")?;
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {MATRIX_ARGUMENT_VARIABLE_NAME}[{i}];"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_vec_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMatVec{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a vec2 on a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_vec_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const VECTOR_ARGUMENT_VARIABLE_NAME: &str = "vec";
+ const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let vec_ty = match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { rows, scalar, .. } => {
+ crate::TypeInner::Vector { size: rows, scalar }
+ }
+ _ => unreachable!(),
+ };
+ self.write_value_type(module, &vec_ty)?;
+ write!(
+ self.out,
+ " {VECTOR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}"
+ )?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ writeln!(
+ self.out,
+ "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
+ )?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {VECTOR_ARGUMENT_VARIABLE_NAME}; break; }}"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ writeln!(self.out, "{INDENT}}}")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_scalar_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMatScalar{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a float on a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_scalar_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const SCALAR_ARGUMENT_VARIABLE_NAME: &str = "scalar";
+ const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
+ const VECTOR_INDEX_ARGUMENT_VARIABLE_NAME: &str = "vec_idx";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_scalar_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let scalar_ty = match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { scalar, .. } => crate::TypeInner::Scalar(scalar),
+ _ => unreachable!(),
+ };
+ self.write_value_type(module, &scalar_ty)?;
+ write!(
+ self.out,
+ " {SCALAR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}, uint {VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}"
+ )?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ writeln!(
+ self.out,
+ "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
+ )?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}[{VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}] = {SCALAR_ARGUMENT_VARIABLE_NAME}; break; }}"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ writeln!(self.out, "{INDENT}}}")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ /// Write functions to create special types.
+ pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult {
+ for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
+ match type_key {
+ &crate::PredeclaredType::ModfResult { size, width }
+ | &crate::PredeclaredType::FrexpResult { size, width } => {
+ let arg_type_name_owner;
+ let arg_type_name = if let Some(size) = size {
+ arg_type_name_owner = format!(
+ "{}{}",
+ if width == 8 { "double" } else { "float" },
+ size as u8
+ );
+ &arg_type_name_owner
+ } else if width == 8 {
+ "double"
+ } else {
+ "float"
+ };
+
+ let (defined_func_name, called_func_name, second_field_name, sign_multiplier) =
+ if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
+ (super::writer::MODF_FUNCTION, "modf", "whole", "")
+ } else {
+ (
+ super::writer::FREXP_FUNCTION,
+ "frexp",
+ "exp_",
+ "sign(arg) * ",
+ )
+ };
+
+ let struct_name = &self.names[&NameKey::Type(*struct_ty)];
+
+ writeln!(
+ self.out,
+ "{struct_name} {defined_func_name}({arg_type_name} arg) {{
+ {arg_type_name} other;
+ {struct_name} result;
+ result.fract = {sign_multiplier}{called_func_name}(arg, other);
+ result.{second_field_name} = other;
+ return result;
+}}"
+ )?;
+ writeln!(self.out)?;
+ }
+ &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper function that writes compose wrapped functions
+ pub(super) fn write_wrapped_compose_functions(
+ &mut self,
+ module: &crate::Module,
+ expressions: &crate::Arena<crate::Expression>,
+ ) -> BackendResult {
+ for (handle, _) in expressions.iter() {
+ if let crate::Expression::Compose { ty, .. } = expressions[handle] {
+ match module.types[ty].inner {
+ crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => {
+ let constructor = WrappedConstructor { ty };
+ if self.wrapped.constructors.insert(constructor) {
+ self.write_wrapped_constructor_function(module, constructor)?;
+ }
+ }
+ _ => {}
+ };
+ }
+ }
+ Ok(())
+ }
+
+ /// Helper function that writes various wrapped functions
+ pub(super) fn write_wrapped_functions(
+ &mut self,
+ module: &crate::Module,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ self.write_wrapped_compose_functions(module, func_ctx.expressions)?;
+
+ for (handle, _) in func_ctx.expressions.iter() {
+ match func_ctx.expressions[handle] {
+ crate::Expression::ArrayLength(expr) => {
+ let global_expr = match func_ctx.expressions[expr] {
+ crate::Expression::GlobalVariable(_) => expr,
+ crate::Expression::AccessIndex { base, index: _ } => base,
+ ref other => unreachable!("Array length of {:?}", other),
+ };
+ let global_var = match func_ctx.expressions[global_expr] {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &module.global_variables[var_handle]
+ }
+ ref other => unreachable!("Array length of base {:?}", other),
+ };
+ let storage_access = match global_var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => crate::StorageAccess::default(),
+ };
+ let wal = WrappedArrayLength {
+ writable: storage_access.contains(crate::StorageAccess::STORE),
+ };
+
+ if self.wrapped.array_lengths.insert(wal) {
+ self.write_wrapped_array_length_function(wal)?;
+ }
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ let wiq = match *func_ctx.resolve_type(image, &module.types) {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => WrappedImageQuery {
+ dim,
+ arrayed,
+ class,
+ query: query.into(),
+ },
+ _ => unreachable!("we only query images"),
+ };
+
+ if self.wrapped.image_queries.insert(wiq) {
+ self.write_wrapped_image_query_function(module, wiq, handle, func_ctx)?;
+ }
+ }
+ // Write `WrappedConstructor` for structs that are loaded from `AddressSpace::Storage`
+ // since they will later be used by the fn `write_storage_load`
+ crate::Expression::Load { pointer } => {
+ let pointer_space = func_ctx
+ .resolve_type(pointer, &module.types)
+ .pointer_space();
+
+ if let Some(crate::AddressSpace::Storage { .. }) = pointer_space {
+ if let Some(ty) = func_ctx.info[handle].ty.handle() {
+ write_wrapped_constructor(self, ty, module)?;
+ }
+ }
+
+ fn write_wrapped_constructor<W: Write>(
+ writer: &mut super::Writer<'_, W>,
+ ty: Handle<crate::Type>,
+ module: &crate::Module,
+ ) -> BackendResult {
+ match module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for member in members {
+ write_wrapped_constructor(writer, member.ty, module)?;
+ }
+
+ let constructor = WrappedConstructor { ty };
+ if writer.wrapped.constructors.insert(constructor) {
+ writer
+ .write_wrapped_constructor_function(module, constructor)?;
+ }
+ }
+ crate::TypeInner::Array { base, .. } => {
+ write_wrapped_constructor(writer, base, module)?;
+
+ let constructor = WrappedConstructor { ty };
+ if writer.wrapped.constructors.insert(constructor) {
+ writer
+ .write_wrapped_constructor_function(module, constructor)?;
+ }
+ }
+ _ => {}
+ };
+
+ Ok(())
+ }
+ }
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s
+ // (see top level module docs for details).
+ //
+ // The functions injected here are required to get the matrix accesses working.
+ crate::Expression::AccessIndex { base, index } => {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ let base_ty_handle = match *resolved {
+ crate::TypeInner::Pointer { base, .. } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+ if let crate::TypeInner::Struct { ref members, .. } = *resolved {
+ let member = &members[index as usize];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ let ty = base_ty_handle.unwrap();
+ let access = WrappedStructMatrixAccess { ty, index };
+
+ if self.wrapped.struct_matrix_access.insert(access) {
+ self.write_wrapped_struct_matrix_get_function(module, access)?;
+ self.write_wrapped_struct_matrix_set_function(module, access)?;
+ self.write_wrapped_struct_matrix_set_vec_function(
+ module, access,
+ )?;
+ self.write_wrapped_struct_matrix_set_scalar_function(
+ module, access,
+ )?;
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ _ => {}
+ };
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn write_texture_coordinates(
+ &mut self,
+ kind: &str,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ mip_level: Option<Handle<crate::Expression>>,
+ module: &crate::Module,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ // HLSL expects the array index to be merged with the coordinate
+ let extra = array_index.is_some() as usize + (mip_level.is_some()) as usize;
+ if extra == 0 {
+ self.write_expr(module, coordinate, func_ctx)?;
+ } else {
+ let num_coords = match *func_ctx.resolve_type(coordinate, &module.types) {
+ crate::TypeInner::Scalar { .. } => 1,
+ crate::TypeInner::Vector { size, .. } => size as usize,
+ _ => unreachable!(),
+ };
+ write!(self.out, "{}{}(", kind, num_coords + extra)?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ if let Some(expr) = mip_level {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Ok(())
+ }
+
+ pub(super) fn write_mat_cx2_typedef_and_functions(
+ &mut self,
+ WrappedMatCx2 { columns }: WrappedMatCx2,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ // typedef
+ write!(self.out, "typedef struct {{ ")?;
+ for i in 0..columns as u8 {
+ write!(self.out, "float2 _{i}; ")?;
+ }
+ writeln!(self.out, "}} __mat{}x2;", columns as u8)?;
+
+ // __get_col_of_mat
+ writeln!(
+ self.out,
+ "float2 __get_col_of_mat{}x2(__mat{}x2 mat, uint idx) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(self.out, "{INDENT}case {i}: {{ return mat._{i}; }}")?;
+ }
+ writeln!(self.out, "{INDENT}default: {{ return (float2)0; }}")?;
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ // __set_col_of_mat
+ writeln!(
+ self.out,
+ "void __set_col_of_mat{}x2(__mat{}x2 mat, uint idx, float2 value) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(self.out, "{INDENT}case {i}: {{ mat._{i} = value; break; }}")?;
+ }
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ // __set_el_of_mat
+ writeln!(
+ self.out,
+ "void __set_el_of_mat{}x2(__mat{}x2 mat, uint idx, uint vec_idx, float value) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ mat._{i}[vec_idx] = value; break; }}"
+ )?;
+ }
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_all_mat_cx2_typedefs_and_functions(
+ &mut self,
+ module: &crate::Module,
+ ) -> BackendResult {
+ for (handle, _) in module.global_variables.iter() {
+ let global = &module.global_variables[handle];
+
+ if global.space == crate::AddressSpace::Uniform {
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, global.ty)
+ {
+ let entry = WrappedMatCx2 { columns };
+ if self.wrapped.mat_cx2s.insert(entry) {
+ self.write_mat_cx2_typedef_and_functions(entry)?;
+ }
+ }
+ }
+ }
+
+ for (_, ty) in module.types.iter() {
+ if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
+ for member in members.iter() {
+ if let crate::TypeInner::Array { .. } = module.types[member.ty].inner {
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, member.ty)
+ {
+ let entry = WrappedMatCx2 { columns };
+ if self.wrapped.mat_cx2s.insert(entry) {
+ self.write_mat_cx2_typedef_and_functions(entry)?;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/keywords.rs b/third_party/rust/naga/src/back/hlsl/keywords.rs
new file mode 100644
index 0000000000..059e533ff7
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/keywords.rs
@@ -0,0 +1,904 @@
+// When compiling with FXC without strict mode, these keywords are actually case insensitive.
+// If you compile with strict mode and specify a different casing like "Pass" instead in an identifier, FXC will give this error:
+// "error X3086: alternate cases for 'pass' are deprecated in strict mode"
+// This behavior is not documented anywhere, but as far as I can tell this is the full list.
+pub const RESERVED_CASE_INSENSITIVE: &[&str] = &[
+ "asm",
+ "decl",
+ "pass",
+ "technique",
+ "Texture1D",
+ "Texture2D",
+ "Texture3D",
+ "TextureCube",
+];
+
+pub const RESERVED: &[&str] = &[
+ // FXC keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-keywords.md?plain=1#L99-L118
+ "AppendStructuredBuffer",
+ "asm",
+ "asm_fragment",
+ "BlendState",
+ "bool",
+ "break",
+ "Buffer",
+ "ByteAddressBuffer",
+ "case",
+ "cbuffer",
+ "centroid",
+ "class",
+ "column_major",
+ "compile",
+ "compile_fragment",
+ "CompileShader",
+ "const",
+ "continue",
+ "ComputeShader",
+ "ConsumeStructuredBuffer",
+ "default",
+ "DepthStencilState",
+ "DepthStencilView",
+ "discard",
+ "do",
+ "double",
+ "DomainShader",
+ "dword",
+ "else",
+ "export",
+ "extern",
+ "false",
+ "float",
+ "for",
+ "fxgroup",
+ "GeometryShader",
+ "groupshared",
+ "half",
+ "Hullshader",
+ "if",
+ "in",
+ "inline",
+ "inout",
+ "InputPatch",
+ "int",
+ "interface",
+ "line",
+ "lineadj",
+ "linear",
+ "LineStream",
+ "matrix",
+ "min16float",
+ "min10float",
+ "min16int",
+ "min12int",
+ "min16uint",
+ "namespace",
+ "nointerpolation",
+ "noperspective",
+ "NULL",
+ "out",
+ "OutputPatch",
+ "packoffset",
+ "pass",
+ "pixelfragment",
+ "PixelShader",
+ "point",
+ "PointStream",
+ "precise",
+ "RasterizerState",
+ "RenderTargetView",
+ "return",
+ "register",
+ "row_major",
+ "RWBuffer",
+ "RWByteAddressBuffer",
+ "RWStructuredBuffer",
+ "RWTexture1D",
+ "RWTexture1DArray",
+ "RWTexture2D",
+ "RWTexture2DArray",
+ "RWTexture3D",
+ "sample",
+ "sampler",
+ "SamplerState",
+ "SamplerComparisonState",
+ "shared",
+ "snorm",
+ "stateblock",
+ "stateblock_state",
+ "static",
+ "string",
+ "struct",
+ "switch",
+ "StructuredBuffer",
+ "tbuffer",
+ "technique",
+ "technique10",
+ "technique11",
+ "texture",
+ "Texture1D",
+ "Texture1DArray",
+ "Texture2D",
+ "Texture2DArray",
+ "Texture2DMS",
+ "Texture2DMSArray",
+ "Texture3D",
+ "TextureCube",
+ "TextureCubeArray",
+ "true",
+ "typedef",
+ "triangle",
+ "triangleadj",
+ "TriangleStream",
+ "uint",
+ "uniform",
+ "unorm",
+ "unsigned",
+ "vector",
+ "vertexfragment",
+ "VertexShader",
+ "void",
+ "volatile",
+ "while",
+ // FXC reserved keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-reserved-words.md?plain=1#L19-L38
+ "auto",
+ "case",
+ "catch",
+ "char",
+ "class",
+ "const_cast",
+ "default",
+ "delete",
+ "dynamic_cast",
+ "enum",
+ "explicit",
+ "friend",
+ "goto",
+ "long",
+ "mutable",
+ "new",
+ "operator",
+ "private",
+ "protected",
+ "public",
+ "reinterpret_cast",
+ "short",
+ "signed",
+ "sizeof",
+ "static_cast",
+ "template",
+ "this",
+ "throw",
+ "try",
+ "typename",
+ "union",
+ "unsigned",
+ "using",
+ "virtual",
+ // FXC intrinsics, from https://github.com/MicrosoftDocs/win32/blob/1682b99e203708f6f5eda972d966e30f3c1588de/desktop-src/direct3dhlsl/dx-graphics-hlsl-intrinsic-functions.md?plain=1#L26-L165
+ "abort",
+ "abs",
+ "acos",
+ "all",
+ "AllMemoryBarrier",
+ "AllMemoryBarrierWithGroupSync",
+ "any",
+ "asdouble",
+ "asfloat",
+ "asin",
+ "asint",
+ "asuint",
+ "atan",
+ "atan2",
+ "ceil",
+ "CheckAccessFullyMapped",
+ "clamp",
+ "clip",
+ "cos",
+ "cosh",
+ "countbits",
+ "cross",
+ "D3DCOLORtoUBYTE4",
+ "ddx",
+ "ddx_coarse",
+ "ddx_fine",
+ "ddy",
+ "ddy_coarse",
+ "ddy_fine",
+ "degrees",
+ "determinant",
+ "DeviceMemoryBarrier",
+ "DeviceMemoryBarrierWithGroupSync",
+ "distance",
+ "dot",
+ "dst",
+ "errorf",
+ "EvaluateAttributeCentroid",
+ "EvaluateAttributeAtSample",
+ "EvaluateAttributeSnapped",
+ "exp",
+ "exp2",
+ "f16tof32",
+ "f32tof16",
+ "faceforward",
+ "firstbithigh",
+ "firstbitlow",
+ "floor",
+ "fma",
+ "fmod",
+ "frac",
+ "frexp",
+ "fwidth",
+ "GetRenderTargetSampleCount",
+ "GetRenderTargetSamplePosition",
+ "GroupMemoryBarrier",
+ "GroupMemoryBarrierWithGroupSync",
+ "InterlockedAdd",
+ "InterlockedAnd",
+ "InterlockedCompareExchange",
+ "InterlockedCompareStore",
+ "InterlockedExchange",
+ "InterlockedMax",
+ "InterlockedMin",
+ "InterlockedOr",
+ "InterlockedXor",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "ldexp",
+ "length",
+ "lerp",
+ "lit",
+ "log",
+ "log10",
+ "log2",
+ "mad",
+ "max",
+ "min",
+ "modf",
+ "msad4",
+ "mul",
+ "noise",
+ "normalize",
+ "pow",
+ "printf",
+ "Process2DQuadTessFactorsAvg",
+ "Process2DQuadTessFactorsMax",
+ "Process2DQuadTessFactorsMin",
+ "ProcessIsolineTessFactors",
+ "ProcessQuadTessFactorsAvg",
+ "ProcessQuadTessFactorsMax",
+ "ProcessQuadTessFactorsMin",
+ "ProcessTriTessFactorsAvg",
+ "ProcessTriTessFactorsMax",
+ "ProcessTriTessFactorsMin",
+ "radians",
+ "rcp",
+ "reflect",
+ "refract",
+ "reversebits",
+ "round",
+ "rsqrt",
+ "saturate",
+ "sign",
+ "sin",
+ "sincos",
+ "sinh",
+ "smoothstep",
+ "sqrt",
+ "step",
+ "tan",
+ "tanh",
+ "tex1D",
+ "tex1Dbias",
+ "tex1Dgrad",
+ "tex1Dlod",
+ "tex1Dproj",
+ "tex2D",
+ "tex2Dbias",
+ "tex2Dgrad",
+ "tex2Dlod",
+ "tex2Dproj",
+ "tex3D",
+ "tex3Dbias",
+ "tex3Dgrad",
+ "tex3Dlod",
+ "tex3Dproj",
+ "texCUBE",
+ "texCUBEbias",
+ "texCUBEgrad",
+ "texCUBElod",
+ "texCUBEproj",
+ "transpose",
+ "trunc",
+ // DXC (reserved) keywords, from https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/include/clang/Basic/TokenKinds.def#L222-L648
+ // with the KEYALL, KEYCXX, BOOLSUPPORT, WCHARSUPPORT, KEYHLSL options enabled (see https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/lib/Frontend/CompilerInvocation.cpp#L1199)
+ "auto",
+ "break",
+ "case",
+ "char",
+ "const",
+ "continue",
+ "default",
+ "do",
+ "double",
+ "else",
+ "enum",
+ "extern",
+ "float",
+ "for",
+ "goto",
+ "if",
+ "inline",
+ "int",
+ "long",
+ "register",
+ "return",
+ "short",
+ "signed",
+ "sizeof",
+ "static",
+ "struct",
+ "switch",
+ "typedef",
+ "union",
+ "unsigned",
+ "void",
+ "volatile",
+ "while",
+ "_Alignas",
+ "_Alignof",
+ "_Atomic",
+ "_Complex",
+ "_Generic",
+ "_Imaginary",
+ "_Noreturn",
+ "_Static_assert",
+ "_Thread_local",
+ "__func__",
+ "__objc_yes",
+ "__objc_no",
+ "asm",
+ "bool",
+ "catch",
+ "class",
+ "const_cast",
+ "delete",
+ "dynamic_cast",
+ "explicit",
+ "export",
+ "false",
+ "friend",
+ "mutable",
+ "namespace",
+ "new",
+ "operator",
+ "private",
+ "protected",
+ "public",
+ "reinterpret_cast",
+ "static_cast",
+ "template",
+ "this",
+ "throw",
+ "true",
+ "try",
+ "typename",
+ "typeid",
+ "using",
+ "virtual",
+ "wchar_t",
+ "_Decimal32",
+ "_Decimal64",
+ "_Decimal128",
+ "__null",
+ "__alignof",
+ "__attribute",
+ "__builtin_choose_expr",
+ "__builtin_offsetof",
+ "__builtin_va_arg",
+ "__extension__",
+ "__imag",
+ "__int128",
+ "__label__",
+ "__real",
+ "__thread",
+ "__FUNCTION__",
+ "__PRETTY_FUNCTION__",
+ "__is_nothrow_assignable",
+ "__is_constructible",
+ "__is_nothrow_constructible",
+ "__has_nothrow_assign",
+ "__has_nothrow_move_assign",
+ "__has_nothrow_copy",
+ "__has_nothrow_constructor",
+ "__has_trivial_assign",
+ "__has_trivial_move_assign",
+ "__has_trivial_copy",
+ "__has_trivial_constructor",
+ "__has_trivial_move_constructor",
+ "__has_trivial_destructor",
+ "__has_virtual_destructor",
+ "__is_abstract",
+ "__is_base_of",
+ "__is_class",
+ "__is_convertible_to",
+ "__is_empty",
+ "__is_enum",
+ "__is_final",
+ "__is_literal",
+ "__is_literal_type",
+ "__is_pod",
+ "__is_polymorphic",
+ "__is_trivial",
+ "__is_union",
+ "__is_trivially_constructible",
+ "__is_trivially_copyable",
+ "__is_trivially_assignable",
+ "__underlying_type",
+ "__is_lvalue_expr",
+ "__is_rvalue_expr",
+ "__is_arithmetic",
+ "__is_floating_point",
+ "__is_integral",
+ "__is_complete_type",
+ "__is_void",
+ "__is_array",
+ "__is_function",
+ "__is_reference",
+ "__is_lvalue_reference",
+ "__is_rvalue_reference",
+ "__is_fundamental",
+ "__is_object",
+ "__is_scalar",
+ "__is_compound",
+ "__is_pointer",
+ "__is_member_object_pointer",
+ "__is_member_function_pointer",
+ "__is_member_pointer",
+ "__is_const",
+ "__is_volatile",
+ "__is_standard_layout",
+ "__is_signed",
+ "__is_unsigned",
+ "__is_same",
+ "__is_convertible",
+ "__array_rank",
+ "__array_extent",
+ "__private_extern__",
+ "__module_private__",
+ "__declspec",
+ "__cdecl",
+ "__stdcall",
+ "__fastcall",
+ "__thiscall",
+ "__vectorcall",
+ "cbuffer",
+ "tbuffer",
+ "packoffset",
+ "linear",
+ "centroid",
+ "nointerpolation",
+ "noperspective",
+ "sample",
+ "column_major",
+ "row_major",
+ "in",
+ "out",
+ "inout",
+ "uniform",
+ "precise",
+ "center",
+ "shared",
+ "groupshared",
+ "discard",
+ "snorm",
+ "unorm",
+ "point",
+ "line",
+ "lineadj",
+ "triangle",
+ "triangleadj",
+ "globallycoherent",
+ "interface",
+ "sampler_state",
+ "technique",
+ "indices",
+ "vertices",
+ "primitives",
+ "payload",
+ "Technique",
+ "technique10",
+ "technique11",
+ "__builtin_omp_required_simd_align",
+ "__pascal",
+ "__fp16",
+ "__alignof__",
+ "__asm",
+ "__asm__",
+ "__attribute__",
+ "__complex",
+ "__complex__",
+ "__const",
+ "__const__",
+ "__decltype",
+ "__imag__",
+ "__inline",
+ "__inline__",
+ "__nullptr",
+ "__real__",
+ "__restrict",
+ "__restrict__",
+ "__signed",
+ "__signed__",
+ "__typeof",
+ "__typeof__",
+ "__volatile",
+ "__volatile__",
+ "_Nonnull",
+ "_Nullable",
+ "_Null_unspecified",
+ "__builtin_convertvector",
+ "__char16_t",
+ "__char32_t",
+ // DXC intrinsics, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/utils/hct/gen_intrin_main.txt#L86-L376
+ "D3DCOLORtoUBYTE4",
+ "GetRenderTargetSampleCount",
+ "GetRenderTargetSamplePosition",
+ "abort",
+ "abs",
+ "acos",
+ "all",
+ "AllMemoryBarrier",
+ "AllMemoryBarrierWithGroupSync",
+ "any",
+ "asdouble",
+ "asfloat",
+ "asfloat16",
+ "asint16",
+ "asin",
+ "asint",
+ "asuint",
+ "asuint16",
+ "atan",
+ "atan2",
+ "ceil",
+ "clamp",
+ "clip",
+ "cos",
+ "cosh",
+ "countbits",
+ "cross",
+ "ddx",
+ "ddx_coarse",
+ "ddx_fine",
+ "ddy",
+ "ddy_coarse",
+ "ddy_fine",
+ "degrees",
+ "determinant",
+ "DeviceMemoryBarrier",
+ "DeviceMemoryBarrierWithGroupSync",
+ "distance",
+ "dot",
+ "dst",
+ "EvaluateAttributeAtSample",
+ "EvaluateAttributeCentroid",
+ "EvaluateAttributeSnapped",
+ "GetAttributeAtVertex",
+ "exp",
+ "exp2",
+ "f16tof32",
+ "f32tof16",
+ "faceforward",
+ "firstbithigh",
+ "firstbitlow",
+ "floor",
+ "fma",
+ "fmod",
+ "frac",
+ "frexp",
+ "fwidth",
+ "GroupMemoryBarrier",
+ "GroupMemoryBarrierWithGroupSync",
+ "InterlockedAdd",
+ "InterlockedMin",
+ "InterlockedMax",
+ "InterlockedAnd",
+ "InterlockedOr",
+ "InterlockedXor",
+ "InterlockedCompareStore",
+ "InterlockedExchange",
+ "InterlockedCompareExchange",
+ "InterlockedCompareStoreFloatBitwise",
+ "InterlockedCompareExchangeFloatBitwise",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "ldexp",
+ "length",
+ "lerp",
+ "lit",
+ "log",
+ "log10",
+ "log2",
+ "mad",
+ "max",
+ "min",
+ "modf",
+ "msad4",
+ "mul",
+ "normalize",
+ "pow",
+ "printf",
+ "Process2DQuadTessFactorsAvg",
+ "Process2DQuadTessFactorsMax",
+ "Process2DQuadTessFactorsMin",
+ "ProcessIsolineTessFactors",
+ "ProcessQuadTessFactorsAvg",
+ "ProcessQuadTessFactorsMax",
+ "ProcessQuadTessFactorsMin",
+ "ProcessTriTessFactorsAvg",
+ "ProcessTriTessFactorsMax",
+ "ProcessTriTessFactorsMin",
+ "radians",
+ "rcp",
+ "reflect",
+ "refract",
+ "reversebits",
+ "round",
+ "rsqrt",
+ "saturate",
+ "sign",
+ "sin",
+ "sincos",
+ "sinh",
+ "smoothstep",
+ "source_mark",
+ "sqrt",
+ "step",
+ "tan",
+ "tanh",
+ "tex1D",
+ "tex1Dbias",
+ "tex1Dgrad",
+ "tex1Dlod",
+ "tex1Dproj",
+ "tex2D",
+ "tex2Dbias",
+ "tex2Dgrad",
+ "tex2Dlod",
+ "tex2Dproj",
+ "tex3D",
+ "tex3Dbias",
+ "tex3Dgrad",
+ "tex3Dlod",
+ "tex3Dproj",
+ "texCUBE",
+ "texCUBEbias",
+ "texCUBEgrad",
+ "texCUBElod",
+ "texCUBEproj",
+ "transpose",
+ "trunc",
+ "CheckAccessFullyMapped",
+ "AddUint64",
+ "NonUniformResourceIndex",
+ "WaveIsFirstLane",
+ "WaveGetLaneIndex",
+ "WaveGetLaneCount",
+ "WaveActiveAnyTrue",
+ "WaveActiveAllTrue",
+ "WaveActiveAllEqual",
+ "WaveActiveBallot",
+ "WaveReadLaneAt",
+ "WaveReadLaneFirst",
+ "WaveActiveCountBits",
+ "WaveActiveSum",
+ "WaveActiveProduct",
+ "WaveActiveBitAnd",
+ "WaveActiveBitOr",
+ "WaveActiveBitXor",
+ "WaveActiveMin",
+ "WaveActiveMax",
+ "WavePrefixCountBits",
+ "WavePrefixSum",
+ "WavePrefixProduct",
+ "WaveMatch",
+ "WaveMultiPrefixBitAnd",
+ "WaveMultiPrefixBitOr",
+ "WaveMultiPrefixBitXor",
+ "WaveMultiPrefixCountBits",
+ "WaveMultiPrefixProduct",
+ "WaveMultiPrefixSum",
+ "QuadReadLaneAt",
+ "QuadReadAcrossX",
+ "QuadReadAcrossY",
+ "QuadReadAcrossDiagonal",
+ "QuadAny",
+ "QuadAll",
+ "TraceRay",
+ "ReportHit",
+ "CallShader",
+ "IgnoreHit",
+ "AcceptHitAndEndSearch",
+ "DispatchRaysIndex",
+ "DispatchRaysDimensions",
+ "WorldRayOrigin",
+ "WorldRayDirection",
+ "ObjectRayOrigin",
+ "ObjectRayDirection",
+ "RayTMin",
+ "RayTCurrent",
+ "PrimitiveIndex",
+ "InstanceID",
+ "InstanceIndex",
+ "GeometryIndex",
+ "HitKind",
+ "RayFlags",
+ "ObjectToWorld",
+ "WorldToObject",
+ "ObjectToWorld3x4",
+ "WorldToObject3x4",
+ "ObjectToWorld4x3",
+ "WorldToObject4x3",
+ "dot4add_u8packed",
+ "dot4add_i8packed",
+ "dot2add",
+ "unpack_s8s16",
+ "unpack_u8u16",
+ "unpack_s8s32",
+ "unpack_u8u32",
+ "pack_s8",
+ "pack_u8",
+ "pack_clamp_s8",
+ "pack_clamp_u8",
+ "SetMeshOutputCounts",
+ "DispatchMesh",
+ "IsHelperLane",
+ "AllocateRayQuery",
+ "CreateResourceFromHeap",
+ "and",
+ "or",
+ "select",
+ // DXC resource and other types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/HlslTypes.cpp#L441-#L572
+ "InputPatch",
+ "OutputPatch",
+ "PointStream",
+ "LineStream",
+ "TriangleStream",
+ "Texture1D",
+ "RWTexture1D",
+ "Texture2D",
+ "RWTexture2D",
+ "Texture2DMS",
+ "RWTexture2DMS",
+ "Texture3D",
+ "RWTexture3D",
+ "TextureCube",
+ "RWTextureCube",
+ "Texture1DArray",
+ "RWTexture1DArray",
+ "Texture2DArray",
+ "RWTexture2DArray",
+ "Texture2DMSArray",
+ "RWTexture2DMSArray",
+ "TextureCubeArray",
+ "RWTextureCubeArray",
+ "FeedbackTexture2D",
+ "FeedbackTexture2DArray",
+ "RasterizerOrderedTexture1D",
+ "RasterizerOrderedTexture2D",
+ "RasterizerOrderedTexture3D",
+ "RasterizerOrderedTexture1DArray",
+ "RasterizerOrderedTexture2DArray",
+ "RasterizerOrderedBuffer",
+ "RasterizerOrderedByteAddressBuffer",
+ "RasterizerOrderedStructuredBuffer",
+ "ByteAddressBuffer",
+ "RWByteAddressBuffer",
+ "StructuredBuffer",
+ "RWStructuredBuffer",
+ "AppendStructuredBuffer",
+ "ConsumeStructuredBuffer",
+ "Buffer",
+ "RWBuffer",
+ "SamplerState",
+ "SamplerComparisonState",
+ "ConstantBuffer",
+ "TextureBuffer",
+ "RaytracingAccelerationStructure",
+ // DXC templated types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp
+ // look for `BuiltinTypeDeclBuilder`
+ "matrix",
+ "vector",
+ "TextureBuffer",
+ "ConstantBuffer",
+ "RayQuery",
+ // Naga utilities
+ super::writer::MODF_FUNCTION,
+ super::writer::FREXP_FUNCTION,
+];
+
+// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254
+// + vector and matrix shorthands
+pub const TYPES: &[&str] = &{
+ const L: usize = 23 * (1 + 4 + 4 * 4);
+ let mut res = [""; L];
+ let mut c = 0;
+
+ /// For each scalar type, it will additionally generate vector and matrix shorthands
+ macro_rules! generate {
+ ([$($roots:literal),*], $x:tt) => {
+ $(
+ generate!(@inner push $roots);
+ generate!(@inner $roots, $x);
+ )*
+ };
+
+ (@inner $root:literal, [$($x:literal),*]) => {
+ generate!(@inner vector $root, $($x)*);
+ generate!(@inner matrix $root, $($x)*);
+ };
+
+ (@inner vector $root:literal, $($x:literal)*) => {
+ $(
+ generate!(@inner push concat!($root, $x));
+ )*
+ };
+
+ (@inner matrix $root:literal, $($x:literal)*) => {
+ // Duplicate the list
+ generate!(@inner matrix $root, $($x)*; $($x)*);
+ };
+
+ // The head/tail recursion: pick the first element of the first list and recursively do it for the tail.
+ (@inner matrix $root:literal, $head:literal $($tail:literal)*; $($x:literal)*) => {
+ $(
+ generate!(@inner push concat!($root, $head, "x", $x));
+ )*
+ generate!(@inner matrix $root, $($tail)*; $($x)*);
+
+ };
+
+ // The end of iteration: we exhausted the list
+ (@inner matrix $root:literal, ; $($x:literal)*) => {};
+
+ (@inner push $v:expr) => {
+ res[c] = $v;
+ c += 1;
+ };
+ }
+
+ generate!(
+ [
+ "bool",
+ "int",
+ "uint",
+ "dword",
+ "half",
+ "float",
+ "double",
+ "min10float",
+ "min16float",
+ "min12int",
+ "min16int",
+ "min16uint",
+ "int16_t",
+ "int32_t",
+ "int64_t",
+ "uint16_t",
+ "uint32_t",
+ "uint64_t",
+ "float16_t",
+ "float32_t",
+ "float64_t",
+ "int8_t4_packed",
+ "uint8_t4_packed"
+ ],
+ ["1", "2", "3", "4"]
+ );
+
+ debug_assert!(c == L);
+
+ res
+};
diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs
new file mode 100644
index 0000000000..37ddbd3d67
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/mod.rs
@@ -0,0 +1,302 @@
+/*!
+Backend for [HLSL][hlsl] (High-Level Shading Language).
+
+# Supported shader model versions:
+- 5.0
+- 5.1
+- 6.0
+
+# Layout of values in `uniform` buffers
+
+WGSL's ["Internal Layout of Values"][ilov] rules specify how each WGSL
+type should be stored in `uniform` and `storage` buffers. The HLSL we
+generate must access values in that form, even when it is not what
+HLSL would use normally.
+
+The rules described here only apply to WGSL `uniform` variables. WGSL
+`storage` buffers are translated as HLSL `ByteAddressBuffers`, for
+which we generate `Load` and `Store` method calls with explicit byte
+offsets. WGSL pipeline inputs must be scalars or vectors; they cannot
+be matrices, which is where the interesting problems arise.
+
+## Row- and column-major ordering for matrices
+
+WGSL specifies that matrices in uniform buffers are stored in
+column-major order. This matches HLSL's default, so one might expect
+things to be straightforward. Unfortunately, WGSL and HLSL disagree on
+what indexing a matrix means: in WGSL, `m[i]` retrieves the `i`'th
+*column* of `m`, whereas in HLSL it retrieves the `i`'th *row*. We
+want to avoid translating `m[i]` into some complicated reassembly of a
+vector from individually fetched components, so this is a problem.
+
+However, with a bit of trickery, it is possible to use HLSL's `m[i]`
+as the translation of WGSL's `m[i]`:
+
+- We declare all matrices in uniform buffers in HLSL with the
+ `row_major` qualifier, and transpose the row and column counts: a
+ WGSL `mat3x4<f32>`, say, becomes an HLSL `row_major float3x4`. (Note
+ that WGSL and HLSL type names put the row and column in reverse
+ order.) Since the HLSL type is the transpose of how WebGPU directs
+ the user to store the data, HLSL will load all matrices transposed.
+
+- Since matrices are transposed, an HLSL indexing expression retrieves
+ the "columns" of the intended WGSL value, as desired.
+
+- For vector-matrix multiplication, since `mul(transpose(m), v)` is
+ equivalent to `mul(v, m)` (note the reversal of the arguments), and
+ `mul(v, transpose(m))` is equivalent to `mul(m, v)`, we can
+ translate WGSL `m * v` and `v * m` to HLSL by simply reversing the
+ arguments to `mul`.
+
+## Padding in two-row matrices
+
+An HLSL `row_major floatKx2` matrix has padding between its rows that
+the WGSL `matKx2<f32>` matrix it represents does not. HLSL stores all
+matrix rows [aligned on 16-byte boundaries][16bb], whereas WGSL says
+that the columns of a `matKx2<f32>` need only be [aligned as required
+for `vec2<f32>`][ilov], which is [eight-byte alignment][8bb].
+
+To compensate for this, any time a `matKx2<f32>` appears in a WGSL
+`uniform` variable, whether directly as the variable's type or as part
+of a struct/array, we actually emit `K` separate `float2` members, and
+assemble/disassemble the matrix from its columns (in WGSL; rows in
+HLSL) upon load and store.
+
+For example, the following WGSL struct type:
+
+```ignore
+struct Baz {
+ m: mat3x2<f32>,
+}
+```
+
+is rendered as the HLSL struct type:
+
+```ignore
+struct Baz {
+ float2 m_0; float2 m_1; float2 m_2;
+};
+```
+
+The `wrapped_struct_matrix` functions in `help.rs` generate HLSL
+helper functions to access such members, converting between the stored
+form and the HLSL matrix types appropriately. For example, for reading
+the member `m` of the `Baz` struct above, we emit:
+
+```ignore
+float3x2 GetMatmOnBaz(Baz obj) {
+ return float3x2(obj.m_0, obj.m_1, obj.m_2);
+}
+```
+
+We also emit an analogous `Set` function, as well as functions for
+accessing individual columns by dynamic index.
+
+[hlsl]: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl
+[ilov]: https://gpuweb.github.io/gpuweb/wgsl/#internal-value-layout
+[16bb]: https://github.com/microsoft/DirectXShaderCompiler/wiki/Buffer-Packing#constant-buffer-packing
+[8bb]: https://gpuweb.github.io/gpuweb/wgsl/#alignment-and-size
+*/
+
+mod conv;
+mod help;
+mod keywords;
+mod storage;
+mod writer;
+
+use std::fmt::Error as FmtError;
+use thiserror::Error;
+
+use crate::{back, proc};
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BindTarget {
+ pub space: u8,
+ pub register: u32,
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
+
+/// A HLSL shader model version.
+#[allow(non_snake_case, non_camel_case_types)]
+#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum ShaderModel {
+ V5_0,
+ V5_1,
+ V6_0,
+}
+
+impl ShaderModel {
+ pub const fn to_str(self) -> &'static str {
+ match self {
+ Self::V5_0 => "5_0",
+ Self::V5_1 => "5_1",
+ Self::V6_0 => "6_0",
+ }
+ }
+}
+
+impl crate::ShaderStage {
+ pub const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::Vertex => "vs",
+ Self::Fragment => "ps",
+ Self::Compute => "cs",
+ }
+ }
+}
+
+impl crate::ImageDimension {
+ const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::D1 => "1D",
+ Self::D2 => "2D",
+ Self::D3 => "3D",
+ Self::Cube => "Cube",
+ }
+ }
+}
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum EntryPointError {
+ #[error("mapping of {0:?} is missing")]
+ MissingBinding(crate::ResourceBinding),
+}
+
+/// Configuration used in the [`Writer`].
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// The hlsl shader model to be used
+ pub shader_model: ShaderModel,
+ /// Map of resources association to binding locations.
+ pub binding_map: BindingMap,
+ /// Don't panic on missing bindings, instead generate any HLSL.
+ pub fake_missing_bindings: bool,
+ /// Add special constants to `SV_VertexIndex` and `SV_InstanceIndex`,
+ /// to make them work like in Vulkan/Metal, with help of the host.
+ pub special_constants_binding: Option<BindTarget>,
+ /// Bind target of the push constant buffer
+ pub push_constants_target: Option<BindTarget>,
+ /// Should workgroup variables be zero initialized (by polyfilling)?
+ pub zero_initialize_workgroup_memory: bool,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ shader_model: ShaderModel::V5_1,
+ binding_map: BindingMap::default(),
+ fake_missing_bindings: true,
+ special_constants_binding: None,
+ push_constants_target: None,
+ zero_initialize_workgroup_memory: true,
+ }
+ }
+}
+
+impl Options {
+ fn resolve_resource_binding(
+ &self,
+ res_binding: &crate::ResourceBinding,
+ ) -> Result<BindTarget, EntryPointError> {
+ match self.binding_map.get(res_binding) {
+ Some(target) => Ok(target.clone()),
+ None if self.fake_missing_bindings => Ok(BindTarget {
+ space: res_binding.group as u8,
+ register: res_binding.binding,
+ binding_array_size: None,
+ }),
+ None => Err(EntryPointError::MissingBinding(res_binding.clone())),
+ }
+ }
+}
+
+/// Reflection info for entry point names.
+#[derive(Default)]
+pub struct ReflectionInfo {
+ /// Mapping of the entry point names.
+ ///
+ /// Each item in the array corresponds to an entry point index. The real entry point name may be different if one of the
+ /// reserved words are used.
+ ///
+ /// Note: Some entry points may fail translation because of missing bindings.
+ pub entry_point_names: Vec<Result<String, EntryPointError>>,
+}
+
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error(transparent)]
+ IoError(#[from] FmtError),
+ #[error("A scalar with an unsupported width was requested: {0:?}")]
+ UnsupportedScalar(crate::Scalar),
+ #[error("{0}")]
+ Unimplemented(String), // TODO: Error used only during development
+ #[error("{0}")]
+ Custom(String),
+}
+
+#[derive(Default)]
+struct Wrapped {
+ array_lengths: crate::FastHashSet<help::WrappedArrayLength>,
+ image_queries: crate::FastHashSet<help::WrappedImageQuery>,
+ constructors: crate::FastHashSet<help::WrappedConstructor>,
+ struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>,
+ mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>,
+}
+
+impl Wrapped {
+ fn clear(&mut self) {
+ self.array_lengths.clear();
+ self.image_queries.clear();
+ self.constructors.clear();
+ self.struct_matrix_access.clear();
+ self.mat_cx2s.clear();
+ }
+}
+
+pub struct Writer<'a, W> {
+ out: W,
+ names: crate::FastHashMap<proc::NameKey, String>,
+ namer: proc::Namer,
+ /// HLSL backend options
+ options: &'a Options,
+ /// Information about entry point arguments and result types.
+ entry_point_io: Vec<writer::EntryPointInterface>,
+ /// Set of expressions that have associated temporary variables
+ named_expressions: crate::NamedExpressions,
+ wrapped: Wrapped,
+
+ /// A reference to some part of a global variable, lowered to a series of
+ /// byte offset calculations.
+ ///
+ /// See the [`storage`] module for background on why we need this.
+ ///
+ /// Each [`SubAccess`] in the vector is a lowering of some [`Access`] or
+ /// [`AccessIndex`] expression to the level of byte strides and offsets. See
+ /// [`SubAccess`] for details.
+ ///
+ /// This field is a member of [`Writer`] solely to allow re-use of
+ /// the `Vec`'s dynamic allocation. The value is no longer needed
+ /// once HLSL for the access has been generated.
+ ///
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`SubAccess`]: storage::SubAccess
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ temp_access_chain: Vec<storage::SubAccess>,
+ need_bake_expressions: back::NeedBakeExpressions,
+}
diff --git a/third_party/rust/naga/src/back/hlsl/storage.rs b/third_party/rust/naga/src/back/hlsl/storage.rs
new file mode 100644
index 0000000000..1b8a6ec12d
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/storage.rs
@@ -0,0 +1,494 @@
+/*!
+Generating accesses to [`ByteAddressBuffer`] contents.
+
+Naga IR globals in the [`Storage`] address space are rendered as
+[`ByteAddressBuffer`]s or [`RWByteAddressBuffer`]s in HLSL. These
+buffers don't have HLSL types (structs, arrays, etc.); instead, they
+are just raw blocks of bytes, with methods to load and store values of
+specific types at particular byte offsets. This means that Naga must
+translate chains of [`Access`] and [`AccessIndex`] expressions into
+HLSL expressions that compute byte offsets into the buffer.
+
+To generate code for a [`Storage`] access:
+
+- Call [`Writer::fill_access_chain`] on the expression referring to
+ the value. This populates [`Writer::temp_access_chain`] with the
+ appropriate byte offset calculations, as a vector of [`SubAccess`]
+ values.
+
+- Call [`Writer::write_storage_address`] to emit an HLSL expression
+ for a given slice of [`SubAccess`] values.
+
+Naga IR expressions can operate on composite values of any type, but
+[`ByteAddressBuffer`] and [`RWByteAddressBuffer`] have only a fixed
+set of `Load` and `Store` methods, to access one through four
+consecutive 32-bit values. To synthesize a Naga access, you can
+initialize [`temp_access_chain`] to refer to the composite, and then
+temporarily push and pop additional steps on
+[`Writer::temp_access_chain`] to generate accesses to the individual
+elements/members.
+
+The [`temp_access_chain`] field is a member of [`Writer`] solely to
+allow re-use of the `Vec`'s dynamic allocation. Its value is no longer
+needed once HLSL for the access has been generated.
+
+[`Storage`]: crate::AddressSpace::Storage
+[`ByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-byteaddressbuffer
+[`RWByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer
+[`Access`]: crate::Expression::Access
+[`AccessIndex`]: crate::Expression::AccessIndex
+[`Writer::fill_access_chain`]: super::Writer::fill_access_chain
+[`Writer::write_storage_address`]: super::Writer::write_storage_address
+[`Writer::temp_access_chain`]: super::Writer::temp_access_chain
+[`temp_access_chain`]: super::Writer::temp_access_chain
+[`Writer`]: super::Writer
+*/
+
+use super::{super::FunctionCtx, BackendResult, Error};
+use crate::{
+ proc::{Alignment, NameKey, TypeResolution},
+ Handle,
+};
+
+use std::{fmt, mem};
+
+const STORE_TEMP_NAME: &str = "_value";
+
+/// One step in accessing a [`Storage`] global's component or element.
+///
+/// [`Writer::temp_access_chain`] holds a series of these structures,
+/// describing how to compute the byte offset of a particular element
+/// or member of some global variable in the [`Storage`] address
+/// space.
+///
+/// [`Writer::temp_access_chain`]: super::Writer::temp_access_chain
+/// [`Storage`]: crate::AddressSpace::Storage
+#[derive(Debug)]
+pub(super) enum SubAccess {
+ /// Add the given byte offset. This is used for struct members, or
+ /// known components of a vector or matrix. In all those cases,
+ /// the byte offset is a compile-time constant.
+ Offset(u32),
+
+ /// Scale `value` by `stride`, and add that to the current byte
+ /// offset. This is used to compute the offset of an array element
+ /// whose index is computed at runtime.
+ Index {
+ value: Handle<crate::Expression>,
+ stride: u32,
+ },
+}
+
+pub(super) enum StoreValue {
+ Expression(Handle<crate::Expression>),
+ TempIndex {
+ depth: usize,
+ index: u32,
+ ty: TypeResolution,
+ },
+ TempAccess {
+ depth: usize,
+ base: Handle<crate::Type>,
+ member_index: u32,
+ },
+}
+
+impl<W: fmt::Write> super::Writer<'_, W> {
+ pub(super) fn write_storage_address(
+ &mut self,
+ module: &crate::Module,
+ chain: &[SubAccess],
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ if chain.is_empty() {
+ write!(self.out, "0")?;
+ }
+ for (i, access) in chain.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, "+")?;
+ }
+ match *access {
+ SubAccess::Offset(offset) => {
+ write!(self.out, "{offset}")?;
+ }
+ SubAccess::Index { value, stride } => {
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, "*{stride}")?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn write_storage_load_sequence<I: Iterator<Item = (TypeResolution, u32)>>(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ sequence: I,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ for (i, (ty_resolution, offset)) in sequence.enumerate() {
+ // add the index temporarily
+ self.temp_access_chain.push(SubAccess::Offset(offset));
+ if i != 0 {
+ write!(self.out, ", ")?;
+ };
+ self.write_storage_load(module, var_handle, ty_resolution, func_ctx)?;
+ self.temp_access_chain.pop();
+ }
+ Ok(())
+ }
+
+ /// Emit code to access a [`Storage`] global's component.
+ ///
+ /// Emit HLSL to access the component of `var_handle`, a global
+ /// variable in the [`Storage`] address space, whose type is
+ /// `result_ty` and whose location within the global is given by
+ /// [`self.temp_access_chain`]. See the [`storage`] module's
+ /// documentation for background.
+ ///
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`self.temp_access_chain`]: super::Writer::temp_access_chain
+ pub(super) fn write_storage_load(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ result_ty: TypeResolution,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ match *result_ty.inner_with(&module.types) {
+ crate::TypeInner::Scalar(scalar) => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ let cast = scalar.kind.to_hlsl_cast();
+ write!(self.out, "{cast}({var_name}.Load(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, "))")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Vector { size, scalar } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ let cast = scalar.kind.to_hlsl_cast();
+ write!(self.out, "{}({}.Load{}(", cast, var_name, size as u8)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, "))")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ write!(
+ self.out,
+ "{}{}x{}(",
+ scalar.to_hlsl_str()?,
+ columns as u8,
+ rows as u8,
+ )?;
+
+ // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
+ let row_stride = Alignment::from(rows) * scalar.width as u32;
+ let iter = (0..columns as u32).map(|i| {
+ let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
+ (TypeResolution::Value(ty_inner), i * row_stride)
+ });
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ stride,
+ } => {
+ let constructor = super::help::WrappedConstructor {
+ ty: result_ty.handle().unwrap(),
+ };
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ write!(self.out, "(")?;
+ let iter = (0..size.get()).map(|i| (TypeResolution::Handle(base), stride * i));
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ let constructor = super::help::WrappedConstructor {
+ ty: result_ty.handle().unwrap(),
+ };
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ write!(self.out, "(")?;
+ let iter = members
+ .iter()
+ .map(|m| (TypeResolution::Handle(m.ty), m.offset));
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+ }
+
+ fn write_store_value(
+ &mut self,
+ module: &crate::Module,
+ value: &StoreValue,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ match *value {
+ StoreValue::Expression(expr) => self.write_expr(module, expr, func_ctx)?,
+ StoreValue::TempIndex {
+ depth,
+ index,
+ ty: _,
+ } => write!(self.out, "{STORE_TEMP_NAME}{depth}[{index}]")?,
+ StoreValue::TempAccess {
+ depth,
+ base,
+ member_index,
+ } => {
+ let name = &self.names[&NameKey::StructMember(base, member_index)];
+ write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")?
+ }
+ }
+ Ok(())
+ }
+
+ /// Helper function to write down the Store operation on a `ByteAddressBuffer`.
+ pub(super) fn write_storage_store(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ value: StoreValue,
+ func_ctx: &FunctionCtx,
+ level: crate::back::Level,
+ ) -> BackendResult {
+ let temp_resolution;
+ let ty_resolution = match value {
+ StoreValue::Expression(expr) => &func_ctx.info[expr].ty,
+ StoreValue::TempIndex {
+ depth: _,
+ index: _,
+ ref ty,
+ } => ty,
+ StoreValue::TempAccess {
+ depth: _,
+ base,
+ member_index,
+ } => {
+ let ty_handle = match module.types[base].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ members[member_index as usize].ty
+ }
+ _ => unreachable!(),
+ };
+ temp_resolution = TypeResolution::Handle(ty_handle);
+ &temp_resolution
+ }
+ };
+ match *ty_resolution.inner_with(&module.types) {
+ crate::TypeInner::Scalar(_) => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{level}{var_name}.Store(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, ", asuint(")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, "));")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Vector { size, .. } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, ", asuint(")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, "));")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ let depth = level.0 + 1;
+ write!(
+ self.out,
+ "{}{}{}x{} {}{} = ",
+ level.next(),
+ scalar.to_hlsl_str()?,
+ columns as u8,
+ rows as u8,
+ STORE_TEMP_NAME,
+ depth,
+ )?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+
+ // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
+ let row_stride = Alignment::from(rows) * scalar.width as u32;
+
+ // then iterate the stores
+ for i in 0..columns as u32 {
+ self.temp_access_chain
+ .push(SubAccess::Offset(i * row_stride));
+ let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
+ let sv = StoreValue::TempIndex {
+ depth,
+ index: i,
+ ty: TypeResolution::Value(ty_inner),
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ stride,
+ } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ write!(self.out, "{}", level.next())?;
+ self.write_value_type(module, &module.types[base].inner)?;
+ let depth = level.next().0;
+ write!(self.out, " {STORE_TEMP_NAME}{depth}")?;
+ self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
+ write!(self.out, " = ")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ // then iterate the stores
+ for i in 0..size.get() {
+ self.temp_access_chain.push(SubAccess::Offset(i * stride));
+ let sv = StoreValue::TempIndex {
+ depth,
+ index: i,
+ ty: TypeResolution::Handle(base),
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ let depth = level.next().0;
+ let struct_ty = ty_resolution.handle().unwrap();
+ let struct_name = &self.names[&NameKey::Type(struct_ty)];
+ write!(
+ self.out,
+ "{}{} {}{} = ",
+ level.next(),
+ struct_name,
+ STORE_TEMP_NAME,
+ depth
+ )?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ // then iterate the stores
+ for (i, member) in members.iter().enumerate() {
+ self.temp_access_chain
+ .push(SubAccess::Offset(member.offset));
+ let sv = StoreValue::TempAccess {
+ depth,
+ base: struct_ty,
+ member_index: i as u32,
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+ }
+
+ /// Set [`temp_access_chain`] to compute the byte offset of `cur_expr`.
+ ///
+ /// The `cur_expr` expression must be a reference to a global
+ /// variable in the [`Storage`] address space, or a chain of
+ /// [`Access`] and [`AccessIndex`] expressions referring to some
+ /// component of such a global.
+ ///
+ /// [`temp_access_chain`]: super::Writer::temp_access_chain
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ pub(super) fn fill_access_chain(
+ &mut self,
+ module: &crate::Module,
+ mut cur_expr: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> Result<Handle<crate::GlobalVariable>, Error> {
+ enum AccessIndex {
+ Expression(Handle<crate::Expression>),
+ Constant(u32),
+ }
+ enum Parent<'a> {
+ Array { stride: u32 },
+ Struct(&'a [crate::StructMember]),
+ }
+ self.temp_access_chain.clear();
+
+ loop {
+ let (next_expr, access_index) = match func_ctx.expressions[cur_expr] {
+ crate::Expression::GlobalVariable(handle) => return Ok(handle),
+ crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)),
+ crate::Expression::AccessIndex { base, index } => {
+ (base, AccessIndex::Constant(index))
+ }
+ ref other => {
+ return Err(Error::Unimplemented(format!("Pointer access of {other:?}")))
+ }
+ };
+
+ let parent = match *func_ctx.resolve_type(next_expr, &module.types) {
+ crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members),
+ crate::TypeInner::Array { stride, .. } => Parent::Array { stride },
+ crate::TypeInner::Vector { scalar, .. } => Parent::Array {
+ stride: scalar.width as u32,
+ },
+ crate::TypeInner::Matrix { rows, scalar, .. } => Parent::Array {
+ // The stride between matrices is the count of rows as this is how
+ // long each column is.
+ stride: Alignment::from(rows) * scalar.width as u32,
+ },
+ _ => unreachable!(),
+ },
+ crate::TypeInner::ValuePointer { scalar, .. } => Parent::Array {
+ stride: scalar.width as u32,
+ },
+ _ => unreachable!(),
+ };
+
+ let sub = match (parent, access_index) {
+ (Parent::Array { stride }, AccessIndex::Expression(value)) => {
+ SubAccess::Index { value, stride }
+ }
+ (Parent::Array { stride }, AccessIndex::Constant(index)) => {
+ SubAccess::Offset(stride * index)
+ }
+ (Parent::Struct(members), AccessIndex::Constant(index)) => {
+ SubAccess::Offset(members[index as usize].offset)
+ }
+ (Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(),
+ };
+
+ self.temp_access_chain.push(sub);
+ cur_expr = next_expr;
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs
new file mode 100644
index 0000000000..43f7212837
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/writer.rs
@@ -0,0 +1,3366 @@
+use super::{
+ help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess},
+ storage::StoreValue,
+ BackendResult, Error, Options,
+};
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, Module, ScalarKind, ShaderStage, TypeInner,
+};
+use std::{fmt, mem};
+
+const LOCATION_SEMANTIC: &str = "LOC";
+const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
+const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
+const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
+const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
+const SPECIAL_OTHER: &str = "other";
+
+pub(crate) const MODF_FUNCTION: &str = "naga_modf";
+pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
+
+struct EpStructMember {
+ name: String,
+ ty: Handle<crate::Type>,
+ // technically, this should always be `Some`
+ binding: Option<crate::Binding>,
+ index: u32,
+}
+
+/// Structure contains information required for generating
+/// wrapped structure of all entry points arguments
+struct EntryPointBinding {
+ /// Name of the fake EP argument that contains the struct
+ /// with all the flattened input data.
+ arg_name: String,
+ /// Generated structure name
+ ty_name: String,
+ /// Members of generated structure
+ members: Vec<EpStructMember>,
+}
+
+pub(super) struct EntryPointInterface {
+ /// If `Some`, the input of an entry point is gathered in a special
+ /// struct with members sorted by binding.
+ /// The `EntryPointBinding::members` array is sorted by index,
+ /// so that we can walk it in `write_ep_arguments_initialization`.
+ input: Option<EntryPointBinding>,
+ /// If `Some`, the output of an entry point is flattened.
+ /// The `EntryPointBinding::members` array is sorted by binding,
+ /// So that we can walk it in `Statement::Return` handler.
+ output: Option<EntryPointBinding>,
+}
+
+#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
+enum InterfaceKey {
+ Location(u32),
+ BuiltIn(crate::BuiltIn),
+ Other,
+}
+
+impl InterfaceKey {
+ const fn new(binding: Option<&crate::Binding>) -> Self {
+ match binding {
+ Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
+ Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
+ None => Self::Other,
+ }
+ }
+}
+
+#[derive(Copy, Clone, PartialEq)]
+enum Io {
+ Input,
+ Output,
+}
+
+impl<'a, W: fmt::Write> super::Writer<'a, W> {
+ pub fn new(out: W, options: &'a Options) -> Self {
+ Self {
+ out,
+ names: crate::FastHashMap::default(),
+ namer: proc::Namer::default(),
+ options,
+ entry_point_io: Vec::new(),
+ named_expressions: crate::NamedExpressions::default(),
+ wrapped: super::Wrapped::default(),
+ temp_access_chain: Vec::new(),
+ need_bake_expressions: Default::default(),
+ }
+ }
+
+ fn reset(&mut self, module: &Module) {
+ self.names.clear();
+ self.namer.reset(
+ module,
+ super::keywords::RESERVED,
+ super::keywords::TYPES,
+ super::keywords::RESERVED_CASE_INSENSITIVE,
+ &[],
+ &mut self.names,
+ );
+ self.entry_point_io.clear();
+ self.named_expressions.clear();
+ self.wrapped.clear();
+ self.need_bake_expressions.clear();
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// Clears `need_bake_expressions` set before adding to it
+ fn update_expressions_to_bake(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ ) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+ for (fun_handle, expr) in func.expressions.iter() {
+ let expr_info = &info[fun_handle];
+ let min_ref_count = func.expressions[fun_handle].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(fun_handle);
+ }
+
+ if let Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } = *expr
+ {
+ match fun {
+ crate::MathFunction::Asinh
+ | crate::MathFunction::Acosh
+ | crate::MathFunction::Atanh
+ | crate::MathFunction::Unpack2x16float
+ | crate::MathFunction::Unpack2x16snorm
+ | crate::MathFunction::Unpack2x16unorm
+ | crate::MathFunction::Unpack4x8snorm
+ | crate::MathFunction::Unpack4x8unorm
+ | crate::MathFunction::Pack2x16float
+ | crate::MathFunction::Pack2x16snorm
+ | crate::MathFunction::Pack2x16unorm
+ | crate::MathFunction::Pack4x8snorm
+ | crate::MathFunction::Pack4x8unorm => {
+ self.need_bake_expressions.insert(arg);
+ }
+ crate::MathFunction::ExtractBits => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ self.need_bake_expressions.insert(arg2.unwrap());
+ }
+ crate::MathFunction::InsertBits => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ self.need_bake_expressions.insert(arg2.unwrap());
+ self.need_bake_expressions.insert(arg3.unwrap());
+ }
+ crate::MathFunction::CountLeadingZeros => {
+ let inner = info[fun_handle].ty.inner_with(&module.types);
+ if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ self.need_bake_expressions.insert(arg);
+ }
+ }
+ _ => {}
+ }
+ }
+
+ if let Expression::Derivative { axis, ctrl, expr } = *expr {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
+ self.need_bake_expressions.insert(expr);
+ }
+ }
+ }
+ }
+
+ pub fn write(
+ &mut self,
+ module: &Module,
+ module_info: &valid::ModuleInfo,
+ ) -> Result<super::ReflectionInfo, Error> {
+ self.reset(module);
+
+ // Write special constants, if needed
+ if let Some(ref bt) = self.options.special_constants_binding {
+ writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
+ writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
+ writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
+ writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
+ writeln!(self.out, "}};")?;
+ write!(
+ self.out,
+ "ConstantBuffer<{}> {}: register(b{}",
+ SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
+ )?;
+ if bt.space != 0 {
+ write!(self.out, ", space{}", bt.space)?;
+ }
+ writeln!(self.out, ");")?;
+
+ // Extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Save all entry point output types
+ let ep_results = module
+ .entry_points
+ .iter()
+ .map(|ep| (ep.stage, ep.function.result.clone()))
+ .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
+
+ self.write_all_mat_cx2_typedefs_and_functions(module)?;
+
+ // Write all structs
+ for (handle, ty) in module.types.iter() {
+ if let TypeInner::Struct { ref members, span } = ty.inner {
+ if module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&module.types)
+ {
+ // unsized arrays can only be in storage buffers,
+ // for which we use `ByteAddressBuffer` anyway.
+ continue;
+ }
+
+ let ep_result = ep_results.iter().find(|e| {
+ if let Some(ref result) = e.1 {
+ result.ty == handle
+ } else {
+ false
+ }
+ });
+
+ self.write_struct(
+ module,
+ handle,
+ members,
+ span,
+ ep_result.map(|r| (r.0, Io::Output)),
+ )?;
+ writeln!(self.out)?;
+ }
+ }
+
+ self.write_special_functions(module)?;
+
+ self.write_wrapped_compose_functions(module, &module.const_expressions)?;
+
+ // Write all named constants
+ let mut constants = module
+ .constants
+ .iter()
+ .filter(|&(_, c)| c.name.is_some())
+ .peekable();
+ while let Some((handle, _)) = constants.next() {
+ self.write_global_constant(module, handle)?;
+ // Add extra newline for readability on last iteration
+ if constants.peek().is_none() {
+ writeln!(self.out)?;
+ }
+ }
+
+ // Write all globals
+ for (ty, _) in module.global_variables.iter() {
+ self.write_global(module, ty)?;
+ }
+
+ if !module.global_variables.is_empty() {
+ // Add extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Write all entry points wrapped structs
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
+ let ep_io = self.write_ep_interface(module, &ep.function, ep.stage, &ep_name)?;
+ self.entry_point_io.push(ep_io);
+ }
+
+ // Write all regular functions
+ for (handle, function) in module.functions.iter() {
+ let info = &module_info[handle];
+
+ // Check if all of the globals are accessible
+ if !self.options.fake_missing_bindings {
+ if let Some((var_handle, _)) =
+ module
+ .global_variables
+ .iter()
+ .find(|&(var_handle, var)| match var.binding {
+ Some(ref binding) if !info[var_handle].is_empty() => {
+ self.options.resolve_resource_binding(binding).is_err()
+ }
+ _ => false,
+ })
+ {
+ log::info!(
+ "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
+ handle,
+ function.name,
+ var_handle
+ );
+ continue;
+ }
+ }
+
+ let ctx = back::FunctionCtx {
+ ty: back::FunctionType::Function(handle),
+ info,
+ expressions: &function.expressions,
+ named_expressions: &function.named_expressions,
+ };
+ let name = self.names[&NameKey::Function(handle)].clone();
+
+ self.write_wrapped_functions(module, &ctx)?;
+
+ self.write_function(module, name.as_str(), function, &ctx, info)?;
+
+ writeln!(self.out)?;
+ }
+
+ let mut entry_point_names = Vec::with_capacity(module.entry_points.len());
+
+ // Write all entry points
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let info = module_info.get_entry_point(index);
+
+ if !self.options.fake_missing_bindings {
+ let mut ep_error = None;
+ for (var_handle, var) in module.global_variables.iter() {
+ match var.binding {
+ Some(ref binding) if !info[var_handle].is_empty() => {
+ if let Err(err) = self.options.resolve_resource_binding(binding) {
+ ep_error = Some(err);
+ break;
+ }
+ }
+ _ => {}
+ }
+ }
+ if let Some(err) = ep_error {
+ entry_point_names.push(Err(err));
+ continue;
+ }
+ }
+
+ let ctx = back::FunctionCtx {
+ ty: back::FunctionType::EntryPoint(index as u16),
+ info,
+ expressions: &ep.function.expressions,
+ named_expressions: &ep.function.named_expressions,
+ };
+
+ self.write_wrapped_functions(module, &ctx)?;
+
+ if ep.stage == ShaderStage::Compute {
+ // HLSL is calling workgroup size "num threads"
+ let num_threads = ep.workgroup_size;
+ writeln!(
+ self.out,
+ "[numthreads({}, {}, {})]",
+ num_threads[0], num_threads[1], num_threads[2]
+ )?;
+ }
+
+ let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
+ self.write_function(module, &name, &ep.function, &ctx, info)?;
+
+ if index < module.entry_points.len() - 1 {
+ writeln!(self.out)?;
+ }
+
+ entry_point_names.push(Ok(name));
+ }
+
+ Ok(super::ReflectionInfo { entry_point_names })
+ }
+
+ fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
+ match *binding {
+ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
+ write!(self.out, "precise ")?;
+ }
+ crate::Binding::Location {
+ interpolation,
+ sampling,
+ ..
+ } => {
+ if let Some(interpolation) = interpolation {
+ if let Some(string) = interpolation.to_hlsl_str() {
+ write!(self.out, "{string} ")?
+ }
+ }
+
+ if let Some(sampling) = sampling {
+ if let Some(string) = sampling.to_hlsl_str() {
+ write!(self.out, "{string} ")?
+ }
+ }
+ }
+ crate::Binding::BuiltIn(_) => {}
+ }
+
+ Ok(())
+ }
+
+ //TODO: we could force fragment outputs to always go through `entry_point_io.output` path
+ // if they are struct, so that the `stage` argument here could be omitted.
+ fn write_semantic(
+ &mut self,
+ binding: &crate::Binding,
+ stage: Option<(ShaderStage, Io)>,
+ ) -> BackendResult {
+ match *binding {
+ crate::Binding::BuiltIn(builtin) => {
+ let builtin_str = builtin.to_hlsl_str()?;
+ write!(self.out, " : {builtin_str}")?;
+ }
+ crate::Binding::Location {
+ second_blend_source: true,
+ ..
+ } => {
+ write!(self.out, " : SV_Target1")?;
+ }
+ crate::Binding::Location {
+ location,
+ second_blend_source: false,
+ ..
+ } => {
+ if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
+ write!(self.out, " : SV_Target{location}")?;
+ } else {
+ write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ fn write_interface_struct(
+ &mut self,
+ module: &Module,
+ shader_stage: (ShaderStage, Io),
+ struct_name: String,
+ mut members: Vec<EpStructMember>,
+ ) -> Result<EntryPointBinding, Error> {
+ // Sort the members so that first come the user-defined varyings
+ // in ascending locations, and then built-ins. This allows VS and FS
+ // interfaces to match with regards to order.
+ members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
+
+ write!(self.out, "struct {struct_name}")?;
+ writeln!(self.out, " {{")?;
+ for m in members.iter() {
+ write!(self.out, "{}", back::INDENT)?;
+ if let Some(ref binding) = m.binding {
+ self.write_modifier(binding)?;
+ }
+ self.write_type(module, m.ty)?;
+ write!(self.out, " {}", &m.name)?;
+ if let Some(ref binding) = m.binding {
+ self.write_semantic(binding, Some(shader_stage))?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ writeln!(self.out, "}};")?;
+ writeln!(self.out)?;
+
+ match shader_stage.1 {
+ Io::Input => {
+ // bring back the original order
+ members.sort_by_key(|m| m.index);
+ }
+ Io::Output => {
+ // keep it sorted by binding
+ }
+ }
+
+ Ok(EntryPointBinding {
+ arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
+ ty_name: struct_name,
+ members,
+ })
+ }
+
+ /// Flatten all entry point arguments into a single struct.
+ /// This is needed since we need to re-order them: first placing user locations,
+ /// then built-ins.
+ fn write_ep_input_struct(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ stage: ShaderStage,
+ entry_point_name: &str,
+ ) -> Result<EntryPointBinding, Error> {
+ let struct_name = format!("{stage:?}Input_{entry_point_name}");
+
+ let mut fake_members = Vec::new();
+ for arg in func.arguments.iter() {
+ match module.types[arg.ty].inner {
+ TypeInner::Struct { ref members, .. } => {
+ for member in members.iter() {
+ let name = self.namer.call_or(&member.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name,
+ ty: member.ty,
+ binding: member.binding.clone(),
+ index,
+ });
+ }
+ }
+ _ => {
+ let member_name = self.namer.call_or(&arg.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name: member_name,
+ ty: arg.ty,
+ binding: arg.binding.clone(),
+ index,
+ });
+ }
+ }
+ }
+
+ self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
+ }
+
+ /// Flatten all entry point results into a single struct.
+ /// This is needed since we need to re-order them: first placing user locations,
+ /// then built-ins.
+ fn write_ep_output_struct(
+ &mut self,
+ module: &Module,
+ result: &crate::FunctionResult,
+ stage: ShaderStage,
+ entry_point_name: &str,
+ ) -> Result<EntryPointBinding, Error> {
+ let struct_name = format!("{stage:?}Output_{entry_point_name}");
+
+ let mut fake_members = Vec::new();
+ let empty = [];
+ let members = match module.types[result.ty].inner {
+ TypeInner::Struct { ref members, .. } => members,
+ ref other => {
+ log::error!("Unexpected {:?} output type without a binding", other);
+ &empty[..]
+ }
+ };
+
+ for member in members.iter() {
+ let member_name = self.namer.call_or(&member.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name: member_name,
+ ty: member.ty,
+ binding: member.binding.clone(),
+ index,
+ });
+ }
+
+ self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
+ }
+
+ /// Writes special interface structures for an entry point. The special structures have
+ /// all the fields flattened into them and sorted by binding. They are only needed for
+ /// VS outputs and FS inputs, so that these interfaces match.
+ fn write_ep_interface(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ stage: ShaderStage,
+ ep_name: &str,
+ ) -> Result<EntryPointInterface, Error> {
+ Ok(EntryPointInterface {
+ input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment {
+ Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
+ } else {
+ None
+ },
+ output: match func.result {
+ Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
+ Some(self.write_ep_output_struct(module, fr, stage, ep_name)?)
+ }
+ _ => None,
+ },
+ })
+ }
+
+ /// Write an entry point preface that initializes the arguments as specified in IR.
+ fn write_ep_arguments_initialization(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ ep_index: u16,
+ ) -> BackendResult {
+ let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
+ Some(ep_input) => ep_input,
+ None => return Ok(()),
+ };
+ let mut fake_iter = ep_input.members.iter();
+ for (arg_index, arg) in func.arguments.iter().enumerate() {
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(module, arg.ty)?;
+ let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
+ write!(self.out, " {arg_name}")?;
+ match module.types[arg.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ self.write_array_size(module, base, size)?;
+ let fake_member = fake_iter.next().unwrap();
+ writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ }
+ TypeInner::Struct { ref members, .. } => {
+ write!(self.out, " = {{ ")?;
+ for index in 0..members.len() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ let fake_member = fake_iter.next().unwrap();
+ write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
+ }
+ writeln!(self.out, " }};")?;
+ }
+ _ => {
+ let fake_member = fake_iter.next().unwrap();
+ writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ }
+ }
+ }
+ assert!(fake_iter.next().is_none());
+ Ok(())
+ }
+
+ /// Helper method used to write global variables
+ /// # Notes
+ /// Always adds a newline
+ fn write_global(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::GlobalVariable>,
+ ) -> BackendResult {
+ let global = &module.global_variables[handle];
+ let inner = &module.types[global.ty].inner;
+
+ if let Some(ref binding) = global.binding {
+ if let Err(err) = self.options.resolve_resource_binding(binding) {
+ log::info!(
+ "Skipping global {:?} (name {:?}) for being inaccessible: {}",
+ handle,
+ global.name,
+ err,
+ );
+ return Ok(());
+ }
+ }
+
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
+ let register_ty = match global.space {
+ crate::AddressSpace::Function => unreachable!("Function address space"),
+ crate::AddressSpace::Private => {
+ write!(self.out, "static ")?;
+ self.write_type(module, global.ty)?;
+ ""
+ }
+ crate::AddressSpace::WorkGroup => {
+ write!(self.out, "groupshared ")?;
+ self.write_type(module, global.ty)?;
+ ""
+ }
+ crate::AddressSpace::Uniform => {
+ // constant buffer declarations are expected to be inlined, e.g.
+ // `cbuffer foo: register(b0) { field1: type1; }`
+ write!(self.out, "cbuffer")?;
+ "b"
+ }
+ crate::AddressSpace::Storage { access } => {
+ let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
+ ("RW", "u")
+ } else {
+ ("", "t")
+ };
+ write!(self.out, "{prefix}ByteAddressBuffer")?;
+ register
+ }
+ crate::AddressSpace::Handle => {
+ let handle_ty = match *inner {
+ TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
+ _ => inner,
+ };
+
+ let register = match *handle_ty {
+ TypeInner::Sampler { .. } => "s",
+ // all storage textures are UAV, unconditionally
+ TypeInner::Image {
+ class: crate::ImageClass::Storage { .. },
+ ..
+ } => "u",
+ _ => "t",
+ };
+ self.write_type(module, global.ty)?;
+ register
+ }
+ crate::AddressSpace::PushConstant => {
+ // The type of the push constants will be wrapped in `ConstantBuffer`
+ write!(self.out, "ConstantBuffer<")?;
+ "b"
+ }
+ };
+
+ // If the global is a push constant write the type now because it will be a
+ // generic argument to `ConstantBuffer`
+ if global.space == crate::AddressSpace::PushConstant {
+ self.write_global_type(module, global.ty)?;
+
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ // Close the angled brackets for the generic argument
+ write!(self.out, ">")?;
+ }
+
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, " {name}")?;
+
+ // Push constants need to be assigned a binding explicitly by the consumer
+ // since naga has no way to know the binding from the shader alone
+ if global.space == crate::AddressSpace::PushConstant {
+ let target = self
+ .options
+ .push_constants_target
+ .as_ref()
+ .expect("No bind target was defined for the push constants block");
+ write!(self.out, ": register(b{}", target.register)?;
+ if target.space != 0 {
+ write!(self.out, ", space{}", target.space)?;
+ }
+ write!(self.out, ")")?;
+ }
+
+ if let Some(ref binding) = global.binding {
+ // this was already resolved earlier when we started evaluating an entry point.
+ let bt = self.options.resolve_resource_binding(binding).unwrap();
+
+ // need to write the binding array size if the type was emitted with `write_type`
+ if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
+ if let Some(overridden_size) = bt.binding_array_size {
+ write!(self.out, "[{overridden_size}]")?;
+ } else {
+ self.write_array_size(module, base, size)?;
+ }
+ }
+
+ write!(self.out, " : register({}{}", register_ty, bt.register)?;
+ if bt.space != 0 {
+ write!(self.out, ", space{}", bt.space)?;
+ }
+ write!(self.out, ")")?;
+ } else {
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ if global.space == crate::AddressSpace::Private {
+ write!(self.out, " = ")?;
+ if let Some(init) = global.init {
+ self.write_const_expression(module, init)?;
+ } else {
+ self.write_default_init(module, global.ty)?;
+ }
+ }
+ }
+
+ if global.space == crate::AddressSpace::Uniform {
+ write!(self.out, " {{ ")?;
+
+ self.write_global_type(module, global.ty)?;
+
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?;
+
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ writeln!(self.out, "; }}")?;
+ } else {
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global constants
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_global_constant(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ write!(self.out, "static const ")?;
+ let constant = &module.constants[handle];
+ self.write_type(module, constant.ty)?;
+ let name = &self.names[&NameKey::Constant(handle)];
+ write!(self.out, " {}", name)?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_const_expression(module, constant.init)?;
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
+ pub(super) fn write_array_size(
+ &mut self,
+ module: &Module,
+ base: Handle<crate::Type>,
+ size: crate::ArraySize,
+ ) -> BackendResult {
+ write!(self.out, "[")?;
+
+ match size {
+ crate::ArraySize::Constant(size) => {
+ write!(self.out, "{size}")?;
+ }
+ crate::ArraySize::Dynamic => unreachable!(),
+ }
+
+ write!(self.out, "]")?;
+
+ if let TypeInner::Array {
+ base: next_base,
+ size: next_size,
+ ..
+ } = module.types[base].inner
+ {
+ self.write_array_size(module, next_base, next_size)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_struct(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ span: u32,
+ shader_stage: Option<(ShaderStage, Io)>,
+ ) -> BackendResult {
+ // Write struct name
+ let struct_name = &self.names[&NameKey::Type(handle)];
+ writeln!(self.out, "struct {struct_name} {{")?;
+
+ let mut last_offset = 0;
+ for (index, member) in members.iter().enumerate() {
+ if member.binding.is_none() && member.offset > last_offset {
+ // using int as padding should work as long as the backend
+ // doesn't support a type that's less than 4 bytes in size
+ // (Error::UnsupportedScalar catches this)
+ let padding = (member.offset - last_offset) / 4;
+ for i in 0..padding {
+ writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
+ }
+ }
+ let ty_inner = &module.types[member.ty].inner;
+ last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx());
+
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+
+ match module.types[member.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ // HLSL arrays are written as `type name[size]`
+
+ self.write_global_type(module, member.ty)?;
+
+ // Write `name`
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, index as u32)]
+ )?;
+ // Write [size]
+ self.write_array_size(module, base, size)?;
+ }
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ TypeInner::Matrix {
+ rows,
+ columns,
+ scalar,
+ } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
+ let vec_ty = crate::TypeInner::Vector { size: rows, scalar };
+ let field_name_key = NameKey::StructMember(handle, index as u32);
+
+ for i in 0..columns as u8 {
+ if i != 0 {
+ write!(self.out, "; ")?;
+ }
+ self.write_value_type(module, &vec_ty)?;
+ write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
+ }
+ }
+ _ => {
+ // Write modifier before type
+ if let Some(ref binding) = member.binding {
+ self.write_modifier(binding)?;
+ }
+
+ // Even though Naga IR matrices are column-major, we must describe
+ // matrices passed from the CPU as being in row-major order.
+ // See the module-level block comment in mod.rs for details.
+ if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
+ write!(self.out, "row_major ")?;
+ }
+
+ // Write the member type and name
+ self.write_type(module, member.ty)?;
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, index as u32)]
+ )?;
+ }
+ }
+
+ if let Some(ref binding) = member.binding {
+ self.write_semantic(binding, shader_stage)?;
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
+ if members.last().unwrap().binding.is_none() && span > last_offset {
+ let padding = (span - last_offset) / 4;
+ for i in 0..padding {
+ writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
+ }
+ }
+
+ writeln!(self.out, "}};")?;
+ Ok(())
+ }
+
+ /// Helper method used to write global/structs non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_global_type(
+ &mut self,
+ module: &Module,
+ ty: Handle<crate::Type>,
+ ) -> BackendResult {
+ let matrix_data = get_inner_matrix_data(module, ty);
+
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = matrix_data
+ {
+ write!(self.out, "__mat{}x2", columns as u8)?;
+ } else {
+ // Even though Naga IR matrices are column-major, we must describe
+ // matrices passed from the CPU as being in row-major order.
+ // See the module-level block comment in mod.rs for details.
+ if matrix_data.is_some() {
+ write!(self.out, "row_major ")?;
+ }
+
+ self.write_type(module, ty)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &module.types[ty].inner;
+ match *inner {
+ TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
+ // hlsl array has the size separated from the base type
+ TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
+ self.write_type(module, base)?
+ }
+ ref other => self.write_value_type(module, other)?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
+ write!(self.out, "{}", scalar.to_hlsl_str()?)?;
+ }
+ TypeInner::Vector { size, scalar } => {
+ write!(
+ self.out,
+ "{}{}",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(size)
+ )?;
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ // The IR supports only float matrix
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
+
+ // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
+ write!(
+ self.out,
+ "{}{}x{}",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ )?;
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ self.write_image_type(dim, arrayed, class)?;
+ }
+ TypeInner::Sampler { comparison } => {
+ let sampler = if comparison {
+ "SamplerComparisonState"
+ } else {
+ "SamplerState"
+ };
+ write!(self.out, "{sampler}")?;
+ }
+ // HLSL arrays are written as `type name[size]`
+ // Current code is written arrays only as `[size]`
+ // Base `type` and `name` should be written outside
+ TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
+ self.write_array_size(module, base, size)?;
+ }
+ _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write functions
+ /// # Notes
+ /// Ends in a newline
+ fn write_function(
+ &mut self,
+ module: &Module,
+ name: &str,
+ func: &crate::Function,
+ func_ctx: &back::FunctionCtx<'_>,
+ info: &valid::FunctionInfo,
+ ) -> BackendResult {
+ // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
+
+ self.update_expressions_to_bake(module, func, info);
+
+ // Write modifier
+ if let Some(crate::FunctionResult {
+ binding:
+ Some(
+ ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position {
+ invariant: true,
+ }),
+ ),
+ ..
+ }) = func.result
+ {
+ self.write_modifier(binding)?;
+ }
+
+ // Write return type
+ if let Some(ref result) = func.result {
+ match func_ctx.ty {
+ back::FunctionType::Function(_) => {
+ self.write_type(module, result.ty)?;
+ }
+ back::FunctionType::EntryPoint(index) => {
+ if let Some(ref ep_output) = self.entry_point_io[index as usize].output {
+ write!(self.out, "{}", ep_output.ty_name)?;
+ } else {
+ self.write_type(module, result.ty)?;
+ }
+ }
+ }
+ } else {
+ write!(self.out, "void")?;
+ }
+
+ // Write function name
+ write!(self.out, " {name}(")?;
+
+ let need_workgroup_variables_initialization =
+ self.need_workgroup_variables_initialization(func_ctx, module);
+
+ // Write function arguments for non entry point functions
+ match func_ctx.ty {
+ back::FunctionType::Function(handle) => {
+ for (index, arg) in func.arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // Write argument type
+ let arg_ty = match module.types[arg.ty].inner {
+ // pointers in function arguments are expected and resolve to `inout`
+ TypeInner::Pointer { base, .. } => {
+ //TODO: can we narrow this down to just `in` when possible?
+ write!(self.out, "inout ")?;
+ base
+ }
+ _ => arg.ty,
+ };
+ self.write_type(module, arg_ty)?;
+
+ let argument_name =
+ &self.names[&NameKey::FunctionArgument(handle, index as u32)];
+
+ // Write argument name. Space is important.
+ write!(self.out, " {argument_name}")?;
+ if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ }
+ }
+ back::FunctionType::EntryPoint(ep_index) => {
+ if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
+ write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?;
+ } else {
+ let stage = module.entry_points[ep_index as usize].stage;
+ for (index, arg) in func.arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_type(module, arg.ty)?;
+
+ let argument_name =
+ &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
+
+ write!(self.out, " {argument_name}")?;
+ if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ if let Some(ref binding) = arg.binding {
+ self.write_semantic(binding, Some((stage, Io::Input)))?;
+ }
+ }
+
+ if need_workgroup_variables_initialization {
+ if !func.arguments.is_empty() {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
+ }
+ }
+ }
+ }
+ // Ends of arguments
+ write!(self.out, ")")?;
+
+ // Write semantic if it present
+ if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
+ let stage = module.entry_points[index as usize].stage;
+ if let Some(crate::FunctionResult {
+ binding: Some(ref binding),
+ ..
+ }) = func.result
+ {
+ self.write_semantic(binding, Some((stage, Io::Output)))?;
+ }
+ }
+
+ // Function body start
+ writeln!(self.out)?;
+ writeln!(self.out, "{{")?;
+
+ if need_workgroup_variables_initialization {
+ self.write_workgroup_variables_initialization(func_ctx, module)?;
+ }
+
+ if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
+ self.write_ep_arguments_initialization(module, func, index)?;
+ }
+
+ // Write function local variables
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability)
+ write!(self.out, "{}", back::INDENT)?;
+
+ // Write the local name
+ // The leading space is important
+ self.write_type(module, local.ty)?;
+ write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ write!(self.out, " = ")?;
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ self.write_expr(module, init, func_ctx)?;
+ } else {
+ // Zero initialize local variables
+ self.write_default_init(module, local.ty)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ if !func.local_variables.is_empty() {
+ writeln!(self.out)?;
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // The indentation should always be 1 when writing the function body
+ self.write_stmt(module, sta, func_ctx, back::Level(1))?;
+ }
+
+ writeln!(self.out, "}}")?;
+
+ self.named_expressions.clear();
+
+ Ok(())
+ }
+
+ fn need_workgroup_variables_initialization(
+ &mut self,
+ func_ctx: &back::FunctionCtx,
+ module: &Module,
+ ) -> bool {
+ self.options.zero_initialize_workgroup_memory
+ && func_ctx.ty.is_compute_entry_point(module)
+ && module.global_variables.iter().any(|(handle, var)| {
+ !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ }
+
+ fn write_workgroup_variables_initialization(
+ &mut self,
+ func_ctx: &back::FunctionCtx,
+ module: &Module,
+ ) -> BackendResult {
+ let level = back::Level(1);
+
+ writeln!(
+ self.out,
+ "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
+ )?;
+
+ let vars = module.global_variables.iter().filter(|&(handle, var)| {
+ !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ });
+
+ for (handle, var) in vars {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{}{} = ", level.next(), name)?;
+ self.write_default_init(module, var.ty)?;
+ writeln!(self.out, ";")?;
+ }
+
+ writeln!(self.out, "{level}}}")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)
+ }
+
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ module: &Module,
+ stmt: &crate::Statement,
+ func_ctx: &back::FunctionCtx<'_>,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::Statement;
+
+ match *stmt {
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ // HLSL can't save a pointer-valued expression in a variable,
+ // but we shouldn't ever need to: they should never be named expressions,
+ // and none of the expression types flagged by bake_ref_count can be pointer-valued.
+ None
+ } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
+ // Front end provides names for all variables at the start of writing.
+ // But we write them to step by step. We need to recache them
+ // Otherwise, we could accidentally write variable name instead of full expression.
+ // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
+ Some(self.namer.call(name))
+ } else if self.need_bake_expressions.contains(&handle) {
+ Some(format!("_expr{}", handle.index()))
+ } else {
+ None
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.write_named_expr(module, handle, name, handle, func_ctx)?;
+ }
+ }
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Block(ref block) => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, level.next())?
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "if (")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ let l2 = level.next();
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Kill => writeln!(self.out, "{level}discard;")?,
+ Statement::Return { value: None } => {
+ writeln!(self.out, "{level}return;")?;
+ }
+ Statement::Return { value: Some(expr) } => {
+ let base_ty_res = &func_ctx.info[expr].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ if let TypeInner::Pointer { base, space: _ } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ if let TypeInner::Struct { .. } = *resolved {
+ // We can safely unwrap here, since we now we working with struct
+ let ty = base_ty_res.handle().unwrap();
+ let struct_name = &self.names[&NameKey::Type(ty)];
+ let variable_name = self.namer.call(&struct_name.to_lowercase());
+ write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
+ self.write_expr(module, expr, func_ctx)?;
+ writeln!(self.out, ";")?;
+
+ // for entry point returns, we may need to reshuffle the outputs into a different struct
+ let ep_output = match func_ctx.ty {
+ back::FunctionType::Function(_) => None,
+ back::FunctionType::EntryPoint(index) => {
+ self.entry_point_io[index as usize].output.as_ref()
+ }
+ };
+ let final_name = match ep_output {
+ Some(ep_output) => {
+ let final_name = self.namer.call(&variable_name);
+ write!(
+ self.out,
+ "{}const {} {} = {{ ",
+ level, ep_output.ty_name, final_name,
+ )?;
+ for (index, m) in ep_output.members.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
+ write!(self.out, "{variable_name}.{member_name}")?;
+ }
+ writeln!(self.out, " }};")?;
+ final_name
+ }
+ None => variable_name,
+ };
+ writeln!(self.out, "{level}return {final_name};")?;
+ } else {
+ write!(self.out, "{level}return ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ writeln!(self.out, ";")?
+ }
+ }
+ Statement::Store { pointer, value } => {
+ let ty_inner = func_ctx.resolve_type(pointer, &module.types);
+ if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ self.write_storage_store(
+ module,
+ var_handle,
+ StoreValue::Expression(value),
+ func_ctx,
+ level,
+ )?;
+ } else {
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ //
+ // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
+ // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
+ struct MatrixAccess {
+ base: Handle<crate::Expression>,
+ index: u32,
+ }
+ enum Index {
+ Expression(Handle<crate::Expression>),
+ Static(u32),
+ }
+
+ let get_members = |expr: Handle<crate::Expression>| {
+ let resolved = func_ctx.resolve_type(expr, &module.types);
+ match *resolved {
+ TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ TypeInner::Struct { ref members, .. } => Some(members),
+ _ => None,
+ },
+ _ => None,
+ }
+ };
+
+ let mut matrix = None;
+ let mut vector = None;
+ let mut scalar = None;
+
+ let mut current_expr = pointer;
+ for _ in 0..3 {
+ let resolved = func_ctx.resolve_type(current_expr, &module.types);
+
+ match (resolved, &func_ctx.expressions[current_expr]) {
+ (
+ &TypeInner::Pointer { base: ty, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) if matches!(
+ module.types[ty].inner,
+ TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ }
+ ) && get_members(base)
+ .map(|members| members[index as usize].binding.is_none())
+ == Some(true) =>
+ {
+ matrix = Some(MatrixAccess { base, index });
+ break;
+ }
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::Access { base, index },
+ ) => {
+ vector = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ vector = Some(Index::Static(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::Access { base, index },
+ ) => {
+ scalar = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ scalar = Some(Index::Static(index));
+ current_expr = base;
+ }
+ _ => break,
+ }
+ }
+
+ write!(self.out, "{level}")?;
+
+ if let Some(MatrixAccess { index, base }) = matrix {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+ let ty = match *resolved {
+ TypeInner::Pointer { base, .. } => base,
+ _ => base_ty_res.handle().unwrap(),
+ };
+
+ if let Some(Index::Static(vec_index)) = vector {
+ self.write_expr(module, base, func_ctx)?;
+ write!(
+ self.out,
+ ".{}_{}",
+ &self.names[&NameKey::StructMember(ty, index)],
+ vec_index
+ )?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, "[")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ write!(self.out, "]")?;
+ }
+
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ } else {
+ let access = WrappedStructMatrixAccess { ty, index };
+ match (&vector, &scalar) {
+ (&Some(_), &Some(_)) => {
+ self.write_wrapped_struct_matrix_set_scalar_function_name(
+ access,
+ )?;
+ }
+ (&Some(_), &None) => {
+ self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
+ }
+ (&None, _) => {
+ self.write_wrapped_struct_matrix_set_function_name(access)?;
+ }
+ }
+
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+
+ if let Some(Index::Expression(vec_index)) = vector {
+ write!(self.out, ", ")?;
+ self.write_expr(module, vec_index, func_ctx)?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, ", ")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
+ } else {
+ // We handle `Store`s to __matCx2 column vectors and scalar elements via
+ // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
+ struct MatrixData {
+ columns: crate::VectorSize,
+ base: Handle<crate::Expression>,
+ }
+
+ enum Index {
+ Expression(Handle<crate::Expression>),
+ Static(u32),
+ }
+
+ let mut matrix = None;
+ let mut vector = None;
+ let mut scalar = None;
+
+ let mut current_expr = pointer;
+ for _ in 0..3 {
+ let resolved = func_ctx.resolve_type(current_expr, &module.types);
+ match (resolved, &func_ctx.expressions[current_expr]) {
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::Access { base, index },
+ ) => {
+ vector = Some(index);
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::Access { base, index },
+ ) => {
+ scalar = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ scalar = Some(Index::Static(index));
+ current_expr = base;
+ }
+ _ => {
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(
+ module,
+ current_expr,
+ func_ctx,
+ true,
+ ) {
+ matrix = Some(MatrixData {
+ columns,
+ base: current_expr,
+ });
+ }
+
+ break;
+ }
+ }
+ }
+
+ if let (Some(MatrixData { columns, base }), Some(vec_index)) =
+ (matrix, vector)
+ {
+ if scalar.is_some() {
+ write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
+ } else {
+ write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
+ }
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, vec_index, func_ctx)?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, ", ")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+
+ writeln!(self.out, ");")?;
+ } else {
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, " = ")?;
+
+ // We cast the RHS of this store in cases where the LHS
+ // is a struct member with type:
+ // - matCx2 or
+ // - a (possibly nested) array of matCx2's
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(
+ module, pointer, func_ctx, false,
+ ) {
+ let mut resolved = func_ctx.resolve_type(pointer, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ write!(self.out, "(__mat{}x2", columns as u8)?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, ")")?;
+ }
+
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?
+ }
+ }
+ }
+ }
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let l2 = level.next();
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{level}bool {gate_name} = true;")?;
+ writeln!(self.out, "{level}while(true) {{")?;
+ writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
+ let l3 = l2.next();
+ for sta in continuing.iter() {
+ self.write_stmt(module, sta, func_ctx, l3)?;
+ }
+ if let Some(condition) = break_if {
+ write!(self.out, "{l3}if (")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", l3.next())?;
+ writeln!(self.out, "{l3}}}")?;
+ }
+ writeln!(self.out, "{l2}}}")?;
+ writeln!(self.out, "{l2}{gate_name} = false;")?;
+ } else {
+ writeln!(self.out, "{level}while(true) {{")?;
+ }
+
+ for sta in body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Break => writeln!(self.out, "{level}break;")?,
+ Statement::Continue => writeln!(self.out, "{level}continue;")?,
+ Statement::Barrier(barrier) => {
+ self.write_barrier(barrier, level)?;
+ }
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{level}")?;
+ self.write_expr(module, image, func_ctx)?;
+
+ write!(self.out, "[")?;
+ if let Some(index) = array_index {
+ // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here
+ write!(self.out, "int3(")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(module, coordinate, func_ctx)?;
+ }
+ write!(self.out, "]")?;
+
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ }
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ write!(self.out, "const ")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ let expr_ty = &func_ctx.info[expr].ty;
+ match *expr_ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+ write!(self.out, " {name} = ")?;
+ self.named_expressions.insert(expr, name);
+ }
+ let func_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{func_name}(")?;
+ for (index, argument) in arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_expr(module, *argument, func_ctx)?;
+ }
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ match func_ctx.info[result].ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+
+ // Validation ensures that `pointer` has a `Pointer` type.
+ let pointer_space = func_ctx
+ .resolve_type(pointer, &module.types)
+ .pointer_space()
+ .unwrap();
+
+ let fun_str = fun.to_hlsl_suffix();
+ write!(self.out, " {res_name}; ")?;
+ match pointer_space {
+ crate::AddressSpace::WorkGroup => {
+ write!(self.out, "Interlocked{fun_str}(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ }
+ crate::AddressSpace::Storage { .. } => {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ // The call to `self.write_storage_address` wants
+ // mutable access to all of `self`, so temporarily take
+ // ownership of our reusable access chain buffer.
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{var_name}.Interlocked{fun_str}(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ self.temp_access_chain = chain;
+ }
+ ref other => {
+ return Err(Error::Custom(format!(
+ "invalid address space {other:?} for atomic statement"
+ )))
+ }
+ }
+ write!(self.out, ", ")?;
+ // handle the special cases
+ match *fun {
+ crate::AtomicFunction::Subtract => {
+ // we just wrote `InterlockedAdd`, so negate the argument
+ write!(self.out, "-")?;
+ }
+ crate::AtomicFunction::Exchange { compare: Some(_) } => {
+ return Err(Error::Unimplemented("atomic CompareExchange".to_string()));
+ }
+ _ => {}
+ }
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ", {res_name});")?;
+ self.named_expressions.insert(result, res_name);
+ }
+ Statement::WorkGroupUniformLoad { pointer, result } => {
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ write!(self.out, "{level}")?;
+ let name = format!("_expr{}", result.index());
+ self.write_named_expr(module, pointer, name, result, func_ctx)?;
+
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ }
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{level}")?;
+ write!(self.out, "switch(")?;
+ self.write_expr(module, selector, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ // Write all cases
+ let indent_level_1 = level.next();
+ let indent_level_2 = indent_level_1.next();
+
+ for (i, case) in cases.iter().enumerate() {
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ write!(self.out, "{indent_level_1}case {value}:")?
+ }
+ crate::SwitchValue::U32(value) => {
+ write!(self.out, "{indent_level_1}case {value}u:")?
+ }
+ crate::SwitchValue::Default => {
+ write!(self.out, "{indent_level_1}default:")?
+ }
+ }
+
+ // The new block is not only stylistic, it plays a role here:
+ // We might end up having to write the same case body
+ // multiple times due to FXC not supporting fallthrough.
+ // Therefore, some `Expression`s written by `Statement::Emit`
+ // will end up having the same name (`_expr<handle_index>`).
+ // So we need to put each case in its own scope.
+ let write_block_braces = !(case.fall_through && case.body.is_empty());
+ if write_block_braces {
+ writeln!(self.out, " {{")?;
+ } else {
+ writeln!(self.out)?;
+ }
+
+ // Although FXC does support a series of case clauses before
+ // a block[^yes], it does not support fallthrough from a
+ // non-empty case block to the next[^no]. If this case has a
+ // non-empty body with a fallthrough, emulate that by
+ // duplicating the bodies of all the cases it would fall
+ // into as extensions of this case's own body. This makes
+ // the HLSL output potentially quadratic in the size of the
+ // Naga IR.
+ //
+ // [^yes]: ```hlsl
+ // case 1:
+ // case 2: do_stuff()
+ // ```
+ // [^no]: ```hlsl
+ // case 1: do_this();
+ // case 2: do_that();
+ // ```
+ if case.fall_through && !case.body.is_empty() {
+ let curr_len = i + 1;
+ let end_case_idx = curr_len
+ + cases
+ .iter()
+ .skip(curr_len)
+ .position(|case| !case.fall_through)
+ .unwrap();
+ let indent_level_3 = indent_level_2.next();
+ for case in &cases[i..=end_case_idx] {
+ writeln!(self.out, "{indent_level_2}{{")?;
+ let prev_len = self.named_expressions.len();
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, indent_level_3)?;
+ }
+ // Clear all named expressions that were previously inserted by the statements in the block
+ self.named_expressions.truncate(prev_len);
+ writeln!(self.out, "{indent_level_2}}}")?;
+ }
+
+ let last_case = &cases[end_case_idx];
+ if last_case.body.last().map_or(true, |s| !s.is_terminator()) {
+ writeln!(self.out, "{indent_level_2}break;")?;
+ }
+ } else {
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, indent_level_2)?;
+ }
+ if !case.fall_through
+ && case.body.last().map_or(true, |s| !s.is_terminator())
+ {
+ writeln!(self.out, "{indent_level_2}break;")?;
+ }
+ }
+
+ if write_block_braces {
+ writeln!(self.out, "{indent_level_1}}}")?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::RayQuery { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ fn write_const_expression(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ ) -> BackendResult {
+ self.write_possibly_const_expression(
+ module,
+ expr,
+ &module.const_expressions,
+ |writer, expr| writer.write_const_expression(module, expr),
+ )
+ }
+
+ fn write_possibly_const_expression<E>(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ expressions: &crate::Arena<crate::Expression>,
+ write_expression: E,
+ ) -> BackendResult
+ where
+ E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
+ {
+ use crate::Expression;
+
+ match expressions[expr] {
+ Expression::Literal(literal) => match literal {
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even it's zero
+ crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
+ crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
+ crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
+ crate::Literal::I32(value) => write!(self.out, "{}", value)?,
+ crate::Literal::I64(value) => write!(self.out, "{}L", value)?,
+ crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
+ crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
+ return Err(Error::Custom(
+ "Abstract types should not appear in IR presented to backends".into(),
+ ));
+ }
+ },
+ Expression::Constant(handle) => {
+ let constant = &module.constants[handle];
+ if constant.name.is_some() {
+ write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.write_const_expression(module, constant.init)?;
+ }
+ }
+ Expression::ZeroValue(ty) => self.write_default_init(module, ty)?,
+ Expression::Compose { ty, ref components } => {
+ match module.types[ty].inner {
+ TypeInner::Struct { .. } | TypeInner::Array { .. } => {
+ self.write_wrapped_constructor_function_name(
+ module,
+ WrappedConstructor { ty },
+ )?;
+ }
+ _ => {
+ self.write_type(module, ty)?;
+ }
+ };
+ write!(self.out, "(")?;
+ for (index, component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ write_expression(self, *component)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Expression::Splat { size, value } => {
+ // hlsl is not supported one value constructor
+ // if we write, for example, int4(0), dxc returns error:
+ // error: too few elements in vector initialization (expected 4 elements, have 1)
+ let number_of_components = match size {
+ crate::VectorSize::Bi => "xx",
+ crate::VectorSize::Tri => "xxx",
+ crate::VectorSize::Quad => "xxxx",
+ };
+ write!(self.out, "(")?;
+ write_expression(self, value)?;
+ write!(self.out, ").{number_of_components}")?
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write expressions
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ pub(super) fn write_expr(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ // Handle the special semantics of vertex_index/instance_index
+ let ff_input = if self.options.special_constants_binding.is_some() {
+ func_ctx.is_fixed_function_input(expr, module)
+ } else {
+ None
+ };
+ let closing_bracket = match ff_input {
+ Some(crate::BuiltIn::VertexIndex) => {
+ write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
+ ")"
+ }
+ Some(crate::BuiltIn::InstanceIndex) => {
+ write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
+ ")"
+ }
+ Some(crate::BuiltIn::NumWorkGroups) => {
+ // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
+ // in compute shaders the special constants contain the number
+ // of workgroups, which we are using here.
+ write!(
+ self.out,
+ "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
+ )?;
+ return Ok(());
+ }
+ _ => "",
+ };
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{name}{closing_bracket}")?;
+ return Ok(());
+ }
+
+ let expression = &func_ctx.expressions[expr];
+
+ match *expression {
+ Expression::Literal(_)
+ | Expression::Constant(_)
+ | Expression::ZeroValue(_)
+ | Expression::Compose { .. }
+ | Expression::Splat { .. } => {
+ self.write_possibly_const_expression(
+ module,
+ expr,
+ func_ctx.expressions,
+ |writer, expr| writer.write_expr(module, expr, func_ctx),
+ )?;
+ }
+ // All of the multiplication can be expressed as `mul`,
+ // except vector * vector, which needs to use the "*" operator.
+ Expression::Binary {
+ op: crate::BinaryOperator::Multiply,
+ left,
+ right,
+ } if func_ctx.resolve_type(left, &module.types).is_matrix()
+ || func_ctx.resolve_type(right, &module.types).is_matrix() =>
+ {
+ // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
+ write!(self.out, "mul(")?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) != sign(right) return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+
+ // While HLSL supports float operands with the % operator it is only
+ // defined in cases where both sides are either positive or negative.
+ Expression::Binary {
+ op: crate::BinaryOperator::Modulo,
+ left,
+ right,
+ } if func_ctx.resolve_type(left, &module.types).scalar_kind()
+ == Some(crate::ScalarKind::Float) =>
+ {
+ write!(self.out, "fmod(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Binary { op, left, right } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, " {} ", crate::back::binary_operation_str(op))?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Access { base, index } => {
+ if let Some(crate::AddressSpace::Storage { .. }) =
+ func_ctx.resolve_type(expr, &module.types).pointer_space()
+ {
+ // do nothing, the chain is written on `Load`/`Store`
+ } else {
+ // We use the function __get_col_of_matCx2 here in cases
+ // where `base`s type resolves to a matCx2 and is part of a
+ // struct member with type of (possibly nested) array of matCx2's.
+ //
+ // Note that this only works for `Load`s and we handle
+ // `Store`s differently in `Statement::Store`.
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
+ {
+ write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, ")")?;
+ return Ok(());
+ }
+
+ let resolved = func_ctx.resolve_type(base, &module.types);
+
+ let non_uniform_qualifier = match *resolved {
+ TypeInner::BindingArray { .. } => {
+ let uniformity = &func_ctx.info[index].uniformity;
+
+ uniformity.non_uniform_result.is_some()
+ }
+ _ => false,
+ };
+
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, "[")?;
+ if non_uniform_qualifier {
+ write!(self.out, "NonUniformResourceIndex(")?;
+ }
+ self.write_expr(module, index, func_ctx)?;
+ if non_uniform_qualifier {
+ write!(self.out, ")")?;
+ }
+ write!(self.out, "]")?;
+ }
+ }
+ Expression::AccessIndex { base, index } => {
+ if let Some(crate::AddressSpace::Storage { .. }) =
+ func_ctx.resolve_type(expr, &module.types).pointer_space()
+ {
+ // do nothing, the chain is written on `Load`/`Store`
+ } else {
+ fn write_access<W: fmt::Write>(
+ writer: &mut super::Writer<'_, W>,
+ resolved: &TypeInner,
+ base_ty_handle: Option<Handle<crate::Type>>,
+ index: u32,
+ ) -> BackendResult {
+ match *resolved {
+ // We specifically lift the ValuePointer to this case. While `[0]` is valid
+ // HLSL for any vector behind a value pointer, FXC completely miscompiles
+ // it and generates completely nonsensical DXBC.
+ //
+ // See https://github.com/gfx-rs/naga/issues/2095 for more details.
+ TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
+ // Write vector access as a swizzle
+ write!(writer.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::BindingArray { .. } => write!(writer.out, "[{index}]")?,
+ TypeInner::Struct { .. } => {
+ // This will never panic in case the type is a `Struct`, this is not true
+ // for other types so we can only check while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ writer.out,
+ ".{}",
+ &writer.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => {
+ return Err(Error::Custom(format!("Cannot index {other:?}")))
+ }
+ }
+ Ok(())
+ }
+
+ // We write the matrix column access in a special way since
+ // the type of `base` is our special __matCx2 struct.
+ if let Some(MatrixType {
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ ..
+ }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
+ {
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, "._{index}")?;
+ return Ok(());
+ }
+
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, .. } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ //
+ // We handle matrix reconstruction here for Loads.
+ // Stores are handled directly by `Statement::Store`.
+ if let TypeInner::Struct { ref members, .. } = *resolved {
+ let member = &members[index as usize];
+
+ match module.types[member.ty].inner {
+ TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ let ty = base_ty_handle.unwrap();
+ self.write_wrapped_struct_matrix_get_function_name(
+ WrappedStructMatrixAccess { ty, index },
+ )?;
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ")")?;
+ return Ok(());
+ }
+ _ => {}
+ }
+ }
+
+ self.write_expr(module, base, func_ctx)?;
+ write_access(self, resolved, base_ty_handle, index)?;
+ }
+ }
+ Expression::FunctionArgument(pos) => {
+ let key = func_ctx.argument_key(pos);
+ let name = &self.names[&key];
+ write!(self.out, "{name}")?;
+ }
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ use crate::SampleLevel as Sl;
+ const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
+
+ let (base_str, component_str) = match gather {
+ Some(component) => ("Gather", COMPONENTS[component as usize]),
+ None => ("Sample", ""),
+ };
+ let cmp_str = match depth_ref {
+ Some(_) => "Cmp",
+ None => "",
+ };
+ let level_str = match level {
+ Sl::Zero if gather.is_none() => "LevelZero",
+ Sl::Auto | Sl::Zero => "",
+ Sl::Exact(_) => "Level",
+ Sl::Bias(_) => "Bias",
+ Sl::Gradient { .. } => "Grad",
+ };
+
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_texture_coordinates(
+ "float",
+ coordinate,
+ array_index,
+ None,
+ module,
+ func_ctx,
+ )?;
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ match level {
+ Sl::Auto | Sl::Zero => {}
+ Sl::Exact(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Bias(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Gradient { x, y } => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, x, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, y, func_ctx)?;
+ }
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
+ self.write_const_expression(module, offset)?;
+ write!(self.out, ")")?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::ImageQuery { image, query } => {
+ // use wrapped image query function
+ if let TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } = *func_ctx.resolve_type(image, &module.types)
+ {
+ let wrapped_image_query = WrappedImageQuery {
+ dim,
+ arrayed,
+ class,
+ query: query.into(),
+ };
+
+ self.write_wrapped_image_query_function_name(wrapped_image_query)?;
+ write!(self.out, "(")?;
+ // Image always first param
+ self.write_expr(module, image, func_ctx)?;
+ if let crate::ImageQuery::Size { level: Some(level) } = query {
+ write!(self.out, ", ")?;
+ self.write_expr(module, level, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ".Load(")?;
+
+ self.write_texture_coordinates(
+ "int",
+ coordinate,
+ array_index,
+ level,
+ module,
+ func_ctx,
+ )?;
+
+ if let Some(sample) = sample {
+ write!(self.out, ", ")?;
+ self.write_expr(module, sample, func_ctx)?;
+ }
+
+ // close bracket for Load function
+ write!(self.out, ")")?;
+
+ // return x component if return type is scalar
+ if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
+ write!(self.out, ".x")?;
+ }
+ }
+ Expression::GlobalVariable(handle) => match module.global_variables[handle].space {
+ crate::AddressSpace::Storage { .. } => {}
+ _ => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{name}")?;
+ }
+ },
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
+ }
+ Expression::Load { pointer } => {
+ match func_ctx
+ .resolve_type(pointer, &module.types)
+ .pointer_space()
+ {
+ Some(crate::AddressSpace::Storage { .. }) => {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ let result_ty = func_ctx.info[expr].ty.clone();
+ self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
+ }
+ _ => {
+ let mut close_paren = false;
+
+ // We cast the value loaded to a native HLSL floatCx2
+ // in cases where it is of type:
+ // - __matCx2 or
+ // - a (possibly nested) array of __matCx2's
+ if let Some(MatrixType {
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ ..
+ }) = get_inner_matrix_of_struct_array_member(
+ module, pointer, func_ctx, false,
+ )
+ .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
+ {
+ let mut resolved = func_ctx.resolve_type(pointer, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ write!(self.out, "((")?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_type(module, base)?;
+ self.write_array_size(module, base, size)?;
+ } else {
+ self.write_value_type(module, resolved)?;
+ }
+ write!(self.out, ")")?;
+ close_paren = true;
+ }
+
+ self.write_expr(module, pointer, func_ctx)?;
+
+ if close_paren {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ }
+ Expression::Unary { op, expr } => {
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
+ let op_str = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::LogicalNot => "!",
+ crate::UnaryOperator::BitwiseNot => "~",
+ };
+ write!(self.out, "{op_str}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let inner = func_ctx.resolve_type(expr, &module.types);
+ match convert {
+ Some(dst_width) => {
+ let scalar = crate::Scalar {
+ kind,
+ width: dst_width,
+ };
+ match *inner {
+ TypeInner::Vector { size, .. } => {
+ write!(
+ self.out,
+ "{}{}(",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(size)
+ )?;
+ }
+ TypeInner::Scalar(_) => {
+ write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
+ }
+ TypeInner::Matrix { columns, rows, .. } => {
+ write!(
+ self.out,
+ "{}{}x{}(",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(columns),
+ back::vector_size_str(rows)
+ )?;
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_expr expression::as {inner:?}"
+ )));
+ }
+ };
+ }
+ None => {
+ write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
+ }
+ }
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ enum Function {
+ Asincosh { is_sin: bool },
+ Atanh,
+ ExtractBits,
+ InsertBits,
+ Pack2x16float,
+ Pack2x16snorm,
+ Pack2x16unorm,
+ Pack4x8snorm,
+ Pack4x8unorm,
+ Unpack2x16float,
+ Unpack2x16snorm,
+ Unpack2x16unorm,
+ Unpack4x8snorm,
+ Unpack4x8unorm,
+ Regular(&'static str),
+ MissingIntOverload(&'static str),
+ MissingIntReturnType(&'static str),
+ CountTrailingZeros,
+ CountLeadingZeros,
+ }
+
+ let fun = match fun {
+ // comparison
+ Mf::Abs => Function::Regular("abs"),
+ Mf::Min => Function::Regular("min"),
+ Mf::Max => Function::Regular("max"),
+ Mf::Clamp => Function::Regular("clamp"),
+ Mf::Saturate => Function::Regular("saturate"),
+ // trigonometry
+ Mf::Cos => Function::Regular("cos"),
+ Mf::Cosh => Function::Regular("cosh"),
+ Mf::Sin => Function::Regular("sin"),
+ Mf::Sinh => Function::Regular("sinh"),
+ Mf::Tan => Function::Regular("tan"),
+ Mf::Tanh => Function::Regular("tanh"),
+ Mf::Acos => Function::Regular("acos"),
+ Mf::Asin => Function::Regular("asin"),
+ Mf::Atan => Function::Regular("atan"),
+ Mf::Atan2 => Function::Regular("atan2"),
+ Mf::Asinh => Function::Asincosh { is_sin: true },
+ Mf::Acosh => Function::Asincosh { is_sin: false },
+ Mf::Atanh => Function::Atanh,
+ Mf::Radians => Function::Regular("radians"),
+ Mf::Degrees => Function::Regular("degrees"),
+ // decomposition
+ Mf::Ceil => Function::Regular("ceil"),
+ Mf::Floor => Function::Regular("floor"),
+ Mf::Round => Function::Regular("round"),
+ Mf::Fract => Function::Regular("frac"),
+ Mf::Trunc => Function::Regular("trunc"),
+ Mf::Modf => Function::Regular(MODF_FUNCTION),
+ Mf::Frexp => Function::Regular(FREXP_FUNCTION),
+ Mf::Ldexp => Function::Regular("ldexp"),
+ // exponent
+ Mf::Exp => Function::Regular("exp"),
+ Mf::Exp2 => Function::Regular("exp2"),
+ Mf::Log => Function::Regular("log"),
+ Mf::Log2 => Function::Regular("log2"),
+ Mf::Pow => Function::Regular("pow"),
+ // geometry
+ Mf::Dot => Function::Regular("dot"),
+ //Mf::Outer => ,
+ Mf::Cross => Function::Regular("cross"),
+ Mf::Distance => Function::Regular("distance"),
+ Mf::Length => Function::Regular("length"),
+ Mf::Normalize => Function::Regular("normalize"),
+ Mf::FaceForward => Function::Regular("faceforward"),
+ Mf::Reflect => Function::Regular("reflect"),
+ Mf::Refract => Function::Regular("refract"),
+ // computational
+ Mf::Sign => Function::Regular("sign"),
+ Mf::Fma => Function::Regular("mad"),
+ Mf::Mix => Function::Regular("lerp"),
+ Mf::Step => Function::Regular("step"),
+ Mf::SmoothStep => Function::Regular("smoothstep"),
+ Mf::Sqrt => Function::Regular("sqrt"),
+ Mf::InverseSqrt => Function::Regular("rsqrt"),
+ //Mf::Inverse =>,
+ Mf::Transpose => Function::Regular("transpose"),
+ Mf::Determinant => Function::Regular("determinant"),
+ // bits
+ Mf::CountTrailingZeros => Function::CountTrailingZeros,
+ Mf::CountLeadingZeros => Function::CountLeadingZeros,
+ Mf::CountOneBits => Function::MissingIntOverload("countbits"),
+ Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
+ Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"),
+ Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"),
+ Mf::ExtractBits => Function::ExtractBits,
+ Mf::InsertBits => Function::InsertBits,
+ // Data Packing
+ Mf::Pack2x16float => Function::Pack2x16float,
+ Mf::Pack2x16snorm => Function::Pack2x16snorm,
+ Mf::Pack2x16unorm => Function::Pack2x16unorm,
+ Mf::Pack4x8snorm => Function::Pack4x8snorm,
+ Mf::Pack4x8unorm => Function::Pack4x8unorm,
+ // Data Unpacking
+ Mf::Unpack2x16float => Function::Unpack2x16float,
+ Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
+ Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
+ Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
+ Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
+ _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
+ };
+
+ match fun {
+ Function::Asincosh { is_sin } => {
+ write!(self.out, "log(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " + sqrt(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " * ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ match is_sin {
+ true => write!(self.out, " + 1.0))")?,
+ false => write!(self.out, " - 1.0))")?,
+ }
+ }
+ Function::Atanh => {
+ write!(self.out, "0.5 * log((1.0 + ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") / (1.0 - ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ }
+ Function::ExtractBits => {
+ // e: T,
+ // offset: u32,
+ // count: u32
+ // T is u32 or i32 or vecN<u32> or vecN<i32>
+ if let (Some(offset), Some(count)) = (arg1, arg2) {
+ let scalar_width: u8 = 32;
+ // Works for signed and unsigned
+ // (count == 0 ? 0 : (e << (32 - count - offset)) >> (32 - count))
+ write!(self.out, "(")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, " == 0 ? 0 : (")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << ({scalar_width} - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, " - ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ")) >> ({scalar_width} - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, "))")?;
+ }
+ }
+ Function::InsertBits => {
+ // e: T,
+ // newbits: T,
+ // offset: u32,
+ // count: u32
+ // returns T
+ // T is i32, u32, vecN<i32>, or vecN<u32>
+ if let (Some(newbits), Some(offset), Some(count)) = (arg1, arg2, arg3) {
+ let scalar_width: u8 = 32;
+ let scalar_max: u32 = 0xFFFFFFFF;
+ // mask = ((0xFFFFFFFFu >> (32 - count)) << offset)
+ // (count == 0 ? e : ((e & ~mask) | ((newbits << offset) & mask)))
+ write!(self.out, "(")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, " == 0 ? ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " : ")?;
+ write!(self.out, "(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " & ~")?;
+ // mask
+ write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, ")) << ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ")")?;
+ // end mask
+ write!(self.out, ") | ((")?;
+ self.write_expr(module, newbits, func_ctx)?;
+ write!(self.out, " << ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ") & ")?;
+ // // mask
+ write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, ")) << ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ")")?;
+ // // end mask
+ write!(self.out, "))")?;
+ }
+ }
+ Function::Pack2x16float => {
+ write!(self.out, "(f32tof16(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[0]) | f32tof16(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[1]) << 16)")?;
+ }
+ Function::Pack2x16snorm => {
+ let scale = 32767;
+
+ write!(self.out, "uint((int(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
+ }
+ Function::Pack2x16unorm => {
+ let scale = 65535;
+
+ write!(self.out, "(uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
+ }
+ Function::Pack4x8snorm => {
+ let scale = 127;
+
+ write!(self.out, "uint((int(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
+ }
+ Function::Pack4x8unorm => {
+ let scale = 255;
+
+ write!(self.out, "(uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
+ }
+
+ Function::Unpack2x16float => {
+ write!(self.out, "float2(f16tof32(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "), f16tof32((")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") >> 16))")?;
+ }
+ Function::Unpack2x16snorm => {
+ let scale = 32767;
+
+ write!(self.out, "(float2(int2(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 16, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") >> 16) / {scale}.0)")?;
+ }
+ Function::Unpack2x16unorm => {
+ let scale = 65535;
+
+ write!(self.out, "(float2(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " & 0xFFFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 16) / {scale}.0)")?;
+ }
+ Function::Unpack4x8snorm => {
+ let scale = 127;
+
+ write!(self.out, "(float4(int4(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 24, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 16, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 8, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") >> 24) / {scale}.0)")?;
+ }
+ Function::Unpack4x8unorm => {
+ let scale = 255;
+
+ write!(self.out, "(float4(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " & 0xFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 8 & 0xFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 16 & 0xFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 24) / {scale}.0)")?;
+ }
+ Function::Regular(fun_name) => {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ if let Some(arg) = arg1 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ if let Some(arg) = arg2 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ if let Some(arg) = arg3 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ write!(self.out, ")")?
+ }
+ Function::MissingIntOverload(fun_name) => {
+ let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind();
+ if let Some(ScalarKind::Sint) = scalar_kind {
+ write!(self.out, "asint({fun_name}(asuint(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ } else {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Function::MissingIntReturnType(fun_name) => {
+ let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind();
+ if let Some(ScalarKind::Sint) = scalar_kind {
+ write!(self.out, "asint({fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Function::CountTrailingZeros => {
+ match *func_ctx.resolve_type(arg, &module.types) {
+ TypeInner::Vector { size, scalar } => {
+ let s = match size {
+ crate::VectorSize::Bi => ".xx",
+ crate::VectorSize::Tri => ".xxx",
+ crate::VectorSize::Quad => ".xxxx",
+ };
+
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "min((32u){s}, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "asint(min((32u){s}, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ TypeInner::Scalar(scalar) => {
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "min(32u, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "asint(min(32u, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ return Ok(());
+ }
+ Function::CountLeadingZeros => {
+ match *func_ctx.resolve_type(arg, &module.types) {
+ TypeInner::Vector { size, scalar } => {
+ let s = match size {
+ crate::VectorSize::Bi => ".xx",
+ crate::VectorSize::Tri => ".xxx",
+ crate::VectorSize::Quad => ".xxxx",
+ };
+
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "((31u){s} - firstbithigh(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ " < (0){s} ? (0){s} : (31){s} - asint(firstbithigh("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ TypeInner::Scalar(scalar) => {
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "(31u - firstbithigh(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " < 0 ? 0 : 31 - asint(firstbithigh(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ return Ok(());
+ }
+ }
+ }
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(module, vector, func_ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ Expression::ArrayLength(expr) => {
+ let var_handle = match func_ctx.expressions[expr] {
+ Expression::AccessIndex { base, index: _ } => {
+ match func_ctx.expressions[base] {
+ Expression::GlobalVariable(handle) => handle,
+ _ => unreachable!(),
+ }
+ }
+ Expression::GlobalVariable(handle) => handle,
+ _ => unreachable!(),
+ };
+
+ let var = &module.global_variables[var_handle];
+ let (offset, stride) = match module.types[var.ty].inner {
+ TypeInner::Array { stride, .. } => (0, stride),
+ TypeInner::Struct { ref members, .. } => {
+ let last = members.last().unwrap();
+ let stride = match module.types[last.ty].inner {
+ TypeInner::Array { stride, .. } => stride,
+ _ => unreachable!(),
+ };
+ (last.offset, stride)
+ }
+ _ => unreachable!(),
+ };
+
+ let storage_access = match var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => crate::StorageAccess::default(),
+ };
+ let wrapped_array_length = WrappedArrayLength {
+ writable: storage_access.contains(crate::StorageAccess::STORE),
+ };
+
+ write!(self.out, "((")?;
+ self.write_wrapped_array_length_function_name(wrapped_array_length)?;
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "({var_name}) - {offset}) / {stride})")?
+ }
+ Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
+ let tail = match ctrl {
+ Ctrl::Coarse => "coarse",
+ Ctrl::Fine => "fine",
+ Ctrl::None => unreachable!(),
+ };
+ write!(self.out, "abs(ddx_{tail}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")) + abs(ddy_{tail}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, "))")?
+ } else {
+ let fun_str = match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => "ddx_coarse",
+ (Axis::X, Ctrl::Fine) => "ddx_fine",
+ (Axis::X, Ctrl::None) => "ddx",
+ (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
+ (Axis::Y, Ctrl::Fine) => "ddy_fine",
+ (Axis::Y, Ctrl::None) => "ddy",
+ (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
+ (Axis::Width, Ctrl::None) => "fwidth",
+ };
+ write!(self.out, "{fun_str}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ }
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_str = match fun {
+ Rf::All => "all",
+ Rf::Any => "any",
+ Rf::IsNan => "isnan",
+ Rf::IsInf => "isinf",
+ };
+ write!(self.out, "{fun_str}(")?;
+ self.write_expr(module, argument, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, condition, func_ctx)?;
+ write!(self.out, " ? ")?;
+ self.write_expr(module, accept, func_ctx)?;
+ write!(self.out, " : ")?;
+ self.write_expr(module, reject, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ // Not supported yet
+ Expression::RayQueryGetIntersection { .. } => unreachable!(),
+ // Nothing to do here, since call expression already cached
+ Expression::CallResult(_)
+ | Expression::AtomicResult { .. }
+ | Expression::WorkGroupUniformLoadResult { .. }
+ | Expression::RayQueryProceedResult => {}
+ }
+
+ if !closing_bracket.is_empty() {
+ write!(self.out, "{closing_bracket}")?;
+ }
+ Ok(())
+ }
+
+ fn write_named_expr(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Expression>,
+ name: String,
+ // The expression which is being named.
+ // Generally, this is the same as handle, except in WorkGroupUniformLoad
+ named: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ match ctx.info[named].ty {
+ proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
+ TypeInner::Struct { .. } => {
+ let ty_name = &self.names[&NameKey::Type(ty_handle)];
+ write!(self.out, "{ty_name}")?;
+ }
+ _ => {
+ self.write_type(module, ty_handle)?;
+ }
+ },
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(module, inner)?;
+ }
+ }
+
+ let resolved = ctx.resolve_type(named, &module.types);
+
+ write!(self.out, " {name}")?;
+ // If rhs is a array type, we should write array size
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_expr(module, handle, ctx)?;
+ writeln!(self.out, ";")?;
+ self.named_expressions.insert(named, name);
+
+ Ok(())
+ }
+
+ /// Helper function that write default zero initialization
+ fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ write!(self.out, "(")?;
+ self.write_type(module, ty)?;
+ if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, ")0")?;
+ Ok(())
+ }
+
+ fn write_barrier(&mut self, barrier: crate::Barrier, level: back::Level) -> BackendResult {
+ if barrier.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
+ }
+ if barrier.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
+ }
+ Ok(())
+ }
+}
+
+pub(super) struct MatrixType {
+ pub(super) columns: crate::VectorSize,
+ pub(super) rows: crate::VectorSize,
+ pub(super) width: crate::Bytes,
+}
+
+pub(super) fn get_inner_matrix_data(
+ module: &Module,
+ handle: Handle<crate::Type>,
+) -> Option<MatrixType> {
+ match module.types[handle].inner {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => Some(MatrixType {
+ columns,
+ rows,
+ width: scalar.width,
+ }),
+ TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
+ _ => None,
+ }
+}
+
+/// Returns the matrix data if the access chain starting at `base`:
+/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
+/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
+/// - ends at an expression with resolved type of [`TypeInner::Struct`]
+pub(super) fn get_inner_matrix_of_struct_array_member(
+ module: &Module,
+ base: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ direct: bool,
+) -> Option<MatrixType> {
+ let mut mat_data = None;
+ let mut array_base = None;
+
+ let mut current_base = base;
+ loop {
+ let mut resolved = func_ctx.resolve_type(current_base, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ };
+
+ match *resolved {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ mat_data = Some(MatrixType {
+ columns,
+ rows,
+ width: scalar.width,
+ })
+ }
+ TypeInner::Array { base, .. } => {
+ array_base = Some(base);
+ }
+ TypeInner::Struct { .. } => {
+ if let Some(array_base) = array_base {
+ if direct {
+ return mat_data;
+ } else {
+ return get_inner_matrix_data(module, array_base);
+ }
+ }
+
+ break;
+ }
+ _ => break,
+ }
+
+ current_base = match func_ctx.expressions[current_base] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ _ => break,
+ };
+ }
+ None
+}
+
+/// Returns the matrix data if the access chain starting at `base`:
+/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
+/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
+/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
+fn get_inner_matrix_of_global_uniform(
+ module: &Module,
+ base: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+) -> Option<MatrixType> {
+ let mut mat_data = None;
+ let mut array_base = None;
+
+ let mut current_base = base;
+ loop {
+ let mut resolved = func_ctx.resolve_type(current_base, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ };
+
+ match *resolved {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ mat_data = Some(MatrixType {
+ columns,
+ rows,
+ width: scalar.width,
+ })
+ }
+ TypeInner::Array { base, .. } => {
+ array_base = Some(base);
+ }
+ _ => break,
+ }
+
+ current_base = match func_ctx.expressions[current_base] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ crate::Expression::GlobalVariable(handle)
+ if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
+ {
+ return mat_data.or_else(|| {
+ array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
+ })
+ }
+ _ => break,
+ };
+ }
+ None
+}
diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs
new file mode 100644
index 0000000000..8100b930e9
--- /dev/null
+++ b/third_party/rust/naga/src/back/mod.rs
@@ -0,0 +1,273 @@
+/*!
+Backend functions that export shader [`Module`](super::Module)s into binary and text formats.
+*/
+#![allow(dead_code)] // can be dead if none of the enabled backends need it
+
+#[cfg(feature = "dot-out")]
+pub mod dot;
+#[cfg(feature = "glsl-out")]
+pub mod glsl;
+#[cfg(feature = "hlsl-out")]
+pub mod hlsl;
+#[cfg(feature = "msl-out")]
+pub mod msl;
+#[cfg(feature = "spv-out")]
+pub mod spv;
+#[cfg(feature = "wgsl-out")]
+pub mod wgsl;
+
+const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
+const INDENT: &str = " ";
+const BAKE_PREFIX: &str = "_e";
+
+type NeedBakeExpressions = crate::FastHashSet<crate::Handle<crate::Expression>>;
+
+#[derive(Clone, Copy)]
+struct Level(usize);
+
+impl Level {
+ const fn next(&self) -> Self {
+ Level(self.0 + 1)
+ }
+}
+
+impl std::fmt::Display for Level {
+ fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+ (0..self.0).try_for_each(|_| formatter.write_str(INDENT))
+ }
+}
+
+/// Whether we're generating an entry point or a regular function.
+///
+/// Backend languages often require different code for a [`Function`]
+/// depending on whether it represents an [`EntryPoint`] or not.
+/// Backends can pass common code one of these values to select the
+/// right behavior.
+///
+/// These values also carry enough information to find the `Function`
+/// in the [`Module`]: the `Handle` for a regular function, or the
+/// index into [`Module::entry_points`] for an entry point.
+///
+/// [`Function`]: crate::Function
+/// [`EntryPoint`]: crate::EntryPoint
+/// [`Module`]: crate::Module
+/// [`Module::entry_points`]: crate::Module::entry_points
+enum FunctionType {
+ /// A regular function.
+ Function(crate::Handle<crate::Function>),
+ /// An [`EntryPoint`], and its index in [`Module::entry_points`].
+ ///
+ /// [`EntryPoint`]: crate::EntryPoint
+ /// [`Module::entry_points`]: crate::Module::entry_points
+ EntryPoint(crate::proc::EntryPointIndex),
+}
+
+impl FunctionType {
+ fn is_compute_entry_point(&self, module: &crate::Module) -> bool {
+ match *self {
+ FunctionType::EntryPoint(index) => {
+ module.entry_points[index as usize].stage == crate::ShaderStage::Compute
+ }
+ FunctionType::Function(_) => false,
+ }
+ }
+}
+
+/// Helper structure that stores data needed when writing the function
+struct FunctionCtx<'a> {
+ /// The current function being written
+ ty: FunctionType,
+ /// Analysis about the function
+ info: &'a crate::valid::FunctionInfo,
+ /// The expression arena of the current function being written
+ expressions: &'a crate::Arena<crate::Expression>,
+ /// Map of expressions that have associated variable names
+ named_expressions: &'a crate::NamedExpressions,
+}
+
+impl FunctionCtx<'_> {
+ fn resolve_type<'a>(
+ &'a self,
+ handle: crate::Handle<crate::Expression>,
+ types: &'a crate::UniqueArena<crate::Type>,
+ ) -> &'a crate::TypeInner {
+ self.info[handle].ty.inner_with(types)
+ }
+
+ /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a local in the current function
+ const fn name_key(&self, local: crate::Handle<crate::LocalVariable>) -> crate::proc::NameKey {
+ match self.ty {
+ FunctionType::Function(handle) => crate::proc::NameKey::FunctionLocal(handle, local),
+ FunctionType::EntryPoint(idx) => crate::proc::NameKey::EntryPointLocal(idx, local),
+ }
+ }
+
+ /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a function argument.
+ ///
+ /// # Panics
+ /// - If the function arguments are less or equal to `arg`
+ const fn argument_key(&self, arg: u32) -> crate::proc::NameKey {
+ match self.ty {
+ FunctionType::Function(handle) => crate::proc::NameKey::FunctionArgument(handle, arg),
+ FunctionType::EntryPoint(ep_index) => {
+ crate::proc::NameKey::EntryPointArgument(ep_index, arg)
+ }
+ }
+ }
+
+ // Returns true if the given expression points to a fixed-function pipeline input.
+ fn is_fixed_function_input(
+ &self,
+ mut expression: crate::Handle<crate::Expression>,
+ module: &crate::Module,
+ ) -> Option<crate::BuiltIn> {
+ let ep_function = match self.ty {
+ FunctionType::Function(_) => return None,
+ FunctionType::EntryPoint(ep_index) => &module.entry_points[ep_index as usize].function,
+ };
+ let mut built_in = None;
+ loop {
+ match self.expressions[expression] {
+ crate::Expression::FunctionArgument(arg_index) => {
+ return match ep_function.arguments[arg_index as usize].binding {
+ Some(crate::Binding::BuiltIn(bi)) => Some(bi),
+ _ => built_in,
+ };
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ match *self.resolve_type(base, &module.types) {
+ crate::TypeInner::Struct { ref members, .. } => {
+ if let Some(crate::Binding::BuiltIn(bi)) =
+ members[index as usize].binding
+ {
+ built_in = Some(bi);
+ }
+ }
+ _ => return None,
+ }
+ expression = base;
+ }
+ _ => return None,
+ }
+ }
+ }
+}
+
+impl crate::Expression {
+ /// Returns the ref count, upon reaching which this expression
+ /// should be considered for baking.
+ ///
+ /// Note: we have to cache any expressions that depend on the control flow,
+ /// or otherwise they may be moved into a non-uniform control flow, accidentally.
+ /// See the [module-level documentation][emit] for details.
+ ///
+ /// [emit]: index.html#expression-evaluation-time
+ const fn bake_ref_count(&self) -> usize {
+ match *self {
+ // accesses are never cached, only loads are
+ crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => usize::MAX,
+ // sampling may use the control flow, and image ops look better by themselves
+ crate::Expression::ImageSample { .. } | crate::Expression::ImageLoad { .. } => 1,
+ // derivatives use the control flow
+ crate::Expression::Derivative { .. } => 1,
+ // TODO: We need a better fix for named `Load` expressions
+ // More info - https://github.com/gfx-rs/naga/pull/914
+ // And https://github.com/gfx-rs/naga/issues/910
+ crate::Expression::Load { .. } => 1,
+ // cache expressions that are referenced multiple times
+ _ => 2,
+ }
+ }
+}
+
+/// Helper function that returns the string corresponding to the [`BinaryOperator`](crate::BinaryOperator)
+/// # Notes
+/// Used by `glsl-out`, `msl-out`, `wgsl-out`, `hlsl-out`.
+const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
+ use crate::BinaryOperator as Bo;
+ match op {
+ Bo::Add => "+",
+ Bo::Subtract => "-",
+ Bo::Multiply => "*",
+ Bo::Divide => "/",
+ Bo::Modulo => "%",
+ Bo::Equal => "==",
+ Bo::NotEqual => "!=",
+ Bo::Less => "<",
+ Bo::LessEqual => "<=",
+ Bo::Greater => ">",
+ Bo::GreaterEqual => ">=",
+ Bo::And => "&",
+ Bo::ExclusiveOr => "^",
+ Bo::InclusiveOr => "|",
+ Bo::LogicalAnd => "&&",
+ Bo::LogicalOr => "||",
+ Bo::ShiftLeft => "<<",
+ Bo::ShiftRight => ">>",
+ }
+}
+
+/// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize)
+/// # Notes
+/// Used by `msl-out`, `wgsl-out`, `hlsl-out`.
+const fn vector_size_str(size: crate::VectorSize) -> &'static str {
+ match size {
+ crate::VectorSize::Bi => "2",
+ crate::VectorSize::Tri => "3",
+ crate::VectorSize::Quad => "4",
+ }
+}
+
+impl crate::TypeInner {
+ const fn is_handle(&self) -> bool {
+ match *self {
+ crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => true,
+ _ => false,
+ }
+ }
+}
+
+impl crate::Statement {
+ /// Returns true if the statement directly terminates the current block.
+ ///
+ /// Used to decide whether case blocks require a explicit `break`.
+ pub const fn is_terminator(&self) -> bool {
+ match *self {
+ crate::Statement::Break
+ | crate::Statement::Continue
+ | crate::Statement::Return { .. }
+ | crate::Statement::Kill => true,
+ _ => false,
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Ray flags, for a [`RayDesc`]'s `flags` field.
+ ///
+ /// Note that these exactly correspond to the SPIR-V "Ray Flags" mask, and
+ /// the SPIR-V backend passes them directly through to the
+ /// `OpRayQueryInitializeKHR` instruction. (We have to choose something, so
+ /// we might as well make one back end's life easier.)
+ ///
+ /// [`RayDesc`]: crate::Module::generate_ray_desc_type
+ #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
+ pub struct RayFlag: u32 {
+ const OPAQUE = 0x01;
+ const NO_OPAQUE = 0x02;
+ const TERMINATE_ON_FIRST_HIT = 0x04;
+ const SKIP_CLOSEST_HIT_SHADER = 0x08;
+ const CULL_BACK_FACING = 0x10;
+ const CULL_FRONT_FACING = 0x20;
+ const CULL_OPAQUE = 0x40;
+ const CULL_NO_OPAQUE = 0x80;
+ const SKIP_TRIANGLES = 0x100;
+ const SKIP_AABBS = 0x200;
+ }
+}
+
+#[repr(u32)]
+enum RayIntersectionType {
+ Triangle = 1,
+ BoundingBox = 4,
+}
diff --git a/third_party/rust/naga/src/back/msl/keywords.rs b/third_party/rust/naga/src/back/msl/keywords.rs
new file mode 100644
index 0000000000..f0025bf239
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/keywords.rs
@@ -0,0 +1,342 @@
+// MSLS - Metal Shading Language Specification:
+// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+//
+// C++ - Standard for Programming Language C++ (N4431)
+// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf
+pub const RESERVED: &[&str] = &[
+ // Standard for Programming Language C++ (N4431): 2.5 Alternative tokens
+ "and",
+ "bitor",
+ "or",
+ "xor",
+ "compl",
+ "bitand",
+ "and_eq",
+ "or_eq",
+ "xor_eq",
+ "not",
+ "not_eq",
+ // Standard for Programming Language C++ (N4431): 2.11 Keywords
+ "alignas",
+ "alignof",
+ "asm",
+ "auto",
+ "bool",
+ "break",
+ "case",
+ "catch",
+ "char",
+ "char16_t",
+ "char32_t",
+ "class",
+ "const",
+ "constexpr",
+ "const_cast",
+ "continue",
+ "decltype",
+ "default",
+ "delete",
+ "do",
+ "double",
+ "dynamic_cast",
+ "else",
+ "enum",
+ "explicit",
+ "export",
+ "extern",
+ "false",
+ "float",
+ "for",
+ "friend",
+ "goto",
+ "if",
+ "inline",
+ "int",
+ "long",
+ "mutable",
+ "namespace",
+ "new",
+ "noexcept",
+ "nullptr",
+ "operator",
+ "private",
+ "protected",
+ "public",
+ "register",
+ "reinterpret_cast",
+ "return",
+ "short",
+ "signed",
+ "sizeof",
+ "static",
+ "static_assert",
+ "static_cast",
+ "struct",
+ "switch",
+ "template",
+ "this",
+ "thread_local",
+ "throw",
+ "true",
+ "try",
+ "typedef",
+ "typeid",
+ "typename",
+ "union",
+ "unsigned",
+ "using",
+ "virtual",
+ "void",
+ "volatile",
+ "wchar_t",
+ "while",
+ // Metal Shading Language Specification: 1.4.4 Restrictions
+ "main",
+ // Metal Shading Language Specification: 2.1 Scalar Data Types
+ "int8_t",
+ "uchar",
+ "uint8_t",
+ "int16_t",
+ "ushort",
+ "uint16_t",
+ "int32_t",
+ "uint",
+ "uint32_t",
+ "int64_t",
+ "uint64_t",
+ "half",
+ "bfloat",
+ "size_t",
+ "ptrdiff_t",
+ // Metal Shading Language Specification: 2.2 Vector Data Types
+ "bool2",
+ "bool3",
+ "bool4",
+ "char2",
+ "char3",
+ "char4",
+ "short2",
+ "short3",
+ "short4",
+ "int2",
+ "int3",
+ "int4",
+ "long2",
+ "long3",
+ "long4",
+ "uchar2",
+ "uchar3",
+ "uchar4",
+ "ushort2",
+ "ushort3",
+ "ushort4",
+ "uint2",
+ "uint3",
+ "uint4",
+ "ulong2",
+ "ulong3",
+ "ulong4",
+ "half2",
+ "half3",
+ "half4",
+ "bfloat2",
+ "bfloat3",
+ "bfloat4",
+ "float2",
+ "float3",
+ "float4",
+ "vec",
+ // Metal Shading Language Specification: 2.2.3 Packed Vector Types
+ "packed_bool2",
+ "packed_bool3",
+ "packed_bool4",
+ "packed_char2",
+ "packed_char3",
+ "packed_char4",
+ "packed_short2",
+ "packed_short3",
+ "packed_short4",
+ "packed_int2",
+ "packed_int3",
+ "packed_int4",
+ "packed_uchar2",
+ "packed_uchar3",
+ "packed_uchar4",
+ "packed_ushort2",
+ "packed_ushort3",
+ "packed_ushort4",
+ "packed_uint2",
+ "packed_uint3",
+ "packed_uint4",
+ "packed_half2",
+ "packed_half3",
+ "packed_half4",
+ "packed_bfloat2",
+ "packed_bfloat3",
+ "packed_bfloat4",
+ "packed_float2",
+ "packed_float3",
+ "packed_float4",
+ "packed_long2",
+ "packed_long3",
+ "packed_long4",
+ "packed_vec",
+ // Metal Shading Language Specification: 2.3 Matrix Data Types
+ "half2x2",
+ "half2x3",
+ "half2x4",
+ "half3x2",
+ "half3x3",
+ "half3x4",
+ "half4x2",
+ "half4x3",
+ "half4x4",
+ "float2x2",
+ "float2x3",
+ "float2x4",
+ "float3x2",
+ "float3x3",
+ "float3x4",
+ "float4x2",
+ "float4x3",
+ "float4x4",
+ "matrix",
+ // Metal Shading Language Specification: 2.6 Atomic Data Types
+ "atomic",
+ "atomic_int",
+ "atomic_uint",
+ "atomic_bool",
+ "atomic_ulong",
+ "atomic_float",
+ // Metal Shading Language Specification: 2.20 Type Conversions and Re-interpreting Data
+ "as_type",
+ // Metal Shading Language Specification: 4 Address Spaces
+ "device",
+ "constant",
+ "thread",
+ "threadgroup",
+ "threadgroup_imageblock",
+ "ray_data",
+ "object_data",
+ // Metal Shading Language Specification: 5.1 Functions
+ "vertex",
+ "fragment",
+ "kernel",
+ // Metal Shading Language Specification: 6.1 Namespace and Header Files
+ "metal",
+ // C99 / C++ extension:
+ "restrict",
+ // Metal reserved types in <metal_types>:
+ "llong",
+ "ullong",
+ "quad",
+ "complex",
+ "imaginary",
+ // Constants in <metal_types>:
+ "CHAR_BIT",
+ "SCHAR_MAX",
+ "SCHAR_MIN",
+ "UCHAR_MAX",
+ "CHAR_MAX",
+ "CHAR_MIN",
+ "USHRT_MAX",
+ "SHRT_MAX",
+ "SHRT_MIN",
+ "UINT_MAX",
+ "INT_MAX",
+ "INT_MIN",
+ "ULONG_MAX",
+ "LONG_MAX",
+ "LONG_MIN",
+ "ULLONG_MAX",
+ "LLONG_MAX",
+ "LLONG_MIN",
+ "FLT_DIG",
+ "FLT_MANT_DIG",
+ "FLT_MAX_10_EXP",
+ "FLT_MAX_EXP",
+ "FLT_MIN_10_EXP",
+ "FLT_MIN_EXP",
+ "FLT_RADIX",
+ "FLT_MAX",
+ "FLT_MIN",
+ "FLT_EPSILON",
+ "FLT_DECIMAL_DIG",
+ "FP_ILOGB0",
+ "FP_ILOGB0",
+ "FP_ILOGBNAN",
+ "FP_ILOGBNAN",
+ "MAXFLOAT",
+ "HUGE_VALF",
+ "INFINITY",
+ "NAN",
+ "M_E_F",
+ "M_LOG2E_F",
+ "M_LOG10E_F",
+ "M_LN2_F",
+ "M_LN10_F",
+ "M_PI_F",
+ "M_PI_2_F",
+ "M_PI_4_F",
+ "M_1_PI_F",
+ "M_2_PI_F",
+ "M_2_SQRTPI_F",
+ "M_SQRT2_F",
+ "M_SQRT1_2_F",
+ "HALF_DIG",
+ "HALF_MANT_DIG",
+ "HALF_MAX_10_EXP",
+ "HALF_MAX_EXP",
+ "HALF_MIN_10_EXP",
+ "HALF_MIN_EXP",
+ "HALF_RADIX",
+ "HALF_MAX",
+ "HALF_MIN",
+ "HALF_EPSILON",
+ "HALF_DECIMAL_DIG",
+ "MAXHALF",
+ "HUGE_VALH",
+ "M_E_H",
+ "M_LOG2E_H",
+ "M_LOG10E_H",
+ "M_LN2_H",
+ "M_LN10_H",
+ "M_PI_H",
+ "M_PI_2_H",
+ "M_PI_4_H",
+ "M_1_PI_H",
+ "M_2_PI_H",
+ "M_2_SQRTPI_H",
+ "M_SQRT2_H",
+ "M_SQRT1_2_H",
+ "DBL_DIG",
+ "DBL_MANT_DIG",
+ "DBL_MAX_10_EXP",
+ "DBL_MAX_EXP",
+ "DBL_MIN_10_EXP",
+ "DBL_MIN_EXP",
+ "DBL_RADIX",
+ "DBL_MAX",
+ "DBL_MIN",
+ "DBL_EPSILON",
+ "DBL_DECIMAL_DIG",
+ "MAXDOUBLE",
+ "HUGE_VAL",
+ "M_E",
+ "M_LOG2E",
+ "M_LOG10E",
+ "M_LN2",
+ "M_LN10",
+ "M_PI",
+ "M_PI_2",
+ "M_PI_4",
+ "M_1_PI",
+ "M_2_PI",
+ "M_2_SQRTPI",
+ "M_SQRT2",
+ "M_SQRT1_2",
+ // Naga utilities
+ "DefaultConstructible",
+ super::writer::FREXP_FUNCTION,
+ super::writer::MODF_FUNCTION,
+];
diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs
new file mode 100644
index 0000000000..5ef18730c9
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/mod.rs
@@ -0,0 +1,541 @@
+/*!
+Backend for [MSL][msl] (Metal Shading Language).
+
+## Binding model
+
+Metal's bindings are flat per resource. Since there isn't an obvious mapping
+from SPIR-V's descriptor sets, we require a separate mapping provided in the options.
+This mapping may have one or more resource end points for each descriptor set + index
+pair.
+
+## Entry points
+
+Even though MSL and our IR appear to be similar in that the entry points in both can
+accept arguments and return values, the restrictions are different.
+MSL allows the varyings to be either in separate arguments, or inside a single
+`[[stage_in]]` struct. We gather input varyings and form this artificial structure.
+We also add all the (non-Private) globals into the arguments.
+
+At the beginning of the entry point, we assign the local constants and re-compose
+the arguments as they are declared on IR side, so that the rest of the logic can
+pretend that MSL doesn't have all the restrictions it has.
+
+For the result type, if it's a structure, we re-compose it with a temporary value
+holding the result.
+
+[msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+*/
+
+use crate::{arena::Handle, proc::index, valid::ModuleInfo};
+use std::fmt::{Error as FmtError, Write};
+
+mod keywords;
+pub mod sampler;
+mod writer;
+
+pub use writer::Writer;
+
+pub type Slot = u8;
+pub type InlineSamplerIndex = u8;
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum BindSamplerTarget {
+ Resource(Slot),
+ Inline(InlineSamplerIndex),
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
+pub struct BindTarget {
+ pub buffer: Option<Slot>,
+ pub texture: Option<Slot>,
+ pub sampler: Option<BindSamplerTarget>,
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+ pub mutable: bool,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
+
+#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
+pub struct EntryPointResources {
+ pub resources: BindingMap,
+
+ pub push_constant_buffer: Option<Slot>,
+
+ /// The slot of a buffer that contains an array of `u32`,
+ /// one for the size of each bound buffer that contains a runtime array,
+ /// in order of [`crate::GlobalVariable`] declarations.
+ pub sizes_buffer: Option<Slot>,
+}
+
+pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>;
+
+enum ResolvedBinding {
+ BuiltIn(crate::BuiltIn),
+ Attribute(u32),
+ Color {
+ location: u32,
+ second_blend_source: bool,
+ },
+ User {
+ prefix: &'static str,
+ index: u32,
+ interpolation: Option<ResolvedInterpolation>,
+ },
+ Resource(BindTarget),
+}
+
+#[derive(Copy, Clone)]
+enum ResolvedInterpolation {
+ CenterPerspective,
+ CenterNoPerspective,
+ CentroidPerspective,
+ CentroidNoPerspective,
+ SamplePerspective,
+ SampleNoPerspective,
+ Flat,
+}
+
+// Note: some of these should be removed in favor of proper IR validation.
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error(transparent)]
+ Format(#[from] FmtError),
+ #[error("bind target {0:?} is empty")]
+ UnimplementedBindTarget(BindTarget),
+ #[error("composing of {0:?} is not implemented yet")]
+ UnsupportedCompose(Handle<crate::Type>),
+ #[error("operation {0:?} is not implemented yet")]
+ UnsupportedBinaryOp(crate::BinaryOperator),
+ #[error("standard function '{0}' is not implemented yet")]
+ UnsupportedCall(String),
+ #[error("feature '{0}' is not implemented yet")]
+ FeatureNotImplemented(String),
+ #[error("module is not valid")]
+ Validation,
+ #[error("BuiltIn {0:?} is not supported")]
+ UnsupportedBuiltIn(crate::BuiltIn),
+ #[error("capability {0:?} is not supported")]
+ CapabilityNotSupported(crate::valid::Capabilities),
+ #[error("attribute '{0}' is not supported for target MSL version")]
+ UnsupportedAttribute(String),
+ #[error("function '{0}' is not supported for target MSL version")]
+ UnsupportedFunction(String),
+ #[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
+ UnsupportedWriteableStorageBuffer,
+ #[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
+ UnsupportedWriteableStorageTexture(crate::ShaderStage),
+ #[error("can not use read-write storage textures prior to MSL 1.2")]
+ UnsupportedRWStorageTexture,
+ #[error("array of '{0}' is not supported for target MSL version")]
+ UnsupportedArrayOf(String),
+ #[error("array of type '{0:?}' is not supported")]
+ UnsupportedArrayOfType(Handle<crate::Type>),
+ #[error("ray tracing is not supported prior to MSL 2.3")]
+ UnsupportedRayTracing,
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum EntryPointError {
+ #[error("global '{0}' doesn't have a binding")]
+ MissingBinding(String),
+ #[error("mapping of {0:?} is missing")]
+ MissingBindTarget(crate::ResourceBinding),
+ #[error("mapping for push constants is missing")]
+ MissingPushConstants,
+ #[error("mapping for sizes buffer is missing")]
+ MissingSizesBuffer,
+}
+
+/// Points in the MSL code where we might emit a pipeline input or output.
+///
+/// Note that, even though vertex shaders' outputs are always fragment
+/// shaders' inputs, we still need to distinguish `VertexOutput` and
+/// `FragmentInput`, since there are certain differences in the way
+/// [`ResolvedBinding`s] are represented on either side.
+///
+/// [`ResolvedBinding`s]: ResolvedBinding
+#[derive(Clone, Copy, Debug)]
+enum LocationMode {
+ /// Input to the vertex shader.
+ VertexInput,
+
+ /// Output from the vertex shader.
+ VertexOutput,
+
+ /// Input to the fragment shader.
+ FragmentInput,
+
+ /// Output from the fragment shader.
+ FragmentOutput,
+
+ /// Compute shader input or output.
+ Uniform,
+}
+
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// (Major, Minor) target version of the Metal Shading Language.
+ pub lang_version: (u8, u8),
+ /// Map of entry-point resources, indexed by entry point function name, to slots.
+ pub per_entry_point_map: EntryPointResourceMap,
+ /// Samplers to be inlined into the code.
+ pub inline_samplers: Vec<sampler::InlineSampler>,
+ /// Make it possible to link different stages via SPIRV-Cross.
+ pub spirv_cross_compatibility: bool,
+ /// Don't panic on missing bindings, instead generate invalid MSL.
+ pub fake_missing_bindings: bool,
+ /// Bounds checking policies.
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub bounds_check_policies: index::BoundsCheckPolicies,
+ /// Should workgroup variables be zero initialized (by polyfilling)?
+ pub zero_initialize_workgroup_memory: bool,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ lang_version: (1, 0),
+ per_entry_point_map: EntryPointResourceMap::default(),
+ inline_samplers: Vec::new(),
+ spirv_cross_compatibility: false,
+ fake_missing_bindings: true,
+ bounds_check_policies: index::BoundsCheckPolicies::default(),
+ zero_initialize_workgroup_memory: true,
+ }
+ }
+}
+
+/// A subset of options that are meant to be changed per pipeline.
+#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// Allow `BuiltIn::PointSize` and inject it if doesn't exist.
+ ///
+ /// Metal doesn't like this for non-point primitive topologies and requires it for
+ /// point primitive topologies.
+ ///
+ /// Enable this for vertex shaders with point primitive topologies.
+ pub allow_and_force_point_size: bool,
+}
+
+impl Options {
+ fn resolve_local_binding(
+ &self,
+ binding: &crate::Binding,
+ mode: LocationMode,
+ ) -> Result<ResolvedBinding, Error> {
+ match *binding {
+ crate::Binding::BuiltIn(mut built_in) => {
+ match built_in {
+ crate::BuiltIn::Position { ref mut invariant } => {
+ if *invariant && self.lang_version < (2, 1) {
+ return Err(Error::UnsupportedAttribute("invariant".to_string()));
+ }
+
+ // The 'invariant' attribute may only appear on vertex
+ // shader outputs, not fragment shader inputs.
+ if !matches!(mode, LocationMode::VertexOutput) {
+ *invariant = false;
+ }
+ }
+ crate::BuiltIn::BaseInstance if self.lang_version < (1, 2) => {
+ return Err(Error::UnsupportedAttribute("base_instance".to_string()));
+ }
+ crate::BuiltIn::InstanceIndex if self.lang_version < (1, 2) => {
+ return Err(Error::UnsupportedAttribute("instance_id".to_string()));
+ }
+ // macOS: Since Metal 2.2
+ // iOS: Since Metal 2.3 (check depends on https://github.com/gfx-rs/naga/issues/2164)
+ crate::BuiltIn::PrimitiveIndex if self.lang_version < (2, 2) => {
+ return Err(Error::UnsupportedAttribute("primitive_id".to_string()));
+ }
+ _ => {}
+ }
+
+ Ok(ResolvedBinding::BuiltIn(built_in))
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source,
+ } => match mode {
+ LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)),
+ LocationMode::FragmentOutput => {
+ if second_blend_source && self.lang_version < (1, 2) {
+ return Err(Error::UnsupportedAttribute(
+ "second_blend_source".to_string(),
+ ));
+ }
+ Ok(ResolvedBinding::Color {
+ location,
+ second_blend_source,
+ })
+ }
+ LocationMode::VertexOutput | LocationMode::FragmentInput => {
+ Ok(ResolvedBinding::User {
+ prefix: if self.spirv_cross_compatibility {
+ "locn"
+ } else {
+ "loc"
+ },
+ index: location,
+ interpolation: {
+ // unwrap: The verifier ensures that vertex shader outputs and fragment
+ // shader inputs always have fully specified interpolation, and that
+ // sampling is `None` only for Flat interpolation.
+ let interpolation = interpolation.unwrap();
+ let sampling = sampling.unwrap_or(crate::Sampling::Center);
+ Some(ResolvedInterpolation::from_binding(interpolation, sampling))
+ },
+ })
+ }
+ LocationMode::Uniform => {
+ log::error!(
+ "Unexpected Binding::Location({}) for the Uniform mode",
+ location
+ );
+ Err(Error::Validation)
+ }
+ },
+ }
+ }
+
+ fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
+ self.per_entry_point_map.get(&ep.name)
+ }
+
+ fn get_resource_binding_target(
+ &self,
+ ep: &crate::EntryPoint,
+ res_binding: &crate::ResourceBinding,
+ ) -> Option<&BindTarget> {
+ self.get_entry_point_resources(ep)
+ .and_then(|res| res.resources.get(res_binding))
+ }
+
+ fn resolve_resource_binding(
+ &self,
+ ep: &crate::EntryPoint,
+ res_binding: &crate::ResourceBinding,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let target = self.get_resource_binding_target(ep, res_binding);
+ match target {
+ Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingBindTarget(res_binding.clone())),
+ }
+ }
+
+ fn resolve_push_constants(
+ &self,
+ ep: &crate::EntryPoint,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let slot = self
+ .get_entry_point_resources(ep)
+ .and_then(|res| res.push_constant_buffer);
+ match slot {
+ Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
+ buffer: Some(slot),
+ ..Default::default()
+ })),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingPushConstants),
+ }
+ }
+
+ fn resolve_sizes_buffer(
+ &self,
+ ep: &crate::EntryPoint,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let slot = self
+ .get_entry_point_resources(ep)
+ .and_then(|res| res.sizes_buffer);
+ match slot {
+ Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
+ buffer: Some(slot),
+ ..Default::default()
+ })),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingSizesBuffer),
+ }
+ }
+}
+
+impl ResolvedBinding {
+ fn as_inline_sampler<'a>(&self, options: &'a Options) -> Option<&'a sampler::InlineSampler> {
+ match *self {
+ Self::Resource(BindTarget {
+ sampler: Some(BindSamplerTarget::Inline(index)),
+ ..
+ }) => Some(&options.inline_samplers[index as usize]),
+ _ => None,
+ }
+ }
+
+ const fn as_bind_target(&self) -> Option<&BindTarget> {
+ match *self {
+ Self::Resource(ref target) => Some(target),
+ _ => None,
+ }
+ }
+
+ fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> {
+ write!(out, " [[")?;
+ match *self {
+ Self::BuiltIn(built_in) => {
+ use crate::BuiltIn as Bi;
+ let name = match built_in {
+ Bi::Position { invariant: false } => "position",
+ Bi::Position { invariant: true } => "position, invariant",
+ // vertex
+ Bi::BaseInstance => "base_instance",
+ Bi::BaseVertex => "base_vertex",
+ Bi::ClipDistance => "clip_distance",
+ Bi::InstanceIndex => "instance_id",
+ Bi::PointSize => "point_size",
+ Bi::VertexIndex => "vertex_id",
+ // fragment
+ Bi::FragDepth => "depth(any)",
+ Bi::PointCoord => "point_coord",
+ Bi::FrontFacing => "front_facing",
+ Bi::PrimitiveIndex => "primitive_id",
+ Bi::SampleIndex => "sample_id",
+ Bi::SampleMask => "sample_mask",
+ // compute
+ Bi::GlobalInvocationId => "thread_position_in_grid",
+ Bi::LocalInvocationId => "thread_position_in_threadgroup",
+ Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
+ Bi::WorkGroupId => "threadgroup_position_in_grid",
+ Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
+ Bi::NumWorkGroups => "threadgroups_per_grid",
+ Bi::CullDistance | Bi::ViewIndex => {
+ return Err(Error::UnsupportedBuiltIn(built_in))
+ }
+ };
+ write!(out, "{name}")?;
+ }
+ Self::Attribute(index) => write!(out, "attribute({index})")?,
+ Self::Color {
+ location,
+ second_blend_source,
+ } => {
+ if second_blend_source {
+ write!(out, "color({location}) index(1)")?
+ } else {
+ write!(out, "color({location})")?
+ }
+ }
+ Self::User {
+ prefix,
+ index,
+ interpolation,
+ } => {
+ write!(out, "user({prefix}{index})")?;
+ if let Some(interpolation) = interpolation {
+ write!(out, ", ")?;
+ interpolation.try_fmt(out)?;
+ }
+ }
+ Self::Resource(ref target) => {
+ if let Some(id) = target.buffer {
+ write!(out, "buffer({id})")?;
+ } else if let Some(id) = target.texture {
+ write!(out, "texture({id})")?;
+ } else if let Some(BindSamplerTarget::Resource(id)) = target.sampler {
+ write!(out, "sampler({id})")?;
+ } else {
+ return Err(Error::UnimplementedBindTarget(target.clone()));
+ }
+ }
+ }
+ write!(out, "]]")?;
+ Ok(())
+ }
+}
+
+impl ResolvedInterpolation {
+ const fn from_binding(interpolation: crate::Interpolation, sampling: crate::Sampling) -> Self {
+ use crate::Interpolation as I;
+ use crate::Sampling as S;
+
+ match (interpolation, sampling) {
+ (I::Perspective, S::Center) => Self::CenterPerspective,
+ (I::Perspective, S::Centroid) => Self::CentroidPerspective,
+ (I::Perspective, S::Sample) => Self::SamplePerspective,
+ (I::Linear, S::Center) => Self::CenterNoPerspective,
+ (I::Linear, S::Centroid) => Self::CentroidNoPerspective,
+ (I::Linear, S::Sample) => Self::SampleNoPerspective,
+ (I::Flat, _) => Self::Flat,
+ }
+ }
+
+ fn try_fmt<W: Write>(self, out: &mut W) -> Result<(), Error> {
+ let identifier = match self {
+ Self::CenterPerspective => "center_perspective",
+ Self::CenterNoPerspective => "center_no_perspective",
+ Self::CentroidPerspective => "centroid_perspective",
+ Self::CentroidNoPerspective => "centroid_no_perspective",
+ Self::SamplePerspective => "sample_perspective",
+ Self::SampleNoPerspective => "sample_no_perspective",
+ Self::Flat => "flat",
+ };
+ out.write_str(identifier)?;
+ Ok(())
+ }
+}
+
+/// Information about a translated module that is required
+/// for the use of the result.
+pub struct TranslationInfo {
+ /// Mapping of the entry point names. Each item in the array
+ /// corresponds to an entry point index.
+ ///
+ ///Note: Some entry points may fail translation because of missing bindings.
+ pub entry_point_names: Vec<Result<String, EntryPointError>>,
+}
+
+pub fn write_string(
+ module: &crate::Module,
+ info: &ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+) -> Result<(String, TranslationInfo), Error> {
+ let mut w = writer::Writer::new(String::new());
+ let info = w.write(module, info, options, pipeline_options)?;
+ Ok((w.finish(), info))
+}
+
+#[test]
+fn test_error_size() {
+ use std::mem::size_of;
+ assert_eq!(size_of::<Error>(), 32);
+}
diff --git a/third_party/rust/naga/src/back/msl/sampler.rs b/third_party/rust/naga/src/back/msl/sampler.rs
new file mode 100644
index 0000000000..0bf987076d
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/sampler.rs
@@ -0,0 +1,176 @@
+#[cfg(feature = "deserialize")]
+use serde::Deserialize;
+#[cfg(feature = "serialize")]
+use serde::Serialize;
+use std::{num::NonZeroU32, ops::Range};
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Coord {
+ Normalized,
+ Pixel,
+}
+
+impl Default for Coord {
+ fn default() -> Self {
+ Self::Normalized
+ }
+}
+
+impl Coord {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Normalized => "normalized",
+ Self::Pixel => "pixel",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Address {
+ Repeat,
+ MirroredRepeat,
+ ClampToEdge,
+ ClampToZero,
+ ClampToBorder,
+}
+
+impl Default for Address {
+ fn default() -> Self {
+ Self::ClampToEdge
+ }
+}
+
+impl Address {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Repeat => "repeat",
+ Self::MirroredRepeat => "mirrored_repeat",
+ Self::ClampToEdge => "clamp_to_edge",
+ Self::ClampToZero => "clamp_to_zero",
+ Self::ClampToBorder => "clamp_to_border",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum BorderColor {
+ TransparentBlack,
+ OpaqueBlack,
+ OpaqueWhite,
+}
+
+impl Default for BorderColor {
+ fn default() -> Self {
+ Self::TransparentBlack
+ }
+}
+
+impl BorderColor {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::TransparentBlack => "transparent_black",
+ Self::OpaqueBlack => "opaque_black",
+ Self::OpaqueWhite => "opaque_white",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Filter {
+ Nearest,
+ Linear,
+}
+
+impl Filter {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Nearest => "nearest",
+ Self::Linear => "linear",
+ }
+ }
+}
+
+impl Default for Filter {
+ fn default() -> Self {
+ Self::Nearest
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum CompareFunc {
+ Never,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+ Equal,
+ NotEqual,
+ Always,
+}
+
+impl Default for CompareFunc {
+ fn default() -> Self {
+ Self::Never
+ }
+}
+
+impl CompareFunc {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Never => "never",
+ Self::Less => "less",
+ Self::LessEqual => "less_equal",
+ Self::Greater => "greater",
+ Self::GreaterEqual => "greater_equal",
+ Self::Equal => "equal",
+ Self::NotEqual => "not_equal",
+ Self::Always => "always",
+ }
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct InlineSampler {
+ pub coord: Coord,
+ pub address: [Address; 3],
+ pub border_color: BorderColor,
+ pub mag_filter: Filter,
+ pub min_filter: Filter,
+ pub mip_filter: Option<Filter>,
+ pub lod_clamp: Option<Range<f32>>,
+ pub max_anisotropy: Option<NonZeroU32>,
+ pub compare_func: CompareFunc,
+}
+
+impl Eq for InlineSampler {}
+
+#[allow(renamed_and_removed_lints)]
+#[allow(clippy::derive_hash_xor_eq)]
+impl std::hash::Hash for InlineSampler {
+ fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
+ self.coord.hash(hasher);
+ self.address.hash(hasher);
+ self.border_color.hash(hasher);
+ self.mag_filter.hash(hasher);
+ self.min_filter.hash(hasher);
+ self.mip_filter.hash(hasher);
+ self.lod_clamp
+ .as_ref()
+ .map(|range| (range.start.to_bits(), range.end.to_bits()))
+ .hash(hasher);
+ self.max_anisotropy.hash(hasher);
+ self.compare_func.hash(hasher);
+ }
+}
diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs
new file mode 100644
index 0000000000..1e496b5f50
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/writer.rs
@@ -0,0 +1,4659 @@
+use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
+use crate::{
+ arena::Handle,
+ back,
+ proc::index,
+ proc::{self, NameKey, TypeResolution},
+ valid, FastHashMap, FastHashSet,
+};
+use bit_set::BitSet;
+use std::{
+ fmt::{Display, Error as FmtError, Formatter, Write},
+ iter,
+};
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+const NAMESPACE: &str = "metal";
+// The name of the array member of the Metal struct types we generate to
+// represent Naga `Array` types. See the comments in `Writer::write_type_defs`
+// for details.
+const WRAPPED_ARRAY_FIELD: &str = "inner";
+// This is a hack: we need to pass a pointer to an atomic,
+// but generally the backend isn't putting "&" in front of every pointer.
+// Some more general handling of pointers is needed to be implemented here.
+const ATOMIC_REFERENCE: &str = "&";
+
+const RT_NAMESPACE: &str = "metal::raytracing";
+const RAY_QUERY_TYPE: &str = "_RayQuery";
+const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
+const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
+const RAY_QUERY_FIELD_READY: &str = "ready";
+const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
+
+pub(crate) const MODF_FUNCTION: &str = "naga_modf";
+pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
+
+/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
+///
+/// The `sizes` slice determines whether this function writes a
+/// scalar, vector, or matrix type:
+///
+/// - An empty slice produces a scalar type.
+/// - A one-element slice produces a vector type.
+/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size.
+fn put_numeric_type(
+ out: &mut impl Write,
+ scalar: crate::Scalar,
+ sizes: &[crate::VectorSize],
+) -> Result<(), FmtError> {
+ match (scalar, sizes) {
+ (scalar, &[]) => {
+ write!(out, "{}", scalar.to_msl_name())
+ }
+ (scalar, &[rows]) => {
+ write!(
+ out,
+ "{}::{}{}",
+ NAMESPACE,
+ scalar.to_msl_name(),
+ back::vector_size_str(rows)
+ )
+ }
+ (scalar, &[rows, columns]) => {
+ write!(
+ out,
+ "{}::{}{}x{}",
+ NAMESPACE,
+ scalar.to_msl_name(),
+ back::vector_size_str(columns),
+ back::vector_size_str(rows)
+ )
+ }
+ (_, _) => Ok(()), // not meaningful
+ }
+}
+
+/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
+const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
+
+struct TypeContext<'a> {
+ handle: Handle<crate::Type>,
+ gctx: proc::GlobalCtx<'a>,
+ names: &'a FastHashMap<NameKey, String>,
+ access: crate::StorageAccess,
+ binding: Option<&'a super::ResolvedBinding>,
+ first_time: bool,
+}
+
+impl<'a> Display for TypeContext<'a> {
+ fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
+ let ty = &self.gctx.types[self.handle];
+ if ty.needs_alias() && !self.first_time {
+ let name = &self.names[&NameKey::Type(self.handle)];
+ return write!(out, "{name}");
+ }
+
+ match ty.inner {
+ crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
+ crate::TypeInner::Atomic(scalar) => {
+ write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
+ }
+ crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ put_numeric_type(out, crate::Scalar::F32, &[rows, columns])
+ }
+ crate::TypeInner::Pointer { base, space } => {
+ let sub = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+ let space_name = match space.to_msl_name() {
+ Some(name) => name,
+ None => return Ok(()),
+ };
+ write!(out, "{space_name} {sub}&")
+ }
+ crate::TypeInner::ValuePointer {
+ size,
+ scalar,
+ space,
+ } => {
+ match space.to_msl_name() {
+ Some(name) => write!(out, "{name} ")?,
+ None => return Ok(()),
+ };
+ match size {
+ Some(rows) => put_numeric_type(out, scalar, &[rows])?,
+ None => put_numeric_type(out, scalar, &[])?,
+ };
+
+ write!(out, "&")
+ }
+ crate::TypeInner::Array { base, .. } => {
+ let sub = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+ // Array lengths go at the end of the type definition,
+ // so just print the element type here.
+ write!(out, "{sub}")
+ }
+ crate::TypeInner::Struct { .. } => unreachable!(),
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let dim_str = match dim {
+ crate::ImageDimension::D1 => "1d",
+ crate::ImageDimension::D2 => "2d",
+ crate::ImageDimension::D3 => "3d",
+ crate::ImageDimension::Cube => "cube",
+ };
+ let (texture_str, msaa_str, kind, access) = match class {
+ crate::ImageClass::Sampled { kind, multi } => {
+ let (msaa_str, access) = if multi {
+ ("_ms", "read")
+ } else {
+ ("", "sample")
+ };
+ ("texture", msaa_str, kind, access)
+ }
+ crate::ImageClass::Depth { multi } => {
+ let (msaa_str, access) = if multi {
+ ("_ms", "read")
+ } else {
+ ("", "sample")
+ };
+ ("depth", msaa_str, crate::ScalarKind::Float, access)
+ }
+ crate::ImageClass::Storage { format, .. } => {
+ let access = if self
+ .access
+ .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
+ {
+ "read_write"
+ } else if self.access.contains(crate::StorageAccess::STORE) {
+ "write"
+ } else if self.access.contains(crate::StorageAccess::LOAD) {
+ "read"
+ } else {
+ log::warn!(
+ "Storage access for {:?} (name '{}'): {:?}",
+ self.handle,
+ ty.name.as_deref().unwrap_or_default(),
+ self.access
+ );
+ unreachable!("module is not valid");
+ };
+ ("texture", "", format.into(), access)
+ }
+ };
+ let base_name = crate::Scalar { kind, width: 4 }.to_msl_name();
+ let array_str = if arrayed { "_array" } else { "" };
+ write!(
+ out,
+ "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
+ )
+ }
+ crate::TypeInner::Sampler { comparison: _ } => {
+ write!(out, "{NAMESPACE}::sampler")
+ }
+ crate::TypeInner::AccelerationStructure => {
+ write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
+ }
+ crate::TypeInner::RayQuery => {
+ write!(out, "{RAY_QUERY_TYPE}")
+ }
+ crate::TypeInner::BindingArray { base, size } => {
+ let base_tyname = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+
+ if let Some(&super::ResolvedBinding::Resource(super::BindTarget {
+ binding_array_size: Some(override_size),
+ ..
+ })) = self.binding
+ {
+ write!(out, "{NAMESPACE}::array<{base_tyname}, {override_size}>")
+ } else if let crate::ArraySize::Constant(size) = size {
+ write!(out, "{NAMESPACE}::array<{base_tyname}, {size}>")
+ } else {
+ unreachable!("metal requires all arrays be constant sized");
+ }
+ }
+ }
+ }
+}
+
+struct TypedGlobalVariable<'a> {
+ module: &'a crate::Module,
+ names: &'a FastHashMap<NameKey, String>,
+ handle: Handle<crate::GlobalVariable>,
+ usage: valid::GlobalUse,
+ binding: Option<&'a super::ResolvedBinding>,
+ reference: bool,
+}
+
+impl<'a> TypedGlobalVariable<'a> {
+ fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
+ let var = &self.module.global_variables[self.handle];
+ let name = &self.names[&NameKey::GlobalVariable(self.handle)];
+
+ let storage_access = match var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => match self.module.types[var.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => access,
+ crate::TypeInner::BindingArray { base, .. } => {
+ match self.module.types[base].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => access,
+ _ => crate::StorageAccess::default(),
+ }
+ }
+ _ => crate::StorageAccess::default(),
+ },
+ };
+ let ty_name = TypeContext {
+ handle: var.ty,
+ gctx: self.module.to_ctx(),
+ names: self.names,
+ access: storage_access,
+ binding: self.binding,
+ first_time: false,
+ };
+
+ let (space, access, reference) = match var.space.to_msl_name() {
+ Some(space) if self.reference => {
+ let access = if var.space.needs_access_qualifier()
+ && !self.usage.contains(valid::GlobalUse::WRITE)
+ {
+ "const"
+ } else {
+ ""
+ };
+ (space, access, "&")
+ }
+ _ => ("", "", ""),
+ };
+
+ Ok(write!(
+ out,
+ "{}{}{}{}{}{} {}",
+ space,
+ if space.is_empty() { "" } else { " " },
+ ty_name,
+ if access.is_empty() { "" } else { " " },
+ access,
+ reference,
+ name,
+ )?)
+ }
+}
+
+pub struct Writer<W> {
+ out: W,
+ names: FastHashMap<NameKey, String>,
+ named_expressions: crate::NamedExpressions,
+ /// Set of expressions that need to be baked to avoid unnecessary repetition in output
+ need_bake_expressions: back::NeedBakeExpressions,
+ namer: proc::Namer,
+ #[cfg(test)]
+ put_expression_stack_pointers: FastHashSet<*const ()>,
+ #[cfg(test)]
+ put_block_stack_pointers: FastHashSet<*const ()>,
+ /// Set of (struct type, struct field index) denoting which fields require
+ /// padding inserted **before** them (i.e. between fields at index - 1 and index)
+ struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
+}
+
+impl crate::Scalar {
+ const fn to_msl_name(self) -> &'static str {
+ use crate::ScalarKind as Sk;
+ match self {
+ Self {
+ kind: Sk::Float,
+ width: _,
+ } => "float",
+ Self {
+ kind: Sk::Sint,
+ width: _,
+ } => "int",
+ Self {
+ kind: Sk::Uint,
+ width: _,
+ } => "uint",
+ Self {
+ kind: Sk::Bool,
+ width: _,
+ } => "bool",
+ Self {
+ kind: Sk::AbstractInt | Sk::AbstractFloat,
+ width: _,
+ } => unreachable!(),
+ }
+ }
+}
+
+const fn separate(need_separator: bool) -> &'static str {
+ if need_separator {
+ ","
+ } else {
+ ""
+ }
+}
+
+fn should_pack_struct_member(
+ members: &[crate::StructMember],
+ span: u32,
+ index: usize,
+ module: &crate::Module,
+) -> Option<crate::Scalar> {
+ let member = &members[index];
+
+ let ty_inner = &module.types[member.ty].inner;
+ let last_offset = member.offset + ty_inner.size(module.to_ctx());
+ let next_offset = match members.get(index + 1) {
+ Some(next) => next.offset,
+ None => span,
+ };
+ let is_tight = next_offset == last_offset;
+
+ match *ty_inner {
+ crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: scalar @ crate::Scalar { width: 4, .. },
+ } if is_tight => Some(scalar),
+ _ => None,
+ }
+}
+
+fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
+ match arena[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ if let Some(member) = members.last() {
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = arena[member.ty].inner
+ {
+ return true;
+ }
+ }
+ false
+ }
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => true,
+ _ => false,
+ }
+}
+
+impl crate::AddressSpace {
+ /// Returns true if global variables in this address space are
+ /// passed in function arguments. These arguments need to be
+ /// passed through any functions called from the entry point.
+ const fn needs_pass_through(&self) -> bool {
+ match *self {
+ Self::Uniform
+ | Self::Storage { .. }
+ | Self::Private
+ | Self::WorkGroup
+ | Self::PushConstant
+ | Self::Handle => true,
+ Self::Function => false,
+ }
+ }
+
+ /// Returns true if the address space may need a "const" qualifier.
+ const fn needs_access_qualifier(&self) -> bool {
+ match *self {
+ //Note: we are ignoring the storage access here, and instead
+ // rely on the actual use of a global by functions. This means we
+ // may end up with "const" even if the binding is read-write,
+ // and that should be OK.
+ Self::Storage { .. } => true,
+ // These should always be read-write.
+ Self::Private | Self::WorkGroup => false,
+ // These translate to `constant` address space, no need for qualifiers.
+ Self::Uniform | Self::PushConstant => false,
+ // Not applicable.
+ Self::Handle | Self::Function => false,
+ }
+ }
+
+ const fn to_msl_name(self) -> Option<&'static str> {
+ match self {
+ Self::Handle => None,
+ Self::Uniform | Self::PushConstant => Some("constant"),
+ Self::Storage { .. } => Some("device"),
+ Self::Private | Self::Function => Some("thread"),
+ Self::WorkGroup => Some("threadgroup"),
+ }
+ }
+}
+
+impl crate::Type {
+ // Returns `true` if we need to emit an alias for this type.
+ const fn needs_alias(&self) -> bool {
+ use crate::TypeInner as Ti;
+
+ match self.inner {
+ // value types are concise enough, we only alias them if they are named
+ Ti::Scalar(_)
+ | Ti::Vector { .. }
+ | Ti::Matrix { .. }
+ | Ti::Atomic(_)
+ | Ti::Pointer { .. }
+ | Ti::ValuePointer { .. } => self.name.is_some(),
+ // composite types are better to be aliased, regardless of the name
+ Ti::Struct { .. } | Ti::Array { .. } => true,
+ // handle types may be different, depending on the global var access, so we always inline them
+ Ti::Image { .. }
+ | Ti::Sampler { .. }
+ | Ti::AccelerationStructure
+ | Ti::RayQuery
+ | Ti::BindingArray { .. } => false,
+ }
+ }
+}
+
+enum FunctionOrigin {
+ Handle(Handle<crate::Function>),
+ EntryPoint(proc::EntryPointIndex),
+}
+
+/// A level of detail argument.
+///
+/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
+/// save the clamped level of detail in a temporary variable whose name is based
+/// on the handle of the `ImageLoad` expression. But for other policies, we just
+/// use the expression directly.
+///
+/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+#[derive(Clone, Copy)]
+enum LevelOfDetail {
+ Direct(Handle<crate::Expression>),
+ Restricted(Handle<crate::Expression>),
+}
+
+/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
+///
+/// When this is used in code paths unconcerned with the `Restrict` bounds check
+/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
+/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
+/// be too awkward. If that changes, we can revisit.
+///
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+/// [`ImageStore`]: crate::Statement::ImageStore
+struct TexelAddress {
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ level: Option<LevelOfDetail>,
+}
+
+struct ExpressionContext<'a> {
+ function: &'a crate::Function,
+ origin: FunctionOrigin,
+ info: &'a valid::FunctionInfo,
+ module: &'a crate::Module,
+ mod_info: &'a valid::ModuleInfo,
+ pipeline_options: &'a PipelineOptions,
+ lang_version: (u8, u8),
+ policies: index::BoundsCheckPolicies,
+
+ /// A bitset containing the `Expression` handle indexes of expressions used
+ /// as indices in `ReadZeroSkipWrite`-policy accesses. These may need to be
+ /// cached in temporary variables. See `index::find_checked_indexes` for
+ /// details.
+ guarded_indices: BitSet,
+}
+
+impl<'a> ExpressionContext<'a> {
+ fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
+ self.info[handle].ty.inner_with(&self.module.types)
+ }
+
+ /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
+ ///
+ /// Only mipmapped images need to specify a level of detail. Since 1D
+ /// textures cannot have mipmaps, MSL requires that the level argument to
+ /// texture1d queries and accesses must be a constexpr 0. It's easiest
+ /// just to omit the level entirely for 1D textures.
+ fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
+ let image_ty = self.resolve_type(image);
+ if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
+ class.is_mipmapped() && dim != crate::ImageDimension::D1
+ } else {
+ false
+ }
+ }
+
+ fn choose_bounds_check_policy(
+ &self,
+ pointer: Handle<crate::Expression>,
+ ) -> index::BoundsCheckPolicy {
+ self.policies
+ .choose_policy(pointer, &self.module.types, self.info)
+ }
+
+ fn access_needs_check(
+ &self,
+ base: Handle<crate::Expression>,
+ index: index::GuardedIndex,
+ ) -> Option<index::IndexableLength> {
+ index::access_needs_check(base, index, self.module, self.function, self.info)
+ }
+
+ fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
+ match self.function.expressions[expr_handle] {
+ crate::Expression::AccessIndex { base, index } => {
+ let ty = match *self.resolve_type(base) {
+ crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
+ ref ty => ty,
+ };
+ match *ty {
+ crate::TypeInner::Struct {
+ ref members, span, ..
+ } => should_pack_struct_member(members, span, index as usize, self.module),
+ _ => None,
+ }
+ }
+ _ => None,
+ }
+ }
+}
+
+struct StatementContext<'a> {
+ expression: ExpressionContext<'a>,
+ result_struct: Option<&'a str>,
+}
+
+impl<W: Write> Writer<W> {
+ /// Creates a new `Writer` instance.
+ pub fn new(out: W) -> Self {
+ Writer {
+ out,
+ names: FastHashMap::default(),
+ named_expressions: Default::default(),
+ need_bake_expressions: Default::default(),
+ namer: proc::Namer::default(),
+ #[cfg(test)]
+ put_expression_stack_pointers: Default::default(),
+ #[cfg(test)]
+ put_block_stack_pointers: Default::default(),
+ struct_member_pads: FastHashSet::default(),
+ }
+ }
+
+ /// Finishes writing and returns the output.
+ // See https://github.com/rust-lang/rust-clippy/issues/4979.
+ #[allow(clippy::missing_const_for_fn)]
+ pub fn finish(self) -> W {
+ self.out
+ }
+
+ fn put_call_parameters(
+ &mut self,
+ parameters: impl Iterator<Item = Handle<crate::Expression>>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
+ writer.put_expression(expr, context, true)
+ })
+ }
+
+ fn put_call_parameters_impl<C, E>(
+ &mut self,
+ parameters: impl Iterator<Item = Handle<crate::Expression>>,
+ ctx: &C,
+ put_expression: E,
+ ) -> BackendResult
+ where
+ E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
+ {
+ write!(self.out, "(")?;
+ for (i, handle) in parameters.enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ put_expression(self, ctx, handle)?;
+ }
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_level_of_detail(
+ &mut self,
+ level: LevelOfDetail,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match level {
+ LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
+ LevelOfDetail::Restricted(load) => {
+ write!(self.out, "{}{}", CLAMPED_LOD_LOAD_PREFIX, load.index())?
+ }
+ }
+ Ok(())
+ }
+
+ fn put_image_query(
+ &mut self,
+ image: Handle<crate::Expression>,
+ query: &str,
+ level: Option<LevelOfDetail>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_{query}(")?;
+ if let Some(level) = level {
+ self.put_level_of_detail(level, context)?;
+ }
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_image_size_query(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: Option<LevelOfDetail>,
+ kind: crate::ScalarKind,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ //Note: MSL only has separate width/height/depth queries,
+ // so compose the result of them.
+ let dim = match *context.resolve_type(image) {
+ crate::TypeInner::Image { dim, .. } => dim,
+ ref other => unreachable!("Unexpected type {:?}", other),
+ };
+ let scalar = crate::Scalar { kind, width: 4 };
+ let coordinate_type = scalar.to_msl_name();
+ match dim {
+ crate::ImageDimension::D1 => {
+ // Since 1D textures never have mipmaps, MSL requires that the
+ // `level` argument be a constexpr 0. It's simplest for us just
+ // to pass `None` and omit the level entirely.
+ if kind == crate::ScalarKind::Uint {
+ // No need to construct a vector. No cast needed.
+ self.put_image_query(image, "width", None, context)?;
+ } else {
+ // There's no definition for `int` in the `metal` namespace.
+ write!(self.out, "int(")?;
+ self.put_image_query(image, "width", None, context)?;
+ write!(self.out, ")")?;
+ }
+ }
+ crate::ImageDimension::D2 => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "height", level, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::ImageDimension::D3 => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "height", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "depth", level, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::ImageDimension::Cube => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn put_cast_to_uint_scalar_or_vector(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // coordinates in IR are int, but Metal expects uint
+ match *context.resolve_type(expr) {
+ crate::TypeInner::Scalar(_) => {
+ put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
+ }
+ crate::TypeInner::Vector { size, .. } => {
+ put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
+ }
+ _ => return Err(Error::Validation),
+ };
+
+ write!(self.out, "(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_image_sample_level(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: crate::SampleLevel,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let has_levels = context.image_needs_lod(image);
+ match level {
+ crate::SampleLevel::Auto => {}
+ crate::SampleLevel::Zero => {
+ //TODO: do we support Zero on `Sampled` image classes?
+ }
+ _ if !has_levels => {
+ log::warn!("1D image can't be sampled with level {:?}", level);
+ }
+ crate::SampleLevel::Exact(h) => {
+ write!(self.out, ", {NAMESPACE}::level(")?;
+ self.put_expression(h, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::SampleLevel::Bias(h) => {
+ write!(self.out, ", {NAMESPACE}::bias(")?;
+ self.put_expression(h, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ write!(self.out, ", {NAMESPACE}::gradient2d(")?;
+ self.put_expression(x, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(y, context, true)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn put_image_coordinate_limits(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: Option<LevelOfDetail>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
+ write!(self.out, " - 1")?;
+ Ok(())
+ }
+
+ /// General function for writing restricted image indexes.
+ ///
+ /// This is used to produce restricted mip levels, array indices, and sample
+ /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
+ /// [`Restrict`] bounds check policy.
+ ///
+ /// This function writes an expression of the form:
+ ///
+ /// ```ignore
+ ///
+ /// metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
+ ///
+ /// ```
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`ImageStore`]: crate::Statement::ImageStore
+ /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
+ fn put_restricted_scalar_image_index(
+ &mut self,
+ image: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ limit_method: &str,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ write!(self.out, "{NAMESPACE}::min(uint(")?;
+ self.put_expression(index, context, true)?;
+ write!(self.out, "), ")?;
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".{limit_method}() - 1)")?;
+ Ok(())
+ }
+
+ fn put_restricted_texel_address(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // Write the coordinate.
+ write!(self.out, "{NAMESPACE}::min(")?;
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_coordinate_limits(image, address.level, context)?;
+ write!(self.out, ")")?;
+
+ // Write the array index, if present.
+ if let Some(array_index) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
+ }
+
+ // Write the sample index, if present.
+ if let Some(sample) = address.sample {
+ write!(self.out, ", ")?;
+ self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
+ }
+
+ // The level of detail should be clamped and cached by
+ // `put_cache_restricted_level`, so we don't need to clamp it here.
+ if let Some(level) = address.level {
+ write!(self.out, ", ")?;
+ self.put_level_of_detail(level, context)?;
+ }
+
+ Ok(())
+ }
+
+ /// Write an expression that is true if the given image access is in bounds.
+ fn put_image_access_bounds_check(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let mut conjunction = "";
+
+ // First, check the level of detail. Only if that is in bounds can we
+ // use it to find the appropriate bounds for the coordinates.
+ let level = if let Some(level) = address.level {
+ write!(self.out, "uint(")?;
+ self.put_level_of_detail(level, context)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_num_mip_levels()")?;
+ conjunction = " && ";
+ Some(level)
+ } else {
+ None
+ };
+
+ // Check sample index, if present.
+ if let Some(sample) = address.sample {
+ write!(self.out, "uint(")?;
+ self.put_expression(sample, context, true)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_num_samples()")?;
+ conjunction = " && ";
+ }
+
+ // Check array index, if present.
+ if let Some(array_index) = address.array_index {
+ write!(self.out, "{conjunction}uint(")?;
+ self.put_expression(array_index, context, true)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_array_size()")?;
+ conjunction = " && ";
+ }
+
+ // Finally, check if the coordinates are within bounds.
+ let coord_is_vector = match *context.resolve_type(address.coordinate) {
+ crate::TypeInner::Vector { .. } => true,
+ _ => false,
+ };
+ write!(self.out, "{conjunction}")?;
+ if coord_is_vector {
+ write!(self.out, "{NAMESPACE}::all(")?;
+ }
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ write!(self.out, " < ")?;
+ self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
+ if coord_is_vector {
+ write!(self.out, ")")?;
+ }
+
+ Ok(())
+ }
+
+ fn put_image_load(
+ &mut self,
+ load: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ mut address: TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.policies.image_load {
+ proc::BoundsCheckPolicy::Restrict => {
+ // Use the cached restricted level of detail, if any. Omit the
+ // level altogether for 1D textures.
+ if address.level.is_some() {
+ address.level = if context.image_needs_lod(image) {
+ Some(LevelOfDetail::Restricted(load))
+ } else {
+ None
+ }
+ }
+
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".read(")?;
+ self.put_restricted_texel_address(image, &address, context)?;
+ write!(self.out, ")")?;
+ }
+ proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ write!(self.out, "(")?;
+ self.put_image_access_bounds_check(image, &address, context)?;
+ write!(self.out, " ? ")?;
+ self.put_unchecked_image_load(image, &address, context)?;
+ write!(self.out, ": DefaultConstructible())")?;
+ }
+ proc::BoundsCheckPolicy::Unchecked => {
+ self.put_unchecked_image_load(image, &address, context)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_image_load(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".read(")?;
+ // coordinates in IR are int, but Metal expects uint
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ if let Some(expr) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, context, true)?;
+ }
+ if let Some(sample) = address.sample {
+ write!(self.out, ", ")?;
+ self.put_expression(sample, context, true)?;
+ }
+ if let Some(level) = address.level {
+ if context.image_needs_lod(image) {
+ write!(self.out, ", ")?;
+ self.put_level_of_detail(level, context)?;
+ }
+ }
+ write!(self.out, ")")?;
+
+ Ok(())
+ }
+
+ fn put_image_store(
+ &mut self,
+ level: back::Level,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ value: Handle<crate::Expression>,
+ context: &StatementContext,
+ ) -> BackendResult {
+ match context.expression.policies.image_store {
+ proc::BoundsCheckPolicy::Restrict => {
+ // We don't have a restricted level value, because we don't
+ // support writes to mipmapped textures.
+ debug_assert!(address.level.is_none());
+
+ write!(self.out, "{level}")?;
+ self.put_expression(image, &context.expression, false)?;
+ write!(self.out, ".write(")?;
+ self.put_expression(value, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ self.put_restricted_texel_address(image, address, &context.expression)?;
+ writeln!(self.out, ");")?;
+ }
+ proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ write!(self.out, "{level}if (")?;
+ self.put_image_access_bounds_check(image, address, &context.expression)?;
+ writeln!(self.out, ") {{")?;
+ self.put_unchecked_image_store(level.next(), image, address, value, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ proc::BoundsCheckPolicy::Unchecked => {
+ self.put_unchecked_image_store(level, image, address, value, context)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_image_store(
+ &mut self,
+ level: back::Level,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ value: Handle<crate::Expression>,
+ context: &StatementContext,
+ ) -> BackendResult {
+ write!(self.out, "{level}")?;
+ self.put_expression(image, &context.expression, false)?;
+ write!(self.out, ".write(")?;
+ self.put_expression(value, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ // coordinates in IR are int, but Metal expects uint
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
+ if let Some(expr) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, &context.expression, true)?;
+ }
+ writeln!(self.out, ");")?;
+
+ Ok(())
+ }
+
+ /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
+ ///
+ /// The 'maximum valid index' is simply one less than the array's length.
+ ///
+ /// This emits an expression of the form `a / b`, so the caller must
+ /// parenthesize its output if it will be applying operators of higher
+ /// precedence.
+ ///
+ /// `handle` must be the handle of a global variable whose final member is a
+ /// dynamically sized array.
+ fn put_dynamic_array_max_index(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let global = &context.module.global_variables[handle];
+ let (offset, array_ty) = match context.module.types[global.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => match members.last() {
+ Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
+ None => return Err(Error::Validation),
+ },
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => (0, global.ty),
+ _ => return Err(Error::Validation),
+ };
+
+ let (size, stride) = match context.module.types[array_ty].inner {
+ crate::TypeInner::Array { base, stride, .. } => (
+ context.module.types[base]
+ .inner
+ .size(context.module.to_ctx()),
+ stride,
+ ),
+ _ => return Err(Error::Validation),
+ };
+
+ // When the stride length is larger than the size, the final element's stride of
+ // bytes would have padding following the value. But the buffer size in
+ // `buffer_sizes.sizeN` may not include this padding - it only needs to be large
+ // enough to hold the actual values' bytes.
+ //
+ // So subtract off the size to get a byte size that falls at the start or within
+ // the final element. Then divide by the stride size, to get one less than the
+ // length, and then add one. This works even if the buffer size does include the
+ // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
+ // if there are zero elements in the array, but the WebGPU `validating shader binding`
+ // rules, together with draw-time validation when `minBindingSize` is zero,
+ // prevent that.
+ write!(
+ self.out,
+ "(_buffer_sizes.size{idx} - {offset} - {size}) / {stride}",
+ idx = handle.index(),
+ offset = offset,
+ size = size,
+ stride = stride,
+ )?;
+ Ok(())
+ }
+
+ fn put_atomic_fetch(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ key: &str,
+ value: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_atomic_operation(pointer, "fetch_", key, value, context)
+ }
+
+ fn put_atomic_operation(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ key1: &str,
+ key2: &str,
+ value: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // If the pointer we're passing to the atomic operation needs to be conditional
+ // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
+ // the pointer operand should be unchecked.
+ let policy = context.choose_bounds_check_policy(pointer);
+ let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
+
+ // If requested and successfully put bounds checks, continue the ternary expression.
+ if checked {
+ write!(self.out, " ? ")?;
+ }
+
+ write!(
+ self.out,
+ "{NAMESPACE}::atomic_{key1}{key2}_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, context)?;
+ write!(self.out, ", ")?;
+ self.put_expression(value, context, true)?;
+ write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
+
+ // Finish the ternary expression.
+ if checked {
+ write!(self.out, " : DefaultConstructible()")?;
+ }
+
+ Ok(())
+ }
+
+ /// Emit code for the arithmetic expression of the dot product.
+ ///
+ fn put_dot_product(
+ &mut self,
+ arg: Handle<crate::Expression>,
+ arg1: Handle<crate::Expression>,
+ size: usize,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // Write parentheses around the dot product expression to prevent operators
+ // with different precedences from applying earlier.
+ write!(self.out, "(")?;
+
+ // Cycle trough all the components of the vector
+ for index in 0..size {
+ let component = back::COMPONENTS[index];
+ // Write the addition to the previous product
+ // This will print an extra '+' at the beginning but that is fine in msl
+ write!(self.out, " + ")?;
+ // Write the first vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.put_expression(arg, context, true)?;
+ // Access the current component on the first vector
+ write!(self.out, ".{component} * ")?;
+ // Write the second vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.put_expression(arg1, context, true)?;
+ // Access the current component on the second vector
+ write!(self.out, ".{component}")?;
+ }
+
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ /// Emit code for the sign(i32) expression.
+ ///
+ fn put_isign(
+ &mut self,
+ arg: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
+ match context.resolve_type(arg) {
+ &crate::TypeInner::Vector { size, .. } => {
+ let size = back::vector_size_str(size);
+ write!(self.out, "int{size}(-1), int{size}(1)")?;
+ }
+ _ => {
+ write!(self.out, "-1, 1")?;
+ }
+ }
+ write!(self.out, ", (")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " > 0)), 0, (")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " == 0))")?;
+ Ok(())
+ }
+
+ fn put_const_expression(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ module: &crate::Module,
+ mod_info: &valid::ModuleInfo,
+ ) -> BackendResult {
+ self.put_possibly_const_expression(
+ expr_handle,
+ &module.const_expressions,
+ module,
+ mod_info,
+ &(module, mod_info),
+ |&(_, mod_info), expr| &mod_info[expr],
+ |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info),
+ )
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn put_possibly_const_expression<C, I, E>(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ expressions: &crate::Arena<crate::Expression>,
+ module: &crate::Module,
+ mod_info: &valid::ModuleInfo,
+ ctx: &C,
+ get_expr_ty: I,
+ put_expression: E,
+ ) -> BackendResult
+ where
+ I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
+ E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
+ {
+ match expressions[expr_handle] {
+ crate::Expression::Literal(literal) => match literal {
+ crate::Literal::F64(_) => {
+ return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
+ }
+ crate::Literal::F32(value) => {
+ if value.is_infinite() {
+ let sign = if value.is_sign_negative() { "-" } else { "" };
+ write!(self.out, "{sign}INFINITY")?;
+ } else if value.is_nan() {
+ write!(self.out, "NAN")?;
+ } else {
+ let suffix = if value.fract() == 0.0 { ".0" } else { "" };
+ write!(self.out, "{value}{suffix}")?;
+ }
+ }
+ crate::Literal::U32(value) => {
+ write!(self.out, "{value}u")?;
+ }
+ crate::Literal::I32(value) => {
+ write!(self.out, "{value}")?;
+ }
+ crate::Literal::I64(value) => {
+ write!(self.out, "{value}L")?;
+ }
+ crate::Literal::Bool(value) => {
+ write!(self.out, "{value}")?;
+ }
+ crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
+ return Err(Error::Validation);
+ }
+ },
+ crate::Expression::Constant(handle) => {
+ let constant = &module.constants[handle];
+ if constant.name.is_some() {
+ write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.put_const_expression(constant.init, module, mod_info)?;
+ }
+ }
+ crate::Expression::ZeroValue(ty) => {
+ let ty_name = TypeContext {
+ handle: ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name} {{}}")?;
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ let ty_name = TypeContext {
+ handle: ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name}")?;
+ match module.types[ty].inner {
+ crate::TypeInner::Scalar(_)
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. } => {
+ self.put_call_parameters_impl(
+ components.iter().copied(),
+ ctx,
+ put_expression,
+ )?;
+ }
+ crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
+ write!(self.out, " {{")?;
+ for (index, &component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // insert padding initialization, if needed
+ if self.struct_member_pads.contains(&(ty, index as u32)) {
+ write!(self.out, "{{}}, ")?;
+ }
+ put_expression(self, ctx, component)?;
+ }
+ write!(self.out, "}}")?;
+ }
+ _ => return Err(Error::UnsupportedCompose(ty)),
+ }
+ }
+ crate::Expression::Splat { size, value } => {
+ let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
+ crate::TypeInner::Scalar(scalar) => scalar,
+ _ => return Err(Error::Validation),
+ };
+ put_numeric_type(&mut self.out, scalar, &[size])?;
+ write!(self.out, "(")?;
+ put_expression(self, ctx, value)?;
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Emit code for the expression `expr_handle`.
+ ///
+ /// The `is_scoped` argument is true if the surrounding operators have the
+ /// precedence of the comma operator, or lower. So, for example:
+ ///
+ /// - Pass `true` for `is_scoped` when writing function arguments, an
+ /// expression statement, an initializer expression, or anything already
+ /// wrapped in parenthesis.
+ ///
+ /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
+ /// almost anything else.
+ fn put_expression(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ // Add to the set in order to track the stack size.
+ #[cfg(test)]
+ #[allow(trivial_casts)]
+ self.put_expression_stack_pointers
+ .insert(&expr_handle as *const _ as *const ());
+
+ if let Some(name) = self.named_expressions.get(&expr_handle) {
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+
+ let expression = &context.function.expressions[expr_handle];
+ log::trace!("expression {:?} = {:?}", expr_handle, expression);
+ match *expression {
+ crate::Expression::Literal(_)
+ | crate::Expression::Constant(_)
+ | crate::Expression::ZeroValue(_)
+ | crate::Expression::Compose { .. }
+ | crate::Expression::Splat { .. } => {
+ self.put_possibly_const_expression(
+ expr_handle,
+ &context.function.expressions,
+ context.module,
+ context.mod_info,
+ context,
+ |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
+ |writer, context, expr| writer.put_expression(expr, context, true),
+ )?;
+ }
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => {
+ // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
+ // Since `put_bounds_checks` and `put_access_chain` handle an entire
+ // access chain at a time, recursing back through `put_expression` only
+ // for index expressions and the base object, we will never see intermediate
+ // `Access` or `AccessIndex` expressions here.
+ let policy = context.choose_bounds_check_policy(base);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(
+ expr_handle,
+ context,
+ back::Level(0),
+ if is_scoped { "" } else { "(" },
+ )?
+ {
+ write!(self.out, " ? ")?;
+ self.put_access_chain(expr_handle, policy, context)?;
+ write!(self.out, " : DefaultConstructible()")?;
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ } else {
+ self.put_access_chain(expr_handle, policy, context)?;
+ }
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.put_wrapped_expression_for_packed_vec3_access(vector, context, false)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
+ }
+ }
+ crate::Expression::FunctionArgument(index) => {
+ let name_key = match context.origin {
+ FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
+ FunctionOrigin::EntryPoint(ep_index) => {
+ NameKey::EntryPointArgument(ep_index, index)
+ }
+ };
+ let name = &self.names[&name_key];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::LocalVariable(handle) => {
+ let name_key = match context.origin {
+ FunctionOrigin::Handle(fun_handle) => {
+ NameKey::FunctionLocal(fun_handle, handle)
+ }
+ FunctionOrigin::EntryPoint(ep_index) => {
+ NameKey::EntryPointLocal(ep_index, handle)
+ }
+ };
+ let name = &self.names[&name_key];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ let main_op = match gather {
+ Some(_) => "gather",
+ None => "sample",
+ };
+ let comparison_op = match depth_ref {
+ Some(_) => "_compare",
+ None => "",
+ };
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".{main_op}{comparison_op}(")?;
+ self.put_expression(sampler, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(coordinate, context, true)?;
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, context, true)?;
+ }
+ if let Some(dref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.put_expression(dref, context, true)?;
+ }
+
+ self.put_image_sample_level(image, level, context)?;
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.put_const_expression(offset, context.module, context.mod_info)?;
+ }
+
+ match gather {
+ None | Some(crate::SwizzleComponent::X) => {}
+ Some(component) => {
+ let is_cube_map = match *context.resolve_type(image) {
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ ..
+ } => true,
+ _ => false,
+ };
+ // Offset always comes before the gather, except
+ // in cube maps where it's not applicable
+ if offset.is_none() && !is_cube_map {
+ write!(self.out, ", {NAMESPACE}::int2(0)")?;
+ }
+ let letter = back::COMPONENTS[component as usize];
+ write!(self.out, ", {NAMESPACE}::component::{letter}")?;
+ }
+ }
+ write!(self.out, ")")?;
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let address = TexelAddress {
+ coordinate,
+ array_index,
+ sample,
+ level: level.map(LevelOfDetail::Direct),
+ };
+ self.put_image_load(expr_handle, image, address, context)?;
+ }
+ //Note: for all the queries, the signed integers are expected,
+ // so a conversion is needed.
+ crate::Expression::ImageQuery { image, query } => match query {
+ crate::ImageQuery::Size { level } => {
+ self.put_image_size_query(
+ image,
+ level.map(LevelOfDetail::Direct),
+ crate::ScalarKind::Uint,
+ context,
+ )?;
+ }
+ crate::ImageQuery::NumLevels => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_num_mip_levels()")?;
+ }
+ crate::ImageQuery::NumLayers => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_array_size()")?;
+ }
+ crate::ImageQuery::NumSamples => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_num_samples()")?;
+ }
+ },
+ crate::Expression::Unary { op, expr } => {
+ let op_str = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::LogicalNot => "!",
+ crate::UnaryOperator::BitwiseNot => "~",
+ };
+ write!(self.out, "{op_str}(")?;
+ self.put_expression(expr, context, false)?;
+ write!(self.out, ")")?;
+ }
+ crate::Expression::Binary { op, left, right } => {
+ let op_str = crate::back::binary_operation_str(op);
+ let kind = context
+ .resolve_type(left)
+ .scalar_kind()
+ .ok_or(Error::UnsupportedBinaryOp(op))?;
+
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+
+ if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
+ write!(self.out, "{NAMESPACE}::fmod(")?;
+ self.put_expression(left, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(right, context, true)?;
+ write!(self.out, ")")?;
+ } else {
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+
+ // Cast packed vector if necessary
+ // Packed vector - matrix multiplications are not supported in MSL
+ if op == crate::BinaryOperator::Multiply
+ && matches!(
+ context.resolve_type(right),
+ &crate::TypeInner::Matrix { .. }
+ )
+ {
+ self.put_wrapped_expression_for_packed_vec3_access(left, context, false)?;
+ } else {
+ self.put_expression(left, context, false)?;
+ }
+
+ write!(self.out, " {op_str} ")?;
+
+ // See comment above
+ if op == crate::BinaryOperator::Multiply
+ && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
+ {
+ self.put_wrapped_expression_for_packed_vec3_access(right, context, false)?;
+ } else {
+ self.put_expression(right, context, false)?;
+ }
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => match *context.resolve_type(condition) {
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ }) => {
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+ self.put_expression(condition, context, false)?;
+ write!(self.out, " ? ")?;
+ self.put_expression(accept, context, false)?;
+ write!(self.out, " : ")?;
+ self.put_expression(reject, context, false)?;
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ crate::TypeInner::Vector {
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ },
+ ..
+ } => {
+ write!(self.out, "{NAMESPACE}::select(")?;
+ self.put_expression(reject, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(accept, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(condition, context, true)?;
+ write!(self.out, ")")?;
+ }
+ _ => return Err(Error::Validation),
+ },
+ crate::Expression::Derivative { axis, expr, .. } => {
+ use crate::DerivativeAxis as Axis;
+ let op = match axis {
+ Axis::X => "dfdx",
+ Axis::Y => "dfdy",
+ Axis::Width => "fwidth",
+ };
+ write!(self.out, "{NAMESPACE}::{op}")?;
+ self.put_call_parameters(iter::once(expr), context)?;
+ }
+ crate::Expression::Relational { fun, argument } => {
+ let op = match fun {
+ crate::RelationalFunction::Any => "any",
+ crate::RelationalFunction::All => "all",
+ crate::RelationalFunction::IsNan => "isnan",
+ crate::RelationalFunction::IsInf => "isinf",
+ };
+ write!(self.out, "{NAMESPACE}::{op}")?;
+ self.put_call_parameters(iter::once(argument), context)?;
+ }
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let arg_type = context.resolve_type(arg);
+ let scalar_argument = match arg_type {
+ &crate::TypeInner::Scalar(_) => true,
+ _ => false,
+ };
+
+ let fun_name = match fun {
+ // comparison
+ Mf::Abs => "abs",
+ Mf::Min => "min",
+ Mf::Max => "max",
+ Mf::Clamp => "clamp",
+ Mf::Saturate => "saturate",
+ // trigonometry
+ Mf::Cos => "cos",
+ Mf::Cosh => "cosh",
+ Mf::Sin => "sin",
+ Mf::Sinh => "sinh",
+ Mf::Tan => "tan",
+ Mf::Tanh => "tanh",
+ Mf::Acos => "acos",
+ Mf::Asin => "asin",
+ Mf::Atan => "atan",
+ Mf::Atan2 => "atan2",
+ Mf::Asinh => "asinh",
+ Mf::Acosh => "acosh",
+ Mf::Atanh => "atanh",
+ Mf::Radians => "",
+ Mf::Degrees => "",
+ // decomposition
+ Mf::Ceil => "ceil",
+ Mf::Floor => "floor",
+ Mf::Round => "rint",
+ Mf::Fract => "fract",
+ Mf::Trunc => "trunc",
+ Mf::Modf => MODF_FUNCTION,
+ Mf::Frexp => FREXP_FUNCTION,
+ Mf::Ldexp => "ldexp",
+ // exponent
+ Mf::Exp => "exp",
+ Mf::Exp2 => "exp2",
+ Mf::Log => "log",
+ Mf::Log2 => "log2",
+ Mf::Pow => "pow",
+ // geometry
+ Mf::Dot => match *context.resolve_type(arg) {
+ crate::TypeInner::Vector {
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Float,
+ ..
+ },
+ ..
+ } => "dot",
+ crate::TypeInner::Vector { size, .. } => {
+ return self.put_dot_product(arg, arg1.unwrap(), size as usize, context)
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
+ Mf::Cross => "cross",
+ Mf::Distance => "distance",
+ Mf::Length if scalar_argument => "abs",
+ Mf::Length => "length",
+ Mf::Normalize => "normalize",
+ Mf::FaceForward => "faceforward",
+ Mf::Reflect => "reflect",
+ Mf::Refract => "refract",
+ // computational
+ Mf::Sign => match arg_type.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => {
+ return self.put_isign(arg, context);
+ }
+ _ => "sign",
+ },
+ Mf::Fma => "fma",
+ Mf::Mix => "mix",
+ Mf::Step => "step",
+ Mf::SmoothStep => "smoothstep",
+ Mf::Sqrt => "sqrt",
+ Mf::InverseSqrt => "rsqrt",
+ Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
+ Mf::Transpose => "transpose",
+ Mf::Determinant => "determinant",
+ // bits
+ Mf::CountTrailingZeros => "ctz",
+ Mf::CountLeadingZeros => "clz",
+ Mf::CountOneBits => "popcount",
+ Mf::ReverseBits => "reverse_bits",
+ Mf::ExtractBits => "extract_bits",
+ Mf::InsertBits => "insert_bits",
+ Mf::FindLsb => "",
+ Mf::FindMsb => "",
+ // data packing
+ Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
+ Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
+ Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
+ Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
+ Mf::Pack2x16float => "",
+ // data unpacking
+ Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
+ Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
+ Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
+ Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
+ Mf::Unpack2x16float => "",
+ };
+
+ match fun {
+ Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
+ // reverse_bits is listed as requiring MSL 2.1 but that
+ // is a copy/paste error. Looking at previous snapshots
+ // on web.archive.org it's present in MSL 1.2.
+ //
+ // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
+ // also talks about MSL 1.2 adding "New integer
+ // functions to extract, insert, and reverse bits, as
+ // described in Integer Functions."
+ if context.lang_version < (1, 2) {
+ return Err(Error::UnsupportedFunction(fun_name.to_string()));
+ }
+ }
+ _ => {}
+ }
+
+ if fun == Mf::Distance && scalar_argument {
+ write!(self.out, "{NAMESPACE}::abs(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, " - ")?;
+ self.put_expression(arg1.unwrap(), context, false)?;
+ write!(self.out, ")")?;
+ } else if fun == Mf::FindLsb {
+ write!(self.out, "((({NAMESPACE}::ctz(")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ") + 1) % 33) - 1)")?;
+ } else if fun == Mf::FindMsb {
+ let inner = context.resolve_type(arg);
+
+ write!(self.out, "{NAMESPACE}::select(31 - {NAMESPACE}::clz(")?;
+
+ if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ write!(self.out, "{NAMESPACE}::select(")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ", ~")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " < 0)")?;
+ } else {
+ self.put_expression(arg, context, true)?;
+ }
+
+ write!(self.out, "), ")?;
+
+ // or metal will complain that select is ambiguous
+ match *inner {
+ crate::TypeInner::Vector { size, scalar } => {
+ let size = back::vector_size_str(size);
+ if let crate::ScalarKind::Sint = scalar.kind {
+ write!(self.out, "int{size}")?;
+ } else {
+ write!(self.out, "uint{size}")?;
+ }
+ }
+ crate::TypeInner::Scalar(scalar) => {
+ if let crate::ScalarKind::Sint = scalar.kind {
+ write!(self.out, "int")?;
+ } else {
+ write!(self.out, "uint")?;
+ }
+ }
+ _ => (),
+ }
+
+ write!(self.out, "(-1), ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " == 0 || ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " == -1)")?;
+ } else if fun == Mf::Unpack2x16float {
+ write!(self.out, "float2(as_type<half2>(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, "))")?;
+ } else if fun == Mf::Pack2x16float {
+ write!(self.out, "as_type<uint>(half2(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, "))")?;
+ } else if fun == Mf::Radians {
+ write!(self.out, "((")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, ") * 0.017453292519943295474)")?;
+ } else if fun == Mf::Degrees {
+ write!(self.out, "((")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, ") * 57.295779513082322865)")?;
+ } else if fun == Mf::Modf || fun == Mf::Frexp {
+ write!(self.out, "{fun_name}")?;
+ self.put_call_parameters(iter::once(arg), context)?;
+ } else {
+ write!(self.out, "{NAMESPACE}::{fun_name}")?;
+ self.put_call_parameters(
+ iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
+ context,
+ )?;
+ }
+ }
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => match *context.resolve_type(expr) {
+ crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
+ let target_scalar = crate::Scalar {
+ kind,
+ width: convert.unwrap_or(src.width),
+ };
+ let is_bool_cast =
+ kind == crate::ScalarKind::Bool || src.kind == crate::ScalarKind::Bool;
+ let op = match convert {
+ Some(w) if w == src.width || is_bool_cast => "static_cast",
+ Some(8) if kind == crate::ScalarKind::Float => {
+ return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
+ }
+ Some(_) => return Err(Error::Validation),
+ None => "as_type",
+ };
+ write!(self.out, "{op}<")?;
+ match *context.resolve_type(expr) {
+ crate::TypeInner::Vector { size, .. } => {
+ put_numeric_type(&mut self.out, target_scalar, &[size])?
+ }
+ _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
+ };
+ write!(self.out, ">(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ let target_scalar = crate::Scalar {
+ kind,
+ width: convert.unwrap_or(scalar.width),
+ };
+ put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
+ write!(self.out, "(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ }
+ _ => return Err(Error::Validation),
+ },
+ // has to be a named expression
+ crate::Expression::CallResult(_)
+ | crate::Expression::AtomicResult { .. }
+ | crate::Expression::WorkGroupUniformLoadResult { .. }
+ | crate::Expression::RayQueryProceedResult => {
+ unreachable!()
+ }
+ crate::Expression::ArrayLength(expr) => {
+ // Find the global to which the array belongs.
+ let global = match context.function.expressions[expr] {
+ crate::Expression::AccessIndex { base, .. } => {
+ match context.function.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => handle,
+ _ => return Err(Error::Validation),
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => handle,
+ _ => return Err(Error::Validation),
+ };
+
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+ write!(self.out, "1 + ")?;
+ self.put_dynamic_array_max_index(global, context)?;
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ crate::Expression::RayQueryGetIntersection { query, committed } => {
+ if context.lang_version < (2, 4) {
+ return Err(Error::UnsupportedRayTracing);
+ }
+
+ if !committed {
+ unimplemented!()
+ }
+ let ty = context.module.special_types.ray_intersection.unwrap();
+ let type_name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
+ self.put_expression(query, context, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
+ let fields = [
+ "distance",
+ "user_instance_id", // req Metal 2.4
+ "instance_id",
+ "", // SBT offset
+ "geometry_id",
+ "primitive_id",
+ "triangle_barycentric_coord",
+ "triangle_front_facing",
+ "", // padding
+ "object_to_world_transform", // req Metal 2.4
+ "world_to_object_transform", // req Metal 2.4
+ ];
+ for field in fields {
+ write!(self.out, ", ")?;
+ if field.is_empty() {
+ write!(self.out, "{{}}")?;
+ } else {
+ self.put_expression(query, context, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
+ }
+ }
+ write!(self.out, "}}")?;
+ }
+ }
+ Ok(())
+ }
+
+ /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3
+ fn put_wrapped_expression_for_packed_vec3_access(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
+ write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
+ self.put_expression(expr_handle, context, is_scoped)?;
+ write!(self.out, ")")?;
+ } else {
+ self.put_expression(expr_handle, context, is_scoped)?;
+ }
+ Ok(())
+ }
+
+ /// Write a `GuardedIndex` as a Metal expression.
+ fn put_index(
+ &mut self,
+ index: index::GuardedIndex,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ match index {
+ index::GuardedIndex::Expression(expr) => {
+ self.put_expression(expr, context, is_scoped)?
+ }
+ index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
+ }
+ Ok(())
+ }
+
+ /// Emit an index bounds check condition for `chain`, if required.
+ ///
+ /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
+ /// operating either on a pointer to a value, or on a value directly. If we cannot
+ /// statically determine that all indexing operations in `chain` are within
+ /// bounds, then write a conditional expression to check them dynamically,
+ /// and return true. All accesses in the chain are checked by the generated
+ /// expression.
+ ///
+ /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
+ ///
+ /// The text written is of the form:
+ ///
+ /// ```ignore
+ /// {level}{prefix}uint(i) < 4 && uint(j) < 10
+ /// ```
+ ///
+ /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
+ /// statements, presumably these arguments start an indented `if` statement; for
+ /// [`Load`] expressions, the caller is probably building up a ternary `?:`
+ /// expression. In either case, what is written is not a complete syntactic structure
+ /// in its own right, and the caller will have to finish it off if we return `true`.
+ ///
+ /// If no expression is written, return false.
+ ///
+ /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
+ /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
+ /// [`Store`]: crate::Statement::Store
+ /// [`Load`]: crate::Expression::Load
+ #[allow(unused_variables)]
+ fn put_bounds_checks(
+ &mut self,
+ mut chain: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ level: back::Level,
+ prefix: &'static str,
+ ) -> Result<bool, Error> {
+ let mut check_written = false;
+
+ // Iterate over the access chain, handling each expression.
+ loop {
+ // Produce a `GuardedIndex`, so we can shared code between the
+ // `Access` and `AccessIndex` cases.
+ let (base, guarded_index) = match context.function.expressions[chain] {
+ crate::Expression::Access { base, index } => {
+ (base, Some(index::GuardedIndex::Expression(index)))
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ // Don't try to check indices into structs. Validation already took
+ // care of them, and index::needs_guard doesn't handle that case.
+ let mut base_inner = context.resolve_type(base);
+ if let crate::TypeInner::Pointer { base, .. } = *base_inner {
+ base_inner = &context.module.types[base].inner;
+ }
+ match *base_inner {
+ crate::TypeInner::Struct { .. } => (base, None),
+ _ => (base, Some(index::GuardedIndex::Known(index))),
+ }
+ }
+ _ => break,
+ };
+
+ if let Some(index) = guarded_index {
+ if let Some(length) = context.access_needs_check(base, index) {
+ if check_written {
+ write!(self.out, " && ")?;
+ } else {
+ write!(self.out, "{level}{prefix}")?;
+ check_written = true;
+ }
+
+ // Check that the index falls within bounds. Do this with a single
+ // comparison, by casting the index to `uint` first, so that negative
+ // indices become large positive values.
+ write!(self.out, "uint(")?;
+ self.put_index(index, context, true)?;
+ self.out.write_str(") < ")?;
+ match length {
+ index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
+ index::IndexableLength::Dynamic => {
+ let global = context
+ .function
+ .originating_global(base)
+ .ok_or(Error::Validation)?;
+ write!(self.out, "1 + ")?;
+ self.put_dynamic_array_max_index(global, context)?
+ }
+ }
+ }
+ }
+
+ chain = base
+ }
+
+ Ok(check_written)
+ }
+
+ /// Write the access chain `chain`.
+ ///
+ /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
+ /// operating either on a pointer to a value, or on a value directly.
+ ///
+ /// Generate bounds checks code only if `policy` is [`Restrict`]. The
+ /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
+ /// that must be handled in the caller.
+ ///
+ /// Handle the entire chain, recursing back into `put_expression` only for index
+ /// expressions and the base expression that originates the pointer or composite value
+ /// being accessed. This allows `put_expression` to assume that any `Access` or
+ /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
+ /// `ReadZeroSkipWrite` checks.
+ ///
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
+ /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
+ fn put_access_chain(
+ &mut self,
+ chain: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.function.expressions[chain] {
+ crate::Expression::Access { base, index } => {
+ let mut base_ty = context.resolve_type(base);
+
+ // Look through any pointers to see what we're really indexing.
+ if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
+ base_ty = &context.module.types[base].inner;
+ }
+
+ self.put_subscripted_access_chain(
+ base,
+ base_ty,
+ index::GuardedIndex::Expression(index),
+ policy,
+ context,
+ )?;
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ let base_resolution = &context.info[base].ty;
+ let mut base_ty = base_resolution.inner_with(&context.module.types);
+ let mut base_ty_handle = base_resolution.handle();
+
+ // Look through any pointers to see what we're really indexing.
+ if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
+ base_ty = &context.module.types[base].inner;
+ base_ty_handle = Some(base);
+ }
+
+ // Handle structs and anything else that can use `.x` syntax here, so
+ // `put_subscripted_access_chain` won't have to handle the absurd case of
+ // indexing a struct with an expression.
+ match *base_ty {
+ crate::TypeInner::Struct { .. } => {
+ let base_ty = base_ty_handle.unwrap();
+ self.put_access_chain(base, policy, context)?;
+ let name = &self.names[&NameKey::StructMember(base_ty, index)];
+ write!(self.out, ".{name}")?;
+ }
+ crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
+ self.put_access_chain(base, policy, context)?;
+ // Prior to Metal v2.1 component access for packed vectors wasn't available
+ // however array indexing is
+ if context.get_packed_vec_kind(base).is_some() {
+ write!(self.out, "[{index}]")?;
+ } else {
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
+ }
+ }
+ _ => {
+ self.put_subscripted_access_chain(
+ base,
+ base_ty,
+ index::GuardedIndex::Known(index),
+ policy,
+ context,
+ )?;
+ }
+ }
+ }
+ _ => self.put_expression(chain, context, false)?,
+ }
+
+ Ok(())
+ }
+
+ /// Write a `[]`-style access of `base` by `index`.
+ ///
+ /// If `policy` is [`Restrict`], then generate code as needed to force all index
+ /// values within bounds.
+ ///
+ /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
+ /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
+ /// removed. Our callers often already have this handy.
+ ///
+ /// This only emits `[]` expressions; it doesn't handle struct member accesses or
+ /// referencing vector components by name.
+ ///
+ /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
+ /// [`Array`]: crate::TypeInner::Array
+ /// [`Vector`]: crate::TypeInner::Vector
+ /// [`Pointer`]: crate::TypeInner::Pointer
+ fn put_subscripted_access_chain(
+ &mut self,
+ base: Handle<crate::Expression>,
+ base_ty: &crate::TypeInner,
+ index: index::GuardedIndex,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let accessing_wrapped_array = match *base_ty {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(_),
+ ..
+ } => true,
+ _ => false,
+ };
+
+ self.put_access_chain(base, policy, context)?;
+ if accessing_wrapped_array {
+ write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
+ }
+ write!(self.out, "[")?;
+
+ // Decide whether this index needs to be clamped to fall within range.
+ let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
+ context.access_needs_check(base, index)
+ } else {
+ None
+ };
+ if let Some(limit) = restriction_needed {
+ write!(self.out, "{NAMESPACE}::min(unsigned(")?;
+ self.put_index(index, context, true)?;
+ write!(self.out, "), ")?;
+ match limit {
+ index::IndexableLength::Known(limit) => {
+ write!(self.out, "{}u", limit - 1)?;
+ }
+ index::IndexableLength::Dynamic => {
+ let global = context
+ .function
+ .originating_global(base)
+ .ok_or(Error::Validation)?;
+ self.put_dynamic_array_max_index(global, context)?;
+ }
+ }
+ write!(self.out, ")")?;
+ } else {
+ self.put_index(index, context, true)?;
+ }
+
+ write!(self.out, "]")?;
+
+ Ok(())
+ }
+
+ fn put_load(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ // Since access chains never cross between address spaces, we can just
+ // check the index bounds check policy once at the top.
+ let policy = context.choose_bounds_check_policy(pointer);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(
+ pointer,
+ context,
+ back::Level(0),
+ if is_scoped { "" } else { "(" },
+ )?
+ {
+ write!(self.out, " ? ")?;
+ self.put_unchecked_load(pointer, policy, context)?;
+ write!(self.out, " : DefaultConstructible()")?;
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ } else {
+ self.put_unchecked_load(pointer, policy, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_load(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let is_atomic_pointer = context
+ .resolve_type(pointer)
+ .is_atomic_pointer(&context.module.types);
+
+ if is_atomic_pointer {
+ write!(
+ self.out,
+ "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, context)?;
+ write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
+ } else {
+ // We don't do any dereferencing with `*` here as pointer arguments to functions
+ // are done by `&` references and not `*` pointers. These do not need to be
+ // dereferenced.
+ self.put_access_chain(pointer, policy, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_return_value(
+ &mut self,
+ level: back::Level,
+ expr_handle: Handle<crate::Expression>,
+ result_struct: Option<&str>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match result_struct {
+ Some(struct_name) => {
+ let mut has_point_size = false;
+ let result_ty = context.function.result.as_ref().unwrap().ty;
+ match context.module.types[result_ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let tmp = "_tmp";
+ write!(self.out, "{level}const auto {tmp} = ")?;
+ self.put_expression(expr_handle, context, true)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{level}return {struct_name} {{")?;
+
+ let mut is_first = true;
+
+ for (index, member) in members.iter().enumerate() {
+ if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
+ member.binding
+ {
+ has_point_size = true;
+ if !context.pipeline_options.allow_and_force_point_size {
+ continue;
+ }
+ }
+
+ let comma = if is_first { "" } else { "," };
+ is_first = false;
+ let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
+ // HACK: we are forcefully deduplicating the expression here
+ // to convert from a wrapped struct to a raw array, e.g.
+ // `float gl_ClipDistance1 [[clip_distance]] [1];`.
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(size),
+ ..
+ } = context.module.types[member.ty].inner
+ {
+ write!(self.out, "{comma} {{")?;
+ for j in 0..size.get() {
+ if j != 0 {
+ write!(self.out, ",")?;
+ }
+ write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
+ }
+ write!(self.out, "}}")?;
+ } else {
+ write!(self.out, "{comma} {tmp}.{name}")?;
+ }
+ }
+ }
+ _ => {
+ write!(self.out, "{level}return {struct_name} {{ ")?;
+ self.put_expression(expr_handle, context, true)?;
+ }
+ }
+
+ if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
+ let stage = context.module.entry_points[ep_index as usize].stage;
+ if context.pipeline_options.allow_and_force_point_size
+ && stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // point size was injected and comes last
+ write!(self.out, ", 1.0")?;
+ }
+ }
+ write!(self.out, " }}")?;
+ }
+ None => {
+ write!(self.out, "{level}return ")?;
+ self.put_expression(expr_handle, context, true)?;
+ }
+ }
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// This function overwrites the contents of `self.need_bake_expressions`
+ fn update_expressions_to_bake(
+ &mut self,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ context: &ExpressionContext,
+ ) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+
+ for (expr_handle, expr) in func.expressions.iter() {
+ // Expressions whose reference count is above the
+ // threshold should always be stored in temporaries.
+ let expr_info = &info[expr_handle];
+ let min_ref_count = func.expressions[expr_handle].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(expr_handle);
+ } else {
+ match expr_info.ty {
+ // force ray desc to be baked: it's used multiple times internally
+ TypeResolution::Handle(h)
+ if Some(h) == context.module.special_types.ray_desc =>
+ {
+ self.need_bake_expressions.insert(expr_handle);
+ }
+ _ => {}
+ }
+ }
+
+ if let Expression::Math { fun, arg, arg1, .. } = *expr {
+ match fun {
+ crate::MathFunction::Dot => {
+ // WGSL's `dot` function works on any `vecN` type, but Metal's only
+ // works on floating-point vectors, so we emit inline code for
+ // integer vector `dot` calls. But that code uses each argument `N`
+ // times, once for each component (see `put_dot_product`), so to
+ // avoid duplicated evaluation, we must bake integer operands.
+
+ // check what kind of product this is depending
+ // on the resolve type of the Dot function itself
+ let inner = context.resolve_type(expr_handle);
+ if let crate::TypeInner::Scalar(scalar) = *inner {
+ match scalar.kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ }
+ _ => {}
+ }
+ }
+ }
+ crate::MathFunction::FindMsb => {
+ self.need_bake_expressions.insert(arg);
+ }
+ crate::MathFunction::Sign => {
+ // WGSL's `sign` function works also on signed ints, but Metal's only
+ // works on floating points, so we emit inline code for integer `sign`
+ // calls. But that code uses each argument 2 times (see `put_isign`),
+ // so to avoid duplicated evaluation, we must bake the argument.
+ let inner = context.resolve_type(expr_handle);
+ if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
+ self.need_bake_expressions.insert(arg);
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+
+ fn start_baking_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ name: &str,
+ ) -> BackendResult {
+ match context.info[handle].ty {
+ TypeResolution::Handle(ty_handle) => {
+ let ty_name = TypeContext {
+ handle: ty_handle,
+ gctx: context.module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name}")?;
+ }
+ TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
+ put_numeric_type(&mut self.out, scalar, &[])?;
+ }
+ TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
+ put_numeric_type(&mut self.out, scalar, &[size])?;
+ }
+ TypeResolution::Value(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ }) => {
+ put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
+ }
+ TypeResolution::Value(ref other) => {
+ log::warn!("Type {:?} isn't a known local", other); //TEMP!
+ return Err(Error::FeatureNotImplemented("weird local type".to_string()));
+ }
+ }
+
+ //TODO: figure out the naming scheme that wouldn't collide with user names.
+ write!(self.out, " {name} = ")?;
+
+ Ok(())
+ }
+
+ /// Cache a clamped level of detail value, if necessary.
+ ///
+ /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
+ /// properly clamped level of detail value both in the access itself, and
+ /// for fetching the size of the requested MIP level, needed to clamp the
+ /// coordinates. To avoid recomputing this clamped level of detail, we cache
+ /// it in a temporary variable, as part of the [`Emit`] statement covering
+ /// the [`ImageLoad`] expression.
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
+ /// [`Emit`]: crate::Statement::Emit
+ fn put_cache_restricted_level(
+ &mut self,
+ load: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ mip_level: Option<Handle<crate::Expression>>,
+ indent: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ // Does this image access actually require (or even permit) a
+ // level-of-detail, and does the policy require us to restrict it?
+ let level_of_detail = match mip_level {
+ Some(level) => level,
+ None => return Ok(()),
+ };
+
+ if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
+ || !context.expression.image_needs_lod(image)
+ {
+ return Ok(());
+ }
+
+ write!(
+ self.out,
+ "{}uint {}{} = ",
+ indent,
+ CLAMPED_LOD_LOAD_PREFIX,
+ load.index(),
+ )?;
+ self.put_restricted_scalar_image_index(
+ image,
+ level_of_detail,
+ "get_num_mip_levels",
+ &context.expression,
+ )?;
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ fn put_block(
+ &mut self,
+ level: back::Level,
+ statements: &[crate::Statement],
+ context: &StatementContext,
+ ) -> BackendResult {
+ // Add to the set in order to track the stack size.
+ #[cfg(test)]
+ #[allow(trivial_casts)]
+ self.put_block_stack_pointers
+ .insert(&level as *const _ as *const ());
+
+ for statement in statements {
+ log::trace!("statement[{}] {:?}", level.0, statement);
+ match *statement {
+ crate::Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ // `ImageLoad` expressions covered by the `Restrict` bounds check policy
+ // may need to cache a clamped version of their level-of-detail argument.
+ if let crate::Expression::ImageLoad {
+ image,
+ level: mip_level,
+ ..
+ } = context.expression.function.expressions[handle]
+ {
+ self.put_cache_restricted_level(
+ handle, image, mip_level, level, context,
+ )?;
+ }
+
+ let ptr_class = context.expression.resolve_type(handle).pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ None // don't bake pointer expressions (just yet)
+ } else if let Some(name) =
+ context.expression.function.named_expressions.get(&handle)
+ {
+ // The `crate::Function::named_expressions` table holds
+ // expressions that should be saved in temporaries once they
+ // are `Emit`ted. We only add them to `self.named_expressions`
+ // when we reach the `Emit` that covers them, so that we don't
+ // try to use their names before we've actually initialized
+ // the temporary that holds them.
+ //
+ // Don't assume the names in `named_expressions` are unique,
+ // or even valid. Use the `Namer`.
+ Some(self.namer.call(name))
+ } else {
+ // If this expression is an index that we're going to first compare
+ // against a limit, and then actually use as an index, then we may
+ // want to cache it in a temporary, to avoid evaluating it twice.
+ let bake =
+ if context.expression.guarded_indices.contains(handle.index()) {
+ true
+ } else {
+ self.need_bake_expressions.contains(&handle)
+ };
+
+ if bake {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else {
+ None
+ }
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.start_baking_expression(handle, &context.expression, &name)?;
+ self.put_expression(handle, &context.expression, true)?;
+ self.named_expressions.insert(handle, name);
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+ crate::Statement::Block(ref block) => {
+ if !block.is_empty() {
+ writeln!(self.out, "{level}{{")?;
+ self.put_block(level.next(), block, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}if (")?;
+ self.put_expression(condition, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ self.put_block(level.next(), accept, context)?;
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+ self.put_block(level.next(), reject, context)?;
+ }
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ write!(self.out, "{level}switch(")?;
+ self.put_expression(selector, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ let lcase = level.next();
+ for case in cases.iter() {
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ write!(self.out, "{lcase}case {value}:")?;
+ }
+ crate::SwitchValue::U32(value) => {
+ write!(self.out, "{lcase}case {value}u:")?;
+ }
+ crate::SwitchValue::Default => {
+ write!(self.out, "{lcase}default:")?;
+ }
+ }
+
+ let write_block_braces = !(case.fall_through && case.body.is_empty());
+ if write_block_braces {
+ writeln!(self.out, " {{")?;
+ } else {
+ writeln!(self.out)?;
+ }
+
+ self.put_block(lcase.next(), &case.body, context)?;
+ if !case.fall_through
+ && case.body.last().map_or(true, |s| !s.is_terminator())
+ {
+ writeln!(self.out, "{}break;", lcase.next())?;
+ }
+
+ if write_block_braces {
+ writeln!(self.out, "{lcase}}}")?;
+ }
+ }
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{level}bool {gate_name} = true;")?;
+ writeln!(self.out, "{level}while(true) {{")?;
+ let lif = level.next();
+ let lcontinuing = lif.next();
+ writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
+ self.put_block(lcontinuing, continuing, context)?;
+ if let Some(condition) = break_if {
+ write!(self.out, "{lcontinuing}if (")?;
+ self.put_expression(condition, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", lcontinuing.next())?;
+ writeln!(self.out, "{lcontinuing}}}")?;
+ }
+ writeln!(self.out, "{lif}}}")?;
+ writeln!(self.out, "{lif}{gate_name} = false;")?;
+ } else {
+ writeln!(self.out, "{level}while(true) {{")?;
+ }
+ self.put_block(level.next(), body, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Break => {
+ writeln!(self.out, "{level}break;")?;
+ }
+ crate::Statement::Continue => {
+ writeln!(self.out, "{level}continue;")?;
+ }
+ crate::Statement::Return {
+ value: Some(expr_handle),
+ } => {
+ self.put_return_value(
+ level,
+ expr_handle,
+ context.result_struct,
+ &context.expression,
+ )?;
+ }
+ crate::Statement::Return { value: None } => {
+ writeln!(self.out, "{level}return;")?;
+ }
+ crate::Statement::Kill => {
+ writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
+ }
+ crate::Statement::Barrier(flags) => {
+ self.write_barrier(flags, level)?;
+ }
+ crate::Statement::Store { pointer, value } => {
+ self.put_store(pointer, value, level, context)?
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ let address = TexelAddress {
+ coordinate,
+ array_index,
+ sample: None,
+ level: None,
+ };
+ self.put_image_store(level, image, &address, value, context)?
+ }
+ crate::Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ self.start_baking_expression(expr, &context.expression, &name)?;
+ self.named_expressions.insert(expr, name);
+ }
+ let fun_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{fun_name}(")?;
+ // first, write down the actual arguments
+ for (i, &handle) in arguments.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.put_expression(handle, &context.expression, true)?;
+ }
+ // follow-up with any global resources used
+ let mut separate = !arguments.is_empty();
+ let fun_info = &context.expression.mod_info[function];
+ let mut supports_array_length = false;
+ for (handle, var) in context.expression.module.global_variables.iter() {
+ if fun_info[handle].is_empty() {
+ continue;
+ }
+ if var.space.needs_pass_through() {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ if separate {
+ write!(self.out, ", ")?;
+ } else {
+ separate = true;
+ }
+ write!(self.out, "{name}")?;
+ }
+ supports_array_length |=
+ needs_array_length(var.ty, &context.expression.module.types);
+ }
+ if supports_array_length {
+ if separate {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "_buffer_sizes")?;
+ }
+
+ // done
+ writeln!(self.out, ");")?;
+ }
+ crate::Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_baking_expression(result, &context.expression, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+ match *fun {
+ crate::AtomicFunction::Add => {
+ self.put_atomic_fetch(pointer, "add", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Subtract => {
+ self.put_atomic_fetch(pointer, "sub", value, &context.expression)?;
+ }
+ crate::AtomicFunction::And => {
+ self.put_atomic_fetch(pointer, "and", value, &context.expression)?;
+ }
+ crate::AtomicFunction::InclusiveOr => {
+ self.put_atomic_fetch(pointer, "or", value, &context.expression)?;
+ }
+ crate::AtomicFunction::ExclusiveOr => {
+ self.put_atomic_fetch(pointer, "xor", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Min => {
+ self.put_atomic_fetch(pointer, "min", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Max => {
+ self.put_atomic_fetch(pointer, "max", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Exchange { compare: None } => {
+ self.put_atomic_operation(
+ pointer,
+ "exchange",
+ "",
+ value,
+ &context.expression,
+ )?;
+ }
+ crate::AtomicFunction::Exchange { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "atomic CompareExchange".to_string(),
+ ));
+ }
+ }
+ // done
+ writeln!(self.out, ";")?;
+ }
+ crate::Statement::WorkGroupUniformLoad { pointer, result } => {
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+
+ write!(self.out, "{level}")?;
+ let name = self.namer.call("");
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.put_load(pointer, &context.expression, true)?;
+ self.named_expressions.insert(result, name);
+
+ writeln!(self.out, ";")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ }
+ crate::Statement::RayQuery { query, ref fun } => {
+ if context.expression.lang_version < (2, 4) {
+ return Err(Error::UnsupportedRayTracing);
+ }
+
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ //TODO: how to deal with winding?
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
+ {
+ let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
+ let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
+ writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
+ }
+ {
+ let f_opaque = back::RayFlag::OPAQUE.bits();
+ let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
+ writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
+ }
+ {
+ let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ writeln!(self.out, ".flags & {flag}) != 0);")?;
+ }
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".origin, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".dir, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".tmin, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".tmax), ")?;
+ self.put_expression(acceleration_structure, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".cull_mask);")?;
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ write!(self.out, "{level}")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.named_expressions.insert(result, name);
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
+ //TODO: actually proceed?
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
+ }
+ crate::RayQueryFunction::Terminate => {
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?;
+ }
+ }
+ }
+ }
+ }
+
+ // un-emit expressions
+ //TODO: take care of loop/continuing?
+ for statement in statements {
+ if let crate::Statement::Emit(ref range) = *statement {
+ for handle in range.clone() {
+ self.named_expressions.remove(&handle);
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn put_store(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ level: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ let policy = context.expression.choose_bounds_check_policy(pointer);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
+ {
+ writeln!(self.out, ") {{")?;
+ self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
+ writeln!(self.out, "{level}}}")?;
+ } else {
+ self.put_unchecked_store(pointer, value, policy, level, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_store(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ level: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ let is_atomic_pointer = context
+ .expression
+ .resolve_type(pointer)
+ .is_atomic_pointer(&context.expression.module.types);
+
+ if is_atomic_pointer {
+ write!(
+ self.out,
+ "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, ", ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
+ } else {
+ write!(self.out, "{level}")?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, " = ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ pub fn write(
+ &mut self,
+ module: &crate::Module,
+ info: &valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+ ) -> Result<TranslationInfo, Error> {
+ self.names.clear();
+ self.namer.reset(
+ module,
+ super::keywords::RESERVED,
+ &[],
+ &[],
+ &[CLAMPED_LOD_LOAD_PREFIX],
+ &mut self.names,
+ );
+ self.struct_member_pads.clear();
+
+ writeln!(
+ self.out,
+ "// language: metal{}.{}",
+ options.lang_version.0, options.lang_version.1
+ )?;
+ writeln!(self.out, "#include <metal_stdlib>")?;
+ writeln!(self.out, "#include <simd/simd.h>")?;
+ writeln!(self.out)?;
+ // Work around Metal bug where `uint` is not available by default
+ writeln!(self.out, "using {NAMESPACE}::uint;")?;
+
+ let mut uses_ray_query = false;
+ for (_, ty) in module.types.iter() {
+ match ty.inner {
+ crate::TypeInner::AccelerationStructure => {
+ if options.lang_version < (2, 4) {
+ return Err(Error::UnsupportedRayTracing);
+ }
+ }
+ crate::TypeInner::RayQuery => {
+ if options.lang_version < (2, 4) {
+ return Err(Error::UnsupportedRayTracing);
+ }
+ uses_ray_query = true;
+ }
+ _ => (),
+ }
+ }
+
+ if module.special_types.ray_desc.is_some()
+ || module.special_types.ray_intersection.is_some()
+ {
+ if options.lang_version < (2, 4) {
+ return Err(Error::UnsupportedRayTracing);
+ }
+ }
+
+ if uses_ray_query {
+ self.put_ray_query_type()?;
+ }
+
+ if options
+ .bounds_check_policies
+ .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
+ {
+ self.put_default_constructible()?;
+ }
+ writeln!(self.out)?;
+
+ {
+ let mut indices = vec![];
+ for (handle, var) in module.global_variables.iter() {
+ if needs_array_length(var.ty, &module.types) {
+ let idx = handle.index();
+ indices.push(idx);
+ }
+ }
+
+ if !indices.is_empty() {
+ writeln!(self.out, "struct _mslBufferSizes {{")?;
+
+ for idx in indices {
+ writeln!(self.out, "{}uint size{};", back::INDENT, idx)?;
+ }
+
+ writeln!(self.out, "}};")?;
+ writeln!(self.out)?;
+ }
+ };
+
+ self.write_type_defs(module)?;
+ self.write_global_constants(module, info)?;
+ self.write_functions(module, info, options, pipeline_options)
+ }
+
+ /// Write the definition for the `DefaultConstructible` class.
+ ///
+ /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
+ /// produce 'zero' values for any type, including structs, arrays, and so
+ /// on. We could do this by emitting default constructor applications, but
+ /// that would entail printing the name of the type, which is more trouble
+ /// than you'd think. Instead, we just construct this magic C++14 class that
+ /// can be converted to any type that can be default constructed, using
+ /// template parameter inference to detect which type is needed, so we don't
+ /// have to figure out the name.
+ ///
+ /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
+ fn put_default_constructible(&mut self) -> BackendResult {
+ let tab = back::INDENT;
+ writeln!(self.out, "struct DefaultConstructible {{")?;
+ writeln!(self.out, "{tab}template<typename T>")?;
+ writeln!(self.out, "{tab}operator T() && {{")?;
+ writeln!(self.out, "{tab}{tab}return T {{}};")?;
+ writeln!(self.out, "{tab}}}")?;
+ writeln!(self.out, "}};")?;
+ Ok(())
+ }
+
+ fn put_ray_query_type(&mut self) -> BackendResult {
+ let tab = back::INDENT;
+ writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
+ let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
+ writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
+ writeln!(
+ self.out,
+ "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
+ )?;
+ writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
+ writeln!(self.out, "}};")?;
+ writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
+ let v_triangle = back::RayIntersectionType::Triangle as u32;
+ let v_bbox = back::RayIntersectionType::BoundingBox as u32;
+ writeln!(
+ self.out,
+ "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
+ )?;
+ writeln!(
+ self.out,
+ "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
+ )?;
+ writeln!(self.out, "}}")?;
+ Ok(())
+ }
+
+ fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
+ for (handle, ty) in module.types.iter() {
+ if !ty.needs_alias() {
+ continue;
+ }
+ let name = &self.names[&NameKey::Type(handle)];
+ match ty.inner {
+ // Naga IR can pass around arrays by value, but Metal, following
+ // C++, performs an array-to-pointer conversion (C++ [conv.array])
+ // on expressions of array type, so assigning the array by value
+ // isn't possible. However, Metal *does* assign structs by
+ // value. So in our Metal output, we wrap all array types in
+ // synthetic struct types:
+ //
+ // struct type1 {
+ // float inner[10]
+ // };
+ //
+ // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
+ // any expression that actually wants access to the array.
+ crate::TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ let base_name = TypeContext {
+ handle: base,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+
+ match size {
+ crate::ArraySize::Constant(size) => {
+ writeln!(self.out, "struct {name} {{")?;
+ writeln!(
+ self.out,
+ "{}{} {}[{}];",
+ back::INDENT,
+ base_name,
+ WRAPPED_ARRAY_FIELD,
+ size
+ )?;
+ writeln!(self.out, "}};")?;
+ }
+ crate::ArraySize::Dynamic => {
+ writeln!(self.out, "typedef {base_name} {name}[1];")?;
+ }
+ }
+ }
+ crate::TypeInner::Struct {
+ ref members, span, ..
+ } => {
+ writeln!(self.out, "struct {name} {{")?;
+ let mut last_offset = 0;
+ for (index, member) in members.iter().enumerate() {
+ if member.offset > last_offset {
+ self.struct_member_pads.insert((handle, index as u32));
+ let pad = member.offset - last_offset;
+ writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
+ }
+ let ty_inner = &module.types[member.ty].inner;
+ last_offset = member.offset + ty_inner.size(module.to_ctx());
+
+ let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
+
+ // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector
+ match should_pack_struct_member(members, span, index, module) {
+ Some(scalar) => {
+ writeln!(
+ self.out,
+ "{}{}::packed_{}3 {};",
+ back::INDENT,
+ NAMESPACE,
+ scalar.to_msl_name(),
+ member_name
+ )?;
+ }
+ None => {
+ let base_name = TypeContext {
+ handle: member.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ writeln!(
+ self.out,
+ "{}{} {};",
+ back::INDENT,
+ base_name,
+ member_name
+ )?;
+
+ // for 3-component vectors, add one component
+ if let crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ scalar,
+ } = *ty_inner
+ {
+ last_offset += scalar.width as u32;
+ }
+ }
+ }
+ }
+ writeln!(self.out, "}};")?;
+ }
+ _ => {
+ let ty_name = TypeContext {
+ handle,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: true,
+ };
+ writeln!(self.out, "typedef {ty_name} {name};")?;
+ }
+ }
+ }
+
+ // Write functions to create special types.
+ for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
+ match type_key {
+ &crate::PredeclaredType::ModfResult { size, width }
+ | &crate::PredeclaredType::FrexpResult { size, width } => {
+ let arg_type_name_owner;
+ let arg_type_name = if let Some(size) = size {
+ arg_type_name_owner = format!(
+ "{NAMESPACE}::{}{}",
+ if width == 8 { "double" } else { "float" },
+ size as u8
+ );
+ &arg_type_name_owner
+ } else if width == 8 {
+ "double"
+ } else {
+ "float"
+ };
+
+ let other_type_name_owner;
+ let (defined_func_name, called_func_name, other_type_name) =
+ if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
+ (MODF_FUNCTION, "modf", arg_type_name)
+ } else {
+ let other_type_name = if let Some(size) = size {
+ other_type_name_owner = format!("int{}", size as u8);
+ &other_type_name_owner
+ } else {
+ "int"
+ };
+ (FREXP_FUNCTION, "frexp", other_type_name)
+ };
+
+ let struct_name = &self.names[&NameKey::Type(*struct_ty)];
+
+ writeln!(self.out)?;
+ writeln!(
+ self.out,
+ "{} {defined_func_name}({arg_type_name} arg) {{
+ {other_type_name} other;
+ {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
+ return {}{{ fract, other }};
+}}",
+ struct_name, struct_name
+ )?;
+ }
+ &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Writes all named constants
+ fn write_global_constants(
+ &mut self,
+ module: &crate::Module,
+ mod_info: &valid::ModuleInfo,
+ ) -> BackendResult {
+ let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
+
+ for (handle, constant) in constants {
+ let ty_name = TypeContext {
+ handle: constant.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let name = &self.names[&NameKey::Constant(handle)];
+ write!(self.out, "constant {ty_name} {name} = ")?;
+ self.put_const_expression(constant.init, module, mod_info)?;
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ fn put_inline_sampler_properties(
+ &mut self,
+ level: back::Level,
+ sampler: &sm::InlineSampler,
+ ) -> BackendResult {
+ for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
+ writeln!(
+ self.out,
+ "{}{}::{}_address::{},",
+ level,
+ NAMESPACE,
+ letter,
+ address.as_str(),
+ )?;
+ }
+ writeln!(
+ self.out,
+ "{}{}::mag_filter::{},",
+ level,
+ NAMESPACE,
+ sampler.mag_filter.as_str(),
+ )?;
+ writeln!(
+ self.out,
+ "{}{}::min_filter::{},",
+ level,
+ NAMESPACE,
+ sampler.min_filter.as_str(),
+ )?;
+ if let Some(filter) = sampler.mip_filter {
+ writeln!(
+ self.out,
+ "{}{}::mip_filter::{},",
+ level,
+ NAMESPACE,
+ filter.as_str(),
+ )?;
+ }
+ // avoid setting it on platforms that don't support it
+ if sampler.border_color != sm::BorderColor::TransparentBlack {
+ writeln!(
+ self.out,
+ "{}{}::border_color::{},",
+ level,
+ NAMESPACE,
+ sampler.border_color.as_str(),
+ )?;
+ }
+ //TODO: I'm not able to feed this in a way that MSL likes:
+ //>error: use of undeclared identifier 'lod_clamp'
+ //>error: no member named 'max_anisotropy' in namespace 'metal'
+ if false {
+ if let Some(ref lod) = sampler.lod_clamp {
+ writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
+ }
+ if let Some(aniso) = sampler.max_anisotropy {
+ writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
+ }
+ }
+ if sampler.compare_func != sm::CompareFunc::Never {
+ writeln!(
+ self.out,
+ "{}{}::compare_func::{},",
+ level,
+ NAMESPACE,
+ sampler.compare_func.as_str(),
+ )?;
+ }
+ writeln!(
+ self.out,
+ "{}{}::coord::{}",
+ level,
+ NAMESPACE,
+ sampler.coord.as_str()
+ )?;
+ Ok(())
+ }
+
+ // Returns the array of mapped entry point names.
+ fn write_functions(
+ &mut self,
+ module: &crate::Module,
+ mod_info: &valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+ ) -> Result<TranslationInfo, Error> {
+ let mut pass_through_globals = Vec::new();
+ for (fun_handle, fun) in module.functions.iter() {
+ log::trace!(
+ "function {:?}, handle {:?}",
+ fun.name.as_deref().unwrap_or("(anonymous)"),
+ fun_handle
+ );
+
+ let fun_info = &mod_info[fun_handle];
+ pass_through_globals.clear();
+ let mut supports_array_length = false;
+ for (handle, var) in module.global_variables.iter() {
+ if !fun_info[handle].is_empty() {
+ if var.space.needs_pass_through() {
+ pass_through_globals.push(handle);
+ }
+ supports_array_length |= needs_array_length(var.ty, &module.types);
+ }
+ }
+
+ writeln!(self.out)?;
+ let fun_name = &self.names[&NameKey::Function(fun_handle)];
+ match fun.result {
+ Some(ref result) => {
+ let ty_name = TypeContext {
+ handle: result.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name}")?;
+ }
+ None => {
+ write!(self.out, "void")?;
+ }
+ }
+ writeln!(self.out, " {fun_name}(")?;
+
+ for (index, arg) in fun.arguments.iter().enumerate() {
+ let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
+ let param_type_name = TypeContext {
+ handle: arg.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let separator = separate(
+ !pass_through_globals.is_empty()
+ || index + 1 != fun.arguments.len()
+ || supports_array_length,
+ );
+ writeln!(
+ self.out,
+ "{}{} {}{}",
+ back::INDENT,
+ param_type_name,
+ name,
+ separator
+ )?;
+ }
+ for (index, &handle) in pass_through_globals.iter().enumerate() {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage: fun_info[handle],
+ binding: None,
+ reference: true,
+ };
+ let separator =
+ separate(index + 1 != pass_through_globals.len() || supports_array_length);
+ write!(self.out, "{}", back::INDENT)?;
+ tyvar.try_fmt(&mut self.out)?;
+ writeln!(self.out, "{separator}")?;
+ }
+
+ if supports_array_length {
+ writeln!(
+ self.out,
+ "{}constant _mslBufferSizes& _buffer_sizes",
+ back::INDENT
+ )?;
+ }
+
+ writeln!(self.out, ") {{")?;
+
+ let guarded_indices =
+ index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
+
+ let context = StatementContext {
+ expression: ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::Handle(fun_handle),
+ info: fun_info,
+ lang_version: options.lang_version,
+ policies: options.bounds_check_policies,
+ guarded_indices,
+ module,
+ mod_info,
+ pipeline_options,
+ },
+ result_struct: None,
+ };
+
+ for (local_handle, local) in fun.local_variables.iter() {
+ let ty_name = TypeContext {
+ handle: local.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)];
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, local_name)?;
+ match local.init {
+ Some(value) => {
+ write!(self.out, " = ")?;
+ self.put_expression(value, &context.expression, true)?;
+ }
+ None => {
+ write!(self.out, " = {{}}")?;
+ }
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ self.update_expressions_to_bake(fun, fun_info, &context.expression);
+ self.put_block(back::Level(1), &fun.body, &context)?;
+ writeln!(self.out, "}}")?;
+ self.named_expressions.clear();
+ }
+
+ let mut info = TranslationInfo {
+ entry_point_names: Vec::with_capacity(module.entry_points.len()),
+ };
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let fun = &ep.function;
+ let fun_info = mod_info.get_entry_point(ep_index);
+ let mut ep_error = None;
+
+ log::trace!(
+ "entry point {:?}, index {:?}",
+ fun.name.as_deref().unwrap_or("(anonymous)"),
+ ep_index
+ );
+
+ // Is any global variable used by this entry point dynamically sized?
+ let supports_array_length = module
+ .global_variables
+ .iter()
+ .filter(|&(handle, _)| !fun_info[handle].is_empty())
+ .any(|(_, var)| needs_array_length(var.ty, &module.types));
+
+ // skip this entry point if any global bindings are missing,
+ // or their types are incompatible.
+ if !options.fake_missing_bindings {
+ for (var_handle, var) in module.global_variables.iter() {
+ if fun_info[var_handle].is_empty() {
+ continue;
+ }
+ match var.space {
+ crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::Handle => {
+ let br = match var.binding {
+ Some(ref br) => br,
+ None => {
+ let var_name = var.name.clone().unwrap_or_default();
+ ep_error =
+ Some(super::EntryPointError::MissingBinding(var_name));
+ break;
+ }
+ };
+ let target = options.get_resource_binding_target(ep, br);
+ let good = match target {
+ Some(target) => {
+ let binding_ty = match module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => {
+ &module.types[base].inner
+ }
+ ref ty => ty,
+ };
+ match *binding_ty {
+ crate::TypeInner::Image { .. } => target.texture.is_some(),
+ crate::TypeInner::Sampler { .. } => {
+ target.sampler.is_some()
+ }
+ _ => target.buffer.is_some(),
+ }
+ }
+ None => false,
+ };
+ if !good {
+ ep_error =
+ Some(super::EntryPointError::MissingBindTarget(br.clone()));
+ break;
+ }
+ }
+ crate::AddressSpace::PushConstant => {
+ if let Err(e) = options.resolve_push_constants(ep) {
+ ep_error = Some(e);
+ break;
+ }
+ }
+ crate::AddressSpace::Function
+ | crate::AddressSpace::Private
+ | crate::AddressSpace::WorkGroup => {}
+ }
+ }
+ if supports_array_length {
+ if let Err(err) = options.resolve_sizes_buffer(ep) {
+ ep_error = Some(err);
+ }
+ }
+ }
+
+ if let Some(err) = ep_error {
+ info.entry_point_names.push(Err(err));
+ continue;
+ }
+ let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
+ info.entry_point_names.push(Ok(fun_name.clone()));
+
+ writeln!(self.out)?;
+
+ let (em_str, in_mode, out_mode) = match ep.stage {
+ crate::ShaderStage::Vertex => (
+ "vertex",
+ LocationMode::VertexInput,
+ LocationMode::VertexOutput,
+ ),
+ crate::ShaderStage::Fragment { .. } => (
+ "fragment",
+ LocationMode::FragmentInput,
+ LocationMode::FragmentOutput,
+ ),
+ crate::ShaderStage::Compute { .. } => {
+ ("kernel", LocationMode::Uniform, LocationMode::Uniform)
+ }
+ };
+
+ // Since `Namer.reset` wasn't expecting struct members to be
+ // suddenly injected into another namespace like this,
+ // `self.names` doesn't keep them distinct from other variables.
+ // Generate fresh names for these arguments, and remember the
+ // mapping.
+ let mut flattened_member_names = FastHashMap::default();
+ // Varyings' members get their own namespace
+ let mut varyings_namer = crate::proc::Namer::default();
+
+ // List all the Naga `EntryPoint`'s `Function`'s arguments,
+ // flattening structs into their members. In Metal, we will pass
+ // each of these values to the entry point as a separate argument—
+ // except for the varyings, handled next.
+ let mut flattened_arguments = Vec::new();
+ for (arg_index, arg) in fun.arguments.iter().enumerate() {
+ match module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (member_index, member) in members.iter().enumerate() {
+ let member_index = member_index as u32;
+ flattened_arguments.push((
+ NameKey::StructMember(arg.ty, member_index),
+ member.ty,
+ member.binding.as_ref(),
+ ));
+ let name_key = NameKey::StructMember(arg.ty, member_index);
+ let name = match member.binding {
+ Some(crate::Binding::Location { .. }) => {
+ varyings_namer.call(&self.names[&name_key])
+ }
+ _ => self.namer.call(&self.names[&name_key]),
+ };
+ flattened_member_names.insert(name_key, name);
+ }
+ }
+ _ => flattened_arguments.push((
+ NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
+ arg.ty,
+ arg.binding.as_ref(),
+ )),
+ }
+ }
+
+ // Identify the varyings among the argument values, and emit a
+ // struct type named `<fun>Input` to hold them.
+ let stage_in_name = format!("{fun_name}Input");
+ let varyings_member_name = self.namer.call("varyings");
+ let mut has_varyings = false;
+ if !flattened_arguments.is_empty() {
+ writeln!(self.out, "struct {stage_in_name} {{")?;
+ for &(ref name_key, ty, binding) in flattened_arguments.iter() {
+ let binding = match binding {
+ Some(ref binding @ &crate::Binding::Location { .. }) => binding,
+ _ => continue,
+ };
+ has_varyings = true;
+ let name = match *name_key {
+ NameKey::StructMember(..) => &flattened_member_names[name_key],
+ _ => &self.names[name_key],
+ };
+ let ty_name = TypeContext {
+ handle: ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let resolved = options.resolve_local_binding(binding, in_mode)?;
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out, ";")?;
+ }
+ writeln!(self.out, "}};")?;
+ }
+
+ // Define a struct type named for the return value, if any, named
+ // `<fun>Output`.
+ let stage_out_name = format!("{fun_name}Output");
+ let result_member_name = self.namer.call("member");
+ let result_type_name = match fun.result {
+ Some(ref result) => {
+ let mut result_members = Vec::new();
+ if let crate::TypeInner::Struct { ref members, .. } =
+ module.types[result.ty].inner
+ {
+ for (member_index, member) in members.iter().enumerate() {
+ result_members.push((
+ &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
+ member.ty,
+ member.binding.as_ref(),
+ ));
+ }
+ } else {
+ result_members.push((
+ &result_member_name,
+ result.ty,
+ result.binding.as_ref(),
+ ));
+ }
+
+ writeln!(self.out, "struct {stage_out_name} {{")?;
+ let mut has_point_size = false;
+ for (name, ty, binding) in result_members {
+ let ty_name = TypeContext {
+ handle: ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: true,
+ };
+ let binding = binding.ok_or(Error::Validation)?;
+
+ if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
+ has_point_size = true;
+ if !pipeline_options.allow_and_force_point_size {
+ continue;
+ }
+ }
+
+ let array_len = match module.types[ty].inner {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => Some(size),
+ _ => None,
+ };
+ let resolved = options.resolve_local_binding(binding, out_mode)?;
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ if let Some(array_len) = array_len {
+ write!(self.out, " [{array_len}]")?;
+ }
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out, ";")?;
+ }
+
+ if pipeline_options.allow_and_force_point_size
+ && ep.stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // inject the point size output last
+ writeln!(
+ self.out,
+ "{}float _point_size [[point_size]];",
+ back::INDENT
+ )?;
+ }
+ writeln!(self.out, "}};")?;
+ &stage_out_name
+ }
+ None => "void",
+ };
+
+ // Write the entry point function's name, and begin its argument list.
+ writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
+ let mut is_first_argument = true;
+
+ // If we have produced a struct holding the `EntryPoint`'s
+ // `Function`'s arguments' varyings, pass that struct first.
+ if has_varyings {
+ writeln!(
+ self.out,
+ " {stage_in_name} {varyings_member_name} [[stage_in]]"
+ )?;
+ is_first_argument = false;
+ }
+
+ let mut local_invocation_id = None;
+
+ // Then pass the remaining arguments not included in the varyings
+ // struct.
+ for &(ref name_key, ty, binding) in flattened_arguments.iter() {
+ let binding = match binding {
+ Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
+ _ => continue,
+ };
+ let name = match *name_key {
+ NameKey::StructMember(..) => &flattened_member_names[name_key],
+ _ => &self.names[name_key],
+ };
+
+ if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
+ local_invocation_id = Some(name_key);
+ }
+
+ let ty_name = TypeContext {
+ handle: ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let resolved = options.resolve_local_binding(binding, in_mode)?;
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ write!(self.out, "{separator} {ty_name} {name}")?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out)?;
+ }
+
+ let need_workgroup_variables_initialization =
+ self.need_workgroup_variables_initialization(options, ep, module, fun_info);
+
+ if need_workgroup_variables_initialization && local_invocation_id.is_none() {
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ writeln!(
+ self.out,
+ "{separator} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]"
+ )?;
+ }
+
+ // Those global variables used by this entry point and its callees
+ // get passed as arguments. `Private` globals are an exception, they
+ // don't outlive this invocation, so we declare them below as locals
+ // within the entry point.
+ for (handle, var) in module.global_variables.iter() {
+ let usage = fun_info[handle];
+ if usage.is_empty() || var.space == crate::AddressSpace::Private {
+ continue;
+ }
+
+ if options.lang_version < (1, 2) {
+ match var.space {
+ // This restriction is not documented in the MSL spec
+ // but validation will fail if it is not upheld.
+ //
+ // We infer the required version from the "Function
+ // Buffer Read-Writes" section of [what's new], where
+ // the feature sets listed correspond with the ones
+ // supporting MSL 1.2.
+ //
+ // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
+ crate::AddressSpace::Storage { access }
+ if access.contains(crate::StorageAccess::STORE)
+ && ep.stage == crate::ShaderStage::Fragment =>
+ {
+ return Err(Error::UnsupportedWriteableStorageBuffer)
+ }
+ crate::AddressSpace::Handle => {
+ match module.types[var.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => {
+ // This restriction is not documented in the MSL spec
+ // but validation will fail if it is not upheld.
+ //
+ // We infer the required version from the "Function
+ // Texture Read-Writes" section of [what's new], where
+ // the feature sets listed correspond with the ones
+ // supporting MSL 1.2.
+ //
+ // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
+ if access.contains(crate::StorageAccess::STORE)
+ && (ep.stage == crate::ShaderStage::Vertex
+ || ep.stage == crate::ShaderStage::Fragment)
+ {
+ return Err(Error::UnsupportedWriteableStorageTexture(
+ ep.stage,
+ ));
+ }
+
+ if access.contains(
+ crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
+ ) {
+ return Err(Error::UnsupportedRWStorageTexture);
+ }
+ }
+ _ => {}
+ }
+ }
+ _ => {}
+ }
+ }
+
+ // Check min MSL version for binding arrays
+ match var.space {
+ crate::AddressSpace::Handle => match module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => {
+ match module.types[base].inner {
+ crate::TypeInner::Sampler { .. } => {
+ if options.lang_version < (2, 0) {
+ return Err(Error::UnsupportedArrayOf(
+ "samplers".to_string(),
+ ));
+ }
+ }
+ crate::TypeInner::Image { class, .. } => match class {
+ crate::ImageClass::Sampled { .. }
+ | crate::ImageClass::Depth { .. }
+ | crate::ImageClass::Storage {
+ access: crate::StorageAccess::LOAD,
+ ..
+ } => {
+ // Array of textures since:
+ // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
+ // - macOS: Metal 2
+
+ if options.lang_version < (2, 0) {
+ return Err(Error::UnsupportedArrayOf(
+ "textures".to_string(),
+ ));
+ }
+ }
+ crate::ImageClass::Storage {
+ access: crate::StorageAccess::STORE,
+ ..
+ } => {
+ // Array of write-only textures since:
+ // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
+ // - macOS: Metal 2
+
+ if options.lang_version < (2, 0) {
+ return Err(Error::UnsupportedArrayOf(
+ "write-only textures".to_string(),
+ ));
+ }
+ }
+ crate::ImageClass::Storage { .. } => {
+ return Err(Error::UnsupportedArrayOf(
+ "read-write textures".to_string(),
+ ));
+ }
+ },
+ _ => {
+ return Err(Error::UnsupportedArrayOfType(base));
+ }
+ }
+ }
+ _ => {}
+ },
+ _ => {}
+ }
+
+ // the resolves have already been checked for `!fake_missing_bindings` case
+ let resolved = match var.space {
+ crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
+ crate::AddressSpace::WorkGroup => None,
+ _ => options
+ .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
+ .ok(),
+ };
+ if let Some(ref resolved) = resolved {
+ // Inline samplers are be defined in the EP body
+ if resolved.as_inline_sampler(options).is_some() {
+ continue;
+ }
+ }
+
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage,
+ binding: resolved.as_ref(),
+ reference: true,
+ };
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ write!(self.out, "{separator} ")?;
+ tyvar.try_fmt(&mut self.out)?;
+ if let Some(resolved) = resolved {
+ resolved.try_fmt(&mut self.out)?;
+ }
+ if let Some(value) = var.init {
+ write!(self.out, " = ")?;
+ self.put_const_expression(value, module, mod_info)?;
+ }
+ writeln!(self.out)?;
+ }
+
+ // If this entry uses any variable-length arrays, their sizes are
+ // passed as a final struct-typed argument.
+ if supports_array_length {
+ // this is checked earlier
+ let resolved = options.resolve_sizes_buffer(ep).unwrap();
+ let separator = if module.global_variables.is_empty() {
+ ' '
+ } else {
+ ','
+ };
+ write!(
+ self.out,
+ "{separator} constant _mslBufferSizes& _buffer_sizes",
+ )?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out)?;
+ }
+
+ // end of the entry point argument list
+ writeln!(self.out, ") {{")?;
+
+ if need_workgroup_variables_initialization {
+ self.write_workgroup_variables_initialization(
+ module,
+ mod_info,
+ fun_info,
+ local_invocation_id,
+ )?;
+ }
+
+ // Metal doesn't support private mutable variables outside of functions,
+ // so we put them here, just like the locals.
+ for (handle, var) in module.global_variables.iter() {
+ let usage = fun_info[handle];
+ if usage.is_empty() {
+ continue;
+ }
+ if var.space == crate::AddressSpace::Private {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage,
+ binding: None,
+ reference: false,
+ };
+ write!(self.out, "{}", back::INDENT)?;
+ tyvar.try_fmt(&mut self.out)?;
+ match var.init {
+ Some(value) => {
+ write!(self.out, " = ")?;
+ self.put_const_expression(value, module, mod_info)?;
+ writeln!(self.out, ";")?;
+ }
+ None => {
+ writeln!(self.out, " = {{}};")?;
+ }
+ };
+ } else if let Some(ref binding) = var.binding {
+ // write an inline sampler
+ let resolved = options.resolve_resource_binding(ep, binding).unwrap();
+ if let Some(sampler) = resolved.as_inline_sampler(options) {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ writeln!(
+ self.out,
+ "{}constexpr {}::sampler {}(",
+ back::INDENT,
+ NAMESPACE,
+ name
+ )?;
+ self.put_inline_sampler_properties(back::Level(2), sampler)?;
+ writeln!(self.out, "{});", back::INDENT)?;
+ }
+ }
+ }
+
+ // Now take the arguments that we gathered into structs, and the
+ // structs that we flattened into arguments, and emit local
+ // variables with initializers that put everything back the way the
+ // body code expects.
+ //
+ // If we had to generate fresh names for struct members passed as
+ // arguments, be sure to use those names when rebuilding the struct.
+ //
+ // "Each day, I change some zeros to ones, and some ones to zeros.
+ // The rest, I leave alone."
+ for (arg_index, arg) in fun.arguments.iter().enumerate() {
+ let arg_name =
+ &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
+ match module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let struct_name = &self.names[&NameKey::Type(arg.ty)];
+ write!(
+ self.out,
+ "{}const {} {} = {{ ",
+ back::INDENT,
+ struct_name,
+ arg_name
+ )?;
+ for (member_index, member) in members.iter().enumerate() {
+ let key = NameKey::StructMember(arg.ty, member_index as u32);
+ let name = &flattened_member_names[&key];
+ if member_index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // insert padding initialization, if needed
+ if self
+ .struct_member_pads
+ .contains(&(arg.ty, member_index as u32))
+ {
+ write!(self.out, "{{}}, ")?;
+ }
+ if let Some(crate::Binding::Location { .. }) = member.binding {
+ write!(self.out, "{varyings_member_name}.")?;
+ }
+ write!(self.out, "{name}")?;
+ }
+ writeln!(self.out, " }};")?;
+ }
+ _ => {
+ if let Some(crate::Binding::Location { .. }) = arg.binding {
+ writeln!(
+ self.out,
+ "{}const auto {} = {}.{};",
+ back::INDENT,
+ arg_name,
+ varyings_member_name,
+ arg_name
+ )?;
+ }
+ }
+ }
+ }
+
+ let guarded_indices =
+ index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
+
+ let context = StatementContext {
+ expression: ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::EntryPoint(ep_index as _),
+ info: fun_info,
+ lang_version: options.lang_version,
+ policies: options.bounds_check_policies,
+ guarded_indices,
+ module,
+ mod_info,
+ pipeline_options,
+ },
+ result_struct: Some(&stage_out_name),
+ };
+
+ // Finally, declare all the local variables that we need
+ //TODO: we can postpone this till the relevant expressions are emitted
+ for (local_handle, local) in fun.local_variables.iter() {
+ let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)];
+ let ty_name = TypeContext {
+ handle: local.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ match local.init {
+ Some(value) => {
+ write!(self.out, " = ")?;
+ self.put_expression(value, &context.expression, true)?;
+ }
+ None => {
+ write!(self.out, " = {{}}")?;
+ }
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ self.update_expressions_to_bake(fun, fun_info, &context.expression);
+ self.put_block(back::Level(1), &fun.body, &context)?;
+ writeln!(self.out, "}}")?;
+ if ep_index + 1 != module.entry_points.len() {
+ writeln!(self.out)?;
+ }
+ self.named_expressions.clear();
+ }
+
+ Ok(info)
+ }
+
+ fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
+ // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
+ // so we try to avoid it here.
+ if flags.is_empty() {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
+ )?;
+ }
+ if flags.contains(crate::Barrier::STORAGE) {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
+ )?;
+ }
+ if flags.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
+ )?;
+ }
+ Ok(())
+ }
+}
+
+/// Initializing workgroup variables is more tricky for Metal because we have to deal
+/// with atomics at the type-level (which don't have a copy constructor).
+mod workgroup_mem_init {
+ use crate::EntryPoint;
+
+ use super::*;
+
+ enum Access {
+ GlobalVariable(Handle<crate::GlobalVariable>),
+ StructMember(Handle<crate::Type>, u32),
+ Array(usize),
+ }
+
+ impl Access {
+ fn write<W: Write>(
+ &self,
+ writer: &mut W,
+ names: &FastHashMap<NameKey, String>,
+ ) -> Result<(), core::fmt::Error> {
+ match *self {
+ Access::GlobalVariable(handle) => {
+ write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
+ }
+ Access::StructMember(handle, index) => {
+ write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
+ }
+ Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
+ }
+ }
+ }
+
+ struct AccessStack {
+ stack: Vec<Access>,
+ array_depth: usize,
+ }
+
+ impl AccessStack {
+ const fn new() -> Self {
+ Self {
+ stack: Vec::new(),
+ array_depth: 0,
+ }
+ }
+
+ fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
+ let array_depth = self.array_depth;
+ self.stack.push(Access::Array(array_depth));
+ self.array_depth += 1;
+ let res = cb(self, array_depth);
+ self.stack.pop();
+ self.array_depth -= 1;
+ res
+ }
+
+ fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
+ self.stack.push(new);
+ let res = cb(self);
+ self.stack.pop();
+ res
+ }
+
+ fn write<W: Write>(
+ &self,
+ writer: &mut W,
+ names: &FastHashMap<NameKey, String>,
+ ) -> Result<(), core::fmt::Error> {
+ for next in self.stack.iter() {
+ next.write(writer, names)?;
+ }
+ Ok(())
+ }
+ }
+
+ impl<W: Write> Writer<W> {
+ pub(super) fn need_workgroup_variables_initialization(
+ &mut self,
+ options: &Options,
+ ep: &EntryPoint,
+ module: &crate::Module,
+ fun_info: &valid::FunctionInfo,
+ ) -> bool {
+ options.zero_initialize_workgroup_memory
+ && ep.stage == crate::ShaderStage::Compute
+ && module.global_variables.iter().any(|(handle, var)| {
+ !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ }
+
+ pub(super) fn write_workgroup_variables_initialization(
+ &mut self,
+ module: &crate::Module,
+ module_info: &valid::ModuleInfo,
+ fun_info: &valid::FunctionInfo,
+ local_invocation_id: Option<&NameKey>,
+ ) -> BackendResult {
+ let level = back::Level(1);
+
+ writeln!(
+ self.out,
+ "{}if ({}::all({} == {}::uint3(0u))) {{",
+ level,
+ NAMESPACE,
+ local_invocation_id
+ .map(|name_key| self.names[name_key].as_str())
+ .unwrap_or("__local_invocation_id"),
+ NAMESPACE,
+ )?;
+
+ let mut access_stack = AccessStack::new();
+
+ let vars = module.global_variables.iter().filter(|&(handle, var)| {
+ !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ });
+
+ for (handle, var) in vars {
+ access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ var.ty,
+ access_stack,
+ level.next(),
+ )
+ })?;
+ }
+
+ writeln!(self.out, "{level}}}")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)
+ }
+
+ fn write_workgroup_variable_initialization(
+ &mut self,
+ module: &crate::Module,
+ module_info: &valid::ModuleInfo,
+ ty: Handle<crate::Type>,
+ access_stack: &mut AccessStack,
+ level: back::Level,
+ ) -> BackendResult {
+ if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
+ write!(self.out, "{level}")?;
+ access_stack.write(&mut self.out, &self.names)?;
+ writeln!(self.out, " = {{}};")?;
+ } else {
+ match module.types[ty].inner {
+ crate::TypeInner::Atomic { .. } => {
+ write!(
+ self.out,
+ "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
+ )?;
+ access_stack.write(&mut self.out, &self.names)?;
+ writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
+ }
+ crate::TypeInner::Array { base, size, .. } => {
+ let count = match size.to_indexable_length(module).expect("Bad array size")
+ {
+ proc::IndexableLength::Known(count) => count,
+ proc::IndexableLength::Dynamic => unreachable!(),
+ };
+
+ access_stack.enter_array(|access_stack, array_depth| {
+ writeln!(
+ self.out,
+ "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
+ )?;
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ base,
+ access_stack,
+ level.next(),
+ )?;
+ writeln!(self.out, "{level}}}")?;
+ BackendResult::Ok(())
+ })?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (index, member) in members.iter().enumerate() {
+ access_stack.enter(
+ Access::StructMember(ty, index as u32),
+ |access_stack| {
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ member.ty,
+ access_stack,
+ level,
+ )
+ },
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ Ok(())
+ }
+ }
+}
+
+#[test]
+fn test_stack_size() {
+ use crate::valid::{Capabilities, ValidationFlags};
+ // create a module with at least one expression nested
+ let mut module = crate::Module::default();
+ let mut fun = crate::Function::default();
+ let const_expr = fun.expressions.append(
+ crate::Expression::Literal(crate::Literal::F32(1.0)),
+ Default::default(),
+ );
+ let nested_expr = fun.expressions.append(
+ crate::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr: const_expr,
+ },
+ Default::default(),
+ );
+ fun.body.push(
+ crate::Statement::Emit(fun.expressions.range_from(1)),
+ Default::default(),
+ );
+ fun.body.push(
+ crate::Statement::If {
+ condition: nested_expr,
+ accept: crate::Block::new(),
+ reject: crate::Block::new(),
+ },
+ Default::default(),
+ );
+ let _ = module.functions.append(fun, Default::default());
+ // analyse the module
+ let info = crate::valid::Validator::new(ValidationFlags::empty(), Capabilities::empty())
+ .validate(&module)
+ .unwrap();
+ // process the module
+ let mut writer = Writer::new(String::new());
+ writer
+ .write(&module, &info, &Default::default(), &Default::default())
+ .unwrap();
+
+ {
+ // check expression stack
+ let mut addresses_start = usize::MAX;
+ let mut addresses_end = 0usize;
+ for pointer in writer.put_expression_stack_pointers {
+ addresses_start = addresses_start.min(pointer as usize);
+ addresses_end = addresses_end.max(pointer as usize);
+ }
+ let stack_size = addresses_end - addresses_start;
+ // check the size (in debug only)
+ // last observed macOS value: 20528 (CI)
+ if !(11000..=25000).contains(&stack_size) {
+ panic!("`put_expression` stack size {stack_size} has changed!");
+ }
+ }
+
+ {
+ // check block stack
+ let mut addresses_start = usize::MAX;
+ let mut addresses_end = 0usize;
+ for pointer in writer.put_block_stack_pointers {
+ addresses_start = addresses_start.min(pointer as usize);
+ addresses_end = addresses_end.max(pointer as usize);
+ }
+ let stack_size = addresses_end - addresses_start;
+ // check the size (in debug only)
+ // last observed macOS value: 19152 (CI)
+ if !(9000..=20000).contains(&stack_size) {
+ panic!("`put_block` stack size {stack_size} has changed!");
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs
new file mode 100644
index 0000000000..6c96fa09e3
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/block.rs
@@ -0,0 +1,2368 @@
+/*!
+Implementations for `BlockContext` methods.
+*/
+
+use super::{
+ helpers, index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext,
+ Dimension, Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer,
+ WriterFlags,
+};
+use crate::{arena::Handle, proc::TypeResolution, Statement};
+use spirv::Word;
+
+fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
+ match *type_inner {
+ crate::TypeInner::Scalar(_) => Dimension::Scalar,
+ crate::TypeInner::Vector { .. } => Dimension::Vector,
+ crate::TypeInner::Matrix { .. } => Dimension::Matrix,
+ _ => unreachable!(),
+ }
+}
+
+/// The results of emitting code for a left-hand-side expression.
+///
+/// On success, `write_expression_pointer` returns one of these.
+enum ExpressionPointer {
+ /// The pointer to the expression's value is available, as the value of the
+ /// expression with the given id.
+ Ready { pointer_id: Word },
+
+ /// The access expression must be conditional on the value of `condition`, a boolean
+ /// expression that is true if all indices are in bounds. If `condition` is true, then
+ /// `access` is an `OpAccessChain` instruction that will compute a pointer to the
+ /// expression's value. If `condition` is false, then executing `access` would be
+ /// undefined behavior.
+ Conditional {
+ condition: Word,
+ access: Instruction,
+ },
+}
+
+/// The termination statement to be added to the end of the block
+pub enum BlockExit {
+ /// Generates an OpReturn (void return)
+ Return,
+ /// Generates an OpBranch to the specified block
+ Branch {
+ /// The branch target block
+ target: Word,
+ },
+ /// Translates a loop `break if` into an `OpBranchConditional` to the
+ /// merge block if true (the merge block is passed through [`LoopContext::break_id`]
+ /// or else to the loop header (passed through [`preamble_id`])
+ ///
+ /// [`preamble_id`]: Self::BreakIf::preamble_id
+ BreakIf {
+ /// The condition of the `break if`
+ condition: Handle<crate::Expression>,
+ /// The loop header block id
+ preamble_id: Word,
+ },
+}
+
+#[derive(Debug)]
+pub(crate) struct DebugInfoInner<'a> {
+ pub source_code: &'a str,
+ pub source_file_id: Word,
+}
+
+impl Writer {
+ // Flip Y coordinate to adjust for coordinate space difference
+ // between SPIR-V and our IR.
+ // The `position_id` argument is a pointer to a `vecN<f32>`,
+ // whose `y` component we will negate.
+ fn write_epilogue_position_y_flip(
+ &mut self,
+ position_id: Word,
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: Some(spirv::StorageClass::Output),
+ }));
+ let index_y_id = self.get_index_constant(1);
+ let access_id = self.id_gen.next();
+ body.push(Instruction::access_chain(
+ float_ptr_type_id,
+ access_id,
+ position_id,
+ &[index_y_id],
+ ));
+
+ let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let load_id = self.id_gen.next();
+ body.push(Instruction::load(float_type_id, load_id, access_id, None));
+
+ let neg_id = self.id_gen.next();
+ body.push(Instruction::unary(
+ spirv::Op::FNegate,
+ float_type_id,
+ neg_id,
+ load_id,
+ ));
+
+ body.push(Instruction::store(access_id, neg_id, None));
+ Ok(())
+ }
+
+ // Clamp fragment depth between 0 and 1.
+ fn write_epilogue_frag_depth_clamp(
+ &mut self,
+ frag_depth_id: Word,
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
+ let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
+
+ let original_id = self.id_gen.next();
+ body.push(Instruction::load(
+ float_type_id,
+ original_id,
+ frag_depth_id,
+ None,
+ ));
+
+ let clamp_id = self.id_gen.next();
+ body.push(Instruction::ext_inst(
+ self.gl450_ext_inst_id,
+ spirv::GLOp::FClamp,
+ float_type_id,
+ clamp_id,
+ &[original_id, zero_scalar_id, one_scalar_id],
+ ));
+
+ body.push(Instruction::store(frag_depth_id, clamp_id, None));
+ Ok(())
+ }
+
+ fn write_entry_point_return(
+ &mut self,
+ value_id: Word,
+ ir_result: &crate::FunctionResult,
+ result_members: &[ResultMember],
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ for (index, res_member) in result_members.iter().enumerate() {
+ let member_value_id = match ir_result.binding {
+ Some(_) => value_id,
+ None => {
+ let member_value_id = self.id_gen.next();
+ body.push(Instruction::composite_extract(
+ res_member.type_id,
+ member_value_id,
+ value_id,
+ &[index as u32],
+ ));
+ member_value_id
+ }
+ };
+
+ body.push(Instruction::store(res_member.id, member_value_id, None));
+
+ match res_member.built_in {
+ Some(crate::BuiltIn::Position { .. })
+ if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
+ {
+ self.write_epilogue_position_y_flip(res_member.id, body)?;
+ }
+ Some(crate::BuiltIn::FragDepth)
+ if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
+ {
+ self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
+ }
+ _ => {}
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<'w> BlockContext<'w> {
+ /// Decide whether to put off emitting instructions for `expr_handle`.
+ ///
+ /// We would like to gather together chains of `Access` and `AccessIndex`
+ /// Naga expressions into a single `OpAccessChain` SPIR-V instruction. To do
+ /// this, we don't generate instructions for these exprs when we first
+ /// encounter them. Their ids in `self.writer.cached.ids` are left as zero. Then,
+ /// once we encounter a `Load` or `Store` expression that actually needs the
+ /// chain's value, we call `write_expression_pointer` to handle the whole
+ /// thing in one fell swoop.
+ fn is_intermediate(&self, expr_handle: Handle<crate::Expression>) -> bool {
+ match self.ir_function.expressions[expr_handle] {
+ crate::Expression::GlobalVariable(handle) => {
+ match self.ir_module.global_variables[handle].space {
+ crate::AddressSpace::Handle => false,
+ _ => true,
+ }
+ }
+ crate::Expression::LocalVariable(_) => true,
+ crate::Expression::FunctionArgument(index) => {
+ let arg = &self.ir_function.arguments[index as usize];
+ self.ir_module.types[arg.ty].inner.pointer_space().is_some()
+ }
+
+ // The chain rule: if this `Access...`'s `base` operand was
+ // previously omitted, then omit this one, too.
+ _ => self.cached.ids[expr_handle.index()] == 0,
+ }
+ }
+
+ /// Cache an expression for a value.
+ pub(super) fn cache_expression_value(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ let is_named_expression = self
+ .ir_function
+ .named_expressions
+ .contains_key(&expr_handle);
+
+ if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
+ return Ok(());
+ }
+
+ let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
+ let id = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
+ crate::Expression::Constant(handle) => {
+ let init = self.ir_module.constants[handle].init;
+ self.writer.constant_ids[init.index()]
+ }
+ crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
+ crate::Expression::Compose { ty, ref components } => {
+ self.temp_list.clear();
+ if self.expression_constness.is_const(expr_handle) {
+ self.temp_list.extend(
+ crate::proc::flatten_compose(
+ ty,
+ components,
+ &self.ir_function.expressions,
+ &self.ir_module.types,
+ )
+ .map(|component| self.cached[component]),
+ );
+ self.writer
+ .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
+ } else {
+ self.temp_list
+ .extend(components.iter().map(|&component| self.cached[component]));
+
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ id,
+ &self.temp_list,
+ ));
+ id
+ }
+ }
+ crate::Expression::Splat { size, value } => {
+ let value_id = self.cached[value];
+ let components = &[value_id; 4][..size as usize];
+
+ if self.expression_constness.is_const(expr_handle) {
+ let ty = self
+ .writer
+ .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
+ self.writer.get_constant_composite(ty, components)
+ } else {
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ id,
+ components,
+ ));
+ id
+ }
+ }
+ crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => {
+ // See `is_intermediate`; we'll handle this later in
+ // `write_expression_pointer`.
+ 0
+ }
+ crate::Expression::Access { base, index } => {
+ let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
+ match *base_ty_inner {
+ crate::TypeInner::Vector { .. } => {
+ self.write_vector_access(expr_handle, base, index, block)?
+ }
+ // Only binding arrays in the Handle address space will take this path (due to `is_intermediate`)
+ crate::TypeInner::BindingArray {
+ base: binding_type, ..
+ } => {
+ let space = match self.ir_function.expressions[base] {
+ crate::Expression::GlobalVariable(gvar) => {
+ self.ir_module.global_variables[gvar].space
+ }
+ _ => unreachable!(),
+ };
+ let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
+ base: binding_type,
+ class: helpers::map_storage_class(space),
+ });
+
+ let result_id = match self.write_expression_pointer(
+ expr_handle,
+ block,
+ Some(binding_array_false_pointer),
+ )? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Texture array out-of-bounds handling",
+ ));
+ }
+ };
+
+ let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));
+
+ let load_id = self.gen_id();
+ block.body.push(Instruction::load(
+ binding_type_id,
+ load_id,
+ result_id,
+ None,
+ ));
+
+ // Subsequent image operations require the image/sampler to be decorated as NonUniform
+ // if the image/sampler binding array was accessed with a non-uniform index
+ // see VUID-RuntimeSpirv-NonUniform-06274
+ if self.fun_info[index].uniformity.non_uniform_result.is_some() {
+ self.writer
+ .decorate_non_uniform_binding_array_access(load_id)?;
+ }
+
+ load_id
+ }
+ ref other => {
+ log::error!(
+ "Unable to access base {:?} of type {:?}",
+ self.ir_function.expressions[base],
+ other
+ );
+ return Err(Error::Validation(
+ "only vectors may be dynamically indexed by value",
+ ));
+ }
+ }
+ }
+ crate::Expression::AccessIndex { base, index: _ } if self.is_intermediate(base) => {
+ // See `is_intermediate`; we'll handle this later in
+ // `write_expression_pointer`.
+ 0
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. }
+ | crate::TypeInner::Array { .. }
+ | crate::TypeInner::Struct { .. } => {
+ // We never need bounds checks here: dynamically sized arrays can
+ // only appear behind pointers, and are thus handled by the
+ // `is_intermediate` case above. Everything else's size is
+ // statically known and checked in validation.
+ let id = self.gen_id();
+ let base_id = self.cached[base];
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ id,
+ base_id,
+ &[index],
+ ));
+ id
+ }
+ // Only binding arrays in the Handle address space will take this path (due to `is_intermediate`)
+ crate::TypeInner::BindingArray {
+ base: binding_type, ..
+ } => {
+ let space = match self.ir_function.expressions[base] {
+ crate::Expression::GlobalVariable(gvar) => {
+ self.ir_module.global_variables[gvar].space
+ }
+ _ => unreachable!(),
+ };
+ let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
+ base: binding_type,
+ class: helpers::map_storage_class(space),
+ });
+
+ let result_id = match self.write_expression_pointer(
+ expr_handle,
+ block,
+ Some(binding_array_false_pointer),
+ )? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Texture array out-of-bounds handling",
+ ));
+ }
+ };
+
+ let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));
+
+ let load_id = self.gen_id();
+ block.body.push(Instruction::load(
+ binding_type_id,
+ load_id,
+ result_id,
+ None,
+ ));
+
+ load_id
+ }
+ ref other => {
+ log::error!("Unable to access index of {:?}", other);
+ return Err(Error::FeatureNotImplemented("access index for type"));
+ }
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ self.writer.global_variables[handle.index()].access_id
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ let vector_id = self.cached[vector];
+ self.temp_list.clear();
+ for &sc in pattern[..size as usize].iter() {
+ self.temp_list.push(sc as Word);
+ }
+ let id = self.gen_id();
+ block.body.push(Instruction::vector_shuffle(
+ result_type_id,
+ id,
+ vector_id,
+ vector_id,
+ &self.temp_list,
+ ));
+ id
+ }
+ crate::Expression::Unary { op, expr } => {
+ let id = self.gen_id();
+ let expr_id = self.cached[expr];
+ let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
+
+ let spirv_op = match op {
+ crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
+ Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
+ _ => return Err(Error::Validation("Unexpected kind for negation")),
+ },
+ crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
+ crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
+ };
+
+ block
+ .body
+ .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
+ id
+ }
+ crate::Expression::Binary { op, left, right } => {
+ let id = self.gen_id();
+ let left_id = self.cached[left];
+ let right_id = self.cached[right];
+
+ let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
+ let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
+
+ let left_dimension = get_dimension(left_ty_inner);
+ let right_dimension = get_dimension(right_ty_inner);
+
+ let mut reverse_operands = false;
+
+ let spirv_op = match op {
+ crate::BinaryOperator::Add => match *left_ty_inner {
+ crate::TypeInner::Scalar(scalar)
+ | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
+ crate::ScalarKind::Float => spirv::Op::FAdd,
+ _ => spirv::Op::IAdd,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ self.write_matrix_matrix_column_op(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ columns,
+ rows,
+ scalar.width,
+ spirv::Op::FAdd,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Subtract => match *left_ty_inner {
+ crate::TypeInner::Scalar(scalar)
+ | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
+ crate::ScalarKind::Float => spirv::Op::FSub,
+ _ => spirv::Op::ISub,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ self.write_matrix_matrix_column_op(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ columns,
+ rows,
+ scalar.width,
+ spirv::Op::FSub,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) {
+ (Dimension::Scalar, Dimension::Vector) => {
+ self.write_vector_scalar_mult(
+ block,
+ id,
+ result_type_id,
+ right_id,
+ left_id,
+ right_ty_inner,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ (Dimension::Vector, Dimension::Scalar) => {
+ self.write_vector_scalar_mult(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ left_ty_inner,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ (Dimension::Vector, Dimension::Matrix) => spirv::Op::VectorTimesMatrix,
+ (Dimension::Matrix, Dimension::Scalar) => spirv::Op::MatrixTimesScalar,
+ (Dimension::Scalar, Dimension::Matrix) => {
+ reverse_operands = true;
+ spirv::Op::MatrixTimesScalar
+ }
+ (Dimension::Matrix, Dimension::Vector) => spirv::Op::MatrixTimesVector,
+ (Dimension::Matrix, Dimension::Matrix) => spirv::Op::MatrixTimesMatrix,
+ (Dimension::Vector, Dimension::Vector)
+ | (Dimension::Scalar, Dimension::Scalar)
+ if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) =>
+ {
+ spirv::Op::FMul
+ }
+ (Dimension::Vector, Dimension::Vector)
+ | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
+ },
+ crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
+ Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
+ // TODO: handle undefined behavior
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ Some(crate::ScalarKind::Sint) => spirv::Op::SRem,
+ // TODO: handle undefined behavior
+ // if right == 0 return 0
+ Some(crate::ScalarKind::Uint) => spirv::Op::UMod,
+ // TODO: handle undefined behavior
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+ Some(crate::ScalarKind::Float) => spirv::Op::FRem,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+ spirv::Op::IEqual
+ }
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+ spirv::Op::INotEqual
+ }
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
+ _ => spirv::Op::BitwiseAnd,
+ },
+ crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
+ crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
+ _ => spirv::Op::BitwiseOr,
+ },
+ crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
+ crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
+ crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
+ crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
+ _ => unimplemented!(),
+ },
+ };
+
+ block.body.push(Instruction::binary(
+ spirv_op,
+ result_type_id,
+ id,
+ if reverse_operands { right_id } else { left_id },
+ if reverse_operands { left_id } else { right_id },
+ ));
+ id
+ }
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+ enum MathOp {
+ Ext(spirv::GLOp),
+ Custom(Instruction),
+ }
+
+ let arg0_id = self.cached[arg];
+ let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
+ let arg_scalar_kind = arg_ty.scalar_kind();
+ let arg1_id = match arg1 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+ let arg2_id = match arg2 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+ let arg3_id = match arg3 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+
+ let id = self.gen_id();
+ let math_op = match fun {
+ // comparison
+ Mf::Abs => {
+ match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
+ Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
+ Some(crate::ScalarKind::Uint) => {
+ MathOp::Custom(Instruction::unary(
+ spirv::Op::CopyObject, // do nothing
+ result_type_id,
+ id,
+ arg0_id,
+ ))
+ }
+ other => unimplemented!("Unexpected abs({:?})", other),
+ }
+ }
+ Mf::Min => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
+ other => unimplemented!("Unexpected min({:?})", other),
+ }),
+ Mf::Max => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
+ other => unimplemented!("Unexpected max({:?})", other),
+ }),
+ Mf::Clamp => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FClamp,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SClamp,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UClamp,
+ other => unimplemented!("Unexpected max({:?})", other),
+ }),
+ Mf::Saturate => {
+ let (maybe_size, scalar) = match *arg_ty {
+ crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
+ crate::TypeInner::Scalar(scalar) => (None, scalar),
+ ref other => unimplemented!("Unexpected saturate({:?})", other),
+ };
+ let scalar = crate::Scalar::float(scalar.width);
+ let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
+ let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
+
+ if let Some(size) = maybe_size {
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(size as _, arg1_id);
+
+ arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
+
+ self.temp_list.fill(arg2_id);
+
+ arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
+ }
+
+ MathOp::Custom(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FClamp,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, arg2_id],
+ ))
+ }
+ // trigonometry
+ Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
+ Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
+ Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
+ Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
+ Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
+ Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
+ Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
+ Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
+ Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
+ Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
+ Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
+ Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
+ Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
+ Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
+ Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
+ // decomposition
+ Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
+ Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
+ Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
+ Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
+ Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
+ Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
+ Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
+ Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
+ // geometry
+ Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Vector {
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Float,
+ ..
+ },
+ ..
+ } => MathOp::Custom(Instruction::binary(
+ spirv::Op::Dot,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ )),
+ // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available
+ crate::TypeInner::Vector { size, .. } => {
+ self.write_dot_product(
+ id,
+ result_type_id,
+ arg0_id,
+ arg1_id,
+ size as u32,
+ block,
+ );
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => MathOp::Custom(Instruction::binary(
+ spirv::Op::OuterProduct,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ )),
+ Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
+ Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
+ Mf::Length => MathOp::Ext(spirv::GLOp::Length),
+ Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
+ Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
+ Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
+ Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
+ // exponent
+ Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
+ Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
+ Mf::Log => MathOp::Ext(spirv::GLOp::Log),
+ Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
+ Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
+ // computational
+ Mf::Sign => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
+ other => unimplemented!("Unexpected sign({:?})", other),
+ }),
+ Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
+ Mf::Mix => {
+ let selector = arg2.unwrap();
+ let selector_ty =
+ self.fun_info[selector].ty.inner_with(&self.ir_module.types);
+ match (arg_ty, selector_ty) {
+ // if the selector is a scalar, we need to splat it
+ (
+ &crate::TypeInner::Vector { size, .. },
+ &crate::TypeInner::Scalar(scalar),
+ ) => {
+ let selector_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ }));
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, arg2_id);
+
+ let selector_id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ selector_type_id,
+ selector_id,
+ &self.temp_list,
+ ));
+
+ MathOp::Custom(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FMix,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, selector_id],
+ ))
+ }
+ _ => MathOp::Ext(spirv::GLOp::FMix),
+ }
+ }
+ Mf::Step => MathOp::Ext(spirv::GLOp::Step),
+ Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
+ Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
+ Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
+ Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
+ Mf::Transpose => MathOp::Custom(Instruction::unary(
+ spirv::Op::Transpose,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
+ Mf::ReverseBits => MathOp::Custom(Instruction::unary(
+ spirv::Op::BitReverse,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::CountTrailingZeros => {
+ let uint_id = match *arg_ty {
+ crate::TypeInner::Vector { size, mut scalar } => {
+ scalar.kind = crate::ScalarKind::Uint;
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(
+ size as _,
+ self.writer.get_constant_scalar_with(32, scalar)?,
+ );
+
+ self.writer.get_constant_composite(ty, &self.temp_list)
+ }
+ crate::TypeInner::Scalar(mut scalar) => {
+ scalar.kind = crate::ScalarKind::Uint;
+ self.writer.get_constant_scalar_with(32, scalar)?
+ }
+ _ => unreachable!(),
+ };
+
+ let lsb_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FindILsb,
+ result_type_id,
+ lsb_id,
+ &[arg0_id],
+ ));
+
+ MathOp::Custom(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ result_type_id,
+ id,
+ &[uint_id, lsb_id],
+ ))
+ }
+ Mf::CountLeadingZeros => {
+ let (int_type_id, int_id) = match *arg_ty {
+ crate::TypeInner::Vector { size, mut scalar } => {
+ scalar.kind = crate::ScalarKind::Sint;
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(
+ size as _,
+ self.writer.get_constant_scalar_with(31, scalar)?,
+ );
+
+ (
+ self.get_type_id(ty),
+ self.writer.get_constant_composite(ty, &self.temp_list),
+ )
+ }
+ crate::TypeInner::Scalar(mut scalar) => {
+ scalar.kind = crate::ScalarKind::Sint;
+ (
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar,
+ pointer_space: None,
+ })),
+ self.writer.get_constant_scalar_with(31, scalar)?,
+ )
+ }
+ _ => unreachable!(),
+ };
+
+ let msb_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FindUMsb,
+ int_type_id,
+ msb_id,
+ &[arg0_id],
+ ));
+
+ MathOp::Custom(Instruction::binary(
+ spirv::Op::ISub,
+ result_type_id,
+ id,
+ int_id,
+ msb_id,
+ ))
+ }
+ Mf::CountOneBits => MathOp::Custom(Instruction::unary(
+ spirv::Op::BitCount,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::ExtractBits => {
+ let op = match arg_scalar_kind {
+ Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
+ Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
+ other => unimplemented!("Unexpected sign({:?})", other),
+ };
+ MathOp::Custom(Instruction::ternary(
+ op,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ arg2_id,
+ ))
+ }
+ Mf::InsertBits => MathOp::Custom(Instruction::quaternary(
+ spirv::Op::BitFieldInsert,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ arg2_id,
+ arg3_id,
+ )),
+ Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb),
+ Mf::FindMsb => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
+ other => unimplemented!("Unexpected findMSB({:?})", other),
+ }),
+ Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
+ Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
+ Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
+ Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
+ Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
+ Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
+ Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
+ Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
+ Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
+ Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
+ };
+
+ block.body.push(match math_op {
+ MathOp::Ext(op) => Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ op,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
+ ),
+ MathOp::Custom(inst) => inst,
+ });
+ id
+ }
+ crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id,
+ crate::Expression::Load { pointer } => {
+ match self.write_expression_pointer(pointer, block, None)? {
+ ExpressionPointer::Ready { pointer_id } => {
+ let id = self.gen_id();
+ let atomic_space =
+ match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Pointer { base, space } => {
+ match self.ir_module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Some(space),
+ _ => None,
+ }
+ }
+ _ => None,
+ };
+ let instruction = if let Some(space) = atomic_space {
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ Instruction::atomic_load(
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ )
+ } else {
+ Instruction::load(result_type_id, id, pointer_id, None)
+ };
+ block.body.push(instruction);
+ id
+ }
+ ExpressionPointer::Conditional { condition, access } => {
+ //TODO: support atomics?
+ self.write_conditional_indexed_load(
+ result_type_id,
+ condition,
+ block,
+ move |id_gen, block| {
+ // The in-bounds path. Perform the access and the load.
+ let pointer_id = access.result_id.unwrap();
+ let value_id = id_gen.next();
+ block.body.push(access);
+ block.body.push(Instruction::load(
+ result_type_id,
+ value_id,
+ pointer_id,
+ None,
+ ));
+ value_id
+ },
+ )
+ }
+ }
+ }
+ crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
+ crate::Expression::CallResult(_)
+ | crate::Expression::AtomicResult { .. }
+ | crate::Expression::WorkGroupUniformLoadResult { .. }
+ | crate::Expression::RayQueryProceedResult => self.cached[expr_handle],
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ use crate::ScalarKind as Sk;
+
+ let expr_id = self.cached[expr];
+ let (src_scalar, src_size, is_matrix) =
+ match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Scalar(scalar) => (scalar, None, false),
+ crate::TypeInner::Vector { scalar, size } => (scalar, Some(size), false),
+ crate::TypeInner::Matrix { scalar, .. } => (scalar, None, true),
+ ref other => {
+ log::error!("As source {:?}", other);
+ return Err(Error::Validation("Unexpected Expression::As source"));
+ }
+ };
+
+ enum Cast {
+ Identity,
+ Unary(spirv::Op),
+ Binary(spirv::Op, Word),
+ Ternary(spirv::Op, Word, Word),
+ }
+
+ let cast = if is_matrix {
+ // we only support identity casts for matrices
+ Cast::Unary(spirv::Op::CopyObject)
+ } else {
+ match (src_scalar.kind, kind, convert) {
+ // Filter out identity casts. Some Adreno drivers are
+ // confused by no-op OpBitCast instructions.
+ (src_kind, kind, convert)
+ if src_kind == kind
+ && convert.filter(|&width| width != src_scalar.width).is_none() =>
+ {
+ Cast::Identity
+ }
+ (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject),
+ (_, _, None) => Cast::Unary(spirv::Op::Bitcast),
+ // casting to a bool - generate `OpXxxNotEqual`
+ (_, Sk::Bool, Some(_)) => {
+ let op = match src_scalar.kind {
+ Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
+ Sk::Float => spirv::Op::FUnordNotEqual,
+ Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
+ };
+ let zero_scalar_id =
+ self.writer.get_constant_scalar_with(0, src_scalar)?;
+ let zero_id = match src_size {
+ Some(size) => {
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ scalar: src_scalar,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(size as _, zero_scalar_id);
+
+ self.writer.get_constant_composite(ty, &self.temp_list)
+ }
+ None => zero_scalar_id,
+ };
+
+ Cast::Binary(op, zero_id)
+ }
+ // casting from a bool - generate `OpSelect`
+ (Sk::Bool, _, Some(dst_width)) => {
+ let dst_scalar = crate::Scalar {
+ kind,
+ width: dst_width,
+ };
+ let zero_scalar_id =
+ self.writer.get_constant_scalar_with(0, dst_scalar)?;
+ let one_scalar_id =
+ self.writer.get_constant_scalar_with(1, dst_scalar)?;
+ let (accept_id, reject_id) = match src_size {
+ Some(size) => {
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ scalar: dst_scalar,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(size as _, zero_scalar_id);
+
+ let vec0_id =
+ self.writer.get_constant_composite(ty, &self.temp_list);
+
+ self.temp_list.fill(one_scalar_id);
+
+ let vec1_id =
+ self.writer.get_constant_composite(ty, &self.temp_list);
+
+ (vec1_id, vec0_id)
+ }
+ None => (one_scalar_id, zero_scalar_id),
+ };
+
+ Cast::Ternary(spirv::Op::Select, accept_id, reject_id)
+ }
+ (Sk::Float, Sk::Uint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToU),
+ (Sk::Float, Sk::Sint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToS),
+ (Sk::Float, Sk::Float, Some(dst_width))
+ if src_scalar.width != dst_width =>
+ {
+ Cast::Unary(spirv::Op::FConvert)
+ }
+ (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF),
+ (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
+ Cast::Unary(spirv::Op::SConvert)
+ }
+ (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF),
+ (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
+ Cast::Unary(spirv::Op::UConvert)
+ }
+ // We assume it's either an identity cast, or int-uint.
+ _ => Cast::Unary(spirv::Op::Bitcast),
+ }
+ };
+
+ let id = self.gen_id();
+ let instruction = match cast {
+ Cast::Identity => None,
+ Cast::Unary(op) => Some(Instruction::unary(op, result_type_id, id, expr_id)),
+ Cast::Binary(op, operand) => Some(Instruction::binary(
+ op,
+ result_type_id,
+ id,
+ expr_id,
+ operand,
+ )),
+ Cast::Ternary(op, op1, op2) => Some(Instruction::ternary(
+ op,
+ result_type_id,
+ id,
+ expr_id,
+ op1,
+ op2,
+ )),
+ };
+ if let Some(instruction) = instruction {
+ block.body.push(instruction);
+ id
+ } else {
+ expr_id
+ }
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => self.write_image_load(
+ result_type_id,
+ image,
+ coordinate,
+ array_index,
+ level,
+ sample,
+ block,
+ )?,
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => self.write_image_sample(
+ result_type_id,
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ block,
+ )?,
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ let id = self.gen_id();
+ let mut condition_id = self.cached[condition];
+ let accept_id = self.cached[accept];
+ let reject_id = self.cached[reject];
+
+ let condition_ty = self.fun_info[condition]
+ .ty
+ .inner_with(&self.ir_module.types);
+ let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
+
+ if let (
+ &crate::TypeInner::Scalar(
+ condition_scalar @ crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ },
+ ),
+ &crate::TypeInner::Vector { size, .. },
+ ) = (condition_ty, object_ty)
+ {
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, condition_id);
+
+ let bool_vector_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(size),
+ scalar: condition_scalar,
+ pointer_space: None,
+ }));
+
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ bool_vector_type_id,
+ id,
+ &self.temp_list,
+ ));
+ condition_id = id
+ }
+
+ let instruction =
+ Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
+ block.body.push(instruction);
+ id
+ }
+ crate::Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ match ctrl {
+ Ctrl::Coarse | Ctrl::Fine => {
+ self.writer.require_any(
+ "DerivativeControl",
+ &[spirv::Capability::DerivativeControl],
+ )?;
+ }
+ Ctrl::None => {}
+ }
+ let id = self.gen_id();
+ let expr_id = self.cached[expr];
+ let op = match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
+ (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
+ (Axis::X, Ctrl::None) => spirv::Op::DPdx,
+ (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
+ (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
+ (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
+ (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
+ (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
+ (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
+ };
+ block
+ .body
+ .push(Instruction::derivative(op, result_type_id, id, expr_id));
+ id
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ self.write_image_query(result_type_id, image, query, block)?
+ }
+ crate::Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+ let arg_id = self.cached[argument];
+ let op = match fun {
+ Rf::All => spirv::Op::All,
+ Rf::Any => spirv::Op::Any,
+ Rf::IsNan => spirv::Op::IsNan,
+ Rf::IsInf => spirv::Op::IsInf,
+ };
+ let id = self.gen_id();
+ block
+ .body
+ .push(Instruction::relational(op, result_type_id, id, arg_id));
+ id
+ }
+ crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
+ crate::Expression::RayQueryGetIntersection { query, committed } => {
+ if !committed {
+ return Err(Error::FeatureNotImplemented("candidate intersection"));
+ }
+ self.write_ray_query_get_intersection(query, block)
+ }
+ };
+
+ self.cached[expr_handle] = id;
+ Ok(())
+ }
+
+ /// Build an `OpAccessChain` instruction.
+ ///
+ /// Emit any needed bounds-checking expressions to `block`.
+ ///
+ /// Some cases we need to generate a different return type than what the IR gives us.
+ /// This is because pointers to binding arrays of handles (such as images or samplers)
+ /// don't exist in the IR, but we need to create them to create an access chain in SPIRV.
+ ///
+ /// On success, the return value is an [`ExpressionPointer`] value; see the
+ /// documentation for that type.
+ fn write_expression_pointer(
+ &mut self,
+ mut expr_handle: Handle<crate::Expression>,
+ block: &mut Block,
+ return_type_override: Option<LookupType>,
+ ) -> Result<ExpressionPointer, Error> {
+ let result_lookup_ty = match self.fun_info[expr_handle].ty {
+ TypeResolution::Handle(ty_handle) => match return_type_override {
+ // We use the return type override as a special case for handle binding arrays as the OpAccessChain
+ // needs to return a pointer, but indexing into a handle binding array just gives you the type of
+ // the binding in the IR.
+ Some(ty) => ty,
+ None => LookupType::Handle(ty_handle),
+ },
+ TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()),
+ };
+ let result_type_id = self.get_type_id(result_lookup_ty);
+
+ // The id of the boolean `and` of all dynamic bounds checks up to this point. If
+ // `None`, then we haven't done any dynamic bounds checks yet.
+ //
+ // When we have a chain of bounds checks, we combine them with `OpLogicalAnd`, not
+ // a short-circuit branch. This means we might do comparisons we don't need to,
+ // but we expect these checks to almost always succeed, and keeping branches to a
+ // minimum is essential.
+ let mut accumulated_checks = None;
+ // Is true if we are accessing into a binding array with a non-uniform index.
+ let mut is_non_uniform_binding_array = false;
+
+ self.temp_list.clear();
+ let root_id = loop {
+ expr_handle = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::Access { base, index } => {
+ if let crate::Expression::GlobalVariable(var_handle) =
+ self.ir_function.expressions[base]
+ {
+ // The access chain needs to be decorated as NonUniform
+ // see VUID-RuntimeSpirv-NonUniform-06274
+ let gvar = &self.ir_module.global_variables[var_handle];
+ if let crate::TypeInner::BindingArray { .. } =
+ self.ir_module.types[gvar.ty].inner
+ {
+ is_non_uniform_binding_array =
+ self.fun_info[index].uniformity.non_uniform_result.is_some();
+ }
+ }
+
+ let index_id = match self.write_bounds_check(base, index, block)? {
+ BoundsCheckResult::KnownInBounds(known_index) => {
+ // Even if the index is known, `OpAccessIndex`
+ // requires expression operands, not literals.
+ let scalar = crate::Literal::U32(known_index);
+ self.writer.get_constant_scalar(scalar)
+ }
+ BoundsCheckResult::Computed(computed_index_id) => computed_index_id,
+ BoundsCheckResult::Conditional(comparison_id) => {
+ match accumulated_checks {
+ Some(prior_checks) => {
+ let combined = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::LogicalAnd,
+ self.writer.get_bool_type_id(),
+ combined,
+ prior_checks,
+ comparison_id,
+ ));
+ accumulated_checks = Some(combined);
+ }
+ None => {
+ // Start a fresh chain of checks.
+ accumulated_checks = Some(comparison_id);
+ }
+ }
+
+ // Either way, the index to use is unchanged.
+ self.cached[index]
+ }
+ };
+ self.temp_list.push(index_id);
+ base
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ let const_id = self.get_index_constant(index);
+ self.temp_list.push(const_id);
+ base
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let gv = &self.writer.global_variables[handle.index()];
+ break gv.access_id;
+ }
+ crate::Expression::LocalVariable(variable) => {
+ let local_var = &self.function.variables[&variable];
+ break local_var.id;
+ }
+ crate::Expression::FunctionArgument(index) => {
+ break self.function.parameter_id(index);
+ }
+ ref other => unimplemented!("Unexpected pointer expression {:?}", other),
+ }
+ };
+
+ let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
+ (
+ root_id,
+ ExpressionPointer::Ready {
+ pointer_id: root_id,
+ },
+ )
+ } else {
+ self.temp_list.reverse();
+ let pointer_id = self.gen_id();
+ let access =
+ Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);
+
+ // If we generated some bounds checks, we need to leave it to our
+ // caller to generate the branch, the access, the load or store, and
+ // the zero value (for loads). Otherwise, we can emit the access
+ // ourselves, and just hand them the id of the pointer.
+ let expr_pointer = match accumulated_checks {
+ Some(condition) => ExpressionPointer::Conditional { condition, access },
+ None => {
+ block.body.push(access);
+ ExpressionPointer::Ready { pointer_id }
+ }
+ };
+ (pointer_id, expr_pointer)
+ };
+ // Subsequent load, store and atomic operations require the pointer to be decorated as NonUniform
+ // if the binding array was accessed with a non-uniform index
+ // see VUID-RuntimeSpirv-NonUniform-06274
+ if is_non_uniform_binding_array {
+ self.writer
+ .decorate_non_uniform_binding_array_access(pointer_id)?;
+ }
+
+ Ok(expr_pointer)
+ }
+
+ /// Build the instructions for matrix - matrix column operations
+ #[allow(clippy::too_many_arguments)]
+ fn write_matrix_matrix_column_op(
+ &mut self,
+ block: &mut Block,
+ result_id: Word,
+ result_type_id: Word,
+ left_id: Word,
+ right_id: Word,
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: u8,
+ op: spirv::Op,
+ ) {
+ self.temp_list.clear();
+
+ let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(rows),
+ scalar: crate::Scalar::float(width),
+ pointer_space: None,
+ }));
+
+ for index in 0..columns as u32 {
+ let column_id_left = self.gen_id();
+ let column_id_right = self.gen_id();
+ let column_id_res = self.gen_id();
+
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ column_id_left,
+ left_id,
+ &[index],
+ ));
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ column_id_right,
+ right_id,
+ &[index],
+ ));
+ block.body.push(Instruction::binary(
+ op,
+ vector_type_id,
+ column_id_res,
+ column_id_left,
+ column_id_right,
+ ));
+
+ self.temp_list.push(column_id_res);
+ }
+
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ result_id,
+ &self.temp_list,
+ ));
+ }
+
+ /// Build the instructions for vector - scalar multiplication
+ fn write_vector_scalar_mult(
+ &mut self,
+ block: &mut Block,
+ result_id: Word,
+ result_type_id: Word,
+ vector_id: Word,
+ scalar_id: Word,
+ vector: &crate::TypeInner,
+ ) {
+ let (size, kind) = match *vector {
+ crate::TypeInner::Vector {
+ size,
+ scalar: crate::Scalar { kind, .. },
+ } => (size, kind),
+ _ => unreachable!(),
+ };
+
+ let (op, operand_id) = match kind {
+ crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
+ _ => {
+ let operand_id = self.gen_id();
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, scalar_id);
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ operand_id,
+ &self.temp_list,
+ ));
+ (spirv::Op::IMul, operand_id)
+ }
+ };
+
+ block.body.push(Instruction::binary(
+ op,
+ result_type_id,
+ result_id,
+ vector_id,
+ operand_id,
+ ));
+ }
+
+ /// Build the instructions for the arithmetic expression of a dot product
+ fn write_dot_product(
+ &mut self,
+ result_id: Word,
+ result_type_id: Word,
+ arg0_id: Word,
+ arg1_id: Word,
+ size: u32,
+ block: &mut Block,
+ ) {
+ let mut partial_sum = self.writer.get_constant_null(result_type_id);
+ let last_component = size - 1;
+ for index in 0..=last_component {
+ // compute the product of the current components
+ let a_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ a_id,
+ arg0_id,
+ &[index],
+ ));
+ let b_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ b_id,
+ arg1_id,
+ &[index],
+ ));
+ let prod_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::IMul,
+ result_type_id,
+ prod_id,
+ a_id,
+ b_id,
+ ));
+
+ // choose the id for the next sum, depending on current index
+ let id = if index == last_component {
+ result_id
+ } else {
+ self.gen_id()
+ };
+
+ // sum the computed product with the partial sum
+ block.body.push(Instruction::binary(
+ spirv::Op::IAdd,
+ result_type_id,
+ id,
+ partial_sum,
+ prod_id,
+ ));
+ // set the id of the result as the previous partial sum
+ partial_sum = id;
+ }
+ }
+
+ pub(super) fn write_block(
+ &mut self,
+ label_id: Word,
+ naga_block: &crate::Block,
+ exit: BlockExit,
+ loop_context: LoopContext,
+ debug_info: Option<&DebugInfoInner>,
+ ) -> Result<(), Error> {
+ let mut block = Block::new(label_id);
+ for (statement, span) in naga_block.span_iter() {
+ if let (Some(debug_info), false) = (
+ debug_info,
+ matches!(
+ statement,
+ &(Statement::Block(..)
+ | Statement::Break
+ | Statement::Continue
+ | Statement::Kill
+ | Statement::Return { .. }
+ | Statement::Loop { .. })
+ ),
+ ) {
+ let loc: crate::SourceLocation = span.location(debug_info.source_code);
+ block.body.push(Instruction::line(
+ debug_info.source_file_id,
+ loc.line_number,
+ loc.line_position,
+ ));
+ };
+ match *statement {
+ crate::Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ // omit const expressions as we've already cached those
+ if !self.expression_constness.is_const(handle) {
+ self.cache_expression_value(handle, &mut block)?;
+ }
+ }
+ }
+ crate::Statement::Block(ref block_statements) => {
+ let scope_id = self.gen_id();
+ self.function.consume(block, Instruction::branch(scope_id));
+
+ let merge_id = self.gen_id();
+ self.write_block(
+ scope_id,
+ block_statements,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ debug_info,
+ )?;
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let condition_id = self.cached[condition];
+
+ let merge_id = self.gen_id();
+ block.body.push(Instruction::selection_merge(
+ merge_id,
+ spirv::SelectionControl::NONE,
+ ));
+
+ let accept_id = if accept.is_empty() {
+ None
+ } else {
+ Some(self.gen_id())
+ };
+ let reject_id = if reject.is_empty() {
+ None
+ } else {
+ Some(self.gen_id())
+ };
+
+ self.function.consume(
+ block,
+ Instruction::branch_conditional(
+ condition_id,
+ accept_id.unwrap_or(merge_id),
+ reject_id.unwrap_or(merge_id),
+ ),
+ );
+
+ if let Some(block_id) = accept_id {
+ self.write_block(
+ block_id,
+ accept,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ debug_info,
+ )?;
+ }
+ if let Some(block_id) = reject_id {
+ self.write_block(
+ block_id,
+ reject,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ debug_info,
+ )?;
+ }
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ let selector_id = self.cached[selector];
+
+ let merge_id = self.gen_id();
+ block.body.push(Instruction::selection_merge(
+ merge_id,
+ spirv::SelectionControl::NONE,
+ ));
+
+ let mut default_id = None;
+ // id of previous empty fall-through case
+ let mut last_id = None;
+
+ let mut raw_cases = Vec::with_capacity(cases.len());
+ let mut case_ids = Vec::with_capacity(cases.len());
+ for case in cases.iter() {
+ // take id of previous empty fall-through case or generate a new one
+ let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
+
+ if case.fall_through && case.body.is_empty() {
+ last_id = Some(label_id);
+ }
+
+ case_ids.push(label_id);
+
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ raw_cases.push(super::instructions::Case {
+ value: value as Word,
+ label_id,
+ });
+ }
+ crate::SwitchValue::U32(value) => {
+ raw_cases.push(super::instructions::Case { value, label_id });
+ }
+ crate::SwitchValue::Default => {
+ default_id = Some(label_id);
+ }
+ }
+ }
+
+ let default_id = default_id.unwrap();
+
+ self.function.consume(
+ block,
+ Instruction::switch(selector_id, default_id, &raw_cases),
+ );
+
+ let inner_context = LoopContext {
+ break_id: Some(merge_id),
+ ..loop_context
+ };
+
+ for (i, (case, label_id)) in cases
+ .iter()
+ .zip(case_ids.iter())
+ .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
+ .enumerate()
+ {
+ let case_finish_id = if case.fall_through {
+ case_ids[i + 1]
+ } else {
+ merge_id
+ };
+ self.write_block(
+ *label_id,
+ &case.body,
+ BlockExit::Branch {
+ target: case_finish_id,
+ },
+ inner_context,
+ debug_info,
+ )?;
+ }
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let preamble_id = self.gen_id();
+ self.function
+ .consume(block, Instruction::branch(preamble_id));
+
+ let merge_id = self.gen_id();
+ let body_id = self.gen_id();
+ let continuing_id = self.gen_id();
+
+ // SPIR-V requires the continuing to the `OpLoopMerge`,
+ // so we have to start a new block with it.
+ block = Block::new(preamble_id);
+ // HACK the loop statement is begin with branch instruction,
+ // so we need to put `OpLine` debug info before merge instruction
+ if let Some(debug_info) = debug_info {
+ let loc: crate::SourceLocation = span.location(debug_info.source_code);
+ block.body.push(Instruction::line(
+ debug_info.source_file_id,
+ loc.line_number,
+ loc.line_position,
+ ))
+ }
+ block.body.push(Instruction::loop_merge(
+ merge_id,
+ continuing_id,
+ spirv::SelectionControl::NONE,
+ ));
+ self.function.consume(block, Instruction::branch(body_id));
+
+ self.write_block(
+ body_id,
+ body,
+ BlockExit::Branch {
+ target: continuing_id,
+ },
+ LoopContext {
+ continuing_id: Some(continuing_id),
+ break_id: Some(merge_id),
+ },
+ debug_info,
+ )?;
+
+ let exit = match break_if {
+ Some(condition) => BlockExit::BreakIf {
+ condition,
+ preamble_id,
+ },
+ None => BlockExit::Branch {
+ target: preamble_id,
+ },
+ };
+
+ self.write_block(
+ continuing_id,
+ continuing,
+ exit,
+ LoopContext {
+ continuing_id: None,
+ break_id: Some(merge_id),
+ },
+ debug_info,
+ )?;
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Break => {
+ self.function
+ .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
+ return Ok(());
+ }
+ crate::Statement::Continue => {
+ self.function.consume(
+ block,
+ Instruction::branch(loop_context.continuing_id.unwrap()),
+ );
+ return Ok(());
+ }
+ crate::Statement::Return { value: Some(value) } => {
+ let value_id = self.cached[value];
+ let instruction = match self.function.entry_point_context {
+ // If this is an entry point, and we need to return anything,
+ // let's instead store the output variables and return `void`.
+ Some(ref context) => {
+ self.writer.write_entry_point_return(
+ value_id,
+ self.ir_function.result.as_ref().unwrap(),
+ &context.results,
+ &mut block.body,
+ )?;
+ Instruction::return_void()
+ }
+ None => Instruction::return_value(value_id),
+ };
+ self.function.consume(block, instruction);
+ return Ok(());
+ }
+ crate::Statement::Return { value: None } => {
+ self.function.consume(block, Instruction::return_void());
+ return Ok(());
+ }
+ crate::Statement::Kill => {
+ self.function.consume(block, Instruction::kill());
+ return Ok(());
+ }
+ crate::Statement::Barrier(flags) => {
+ self.writer.write_barrier(flags, &mut block);
+ }
+ crate::Statement::Store { pointer, value } => {
+ let value_id = self.cached[value];
+ match self.write_expression_pointer(pointer, &mut block, None)? {
+ ExpressionPointer::Ready { pointer_id } => {
+ let atomic_space = match *self.fun_info[pointer]
+ .ty
+ .inner_with(&self.ir_module.types)
+ {
+ crate::TypeInner::Pointer { base, space } => {
+ match self.ir_module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Some(space),
+ _ => None,
+ }
+ }
+ _ => None,
+ };
+ let instruction = if let Some(space) = atomic_space {
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ Instruction::atomic_store(
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ } else {
+ Instruction::store(pointer_id, value_id, None)
+ };
+ block.body.push(instruction);
+ }
+ ExpressionPointer::Conditional { condition, access } => {
+ let mut selection = Selection::start(&mut block, ());
+ selection.if_true(self, condition, ());
+
+ // The in-bounds path. Perform the access and the store.
+ let pointer_id = access.result_id.unwrap();
+ selection.block().body.push(access);
+ selection
+ .block()
+ .body
+ .push(Instruction::store(pointer_id, value_id, None));
+
+ // Finish the in-bounds block and start the merge block. This
+ // is the block we'll leave current on return.
+ selection.finish(self, ());
+ }
+ };
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
+ crate::Statement::Call {
+ function: local_function,
+ ref arguments,
+ result,
+ } => {
+ let id = self.gen_id();
+ self.temp_list.clear();
+ for &argument in arguments {
+ self.temp_list.push(self.cached[argument]);
+ }
+
+ let type_id = match result {
+ Some(expr) => {
+ self.cached[expr] = id;
+ self.get_expression_type_id(&self.fun_info[expr].ty)
+ }
+ None => self.writer.void_type,
+ };
+
+ block.body.push(Instruction::function_call(
+ type_id,
+ id,
+ self.writer.lookup_function[&local_function],
+ &self.temp_list,
+ ));
+ }
+ crate::Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ let id = self.gen_id();
+ let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
+
+ self.cached[result] = id;
+
+ let pointer_id =
+ match self.write_expression_pointer(pointer, &mut block, None)? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Atomics out-of-bounds handling",
+ ));
+ }
+ };
+
+ let space = self.fun_info[pointer]
+ .ty
+ .inner_with(&self.ir_module.types)
+ .pointer_space()
+ .unwrap();
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ let value_id = self.cached[value];
+ let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
+
+ let instruction = match *fun {
+ crate::AtomicFunction::Add => Instruction::atomic_binary(
+ spirv::Op::AtomicIAdd,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::Subtract => Instruction::atomic_binary(
+ spirv::Op::AtomicISub,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::And => Instruction::atomic_binary(
+ spirv::Op::AtomicAnd,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary(
+ spirv::Op::AtomicOr,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary(
+ spirv::Op::AtomicXor,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::Min => {
+ let spirv_op = match *value_inner {
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ }) => spirv::Op::AtomicSMin,
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ }) => spirv::Op::AtomicUMin,
+ _ => unimplemented!(),
+ };
+ Instruction::atomic_binary(
+ spirv_op,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Max => {
+ let spirv_op = match *value_inner {
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ }) => spirv::Op::AtomicSMax,
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ }) => spirv::Op::AtomicUMax,
+ _ => unimplemented!(),
+ };
+ Instruction::atomic_binary(
+ spirv_op,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Exchange { compare: None } => {
+ Instruction::atomic_binary(
+ spirv::Op::AtomicExchange,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
+ let scalar_type_id = match *value_inner {
+ crate::TypeInner::Scalar(scalar) => {
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar,
+ pointer_space: None,
+ }))
+ }
+ _ => unimplemented!(),
+ };
+ let bool_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::BOOL,
+ pointer_space: None,
+ }));
+
+ let cas_result_id = self.gen_id();
+ let equality_result_id = self.gen_id();
+ let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
+ cas_instr.set_type(scalar_type_id);
+ cas_instr.set_result(cas_result_id);
+ cas_instr.add_operand(pointer_id);
+ cas_instr.add_operand(scope_constant_id);
+ cas_instr.add_operand(semantics_id); // semantics if equal
+ cas_instr.add_operand(semantics_id); // semantics if not equal
+ cas_instr.add_operand(value_id);
+ cas_instr.add_operand(self.cached[cmp]);
+ block.body.push(cas_instr);
+ block.body.push(Instruction::binary(
+ spirv::Op::IEqual,
+ bool_type_id,
+ equality_result_id,
+ cas_result_id,
+ self.cached[cmp],
+ ));
+ Instruction::composite_construct(
+ result_type_id,
+ id,
+ &[cas_result_id, equality_result_id],
+ )
+ }
+ };
+
+ block.body.push(instruction);
+ }
+ crate::Statement::WorkGroupUniformLoad { pointer, result } => {
+ self.writer
+ .write_barrier(crate::Barrier::WORK_GROUP, &mut block);
+ let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
+ // Embed the body of
+ match self.write_expression_pointer(pointer, &mut block, None)? {
+ ExpressionPointer::Ready { pointer_id } => {
+ let id = self.gen_id();
+ block.body.push(Instruction::load(
+ result_type_id,
+ id,
+ pointer_id,
+ None,
+ ));
+ self.cached[result] = id;
+ }
+ ExpressionPointer::Conditional { condition, access } => {
+ self.cached[result] = self.write_conditional_indexed_load(
+ result_type_id,
+ condition,
+ &mut block,
+ move |id_gen, block| {
+ // The in-bounds path. Perform the access and the load.
+ let pointer_id = access.result_id.unwrap();
+ let value_id = id_gen.next();
+ block.body.push(access);
+ block.body.push(Instruction::load(
+ result_type_id,
+ value_id,
+ pointer_id,
+ None,
+ ));
+ value_id
+ },
+ )
+ }
+ }
+ self.writer
+ .write_barrier(crate::Barrier::WORK_GROUP, &mut block);
+ }
+ crate::Statement::RayQuery { query, ref fun } => {
+ self.write_ray_query_function(query, fun, &mut block);
+ }
+ }
+ }
+
+ let termination = match exit {
+ // We're generating code for the top-level Block of the function, so we
+ // need to end it with some kind of return instruction.
+ BlockExit::Return => match self.ir_function.result {
+ Some(ref result) if self.function.entry_point_context.is_none() => {
+ let type_id = self.get_type_id(LookupType::Handle(result.ty));
+ let null_id = self.writer.get_constant_null(type_id);
+ Instruction::return_value(null_id)
+ }
+ _ => Instruction::return_void(),
+ },
+ BlockExit::Branch { target } => Instruction::branch(target),
+ BlockExit::BreakIf {
+ condition,
+ preamble_id,
+ } => {
+ let condition_id = self.cached[condition];
+
+ Instruction::branch_conditional(
+ condition_id,
+ loop_context.break_id.unwrap(),
+ preamble_id,
+ )
+ }
+ };
+
+ self.function.consume(block, termination);
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs
new file mode 100644
index 0000000000..5b6226db85
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/helpers.rs
@@ -0,0 +1,109 @@
+use crate::{Handle, UniqueArena};
+use spirv::Word;
+
+pub(super) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> {
+ bytes
+ .chunks(4)
+ .map(|chars| chars.iter().rev().fold(0u32, |u, c| (u << 8) | *c as u32))
+ .collect()
+}
+
+pub(super) fn string_to_words(input: &str) -> Vec<Word> {
+ let bytes = input.as_bytes();
+ let mut words = bytes_to_words(bytes);
+
+ if bytes.len() % 4 == 0 {
+ // nul-termination
+ words.push(0x0u32);
+ }
+
+ words
+}
+
+pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::StorageClass {
+ match space {
+ crate::AddressSpace::Handle => spirv::StorageClass::UniformConstant,
+ crate::AddressSpace::Function => spirv::StorageClass::Function,
+ crate::AddressSpace::Private => spirv::StorageClass::Private,
+ crate::AddressSpace::Storage { .. } => spirv::StorageClass::StorageBuffer,
+ crate::AddressSpace::Uniform => spirv::StorageClass::Uniform,
+ crate::AddressSpace::WorkGroup => spirv::StorageClass::Workgroup,
+ crate::AddressSpace::PushConstant => spirv::StorageClass::PushConstant,
+ }
+}
+
+pub(super) fn contains_builtin(
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ arena: &UniqueArena<crate::Type>,
+ built_in: crate::BuiltIn,
+) -> bool {
+ if let Some(&crate::Binding::BuiltIn(bi)) = binding {
+ bi == built_in
+ } else if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner {
+ members
+ .iter()
+ .any(|member| contains_builtin(member.binding.as_ref(), member.ty, arena, built_in))
+ } else {
+ false // unreachable
+ }
+}
+
+impl crate::AddressSpace {
+ pub(super) const fn to_spirv_semantics_and_scope(
+ self,
+ ) -> (spirv::MemorySemantics, spirv::Scope) {
+ match self {
+ Self::Storage { .. } => (spirv::MemorySemantics::UNIFORM_MEMORY, spirv::Scope::Device),
+ Self::WorkGroup => (
+ spirv::MemorySemantics::WORKGROUP_MEMORY,
+ spirv::Scope::Workgroup,
+ ),
+ _ => (spirv::MemorySemantics::empty(), spirv::Scope::Invocation),
+ }
+ }
+}
+
+/// Return true if the global requires a type decorated with `Block`.
+///
+/// Vulkan spec v1.3 §15.6.2, "Descriptor Set Interface", says:
+///
+/// > Variables identified with the `Uniform` storage class are used to
+/// > access transparent buffer backed resources. Such variables must
+/// > be:
+/// >
+/// > - typed as `OpTypeStruct`, or an array of this type,
+/// >
+/// > - identified with a `Block` or `BufferBlock` decoration, and
+/// >
+/// > - laid out explicitly using the `Offset`, `ArrayStride`, and
+/// > `MatrixStride` decorations as specified in §15.6.4, "Offset
+/// > and Stride Assignment."
+// See `back::spv::GlobalVariable::access_id` for details.
+pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariable) -> bool {
+ match var.space {
+ crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::PushConstant => {}
+ _ => return false,
+ };
+ match ir_module.types[var.ty].inner {
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => match members.last() {
+ Some(member) => match ir_module.types[member.ty].inner {
+ // Structs with dynamically sized arrays can't be copied and can't be wrapped.
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => false,
+ _ => true,
+ },
+ None => false,
+ },
+ crate::TypeInner::BindingArray { .. } => false,
+ // if it's not a structure or a binding array, let's wrap it to be able to put "Block"
+ _ => true,
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/image.rs b/third_party/rust/naga/src/back/spv/image.rs
new file mode 100644
index 0000000000..c0fc41cbb6
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/image.rs
@@ -0,0 +1,1210 @@
+/*!
+Generating SPIR-V for image operations.
+*/
+
+use super::{
+ selection::{MergeTuple, Selection},
+ Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType,
+};
+use crate::arena::Handle;
+use spirv::Word;
+
+/// Information about a vector of coordinates.
+///
+/// The coordinate vectors expected by SPIR-V `OpImageRead` and `OpImageFetch`
+/// supply the array index for arrayed images as an additional component at
+/// the end, whereas Naga's `ImageLoad`, `ImageStore`, and `ImageSample` carry
+/// the array index as a separate field.
+///
+/// In the process of generating code to compute the combined vector, we also
+/// produce SPIR-V types and vector lengths that are useful elsewhere. This
+/// struct gathers that information into one place, with standard names.
+struct ImageCoordinates {
+ /// The SPIR-V id of the combined coordinate/index vector value.
+ ///
+ /// Note: when indexing a non-arrayed 1D image, this will be a scalar.
+ value_id: Word,
+
+ /// The SPIR-V id of the type of `value`.
+ type_id: Word,
+
+ /// The number of components in `value`, if it is a vector, or `None` if it
+ /// is a scalar.
+ size: Option<crate::VectorSize>,
+}
+
+/// A trait for image access (load or store) code generators.
+///
+/// Types implementing this trait hold information about an `ImageStore` or
+/// `ImageLoad` operation that is not affected by the bounds check policy. The
+/// `generate` method emits code for the access, given the results of bounds
+/// checking.
+///
+/// The [`image`] bounds checks policy affects access coordinates, level of
+/// detail, and sample index, but never the image id, result type (if any), or
+/// the specific SPIR-V instruction used. Types that implement this trait gather
+/// together the latter category, so we don't have to plumb them through the
+/// bounds-checking code.
+///
+/// [`image`]: crate::proc::BoundsCheckPolicies::index
+trait Access {
+ /// The Rust type that represents SPIR-V values and types for this access.
+ ///
+ /// For operations like loads, this is `Word`. For operations like stores,
+ /// this is `()`.
+ ///
+ /// For `ReadZeroSkipWrite`, this will be the type of the selection
+ /// construct that performs the bounds checks, so it must implement
+ /// `MergeTuple`.
+ type Output: MergeTuple + Copy + Clone;
+
+ /// Write an image access to `block`.
+ ///
+ /// Access the texel at `coordinates_id`. The optional `level_id` indicates
+ /// the level of detail, and `sample_id` is the index of the sample to
+ /// access in a multisampled texel.
+ ///
+ /// This method assumes that `coordinates_id` has already had the image array
+ /// index, if any, folded in, as done by `write_image_coordinates`.
+ ///
+ /// Return the value id produced by the instruction, if any.
+ ///
+ /// Use `id_gen` to generate SPIR-V ids as necessary.
+ fn generate(
+ &self,
+ id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Self::Output;
+
+ /// Return the SPIR-V type of the value produced by the code written by
+ /// `generate`. If the access does not produce a value, `Self::Output`
+ /// should be `()`.
+ fn result_type(&self) -> Self::Output;
+
+ /// Construct the SPIR-V 'zero' value to be returned for an out-of-bounds
+ /// access under the `ReadZeroSkipWrite` policy. If the access does not
+ /// produce a value, `Self::Output` should be `()`.
+ fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Self::Output;
+}
+
+/// Texel access information for an [`ImageLoad`] expression.
+///
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+struct Load {
+ /// The specific opcode we'll use to perform the fetch. Storage images
+ /// require `OpImageRead`, while sampled images require `OpImageFetch`.
+ opcode: spirv::Op,
+
+ /// The type id produced by the actual image access instruction.
+ type_id: Word,
+
+ /// The id of the image being accessed.
+ image_id: Word,
+}
+
+impl Load {
+ fn from_image_expr(
+ ctx: &mut BlockContext<'_>,
+ image_id: Word,
+ image_class: crate::ImageClass,
+ result_type_id: Word,
+ ) -> Result<Load, Error> {
+ let opcode = match image_class {
+ crate::ImageClass::Storage { .. } => spirv::Op::ImageRead,
+ crate::ImageClass::Depth { .. } | crate::ImageClass::Sampled { .. } => {
+ spirv::Op::ImageFetch
+ }
+ };
+
+ // `OpImageRead` and `OpImageFetch` instructions produce vec4<f32>
+ // values. Most of the time, we can just use `result_type_id` for
+ // this. The exception is that `Expression::ImageLoad` from a depth
+ // image produces a scalar `f32`, so in that case we need to find
+ // the right SPIR-V type for the access instruction here.
+ let type_id = match image_class {
+ crate::ImageClass::Depth { .. } => {
+ ctx.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Quad),
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }))
+ }
+ _ => result_type_id,
+ };
+
+ Ok(Load {
+ opcode,
+ type_id,
+ image_id,
+ })
+ }
+}
+
+impl Access for Load {
+ type Output = Word;
+
+ /// Write an instruction to access a given texel of this image.
+ fn generate(
+ &self,
+ id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Word {
+ let texel_id = id_gen.next();
+ let mut instruction = Instruction::image_fetch_or_read(
+ self.opcode,
+ self.type_id,
+ texel_id,
+ self.image_id,
+ coordinates_id,
+ );
+
+ match (level_id, sample_id) {
+ (None, None) => {}
+ (Some(level_id), None) => {
+ instruction.add_operand(spirv::ImageOperands::LOD.bits());
+ instruction.add_operand(level_id);
+ }
+ (None, Some(sample_id)) => {
+ instruction.add_operand(spirv::ImageOperands::SAMPLE.bits());
+ instruction.add_operand(sample_id);
+ }
+ // There's no such thing as a multi-sampled mipmap.
+ (Some(_), Some(_)) => unreachable!(),
+ }
+
+ block.body.push(instruction);
+
+ texel_id
+ }
+
+ fn result_type(&self) -> Word {
+ self.type_id
+ }
+
+ fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Word {
+ ctx.writer.get_constant_null(self.type_id)
+ }
+}
+
+/// Texel access information for a [`Store`] statement.
+///
+/// [`Store`]: crate::Statement::Store
+struct Store {
+ /// The id of the image being written to.
+ image_id: Word,
+
+ /// The value we're going to write to the texel.
+ value_id: Word,
+}
+
+impl Access for Store {
+ /// Stores don't generate any value.
+ type Output = ();
+
+ fn generate(
+ &self,
+ _id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ _level_id: Option<Word>,
+ _sample_id: Option<Word>,
+ block: &mut Block,
+ ) {
+ block.body.push(Instruction::image_write(
+ self.image_id,
+ coordinates_id,
+ self.value_id,
+ ));
+ }
+
+ /// Stores don't generate any value, so this just returns `()`.
+ fn result_type(&self) {}
+
+ /// Stores don't generate any value, so this just returns `()`.
+ fn out_of_bounds_value(&self, _ctx: &mut BlockContext<'_>) {}
+}
+
+impl<'w> BlockContext<'w> {
+ /// Extend image coordinates with an array index, if necessary.
+ ///
+ /// Whereas [`Expression::ImageLoad`] and [`ImageSample`] treat the array
+ /// index as a separate operand from the coordinates, SPIR-V image access
+ /// instructions include the array index in the `coordinates` operand. This
+ /// function builds a SPIR-V coordinate vector from a Naga coordinate vector
+ /// and array index, if one is supplied, and returns a `ImageCoordinates`
+ /// struct describing what it built.
+ ///
+ /// If `array_index` is `Some(expr)`, then this function constructs a new
+ /// vector that is `coordinates` with `array_index` concatenated onto the
+ /// end: a `vec2` becomes a `vec3`, a scalar becomes a `vec2`, and so on.
+ ///
+ /// If `array_index` is `None`, then the return value uses `coordinates`
+ /// unchanged. Note that, when indexing a non-arrayed 1D image, this will be
+ /// a scalar value.
+ ///
+ /// If needed, this function generates code to convert the array index,
+ /// always an integer scalar, to match the component type of `coordinates`.
+ /// Naga's `ImageLoad` and SPIR-V's `OpImageRead`, `OpImageFetch`, and
+ /// `OpImageWrite` all use integer coordinates, while Naga's `ImageSample`
+ /// and SPIR-V's `OpImageSample...` instructions all take floating-point
+ /// coordinate vectors.
+ ///
+ /// [`Expression::ImageLoad`]: crate::Expression::ImageLoad
+ /// [`ImageSample`]: crate::Expression::ImageSample
+ fn write_image_coordinates(
+ &mut self,
+ coordinates: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<ImageCoordinates, Error> {
+ use crate::TypeInner as Ti;
+ use crate::VectorSize as Vs;
+
+ let coordinates_id = self.cached[coordinates];
+ let ty = &self.fun_info[coordinates].ty;
+ let inner_ty = ty.inner_with(&self.ir_module.types);
+
+ // If there's no array index, the image coordinates are exactly the
+ // `coordinate` field of the `Expression::ImageLoad`. No work is needed.
+ let array_index = match array_index {
+ None => {
+ let value_id = coordinates_id;
+ let type_id = self.get_expression_type_id(ty);
+ let size = match *inner_ty {
+ Ti::Scalar { .. } => None,
+ Ti::Vector { size, .. } => Some(size),
+ _ => return Err(Error::Validation("coordinate type")),
+ };
+ return Ok(ImageCoordinates {
+ value_id,
+ type_id,
+ size,
+ });
+ }
+ Some(ix) => ix,
+ };
+
+ // Find the component type of `coordinates`, and figure out the size the
+ // combined coordinate vector will have.
+ let (component_scalar, size) = match *inner_ty {
+ Ti::Scalar(scalar @ crate::Scalar { width: 4, .. }) => (scalar, Some(Vs::Bi)),
+ Ti::Vector {
+ scalar: scalar @ crate::Scalar { width: 4, .. },
+ size: Vs::Bi,
+ } => (scalar, Some(Vs::Tri)),
+ Ti::Vector {
+ scalar: scalar @ crate::Scalar { width: 4, .. },
+ size: Vs::Tri,
+ } => (scalar, Some(Vs::Quad)),
+ Ti::Vector { size: Vs::Quad, .. } => {
+ return Err(Error::Validation("extending vec4 coordinate"));
+ }
+ ref other => {
+ log::error!("wrong coordinate type {:?}", other);
+ return Err(Error::Validation("coordinate type"));
+ }
+ };
+
+ // Convert the index to the coordinate component type, if necessary.
+ let array_index_id = self.cached[array_index];
+ let ty = &self.fun_info[array_index].ty;
+ let inner_ty = ty.inner_with(&self.ir_module.types);
+ let array_index_scalar = match *inner_ty {
+ Ti::Scalar(
+ scalar @ crate::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ width: 4,
+ },
+ ) => scalar,
+ _ => unreachable!("we only allow i32 and u32"),
+ };
+ let cast = match (component_scalar.kind, array_index_scalar.kind) {
+ (crate::ScalarKind::Sint, crate::ScalarKind::Sint)
+ | (crate::ScalarKind::Uint, crate::ScalarKind::Uint) => None,
+ (crate::ScalarKind::Sint, crate::ScalarKind::Uint)
+ | (crate::ScalarKind::Uint, crate::ScalarKind::Sint) => Some(spirv::Op::Bitcast),
+ (crate::ScalarKind::Float, crate::ScalarKind::Sint) => Some(spirv::Op::ConvertSToF),
+ (crate::ScalarKind::Float, crate::ScalarKind::Uint) => Some(spirv::Op::ConvertUToF),
+ (crate::ScalarKind::Bool, _) => unreachable!("we don't allow bool for component"),
+ (_, crate::ScalarKind::Bool | crate::ScalarKind::Float) => {
+ unreachable!("we don't allow bool or float for array index")
+ }
+ (crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat, _)
+ | (_, crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat) => {
+ unreachable!("abstract types should never reach backends")
+ }
+ };
+ let reconciled_array_index_id = if let Some(cast) = cast {
+ let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: component_scalar,
+ pointer_space: None,
+ }));
+ let reconciled_id = self.gen_id();
+ block.body.push(Instruction::unary(
+ cast,
+ component_ty_id,
+ reconciled_id,
+ array_index_id,
+ ));
+ reconciled_id
+ } else {
+ array_index_id
+ };
+
+ // Find the SPIR-V type for the combined coordinates/index vector.
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: size,
+ scalar: component_scalar,
+ pointer_space: None,
+ }));
+
+ // Schmear the coordinates and index together.
+ let value_id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ type_id,
+ value_id,
+ &[coordinates_id, reconciled_array_index_id],
+ ));
+ Ok(ImageCoordinates {
+ value_id,
+ type_id,
+ size,
+ })
+ }
+
+ pub(super) fn get_handle_id(&mut self, expr_handle: Handle<crate::Expression>) -> Word {
+ let id = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::GlobalVariable(handle) => {
+ self.writer.global_variables[handle.index()].handle_id
+ }
+ crate::Expression::FunctionArgument(i) => {
+ self.function.parameters[i as usize].handle_id
+ }
+ crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => {
+ self.cached[expr_handle]
+ }
+ ref other => unreachable!("Unexpected image expression {:?}", other),
+ };
+
+ if id == 0 {
+ unreachable!(
+ "Image expression {:?} doesn't have a handle ID",
+ expr_handle
+ );
+ }
+
+ id
+ }
+
+ /// Generate a vector or scalar 'one' for arithmetic on `coordinates`.
+ ///
+ /// If `coordinates` is a scalar, return a scalar one. Otherwise, return
+ /// a vector of ones.
+ fn write_coordinate_one(&mut self, coordinates: &ImageCoordinates) -> Result<Word, Error> {
+ let one = self.get_scope_constant(1);
+ match coordinates.size {
+ None => Ok(one),
+ Some(vector_size) => {
+ let ones = [one; 4];
+ let id = self.gen_id();
+ Instruction::constant_composite(
+ coordinates.type_id,
+ id,
+ &ones[..vector_size as usize],
+ )
+ .to_words(&mut self.writer.logical_layout.declarations);
+ Ok(id)
+ }
+ }
+ }
+
+ /// Generate code to restrict `input` to fall between zero and one less than
+ /// `size_id`.
+ ///
+ /// Both must be 32-bit scalar integer values, whose type is given by
+ /// `type_id`. The computed value is also of type `type_id`.
+ fn restrict_scalar(
+ &mut self,
+ type_id: Word,
+ input_id: Word,
+ size_id: Word,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let i32_one_id = self.get_scope_constant(1);
+
+ // Subtract one from `size` to get the largest valid value.
+ let limit_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ type_id,
+ limit_id,
+ size_id,
+ i32_one_id,
+ ));
+
+ // Use an unsigned minimum, to handle both positive out-of-range values
+ // and negative values in a single instruction: negative values of
+ // `input_id` get treated as very large positive values.
+ let restricted_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ type_id,
+ restricted_id,
+ &[input_id, limit_id],
+ ));
+
+ Ok(restricted_id)
+ }
+
+ /// Write instructions to query the size of an image.
+ ///
+ /// This takes care of selecting the right instruction depending on whether
+ /// a level of detail parameter is present.
+ fn write_coordinate_bounds(
+ &mut self,
+ type_id: Word,
+ image_id: Word,
+ level_id: Option<Word>,
+ block: &mut Block,
+ ) -> Word {
+ let coordinate_bounds_id = self.gen_id();
+ match level_id {
+ Some(level_id) => {
+ // A level of detail was provided, so fetch the image size for
+ // that level.
+ let mut inst = Instruction::image_query(
+ spirv::Op::ImageQuerySizeLod,
+ type_id,
+ coordinate_bounds_id,
+ image_id,
+ );
+ inst.add_operand(level_id);
+ block.body.push(inst);
+ }
+ _ => {
+ // No level of detail was given.
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySize,
+ type_id,
+ coordinate_bounds_id,
+ image_id,
+ ));
+ }
+ }
+
+ coordinate_bounds_id
+ }
+
+ /// Write code to restrict coordinates for an image reference.
+ ///
+ /// First, clamp the level of detail or sample index to fall within bounds.
+ /// Then, obtain the image size, possibly using the clamped level of detail.
+ /// Finally, use an unsigned minimum instruction to force all coordinates
+ /// into range.
+ ///
+ /// Return a triple `(COORDS, LEVEL, SAMPLE)`, where `COORDS` is a coordinate
+ /// vector (including the array index, if any), `LEVEL` is an optional level
+ /// of detail, and `SAMPLE` is an optional sample index, all guaranteed to
+ /// be in-bounds for `image_id`.
+ ///
+ /// The result is usually a vector, but it is a scalar when indexing
+ /// non-arrayed 1D images.
+ fn write_restricted_coordinates(
+ &mut self,
+ image_id: Word,
+ coordinates: ImageCoordinates,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Result<(Word, Option<Word>, Option<Word>), Error> {
+ self.writer.require_any(
+ "the `Restrict` image bounds check policy",
+ &[spirv::Capability::ImageQuery],
+ )?;
+
+ let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::I32,
+ pointer_space: None,
+ }));
+
+ // If `level` is `Some`, clamp it to fall within bounds. This must
+ // happen first, because we'll use it to query the image size for
+ // clamping the actual coordinates.
+ let level_id = level_id
+ .map(|level_id| {
+ // Find the number of mipmap levels in this image.
+ let num_levels_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ i32_type_id,
+ num_levels_id,
+ image_id,
+ ));
+
+ self.restrict_scalar(i32_type_id, level_id, num_levels_id, block)
+ })
+ .transpose()?;
+
+ // If `sample_id` is `Some`, clamp it to fall within bounds.
+ let sample_id = sample_id
+ .map(|sample_id| {
+ // Find the number of samples per texel.
+ let num_samples_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ i32_type_id,
+ num_samples_id,
+ image_id,
+ ));
+
+ self.restrict_scalar(i32_type_id, sample_id, num_samples_id, block)
+ })
+ .transpose()?;
+
+ // Obtain the image bounds, including the array element count.
+ let coordinate_bounds_id =
+ self.write_coordinate_bounds(coordinates.type_id, image_id, level_id, block);
+
+ // Compute maximum valid values from the bounds.
+ let ones = self.write_coordinate_one(&coordinates)?;
+ let coordinate_limit_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ coordinates.type_id,
+ coordinate_limit_id,
+ coordinate_bounds_id,
+ ones,
+ ));
+
+ // Restrict the coordinates to fall within those bounds.
+ //
+ // Use an unsigned minimum, to handle both positive out-of-range values
+ // and negative values in a single instruction: negative values of
+ // `coordinates` get treated as very large positive values.
+ let restricted_coordinates_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ coordinates.type_id,
+ restricted_coordinates_id,
+ &[coordinates.value_id, coordinate_limit_id],
+ ));
+
+ Ok((restricted_coordinates_id, level_id, sample_id))
+ }
+
+ fn write_conditional_image_access<A: Access>(
+ &mut self,
+ image_id: Word,
+ coordinates: ImageCoordinates,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ access: &A,
+ ) -> Result<A::Output, Error> {
+ self.writer.require_any(
+ "the `ReadZeroSkipWrite` image bounds check policy",
+ &[spirv::Capability::ImageQuery],
+ )?;
+
+ let bool_type_id = self.writer.get_bool_type_id();
+ let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::I32,
+ pointer_space: None,
+ }));
+
+ let null_id = access.out_of_bounds_value(self);
+
+ let mut selection = Selection::start(block, access.result_type());
+
+ // If `level_id` is `Some`, check whether it is within bounds. This must
+ // happen first, because we'll be supplying this as an argument when we
+ // query the image size.
+ if let Some(level_id) = level_id {
+ // Find the number of mipmap levels in this image.
+ let num_levels_id = self.gen_id();
+ selection.block().body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ i32_type_id,
+ num_levels_id,
+ image_id,
+ ));
+
+ let lod_cond_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ bool_type_id,
+ lod_cond_id,
+ level_id,
+ num_levels_id,
+ ));
+
+ selection.if_true(self, lod_cond_id, null_id);
+ }
+
+ // If `sample_id` is `Some`, check whether it is in bounds.
+ if let Some(sample_id) = sample_id {
+ // Find the number of samples per texel.
+ let num_samples_id = self.gen_id();
+ selection.block().body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ i32_type_id,
+ num_samples_id,
+ image_id,
+ ));
+
+ let samples_cond_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ bool_type_id,
+ samples_cond_id,
+ sample_id,
+ num_samples_id,
+ ));
+
+ selection.if_true(self, samples_cond_id, null_id);
+ }
+
+ // Obtain the image bounds, including any array element count.
+ let coordinate_bounds_id = self.write_coordinate_bounds(
+ coordinates.type_id,
+ image_id,
+ level_id,
+ selection.block(),
+ );
+
+ // Compare the coordinates against the bounds.
+ let coords_bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: coordinates.size,
+ scalar: crate::Scalar::BOOL,
+ pointer_space: None,
+ }));
+ let coords_conds_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ coords_bool_type_id,
+ coords_conds_id,
+ coordinates.value_id,
+ coordinate_bounds_id,
+ ));
+
+ // If the comparison above was a vector comparison, then we need to
+ // check that all components of the comparison are true.
+ let coords_cond_id = if coords_bool_type_id != bool_type_id {
+ let id = self.gen_id();
+ selection.block().body.push(Instruction::relational(
+ spirv::Op::All,
+ bool_type_id,
+ id,
+ coords_conds_id,
+ ));
+ id
+ } else {
+ coords_conds_id
+ };
+
+ selection.if_true(self, coords_cond_id, null_id);
+
+ // All conditions are met. We can carry out the access.
+ let texel_id = access.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ level_id,
+ sample_id,
+ selection.block(),
+ );
+
+ // This, then, is the value of the 'true' branch.
+ Ok(selection.finish(self, texel_id))
+ }
+
+ /// Generate code for an `ImageLoad` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageLoad` variant.
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn write_image_load(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ level: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let image_id = self.get_handle_id(image);
+ let image_type = self.fun_info[image].ty.inner_with(&self.ir_module.types);
+ let image_class = match *image_type {
+ crate::TypeInner::Image { class, .. } => class,
+ _ => return Err(Error::Validation("image type")),
+ };
+
+ let access = Load::from_image_expr(self, image_id, image_class, result_type_id)?;
+ let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
+
+ let level_id = level.map(|expr| self.cached[expr]);
+ let sample_id = sample.map(|expr| self.cached[expr]);
+
+ // Perform the access, according to the bounds check policy.
+ let access_id = match self.writer.bounds_check_policies.image_load {
+ crate::proc::BoundsCheckPolicy::Restrict => {
+ let (coords, level_id, sample_id) = self.write_restricted_coordinates(
+ image_id,
+ coordinates,
+ level_id,
+ sample_id,
+ block,
+ )?;
+ access.generate(&mut self.writer.id_gen, coords, level_id, sample_id, block)
+ }
+ crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => self
+ .write_conditional_image_access(
+ image_id,
+ coordinates,
+ level_id,
+ sample_id,
+ block,
+ &access,
+ )?,
+ crate::proc::BoundsCheckPolicy::Unchecked => access.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ level_id,
+ sample_id,
+ block,
+ ),
+ };
+
+ // For depth images, `ImageLoad` expressions produce a single f32,
+ // whereas the SPIR-V instructions always produce a vec4. So we may have
+ // to pull out the component we need.
+ let result_id = if result_type_id == access.result_type() {
+ // The instruction produced the type we expected. We can use
+ // its result as-is.
+ access_id
+ } else {
+ // For `ImageClass::Depth` images, SPIR-V gave us four components,
+ // but we only want the first one.
+ let component_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ component_id,
+ access_id,
+ &[0],
+ ));
+ component_id
+ };
+
+ Ok(result_id)
+ }
+
+ /// Generate code for an `ImageSample` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageSample` variant.
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn write_image_sample(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ sampler: Handle<crate::Expression>,
+ gather: Option<crate::SwizzleComponent>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ offset: Option<Handle<crate::Expression>>,
+ level: crate::SampleLevel,
+ depth_ref: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ use super::instructions::SampleLod;
+ // image
+ let image_id = self.get_handle_id(image);
+ let image_type = self.fun_info[image].ty.handle().unwrap();
+ // SPIR-V doesn't know about our `Depth` class, and it returns
+ // `vec4<f32>`, so we need to grab the first component out of it.
+ let needs_sub_access = match self.ir_module.types[image_type].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Depth { .. },
+ ..
+ } => depth_ref.is_none() && gather.is_none(),
+ _ => false,
+ };
+ let sample_result_type_id = if needs_sub_access {
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Quad),
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }))
+ } else {
+ result_type_id
+ };
+
+ // OpTypeSampledImage
+ let image_type_id = self.get_type_id(LookupType::Handle(image_type));
+ let sampled_image_type_id =
+ self.get_type_id(LookupType::Local(LocalType::SampledImage { image_type_id }));
+
+ let sampler_id = self.get_handle_id(sampler);
+ let coordinates_id = self
+ .write_image_coordinates(coordinate, array_index, block)?
+ .value_id;
+
+ let sampled_image_id = self.gen_id();
+ block.body.push(Instruction::sampled_image(
+ sampled_image_type_id,
+ sampled_image_id,
+ image_id,
+ sampler_id,
+ ));
+ let id = self.gen_id();
+
+ let depth_id = depth_ref.map(|handle| self.cached[handle]);
+ let mut mask = spirv::ImageOperands::empty();
+ mask.set(spirv::ImageOperands::CONST_OFFSET, offset.is_some());
+
+ let mut main_instruction = match (level, gather) {
+ (_, Some(component)) => {
+ let component_id = self.get_index_constant(component as u32);
+ let mut inst = Instruction::image_gather(
+ sample_result_type_id,
+ id,
+ sampled_image_id,
+ coordinates_id,
+ component_id,
+ depth_id,
+ );
+ if !mask.is_empty() {
+ inst.add_operand(mask.bits());
+ }
+ inst
+ }
+ (crate::SampleLevel::Zero, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let zero_id = self.writer.get_constant_scalar(crate::Literal::F32(0.0));
+
+ mask |= spirv::ImageOperands::LOD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(zero_id);
+
+ inst
+ }
+ (crate::SampleLevel::Auto, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Implicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+ if !mask.is_empty() {
+ inst.add_operand(mask.bits());
+ }
+ inst
+ }
+ (crate::SampleLevel::Exact(lod_handle), None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let lod_id = self.cached[lod_handle];
+ mask |= spirv::ImageOperands::LOD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(lod_id);
+
+ inst
+ }
+ (crate::SampleLevel::Bias(bias_handle), None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Implicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let bias_id = self.cached[bias_handle];
+ mask |= spirv::ImageOperands::BIAS;
+ inst.add_operand(mask.bits());
+ inst.add_operand(bias_id);
+
+ inst
+ }
+ (crate::SampleLevel::Gradient { x, y }, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let x_id = self.cached[x];
+ let y_id = self.cached[y];
+ mask |= spirv::ImageOperands::GRAD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(x_id);
+ inst.add_operand(y_id);
+
+ inst
+ }
+ };
+
+ if let Some(offset_const) = offset {
+ let offset_id = self.writer.constant_ids[offset_const.index()];
+ main_instruction.add_operand(offset_id);
+ }
+
+ block.body.push(main_instruction);
+
+ let id = if needs_sub_access {
+ let sub_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ sub_id,
+ id,
+ &[0],
+ ));
+ sub_id
+ } else {
+ id
+ };
+
+ Ok(id)
+ }
+
+ /// Generate code for an `ImageQuery` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageQuery` variant.
+ pub(super) fn write_image_query(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ query: crate::ImageQuery,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ use crate::{ImageClass as Ic, ImageDimension as Id, ImageQuery as Iq};
+
+ let image_id = self.get_handle_id(image);
+ let image_type = self.fun_info[image].ty.handle().unwrap();
+ let (dim, arrayed, class) = match self.ir_module.types[image_type].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => (dim, arrayed, class),
+ _ => {
+ return Err(Error::Validation("image type"));
+ }
+ };
+
+ self.writer
+ .require_any("image queries", &[spirv::Capability::ImageQuery])?;
+
+ let id = match query {
+ Iq::Size { level } => {
+ let dim_coords = match dim {
+ Id::D1 => 1,
+ Id::D2 | Id::Cube => 2,
+ Id::D3 => 3,
+ };
+ let array_coords = usize::from(arrayed);
+ let vector_size = match dim_coords + array_coords {
+ 2 => Some(crate::VectorSize::Bi),
+ 3 => Some(crate::VectorSize::Tri),
+ 4 => Some(crate::VectorSize::Quad),
+ _ => None,
+ };
+ let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size,
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ }));
+
+ let (query_op, level_id) = match class {
+ Ic::Sampled { multi: true, .. }
+ | Ic::Depth { multi: true }
+ | Ic::Storage { .. } => (spirv::Op::ImageQuerySize, None),
+ _ => {
+ let level_id = match level {
+ Some(expr) => self.cached[expr],
+ None => self.get_index_constant(0),
+ };
+ (spirv::Op::ImageQuerySizeLod, Some(level_id))
+ }
+ };
+
+ // The ID of the vector returned by SPIR-V, which contains the dimensions
+ // as well as the layer count.
+ let id_extended = self.gen_id();
+ let mut inst = Instruction::image_query(
+ query_op,
+ extended_size_type_id,
+ id_extended,
+ image_id,
+ );
+ if let Some(expr_id) = level_id {
+ inst.add_operand(expr_id);
+ }
+ block.body.push(inst);
+
+ if result_type_id != extended_size_type_id {
+ let id = self.gen_id();
+ let components = match dim {
+ // always pick the first component, and duplicate it for all 3 dimensions
+ Id::Cube => &[0u32, 0][..],
+ _ => &[0u32, 1, 2, 3][..dim_coords],
+ };
+ block.body.push(Instruction::vector_shuffle(
+ result_type_id,
+ id,
+ id_extended,
+ id_extended,
+ components,
+ ));
+
+ id
+ } else {
+ id_extended
+ }
+ }
+ Iq::NumLevels => {
+ let query_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ result_type_id,
+ query_id,
+ image_id,
+ ));
+
+ query_id
+ }
+ Iq::NumLayers => {
+ let vec_size = match dim {
+ Id::D1 => crate::VectorSize::Bi,
+ Id::D2 | Id::Cube => crate::VectorSize::Tri,
+ Id::D3 => crate::VectorSize::Quad,
+ };
+ let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(vec_size),
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ }));
+ let id_extended = self.gen_id();
+ let mut inst = Instruction::image_query(
+ spirv::Op::ImageQuerySizeLod,
+ extended_size_type_id,
+ id_extended,
+ image_id,
+ );
+ inst.add_operand(self.get_index_constant(0));
+ block.body.push(inst);
+
+ let extract_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ extract_id,
+ id_extended,
+ &[vec_size as u32 - 1],
+ ));
+
+ extract_id
+ }
+ Iq::NumSamples => {
+ let query_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ result_type_id,
+ query_id,
+ image_id,
+ ));
+
+ query_id
+ }
+ };
+
+ Ok(id)
+ }
+
+ pub(super) fn write_image_store(
+ &mut self,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ value: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ let image_id = self.get_handle_id(image);
+ let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
+ let value_id = self.cached[value];
+
+ let write = Store { image_id, value_id };
+
+ match *self.fun_info[image].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Image {
+ class:
+ crate::ImageClass::Storage {
+ format: crate::StorageFormat::Bgra8Unorm,
+ ..
+ },
+ ..
+ } => self.writer.require_any(
+ "Bgra8Unorm storage write",
+ &[spirv::Capability::StorageImageWriteWithoutFormat],
+ )?,
+ _ => {}
+ }
+
+ match self.writer.bounds_check_policies.image_store {
+ crate::proc::BoundsCheckPolicy::Restrict => {
+ let (coords, _, _) =
+ self.write_restricted_coordinates(image_id, coordinates, None, None, block)?;
+ write.generate(&mut self.writer.id_gen, coords, None, None, block);
+ }
+ crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ self.write_conditional_image_access(
+ image_id,
+ coordinates,
+ None,
+ None,
+ block,
+ &write,
+ )?;
+ }
+ crate::proc::BoundsCheckPolicy::Unchecked => {
+ write.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ None,
+ None,
+ block,
+ );
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/index.rs b/third_party/rust/naga/src/back/spv/index.rs
new file mode 100644
index 0000000000..92e0f88d9a
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/index.rs
@@ -0,0 +1,421 @@
+/*!
+Bounds-checking for SPIR-V output.
+*/
+
+use super::{
+ helpers::global_needs_wrapper, selection::Selection, Block, BlockContext, Error, IdGenerator,
+ Instruction, Word,
+};
+use crate::{arena::Handle, proc::BoundsCheckPolicy};
+
+/// The results of performing a bounds check.
+///
+/// On success, `write_bounds_check` returns a value of this type.
+pub(super) enum BoundsCheckResult {
+ /// The index is statically known and in bounds, with the given value.
+ KnownInBounds(u32),
+
+ /// The given instruction computes the index to be used.
+ Computed(Word),
+
+ /// The given instruction computes a boolean condition which is true
+ /// if the index is in bounds.
+ Conditional(Word),
+}
+
+/// A value that we either know at translation time, or need to compute at runtime.
+pub(super) enum MaybeKnown<T> {
+ /// The value is known at shader translation time.
+ Known(T),
+
+ /// The value is computed by the instruction with the given id.
+ Computed(Word),
+}
+
+impl<'w> BlockContext<'w> {
+ /// Emit code to compute the length of a run-time array.
+ ///
+ /// Given `array`, an expression referring a runtime-sized array, return the
+ /// instruction id for the array's length.
+ pub(super) fn write_runtime_array_length(
+ &mut self,
+ array: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ // Naga IR permits runtime-sized arrays as global variables or as the
+ // final member of a struct that is a global variable. SPIR-V permits
+ // only the latter, so this back end wraps bare runtime-sized arrays
+ // in a made-up struct; see `helpers::global_needs_wrapper` and its uses.
+ // This code must handle both cases.
+ let (structure_id, last_member_index) = match self.ir_function.expressions[array] {
+ crate::Expression::AccessIndex { base, index } => {
+ match self.ir_function.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => (
+ self.writer.global_variables[handle.index()].access_id,
+ index,
+ ),
+ _ => return Err(Error::Validation("array length expression")),
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let global = &self.ir_module.global_variables[handle];
+ if !global_needs_wrapper(self.ir_module, global) {
+ return Err(Error::Validation("array length expression"));
+ }
+
+ (self.writer.global_variables[handle.index()].var_id, 0)
+ }
+ _ => return Err(Error::Validation("array length expression")),
+ };
+
+ let length_id = self.gen_id();
+ block.body.push(Instruction::array_length(
+ self.writer.get_uint_type_id(),
+ length_id,
+ structure_id,
+ last_member_index,
+ ));
+
+ Ok(length_id)
+ }
+
+ /// Compute the length of a subscriptable value.
+ ///
+ /// Given `sequence`, an expression referring to some indexable type, return
+ /// its length. The result may either be computed by SPIR-V instructions, or
+ /// known at shader translation time.
+ ///
+ /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
+ /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
+ /// sized, or use a specializable constant as its length.
+ fn write_sequence_length(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<MaybeKnown<u32>, Error> {
+ let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types);
+ match sequence_ty.indexable_length(self.ir_module) {
+ Ok(crate::proc::IndexableLength::Known(known_length)) => {
+ Ok(MaybeKnown::Known(known_length))
+ }
+ Ok(crate::proc::IndexableLength::Dynamic) => {
+ let length_id = self.write_runtime_array_length(sequence, block)?;
+ Ok(MaybeKnown::Computed(length_id))
+ }
+ Err(err) => {
+ log::error!("Sequence length for {:?} failed: {}", sequence, err);
+ Err(Error::Validation("indexable length"))
+ }
+ }
+ }
+
+ /// Compute the maximum valid index of a subscriptable value.
+ ///
+ /// Given `sequence`, an expression referring to some indexable type, return
+ /// its maximum valid index - one less than its length. The result may
+ /// either be computed, or known at shader translation time.
+ ///
+ /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
+ /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
+ /// sized, or use a specializable constant as its length.
+ fn write_sequence_max_index(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<MaybeKnown<u32>, Error> {
+ match self.write_sequence_length(sequence, block)? {
+ MaybeKnown::Known(known_length) => {
+ // We should have thrown out all attempts to subscript zero-length
+ // sequences during validation, so the following subtraction should never
+ // underflow.
+ assert!(known_length > 0);
+ // Compute the max index from the length now.
+ Ok(MaybeKnown::Known(known_length - 1))
+ }
+ MaybeKnown::Computed(length_id) => {
+ // Emit code to compute the max index from the length.
+ let const_one_id = self.get_index_constant(1);
+ let max_index_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ self.writer.get_uint_type_id(),
+ max_index_id,
+ length_id,
+ const_one_id,
+ ));
+ Ok(MaybeKnown::Computed(max_index_id))
+ }
+ }
+ }
+
+ /// Restrict an index to be in range for a vector, matrix, or array.
+ ///
+ /// This is used to implement `BoundsCheckPolicy::Restrict`. An in-bounds
+ /// index is left unchanged. An out-of-bounds index is replaced with some
+ /// arbitrary in-bounds index. Note,this is not necessarily clamping; for
+ /// example, negative indices might be changed to refer to the last element
+ /// of the sequence, not the first, as clamping would do.
+ ///
+ /// Either return the restricted index value, if known, or add instructions
+ /// to `block` to compute it, and return the id of the result. See the
+ /// documentation for `BoundsCheckResult` for details.
+ ///
+ /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
+ /// `Pointer` to any of those, or a `ValuePointer`. An array may be
+ /// fixed-size, dynamically sized, or use a specializable constant as its
+ /// length.
+ pub(super) fn write_restricted_index(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let index_id = self.cached[index];
+
+ // Get the sequence's maximum valid index. Return early if we've already
+ // done the bounds check.
+ let max_index_id = match self.write_sequence_max_index(sequence, block)? {
+ MaybeKnown::Known(known_max_index) => {
+ if let Ok(known_index) = self
+ .ir_module
+ .to_ctx()
+ .eval_expr_to_u32_from(index, &self.ir_function.expressions)
+ {
+ // Both the index and length are known at compile time.
+ //
+ // In strict WGSL compliance mode, out-of-bounds indices cannot be
+ // reported at shader translation time, and must be replaced with
+ // in-bounds indices at run time. So we cannot assume that
+ // validation ensured the index was in bounds. Restrict now.
+ let restricted = std::cmp::min(known_index, known_max_index);
+ return Ok(BoundsCheckResult::KnownInBounds(restricted));
+ }
+
+ self.get_index_constant(known_max_index)
+ }
+ MaybeKnown::Computed(max_index_id) => max_index_id,
+ };
+
+ // One or the other of the index or length is dynamic, so emit code for
+ // BoundsCheckPolicy::Restrict.
+ let restricted_index_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ self.writer.get_uint_type_id(),
+ restricted_index_id,
+ &[index_id, max_index_id],
+ ));
+ Ok(BoundsCheckResult::Computed(restricted_index_id))
+ }
+
+ /// Write an index bounds comparison to `block`, if needed.
+ ///
+ /// If we're able to determine statically that `index` is in bounds for
+ /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual
+ /// value of the index. (In principle, one could know that the index is in
+ /// bounds without knowing its specific value, but in our simple-minded
+ /// situation, we always know it.)
+ ///
+ /// If instead we must generate code to perform the comparison at run time,
+ /// return `Conditional(comparison_id)`, where `comparison_id` is an
+ /// instruction producing a boolean value that is true if `index` is in
+ /// bounds for `sequence`.
+ ///
+ /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
+ /// `Pointer` to any of those, or a `ValuePointer`. An array may be
+ /// fixed-size, dynamically sized, or use a specializable constant as its
+ /// length.
+ fn write_index_comparison(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let index_id = self.cached[index];
+
+ // Get the sequence's length. Return early if we've already done the
+ // bounds check.
+ let length_id = match self.write_sequence_length(sequence, block)? {
+ MaybeKnown::Known(known_length) => {
+ if let Ok(known_index) = self
+ .ir_module
+ .to_ctx()
+ .eval_expr_to_u32_from(index, &self.ir_function.expressions)
+ {
+ // Both the index and length are known at compile time.
+ //
+ // It would be nice to assume that, since we are using the
+ // `ReadZeroSkipWrite` policy, we are not in strict WGSL
+ // compliance mode, and thus we can count on the validator to have
+ // rejected any programs with known out-of-bounds indices, and
+ // thus just return `KnownInBounds` here without actually
+ // checking.
+ //
+ // But it's also reasonable to expect that bounds check policies
+ // and error reporting policies should be able to vary
+ // independently without introducing security holes. So, we should
+ // support the case where bad indices do not cause validation
+ // errors, and are handled via `ReadZeroSkipWrite`.
+ //
+ // In theory, when `known_index` is bad, we could return a new
+ // `KnownOutOfBounds` variant here. But it's simpler just to fall
+ // through and let the bounds check take place. The shader is
+ // broken anyway, so it doesn't make sense to invest in emitting
+ // the ideal code for it.
+ if known_index < known_length {
+ return Ok(BoundsCheckResult::KnownInBounds(known_index));
+ }
+ }
+
+ self.get_index_constant(known_length)
+ }
+ MaybeKnown::Computed(length_id) => length_id,
+ };
+
+ // Compare the index against the length.
+ let condition_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ self.writer.get_bool_type_id(),
+ condition_id,
+ index_id,
+ length_id,
+ ));
+
+ // Indicate that we did generate the check.
+ Ok(BoundsCheckResult::Conditional(condition_id))
+ }
+
+ /// Emit a conditional load for `BoundsCheckPolicy::ReadZeroSkipWrite`.
+ ///
+ /// Generate code to load a value of `result_type` if `condition` is true,
+ /// and generate a null value of that type if it is false. Call `emit_load`
+ /// to emit the instructions to perform the load. Return the id of the
+ /// merged value of the two branches.
+ pub(super) fn write_conditional_indexed_load<F>(
+ &mut self,
+ result_type: Word,
+ condition: Word,
+ block: &mut Block,
+ emit_load: F,
+ ) -> Word
+ where
+ F: FnOnce(&mut IdGenerator, &mut Block) -> Word,
+ {
+ // For the out-of-bounds case, we produce a zero value.
+ let null_id = self.writer.get_constant_null(result_type);
+
+ let mut selection = Selection::start(block, result_type);
+
+ // As it turns out, we don't actually need a full 'if-then-else'
+ // structure for this: SPIR-V constants are declared up front, so the
+ // 'else' block would have no instructions. Instead we emit something
+ // like this:
+ //
+ // result = zero;
+ // if in_bounds {
+ // result = do the load;
+ // }
+ // use result;
+
+ // Continue only if the index was in bounds. Otherwise, branch to the
+ // merge block.
+ selection.if_true(self, condition, null_id);
+
+ // The in-bounds path. Perform the access and the load.
+ let loaded_value = emit_load(&mut self.writer.id_gen, selection.block());
+
+ selection.finish(self, loaded_value)
+ }
+
+ /// Emit code for bounds checks for an array, vector, or matrix access.
+ ///
+ /// This implements either `index_bounds_check_policy` or
+ /// `buffer_bounds_check_policy`, depending on the address space of the
+ /// pointer being accessed.
+ ///
+ /// Return a `BoundsCheckResult` indicating how the index should be
+ /// consumed. See that type's documentation for details.
+ pub(super) fn write_bounds_check(
+ &mut self,
+ base: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let policy = self.writer.bounds_check_policies.choose_policy(
+ base,
+ &self.ir_module.types,
+ self.fun_info,
+ );
+
+ Ok(match policy {
+ BoundsCheckPolicy::Restrict => self.write_restricted_index(base, index, block)?,
+ BoundsCheckPolicy::ReadZeroSkipWrite => {
+ self.write_index_comparison(base, index, block)?
+ }
+ BoundsCheckPolicy::Unchecked => BoundsCheckResult::Computed(self.cached[index]),
+ })
+ }
+
+ /// Emit code to subscript a vector by value with a computed index.
+ ///
+ /// Return the id of the element value.
+ pub(super) fn write_vector_access(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ base: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
+
+ let base_id = self.cached[base];
+ let index_id = self.cached[index];
+
+ let result_id = match self.write_bounds_check(base, index, block)? {
+ BoundsCheckResult::KnownInBounds(known_index) => {
+ let result_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ result_id,
+ base_id,
+ &[known_index],
+ ));
+ result_id
+ }
+ BoundsCheckResult::Computed(computed_index_id) => {
+ let result_id = self.gen_id();
+ block.body.push(Instruction::vector_extract_dynamic(
+ result_type_id,
+ result_id,
+ base_id,
+ computed_index_id,
+ ));
+ result_id
+ }
+ BoundsCheckResult::Conditional(comparison_id) => {
+ // Run-time bounds checks were required. Emit
+ // conditional load.
+ self.write_conditional_indexed_load(
+ result_type_id,
+ comparison_id,
+ block,
+ |id_gen, block| {
+ // The in-bounds path. Generate the access.
+ let element_id = id_gen.next();
+ block.body.push(Instruction::vector_extract_dynamic(
+ result_type_id,
+ element_id,
+ base_id,
+ index_id,
+ ));
+ element_id
+ },
+ )
+ }
+ };
+
+ Ok(result_id)
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs
new file mode 100644
index 0000000000..b963793ad3
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/instructions.rs
@@ -0,0 +1,1100 @@
+use super::{block::DebugInfoInner, helpers};
+use spirv::{Op, Word};
+
+pub(super) enum Signedness {
+ Unsigned = 0,
+ Signed = 1,
+}
+
+pub(super) enum SampleLod {
+ Explicit,
+ Implicit,
+}
+
+pub(super) struct Case {
+ pub value: Word,
+ pub label_id: Word,
+}
+
+impl super::Instruction {
+ //
+ // Debug Instructions
+ //
+
+ pub(super) fn string(name: &str, id: Word) -> Self {
+ let mut instruction = Self::new(Op::String);
+ instruction.set_result(id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn source(
+ source_language: spirv::SourceLanguage,
+ version: u32,
+ source: &Option<DebugInfoInner>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Source);
+ instruction.add_operand(source_language as u32);
+ instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes()));
+ if let Some(source) = source.as_ref() {
+ instruction.add_operand(source.source_file_id);
+ instruction.add_operands(helpers::string_to_words(source.source_code));
+ }
+ instruction
+ }
+
+ pub(super) fn name(target_id: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::Name);
+ instruction.add_operand(target_id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn member_name(target_id: Word, member: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::MemberName);
+ instruction.add_operand(target_id);
+ instruction.add_operand(member);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn line(file: Word, line: Word, column: Word) -> Self {
+ let mut instruction = Self::new(Op::Line);
+ instruction.add_operand(file);
+ instruction.add_operand(line);
+ instruction.add_operand(column);
+ instruction
+ }
+
+ pub(super) const fn no_line() -> Self {
+ Self::new(Op::NoLine)
+ }
+
+ //
+ // Annotation Instructions
+ //
+
+ pub(super) fn decorate(
+ target_id: Word,
+ decoration: spirv::Decoration,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::Decorate);
+ instruction.add_operand(target_id);
+ instruction.add_operand(decoration as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ pub(super) fn member_decorate(
+ target_id: Word,
+ member_index: Word,
+ decoration: spirv::Decoration,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::MemberDecorate);
+ instruction.add_operand(target_id);
+ instruction.add_operand(member_index);
+ instruction.add_operand(decoration as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ //
+ // Extension Instructions
+ //
+
+ pub(super) fn extension(name: &str) -> Self {
+ let mut instruction = Self::new(Op::Extension);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn ext_inst_import(id: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::ExtInstImport);
+ instruction.set_result(id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn ext_inst(
+ set_id: Word,
+ op: spirv::GLOp,
+ result_type_id: Word,
+ id: Word,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ExtInst);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(set_id);
+ instruction.add_operand(op as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ //
+ // Mode-Setting Instructions
+ //
+
+ pub(super) fn memory_model(
+ addressing_model: spirv::AddressingModel,
+ memory_model: spirv::MemoryModel,
+ ) -> Self {
+ let mut instruction = Self::new(Op::MemoryModel);
+ instruction.add_operand(addressing_model as u32);
+ instruction.add_operand(memory_model as u32);
+ instruction
+ }
+
+ pub(super) fn entry_point(
+ execution_model: spirv::ExecutionModel,
+ entry_point_id: Word,
+ name: &str,
+ interface_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::EntryPoint);
+ instruction.add_operand(execution_model as u32);
+ instruction.add_operand(entry_point_id);
+ instruction.add_operands(helpers::string_to_words(name));
+
+ for interface_id in interface_ids {
+ instruction.add_operand(*interface_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn execution_mode(
+ entry_point_id: Word,
+ execution_mode: spirv::ExecutionMode,
+ args: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ExecutionMode);
+ instruction.add_operand(entry_point_id);
+ instruction.add_operand(execution_mode as u32);
+ for arg in args {
+ instruction.add_operand(*arg);
+ }
+ instruction
+ }
+
+ pub(super) fn capability(capability: spirv::Capability) -> Self {
+ let mut instruction = Self::new(Op::Capability);
+ instruction.add_operand(capability as u32);
+ instruction
+ }
+
+ //
+ // Type-Declaration Instructions
+ //
+
+ pub(super) fn type_void(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeVoid);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_bool(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeBool);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_int(id: Word, width: Word, signedness: Signedness) -> Self {
+ let mut instruction = Self::new(Op::TypeInt);
+ instruction.set_result(id);
+ instruction.add_operand(width);
+ instruction.add_operand(signedness as u32);
+ instruction
+ }
+
+ pub(super) fn type_float(id: Word, width: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeFloat);
+ instruction.set_result(id);
+ instruction.add_operand(width);
+ instruction
+ }
+
+ pub(super) fn type_vector(
+ id: Word,
+ component_type_id: Word,
+ component_count: crate::VectorSize,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeVector);
+ instruction.set_result(id);
+ instruction.add_operand(component_type_id);
+ instruction.add_operand(component_count as u32);
+ instruction
+ }
+
+ pub(super) fn type_matrix(
+ id: Word,
+ column_type_id: Word,
+ column_count: crate::VectorSize,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeMatrix);
+ instruction.set_result(id);
+ instruction.add_operand(column_type_id);
+ instruction.add_operand(column_count as u32);
+ instruction
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn type_image(
+ id: Word,
+ sampled_type_id: Word,
+ dim: spirv::Dim,
+ flags: super::ImageTypeFlags,
+ image_format: spirv::ImageFormat,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeImage);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_type_id);
+ instruction.add_operand(dim as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::DEPTH) as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::ARRAYED) as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::MULTISAMPLED) as u32);
+ instruction.add_operand(if flags.contains(super::ImageTypeFlags::SAMPLED) {
+ 1
+ } else {
+ 2
+ });
+ instruction.add_operand(image_format as u32);
+ instruction
+ }
+
+ pub(super) fn type_sampler(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeSampler);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_acceleration_structure(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeAccelerationStructureKHR);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_ray_query(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeRayQueryKHR);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_sampled_image(id: Word, image_type_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeSampledImage);
+ instruction.set_result(id);
+ instruction.add_operand(image_type_id);
+ instruction
+ }
+
+ pub(super) fn type_array(id: Word, element_type_id: Word, length_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeArray);
+ instruction.set_result(id);
+ instruction.add_operand(element_type_id);
+ instruction.add_operand(length_id);
+ instruction
+ }
+
+ pub(super) fn type_runtime_array(id: Word, element_type_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeRuntimeArray);
+ instruction.set_result(id);
+ instruction.add_operand(element_type_id);
+ instruction
+ }
+
+ pub(super) fn type_struct(id: Word, member_ids: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::TypeStruct);
+ instruction.set_result(id);
+
+ for member_id in member_ids {
+ instruction.add_operand(*member_id)
+ }
+
+ instruction
+ }
+
+ pub(super) fn type_pointer(
+ id: Word,
+ storage_class: spirv::StorageClass,
+ type_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypePointer);
+ instruction.set_result(id);
+ instruction.add_operand(storage_class as u32);
+ instruction.add_operand(type_id);
+ instruction
+ }
+
+ pub(super) fn type_function(id: Word, return_type_id: Word, parameter_ids: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::TypeFunction);
+ instruction.set_result(id);
+ instruction.add_operand(return_type_id);
+
+ for parameter_id in parameter_ids {
+ instruction.add_operand(*parameter_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Constant-Creation Instructions
+ //
+
+ pub(super) fn constant_null(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantNull);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant_true(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantTrue);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant_false(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantFalse);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant_32bit(result_type_id: Word, id: Word, value: Word) -> Self {
+ Self::constant(result_type_id, id, &[value])
+ }
+
+ pub(super) fn constant_64bit(result_type_id: Word, id: Word, low: Word, high: Word) -> Self {
+ Self::constant(result_type_id, id, &[low, high])
+ }
+
+ pub(super) fn constant(result_type_id: Word, id: Word, values: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::Constant);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for value in values {
+ instruction.add_operand(*value);
+ }
+
+ instruction
+ }
+
+ pub(super) fn constant_composite(
+ result_type_id: Word,
+ id: Word,
+ constituent_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ConstantComposite);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for constituent_id in constituent_ids {
+ instruction.add_operand(*constituent_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Memory Instructions
+ //
+
+ pub(super) fn variable(
+ result_type_id: Word,
+ id: Word,
+ storage_class: spirv::StorageClass,
+ initializer_id: Option<Word>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Variable);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(storage_class as u32);
+
+ if let Some(initializer_id) = initializer_id {
+ instruction.add_operand(initializer_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn load(
+ result_type_id: Word,
+ id: Word,
+ pointer_id: Word,
+ memory_access: Option<spirv::MemoryAccess>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Load);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer_id);
+
+ if let Some(memory_access) = memory_access {
+ instruction.add_operand(memory_access.bits());
+ }
+
+ instruction
+ }
+
+ pub(super) fn atomic_load(
+ result_type_id: Word,
+ id: Word,
+ pointer_id: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::AtomicLoad);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction
+ }
+
+ pub(super) fn store(
+ pointer_id: Word,
+ value_id: Word,
+ memory_access: Option<spirv::MemoryAccess>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Store);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(value_id);
+
+ if let Some(memory_access) = memory_access {
+ instruction.add_operand(memory_access.bits());
+ }
+
+ instruction
+ }
+
+ pub(super) fn atomic_store(
+ pointer_id: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ value_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::AtomicStore);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction.add_operand(value_id);
+ instruction
+ }
+
+ pub(super) fn access_chain(
+ result_type_id: Word,
+ id: Word,
+ base_id: Word,
+ index_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::AccessChain);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(base_id);
+
+ for index_id in index_ids {
+ instruction.add_operand(*index_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn array_length(
+ result_type_id: Word,
+ id: Word,
+ structure_id: Word,
+ array_member: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::ArrayLength);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(structure_id);
+ instruction.add_operand(array_member);
+ instruction
+ }
+
+ //
+ // Function Instructions
+ //
+
+ pub(super) fn function(
+ return_type_id: Word,
+ id: Word,
+ function_control: spirv::FunctionControl,
+ function_type_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Function);
+ instruction.set_type(return_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(function_control.bits());
+ instruction.add_operand(function_type_id);
+ instruction
+ }
+
+ pub(super) fn function_parameter(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::FunctionParameter);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) const fn function_end() -> Self {
+ Self::new(Op::FunctionEnd)
+ }
+
+ pub(super) fn function_call(
+ result_type_id: Word,
+ id: Word,
+ function_id: Word,
+ argument_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::FunctionCall);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(function_id);
+
+ for argument_id in argument_ids {
+ instruction.add_operand(*argument_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Image Instructions
+ //
+
+ pub(super) fn sampled_image(
+ result_type_id: Word,
+ id: Word,
+ image: Word,
+ sampler: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::SampledImage);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction.add_operand(sampler);
+ instruction
+ }
+
+ pub(super) fn image_sample(
+ result_type_id: Word,
+ id: Word,
+ lod: SampleLod,
+ sampled_image: Word,
+ coordinates: Word,
+ depth_ref: Option<Word>,
+ ) -> Self {
+ let op = match (lod, depth_ref) {
+ (SampleLod::Explicit, None) => Op::ImageSampleExplicitLod,
+ (SampleLod::Implicit, None) => Op::ImageSampleImplicitLod,
+ (SampleLod::Explicit, Some(_)) => Op::ImageSampleDrefExplicitLod,
+ (SampleLod::Implicit, Some(_)) => Op::ImageSampleDrefImplicitLod,
+ };
+
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_image);
+ instruction.add_operand(coordinates);
+ if let Some(dref) = depth_ref {
+ instruction.add_operand(dref);
+ }
+
+ instruction
+ }
+
+ pub(super) fn image_gather(
+ result_type_id: Word,
+ id: Word,
+ sampled_image: Word,
+ coordinates: Word,
+ component_id: Word,
+ depth_ref: Option<Word>,
+ ) -> Self {
+ let op = match depth_ref {
+ None => Op::ImageGather,
+ Some(_) => Op::ImageDrefGather,
+ };
+
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_image);
+ instruction.add_operand(coordinates);
+ if let Some(dref) = depth_ref {
+ instruction.add_operand(dref);
+ } else {
+ instruction.add_operand(component_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn image_fetch_or_read(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ image: Word,
+ coordinates: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction.add_operand(coordinates);
+ instruction
+ }
+
+ pub(super) fn image_write(image: Word, coordinates: Word, value: Word) -> Self {
+ let mut instruction = Self::new(Op::ImageWrite);
+ instruction.add_operand(image);
+ instruction.add_operand(coordinates);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ pub(super) fn image_query(op: Op, result_type_id: Word, id: Word, image: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction
+ }
+
+ //
+ // Ray Query Instructions
+ //
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn ray_query_initialize(
+ query: Word,
+ acceleration_structure: Word,
+ ray_flags: Word,
+ cull_mask: Word,
+ ray_origin: Word,
+ ray_tmin: Word,
+ ray_dir: Word,
+ ray_tmax: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::RayQueryInitializeKHR);
+ instruction.add_operand(query);
+ instruction.add_operand(acceleration_structure);
+ instruction.add_operand(ray_flags);
+ instruction.add_operand(cull_mask);
+ instruction.add_operand(ray_origin);
+ instruction.add_operand(ray_tmin);
+ instruction.add_operand(ray_dir);
+ instruction.add_operand(ray_tmax);
+ instruction
+ }
+
+ pub(super) fn ray_query_proceed(result_type_id: Word, id: Word, query: Word) -> Self {
+ let mut instruction = Self::new(Op::RayQueryProceedKHR);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(query);
+ instruction
+ }
+
+ pub(super) fn ray_query_get_intersection(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ query: Word,
+ intersection: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(query);
+ instruction.add_operand(intersection);
+ instruction
+ }
+
+ //
+ // Conversion Instructions
+ //
+ pub(super) fn unary(op: Op, result_type_id: Word, id: Word, value: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ //
+ // Composite Instructions
+ //
+
+ pub(super) fn composite_construct(
+ result_type_id: Word,
+ id: Word,
+ constituent_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::CompositeConstruct);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for constituent_id in constituent_ids {
+ instruction.add_operand(*constituent_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn composite_extract(
+ result_type_id: Word,
+ id: Word,
+ composite_id: Word,
+ indices: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::CompositeExtract);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ instruction.add_operand(composite_id);
+ for index in indices {
+ instruction.add_operand(*index);
+ }
+
+ instruction
+ }
+
+ pub(super) fn vector_extract_dynamic(
+ result_type_id: Word,
+ id: Word,
+ vector_id: Word,
+ index_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::VectorExtractDynamic);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ instruction.add_operand(vector_id);
+ instruction.add_operand(index_id);
+
+ instruction
+ }
+
+ pub(super) fn vector_shuffle(
+ result_type_id: Word,
+ id: Word,
+ v1_id: Word,
+ v2_id: Word,
+ components: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::VectorShuffle);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(v1_id);
+ instruction.add_operand(v2_id);
+
+ for &component in components {
+ instruction.add_operand(component);
+ }
+
+ instruction
+ }
+
+ //
+ // Arithmetic Instructions
+ //
+ pub(super) fn binary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction
+ }
+
+ pub(super) fn ternary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ operand_3: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction.add_operand(operand_3);
+ instruction
+ }
+
+ pub(super) fn quaternary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ operand_3: Word,
+ operand_4: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction.add_operand(operand_3);
+ instruction.add_operand(operand_4);
+ instruction
+ }
+
+ pub(super) fn relational(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(expr_id);
+ instruction
+ }
+
+ pub(super) fn atomic_binary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ pointer: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ value: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ //
+ // Bit Instructions
+ //
+
+ //
+ // Relational and Logical Instructions
+ //
+
+ //
+ // Derivative Instructions
+ //
+
+ pub(super) fn derivative(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(expr_id);
+ instruction
+ }
+
+ //
+ // Control-Flow Instructions
+ //
+
+ pub(super) fn phi(
+ result_type_id: Word,
+ result_id: Word,
+ var_parent_pairs: &[(Word, Word)],
+ ) -> Self {
+ let mut instruction = Self::new(Op::Phi);
+ instruction.add_operand(result_type_id);
+ instruction.add_operand(result_id);
+ for &(variable, parent) in var_parent_pairs {
+ instruction.add_operand(variable);
+ instruction.add_operand(parent);
+ }
+ instruction
+ }
+
+ pub(super) fn selection_merge(
+ merge_id: Word,
+ selection_control: spirv::SelectionControl,
+ ) -> Self {
+ let mut instruction = Self::new(Op::SelectionMerge);
+ instruction.add_operand(merge_id);
+ instruction.add_operand(selection_control.bits());
+ instruction
+ }
+
+ pub(super) fn loop_merge(
+ merge_id: Word,
+ continuing_id: Word,
+ selection_control: spirv::SelectionControl,
+ ) -> Self {
+ let mut instruction = Self::new(Op::LoopMerge);
+ instruction.add_operand(merge_id);
+ instruction.add_operand(continuing_id);
+ instruction.add_operand(selection_control.bits());
+ instruction
+ }
+
+ pub(super) fn label(id: Word) -> Self {
+ let mut instruction = Self::new(Op::Label);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn branch(id: Word) -> Self {
+ let mut instruction = Self::new(Op::Branch);
+ instruction.add_operand(id);
+ instruction
+ }
+
+ // TODO Branch Weights not implemented.
+ pub(super) fn branch_conditional(
+ condition_id: Word,
+ true_label: Word,
+ false_label: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::BranchConditional);
+ instruction.add_operand(condition_id);
+ instruction.add_operand(true_label);
+ instruction.add_operand(false_label);
+ instruction
+ }
+
+ pub(super) fn switch(selector_id: Word, default_id: Word, cases: &[Case]) -> Self {
+ let mut instruction = Self::new(Op::Switch);
+ instruction.add_operand(selector_id);
+ instruction.add_operand(default_id);
+ for case in cases {
+ instruction.add_operand(case.value);
+ instruction.add_operand(case.label_id);
+ }
+ instruction
+ }
+
+ pub(super) fn select(
+ result_type_id: Word,
+ id: Word,
+ condition_id: Word,
+ accept_id: Word,
+ reject_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Select);
+ instruction.add_operand(result_type_id);
+ instruction.add_operand(id);
+ instruction.add_operand(condition_id);
+ instruction.add_operand(accept_id);
+ instruction.add_operand(reject_id);
+ instruction
+ }
+
+ pub(super) const fn kill() -> Self {
+ Self::new(Op::Kill)
+ }
+
+ pub(super) const fn return_void() -> Self {
+ Self::new(Op::Return)
+ }
+
+ pub(super) fn return_value(value_id: Word) -> Self {
+ let mut instruction = Self::new(Op::ReturnValue);
+ instruction.add_operand(value_id);
+ instruction
+ }
+
+ //
+ // Atomic Instructions
+ //
+
+ //
+ // Primitive Instructions
+ //
+
+ // Barriers
+
+ pub(super) fn control_barrier(
+ exec_scope_id: Word,
+ mem_scope_id: Word,
+ semantics_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::ControlBarrier);
+ instruction.add_operand(exec_scope_id);
+ instruction.add_operand(mem_scope_id);
+ instruction.add_operand(semantics_id);
+ instruction
+ }
+}
+
+impl From<crate::StorageFormat> for spirv::ImageFormat {
+ fn from(format: crate::StorageFormat) -> Self {
+ use crate::StorageFormat as Sf;
+ match format {
+ Sf::R8Unorm => Self::R8,
+ Sf::R8Snorm => Self::R8Snorm,
+ Sf::R8Uint => Self::R8ui,
+ Sf::R8Sint => Self::R8i,
+ Sf::R16Uint => Self::R16ui,
+ Sf::R16Sint => Self::R16i,
+ Sf::R16Float => Self::R16f,
+ Sf::Rg8Unorm => Self::Rg8,
+ Sf::Rg8Snorm => Self::Rg8Snorm,
+ Sf::Rg8Uint => Self::Rg8ui,
+ Sf::Rg8Sint => Self::Rg8i,
+ Sf::R32Uint => Self::R32ui,
+ Sf::R32Sint => Self::R32i,
+ Sf::R32Float => Self::R32f,
+ Sf::Rg16Uint => Self::Rg16ui,
+ Sf::Rg16Sint => Self::Rg16i,
+ Sf::Rg16Float => Self::Rg16f,
+ Sf::Rgba8Unorm => Self::Rgba8,
+ Sf::Rgba8Snorm => Self::Rgba8Snorm,
+ Sf::Rgba8Uint => Self::Rgba8ui,
+ Sf::Rgba8Sint => Self::Rgba8i,
+ Sf::Bgra8Unorm => Self::Unknown,
+ Sf::Rgb10a2Uint => Self::Rgb10a2ui,
+ Sf::Rgb10a2Unorm => Self::Rgb10A2,
+ Sf::Rg11b10Float => Self::R11fG11fB10f,
+ Sf::Rg32Uint => Self::Rg32ui,
+ Sf::Rg32Sint => Self::Rg32i,
+ Sf::Rg32Float => Self::Rg32f,
+ Sf::Rgba16Uint => Self::Rgba16ui,
+ Sf::Rgba16Sint => Self::Rgba16i,
+ Sf::Rgba16Float => Self::Rgba16f,
+ Sf::Rgba32Uint => Self::Rgba32ui,
+ Sf::Rgba32Sint => Self::Rgba32i,
+ Sf::Rgba32Float => Self::Rgba32f,
+ Sf::R16Unorm => Self::R16,
+ Sf::R16Snorm => Self::R16Snorm,
+ Sf::Rg16Unorm => Self::Rg16,
+ Sf::Rg16Snorm => Self::Rg16Snorm,
+ Sf::Rgba16Unorm => Self::Rgba16,
+ Sf::Rgba16Snorm => Self::Rgba16Snorm,
+ }
+ }
+}
+
+impl From<crate::ImageDimension> for spirv::Dim {
+ fn from(dim: crate::ImageDimension) -> Self {
+ use crate::ImageDimension as Id;
+ match dim {
+ Id::D1 => Self::Dim1D,
+ Id::D2 => Self::Dim2D,
+ Id::D3 => Self::Dim3D,
+ Id::Cube => Self::DimCube,
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/layout.rs b/third_party/rust/naga/src/back/spv/layout.rs
new file mode 100644
index 0000000000..39117a3d2a
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/layout.rs
@@ -0,0 +1,210 @@
+use super::{Instruction, LogicalLayout, PhysicalLayout};
+use spirv::{Op, Word, MAGIC_NUMBER};
+use std::iter;
+
+// https://github.com/KhronosGroup/SPIRV-Headers/pull/195
+const GENERATOR: Word = 28;
+
+impl PhysicalLayout {
+ pub(super) const fn new(version: Word) -> Self {
+ PhysicalLayout {
+ magic_number: MAGIC_NUMBER,
+ version,
+ generator: GENERATOR,
+ bound: 0,
+ instruction_schema: 0x0u32,
+ }
+ }
+
+ pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(iter::once(self.magic_number));
+ sink.extend(iter::once(self.version));
+ sink.extend(iter::once(self.generator));
+ sink.extend(iter::once(self.bound));
+ sink.extend(iter::once(self.instruction_schema));
+ }
+}
+
+impl super::recyclable::Recyclable for PhysicalLayout {
+ fn recycle(self) -> Self {
+ PhysicalLayout {
+ magic_number: self.magic_number,
+ version: self.version,
+ generator: self.generator,
+ instruction_schema: self.instruction_schema,
+ bound: 0,
+ }
+ }
+}
+
+impl LogicalLayout {
+ pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(self.capabilities.iter().cloned());
+ sink.extend(self.extensions.iter().cloned());
+ sink.extend(self.ext_inst_imports.iter().cloned());
+ sink.extend(self.memory_model.iter().cloned());
+ sink.extend(self.entry_points.iter().cloned());
+ sink.extend(self.execution_modes.iter().cloned());
+ sink.extend(self.debugs.iter().cloned());
+ sink.extend(self.annotations.iter().cloned());
+ sink.extend(self.declarations.iter().cloned());
+ sink.extend(self.function_declarations.iter().cloned());
+ sink.extend(self.function_definitions.iter().cloned());
+ }
+}
+
+impl super::recyclable::Recyclable for LogicalLayout {
+ fn recycle(self) -> Self {
+ Self {
+ capabilities: self.capabilities.recycle(),
+ extensions: self.extensions.recycle(),
+ ext_inst_imports: self.ext_inst_imports.recycle(),
+ memory_model: self.memory_model.recycle(),
+ entry_points: self.entry_points.recycle(),
+ execution_modes: self.execution_modes.recycle(),
+ debugs: self.debugs.recycle(),
+ annotations: self.annotations.recycle(),
+ declarations: self.declarations.recycle(),
+ function_declarations: self.function_declarations.recycle(),
+ function_definitions: self.function_definitions.recycle(),
+ }
+ }
+}
+
+impl Instruction {
+ pub(super) const fn new(op: Op) -> Self {
+ Instruction {
+ op,
+ wc: 1, // Always start at 1 for the first word (OP + WC),
+ type_id: None,
+ result_id: None,
+ operands: vec![],
+ }
+ }
+
+ #[allow(clippy::panic)]
+ pub(super) fn set_type(&mut self, id: Word) {
+ assert!(self.type_id.is_none(), "Type can only be set once");
+ self.type_id = Some(id);
+ self.wc += 1;
+ }
+
+ #[allow(clippy::panic)]
+ pub(super) fn set_result(&mut self, id: Word) {
+ assert!(self.result_id.is_none(), "Result can only be set once");
+ self.result_id = Some(id);
+ self.wc += 1;
+ }
+
+ pub(super) fn add_operand(&mut self, operand: Word) {
+ self.operands.push(operand);
+ self.wc += 1;
+ }
+
+ pub(super) fn add_operands(&mut self, operands: Vec<Word>) {
+ for operand in operands.into_iter() {
+ self.add_operand(operand)
+ }
+ }
+
+ pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(Some(self.wc << 16 | self.op as u32));
+ sink.extend(self.type_id);
+ sink.extend(self.result_id);
+ sink.extend(self.operands.iter().cloned());
+ }
+}
+
+impl Instruction {
+ #[cfg(test)]
+ fn validate(&self, words: &[Word]) {
+ let mut inst_index = 0;
+ let (wc, op) = ((words[inst_index] >> 16) as u16, words[inst_index] as u16);
+ inst_index += 1;
+
+ assert_eq!(wc, words.len() as u16);
+ assert_eq!(op, self.op as u16);
+
+ if self.type_id.is_some() {
+ assert_eq!(words[inst_index], self.type_id.unwrap());
+ inst_index += 1;
+ }
+
+ if self.result_id.is_some() {
+ assert_eq!(words[inst_index], self.result_id.unwrap());
+ inst_index += 1;
+ }
+
+ for (op_index, i) in (inst_index..wc as usize).enumerate() {
+ assert_eq!(words[i], self.operands[op_index]);
+ }
+ }
+}
+
+#[test]
+fn test_physical_layout_in_words() {
+ let bound = 5;
+ let version = 0x10203;
+
+ let mut output = vec![];
+ let mut layout = PhysicalLayout::new(version);
+ layout.bound = bound;
+
+ layout.in_words(&mut output);
+
+ assert_eq!(&output, &[MAGIC_NUMBER, version, GENERATOR, bound, 0,]);
+}
+
+#[test]
+fn test_logical_layout_in_words() {
+ let mut output = vec![];
+ let mut layout = LogicalLayout::default();
+ let layout_vectors = 11;
+ let mut instructions = Vec::with_capacity(layout_vectors);
+
+ let vector_names = &[
+ "Capabilities",
+ "Extensions",
+ "External Instruction Imports",
+ "Memory Model",
+ "Entry Points",
+ "Execution Modes",
+ "Debugs",
+ "Annotations",
+ "Declarations",
+ "Function Declarations",
+ "Function Definitions",
+ ];
+
+ for (i, _) in vector_names.iter().enumerate().take(layout_vectors) {
+ let mut dummy_instruction = Instruction::new(Op::Constant);
+ dummy_instruction.set_type((i + 1) as u32);
+ dummy_instruction.set_result((i + 2) as u32);
+ dummy_instruction.add_operand((i + 3) as u32);
+ dummy_instruction.add_operands(super::helpers::string_to_words(
+ format!("This is the vector: {}", vector_names[i]).as_str(),
+ ));
+ instructions.push(dummy_instruction);
+ }
+
+ instructions[0].to_words(&mut layout.capabilities);
+ instructions[1].to_words(&mut layout.extensions);
+ instructions[2].to_words(&mut layout.ext_inst_imports);
+ instructions[3].to_words(&mut layout.memory_model);
+ instructions[4].to_words(&mut layout.entry_points);
+ instructions[5].to_words(&mut layout.execution_modes);
+ instructions[6].to_words(&mut layout.debugs);
+ instructions[7].to_words(&mut layout.annotations);
+ instructions[8].to_words(&mut layout.declarations);
+ instructions[9].to_words(&mut layout.function_declarations);
+ instructions[10].to_words(&mut layout.function_definitions);
+
+ layout.in_words(&mut output);
+
+ let mut index: usize = 0;
+ for instruction in instructions {
+ let wc = instruction.wc as usize;
+ instruction.validate(&output[index..index + wc]);
+ index += wc;
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs
new file mode 100644
index 0000000000..b7d57be0d4
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/mod.rs
@@ -0,0 +1,748 @@
+/*!
+Backend for [SPIR-V][spv] (Standard Portable Intermediate Representation).
+
+[spv]: https://www.khronos.org/registry/SPIR-V/
+*/
+
+mod block;
+mod helpers;
+mod image;
+mod index;
+mod instructions;
+mod layout;
+mod ray;
+mod recyclable;
+mod selection;
+mod writer;
+
+pub use spirv::Capability;
+
+use crate::arena::Handle;
+use crate::proc::{BoundsCheckPolicies, TypeResolution};
+
+use spirv::Word;
+use std::ops;
+use thiserror::Error;
+
+#[derive(Clone)]
+struct PhysicalLayout {
+ magic_number: Word,
+ version: Word,
+ generator: Word,
+ bound: Word,
+ instruction_schema: Word,
+}
+
+#[derive(Default)]
+struct LogicalLayout {
+ capabilities: Vec<Word>,
+ extensions: Vec<Word>,
+ ext_inst_imports: Vec<Word>,
+ memory_model: Vec<Word>,
+ entry_points: Vec<Word>,
+ execution_modes: Vec<Word>,
+ debugs: Vec<Word>,
+ annotations: Vec<Word>,
+ declarations: Vec<Word>,
+ function_declarations: Vec<Word>,
+ function_definitions: Vec<Word>,
+}
+
+struct Instruction {
+ op: spirv::Op,
+ wc: u32,
+ type_id: Option<Word>,
+ result_id: Option<Word>,
+ operands: Vec<Word>,
+}
+
+const BITS_PER_BYTE: crate::Bytes = 8;
+
+#[derive(Clone, Debug, Error)]
+pub enum Error {
+ #[error("The requested entry point couldn't be found")]
+ EntryPointNotFound,
+ #[error("target SPIRV-{0}.{1} is not supported")]
+ UnsupportedVersion(u8, u8),
+ #[error("using {0} requires at least one of the capabilities {1:?}, but none are available")]
+ MissingCapabilities(&'static str, Vec<Capability>),
+ #[error("unimplemented {0}")]
+ FeatureNotImplemented(&'static str),
+ #[error("module is not validated properly: {0}")]
+ Validation(&'static str),
+}
+
+#[derive(Default)]
+struct IdGenerator(Word);
+
+impl IdGenerator {
+ fn next(&mut self) -> Word {
+ self.0 += 1;
+ self.0
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct DebugInfo<'a> {
+ pub source_code: &'a str,
+ pub file_name: &'a std::path::Path,
+}
+
+/// A SPIR-V block to which we are still adding instructions.
+///
+/// A `Block` represents a SPIR-V block that does not yet have a termination
+/// instruction like `OpBranch` or `OpReturn`.
+///
+/// The `OpLabel` that starts the block is implicit. It will be emitted based on
+/// `label_id` when we write the block to a `LogicalLayout`.
+///
+/// To terminate a `Block`, pass the block and the termination instruction to
+/// `Function::consume`. This takes ownership of the `Block` and transforms it
+/// into a `TerminatedBlock`.
+struct Block {
+ label_id: Word,
+ body: Vec<Instruction>,
+}
+
+/// A SPIR-V block that ends with a termination instruction.
+struct TerminatedBlock {
+ label_id: Word,
+ body: Vec<Instruction>,
+}
+
+impl Block {
+ const fn new(label_id: Word) -> Self {
+ Block {
+ label_id,
+ body: Vec::new(),
+ }
+ }
+}
+
+struct LocalVariable {
+ id: Word,
+ instruction: Instruction,
+}
+
+struct ResultMember {
+ id: Word,
+ type_id: Word,
+ built_in: Option<crate::BuiltIn>,
+}
+
+struct EntryPointContext {
+ argument_ids: Vec<Word>,
+ results: Vec<ResultMember>,
+}
+
+#[derive(Default)]
+struct Function {
+ signature: Option<Instruction>,
+ parameters: Vec<FunctionArgument>,
+ variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
+ blocks: Vec<TerminatedBlock>,
+ entry_point_context: Option<EntryPointContext>,
+}
+
+impl Function {
+ fn consume(&mut self, mut block: Block, termination: Instruction) {
+ block.body.push(termination);
+ self.blocks.push(TerminatedBlock {
+ label_id: block.label_id,
+ body: block.body,
+ })
+ }
+
+ fn parameter_id(&self, index: u32) -> Word {
+ match self.entry_point_context {
+ Some(ref context) => context.argument_ids[index as usize],
+ None => self.parameters[index as usize]
+ .instruction
+ .result_id
+ .unwrap(),
+ }
+ }
+}
+
+/// Characteristics of a SPIR-V `OpTypeImage` type.
+///
+/// SPIR-V requires non-composite types to be unique, including images. Since we
+/// use `LocalType` for this deduplication, it's essential that `LocalImageType`
+/// be equal whenever the corresponding `OpTypeImage`s would be. To reduce the
+/// likelihood of mistakes, we use fields that correspond exactly to the
+/// operands of an `OpTypeImage` instruction, using the actual SPIR-V types
+/// where practical.
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+struct LocalImageType {
+ sampled_type: crate::ScalarKind,
+ dim: spirv::Dim,
+ flags: ImageTypeFlags,
+ image_format: spirv::ImageFormat,
+}
+
+bitflags::bitflags! {
+ /// Flags corresponding to the boolean(-ish) parameters to OpTypeImage.
+ #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+ pub struct ImageTypeFlags: u8 {
+ const DEPTH = 0x1;
+ const ARRAYED = 0x2;
+ const MULTISAMPLED = 0x4;
+ const SAMPLED = 0x8;
+ }
+}
+
+impl LocalImageType {
+ /// Construct a `LocalImageType` from the fields of a `TypeInner::Image`.
+ fn from_inner(dim: crate::ImageDimension, arrayed: bool, class: crate::ImageClass) -> Self {
+ let make_flags = |multi: bool, other: ImageTypeFlags| -> ImageTypeFlags {
+ let mut flags = other;
+ flags.set(ImageTypeFlags::ARRAYED, arrayed);
+ flags.set(ImageTypeFlags::MULTISAMPLED, multi);
+ flags
+ };
+
+ let dim = spirv::Dim::from(dim);
+
+ match class {
+ crate::ImageClass::Sampled { kind, multi } => LocalImageType {
+ sampled_type: kind,
+ dim,
+ flags: make_flags(multi, ImageTypeFlags::SAMPLED),
+ image_format: spirv::ImageFormat::Unknown,
+ },
+ crate::ImageClass::Depth { multi } => LocalImageType {
+ sampled_type: crate::ScalarKind::Float,
+ dim,
+ flags: make_flags(multi, ImageTypeFlags::DEPTH | ImageTypeFlags::SAMPLED),
+ image_format: spirv::ImageFormat::Unknown,
+ },
+ crate::ImageClass::Storage { format, access: _ } => LocalImageType {
+ sampled_type: crate::ScalarKind::from(format),
+ dim,
+ flags: make_flags(false, ImageTypeFlags::empty()),
+ image_format: format.into(),
+ },
+ }
+ }
+}
+
+/// A SPIR-V type constructed during code generation.
+///
+/// This is the variant of [`LookupType`] used to represent types that might not
+/// be available in the arena. Variants are present here for one of two reasons:
+///
+/// - They represent types synthesized during code generation, as explained
+/// in the documentation for [`LookupType`].
+///
+/// - They represent types for which SPIR-V forbids duplicate `OpType...`
+/// instructions, requiring deduplication.
+///
+/// This is not a complete copy of [`TypeInner`]: for example, SPIR-V generation
+/// never synthesizes new struct types, so `LocalType` has nothing for that.
+///
+/// Each `LocalType` variant should be handled identically to its analogous
+/// `TypeInner` variant. You can use the [`make_local`] function to help with
+/// this, by converting everything possible to a `LocalType` before inspecting
+/// it.
+///
+/// ## `Localtype` equality and SPIR-V `OpType` uniqueness
+///
+/// The definition of `Eq` on `LocalType` is carefully chosen to help us follow
+/// certain SPIR-V rules. SPIR-V §2.8 requires some classes of `OpType...`
+/// instructions to be unique; for example, you can't have two `OpTypeInt 32 1`
+/// instructions in the same module. All 32-bit signed integers must use the
+/// same type id.
+///
+/// All SPIR-V types that must be unique can be represented as a `LocalType`,
+/// and two `LocalType`s are always `Eq` if SPIR-V would require them to use the
+/// same `OpType...` instruction. This lets us avoid duplicates by recording the
+/// ids of the type instructions we've already generated in a hash table,
+/// [`Writer::lookup_type`], keyed by `LocalType`.
+///
+/// As another example, [`LocalImageType`], stored in the `LocalType::Image`
+/// variant, is designed to help us deduplicate `OpTypeImage` instructions. See
+/// its documentation for details.
+///
+/// `LocalType` also includes variants like `Pointer` that do not need to be
+/// unique - but it is harmless to avoid the duplication.
+///
+/// As it always must, the `Hash` implementation respects the `Eq` relation.
+///
+/// [`TypeInner`]: crate::TypeInner
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+enum LocalType {
+ /// A scalar, vector, or pointer to one of those.
+ Value {
+ /// If `None`, this represents a scalar type. If `Some`, this represents
+ /// a vector type of the given size.
+ vector_size: Option<crate::VectorSize>,
+ scalar: crate::Scalar,
+ pointer_space: Option<spirv::StorageClass>,
+ },
+ /// A matrix of floating-point values.
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+ Pointer {
+ base: Handle<crate::Type>,
+ class: spirv::StorageClass,
+ },
+ Image(LocalImageType),
+ SampledImage {
+ image_type_id: Word,
+ },
+ Sampler,
+ /// Equivalent to a [`LocalType::Pointer`] whose `base` is a Naga IR [`BindingArray`]. SPIR-V
+ /// permits duplicated `OpTypePointer` ids, so it's fine to have two different [`LocalType`]
+ /// representations for pointer types.
+ ///
+ /// [`BindingArray`]: crate::TypeInner::BindingArray
+ PointerToBindingArray {
+ base: Handle<crate::Type>,
+ size: u32,
+ space: crate::AddressSpace,
+ },
+ BindingArray {
+ base: Handle<crate::Type>,
+ size: u32,
+ },
+ AccelerationStructure,
+ RayQuery,
+}
+
+/// A type encountered during SPIR-V generation.
+///
+/// In the process of writing SPIR-V, we need to synthesize various types for
+/// intermediate results and such: pointer types, vector/matrix component types,
+/// or even booleans, which usually appear in SPIR-V code even when they're not
+/// used by the module source.
+///
+/// However, we can't use `crate::Type` or `crate::TypeInner` for these, as the
+/// type arena may not contain what we need (it only contains types used
+/// directly by other parts of the IR), and the IR module is immutable, so we
+/// can't add anything to it.
+///
+/// So for local use in the SPIR-V writer, we use this type, which holds either
+/// a handle into the arena, or a [`LocalType`] containing something synthesized
+/// locally.
+///
+/// This is very similar to the [`proc::TypeResolution`] enum, with `LocalType`
+/// playing the role of `TypeInner`. However, `LocalType` also has other
+/// properties needed for SPIR-V generation; see the description of
+/// [`LocalType`] for details.
+///
+/// [`proc::TypeResolution`]: crate::proc::TypeResolution
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+enum LookupType {
+ Handle(Handle<crate::Type>),
+ Local(LocalType),
+}
+
+impl From<LocalType> for LookupType {
+ fn from(local: LocalType) -> Self {
+ Self::Local(local)
+ }
+}
+
+#[derive(Debug, PartialEq, Clone, Hash, Eq)]
+struct LookupFunctionType {
+ parameter_type_ids: Vec<Word>,
+ return_type_id: Word,
+}
+
+fn make_local(inner: &crate::TypeInner) -> Option<LocalType> {
+ Some(match *inner {
+ crate::TypeInner::Scalar(scalar) | crate::TypeInner::Atomic(scalar) => LocalType::Value {
+ vector_size: None,
+ scalar,
+ pointer_space: None,
+ },
+ crate::TypeInner::Vector { size, scalar } => LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => LocalType::Matrix {
+ columns,
+ rows,
+ width: scalar.width,
+ },
+ crate::TypeInner::Pointer { base, space } => LocalType::Pointer {
+ base,
+ class: helpers::map_storage_class(space),
+ },
+ crate::TypeInner::ValuePointer {
+ size,
+ scalar,
+ space,
+ } => LocalType::Value {
+ vector_size: size,
+ scalar,
+ pointer_space: Some(helpers::map_storage_class(space)),
+ },
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
+ crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
+ crate::TypeInner::AccelerationStructure => LocalType::AccelerationStructure,
+ crate::TypeInner::RayQuery => LocalType::RayQuery,
+ crate::TypeInner::Array { .. }
+ | crate::TypeInner::Struct { .. }
+ | crate::TypeInner::BindingArray { .. } => return None,
+ })
+}
+
+#[derive(Debug)]
+enum Dimension {
+ Scalar,
+ Vector,
+ Matrix,
+}
+
+/// A map from evaluated [`Expression`](crate::Expression)s to their SPIR-V ids.
+///
+/// When we emit code to evaluate a given `Expression`, we record the
+/// SPIR-V id of its value here, under its `Handle<Expression>` index.
+///
+/// A `CachedExpressions` value can be indexed by a `Handle<Expression>` value.
+///
+/// [emit]: index.html#expression-evaluation-time-and-scope
+#[derive(Default)]
+struct CachedExpressions {
+ ids: Vec<Word>,
+}
+impl CachedExpressions {
+ fn reset(&mut self, length: usize) {
+ self.ids.clear();
+ self.ids.resize(length, 0);
+ }
+}
+impl ops::Index<Handle<crate::Expression>> for CachedExpressions {
+ type Output = Word;
+ fn index(&self, h: Handle<crate::Expression>) -> &Word {
+ let id = &self.ids[h.index()];
+ if *id == 0 {
+ unreachable!("Expression {:?} is not cached!", h);
+ }
+ id
+ }
+}
+impl ops::IndexMut<Handle<crate::Expression>> for CachedExpressions {
+ fn index_mut(&mut self, h: Handle<crate::Expression>) -> &mut Word {
+ let id = &mut self.ids[h.index()];
+ if *id != 0 {
+ unreachable!("Expression {:?} is already cached!", h);
+ }
+ id
+ }
+}
+impl recyclable::Recyclable for CachedExpressions {
+ fn recycle(self) -> Self {
+ CachedExpressions {
+ ids: self.ids.recycle(),
+ }
+ }
+}
+
+#[derive(Eq, Hash, PartialEq)]
+enum CachedConstant {
+ Literal(crate::Literal),
+ Composite {
+ ty: LookupType,
+ constituent_ids: Vec<Word>,
+ },
+ ZeroValue(Word),
+}
+
+#[derive(Clone)]
+struct GlobalVariable {
+ /// ID of the OpVariable that declares the global.
+ ///
+ /// If you need the variable's value, use [`access_id`] instead of this
+ /// field. If we wrapped the Naga IR `GlobalVariable`'s type in a struct to
+ /// comply with Vulkan's requirements, then this points to the `OpVariable`
+ /// with the synthesized struct type, whereas `access_id` points to the
+ /// field of said struct that holds the variable's actual value.
+ ///
+ /// This is used to compute the `access_id` pointer in function prologues,
+ /// and used for `ArrayLength` expressions, which do need the struct.
+ ///
+ /// [`access_id`]: GlobalVariable::access_id
+ var_id: Word,
+
+ /// For `AddressSpace::Handle` variables, this ID is recorded in the function
+ /// prelude block (and reset before every function) as `OpLoad` of the variable.
+ /// It is then used for all the global ops, such as `OpImageSample`.
+ handle_id: Word,
+
+ /// Actual ID used to access this variable.
+ /// For wrapped buffer variables, this ID is `OpAccessChain` into the
+ /// wrapper. Otherwise, the same as `var_id`.
+ ///
+ /// Vulkan requires that globals in the `StorageBuffer` and `Uniform` storage
+ /// classes must be structs with the `Block` decoration, but WGSL and Naga IR
+ /// make no such requirement. So for such variables, we generate a wrapper struct
+ /// type with a single element of the type given by Naga, generate an
+ /// `OpAccessChain` for that member in the function prelude, and use that pointer
+ /// to refer to the global in the function body. This is the id of that access,
+ /// updated for each function in `write_function`.
+ access_id: Word,
+}
+
+impl GlobalVariable {
+ const fn dummy() -> Self {
+ Self {
+ var_id: 0,
+ handle_id: 0,
+ access_id: 0,
+ }
+ }
+
+ const fn new(id: Word) -> Self {
+ Self {
+ var_id: id,
+ handle_id: 0,
+ access_id: 0,
+ }
+ }
+
+ /// Prepare `self` for use within a single function.
+ fn reset_for_function(&mut self) {
+ self.handle_id = 0;
+ self.access_id = 0;
+ }
+}
+
+struct FunctionArgument {
+ /// Actual instruction of the argument.
+ instruction: Instruction,
+ handle_id: Word,
+}
+
+/// General information needed to emit SPIR-V for Naga statements.
+struct BlockContext<'w> {
+ /// The writer handling the module to which this code belongs.
+ writer: &'w mut Writer,
+
+ /// The [`Module`](crate::Module) for which we're generating code.
+ ir_module: &'w crate::Module,
+
+ /// The [`Function`](crate::Function) for which we're generating code.
+ ir_function: &'w crate::Function,
+
+ /// Information module validation produced about
+ /// [`ir_function`](BlockContext::ir_function).
+ fun_info: &'w crate::valid::FunctionInfo,
+
+ /// The [`spv::Function`](Function) to which we are contributing SPIR-V instructions.
+ function: &'w mut Function,
+
+ /// SPIR-V ids for expressions we've evaluated.
+ cached: CachedExpressions,
+
+ /// The `Writer`'s temporary vector, for convenience.
+ temp_list: Vec<Word>,
+
+ /// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`
+ expression_constness: crate::proc::ExpressionConstnessTracker,
+}
+
+impl BlockContext<'_> {
+ fn gen_id(&mut self) -> Word {
+ self.writer.id_gen.next()
+ }
+
+ fn get_type_id(&mut self, lookup_type: LookupType) -> Word {
+ self.writer.get_type_id(lookup_type)
+ }
+
+ fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
+ self.writer.get_expression_type_id(tr)
+ }
+
+ fn get_index_constant(&mut self, index: Word) -> Word {
+ self.writer.get_constant_scalar(crate::Literal::U32(index))
+ }
+
+ fn get_scope_constant(&mut self, scope: Word) -> Word {
+ self.writer
+ .get_constant_scalar(crate::Literal::I32(scope as _))
+ }
+}
+
+#[derive(Clone, Copy, Default)]
+struct LoopContext {
+ continuing_id: Option<Word>,
+ break_id: Option<Word>,
+}
+
+pub struct Writer {
+ physical_layout: PhysicalLayout,
+ logical_layout: LogicalLayout,
+ id_gen: IdGenerator,
+
+ /// The set of capabilities modules are permitted to use.
+ ///
+ /// This is initialized from `Options::capabilities`.
+ capabilities_available: Option<crate::FastHashSet<Capability>>,
+
+ /// The set of capabilities used by this module.
+ ///
+ /// If `capabilities_available` is `Some`, then this is always a subset of
+ /// that.
+ capabilities_used: crate::FastIndexSet<Capability>,
+
+ /// The set of spirv extensions used.
+ extensions_used: crate::FastIndexSet<&'static str>,
+
+ debugs: Vec<Instruction>,
+ annotations: Vec<Instruction>,
+ flags: WriterFlags,
+ bounds_check_policies: BoundsCheckPolicies,
+ zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
+ void_type: Word,
+ //TODO: convert most of these into vectors, addressable by handle indices
+ lookup_type: crate::FastHashMap<LookupType, Word>,
+ lookup_function: crate::FastHashMap<Handle<crate::Function>, Word>,
+ lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>,
+ /// Indexed by const-expression handle indexes
+ constant_ids: Vec<Word>,
+ cached_constants: crate::FastHashMap<CachedConstant, Word>,
+ global_variables: Vec<GlobalVariable>,
+ binding_map: BindingMap,
+
+ // Cached expressions are only meaningful within a BlockContext, but we
+ // retain the table here between functions to save heap allocations.
+ saved_cached: CachedExpressions,
+
+ gl450_ext_inst_id: Word,
+
+ // Just a temporary list of SPIR-V ids
+ temp_list: Vec<Word>,
+}
+
+bitflags::bitflags! {
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct WriterFlags: u32 {
+ /// Include debug labels for everything.
+ const DEBUG = 0x1;
+ /// Flip Y coordinate of `BuiltIn::Position` output.
+ const ADJUST_COORDINATE_SPACE = 0x2;
+ /// Emit `OpName` for input/output locations.
+ /// Contrary to spec, some drivers treat it as semantic, not allowing
+ /// any conflicts.
+ const LABEL_VARYINGS = 0x4;
+ /// Emit `PointSize` output builtin to vertex shaders, which is
+ /// required for drawing with `PointList` topology.
+ const FORCE_POINT_SIZE = 0x8;
+ /// Clamp `BuiltIn::FragDepth` output between 0 and 1.
+ const CLAMP_FRAG_DEPTH = 0x10;
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BindingInfo {
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that we can hash itself.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindingInfo>;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum ZeroInitializeWorkgroupMemoryMode {
+ /// Via `VK_KHR_zero_initialize_workgroup_memory` or Vulkan 1.3
+ Native,
+ /// Via assignments + barrier
+ Polyfill,
+ None,
+}
+
+#[derive(Debug, Clone)]
+pub struct Options<'a> {
+ /// (Major, Minor) target version of the SPIR-V.
+ pub lang_version: (u8, u8),
+
+ /// Configuration flags for the writer.
+ pub flags: WriterFlags,
+
+ /// Map of resources to information about the binding.
+ pub binding_map: BindingMap,
+
+ /// If given, the set of capabilities modules are allowed to use. Code that
+ /// requires capabilities beyond these is rejected with an error.
+ ///
+ /// If this is `None`, all capabilities are permitted.
+ pub capabilities: Option<crate::FastHashSet<Capability>>,
+
+ /// How should generate code handle array, vector, matrix, or image texel
+ /// indices that are out of range?
+ pub bounds_check_policies: BoundsCheckPolicies,
+
+ /// Dictates the way workgroup variables should be zero initialized
+ pub zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
+
+ pub debug_info: Option<DebugInfo<'a>>,
+}
+
+impl<'a> Default for Options<'a> {
+ fn default() -> Self {
+ let mut flags = WriterFlags::ADJUST_COORDINATE_SPACE
+ | WriterFlags::LABEL_VARYINGS
+ | WriterFlags::CLAMP_FRAG_DEPTH;
+ if cfg!(debug_assertions) {
+ flags |= WriterFlags::DEBUG;
+ }
+ Options {
+ lang_version: (1, 0),
+ flags,
+ binding_map: BindingMap::default(),
+ capabilities: None,
+ bounds_check_policies: crate::proc::BoundsCheckPolicies::default(),
+ zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill,
+ debug_info: None,
+ }
+ }
+}
+
+// A subset of options meant to be changed per pipeline.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// The stage of the entry point.
+ pub shader_stage: crate::ShaderStage,
+ /// The name of the entry point.
+ ///
+ /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown.
+ pub entry_point: String,
+}
+
+pub fn write_vec(
+ module: &crate::Module,
+ info: &crate::valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: Option<&PipelineOptions>,
+) -> Result<Vec<u32>, Error> {
+ let mut words: Vec<u32> = Vec::new();
+ let mut w = Writer::new(options)?;
+
+ w.write(
+ module,
+ info,
+ pipeline_options,
+ &options.debug_info,
+ &mut words,
+ )?;
+ Ok(words)
+}
diff --git a/third_party/rust/naga/src/back/spv/ray.rs b/third_party/rust/naga/src/back/spv/ray.rs
new file mode 100644
index 0000000000..bc2c4ce3c6
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/ray.rs
@@ -0,0 +1,261 @@
+/*!
+Generating SPIR-V for ray query operations.
+*/
+
+use super::{Block, BlockContext, Instruction, LocalType, LookupType};
+use crate::arena::Handle;
+
+impl<'w> BlockContext<'w> {
+ pub(super) fn write_ray_query_function(
+ &mut self,
+ query: Handle<crate::Expression>,
+ function: &crate::RayQueryFunction,
+ block: &mut Block,
+ ) {
+ let query_id = self.cached[query];
+ match *function {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ //Note: composite extract indices and types must match `generate_ray_desc_type`
+ let desc_id = self.cached[descriptor];
+ let acc_struct_id = self.get_handle_id(acceleration_structure);
+
+ let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ }));
+ let ray_flags_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ flag_type_id,
+ ray_flags_id,
+ desc_id,
+ &[0],
+ ));
+ let cull_mask_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ flag_type_id,
+ cull_mask_id,
+ desc_id,
+ &[1],
+ ));
+
+ let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let tmin_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ scalar_type_id,
+ tmin_id,
+ desc_id,
+ &[2],
+ ));
+ let tmax_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ scalar_type_id,
+ tmax_id,
+ desc_id,
+ &[3],
+ ));
+
+ let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let ray_origin_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ ray_origin_id,
+ desc_id,
+ &[4],
+ ));
+ let ray_dir_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ ray_dir_id,
+ desc_id,
+ &[5],
+ ));
+
+ block.body.push(Instruction::ray_query_initialize(
+ query_id,
+ acc_struct_id,
+ ray_flags_id,
+ cull_mask_id,
+ ray_origin_id,
+ tmin_id,
+ ray_dir_id,
+ tmax_id,
+ ));
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ let id = self.gen_id();
+ self.cached[result] = id;
+ let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
+
+ block
+ .body
+ .push(Instruction::ray_query_proceed(result_type_id, id, query_id));
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ }
+
+ pub(super) fn write_ray_query_get_intersection(
+ &mut self,
+ query: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> spirv::Word {
+ let query_id = self.cached[query];
+ let intersection_id = self.writer.get_constant_scalar(crate::Literal::U32(
+ spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
+ ));
+
+ let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ }));
+ let kind_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionTypeKHR,
+ flag_type_id,
+ kind_id,
+ query_id,
+ intersection_id,
+ ));
+ let instance_custom_index_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
+ flag_type_id,
+ instance_custom_index_id,
+ query_id,
+ intersection_id,
+ ));
+ let instance_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
+ flag_type_id,
+ instance_id,
+ query_id,
+ intersection_id,
+ ));
+ let sbt_record_offset_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
+ flag_type_id,
+ sbt_record_offset_id,
+ query_id,
+ intersection_id,
+ ));
+ let geometry_index_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
+ flag_type_id,
+ geometry_index_id,
+ query_id,
+ intersection_id,
+ ));
+ let primitive_index_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
+ flag_type_id,
+ primitive_index_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let t_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionTKHR,
+ scalar_type_id,
+ t_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let barycentrics_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Bi),
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let barycentrics_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
+ barycentrics_type_id,
+ barycentrics_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::BOOL,
+ pointer_space: None,
+ }));
+ let front_face_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
+ bool_type_id,
+ front_face_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let transform_type_id = self.get_type_id(LookupType::Local(LocalType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }));
+ let object_to_world_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
+ transform_type_id,
+ object_to_world_id,
+ query_id,
+ intersection_id,
+ ));
+ let world_to_object_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
+ transform_type_id,
+ world_to_object_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let id = self.gen_id();
+ let intersection_type_id = self.get_type_id(LookupType::Handle(
+ self.ir_module.special_types.ray_intersection.unwrap(),
+ ));
+ //Note: the arguments must match `generate_ray_intersection_type` layout
+ block.body.push(Instruction::composite_construct(
+ intersection_type_id,
+ id,
+ &[
+ kind_id,
+ t_id,
+ instance_custom_index_id,
+ instance_id,
+ sbt_record_offset_id,
+ geometry_index_id,
+ primitive_index_id,
+ barycentrics_id,
+ front_face_id,
+ object_to_world_id,
+ world_to_object_id,
+ ],
+ ));
+ id
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/recyclable.rs b/third_party/rust/naga/src/back/spv/recyclable.rs
new file mode 100644
index 0000000000..cd1466e3c7
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/recyclable.rs
@@ -0,0 +1,67 @@
+/*!
+Reusing collections' previous allocations.
+*/
+
+/// A value that can be reset to its initial state, retaining its current allocations.
+///
+/// Naga attempts to lower the cost of SPIR-V generation by allowing clients to
+/// reuse the same `Writer` for multiple Module translations. Reusing a `Writer`
+/// means that the `Vec`s, `HashMap`s, and other heap-allocated structures the
+/// `Writer` uses internally begin the translation with heap-allocated buffers
+/// ready to use.
+///
+/// But this approach introduces the risk of `Writer` state leaking from one
+/// module to the next. When a developer adds fields to `Writer` or its internal
+/// types, they must remember to reset their contents between modules.
+///
+/// One trick to ensure that every field has been accounted for is to use Rust's
+/// struct literal syntax to construct a new, reset value. If a developer adds a
+/// field, but neglects to update the reset code, the compiler will complain
+/// that a field is missing from the literal. This trait's `recycle` method
+/// takes `self` by value, and returns `Self` by value, encouraging the use of
+/// struct literal expressions in its implementation.
+pub trait Recyclable {
+ /// Clear `self`, retaining its current memory allocations.
+ ///
+ /// Shrink the buffer if it's currently much larger than was actually used.
+ /// This prevents a module with exceptionally large allocations from causing
+ /// the `Writer` to retain more memory than it needs indefinitely.
+ fn recycle(self) -> Self;
+}
+
+// Stock values for various collections.
+
+impl<T> Recyclable for Vec<T> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K, V, S: Clone> Recyclable for std::collections::HashMap<K, V, S> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K, S: Clone> Recyclable for std::collections::HashSet<K, S> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K, S: Clone> Recyclable for indexmap::IndexSet<K, S> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K: Ord, V> Recyclable for std::collections::BTreeMap<K, V> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/selection.rs b/third_party/rust/naga/src/back/spv/selection.rs
new file mode 100644
index 0000000000..788b1f10ab
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/selection.rs
@@ -0,0 +1,257 @@
+/*!
+Generate SPIR-V conditional structures.
+
+Builders for `if` structures with `and`s.
+
+The types in this module track the information needed to emit SPIR-V code
+for complex conditional structures, like those whose conditions involve
+short-circuiting 'and' and 'or' structures. These track labels and can emit
+`OpPhi` instructions to merge values produced along different paths.
+
+This currently only supports exactly the forms Naga uses, so it doesn't
+support `or` or `else`, and only supports zero or one merged values.
+
+Naga needs to emit code roughly like this:
+
+```ignore
+
+ value = DEFAULT;
+ if COND1 && COND2 {
+ value = THEN_VALUE;
+ }
+ // use value
+
+```
+
+Assuming `ctx` and `block` are a mutable references to a [`BlockContext`]
+and the current [`Block`], and `merge_type` is the SPIR-V type for the
+merged value `value`, we can build SPIR-V for the code above like so:
+
+```ignore
+
+ let cond = Selection::start(block, merge_type);
+ // ... compute `cond1` ...
+ cond.if_true(ctx, cond1, DEFAULT);
+ // ... compute `cond2` ...
+ cond.if_true(ctx, cond2, DEFAULT);
+ // ... compute THEN_VALUE
+ let merged_value = cond.finish(ctx, THEN_VALUE);
+
+```
+
+After this, `merged_value` is either `DEFAULT` or `THEN_VALUE`, depending on
+the path by which the merged block was reached.
+
+This takes care of writing all branch instructions, including an
+`OpSelectionMerge` annotation in the header block; starting new blocks and
+assigning them labels; and emitting the `OpPhi` that gathers together the
+right sources for the merged values, for every path through the selection
+construct.
+
+When there is no merged value to produce, you can pass `()` for `merge_type`
+and the merge values. In this case no `OpPhi` instructions are produced, and
+the `finish` method returns `()`.
+
+To enforce proper nesting, a `Selection` takes ownership of the `&mut Block`
+pointer for the duration of its lifetime. To obtain the block for generating
+code in the selection's body, call the `Selection::block` method.
+*/
+
+use super::{Block, BlockContext, Instruction};
+use spirv::Word;
+
+/// A private struct recording what we know about the selection construct so far.
+pub(super) struct Selection<'b, M: MergeTuple> {
+ /// The block pointer we're emitting code into.
+ block: &'b mut Block,
+
+ /// The label of the selection construct's merge block, or `None` if we
+ /// haven't yet written the `OpSelectionMerge` merge instruction.
+ merge_label: Option<Word>,
+
+ /// A set of `(VALUES, PARENT)` pairs, used to build `OpPhi` instructions in
+ /// the merge block. Each `PARENT` is the label of a predecessor block of
+ /// the merge block. The corresponding `VALUES` holds the ids of the values
+ /// that `PARENT` contributes to the merged values.
+ ///
+ /// We emit all branches to the merge block, so we know all its
+ /// predecessors. And we refuse to emit a branch unless we're given the
+ /// values the branching block contributes to the merge, so we always have
+ /// everything we need to emit the correct phis, by construction.
+ values: Vec<(M, Word)>,
+
+ /// The types of the values in each element of `values`.
+ merge_types: M,
+}
+
+impl<'b, M: MergeTuple> Selection<'b, M> {
+ /// Start a new selection construct.
+ ///
+ /// The `block` argument indicates the selection's header block.
+ ///
+ /// The `merge_types` argument should be a `Word` or tuple of `Word`s, each
+ /// value being the SPIR-V result type id of an `OpPhi` instruction that
+ /// will be written to the selection's merge block when this selection's
+ /// [`finish`] method is called. This argument may also be `()`, for
+ /// selections that produce no values.
+ ///
+ /// (This function writes no code to `block` itself; it simply constructs a
+ /// fresh `Selection`.)
+ ///
+ /// [`finish`]: Selection::finish
+ pub(super) fn start(block: &'b mut Block, merge_types: M) -> Self {
+ Selection {
+ block,
+ merge_label: None,
+ values: vec![],
+ merge_types,
+ }
+ }
+
+ pub(super) fn block(&mut self) -> &mut Block {
+ self.block
+ }
+
+ /// Branch to a successor block if `cond` is true, otherwise merge.
+ ///
+ /// If `cond` is false, branch to the merge block, using `values` as the
+ /// merged values. Otherwise, proceed to a new block.
+ ///
+ /// The `values` argument must be the same shape as the `merge_types`
+ /// argument passed to `Selection::start`.
+ pub(super) fn if_true(&mut self, ctx: &mut BlockContext, cond: Word, values: M) {
+ self.values.push((values, self.block.label_id));
+
+ let merge_label = self.make_merge_label(ctx);
+ let next_label = ctx.gen_id();
+ ctx.function.consume(
+ std::mem::replace(self.block, Block::new(next_label)),
+ Instruction::branch_conditional(cond, next_label, merge_label),
+ );
+ }
+
+ /// Emit an unconditional branch to the merge block, and compute merged
+ /// values.
+ ///
+ /// Use `final_values` as the merged values contributed by the current
+ /// block, and transition to the merge block, emitting `OpPhi` instructions
+ /// to produce the merged values. This must be the same shape as the
+ /// `merge_types` argument passed to [`Selection::start`].
+ ///
+ /// Return the SPIR-V ids of the merged values. This value has the same
+ /// shape as the `merge_types` argument passed to `Selection::start`.
+ pub(super) fn finish(self, ctx: &mut BlockContext, final_values: M) -> M {
+ match self {
+ Selection {
+ merge_label: None, ..
+ } => {
+ // We didn't actually emit any branches, so `self.values` must
+ // be empty, and `final_values` are the only sources we have for
+ // the merged values. Easy peasy.
+ final_values
+ }
+
+ Selection {
+ block,
+ merge_label: Some(merge_label),
+ mut values,
+ merge_types,
+ } => {
+ // Emit the final branch and transition to the merge block.
+ values.push((final_values, block.label_id));
+ ctx.function.consume(
+ std::mem::replace(block, Block::new(merge_label)),
+ Instruction::branch(merge_label),
+ );
+
+ // Now that we're in the merge block, build the phi instructions.
+ merge_types.write_phis(ctx, block, &values)
+ }
+ }
+ }
+
+ /// Return the id of the merge block, writing a merge instruction if needed.
+ fn make_merge_label(&mut self, ctx: &mut BlockContext) -> Word {
+ match self.merge_label {
+ None => {
+ let merge_label = ctx.gen_id();
+ self.block.body.push(Instruction::selection_merge(
+ merge_label,
+ spirv::SelectionControl::NONE,
+ ));
+ self.merge_label = Some(merge_label);
+ merge_label
+ }
+ Some(merge_label) => merge_label,
+ }
+ }
+}
+
+/// A trait to help `Selection` manage any number of merged values.
+///
+/// Some selection constructs, like a `ReadZeroSkipWrite` bounds check on a
+/// [`Load`] expression, produce a single merged value. Others produce no merged
+/// value, like a bounds check on a [`Store`] statement.
+///
+/// To let `Selection` work nicely with both cases, we let the merge type
+/// argument passed to [`Selection::start`] be any type that implements this
+/// `MergeTuple` trait. `MergeTuple` is then implemented for `()`, `Word`,
+/// `(Word, Word)`, and so on.
+///
+/// A `MergeTuple` type can represent either a bunch of SPIR-V types or values;
+/// the `merge_types` argument to `Selection::start` are type ids, whereas the
+/// `values` arguments to the [`if_true`] and [`finish`] methods are value ids.
+/// The set of merged value returned by `finish` is a tuple of value ids.
+///
+/// In fact, since Naga only uses zero- and single-valued selection constructs
+/// at present, we only implement `MergeTuple` for `()` and `Word`. But if you
+/// add more cases, feel free to add more implementations. Once const generics
+/// are available, we could have a single implementation of `MergeTuple` for all
+/// lengths of arrays, and be done with it.
+///
+/// [`Load`]: crate::Expression::Load
+/// [`Store`]: crate::Statement::Store
+/// [`if_true`]: Selection::if_true
+/// [`finish`]: Selection::finish
+pub(super) trait MergeTuple: Sized {
+ /// Write OpPhi instructions for the given set of predecessors.
+ ///
+ /// The `predecessors` vector should be a vector of `(LABEL, VALUES)` pairs,
+ /// where each `VALUES` holds the values contributed by the branch from
+ /// `LABEL`, which should be one of the current block's predecessors.
+ fn write_phis(
+ self,
+ ctx: &mut BlockContext,
+ block: &mut Block,
+ predecessors: &[(Self, Word)],
+ ) -> Self;
+}
+
+/// Selections that produce a single merged value.
+///
+/// For example, `ImageLoad` with `BoundsCheckPolicy::ReadZeroSkipWrite` either
+/// returns a texel value or zeros.
+impl MergeTuple for Word {
+ fn write_phis(
+ self,
+ ctx: &mut BlockContext,
+ block: &mut Block,
+ predecessors: &[(Word, Word)],
+ ) -> Word {
+ let merged_value = ctx.gen_id();
+ block
+ .body
+ .push(Instruction::phi(self, merged_value, predecessors));
+ merged_value
+ }
+}
+
+/// Selections that produce no merged values.
+///
+/// For example, `ImageStore` under `BoundsCheckPolicy::ReadZeroSkipWrite`
+/// either does the store or skips it, but in neither case does it produce a
+/// value.
+impl MergeTuple for () {
+ /// No phis need to be generated.
+ fn write_phis(self, _: &mut BlockContext, _: &mut Block, _: &[((), Word)]) {}
+}
diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs
new file mode 100644
index 0000000000..4db86c93a7
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/writer.rs
@@ -0,0 +1,2063 @@
+use super::{
+ block::DebugInfoInner,
+ helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
+ make_local, Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo,
+ EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction,
+ LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, LoopContext, Options,
+ PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE,
+};
+use crate::{
+ arena::{Handle, UniqueArena},
+ back::spv::BindingInfo,
+ proc::{Alignment, TypeResolution},
+ valid::{FunctionInfo, ModuleInfo},
+};
+use spirv::Word;
+use std::collections::hash_map::Entry;
+
+struct FunctionInterface<'a> {
+ varying_ids: &'a mut Vec<Word>,
+ stage: crate::ShaderStage,
+}
+
+impl Function {
+ fn to_words(&self, sink: &mut impl Extend<Word>) {
+ self.signature.as_ref().unwrap().to_words(sink);
+ for argument in self.parameters.iter() {
+ argument.instruction.to_words(sink);
+ }
+ for (index, block) in self.blocks.iter().enumerate() {
+ Instruction::label(block.label_id).to_words(sink);
+ if index == 0 {
+ for local_var in self.variables.values() {
+ local_var.instruction.to_words(sink);
+ }
+ }
+ for instruction in block.body.iter() {
+ instruction.to_words(sink);
+ }
+ }
+ }
+}
+
+impl Writer {
+ pub fn new(options: &Options) -> Result<Self, Error> {
+ let (major, minor) = options.lang_version;
+ if major != 1 {
+ return Err(Error::UnsupportedVersion(major, minor));
+ }
+ let raw_version = ((major as u32) << 16) | ((minor as u32) << 8);
+
+ let mut capabilities_used = crate::FastIndexSet::default();
+ capabilities_used.insert(spirv::Capability::Shader);
+
+ let mut id_gen = IdGenerator::default();
+ let gl450_ext_inst_id = id_gen.next();
+ let void_type = id_gen.next();
+
+ Ok(Writer {
+ physical_layout: PhysicalLayout::new(raw_version),
+ logical_layout: LogicalLayout::default(),
+ id_gen,
+ capabilities_available: options.capabilities.clone(),
+ capabilities_used,
+ extensions_used: crate::FastIndexSet::default(),
+ debugs: vec![],
+ annotations: vec![],
+ flags: options.flags,
+ bounds_check_policies: options.bounds_check_policies,
+ zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
+ void_type,
+ lookup_type: crate::FastHashMap::default(),
+ lookup_function: crate::FastHashMap::default(),
+ lookup_function_type: crate::FastHashMap::default(),
+ constant_ids: Vec::new(),
+ cached_constants: crate::FastHashMap::default(),
+ global_variables: Vec::new(),
+ binding_map: options.binding_map.clone(),
+ saved_cached: CachedExpressions::default(),
+ gl450_ext_inst_id,
+ temp_list: Vec::new(),
+ })
+ }
+
+ /// Reset `Writer` to its initial state, retaining any allocations.
+ ///
+ /// Why not just implement `Recyclable` for `Writer`? By design,
+ /// `Recyclable::recycle` requires ownership of the value, not just
+ /// `&mut`; see the trait documentation. But we need to use this method
+ /// from functions like `Writer::write`, which only have `&mut Writer`.
+ /// Workarounds include unsafe code (`std::ptr::read`, then `write`, ugh)
+ /// or something like a `Default` impl that returns an oddly-initialized
+ /// `Writer`, which is worse.
+ fn reset(&mut self) {
+ use super::recyclable::Recyclable;
+ use std::mem::take;
+
+ let mut id_gen = IdGenerator::default();
+ let gl450_ext_inst_id = id_gen.next();
+ let void_type = id_gen.next();
+
+ // Every field of the old writer that is not determined by the `Options`
+ // passed to `Writer::new` should be reset somehow.
+ let fresh = Writer {
+ // Copied from the old Writer:
+ flags: self.flags,
+ bounds_check_policies: self.bounds_check_policies,
+ zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
+ capabilities_available: take(&mut self.capabilities_available),
+ binding_map: take(&mut self.binding_map),
+
+ // Initialized afresh:
+ id_gen,
+ void_type,
+ gl450_ext_inst_id,
+
+ // Recycled:
+ capabilities_used: take(&mut self.capabilities_used).recycle(),
+ extensions_used: take(&mut self.extensions_used).recycle(),
+ physical_layout: self.physical_layout.clone().recycle(),
+ logical_layout: take(&mut self.logical_layout).recycle(),
+ debugs: take(&mut self.debugs).recycle(),
+ annotations: take(&mut self.annotations).recycle(),
+ lookup_type: take(&mut self.lookup_type).recycle(),
+ lookup_function: take(&mut self.lookup_function).recycle(),
+ lookup_function_type: take(&mut self.lookup_function_type).recycle(),
+ constant_ids: take(&mut self.constant_ids).recycle(),
+ cached_constants: take(&mut self.cached_constants).recycle(),
+ global_variables: take(&mut self.global_variables).recycle(),
+ saved_cached: take(&mut self.saved_cached).recycle(),
+ temp_list: take(&mut self.temp_list).recycle(),
+ };
+
+ *self = fresh;
+
+ self.capabilities_used.insert(spirv::Capability::Shader);
+ }
+
+ /// Indicate that the code requires any one of the listed capabilities.
+ ///
+ /// If nothing in `capabilities` appears in the available capabilities
+ /// specified in the [`Options`] from which this `Writer` was created,
+ /// return an error. The `what` string is used in the error message to
+ /// explain what provoked the requirement. (If no available capabilities were
+ /// given, assume everything is available.)
+ ///
+ /// The first acceptable capability will be added to this `Writer`'s
+ /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the
+ /// result. For this reason, more specific capabilities should be listed
+ /// before more general.
+ ///
+ /// [`capabilities_used`]: Writer::capabilities_used
+ pub(super) fn require_any(
+ &mut self,
+ what: &'static str,
+ capabilities: &[spirv::Capability],
+ ) -> Result<(), Error> {
+ match *capabilities {
+ [] => Ok(()),
+ [first, ..] => {
+ // Find the first acceptable capability, or return an error if
+ // there is none.
+ let selected = match self.capabilities_available {
+ None => first,
+ Some(ref available) => {
+ match capabilities.iter().find(|cap| available.contains(cap)) {
+ Some(&cap) => cap,
+ None => {
+ return Err(Error::MissingCapabilities(what, capabilities.to_vec()))
+ }
+ }
+ }
+ };
+ self.capabilities_used.insert(selected);
+ Ok(())
+ }
+ }
+ }
+
+ /// Indicate that the code uses the given extension.
+ pub(super) fn use_extension(&mut self, extension: &'static str) {
+ self.extensions_used.insert(extension);
+ }
+
+ pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word {
+ match self.lookup_type.entry(lookup_ty) {
+ Entry::Occupied(e) => *e.get(),
+ Entry::Vacant(e) => {
+ let local = match lookup_ty {
+ LookupType::Handle(_handle) => unreachable!("Handles are populated at start"),
+ LookupType::Local(local) => local,
+ };
+
+ let id = self.id_gen.next();
+ e.insert(id);
+ self.write_type_declaration_local(id, local);
+ id
+ }
+ }
+ }
+
+ pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType {
+ match *tr {
+ TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
+ TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()),
+ }
+ }
+
+ pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
+ let lookup_ty = self.get_expression_lookup_type(tr);
+ self.get_type_id(lookup_ty)
+ }
+
+ pub(super) fn get_pointer_id(
+ &mut self,
+ arena: &UniqueArena<crate::Type>,
+ handle: Handle<crate::Type>,
+ class: spirv::StorageClass,
+ ) -> Result<Word, Error> {
+ let ty_id = self.get_type_id(LookupType::Handle(handle));
+ if let crate::TypeInner::Pointer { .. } = arena[handle].inner {
+ return Ok(ty_id);
+ }
+ let lookup_type = LookupType::Local(LocalType::Pointer {
+ base: handle,
+ class,
+ });
+ Ok(if let Some(&id) = self.lookup_type.get(&lookup_type) {
+ id
+ } else {
+ let id = self.id_gen.next();
+ let instruction = Instruction::type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_type.insert(lookup_type, id);
+ id
+ })
+ }
+
+ pub(super) fn get_uint_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_float_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_uint3_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_float_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
+ let lookup_type = LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: Some(class),
+ });
+ if let Some(&id) = self.lookup_type.get(&lookup_type) {
+ id
+ } else {
+ let id = self.id_gen.next();
+ let ty_id = self.get_float_type_id();
+ let instruction = Instruction::type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_type.insert(lookup_type, id);
+ id
+ }
+ }
+
+ pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
+ let lookup_type = LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ scalar: crate::Scalar::U32,
+ pointer_space: Some(class),
+ });
+ if let Some(&id) = self.lookup_type.get(&lookup_type) {
+ id
+ } else {
+ let id = self.id_gen.next();
+ let ty_id = self.get_uint3_type_id();
+ let instruction = Instruction::type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_type.insert(lookup_type, id);
+ id
+ }
+ }
+
+ pub(super) fn get_bool_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::BOOL,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_bool3_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ scalar: crate::Scalar::BOOL,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
+ self.annotations
+ .push(Instruction::decorate(id, decoration, operands));
+ }
+
+ fn write_function(
+ &mut self,
+ ir_function: &crate::Function,
+ info: &FunctionInfo,
+ ir_module: &crate::Module,
+ mut interface: Option<FunctionInterface>,
+ debug_info: &Option<DebugInfoInner>,
+ ) -> Result<Word, Error> {
+ let mut function = Function::default();
+
+ let prelude_id = self.id_gen.next();
+ let mut prelude = Block::new(prelude_id);
+ let mut ep_context = EntryPointContext {
+ argument_ids: Vec::new(),
+ results: Vec::new(),
+ };
+
+ let mut local_invocation_id = None;
+
+ let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
+ for argument in ir_function.arguments.iter() {
+ let class = spirv::StorageClass::Input;
+ let handle_ty = ir_module.types[argument.ty].inner.is_handle();
+ let argument_type_id = match handle_ty {
+ true => self.get_pointer_id(
+ &ir_module.types,
+ argument.ty,
+ spirv::StorageClass::UniformConstant,
+ )?,
+ false => self.get_type_id(LookupType::Handle(argument.ty)),
+ };
+
+ if let Some(ref mut iface) = interface {
+ let id = if let Some(ref binding) = argument.binding {
+ let name = argument.name.as_deref();
+
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ argument.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(argument_type_id, id, varying_id, None));
+
+ if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
+ local_invocation_id = Some(id);
+ }
+
+ id
+ } else if let crate::TypeInner::Struct { ref members, .. } =
+ ir_module.types[argument.ty].inner
+ {
+ let struct_id = self.id_gen.next();
+ let mut constituent_ids = Vec::with_capacity(members.len());
+ for member in members {
+ let type_id = self.get_type_id(LookupType::Handle(member.ty));
+ let name = member.name.as_deref();
+ let binding = member.binding.as_ref().unwrap();
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ member.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(type_id, id, varying_id, None));
+ constituent_ids.push(id);
+
+ if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) {
+ local_invocation_id = Some(id);
+ }
+ }
+ prelude.body.push(Instruction::composite_construct(
+ argument_type_id,
+ struct_id,
+ &constituent_ids,
+ ));
+ struct_id
+ } else {
+ unreachable!("Missing argument binding on an entry point");
+ };
+ ep_context.argument_ids.push(id);
+ } else {
+ let argument_id = self.id_gen.next();
+ let instruction = Instruction::function_parameter(argument_type_id, argument_id);
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = argument.name {
+ self.debugs.push(Instruction::name(argument_id, name));
+ }
+ }
+ function.parameters.push(FunctionArgument {
+ instruction,
+ handle_id: if handle_ty {
+ let id = self.id_gen.next();
+ prelude.body.push(Instruction::load(
+ self.get_type_id(LookupType::Handle(argument.ty)),
+ id,
+ argument_id,
+ None,
+ ));
+ id
+ } else {
+ 0
+ },
+ });
+ parameter_type_ids.push(argument_type_id);
+ };
+ }
+
+ let return_type_id = match ir_function.result {
+ Some(ref result) => {
+ if let Some(ref mut iface) = interface {
+ let mut has_point_size = false;
+ let class = spirv::StorageClass::Output;
+ if let Some(ref binding) = result.binding {
+ has_point_size |=
+ *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
+ let type_id = self.get_type_id(LookupType::Handle(result.ty));
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ None,
+ result.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ ep_context.results.push(ResultMember {
+ id: varying_id,
+ type_id,
+ built_in: binding.to_built_in(),
+ });
+ } else if let crate::TypeInner::Struct { ref members, .. } =
+ ir_module.types[result.ty].inner
+ {
+ for member in members {
+ let type_id = self.get_type_id(LookupType::Handle(member.ty));
+ let name = member.name.as_deref();
+ let binding = member.binding.as_ref().unwrap();
+ has_point_size |=
+ *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ member.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ ep_context.results.push(ResultMember {
+ id: varying_id,
+ type_id,
+ built_in: binding.to_built_in(),
+ });
+ }
+ } else {
+ unreachable!("Missing result binding on an entry point");
+ }
+
+ if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
+ && iface.stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // add point size artificially
+ let varying_id = self.id_gen.next();
+ let pointer_type_id = self.get_float_pointer_type_id(class);
+ Instruction::variable(pointer_type_id, varying_id, class, None)
+ .to_words(&mut self.logical_layout.declarations);
+ self.decorate(
+ varying_id,
+ spirv::Decoration::BuiltIn,
+ &[spirv::BuiltIn::PointSize as u32],
+ );
+ iface.varying_ids.push(varying_id);
+
+ let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0));
+ prelude
+ .body
+ .push(Instruction::store(varying_id, default_value_id, None));
+ }
+ self.void_type
+ } else {
+ self.get_type_id(LookupType::Handle(result.ty))
+ }
+ }
+ None => self.void_type,
+ };
+
+ let lookup_function_type = LookupFunctionType {
+ parameter_type_ids,
+ return_type_id,
+ };
+
+ let function_id = self.id_gen.next();
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = ir_function.name {
+ self.debugs.push(Instruction::name(function_id, name));
+ }
+ }
+
+ let function_type = self.get_function_type(lookup_function_type);
+ function.signature = Some(Instruction::function(
+ return_type_id,
+ function_id,
+ spirv::FunctionControl::empty(),
+ function_type,
+ ));
+
+ if interface.is_some() {
+ function.entry_point_context = Some(ep_context);
+ }
+
+ // fill up the `GlobalVariable::access_id`
+ for gv in self.global_variables.iter_mut() {
+ gv.reset_for_function();
+ }
+ for (handle, var) in ir_module.global_variables.iter() {
+ if info[handle].is_empty() {
+ continue;
+ }
+
+ let mut gv = self.global_variables[handle.index()].clone();
+ if let Some(ref mut iface) = interface {
+ // Have to include global variables in the interface
+ if self.physical_layout.version >= 0x10400 {
+ iface.varying_ids.push(gv.var_id);
+ }
+ }
+
+ // Handle globals are pre-emitted and should be loaded automatically.
+ //
+ // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing.
+ let is_binding_array = match ir_module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { .. } => true,
+ _ => false,
+ };
+
+ if var.space == crate::AddressSpace::Handle && !is_binding_array {
+ let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(var_type_id, id, gv.var_id, None));
+ gv.access_id = gv.var_id;
+ gv.handle_id = id;
+ } else if global_needs_wrapper(ir_module, var) {
+ let class = map_storage_class(var.space);
+ let pointer_type_id = self.get_pointer_id(&ir_module.types, var.ty, class)?;
+ let index_id = self.get_index_constant(0);
+
+ let id = self.id_gen.next();
+ prelude.body.push(Instruction::access_chain(
+ pointer_type_id,
+ id,
+ gv.var_id,
+ &[index_id],
+ ));
+ gv.access_id = id;
+ } else {
+ // by default, the variable ID is accessed as is
+ gv.access_id = gv.var_id;
+ };
+
+ // work around borrow checking in the presence of `self.xxx()` calls
+ self.global_variables[handle.index()] = gv;
+ }
+
+ // Create a `BlockContext` for generating SPIR-V for the function's
+ // body.
+ let mut context = BlockContext {
+ ir_module,
+ ir_function,
+ fun_info: info,
+ function: &mut function,
+ // Re-use the cached expression table from prior functions.
+ cached: std::mem::take(&mut self.saved_cached),
+
+ // Steal the Writer's temp list for a bit.
+ temp_list: std::mem::take(&mut self.temp_list),
+ writer: self,
+ expression_constness: crate::proc::ExpressionConstnessTracker::from_arena(
+ &ir_function.expressions,
+ ),
+ };
+
+ // fill up the pre-emitted and const expressions
+ context.cached.reset(ir_function.expressions.len());
+ for (handle, expr) in ir_function.expressions.iter() {
+ if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_)))
+ || context.expression_constness.is_const(handle)
+ {
+ context.cache_expression_value(handle, &mut prelude)?;
+ }
+ }
+
+ for (handle, variable) in ir_function.local_variables.iter() {
+ let id = context.gen_id();
+
+ if context.writer.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = variable.name {
+ context.writer.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ let init_word = variable.init.map(|constant| context.cached[constant]);
+ let pointer_type_id = context.writer.get_pointer_id(
+ &ir_module.types,
+ variable.ty,
+ spirv::StorageClass::Function,
+ )?;
+ let instruction = Instruction::variable(
+ pointer_type_id,
+ id,
+ spirv::StorageClass::Function,
+ init_word.or_else(|| match ir_module.types[variable.ty].inner {
+ crate::TypeInner::RayQuery => None,
+ _ => {
+ let type_id = context.get_type_id(LookupType::Handle(variable.ty));
+ Some(context.writer.write_constant_null(type_id))
+ }
+ }),
+ );
+ context
+ .function
+ .variables
+ .insert(handle, LocalVariable { id, instruction });
+ }
+
+ // cache local variable expressions
+ for (handle, expr) in ir_function.expressions.iter() {
+ if matches!(*expr, crate::Expression::LocalVariable(_)) {
+ context.cache_expression_value(handle, &mut prelude)?;
+ }
+ }
+
+ let next_id = context.gen_id();
+
+ context
+ .function
+ .consume(prelude, Instruction::branch(next_id));
+
+ let workgroup_vars_init_exit_block_id =
+ match (context.writer.zero_initialize_workgroup_memory, interface) {
+ (
+ super::ZeroInitializeWorkgroupMemoryMode::Polyfill,
+ Some(
+ ref mut interface @ FunctionInterface {
+ stage: crate::ShaderStage::Compute,
+ ..
+ },
+ ),
+ ) => context.writer.generate_workgroup_vars_init_block(
+ next_id,
+ ir_module,
+ info,
+ local_invocation_id,
+ interface,
+ context.function,
+ ),
+ _ => None,
+ };
+
+ let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id {
+ exit_id
+ } else {
+ next_id
+ };
+
+ context.write_block(
+ main_id,
+ &ir_function.body,
+ super::block::BlockExit::Return,
+ LoopContext::default(),
+ debug_info.as_ref(),
+ )?;
+
+ // Consume the `BlockContext`, ending its borrows and letting the
+ // `Writer` steal back its cached expression table and temp_list.
+ let BlockContext {
+ cached, temp_list, ..
+ } = context;
+ self.saved_cached = cached;
+ self.temp_list = temp_list;
+
+ function.to_words(&mut self.logical_layout.function_definitions);
+ Instruction::function_end().to_words(&mut self.logical_layout.function_definitions);
+
+ Ok(function_id)
+ }
+
+ fn write_execution_mode(
+ &mut self,
+ function_id: Word,
+ mode: spirv::ExecutionMode,
+ ) -> Result<(), Error> {
+ //self.check(mode.required_capabilities())?;
+ Instruction::execution_mode(function_id, mode, &[])
+ .to_words(&mut self.logical_layout.execution_modes);
+ Ok(())
+ }
+
+ // TODO Move to instructions module
+ fn write_entry_point(
+ &mut self,
+ entry_point: &crate::EntryPoint,
+ info: &FunctionInfo,
+ ir_module: &crate::Module,
+ debug_info: &Option<DebugInfoInner>,
+ ) -> Result<Instruction, Error> {
+ let mut interface_ids = Vec::new();
+ let function_id = self.write_function(
+ &entry_point.function,
+ info,
+ ir_module,
+ Some(FunctionInterface {
+ varying_ids: &mut interface_ids,
+ stage: entry_point.stage,
+ }),
+ debug_info,
+ )?;
+
+ let exec_model = match entry_point.stage {
+ crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
+ crate::ShaderStage::Fragment => {
+ self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
+ if let Some(ref result) = entry_point.function.result {
+ if contains_builtin(
+ result.binding.as_ref(),
+ result.ty,
+ &ir_module.types,
+ crate::BuiltIn::FragDepth,
+ ) {
+ self.write_execution_mode(
+ function_id,
+ spirv::ExecutionMode::DepthReplacing,
+ )?;
+ }
+ }
+ spirv::ExecutionModel::Fragment
+ }
+ crate::ShaderStage::Compute => {
+ let execution_mode = spirv::ExecutionMode::LocalSize;
+ //self.check(execution_mode.required_capabilities())?;
+ Instruction::execution_mode(
+ function_id,
+ execution_mode,
+ &entry_point.workgroup_size,
+ )
+ .to_words(&mut self.logical_layout.execution_modes);
+ spirv::ExecutionModel::GLCompute
+ }
+ };
+ //self.check(exec_model.required_capabilities())?;
+
+ Ok(Instruction::entry_point(
+ exec_model,
+ function_id,
+ &entry_point.name,
+ interface_ids.as_slice(),
+ ))
+ }
+
+ fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction {
+ use crate::ScalarKind as Sk;
+
+ let bits = (scalar.width * BITS_PER_BYTE) as u32;
+ match scalar.kind {
+ Sk::Sint | Sk::Uint => {
+ let signedness = if scalar.kind == Sk::Sint {
+ super::instructions::Signedness::Signed
+ } else {
+ super::instructions::Signedness::Unsigned
+ };
+ let cap = match bits {
+ 8 => Some(spirv::Capability::Int8),
+ 16 => Some(spirv::Capability::Int16),
+ 64 => Some(spirv::Capability::Int64),
+ _ => None,
+ };
+ if let Some(cap) = cap {
+ self.capabilities_used.insert(cap);
+ }
+ Instruction::type_int(id, bits, signedness)
+ }
+ Sk::Float => {
+ if bits == 64 {
+ self.capabilities_used.insert(spirv::Capability::Float64);
+ }
+ Instruction::type_float(id, bits)
+ }
+ Sk::Bool => Instruction::type_bool(id),
+ Sk::AbstractInt | Sk::AbstractFloat => {
+ unreachable!("abstract types should never reach the backend");
+ }
+ }
+ }
+
+ fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
+ match *inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let sampled = match class {
+ crate::ImageClass::Sampled { .. } => true,
+ crate::ImageClass::Depth { .. } => true,
+ crate::ImageClass::Storage { format, .. } => {
+ self.request_image_format_capabilities(format.into())?;
+ false
+ }
+ };
+
+ match dim {
+ crate::ImageDimension::D1 => {
+ if sampled {
+ self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
+ } else {
+ self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
+ }
+ }
+ crate::ImageDimension::Cube if arrayed => {
+ if sampled {
+ self.require_any(
+ "sampled cube array images",
+ &[spirv::Capability::SampledCubeArray],
+ )?;
+ } else {
+ self.require_any(
+ "cube array storage images",
+ &[spirv::Capability::ImageCubeArray],
+ )?;
+ }
+ }
+ _ => {}
+ }
+ }
+ crate::TypeInner::AccelerationStructure => {
+ self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?;
+ }
+ crate::TypeInner::RayQuery => {
+ self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
+ }
+ _ => {}
+ }
+ Ok(())
+ }
+
+ fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
+ let instruction = match local_ty {
+ LocalType::Value {
+ vector_size: None,
+ scalar,
+ pointer_space: None,
+ } => self.make_scalar(id, scalar),
+ LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ } => {
+ let scalar_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar,
+ pointer_space: None,
+ }));
+ Instruction::type_vector(id, scalar_id, size)
+ }
+ LocalType::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let vector_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(rows),
+ scalar: crate::Scalar::float(width),
+ pointer_space: None,
+ }));
+ Instruction::type_matrix(id, vector_id, columns)
+ }
+ LocalType::Pointer { base, class } => {
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ Instruction::type_pointer(id, class, type_id)
+ }
+ LocalType::Value {
+ vector_size,
+ scalar,
+ pointer_space: Some(class),
+ } => {
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size,
+ scalar,
+ pointer_space: None,
+ }));
+ Instruction::type_pointer(id, class, type_id)
+ }
+ LocalType::Image(image) => {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar {
+ kind: image.sampled_type,
+ width: 4,
+ },
+ pointer_space: None,
+ };
+ let type_id = self.get_type_id(LookupType::Local(local_type));
+ Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
+ }
+ LocalType::Sampler => Instruction::type_sampler(id),
+ LocalType::SampledImage { image_type_id } => {
+ Instruction::type_sampled_image(id, image_type_id)
+ }
+ LocalType::BindingArray { base, size } => {
+ let inner_ty = self.get_type_id(LookupType::Handle(base));
+ let scalar_id = self.get_constant_scalar(crate::Literal::U32(size));
+ Instruction::type_array(id, inner_ty, scalar_id)
+ }
+ LocalType::PointerToBindingArray { base, size, space } => {
+ let inner_ty =
+ self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size }));
+ let class = map_storage_class(space);
+ Instruction::type_pointer(id, class, inner_ty)
+ }
+ LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
+ LocalType::RayQuery => Instruction::type_ray_query(id),
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ }
+
+ fn write_type_declaration_arena(
+ &mut self,
+ arena: &UniqueArena<crate::Type>,
+ handle: Handle<crate::Type>,
+ ) -> Result<Word, Error> {
+ let ty = &arena[handle];
+ let id = if let Some(local) = make_local(&ty.inner) {
+ // This type can be represented as a `LocalType`, so check if we've
+ // already written an instruction for it. If not, do so now, with
+ // `write_type_declaration_local`.
+ match self.lookup_type.entry(LookupType::Local(local)) {
+ // We already have an id for this `LocalType`.
+ Entry::Occupied(e) => *e.get(),
+
+ // It's a type we haven't seen before.
+ Entry::Vacant(e) => {
+ let id = self.id_gen.next();
+ e.insert(id);
+
+ self.write_type_declaration_local(id, local);
+
+ // If it's a type that needs SPIR-V capabilities, request them now,
+ // so write_type_declaration_local can stay infallible.
+ self.request_type_capabilities(&ty.inner)?;
+
+ id
+ }
+ }
+ } else {
+ use spirv::Decoration;
+
+ let id = self.id_gen.next();
+ let instruction = match ty.inner {
+ crate::TypeInner::Array { base, size, stride } => {
+ self.decorate(id, Decoration::ArrayStride, &[stride]);
+
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ match size {
+ crate::ArraySize::Constant(length) => {
+ let length_id = self.get_index_constant(length.get());
+ Instruction::type_array(id, type_id, length_id)
+ }
+ crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
+ }
+ }
+ crate::TypeInner::BindingArray { base, size } => {
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ match size {
+ crate::ArraySize::Constant(length) => {
+ let length_id = self.get_index_constant(length.get());
+ Instruction::type_array(id, type_id, length_id)
+ }
+ crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
+ }
+ }
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => {
+ let mut has_runtime_array = false;
+ let mut member_ids = Vec::with_capacity(members.len());
+ for (index, member) in members.iter().enumerate() {
+ let member_ty = &arena[member.ty];
+ match member_ty.inner {
+ crate::TypeInner::Array {
+ base: _,
+ size: crate::ArraySize::Dynamic,
+ stride: _,
+ } => {
+ has_runtime_array = true;
+ }
+ _ => (),
+ }
+ self.decorate_struct_member(id, index, member, arena)?;
+ let member_id = self.get_type_id(LookupType::Handle(member.ty));
+ member_ids.push(member_id);
+ }
+ if has_runtime_array {
+ self.decorate(id, Decoration::Block, &[]);
+ }
+ Instruction::type_struct(id, member_ids.as_slice())
+ }
+
+ // These all have TypeLocal representations, so they should have been
+ // handled by `write_type_declaration_local` above.
+ crate::TypeInner::Scalar(_)
+ | crate::TypeInner::Atomic(_)
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. }
+ | crate::TypeInner::Pointer { .. }
+ | crate::TypeInner::ValuePointer { .. }
+ | crate::TypeInner::Image { .. }
+ | crate::TypeInner::Sampler { .. }
+ | crate::TypeInner::AccelerationStructure
+ | crate::TypeInner::RayQuery => unreachable!(),
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ id
+ };
+
+ // Add this handle as a new alias for that type.
+ self.lookup_type.insert(LookupType::Handle(handle), id);
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = ty.name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ Ok(id)
+ }
+
+ fn request_image_format_capabilities(
+ &mut self,
+ format: spirv::ImageFormat,
+ ) -> Result<(), Error> {
+ use spirv::ImageFormat as If;
+ match format {
+ If::Rg32f
+ | If::Rg16f
+ | If::R11fG11fB10f
+ | If::R16f
+ | If::Rgba16
+ | If::Rgb10A2
+ | If::Rg16
+ | If::Rg8
+ | If::R16
+ | If::R8
+ | If::Rgba16Snorm
+ | If::Rg16Snorm
+ | If::Rg8Snorm
+ | If::R16Snorm
+ | If::R8Snorm
+ | If::Rg32i
+ | If::Rg16i
+ | If::Rg8i
+ | If::R16i
+ | If::R8i
+ | If::Rgb10a2ui
+ | If::Rg32ui
+ | If::Rg16ui
+ | If::Rg8ui
+ | If::R16ui
+ | If::R8ui => self.require_any(
+ "storage image format",
+ &[spirv::Capability::StorageImageExtendedFormats],
+ ),
+ If::R64ui | If::R64i => self.require_any(
+ "64-bit integer storage image format",
+ &[spirv::Capability::Int64ImageEXT],
+ ),
+ If::Unknown
+ | If::Rgba32f
+ | If::Rgba16f
+ | If::R32f
+ | If::Rgba8
+ | If::Rgba8Snorm
+ | If::Rgba32i
+ | If::Rgba16i
+ | If::Rgba8i
+ | If::R32i
+ | If::Rgba32ui
+ | If::Rgba16ui
+ | If::Rgba8ui
+ | If::R32ui => Ok(()),
+ }
+ }
+
+ pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
+ self.get_constant_scalar(crate::Literal::U32(index))
+ }
+
+ pub(super) fn get_constant_scalar_with(
+ &mut self,
+ value: u8,
+ scalar: crate::Scalar,
+ ) -> Result<Word, Error> {
+ Ok(
+ self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or(
+ Error::Validation("Unexpected kind and/or width for Literal"),
+ )?),
+ )
+ }
+
+ pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
+ let scalar = CachedConstant::Literal(value);
+ if let Some(&id) = self.cached_constants.get(&scalar) {
+ return id;
+ }
+ let id = self.id_gen.next();
+ self.write_constant_scalar(id, &value, None);
+ self.cached_constants.insert(scalar, id);
+ id
+ }
+
+ fn write_constant_scalar(
+ &mut self,
+ id: Word,
+ value: &crate::Literal,
+ debug_name: Option<&String>,
+ ) {
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(name) = debug_name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: value.scalar(),
+ pointer_space: None,
+ }));
+ let instruction = match *value {
+ crate::Literal::F64(value) => {
+ let bits = value.to_bits();
+ Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
+ }
+ crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
+ crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
+ crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
+ crate::Literal::I64(value) => {
+ Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
+ }
+ crate::Literal::Bool(true) => Instruction::constant_true(type_id, id),
+ crate::Literal::Bool(false) => Instruction::constant_false(type_id, id),
+ crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
+ unreachable!("Abstract types should not appear in IR presented to backends");
+ }
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ }
+
+ pub(super) fn get_constant_composite(
+ &mut self,
+ ty: LookupType,
+ constituent_ids: &[Word],
+ ) -> Word {
+ let composite = CachedConstant::Composite {
+ ty,
+ constituent_ids: constituent_ids.to_vec(),
+ };
+ if let Some(&id) = self.cached_constants.get(&composite) {
+ return id;
+ }
+ let id = self.id_gen.next();
+ self.write_constant_composite(id, ty, constituent_ids, None);
+ self.cached_constants.insert(composite, id);
+ id
+ }
+
+ fn write_constant_composite(
+ &mut self,
+ id: Word,
+ ty: LookupType,
+ constituent_ids: &[Word],
+ debug_name: Option<&String>,
+ ) {
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(name) = debug_name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+ let type_id = self.get_type_id(ty);
+ Instruction::constant_composite(type_id, id, constituent_ids)
+ .to_words(&mut self.logical_layout.declarations);
+ }
+
+ pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word {
+ let null = CachedConstant::ZeroValue(type_id);
+ if let Some(&id) = self.cached_constants.get(&null) {
+ return id;
+ }
+ let id = self.write_constant_null(type_id);
+ self.cached_constants.insert(null, id);
+ id
+ }
+
+ pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
+ let null_id = self.id_gen.next();
+ Instruction::constant_null(type_id, null_id)
+ .to_words(&mut self.logical_layout.declarations);
+ null_id
+ }
+
+ fn write_constant_expr(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ir_module: &crate::Module,
+ mod_info: &ModuleInfo,
+ ) -> Result<Word, Error> {
+ let id = match ir_module.const_expressions[handle] {
+ crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
+ crate::Expression::Constant(constant) => {
+ let constant = &ir_module.constants[constant];
+ self.constant_ids[constant.init.index()]
+ }
+ crate::Expression::ZeroValue(ty) => {
+ let type_id = self.get_type_id(LookupType::Handle(ty));
+ self.get_constant_null(type_id)
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ let component_ids: Vec<_> = crate::proc::flatten_compose(
+ ty,
+ components,
+ &ir_module.const_expressions,
+ &ir_module.types,
+ )
+ .map(|component| self.constant_ids[component.index()])
+ .collect();
+ self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
+ }
+ crate::Expression::Splat { size, value } => {
+ let value_id = self.constant_ids[value.index()];
+ let component_ids = &[value_id; 4][..size as usize];
+
+ let ty = self.get_expression_lookup_type(&mod_info[handle]);
+
+ self.get_constant_composite(ty, component_ids)
+ }
+ _ => unreachable!(),
+ };
+
+ self.constant_ids[handle.index()] = id;
+
+ Ok(id)
+ }
+
+ pub(super) fn write_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
+ let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
+ spirv::Scope::Device
+ } else {
+ spirv::Scope::Workgroup
+ };
+ let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
+ semantics.set(
+ spirv::MemorySemantics::UNIFORM_MEMORY,
+ flags.contains(crate::Barrier::STORAGE),
+ );
+ semantics.set(
+ spirv::MemorySemantics::WORKGROUP_MEMORY,
+ flags.contains(crate::Barrier::WORK_GROUP),
+ );
+ let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32);
+ let mem_scope_id = self.get_index_constant(memory_scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ block.body.push(Instruction::control_barrier(
+ exec_scope_id,
+ mem_scope_id,
+ semantics_id,
+ ));
+ }
+
+ fn generate_workgroup_vars_init_block(
+ &mut self,
+ entry_id: Word,
+ ir_module: &crate::Module,
+ info: &FunctionInfo,
+ local_invocation_id: Option<Word>,
+ interface: &mut FunctionInterface,
+ function: &mut Function,
+ ) -> Option<Word> {
+ let body = ir_module
+ .global_variables
+ .iter()
+ .filter(|&(handle, var)| {
+ !info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ .map(|(handle, var)| {
+ // It's safe to use `var_id` here, not `access_id`, because only
+ // variables in the `Uniform` and `StorageBuffer` address spaces
+ // get wrapped, and we're initializing `WorkGroup` variables.
+ let var_id = self.global_variables[handle.index()].var_id;
+ let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
+ let init_word = self.get_constant_null(var_type_id);
+ Instruction::store(var_id, init_word, None)
+ })
+ .collect::<Vec<_>>();
+
+ if body.is_empty() {
+ return None;
+ }
+
+ let uint3_type_id = self.get_uint3_type_id();
+
+ let mut pre_if_block = Block::new(entry_id);
+
+ let local_invocation_id = if let Some(local_invocation_id) = local_invocation_id {
+ local_invocation_id
+ } else {
+ let varying_id = self.id_gen.next();
+ let class = spirv::StorageClass::Input;
+ let pointer_type_id = self.get_uint3_pointer_type_id(class);
+
+ Instruction::variable(pointer_type_id, varying_id, class, None)
+ .to_words(&mut self.logical_layout.declarations);
+
+ self.decorate(
+ varying_id,
+ spirv::Decoration::BuiltIn,
+ &[spirv::BuiltIn::LocalInvocationId as u32],
+ );
+
+ interface.varying_ids.push(varying_id);
+ let id = self.id_gen.next();
+ pre_if_block
+ .body
+ .push(Instruction::load(uint3_type_id, id, varying_id, None));
+
+ id
+ };
+
+ let zero_id = self.get_constant_null(uint3_type_id);
+ let bool3_type_id = self.get_bool3_type_id();
+
+ let eq_id = self.id_gen.next();
+ pre_if_block.body.push(Instruction::binary(
+ spirv::Op::IEqual,
+ bool3_type_id,
+ eq_id,
+ local_invocation_id,
+ zero_id,
+ ));
+
+ let condition_id = self.id_gen.next();
+ let bool_type_id = self.get_bool_type_id();
+ pre_if_block.body.push(Instruction::relational(
+ spirv::Op::All,
+ bool_type_id,
+ condition_id,
+ eq_id,
+ ));
+
+ let merge_id = self.id_gen.next();
+ pre_if_block.body.push(Instruction::selection_merge(
+ merge_id,
+ spirv::SelectionControl::NONE,
+ ));
+
+ let accept_id = self.id_gen.next();
+ function.consume(
+ pre_if_block,
+ Instruction::branch_conditional(condition_id, accept_id, merge_id),
+ );
+
+ let accept_block = Block {
+ label_id: accept_id,
+ body,
+ };
+ function.consume(accept_block, Instruction::branch(merge_id));
+
+ let mut post_if_block = Block::new(merge_id);
+
+ self.write_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block);
+
+ let next_id = self.id_gen.next();
+ function.consume(post_if_block, Instruction::branch(next_id));
+ Some(next_id)
+ }
+
+ /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
+ ///
+ /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
+ /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
+ /// interface is represented by global variables in the `Input` and `Output`
+ /// storage classes, with decorations indicating which builtin or location
+ /// each variable corresponds to.
+ ///
+ /// This function emits a single global `OpVariable` for a single value from
+ /// the interface, and adds appropriate decorations to indicate which
+ /// builtin or location it represents, how it should be interpolated, and so
+ /// on. The `class` argument gives the variable's SPIR-V storage class,
+ /// which should be either [`Input`] or [`Output`].
+ ///
+ /// [`Binding`]: crate::Binding
+ /// [`Function`]: crate::Function
+ /// [`EntryPoint`]: crate::EntryPoint
+ /// [`Input`]: spirv::StorageClass::Input
+ /// [`Output`]: spirv::StorageClass::Output
+ fn write_varying(
+ &mut self,
+ ir_module: &crate::Module,
+ stage: crate::ShaderStage,
+ class: spirv::StorageClass,
+ debug_name: Option<&str>,
+ ty: Handle<crate::Type>,
+ binding: &crate::Binding,
+ ) -> Result<Word, Error> {
+ let id = self.id_gen.next();
+ let pointer_type_id = self.get_pointer_id(&ir_module.types, ty, class)?;
+ Instruction::variable(pointer_type_id, id, class, None)
+ .to_words(&mut self.logical_layout.declarations);
+
+ if self
+ .flags
+ .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
+ {
+ if let Some(name) = debug_name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ use spirv::{BuiltIn, Decoration};
+
+ match *binding {
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source,
+ } => {
+ self.decorate(id, Decoration::Location, &[location]);
+
+ let no_decorations =
+ // VUID-StandaloneSpirv-Flat-06202
+ // > The Flat, NoPerspective, Sample, and Centroid decorations
+ // > must not be used on variables with the Input storage class in a vertex shader
+ (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
+ // VUID-StandaloneSpirv-Flat-06201
+ // > The Flat, NoPerspective, Sample, and Centroid decorations
+ // > must not be used on variables with the Output storage class in a fragment shader
+ (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);
+
+ if !no_decorations {
+ match interpolation {
+ // Perspective-correct interpolation is the default in SPIR-V.
+ None | Some(crate::Interpolation::Perspective) => (),
+ Some(crate::Interpolation::Flat) => {
+ self.decorate(id, Decoration::Flat, &[]);
+ }
+ Some(crate::Interpolation::Linear) => {
+ self.decorate(id, Decoration::NoPerspective, &[]);
+ }
+ }
+ match sampling {
+ // Center sampling is the default in SPIR-V.
+ None | Some(crate::Sampling::Center) => (),
+ Some(crate::Sampling::Centroid) => {
+ self.decorate(id, Decoration::Centroid, &[]);
+ }
+ Some(crate::Sampling::Sample) => {
+ self.require_any(
+ "per-sample interpolation",
+ &[spirv::Capability::SampleRateShading],
+ )?;
+ self.decorate(id, Decoration::Sample, &[]);
+ }
+ }
+ }
+ if second_blend_source {
+ self.decorate(id, Decoration::Index, &[1]);
+ }
+ }
+ crate::Binding::BuiltIn(built_in) => {
+ use crate::BuiltIn as Bi;
+ let built_in = match built_in {
+ Bi::Position { invariant } => {
+ if invariant {
+ self.decorate(id, Decoration::Invariant, &[]);
+ }
+
+ if class == spirv::StorageClass::Output {
+ BuiltIn::Position
+ } else {
+ BuiltIn::FragCoord
+ }
+ }
+ Bi::ViewIndex => {
+ self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?;
+ BuiltIn::ViewIndex
+ }
+ // vertex
+ Bi::BaseInstance => BuiltIn::BaseInstance,
+ Bi::BaseVertex => BuiltIn::BaseVertex,
+ Bi::ClipDistance => {
+ self.require_any(
+ "`clip_distance` built-in",
+ &[spirv::Capability::ClipDistance],
+ )?;
+ BuiltIn::ClipDistance
+ }
+ Bi::CullDistance => {
+ self.require_any(
+ "`cull_distance` built-in",
+ &[spirv::Capability::CullDistance],
+ )?;
+ BuiltIn::CullDistance
+ }
+ Bi::InstanceIndex => BuiltIn::InstanceIndex,
+ Bi::PointSize => BuiltIn::PointSize,
+ Bi::VertexIndex => BuiltIn::VertexIndex,
+ // fragment
+ Bi::FragDepth => BuiltIn::FragDepth,
+ Bi::PointCoord => BuiltIn::PointCoord,
+ Bi::FrontFacing => BuiltIn::FrontFacing,
+ Bi::PrimitiveIndex => {
+ self.require_any(
+ "`primitive_index` built-in",
+ &[spirv::Capability::Geometry],
+ )?;
+ BuiltIn::PrimitiveId
+ }
+ Bi::SampleIndex => {
+ self.require_any(
+ "`sample_index` built-in",
+ &[spirv::Capability::SampleRateShading],
+ )?;
+
+ BuiltIn::SampleId
+ }
+ Bi::SampleMask => BuiltIn::SampleMask,
+ // compute
+ Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId,
+ Bi::LocalInvocationId => BuiltIn::LocalInvocationId,
+ Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
+ Bi::WorkGroupId => BuiltIn::WorkgroupId,
+ Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
+ Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
+ };
+
+ self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);
+
+ use crate::ScalarKind as Sk;
+
+ // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`:
+ //
+ // > Any variable with integer or double-precision floating-
+ // > point type and with Input storage class in a fragment
+ // > shader, must be decorated Flat
+ if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
+ let is_flat = match ir_module.types[ty].inner {
+ crate::TypeInner::Scalar(scalar)
+ | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
+ Sk::Uint | Sk::Sint | Sk::Bool => true,
+ Sk::Float => false,
+ Sk::AbstractInt | Sk::AbstractFloat => {
+ return Err(Error::Validation(
+ "Abstract types should not appear in IR presented to backends",
+ ))
+ }
+ },
+ _ => false,
+ };
+
+ if is_flat {
+ self.decorate(id, Decoration::Flat, &[]);
+ }
+ }
+ }
+ }
+
+ Ok(id)
+ }
+
+ fn write_global_variable(
+ &mut self,
+ ir_module: &crate::Module,
+ global_variable: &crate::GlobalVariable,
+ ) -> Result<Word, Error> {
+ use spirv::Decoration;
+
+ let id = self.id_gen.next();
+ let class = map_storage_class(global_variable.space);
+
+ //self.check(class.required_capabilities())?;
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = global_variable.name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ let storage_access = match global_variable.space {
+ crate::AddressSpace::Storage { access } => Some(access),
+ _ => match ir_module.types[global_variable.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => Some(access),
+ _ => None,
+ },
+ };
+ if let Some(storage_access) = storage_access {
+ if !storage_access.contains(crate::StorageAccess::LOAD) {
+ self.decorate(id, Decoration::NonReadable, &[]);
+ }
+ if !storage_access.contains(crate::StorageAccess::STORE) {
+ self.decorate(id, Decoration::NonWritable, &[]);
+ }
+ }
+
+ // Note: we should be able to substitute `binding_array<Foo, 0>`,
+ // but there is still code that tries to register the pre-substituted type,
+ // and it is failing on 0.
+ let mut substitute_inner_type_lookup = None;
+ if let Some(ref res_binding) = global_variable.binding {
+ self.decorate(id, Decoration::DescriptorSet, &[res_binding.group]);
+ self.decorate(id, Decoration::Binding, &[res_binding.binding]);
+
+ if let Some(&BindingInfo {
+ binding_array_size: Some(remapped_binding_array_size),
+ }) = self.binding_map.get(res_binding)
+ {
+ if let crate::TypeInner::BindingArray { base, .. } =
+ ir_module.types[global_variable.ty].inner
+ {
+ substitute_inner_type_lookup =
+ Some(LookupType::Local(LocalType::PointerToBindingArray {
+ base,
+ size: remapped_binding_array_size,
+ space: global_variable.space,
+ }))
+ }
+ }
+ };
+
+ let init_word = global_variable
+ .init
+ .map(|constant| self.constant_ids[constant.index()]);
+ let inner_type_id = self.get_type_id(
+ substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)),
+ );
+
+ // generate the wrapping structure if needed
+ let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) {
+ let wrapper_type_id = self.id_gen.next();
+
+ self.decorate(wrapper_type_id, Decoration::Block, &[]);
+ let member = crate::StructMember {
+ name: None,
+ ty: global_variable.ty,
+ binding: None,
+ offset: 0,
+ };
+ self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?;
+
+ Instruction::type_struct(wrapper_type_id, &[inner_type_id])
+ .to_words(&mut self.logical_layout.declarations);
+
+ let pointer_type_id = self.id_gen.next();
+ Instruction::type_pointer(pointer_type_id, class, wrapper_type_id)
+ .to_words(&mut self.logical_layout.declarations);
+
+ pointer_type_id
+ } else {
+ // This is a global variable in the Storage address space. The only
+ // way it could have `global_needs_wrapper() == false` is if it has
+ // a runtime-sized or binding array.
+ // Runtime-sized arrays were decorated when iterating through struct content.
+ // Now binding arrays require Block decorating.
+ if let crate::AddressSpace::Storage { .. } = global_variable.space {
+ match ir_module.types[global_variable.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => {
+ let decorated_id = self.get_type_id(LookupType::Handle(base));
+ self.decorate(decorated_id, Decoration::Block, &[]);
+ }
+ _ => (),
+ };
+ }
+ if substitute_inner_type_lookup.is_some() {
+ inner_type_id
+ } else {
+ self.get_pointer_id(&ir_module.types, global_variable.ty, class)?
+ }
+ };
+
+ let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) {
+ (crate::AddressSpace::Private, _)
+ | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => {
+ init_word.or_else(|| Some(self.get_constant_null(inner_type_id)))
+ }
+ _ => init_word,
+ };
+
+ Instruction::variable(pointer_type_id, id, class, init_word)
+ .to_words(&mut self.logical_layout.declarations);
+ Ok(id)
+ }
+
+ /// Write the necessary decorations for a struct member.
+ ///
+ /// Emit decorations for the `index`'th member of the struct type
+ /// designated by `struct_id`, described by `member`.
+ fn decorate_struct_member(
+ &mut self,
+ struct_id: Word,
+ index: usize,
+ member: &crate::StructMember,
+ arena: &UniqueArena<crate::Type>,
+ ) -> Result<(), Error> {
+ use spirv::Decoration;
+
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::Offset,
+ &[member.offset],
+ ));
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = member.name {
+ self.debugs
+ .push(Instruction::member_name(struct_id, index as u32, name));
+ }
+ }
+
+ // Matrices and arrays of matrices both require decorations,
+ // so "see through" an array to determine if they're needed.
+ let member_array_subty_inner = match arena[member.ty].inner {
+ crate::TypeInner::Array { base, .. } => &arena[base].inner,
+ ref other => other,
+ };
+ if let crate::TypeInner::Matrix {
+ columns: _,
+ rows,
+ scalar,
+ } = *member_array_subty_inner
+ {
+ let byte_stride = Alignment::from(rows) * scalar.width as u32;
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::ColMajor,
+ &[],
+ ));
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::MatrixStride,
+ &[byte_stride],
+ ));
+ }
+
+ Ok(())
+ }
+
+ fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
+ match self
+ .lookup_function_type
+ .entry(lookup_function_type.clone())
+ {
+ Entry::Occupied(e) => *e.get(),
+ Entry::Vacant(_) => {
+ let id = self.id_gen.next();
+ let instruction = Instruction::type_function(
+ id,
+ lookup_function_type.return_type_id,
+ &lookup_function_type.parameter_type_ids,
+ );
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_function_type.insert(lookup_function_type, id);
+ id
+ }
+ }
+ }
+
+ fn write_physical_layout(&mut self) {
+ self.physical_layout.bound = self.id_gen.0 + 1;
+ }
+
+ fn write_logical_layout(
+ &mut self,
+ ir_module: &crate::Module,
+ mod_info: &ModuleInfo,
+ ep_index: Option<usize>,
+ debug_info: &Option<DebugInfo>,
+ ) -> Result<(), Error> {
+ fn has_view_index_check(
+ ir_module: &crate::Module,
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ ) -> bool {
+ match ir_module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| {
+ has_view_index_check(ir_module, member.binding.as_ref(), member.ty)
+ }),
+ _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)),
+ }
+ }
+
+ let has_storage_buffers =
+ ir_module
+ .global_variables
+ .iter()
+ .any(|(_, var)| match var.space {
+ crate::AddressSpace::Storage { .. } => true,
+ _ => false,
+ });
+ let has_view_index = ir_module
+ .entry_points
+ .iter()
+ .flat_map(|entry| entry.function.arguments.iter())
+ .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty));
+ let has_ray_query = ir_module.special_types.ray_desc.is_some()
+ | ir_module.special_types.ray_intersection.is_some();
+
+ if self.physical_layout.version < 0x10300 && has_storage_buffers {
+ // enable the storage buffer class on < SPV-1.3
+ Instruction::extension("SPV_KHR_storage_buffer_storage_class")
+ .to_words(&mut self.logical_layout.extensions);
+ }
+ if has_view_index {
+ Instruction::extension("SPV_KHR_multiview")
+ .to_words(&mut self.logical_layout.extensions)
+ }
+ if has_ray_query {
+ Instruction::extension("SPV_KHR_ray_query")
+ .to_words(&mut self.logical_layout.extensions)
+ }
+ Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations);
+ Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450")
+ .to_words(&mut self.logical_layout.ext_inst_imports);
+
+ let mut debug_info_inner = None;
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(debug_info) = debug_info.as_ref() {
+ let source_file_id = self.id_gen.next();
+ self.debugs.push(Instruction::string(
+ &debug_info.file_name.display().to_string(),
+ source_file_id,
+ ));
+
+ debug_info_inner = Some(DebugInfoInner {
+ source_code: debug_info.source_code,
+ source_file_id,
+ });
+ self.debugs.push(Instruction::source(
+ spirv::SourceLanguage::Unknown,
+ 0,
+ &debug_info_inner,
+ ));
+ }
+ }
+
+ // write all types
+ for (handle, _) in ir_module.types.iter() {
+ self.write_type_declaration_arena(&ir_module.types, handle)?;
+ }
+
+ // write all const-expressions as constants
+ self.constant_ids
+ .resize(ir_module.const_expressions.len(), 0);
+ for (handle, _) in ir_module.const_expressions.iter() {
+ self.write_constant_expr(handle, ir_module, mod_info)?;
+ }
+ debug_assert!(self.constant_ids.iter().all(|&id| id != 0));
+
+ // write the name of constants on their respective const-expression initializer
+ if self.flags.contains(WriterFlags::DEBUG) {
+ for (_, constant) in ir_module.constants.iter() {
+ if let Some(ref name) = constant.name {
+ let id = self.constant_ids[constant.init.index()];
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+ }
+
+ // write all global variables
+ for (handle, var) in ir_module.global_variables.iter() {
+ // If a single entry point was specified, only write `OpVariable` instructions
+ // for the globals it actually uses. Emit dummies for the others,
+ // to preserve the indices in `global_variables`.
+ let gvar = match ep_index {
+ Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => {
+ GlobalVariable::dummy()
+ }
+ _ => {
+ let id = self.write_global_variable(ir_module, var)?;
+ GlobalVariable::new(id)
+ }
+ };
+ self.global_variables.push(gvar);
+ }
+
+ // write all functions
+ for (handle, ir_function) in ir_module.functions.iter() {
+ let info = &mod_info[handle];
+ if let Some(index) = ep_index {
+ let ep_info = mod_info.get_entry_point(index);
+ // If this function uses globals that we omitted from the SPIR-V
+ // because the entry point and its callees didn't use them,
+ // then we must skip it.
+ if !ep_info.dominates_global_use(info) {
+ log::info!("Skip function {:?}", ir_function.name);
+ continue;
+ }
+
+ // Skip functions that that are not compatible with this entry point's stage.
+ //
+ // When validation is enabled, it rejects modules whose entry points try to call
+ // incompatible functions, so if we got this far, then any functions incompatible
+ // with our selected entry point must not be used.
+ //
+ // When validation is disabled, `fun_info.available_stages` is always just
+ // `ShaderStages::all()`, so this will write all functions in the module, and
+ // the downstream GLSL compiler will catch any problems.
+ if !info.available_stages.contains(ep_info.available_stages) {
+ continue;
+ }
+ }
+ let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?;
+ self.lookup_function.insert(handle, id);
+ }
+
+ // write all or one entry points
+ for (index, ir_ep) in ir_module.entry_points.iter().enumerate() {
+ if ep_index.is_some() && ep_index != Some(index) {
+ continue;
+ }
+ let info = mod_info.get_entry_point(index);
+ let ep_instruction =
+ self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?;
+ ep_instruction.to_words(&mut self.logical_layout.entry_points);
+ }
+
+ for capability in self.capabilities_used.iter() {
+ Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities);
+ }
+ for extension in self.extensions_used.iter() {
+ Instruction::extension(extension).to_words(&mut self.logical_layout.extensions);
+ }
+ if ir_module.entry_points.is_empty() {
+ // SPIR-V doesn't like modules without entry points
+ Instruction::capability(spirv::Capability::Linkage)
+ .to_words(&mut self.logical_layout.capabilities);
+ }
+
+ let addressing_model = spirv::AddressingModel::Logical;
+ let memory_model = spirv::MemoryModel::GLSL450;
+ //self.check(addressing_model.required_capabilities())?;
+ //self.check(memory_model.required_capabilities())?;
+
+ Instruction::memory_model(addressing_model, memory_model)
+ .to_words(&mut self.logical_layout.memory_model);
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ for debug in self.debugs.iter() {
+ debug.to_words(&mut self.logical_layout.debugs);
+ }
+ }
+
+ for annotation in self.annotations.iter() {
+ annotation.to_words(&mut self.logical_layout.annotations);
+ }
+
+ Ok(())
+ }
+
+ pub fn write(
+ &mut self,
+ ir_module: &crate::Module,
+ info: &ModuleInfo,
+ pipeline_options: Option<&PipelineOptions>,
+ debug_info: &Option<DebugInfo>,
+ words: &mut Vec<Word>,
+ ) -> Result<(), Error> {
+ self.reset();
+
+ // Try to find the entry point and corresponding index
+ let ep_index = match pipeline_options {
+ Some(po) => {
+ let index = ir_module
+ .entry_points
+ .iter()
+ .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name)
+ .ok_or(Error::EntryPointNotFound)?;
+ Some(index)
+ }
+ None => None,
+ };
+
+ self.write_logical_layout(ir_module, info, ep_index, debug_info)?;
+ self.write_physical_layout();
+
+ self.physical_layout.in_words(words);
+ self.logical_layout.in_words(words);
+ Ok(())
+ }
+
+ /// Return the set of capabilities the last module written used.
+ pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> {
+ &self.capabilities_used
+ }
+
+ pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
+ self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
+ self.use_extension("SPV_EXT_descriptor_indexing");
+ self.decorate(id, spirv::Decoration::NonUniform, &[]);
+ Ok(())
+ }
+}
+
+#[test]
+fn test_write_physical_layout() {
+ let mut writer = Writer::new(&Options::default()).unwrap();
+ assert_eq!(writer.physical_layout.bound, 0);
+ writer.write_physical_layout();
+ assert_eq!(writer.physical_layout.bound, 3);
+}
diff --git a/third_party/rust/naga/src/back/wgsl/mod.rs b/third_party/rust/naga/src/back/wgsl/mod.rs
new file mode 100644
index 0000000000..d731b1ca0c
--- /dev/null
+++ b/third_party/rust/naga/src/back/wgsl/mod.rs
@@ -0,0 +1,52 @@
+/*!
+Backend for [WGSL][wgsl] (WebGPU Shading Language).
+
+[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
+*/
+
+mod writer;
+
+use thiserror::Error;
+
+pub use writer::{Writer, WriterFlags};
+
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error(transparent)]
+ FmtError(#[from] std::fmt::Error),
+ #[error("{0}")]
+ Custom(String),
+ #[error("{0}")]
+ Unimplemented(String), // TODO: Error used only during development
+ #[error("Unsupported math function: {0:?}")]
+ UnsupportedMathFunction(crate::MathFunction),
+ #[error("Unsupported relational function: {0:?}")]
+ UnsupportedRelationalFunction(crate::RelationalFunction),
+}
+
+pub fn write_string(
+ module: &crate::Module,
+ info: &crate::valid::ModuleInfo,
+ flags: WriterFlags,
+) -> Result<String, Error> {
+ let mut w = Writer::new(String::new(), flags);
+ w.write(module, info)?;
+ let output = w.finish();
+ Ok(output)
+}
+
+impl crate::AtomicFunction {
+ const fn to_wgsl(self) -> &'static str {
+ match self {
+ Self::Add => "Add",
+ Self::Subtract => "Sub",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { .. } => "CompareExchangeWeak",
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs
new file mode 100644
index 0000000000..c737934f5e
--- /dev/null
+++ b/third_party/rust/naga/src/back/wgsl/writer.rs
@@ -0,0 +1,1961 @@
+use super::Error;
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, Module, ShaderStage, TypeInner,
+};
+use std::fmt::Write;
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+/// WGSL [attribute](https://gpuweb.github.io/gpuweb/wgsl/#attributes)
+enum Attribute {
+ Binding(u32),
+ BuiltIn(crate::BuiltIn),
+ Group(u32),
+ Invariant,
+ Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>),
+ Location(u32),
+ SecondBlendSource,
+ Stage(ShaderStage),
+ WorkGroupSize([u32; 3]),
+}
+
+/// The WGSL form that `write_expr_with_indirection` should use to render a Naga
+/// expression.
+///
+/// Sometimes a Naga `Expression` alone doesn't provide enough information to
+/// choose the right rendering for it in WGSL. For example, one natural WGSL
+/// rendering of a Naga `LocalVariable(x)` expression might be `&x`, since
+/// `LocalVariable` produces a pointer to the local variable's storage. But when
+/// rendering a `Store` statement, the `pointer` operand must be the left hand
+/// side of a WGSL assignment, so the proper rendering is `x`.
+///
+/// The caller of `write_expr_with_indirection` must provide an `Expected` value
+/// to indicate how ambiguous expressions should be rendered.
+#[derive(Clone, Copy, Debug)]
+enum Indirection {
+ /// Render pointer-construction expressions as WGSL `ptr`-typed expressions.
+ ///
+ /// This is the right choice for most cases. Whenever a Naga pointer
+ /// expression is not the `pointer` operand of a `Load` or `Store`, it
+ /// must be a WGSL pointer expression.
+ Ordinary,
+
+ /// Render pointer-construction expressions as WGSL reference-typed
+ /// expressions.
+ ///
+ /// For example, this is the right choice for the `pointer` operand when
+ /// rendering a `Store` statement as a WGSL assignment.
+ Reference,
+}
+
+bitflags::bitflags! {
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct WriterFlags: u32 {
+ /// Always annotate the type information instead of inferring.
+ const EXPLICIT_TYPES = 0x1;
+ }
+}
+
+pub struct Writer<W> {
+ out: W,
+ flags: WriterFlags,
+ names: crate::FastHashMap<NameKey, String>,
+ namer: proc::Namer,
+ named_expressions: crate::NamedExpressions,
+ ep_results: Vec<(ShaderStage, Handle<crate::Type>)>,
+}
+
+impl<W: Write> Writer<W> {
+ pub fn new(out: W, flags: WriterFlags) -> Self {
+ Writer {
+ out,
+ flags,
+ names: crate::FastHashMap::default(),
+ namer: proc::Namer::default(),
+ named_expressions: crate::NamedExpressions::default(),
+ ep_results: vec![],
+ }
+ }
+
+ fn reset(&mut self, module: &Module) {
+ self.names.clear();
+ self.namer.reset(
+ module,
+ crate::keywords::wgsl::RESERVED,
+ // an identifier must not start with two underscore
+ &[],
+ &[],
+ &["__"],
+ &mut self.names,
+ );
+ self.named_expressions.clear();
+ self.ep_results.clear();
+ }
+
+ fn is_builtin_wgsl_struct(&self, module: &Module, handle: Handle<crate::Type>) -> bool {
+ module
+ .special_types
+ .predeclared_types
+ .values()
+ .any(|t| *t == handle)
+ }
+
+ pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
+ self.reset(module);
+
+ // Save all ep result types
+ for (_, ep) in module.entry_points.iter().enumerate() {
+ if let Some(ref result) = ep.function.result {
+ self.ep_results.push((ep.stage, result.ty));
+ }
+ }
+
+ // Write all structs
+ for (handle, ty) in module.types.iter() {
+ if let TypeInner::Struct { ref members, .. } = ty.inner {
+ {
+ if !self.is_builtin_wgsl_struct(module, handle) {
+ self.write_struct(module, handle, members)?;
+ writeln!(self.out)?;
+ }
+ }
+ }
+ }
+
+ // Write all named constants
+ let mut constants = module
+ .constants
+ .iter()
+ .filter(|&(_, c)| c.name.is_some())
+ .peekable();
+ while let Some((handle, _)) = constants.next() {
+ self.write_global_constant(module, handle)?;
+ // Add extra newline for readability on last iteration
+ if constants.peek().is_none() {
+ writeln!(self.out)?;
+ }
+ }
+
+ // Write all globals
+ for (ty, global) in module.global_variables.iter() {
+ self.write_global(module, global, ty)?;
+ }
+
+ if !module.global_variables.is_empty() {
+ // Add extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Write all regular functions
+ for (handle, function) in module.functions.iter() {
+ let fun_info = &info[handle];
+
+ let func_ctx = back::FunctionCtx {
+ ty: back::FunctionType::Function(handle),
+ info: fun_info,
+ expressions: &function.expressions,
+ named_expressions: &function.named_expressions,
+ };
+
+ // Write the function
+ self.write_function(module, function, &func_ctx)?;
+
+ writeln!(self.out)?;
+ }
+
+ // Write all entry points
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let attributes = match ep.stage {
+ ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
+ ShaderStage::Compute => vec![
+ Attribute::Stage(ShaderStage::Compute),
+ Attribute::WorkGroupSize(ep.workgroup_size),
+ ],
+ };
+
+ self.write_attributes(&attributes)?;
+ // Add a newline after attribute
+ writeln!(self.out)?;
+
+ let func_ctx = back::FunctionCtx {
+ ty: back::FunctionType::EntryPoint(index as u16),
+ info: info.get_entry_point(index),
+ expressions: &ep.function.expressions,
+ named_expressions: &ep.function.named_expressions,
+ };
+ self.write_function(module, &ep.function, &func_ctx)?;
+
+ if index < module.entry_points.len() - 1 {
+ writeln!(self.out)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write struct name
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_struct_name(&mut self, module: &Module, handle: Handle<crate::Type>) -> BackendResult {
+ if module.types[handle].name.is_none() {
+ if let Some(&(stage, _)) = self.ep_results.iter().find(|&&(_, ty)| ty == handle) {
+ let name = match stage {
+ ShaderStage::Compute => "ComputeOutput",
+ ShaderStage::Fragment => "FragmentOutput",
+ ShaderStage::Vertex => "VertexOutput",
+ };
+
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+ }
+
+ write!(self.out, "{}", self.names[&NameKey::Type(handle)])?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write
+ /// [functions](https://gpuweb.github.io/gpuweb/wgsl/#functions)
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_function(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ let func_name = match func_ctx.ty {
+ back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
+ back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
+ };
+
+ // Write function name
+ write!(self.out, "fn {func_name}(")?;
+
+ // Write function arguments
+ for (index, arg) in func.arguments.iter().enumerate() {
+ // Write argument attribute if a binding is present
+ if let Some(ref binding) = arg.binding {
+ self.write_attributes(&map_binding_to_attribute(binding))?;
+ }
+ // Write argument name
+ let argument_name = &self.names[&func_ctx.argument_key(index as u32)];
+
+ write!(self.out, "{argument_name}: ")?;
+ // Write argument type
+ self.write_type(module, arg.ty)?;
+ if index < func.arguments.len() - 1 {
+ // Add a separator between args
+ write!(self.out, ", ")?;
+ }
+ }
+
+ write!(self.out, ")")?;
+
+ // Write function return type
+ if let Some(ref result) = func.result {
+ write!(self.out, " -> ")?;
+ if let Some(ref binding) = result.binding {
+ self.write_attributes(&map_binding_to_attribute(binding))?;
+ }
+ self.write_type(module, result.ty)?;
+ }
+
+ write!(self.out, " {{")?;
+ writeln!(self.out)?;
+
+ // Write function local variables
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability)
+ write!(self.out, "{}", back::INDENT)?;
+
+ // Write the local name
+ // The leading space is important
+ write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?;
+
+ // Write the local type
+ self.write_type(module, local.ty)?;
+
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ // Put the equal signal only if there's a initializer
+ // The leading and trailing spaces aren't needed but help with readability
+ write!(self.out, " = ")?;
+
+ // Write the constant
+ // `write_constant` adds no trailing or leading space/newline
+ self.write_expr(module, init, func_ctx)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ if !func.local_variables.is_empty() {
+ writeln!(self.out)?;
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // The indentation should always be 1 when writing the function body
+ self.write_stmt(module, sta, func_ctx, back::Level(1))?;
+ }
+
+ writeln!(self.out, "}}")?;
+
+ self.named_expressions.clear();
+
+ Ok(())
+ }
+
+ /// Helper method to write a attribute
+ fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult {
+ for attribute in attributes {
+ match *attribute {
+ Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
+ Attribute::SecondBlendSource => write!(self.out, "@second_blend_source ")?,
+ Attribute::BuiltIn(builtin_attrib) => {
+ let builtin = builtin_str(builtin_attrib)?;
+ write!(self.out, "@builtin({builtin}) ")?;
+ }
+ Attribute::Stage(shader_stage) => {
+ let stage_str = match shader_stage {
+ ShaderStage::Vertex => "vertex",
+ ShaderStage::Fragment => "fragment",
+ ShaderStage::Compute => "compute",
+ };
+ write!(self.out, "@{stage_str} ")?;
+ }
+ Attribute::WorkGroupSize(size) => {
+ write!(
+ self.out,
+ "@workgroup_size({}, {}, {}) ",
+ size[0], size[1], size[2]
+ )?;
+ }
+ Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?,
+ Attribute::Group(id) => write!(self.out, "@group({id}) ")?,
+ Attribute::Invariant => write!(self.out, "@invariant ")?,
+ Attribute::Interpolate(interpolation, sampling) => {
+ if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
+ write!(
+ self.out,
+ "@interpolate({}, {}) ",
+ interpolation_str(
+ interpolation.unwrap_or(crate::Interpolation::Perspective)
+ ),
+ sampling_str(sampling.unwrap_or(crate::Sampling::Center))
+ )?;
+ } else if interpolation.is_some()
+ && interpolation != Some(crate::Interpolation::Perspective)
+ {
+ write!(
+ self.out,
+ "@interpolate({}) ",
+ interpolation_str(
+ interpolation.unwrap_or(crate::Interpolation::Perspective)
+ )
+ )?;
+ }
+ }
+ };
+ }
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_struct(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ ) -> BackendResult {
+ write!(self.out, "struct ")?;
+ self.write_struct_name(module, handle)?;
+ write!(self.out, " {{")?;
+ writeln!(self.out)?;
+ for (index, member) in members.iter().enumerate() {
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+ if let Some(ref binding) = member.binding {
+ self.write_attributes(&map_binding_to_attribute(binding))?;
+ }
+ // Write struct member name and type
+ let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
+ write!(self.out, "{member_name}: ")?;
+ self.write_type(module, member.ty)?;
+ write!(self.out, ",")?;
+ writeln!(self.out)?;
+ }
+
+ write!(self.out, "}}")?;
+
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &module.types[ty].inner;
+ match *inner {
+ TypeInner::Struct { .. } => self.write_struct_name(module, ty)?,
+ ref other => self.write_value_type(module, other)?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ TypeInner::Vector { size, scalar } => write!(
+ self.out,
+ "vec{}<{}>",
+ back::vector_size_str(size),
+ scalar_kind_str(scalar),
+ )?,
+ TypeInner::Sampler { comparison: false } => {
+ write!(self.out, "sampler")?;
+ }
+ TypeInner::Sampler { comparison: true } => {
+ write!(self.out, "sampler_comparison")?;
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ // More about texture types: https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ use crate::ImageClass as Ic;
+
+ let dim_str = image_dimension_str(dim);
+ let arrayed_str = if arrayed { "_array" } else { "" };
+ let (class_str, multisampled_str, format_str, storage_str) = match class {
+ Ic::Sampled { kind, multi } => (
+ "",
+ if multi { "multisampled_" } else { "" },
+ scalar_kind_str(crate::Scalar { kind, width: 4 }),
+ "",
+ ),
+ Ic::Depth { multi } => {
+ ("depth_", if multi { "multisampled_" } else { "" }, "", "")
+ }
+ Ic::Storage { format, access } => (
+ "storage_",
+ "",
+ storage_format_str(format),
+ if access.contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
+ {
+ ",read_write"
+ } else if access.contains(crate::StorageAccess::LOAD) {
+ ",read"
+ } else {
+ ",write"
+ },
+ ),
+ };
+ write!(
+ self.out,
+ "texture_{class_str}{multisampled_str}{dim_str}{arrayed_str}"
+ )?;
+
+ if !format_str.is_empty() {
+ write!(self.out, "<{format_str}{storage_str}>")?;
+ }
+ }
+ TypeInner::Scalar(scalar) => {
+ write!(self.out, "{}", scalar_kind_str(scalar))?;
+ }
+ TypeInner::Atomic(scalar) => {
+ write!(self.out, "atomic<{}>", scalar_kind_str(scalar))?;
+ }
+ TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ // More info https://gpuweb.github.io/gpuweb/wgsl/#array-types
+ // array<A, 3> -- Constant array
+ // array<A> -- Dynamic array
+ write!(self.out, "array<")?;
+ match size {
+ crate::ArraySize::Constant(len) => {
+ self.write_type(module, base)?;
+ write!(self.out, ", {len}")?;
+ }
+ crate::ArraySize::Dynamic => {
+ self.write_type(module, base)?;
+ }
+ }
+ write!(self.out, ">")?;
+ }
+ TypeInner::BindingArray { base, size } => {
+ // More info https://github.com/gpuweb/gpuweb/issues/2105
+ write!(self.out, "binding_array<")?;
+ match size {
+ crate::ArraySize::Constant(len) => {
+ self.write_type(module, base)?;
+ write!(self.out, ", {len}")?;
+ }
+ crate::ArraySize::Dynamic => {
+ self.write_type(module, base)?;
+ }
+ }
+ write!(self.out, ">")?;
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ write!(
+ self.out,
+ "mat{}x{}<{}>",
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ scalar_kind_str(scalar)
+ )?;
+ }
+ TypeInner::Pointer { base, space } => {
+ let (address, maybe_access) = address_space_str(space);
+ // Everything but `AddressSpace::Handle` gives us a `address` name, but
+ // Naga IR never produces pointers to handles, so it doesn't matter much
+ // how we write such a type. Just write it as the base type alone.
+ if let Some(space) = address {
+ write!(self.out, "ptr<{space}, ")?;
+ }
+ self.write_type(module, base)?;
+ if address.is_some() {
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ }
+ }
+ TypeInner::ValuePointer {
+ size: None,
+ scalar,
+ space,
+ } => {
+ let (address, maybe_access) = address_space_str(space);
+ if let Some(space) = address {
+ write!(self.out, "ptr<{}, {}", space, scalar_kind_str(scalar))?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ } else {
+ return Err(Error::Unimplemented(format!(
+ "ValuePointer to AddressSpace::Handle {inner:?}"
+ )));
+ }
+ }
+ TypeInner::ValuePointer {
+ size: Some(size),
+ scalar,
+ space,
+ } => {
+ let (address, maybe_access) = address_space_str(space);
+ if let Some(space) = address {
+ write!(
+ self.out,
+ "ptr<{}, vec{}<{}>",
+ space,
+ back::vector_size_str(size),
+ scalar_kind_str(scalar)
+ )?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ } else {
+ return Err(Error::Unimplemented(format!(
+ "ValuePointer to AddressSpace::Handle {inner:?}"
+ )));
+ }
+ write!(self.out, ">")?;
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!("write_value_type {inner:?}")));
+ }
+ }
+
+ Ok(())
+ }
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ module: &Module,
+ stmt: &crate::Statement,
+ func_ctx: &back::FunctionCtx<'_>,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::{Expression, Statement};
+
+ match *stmt {
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let info = &func_ctx.info[handle];
+ let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
+ // Front end provides names for all variables at the start of writing.
+ // But we write them to step by step. We need to recache them
+ // Otherwise, we could accidentally write variable name instead of full expression.
+ // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
+ Some(self.namer.call(name))
+ } else {
+ let expr = &func_ctx.expressions[handle];
+ let min_ref_count = expr.bake_ref_count();
+ // Forcefully creating baking expressions in some cases to help with readability
+ let required_baking_expr = match *expr {
+ Expression::ImageLoad { .. }
+ | Expression::ImageQuery { .. }
+ | Expression::ImageSample { .. } => true,
+ _ => false,
+ };
+ if min_ref_count <= info.ref_count || required_baking_expr {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else {
+ None
+ }
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.start_named_expr(module, handle, func_ctx, &name)?;
+ self.write_expr(module, handle, func_ctx)?;
+ self.named_expressions.insert(handle, name);
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "if ")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, " {{")?;
+
+ let l2 = level.next();
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Return { value } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "return")?;
+ if let Some(return_value) = value {
+ // The leading space is important
+ write!(self.out, " ")?;
+ self.write_expr(module, return_value, func_ctx)?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Kill => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "discard;")?
+ }
+ Statement::Store { pointer, value } => {
+ write!(self.out, "{level}")?;
+
+ let is_atomic_pointer = func_ctx
+ .resolve_type(pointer, &module.types)
+ .is_atomic_pointer(&module.types);
+
+ if is_atomic_pointer {
+ write!(self.out, "atomicStore(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr_with_indirection(
+ module,
+ pointer,
+ func_ctx,
+ Indirection::Reference,
+ )?;
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ }
+ writeln!(self.out, ";")?
+ }
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ self.start_named_expr(module, expr, func_ctx, &name)?;
+ self.named_expressions.insert(expr, name);
+ }
+ let func_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{func_name}(")?;
+ for (index, &argument) in arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_expr(module, argument, func_ctx)?;
+ }
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+
+ let fun_str = fun.to_wgsl();
+ write!(self.out, "atomic{fun_str}(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ write!(self.out, ", ")?;
+ self.write_expr(module, cmp, func_ctx)?;
+ }
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ");")?
+ }
+ Statement::WorkGroupUniformLoad { pointer, result } => {
+ write!(self.out, "{level}")?;
+ // TODO: Obey named expressions here.
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+ write!(self.out, "workgroupUniformLoad(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "textureStore(")?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(array_index_expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index_expr, func_ctx)?;
+ }
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Block(ref block) => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, level.next())?
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{level}")?;
+ write!(self.out, "switch ")?;
+ self.write_expr(module, selector, func_ctx)?;
+ writeln!(self.out, " {{")?;
+
+ let l2 = level.next();
+ let mut new_case = true;
+ for case in cases {
+ if case.fall_through && !case.body.is_empty() {
+ // TODO: we could do the same workaround as we did for the HLSL backend
+ return Err(Error::Unimplemented(
+ "fall-through switch case block".into(),
+ ));
+ }
+
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ if new_case {
+ write!(self.out, "{l2}case ")?;
+ }
+ write!(self.out, "{value}")?;
+ }
+ crate::SwitchValue::U32(value) => {
+ if new_case {
+ write!(self.out, "{l2}case ")?;
+ }
+ write!(self.out, "{value}u")?;
+ }
+ crate::SwitchValue::Default => {
+ if new_case {
+ if case.fall_through {
+ write!(self.out, "{l2}case ")?;
+ } else {
+ write!(self.out, "{l2}")?;
+ }
+ }
+ write!(self.out, "default")?;
+ }
+ }
+
+ new_case = !case.fall_through;
+
+ if case.fall_through {
+ write!(self.out, ", ")?;
+ } else {
+ writeln!(self.out, ": {{")?;
+ }
+
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2.next())?;
+ }
+
+ if !case.fall_through {
+ writeln!(self.out, "{l2}}}")?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "loop {{")?;
+
+ let l2 = level.next();
+ for sta in body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // The continuing is optional so we don't need to write it if
+ // it is empty, but the `break if` counts as a continuing statement
+ // so even if `continuing` is empty we must generate it if a
+ // `break if` exists
+ if !continuing.is_empty() || break_if.is_some() {
+ writeln!(self.out, "{l2}continuing {{")?;
+ for sta in continuing.iter() {
+ self.write_stmt(module, sta, func_ctx, l2.next())?;
+ }
+
+ // The `break if` is always the last
+ // statement of the `continuing` block
+ if let Some(condition) = break_if {
+ // The trailing space is important
+ write!(self.out, "{}break if ", l2.next())?;
+ self.write_expr(module, condition, func_ctx)?;
+ // Close the `break if` statement
+ writeln!(self.out, ";")?;
+ }
+
+ writeln!(self.out, "{l2}}}")?;
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Break => {
+ writeln!(self.out, "{level}break;")?;
+ }
+ Statement::Continue => {
+ writeln!(self.out, "{level}continue;")?;
+ }
+ Statement::Barrier(barrier) => {
+ if barrier.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{level}storageBarrier();")?;
+ }
+
+ if barrier.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{level}workgroupBarrier();")?;
+ }
+ }
+ Statement::RayQuery { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Return the sort of indirection that `expr`'s plain form evaluates to.
+ ///
+ /// An expression's 'plain form' is the most general rendition of that
+ /// expression into WGSL, lacking `&` or `*` operators:
+ ///
+ /// - The plain form of `LocalVariable(x)` is simply `x`, which is a reference
+ /// to the local variable's storage.
+ ///
+ /// - The plain form of `GlobalVariable(g)` is simply `g`, which is usually a
+ /// reference to the global variable's storage. However, globals in the
+ /// `Handle` address space are immutable, and `GlobalVariable` expressions for
+ /// those produce the value directly, not a pointer to it. Such
+ /// `GlobalVariable` expressions are `Ordinary`.
+ ///
+ /// - `Access` and `AccessIndex` are `Reference` when their `base` operand is a
+ /// pointer. If they are applied directly to a composite value, they are
+ /// `Ordinary`.
+ ///
+ /// Note that `FunctionArgument` expressions are never `Reference`, even when
+ /// the argument's type is `Pointer`. `FunctionArgument` always evaluates to the
+ /// argument's value directly, so any pointer it produces is merely the value
+ /// passed by the caller.
+ fn plain_form_indirection(
+ &self,
+ expr: Handle<crate::Expression>,
+ module: &Module,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> Indirection {
+ use crate::Expression as Ex;
+
+ // Named expressions are `let` expressions, which apply the Load Rule,
+ // so if their type is a Naga pointer, then that must be a WGSL pointer
+ // as well.
+ if self.named_expressions.contains_key(&expr) {
+ return Indirection::Ordinary;
+ }
+
+ match func_ctx.expressions[expr] {
+ Ex::LocalVariable(_) => Indirection::Reference,
+ Ex::GlobalVariable(handle) => {
+ let global = &module.global_variables[handle];
+ match global.space {
+ crate::AddressSpace::Handle => Indirection::Ordinary,
+ _ => Indirection::Reference,
+ }
+ }
+ Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
+ let base_ty = func_ctx.resolve_type(base, &module.types);
+ match *base_ty {
+ crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
+ Indirection::Reference
+ }
+ _ => Indirection::Ordinary,
+ }
+ }
+ _ => Indirection::Ordinary,
+ }
+ }
+
+ fn start_named_expr(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx,
+ name: &str,
+ ) -> BackendResult {
+ // Write variable name
+ write!(self.out, "let {name}")?;
+ if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
+ write!(self.out, ": ")?;
+ let ty = &func_ctx.info[handle].ty;
+ // Write variable type
+ match *ty {
+ proc::TypeResolution::Handle(handle) => {
+ self.write_type(module, handle)?;
+ }
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(module, inner)?;
+ }
+ }
+ }
+
+ write!(self.out, " = ")?;
+ Ok(())
+ }
+
+ /// Write the ordinary WGSL form of `expr`.
+ ///
+ /// See `write_expr_with_indirection` for details.
+ fn write_expr(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
+ }
+
+ /// Write `expr` as a WGSL expression with the requested indirection.
+ ///
+ /// In terms of the WGSL grammar, the resulting expression is a
+ /// `singular_expression`. It may be parenthesized. This makes it suitable
+ /// for use as the operand of a unary or binary operator without worrying
+ /// about precedence.
+ ///
+ /// This does not produce newlines or indentation.
+ ///
+ /// The `requested` argument indicates (roughly) whether Naga
+ /// `Pointer`-valued expressions represent WGSL references or pointers. See
+ /// `Indirection` for details.
+ fn write_expr_with_indirection(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ requested: Indirection,
+ ) -> BackendResult {
+ // If the plain form of the expression is not what we need, emit the
+ // operator necessary to correct that.
+ let plain = self.plain_form_indirection(expr, module, func_ctx);
+ match (requested, plain) {
+ (Indirection::Ordinary, Indirection::Reference) => {
+ write!(self.out, "(&")?;
+ self.write_expr_plain_form(module, expr, func_ctx, plain)?;
+ write!(self.out, ")")?;
+ }
+ (Indirection::Reference, Indirection::Ordinary) => {
+ write!(self.out, "(*")?;
+ self.write_expr_plain_form(module, expr, func_ctx, plain)?;
+ write!(self.out, ")")?;
+ }
+ (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
+ }
+
+ Ok(())
+ }
+
+ fn write_const_expression(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ ) -> BackendResult {
+ self.write_possibly_const_expression(
+ module,
+ expr,
+ &module.const_expressions,
+ |writer, expr| writer.write_const_expression(module, expr),
+ )
+ }
+
+ fn write_possibly_const_expression<E>(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ expressions: &crate::Arena<crate::Expression>,
+ write_expression: E,
+ ) -> BackendResult
+ where
+ E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
+ {
+ use crate::Expression;
+
+ match expressions[expr] {
+ Expression::Literal(literal) => match literal {
+ crate::Literal::F32(value) => write!(self.out, "{}f", value)?,
+ crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
+ crate::Literal::I32(value) => {
+ // `-2147483648i` is not valid WGSL. The most negative `i32`
+ // value can only be expressed in WGSL using AbstractInt and
+ // a unary negation operator.
+ if value == i32::MIN {
+ write!(self.out, "i32(-2147483648)")?;
+ } else {
+ write!(self.out, "{}i", value)?;
+ }
+ }
+ crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
+ crate::Literal::F64(value) => write!(self.out, "{:?}lf", value)?,
+ crate::Literal::I64(_) => {
+ return Err(Error::Custom("unsupported i64 literal".to_string()));
+ }
+ crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
+ return Err(Error::Custom(
+ "Abstract types should not appear in IR presented to backends".into(),
+ ));
+ }
+ },
+ Expression::Constant(handle) => {
+ let constant = &module.constants[handle];
+ if constant.name.is_some() {
+ write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.write_const_expression(module, constant.init)?;
+ }
+ }
+ Expression::ZeroValue(ty) => {
+ self.write_type(module, ty)?;
+ write!(self.out, "()")?;
+ }
+ Expression::Compose { ty, ref components } => {
+ self.write_type(module, ty)?;
+ write!(self.out, "(")?;
+ for (index, component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ write_expression(self, *component)?;
+ }
+ write!(self.out, ")")?
+ }
+ Expression::Splat { size, value } => {
+ let size = back::vector_size_str(size);
+ write!(self.out, "vec{size}(")?;
+ write_expression(self, value)?;
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Write the 'plain form' of `expr`.
+ ///
+ /// An expression's 'plain form' is the most general rendition of that
+ /// expression into WGSL, lacking `&` or `*` operators. The plain forms of
+ /// `LocalVariable(x)` and `GlobalVariable(g)` are simply `x` and `g`. Such
+ /// Naga expressions represent both WGSL pointers and references; it's the
+ /// caller's responsibility to distinguish those cases appropriately.
+ fn write_expr_plain_form(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ indirection: Indirection,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+
+ let expression = &func_ctx.expressions[expr];
+
+ // Write the plain WGSL form of a Naga expression.
+ //
+ // The plain form of `LocalVariable` and `GlobalVariable` expressions is
+ // simply the variable name; `*` and `&` operators are never emitted.
+ //
+ // The plain form of `Access` and `AccessIndex` expressions are WGSL
+ // `postfix_expression` forms for member/component access and
+ // subscripting.
+ match *expression {
+ Expression::Literal(_)
+ | Expression::Constant(_)
+ | Expression::ZeroValue(_)
+ | Expression::Compose { .. }
+ | Expression::Splat { .. } => {
+ self.write_possibly_const_expression(
+ module,
+ expr,
+ func_ctx.expressions,
+ |writer, expr| writer.write_expr(module, expr, func_ctx),
+ )?;
+ }
+ Expression::FunctionArgument(pos) => {
+ let name_key = func_ctx.argument_key(pos);
+ let name = &self.names[&name_key];
+ write!(self.out, "{name}")?;
+ }
+ Expression::Binary { op, left, right } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Access { base, index } => {
+ self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
+ write!(self.out, "[")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, "]")?
+ }
+ Expression::AccessIndex { base, index } => {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+
+ self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
+
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, space: _ } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ match *resolved {
+ TypeInner::Vector { .. } => {
+ // Write vector access as a swizzle
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::BindingArray { .. }
+ | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
+ TypeInner::Struct { .. } => {
+ // This will never panic in case the type is a `Struct`, this is not true
+ // for other types so we can only check while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ self.out,
+ ".{}",
+ &self.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
+ }
+ }
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather: None,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ use crate::SampleLevel as Sl;
+
+ let suffix_cmp = match depth_ref {
+ Some(_) => "Compare",
+ None => "",
+ };
+ let suffix_level = match level {
+ Sl::Auto => "",
+ Sl::Zero | Sl::Exact(_) => "Level",
+ Sl::Bias(_) => "Bias",
+ Sl::Gradient { .. } => "Grad",
+ };
+
+ write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ match level {
+ Sl::Auto => {}
+ Sl::Zero => {
+ // Level 0 is implied for depth comparison
+ if depth_ref.is_none() {
+ write!(self.out, ", 0.0")?;
+ }
+ }
+ Sl::Exact(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Bias(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Gradient { x, y } => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, x, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, y, func_ctx)?;
+ }
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.write_const_expression(module, offset)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather: Some(component),
+ coordinate,
+ array_index,
+ offset,
+ level: _,
+ depth_ref,
+ } => {
+ let suffix_cmp = match depth_ref {
+ Some(_) => "Compare",
+ None => "",
+ };
+
+ write!(self.out, "textureGather{suffix_cmp}(")?;
+ match *func_ctx.resolve_type(image, &module.types) {
+ TypeInner::Image {
+ class: crate::ImageClass::Depth { multi: _ },
+ ..
+ } => {}
+ _ => {
+ write!(self.out, "{}, ", component as u8)?;
+ }
+ }
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.write_const_expression(module, offset)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::ImageQuery { image, query } => {
+ use crate::ImageQuery as Iq;
+
+ let texture_function = match query {
+ Iq::Size { .. } => "textureDimensions",
+ Iq::NumLevels => "textureNumLevels",
+ Iq::NumLayers => "textureNumLayers",
+ Iq::NumSamples => "textureNumSamples",
+ };
+
+ write!(self.out, "{texture_function}(")?;
+ self.write_expr(module, image, func_ctx)?;
+ if let Iq::Size { level: Some(level) } = query {
+ write!(self.out, ", ")?;
+ self.write_expr(module, level, func_ctx)?;
+ };
+ write!(self.out, ")")?;
+ }
+
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ write!(self.out, "textureLoad(")?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+ if let Some(index) = sample.or(level) {
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Expression::GlobalVariable(handle) => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{name}")?;
+ }
+
+ Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let inner = func_ctx.resolve_type(expr, &module.types);
+ match *inner {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ let scalar = crate::Scalar {
+ kind,
+ width: convert.unwrap_or(scalar.width),
+ };
+ let scalar_kind_str = scalar_kind_str(scalar);
+ write!(
+ self.out,
+ "mat{}x{}<{}>",
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ scalar_kind_str
+ )?;
+ }
+ TypeInner::Vector {
+ size,
+ scalar: crate::Scalar { width, .. },
+ } => {
+ let scalar = crate::Scalar {
+ kind,
+ width: convert.unwrap_or(width),
+ };
+ let vector_size_str = back::vector_size_str(size);
+ let scalar_kind_str = scalar_kind_str(scalar);
+ if convert.is_some() {
+ write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?;
+ } else {
+ write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?;
+ }
+ }
+ TypeInner::Scalar(crate::Scalar { width, .. }) => {
+ let scalar = crate::Scalar {
+ kind,
+ width: convert.unwrap_or(width),
+ };
+ let scalar_kind_str = scalar_kind_str(scalar);
+ if convert.is_some() {
+ write!(self.out, "{scalar_kind_str}")?
+ } else {
+ write!(self.out, "bitcast<{scalar_kind_str}>")?
+ }
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_expr expression::as {inner:?}"
+ )));
+ }
+ };
+ write!(self.out, "(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Load { pointer } => {
+ let is_atomic_pointer = func_ctx
+ .resolve_type(pointer, &module.types)
+ .is_atomic_pointer(&module.types);
+
+ if is_atomic_pointer {
+ write!(self.out, "atomicLoad(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr_with_indirection(
+ module,
+ pointer,
+ func_ctx,
+ Indirection::Reference,
+ )?;
+ }
+ }
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
+ }
+ Expression::ArrayLength(expr) => {
+ write!(self.out, "arrayLength(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ enum Function {
+ Regular(&'static str),
+ }
+
+ let function = match fun {
+ Mf::Abs => Function::Regular("abs"),
+ Mf::Min => Function::Regular("min"),
+ Mf::Max => Function::Regular("max"),
+ Mf::Clamp => Function::Regular("clamp"),
+ Mf::Saturate => Function::Regular("saturate"),
+ // trigonometry
+ Mf::Cos => Function::Regular("cos"),
+ Mf::Cosh => Function::Regular("cosh"),
+ Mf::Sin => Function::Regular("sin"),
+ Mf::Sinh => Function::Regular("sinh"),
+ Mf::Tan => Function::Regular("tan"),
+ Mf::Tanh => Function::Regular("tanh"),
+ Mf::Acos => Function::Regular("acos"),
+ Mf::Asin => Function::Regular("asin"),
+ Mf::Atan => Function::Regular("atan"),
+ Mf::Atan2 => Function::Regular("atan2"),
+ Mf::Asinh => Function::Regular("asinh"),
+ Mf::Acosh => Function::Regular("acosh"),
+ Mf::Atanh => Function::Regular("atanh"),
+ Mf::Radians => Function::Regular("radians"),
+ Mf::Degrees => Function::Regular("degrees"),
+ // decomposition
+ Mf::Ceil => Function::Regular("ceil"),
+ Mf::Floor => Function::Regular("floor"),
+ Mf::Round => Function::Regular("round"),
+ Mf::Fract => Function::Regular("fract"),
+ Mf::Trunc => Function::Regular("trunc"),
+ Mf::Modf => Function::Regular("modf"),
+ Mf::Frexp => Function::Regular("frexp"),
+ Mf::Ldexp => Function::Regular("ldexp"),
+ // exponent
+ Mf::Exp => Function::Regular("exp"),
+ Mf::Exp2 => Function::Regular("exp2"),
+ Mf::Log => Function::Regular("log"),
+ Mf::Log2 => Function::Regular("log2"),
+ Mf::Pow => Function::Regular("pow"),
+ // geometry
+ Mf::Dot => Function::Regular("dot"),
+ Mf::Cross => Function::Regular("cross"),
+ Mf::Distance => Function::Regular("distance"),
+ Mf::Length => Function::Regular("length"),
+ Mf::Normalize => Function::Regular("normalize"),
+ Mf::FaceForward => Function::Regular("faceForward"),
+ Mf::Reflect => Function::Regular("reflect"),
+ Mf::Refract => Function::Regular("refract"),
+ // computational
+ Mf::Sign => Function::Regular("sign"),
+ Mf::Fma => Function::Regular("fma"),
+ Mf::Mix => Function::Regular("mix"),
+ Mf::Step => Function::Regular("step"),
+ Mf::SmoothStep => Function::Regular("smoothstep"),
+ Mf::Sqrt => Function::Regular("sqrt"),
+ Mf::InverseSqrt => Function::Regular("inverseSqrt"),
+ Mf::Transpose => Function::Regular("transpose"),
+ Mf::Determinant => Function::Regular("determinant"),
+ // bits
+ Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"),
+ Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"),
+ Mf::CountOneBits => Function::Regular("countOneBits"),
+ Mf::ReverseBits => Function::Regular("reverseBits"),
+ Mf::ExtractBits => Function::Regular("extractBits"),
+ Mf::InsertBits => Function::Regular("insertBits"),
+ Mf::FindLsb => Function::Regular("firstTrailingBit"),
+ Mf::FindMsb => Function::Regular("firstLeadingBit"),
+ // data packing
+ Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"),
+ Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"),
+ Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"),
+ Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"),
+ Mf::Pack2x16float => Function::Regular("pack2x16float"),
+ // data unpacking
+ Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"),
+ Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"),
+ Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"),
+ Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"),
+ Mf::Unpack2x16float => Function::Regular("unpack2x16float"),
+ Mf::Inverse | Mf::Outer => {
+ return Err(Error::UnsupportedMathFunction(fun));
+ }
+ };
+
+ match function {
+ Function::Regular(fun_name) => {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ write!(self.out, ")")?
+ }
+ }
+ }
+
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(module, vector, func_ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ Expression::Unary { op, expr } => {
+ let unary = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::LogicalNot => "!",
+ crate::UnaryOperator::BitwiseNot => "~",
+ };
+
+ write!(self.out, "{unary}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+
+ write!(self.out, ")")?
+ }
+
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "select(")?;
+ self.write_expr(module, reject, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, accept, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, condition, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ let op = match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => "dpdxCoarse",
+ (Axis::X, Ctrl::Fine) => "dpdxFine",
+ (Axis::X, Ctrl::None) => "dpdx",
+ (Axis::Y, Ctrl::Coarse) => "dpdyCoarse",
+ (Axis::Y, Ctrl::Fine) => "dpdyFine",
+ (Axis::Y, Ctrl::None) => "dpdy",
+ (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
+ (Axis::Width, Ctrl::Fine) => "fwidthFine",
+ (Axis::Width, Ctrl::None) => "fwidth",
+ };
+ write!(self.out, "{op}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_name = match fun {
+ Rf::All => "all",
+ Rf::Any => "any",
+ _ => return Err(Error::UnsupportedRelationalFunction(fun)),
+ };
+ write!(self.out, "{fun_name}(")?;
+
+ self.write_expr(module, argument, func_ctx)?;
+
+ write!(self.out, ")")?
+ }
+ // Not supported yet
+ Expression::RayQueryGetIntersection { .. } => unreachable!(),
+ // Nothing to do here, since call expression already cached
+ Expression::CallResult(_)
+ | Expression::AtomicResult { .. }
+ | Expression::RayQueryProceedResult
+ | Expression::WorkGroupUniformLoadResult { .. } => {}
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global variables
+ /// # Notes
+ /// Always adds a newline
+ fn write_global(
+ &mut self,
+ module: &Module,
+ global: &crate::GlobalVariable,
+ handle: Handle<crate::GlobalVariable>,
+ ) -> BackendResult {
+ // Write group and binding attributes if present
+ if let Some(ref binding) = global.binding {
+ self.write_attributes(&[
+ Attribute::Group(binding.group),
+ Attribute::Binding(binding.binding),
+ ])?;
+ writeln!(self.out)?;
+ }
+
+ // First write global name and address space if supported
+ write!(self.out, "var")?;
+ let (address, maybe_access) = address_space_str(global.space);
+ if let Some(space) = address {
+ write!(self.out, "<{space}")?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ }
+ write!(
+ self.out,
+ " {}: ",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?;
+
+ // Write global type
+ self.write_type(module, global.ty)?;
+
+ // Write initializer
+ if let Some(init) = global.init {
+ write!(self.out, " = ")?;
+ self.write_const_expression(module, init)?;
+ }
+
+ // End with semicolon
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write global constants
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_global_constant(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Constant(handle)];
+ // First write only constant name
+ write!(self.out, "const {name}: ")?;
+ self.write_type(module, module.constants[handle].ty)?;
+ write!(self.out, " = ")?;
+ let init = module.constants[handle].init;
+ self.write_const_expression(module, init)?;
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ // See https://github.com/rust-lang/rust-clippy/issues/4979.
+ #[allow(clippy::missing_const_for_fn)]
+ pub fn finish(self) -> W {
+ self.out
+ }
+}
+
+fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> {
+ use crate::BuiltIn as Bi;
+
+ Ok(match built_in {
+ Bi::VertexIndex => "vertex_index",
+ Bi::InstanceIndex => "instance_index",
+ Bi::Position { .. } => "position",
+ Bi::FrontFacing => "front_facing",
+ Bi::FragDepth => "frag_depth",
+ Bi::LocalInvocationId => "local_invocation_id",
+ Bi::LocalInvocationIndex => "local_invocation_index",
+ Bi::GlobalInvocationId => "global_invocation_id",
+ Bi::WorkGroupId => "workgroup_id",
+ Bi::NumWorkGroups => "num_workgroups",
+ Bi::SampleIndex => "sample_index",
+ Bi::SampleMask => "sample_mask",
+ Bi::PrimitiveIndex => "primitive_index",
+ Bi::ViewIndex => "view_index",
+ Bi::BaseInstance
+ | Bi::BaseVertex
+ | Bi::ClipDistance
+ | Bi::CullDistance
+ | Bi::PointSize
+ | Bi::PointCoord
+ | Bi::WorkGroupSize => {
+ return Err(Error::Custom(format!("Unsupported builtin {built_in:?}")))
+ }
+ })
+}
+
+const fn image_dimension_str(dim: crate::ImageDimension) -> &'static str {
+ use crate::ImageDimension as IDim;
+
+ match dim {
+ IDim::D1 => "1d",
+ IDim::D2 => "2d",
+ IDim::D3 => "3d",
+ IDim::Cube => "cube",
+ }
+}
+
+const fn scalar_kind_str(scalar: crate::Scalar) -> &'static str {
+ use crate::Scalar;
+ use crate::ScalarKind as Sk;
+
+ match scalar {
+ Scalar {
+ kind: Sk::Float,
+ width: 8,
+ } => "f64",
+ Scalar {
+ kind: Sk::Float,
+ width: 4,
+ } => "f32",
+ Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ } => "i32",
+ Scalar {
+ kind: Sk::Uint,
+ width: 4,
+ } => "u32",
+ Scalar {
+ kind: Sk::Bool,
+ width: 1,
+ } => "bool",
+ _ => unreachable!(),
+ }
+}
+
+const fn storage_format_str(format: crate::StorageFormat) -> &'static str {
+ use crate::StorageFormat as Sf;
+
+ match format {
+ Sf::R8Unorm => "r8unorm",
+ Sf::R8Snorm => "r8snorm",
+ Sf::R8Uint => "r8uint",
+ Sf::R8Sint => "r8sint",
+ Sf::R16Uint => "r16uint",
+ Sf::R16Sint => "r16sint",
+ Sf::R16Float => "r16float",
+ Sf::Rg8Unorm => "rg8unorm",
+ Sf::Rg8Snorm => "rg8snorm",
+ Sf::Rg8Uint => "rg8uint",
+ Sf::Rg8Sint => "rg8sint",
+ Sf::R32Uint => "r32uint",
+ Sf::R32Sint => "r32sint",
+ Sf::R32Float => "r32float",
+ Sf::Rg16Uint => "rg16uint",
+ Sf::Rg16Sint => "rg16sint",
+ Sf::Rg16Float => "rg16float",
+ Sf::Rgba8Unorm => "rgba8unorm",
+ Sf::Rgba8Snorm => "rgba8snorm",
+ Sf::Rgba8Uint => "rgba8uint",
+ Sf::Rgba8Sint => "rgba8sint",
+ Sf::Bgra8Unorm => "bgra8unorm",
+ Sf::Rgb10a2Uint => "rgb10a2uint",
+ Sf::Rgb10a2Unorm => "rgb10a2unorm",
+ Sf::Rg11b10Float => "rg11b10float",
+ Sf::Rg32Uint => "rg32uint",
+ Sf::Rg32Sint => "rg32sint",
+ Sf::Rg32Float => "rg32float",
+ Sf::Rgba16Uint => "rgba16uint",
+ Sf::Rgba16Sint => "rgba16sint",
+ Sf::Rgba16Float => "rgba16float",
+ Sf::Rgba32Uint => "rgba32uint",
+ Sf::Rgba32Sint => "rgba32sint",
+ Sf::Rgba32Float => "rgba32float",
+ Sf::R16Unorm => "r16unorm",
+ Sf::R16Snorm => "r16snorm",
+ Sf::Rg16Unorm => "rg16unorm",
+ Sf::Rg16Snorm => "rg16snorm",
+ Sf::Rgba16Unorm => "rgba16unorm",
+ Sf::Rgba16Snorm => "rgba16snorm",
+ }
+}
+
+/// Helper function that returns the string corresponding to the WGSL interpolation qualifier
+const fn interpolation_str(interpolation: crate::Interpolation) -> &'static str {
+ use crate::Interpolation as I;
+
+ match interpolation {
+ I::Perspective => "perspective",
+ I::Linear => "linear",
+ I::Flat => "flat",
+ }
+}
+
+/// Return the WGSL auxiliary qualifier for the given sampling value.
+const fn sampling_str(sampling: crate::Sampling) -> &'static str {
+ use crate::Sampling as S;
+
+ match sampling {
+ S::Center => "",
+ S::Centroid => "centroid",
+ S::Sample => "sample",
+ }
+}
+
+const fn address_space_str(
+ space: crate::AddressSpace,
+) -> (Option<&'static str>, Option<&'static str>) {
+ use crate::AddressSpace as As;
+
+ (
+ Some(match space {
+ As::Private => "private",
+ As::Uniform => "uniform",
+ As::Storage { access } => {
+ if access.contains(crate::StorageAccess::STORE) {
+ return (Some("storage"), Some("read_write"));
+ } else {
+ "storage"
+ }
+ }
+ As::PushConstant => "push_constant",
+ As::WorkGroup => "workgroup",
+ As::Handle => return (None, None),
+ As::Function => "function",
+ }),
+ None,
+ )
+}
+
+fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> {
+ match *binding {
+ crate::Binding::BuiltIn(built_in) => {
+ if let crate::BuiltIn::Position { invariant: true } = built_in {
+ vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
+ } else {
+ vec![Attribute::BuiltIn(built_in)]
+ }
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source: false,
+ } => vec![
+ Attribute::Location(location),
+ Attribute::Interpolate(interpolation, sampling),
+ ],
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source: true,
+ } => vec![
+ Attribute::Location(location),
+ Attribute::SecondBlendSource,
+ Attribute::Interpolate(interpolation, sampling),
+ ],
+ }
+}
diff --git a/third_party/rust/naga/src/block.rs b/third_party/rust/naga/src/block.rs
new file mode 100644
index 0000000000..0abda9da7c
--- /dev/null
+++ b/third_party/rust/naga/src/block.rs
@@ -0,0 +1,123 @@
+use crate::{Span, Statement};
+use std::ops::{Deref, DerefMut, RangeBounds};
+
+/// A code block is a vector of statements, with maybe a vector of spans.
+#[derive(Debug, Clone, Default)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "serialize", serde(transparent))]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct Block {
+ body: Vec<Statement>,
+ #[cfg_attr(feature = "serialize", serde(skip))]
+ span_info: Vec<Span>,
+}
+
+impl Block {
+ pub const fn new() -> Self {
+ Self {
+ body: Vec::new(),
+ span_info: Vec::new(),
+ }
+ }
+
+ pub fn from_vec(body: Vec<Statement>) -> Self {
+ let span_info = std::iter::repeat(Span::default())
+ .take(body.len())
+ .collect();
+ Self { body, span_info }
+ }
+
+ pub fn with_capacity(capacity: usize) -> Self {
+ Self {
+ body: Vec::with_capacity(capacity),
+ span_info: Vec::with_capacity(capacity),
+ }
+ }
+
+ #[allow(unused_variables)]
+ pub fn push(&mut self, end: Statement, span: Span) {
+ self.body.push(end);
+ self.span_info.push(span);
+ }
+
+ pub fn extend(&mut self, item: Option<(Statement, Span)>) {
+ if let Some((end, span)) = item {
+ self.push(end, span)
+ }
+ }
+
+ pub fn extend_block(&mut self, other: Self) {
+ self.span_info.extend(other.span_info);
+ self.body.extend(other.body);
+ }
+
+ pub fn append(&mut self, other: &mut Self) {
+ self.span_info.append(&mut other.span_info);
+ self.body.append(&mut other.body);
+ }
+
+ pub fn cull<R: RangeBounds<usize> + Clone>(&mut self, range: R) {
+ self.span_info.drain(range.clone());
+ self.body.drain(range);
+ }
+
+ pub fn splice<R: RangeBounds<usize> + Clone>(&mut self, range: R, other: Self) {
+ self.span_info.splice(range.clone(), other.span_info);
+ self.body.splice(range, other.body);
+ }
+ pub fn span_iter(&self) -> impl Iterator<Item = (&Statement, &Span)> {
+ let span_iter = self.span_info.iter();
+ self.body.iter().zip(span_iter)
+ }
+
+ pub fn span_iter_mut(&mut self) -> impl Iterator<Item = (&mut Statement, Option<&mut Span>)> {
+ let span_iter = self.span_info.iter_mut().map(Some);
+ self.body.iter_mut().zip(span_iter)
+ }
+
+ pub fn is_empty(&self) -> bool {
+ self.body.is_empty()
+ }
+
+ pub fn len(&self) -> usize {
+ self.body.len()
+ }
+}
+
+impl Deref for Block {
+ type Target = [Statement];
+ fn deref(&self) -> &[Statement] {
+ &self.body
+ }
+}
+
+impl DerefMut for Block {
+ fn deref_mut(&mut self) -> &mut [Statement] {
+ &mut self.body
+ }
+}
+
+impl<'a> IntoIterator for &'a Block {
+ type Item = &'a Statement;
+ type IntoIter = std::slice::Iter<'a, Statement>;
+
+ fn into_iter(self) -> std::slice::Iter<'a, Statement> {
+ self.iter()
+ }
+}
+
+#[cfg(feature = "deserialize")]
+impl<'de> serde::Deserialize<'de> for Block {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ Ok(Self::from_vec(Vec::deserialize(deserializer)?))
+ }
+}
+
+impl From<Vec<Statement>> for Block {
+ fn from(body: Vec<Statement>) -> Self {
+ Self::from_vec(body)
+ }
+}
diff --git a/third_party/rust/naga/src/compact/expressions.rs b/third_party/rust/naga/src/compact/expressions.rs
new file mode 100644
index 0000000000..301bbe3240
--- /dev/null
+++ b/third_party/rust/naga/src/compact/expressions.rs
@@ -0,0 +1,389 @@
+use super::{HandleMap, HandleSet, ModuleMap};
+use crate::arena::{Arena, Handle};
+
+pub struct ExpressionTracer<'tracer> {
+ pub constants: &'tracer Arena<crate::Constant>,
+
+ /// The arena in which we are currently tracing expressions.
+ pub expressions: &'tracer Arena<crate::Expression>,
+
+ /// The used map for `types`.
+ pub types_used: &'tracer mut HandleSet<crate::Type>,
+
+ /// The used map for `constants`.
+ pub constants_used: &'tracer mut HandleSet<crate::Constant>,
+
+ /// The used set for `arena`.
+ ///
+ /// This points to whatever arena holds the expressions we are
+ /// currently tracing: either a function's expression arena, or
+ /// the module's constant expression arena.
+ pub expressions_used: &'tracer mut HandleSet<crate::Expression>,
+
+ /// The used set for the module's `const_expressions` arena.
+ ///
+ /// If `None`, we are already tracing the constant expressions,
+ /// and `expressions_used` already refers to their handle set.
+ pub const_expressions_used: Option<&'tracer mut HandleSet<crate::Expression>>,
+}
+
+impl<'tracer> ExpressionTracer<'tracer> {
+ /// Propagate usage through `self.expressions`, starting with `self.expressions_used`.
+ ///
+ /// Treat `self.expressions_used` as the initial set of "known
+ /// live" expressions, and follow through to identify all
+ /// transitively used expressions.
+ ///
+ /// Mark types, constants, and constant expressions used directly
+ /// by `self.expressions` as used. Items used indirectly are not
+ /// marked.
+ ///
+ /// [fe]: crate::Function::expressions
+ /// [ce]: crate::Module::const_expressions
+ pub fn trace_expressions(&mut self) {
+ log::trace!(
+ "entering trace_expression of {}",
+ if self.const_expressions_used.is_some() {
+ "function expressions"
+ } else {
+ "const expressions"
+ }
+ );
+
+ // We don't need recursion or a work list. Because an
+ // expression may only refer to other expressions that precede
+ // it in the arena, it suffices to make a single pass over the
+ // arena from back to front, marking the referents of used
+ // expressions as used themselves.
+ for (handle, expr) in self.expressions.iter().rev() {
+ // If this expression isn't used, it doesn't matter what it uses.
+ if !self.expressions_used.contains(handle) {
+ continue;
+ }
+
+ log::trace!("tracing new expression {:?}", expr);
+
+ use crate::Expression as Ex;
+ match *expr {
+ // Expressions that do not contain handles that need to be traced.
+ Ex::Literal(_)
+ | Ex::FunctionArgument(_)
+ | Ex::GlobalVariable(_)
+ | Ex::LocalVariable(_)
+ | Ex::CallResult(_)
+ | Ex::RayQueryProceedResult => {}
+
+ Ex::Constant(handle) => {
+ self.constants_used.insert(handle);
+ // Constants and expressions are mutually recursive, which
+ // complicates our nice one-pass algorithm. However, since
+ // constants don't refer to each other, we can get around
+ // this by looking *through* each constant and marking its
+ // initializer as used. Since `expr` refers to the constant,
+ // and the constant refers to the initializer, it must
+ // precede `expr` in the arena.
+ let init = self.constants[handle].init;
+ match self.const_expressions_used {
+ Some(ref mut used) => used.insert(init),
+ None => self.expressions_used.insert(init),
+ }
+ }
+ Ex::ZeroValue(ty) => self.types_used.insert(ty),
+ Ex::Compose { ty, ref components } => {
+ self.types_used.insert(ty);
+ self.expressions_used
+ .insert_iter(components.iter().cloned());
+ }
+ Ex::Access { base, index } => self.expressions_used.insert_iter([base, index]),
+ Ex::AccessIndex { base, index: _ } => self.expressions_used.insert(base),
+ Ex::Splat { size: _, value } => self.expressions_used.insert(value),
+ Ex::Swizzle {
+ size: _,
+ vector,
+ pattern: _,
+ } => self.expressions_used.insert(vector),
+ Ex::Load { pointer } => self.expressions_used.insert(pointer),
+ Ex::ImageSample {
+ image,
+ sampler,
+ gather: _,
+ coordinate,
+ array_index,
+ offset,
+ ref level,
+ depth_ref,
+ } => {
+ self.expressions_used
+ .insert_iter([image, sampler, coordinate]);
+ self.expressions_used.insert_iter(array_index);
+ match self.const_expressions_used {
+ Some(ref mut used) => used.insert_iter(offset),
+ None => self.expressions_used.insert_iter(offset),
+ }
+ use crate::SampleLevel as Sl;
+ match *level {
+ Sl::Auto | Sl::Zero => {}
+ Sl::Exact(expr) | Sl::Bias(expr) => self.expressions_used.insert(expr),
+ Sl::Gradient { x, y } => self.expressions_used.insert_iter([x, y]),
+ }
+ self.expressions_used.insert_iter(depth_ref);
+ }
+ Ex::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ self.expressions_used.insert(image);
+ self.expressions_used.insert(coordinate);
+ self.expressions_used.insert_iter(array_index);
+ self.expressions_used.insert_iter(sample);
+ self.expressions_used.insert_iter(level);
+ }
+ Ex::ImageQuery { image, ref query } => {
+ self.expressions_used.insert(image);
+ use crate::ImageQuery as Iq;
+ match *query {
+ Iq::Size { level } => self.expressions_used.insert_iter(level),
+ Iq::NumLevels | Iq::NumLayers | Iq::NumSamples => {}
+ }
+ }
+ Ex::Unary { op: _, expr } => self.expressions_used.insert(expr),
+ Ex::Binary { op: _, left, right } => {
+ self.expressions_used.insert_iter([left, right]);
+ }
+ Ex::Select {
+ condition,
+ accept,
+ reject,
+ } => self
+ .expressions_used
+ .insert_iter([condition, accept, reject]),
+ Ex::Derivative {
+ axis: _,
+ ctrl: _,
+ expr,
+ } => self.expressions_used.insert(expr),
+ Ex::Relational { fun: _, argument } => self.expressions_used.insert(argument),
+ Ex::Math {
+ fun: _,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ self.expressions_used.insert(arg);
+ self.expressions_used.insert_iter(arg1);
+ self.expressions_used.insert_iter(arg2);
+ self.expressions_used.insert_iter(arg3);
+ }
+ Ex::As {
+ expr,
+ kind: _,
+ convert: _,
+ } => self.expressions_used.insert(expr),
+ Ex::AtomicResult { ty, comparison: _ } => self.types_used.insert(ty),
+ Ex::WorkGroupUniformLoadResult { ty } => self.types_used.insert(ty),
+ Ex::ArrayLength(expr) => self.expressions_used.insert(expr),
+ Ex::RayQueryGetIntersection {
+ query,
+ committed: _,
+ } => self.expressions_used.insert(query),
+ }
+ }
+ }
+}
+
+impl ModuleMap {
+ /// Fix up all handles in `expr`.
+ ///
+ /// Use the expression handle remappings in `operand_map`, and all
+ /// other mappings from `self`.
+ pub fn adjust_expression(
+ &self,
+ expr: &mut crate::Expression,
+ operand_map: &HandleMap<crate::Expression>,
+ ) {
+ let adjust = |expr: &mut Handle<crate::Expression>| {
+ operand_map.adjust(expr);
+ };
+
+ use crate::Expression as Ex;
+ match *expr {
+ // Expressions that do not contain handles that need to be adjusted.
+ Ex::Literal(_)
+ | Ex::FunctionArgument(_)
+ | Ex::GlobalVariable(_)
+ | Ex::LocalVariable(_)
+ | Ex::CallResult(_)
+ | Ex::RayQueryProceedResult => {}
+
+ // Expressions that contain handles that need to be adjusted.
+ Ex::Constant(ref mut constant) => self.constants.adjust(constant),
+ Ex::ZeroValue(ref mut ty) => self.types.adjust(ty),
+ Ex::Compose {
+ ref mut ty,
+ ref mut components,
+ } => {
+ self.types.adjust(ty);
+ for component in components {
+ adjust(component);
+ }
+ }
+ Ex::Access {
+ ref mut base,
+ ref mut index,
+ } => {
+ adjust(base);
+ adjust(index);
+ }
+ Ex::AccessIndex {
+ ref mut base,
+ index: _,
+ } => adjust(base),
+ Ex::Splat {
+ size: _,
+ ref mut value,
+ } => adjust(value),
+ Ex::Swizzle {
+ size: _,
+ ref mut vector,
+ pattern: _,
+ } => adjust(vector),
+ Ex::Load { ref mut pointer } => adjust(pointer),
+ Ex::ImageSample {
+ ref mut image,
+ ref mut sampler,
+ gather: _,
+ ref mut coordinate,
+ ref mut array_index,
+ ref mut offset,
+ ref mut level,
+ ref mut depth_ref,
+ } => {
+ adjust(image);
+ adjust(sampler);
+ adjust(coordinate);
+ operand_map.adjust_option(array_index);
+ if let Some(ref mut offset) = *offset {
+ self.const_expressions.adjust(offset);
+ }
+ self.adjust_sample_level(level, operand_map);
+ operand_map.adjust_option(depth_ref);
+ }
+ Ex::ImageLoad {
+ ref mut image,
+ ref mut coordinate,
+ ref mut array_index,
+ ref mut sample,
+ ref mut level,
+ } => {
+ adjust(image);
+ adjust(coordinate);
+ operand_map.adjust_option(array_index);
+ operand_map.adjust_option(sample);
+ operand_map.adjust_option(level);
+ }
+ Ex::ImageQuery {
+ ref mut image,
+ ref mut query,
+ } => {
+ adjust(image);
+ self.adjust_image_query(query, operand_map);
+ }
+ Ex::Unary {
+ op: _,
+ ref mut expr,
+ } => adjust(expr),
+ Ex::Binary {
+ op: _,
+ ref mut left,
+ ref mut right,
+ } => {
+ adjust(left);
+ adjust(right);
+ }
+ Ex::Select {
+ ref mut condition,
+ ref mut accept,
+ ref mut reject,
+ } => {
+ adjust(condition);
+ adjust(accept);
+ adjust(reject);
+ }
+ Ex::Derivative {
+ axis: _,
+ ctrl: _,
+ ref mut expr,
+ } => adjust(expr),
+ Ex::Relational {
+ fun: _,
+ ref mut argument,
+ } => adjust(argument),
+ Ex::Math {
+ fun: _,
+ ref mut arg,
+ ref mut arg1,
+ ref mut arg2,
+ ref mut arg3,
+ } => {
+ adjust(arg);
+ operand_map.adjust_option(arg1);
+ operand_map.adjust_option(arg2);
+ operand_map.adjust_option(arg3);
+ }
+ Ex::As {
+ ref mut expr,
+ kind: _,
+ convert: _,
+ } => adjust(expr),
+ Ex::AtomicResult {
+ ref mut ty,
+ comparison: _,
+ } => self.types.adjust(ty),
+ Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty),
+ Ex::ArrayLength(ref mut expr) => adjust(expr),
+ Ex::RayQueryGetIntersection {
+ ref mut query,
+ committed: _,
+ } => adjust(query),
+ }
+ }
+
+ fn adjust_sample_level(
+ &self,
+ level: &mut crate::SampleLevel,
+ operand_map: &HandleMap<crate::Expression>,
+ ) {
+ let adjust = |expr: &mut Handle<crate::Expression>| operand_map.adjust(expr);
+
+ use crate::SampleLevel as Sl;
+ match *level {
+ Sl::Auto | Sl::Zero => {}
+ Sl::Exact(ref mut expr) => adjust(expr),
+ Sl::Bias(ref mut expr) => adjust(expr),
+ Sl::Gradient {
+ ref mut x,
+ ref mut y,
+ } => {
+ adjust(x);
+ adjust(y);
+ }
+ }
+ }
+
+ fn adjust_image_query(
+ &self,
+ query: &mut crate::ImageQuery,
+ operand_map: &HandleMap<crate::Expression>,
+ ) {
+ use crate::ImageQuery as Iq;
+
+ match *query {
+ Iq::Size { ref mut level } => operand_map.adjust_option(level),
+ Iq::NumLevels | Iq::NumLayers | Iq::NumSamples => {}
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/compact/functions.rs b/third_party/rust/naga/src/compact/functions.rs
new file mode 100644
index 0000000000..b0d08c7e96
--- /dev/null
+++ b/third_party/rust/naga/src/compact/functions.rs
@@ -0,0 +1,106 @@
+use super::handle_set_map::HandleSet;
+use super::{FunctionMap, ModuleMap};
+
+pub struct FunctionTracer<'a> {
+ pub function: &'a crate::Function,
+ pub constants: &'a crate::Arena<crate::Constant>,
+
+ pub types_used: &'a mut HandleSet<crate::Type>,
+ pub constants_used: &'a mut HandleSet<crate::Constant>,
+ pub const_expressions_used: &'a mut HandleSet<crate::Expression>,
+
+ /// Function-local expressions used.
+ pub expressions_used: HandleSet<crate::Expression>,
+}
+
+impl<'a> FunctionTracer<'a> {
+ pub fn trace(&mut self) {
+ for argument in self.function.arguments.iter() {
+ self.types_used.insert(argument.ty);
+ }
+
+ if let Some(ref result) = self.function.result {
+ self.types_used.insert(result.ty);
+ }
+
+ for (_, local) in self.function.local_variables.iter() {
+ self.types_used.insert(local.ty);
+ if let Some(init) = local.init {
+ self.expressions_used.insert(init);
+ }
+ }
+
+ // Treat named expressions as alive, for the sake of our test suite,
+ // which uses `let blah = expr;` to exercise lots of things.
+ for (&value, _name) in &self.function.named_expressions {
+ self.expressions_used.insert(value);
+ }
+
+ self.trace_block(&self.function.body);
+
+ // Given that `trace_block` has marked the expressions used
+ // directly by statements, walk the arena to find all
+ // expressions used, directly or indirectly.
+ self.as_expression().trace_expressions();
+ }
+
+ fn as_expression(&mut self) -> super::expressions::ExpressionTracer {
+ super::expressions::ExpressionTracer {
+ constants: self.constants,
+ expressions: &self.function.expressions,
+
+ types_used: self.types_used,
+ constants_used: self.constants_used,
+ expressions_used: &mut self.expressions_used,
+ const_expressions_used: Some(&mut self.const_expressions_used),
+ }
+ }
+}
+
+impl FunctionMap {
+ pub fn compact(
+ &self,
+ function: &mut crate::Function,
+ module_map: &ModuleMap,
+ reuse: &mut crate::NamedExpressions,
+ ) {
+ assert!(reuse.is_empty());
+
+ for argument in function.arguments.iter_mut() {
+ module_map.types.adjust(&mut argument.ty);
+ }
+
+ if let Some(ref mut result) = function.result {
+ module_map.types.adjust(&mut result.ty);
+ }
+
+ for (_, local) in function.local_variables.iter_mut() {
+ log::trace!("adjusting local variable {:?}", local.name);
+ module_map.types.adjust(&mut local.ty);
+ if let Some(ref mut init) = local.init {
+ self.expressions.adjust(init);
+ }
+ }
+
+ // Drop unused expressions, reusing existing storage.
+ function.expressions.retain_mut(|handle, expr| {
+ if self.expressions.used(handle) {
+ module_map.adjust_expression(expr, &self.expressions);
+ true
+ } else {
+ false
+ }
+ });
+
+ // Adjust named expressions.
+ for (mut handle, name) in function.named_expressions.drain(..) {
+ self.expressions.adjust(&mut handle);
+ reuse.insert(handle, name);
+ }
+ std::mem::swap(&mut function.named_expressions, reuse);
+ assert!(reuse.is_empty());
+
+ // Adjust statements.
+ self.adjust_body(function);
+ }
+}
diff --git a/third_party/rust/naga/src/compact/handle_set_map.rs b/third_party/rust/naga/src/compact/handle_set_map.rs
new file mode 100644
index 0000000000..c716ca8294
--- /dev/null
+++ b/third_party/rust/naga/src/compact/handle_set_map.rs
@@ -0,0 +1,170 @@
+use crate::arena::{Arena, Handle, Range, UniqueArena};
+
+type Index = std::num::NonZeroU32;
+
+/// A set of `Handle<T>` values.
+pub struct HandleSet<T> {
+ /// Bound on zero-based indexes of handles stored in this set.
+ len: usize,
+
+ /// `members[i]` is true if the handle with zero-based index `i`
+ /// is a member.
+ members: bit_set::BitSet,
+
+ /// This type is indexed by values of type `T`.
+ as_keys: std::marker::PhantomData<T>,
+}
+
+impl<T> HandleSet<T> {
+ pub fn for_arena(arena: &impl ArenaType<T>) -> Self {
+ let len = arena.len();
+ Self {
+ len,
+ members: bit_set::BitSet::with_capacity(len),
+ as_keys: std::marker::PhantomData,
+ }
+ }
+
+ /// Add `handle` to the set.
+ pub fn insert(&mut self, handle: Handle<T>) {
+ // Note that, oddly, `Handle::index` does not return a 1-based
+ // `Index`, but rather a zero-based `usize`.
+ self.members.insert(handle.index());
+ }
+
+ /// Add handles from `iter` to the set.
+ pub fn insert_iter(&mut self, iter: impl IntoIterator<Item = Handle<T>>) {
+ for handle in iter {
+ self.insert(handle);
+ }
+ }
+
+ pub fn contains(&self, handle: Handle<T>) -> bool {
+ // Note that, oddly, `Handle::index` does not return a 1-based
+ // `Index`, but rather a zero-based `usize`.
+ self.members.contains(handle.index())
+ }
+}
+
+pub trait ArenaType<T> {
+ fn len(&self) -> usize;
+}
+
+impl<T> ArenaType<T> for Arena<T> {
+ fn len(&self) -> usize {
+ self.len()
+ }
+}
+
+impl<T: std::hash::Hash + Eq> ArenaType<T> for UniqueArena<T> {
+ fn len(&self) -> usize {
+ self.len()
+ }
+}
+
+/// A map from old handle indices to new, compressed handle indices.
+pub struct HandleMap<T> {
+ /// The indices assigned to handles in the compacted module.
+ ///
+ /// If `new_index[i]` is `Some(n)`, then `n` is the 1-based
+ /// `Index` of the compacted `Handle` corresponding to the
+ /// pre-compacted `Handle` whose zero-based index is `i`. ("Clear
+ /// as mud.")
+ new_index: Vec<Option<Index>>,
+
+ /// This type is indexed by values of type `T`.
+ as_keys: std::marker::PhantomData<T>,
+}
+
+impl<T: 'static> HandleMap<T> {
+ pub fn from_set(set: HandleSet<T>) -> Self {
+ let mut next_index = Index::new(1).unwrap();
+ Self {
+ new_index: (0..set.len)
+ .map(|zero_based_index| {
+ if set.members.contains(zero_based_index) {
+ // This handle will be retained in the compacted version,
+ // so assign it a new index.
+ let this = next_index;
+ next_index = next_index.checked_add(1).unwrap();
+ Some(this)
+ } else {
+ // This handle will be omitted in the compacted version.
+ None
+ }
+ })
+ .collect(),
+ as_keys: std::marker::PhantomData,
+ }
+ }
+
+ /// Return true if `old` is used in the compacted module.
+ pub fn used(&self, old: Handle<T>) -> bool {
+ self.new_index[old.index()].is_some()
+ }
+
+ /// Return the counterpart to `old` in the compacted module.
+ ///
+ /// If we thought `old` wouldn't be used in the compacted module, return
+ /// `None`.
+ pub fn try_adjust(&self, old: Handle<T>) -> Option<Handle<T>> {
+ log::trace!(
+ "adjusting {} handle [{}] -> [{:?}]",
+ std::any::type_name::<T>(),
+ old.index() + 1,
+ self.new_index[old.index()]
+ );
+ // Note that `Handle::index` returns a zero-based index,
+ // but `Handle::new` accepts a 1-based `Index`.
+ self.new_index[old.index()].map(Handle::new)
+ }
+
+ /// Return the counterpart to `old` in the compacted module.
+ ///
+ /// If we thought `old` wouldn't be used in the compacted module, panic.
+ pub fn adjust(&self, handle: &mut Handle<T>) {
+ *handle = self.try_adjust(*handle).unwrap();
+ }
+
+ /// Like `adjust`, but for optional handles.
+ pub fn adjust_option(&self, handle: &mut Option<Handle<T>>) {
+ if let Some(ref mut handle) = *handle {
+ self.adjust(handle);
+ }
+ }
+
+ /// Shrink `range` to include only used handles.
+ ///
+ /// Fortunately, compaction doesn't arbitrarily scramble the expressions
+ /// in the arena, but instead preserves the order of the elements while
+ /// squeezing out unused ones. That means that a contiguous range in the
+ /// pre-compacted arena always maps to a contiguous range in the
+ /// post-compacted arena. So we just need to adjust the endpoints.
+ ///
+ /// Compaction may have eliminated the endpoints themselves.
+ ///
+ /// Use `compacted_arena` to bounds-check the result.
+ pub fn adjust_range(&self, range: &mut Range<T>, compacted_arena: &Arena<T>) {
+ let mut index_range = range.zero_based_index_range();
+ let compacted;
+ // Remember that the indices we retrieve from `new_index` are 1-based
+ // compacted indices, but the index range we're computing is zero-based
+ // compacted indices.
+ if let Some(first1) = index_range.find_map(|i| self.new_index[i as usize]) {
+ // The first call to `find_map` mutated `index_range` to hold the
+ // remainder of original range, which is exactly the range we need
+ // to search for the new last handle.
+ if let Some(last1) = index_range.rev().find_map(|i| self.new_index[i as usize]) {
+ // Build a zero-based end-exclusive range, given one-based handle indices.
+ compacted = first1.get() - 1..last1.get();
+ } else {
+ // The range contains only a single live handle, which
+ // we identified with the first `find_map` call.
+ compacted = first1.get() - 1..first1.get();
+ }
+ } else {
+ compacted = 0..0;
+ };
+ *range = Range::from_zero_based_index_range(compacted, compacted_arena);
+ }
+}
diff --git a/third_party/rust/naga/src/compact/mod.rs b/third_party/rust/naga/src/compact/mod.rs
new file mode 100644
index 0000000000..b4e57ed5c9
--- /dev/null
+++ b/third_party/rust/naga/src/compact/mod.rs
@@ -0,0 +1,307 @@
+mod expressions;
+mod functions;
+mod handle_set_map;
+mod statements;
+mod types;
+
+use crate::{arena, compact::functions::FunctionTracer};
+use handle_set_map::{HandleMap, HandleSet};
+
+/// Remove unused types, expressions, and constants from `module`.
+///
+/// Assuming that all globals, named constants, special types,
+/// functions and entry points in `module` are used, determine which
+/// types, constants, and expressions (both function-local and global
+/// constant expressions) are actually used, and remove the rest,
+/// adjusting all handles as necessary. The result should be a module
+/// functionally identical to the original.
+///
+/// This may be useful to apply to modules generated in the snapshot
+/// tests. Our backends often generate temporary names based on handle
+/// indices, which means that adding or removing unused arena entries
+/// can affect the output even though they have no semantic effect.
+/// Such meaningless changes add noise to snapshot diffs, making
+/// accurate patch review difficult. Compacting the modules before
+/// generating snapshots makes the output independent of unused arena
+/// entries.
+///
+/// # Panics
+///
+/// If `module` has not passed validation, this may panic.
+pub fn compact(module: &mut crate::Module) {
+ let mut module_tracer = ModuleTracer::new(module);
+
+ // We treat all globals as used by definition.
+ log::trace!("tracing global variables");
+ {
+ for (_, global) in module.global_variables.iter() {
+ log::trace!("tracing global {:?}", global.name);
+ module_tracer.types_used.insert(global.ty);
+ if let Some(init) = global.init {
+ module_tracer.const_expressions_used.insert(init);
+ }
+ }
+ }
+
+ // We treat all special types as used by definition.
+ module_tracer.trace_special_types(&module.special_types);
+
+ // We treat all named constants as used by definition.
+ for (handle, constant) in module.constants.iter() {
+ if constant.name.is_some() {
+ module_tracer.constants_used.insert(handle);
+ module_tracer.const_expressions_used.insert(constant.init);
+ }
+ }
+
+ // We assume that all functions are used.
+ //
+ // Observe which types, constant expressions, constants, and
+ // expressions each function uses, and produce maps for each
+ // function from pre-compaction to post-compaction expression
+ // handles.
+ log::trace!("tracing functions");
+ let function_maps: Vec<FunctionMap> = module
+ .functions
+ .iter()
+ .map(|(_, f)| {
+ log::trace!("tracing function {:?}", f.name);
+ let mut function_tracer = module_tracer.as_function(f);
+ function_tracer.trace();
+ FunctionMap::from(function_tracer)
+ })
+ .collect();
+
+ // Similarly, observe what each entry point actually uses.
+ log::trace!("tracing entry points");
+ let entry_point_maps: Vec<FunctionMap> = module
+ .entry_points
+ .iter()
+ .map(|e| {
+ log::trace!("tracing entry point {:?}", e.function.name);
+ let mut used = module_tracer.as_function(&e.function);
+ used.trace();
+ FunctionMap::from(used)
+ })
+ .collect();
+
+ // Given that the above steps have marked all the constant
+ // expressions used directly by globals, constants, functions, and
+ // entry points, walk the constant expression arena to find all
+ // constant expressions used, directly or indirectly.
+ module_tracer.as_const_expression().trace_expressions();
+
+ // Constants' initializers are taken care of already, because
+ // expression tracing sees through constants. But we still need to
+ // note type usage.
+ for (handle, constant) in module.constants.iter() {
+ if module_tracer.constants_used.contains(handle) {
+ module_tracer.types_used.insert(constant.ty);
+ }
+ }
+
+ // Treat all named types as used.
+ for (handle, ty) in module.types.iter() {
+ log::trace!("tracing type {:?}, name {:?}", handle, ty.name);
+ if ty.name.is_some() {
+ module_tracer.types_used.insert(handle);
+ }
+ }
+
+ // Propagate usage through types.
+ module_tracer.as_type().trace_types();
+
+ // Now that we know what is used and what is never touched,
+ // produce maps from the `Handle`s that appear in `module` now to
+ // the corresponding `Handle`s that will refer to the same items
+ // in the compacted module.
+ let module_map = ModuleMap::from(module_tracer);
+
+ // Drop unused types from the type arena.
+ //
+ // `FastIndexSet`s don't have an underlying Vec<T> that we can
+ // steal, compact in place, and then rebuild the `FastIndexSet`
+ // from. So we have to rebuild the type arena from scratch.
+ log::trace!("compacting types");
+ let mut new_types = arena::UniqueArena::new();
+ for (old_handle, mut ty, span) in module.types.drain_all() {
+ if let Some(expected_new_handle) = module_map.types.try_adjust(old_handle) {
+ module_map.adjust_type(&mut ty);
+ let actual_new_handle = new_types.insert(ty, span);
+ assert_eq!(actual_new_handle, expected_new_handle);
+ }
+ }
+ module.types = new_types;
+ log::trace!("adjusting special types");
+ module_map.adjust_special_types(&mut module.special_types);
+
+ // Drop unused constant expressions, reusing existing storage.
+ log::trace!("adjusting constant expressions");
+ module.const_expressions.retain_mut(|handle, expr| {
+ if module_map.const_expressions.used(handle) {
+ module_map.adjust_expression(expr, &module_map.const_expressions);
+ true
+ } else {
+ false
+ }
+ });
+
+ // Drop unused constants in place, reusing existing storage.
+ log::trace!("adjusting constants");
+ module.constants.retain_mut(|handle, constant| {
+ if module_map.constants.used(handle) {
+ module_map.types.adjust(&mut constant.ty);
+ module_map.const_expressions.adjust(&mut constant.init);
+ true
+ } else {
+ false
+ }
+ });
+
+ // Adjust global variables' types and initializers.
+ log::trace!("adjusting global variables");
+ for (_, global) in module.global_variables.iter_mut() {
+ log::trace!("adjusting global {:?}", global.name);
+ module_map.types.adjust(&mut global.ty);
+ if let Some(ref mut init) = global.init {
+ module_map.const_expressions.adjust(init);
+ }
+ }
+
+ // Temporary storage to help us reuse allocations of existing
+ // named expression tables.
+ let mut reused_named_expressions = crate::NamedExpressions::default();
+
+ // Compact each function.
+ for ((_, function), map) in module.functions.iter_mut().zip(function_maps.iter()) {
+ log::trace!("compacting function {:?}", function.name);
+ map.compact(function, &module_map, &mut reused_named_expressions);
+ }
+
+ // Compact each entry point.
+ for (entry, map) in module.entry_points.iter_mut().zip(entry_point_maps.iter()) {
+ log::trace!("compacting entry point {:?}", entry.function.name);
+ map.compact(
+ &mut entry.function,
+ &module_map,
+ &mut reused_named_expressions,
+ );
+ }
+}
+
+struct ModuleTracer<'module> {
+ module: &'module crate::Module,
+ types_used: HandleSet<crate::Type>,
+ constants_used: HandleSet<crate::Constant>,
+ const_expressions_used: HandleSet<crate::Expression>,
+}
+
+impl<'module> ModuleTracer<'module> {
+ fn new(module: &'module crate::Module) -> Self {
+ Self {
+ module,
+ types_used: HandleSet::for_arena(&module.types),
+ constants_used: HandleSet::for_arena(&module.constants),
+ const_expressions_used: HandleSet::for_arena(&module.const_expressions),
+ }
+ }
+
+ fn trace_special_types(&mut self, special_types: &crate::SpecialTypes) {
+ let crate::SpecialTypes {
+ ref ray_desc,
+ ref ray_intersection,
+ ref predeclared_types,
+ } = *special_types;
+
+ if let Some(ray_desc) = *ray_desc {
+ self.types_used.insert(ray_desc);
+ }
+ if let Some(ray_intersection) = *ray_intersection {
+ self.types_used.insert(ray_intersection);
+ }
+ for (_, &handle) in predeclared_types {
+ self.types_used.insert(handle);
+ }
+ }
+
+ fn as_type(&mut self) -> types::TypeTracer {
+ types::TypeTracer {
+ types: &self.module.types,
+ types_used: &mut self.types_used,
+ }
+ }
+
+ fn as_const_expression(&mut self) -> expressions::ExpressionTracer {
+ expressions::ExpressionTracer {
+ expressions: &self.module.const_expressions,
+ constants: &self.module.constants,
+ types_used: &mut self.types_used,
+ constants_used: &mut self.constants_used,
+ expressions_used: &mut self.const_expressions_used,
+ const_expressions_used: None,
+ }
+ }
+
+ pub fn as_function<'tracer>(
+ &'tracer mut self,
+ function: &'tracer crate::Function,
+ ) -> FunctionTracer<'tracer> {
+ FunctionTracer {
+ function,
+ constants: &self.module.constants,
+ types_used: &mut self.types_used,
+ constants_used: &mut self.constants_used,
+ const_expressions_used: &mut self.const_expressions_used,
+ expressions_used: HandleSet::for_arena(&function.expressions),
+ }
+ }
+}
+
+struct ModuleMap {
+ types: HandleMap<crate::Type>,
+ constants: HandleMap<crate::Constant>,
+ const_expressions: HandleMap<crate::Expression>,
+}
+
+impl From<ModuleTracer<'_>> for ModuleMap {
+ fn from(used: ModuleTracer) -> Self {
+ ModuleMap {
+ types: HandleMap::from_set(used.types_used),
+ constants: HandleMap::from_set(used.constants_used),
+ const_expressions: HandleMap::from_set(used.const_expressions_used),
+ }
+ }
+}
+
+impl ModuleMap {
+ fn adjust_special_types(&self, special: &mut crate::SpecialTypes) {
+ let crate::SpecialTypes {
+ ref mut ray_desc,
+ ref mut ray_intersection,
+ ref mut predeclared_types,
+ } = *special;
+
+ if let Some(ref mut ray_desc) = *ray_desc {
+ self.types.adjust(ray_desc);
+ }
+ if let Some(ref mut ray_intersection) = *ray_intersection {
+ self.types.adjust(ray_intersection);
+ }
+
+ for handle in predeclared_types.values_mut() {
+ self.types.adjust(handle);
+ }
+ }
+}
+
+struct FunctionMap {
+ expressions: HandleMap<crate::Expression>,
+}
+
+impl From<FunctionTracer<'_>> for FunctionMap {
+ fn from(used: FunctionTracer) -> Self {
+ FunctionMap {
+ expressions: HandleMap::from_set(used.expressions_used),
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/compact/statements.rs b/third_party/rust/naga/src/compact/statements.rs
new file mode 100644
index 0000000000..0698b57258
--- /dev/null
+++ b/third_party/rust/naga/src/compact/statements.rs
@@ -0,0 +1,300 @@
+use super::functions::FunctionTracer;
+use super::FunctionMap;
+use crate::arena::Handle;
+
+impl FunctionTracer<'_> {
+ pub fn trace_block(&mut self, block: &[crate::Statement]) {
+ let mut worklist: Vec<&[crate::Statement]> = vec![block];
+ while let Some(last) = worklist.pop() {
+ for stmt in last {
+ use crate::Statement as St;
+ match *stmt {
+ St::Emit(ref _range) => {
+ // If we come across a statement that actually uses an
+ // expression in this range, it'll get traced from
+ // there. But since evaluating expressions has no
+ // effect, we don't need to assume that everything
+ // emitted is live.
+ }
+ St::Block(ref block) => worklist.push(block),
+ St::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ self.expressions_used.insert(condition);
+ worklist.push(accept);
+ worklist.push(reject);
+ }
+ St::Switch {
+ selector,
+ ref cases,
+ } => {
+ self.expressions_used.insert(selector);
+ for case in cases {
+ worklist.push(&case.body);
+ }
+ }
+ St::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ if let Some(break_if) = break_if {
+ self.expressions_used.insert(break_if);
+ }
+ worklist.push(body);
+ worklist.push(continuing);
+ }
+ St::Return { value: Some(value) } => {
+ self.expressions_used.insert(value);
+ }
+ St::Store { pointer, value } => {
+ self.expressions_used.insert(pointer);
+ self.expressions_used.insert(value);
+ }
+ St::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ self.expressions_used.insert(image);
+ self.expressions_used.insert(coordinate);
+ if let Some(array_index) = array_index {
+ self.expressions_used.insert(array_index);
+ }
+ self.expressions_used.insert(value);
+ }
+ St::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ self.expressions_used.insert(pointer);
+ self.trace_atomic_function(fun);
+ self.expressions_used.insert(value);
+ self.expressions_used.insert(result);
+ }
+ St::WorkGroupUniformLoad { pointer, result } => {
+ self.expressions_used.insert(pointer);
+ self.expressions_used.insert(result);
+ }
+ St::Call {
+ function: _,
+ ref arguments,
+ result,
+ } => {
+ for expr in arguments {
+ self.expressions_used.insert(*expr);
+ }
+ if let Some(result) = result {
+ self.expressions_used.insert(result);
+ }
+ }
+ St::RayQuery { query, ref fun } => {
+ self.expressions_used.insert(query);
+ self.trace_ray_query_function(fun);
+ }
+
+ // Trivial statements.
+ St::Break
+ | St::Continue
+ | St::Kill
+ | St::Barrier(_)
+ | St::Return { value: None } => {}
+ }
+ }
+ }
+ }
+
+ fn trace_atomic_function(&mut self, fun: &crate::AtomicFunction) {
+ use crate::AtomicFunction as Af;
+ match *fun {
+ Af::Exchange {
+ compare: Some(expr),
+ } => {
+ self.expressions_used.insert(expr);
+ }
+ Af::Exchange { compare: None }
+ | Af::Add
+ | Af::Subtract
+ | Af::And
+ | Af::ExclusiveOr
+ | Af::InclusiveOr
+ | Af::Min
+ | Af::Max => {}
+ }
+ }
+
+ fn trace_ray_query_function(&mut self, fun: &crate::RayQueryFunction) {
+ use crate::RayQueryFunction as Qf;
+ match *fun {
+ Qf::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ self.expressions_used.insert(acceleration_structure);
+ self.expressions_used.insert(descriptor);
+ }
+ Qf::Proceed { result } => {
+ self.expressions_used.insert(result);
+ }
+ Qf::Terminate => {}
+ }
+ }
+}
+
+impl FunctionMap {
+ pub fn adjust_body(&self, function: &mut crate::Function) {
+ let block = &mut function.body;
+ let mut worklist: Vec<&mut [crate::Statement]> = vec![block];
+ let adjust = |handle: &mut Handle<crate::Expression>| {
+ self.expressions.adjust(handle);
+ };
+ while let Some(last) = worklist.pop() {
+ for stmt in last {
+ use crate::Statement as St;
+ match *stmt {
+ St::Emit(ref mut range) => {
+ self.expressions.adjust_range(range, &function.expressions);
+ }
+ St::Block(ref mut block) => worklist.push(block),
+ St::If {
+ ref mut condition,
+ ref mut accept,
+ ref mut reject,
+ } => {
+ adjust(condition);
+ worklist.push(accept);
+ worklist.push(reject);
+ }
+ St::Switch {
+ ref mut selector,
+ ref mut cases,
+ } => {
+ adjust(selector);
+ for case in cases {
+ worklist.push(&mut case.body);
+ }
+ }
+ St::Loop {
+ ref mut body,
+ ref mut continuing,
+ ref mut break_if,
+ } => {
+ if let Some(ref mut break_if) = *break_if {
+ adjust(break_if);
+ }
+ worklist.push(body);
+ worklist.push(continuing);
+ }
+ St::Return {
+ value: Some(ref mut value),
+ } => adjust(value),
+ St::Store {
+ ref mut pointer,
+ ref mut value,
+ } => {
+ adjust(pointer);
+ adjust(value);
+ }
+ St::ImageStore {
+ ref mut image,
+ ref mut coordinate,
+ ref mut array_index,
+ ref mut value,
+ } => {
+ adjust(image);
+ adjust(coordinate);
+ if let Some(ref mut array_index) = *array_index {
+ adjust(array_index);
+ }
+ adjust(value);
+ }
+ St::Atomic {
+ ref mut pointer,
+ ref mut fun,
+ ref mut value,
+ ref mut result,
+ } => {
+ adjust(pointer);
+ self.adjust_atomic_function(fun);
+ adjust(value);
+ adjust(result);
+ }
+ St::WorkGroupUniformLoad {
+ ref mut pointer,
+ ref mut result,
+ } => {
+ adjust(pointer);
+ adjust(result);
+ }
+ St::Call {
+ function: _,
+ ref mut arguments,
+ ref mut result,
+ } => {
+ for expr in arguments {
+ adjust(expr);
+ }
+ if let Some(ref mut result) = *result {
+ adjust(result);
+ }
+ }
+ St::RayQuery {
+ ref mut query,
+ ref mut fun,
+ } => {
+ adjust(query);
+ self.adjust_ray_query_function(fun);
+ }
+
+ // Trivial statements.
+ St::Break
+ | St::Continue
+ | St::Kill
+ | St::Barrier(_)
+ | St::Return { value: None } => {}
+ }
+ }
+ }
+ }
+
+ fn adjust_atomic_function(&self, fun: &mut crate::AtomicFunction) {
+ use crate::AtomicFunction as Af;
+ match *fun {
+ Af::Exchange {
+ compare: Some(ref mut expr),
+ } => {
+ self.expressions.adjust(expr);
+ }
+ Af::Exchange { compare: None }
+ | Af::Add
+ | Af::Subtract
+ | Af::And
+ | Af::ExclusiveOr
+ | Af::InclusiveOr
+ | Af::Min
+ | Af::Max => {}
+ }
+ }
+
+ fn adjust_ray_query_function(&self, fun: &mut crate::RayQueryFunction) {
+ use crate::RayQueryFunction as Qf;
+ match *fun {
+ Qf::Initialize {
+ ref mut acceleration_structure,
+ ref mut descriptor,
+ } => {
+ self.expressions.adjust(acceleration_structure);
+ self.expressions.adjust(descriptor);
+ }
+ Qf::Proceed { ref mut result } => {
+ self.expressions.adjust(result);
+ }
+ Qf::Terminate => {}
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/compact/types.rs b/third_party/rust/naga/src/compact/types.rs
new file mode 100644
index 0000000000..b78619d9a8
--- /dev/null
+++ b/third_party/rust/naga/src/compact/types.rs
@@ -0,0 +1,102 @@
+use super::{HandleSet, ModuleMap};
+use crate::{Handle, UniqueArena};
+
+pub struct TypeTracer<'a> {
+ pub types: &'a UniqueArena<crate::Type>,
+ pub types_used: &'a mut HandleSet<crate::Type>,
+}
+
+impl<'a> TypeTracer<'a> {
+ /// Propagate usage through `self.types`, starting with `self.types_used`.
+ ///
+ /// Treat `self.types_used` as the initial set of "known
+ /// live" types, and follow through to identify all
+ /// transitively used types.
+ pub fn trace_types(&mut self) {
+ // We don't need recursion or a work list. Because an
+ // expression may only refer to other expressions that precede
+ // it in the arena, it suffices to make a single pass over the
+ // arena from back to front, marking the referents of used
+ // expressions as used themselves.
+ for (handle, ty) in self.types.iter().rev() {
+ // If this type isn't used, it doesn't matter what it uses.
+ if !self.types_used.contains(handle) {
+ continue;
+ }
+
+ use crate::TypeInner as Ti;
+ match ty.inner {
+ // Types that do not contain handles.
+ Ti::Scalar { .. }
+ | Ti::Vector { .. }
+ | Ti::Matrix { .. }
+ | Ti::Atomic { .. }
+ | Ti::ValuePointer { .. }
+ | Ti::Image { .. }
+ | Ti::Sampler { .. }
+ | Ti::AccelerationStructure
+ | Ti::RayQuery => {}
+
+ // Types that do contain handles.
+ Ti::Pointer { base, space: _ }
+ | Ti::Array {
+ base,
+ size: _,
+ stride: _,
+ }
+ | Ti::BindingArray { base, size: _ } => self.types_used.insert(base),
+ Ti::Struct {
+ ref members,
+ span: _,
+ } => {
+ self.types_used.insert_iter(members.iter().map(|m| m.ty));
+ }
+ }
+ }
+ }
+}
+
+impl ModuleMap {
+ pub fn adjust_type(&self, ty: &mut crate::Type) {
+ let adjust = |ty: &mut Handle<crate::Type>| self.types.adjust(ty);
+
+ use crate::TypeInner as Ti;
+ match ty.inner {
+ // Types that do not contain handles.
+ Ti::Scalar(_)
+ | Ti::Vector { .. }
+ | Ti::Matrix { .. }
+ | Ti::Atomic(_)
+ | Ti::ValuePointer { .. }
+ | Ti::Image { .. }
+ | Ti::Sampler { .. }
+ | Ti::AccelerationStructure
+ | Ti::RayQuery => {}
+
+ // Types that do contain handles.
+ Ti::Pointer {
+ ref mut base,
+ space: _,
+ } => adjust(base),
+ Ti::Array {
+ ref mut base,
+ size: _,
+ stride: _,
+ } => adjust(base),
+ Ti::Struct {
+ ref mut members,
+ span: _,
+ } => {
+ for member in members {
+ self.types.adjust(&mut member.ty);
+ }
+ }
+ Ti::BindingArray {
+ ref mut base,
+ size: _,
+ } => {
+ adjust(base);
+ }
+ };
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/ast.rs b/third_party/rust/naga/src/front/glsl/ast.rs
new file mode 100644
index 0000000000..96b676dd6d
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/ast.rs
@@ -0,0 +1,394 @@
+use std::{borrow::Cow, fmt};
+
+use super::{builtins::MacroCall, context::ExprPos, Span};
+use crate::{
+ AddressSpace, BinaryOperator, Binding, Constant, Expression, Function, GlobalVariable, Handle,
+ Interpolation, Literal, Sampling, StorageAccess, Type, UnaryOperator,
+};
+
+#[derive(Debug, Clone, Copy)]
+pub enum GlobalLookupKind {
+ Variable(Handle<GlobalVariable>),
+ Constant(Handle<Constant>, Handle<Type>),
+ BlockSelect(Handle<GlobalVariable>, u32),
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct GlobalLookup {
+ pub kind: GlobalLookupKind,
+ pub entry_arg: Option<usize>,
+ pub mutable: bool,
+}
+
+#[derive(Debug, Clone)]
+pub struct ParameterInfo {
+ pub qualifier: ParameterQualifier,
+ /// Whether the parameter should be treated as a depth image instead of a
+ /// sampled image.
+ pub depth: bool,
+}
+
+/// How the function is implemented
+#[derive(Clone, Copy)]
+pub enum FunctionKind {
+ /// The function is user defined
+ Call(Handle<Function>),
+ /// The function is a builtin
+ Macro(MacroCall),
+}
+
+impl fmt::Debug for FunctionKind {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match *self {
+ Self::Call(_) => write!(f, "Call"),
+ Self::Macro(_) => write!(f, "Macro"),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Overload {
+ /// Normalized function parameters, modifiers are not applied
+ pub parameters: Vec<Handle<Type>>,
+ pub parameters_info: Vec<ParameterInfo>,
+ /// How the function is implemented
+ pub kind: FunctionKind,
+ /// Whether this function was already defined or is just a prototype
+ pub defined: bool,
+ /// Whether this overload is the one provided by the language or has
+ /// been redeclared by the user (builtins only)
+ pub internal: bool,
+ /// Whether or not this function returns void (nothing)
+ pub void: bool,
+}
+
+bitflags::bitflags! {
+ /// Tracks the variations of the builtin already generated, this is needed because some
+ /// builtins overloads can't be generated unless explicitly used, since they might cause
+ /// unneeded capabilities to be requested
+ #[derive(Default)]
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct BuiltinVariations: u32 {
+ /// Request the standard overloads
+ const STANDARD = 1 << 0;
+ /// Request overloads that use the double type
+ const DOUBLE = 1 << 1;
+ /// Request overloads that use samplerCubeArray(Shadow)
+ const CUBE_TEXTURES_ARRAY = 1 << 2;
+ /// Request overloads that use sampler2DMSArray
+ const D2_MULTI_TEXTURES_ARRAY = 1 << 3;
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct FunctionDeclaration {
+ pub overloads: Vec<Overload>,
+ /// Tracks the builtin overload variations that were already generated
+ pub variations: BuiltinVariations,
+}
+
+#[derive(Debug)]
+pub struct EntryArg {
+ pub name: Option<String>,
+ pub binding: Binding,
+ pub handle: Handle<GlobalVariable>,
+ pub storage: StorageQualifier,
+}
+
+#[derive(Debug, Clone)]
+pub struct VariableReference {
+ pub expr: Handle<Expression>,
+ /// Whether the variable is of a pointer type (and needs loading) or not
+ pub load: bool,
+ /// Whether the value of the variable can be changed or not
+ pub mutable: bool,
+ pub constant: Option<(Handle<Constant>, Handle<Type>)>,
+ pub entry_arg: Option<usize>,
+}
+
+#[derive(Debug, Clone)]
+pub struct HirExpr {
+ pub kind: HirExprKind,
+ pub meta: Span,
+}
+
+#[derive(Debug, Clone)]
+pub enum HirExprKind {
+ Access {
+ base: Handle<HirExpr>,
+ index: Handle<HirExpr>,
+ },
+ Select {
+ base: Handle<HirExpr>,
+ field: String,
+ },
+ Literal(Literal),
+ Binary {
+ left: Handle<HirExpr>,
+ op: BinaryOperator,
+ right: Handle<HirExpr>,
+ },
+ Unary {
+ op: UnaryOperator,
+ expr: Handle<HirExpr>,
+ },
+ Variable(VariableReference),
+ Call(FunctionCall),
+ /// Represents the ternary operator in glsl (`:?`)
+ Conditional {
+ /// The expression that will decide which branch to take, must evaluate to a boolean
+ condition: Handle<HirExpr>,
+ /// The expression that will be evaluated if [`condition`] returns `true`
+ ///
+ /// [`condition`]: Self::Conditional::condition
+ accept: Handle<HirExpr>,
+ /// The expression that will be evaluated if [`condition`] returns `false`
+ ///
+ /// [`condition`]: Self::Conditional::condition
+ reject: Handle<HirExpr>,
+ },
+ Assign {
+ tgt: Handle<HirExpr>,
+ value: Handle<HirExpr>,
+ },
+ /// A prefix/postfix operator like `++`
+ PrePostfix {
+ /// The operation to be performed
+ op: BinaryOperator,
+ /// Whether this is a postfix or a prefix
+ postfix: bool,
+ /// The target expression
+ expr: Handle<HirExpr>,
+ },
+ /// A method call like `what.something(a, b, c)`
+ Method {
+ /// expression the method call applies to (`what` in the example)
+ expr: Handle<HirExpr>,
+ /// the method name (`something` in the example)
+ name: String,
+ /// the arguments to the method (`a`, `b`, and `c` in the example)
+ args: Vec<Handle<HirExpr>>,
+ },
+}
+
+#[derive(Debug, Hash, PartialEq, Eq)]
+pub enum QualifierKey<'a> {
+ String(Cow<'a, str>),
+ /// Used for `std140` and `std430` layout qualifiers
+ Layout,
+ /// Used for image formats
+ Format,
+}
+
+#[derive(Debug)]
+pub enum QualifierValue {
+ None,
+ Uint(u32),
+ Layout(StructLayout),
+ Format(crate::StorageFormat),
+}
+
+#[derive(Debug, Default)]
+pub struct TypeQualifiers<'a> {
+ pub span: Span,
+ pub storage: (StorageQualifier, Span),
+ pub invariant: Option<Span>,
+ pub interpolation: Option<(Interpolation, Span)>,
+ pub precision: Option<(Precision, Span)>,
+ pub sampling: Option<(Sampling, Span)>,
+ /// Memory qualifiers used in the declaration to set the storage access to be used
+ /// in declarations that support it (storage images and buffers)
+ pub storage_access: Option<(StorageAccess, Span)>,
+ pub layout_qualifiers: crate::FastHashMap<QualifierKey<'a>, (QualifierValue, Span)>,
+}
+
+impl<'a> TypeQualifiers<'a> {
+ /// Appends `errors` with errors for all unused qualifiers
+ pub fn unused_errors(&self, errors: &mut Vec<super::Error>) {
+ if let Some(meta) = self.invariant {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Invariant qualifier can only be used in in/out variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ if let Some((_, meta)) = self.interpolation {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Interpolation qualifiers can only be used in in/out variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ if let Some((_, meta)) = self.sampling {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Sampling qualifiers can only be used in in/out variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ if let Some((_, meta)) = self.storage_access {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Memory qualifiers can only be used in storage variables".into(),
+ ),
+ meta,
+ });
+ }
+
+ for &(_, meta) in self.layout_qualifiers.values() {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError("Unexpected qualifier".into()),
+ meta,
+ });
+ }
+ }
+
+ /// Removes the layout qualifier with `name`, if it exists and adds an error if it isn't
+ /// a [`QualifierValue::Uint`]
+ pub fn uint_layout_qualifier(
+ &mut self,
+ name: &'a str,
+ errors: &mut Vec<super::Error>,
+ ) -> Option<u32> {
+ match self
+ .layout_qualifiers
+ .remove(&QualifierKey::String(name.into()))
+ {
+ Some((QualifierValue::Uint(v), _)) => Some(v),
+ Some((_, meta)) => {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError("Qualifier expects a uint value".into()),
+ meta,
+ });
+ // Return a dummy value instead of `None` to differentiate from
+ // the qualifier not existing, since some parts might require the
+ // qualifier to exist and throwing another error that it doesn't
+ // exist would be unhelpful
+ Some(0)
+ }
+ _ => None,
+ }
+ }
+
+ /// Removes the layout qualifier with `name`, if it exists and adds an error if it isn't
+ /// a [`QualifierValue::None`]
+ pub fn none_layout_qualifier(&mut self, name: &'a str, errors: &mut Vec<super::Error>) -> bool {
+ match self
+ .layout_qualifiers
+ .remove(&QualifierKey::String(name.into()))
+ {
+ Some((QualifierValue::None, _)) => true,
+ Some((_, meta)) => {
+ errors.push(super::Error {
+ kind: super::ErrorKind::SemanticError(
+ "Qualifier doesn't expect a value".into(),
+ ),
+ meta,
+ });
+ // Return a `true` to since the qualifier is defined and adding
+ // another error for it not being defined would be unhelpful
+ true
+ }
+ _ => false,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub enum FunctionCallKind {
+ TypeConstructor(Handle<Type>),
+ Function(String),
+}
+
+#[derive(Debug, Clone)]
+pub struct FunctionCall {
+ pub kind: FunctionCallKind,
+ pub args: Vec<Handle<HirExpr>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum StorageQualifier {
+ AddressSpace(AddressSpace),
+ Input,
+ Output,
+ Const,
+}
+
+impl Default for StorageQualifier {
+ fn default() -> Self {
+ StorageQualifier::AddressSpace(AddressSpace::Function)
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum StructLayout {
+ Std140,
+ Std430,
+}
+
+// TODO: Encode precision hints in the IR
+/// A precision hint used in GLSL declarations.
+///
+/// Precision hints can be used to either speed up shader execution or control
+/// the precision of arithmetic operations.
+///
+/// To use a precision hint simply add it before the type in the declaration.
+/// ```glsl
+/// mediump float a;
+/// ```
+///
+/// The default when no precision is declared is `highp` which means that all
+/// operations operate with the type defined width.
+///
+/// For `mediump` and `lowp` operations follow the spir-v
+/// [`RelaxedPrecision`][RelaxedPrecision] decoration semantics.
+///
+/// [RelaxedPrecision]: https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#_a_id_relaxedprecisionsection_a_relaxed_precision
+#[derive(Debug, Clone, PartialEq, Copy)]
+pub enum Precision {
+ /// `lowp` precision
+ Low,
+ /// `mediump` precision
+ Medium,
+ /// `highp` precision
+ High,
+}
+
+#[derive(Debug, Clone, PartialEq, Copy)]
+pub enum ParameterQualifier {
+ In,
+ Out,
+ InOut,
+ Const,
+}
+
+impl ParameterQualifier {
+ /// Returns true if the argument should be passed as a lhs expression
+ pub const fn is_lhs(&self) -> bool {
+ match *self {
+ ParameterQualifier::Out | ParameterQualifier::InOut => true,
+ _ => false,
+ }
+ }
+
+ /// Converts from a parameter qualifier into a [`ExprPos`]
+ pub const fn as_pos(&self) -> ExprPos {
+ match *self {
+ ParameterQualifier::Out | ParameterQualifier::InOut => ExprPos::Lhs,
+ _ => ExprPos::Rhs,
+ }
+ }
+}
+
+/// The GLSL profile used by a shader.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum Profile {
+ /// The `core` profile, default when no profile is specified.
+ Core,
+}
diff --git a/third_party/rust/naga/src/front/glsl/builtins.rs b/third_party/rust/naga/src/front/glsl/builtins.rs
new file mode 100644
index 0000000000..9e3a578c6b
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/builtins.rs
@@ -0,0 +1,2314 @@
+use super::{
+ ast::{
+ BuiltinVariations, FunctionDeclaration, FunctionKind, Overload, ParameterInfo,
+ ParameterQualifier,
+ },
+ context::Context,
+ Error, ErrorKind, Frontend, Result,
+};
+use crate::{
+ BinaryOperator, DerivativeAxis as Axis, DerivativeControl as Ctrl, Expression, Handle,
+ ImageClass, ImageDimension as Dim, ImageQuery, MathFunction, Module, RelationalFunction,
+ SampleLevel, Scalar, ScalarKind as Sk, Span, Type, TypeInner, UnaryOperator, VectorSize,
+};
+
+impl crate::ScalarKind {
+ const fn dummy_storage_format(&self) -> crate::StorageFormat {
+ match *self {
+ Sk::Sint => crate::StorageFormat::R16Sint,
+ Sk::Uint => crate::StorageFormat::R16Uint,
+ _ => crate::StorageFormat::R16Float,
+ }
+ }
+}
+
+impl Module {
+ /// Helper function, to create a function prototype for a builtin
+ fn add_builtin(&mut self, args: Vec<TypeInner>, builtin: MacroCall) -> Overload {
+ let mut parameters = Vec::with_capacity(args.len());
+ let mut parameters_info = Vec::with_capacity(args.len());
+
+ for arg in args {
+ parameters.push(self.types.insert(
+ Type {
+ name: None,
+ inner: arg,
+ },
+ Span::default(),
+ ));
+ parameters_info.push(ParameterInfo {
+ qualifier: ParameterQualifier::In,
+ depth: false,
+ });
+ }
+
+ Overload {
+ parameters,
+ parameters_info,
+ kind: FunctionKind::Macro(builtin),
+ defined: false,
+ internal: true,
+ void: false,
+ }
+ }
+}
+
+const fn make_coords_arg(number_of_components: usize, kind: Sk) -> TypeInner {
+ let scalar = Scalar { kind, width: 4 };
+
+ match number_of_components {
+ 1 => TypeInner::Scalar(scalar),
+ _ => TypeInner::Vector {
+ size: match number_of_components {
+ 2 => VectorSize::Bi,
+ 3 => VectorSize::Tri,
+ _ => VectorSize::Quad,
+ },
+ scalar,
+ },
+ }
+}
+
+/// Inject builtins into the declaration
+///
+/// This is done to not add a large startup cost and not increase memory
+/// usage if it isn't needed.
+pub fn inject_builtin(
+ declaration: &mut FunctionDeclaration,
+ module: &mut Module,
+ name: &str,
+ mut variations: BuiltinVariations,
+) {
+ log::trace!(
+ "{} variations: {:?} {:?}",
+ name,
+ variations,
+ declaration.variations
+ );
+ // Don't regeneate variations
+ variations.remove(declaration.variations);
+ declaration.variations |= variations;
+
+ if variations.contains(BuiltinVariations::STANDARD) {
+ inject_standard_builtins(declaration, module, name)
+ }
+
+ if variations.contains(BuiltinVariations::DOUBLE) {
+ inject_double_builtin(declaration, module, name)
+ }
+
+ match name {
+ "texture"
+ | "textureGrad"
+ | "textureGradOffset"
+ | "textureLod"
+ | "textureLodOffset"
+ | "textureOffset"
+ | "textureProj"
+ | "textureProjGrad"
+ | "textureProjGradOffset"
+ | "textureProjLod"
+ | "textureProjLodOffset"
+ | "textureProjOffset" => {
+ let f = |kind, dim, arrayed, multi, shadow| {
+ for bits in 0..=0b11 {
+ let variant = bits & 0b1 != 0;
+ let bias = bits & 0b10 != 0;
+
+ let (proj, offset, level_type) = match name {
+ // texture(gsampler, gvec P, [float bias]);
+ "texture" => (false, false, TextureLevelType::None),
+ // textureGrad(gsampler, gvec P, gvec dPdx, gvec dPdy);
+ "textureGrad" => (false, false, TextureLevelType::Grad),
+ // textureGradOffset(gsampler, gvec P, gvec dPdx, gvec dPdy, ivec offset);
+ "textureGradOffset" => (false, true, TextureLevelType::Grad),
+ // textureLod(gsampler, gvec P, float lod);
+ "textureLod" => (false, false, TextureLevelType::Lod),
+ // textureLodOffset(gsampler, gvec P, float lod, ivec offset);
+ "textureLodOffset" => (false, true, TextureLevelType::Lod),
+ // textureOffset(gsampler, gvec+1 P, ivec offset, [float bias]);
+ "textureOffset" => (false, true, TextureLevelType::None),
+ // textureProj(gsampler, gvec+1 P, [float bias]);
+ "textureProj" => (true, false, TextureLevelType::None),
+ // textureProjGrad(gsampler, gvec+1 P, gvec dPdx, gvec dPdy);
+ "textureProjGrad" => (true, false, TextureLevelType::Grad),
+ // textureProjGradOffset(gsampler, gvec+1 P, gvec dPdx, gvec dPdy, ivec offset);
+ "textureProjGradOffset" => (true, true, TextureLevelType::Grad),
+ // textureProjLod(gsampler, gvec+1 P, float lod);
+ "textureProjLod" => (true, false, TextureLevelType::Lod),
+ // textureProjLodOffset(gsampler, gvec+1 P, gvec dPdx, gvec dPdy, ivec offset);
+ "textureProjLodOffset" => (true, true, TextureLevelType::Lod),
+ // textureProjOffset(gsampler, gvec+1 P, ivec offset, [float bias]);
+ "textureProjOffset" => (true, true, TextureLevelType::None),
+ _ => unreachable!(),
+ };
+
+ let builtin = MacroCall::Texture {
+ proj,
+ offset,
+ shadow,
+ level_type,
+ };
+
+ // Parse out the variant settings.
+ let grad = level_type == TextureLevelType::Grad;
+ let lod = level_type == TextureLevelType::Lod;
+
+ let supports_variant = proj && !shadow;
+ if variant && !supports_variant {
+ continue;
+ }
+
+ if bias && !matches!(level_type, TextureLevelType::None) {
+ continue;
+ }
+
+ // Proj doesn't work with arrayed or Cube
+ if proj && (arrayed || dim == Dim::Cube) {
+ continue;
+ }
+
+ // texture operations with offset are not supported for cube maps
+ if dim == Dim::Cube && offset {
+ continue;
+ }
+
+ // sampler2DArrayShadow can't be used in textureLod or in texture with bias
+ if (lod || bias) && arrayed && shadow && dim == Dim::D2 {
+ continue;
+ }
+
+ // TODO: glsl supports using bias with depth samplers but naga doesn't
+ if bias && shadow {
+ continue;
+ }
+
+ let class = match shadow {
+ true => ImageClass::Depth { multi },
+ false => ImageClass::Sampled { kind, multi },
+ };
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ };
+
+ let num_coords_from_dim = image_dims_to_coords_size(dim).min(3);
+ let mut num_coords = num_coords_from_dim;
+
+ if shadow && proj {
+ num_coords = 4;
+ } else if dim == Dim::D1 && shadow {
+ num_coords = 3;
+ } else if shadow {
+ num_coords += 1;
+ } else if proj {
+ if variant && num_coords == 4 {
+ // Normal form already has 4 components, no need to have a variant form.
+ continue;
+ } else if variant {
+ num_coords = 4;
+ } else {
+ num_coords += 1;
+ }
+ }
+
+ if !(dim == Dim::D1 && shadow) {
+ num_coords += arrayed as usize;
+ }
+
+ // Special case: texture(gsamplerCubeArrayShadow) kicks the shadow compare ref to a separate argument,
+ // since it would otherwise take five arguments. It also can't take a bias, nor can it be proj/grad/lod/offset
+ // (presumably because nobody asked for it, and implementation complexity?)
+ if num_coords >= 5 {
+ if lod || grad || offset || proj || bias {
+ continue;
+ }
+ debug_assert!(dim == Dim::Cube && shadow && arrayed);
+ }
+ debug_assert!(num_coords <= 5);
+
+ let vector = make_coords_arg(num_coords, Sk::Float);
+ let mut args = vec![image, vector];
+
+ if num_coords == 5 {
+ args.push(TypeInner::Scalar(Scalar::F32));
+ }
+
+ match level_type {
+ TextureLevelType::Lod => {
+ args.push(TypeInner::Scalar(Scalar::F32));
+ }
+ TextureLevelType::Grad => {
+ args.push(make_coords_arg(num_coords_from_dim, Sk::Float));
+ args.push(make_coords_arg(num_coords_from_dim, Sk::Float));
+ }
+ TextureLevelType::None => {}
+ };
+
+ if offset {
+ args.push(make_coords_arg(num_coords_from_dim, Sk::Sint));
+ }
+
+ if bias {
+ args.push(TypeInner::Scalar(Scalar::F32));
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, builtin));
+ }
+ };
+
+ texture_args_generator(TextureArgsOptions::SHADOW | variations.into(), f)
+ }
+ "textureSize" => {
+ let f = |kind, dim, arrayed, multi, shadow| {
+ let class = match shadow {
+ true => ImageClass::Depth { multi },
+ false => ImageClass::Sampled { kind, multi },
+ };
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ };
+
+ let mut args = vec![image];
+
+ if !multi {
+ args.push(TypeInner::Scalar(Scalar::I32))
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::TextureSize { arrayed }))
+ };
+
+ texture_args_generator(
+ TextureArgsOptions::SHADOW | TextureArgsOptions::MULTI | variations.into(),
+ f,
+ )
+ }
+ "texelFetch" | "texelFetchOffset" => {
+ let offset = "texelFetchOffset" == name;
+ let f = |kind, dim, arrayed, multi, _shadow| {
+ // Cube images aren't supported
+ if let Dim::Cube = dim {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Sampled { kind, multi },
+ };
+
+ let dim_value = image_dims_to_coords_size(dim);
+ let coordinates = make_coords_arg(dim_value + arrayed as usize, Sk::Sint);
+
+ let mut args = vec![image, coordinates, TypeInner::Scalar(Scalar::I32)];
+
+ if offset {
+ args.push(make_coords_arg(dim_value, Sk::Sint));
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::ImageLoad { multi }))
+ };
+
+ // Don't generate shadow images since they aren't supported
+ texture_args_generator(TextureArgsOptions::MULTI | variations.into(), f)
+ }
+ "imageSize" => {
+ let f = |kind: Sk, dim, arrayed, _, _| {
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ if dim == Dim::Cube {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Storage {
+ format: kind.dummy_storage_format(),
+ access: crate::StorageAccess::empty(),
+ },
+ };
+
+ declaration
+ .overloads
+ .push(module.add_builtin(vec![image], MacroCall::TextureSize { arrayed }))
+ };
+
+ texture_args_generator(variations.into(), f)
+ }
+ "imageLoad" => {
+ let f = |kind: Sk, dim, arrayed, _, _| {
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ if dim == Dim::Cube {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Storage {
+ format: kind.dummy_storage_format(),
+ access: crate::StorageAccess::LOAD,
+ },
+ };
+
+ let dim_value = image_dims_to_coords_size(dim);
+ let mut coord_size = dim_value + arrayed as usize;
+ // > Every OpenGL API call that operates on cubemap array
+ // > textures takes layer-faces, not array layers
+ //
+ // So this means that imageCubeArray only takes a three component
+ // vector coordinate and the third component is a layer index.
+ if Dim::Cube == dim && arrayed {
+ coord_size = 3
+ }
+ let coordinates = make_coords_arg(coord_size, Sk::Sint);
+
+ let args = vec![image, coordinates];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::ImageLoad { multi: false }))
+ };
+
+ // Don't generate shadow nor multisampled images since they aren't supported
+ texture_args_generator(variations.into(), f)
+ }
+ "imageStore" => {
+ let f = |kind: Sk, dim, arrayed, _, _| {
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ if dim == Dim::Cube {
+ return;
+ }
+
+ let image = TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Storage {
+ format: kind.dummy_storage_format(),
+ access: crate::StorageAccess::STORE,
+ },
+ };
+
+ let dim_value = image_dims_to_coords_size(dim);
+ let mut coord_size = dim_value + arrayed as usize;
+ // > Every OpenGL API call that operates on cubemap array
+ // > textures takes layer-faces, not array layers
+ //
+ // So this means that imageCubeArray only takes a three component
+ // vector coordinate and the third component is a layer index.
+ if Dim::Cube == dim && arrayed {
+ coord_size = 3
+ }
+ let coordinates = make_coords_arg(coord_size, Sk::Sint);
+
+ let args = vec![
+ image,
+ coordinates,
+ TypeInner::Vector {
+ size: VectorSize::Quad,
+ scalar: Scalar { kind, width: 4 },
+ },
+ ];
+
+ let mut overload = module.add_builtin(args, MacroCall::ImageStore);
+ overload.void = true;
+ declaration.overloads.push(overload)
+ };
+
+ // Don't generate shadow nor multisampled images since they aren't supported
+ texture_args_generator(variations.into(), f)
+ }
+ _ => {}
+ }
+}
+
+/// Injects the builtins into declaration that don't need any special variations
+fn inject_standard_builtins(
+ declaration: &mut FunctionDeclaration,
+ module: &mut Module,
+ name: &str,
+) {
+ match name {
+ "sampler1D" | "sampler1DArray" | "sampler2D" | "sampler2DArray" | "sampler2DMS"
+ | "sampler2DMSArray" | "sampler3D" | "samplerCube" | "samplerCubeArray" => {
+ declaration.overloads.push(module.add_builtin(
+ vec![
+ TypeInner::Image {
+ dim: match name {
+ "sampler1D" | "sampler1DArray" => Dim::D1,
+ "sampler2D" | "sampler2DArray" | "sampler2DMS" | "sampler2DMSArray" => {
+ Dim::D2
+ }
+ "sampler3D" => Dim::D3,
+ _ => Dim::Cube,
+ },
+ arrayed: matches!(
+ name,
+ "sampler1DArray"
+ | "sampler2DArray"
+ | "sampler2DMSArray"
+ | "samplerCubeArray"
+ ),
+ class: ImageClass::Sampled {
+ kind: Sk::Float,
+ multi: matches!(name, "sampler2DMS" | "sampler2DMSArray"),
+ },
+ },
+ TypeInner::Sampler { comparison: false },
+ ],
+ MacroCall::Sampler,
+ ))
+ }
+ "sampler1DShadow"
+ | "sampler1DArrayShadow"
+ | "sampler2DShadow"
+ | "sampler2DArrayShadow"
+ | "samplerCubeShadow"
+ | "samplerCubeArrayShadow" => {
+ let dim = match name {
+ "sampler1DShadow" | "sampler1DArrayShadow" => Dim::D1,
+ "sampler2DShadow" | "sampler2DArrayShadow" => Dim::D2,
+ _ => Dim::Cube,
+ };
+ let arrayed = matches!(
+ name,
+ "sampler1DArrayShadow" | "sampler2DArrayShadow" | "samplerCubeArrayShadow"
+ );
+
+ for i in 0..2 {
+ let ty = TypeInner::Image {
+ dim,
+ arrayed,
+ class: match i {
+ 0 => ImageClass::Sampled {
+ kind: Sk::Float,
+ multi: false,
+ },
+ _ => ImageClass::Depth { multi: false },
+ },
+ };
+
+ declaration.overloads.push(module.add_builtin(
+ vec![ty, TypeInner::Sampler { comparison: true }],
+ MacroCall::SamplerShadow,
+ ))
+ }
+ }
+ "sin" | "exp" | "exp2" | "sinh" | "cos" | "cosh" | "tan" | "tanh" | "acos" | "asin"
+ | "log" | "log2" | "radians" | "degrees" | "asinh" | "acosh" | "atanh"
+ | "floatBitsToInt" | "floatBitsToUint" | "dFdx" | "dFdxFine" | "dFdxCoarse" | "dFdy"
+ | "dFdyFine" | "dFdyCoarse" | "fwidth" | "fwidthFine" | "fwidthCoarse" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let scalar = Scalar::F32;
+
+ declaration.overloads.push(module.add_builtin(
+ vec![match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ }],
+ match name {
+ "sin" => MacroCall::MathFunction(MathFunction::Sin),
+ "exp" => MacroCall::MathFunction(MathFunction::Exp),
+ "exp2" => MacroCall::MathFunction(MathFunction::Exp2),
+ "sinh" => MacroCall::MathFunction(MathFunction::Sinh),
+ "cos" => MacroCall::MathFunction(MathFunction::Cos),
+ "cosh" => MacroCall::MathFunction(MathFunction::Cosh),
+ "tan" => MacroCall::MathFunction(MathFunction::Tan),
+ "tanh" => MacroCall::MathFunction(MathFunction::Tanh),
+ "acos" => MacroCall::MathFunction(MathFunction::Acos),
+ "asin" => MacroCall::MathFunction(MathFunction::Asin),
+ "log" => MacroCall::MathFunction(MathFunction::Log),
+ "log2" => MacroCall::MathFunction(MathFunction::Log2),
+ "asinh" => MacroCall::MathFunction(MathFunction::Asinh),
+ "acosh" => MacroCall::MathFunction(MathFunction::Acosh),
+ "atanh" => MacroCall::MathFunction(MathFunction::Atanh),
+ "radians" => MacroCall::MathFunction(MathFunction::Radians),
+ "degrees" => MacroCall::MathFunction(MathFunction::Degrees),
+ "floatBitsToInt" => MacroCall::BitCast(Sk::Sint),
+ "floatBitsToUint" => MacroCall::BitCast(Sk::Uint),
+ "dFdxCoarse" => MacroCall::Derivate(Axis::X, Ctrl::Coarse),
+ "dFdyCoarse" => MacroCall::Derivate(Axis::Y, Ctrl::Coarse),
+ "fwidthCoarse" => MacroCall::Derivate(Axis::Width, Ctrl::Coarse),
+ "dFdxFine" => MacroCall::Derivate(Axis::X, Ctrl::Fine),
+ "dFdyFine" => MacroCall::Derivate(Axis::Y, Ctrl::Fine),
+ "fwidthFine" => MacroCall::Derivate(Axis::Width, Ctrl::Fine),
+ "dFdx" => MacroCall::Derivate(Axis::X, Ctrl::None),
+ "dFdy" => MacroCall::Derivate(Axis::Y, Ctrl::None),
+ "fwidth" => MacroCall::Derivate(Axis::Width, Ctrl::None),
+ _ => unreachable!(),
+ },
+ ))
+ }
+ }
+ "intBitsToFloat" | "uintBitsToFloat" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let scalar = match name {
+ "intBitsToFloat" => Scalar::I32,
+ _ => Scalar::U32,
+ };
+
+ declaration.overloads.push(module.add_builtin(
+ vec![match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ }],
+ MacroCall::BitCast(Sk::Float),
+ ))
+ }
+ }
+ "pow" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let scalar = Scalar::F32;
+ let ty = || match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ };
+
+ declaration.overloads.push(
+ module
+ .add_builtin(vec![ty(), ty()], MacroCall::MathFunction(MathFunction::Pow)),
+ )
+ }
+ }
+ "abs" | "sign" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 - float/sint
+ for bits in 0..0b1000 {
+ let size = match bits & 0b11 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let scalar = match bits >> 2 {
+ 0b0 => Scalar::F32,
+ _ => Scalar::I32,
+ };
+
+ let args = vec![match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ }];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ MacroCall::MathFunction(match name {
+ "abs" => MathFunction::Abs,
+ "sign" => MathFunction::Sign,
+ _ => unreachable!(),
+ }),
+ ))
+ }
+ }
+ "bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" | "findLSB"
+ | "findMSB" => {
+ let fun = match name {
+ "bitCount" => MathFunction::CountOneBits,
+ "bitfieldReverse" => MathFunction::ReverseBits,
+ "bitfieldExtract" => MathFunction::ExtractBits,
+ "bitfieldInsert" => MathFunction::InsertBits,
+ "findLSB" => MathFunction::FindLsb,
+ "findMSB" => MathFunction::FindMsb,
+ _ => unreachable!(),
+ };
+
+ let mc = match fun {
+ MathFunction::ExtractBits => MacroCall::BitfieldExtract,
+ MathFunction::InsertBits => MacroCall::BitfieldInsert,
+ _ => MacroCall::MathFunction(fun),
+ };
+
+ // bits layout
+ // bit 0 - int/uint
+ // bit 1 through 2 - dims
+ for bits in 0..0b1000 {
+ let scalar = match bits & 0b1 {
+ 0b0 => Scalar::I32,
+ _ => Scalar::U32,
+ };
+ let size = match bits >> 1 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let ty = || match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ };
+
+ let mut args = vec![ty()];
+
+ match fun {
+ MathFunction::ExtractBits => {
+ args.push(TypeInner::Scalar(Scalar::I32));
+ args.push(TypeInner::Scalar(Scalar::I32));
+ }
+ MathFunction::InsertBits => {
+ args.push(ty());
+ args.push(TypeInner::Scalar(Scalar::I32));
+ args.push(TypeInner::Scalar(Scalar::I32));
+ }
+ _ => {}
+ }
+
+ // we need to cast the return type of findLsb / findMsb
+ let mc = if scalar.kind == Sk::Uint {
+ match mc {
+ MacroCall::MathFunction(MathFunction::FindLsb) => MacroCall::FindLsbUint,
+ MacroCall::MathFunction(MathFunction::FindMsb) => MacroCall::FindMsbUint,
+ mc => mc,
+ }
+ } else {
+ mc
+ };
+
+ declaration.overloads.push(module.add_builtin(args, mc))
+ }
+ }
+ "packSnorm4x8" | "packUnorm4x8" | "packSnorm2x16" | "packUnorm2x16" | "packHalf2x16" => {
+ let fun = match name {
+ "packSnorm4x8" => MathFunction::Pack4x8snorm,
+ "packUnorm4x8" => MathFunction::Pack4x8unorm,
+ "packSnorm2x16" => MathFunction::Pack2x16unorm,
+ "packUnorm2x16" => MathFunction::Pack2x16snorm,
+ "packHalf2x16" => MathFunction::Pack2x16float,
+ _ => unreachable!(),
+ };
+
+ let ty = match fun {
+ MathFunction::Pack4x8snorm | MathFunction::Pack4x8unorm => TypeInner::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar::F32,
+ },
+ MathFunction::Pack2x16unorm
+ | MathFunction::Pack2x16snorm
+ | MathFunction::Pack2x16float => TypeInner::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar::F32,
+ },
+ _ => unreachable!(),
+ };
+
+ let args = vec![ty];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(fun)));
+ }
+ "unpackSnorm4x8" | "unpackUnorm4x8" | "unpackSnorm2x16" | "unpackUnorm2x16"
+ | "unpackHalf2x16" => {
+ let fun = match name {
+ "unpackSnorm4x8" => MathFunction::Unpack4x8snorm,
+ "unpackUnorm4x8" => MathFunction::Unpack4x8unorm,
+ "unpackSnorm2x16" => MathFunction::Unpack2x16snorm,
+ "unpackUnorm2x16" => MathFunction::Unpack2x16unorm,
+ "unpackHalf2x16" => MathFunction::Unpack2x16float,
+ _ => unreachable!(),
+ };
+
+ let args = vec![TypeInner::Scalar(Scalar::U32)];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(fun)));
+ }
+ "atan" => {
+ // bits layout
+ // bit 0 - atan/atan2
+ // bit 1 through 2 - dims
+ for bits in 0..0b1000 {
+ let fun = match bits & 0b1 {
+ 0b0 => MathFunction::Atan,
+ _ => MathFunction::Atan2,
+ };
+ let size = match bits >> 1 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let scalar = Scalar::F32;
+ let ty = || match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ };
+
+ let mut args = vec![ty()];
+
+ if fun == MathFunction::Atan2 {
+ args.push(ty())
+ }
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(fun)))
+ }
+ }
+ "all" | "any" | "not" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b11 {
+ let size = match bits {
+ 0b00 => VectorSize::Bi,
+ 0b01 => VectorSize::Tri,
+ _ => VectorSize::Quad,
+ };
+
+ let args = vec![TypeInner::Vector {
+ size,
+ scalar: Scalar::BOOL,
+ }];
+
+ let fun = match name {
+ "all" => MacroCall::Relational(RelationalFunction::All),
+ "any" => MacroCall::Relational(RelationalFunction::Any),
+ "not" => MacroCall::Unary(UnaryOperator::LogicalNot),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" => {
+ for bits in 0..0b1001 {
+ let (size, scalar) = match bits {
+ 0b0000 => (VectorSize::Bi, Scalar::F32),
+ 0b0001 => (VectorSize::Tri, Scalar::F32),
+ 0b0010 => (VectorSize::Quad, Scalar::F32),
+ 0b0011 => (VectorSize::Bi, Scalar::I32),
+ 0b0100 => (VectorSize::Tri, Scalar::I32),
+ 0b0101 => (VectorSize::Quad, Scalar::I32),
+ 0b0110 => (VectorSize::Bi, Scalar::U32),
+ 0b0111 => (VectorSize::Tri, Scalar::U32),
+ _ => (VectorSize::Quad, Scalar::U32),
+ };
+
+ let ty = || TypeInner::Vector { size, scalar };
+ let args = vec![ty(), ty()];
+
+ let fun = MacroCall::Binary(match name {
+ "lessThan" => BinaryOperator::Less,
+ "greaterThan" => BinaryOperator::Greater,
+ "lessThanEqual" => BinaryOperator::LessEqual,
+ "greaterThanEqual" => BinaryOperator::GreaterEqual,
+ _ => unreachable!(),
+ });
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "equal" | "notEqual" => {
+ for bits in 0..0b1100 {
+ let (size, scalar) = match bits {
+ 0b0000 => (VectorSize::Bi, Scalar::F32),
+ 0b0001 => (VectorSize::Tri, Scalar::F32),
+ 0b0010 => (VectorSize::Quad, Scalar::F32),
+ 0b0011 => (VectorSize::Bi, Scalar::I32),
+ 0b0100 => (VectorSize::Tri, Scalar::I32),
+ 0b0101 => (VectorSize::Quad, Scalar::I32),
+ 0b0110 => (VectorSize::Bi, Scalar::U32),
+ 0b0111 => (VectorSize::Tri, Scalar::U32),
+ 0b1000 => (VectorSize::Quad, Scalar::U32),
+ 0b1001 => (VectorSize::Bi, Scalar::BOOL),
+ 0b1010 => (VectorSize::Tri, Scalar::BOOL),
+ _ => (VectorSize::Quad, Scalar::BOOL),
+ };
+
+ let ty = || TypeInner::Vector { size, scalar };
+ let args = vec![ty(), ty()];
+
+ let fun = MacroCall::Binary(match name {
+ "equal" => BinaryOperator::Equal,
+ "notEqual" => BinaryOperator::NotEqual,
+ _ => unreachable!(),
+ });
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "min" | "max" => {
+ // bits layout
+ // bit 0 through 1 - scalar kind
+ // bit 2 through 4 - dims
+ for bits in 0..0b11100 {
+ let scalar = match bits & 0b11 {
+ 0b00 => Scalar::F32,
+ 0b01 => Scalar::I32,
+ 0b10 => Scalar::U32,
+ _ => continue,
+ };
+ let (size, second_size) = match bits >> 2 {
+ 0b000 => (None, None),
+ 0b001 => (Some(VectorSize::Bi), None),
+ 0b010 => (Some(VectorSize::Tri), None),
+ 0b011 => (Some(VectorSize::Quad), None),
+ 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)),
+ 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)),
+ _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)),
+ };
+
+ let args = vec![
+ match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ },
+ match second_size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ },
+ ];
+
+ let fun = match name {
+ "max" => MacroCall::Splatted(MathFunction::Max, size, 1),
+ "min" => MacroCall::Splatted(MathFunction::Min, size, 1),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "mix" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 through 4 - types
+ //
+ // 0b10011 is the last element since splatted single elements
+ // were already added
+ for bits in 0..0b10011 {
+ let size = match bits & 0b11 {
+ 0b00 => Some(VectorSize::Bi),
+ 0b01 => Some(VectorSize::Tri),
+ 0b10 => Some(VectorSize::Quad),
+ _ => None,
+ };
+ let (scalar, splatted, boolean) = match bits >> 2 {
+ 0b000 => (Scalar::I32, false, true),
+ 0b001 => (Scalar::U32, false, true),
+ 0b010 => (Scalar::F32, false, true),
+ 0b011 => (Scalar::F32, false, false),
+ _ => (Scalar::F32, true, false),
+ };
+
+ let ty = |scalar| match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ };
+ let args = vec![
+ ty(scalar),
+ ty(scalar),
+ match (boolean, splatted) {
+ (true, _) => ty(Scalar::BOOL),
+ (_, false) => TypeInner::Scalar(scalar),
+ _ => ty(scalar),
+ },
+ ];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ match boolean {
+ true => MacroCall::MixBoolean,
+ false => MacroCall::Splatted(MathFunction::Mix, size, 2),
+ },
+ ))
+ }
+ }
+ "clamp" => {
+ // bits layout
+ // bit 0 through 1 - float/int/uint
+ // bit 2 through 3 - dims
+ // bit 4 - splatted
+ //
+ // 0b11010 is the last element since splatted single elements
+ // were already added
+ for bits in 0..0b11011 {
+ let scalar = match bits & 0b11 {
+ 0b00 => Scalar::F32,
+ 0b01 => Scalar::I32,
+ 0b10 => Scalar::U32,
+ _ => continue,
+ };
+ let size = match (bits >> 2) & 0b11 {
+ 0b00 => Some(VectorSize::Bi),
+ 0b01 => Some(VectorSize::Tri),
+ 0b10 => Some(VectorSize::Quad),
+ _ => None,
+ };
+ let splatted = bits & 0b10000 == 0b10000;
+
+ let base_ty = || match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ };
+ let limit_ty = || match splatted {
+ true => TypeInner::Scalar(scalar),
+ false => base_ty(),
+ };
+
+ let args = vec![base_ty(), limit_ty(), limit_ty()];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::Clamp(size)))
+ }
+ }
+ "barrier" => declaration
+ .overloads
+ .push(module.add_builtin(Vec::new(), MacroCall::Barrier)),
+ // Add common builtins with floats
+ _ => inject_common_builtin(declaration, module, name, 4),
+ }
+}
+
+/// Injects the builtins into declaration that need doubles
+fn inject_double_builtin(declaration: &mut FunctionDeclaration, module: &mut Module, name: &str) {
+ match name {
+ "abs" | "sign" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let scalar = Scalar::F64;
+
+ let args = vec![match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ }];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ MacroCall::MathFunction(match name {
+ "abs" => MathFunction::Abs,
+ "sign" => MathFunction::Sign,
+ _ => unreachable!(),
+ }),
+ ))
+ }
+ }
+ "min" | "max" => {
+ // bits layout
+ // bit 0 through 2 - dims
+ for bits in 0..0b111 {
+ let (size, second_size) = match bits {
+ 0b000 => (None, None),
+ 0b001 => (Some(VectorSize::Bi), None),
+ 0b010 => (Some(VectorSize::Tri), None),
+ 0b011 => (Some(VectorSize::Quad), None),
+ 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)),
+ 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)),
+ _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)),
+ };
+ let scalar = Scalar::F64;
+
+ let args = vec![
+ match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ },
+ match second_size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ },
+ ];
+
+ let fun = match name {
+ "max" => MacroCall::Splatted(MathFunction::Max, size, 1),
+ "min" => MacroCall::Splatted(MathFunction::Min, size, 1),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "mix" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 through 3 - splatted/boolean
+ //
+ // 0b1010 is the last element since splatted with single elements
+ // is equal to normal single elements
+ for bits in 0..0b1011 {
+ let size = match bits & 0b11 {
+ 0b00 => Some(VectorSize::Quad),
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => None,
+ };
+ let scalar = Scalar::F64;
+ let (splatted, boolean) = match bits >> 2 {
+ 0b00 => (false, false),
+ 0b01 => (false, true),
+ _ => (true, false),
+ };
+
+ let ty = |scalar| match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ };
+ let args = vec![
+ ty(scalar),
+ ty(scalar),
+ match (boolean, splatted) {
+ (true, _) => ty(Scalar::BOOL),
+ (_, false) => TypeInner::Scalar(scalar),
+ _ => ty(scalar),
+ },
+ ];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ match boolean {
+ true => MacroCall::MixBoolean,
+ false => MacroCall::Splatted(MathFunction::Mix, size, 2),
+ },
+ ))
+ }
+ }
+ "clamp" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ // bit 2 - splatted
+ //
+ // 0b110 is the last element since splatted with single elements
+ // is equal to normal single elements
+ for bits in 0..0b111 {
+ let scalar = Scalar::F64;
+ let size = match bits & 0b11 {
+ 0b00 => Some(VectorSize::Bi),
+ 0b01 => Some(VectorSize::Tri),
+ 0b10 => Some(VectorSize::Quad),
+ _ => None,
+ };
+ let splatted = bits & 0b100 == 0b100;
+
+ let base_ty = || match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ };
+ let limit_ty = || match splatted {
+ true => TypeInner::Scalar(scalar),
+ false => base_ty(),
+ };
+
+ let args = vec![base_ty(), limit_ty(), limit_ty()];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::Clamp(size)))
+ }
+ }
+ "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" | "equal"
+ | "notEqual" => {
+ let scalar = Scalar::F64;
+ for bits in 0..0b11 {
+ let size = match bits {
+ 0b00 => VectorSize::Bi,
+ 0b01 => VectorSize::Tri,
+ _ => VectorSize::Quad,
+ };
+
+ let ty = || TypeInner::Vector { size, scalar };
+ let args = vec![ty(), ty()];
+
+ let fun = MacroCall::Binary(match name {
+ "lessThan" => BinaryOperator::Less,
+ "greaterThan" => BinaryOperator::Greater,
+ "lessThanEqual" => BinaryOperator::LessEqual,
+ "greaterThanEqual" => BinaryOperator::GreaterEqual,
+ "equal" => BinaryOperator::Equal,
+ "notEqual" => BinaryOperator::NotEqual,
+ _ => unreachable!(),
+ });
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ // Add common builtins with doubles
+ _ => inject_common_builtin(declaration, module, name, 8),
+ }
+}
+
+/// Injects the builtins into declaration that can used either float or doubles
+fn inject_common_builtin(
+ declaration: &mut FunctionDeclaration,
+ module: &mut Module,
+ name: &str,
+ float_width: crate::Bytes,
+) {
+ let float_scalar = Scalar {
+ kind: Sk::Float,
+ width: float_width,
+ };
+ match name {
+ "ceil" | "round" | "roundEven" | "floor" | "fract" | "trunc" | "sqrt" | "inversesqrt"
+ | "normalize" | "length" | "isinf" | "isnan" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let args = vec![match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ scalar: float_scalar,
+ },
+ None => TypeInner::Scalar(float_scalar),
+ }];
+
+ let fun = match name {
+ "ceil" => MacroCall::MathFunction(MathFunction::Ceil),
+ "round" | "roundEven" => MacroCall::MathFunction(MathFunction::Round),
+ "floor" => MacroCall::MathFunction(MathFunction::Floor),
+ "fract" => MacroCall::MathFunction(MathFunction::Fract),
+ "trunc" => MacroCall::MathFunction(MathFunction::Trunc),
+ "sqrt" => MacroCall::MathFunction(MathFunction::Sqrt),
+ "inversesqrt" => MacroCall::MathFunction(MathFunction::InverseSqrt),
+ "normalize" => MacroCall::MathFunction(MathFunction::Normalize),
+ "length" => MacroCall::MathFunction(MathFunction::Length),
+ "isinf" => MacroCall::Relational(RelationalFunction::IsInf),
+ "isnan" => MacroCall::Relational(RelationalFunction::IsNan),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "dot" | "reflect" | "distance" | "ldexp" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+ let ty = |scalar| match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ };
+
+ let fun = match name {
+ "dot" => MacroCall::MathFunction(MathFunction::Dot),
+ "reflect" => MacroCall::MathFunction(MathFunction::Reflect),
+ "distance" => MacroCall::MathFunction(MathFunction::Distance),
+ "ldexp" => MacroCall::MathFunction(MathFunction::Ldexp),
+ _ => unreachable!(),
+ };
+
+ let second_scalar = match fun {
+ MacroCall::MathFunction(MathFunction::Ldexp) => Scalar::I32,
+ _ => float_scalar,
+ };
+
+ declaration
+ .overloads
+ .push(module.add_builtin(vec![ty(float_scalar), ty(second_scalar)], fun))
+ }
+ }
+ "transpose" => {
+ // bits layout
+ // bit 0 through 3 - dims
+ for bits in 0..0b1001 {
+ let (rows, columns) = match bits {
+ 0b0000 => (VectorSize::Bi, VectorSize::Bi),
+ 0b0001 => (VectorSize::Bi, VectorSize::Tri),
+ 0b0010 => (VectorSize::Bi, VectorSize::Quad),
+ 0b0011 => (VectorSize::Tri, VectorSize::Bi),
+ 0b0100 => (VectorSize::Tri, VectorSize::Tri),
+ 0b0101 => (VectorSize::Tri, VectorSize::Quad),
+ 0b0110 => (VectorSize::Quad, VectorSize::Bi),
+ 0b0111 => (VectorSize::Quad, VectorSize::Tri),
+ _ => (VectorSize::Quad, VectorSize::Quad),
+ };
+
+ declaration.overloads.push(module.add_builtin(
+ vec![TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: float_scalar,
+ }],
+ MacroCall::MathFunction(MathFunction::Transpose),
+ ))
+ }
+ }
+ "inverse" | "determinant" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b11 {
+ let (rows, columns) = match bits {
+ 0b00 => (VectorSize::Bi, VectorSize::Bi),
+ 0b01 => (VectorSize::Tri, VectorSize::Tri),
+ _ => (VectorSize::Quad, VectorSize::Quad),
+ };
+
+ let args = vec![TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: float_scalar,
+ }];
+
+ declaration.overloads.push(module.add_builtin(
+ args,
+ MacroCall::MathFunction(match name {
+ "inverse" => MathFunction::Inverse,
+ "determinant" => MathFunction::Determinant,
+ _ => unreachable!(),
+ }),
+ ))
+ }
+ }
+ "mod" | "step" => {
+ // bits layout
+ // bit 0 through 2 - dims
+ for bits in 0..0b111 {
+ let (size, second_size) = match bits {
+ 0b000 => (None, None),
+ 0b001 => (Some(VectorSize::Bi), None),
+ 0b010 => (Some(VectorSize::Tri), None),
+ 0b011 => (Some(VectorSize::Quad), None),
+ 0b100 => (Some(VectorSize::Bi), Some(VectorSize::Bi)),
+ 0b101 => (Some(VectorSize::Tri), Some(VectorSize::Tri)),
+ _ => (Some(VectorSize::Quad), Some(VectorSize::Quad)),
+ };
+
+ let mut args = Vec::with_capacity(2);
+ let step = name == "step";
+
+ for i in 0..2 {
+ let maybe_size = match i == step as u32 {
+ true => size,
+ false => second_size,
+ };
+
+ args.push(match maybe_size {
+ Some(size) => TypeInner::Vector {
+ size,
+ scalar: float_scalar,
+ },
+ None => TypeInner::Scalar(float_scalar),
+ })
+ }
+
+ let fun = match name {
+ "mod" => MacroCall::Mod(size),
+ "step" => MacroCall::Splatted(MathFunction::Step, size, 0),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ // TODO: https://github.com/gfx-rs/naga/issues/2526
+ // "modf" | "frexp" => { ... }
+ "cross" => {
+ let args = vec![
+ TypeInner::Vector {
+ size: VectorSize::Tri,
+ scalar: float_scalar,
+ },
+ TypeInner::Vector {
+ size: VectorSize::Tri,
+ scalar: float_scalar,
+ },
+ ];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Cross)))
+ }
+ "outerProduct" => {
+ // bits layout
+ // bit 0 through 3 - dims
+ for bits in 0..0b1001 {
+ let (size1, size2) = match bits {
+ 0b0000 => (VectorSize::Bi, VectorSize::Bi),
+ 0b0001 => (VectorSize::Bi, VectorSize::Tri),
+ 0b0010 => (VectorSize::Bi, VectorSize::Quad),
+ 0b0011 => (VectorSize::Tri, VectorSize::Bi),
+ 0b0100 => (VectorSize::Tri, VectorSize::Tri),
+ 0b0101 => (VectorSize::Tri, VectorSize::Quad),
+ 0b0110 => (VectorSize::Quad, VectorSize::Bi),
+ 0b0111 => (VectorSize::Quad, VectorSize::Tri),
+ _ => (VectorSize::Quad, VectorSize::Quad),
+ };
+
+ let args = vec![
+ TypeInner::Vector {
+ size: size1,
+ scalar: float_scalar,
+ },
+ TypeInner::Vector {
+ size: size2,
+ scalar: float_scalar,
+ },
+ ];
+
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Outer)))
+ }
+ }
+ "faceforward" | "fma" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let ty = || match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ scalar: float_scalar,
+ },
+ None => TypeInner::Scalar(float_scalar),
+ };
+ let args = vec![ty(), ty(), ty()];
+
+ let fun = match name {
+ "faceforward" => MacroCall::MathFunction(MathFunction::FaceForward),
+ "fma" => MacroCall::MathFunction(MathFunction::Fma),
+ _ => unreachable!(),
+ };
+
+ declaration.overloads.push(module.add_builtin(args, fun))
+ }
+ }
+ "refract" => {
+ // bits layout
+ // bit 0 through 1 - dims
+ for bits in 0..0b100 {
+ let size = match bits {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ let ty = || match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ scalar: float_scalar,
+ },
+ None => TypeInner::Scalar(float_scalar),
+ };
+ let args = vec![ty(), ty(), TypeInner::Scalar(Scalar::F32)];
+ declaration
+ .overloads
+ .push(module.add_builtin(args, MacroCall::MathFunction(MathFunction::Refract)))
+ }
+ }
+ "smoothstep" => {
+ // bit 0 - splatted
+ // bit 1 through 2 - dims
+ for bits in 0..0b1000 {
+ let splatted = bits & 0b1 == 0b1;
+ let size = match bits >> 1 {
+ 0b00 => None,
+ 0b01 => Some(VectorSize::Bi),
+ 0b10 => Some(VectorSize::Tri),
+ _ => Some(VectorSize::Quad),
+ };
+
+ if splatted && size.is_none() {
+ continue;
+ }
+
+ let base_ty = || match size {
+ Some(size) => TypeInner::Vector {
+ size,
+ scalar: float_scalar,
+ },
+ None => TypeInner::Scalar(float_scalar),
+ };
+ let ty = || match splatted {
+ true => TypeInner::Scalar(float_scalar),
+ false => base_ty(),
+ };
+ declaration.overloads.push(module.add_builtin(
+ vec![ty(), ty(), base_ty()],
+ MacroCall::SmoothStep { splatted: size },
+ ))
+ }
+ }
+ // The function isn't a builtin or we don't yet support it
+ _ => {}
+ }
+}
+
+#[derive(Clone, Copy, PartialEq, Debug)]
+pub enum TextureLevelType {
+ None,
+ Lod,
+ Grad,
+}
+
+/// A compiler defined builtin function
+#[derive(Clone, Copy, PartialEq, Debug)]
+pub enum MacroCall {
+ Sampler,
+ SamplerShadow,
+ Texture {
+ proj: bool,
+ offset: bool,
+ shadow: bool,
+ level_type: TextureLevelType,
+ },
+ TextureSize {
+ arrayed: bool,
+ },
+ ImageLoad {
+ multi: bool,
+ },
+ ImageStore,
+ MathFunction(MathFunction),
+ FindLsbUint,
+ FindMsbUint,
+ BitfieldExtract,
+ BitfieldInsert,
+ Relational(RelationalFunction),
+ Unary(UnaryOperator),
+ Binary(BinaryOperator),
+ Mod(Option<VectorSize>),
+ Splatted(MathFunction, Option<VectorSize>, usize),
+ MixBoolean,
+ Clamp(Option<VectorSize>),
+ BitCast(Sk),
+ Derivate(Axis, Ctrl),
+ Barrier,
+ /// SmoothStep needs a separate variant because it might need it's inputs
+ /// to be splatted depending on the overload
+ SmoothStep {
+ /// The size of the splat operation if some
+ splatted: Option<VectorSize>,
+ },
+}
+
+impl MacroCall {
+ /// Adds the necessary expressions and statements to the passed body and
+ /// finally returns the final expression with the correct result
+ pub fn call(
+ &self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ args: &mut [Handle<Expression>],
+ meta: Span,
+ ) -> Result<Option<Handle<Expression>>> {
+ Ok(Some(match *self {
+ MacroCall::Sampler => {
+ ctx.samplers.insert(args[0], args[1]);
+ args[0]
+ }
+ MacroCall::SamplerShadow => {
+ sampled_to_depth(ctx, args[0], meta, &mut frontend.errors);
+ ctx.invalidate_expression(args[0], meta)?;
+ ctx.samplers.insert(args[0], args[1]);
+ args[0]
+ }
+ MacroCall::Texture {
+ proj,
+ offset,
+ shadow,
+ level_type,
+ } => {
+ let mut coords = args[1];
+
+ if proj {
+ let size = match *ctx.resolve_type(coords, meta)? {
+ TypeInner::Vector { size, .. } => size,
+ _ => unreachable!(),
+ };
+ let mut right = ctx.add_expression(
+ Expression::AccessIndex {
+ base: coords,
+ index: size as u32 - 1,
+ },
+ Span::default(),
+ )?;
+ let left = if let VectorSize::Bi = size {
+ ctx.add_expression(
+ Expression::AccessIndex {
+ base: coords,
+ index: 0,
+ },
+ Span::default(),
+ )?
+ } else {
+ let size = match size {
+ VectorSize::Tri => VectorSize::Bi,
+ _ => VectorSize::Tri,
+ };
+ right = ctx.add_expression(
+ Expression::Splat { size, value: right },
+ Span::default(),
+ )?;
+ ctx.vector_resize(size, coords, Span::default())?
+ };
+ coords = ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Divide,
+ left,
+ right,
+ },
+ Span::default(),
+ )?;
+ }
+
+ let extra = args.get(2).copied();
+ let comps = frontend.coordinate_components(ctx, args[0], coords, extra, meta)?;
+
+ let mut num_args = 2;
+
+ if comps.used_extra {
+ num_args += 1;
+ };
+
+ // Parse out explicit texture level.
+ let mut level = match level_type {
+ TextureLevelType::None => SampleLevel::Auto,
+
+ TextureLevelType::Lod => {
+ num_args += 1;
+
+ if shadow {
+ log::warn!("Assuming LOD {:?} is zero", args[2],);
+
+ SampleLevel::Zero
+ } else {
+ SampleLevel::Exact(args[2])
+ }
+ }
+
+ TextureLevelType::Grad => {
+ num_args += 2;
+
+ if shadow {
+ log::warn!(
+ "Assuming gradients {:?} and {:?} are not greater than 1",
+ args[2],
+ args[3],
+ );
+ SampleLevel::Zero
+ } else {
+ SampleLevel::Gradient {
+ x: args[2],
+ y: args[3],
+ }
+ }
+ }
+ };
+
+ let texture_offset = match offset {
+ true => {
+ let offset_arg = args[num_args];
+ num_args += 1;
+ match ctx.lift_up_const_expression(offset_arg) {
+ Ok(v) => Some(v),
+ Err(e) => {
+ frontend.errors.push(e);
+ None
+ }
+ }
+ }
+ false => None,
+ };
+
+ // Now go back and look for optional bias arg (if available)
+ if let TextureLevelType::None = level_type {
+ level = args
+ .get(num_args)
+ .copied()
+ .map_or(SampleLevel::Auto, SampleLevel::Bias);
+ }
+
+ texture_call(ctx, args[0], level, comps, texture_offset, meta)?
+ }
+
+ MacroCall::TextureSize { arrayed } => {
+ let mut expr = ctx.add_expression(
+ Expression::ImageQuery {
+ image: args[0],
+ query: ImageQuery::Size {
+ level: args.get(1).copied(),
+ },
+ },
+ Span::default(),
+ )?;
+
+ if arrayed {
+ let mut components = Vec::with_capacity(4);
+
+ let size = match *ctx.resolve_type(expr, meta)? {
+ TypeInner::Vector { size: ori_size, .. } => {
+ for index in 0..(ori_size as u32) {
+ components.push(ctx.add_expression(
+ Expression::AccessIndex { base: expr, index },
+ Span::default(),
+ )?)
+ }
+
+ match ori_size {
+ VectorSize::Bi => VectorSize::Tri,
+ _ => VectorSize::Quad,
+ }
+ }
+ _ => {
+ components.push(expr);
+ VectorSize::Bi
+ }
+ };
+
+ components.push(ctx.add_expression(
+ Expression::ImageQuery {
+ image: args[0],
+ query: ImageQuery::NumLayers,
+ },
+ Span::default(),
+ )?);
+
+ let ty = ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size,
+ scalar: Scalar::U32,
+ },
+ },
+ Span::default(),
+ );
+
+ expr = ctx.add_expression(Expression::Compose { components, ty }, meta)?
+ }
+
+ ctx.add_expression(
+ Expression::As {
+ expr,
+ kind: Sk::Sint,
+ convert: Some(4),
+ },
+ Span::default(),
+ )?
+ }
+ MacroCall::ImageLoad { multi } => {
+ let comps = frontend.coordinate_components(ctx, args[0], args[1], None, meta)?;
+ let (sample, level) = match (multi, args.get(2)) {
+ (_, None) => (None, None),
+ (true, Some(&arg)) => (Some(arg), None),
+ (false, Some(&arg)) => (None, Some(arg)),
+ };
+ ctx.add_expression(
+ Expression::ImageLoad {
+ image: args[0],
+ coordinate: comps.coordinate,
+ array_index: comps.array_index,
+ sample,
+ level,
+ },
+ Span::default(),
+ )?
+ }
+ MacroCall::ImageStore => {
+ let comps = frontend.coordinate_components(ctx, args[0], args[1], None, meta)?;
+ ctx.emit_restart();
+ ctx.body.push(
+ crate::Statement::ImageStore {
+ image: args[0],
+ coordinate: comps.coordinate,
+ array_index: comps.array_index,
+ value: args[2],
+ },
+ meta,
+ );
+ return Ok(None);
+ }
+ MacroCall::MathFunction(fun) => ctx.add_expression(
+ Expression::Math {
+ fun,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: args.get(3).copied(),
+ },
+ Span::default(),
+ )?,
+ mc @ (MacroCall::FindLsbUint | MacroCall::FindMsbUint) => {
+ let fun = match mc {
+ MacroCall::FindLsbUint => MathFunction::FindLsb,
+ MacroCall::FindMsbUint => MathFunction::FindMsb,
+ _ => unreachable!(),
+ };
+ let res = ctx.add_expression(
+ Expression::Math {
+ fun,
+ arg: args[0],
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ Span::default(),
+ )?;
+ ctx.add_expression(
+ Expression::As {
+ expr: res,
+ kind: Sk::Sint,
+ convert: Some(4),
+ },
+ Span::default(),
+ )?
+ }
+ MacroCall::BitfieldInsert => {
+ let conv_arg_2 = ctx.add_expression(
+ Expression::As {
+ expr: args[2],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ )?;
+ let conv_arg_3 = ctx.add_expression(
+ Expression::As {
+ expr: args[3],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ )?;
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::InsertBits,
+ arg: args[0],
+ arg1: Some(args[1]),
+ arg2: Some(conv_arg_2),
+ arg3: Some(conv_arg_3),
+ },
+ Span::default(),
+ )?
+ }
+ MacroCall::BitfieldExtract => {
+ let conv_arg_1 = ctx.add_expression(
+ Expression::As {
+ expr: args[1],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ )?;
+ let conv_arg_2 = ctx.add_expression(
+ Expression::As {
+ expr: args[2],
+ kind: Sk::Uint,
+ convert: Some(4),
+ },
+ Span::default(),
+ )?;
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::ExtractBits,
+ arg: args[0],
+ arg1: Some(conv_arg_1),
+ arg2: Some(conv_arg_2),
+ arg3: None,
+ },
+ Span::default(),
+ )?
+ }
+ MacroCall::Relational(fun) => ctx.add_expression(
+ Expression::Relational {
+ fun,
+ argument: args[0],
+ },
+ Span::default(),
+ )?,
+ MacroCall::Unary(op) => {
+ ctx.add_expression(Expression::Unary { op, expr: args[0] }, Span::default())?
+ }
+ MacroCall::Binary(op) => ctx.add_expression(
+ Expression::Binary {
+ op,
+ left: args[0],
+ right: args[1],
+ },
+ Span::default(),
+ )?,
+ MacroCall::Mod(size) => {
+ ctx.implicit_splat(&mut args[1], meta, size)?;
+
+ // x - y * floor(x / y)
+
+ let div = ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Divide,
+ left: args[0],
+ right: args[1],
+ },
+ Span::default(),
+ )?;
+ let floor = ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::Floor,
+ arg: div,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ Span::default(),
+ )?;
+ let mult = ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Multiply,
+ left: floor,
+ right: args[1],
+ },
+ Span::default(),
+ )?;
+ ctx.add_expression(
+ Expression::Binary {
+ op: BinaryOperator::Subtract,
+ left: args[0],
+ right: mult,
+ },
+ Span::default(),
+ )?
+ }
+ MacroCall::Splatted(fun, size, i) => {
+ ctx.implicit_splat(&mut args[i], meta, size)?;
+
+ ctx.add_expression(
+ Expression::Math {
+ fun,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: args.get(3).copied(),
+ },
+ Span::default(),
+ )?
+ }
+ MacroCall::MixBoolean => ctx.add_expression(
+ Expression::Select {
+ condition: args[2],
+ accept: args[1],
+ reject: args[0],
+ },
+ Span::default(),
+ )?,
+ MacroCall::Clamp(size) => {
+ ctx.implicit_splat(&mut args[1], meta, size)?;
+ ctx.implicit_splat(&mut args[2], meta, size)?;
+
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::Clamp,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: args.get(3).copied(),
+ },
+ Span::default(),
+ )?
+ }
+ MacroCall::BitCast(kind) => ctx.add_expression(
+ Expression::As {
+ expr: args[0],
+ kind,
+ convert: None,
+ },
+ Span::default(),
+ )?,
+ MacroCall::Derivate(axis, ctrl) => ctx.add_expression(
+ Expression::Derivative {
+ axis,
+ ctrl,
+ expr: args[0],
+ },
+ Span::default(),
+ )?,
+ MacroCall::Barrier => {
+ ctx.emit_restart();
+ ctx.body
+ .push(crate::Statement::Barrier(crate::Barrier::all()), meta);
+ return Ok(None);
+ }
+ MacroCall::SmoothStep { splatted } => {
+ ctx.implicit_splat(&mut args[0], meta, splatted)?;
+ ctx.implicit_splat(&mut args[1], meta, splatted)?;
+
+ ctx.add_expression(
+ Expression::Math {
+ fun: MathFunction::SmoothStep,
+ arg: args[0],
+ arg1: args.get(1).copied(),
+ arg2: args.get(2).copied(),
+ arg3: None,
+ },
+ Span::default(),
+ )?
+ }
+ }))
+ }
+}
+
+fn texture_call(
+ ctx: &mut Context,
+ image: Handle<Expression>,
+ level: SampleLevel,
+ comps: CoordComponents,
+ offset: Option<Handle<Expression>>,
+ meta: Span,
+) -> Result<Handle<Expression>> {
+ if let Some(sampler) = ctx.samplers.get(&image).copied() {
+ let mut array_index = comps.array_index;
+
+ if let Some(ref mut array_index_expr) = array_index {
+ ctx.conversion(array_index_expr, meta, Scalar::I32)?;
+ }
+
+ Ok(ctx.add_expression(
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather: None, //TODO
+ coordinate: comps.coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref: comps.depth_ref,
+ },
+ meta,
+ )?)
+ } else {
+ Err(Error {
+ kind: ErrorKind::SemanticError("Bad call".into()),
+ meta,
+ })
+ }
+}
+
+/// Helper struct for texture calls with the separate components from the vector argument
+///
+/// Obtained by calling [`coordinate_components`](Frontend::coordinate_components)
+#[derive(Debug)]
+struct CoordComponents {
+ coordinate: Handle<Expression>,
+ depth_ref: Option<Handle<Expression>>,
+ array_index: Option<Handle<Expression>>,
+ used_extra: bool,
+}
+
+impl Frontend {
+ /// Helper function for texture calls, splits the vector argument into it's components
+ fn coordinate_components(
+ &mut self,
+ ctx: &mut Context,
+ image: Handle<Expression>,
+ coord: Handle<Expression>,
+ extra: Option<Handle<Expression>>,
+ meta: Span,
+ ) -> Result<CoordComponents> {
+ if let TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } = *ctx.resolve_type(image, meta)?
+ {
+ let image_size = match dim {
+ Dim::D1 => None,
+ Dim::D2 => Some(VectorSize::Bi),
+ Dim::D3 => Some(VectorSize::Tri),
+ Dim::Cube => Some(VectorSize::Tri),
+ };
+ let coord_size = match *ctx.resolve_type(coord, meta)? {
+ TypeInner::Vector { size, .. } => Some(size),
+ _ => None,
+ };
+ let (shadow, storage) = match class {
+ ImageClass::Depth { .. } => (true, false),
+ ImageClass::Storage { .. } => (false, true),
+ ImageClass::Sampled { .. } => (false, false),
+ };
+
+ let coordinate = match (image_size, coord_size) {
+ (Some(size), Some(coord_s)) if size != coord_s => {
+ ctx.vector_resize(size, coord, Span::default())?
+ }
+ (None, Some(_)) => ctx.add_expression(
+ Expression::AccessIndex {
+ base: coord,
+ index: 0,
+ },
+ Span::default(),
+ )?,
+ _ => coord,
+ };
+
+ let mut coord_index = image_size.map_or(1, |s| s as u32);
+
+ let array_index = if arrayed && !(storage && dim == Dim::Cube) {
+ let index = coord_index;
+ coord_index += 1;
+
+ Some(ctx.add_expression(
+ Expression::AccessIndex { base: coord, index },
+ Span::default(),
+ )?)
+ } else {
+ None
+ };
+ let mut used_extra = false;
+ let depth_ref = match shadow {
+ true => {
+ let index = coord_index;
+
+ if index == 4 {
+ used_extra = true;
+ extra
+ } else {
+ Some(ctx.add_expression(
+ Expression::AccessIndex { base: coord, index },
+ Span::default(),
+ )?)
+ }
+ }
+ false => None,
+ };
+
+ Ok(CoordComponents {
+ coordinate,
+ depth_ref,
+ array_index,
+ used_extra,
+ })
+ } else {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Type is not an image".into()),
+ meta,
+ });
+
+ Ok(CoordComponents {
+ coordinate: coord,
+ depth_ref: None,
+ array_index: None,
+ used_extra: false,
+ })
+ }
+ }
+}
+
+/// Helper function to cast a expression holding a sampled image to a
+/// depth image.
+pub fn sampled_to_depth(
+ ctx: &mut Context,
+ image: Handle<Expression>,
+ meta: Span,
+ errors: &mut Vec<Error>,
+) {
+ // Get the a mutable type handle of the underlying image storage
+ let ty = match ctx[image] {
+ Expression::GlobalVariable(handle) => &mut ctx.module.global_variables.get_mut(handle).ty,
+ Expression::FunctionArgument(i) => {
+ // Mark the function argument as carrying a depth texture
+ ctx.parameters_info[i as usize].depth = true;
+ // NOTE: We need to later also change the parameter type
+ &mut ctx.arguments[i as usize].ty
+ }
+ _ => {
+ // Only globals and function arguments are allowed to carry an image
+ return errors.push(Error {
+ kind: ErrorKind::SemanticError("Not a valid texture expression".into()),
+ meta,
+ });
+ }
+ };
+
+ match ctx.module.types[*ty].inner {
+ // Update the image class to depth in case it already isn't
+ TypeInner::Image {
+ class,
+ dim,
+ arrayed,
+ } => match class {
+ ImageClass::Sampled { multi, .. } => {
+ *ty = ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class: ImageClass::Depth { multi },
+ },
+ },
+ Span::default(),
+ )
+ }
+ ImageClass::Depth { .. } => {}
+ // Other image classes aren't allowed to be transformed to depth
+ ImageClass::Storage { .. } => errors.push(Error {
+ kind: ErrorKind::SemanticError("Not a texture".into()),
+ meta,
+ }),
+ },
+ _ => errors.push(Error {
+ kind: ErrorKind::SemanticError("Not a texture".into()),
+ meta,
+ }),
+ };
+
+ // Copy the handle to allow borrowing the `ctx` again
+ let ty = *ty;
+
+ // If the image was passed through a function argument we also need to change
+ // the corresponding parameter
+ if let Expression::FunctionArgument(i) = ctx[image] {
+ ctx.parameters[i as usize] = ty;
+ }
+}
+
+bitflags::bitflags! {
+ /// Influences the operation `texture_args_generator`
+ struct TextureArgsOptions: u32 {
+ /// Generates multisampled variants of images
+ const MULTI = 1 << 0;
+ /// Generates shadow variants of images
+ const SHADOW = 1 << 1;
+ /// Generates standard images
+ const STANDARD = 1 << 2;
+ /// Generates cube arrayed images
+ const CUBE_ARRAY = 1 << 3;
+ /// Generates cube arrayed images
+ const D2_MULTI_ARRAY = 1 << 4;
+ }
+}
+
+impl From<BuiltinVariations> for TextureArgsOptions {
+ fn from(variations: BuiltinVariations) -> Self {
+ let mut options = TextureArgsOptions::empty();
+ if variations.contains(BuiltinVariations::STANDARD) {
+ options |= TextureArgsOptions::STANDARD
+ }
+ if variations.contains(BuiltinVariations::CUBE_TEXTURES_ARRAY) {
+ options |= TextureArgsOptions::CUBE_ARRAY
+ }
+ if variations.contains(BuiltinVariations::D2_MULTI_TEXTURES_ARRAY) {
+ options |= TextureArgsOptions::D2_MULTI_ARRAY
+ }
+ options
+ }
+}
+
+/// Helper function to generate the image components for texture/image builtins
+///
+/// Calls the passed function `f` with:
+/// ```text
+/// f(ScalarKind, ImageDimension, arrayed, multi, shadow)
+/// ```
+///
+/// `options` controls extra image variants generation like multisampling and depth,
+/// see the struct documentation
+fn texture_args_generator(
+ options: TextureArgsOptions,
+ mut f: impl FnMut(crate::ScalarKind, Dim, bool, bool, bool),
+) {
+ for kind in [Sk::Float, Sk::Uint, Sk::Sint].iter().copied() {
+ for dim in [Dim::D1, Dim::D2, Dim::D3, Dim::Cube].iter().copied() {
+ for arrayed in [false, true].iter().copied() {
+ if dim == Dim::Cube && arrayed {
+ if !options.contains(TextureArgsOptions::CUBE_ARRAY) {
+ continue;
+ }
+ } else if Dim::D2 == dim
+ && options.contains(TextureArgsOptions::MULTI)
+ && arrayed
+ && options.contains(TextureArgsOptions::D2_MULTI_ARRAY)
+ {
+ // multisampling for sampler2DMSArray
+ f(kind, dim, arrayed, true, false);
+ } else if !options.contains(TextureArgsOptions::STANDARD) {
+ continue;
+ }
+
+ f(kind, dim, arrayed, false, false);
+
+ // 3D images can't be neither arrayed nor shadow
+ // so we break out early, this way arrayed will always
+ // be false and we won't hit the shadow branch
+ if let Dim::D3 = dim {
+ break;
+ }
+
+ if Dim::D2 == dim && options.contains(TextureArgsOptions::MULTI) && !arrayed {
+ // multisampling
+ f(kind, dim, arrayed, true, false);
+ }
+
+ if Sk::Float == kind && options.contains(TextureArgsOptions::SHADOW) {
+ // shadow
+ f(kind, dim, arrayed, false, true);
+ }
+ }
+ }
+ }
+}
+
+/// Helper functions used to convert from a image dimension into a integer representing the
+/// number of components needed for the coordinates vector (1 means scalar instead of vector)
+const fn image_dims_to_coords_size(dim: Dim) -> usize {
+ match dim {
+ Dim::D1 => 1,
+ Dim::D2 => 2,
+ _ => 3,
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/context.rs b/third_party/rust/naga/src/front/glsl/context.rs
new file mode 100644
index 0000000000..f26c57965d
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/context.rs
@@ -0,0 +1,1506 @@
+use super::{
+ ast::{
+ GlobalLookup, GlobalLookupKind, HirExpr, HirExprKind, ParameterInfo, ParameterQualifier,
+ VariableReference,
+ },
+ error::{Error, ErrorKind},
+ types::{scalar_components, type_power},
+ Frontend, Result,
+};
+use crate::{
+ front::Typifier, proc::Emitter, AddressSpace, Arena, BinaryOperator, Block, Expression,
+ FastHashMap, FunctionArgument, Handle, Literal, LocalVariable, RelationalFunction, Scalar,
+ Span, Statement, Type, TypeInner, VectorSize,
+};
+use std::ops::Index;
+
+/// The position at which an expression is, used while lowering
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub enum ExprPos {
+ /// The expression is in the left hand side of an assignment
+ Lhs,
+ /// The expression is in the right hand side of an assignment
+ Rhs,
+ /// The expression is an array being indexed, needed to allow constant
+ /// arrays to be dynamically indexed
+ AccessBase {
+ /// The index is a constant
+ constant_index: bool,
+ },
+}
+
+impl ExprPos {
+ /// Returns an lhs position if the current position is lhs otherwise AccessBase
+ const fn maybe_access_base(&self, constant_index: bool) -> Self {
+ match *self {
+ ExprPos::Lhs
+ | ExprPos::AccessBase {
+ constant_index: false,
+ } => *self,
+ _ => ExprPos::AccessBase { constant_index },
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Context<'a> {
+ pub expressions: Arena<Expression>,
+ pub locals: Arena<LocalVariable>,
+
+ /// The [`FunctionArgument`]s for the final [`crate::Function`].
+ ///
+ /// Parameters with the `out` and `inout` qualifiers have [`Pointer`] types
+ /// here. For example, an `inout vec2 a` argument would be a [`Pointer`] to
+ /// a [`Vector`].
+ ///
+ /// [`Pointer`]: crate::TypeInner::Pointer
+ /// [`Vector`]: crate::TypeInner::Vector
+ pub arguments: Vec<FunctionArgument>,
+
+ /// The parameter types given in the source code.
+ ///
+ /// The `out` and `inout` qualifiers don't affect the types that appear
+ /// here. For example, an `inout vec2 a` argument would simply be a
+ /// [`Vector`], not a pointer to one.
+ ///
+ /// [`Vector`]: crate::TypeInner::Vector
+ pub parameters: Vec<Handle<Type>>,
+ pub parameters_info: Vec<ParameterInfo>,
+
+ pub symbol_table: crate::front::SymbolTable<String, VariableReference>,
+ pub samplers: FastHashMap<Handle<Expression>, Handle<Expression>>,
+
+ pub const_typifier: Typifier,
+ pub typifier: Typifier,
+ emitter: Emitter,
+ stmt_ctx: Option<StmtContext>,
+ pub body: Block,
+ pub module: &'a mut crate::Module,
+ pub is_const: bool,
+ /// Tracks the constness of `Expression`s residing in `self.expressions`
+ pub expression_constness: crate::proc::ExpressionConstnessTracker,
+}
+
+impl<'a> Context<'a> {
+ pub fn new(frontend: &Frontend, module: &'a mut crate::Module, is_const: bool) -> Result<Self> {
+ let mut this = Context {
+ expressions: Arena::new(),
+ locals: Arena::new(),
+ arguments: Vec::new(),
+
+ parameters: Vec::new(),
+ parameters_info: Vec::new(),
+
+ symbol_table: crate::front::SymbolTable::default(),
+ samplers: FastHashMap::default(),
+
+ const_typifier: Typifier::new(),
+ typifier: Typifier::new(),
+ emitter: Emitter::default(),
+ stmt_ctx: Some(StmtContext::new()),
+ body: Block::new(),
+ module,
+ is_const: false,
+ expression_constness: crate::proc::ExpressionConstnessTracker::new(),
+ };
+
+ this.emit_start();
+
+ for &(ref name, lookup) in frontend.global_variables.iter() {
+ this.add_global(name, lookup)?
+ }
+ this.is_const = is_const;
+
+ Ok(this)
+ }
+
+ pub fn new_body<F>(&mut self, cb: F) -> Result<Block>
+ where
+ F: FnOnce(&mut Self) -> Result<()>,
+ {
+ self.new_body_with_ret(cb).map(|(b, _)| b)
+ }
+
+ pub fn new_body_with_ret<F, R>(&mut self, cb: F) -> Result<(Block, R)>
+ where
+ F: FnOnce(&mut Self) -> Result<R>,
+ {
+ self.emit_restart();
+ let old_body = std::mem::replace(&mut self.body, Block::new());
+ let res = cb(self);
+ self.emit_restart();
+ let new_body = std::mem::replace(&mut self.body, old_body);
+ res.map(|r| (new_body, r))
+ }
+
+ pub fn with_body<F>(&mut self, body: Block, cb: F) -> Result<Block>
+ where
+ F: FnOnce(&mut Self) -> Result<()>,
+ {
+ self.emit_restart();
+ let old_body = std::mem::replace(&mut self.body, body);
+ let res = cb(self);
+ self.emit_restart();
+ let body = std::mem::replace(&mut self.body, old_body);
+ res.map(|_| body)
+ }
+
+ pub fn add_global(
+ &mut self,
+ name: &str,
+ GlobalLookup {
+ kind,
+ entry_arg,
+ mutable,
+ }: GlobalLookup,
+ ) -> Result<()> {
+ let (expr, load, constant) = match kind {
+ GlobalLookupKind::Variable(v) => {
+ let span = self.module.global_variables.get_span(v);
+ (
+ self.add_expression(Expression::GlobalVariable(v), span)?,
+ self.module.global_variables[v].space != AddressSpace::Handle,
+ None,
+ )
+ }
+ GlobalLookupKind::BlockSelect(handle, index) => {
+ let span = self.module.global_variables.get_span(handle);
+ let base = self.add_expression(Expression::GlobalVariable(handle), span)?;
+ let expr = self.add_expression(Expression::AccessIndex { base, index }, span)?;
+
+ (
+ expr,
+ {
+ let ty = self.module.global_variables[handle].ty;
+
+ match self.module.types[ty].inner {
+ TypeInner::Struct { ref members, .. } => {
+ if let TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = self.module.types[members[index as usize].ty].inner
+ {
+ false
+ } else {
+ true
+ }
+ }
+ _ => true,
+ }
+ },
+ None,
+ )
+ }
+ GlobalLookupKind::Constant(v, ty) => {
+ let span = self.module.constants.get_span(v);
+ (
+ self.add_expression(Expression::Constant(v), span)?,
+ false,
+ Some((v, ty)),
+ )
+ }
+ };
+
+ let var = VariableReference {
+ expr,
+ load,
+ mutable,
+ constant,
+ entry_arg,
+ };
+
+ self.symbol_table.add(name.into(), var);
+
+ Ok(())
+ }
+
+ /// Starts the expression emitter
+ ///
+ /// # Panics
+ ///
+ /// - If called twice in a row without calling [`emit_end`][Self::emit_end].
+ #[inline]
+ pub fn emit_start(&mut self) {
+ self.emitter.start(&self.expressions)
+ }
+
+ /// Emits all the expressions captured by the emitter to the current body
+ ///
+ /// # Panics
+ ///
+ /// - If called before calling [`emit_start`].
+ /// - If called twice in a row without calling [`emit_start`].
+ ///
+ /// [`emit_start`]: Self::emit_start
+ pub fn emit_end(&mut self) {
+ self.body.extend(self.emitter.finish(&self.expressions))
+ }
+
+ /// Emits all the expressions captured by the emitter to the current body
+ /// and starts the emitter again
+ ///
+ /// # Panics
+ ///
+ /// - If called before calling [`emit_start`][Self::emit_start].
+ pub fn emit_restart(&mut self) {
+ self.emit_end();
+ self.emit_start()
+ }
+
+ pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
+ let mut eval = if self.is_const {
+ crate::proc::ConstantEvaluator::for_glsl_module(self.module)
+ } else {
+ crate::proc::ConstantEvaluator::for_glsl_function(
+ self.module,
+ &mut self.expressions,
+ &mut self.expression_constness,
+ &mut self.emitter,
+ &mut self.body,
+ )
+ };
+
+ let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error {
+ kind: e.into(),
+ meta,
+ });
+
+ match res {
+ Ok(expr) => Ok(expr),
+ Err(e) => {
+ if self.is_const {
+ Err(e)
+ } else {
+ let needs_pre_emit = expr.needs_pre_emit();
+ if needs_pre_emit {
+ self.body.extend(self.emitter.finish(&self.expressions));
+ }
+ let h = self.expressions.append(expr, meta);
+ if needs_pre_emit {
+ self.emitter.start(&self.expressions);
+ }
+ Ok(h)
+ }
+ }
+ }
+ }
+
+ /// Add variable to current scope
+ ///
+ /// Returns a variable if a variable with the same name was already defined,
+ /// otherwise returns `None`
+ pub fn add_local_var(
+ &mut self,
+ name: String,
+ expr: Handle<Expression>,
+ mutable: bool,
+ ) -> Option<VariableReference> {
+ let var = VariableReference {
+ expr,
+ load: true,
+ mutable,
+ constant: None,
+ entry_arg: None,
+ };
+
+ self.symbol_table.add(name, var)
+ }
+
+ /// Add function argument to current scope
+ pub fn add_function_arg(
+ &mut self,
+ name_meta: Option<(String, Span)>,
+ ty: Handle<Type>,
+ qualifier: ParameterQualifier,
+ ) -> Result<()> {
+ let index = self.arguments.len();
+ let mut arg = FunctionArgument {
+ name: name_meta.as_ref().map(|&(ref name, _)| name.clone()),
+ ty,
+ binding: None,
+ };
+ self.parameters.push(ty);
+
+ let opaque = match self.module.types[ty].inner {
+ TypeInner::Image { .. } | TypeInner::Sampler { .. } => true,
+ _ => false,
+ };
+
+ if qualifier.is_lhs() {
+ let span = self.module.types.get_span(arg.ty);
+ arg.ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Pointer {
+ base: arg.ty,
+ space: AddressSpace::Function,
+ },
+ },
+ span,
+ )
+ }
+
+ self.arguments.push(arg);
+
+ self.parameters_info.push(ParameterInfo {
+ qualifier,
+ depth: false,
+ });
+
+ if let Some((name, meta)) = name_meta {
+ let expr = self.add_expression(Expression::FunctionArgument(index as u32), meta)?;
+ let mutable = qualifier != ParameterQualifier::Const && !opaque;
+ let load = qualifier.is_lhs();
+
+ let var = if mutable && !load {
+ let handle = self.locals.append(
+ LocalVariable {
+ name: Some(name.clone()),
+ ty,
+ init: None,
+ },
+ meta,
+ );
+ let local_expr = self.add_expression(Expression::LocalVariable(handle), meta)?;
+
+ self.emit_restart();
+
+ self.body.push(
+ Statement::Store {
+ pointer: local_expr,
+ value: expr,
+ },
+ meta,
+ );
+
+ VariableReference {
+ expr: local_expr,
+ load: true,
+ mutable,
+ constant: None,
+ entry_arg: None,
+ }
+ } else {
+ VariableReference {
+ expr,
+ load,
+ mutable,
+ constant: None,
+ entry_arg: None,
+ }
+ };
+
+ self.symbol_table.add(name, var);
+ }
+
+ Ok(())
+ }
+
+ /// Returns a [`StmtContext`] to be used in parsing and lowering
+ ///
+ /// # Panics
+ ///
+ /// - If more than one [`StmtContext`] are active at the same time or if the
+ /// previous call didn't use it in lowering.
+ #[must_use]
+ pub fn stmt_ctx(&mut self) -> StmtContext {
+ self.stmt_ctx.take().unwrap()
+ }
+
+ /// Lowers a [`HirExpr`] which might produce a [`Expression`].
+ ///
+ /// consumes a [`StmtContext`] returning it to the context so that it can be
+ /// used again later.
+ pub fn lower(
+ &mut self,
+ mut stmt: StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ ) -> Result<(Option<Handle<Expression>>, Span)> {
+ let res = self.lower_inner(&stmt, frontend, expr, pos);
+
+ stmt.hir_exprs.clear();
+ self.stmt_ctx = Some(stmt);
+
+ res
+ }
+
+ /// Similar to [`lower`](Self::lower) but returns an error if the expression
+ /// returns void (ie. doesn't produce a [`Expression`]).
+ ///
+ /// consumes a [`StmtContext`] returning it to the context so that it can be
+ /// used again later.
+ pub fn lower_expect(
+ &mut self,
+ mut stmt: StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ ) -> Result<(Handle<Expression>, Span)> {
+ let res = self.lower_expect_inner(&stmt, frontend, expr, pos);
+
+ stmt.hir_exprs.clear();
+ self.stmt_ctx = Some(stmt);
+
+ res
+ }
+
+ /// internal implementation of [`lower_expect`](Self::lower_expect)
+ ///
+ /// this method is only public because it's used in
+ /// [`function_call`](Frontend::function_call), unless you know what
+ /// you're doing use [`lower_expect`](Self::lower_expect)
+ pub fn lower_expect_inner(
+ &mut self,
+ stmt: &StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ ) -> Result<(Handle<Expression>, Span)> {
+ let (maybe_expr, meta) = self.lower_inner(stmt, frontend, expr, pos)?;
+
+ let expr = match maybe_expr {
+ Some(e) => e,
+ None => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError("Expression returns void".into()),
+ meta,
+ })
+ }
+ };
+
+ Ok((expr, meta))
+ }
+
+ fn lower_store(
+ &mut self,
+ pointer: Handle<Expression>,
+ value: Handle<Expression>,
+ meta: Span,
+ ) -> Result<()> {
+ if let Expression::Swizzle {
+ size,
+ mut vector,
+ pattern,
+ } = self.expressions[pointer]
+ {
+ // Stores to swizzled values are not directly supported,
+ // lower them as series of per-component stores.
+ let size = match size {
+ VectorSize::Bi => 2,
+ VectorSize::Tri => 3,
+ VectorSize::Quad => 4,
+ };
+
+ if let Expression::Load { pointer } = self.expressions[vector] {
+ vector = pointer;
+ }
+
+ #[allow(clippy::needless_range_loop)]
+ for index in 0..size {
+ let dst = self.add_expression(
+ Expression::AccessIndex {
+ base: vector,
+ index: pattern[index].index(),
+ },
+ meta,
+ )?;
+ let src = self.add_expression(
+ Expression::AccessIndex {
+ base: value,
+ index: index as u32,
+ },
+ meta,
+ )?;
+
+ self.emit_restart();
+
+ self.body.push(
+ Statement::Store {
+ pointer: dst,
+ value: src,
+ },
+ meta,
+ );
+ }
+ } else {
+ self.emit_restart();
+
+ self.body.push(Statement::Store { pointer, value }, meta);
+ }
+
+ Ok(())
+ }
+
+ /// Internal implementation of [`lower`](Self::lower)
+ fn lower_inner(
+ &mut self,
+ stmt: &StmtContext,
+ frontend: &mut Frontend,
+ expr: Handle<HirExpr>,
+ pos: ExprPos,
+ ) -> Result<(Option<Handle<Expression>>, Span)> {
+ let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];
+
+ log::debug!("Lowering {:?} (kind {:?}, pos {:?})", expr, kind, pos);
+
+ let handle = match *kind {
+ HirExprKind::Access { base, index } => {
+ let (index, _) = self.lower_expect_inner(stmt, frontend, index, ExprPos::Rhs)?;
+ let maybe_constant_index = match pos {
+ // Don't try to generate `AccessIndex` if in a LHS position, since it
+ // wouldn't produce a pointer.
+ ExprPos::Lhs => None,
+ _ => self
+ .module
+ .to_ctx()
+ .eval_expr_to_u32_from(index, &self.expressions)
+ .ok(),
+ };
+
+ let base = self
+ .lower_expect_inner(
+ stmt,
+ frontend,
+ base,
+ pos.maybe_access_base(maybe_constant_index.is_some()),
+ )?
+ .0;
+
+ let pointer = maybe_constant_index
+ .map(|index| self.add_expression(Expression::AccessIndex { base, index }, meta))
+ .unwrap_or_else(|| {
+ self.add_expression(Expression::Access { base, index }, meta)
+ })?;
+
+ if ExprPos::Rhs == pos {
+ let resolved = self.resolve_type(pointer, meta)?;
+ if resolved.pointer_space().is_some() {
+ return Ok((
+ Some(self.add_expression(Expression::Load { pointer }, meta)?),
+ meta,
+ ));
+ }
+ }
+
+ pointer
+ }
+ HirExprKind::Select { base, ref field } => {
+ let base = self.lower_expect_inner(stmt, frontend, base, pos)?.0;
+
+ frontend.field_selection(self, pos, base, field, meta)?
+ }
+ HirExprKind::Literal(literal) if pos != ExprPos::Lhs => {
+ self.add_expression(Expression::Literal(literal), meta)?
+ }
+ HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
+ let (mut left, left_meta) =
+ self.lower_expect_inner(stmt, frontend, left, ExprPos::Rhs)?;
+ let (mut right, right_meta) =
+ self.lower_expect_inner(stmt, frontend, right, ExprPos::Rhs)?;
+
+ match op {
+ BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => {
+ self.implicit_conversion(&mut right, right_meta, Scalar::U32)?
+ }
+ _ => self
+ .binary_implicit_conversion(&mut left, left_meta, &mut right, right_meta)?,
+ }
+
+ self.typifier_grow(left, left_meta)?;
+ self.typifier_grow(right, right_meta)?;
+
+ let left_inner = self.get_type(left);
+ let right_inner = self.get_type(right);
+
+ match (left_inner, right_inner) {
+ (
+ &TypeInner::Matrix {
+ columns: left_columns,
+ rows: left_rows,
+ scalar: left_scalar,
+ },
+ &TypeInner::Matrix {
+ columns: right_columns,
+ rows: right_rows,
+ scalar: right_scalar,
+ },
+ ) => {
+ let dimensions_ok = if op == BinaryOperator::Multiply {
+ left_columns == right_rows
+ } else {
+ left_columns == right_columns && left_rows == right_rows
+ };
+
+ // Check that the two arguments have the same dimensions
+ if !dimensions_ok || left_scalar != right_scalar {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "Cannot apply operation to {left_inner:?} and {right_inner:?}"
+ )
+ .into(),
+ ),
+ meta,
+ })
+ }
+
+ match op {
+ BinaryOperator::Divide => {
+ // Naga IR doesn't support matrix division so we need to
+ // divide the columns individually and reassemble the matrix
+ let mut components = Vec::with_capacity(left_columns as usize);
+
+ for index in 0..left_columns as u32 {
+ // Get the column vectors
+ let left_vector = self.add_expression(
+ Expression::AccessIndex { base: left, index },
+ meta,
+ )?;
+ let right_vector = self.add_expression(
+ Expression::AccessIndex { base: right, index },
+ meta,
+ )?;
+
+ // Divide the vectors
+ let column = self.add_expression(
+ Expression::Binary {
+ op,
+ left: left_vector,
+ right: right_vector,
+ },
+ meta,
+ )?;
+
+ components.push(column)
+ }
+
+ let ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns: left_columns,
+ rows: left_rows,
+ scalar: left_scalar,
+ },
+ },
+ Span::default(),
+ );
+
+ // Rebuild the matrix from the divided vectors
+ self.add_expression(Expression::Compose { ty, components }, meta)?
+ }
+ BinaryOperator::Equal | BinaryOperator::NotEqual => {
+ // Naga IR doesn't support matrix comparisons so we need to
+ // compare the columns individually and then fold them together
+ //
+ // The folding is done using a logical and for equality and
+ // a logical or for inequality
+ let equals = op == BinaryOperator::Equal;
+
+ let (op, combine, fun) = match equals {
+ true => (
+ BinaryOperator::Equal,
+ BinaryOperator::LogicalAnd,
+ RelationalFunction::All,
+ ),
+ false => (
+ BinaryOperator::NotEqual,
+ BinaryOperator::LogicalOr,
+ RelationalFunction::Any,
+ ),
+ };
+
+ let mut root = None;
+
+ for index in 0..left_columns as u32 {
+ // Get the column vectors
+ let left_vector = self.add_expression(
+ Expression::AccessIndex { base: left, index },
+ meta,
+ )?;
+ let right_vector = self.add_expression(
+ Expression::AccessIndex { base: right, index },
+ meta,
+ )?;
+
+ let argument = self.add_expression(
+ Expression::Binary {
+ op,
+ left: left_vector,
+ right: right_vector,
+ },
+ meta,
+ )?;
+
+ // The result of comparing two vectors is a boolean vector
+ // so use a relational function like all to get a single
+ // boolean value
+ let compare = self.add_expression(
+ Expression::Relational { fun, argument },
+ meta,
+ )?;
+
+ // Fold the result
+ root = Some(match root {
+ Some(right) => self.add_expression(
+ Expression::Binary {
+ op: combine,
+ left: compare,
+ right,
+ },
+ meta,
+ )?,
+ None => compare,
+ });
+ }
+
+ root.unwrap()
+ }
+ _ => {
+ self.add_expression(Expression::Binary { left, op, right }, meta)?
+ }
+ }
+ }
+ (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op {
+ BinaryOperator::Equal | BinaryOperator::NotEqual => {
+ let equals = op == BinaryOperator::Equal;
+
+ let (op, fun) = match equals {
+ true => (BinaryOperator::Equal, RelationalFunction::All),
+ false => (BinaryOperator::NotEqual, RelationalFunction::Any),
+ };
+
+ let argument =
+ self.add_expression(Expression::Binary { op, left, right }, meta)?;
+
+ self.add_expression(Expression::Relational { fun, argument }, meta)?
+ }
+ _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
+ },
+ (&TypeInner::Vector { size, .. }, &TypeInner::Scalar { .. }) => match op {
+ BinaryOperator::Add
+ | BinaryOperator::Subtract
+ | BinaryOperator::Divide
+ | BinaryOperator::And
+ | BinaryOperator::ExclusiveOr
+ | BinaryOperator::InclusiveOr
+ | BinaryOperator::ShiftLeft
+ | BinaryOperator::ShiftRight => {
+ let scalar_vector = self
+ .add_expression(Expression::Splat { size, value: right }, meta)?;
+
+ self.add_expression(
+ Expression::Binary {
+ op,
+ left,
+ right: scalar_vector,
+ },
+ meta,
+ )?
+ }
+ _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
+ },
+ (&TypeInner::Scalar { .. }, &TypeInner::Vector { size, .. }) => match op {
+ BinaryOperator::Add
+ | BinaryOperator::Subtract
+ | BinaryOperator::Divide
+ | BinaryOperator::And
+ | BinaryOperator::ExclusiveOr
+ | BinaryOperator::InclusiveOr => {
+ let scalar_vector =
+ self.add_expression(Expression::Splat { size, value: left }, meta)?;
+
+ self.add_expression(
+ Expression::Binary {
+ op,
+ left: scalar_vector,
+ right,
+ },
+ meta,
+ )?
+ }
+ _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
+ },
+ (
+ &TypeInner::Scalar(left_scalar),
+ &TypeInner::Matrix {
+ rows,
+ columns,
+ scalar: right_scalar,
+ },
+ ) => {
+ // Check that the two arguments have the same scalar type
+ if left_scalar != right_scalar {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "Cannot apply operation to {left_inner:?} and {right_inner:?}"
+ )
+ .into(),
+ ),
+ meta,
+ })
+ }
+
+ match op {
+ BinaryOperator::Divide
+ | BinaryOperator::Add
+ | BinaryOperator::Subtract => {
+ // Naga IR doesn't support all matrix by scalar operations so
+ // we need for some to turn the scalar into a vector by
+ // splatting it and then for each column vector apply the
+ // operation and finally reconstruct the matrix
+ let scalar_vector = self.add_expression(
+ Expression::Splat {
+ size: rows,
+ value: left,
+ },
+ meta,
+ )?;
+
+ let mut components = Vec::with_capacity(columns as usize);
+
+ for index in 0..columns as u32 {
+ // Get the column vector
+ let matrix_column = self.add_expression(
+ Expression::AccessIndex { base: right, index },
+ meta,
+ )?;
+
+ // Apply the operation to the splatted vector and
+ // the column vector
+ let column = self.add_expression(
+ Expression::Binary {
+ op,
+ left: scalar_vector,
+ right: matrix_column,
+ },
+ meta,
+ )?;
+
+ components.push(column)
+ }
+
+ let ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: left_scalar,
+ },
+ },
+ Span::default(),
+ );
+
+ // Rebuild the matrix from the operation result vectors
+ self.add_expression(Expression::Compose { ty, components }, meta)?
+ }
+ _ => {
+ self.add_expression(Expression::Binary { left, op, right }, meta)?
+ }
+ }
+ }
+ (
+ &TypeInner::Matrix {
+ rows,
+ columns,
+ scalar: left_scalar,
+ },
+ &TypeInner::Scalar(right_scalar),
+ ) => {
+ // Check that the two arguments have the same scalar type
+ if left_scalar != right_scalar {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "Cannot apply operation to {left_inner:?} and {right_inner:?}"
+ )
+ .into(),
+ ),
+ meta,
+ })
+ }
+
+ match op {
+ BinaryOperator::Divide
+ | BinaryOperator::Add
+ | BinaryOperator::Subtract => {
+ // Naga IR doesn't support all matrix by scalar operations so
+ // we need for some to turn the scalar into a vector by
+ // splatting it and then for each column vector apply the
+ // operation and finally reconstruct the matrix
+
+ let scalar_vector = self.add_expression(
+ Expression::Splat {
+ size: rows,
+ value: right,
+ },
+ meta,
+ )?;
+
+ let mut components = Vec::with_capacity(columns as usize);
+
+ for index in 0..columns as u32 {
+ // Get the column vector
+ let matrix_column = self.add_expression(
+ Expression::AccessIndex { base: left, index },
+ meta,
+ )?;
+
+ // Apply the operation to the splatted vector and
+ // the column vector
+ let column = self.add_expression(
+ Expression::Binary {
+ op,
+ left: matrix_column,
+ right: scalar_vector,
+ },
+ meta,
+ )?;
+
+ components.push(column)
+ }
+
+ let ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: left_scalar,
+ },
+ },
+ Span::default(),
+ );
+
+ // Rebuild the matrix from the operation result vectors
+ self.add_expression(Expression::Compose { ty, components }, meta)?
+ }
+ _ => {
+ self.add_expression(Expression::Binary { left, op, right }, meta)?
+ }
+ }
+ }
+ _ => self.add_expression(Expression::Binary { left, op, right }, meta)?,
+ }
+ }
+ HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => {
+ let expr = self
+ .lower_expect_inner(stmt, frontend, expr, ExprPos::Rhs)?
+ .0;
+
+ self.add_expression(Expression::Unary { op, expr }, meta)?
+ }
+ HirExprKind::Variable(ref var) => match pos {
+ ExprPos::Lhs => {
+ if !var.mutable {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Variable cannot be used in LHS position".into(),
+ ),
+ meta,
+ })
+ }
+
+ var.expr
+ }
+ ExprPos::AccessBase { constant_index } => {
+ // If the index isn't constant all accesses backed by a constant base need
+ // to be done through a proxy local variable, since constants have a non
+ // pointer type which is required for dynamic indexing
+ if !constant_index {
+ if let Some((constant, ty)) = var.constant {
+ let init = self
+ .add_expression(Expression::Constant(constant), Span::default())?;
+ let local = self.locals.append(
+ LocalVariable {
+ name: None,
+ ty,
+ init: Some(init),
+ },
+ Span::default(),
+ );
+
+ self.add_expression(Expression::LocalVariable(local), Span::default())?
+ } else {
+ var.expr
+ }
+ } else {
+ var.expr
+ }
+ }
+ _ if var.load => {
+ self.add_expression(Expression::Load { pointer: var.expr }, meta)?
+ }
+ ExprPos::Rhs => {
+ if let Some((constant, _)) = self.is_const.then_some(var.constant).flatten() {
+ self.add_expression(Expression::Constant(constant), meta)?
+ } else {
+ var.expr
+ }
+ }
+ },
+ HirExprKind::Call(ref call) if pos != ExprPos::Lhs => {
+ let maybe_expr = frontend.function_or_constructor_call(
+ self,
+ stmt,
+ call.kind.clone(),
+ &call.args,
+ meta,
+ )?;
+ return Ok((maybe_expr, meta));
+ }
+ // `HirExprKind::Conditional` represents the ternary operator in glsl (`:?`)
+ //
+ // The ternary operator is defined to only evaluate one of the two possible
+ // expressions which means that it's behavior is that of an `if` statement,
+ // and it's merely syntactic sugar for it.
+ HirExprKind::Conditional {
+ condition,
+ accept,
+ reject,
+ } if ExprPos::Lhs != pos => {
+ // Given an expression `a ? b : c`, we need to produce a Naga
+ // statement roughly like:
+ //
+ // var temp;
+ // if a {
+ // temp = convert(b);
+ // } else {
+ // temp = convert(c);
+ // }
+ //
+ // where `convert` stands for type conversions to bring `b` and `c` to
+ // the same type, and then use `temp` to represent the value of the whole
+ // conditional expression in subsequent code.
+
+ // Lower the condition first to the current bodyy
+ let condition = self
+ .lower_expect_inner(stmt, frontend, condition, ExprPos::Rhs)?
+ .0;
+
+ let (mut accept_body, (mut accept, accept_meta)) =
+ self.new_body_with_ret(|ctx| {
+ // Lower the `true` branch
+ ctx.lower_expect_inner(stmt, frontend, accept, pos)
+ })?;
+
+ let (mut reject_body, (mut reject, reject_meta)) =
+ self.new_body_with_ret(|ctx| {
+ // Lower the `false` branch
+ ctx.lower_expect_inner(stmt, frontend, reject, pos)
+ })?;
+
+ // We need to do some custom implicit conversions since the two target expressions
+ // are in different bodies
+ if let (Some((accept_power, accept_scalar)), Some((reject_power, reject_scalar))) = (
+ // Get the components of both branches and calculate the type power
+ self.expr_scalar_components(accept, accept_meta)?
+ .and_then(|scalar| Some((type_power(scalar)?, scalar))),
+ self.expr_scalar_components(reject, reject_meta)?
+ .and_then(|scalar| Some((type_power(scalar)?, scalar))),
+ ) {
+ match accept_power.cmp(&reject_power) {
+ std::cmp::Ordering::Less => {
+ accept_body = self.with_body(accept_body, |ctx| {
+ ctx.conversion(&mut accept, accept_meta, reject_scalar)?;
+ Ok(())
+ })?;
+ }
+ std::cmp::Ordering::Equal => {}
+ std::cmp::Ordering::Greater => {
+ reject_body = self.with_body(reject_body, |ctx| {
+ ctx.conversion(&mut reject, reject_meta, accept_scalar)?;
+ Ok(())
+ })?;
+ }
+ }
+ }
+
+ // We need to get the type of the resulting expression to create the local,
+ // this must be done after implicit conversions to ensure both branches have
+ // the same type.
+ let ty = self.resolve_type_handle(accept, accept_meta)?;
+
+ // Add the local that will hold the result of our conditional
+ let local = self.locals.append(
+ LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ meta,
+ );
+
+ let local_expr = self.add_expression(Expression::LocalVariable(local), meta)?;
+
+ // Add to each the store to the result variable
+ accept_body.push(
+ Statement::Store {
+ pointer: local_expr,
+ value: accept,
+ },
+ accept_meta,
+ );
+ reject_body.push(
+ Statement::Store {
+ pointer: local_expr,
+ value: reject,
+ },
+ reject_meta,
+ );
+
+ // Finally add the `If` to the main body with the `condition` we lowered
+ // earlier and the branches we prepared.
+ self.body.push(
+ Statement::If {
+ condition,
+ accept: accept_body,
+ reject: reject_body,
+ },
+ meta,
+ );
+
+ // Note: `Expression::Load` must be emitted before it's used so make
+ // sure the emitter is active here.
+ self.add_expression(
+ Expression::Load {
+ pointer: local_expr,
+ },
+ meta,
+ )?
+ }
+ HirExprKind::Assign { tgt, value } if ExprPos::Lhs != pos => {
+ let (pointer, ptr_meta) =
+ self.lower_expect_inner(stmt, frontend, tgt, ExprPos::Lhs)?;
+ let (mut value, value_meta) =
+ self.lower_expect_inner(stmt, frontend, value, ExprPos::Rhs)?;
+
+ let ty = match *self.resolve_type(pointer, ptr_meta)? {
+ TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
+ ref ty => ty,
+ };
+
+ if let Some(scalar) = scalar_components(ty) {
+ self.implicit_conversion(&mut value, value_meta, scalar)?;
+ }
+
+ self.lower_store(pointer, value, meta)?;
+
+ value
+ }
+ HirExprKind::PrePostfix { op, postfix, expr } if ExprPos::Lhs != pos => {
+ let (pointer, _) = self.lower_expect_inner(stmt, frontend, expr, ExprPos::Lhs)?;
+ let left = if let Expression::Swizzle { .. } = self.expressions[pointer] {
+ pointer
+ } else {
+ self.add_expression(Expression::Load { pointer }, meta)?
+ };
+
+ let res = match *self.resolve_type(left, meta)? {
+ TypeInner::Scalar(scalar) => {
+ let ty = TypeInner::Scalar(scalar);
+ Literal::one(scalar).map(|i| (ty, i, None, None))
+ }
+ TypeInner::Vector { size, scalar } => {
+ let ty = TypeInner::Vector { size, scalar };
+ Literal::one(scalar).map(|i| (ty, i, Some(size), None))
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ let ty = TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ };
+ Literal::one(scalar).map(|i| (ty, i, Some(rows), Some(columns)))
+ }
+ _ => None,
+ };
+ let (ty_inner, literal, rows, columns) = match res {
+ Some(res) => res,
+ None => {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Increment/decrement only works on scalar/vector/matrix".into(),
+ ),
+ meta,
+ });
+ return Ok((Some(left), meta));
+ }
+ };
+
+ let mut right = self.add_expression(Expression::Literal(literal), meta)?;
+
+ // Glsl allows pre/postfixes operations on vectors and matrices, so if the
+ // target is either of them change the right side of the addition to be splatted
+ // to the same size as the target, furthermore if the target is a matrix
+ // use a composed matrix using the splatted value.
+ if let Some(size) = rows {
+ right = self.add_expression(Expression::Splat { size, value: right }, meta)?;
+
+ if let Some(cols) = columns {
+ let ty = self.module.types.insert(
+ Type {
+ name: None,
+ inner: ty_inner,
+ },
+ meta,
+ );
+
+ right = self.add_expression(
+ Expression::Compose {
+ ty,
+ components: std::iter::repeat(right).take(cols as usize).collect(),
+ },
+ meta,
+ )?;
+ }
+ }
+
+ let value = self.add_expression(Expression::Binary { op, left, right }, meta)?;
+
+ self.lower_store(pointer, value, meta)?;
+
+ if postfix {
+ left
+ } else {
+ value
+ }
+ }
+ HirExprKind::Method {
+ expr: object,
+ ref name,
+ ref args,
+ } if ExprPos::Lhs != pos => {
+ let args = args
+ .iter()
+ .map(|e| self.lower_expect_inner(stmt, frontend, *e, ExprPos::Rhs))
+ .collect::<Result<Vec<_>>>()?;
+ match name.as_ref() {
+ "length" => {
+ if !args.is_empty() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ ".length() doesn't take any arguments".into(),
+ ),
+ meta,
+ });
+ }
+ let lowered_array = self.lower_expect_inner(stmt, frontend, object, pos)?.0;
+ let array_type = self.resolve_type(lowered_array, meta)?;
+
+ match *array_type {
+ TypeInner::Array {
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ let mut array_length = self.add_expression(
+ Expression::Literal(Literal::U32(size.get())),
+ meta,
+ )?;
+ self.forced_conversion(&mut array_length, meta, Scalar::I32)?;
+ array_length
+ }
+ // let the error be handled in type checking if it's not a dynamic array
+ _ => {
+ let mut array_length = self
+ .add_expression(Expression::ArrayLength(lowered_array), meta)?;
+ self.conversion(&mut array_length, meta, Scalar::I32)?;
+ array_length
+ }
+ }
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("unknown method '{name}'").into(),
+ ),
+ meta,
+ });
+ }
+ }
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("{:?} cannot be in the left hand side", stmt.hir_exprs[expr])
+ .into(),
+ ),
+ meta,
+ })
+ }
+ };
+
+ log::trace!(
+ "Lowered {:?}\n\tKind = {:?}\n\tPos = {:?}\n\tResult = {:?}",
+ expr,
+ kind,
+ pos,
+ handle
+ );
+
+ Ok((Some(handle), meta))
+ }
+
+ pub fn expr_scalar_components(
+ &mut self,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<Option<Scalar>> {
+ let ty = self.resolve_type(expr, meta)?;
+ Ok(scalar_components(ty))
+ }
+
+ pub fn expr_power(&mut self, expr: Handle<Expression>, meta: Span) -> Result<Option<u32>> {
+ Ok(self
+ .expr_scalar_components(expr, meta)?
+ .and_then(type_power))
+ }
+
+ pub fn conversion(
+ &mut self,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ scalar: Scalar,
+ ) -> Result<()> {
+ *expr = self.add_expression(
+ Expression::As {
+ expr: *expr,
+ kind: scalar.kind,
+ convert: Some(scalar.width),
+ },
+ meta,
+ )?;
+
+ Ok(())
+ }
+
+ pub fn implicit_conversion(
+ &mut self,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ scalar: Scalar,
+ ) -> Result<()> {
+ if let (Some(tgt_power), Some(expr_power)) =
+ (type_power(scalar), self.expr_power(*expr, meta)?)
+ {
+ if tgt_power > expr_power {
+ self.conversion(expr, meta, scalar)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ pub fn forced_conversion(
+ &mut self,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ scalar: Scalar,
+ ) -> Result<()> {
+ if let Some(expr_scalar) = self.expr_scalar_components(*expr, meta)? {
+ if expr_scalar != scalar {
+ self.conversion(expr, meta, scalar)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ pub fn binary_implicit_conversion(
+ &mut self,
+ left: &mut Handle<Expression>,
+ left_meta: Span,
+ right: &mut Handle<Expression>,
+ right_meta: Span,
+ ) -> Result<()> {
+ let left_components = self.expr_scalar_components(*left, left_meta)?;
+ let right_components = self.expr_scalar_components(*right, right_meta)?;
+
+ if let (Some((left_power, left_scalar)), Some((right_power, right_scalar))) = (
+ left_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
+ right_components.and_then(|scalar| Some((type_power(scalar)?, scalar))),
+ ) {
+ match left_power.cmp(&right_power) {
+ std::cmp::Ordering::Less => {
+ self.conversion(left, left_meta, right_scalar)?;
+ }
+ std::cmp::Ordering::Equal => {}
+ std::cmp::Ordering::Greater => {
+ self.conversion(right, right_meta, left_scalar)?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ pub fn implicit_splat(
+ &mut self,
+ expr: &mut Handle<Expression>,
+ meta: Span,
+ vector_size: Option<VectorSize>,
+ ) -> Result<()> {
+ let expr_type = self.resolve_type(*expr, meta)?;
+
+ if let (&TypeInner::Scalar { .. }, Some(size)) = (expr_type, vector_size) {
+ *expr = self.add_expression(Expression::Splat { size, value: *expr }, meta)?
+ }
+
+ Ok(())
+ }
+
+ pub fn vector_resize(
+ &mut self,
+ size: VectorSize,
+ vector: Handle<Expression>,
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ self.add_expression(
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern: crate::SwizzleComponent::XYZW,
+ },
+ meta,
+ )
+ }
+}
+
+impl Index<Handle<Expression>> for Context<'_> {
+ type Output = Expression;
+
+ fn index(&self, index: Handle<Expression>) -> &Self::Output {
+ if self.is_const {
+ &self.module.const_expressions[index]
+ } else {
+ &self.expressions[index]
+ }
+ }
+}
+
+/// Helper struct passed when parsing expressions
+///
+/// This struct should only be obtained through [`stmt_ctx`](Context::stmt_ctx)
+/// and only one of these may be active at any time per context.
+#[derive(Debug)]
+pub struct StmtContext {
+ /// A arena of high level expressions which can be lowered through a
+ /// [`Context`] to Naga's [`Expression`]s
+ pub hir_exprs: Arena<HirExpr>,
+}
+
+impl StmtContext {
+ const fn new() -> Self {
+ StmtContext {
+ hir_exprs: Arena::new(),
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/error.rs b/third_party/rust/naga/src/front/glsl/error.rs
new file mode 100644
index 0000000000..bd16ee30bc
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/error.rs
@@ -0,0 +1,191 @@
+use super::token::TokenValue;
+use crate::{proc::ConstantEvaluatorError, Span};
+use codespan_reporting::diagnostic::{Diagnostic, Label};
+use codespan_reporting::files::SimpleFile;
+use codespan_reporting::term;
+use pp_rs::token::PreprocessorError;
+use std::borrow::Cow;
+use termcolor::{NoColor, WriteColor};
+use thiserror::Error;
+
+fn join_with_comma(list: &[ExpectedToken]) -> String {
+ let mut string = "".to_string();
+ for (i, val) in list.iter().enumerate() {
+ string.push_str(&val.to_string());
+ match i {
+ i if i == list.len() - 1 => {}
+ i if i == list.len() - 2 => string.push_str(" or "),
+ _ => string.push_str(", "),
+ }
+ }
+ string
+}
+
+/// One of the expected tokens returned in [`InvalidToken`](ErrorKind::InvalidToken).
+#[derive(Clone, Debug, PartialEq)]
+pub enum ExpectedToken {
+ /// A specific token was expected.
+ Token(TokenValue),
+ /// A type was expected.
+ TypeName,
+ /// An identifier was expected.
+ Identifier,
+ /// An integer literal was expected.
+ IntLiteral,
+ /// A float literal was expected.
+ FloatLiteral,
+ /// A boolean literal was expected.
+ BoolLiteral,
+ /// The end of file was expected.
+ Eof,
+}
+impl From<TokenValue> for ExpectedToken {
+ fn from(token: TokenValue) -> Self {
+ ExpectedToken::Token(token)
+ }
+}
+impl std::fmt::Display for ExpectedToken {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match *self {
+ ExpectedToken::Token(ref token) => write!(f, "{token:?}"),
+ ExpectedToken::TypeName => write!(f, "a type"),
+ ExpectedToken::Identifier => write!(f, "identifier"),
+ ExpectedToken::IntLiteral => write!(f, "integer literal"),
+ ExpectedToken::FloatLiteral => write!(f, "float literal"),
+ ExpectedToken::BoolLiteral => write!(f, "bool literal"),
+ ExpectedToken::Eof => write!(f, "end of file"),
+ }
+ }
+}
+
+/// Information about the cause of an error.
+#[derive(Clone, Debug, Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum ErrorKind {
+ /// Whilst parsing as encountered an unexpected EOF.
+ #[error("Unexpected end of file")]
+ EndOfFile,
+ /// The shader specified an unsupported or invalid profile.
+ #[error("Invalid profile: {0}")]
+ InvalidProfile(String),
+ /// The shader requested an unsupported or invalid version.
+ #[error("Invalid version: {0}")]
+ InvalidVersion(u64),
+ /// Whilst parsing an unexpected token was encountered.
+ ///
+ /// A list of expected tokens is also returned.
+ #[error("Expected {}, found {0:?}", join_with_comma(.1))]
+ InvalidToken(TokenValue, Vec<ExpectedToken>),
+ /// A specific feature is not yet implemented.
+ ///
+ /// To help prioritize work please open an issue in the github issue tracker
+ /// if none exist already or react to the already existing one.
+ #[error("Not implemented: {0}")]
+ NotImplemented(&'static str),
+ /// A reference to a variable that wasn't declared was used.
+ #[error("Unknown variable: {0}")]
+ UnknownVariable(String),
+ /// A reference to a type that wasn't declared was used.
+ #[error("Unknown type: {0}")]
+ UnknownType(String),
+ /// A reference to a non existent member of a type was made.
+ #[error("Unknown field: {0}")]
+ UnknownField(String),
+ /// An unknown layout qualifier was used.
+ ///
+ /// If the qualifier does exist please open an issue in the github issue tracker
+ /// if none exist already or react to the already existing one to help
+ /// prioritize work.
+ #[error("Unknown layout qualifier: {0}")]
+ UnknownLayoutQualifier(String),
+ /// Unsupported matrix of the form matCx2
+ ///
+ /// Our IR expects matrices of the form matCx2 to have a stride of 8 however
+ /// matrices in the std140 layout have a stride of at least 16
+ #[error("unsupported matrix of the form matCx2 in std140 block layout")]
+ UnsupportedMatrixTypeInStd140,
+ /// A variable with the same name already exists in the current scope.
+ #[error("Variable already declared: {0}")]
+ VariableAlreadyDeclared(String),
+ /// A semantic error was detected in the shader.
+ #[error("{0}")]
+ SemanticError(Cow<'static, str>),
+ /// An error was returned by the preprocessor.
+ #[error("{0:?}")]
+ PreprocessorError(PreprocessorError),
+ /// The parser entered an illegal state and exited
+ ///
+ /// This obviously is a bug and as such should be reported in the github issue tracker
+ #[error("Internal error: {0}")]
+ InternalError(&'static str),
+}
+
+impl From<ConstantEvaluatorError> for ErrorKind {
+ fn from(err: ConstantEvaluatorError) -> Self {
+ ErrorKind::SemanticError(err.to_string().into())
+ }
+}
+
+/// Error returned during shader parsing.
+#[derive(Clone, Debug, Error)]
+#[error("{kind}")]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Error {
+ /// Holds the information about the error itself.
+ pub kind: ErrorKind,
+ /// Holds information about the range of the source code where the error happened.
+ pub meta: Span,
+}
+
+/// A collection of errors returned during shader parsing.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct ParseError {
+ pub errors: Vec<Error>,
+}
+
+impl ParseError {
+ pub fn emit_to_writer(&self, writer: &mut impl WriteColor, source: &str) {
+ self.emit_to_writer_with_path(writer, source, "glsl");
+ }
+
+ pub fn emit_to_writer_with_path(&self, writer: &mut impl WriteColor, source: &str, path: &str) {
+ let path = path.to_string();
+ let files = SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+
+ for err in &self.errors {
+ let mut diagnostic = Diagnostic::error().with_message(err.kind.to_string());
+
+ if let Some(range) = err.meta.to_range() {
+ diagnostic = diagnostic.with_labels(vec![Label::primary((), range)]);
+ }
+
+ term::emit(writer, &config, &files, &diagnostic).expect("cannot write error");
+ }
+ }
+
+ pub fn emit_to_string(&self, source: &str) -> String {
+ let mut writer = NoColor::new(Vec::new());
+ self.emit_to_writer(&mut writer, source);
+ String::from_utf8(writer.into_inner()).unwrap()
+ }
+}
+
+impl std::fmt::Display for ParseError {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ self.errors.iter().try_for_each(|e| write!(f, "{e:?}"))
+ }
+}
+
+impl std::error::Error for ParseError {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ None
+ }
+}
+
+impl From<Vec<Error>> for ParseError {
+ fn from(errors: Vec<Error>) -> Self {
+ Self { errors }
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/functions.rs b/third_party/rust/naga/src/front/glsl/functions.rs
new file mode 100644
index 0000000000..df8cc8a30e
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/functions.rs
@@ -0,0 +1,1602 @@
+use super::{
+ ast::*,
+ builtins::{inject_builtin, sampled_to_depth},
+ context::{Context, ExprPos, StmtContext},
+ error::{Error, ErrorKind},
+ types::scalar_components,
+ Frontend, Result,
+};
+use crate::{
+ front::glsl::types::type_power, proc::ensure_block_returns, AddressSpace, Block, EntryPoint,
+ Expression, Function, FunctionArgument, FunctionResult, Handle, Literal, LocalVariable, Scalar,
+ ScalarKind, Span, Statement, StructMember, Type, TypeInner,
+};
+use std::iter;
+
+/// Struct detailing a store operation that must happen after a function call
+struct ProxyWrite {
+ /// The store target
+ target: Handle<Expression>,
+ /// A pointer to read the value of the store
+ value: Handle<Expression>,
+ /// An optional conversion to be applied
+ convert: Option<Scalar>,
+}
+
+impl Frontend {
+ pub(crate) fn function_or_constructor_call(
+ &mut self,
+ ctx: &mut Context,
+ stmt: &StmtContext,
+ fc: FunctionCallKind,
+ raw_args: &[Handle<HirExpr>],
+ meta: Span,
+ ) -> Result<Option<Handle<Expression>>> {
+ let args: Vec<_> = raw_args
+ .iter()
+ .map(|e| ctx.lower_expect_inner(stmt, self, *e, ExprPos::Rhs))
+ .collect::<Result<_>>()?;
+
+ match fc {
+ FunctionCallKind::TypeConstructor(ty) => {
+ if args.len() == 1 {
+ self.constructor_single(ctx, ty, args[0], meta).map(Some)
+ } else {
+ self.constructor_many(ctx, ty, args, meta).map(Some)
+ }
+ }
+ FunctionCallKind::Function(name) => {
+ self.function_call(ctx, stmt, name, args, raw_args, meta)
+ }
+ }
+ }
+
+ fn constructor_single(
+ &mut self,
+ ctx: &mut Context,
+ ty: Handle<Type>,
+ (mut value, expr_meta): (Handle<Expression>, Span),
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let expr_type = ctx.resolve_type(value, expr_meta)?;
+
+ let vector_size = match *expr_type {
+ TypeInner::Vector { size, .. } => Some(size),
+ _ => None,
+ };
+
+ let expr_is_bool = expr_type.scalar_kind() == Some(ScalarKind::Bool);
+
+ // Special case: if casting from a bool, we need to use Select and not As.
+ match ctx.module.types[ty].inner.scalar() {
+ Some(result_scalar) if expr_is_bool && result_scalar.kind != ScalarKind::Bool => {
+ let result_scalar = Scalar {
+ width: 4,
+ ..result_scalar
+ };
+ let l0 = Literal::zero(result_scalar).unwrap();
+ let l1 = Literal::one(result_scalar).unwrap();
+ let mut reject = ctx.add_expression(Expression::Literal(l0), expr_meta)?;
+ let mut accept = ctx.add_expression(Expression::Literal(l1), expr_meta)?;
+
+ ctx.implicit_splat(&mut reject, meta, vector_size)?;
+ ctx.implicit_splat(&mut accept, meta, vector_size)?;
+
+ let h = ctx.add_expression(
+ Expression::Select {
+ accept,
+ reject,
+ condition: value,
+ },
+ expr_meta,
+ )?;
+
+ return Ok(h);
+ }
+ _ => {}
+ }
+
+ Ok(match ctx.module.types[ty].inner {
+ TypeInner::Vector { size, scalar } if vector_size.is_none() => {
+ ctx.forced_conversion(&mut value, expr_meta, scalar)?;
+
+ if let TypeInner::Scalar { .. } = *ctx.resolve_type(value, expr_meta)? {
+ ctx.add_expression(Expression::Splat { size, value }, meta)?
+ } else {
+ self.vector_constructor(ctx, ty, size, scalar, &[(value, expr_meta)], meta)?
+ }
+ }
+ TypeInner::Scalar(scalar) => {
+ let mut expr = value;
+ if let TypeInner::Vector { .. } | TypeInner::Matrix { .. } =
+ *ctx.resolve_type(value, expr_meta)?
+ {
+ expr = ctx.add_expression(
+ Expression::AccessIndex {
+ base: expr,
+ index: 0,
+ },
+ meta,
+ )?;
+ }
+
+ if let TypeInner::Matrix { .. } = *ctx.resolve_type(value, expr_meta)? {
+ expr = ctx.add_expression(
+ Expression::AccessIndex {
+ base: expr,
+ index: 0,
+ },
+ meta,
+ )?;
+ }
+
+ ctx.add_expression(
+ Expression::As {
+ kind: scalar.kind,
+ expr,
+ convert: Some(scalar.width),
+ },
+ meta,
+ )?
+ }
+ TypeInner::Vector { size, scalar } => {
+ if vector_size.map_or(true, |s| s != size) {
+ value = ctx.vector_resize(size, value, expr_meta)?;
+ }
+
+ ctx.add_expression(
+ Expression::As {
+ kind: scalar.kind,
+ expr: value,
+ convert: Some(scalar.width),
+ },
+ meta,
+ )?
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => self.matrix_one_arg(ctx, ty, columns, rows, scalar, (value, expr_meta), meta)?,
+ TypeInner::Struct { ref members, .. } => {
+ let scalar_components = members
+ .get(0)
+ .and_then(|member| scalar_components(&ctx.module.types[member.ty].inner));
+ if let Some(scalar) = scalar_components {
+ ctx.implicit_conversion(&mut value, expr_meta, scalar)?;
+ }
+
+ ctx.add_expression(
+ Expression::Compose {
+ ty,
+ components: vec![value],
+ },
+ meta,
+ )?
+ }
+
+ TypeInner::Array { base, .. } => {
+ let scalar_components = scalar_components(&ctx.module.types[base].inner);
+ if let Some(scalar) = scalar_components {
+ ctx.implicit_conversion(&mut value, expr_meta, scalar)?;
+ }
+
+ ctx.add_expression(
+ Expression::Compose {
+ ty,
+ components: vec![value],
+ },
+ meta,
+ )?
+ }
+ _ => {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Bad type constructor".into()),
+ meta,
+ });
+
+ value
+ }
+ })
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn matrix_one_arg(
+ &mut self,
+ ctx: &mut Context,
+ ty: Handle<Type>,
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ element_scalar: Scalar,
+ (mut value, expr_meta): (Handle<Expression>, Span),
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let mut components = Vec::with_capacity(columns as usize);
+ // TODO: casts
+ // `Expression::As` doesn't support matrix width
+ // casts so we need to do some extra work for casts
+
+ ctx.forced_conversion(&mut value, expr_meta, element_scalar)?;
+ match *ctx.resolve_type(value, expr_meta)? {
+ TypeInner::Scalar(_) => {
+ // If a matrix is constructed with a single scalar value, then that
+ // value is used to initialize all the values along the diagonal of
+ // the matrix; the rest are given zeros.
+ let vector_ty = ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: rows,
+ scalar: element_scalar,
+ },
+ },
+ meta,
+ );
+
+ let zero_literal = Literal::zero(element_scalar).unwrap();
+ let zero = ctx.add_expression(Expression::Literal(zero_literal), meta)?;
+
+ for i in 0..columns as u32 {
+ components.push(
+ ctx.add_expression(
+ Expression::Compose {
+ ty: vector_ty,
+ components: (0..rows as u32)
+ .map(|r| match r == i {
+ true => value,
+ false => zero,
+ })
+ .collect(),
+ },
+ meta,
+ )?,
+ )
+ }
+ }
+ TypeInner::Matrix {
+ rows: ori_rows,
+ columns: ori_cols,
+ ..
+ } => {
+ // If a matrix is constructed from a matrix, then each component
+ // (column i, row j) in the result that has a corresponding component
+ // (column i, row j) in the argument will be initialized from there. All
+ // other components will be initialized to the identity matrix.
+
+ let zero_literal = Literal::zero(element_scalar).unwrap();
+ let one_literal = Literal::one(element_scalar).unwrap();
+
+ let zero = ctx.add_expression(Expression::Literal(zero_literal), meta)?;
+ let one = ctx.add_expression(Expression::Literal(one_literal), meta)?;
+
+ let vector_ty = ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: rows,
+ scalar: element_scalar,
+ },
+ },
+ meta,
+ );
+
+ for i in 0..columns as u32 {
+ if i < ori_cols as u32 {
+ use std::cmp::Ordering;
+
+ let vector = ctx.add_expression(
+ Expression::AccessIndex {
+ base: value,
+ index: i,
+ },
+ meta,
+ )?;
+
+ components.push(match ori_rows.cmp(&rows) {
+ Ordering::Less => {
+ let components = (0..rows as u32)
+ .map(|r| {
+ if r < ori_rows as u32 {
+ ctx.add_expression(
+ Expression::AccessIndex {
+ base: vector,
+ index: r,
+ },
+ meta,
+ )
+ } else if r == i {
+ Ok(one)
+ } else {
+ Ok(zero)
+ }
+ })
+ .collect::<Result<_>>()?;
+
+ ctx.add_expression(
+ Expression::Compose {
+ ty: vector_ty,
+ components,
+ },
+ meta,
+ )?
+ }
+ Ordering::Equal => vector,
+ Ordering::Greater => ctx.vector_resize(rows, vector, meta)?,
+ })
+ } else {
+ let compose_expr = Expression::Compose {
+ ty: vector_ty,
+ components: (0..rows as u32)
+ .map(|r| match r == i {
+ true => one,
+ false => zero,
+ })
+ .collect(),
+ };
+
+ let vec = ctx.add_expression(compose_expr, meta)?;
+
+ components.push(vec)
+ }
+ }
+ }
+ _ => {
+ components = iter::repeat(value).take(columns as usize).collect();
+ }
+ }
+
+ ctx.add_expression(Expression::Compose { ty, components }, meta)
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn vector_constructor(
+ &mut self,
+ ctx: &mut Context,
+ ty: Handle<Type>,
+ size: crate::VectorSize,
+ scalar: Scalar,
+ args: &[(Handle<Expression>, Span)],
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let mut components = Vec::with_capacity(size as usize);
+
+ for (mut arg, expr_meta) in args.iter().copied() {
+ ctx.forced_conversion(&mut arg, expr_meta, scalar)?;
+
+ if components.len() >= size as usize {
+ break;
+ }
+
+ match *ctx.resolve_type(arg, expr_meta)? {
+ TypeInner::Scalar { .. } => components.push(arg),
+ TypeInner::Matrix { rows, columns, .. } => {
+ components.reserve(rows as usize * columns as usize);
+ for c in 0..(columns as u32) {
+ let base = ctx.add_expression(
+ Expression::AccessIndex {
+ base: arg,
+ index: c,
+ },
+ expr_meta,
+ )?;
+ for r in 0..(rows as u32) {
+ components.push(ctx.add_expression(
+ Expression::AccessIndex { base, index: r },
+ expr_meta,
+ )?)
+ }
+ }
+ }
+ TypeInner::Vector { size: ori_size, .. } => {
+ components.reserve(ori_size as usize);
+ for index in 0..(ori_size as u32) {
+ components.push(ctx.add_expression(
+ Expression::AccessIndex { base: arg, index },
+ expr_meta,
+ )?)
+ }
+ }
+ _ => components.push(arg),
+ }
+ }
+
+ components.truncate(size as usize);
+
+ ctx.add_expression(Expression::Compose { ty, components }, meta)
+ }
+
+ fn constructor_many(
+ &mut self,
+ ctx: &mut Context,
+ ty: Handle<Type>,
+ args: Vec<(Handle<Expression>, Span)>,
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let mut components = Vec::with_capacity(args.len());
+
+ let struct_member_data = match ctx.module.types[ty].inner {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: element_scalar,
+ } => {
+ let mut flattened = Vec::with_capacity(columns as usize * rows as usize);
+
+ for (mut arg, meta) in args.iter().copied() {
+ ctx.forced_conversion(&mut arg, meta, element_scalar)?;
+
+ match *ctx.resolve_type(arg, meta)? {
+ TypeInner::Vector { size, .. } => {
+ for i in 0..(size as u32) {
+ flattened.push(ctx.add_expression(
+ Expression::AccessIndex {
+ base: arg,
+ index: i,
+ },
+ meta,
+ )?)
+ }
+ }
+ _ => flattened.push(arg),
+ }
+ }
+
+ let ty = ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: rows,
+ scalar: element_scalar,
+ },
+ },
+ meta,
+ );
+
+ for chunk in flattened.chunks(rows as usize) {
+ components.push(ctx.add_expression(
+ Expression::Compose {
+ ty,
+ components: Vec::from(chunk),
+ },
+ meta,
+ )?)
+ }
+ None
+ }
+ TypeInner::Vector { size, scalar } => {
+ return self.vector_constructor(ctx, ty, size, scalar, &args, meta)
+ }
+ TypeInner::Array { base, .. } => {
+ for (mut arg, meta) in args.iter().copied() {
+ let scalar_components = scalar_components(&ctx.module.types[base].inner);
+ if let Some(scalar) = scalar_components {
+ ctx.implicit_conversion(&mut arg, meta, scalar)?;
+ }
+
+ components.push(arg)
+ }
+ None
+ }
+ TypeInner::Struct { ref members, .. } => Some(
+ members
+ .iter()
+ .map(|member| scalar_components(&ctx.module.types[member.ty].inner))
+ .collect::<Vec<_>>(),
+ ),
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError("Constructor: Too many arguments".into()),
+ meta,
+ })
+ }
+ };
+
+ if let Some(struct_member_data) = struct_member_data {
+ for ((mut arg, meta), scalar_components) in
+ args.iter().copied().zip(struct_member_data.iter().copied())
+ {
+ if let Some(scalar) = scalar_components {
+ ctx.implicit_conversion(&mut arg, meta, scalar)?;
+ }
+
+ components.push(arg)
+ }
+ }
+
+ ctx.add_expression(Expression::Compose { ty, components }, meta)
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn function_call(
+ &mut self,
+ ctx: &mut Context,
+ stmt: &StmtContext,
+ name: String,
+ args: Vec<(Handle<Expression>, Span)>,
+ raw_args: &[Handle<HirExpr>],
+ meta: Span,
+ ) -> Result<Option<Handle<Expression>>> {
+ // Grow the typifier to be able to index it later without needing
+ // to hold the context mutably
+ for &(expr, span) in args.iter() {
+ ctx.typifier_grow(expr, span)?;
+ }
+
+ // Check if the passed arguments require any special variations
+ let mut variations =
+ builtin_required_variations(args.iter().map(|&(expr, _)| ctx.get_type(expr)));
+
+ // Initiate the declaration if it wasn't previously initialized and inject builtins
+ let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| {
+ variations |= BuiltinVariations::STANDARD;
+ Default::default()
+ });
+ inject_builtin(declaration, ctx.module, &name, variations);
+
+ // Borrow again but without mutability, at this point a declaration is guaranteed
+ let declaration = self.lookup_function.get(&name).unwrap();
+
+ // Possibly contains the overload to be used in the call
+ let mut maybe_overload = None;
+ // The conversions needed for the best analyzed overload, this is initialized all to
+ // `NONE` to make sure that conversions always pass the first time without ambiguity
+ let mut old_conversions = vec![Conversion::None; args.len()];
+ // Tracks whether the comparison between overloads lead to an ambiguity
+ let mut ambiguous = false;
+
+ // Iterate over all the available overloads to select either an exact match or a
+ // overload which has suitable implicit conversions
+ 'outer: for (overload_idx, overload) in declaration.overloads.iter().enumerate() {
+ // If the overload and the function call don't have the same number of arguments
+ // continue to the next overload
+ if args.len() != overload.parameters.len() {
+ continue;
+ }
+
+ log::trace!("Testing overload {}", overload_idx);
+
+ // Stores whether the current overload matches exactly the function call
+ let mut exact = true;
+ // State of the selection
+ // If None we still don't know what is the best overload
+ // If Some(true) the new overload is better
+ // If Some(false) the old overload is better
+ let mut superior = None;
+ // Store the conversions for the current overload so that later they can replace the
+ // conversions used for querying the best overload
+ let mut new_conversions = vec![Conversion::None; args.len()];
+
+ // Loop through the overload parameters and check if the current overload is better
+ // compared to the previous best overload.
+ for (i, overload_parameter) in overload.parameters.iter().enumerate() {
+ let call_argument = &args[i];
+ let parameter_info = &overload.parameters_info[i];
+
+ // If the image is used in the overload as a depth texture convert it
+ // before comparing, otherwise exact matches wouldn't be reported
+ if parameter_info.depth {
+ sampled_to_depth(ctx, call_argument.0, call_argument.1, &mut self.errors);
+ ctx.invalidate_expression(call_argument.0, call_argument.1)?
+ }
+
+ ctx.typifier_grow(call_argument.0, call_argument.1)?;
+
+ let overload_param_ty = &ctx.module.types[*overload_parameter].inner;
+ let call_arg_ty = ctx.get_type(call_argument.0);
+
+ log::trace!(
+ "Testing parameter {}\n\tOverload = {:?}\n\tCall = {:?}",
+ i,
+ overload_param_ty,
+ call_arg_ty
+ );
+
+ // Storage images cannot be directly compared since while the access is part of the
+ // type in naga's IR, in glsl they are a qualifier and don't enter in the match as
+ // long as the access needed is satisfied.
+ if let (
+ &TypeInner::Image {
+ class:
+ crate::ImageClass::Storage {
+ format: overload_format,
+ access: overload_access,
+ },
+ dim: overload_dim,
+ arrayed: overload_arrayed,
+ },
+ &TypeInner::Image {
+ class:
+ crate::ImageClass::Storage {
+ format: call_format,
+ access: call_access,
+ },
+ dim: call_dim,
+ arrayed: call_arrayed,
+ },
+ ) = (overload_param_ty, call_arg_ty)
+ {
+ // Images size must match otherwise the overload isn't what we want
+ let good_size = call_dim == overload_dim && call_arrayed == overload_arrayed;
+ // Glsl requires the formats to strictly match unless you are builtin
+ // function overload and have not been replaced, in which case we only
+ // check that the format scalar kind matches
+ let good_format = overload_format == call_format
+ || (overload.internal
+ && ScalarKind::from(overload_format) == ScalarKind::from(call_format));
+ if !(good_size && good_format) {
+ continue 'outer;
+ }
+
+ // While storage access mismatch is an error it isn't one that causes
+ // the overload matching to fail so we defer the error and consider
+ // that the images match exactly
+ if !call_access.contains(overload_access) {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!(
+ "'{name}': image needs {overload_access:?} access but only {call_access:?} was provided"
+ )
+ .into(),
+ ),
+ meta,
+ });
+ }
+
+ // The images satisfy the conditions to be considered as an exact match
+ new_conversions[i] = Conversion::Exact;
+ continue;
+ } else if overload_param_ty == call_arg_ty {
+ // If the types match there's no need to check for conversions so continue
+ new_conversions[i] = Conversion::Exact;
+ continue;
+ }
+
+ // Glsl defines that inout follows both the conversions for input parameters and
+ // output parameters, this means that the type must have a conversion from both the
+ // call argument to the function parameter and the function parameter to the call
+ // argument, the only way this is possible is for the conversion to be an identity
+ // (i.e. call argument = function parameter)
+ if let ParameterQualifier::InOut = parameter_info.qualifier {
+ continue 'outer;
+ }
+
+ // The function call argument and the function definition
+ // parameter are not equal at this point, so we need to try
+ // implicit conversions.
+ //
+ // Now there are two cases, the argument is defined as a normal
+ // parameter (`in` or `const`), in this case an implicit
+ // conversion is made from the calling argument to the
+ // definition argument. If the parameter is `out` the
+ // opposite needs to be done, so the implicit conversion is made
+ // from the definition argument to the calling argument.
+ let maybe_conversion = if parameter_info.qualifier.is_lhs() {
+ conversion(call_arg_ty, overload_param_ty)
+ } else {
+ conversion(overload_param_ty, call_arg_ty)
+ };
+
+ let conversion = match maybe_conversion {
+ Some(info) => info,
+ None => continue 'outer,
+ };
+
+ // At this point a conversion will be needed so the overload no longer
+ // exactly matches the call arguments
+ exact = false;
+
+ // Compare the conversions needed for this overload parameter to that of the
+ // last overload analyzed respective parameter, the value is:
+ // - `true` when the new overload argument has a better conversion
+ // - `false` when the old overload argument has a better conversion
+ let best_arg = match (conversion, old_conversions[i]) {
+ // An exact match is always better, we don't need to check this for the
+ // current overload since it was checked earlier
+ (_, Conversion::Exact) => false,
+ // No overload was yet analyzed so this one is the best yet
+ (_, Conversion::None) => true,
+ // A conversion from a float to a double is the best possible conversion
+ (Conversion::FloatToDouble, _) => true,
+ (_, Conversion::FloatToDouble) => false,
+ // A conversion from a float to an integer is preferred than one
+ // from double to an integer
+ (Conversion::IntToFloat, Conversion::IntToDouble) => true,
+ (Conversion::IntToDouble, Conversion::IntToFloat) => false,
+ // This case handles things like no conversion and exact which were already
+ // treated and other cases which no conversion is better than the other
+ _ => continue,
+ };
+
+ // Check if the best parameter corresponds to the current selected overload
+ // to pass to the next comparison, if this isn't true mark it as ambiguous
+ match best_arg {
+ true => match superior {
+ Some(false) => ambiguous = true,
+ _ => {
+ superior = Some(true);
+ new_conversions[i] = conversion
+ }
+ },
+ false => match superior {
+ Some(true) => ambiguous = true,
+ _ => superior = Some(false),
+ },
+ }
+ }
+
+ // The overload matches exactly the function call so there's no ambiguity (since
+ // repeated overload aren't allowed) and the current overload is selected, no
+ // further querying is needed.
+ if exact {
+ maybe_overload = Some(overload);
+ ambiguous = false;
+ break;
+ }
+
+ match superior {
+ // New overload is better keep it
+ Some(true) => {
+ maybe_overload = Some(overload);
+ // Replace the conversions
+ old_conversions = new_conversions;
+ }
+ // Old overload is better do nothing
+ Some(false) => {}
+ // No overload was better than the other this can be caused
+ // when all conversions are ambiguous in which the overloads themselves are
+ // ambiguous.
+ None => {
+ ambiguous = true;
+ // Assign the new overload, this helps ensures that in this case of
+ // ambiguity the parsing won't end immediately and allow for further
+ // collection of errors.
+ maybe_overload = Some(overload);
+ }
+ }
+ }
+
+ if ambiguous {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Ambiguous best function for '{name}'").into(),
+ ),
+ meta,
+ })
+ }
+
+ let overload = maybe_overload.ok_or_else(|| Error {
+ kind: ErrorKind::SemanticError(format!("Unknown function '{name}'").into()),
+ meta,
+ })?;
+
+ let parameters_info = overload.parameters_info.clone();
+ let parameters = overload.parameters.clone();
+ let is_void = overload.void;
+ let kind = overload.kind;
+
+ let mut arguments = Vec::with_capacity(args.len());
+ let mut proxy_writes = Vec::new();
+
+ // Iterate through the function call arguments applying transformations as needed
+ for (((parameter_info, call_argument), expr), parameter) in parameters_info
+ .iter()
+ .zip(&args)
+ .zip(raw_args)
+ .zip(&parameters)
+ {
+ let (mut handle, meta) =
+ ctx.lower_expect_inner(stmt, self, *expr, parameter_info.qualifier.as_pos())?;
+
+ if parameter_info.qualifier.is_lhs() {
+ self.process_lhs_argument(
+ ctx,
+ meta,
+ *parameter,
+ parameter_info,
+ handle,
+ call_argument,
+ &mut proxy_writes,
+ &mut arguments,
+ )?;
+
+ continue;
+ }
+
+ let scalar_comps = scalar_components(&ctx.module.types[*parameter].inner);
+
+ // Apply implicit conversions as needed
+ if let Some(scalar) = scalar_comps {
+ ctx.implicit_conversion(&mut handle, meta, scalar)?;
+ }
+
+ arguments.push(handle)
+ }
+
+ match kind {
+ FunctionKind::Call(function) => {
+ ctx.emit_end();
+
+ let result = if !is_void {
+ Some(ctx.add_expression(Expression::CallResult(function), meta)?)
+ } else {
+ None
+ };
+
+ ctx.body.push(
+ crate::Statement::Call {
+ function,
+ arguments,
+ result,
+ },
+ meta,
+ );
+
+ ctx.emit_start();
+
+ // Write back all the variables that were scheduled to their original place
+ for proxy_write in proxy_writes {
+ let mut value = ctx.add_expression(
+ Expression::Load {
+ pointer: proxy_write.value,
+ },
+ meta,
+ )?;
+
+ if let Some(scalar) = proxy_write.convert {
+ ctx.conversion(&mut value, meta, scalar)?;
+ }
+
+ ctx.emit_restart();
+
+ ctx.body.push(
+ Statement::Store {
+ pointer: proxy_write.target,
+ value,
+ },
+ meta,
+ );
+ }
+
+ Ok(result)
+ }
+ FunctionKind::Macro(builtin) => builtin.call(self, ctx, arguments.as_mut_slice(), meta),
+ }
+ }
+
+ /// Processes a function call argument that appears in place of an output
+ /// parameter.
+ #[allow(clippy::too_many_arguments)]
+ fn process_lhs_argument(
+ &mut self,
+ ctx: &mut Context,
+ meta: Span,
+ parameter_ty: Handle<Type>,
+ parameter_info: &ParameterInfo,
+ original: Handle<Expression>,
+ call_argument: &(Handle<Expression>, Span),
+ proxy_writes: &mut Vec<ProxyWrite>,
+ arguments: &mut Vec<Handle<Expression>>,
+ ) -> Result<()> {
+ let original_ty = ctx.resolve_type(original, meta)?;
+ let original_pointer_space = original_ty.pointer_space();
+
+ // The type of a possible spill variable needed for a proxy write
+ let mut maybe_ty = match *original_ty {
+ // If the argument is to be passed as a pointer but the type of the
+ // expression returns a vector it must mean that it was for example
+ // swizzled and it must be spilled into a local before calling
+ TypeInner::Vector { size, scalar } => Some(ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector { size, scalar },
+ },
+ Span::default(),
+ )),
+ // If the argument is a pointer whose address space isn't `Function`, an
+ // indirection through a local variable is needed to align the address
+ // spaces of the call argument and the overload parameter.
+ TypeInner::Pointer { base, space } if space != AddressSpace::Function => Some(base),
+ TypeInner::ValuePointer {
+ size,
+ scalar,
+ space,
+ } if space != AddressSpace::Function => {
+ let inner = match size {
+ Some(size) => TypeInner::Vector { size, scalar },
+ None => TypeInner::Scalar(scalar),
+ };
+
+ Some(
+ ctx.module
+ .types
+ .insert(Type { name: None, inner }, Span::default()),
+ )
+ }
+ _ => None,
+ };
+
+ // Since the original expression might be a pointer and we want a value
+ // for the proxy writes, we might need to load the pointer.
+ let value = if original_pointer_space.is_some() {
+ ctx.add_expression(Expression::Load { pointer: original }, Span::default())?
+ } else {
+ original
+ };
+
+ ctx.typifier_grow(call_argument.0, call_argument.1)?;
+
+ let overload_param_ty = &ctx.module.types[parameter_ty].inner;
+ let call_arg_ty = ctx.get_type(call_argument.0);
+ let needs_conversion = call_arg_ty != overload_param_ty;
+
+ let arg_scalar_comps = scalar_components(call_arg_ty);
+
+ // Since output parameters also allow implicit conversions from the
+ // parameter to the argument, we need to spill the conversion to a
+ // variable and create a proxy write for the original variable.
+ if needs_conversion {
+ maybe_ty = Some(parameter_ty);
+ }
+
+ if let Some(ty) = maybe_ty {
+ // Create the spill variable
+ let spill_var = ctx.locals.append(
+ LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ Span::default(),
+ );
+ let spill_expr =
+ ctx.add_expression(Expression::LocalVariable(spill_var), Span::default())?;
+
+ // If the argument is also copied in we must store the value of the
+ // original variable to the spill variable.
+ if let ParameterQualifier::InOut = parameter_info.qualifier {
+ ctx.body.push(
+ Statement::Store {
+ pointer: spill_expr,
+ value,
+ },
+ Span::default(),
+ );
+ }
+
+ // Add the spill variable as an argument to the function call
+ arguments.push(spill_expr);
+
+ let convert = if needs_conversion {
+ arg_scalar_comps
+ } else {
+ None
+ };
+
+ // Register the temporary local to be written back to it's original
+ // place after the function call
+ if let Expression::Swizzle {
+ size,
+ mut vector,
+ pattern,
+ } = ctx.expressions[original]
+ {
+ if let Expression::Load { pointer } = ctx.expressions[vector] {
+ vector = pointer;
+ }
+
+ for (i, component) in pattern.iter().take(size as usize).enumerate() {
+ let original = ctx.add_expression(
+ Expression::AccessIndex {
+ base: vector,
+ index: *component as u32,
+ },
+ Span::default(),
+ )?;
+
+ let spill_component = ctx.add_expression(
+ Expression::AccessIndex {
+ base: spill_expr,
+ index: i as u32,
+ },
+ Span::default(),
+ )?;
+
+ proxy_writes.push(ProxyWrite {
+ target: original,
+ value: spill_component,
+ convert,
+ });
+ }
+ } else {
+ proxy_writes.push(ProxyWrite {
+ target: original,
+ value: spill_expr,
+ convert,
+ });
+ }
+ } else {
+ arguments.push(original);
+ }
+
+ Ok(())
+ }
+
+ pub(crate) fn add_function(
+ &mut self,
+ mut ctx: Context,
+ name: String,
+ result: Option<FunctionResult>,
+ meta: Span,
+ ) {
+ ensure_block_returns(&mut ctx.body);
+
+ let void = result.is_none();
+
+ // Check if the passed arguments require any special variations
+ let mut variations = builtin_required_variations(
+ ctx.parameters
+ .iter()
+ .map(|&arg| &ctx.module.types[arg].inner),
+ );
+
+ // Initiate the declaration if it wasn't previously initialized and inject builtins
+ let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| {
+ variations |= BuiltinVariations::STANDARD;
+ Default::default()
+ });
+ inject_builtin(declaration, ctx.module, &name, variations);
+
+ let Context {
+ expressions,
+ locals,
+ arguments,
+ parameters,
+ parameters_info,
+ body,
+ module,
+ ..
+ } = ctx;
+
+ let function = Function {
+ name: Some(name),
+ arguments,
+ result,
+ local_variables: locals,
+ expressions,
+ named_expressions: crate::NamedExpressions::default(),
+ body,
+ };
+
+ 'outer: for decl in declaration.overloads.iter_mut() {
+ if parameters.len() != decl.parameters.len() {
+ continue;
+ }
+
+ for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) {
+ let new_inner = &module.types[*new_parameter].inner;
+ let old_inner = &module.types[*old_parameter].inner;
+
+ if new_inner != old_inner {
+ continue 'outer;
+ }
+ }
+
+ if decl.defined {
+ return self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Function already defined".into()),
+ meta,
+ });
+ }
+
+ decl.defined = true;
+ decl.parameters_info = parameters_info;
+ match decl.kind {
+ FunctionKind::Call(handle) => *module.functions.get_mut(handle) = function,
+ FunctionKind::Macro(_) => {
+ let handle = module.functions.append(function, meta);
+ decl.kind = FunctionKind::Call(handle)
+ }
+ }
+ return;
+ }
+
+ let handle = module.functions.append(function, meta);
+ declaration.overloads.push(Overload {
+ parameters,
+ parameters_info,
+ kind: FunctionKind::Call(handle),
+ defined: true,
+ internal: false,
+ void,
+ });
+ }
+
+ pub(crate) fn add_prototype(
+ &mut self,
+ ctx: Context,
+ name: String,
+ result: Option<FunctionResult>,
+ meta: Span,
+ ) {
+ let void = result.is_none();
+
+ // Check if the passed arguments require any special variations
+ let mut variations = builtin_required_variations(
+ ctx.parameters
+ .iter()
+ .map(|&arg| &ctx.module.types[arg].inner),
+ );
+
+ // Initiate the declaration if it wasn't previously initialized and inject builtins
+ let declaration = self.lookup_function.entry(name.clone()).or_insert_with(|| {
+ variations |= BuiltinVariations::STANDARD;
+ Default::default()
+ });
+ inject_builtin(declaration, ctx.module, &name, variations);
+
+ let Context {
+ arguments,
+ parameters,
+ parameters_info,
+ module,
+ ..
+ } = ctx;
+
+ let function = Function {
+ name: Some(name),
+ arguments,
+ result,
+ ..Default::default()
+ };
+
+ 'outer: for decl in declaration.overloads.iter() {
+ if parameters.len() != decl.parameters.len() {
+ continue;
+ }
+
+ for (new_parameter, old_parameter) in parameters.iter().zip(decl.parameters.iter()) {
+ let new_inner = &module.types[*new_parameter].inner;
+ let old_inner = &module.types[*old_parameter].inner;
+
+ if new_inner != old_inner {
+ continue 'outer;
+ }
+ }
+
+ return self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Prototype already defined".into()),
+ meta,
+ });
+ }
+
+ let handle = module.functions.append(function, meta);
+ declaration.overloads.push(Overload {
+ parameters,
+ parameters_info,
+ kind: FunctionKind::Call(handle),
+ defined: false,
+ internal: false,
+ void,
+ });
+ }
+
+ /// Create a Naga [`EntryPoint`] that calls the GLSL `main` function.
+ ///
+ /// We compile the GLSL `main` function as an ordinary Naga [`Function`].
+ /// This function synthesizes a Naga [`EntryPoint`] to call that.
+ ///
+ /// Each GLSL input and output variable (including builtins) becomes a Naga
+ /// [`GlobalVariable`]s in the [`Private`] address space, which `main` can
+ /// access in the usual way.
+ ///
+ /// The `EntryPoint` we synthesize here has an argument for each GLSL input
+ /// variable, and returns a struct with a member for each GLSL output
+ /// variable. The entry point contains code to:
+ ///
+ /// - copy its arguments into the Naga globals representing the GLSL input
+ /// variables,
+ ///
+ /// - call the Naga `Function` representing the GLSL `main` function, and then
+ ///
+ /// - build its return value from whatever values the GLSL `main` left in
+ /// the Naga globals representing GLSL `output` variables.
+ ///
+ /// Upon entry, [`ctx.body`] should contain code, accumulated by prior calls
+ /// to [`ParsingContext::parse_external_declaration`][pxd], to initialize
+ /// private global variables as needed. This code gets spliced into the
+ /// entry point before the call to `main`.
+ ///
+ /// [`GlobalVariable`]: crate::GlobalVariable
+ /// [`Private`]: crate::AddressSpace::Private
+ /// [`ctx.body`]: Context::body
+ /// [pxd]: super::ParsingContext::parse_external_declaration
+ pub(crate) fn add_entry_point(
+ &mut self,
+ function: Handle<Function>,
+ mut ctx: Context,
+ ) -> Result<()> {
+ let mut arguments = Vec::new();
+
+ let body = Block::with_capacity(
+ // global init body
+ ctx.body.len() +
+ // prologue and epilogue
+ self.entry_args.len() * 2
+ // Call, Emit for composing struct and return
+ + 3,
+ );
+
+ let global_init_body = std::mem::replace(&mut ctx.body, body);
+
+ for arg in self.entry_args.iter() {
+ if arg.storage != StorageQualifier::Input {
+ continue;
+ }
+
+ let pointer = ctx
+ .expressions
+ .append(Expression::GlobalVariable(arg.handle), Default::default());
+
+ let ty = ctx.module.global_variables[arg.handle].ty;
+
+ ctx.arg_type_walker(
+ arg.name.clone(),
+ arg.binding.clone(),
+ pointer,
+ ty,
+ &mut |ctx, name, pointer, ty, binding| {
+ let idx = arguments.len() as u32;
+
+ arguments.push(FunctionArgument {
+ name,
+ ty,
+ binding: Some(binding),
+ });
+
+ let value = ctx
+ .expressions
+ .append(Expression::FunctionArgument(idx), Default::default());
+ ctx.body
+ .push(Statement::Store { pointer, value }, Default::default());
+ },
+ )?
+ }
+
+ ctx.body.extend_block(global_init_body);
+
+ ctx.body.push(
+ Statement::Call {
+ function,
+ arguments: Vec::new(),
+ result: None,
+ },
+ Default::default(),
+ );
+
+ let mut span = 0;
+ let mut members = Vec::new();
+ let mut components = Vec::new();
+
+ for arg in self.entry_args.iter() {
+ if arg.storage != StorageQualifier::Output {
+ continue;
+ }
+
+ let pointer = ctx
+ .expressions
+ .append(Expression::GlobalVariable(arg.handle), Default::default());
+
+ let ty = ctx.module.global_variables[arg.handle].ty;
+
+ ctx.arg_type_walker(
+ arg.name.clone(),
+ arg.binding.clone(),
+ pointer,
+ ty,
+ &mut |ctx, name, pointer, ty, binding| {
+ members.push(StructMember {
+ name,
+ ty,
+ binding: Some(binding),
+ offset: span,
+ });
+
+ span += ctx.module.types[ty].inner.size(ctx.module.to_ctx());
+
+ let len = ctx.expressions.len();
+ let load = ctx
+ .expressions
+ .append(Expression::Load { pointer }, Default::default());
+ ctx.body.push(
+ Statement::Emit(ctx.expressions.range_from(len)),
+ Default::default(),
+ );
+ components.push(load)
+ },
+ )?
+ }
+
+ let (ty, value) = if !components.is_empty() {
+ let ty = ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Struct { members, span },
+ },
+ Default::default(),
+ );
+
+ let len = ctx.expressions.len();
+ let res = ctx
+ .expressions
+ .append(Expression::Compose { ty, components }, Default::default());
+ ctx.body.push(
+ Statement::Emit(ctx.expressions.range_from(len)),
+ Default::default(),
+ );
+
+ (Some(ty), Some(res))
+ } else {
+ (None, None)
+ };
+
+ ctx.body
+ .push(Statement::Return { value }, Default::default());
+
+ let Context {
+ body, expressions, ..
+ } = ctx;
+
+ ctx.module.entry_points.push(EntryPoint {
+ name: "main".to_string(),
+ stage: self.meta.stage,
+ early_depth_test: Some(crate::EarlyDepthTest { conservative: None })
+ .filter(|_| self.meta.early_fragment_tests),
+ workgroup_size: self.meta.workgroup_size,
+ function: Function {
+ arguments,
+ expressions,
+ body,
+ result: ty.map(|ty| FunctionResult { ty, binding: None }),
+ ..Default::default()
+ },
+ });
+
+ Ok(())
+ }
+}
+
+impl Context<'_> {
+ /// Helper function for building the input/output interface of the entry point
+ ///
+ /// Calls `f` with the data of the entry point argument, flattening composite types
+ /// recursively
+ ///
+ /// The passed arguments to the callback are:
+ /// - The ctx
+ /// - The name
+ /// - The pointer expression to the global storage
+ /// - The handle to the type of the entry point argument
+ /// - The binding of the entry point argument
+ fn arg_type_walker(
+ &mut self,
+ name: Option<String>,
+ binding: crate::Binding,
+ pointer: Handle<Expression>,
+ ty: Handle<Type>,
+ f: &mut impl FnMut(
+ &mut Context,
+ Option<String>,
+ Handle<Expression>,
+ Handle<Type>,
+ crate::Binding,
+ ),
+ ) -> Result<()> {
+ match self.module.types[ty].inner {
+ // TODO: Better error reporting
+ // right now we just don't walk the array if the size isn't known at
+ // compile time and let validation catch it
+ TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ let mut location = match binding {
+ crate::Binding::Location { location, .. } => location,
+ crate::Binding::BuiltIn(_) => return Ok(()),
+ };
+
+ let interpolation =
+ self.module.types[base]
+ .inner
+ .scalar_kind()
+ .map(|kind| match kind {
+ ScalarKind::Float => crate::Interpolation::Perspective,
+ _ => crate::Interpolation::Flat,
+ });
+
+ for index in 0..size.get() {
+ let member_pointer = self.add_expression(
+ Expression::AccessIndex {
+ base: pointer,
+ index,
+ },
+ crate::Span::default(),
+ )?;
+
+ let binding = crate::Binding::Location {
+ location,
+ interpolation,
+ sampling: None,
+ second_blend_source: false,
+ };
+ location += 1;
+
+ self.arg_type_walker(name.clone(), binding, member_pointer, base, f)?
+ }
+ }
+ TypeInner::Struct { ref members, .. } => {
+ let mut location = match binding {
+ crate::Binding::Location { location, .. } => location,
+ crate::Binding::BuiltIn(_) => return Ok(()),
+ };
+
+ for (i, member) in members.clone().into_iter().enumerate() {
+ let member_pointer = self.add_expression(
+ Expression::AccessIndex {
+ base: pointer,
+ index: i as u32,
+ },
+ crate::Span::default(),
+ )?;
+
+ let binding = match member.binding {
+ Some(binding) => binding,
+ None => {
+ let interpolation = self.module.types[member.ty]
+ .inner
+ .scalar_kind()
+ .map(|kind| match kind {
+ ScalarKind::Float => crate::Interpolation::Perspective,
+ _ => crate::Interpolation::Flat,
+ });
+ let binding = crate::Binding::Location {
+ location,
+ interpolation,
+ sampling: None,
+ second_blend_source: false,
+ };
+ location += 1;
+ binding
+ }
+ };
+
+ self.arg_type_walker(member.name, binding, member_pointer, member.ty, f)?
+ }
+ }
+ _ => f(self, name, pointer, ty, binding),
+ }
+
+ Ok(())
+ }
+}
+
+/// Helper enum containing the type of conversion need for a call
+#[derive(PartialEq, Eq, Clone, Copy, Debug)]
+enum Conversion {
+ /// No conversion needed
+ Exact,
+ /// Float to double conversion needed
+ FloatToDouble,
+ /// Int or uint to float conversion needed
+ IntToFloat,
+ /// Int or uint to double conversion needed
+ IntToDouble,
+ /// Other type of conversion needed
+ Other,
+ /// No conversion was yet registered
+ None,
+}
+
+/// Helper function, returns the type of conversion from `source` to `target`, if a
+/// conversion is not possible returns None.
+fn conversion(target: &TypeInner, source: &TypeInner) -> Option<Conversion> {
+ use ScalarKind::*;
+
+ // Gather the `ScalarKind` and scalar width from both the target and the source
+ let (target_scalar, source_scalar) = match (target, source) {
+ // Conversions between scalars are allowed
+ (&TypeInner::Scalar(tgt_scalar), &TypeInner::Scalar(src_scalar)) => {
+ (tgt_scalar, src_scalar)
+ }
+ // Conversions between vectors of the same size are allowed
+ (
+ &TypeInner::Vector {
+ size: tgt_size,
+ scalar: tgt_scalar,
+ },
+ &TypeInner::Vector {
+ size: src_size,
+ scalar: src_scalar,
+ },
+ ) if tgt_size == src_size => (tgt_scalar, src_scalar),
+ // Conversions between matrices of the same size are allowed
+ (
+ &TypeInner::Matrix {
+ rows: tgt_rows,
+ columns: tgt_cols,
+ scalar: tgt_scalar,
+ },
+ &TypeInner::Matrix {
+ rows: src_rows,
+ columns: src_cols,
+ scalar: src_scalar,
+ },
+ ) if tgt_cols == src_cols && tgt_rows == src_rows => (tgt_scalar, src_scalar),
+ _ => return None,
+ };
+
+ // Check if source can be converted into target, if this is the case then the type
+ // power of target must be higher than that of source
+ let target_power = type_power(target_scalar);
+ let source_power = type_power(source_scalar);
+ if target_power < source_power {
+ return None;
+ }
+
+ Some(match (target_scalar, source_scalar) {
+ // A conversion from a float to a double is special
+ (Scalar::F64, Scalar::F32) => Conversion::FloatToDouble,
+ // A conversion from an integer to a float is special
+ (
+ Scalar::F32,
+ Scalar {
+ kind: Sint | Uint,
+ width: _,
+ },
+ ) => Conversion::IntToFloat,
+ // A conversion from an integer to a double is special
+ (
+ Scalar::F64,
+ Scalar {
+ kind: Sint | Uint,
+ width: _,
+ },
+ ) => Conversion::IntToDouble,
+ _ => Conversion::Other,
+ })
+}
+
+/// Helper method returning all the non standard builtin variations needed
+/// to process the function call with the passed arguments
+fn builtin_required_variations<'a>(args: impl Iterator<Item = &'a TypeInner>) -> BuiltinVariations {
+ let mut variations = BuiltinVariations::empty();
+
+ for ty in args {
+ match *ty {
+ TypeInner::ValuePointer { scalar, .. }
+ | TypeInner::Scalar(scalar)
+ | TypeInner::Vector { scalar, .. }
+ | TypeInner::Matrix { scalar, .. } => {
+ if scalar == Scalar::F64 {
+ variations |= BuiltinVariations::DOUBLE
+ }
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ if dim == crate::ImageDimension::Cube && arrayed {
+ variations |= BuiltinVariations::CUBE_TEXTURES_ARRAY
+ }
+
+ if dim == crate::ImageDimension::D2 && arrayed && class.is_multisampled() {
+ variations |= BuiltinVariations::D2_MULTI_TEXTURES_ARRAY
+ }
+ }
+ _ => {}
+ }
+ }
+
+ variations
+}
diff --git a/third_party/rust/naga/src/front/glsl/lex.rs b/third_party/rust/naga/src/front/glsl/lex.rs
new file mode 100644
index 0000000000..1b59a9bf3e
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/lex.rs
@@ -0,0 +1,301 @@
+use super::{
+ ast::Precision,
+ token::{Directive, DirectiveKind, Token, TokenValue},
+ types::parse_type,
+};
+use crate::{FastHashMap, Span, StorageAccess};
+use pp_rs::{
+ pp::Preprocessor,
+ token::{PreprocessorError, Punct, TokenValue as PPTokenValue},
+};
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct LexerResult {
+ pub kind: LexerResultKind,
+ pub meta: Span,
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum LexerResultKind {
+ Token(Token),
+ Directive(Directive),
+ Error(PreprocessorError),
+}
+
+pub struct Lexer<'a> {
+ pp: Preprocessor<'a>,
+}
+
+impl<'a> Lexer<'a> {
+ pub fn new(input: &'a str, defines: &'a FastHashMap<String, String>) -> Self {
+ let mut pp = Preprocessor::new(input);
+ for (define, value) in defines {
+ pp.add_define(define, value).unwrap(); //TODO: handle error
+ }
+ Lexer { pp }
+ }
+}
+
+impl<'a> Iterator for Lexer<'a> {
+ type Item = LexerResult;
+ fn next(&mut self) -> Option<Self::Item> {
+ let pp_token = match self.pp.next()? {
+ Ok(t) => t,
+ Err((err, loc)) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Error(err),
+ meta: loc.into(),
+ });
+ }
+ };
+
+ let meta = pp_token.location.into();
+ let value = match pp_token.value {
+ PPTokenValue::Extension(extension) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Extension,
+ tokens: extension.tokens,
+ }),
+ meta,
+ })
+ }
+ PPTokenValue::Float(float) => TokenValue::FloatConstant(float),
+ PPTokenValue::Ident(ident) => {
+ match ident.as_str() {
+ // Qualifiers
+ "layout" => TokenValue::Layout,
+ "in" => TokenValue::In,
+ "out" => TokenValue::Out,
+ "uniform" => TokenValue::Uniform,
+ "buffer" => TokenValue::Buffer,
+ "shared" => TokenValue::Shared,
+ "invariant" => TokenValue::Invariant,
+ "flat" => TokenValue::Interpolation(crate::Interpolation::Flat),
+ "noperspective" => TokenValue::Interpolation(crate::Interpolation::Linear),
+ "smooth" => TokenValue::Interpolation(crate::Interpolation::Perspective),
+ "centroid" => TokenValue::Sampling(crate::Sampling::Centroid),
+ "sample" => TokenValue::Sampling(crate::Sampling::Sample),
+ "const" => TokenValue::Const,
+ "inout" => TokenValue::InOut,
+ "precision" => TokenValue::Precision,
+ "highp" => TokenValue::PrecisionQualifier(Precision::High),
+ "mediump" => TokenValue::PrecisionQualifier(Precision::Medium),
+ "lowp" => TokenValue::PrecisionQualifier(Precision::Low),
+ "restrict" => TokenValue::Restrict,
+ "readonly" => TokenValue::MemoryQualifier(StorageAccess::LOAD),
+ "writeonly" => TokenValue::MemoryQualifier(StorageAccess::STORE),
+ // values
+ "true" => TokenValue::BoolConstant(true),
+ "false" => TokenValue::BoolConstant(false),
+ // jump statements
+ "continue" => TokenValue::Continue,
+ "break" => TokenValue::Break,
+ "return" => TokenValue::Return,
+ "discard" => TokenValue::Discard,
+ // selection statements
+ "if" => TokenValue::If,
+ "else" => TokenValue::Else,
+ "switch" => TokenValue::Switch,
+ "case" => TokenValue::Case,
+ "default" => TokenValue::Default,
+ // iteration statements
+ "while" => TokenValue::While,
+ "do" => TokenValue::Do,
+ "for" => TokenValue::For,
+ // types
+ "void" => TokenValue::Void,
+ "struct" => TokenValue::Struct,
+ word => match parse_type(word) {
+ Some(t) => TokenValue::TypeName(t),
+ None => TokenValue::Identifier(String::from(word)),
+ },
+ }
+ }
+ PPTokenValue::Integer(integer) => TokenValue::IntConstant(integer),
+ PPTokenValue::Punct(punct) => match punct {
+ // Compound assignments
+ Punct::AddAssign => TokenValue::AddAssign,
+ Punct::SubAssign => TokenValue::SubAssign,
+ Punct::MulAssign => TokenValue::MulAssign,
+ Punct::DivAssign => TokenValue::DivAssign,
+ Punct::ModAssign => TokenValue::ModAssign,
+ Punct::LeftShiftAssign => TokenValue::LeftShiftAssign,
+ Punct::RightShiftAssign => TokenValue::RightShiftAssign,
+ Punct::AndAssign => TokenValue::AndAssign,
+ Punct::XorAssign => TokenValue::XorAssign,
+ Punct::OrAssign => TokenValue::OrAssign,
+
+ // Two character punctuation
+ Punct::Increment => TokenValue::Increment,
+ Punct::Decrement => TokenValue::Decrement,
+ Punct::LogicalAnd => TokenValue::LogicalAnd,
+ Punct::LogicalOr => TokenValue::LogicalOr,
+ Punct::LogicalXor => TokenValue::LogicalXor,
+ Punct::LessEqual => TokenValue::LessEqual,
+ Punct::GreaterEqual => TokenValue::GreaterEqual,
+ Punct::EqualEqual => TokenValue::Equal,
+ Punct::NotEqual => TokenValue::NotEqual,
+ Punct::LeftShift => TokenValue::LeftShift,
+ Punct::RightShift => TokenValue::RightShift,
+
+ // Parenthesis or similar
+ Punct::LeftBrace => TokenValue::LeftBrace,
+ Punct::RightBrace => TokenValue::RightBrace,
+ Punct::LeftParen => TokenValue::LeftParen,
+ Punct::RightParen => TokenValue::RightParen,
+ Punct::LeftBracket => TokenValue::LeftBracket,
+ Punct::RightBracket => TokenValue::RightBracket,
+
+ // Other one character punctuation
+ Punct::LeftAngle => TokenValue::LeftAngle,
+ Punct::RightAngle => TokenValue::RightAngle,
+ Punct::Semicolon => TokenValue::Semicolon,
+ Punct::Comma => TokenValue::Comma,
+ Punct::Colon => TokenValue::Colon,
+ Punct::Dot => TokenValue::Dot,
+ Punct::Equal => TokenValue::Assign,
+ Punct::Bang => TokenValue::Bang,
+ Punct::Minus => TokenValue::Dash,
+ Punct::Tilde => TokenValue::Tilde,
+ Punct::Plus => TokenValue::Plus,
+ Punct::Star => TokenValue::Star,
+ Punct::Slash => TokenValue::Slash,
+ Punct::Percent => TokenValue::Percent,
+ Punct::Pipe => TokenValue::VerticalBar,
+ Punct::Caret => TokenValue::Caret,
+ Punct::Ampersand => TokenValue::Ampersand,
+ Punct::Question => TokenValue::Question,
+ },
+ PPTokenValue::Pragma(pragma) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Pragma,
+ tokens: pragma.tokens,
+ }),
+ meta,
+ })
+ }
+ PPTokenValue::Version(version) => {
+ return Some(LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Version {
+ is_first_directive: version.is_first_directive,
+ },
+ tokens: version.tokens,
+ }),
+ meta,
+ })
+ }
+ };
+
+ Some(LexerResult {
+ kind: LexerResultKind::Token(Token { value, meta }),
+ meta,
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use pp_rs::token::{Integer, Location, Token as PPToken, TokenValue as PPTokenValue};
+
+ use super::{
+ super::token::{Directive, DirectiveKind, Token, TokenValue},
+ Lexer, LexerResult, LexerResultKind,
+ };
+ use crate::Span;
+
+ #[test]
+ fn lex_tokens() {
+ let defines = crate::FastHashMap::default();
+
+ // line comments
+ let mut lex = Lexer::new("#version 450\nvoid main () {}", &defines);
+ let mut location = Location::default();
+ location.start = 9;
+ location.end = 12;
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Directive(Directive {
+ kind: DirectiveKind::Version {
+ is_first_directive: true
+ },
+ tokens: vec![PPToken {
+ value: PPTokenValue::Integer(Integer {
+ signed: true,
+ value: 450,
+ width: 32
+ }),
+ location
+ }]
+ }),
+ meta: Span::new(1, 8)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::Void,
+ meta: Span::new(13, 17)
+ }),
+ meta: Span::new(13, 17)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::Identifier("main".into()),
+ meta: Span::new(18, 22)
+ }),
+ meta: Span::new(18, 22)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::LeftParen,
+ meta: Span::new(23, 24)
+ }),
+ meta: Span::new(23, 24)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::RightParen,
+ meta: Span::new(24, 25)
+ }),
+ meta: Span::new(24, 25)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::LeftBrace,
+ meta: Span::new(26, 27)
+ }),
+ meta: Span::new(26, 27)
+ }
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LexerResult {
+ kind: LexerResultKind::Token(Token {
+ value: TokenValue::RightBrace,
+ meta: Span::new(27, 28)
+ }),
+ meta: Span::new(27, 28)
+ }
+ );
+ assert_eq!(lex.next(), None);
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/mod.rs b/third_party/rust/naga/src/front/glsl/mod.rs
new file mode 100644
index 0000000000..75f3929db4
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/mod.rs
@@ -0,0 +1,232 @@
+/*!
+Frontend for [GLSL][glsl] (OpenGL Shading Language).
+
+To begin, take a look at the documentation for the [`Frontend`].
+
+# Supported versions
+## Vulkan
+- 440 (partial)
+- 450
+- 460
+
+[glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php
+*/
+
+pub use ast::{Precision, Profile};
+pub use error::{Error, ErrorKind, ExpectedToken, ParseError};
+pub use token::TokenValue;
+
+use crate::{proc::Layouter, FastHashMap, FastHashSet, Handle, Module, ShaderStage, Span, Type};
+use ast::{EntryArg, FunctionDeclaration, GlobalLookup};
+use parser::ParsingContext;
+
+mod ast;
+mod builtins;
+mod context;
+mod error;
+mod functions;
+mod lex;
+mod offset;
+mod parser;
+#[cfg(test)]
+mod parser_tests;
+mod token;
+mod types;
+mod variables;
+
+type Result<T> = std::result::Result<T, Error>;
+
+/// Per-shader options passed to [`parse`](Frontend::parse).
+///
+/// The [`From`] trait is implemented for [`ShaderStage`] to provide a quick way
+/// to create an `Options` instance.
+///
+/// ```rust
+/// # use naga::ShaderStage;
+/// # use naga::front::glsl::Options;
+/// Options::from(ShaderStage::Vertex);
+/// ```
+#[derive(Debug)]
+pub struct Options {
+ /// The shader stage in the pipeline.
+ pub stage: ShaderStage,
+ /// Preprocessor definitions to be used, akin to having
+ /// ```glsl
+ /// #define key value
+ /// ```
+ /// for each key value pair in the map.
+ pub defines: FastHashMap<String, String>,
+}
+
+impl From<ShaderStage> for Options {
+ fn from(stage: ShaderStage) -> Self {
+ Options {
+ stage,
+ defines: FastHashMap::default(),
+ }
+ }
+}
+
+/// Additional information about the GLSL shader.
+///
+/// Stores additional information about the GLSL shader which might not be
+/// stored in the shader [`Module`].
+#[derive(Debug)]
+pub struct ShaderMetadata {
+ /// The GLSL version specified in the shader through the use of the
+ /// `#version` preprocessor directive.
+ pub version: u16,
+ /// The GLSL profile specified in the shader through the use of the
+ /// `#version` preprocessor directive.
+ pub profile: Profile,
+ /// The shader stage in the pipeline, passed to the [`parse`](Frontend::parse)
+ /// method via the [`Options`] struct.
+ pub stage: ShaderStage,
+
+ /// The workgroup size for compute shaders, defaults to `[1; 3]` for
+ /// compute shaders and `[0; 3]` for non compute shaders.
+ pub workgroup_size: [u32; 3],
+ /// Whether or not early fragment tests where requested by the shader.
+ /// Defaults to `false`.
+ pub early_fragment_tests: bool,
+
+ /// The shader can request extensions via the
+ /// `#extension` preprocessor directive, in the directive a behavior
+ /// parameter is used to control whether the extension should be disabled,
+ /// warn on usage, enabled if possible or required.
+ ///
+ /// This field only stores extensions which were required or requested to
+ /// be enabled if possible and they are supported.
+ pub extensions: FastHashSet<String>,
+}
+
+impl ShaderMetadata {
+ fn reset(&mut self, stage: ShaderStage) {
+ self.version = 0;
+ self.profile = Profile::Core;
+ self.stage = stage;
+ self.workgroup_size = [u32::from(stage == ShaderStage::Compute); 3];
+ self.early_fragment_tests = false;
+ self.extensions.clear();
+ }
+}
+
+impl Default for ShaderMetadata {
+ fn default() -> Self {
+ ShaderMetadata {
+ version: 0,
+ profile: Profile::Core,
+ stage: ShaderStage::Vertex,
+ workgroup_size: [0; 3],
+ early_fragment_tests: false,
+ extensions: FastHashSet::default(),
+ }
+ }
+}
+
+/// The `Frontend` is the central structure of the GLSL frontend.
+///
+/// To instantiate a new `Frontend` the [`Default`] trait is used, so a
+/// call to the associated function [`Frontend::default`](Frontend::default) will
+/// return a new `Frontend` instance.
+///
+/// To parse a shader simply call the [`parse`](Frontend::parse) method with a
+/// [`Options`] struct and a [`&str`](str) holding the glsl code.
+///
+/// The `Frontend` also provides the [`metadata`](Frontend::metadata) to get some
+/// further information about the previously parsed shader, like version and
+/// extensions used (see the documentation for
+/// [`ShaderMetadata`] to see all the returned information)
+///
+/// # Example usage
+/// ```rust
+/// use naga::ShaderStage;
+/// use naga::front::glsl::{Frontend, Options};
+///
+/// let glsl = r#"
+/// #version 450 core
+///
+/// void main() {}
+/// "#;
+///
+/// let mut frontend = Frontend::default();
+/// let options = Options::from(ShaderStage::Vertex);
+/// frontend.parse(&options, glsl);
+/// ```
+///
+/// # Reusability
+///
+/// If there's a need to parse more than one shader reusing the same `Frontend`
+/// instance may be beneficial since internal allocations will be reused.
+///
+/// Calling the [`parse`](Frontend::parse) method multiple times will reset the
+/// `Frontend` so no extra care is needed when reusing.
+#[derive(Debug, Default)]
+pub struct Frontend {
+ meta: ShaderMetadata,
+
+ lookup_function: FastHashMap<String, FunctionDeclaration>,
+ lookup_type: FastHashMap<String, Handle<Type>>,
+
+ global_variables: Vec<(String, GlobalLookup)>,
+
+ entry_args: Vec<EntryArg>,
+
+ layouter: Layouter,
+
+ errors: Vec<Error>,
+}
+
+impl Frontend {
+ fn reset(&mut self, stage: ShaderStage) {
+ self.meta.reset(stage);
+
+ self.lookup_function.clear();
+ self.lookup_type.clear();
+ self.global_variables.clear();
+ self.entry_args.clear();
+ self.layouter.clear();
+ }
+
+ /// Parses a shader either outputting a shader [`Module`] or a list of
+ /// [`Error`]s.
+ ///
+ /// Multiple calls using the same `Frontend` and different shaders are supported.
+ pub fn parse(
+ &mut self,
+ options: &Options,
+ source: &str,
+ ) -> std::result::Result<Module, ParseError> {
+ self.reset(options.stage);
+
+ let lexer = lex::Lexer::new(source, &options.defines);
+ let mut ctx = ParsingContext::new(lexer);
+
+ match ctx.parse(self) {
+ Ok(module) => {
+ if self.errors.is_empty() {
+ Ok(module)
+ } else {
+ Err(std::mem::take(&mut self.errors).into())
+ }
+ }
+ Err(e) => {
+ self.errors.push(e);
+ Err(std::mem::take(&mut self.errors).into())
+ }
+ }
+ }
+
+ /// Returns additional information about the parsed shader which might not
+ /// be stored in the [`Module`], see the documentation for
+ /// [`ShaderMetadata`] for more information about the returned data.
+ ///
+ /// # Notes
+ ///
+ /// Following an unsuccessful parsing the state of the returned information
+ /// is undefined, it might contain only partial information about the
+ /// current shader, the previous shader or both.
+ pub const fn metadata(&self) -> &ShaderMetadata {
+ &self.meta
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/offset.rs b/third_party/rust/naga/src/front/glsl/offset.rs
new file mode 100644
index 0000000000..c88c46598d
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/offset.rs
@@ -0,0 +1,173 @@
+/*!
+Module responsible for calculating the offset and span for types.
+
+There exists two types of layouts std140 and std430 (there's technically
+two more layouts, shared and packed. Shared is not supported by spirv. Packed is
+implementation dependent and for now it's just implemented as an alias to
+std140).
+
+The OpenGl spec (the layout rules are defined by the OpenGl spec in section
+7.6.2.2 as opposed to the GLSL spec) uses the term basic machine units which are
+equivalent to bytes.
+*/
+
+use super::{
+ ast::StructLayout,
+ error::{Error, ErrorKind},
+ Span,
+};
+use crate::{proc::Alignment, Handle, Scalar, Type, TypeInner, UniqueArena};
+
+/// Struct with information needed for defining a struct member.
+///
+/// Returned by [`calculate_offset`].
+#[derive(Debug)]
+pub struct TypeAlignSpan {
+ /// The handle to the type, this might be the same handle passed to
+ /// [`calculate_offset`] or a new such a new array type with a different
+ /// stride set.
+ pub ty: Handle<Type>,
+ /// The alignment required by the type.
+ pub align: Alignment,
+ /// The size of the type.
+ pub span: u32,
+}
+
+/// Returns the type, alignment and span of a struct member according to a [`StructLayout`].
+///
+/// The functions returns a [`TypeAlignSpan`] which has a `ty` member this
+/// should be used as the struct member type because for example arrays may have
+/// to change the stride and as such need to have a different type.
+pub fn calculate_offset(
+ mut ty: Handle<Type>,
+ meta: Span,
+ layout: StructLayout,
+ types: &mut UniqueArena<Type>,
+ errors: &mut Vec<Error>,
+) -> TypeAlignSpan {
+ // When using the std430 storage layout, shader storage blocks will be laid out in buffer storage
+ // identically to uniform and shader storage blocks using the std140 layout, except
+ // that the base alignment and stride of arrays of scalars and vectors in rule 4 and of
+ // structures in rule 9 are not rounded up a multiple of the base alignment of a vec4.
+
+ let (align, span) = match types[ty].inner {
+ // 1. If the member is a scalar consuming N basic machine units,
+ // the base alignment is N.
+ TypeInner::Scalar(Scalar { width, .. }) => (Alignment::from_width(width), width as u32),
+ // 2. If the member is a two- or four-component vector with components
+ // consuming N basic machine units, the base alignment is 2N or 4N, respectively.
+ // 3. If the member is a three-component vector with components consuming N
+ // basic machine units, the base alignment is 4N.
+ TypeInner::Vector {
+ size,
+ scalar: Scalar { width, .. },
+ } => (
+ Alignment::from(size) * Alignment::from_width(width),
+ size as u32 * width as u32,
+ ),
+ // 4. If the member is an array of scalars or vectors, the base alignment and array
+ // stride are set to match the base alignment of a single array element, according
+ // to rules (1), (2), and (3), and rounded up to the base alignment of a vec4.
+ // TODO: Matrices array
+ TypeInner::Array { base, size, .. } => {
+ let info = calculate_offset(base, meta, layout, types, errors);
+
+ let name = types[ty].name.clone();
+
+ // See comment at the beginning of the function
+ let (align, stride) = if StructLayout::Std430 == layout {
+ (info.align, info.align.round_up(info.span))
+ } else {
+ let align = info.align.max(Alignment::MIN_UNIFORM);
+ (align, align.round_up(info.span))
+ };
+
+ let span = match size {
+ crate::ArraySize::Constant(size) => size.get() * stride,
+ crate::ArraySize::Dynamic => stride,
+ };
+
+ let ty_span = types.get_span(ty);
+ ty = types.insert(
+ Type {
+ name,
+ inner: TypeInner::Array {
+ base: info.ty,
+ size,
+ stride,
+ },
+ },
+ ty_span,
+ );
+
+ (align, span)
+ }
+ // 5. If the member is a column-major matrix with C columns and R rows, the
+ // matrix is stored identically to an array of C column vectors with R
+ // components each, according to rule (4)
+ // TODO: Row major matrices
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ let mut align = Alignment::from(rows) * Alignment::from_width(scalar.width);
+
+ // See comment at the beginning of the function
+ if StructLayout::Std430 != layout {
+ align = align.max(Alignment::MIN_UNIFORM);
+ }
+
+ // See comment on the error kind
+ if StructLayout::Std140 == layout && rows == crate::VectorSize::Bi {
+ errors.push(Error {
+ kind: ErrorKind::UnsupportedMatrixTypeInStd140,
+ meta,
+ });
+ }
+
+ (align, align * columns as u32)
+ }
+ TypeInner::Struct { ref members, .. } => {
+ let mut span = 0;
+ let mut align = Alignment::ONE;
+ let mut members = members.clone();
+ let name = types[ty].name.clone();
+
+ for member in members.iter_mut() {
+ let info = calculate_offset(member.ty, meta, layout, types, errors);
+
+ let member_alignment = info.align;
+ span = member_alignment.round_up(span);
+ align = member_alignment.max(align);
+
+ member.ty = info.ty;
+ member.offset = span;
+
+ span += info.span;
+ }
+
+ span = align.round_up(span);
+
+ let ty_span = types.get_span(ty);
+ ty = types.insert(
+ Type {
+ name,
+ inner: TypeInner::Struct { members, span },
+ },
+ ty_span,
+ );
+
+ (align, span)
+ }
+ _ => {
+ errors.push(Error {
+ kind: ErrorKind::SemanticError("Invalid struct member type".into()),
+ meta,
+ });
+ (Alignment::ONE, 0)
+ }
+ };
+
+ TypeAlignSpan { ty, align, span }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser.rs b/third_party/rust/naga/src/front/glsl/parser.rs
new file mode 100644
index 0000000000..851d2e1d79
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser.rs
@@ -0,0 +1,431 @@
+use super::{
+ ast::{FunctionKind, Profile, TypeQualifiers},
+ context::{Context, ExprPos},
+ error::ExpectedToken,
+ error::{Error, ErrorKind},
+ lex::{Lexer, LexerResultKind},
+ token::{Directive, DirectiveKind},
+ token::{Token, TokenValue},
+ variables::{GlobalOrConstant, VarDeclaration},
+ Frontend, Result,
+};
+use crate::{arena::Handle, proc::U32EvalError, Expression, Module, Span, Type};
+use pp_rs::token::{PreprocessorError, Token as PPToken, TokenValue as PPTokenValue};
+use std::iter::Peekable;
+
+mod declarations;
+mod expressions;
+mod functions;
+mod types;
+
+pub struct ParsingContext<'source> {
+ lexer: Peekable<Lexer<'source>>,
+ /// Used to store tokens already consumed by the parser but that need to be backtracked
+ backtracked_token: Option<Token>,
+ last_meta: Span,
+}
+
+impl<'source> ParsingContext<'source> {
+ pub fn new(lexer: Lexer<'source>) -> Self {
+ ParsingContext {
+ lexer: lexer.peekable(),
+ backtracked_token: None,
+ last_meta: Span::default(),
+ }
+ }
+
+ /// Helper method for backtracking from a consumed token
+ ///
+ /// This method should always be used instead of assigning to `backtracked_token` since
+ /// it validates that backtracking hasn't occurred more than one time in a row
+ ///
+ /// # Panics
+ /// - If the parser already backtracked without bumping in between
+ pub fn backtrack(&mut self, token: Token) -> Result<()> {
+ // This should never happen
+ if let Some(ref prev_token) = self.backtracked_token {
+ return Err(Error {
+ kind: ErrorKind::InternalError("The parser tried to backtrack twice in a row"),
+ meta: prev_token.meta,
+ });
+ }
+
+ self.backtracked_token = Some(token);
+
+ Ok(())
+ }
+
+ pub fn expect_ident(&mut self, frontend: &mut Frontend) -> Result<(String, Span)> {
+ let token = self.bump(frontend)?;
+
+ match token.value {
+ TokenValue::Identifier(name) => Ok((name, token.meta)),
+ _ => Err(Error {
+ kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]),
+ meta: token.meta,
+ }),
+ }
+ }
+
+ pub fn expect(&mut self, frontend: &mut Frontend, value: TokenValue) -> Result<Token> {
+ let token = self.bump(frontend)?;
+
+ if token.value != value {
+ Err(Error {
+ kind: ErrorKind::InvalidToken(token.value, vec![value.into()]),
+ meta: token.meta,
+ })
+ } else {
+ Ok(token)
+ }
+ }
+
+ pub fn next(&mut self, frontend: &mut Frontend) -> Option<Token> {
+ loop {
+ if let Some(token) = self.backtracked_token.take() {
+ self.last_meta = token.meta;
+ break Some(token);
+ }
+
+ let res = self.lexer.next()?;
+
+ match res.kind {
+ LexerResultKind::Token(token) => {
+ self.last_meta = token.meta;
+ break Some(token);
+ }
+ LexerResultKind::Directive(directive) => {
+ frontend.handle_directive(directive, res.meta)
+ }
+ LexerResultKind::Error(error) => frontend.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(error),
+ meta: res.meta,
+ }),
+ }
+ }
+ }
+
+ pub fn bump(&mut self, frontend: &mut Frontend) -> Result<Token> {
+ self.next(frontend).ok_or(Error {
+ kind: ErrorKind::EndOfFile,
+ meta: self.last_meta,
+ })
+ }
+
+ /// Returns None on the end of the file rather than an error like other methods
+ pub fn bump_if(&mut self, frontend: &mut Frontend, value: TokenValue) -> Option<Token> {
+ if self.peek(frontend).filter(|t| t.value == value).is_some() {
+ self.bump(frontend).ok()
+ } else {
+ None
+ }
+ }
+
+ pub fn peek(&mut self, frontend: &mut Frontend) -> Option<&Token> {
+ loop {
+ if let Some(ref token) = self.backtracked_token {
+ break Some(token);
+ }
+
+ match self.lexer.peek()?.kind {
+ LexerResultKind::Token(_) => {
+ let res = self.lexer.peek()?;
+
+ match res.kind {
+ LexerResultKind::Token(ref token) => break Some(token),
+ _ => unreachable!(),
+ }
+ }
+ LexerResultKind::Error(_) | LexerResultKind::Directive(_) => {
+ let res = self.lexer.next()?;
+
+ match res.kind {
+ LexerResultKind::Directive(directive) => {
+ frontend.handle_directive(directive, res.meta)
+ }
+ LexerResultKind::Error(error) => frontend.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(error),
+ meta: res.meta,
+ }),
+ LexerResultKind::Token(_) => unreachable!(),
+ }
+ }
+ }
+ }
+ }
+
+ pub fn expect_peek(&mut self, frontend: &mut Frontend) -> Result<&Token> {
+ let meta = self.last_meta;
+ self.peek(frontend).ok_or(Error {
+ kind: ErrorKind::EndOfFile,
+ meta,
+ })
+ }
+
+ pub fn parse(&mut self, frontend: &mut Frontend) -> Result<Module> {
+ let mut module = Module::default();
+
+ // Body and expression arena for global initialization
+ let mut ctx = Context::new(frontend, &mut module, false)?;
+
+ while self.peek(frontend).is_some() {
+ self.parse_external_declaration(frontend, &mut ctx)?;
+ }
+
+ // Add an `EntryPoint` to `parser.module` for `main`, if a
+ // suitable overload exists. Error out if we can't find one.
+ if let Some(declaration) = frontend.lookup_function.get("main") {
+ for decl in declaration.overloads.iter() {
+ if let FunctionKind::Call(handle) = decl.kind {
+ if decl.defined && decl.parameters.is_empty() {
+ frontend.add_entry_point(handle, ctx)?;
+ return Ok(module);
+ }
+ }
+ }
+ }
+
+ Err(Error {
+ kind: ErrorKind::SemanticError("Missing entry point".into()),
+ meta: Span::default(),
+ })
+ }
+
+ fn parse_uint_constant(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ ) -> Result<(u32, Span)> {
+ let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?;
+
+ let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr);
+
+ let int = match res {
+ Ok(value) => Ok(value),
+ Err(U32EvalError::Negative) => Err(Error {
+ kind: ErrorKind::SemanticError("int constant overflows".into()),
+ meta,
+ }),
+ Err(U32EvalError::NonConst) => Err(Error {
+ kind: ErrorKind::SemanticError("Expected a uint constant".into()),
+ meta,
+ }),
+ }?;
+
+ Ok((int, meta))
+ }
+
+ fn parse_constant_expression(
+ &mut self,
+ frontend: &mut Frontend,
+ module: &mut Module,
+ ) -> Result<(Handle<Expression>, Span)> {
+ let mut ctx = Context::new(frontend, module, true)?;
+
+ let mut stmt_ctx = ctx.stmt_ctx();
+ let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?;
+ let (root, meta) = ctx.lower_expect(stmt_ctx, frontend, expr, ExprPos::Rhs)?;
+
+ Ok((root, meta))
+ }
+}
+
+impl Frontend {
+ fn handle_directive(&mut self, directive: Directive, meta: Span) {
+ let mut tokens = directive.tokens.into_iter();
+
+ match directive.kind {
+ DirectiveKind::Version { is_first_directive } => {
+ if !is_first_directive {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "#version must occur first in shader".into(),
+ ),
+ meta,
+ })
+ }
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Integer(int),
+ location,
+ }) => match int.value {
+ 440 | 450 | 460 => self.meta.version = int.value as u16,
+ _ => self.errors.push(Error {
+ kind: ErrorKind::InvalidVersion(int.value),
+ meta: location.into(),
+ }),
+ },
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine),
+ meta,
+ }),
+ };
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Ident(name),
+ location,
+ }) => match name.as_str() {
+ "core" => self.meta.profile = Profile::Core,
+ _ => self.errors.push(Error {
+ kind: ErrorKind::InvalidProfile(name),
+ meta: location.into(),
+ }),
+ },
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => {}
+ };
+
+ if let Some(PPToken { value, location }) = tokens.next() {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ })
+ }
+ }
+ DirectiveKind::Extension => {
+ // TODO: Proper extension handling
+ // - Checking for extension support in the compiler
+ // - Handle behaviors such as warn
+ // - Handle the all extension
+ let name = match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Ident(name),
+ ..
+ }) => Some(name),
+ Some(PPToken { value, location }) => {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ });
+
+ None
+ }
+ None => {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(
+ PreprocessorError::UnexpectedNewLine,
+ ),
+ meta,
+ });
+
+ None
+ }
+ };
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Punct(pp_rs::token::Punct::Colon),
+ ..
+ }) => {}
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine),
+ meta,
+ }),
+ };
+
+ match tokens.next() {
+ Some(PPToken {
+ value: PPTokenValue::Ident(behavior),
+ location,
+ }) => match behavior.as_str() {
+ "require" | "enable" | "warn" | "disable" => {
+ if let Some(name) = name {
+ self.meta.extensions.insert(name);
+ }
+ }
+ _ => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ PPTokenValue::Ident(behavior),
+ )),
+ meta: location.into(),
+ }),
+ },
+ Some(PPToken { value, location }) => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ }),
+ None => self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedNewLine),
+ meta,
+ }),
+ }
+
+ if let Some(PPToken { value, location }) = tokens.next() {
+ self.errors.push(Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedToken(
+ value,
+ )),
+ meta: location.into(),
+ })
+ }
+ }
+ DirectiveKind::Pragma => {
+ // TODO: handle some common pragmas?
+ }
+ }
+ }
+}
+
+pub struct DeclarationContext<'ctx, 'qualifiers, 'a> {
+ qualifiers: TypeQualifiers<'qualifiers>,
+ /// Indicates a global declaration
+ external: bool,
+ is_inside_loop: bool,
+ ctx: &'ctx mut Context<'a>,
+}
+
+impl<'ctx, 'qualifiers, 'a> DeclarationContext<'ctx, 'qualifiers, 'a> {
+ fn add_var(
+ &mut self,
+ frontend: &mut Frontend,
+ ty: Handle<Type>,
+ name: String,
+ init: Option<Handle<Expression>>,
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let decl = VarDeclaration {
+ qualifiers: &mut self.qualifiers,
+ ty,
+ name: Some(name),
+ init,
+ meta,
+ };
+
+ match self.external {
+ true => {
+ let global = frontend.add_global_var(self.ctx, decl)?;
+ let expr = match global {
+ GlobalOrConstant::Global(handle) => Expression::GlobalVariable(handle),
+ GlobalOrConstant::Constant(handle) => Expression::Constant(handle),
+ };
+ Ok(self.ctx.add_expression(expr, meta)?)
+ }
+ false => frontend.add_local_var(self.ctx, decl),
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/declarations.rs b/third_party/rust/naga/src/front/glsl/parser/declarations.rs
new file mode 100644
index 0000000000..f5e38fb016
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/declarations.rs
@@ -0,0 +1,677 @@
+use crate::{
+ front::glsl::{
+ ast::{
+ GlobalLookup, GlobalLookupKind, Precision, QualifierKey, QualifierValue,
+ StorageQualifier, StructLayout, TypeQualifiers,
+ },
+ context::{Context, ExprPos},
+ error::ExpectedToken,
+ offset,
+ token::{Token, TokenValue},
+ types::scalar_components,
+ variables::{GlobalOrConstant, VarDeclaration},
+ Error, ErrorKind, Frontend, Span,
+ },
+ proc::Alignment,
+ AddressSpace, Expression, FunctionResult, Handle, Scalar, ScalarKind, Statement, StructMember,
+ Type, TypeInner,
+};
+
+use super::{DeclarationContext, ParsingContext, Result};
+
+/// Helper method used to retrieve the child type of `ty` at
+/// index `i`.
+///
+/// # Note
+///
+/// Does not check if the index is valid and returns the same type
+/// when indexing out-of-bounds a struct or indexing a non indexable
+/// type.
+fn element_or_member_type(
+ ty: Handle<Type>,
+ i: usize,
+ types: &mut crate::UniqueArena<Type>,
+) -> Handle<Type> {
+ match types[ty].inner {
+ // The child type of a vector is a scalar of the same kind and width
+ TypeInner::Vector { scalar, .. } => types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar(scalar),
+ },
+ Default::default(),
+ ),
+ // The child type of a matrix is a vector of floats with the same
+ // width and the size of the matrix rows.
+ TypeInner::Matrix { rows, scalar, .. } => types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector { size: rows, scalar },
+ },
+ Default::default(),
+ ),
+ // The child type of an array is the base type of the array
+ TypeInner::Array { base, .. } => base,
+ // The child type of a struct at index `i` is the type of it's
+ // member at that same index.
+ //
+ // In case the index is out of bounds the same type is returned
+ TypeInner::Struct { ref members, .. } => {
+ members.get(i).map(|member| member.ty).unwrap_or(ty)
+ }
+ // The type isn't indexable, the same type is returned
+ _ => ty,
+ }
+}
+
+impl<'source> ParsingContext<'source> {
+ pub fn parse_external_declaration(
+ &mut self,
+ frontend: &mut Frontend,
+ global_ctx: &mut Context,
+ ) -> Result<()> {
+ if self
+ .parse_declaration(frontend, global_ctx, true, false)?
+ .is_none()
+ {
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Semicolon if frontend.meta.version == 460 => Ok(()),
+ _ => {
+ let expected = match frontend.meta.version {
+ 460 => vec![TokenValue::Semicolon.into(), ExpectedToken::Eof],
+ _ => vec![ExpectedToken::Eof],
+ };
+ Err(Error {
+ kind: ErrorKind::InvalidToken(token.value, expected),
+ meta: token.meta,
+ })
+ }
+ }
+ } else {
+ Ok(())
+ }
+ }
+
+ pub fn parse_initializer(
+ &mut self,
+ frontend: &mut Frontend,
+ ty: Handle<Type>,
+ ctx: &mut Context,
+ ) -> Result<(Handle<Expression>, Span)> {
+ // initializer:
+ // assignment_expression
+ // LEFT_BRACE initializer_list RIGHT_BRACE
+ // LEFT_BRACE initializer_list COMMA RIGHT_BRACE
+ //
+ // initializer_list:
+ // initializer
+ // initializer_list COMMA initializer
+ if let Some(Token { mut meta, .. }) = self.bump_if(frontend, TokenValue::LeftBrace) {
+ // initializer_list
+ let mut components = Vec::new();
+ loop {
+ // The type expected to be parsed inside the initializer list
+ let new_ty = element_or_member_type(ty, components.len(), &mut ctx.module.types);
+
+ components.push(self.parse_initializer(frontend, new_ty, ctx)?.0);
+
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Comma => {
+ if let Some(Token { meta: end_meta, .. }) =
+ self.bump_if(frontend, TokenValue::RightBrace)
+ {
+ meta.subsume(end_meta);
+ break;
+ }
+ }
+ TokenValue::RightBrace => {
+ meta.subsume(token.meta);
+ break;
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Comma.into(), TokenValue::RightBrace.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ }
+ }
+
+ Ok((
+ ctx.add_expression(Expression::Compose { ty, components }, meta)?,
+ meta,
+ ))
+ } else {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_assignment(frontend, ctx, &mut stmt)?;
+ let (mut init, init_meta) = ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?;
+
+ let scalar_components = scalar_components(&ctx.module.types[ty].inner);
+ if let Some(scalar) = scalar_components {
+ ctx.implicit_conversion(&mut init, init_meta, scalar)?;
+ }
+
+ Ok((init, init_meta))
+ }
+ }
+
+ // Note: caller preparsed the type and qualifiers
+ // Note: caller skips this if the fallthrough token is not expected to be consumed here so this
+ // produced Error::InvalidToken if it isn't consumed
+ pub fn parse_init_declarator_list(
+ &mut self,
+ frontend: &mut Frontend,
+ mut ty: Handle<Type>,
+ ctx: &mut DeclarationContext,
+ ) -> Result<()> {
+ // init_declarator_list:
+ // single_declaration
+ // init_declarator_list COMMA IDENTIFIER
+ // init_declarator_list COMMA IDENTIFIER array_specifier
+ // init_declarator_list COMMA IDENTIFIER array_specifier EQUAL initializer
+ // init_declarator_list COMMA IDENTIFIER EQUAL initializer
+ //
+ // single_declaration:
+ // fully_specified_type
+ // fully_specified_type IDENTIFIER
+ // fully_specified_type IDENTIFIER array_specifier
+ // fully_specified_type IDENTIFIER array_specifier EQUAL initializer
+ // fully_specified_type IDENTIFIER EQUAL initializer
+
+ // Consume any leading comma, e.g. this is valid: `float, a=1;`
+ if self
+ .peek(frontend)
+ .map_or(false, |t| t.value == TokenValue::Comma)
+ {
+ self.next(frontend);
+ }
+
+ loop {
+ let token = self.bump(frontend)?;
+ let name = match token.value {
+ TokenValue::Semicolon => break,
+ TokenValue::Identifier(name) => name,
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ };
+ let mut meta = token.meta;
+
+ // array_specifier
+ // array_specifier EQUAL initializer
+ // EQUAL initializer
+
+ // parse an array specifier if it exists
+ // NOTE: unlike other parse methods this one doesn't expect an array specifier and
+ // returns Ok(None) rather than an error if there is not one
+ self.parse_array_specifier(frontend, ctx.ctx, &mut meta, &mut ty)?;
+
+ let is_global_const =
+ ctx.qualifiers.storage.0 == StorageQualifier::Const && ctx.external;
+
+ let init = self
+ .bump_if(frontend, TokenValue::Assign)
+ .map::<Result<_>, _>(|_| {
+ let prev_const = ctx.ctx.is_const;
+ ctx.ctx.is_const = is_global_const;
+
+ let (mut expr, init_meta) = self.parse_initializer(frontend, ty, ctx.ctx)?;
+
+ let scalar_components = scalar_components(&ctx.ctx.module.types[ty].inner);
+ if let Some(scalar) = scalar_components {
+ ctx.ctx.implicit_conversion(&mut expr, init_meta, scalar)?;
+ }
+
+ ctx.ctx.is_const = prev_const;
+
+ meta.subsume(init_meta);
+
+ Ok(expr)
+ })
+ .transpose()?;
+
+ let decl_initializer;
+ let late_initializer;
+ if is_global_const {
+ decl_initializer = init;
+ late_initializer = None;
+ } else if ctx.external {
+ decl_initializer =
+ init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok());
+ late_initializer = None;
+ } else if let Some(init) = init {
+ if ctx.is_inside_loop || !ctx.ctx.expression_constness.is_const(init) {
+ decl_initializer = None;
+ late_initializer = Some(init);
+ } else {
+ decl_initializer = Some(init);
+ late_initializer = None;
+ }
+ } else {
+ decl_initializer = None;
+ late_initializer = None;
+ };
+
+ let pointer = ctx.add_var(frontend, ty, name, decl_initializer, meta)?;
+
+ if let Some(value) = late_initializer {
+ ctx.ctx.emit_restart();
+ ctx.ctx.body.push(Statement::Store { pointer, value }, meta);
+ }
+
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Semicolon => break,
+ TokenValue::Comma => {}
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Comma.into(), TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ /// `external` whether or not we are in a global or local context
+ pub fn parse_declaration(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ external: bool,
+ is_inside_loop: bool,
+ ) -> Result<Option<Span>> {
+ //declaration:
+ // function_prototype SEMICOLON
+ //
+ // init_declarator_list SEMICOLON
+ // PRECISION precision_qualifier type_specifier SEMICOLON
+ //
+ // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE SEMICOLON
+ // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER SEMICOLON
+ // type_qualifier IDENTIFIER LEFT_BRACE struct_declaration_list RIGHT_BRACE IDENTIFIER array_specifier SEMICOLON
+ // type_qualifier SEMICOLON type_qualifier IDENTIFIER SEMICOLON
+ // type_qualifier IDENTIFIER identifier_list SEMICOLON
+
+ if self.peek_type_qualifier(frontend) || self.peek_type_name(frontend) {
+ let mut qualifiers = self.parse_type_qualifiers(frontend, ctx)?;
+
+ if self.peek_type_name(frontend) {
+ // This branch handles variables and function prototypes and if
+ // external is true also function definitions
+ let (ty, mut meta) = self.parse_type(frontend, ctx)?;
+
+ let token = self.bump(frontend)?;
+ let token_fallthrough = match token.value {
+ TokenValue::Identifier(name) => match self.expect_peek(frontend)?.value {
+ TokenValue::LeftParen => {
+ // This branch handles function definition and prototypes
+ self.bump(frontend)?;
+
+ let result = ty.map(|ty| FunctionResult { ty, binding: None });
+
+ let mut context = Context::new(frontend, ctx.module, false)?;
+
+ self.parse_function_args(frontend, &mut context)?;
+
+ let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta;
+ meta.subsume(end_meta);
+
+ let token = self.bump(frontend)?;
+ return match token.value {
+ TokenValue::Semicolon => {
+ // This branch handles function prototypes
+ frontend.add_prototype(context, name, result, meta);
+
+ Ok(Some(meta))
+ }
+ TokenValue::LeftBrace if external => {
+ // This branch handles function definitions
+ // as you can see by the guard this branch
+ // only happens if external is also true
+
+ // parse the body
+ self.parse_compound_statement(
+ token.meta,
+ frontend,
+ &mut context,
+ &mut None,
+ false,
+ )?;
+
+ frontend.add_function(context, name, result, meta);
+
+ Ok(Some(meta))
+ }
+ _ if external => Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::LeftBrace.into(),
+ TokenValue::Semicolon.into(),
+ ],
+ ),
+ meta: token.meta,
+ }),
+ _ => Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ }),
+ };
+ }
+ // Pass the token to the init_declarator_list parser
+ _ => Token {
+ value: TokenValue::Identifier(name),
+ meta: token.meta,
+ },
+ },
+ // Pass the token to the init_declarator_list parser
+ _ => token,
+ };
+
+ // If program execution has reached here then this will be a
+ // init_declarator_list
+ // token_fallthrough will have a token that was already bumped
+ if let Some(ty) = ty {
+ let mut ctx = DeclarationContext {
+ qualifiers,
+ external,
+ is_inside_loop,
+ ctx,
+ };
+
+ self.backtrack(token_fallthrough)?;
+ self.parse_init_declarator_list(frontend, ty, &mut ctx)?;
+ } else {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError("Declaration cannot have void type".into()),
+ meta,
+ })
+ }
+
+ Ok(Some(meta))
+ } else {
+ // This branch handles struct definitions and modifiers like
+ // ```glsl
+ // layout(early_fragment_tests);
+ // ```
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Identifier(ty_name) => {
+ if self.bump_if(frontend, TokenValue::LeftBrace).is_some() {
+ self.parse_block_declaration(
+ frontend,
+ ctx,
+ &mut qualifiers,
+ ty_name,
+ token.meta,
+ )
+ .map(Some)
+ } else {
+ if qualifiers.invariant.take().is_some() {
+ frontend.make_variable_invariant(ctx, &ty_name, token.meta)?;
+
+ qualifiers.unused_errors(&mut frontend.errors);
+ self.expect(frontend, TokenValue::Semicolon)?;
+ return Ok(Some(qualifiers.span));
+ }
+
+ //TODO: declaration
+ // type_qualifier IDENTIFIER SEMICOLON
+ // type_qualifier IDENTIFIER identifier_list SEMICOLON
+ Err(Error {
+ kind: ErrorKind::NotImplemented("variable qualifier"),
+ meta: token.meta,
+ })
+ }
+ }
+ TokenValue::Semicolon => {
+ if let Some(value) =
+ qualifiers.uint_layout_qualifier("local_size_x", &mut frontend.errors)
+ {
+ frontend.meta.workgroup_size[0] = value;
+ }
+ if let Some(value) =
+ qualifiers.uint_layout_qualifier("local_size_y", &mut frontend.errors)
+ {
+ frontend.meta.workgroup_size[1] = value;
+ }
+ if let Some(value) =
+ qualifiers.uint_layout_qualifier("local_size_z", &mut frontend.errors)
+ {
+ frontend.meta.workgroup_size[2] = value;
+ }
+
+ frontend.meta.early_fragment_tests |= qualifiers
+ .none_layout_qualifier("early_fragment_tests", &mut frontend.errors);
+
+ qualifiers.unused_errors(&mut frontend.errors);
+
+ Ok(Some(qualifiers.span))
+ }
+ _ => Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ }),
+ }
+ }
+ } else {
+ match self.peek(frontend).map(|t| &t.value) {
+ Some(&TokenValue::Precision) => {
+ // PRECISION precision_qualifier type_specifier SEMICOLON
+ self.bump(frontend)?;
+
+ let token = self.bump(frontend)?;
+ let _ = match token.value {
+ TokenValue::PrecisionQualifier(p) => p,
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::PrecisionQualifier(Precision::High).into(),
+ TokenValue::PrecisionQualifier(Precision::Medium).into(),
+ TokenValue::PrecisionQualifier(Precision::Low).into(),
+ ],
+ ),
+ meta: token.meta,
+ })
+ }
+ };
+
+ let (ty, meta) = self.parse_type_non_void(frontend, ctx)?;
+
+ match ctx.module.types[ty].inner {
+ TypeInner::Scalar(Scalar {
+ kind: ScalarKind::Float | ScalarKind::Sint,
+ ..
+ }) => {}
+ _ => frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Precision statement can only work on floats and ints".into(),
+ ),
+ meta,
+ }),
+ }
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+
+ Ok(Some(meta))
+ }
+ _ => Ok(None),
+ }
+ }
+ }
+
+ pub fn parse_block_declaration(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ qualifiers: &mut TypeQualifiers,
+ ty_name: String,
+ mut meta: Span,
+ ) -> Result<Span> {
+ let layout = match qualifiers.layout_qualifiers.remove(&QualifierKey::Layout) {
+ Some((QualifierValue::Layout(l), _)) => l,
+ None => {
+ if let StorageQualifier::AddressSpace(AddressSpace::Storage { .. }) =
+ qualifiers.storage.0
+ {
+ StructLayout::Std430
+ } else {
+ StructLayout::Std140
+ }
+ }
+ _ => unreachable!(),
+ };
+
+ let mut members = Vec::new();
+ let span = self.parse_struct_declaration_list(frontend, ctx, &mut members, layout)?;
+ self.expect(frontend, TokenValue::RightBrace)?;
+
+ let mut ty = ctx.module.types.insert(
+ Type {
+ name: Some(ty_name),
+ inner: TypeInner::Struct {
+ members: members.clone(),
+ span,
+ },
+ },
+ Default::default(),
+ );
+
+ let token = self.bump(frontend)?;
+ let name = match token.value {
+ TokenValue::Semicolon => None,
+ TokenValue::Identifier(name) => {
+ self.parse_array_specifier(frontend, ctx, &mut meta, &mut ty)?;
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+
+ Some(name)
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![ExpectedToken::Identifier, TokenValue::Semicolon.into()],
+ ),
+ meta: token.meta,
+ })
+ }
+ };
+
+ let global = frontend.add_global_var(
+ ctx,
+ VarDeclaration {
+ qualifiers,
+ ty,
+ name,
+ init: None,
+ meta,
+ },
+ )?;
+
+ for (i, k, ty) in members.into_iter().enumerate().filter_map(|(i, m)| {
+ let ty = m.ty;
+ m.name.map(|s| (i as u32, s, ty))
+ }) {
+ let lookup = GlobalLookup {
+ kind: match global {
+ GlobalOrConstant::Global(handle) => GlobalLookupKind::BlockSelect(handle, i),
+ GlobalOrConstant::Constant(handle) => GlobalLookupKind::Constant(handle, ty),
+ },
+ entry_arg: None,
+ mutable: true,
+ };
+ ctx.add_global(&k, lookup)?;
+
+ frontend.global_variables.push((k, lookup));
+ }
+
+ Ok(meta)
+ }
+
+ // TODO: Accept layout arguments
+ pub fn parse_struct_declaration_list(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ members: &mut Vec<StructMember>,
+ layout: StructLayout,
+ ) -> Result<u32> {
+ let mut span = 0;
+ let mut align = Alignment::ONE;
+
+ loop {
+ // TODO: type_qualifier
+
+ let (base_ty, mut meta) = self.parse_type_non_void(frontend, ctx)?;
+
+ loop {
+ let (name, name_meta) = self.expect_ident(frontend)?;
+ let mut ty = base_ty;
+ self.parse_array_specifier(frontend, ctx, &mut meta, &mut ty)?;
+
+ meta.subsume(name_meta);
+
+ let info = offset::calculate_offset(
+ ty,
+ meta,
+ layout,
+ &mut ctx.module.types,
+ &mut frontend.errors,
+ );
+
+ let member_alignment = info.align;
+ span = member_alignment.round_up(span);
+ align = member_alignment.max(align);
+
+ members.push(StructMember {
+ name: Some(name),
+ ty: info.ty,
+ binding: None,
+ offset: span,
+ });
+
+ span += info.span;
+
+ if self.bump_if(frontend, TokenValue::Comma).is_none() {
+ break;
+ }
+ }
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+
+ if let TokenValue::RightBrace = self.expect_peek(frontend)?.value {
+ break;
+ }
+ }
+
+ span = align.round_up(span);
+
+ Ok(span)
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/expressions.rs b/third_party/rust/naga/src/front/glsl/parser/expressions.rs
new file mode 100644
index 0000000000..1b8febce90
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/expressions.rs
@@ -0,0 +1,542 @@
+use std::num::NonZeroU32;
+
+use crate::{
+ front::glsl::{
+ ast::{FunctionCall, FunctionCallKind, HirExpr, HirExprKind},
+ context::{Context, StmtContext},
+ error::{ErrorKind, ExpectedToken},
+ parser::ParsingContext,
+ token::{Token, TokenValue},
+ Error, Frontend, Result, Span,
+ },
+ ArraySize, BinaryOperator, Handle, Literal, Type, TypeInner, UnaryOperator,
+};
+
+impl<'source> ParsingContext<'source> {
+ pub fn parse_primary(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ ) -> Result<Handle<HirExpr>> {
+ let mut token = self.bump(frontend)?;
+
+ let literal = match token.value {
+ TokenValue::IntConstant(int) => {
+ if int.width != 32 {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError("Unsupported non-32bit integer".into()),
+ meta: token.meta,
+ });
+ }
+ if int.signed {
+ Literal::I32(int.value as i32)
+ } else {
+ Literal::U32(int.value as u32)
+ }
+ }
+ TokenValue::FloatConstant(float) => {
+ if float.width != 32 {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError("Unsupported floating-point value (expected single-precision floating-point number)".into()),
+ meta: token.meta,
+ });
+ }
+ Literal::F32(float.value)
+ }
+ TokenValue::BoolConstant(value) => Literal::Bool(value),
+ TokenValue::LeftParen => {
+ let expr = self.parse_expression(frontend, ctx, stmt)?;
+ let meta = self.expect(frontend, TokenValue::RightParen)?.meta;
+
+ token.meta.subsume(meta);
+
+ return Ok(expr);
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::LeftParen.into(),
+ ExpectedToken::IntLiteral,
+ ExpectedToken::FloatLiteral,
+ ExpectedToken::BoolLiteral,
+ ],
+ ),
+ meta: token.meta,
+ });
+ }
+ };
+
+ Ok(stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Literal(literal),
+ meta: token.meta,
+ },
+ Default::default(),
+ ))
+ }
+
+ pub fn parse_function_call_args(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ meta: &mut Span,
+ ) -> Result<Vec<Handle<HirExpr>>> {
+ let mut args = Vec::new();
+ if let Some(token) = self.bump_if(frontend, TokenValue::RightParen) {
+ meta.subsume(token.meta);
+ } else {
+ loop {
+ args.push(self.parse_assignment(frontend, ctx, stmt)?);
+
+ let token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Comma => {}
+ TokenValue::RightParen => {
+ meta.subsume(token.meta);
+ break;
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![TokenValue::Comma.into(), TokenValue::RightParen.into()],
+ ),
+ meta: token.meta,
+ });
+ }
+ }
+ }
+ }
+
+ Ok(args)
+ }
+
+ pub fn parse_postfix(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ ) -> Result<Handle<HirExpr>> {
+ let mut base = if self.peek_type_name(frontend) {
+ let (mut handle, mut meta) = self.parse_type_non_void(frontend, ctx)?;
+
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?;
+
+ if let TypeInner::Array {
+ size: ArraySize::Dynamic,
+ stride,
+ base,
+ } = ctx.module.types[handle].inner
+ {
+ let span = ctx.module.types.get_span(handle);
+
+ let size = u32::try_from(args.len())
+ .ok()
+ .and_then(NonZeroU32::new)
+ .ok_or(Error {
+ kind: ErrorKind::SemanticError(
+ "There must be at least one argument".into(),
+ ),
+ meta,
+ })?;
+
+ handle = ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Array {
+ stride,
+ base,
+ size: ArraySize::Constant(size),
+ },
+ },
+ span,
+ )
+ }
+
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Call(FunctionCall {
+ kind: FunctionCallKind::TypeConstructor(handle),
+ args,
+ }),
+ meta,
+ },
+ Default::default(),
+ )
+ } else if let TokenValue::Identifier(_) = self.expect_peek(frontend)?.value {
+ let (name, mut meta) = self.expect_ident(frontend)?;
+
+ let expr = if self.bump_if(frontend, TokenValue::LeftParen).is_some() {
+ let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?;
+
+ let kind = match frontend.lookup_type.get(&name) {
+ Some(ty) => FunctionCallKind::TypeConstructor(*ty),
+ None => FunctionCallKind::Function(name),
+ };
+
+ HirExpr {
+ kind: HirExprKind::Call(FunctionCall { kind, args }),
+ meta,
+ }
+ } else {
+ let var = match frontend.lookup_variable(ctx, &name, meta)? {
+ Some(var) => var,
+ None => {
+ return Err(Error {
+ kind: ErrorKind::UnknownVariable(name),
+ meta,
+ })
+ }
+ };
+
+ HirExpr {
+ kind: HirExprKind::Variable(var),
+ meta,
+ }
+ };
+
+ stmt.hir_exprs.append(expr, Default::default())
+ } else {
+ self.parse_primary(frontend, ctx, stmt)?
+ };
+
+ while let TokenValue::LeftBracket
+ | TokenValue::Dot
+ | TokenValue::Increment
+ | TokenValue::Decrement = self.expect_peek(frontend)?.value
+ {
+ let Token { value, mut meta } = self.bump(frontend)?;
+
+ match value {
+ TokenValue::LeftBracket => {
+ let index = self.parse_expression(frontend, ctx, stmt)?;
+ let end_meta = self.expect(frontend, TokenValue::RightBracket)?.meta;
+
+ meta.subsume(end_meta);
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Access { base, index },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ TokenValue::Dot => {
+ let (field, end_meta) = self.expect_ident(frontend)?;
+
+ if self.bump_if(frontend, TokenValue::LeftParen).is_some() {
+ let args = self.parse_function_call_args(frontend, ctx, stmt, &mut meta)?;
+
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Method {
+ expr: base,
+ name: field,
+ args,
+ },
+ meta,
+ },
+ Default::default(),
+ );
+ continue;
+ }
+
+ meta.subsume(end_meta);
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Select { base, field },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ TokenValue::Increment | TokenValue::Decrement => {
+ base = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::PrePostfix {
+ op: match value {
+ TokenValue::Increment => crate::BinaryOperator::Add,
+ _ => crate::BinaryOperator::Subtract,
+ },
+ postfix: true,
+ expr: base,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ Ok(base)
+ }
+
+ pub fn parse_unary(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ ) -> Result<Handle<HirExpr>> {
+ Ok(match self.expect_peek(frontend)?.value {
+ TokenValue::Plus | TokenValue::Dash | TokenValue::Bang | TokenValue::Tilde => {
+ let Token { value, mut meta } = self.bump(frontend)?;
+
+ let expr = self.parse_unary(frontend, ctx, stmt)?;
+ let end_meta = stmt.hir_exprs[expr].meta;
+
+ let kind = match value {
+ TokenValue::Dash => HirExprKind::Unary {
+ op: UnaryOperator::Negate,
+ expr,
+ },
+ TokenValue::Bang => HirExprKind::Unary {
+ op: UnaryOperator::LogicalNot,
+ expr,
+ },
+ TokenValue::Tilde => HirExprKind::Unary {
+ op: UnaryOperator::BitwiseNot,
+ expr,
+ },
+ _ => return Ok(expr),
+ };
+
+ meta.subsume(end_meta);
+ stmt.hir_exprs
+ .append(HirExpr { kind, meta }, Default::default())
+ }
+ TokenValue::Increment | TokenValue::Decrement => {
+ let Token { value, meta } = self.bump(frontend)?;
+
+ let expr = self.parse_unary(frontend, ctx, stmt)?;
+
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::PrePostfix {
+ op: match value {
+ TokenValue::Increment => crate::BinaryOperator::Add,
+ _ => crate::BinaryOperator::Subtract,
+ },
+ postfix: false,
+ expr,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ _ => self.parse_postfix(frontend, ctx, stmt)?,
+ })
+ }
+
+ pub fn parse_binary(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ passthrough: Option<Handle<HirExpr>>,
+ min_bp: u8,
+ ) -> Result<Handle<HirExpr>> {
+ let mut left = passthrough
+ .ok_or(ErrorKind::EndOfFile /* Dummy error */)
+ .or_else(|_| self.parse_unary(frontend, ctx, stmt))?;
+ let mut meta = stmt.hir_exprs[left].meta;
+
+ while let Some((l_bp, r_bp)) = binding_power(&self.expect_peek(frontend)?.value) {
+ if l_bp < min_bp {
+ break;
+ }
+
+ let Token { value, .. } = self.bump(frontend)?;
+
+ let right = self.parse_binary(frontend, ctx, stmt, None, r_bp)?;
+ let end_meta = stmt.hir_exprs[right].meta;
+
+ meta.subsume(end_meta);
+ left = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Binary {
+ left,
+ op: match value {
+ TokenValue::LogicalOr => BinaryOperator::LogicalOr,
+ TokenValue::LogicalXor => BinaryOperator::NotEqual,
+ TokenValue::LogicalAnd => BinaryOperator::LogicalAnd,
+ TokenValue::VerticalBar => BinaryOperator::InclusiveOr,
+ TokenValue::Caret => BinaryOperator::ExclusiveOr,
+ TokenValue::Ampersand => BinaryOperator::And,
+ TokenValue::Equal => BinaryOperator::Equal,
+ TokenValue::NotEqual => BinaryOperator::NotEqual,
+ TokenValue::GreaterEqual => BinaryOperator::GreaterEqual,
+ TokenValue::LessEqual => BinaryOperator::LessEqual,
+ TokenValue::LeftAngle => BinaryOperator::Less,
+ TokenValue::RightAngle => BinaryOperator::Greater,
+ TokenValue::LeftShift => BinaryOperator::ShiftLeft,
+ TokenValue::RightShift => BinaryOperator::ShiftRight,
+ TokenValue::Plus => BinaryOperator::Add,
+ TokenValue::Dash => BinaryOperator::Subtract,
+ TokenValue::Star => BinaryOperator::Multiply,
+ TokenValue::Slash => BinaryOperator::Divide,
+ TokenValue::Percent => BinaryOperator::Modulo,
+ _ => unreachable!(),
+ },
+ right,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+
+ Ok(left)
+ }
+
+ pub fn parse_conditional(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ passthrough: Option<Handle<HirExpr>>,
+ ) -> Result<Handle<HirExpr>> {
+ let mut condition = self.parse_binary(frontend, ctx, stmt, passthrough, 0)?;
+ let mut meta = stmt.hir_exprs[condition].meta;
+
+ if self.bump_if(frontend, TokenValue::Question).is_some() {
+ let accept = self.parse_expression(frontend, ctx, stmt)?;
+ self.expect(frontend, TokenValue::Colon)?;
+ let reject = self.parse_assignment(frontend, ctx, stmt)?;
+ let end_meta = stmt.hir_exprs[reject].meta;
+
+ meta.subsume(end_meta);
+ condition = stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Conditional {
+ condition,
+ accept,
+ reject,
+ },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+
+ Ok(condition)
+ }
+
+ pub fn parse_assignment(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ ) -> Result<Handle<HirExpr>> {
+ let tgt = self.parse_unary(frontend, ctx, stmt)?;
+ let mut meta = stmt.hir_exprs[tgt].meta;
+
+ Ok(match self.expect_peek(frontend)?.value {
+ TokenValue::Assign => {
+ self.bump(frontend)?;
+ let value = self.parse_assignment(frontend, ctx, stmt)?;
+ let end_meta = stmt.hir_exprs[value].meta;
+
+ meta.subsume(end_meta);
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Assign { tgt, value },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ TokenValue::OrAssign
+ | TokenValue::AndAssign
+ | TokenValue::AddAssign
+ | TokenValue::DivAssign
+ | TokenValue::ModAssign
+ | TokenValue::SubAssign
+ | TokenValue::MulAssign
+ | TokenValue::LeftShiftAssign
+ | TokenValue::RightShiftAssign
+ | TokenValue::XorAssign => {
+ let token = self.bump(frontend)?;
+ let right = self.parse_assignment(frontend, ctx, stmt)?;
+ let end_meta = stmt.hir_exprs[right].meta;
+
+ meta.subsume(end_meta);
+ let value = stmt.hir_exprs.append(
+ HirExpr {
+ meta,
+ kind: HirExprKind::Binary {
+ left: tgt,
+ op: match token.value {
+ TokenValue::OrAssign => BinaryOperator::InclusiveOr,
+ TokenValue::AndAssign => BinaryOperator::And,
+ TokenValue::AddAssign => BinaryOperator::Add,
+ TokenValue::DivAssign => BinaryOperator::Divide,
+ TokenValue::ModAssign => BinaryOperator::Modulo,
+ TokenValue::SubAssign => BinaryOperator::Subtract,
+ TokenValue::MulAssign => BinaryOperator::Multiply,
+ TokenValue::LeftShiftAssign => BinaryOperator::ShiftLeft,
+ TokenValue::RightShiftAssign => BinaryOperator::ShiftRight,
+ TokenValue::XorAssign => BinaryOperator::ExclusiveOr,
+ _ => unreachable!(),
+ },
+ right,
+ },
+ },
+ Default::default(),
+ );
+
+ stmt.hir_exprs.append(
+ HirExpr {
+ kind: HirExprKind::Assign { tgt, value },
+ meta,
+ },
+ Default::default(),
+ )
+ }
+ _ => self.parse_conditional(frontend, ctx, stmt, Some(tgt))?,
+ })
+ }
+
+ pub fn parse_expression(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ stmt: &mut StmtContext,
+ ) -> Result<Handle<HirExpr>> {
+ let mut expr = self.parse_assignment(frontend, ctx, stmt)?;
+
+ while let TokenValue::Comma = self.expect_peek(frontend)?.value {
+ self.bump(frontend)?;
+ expr = self.parse_assignment(frontend, ctx, stmt)?;
+ }
+
+ Ok(expr)
+ }
+}
+
+const fn binding_power(value: &TokenValue) -> Option<(u8, u8)> {
+ Some(match *value {
+ TokenValue::LogicalOr => (1, 2),
+ TokenValue::LogicalXor => (3, 4),
+ TokenValue::LogicalAnd => (5, 6),
+ TokenValue::VerticalBar => (7, 8),
+ TokenValue::Caret => (9, 10),
+ TokenValue::Ampersand => (11, 12),
+ TokenValue::Equal | TokenValue::NotEqual => (13, 14),
+ TokenValue::GreaterEqual
+ | TokenValue::LessEqual
+ | TokenValue::LeftAngle
+ | TokenValue::RightAngle => (15, 16),
+ TokenValue::LeftShift | TokenValue::RightShift => (17, 18),
+ TokenValue::Plus | TokenValue::Dash => (19, 20),
+ TokenValue::Star | TokenValue::Slash | TokenValue::Percent => (21, 22),
+ _ => return None,
+ })
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/functions.rs b/third_party/rust/naga/src/front/glsl/parser/functions.rs
new file mode 100644
index 0000000000..38184eedf7
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/functions.rs
@@ -0,0 +1,656 @@
+use crate::front::glsl::context::ExprPos;
+use crate::front::glsl::Span;
+use crate::Literal;
+use crate::{
+ front::glsl::{
+ ast::ParameterQualifier,
+ context::Context,
+ parser::ParsingContext,
+ token::{Token, TokenValue},
+ variables::VarDeclaration,
+ Error, ErrorKind, Frontend, Result,
+ },
+ Block, Expression, Statement, SwitchCase, UnaryOperator,
+};
+
+impl<'source> ParsingContext<'source> {
+ pub fn peek_parameter_qualifier(&mut self, frontend: &mut Frontend) -> bool {
+ self.peek(frontend).map_or(false, |t| match t.value {
+ TokenValue::In | TokenValue::Out | TokenValue::InOut | TokenValue::Const => true,
+ _ => false,
+ })
+ }
+
+ /// Returns the parsed `ParameterQualifier` or `ParameterQualifier::In`
+ pub fn parse_parameter_qualifier(&mut self, frontend: &mut Frontend) -> ParameterQualifier {
+ if self.peek_parameter_qualifier(frontend) {
+ match self.bump(frontend).unwrap().value {
+ TokenValue::In => ParameterQualifier::In,
+ TokenValue::Out => ParameterQualifier::Out,
+ TokenValue::InOut => ParameterQualifier::InOut,
+ TokenValue::Const => ParameterQualifier::Const,
+ _ => unreachable!(),
+ }
+ } else {
+ ParameterQualifier::In
+ }
+ }
+
+ pub fn parse_statement(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ terminator: &mut Option<usize>,
+ is_inside_loop: bool,
+ ) -> Result<Option<Span>> {
+ // Type qualifiers always identify a declaration statement
+ if self.peek_type_qualifier(frontend) {
+ return self.parse_declaration(frontend, ctx, false, is_inside_loop);
+ }
+
+ // Type names can identify either declaration statements or type constructors
+ // depending on whether the token following the type name is a `(` (LeftParen)
+ if self.peek_type_name(frontend) {
+ // Start by consuming the type name so that we can peek the token after it
+ let token = self.bump(frontend)?;
+ // Peek the next token and check if it's a `(` (LeftParen) if so the statement
+ // is a constructor, otherwise it's a declaration. We need to do the check
+ // beforehand and not in the if since we will backtrack before the if
+ let declaration = TokenValue::LeftParen != self.expect_peek(frontend)?.value;
+
+ self.backtrack(token)?;
+
+ if declaration {
+ return self.parse_declaration(frontend, ctx, false, is_inside_loop);
+ }
+ }
+
+ let new_break = || {
+ let mut block = Block::new();
+ block.push(Statement::Break, crate::Span::default());
+ block
+ };
+
+ let &Token {
+ ref value,
+ mut meta,
+ } = self.expect_peek(frontend)?;
+
+ let meta_rest = match *value {
+ TokenValue::Continue => {
+ let meta = self.bump(frontend)?.meta;
+ ctx.body.push(Statement::Continue, meta);
+ terminator.get_or_insert(ctx.body.len());
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ TokenValue::Break => {
+ let meta = self.bump(frontend)?.meta;
+ ctx.body.push(Statement::Break, meta);
+ terminator.get_or_insert(ctx.body.len());
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ TokenValue::Return => {
+ self.bump(frontend)?;
+ let (value, meta) = match self.expect_peek(frontend)?.value {
+ TokenValue::Semicolon => (None, self.bump(frontend)?.meta),
+ _ => {
+ // TODO: Implicit conversions
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt)?;
+ self.expect(frontend, TokenValue::Semicolon)?;
+ let (handle, meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?;
+ (Some(handle), meta)
+ }
+ };
+
+ ctx.emit_restart();
+
+ ctx.body.push(Statement::Return { value }, meta);
+ terminator.get_or_insert(ctx.body.len());
+
+ meta
+ }
+ TokenValue::Discard => {
+ let meta = self.bump(frontend)?.meta;
+ ctx.body.push(Statement::Kill, meta);
+ terminator.get_or_insert(ctx.body.len());
+
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ TokenValue::If => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let condition = {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt)?;
+ let (handle, more_meta) =
+ ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?;
+ meta.subsume(more_meta);
+ handle
+ };
+ self.expect(frontend, TokenValue::RightParen)?;
+
+ let accept = ctx.new_body(|ctx| {
+ if let Some(more_meta) =
+ self.parse_statement(frontend, ctx, &mut None, is_inside_loop)?
+ {
+ meta.subsume(more_meta);
+ }
+ Ok(())
+ })?;
+
+ let reject = ctx.new_body(|ctx| {
+ if self.bump_if(frontend, TokenValue::Else).is_some() {
+ if let Some(more_meta) =
+ self.parse_statement(frontend, ctx, &mut None, is_inside_loop)?
+ {
+ meta.subsume(more_meta);
+ }
+ }
+ Ok(())
+ })?;
+
+ ctx.body.push(
+ Statement::If {
+ condition,
+ accept,
+ reject,
+ },
+ meta,
+ );
+
+ meta
+ }
+ TokenValue::Switch => {
+ let mut meta = self.bump(frontend)?.meta;
+ let end_meta;
+
+ self.expect(frontend, TokenValue::LeftParen)?;
+
+ let (selector, uint) = {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt)?;
+ let (root, meta) = ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?;
+ let uint = ctx.resolve_type(root, meta)?.scalar_kind()
+ == Some(crate::ScalarKind::Uint);
+ (root, uint)
+ };
+
+ self.expect(frontend, TokenValue::RightParen)?;
+
+ ctx.emit_restart();
+
+ let mut cases = Vec::new();
+ // Track if any default case is present in the switch statement.
+ let mut default_present = false;
+
+ self.expect(frontend, TokenValue::LeftBrace)?;
+ loop {
+ let value = match self.expect_peek(frontend)?.value {
+ TokenValue::Case => {
+ self.bump(frontend)?;
+
+ let (const_expr, meta) =
+ self.parse_constant_expression(frontend, ctx.module)?;
+
+ match ctx.module.const_expressions[const_expr] {
+ Expression::Literal(Literal::I32(value)) => match uint {
+ // This unchecked cast isn't good, but since
+ // we only reach this code when the selector
+ // is unsigned but the case label is signed,
+ // verification will reject the module
+ // anyway (which also matches GLSL's rules).
+ true => crate::SwitchValue::U32(value as u32),
+ false => crate::SwitchValue::I32(value),
+ },
+ Expression::Literal(Literal::U32(value)) => {
+ crate::SwitchValue::U32(value)
+ }
+ _ => {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Case values can only be integers".into(),
+ ),
+ meta,
+ });
+
+ crate::SwitchValue::I32(0)
+ }
+ }
+ }
+ TokenValue::Default => {
+ self.bump(frontend)?;
+ default_present = true;
+ crate::SwitchValue::Default
+ }
+ TokenValue::RightBrace => {
+ end_meta = self.bump(frontend)?.meta;
+ break;
+ }
+ _ => {
+ let Token { value, meta } = self.bump(frontend)?;
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ value,
+ vec![
+ TokenValue::Case.into(),
+ TokenValue::Default.into(),
+ TokenValue::RightBrace.into(),
+ ],
+ ),
+ meta,
+ });
+ }
+ };
+
+ self.expect(frontend, TokenValue::Colon)?;
+
+ let mut fall_through = true;
+
+ let body = ctx.new_body(|ctx| {
+ let mut case_terminator = None;
+ loop {
+ match self.expect_peek(frontend)?.value {
+ TokenValue::Case | TokenValue::Default | TokenValue::RightBrace => {
+ break
+ }
+ _ => {
+ self.parse_statement(
+ frontend,
+ ctx,
+ &mut case_terminator,
+ is_inside_loop,
+ )?;
+ }
+ }
+ }
+
+ if let Some(mut idx) = case_terminator {
+ if let Statement::Break = ctx.body[idx - 1] {
+ fall_through = false;
+ idx -= 1;
+ }
+
+ ctx.body.cull(idx..)
+ }
+
+ Ok(())
+ })?;
+
+ cases.push(SwitchCase {
+ value,
+ body,
+ fall_through,
+ })
+ }
+
+ meta.subsume(end_meta);
+
+ // NOTE: do not unwrap here since a switch statement isn't required
+ // to have any cases.
+ if let Some(case) = cases.last_mut() {
+ // GLSL requires that the last case not be empty, so we check
+ // that here and produce an error otherwise (fall_through must
+ // also be checked because `break`s count as statements but
+ // they aren't added to the body)
+ if case.body.is_empty() && case.fall_through {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "last case/default label must be followed by statements".into(),
+ ),
+ meta,
+ })
+ }
+
+ // GLSL allows the last case to not have any `break` statement,
+ // this would mark it as fall through but naga's IR requires that
+ // the last case must not be fall through, so we mark need to mark
+ // the last case as not fall through always.
+ case.fall_through = false;
+ }
+
+ // Add an empty default case in case non was present, this is needed because
+ // naga's IR requires that all switch statements must have a default case but
+ // GLSL doesn't require that, so we might need to add an empty default case.
+ if !default_present {
+ cases.push(SwitchCase {
+ value: crate::SwitchValue::Default,
+ body: Block::new(),
+ fall_through: false,
+ })
+ }
+
+ ctx.body.push(Statement::Switch { selector, cases }, meta);
+
+ meta
+ }
+ TokenValue::While => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ let loop_body = ctx.new_body(|ctx| {
+ let mut stmt = ctx.stmt_ctx();
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let root = self.parse_expression(frontend, ctx, &mut stmt)?;
+ meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta);
+
+ let (expr, expr_meta) = ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?;
+ let condition = ctx.add_expression(
+ Expression::Unary {
+ op: UnaryOperator::LogicalNot,
+ expr,
+ },
+ expr_meta,
+ )?;
+
+ ctx.emit_restart();
+
+ ctx.body.push(
+ Statement::If {
+ condition,
+ accept: new_break(),
+ reject: Block::new(),
+ },
+ crate::Span::default(),
+ );
+
+ meta.subsume(expr_meta);
+
+ if let Some(body_meta) = self.parse_statement(frontend, ctx, &mut None, true)? {
+ meta.subsume(body_meta);
+ }
+ Ok(())
+ })?;
+
+ ctx.body.push(
+ Statement::Loop {
+ body: loop_body,
+ continuing: Block::new(),
+ break_if: None,
+ },
+ meta,
+ );
+
+ meta
+ }
+ TokenValue::Do => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ let loop_body = ctx.new_body(|ctx| {
+ let mut terminator = None;
+ self.parse_statement(frontend, ctx, &mut terminator, true)?;
+
+ let mut stmt = ctx.stmt_ctx();
+
+ self.expect(frontend, TokenValue::While)?;
+ self.expect(frontend, TokenValue::LeftParen)?;
+ let root = self.parse_expression(frontend, ctx, &mut stmt)?;
+ let end_meta = self.expect(frontend, TokenValue::RightParen)?.meta;
+
+ meta.subsume(end_meta);
+
+ let (expr, expr_meta) = ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?;
+ let condition = ctx.add_expression(
+ Expression::Unary {
+ op: UnaryOperator::LogicalNot,
+ expr,
+ },
+ expr_meta,
+ )?;
+
+ ctx.emit_restart();
+
+ ctx.body.push(
+ Statement::If {
+ condition,
+ accept: new_break(),
+ reject: Block::new(),
+ },
+ crate::Span::default(),
+ );
+
+ if let Some(idx) = terminator {
+ ctx.body.cull(idx..)
+ }
+ Ok(())
+ })?;
+
+ ctx.body.push(
+ Statement::Loop {
+ body: loop_body,
+ continuing: Block::new(),
+ break_if: None,
+ },
+ meta,
+ );
+
+ meta
+ }
+ TokenValue::For => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ ctx.symbol_table.push_scope();
+ self.expect(frontend, TokenValue::LeftParen)?;
+
+ if self.bump_if(frontend, TokenValue::Semicolon).is_none() {
+ if self.peek_type_name(frontend) || self.peek_type_qualifier(frontend) {
+ self.parse_declaration(frontend, ctx, false, false)?;
+ } else {
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt)?;
+ ctx.lower(stmt, frontend, expr, ExprPos::Rhs)?;
+ self.expect(frontend, TokenValue::Semicolon)?;
+ }
+ }
+
+ let loop_body = ctx.new_body(|ctx| {
+ if self.bump_if(frontend, TokenValue::Semicolon).is_none() {
+ let (expr, expr_meta) = if self.peek_type_name(frontend)
+ || self.peek_type_qualifier(frontend)
+ {
+ let mut qualifiers = self.parse_type_qualifiers(frontend, ctx)?;
+ let (ty, mut meta) = self.parse_type_non_void(frontend, ctx)?;
+ let name = self.expect_ident(frontend)?.0;
+
+ self.expect(frontend, TokenValue::Assign)?;
+
+ let (value, end_meta) = self.parse_initializer(frontend, ty, ctx)?;
+ meta.subsume(end_meta);
+
+ let decl = VarDeclaration {
+ qualifiers: &mut qualifiers,
+ ty,
+ name: Some(name),
+ init: None,
+ meta,
+ };
+
+ let pointer = frontend.add_local_var(ctx, decl)?;
+
+ ctx.emit_restart();
+
+ ctx.body.push(Statement::Store { pointer, value }, meta);
+
+ (value, end_meta)
+ } else {
+ let mut stmt = ctx.stmt_ctx();
+ let root = self.parse_expression(frontend, ctx, &mut stmt)?;
+ ctx.lower_expect(stmt, frontend, root, ExprPos::Rhs)?
+ };
+
+ let condition = ctx.add_expression(
+ Expression::Unary {
+ op: UnaryOperator::LogicalNot,
+ expr,
+ },
+ expr_meta,
+ )?;
+
+ ctx.emit_restart();
+
+ ctx.body.push(
+ Statement::If {
+ condition,
+ accept: new_break(),
+ reject: Block::new(),
+ },
+ crate::Span::default(),
+ );
+
+ self.expect(frontend, TokenValue::Semicolon)?;
+ }
+ Ok(())
+ })?;
+
+ let continuing = ctx.new_body(|ctx| {
+ match self.expect_peek(frontend)?.value {
+ TokenValue::RightParen => {}
+ _ => {
+ let mut stmt = ctx.stmt_ctx();
+ let rest = self.parse_expression(frontend, ctx, &mut stmt)?;
+ ctx.lower(stmt, frontend, rest, ExprPos::Rhs)?;
+ }
+ }
+ Ok(())
+ })?;
+
+ meta.subsume(self.expect(frontend, TokenValue::RightParen)?.meta);
+
+ let loop_body = ctx.with_body(loop_body, |ctx| {
+ if let Some(stmt_meta) = self.parse_statement(frontend, ctx, &mut None, true)? {
+ meta.subsume(stmt_meta);
+ }
+ Ok(())
+ })?;
+
+ ctx.body.push(
+ Statement::Loop {
+ body: loop_body,
+ continuing,
+ break_if: None,
+ },
+ meta,
+ );
+
+ ctx.symbol_table.pop_scope();
+
+ meta
+ }
+ TokenValue::LeftBrace => {
+ let mut meta = self.bump(frontend)?.meta;
+
+ let mut block_terminator = None;
+
+ let block = ctx.new_body(|ctx| {
+ let block_meta = self.parse_compound_statement(
+ meta,
+ frontend,
+ ctx,
+ &mut block_terminator,
+ is_inside_loop,
+ )?;
+ meta.subsume(block_meta);
+ Ok(())
+ })?;
+
+ ctx.body.push(Statement::Block(block), meta);
+ if block_terminator.is_some() {
+ terminator.get_or_insert(ctx.body.len());
+ }
+
+ meta
+ }
+ TokenValue::Semicolon => self.bump(frontend)?.meta,
+ _ => {
+ // Attempt to force expression parsing for remainder of the
+ // tokens. Unknown or invalid tokens will be caught there and
+ // turned into an error.
+ let mut stmt = ctx.stmt_ctx();
+ let expr = self.parse_expression(frontend, ctx, &mut stmt)?;
+ ctx.lower(stmt, frontend, expr, ExprPos::Rhs)?;
+ self.expect(frontend, TokenValue::Semicolon)?.meta
+ }
+ };
+
+ meta.subsume(meta_rest);
+ Ok(Some(meta))
+ }
+
+ pub fn parse_compound_statement(
+ &mut self,
+ mut meta: Span,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ terminator: &mut Option<usize>,
+ is_inside_loop: bool,
+ ) -> Result<Span> {
+ ctx.symbol_table.push_scope();
+
+ loop {
+ if let Some(Token {
+ meta: brace_meta, ..
+ }) = self.bump_if(frontend, TokenValue::RightBrace)
+ {
+ meta.subsume(brace_meta);
+ break;
+ }
+
+ let stmt = self.parse_statement(frontend, ctx, terminator, is_inside_loop)?;
+
+ if let Some(stmt_meta) = stmt {
+ meta.subsume(stmt_meta);
+ }
+ }
+
+ if let Some(idx) = *terminator {
+ ctx.body.cull(idx..)
+ }
+
+ ctx.symbol_table.pop_scope();
+
+ Ok(meta)
+ }
+
+ pub fn parse_function_args(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ ) -> Result<()> {
+ if self.bump_if(frontend, TokenValue::Void).is_some() {
+ return Ok(());
+ }
+
+ loop {
+ if self.peek_type_name(frontend) || self.peek_parameter_qualifier(frontend) {
+ let qualifier = self.parse_parameter_qualifier(frontend);
+ let mut ty = self.parse_type_non_void(frontend, ctx)?.0;
+
+ match self.expect_peek(frontend)?.value {
+ TokenValue::Comma => {
+ self.bump(frontend)?;
+ ctx.add_function_arg(None, ty, qualifier)?;
+ continue;
+ }
+ TokenValue::Identifier(_) => {
+ let mut name = self.expect_ident(frontend)?;
+ self.parse_array_specifier(frontend, ctx, &mut name.1, &mut ty)?;
+
+ ctx.add_function_arg(Some(name), ty, qualifier)?;
+
+ if self.bump_if(frontend, TokenValue::Comma).is_some() {
+ continue;
+ }
+
+ break;
+ }
+ _ => break,
+ }
+ }
+
+ break;
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser/types.rs b/third_party/rust/naga/src/front/glsl/parser/types.rs
new file mode 100644
index 0000000000..1b612b298d
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser/types.rs
@@ -0,0 +1,443 @@
+use std::num::NonZeroU32;
+
+use crate::{
+ front::glsl::{
+ ast::{QualifierKey, QualifierValue, StorageQualifier, StructLayout, TypeQualifiers},
+ context::Context,
+ error::ExpectedToken,
+ parser::ParsingContext,
+ token::{Token, TokenValue},
+ Error, ErrorKind, Frontend, Result,
+ },
+ AddressSpace, ArraySize, Handle, Span, Type, TypeInner,
+};
+
+impl<'source> ParsingContext<'source> {
+ /// Parses an optional array_specifier returning whether or not it's present
+ /// and modifying the type handle if it exists
+ pub fn parse_array_specifier(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ span: &mut Span,
+ ty: &mut Handle<Type>,
+ ) -> Result<()> {
+ while self.parse_array_specifier_single(frontend, ctx, span, ty)? {}
+ Ok(())
+ }
+
+ /// Implementation of [`Self::parse_array_specifier`] for a single array_specifier
+ fn parse_array_specifier_single(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ span: &mut Span,
+ ty: &mut Handle<Type>,
+ ) -> Result<bool> {
+ if self.bump_if(frontend, TokenValue::LeftBracket).is_some() {
+ let size = if let Some(Token { meta, .. }) =
+ self.bump_if(frontend, TokenValue::RightBracket)
+ {
+ span.subsume(meta);
+ ArraySize::Dynamic
+ } else {
+ let (value, constant_span) = self.parse_uint_constant(frontend, ctx)?;
+ let size = NonZeroU32::new(value).ok_or(Error {
+ kind: ErrorKind::SemanticError("Array size must be greater than zero".into()),
+ meta: constant_span,
+ })?;
+ let end_span = self.expect(frontend, TokenValue::RightBracket)?.meta;
+ span.subsume(end_span);
+ ArraySize::Constant(size)
+ };
+
+ frontend.layouter.update(ctx.module.to_ctx()).unwrap();
+ let stride = frontend.layouter[*ty].to_stride();
+ *ty = ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Array {
+ base: *ty,
+ size,
+ stride,
+ },
+ },
+ *span,
+ );
+
+ Ok(true)
+ } else {
+ Ok(false)
+ }
+ }
+
+ pub fn parse_type(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ ) -> Result<(Option<Handle<Type>>, Span)> {
+ let token = self.bump(frontend)?;
+ let mut handle = match token.value {
+ TokenValue::Void => return Ok((None, token.meta)),
+ TokenValue::TypeName(ty) => ctx.module.types.insert(ty, token.meta),
+ TokenValue::Struct => {
+ let mut meta = token.meta;
+ let ty_name = self.expect_ident(frontend)?.0;
+ self.expect(frontend, TokenValue::LeftBrace)?;
+ let mut members = Vec::new();
+ let span = self.parse_struct_declaration_list(
+ frontend,
+ ctx,
+ &mut members,
+ StructLayout::Std140,
+ )?;
+ let end_meta = self.expect(frontend, TokenValue::RightBrace)?.meta;
+ meta.subsume(end_meta);
+ let ty = ctx.module.types.insert(
+ Type {
+ name: Some(ty_name.clone()),
+ inner: TypeInner::Struct { members, span },
+ },
+ meta,
+ );
+ frontend.lookup_type.insert(ty_name, ty);
+ ty
+ }
+ TokenValue::Identifier(ident) => match frontend.lookup_type.get(&ident) {
+ Some(ty) => *ty,
+ None => {
+ return Err(Error {
+ kind: ErrorKind::UnknownType(ident),
+ meta: token.meta,
+ })
+ }
+ },
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::InvalidToken(
+ token.value,
+ vec![
+ TokenValue::Void.into(),
+ TokenValue::Struct.into(),
+ ExpectedToken::TypeName,
+ ],
+ ),
+ meta: token.meta,
+ });
+ }
+ };
+
+ let mut span = token.meta;
+ self.parse_array_specifier(frontend, ctx, &mut span, &mut handle)?;
+ Ok((Some(handle), span))
+ }
+
+ pub fn parse_type_non_void(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ ) -> Result<(Handle<Type>, Span)> {
+ let (maybe_ty, meta) = self.parse_type(frontend, ctx)?;
+ let ty = maybe_ty.ok_or_else(|| Error {
+ kind: ErrorKind::SemanticError("Type can't be void".into()),
+ meta,
+ })?;
+
+ Ok((ty, meta))
+ }
+
+ pub fn peek_type_qualifier(&mut self, frontend: &mut Frontend) -> bool {
+ self.peek(frontend).map_or(false, |t| match t.value {
+ TokenValue::Invariant
+ | TokenValue::Interpolation(_)
+ | TokenValue::Sampling(_)
+ | TokenValue::PrecisionQualifier(_)
+ | TokenValue::Const
+ | TokenValue::In
+ | TokenValue::Out
+ | TokenValue::Uniform
+ | TokenValue::Shared
+ | TokenValue::Buffer
+ | TokenValue::Restrict
+ | TokenValue::MemoryQualifier(_)
+ | TokenValue::Layout => true,
+ _ => false,
+ })
+ }
+
+ pub fn parse_type_qualifiers<'a>(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ ) -> Result<TypeQualifiers<'a>> {
+ let mut qualifiers = TypeQualifiers::default();
+
+ while self.peek_type_qualifier(frontend) {
+ let token = self.bump(frontend)?;
+
+ // Handle layout qualifiers outside the match since this can push multiple values
+ if token.value == TokenValue::Layout {
+ self.parse_layout_qualifier_id_list(frontend, ctx, &mut qualifiers)?;
+ continue;
+ }
+
+ qualifiers.span.subsume(token.meta);
+
+ match token.value {
+ TokenValue::Invariant => {
+ if qualifiers.invariant.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one invariant qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.invariant = Some(token.meta);
+ }
+ TokenValue::Interpolation(i) => {
+ if qualifiers.interpolation.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one interpolation qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.interpolation = Some((i, token.meta));
+ }
+ TokenValue::Const
+ | TokenValue::In
+ | TokenValue::Out
+ | TokenValue::Uniform
+ | TokenValue::Shared
+ | TokenValue::Buffer => {
+ let storage = match token.value {
+ TokenValue::Const => StorageQualifier::Const,
+ TokenValue::In => StorageQualifier::Input,
+ TokenValue::Out => StorageQualifier::Output,
+ TokenValue::Uniform => {
+ StorageQualifier::AddressSpace(AddressSpace::Uniform)
+ }
+ TokenValue::Shared => {
+ StorageQualifier::AddressSpace(AddressSpace::WorkGroup)
+ }
+ TokenValue::Buffer => {
+ StorageQualifier::AddressSpace(AddressSpace::Storage {
+ access: crate::StorageAccess::all(),
+ })
+ }
+ _ => unreachable!(),
+ };
+
+ if StorageQualifier::AddressSpace(AddressSpace::Function)
+ != qualifiers.storage.0
+ {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one storage qualifier per declaration".into(),
+ ),
+ meta: token.meta,
+ });
+ }
+
+ qualifiers.storage = (storage, token.meta);
+ }
+ TokenValue::Sampling(s) => {
+ if qualifiers.sampling.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one sampling qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.sampling = Some((s, token.meta));
+ }
+ TokenValue::PrecisionQualifier(p) => {
+ if qualifiers.precision.is_some() {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "Cannot use more than one precision qualifier per declaration"
+ .into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ qualifiers.precision = Some((p, token.meta));
+ }
+ TokenValue::MemoryQualifier(access) => {
+ let storage_access = qualifiers
+ .storage_access
+ .get_or_insert((crate::StorageAccess::all(), Span::default()));
+ if !storage_access.0.contains(!access) {
+ frontend.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "The same memory qualifier can only be used once".into(),
+ ),
+ meta: token.meta,
+ })
+ }
+
+ storage_access.0 &= access;
+ storage_access.1.subsume(token.meta);
+ }
+ TokenValue::Restrict => continue,
+ _ => unreachable!(),
+ };
+ }
+
+ Ok(qualifiers)
+ }
+
+ pub fn parse_layout_qualifier_id_list(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ qualifiers: &mut TypeQualifiers,
+ ) -> Result<()> {
+ self.expect(frontend, TokenValue::LeftParen)?;
+ loop {
+ self.parse_layout_qualifier_id(frontend, ctx, &mut qualifiers.layout_qualifiers)?;
+
+ if self.bump_if(frontend, TokenValue::Comma).is_some() {
+ continue;
+ }
+
+ break;
+ }
+ let token = self.expect(frontend, TokenValue::RightParen)?;
+ qualifiers.span.subsume(token.meta);
+
+ Ok(())
+ }
+
+ pub fn parse_layout_qualifier_id(
+ &mut self,
+ frontend: &mut Frontend,
+ ctx: &mut Context,
+ qualifiers: &mut crate::FastHashMap<QualifierKey, (QualifierValue, Span)>,
+ ) -> Result<()> {
+ // layout_qualifier_id:
+ // IDENTIFIER
+ // IDENTIFIER EQUAL constant_expression
+ // SHARED
+ let mut token = self.bump(frontend)?;
+ match token.value {
+ TokenValue::Identifier(name) => {
+ let (key, value) = match name.as_str() {
+ "std140" => (
+ QualifierKey::Layout,
+ QualifierValue::Layout(StructLayout::Std140),
+ ),
+ "std430" => (
+ QualifierKey::Layout,
+ QualifierValue::Layout(StructLayout::Std430),
+ ),
+ word => {
+ if let Some(format) = map_image_format(word) {
+ (QualifierKey::Format, QualifierValue::Format(format))
+ } else {
+ let key = QualifierKey::String(name.into());
+ let value = if self.bump_if(frontend, TokenValue::Assign).is_some() {
+ let (value, end_meta) =
+ match self.parse_uint_constant(frontend, ctx) {
+ Ok(v) => v,
+ Err(e) => {
+ frontend.errors.push(e);
+ (0, Span::default())
+ }
+ };
+ token.meta.subsume(end_meta);
+
+ QualifierValue::Uint(value)
+ } else {
+ QualifierValue::None
+ };
+
+ (key, value)
+ }
+ }
+ };
+
+ qualifiers.insert(key, (value, token.meta));
+ }
+ _ => frontend.errors.push(Error {
+ kind: ErrorKind::InvalidToken(token.value, vec![ExpectedToken::Identifier]),
+ meta: token.meta,
+ }),
+ }
+
+ Ok(())
+ }
+
+ pub fn peek_type_name(&mut self, frontend: &mut Frontend) -> bool {
+ self.peek(frontend).map_or(false, |t| match t.value {
+ TokenValue::TypeName(_) | TokenValue::Void => true,
+ TokenValue::Struct => true,
+ TokenValue::Identifier(ref ident) => frontend.lookup_type.contains_key(ident),
+ _ => false,
+ })
+ }
+}
+
+fn map_image_format(word: &str) -> Option<crate::StorageFormat> {
+ use crate::StorageFormat as Sf;
+
+ let format = match word {
+ // float-image-format-qualifier:
+ "rgba32f" => Sf::Rgba32Float,
+ "rgba16f" => Sf::Rgba16Float,
+ "rg32f" => Sf::Rg32Float,
+ "rg16f" => Sf::Rg16Float,
+ "r11f_g11f_b10f" => Sf::Rg11b10Float,
+ "r32f" => Sf::R32Float,
+ "r16f" => Sf::R16Float,
+ "rgba16" => Sf::Rgba16Unorm,
+ "rgb10_a2ui" => Sf::Rgb10a2Uint,
+ "rgb10_a2" => Sf::Rgb10a2Unorm,
+ "rgba8" => Sf::Rgba8Unorm,
+ "rg16" => Sf::Rg16Unorm,
+ "rg8" => Sf::Rg8Unorm,
+ "r16" => Sf::R16Unorm,
+ "r8" => Sf::R8Unorm,
+ "rgba16_snorm" => Sf::Rgba16Snorm,
+ "rgba8_snorm" => Sf::Rgba8Snorm,
+ "rg16_snorm" => Sf::Rg16Snorm,
+ "rg8_snorm" => Sf::Rg8Snorm,
+ "r16_snorm" => Sf::R16Snorm,
+ "r8_snorm" => Sf::R8Snorm,
+ // int-image-format-qualifier:
+ "rgba32i" => Sf::Rgba32Sint,
+ "rgba16i" => Sf::Rgba16Sint,
+ "rgba8i" => Sf::Rgba8Sint,
+ "rg32i" => Sf::Rg32Sint,
+ "rg16i" => Sf::Rg16Sint,
+ "rg8i" => Sf::Rg8Sint,
+ "r32i" => Sf::R32Sint,
+ "r16i" => Sf::R16Sint,
+ "r8i" => Sf::R8Sint,
+ // uint-image-format-qualifier:
+ "rgba32ui" => Sf::Rgba32Uint,
+ "rgba16ui" => Sf::Rgba16Uint,
+ "rgba8ui" => Sf::Rgba8Uint,
+ "rg32ui" => Sf::Rg32Uint,
+ "rg16ui" => Sf::Rg16Uint,
+ "rg8ui" => Sf::Rg8Uint,
+ "r32ui" => Sf::R32Uint,
+ "r16ui" => Sf::R16Uint,
+ "r8ui" => Sf::R8Uint,
+ // TODO: These next ones seem incorrect to me
+ // "rgb10_a2ui" => Sf::Rgb10a2Unorm,
+ _ => return None,
+ };
+
+ Some(format)
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser_tests.rs b/third_party/rust/naga/src/front/glsl/parser_tests.rs
new file mode 100644
index 0000000000..259052cd27
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser_tests.rs
@@ -0,0 +1,858 @@
+use super::{
+ ast::Profile,
+ error::ExpectedToken,
+ error::{Error, ErrorKind, ParseError},
+ token::TokenValue,
+ Frontend, Options, Span,
+};
+use crate::ShaderStage;
+use pp_rs::token::PreprocessorError;
+
+#[test]
+fn version() {
+ let mut frontend = Frontend::default();
+
+ // invalid versions
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 99000\n void main(){}",
+ )
+ .err()
+ .unwrap(),
+ ParseError {
+ errors: vec![Error {
+ kind: ErrorKind::InvalidVersion(99000),
+ meta: Span::new(9, 14)
+ }],
+ },
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 449\n void main(){}",
+ )
+ .err()
+ .unwrap(),
+ ParseError {
+ errors: vec![Error {
+ kind: ErrorKind::InvalidVersion(449),
+ meta: Span::new(9, 12)
+ }]
+ },
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450 smart\n void main(){}",
+ )
+ .err()
+ .unwrap(),
+ ParseError {
+ errors: vec![Error {
+ kind: ErrorKind::InvalidProfile("smart".into()),
+ meta: Span::new(13, 18),
+ }]
+ },
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450\nvoid main(){} #version 450",
+ )
+ .err()
+ .unwrap(),
+ ParseError {
+ errors: vec![
+ Error {
+ kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedHash,),
+ meta: Span::new(27, 28),
+ },
+ Error {
+ kind: ErrorKind::InvalidToken(
+ TokenValue::Identifier("version".into()),
+ vec![ExpectedToken::Eof]
+ ),
+ meta: Span::new(28, 35)
+ }
+ ]
+ },
+ );
+
+ // valid versions
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ " # version 450\nvoid main() {}",
+ )
+ .unwrap();
+ assert_eq!(
+ (frontend.metadata().version, frontend.metadata().profile),
+ (450, Profile::Core)
+ );
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450\nvoid main() {}",
+ )
+ .unwrap();
+ assert_eq!(
+ (frontend.metadata().version, frontend.metadata().profile),
+ (450, Profile::Core)
+ );
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ "#version 450 core\nvoid main(void) {}",
+ )
+ .unwrap();
+ assert_eq!(
+ (frontend.metadata().version, frontend.metadata().profile),
+ (450, Profile::Core)
+ );
+}
+
+#[test]
+fn control_flow() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ if (true) {
+ return 1;
+ } else {
+ return 2;
+ }
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ if (true) {
+ return 1;
+ }
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ int x;
+ int y = 3;
+ switch (5) {
+ case 2:
+ x = 2;
+ case 5:
+ x = 5;
+ y = 2;
+ break;
+ default:
+ x = 0;
+ }
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ int x = 0;
+ while(x < 5) {
+ x = x + 1;
+ }
+ do {
+ x = x - 1;
+ } while(x >= 4)
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ int x = 0;
+ for(int i = 0; i < 10;) {
+ x = x + 2;
+ }
+ for(;;);
+ return x;
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn declarations() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(location = 0) in vec2 v_uv;
+ layout(location = 0) out vec4 o_color;
+ layout(set = 1, binding = 1) uniform texture2D tex;
+ layout(set = 1, binding = 2) uniform sampler tex_sampler;
+
+ layout(early_fragment_tests) in;
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(std140, set = 2, binding = 0)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ ivec4 atlas_offs;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(push_constant)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ ivec4 atlas_offs;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(std430, set = 2, binding = 0)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ ivec4 atlas_offs;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(std140, set = 2, binding = 0)
+ uniform u_locals {
+ vec3 model_offs;
+ float load_time;
+ } block_var;
+
+ void main() {
+ load_time * model_offs;
+ block_var.load_time * block_var.model_offs;
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ float vector = vec4(1.0 / 17.0, 9.0 / 17.0, 3.0 / 17.0, 11.0 / 17.0);
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ precision highp float;
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn textures() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ #version 450
+ layout(location = 0) in vec2 v_uv;
+ layout(location = 0) out vec4 o_color;
+ layout(set = 1, binding = 1) uniform texture2D tex;
+ layout(set = 1, binding = 2) uniform sampler tex_sampler;
+ void main() {
+ o_color = texture(sampler2D(tex, tex_sampler), v_uv);
+ o_color.a = texture(sampler2D(tex, tex_sampler), v_uv, 2.0).a;
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn functions() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test1(float);
+ void test1(float) {}
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test2(float a) {}
+ void test3(float a, float b) {}
+ void test4(float, float) {}
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(float a) { return a; }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(vec4 p) {
+ return p.x;
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ // Function overloading
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(vec2 p);
+ float test(vec3 p);
+ float test(vec4 p);
+
+ float test(vec2 p) {
+ return p.x;
+ }
+
+ float test(vec3 p) {
+ return p.x;
+ }
+
+ float test(vec4 p) {
+ return p.x;
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ int test(vec4 p) {
+ return p.x;
+ }
+
+ float test(vec4 p) {
+ return p.x;
+ }
+
+ void main() {}
+ "#,
+ )
+ .err()
+ .unwrap(),
+ ParseError {
+ errors: vec![Error {
+ kind: ErrorKind::SemanticError("Function already defined".into()),
+ meta: Span::new(134, 152),
+ }]
+ },
+ );
+
+ println!();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float callee(uint q) {
+ return float(q);
+ }
+
+ float caller() {
+ callee(1u);
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ // Nested function call
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ layout(set = 0, binding = 1) uniform texture2D t_noise;
+ layout(set = 0, binding = 2) uniform sampler s_noise;
+
+ void main() {
+ textureLod(sampler2D(t_noise, s_noise), vec2(1.0), 0);
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void fun(vec2 in_parameter, out float out_parameter) {
+ ivec2 _ = ivec2(in_parameter);
+ }
+
+ void main() {
+ float a;
+ fun(vec2(1.0), a);
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn constants() {
+ use crate::{Constant, Expression, Type, TypeInner};
+
+ let mut frontend = Frontend::default();
+
+ let module = frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ const float a = 1.0;
+ float global = a;
+ const float b = a;
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ let mut types = module.types.iter();
+ let mut constants = module.constants.iter();
+ let mut const_expressions = module.const_expressions.iter();
+
+ let (ty_handle, ty) = types.next().unwrap();
+ assert_eq!(
+ ty,
+ &Type {
+ name: None,
+ inner: TypeInner::Scalar(crate::Scalar::F32)
+ }
+ );
+
+ let (init_handle, init) = const_expressions.next().unwrap();
+ assert_eq!(init, &Expression::Literal(crate::Literal::F32(1.0)));
+
+ assert_eq!(
+ constants.next().unwrap().1,
+ &Constant {
+ name: Some("a".to_owned()),
+ r#override: crate::Override::None,
+ ty: ty_handle,
+ init: init_handle
+ }
+ );
+
+ assert_eq!(
+ constants.next().unwrap().1,
+ &Constant {
+ name: Some("b".to_owned()),
+ r#override: crate::Override::None,
+ ty: ty_handle,
+ init: init_handle
+ }
+ );
+
+ assert!(constants.next().is_none());
+}
+
+#[test]
+fn function_overloading() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+
+ float saturate(float v) { return clamp(v, 0.0, 1.0); }
+ vec2 saturate(vec2 v) { return clamp(v, vec2(0.0), vec2(1.0)); }
+ vec3 saturate(vec3 v) { return clamp(v, vec3(0.0), vec3(1.0)); }
+ vec4 saturate(vec4 v) { return clamp(v, vec4(0.0), vec4(1.0)); }
+
+ void main() {
+ float v1 = saturate(1.5);
+ vec2 v2 = saturate(vec2(0.5, 1.5));
+ vec3 v3 = saturate(vec3(0.5, 1.5, 2.5));
+ vec3 v4 = saturate(vec4(0.5, 1.5, 2.5, 3.5));
+ }
+ "#,
+ )
+ .unwrap();
+}
+
+#[test]
+fn implicit_conversions() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ mat4 a = mat4(1);
+ float b = 1u;
+ float c = 1 + 2.0;
+ }
+ "#,
+ )
+ .unwrap();
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test(int a) {}
+ void test(uint a) {}
+
+ void main() {
+ test(1.0);
+ }
+ "#,
+ )
+ .err()
+ .unwrap(),
+ ParseError {
+ errors: vec![Error {
+ kind: ErrorKind::SemanticError("Unknown function \'test\'".into()),
+ meta: Span::new(156, 165),
+ }]
+ },
+ );
+
+ assert_eq!(
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void test(float a) {}
+ void test(uint a) {}
+
+ void main() {
+ test(1);
+ }
+ "#,
+ )
+ .err()
+ .unwrap(),
+ ParseError {
+ errors: vec![Error {
+ kind: ErrorKind::SemanticError("Ambiguous best function for \'test\'".into()),
+ meta: Span::new(158, 165),
+ }]
+ }
+ );
+}
+
+#[test]
+fn structs() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ Test {
+ vec4 pos;
+ } xx;
+
+ void main() {}
+ "#,
+ )
+ .unwrap_err();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ struct Test {
+ vec4 pos;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ const int NUM_VECS = 42;
+ struct Test {
+ vec4 vecs[NUM_VECS];
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ struct Hello {
+ vec4 test;
+ } test() {
+ return Hello( vec4(1.0) );
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ struct Test {};
+
+ void main() {}
+ "#,
+ )
+ .unwrap_err();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ inout struct Test {
+ vec4 x;
+ };
+
+ void main() {}
+ "#,
+ )
+ .unwrap_err();
+}
+
+#[test]
+fn swizzles() {
+ let mut frontend = Frontend::default();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ vec4 v = vec4(1);
+ v.xyz = vec3(2);
+ v.x = 5.0;
+ v.xyz.zxy.yx.xy = vec2(5.0, 1.0);
+ }
+ "#,
+ )
+ .unwrap();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ vec4 v = vec4(1);
+ v.xx = vec2(5.0);
+ }
+ "#,
+ )
+ .unwrap_err();
+
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ vec3 v = vec3(1);
+ v.w = 2.0;
+ }
+ "#,
+ )
+ .unwrap_err();
+}
+
+#[test]
+fn expressions() {
+ let mut frontend = Frontend::default();
+
+ // Vector indexing
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ float test(int index) {
+ vec4 v = vec4(1.0, 2.0, 3.0, 4.0);
+ return v[index] + 1.0;
+ }
+
+ void main() {}
+ "#,
+ )
+ .unwrap();
+
+ // Prefix increment/decrement
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ uint index = 0;
+
+ --index;
+ ++index;
+ }
+ "#,
+ )
+ .unwrap();
+
+ // Dynamic indexing of array
+ frontend
+ .parse(
+ &Options::from(ShaderStage::Vertex),
+ r#"
+ # version 450
+ void main() {
+ const vec4 positions[1] = { vec4(0) };
+
+ gl_Position = positions[gl_VertexIndex];
+ }
+ "#,
+ )
+ .unwrap();
+}
diff --git a/third_party/rust/naga/src/front/glsl/token.rs b/third_party/rust/naga/src/front/glsl/token.rs
new file mode 100644
index 0000000000..303723a27b
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/token.rs
@@ -0,0 +1,137 @@
+pub use pp_rs::token::{Float, Integer, Location, Token as PPToken};
+
+use super::ast::Precision;
+use crate::{Interpolation, Sampling, Span, Type};
+
+impl From<Location> for Span {
+ fn from(loc: Location) -> Self {
+ Span::new(loc.start, loc.end)
+ }
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Token {
+ pub value: TokenValue,
+ pub meta: Span,
+}
+
+/// A token passed from the lexing used in the parsing.
+///
+/// This type is exported since it's returned in the
+/// [`InvalidToken`](super::ErrorKind::InvalidToken) error.
+#[derive(Clone, Debug, PartialEq)]
+pub enum TokenValue {
+ Identifier(String),
+
+ FloatConstant(Float),
+ IntConstant(Integer),
+ BoolConstant(bool),
+
+ Layout,
+ In,
+ Out,
+ InOut,
+ Uniform,
+ Buffer,
+ Const,
+ Shared,
+
+ Restrict,
+ /// A `glsl` memory qualifier such as `writeonly`
+ ///
+ /// The associated [`crate::StorageAccess`] is the access being allowed
+ /// (for example `writeonly` has an associated value of [`crate::StorageAccess::STORE`])
+ MemoryQualifier(crate::StorageAccess),
+
+ Invariant,
+ Interpolation(Interpolation),
+ Sampling(Sampling),
+ Precision,
+ PrecisionQualifier(Precision),
+
+ Continue,
+ Break,
+ Return,
+ Discard,
+
+ If,
+ Else,
+ Switch,
+ Case,
+ Default,
+ While,
+ Do,
+ For,
+
+ Void,
+ Struct,
+ TypeName(Type),
+
+ Assign,
+ AddAssign,
+ SubAssign,
+ MulAssign,
+ DivAssign,
+ ModAssign,
+ LeftShiftAssign,
+ RightShiftAssign,
+ AndAssign,
+ XorAssign,
+ OrAssign,
+
+ Increment,
+ Decrement,
+
+ LogicalOr,
+ LogicalAnd,
+ LogicalXor,
+
+ LessEqual,
+ GreaterEqual,
+ Equal,
+ NotEqual,
+
+ LeftShift,
+ RightShift,
+
+ LeftBrace,
+ RightBrace,
+ LeftParen,
+ RightParen,
+ LeftBracket,
+ RightBracket,
+ LeftAngle,
+ RightAngle,
+
+ Comma,
+ Semicolon,
+ Colon,
+ Dot,
+ Bang,
+ Dash,
+ Tilde,
+ Plus,
+ Star,
+ Slash,
+ Percent,
+ VerticalBar,
+ Caret,
+ Ampersand,
+ Question,
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Directive {
+ pub kind: DirectiveKind,
+ pub tokens: Vec<PPToken>,
+}
+
+#[derive(Debug)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum DirectiveKind {
+ Version { is_first_directive: bool },
+ Extension,
+ Pragma,
+}
diff --git a/third_party/rust/naga/src/front/glsl/types.rs b/third_party/rust/naga/src/front/glsl/types.rs
new file mode 100644
index 0000000000..e87d76fffc
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/types.rs
@@ -0,0 +1,360 @@
+use super::{context::Context, Error, ErrorKind, Result, Span};
+use crate::{
+ proc::ResolveContext, Expression, Handle, ImageClass, ImageDimension, Scalar, ScalarKind, Type,
+ TypeInner, VectorSize,
+};
+
+pub fn parse_type(type_name: &str) -> Option<Type> {
+ match type_name {
+ "bool" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar(Scalar::BOOL),
+ }),
+ "float" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar(Scalar::F32),
+ }),
+ "double" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar(Scalar::F64),
+ }),
+ "int" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar(Scalar::I32),
+ }),
+ "uint" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar(Scalar::U32),
+ }),
+ "sampler" | "samplerShadow" => Some(Type {
+ name: None,
+ inner: TypeInner::Sampler {
+ comparison: type_name == "samplerShadow",
+ },
+ }),
+ word => {
+ fn kind_width_parse(ty: &str) -> Option<Scalar> {
+ Some(match ty {
+ "" => Scalar::F32,
+ "b" => Scalar::BOOL,
+ "i" => Scalar::I32,
+ "u" => Scalar::U32,
+ "d" => Scalar::F64,
+ _ => return None,
+ })
+ }
+
+ fn size_parse(n: &str) -> Option<VectorSize> {
+ Some(match n {
+ "2" => VectorSize::Bi,
+ "3" => VectorSize::Tri,
+ "4" => VectorSize::Quad,
+ _ => return None,
+ })
+ }
+
+ let vec_parse = |word: &str| {
+ let mut iter = word.split("vec");
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ let scalar = kind_width_parse(kind)?;
+ let size = size_parse(size)?;
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Vector { size, scalar },
+ })
+ };
+
+ let mat_parse = |word: &str| {
+ let mut iter = word.split("mat");
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ let scalar = kind_width_parse(kind)?;
+
+ let (columns, rows) = if let Some(size) = size_parse(size) {
+ (size, size)
+ } else {
+ let mut iter = size.split('x');
+ match (iter.next()?, iter.next()?, iter.next()) {
+ (col, row, None) => (size_parse(col)?, size_parse(row)?),
+ _ => return None,
+ }
+ };
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ },
+ })
+ };
+
+ let texture_parse = |word: &str| {
+ let mut iter = word.split("texture");
+
+ let texture_kind = |ty| {
+ Some(match ty {
+ "" => ScalarKind::Float,
+ "i" => ScalarKind::Sint,
+ "u" => ScalarKind::Uint,
+ _ => return None,
+ })
+ };
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ let kind = texture_kind(kind)?;
+
+ let sampled = |multi| ImageClass::Sampled { kind, multi };
+
+ let (dim, arrayed, class) = match size {
+ "1D" => (ImageDimension::D1, false, sampled(false)),
+ "1DArray" => (ImageDimension::D1, true, sampled(false)),
+ "2D" => (ImageDimension::D2, false, sampled(false)),
+ "2DArray" => (ImageDimension::D2, true, sampled(false)),
+ "2DMS" => (ImageDimension::D2, false, sampled(true)),
+ "2DMSArray" => (ImageDimension::D2, true, sampled(true)),
+ "3D" => (ImageDimension::D3, false, sampled(false)),
+ "Cube" => (ImageDimension::Cube, false, sampled(false)),
+ "CubeArray" => (ImageDimension::Cube, true, sampled(false)),
+ _ => return None,
+ };
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ },
+ })
+ };
+
+ let image_parse = |word: &str| {
+ let mut iter = word.split("image");
+
+ let texture_kind = |ty| {
+ Some(match ty {
+ "" => ScalarKind::Float,
+ "i" => ScalarKind::Sint,
+ "u" => ScalarKind::Uint,
+ _ => return None,
+ })
+ };
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ // TODO: Check that the texture format and the kind match
+ let _ = texture_kind(kind)?;
+
+ let class = ImageClass::Storage {
+ format: crate::StorageFormat::R8Uint,
+ access: crate::StorageAccess::all(),
+ };
+
+ // TODO: glsl support multisampled storage images, naga doesn't
+ let (dim, arrayed) = match size {
+ "1D" => (ImageDimension::D1, false),
+ "1DArray" => (ImageDimension::D1, true),
+ "2D" => (ImageDimension::D2, false),
+ "2DArray" => (ImageDimension::D2, true),
+ "3D" => (ImageDimension::D3, false),
+ // Naga doesn't support cube images and it's usefulness
+ // is questionable, so they won't be supported for now
+ // "Cube" => (ImageDimension::Cube, false),
+ // "CubeArray" => (ImageDimension::Cube, true),
+ _ => return None,
+ };
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ },
+ })
+ };
+
+ vec_parse(word)
+ .or_else(|| mat_parse(word))
+ .or_else(|| texture_parse(word))
+ .or_else(|| image_parse(word))
+ }
+ }
+}
+
+pub const fn scalar_components(ty: &TypeInner) -> Option<Scalar> {
+ match *ty {
+ TypeInner::Scalar(scalar)
+ | TypeInner::Vector { scalar, .. }
+ | TypeInner::ValuePointer { scalar, .. }
+ | TypeInner::Matrix { scalar, .. } => Some(scalar),
+ _ => None,
+ }
+}
+
+pub const fn type_power(scalar: Scalar) -> Option<u32> {
+ Some(match scalar.kind {
+ ScalarKind::Sint => 0,
+ ScalarKind::Uint => 1,
+ ScalarKind::Float if scalar.width == 4 => 2,
+ ScalarKind::Float => 3,
+ ScalarKind::Bool | ScalarKind::AbstractInt | ScalarKind::AbstractFloat => return None,
+ })
+}
+
+impl Context<'_> {
+ /// Resolves the types of the expressions until `expr` (inclusive)
+ ///
+ /// This needs to be done before the [`typifier`] can be queried for
+ /// the types of the expressions in the range between the last grow and `expr`.
+ ///
+ /// # Note
+ ///
+ /// The `resolve_type*` methods (like [`resolve_type`]) automatically
+ /// grow the [`typifier`] so calling this method is not necessary when using
+ /// them.
+ ///
+ /// [`typifier`]: Context::typifier
+ /// [`resolve_type`]: Self::resolve_type
+ pub(crate) fn typifier_grow(&mut self, expr: Handle<Expression>, meta: Span) -> Result<()> {
+ let resolve_ctx = ResolveContext::with_locals(self.module, &self.locals, &self.arguments);
+
+ let typifier = if self.is_const {
+ &mut self.const_typifier
+ } else {
+ &mut self.typifier
+ };
+
+ let expressions = if self.is_const {
+ &self.module.const_expressions
+ } else {
+ &self.expressions
+ };
+
+ typifier
+ .grow(expr, expressions, &resolve_ctx)
+ .map_err(|error| Error {
+ kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()),
+ meta,
+ })
+ }
+
+ pub(crate) fn get_type(&self, expr: Handle<Expression>) -> &TypeInner {
+ let typifier = if self.is_const {
+ &self.const_typifier
+ } else {
+ &self.typifier
+ };
+
+ typifier.get(expr, &self.module.types)
+ }
+
+ /// Gets the type for the result of the `expr` expression
+ ///
+ /// Automatically grows the [`typifier`] to `expr` so calling
+ /// [`typifier_grow`] is not necessary
+ ///
+ /// [`typifier`]: Context::typifier
+ /// [`typifier_grow`]: Self::typifier_grow
+ pub(crate) fn resolve_type(
+ &mut self,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<&TypeInner> {
+ self.typifier_grow(expr, meta)?;
+ Ok(self.get_type(expr))
+ }
+
+ /// Gets the type handle for the result of the `expr` expression
+ ///
+ /// Automatically grows the [`typifier`] to `expr` so calling
+ /// [`typifier_grow`] is not necessary
+ ///
+ /// # Note
+ ///
+ /// Consider using [`resolve_type`] whenever possible
+ /// since it doesn't require adding each type to the [`types`] arena
+ /// and it doesn't need to mutably borrow the [`Parser`][Self]
+ ///
+ /// [`types`]: crate::Module::types
+ /// [`typifier`]: Context::typifier
+ /// [`typifier_grow`]: Self::typifier_grow
+ /// [`resolve_type`]: Self::resolve_type
+ pub(crate) fn resolve_type_handle(
+ &mut self,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<Handle<Type>> {
+ self.typifier_grow(expr, meta)?;
+
+ let typifier = if self.is_const {
+ &mut self.const_typifier
+ } else {
+ &mut self.typifier
+ };
+
+ Ok(typifier.register_type(expr, &mut self.module.types))
+ }
+
+ /// Invalidates the cached type resolution for `expr` forcing a recomputation
+ pub(crate) fn invalidate_expression(
+ &mut self,
+ expr: Handle<Expression>,
+ meta: Span,
+ ) -> Result<()> {
+ let resolve_ctx = ResolveContext::with_locals(self.module, &self.locals, &self.arguments);
+
+ let typifier = if self.is_const {
+ &mut self.const_typifier
+ } else {
+ &mut self.typifier
+ };
+
+ typifier
+ .invalidate(expr, &self.expressions, &resolve_ctx)
+ .map_err(|error| Error {
+ kind: ErrorKind::SemanticError(format!("Can't resolve type: {error:?}").into()),
+ meta,
+ })
+ }
+
+ pub(crate) fn lift_up_const_expression(
+ &mut self,
+ expr: Handle<Expression>,
+ ) -> Result<Handle<Expression>> {
+ let meta = self.expressions.get_span(expr);
+ Ok(match self.expressions[expr] {
+ ref expr @ (Expression::Literal(_)
+ | Expression::Constant(_)
+ | Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta),
+ Expression::Compose { ty, ref components } => {
+ let mut components = components.clone();
+ for component in &mut components {
+ *component = self.lift_up_const_expression(*component)?;
+ }
+ self.module
+ .const_expressions
+ .append(Expression::Compose { ty, components }, meta)
+ }
+ Expression::Splat { size, value } => {
+ let value = self.lift_up_const_expression(value)?;
+ self.module
+ .const_expressions
+ .append(Expression::Splat { size, value }, meta)
+ }
+ _ => {
+ return Err(Error {
+ kind: ErrorKind::SemanticError("Expression is not const-expression".into()),
+ meta,
+ })
+ }
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs
new file mode 100644
index 0000000000..5af2b228f0
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/variables.rs
@@ -0,0 +1,646 @@
+use super::{
+ ast::*,
+ context::{Context, ExprPos},
+ error::{Error, ErrorKind},
+ Frontend, Result, Span,
+};
+use crate::{
+ AddressSpace, Binding, BuiltIn, Constant, Expression, GlobalVariable, Handle, Interpolation,
+ LocalVariable, ResourceBinding, Scalar, ScalarKind, ShaderStage, SwizzleComponent, Type,
+ TypeInner, VectorSize,
+};
+
+pub struct VarDeclaration<'a, 'key> {
+ pub qualifiers: &'a mut TypeQualifiers<'key>,
+ pub ty: Handle<Type>,
+ pub name: Option<String>,
+ pub init: Option<Handle<Expression>>,
+ pub meta: Span,
+}
+
+/// Information about a builtin used in [`add_builtin`](Frontend::add_builtin).
+struct BuiltInData {
+ /// The type of the builtin.
+ inner: TypeInner,
+ /// The associated builtin class.
+ builtin: BuiltIn,
+ /// Whether the builtin can be written to or not.
+ mutable: bool,
+ /// The storage used for the builtin.
+ storage: StorageQualifier,
+}
+
+pub enum GlobalOrConstant {
+ Global(Handle<GlobalVariable>),
+ Constant(Handle<Constant>),
+}
+
+impl Frontend {
+ /// Adds a builtin and returns a variable reference to it
+ fn add_builtin(
+ &mut self,
+ ctx: &mut Context,
+ name: &str,
+ data: BuiltInData,
+ meta: Span,
+ ) -> Result<Option<VariableReference>> {
+ let ty = ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: data.inner,
+ },
+ meta,
+ );
+
+ let handle = ctx.module.global_variables.append(
+ GlobalVariable {
+ name: Some(name.into()),
+ space: AddressSpace::Private,
+ binding: None,
+ ty,
+ init: None,
+ },
+ meta,
+ );
+
+ let idx = self.entry_args.len();
+ self.entry_args.push(EntryArg {
+ name: None,
+ binding: Binding::BuiltIn(data.builtin),
+ handle,
+ storage: data.storage,
+ });
+
+ self.global_variables.push((
+ name.into(),
+ GlobalLookup {
+ kind: GlobalLookupKind::Variable(handle),
+ entry_arg: Some(idx),
+ mutable: data.mutable,
+ },
+ ));
+
+ let expr = ctx.add_expression(Expression::GlobalVariable(handle), meta)?;
+
+ let var = VariableReference {
+ expr,
+ load: true,
+ mutable: data.mutable,
+ constant: None,
+ entry_arg: Some(idx),
+ };
+
+ ctx.symbol_table.add_root(name.into(), var.clone());
+
+ Ok(Some(var))
+ }
+
+ pub(crate) fn lookup_variable(
+ &mut self,
+ ctx: &mut Context,
+ name: &str,
+ meta: Span,
+ ) -> Result<Option<VariableReference>> {
+ if let Some(var) = ctx.symbol_table.lookup(name).cloned() {
+ return Ok(Some(var));
+ }
+
+ let data = match name {
+ "gl_Position" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Quad,
+ scalar: Scalar::F32,
+ },
+ builtin: BuiltIn::Position { invariant: false },
+ mutable: true,
+ storage: StorageQualifier::Output,
+ },
+ "gl_FragCoord" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Quad,
+ scalar: Scalar::F32,
+ },
+ builtin: BuiltIn::Position { invariant: false },
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_PointCoord" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Bi,
+ scalar: Scalar::F32,
+ },
+ builtin: BuiltIn::PointCoord,
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_GlobalInvocationID"
+ | "gl_NumWorkGroups"
+ | "gl_WorkGroupSize"
+ | "gl_WorkGroupID"
+ | "gl_LocalInvocationID" => BuiltInData {
+ inner: TypeInner::Vector {
+ size: VectorSize::Tri,
+ scalar: Scalar::U32,
+ },
+ builtin: match name {
+ "gl_GlobalInvocationID" => BuiltIn::GlobalInvocationId,
+ "gl_NumWorkGroups" => BuiltIn::NumWorkGroups,
+ "gl_WorkGroupSize" => BuiltIn::WorkGroupSize,
+ "gl_WorkGroupID" => BuiltIn::WorkGroupId,
+ "gl_LocalInvocationID" => BuiltIn::LocalInvocationId,
+ _ => unreachable!(),
+ },
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_FrontFacing" => BuiltInData {
+ inner: TypeInner::Scalar(Scalar::BOOL),
+ builtin: BuiltIn::FrontFacing,
+ mutable: false,
+ storage: StorageQualifier::Input,
+ },
+ "gl_PointSize" | "gl_FragDepth" => BuiltInData {
+ inner: TypeInner::Scalar(Scalar::F32),
+ builtin: match name {
+ "gl_PointSize" => BuiltIn::PointSize,
+ "gl_FragDepth" => BuiltIn::FragDepth,
+ _ => unreachable!(),
+ },
+ mutable: true,
+ storage: StorageQualifier::Output,
+ },
+ "gl_ClipDistance" | "gl_CullDistance" => {
+ let base = ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar(Scalar::F32),
+ },
+ meta,
+ );
+
+ BuiltInData {
+ inner: TypeInner::Array {
+ base,
+ size: crate::ArraySize::Dynamic,
+ stride: 4,
+ },
+ builtin: match name {
+ "gl_ClipDistance" => BuiltIn::ClipDistance,
+ "gl_CullDistance" => BuiltIn::CullDistance,
+ _ => unreachable!(),
+ },
+ mutable: self.meta.stage == ShaderStage::Vertex,
+ storage: StorageQualifier::Output,
+ }
+ }
+ _ => {
+ let builtin = match name {
+ "gl_BaseVertex" => BuiltIn::BaseVertex,
+ "gl_BaseInstance" => BuiltIn::BaseInstance,
+ "gl_PrimitiveID" => BuiltIn::PrimitiveIndex,
+ "gl_InstanceIndex" => BuiltIn::InstanceIndex,
+ "gl_VertexIndex" => BuiltIn::VertexIndex,
+ "gl_SampleID" => BuiltIn::SampleIndex,
+ "gl_LocalInvocationIndex" => BuiltIn::LocalInvocationIndex,
+ _ => return Ok(None),
+ };
+
+ BuiltInData {
+ inner: TypeInner::Scalar(Scalar::U32),
+ builtin,
+ mutable: false,
+ storage: StorageQualifier::Input,
+ }
+ }
+ };
+
+ self.add_builtin(ctx, name, data, meta)
+ }
+
+ pub(crate) fn make_variable_invariant(
+ &mut self,
+ ctx: &mut Context,
+ name: &str,
+ meta: Span,
+ ) -> Result<()> {
+ if let Some(var) = self.lookup_variable(ctx, name, meta)? {
+ if let Some(index) = var.entry_arg {
+ if let Binding::BuiltIn(BuiltIn::Position { ref mut invariant }) =
+ self.entry_args[index].binding
+ {
+ *invariant = true;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ pub(crate) fn field_selection(
+ &mut self,
+ ctx: &mut Context,
+ pos: ExprPos,
+ expression: Handle<Expression>,
+ name: &str,
+ meta: Span,
+ ) -> Result<Handle<Expression>> {
+ let (ty, is_pointer) = match *ctx.resolve_type(expression, meta)? {
+ TypeInner::Pointer { base, .. } => (&ctx.module.types[base].inner, true),
+ ref ty => (ty, false),
+ };
+ match *ty {
+ TypeInner::Struct { ref members, .. } => {
+ let index = members
+ .iter()
+ .position(|m| m.name == Some(name.into()))
+ .ok_or_else(|| Error {
+ kind: ErrorKind::UnknownField(name.into()),
+ meta,
+ })?;
+ let pointer = ctx.add_expression(
+ Expression::AccessIndex {
+ base: expression,
+ index: index as u32,
+ },
+ meta,
+ )?;
+
+ Ok(match pos {
+ ExprPos::Rhs if is_pointer => {
+ ctx.add_expression(Expression::Load { pointer }, meta)?
+ }
+ _ => pointer,
+ })
+ }
+ // swizzles (xyzw, rgba, stpq)
+ TypeInner::Vector { size, .. } => {
+ let check_swizzle_components = |comps: &str| {
+ name.chars()
+ .map(|c| {
+ comps
+ .find(c)
+ .filter(|i| *i < size as usize)
+ .map(|i| SwizzleComponent::from_index(i as u32))
+ })
+ .collect::<Option<Vec<SwizzleComponent>>>()
+ };
+
+ let components = check_swizzle_components("xyzw")
+ .or_else(|| check_swizzle_components("rgba"))
+ .or_else(|| check_swizzle_components("stpq"));
+
+ if let Some(components) = components {
+ if let ExprPos::Lhs = pos {
+ let not_unique = (1..components.len())
+ .any(|i| components[i..].contains(&components[i - 1]));
+ if not_unique {
+ self.errors.push(Error {
+ kind:
+ ErrorKind::SemanticError(
+ format!(
+ "swizzle cannot have duplicate components in left-hand-side expression for \"{name:?}\""
+ )
+ .into(),
+ ),
+ meta ,
+ })
+ }
+ }
+
+ let mut pattern = [SwizzleComponent::X; 4];
+ for (pat, component) in pattern.iter_mut().zip(&components) {
+ *pat = *component;
+ }
+
+ // flatten nested swizzles (vec.zyx.xy.x => vec.z)
+ let mut expression = expression;
+ if let Expression::Swizzle {
+ size: _,
+ vector,
+ pattern: ref src_pattern,
+ } = ctx[expression]
+ {
+ expression = vector;
+ for pat in &mut pattern {
+ *pat = src_pattern[pat.index() as usize];
+ }
+ }
+
+ let size = match components.len() {
+ // Swizzles with just one component are accesses and not swizzles
+ 1 => {
+ match pos {
+ // If the position is in the right hand side and the base
+ // vector is a pointer, load it, otherwise the swizzle would
+ // produce a pointer
+ ExprPos::Rhs if is_pointer => {
+ expression = ctx.add_expression(
+ Expression::Load {
+ pointer: expression,
+ },
+ meta,
+ )?;
+ }
+ _ => {}
+ };
+ return ctx.add_expression(
+ Expression::AccessIndex {
+ base: expression,
+ index: pattern[0].index(),
+ },
+ meta,
+ );
+ }
+ 2 => VectorSize::Bi,
+ 3 => VectorSize::Tri,
+ 4 => VectorSize::Quad,
+ _ => {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Bad swizzle size for \"{name:?}\"").into(),
+ ),
+ meta,
+ });
+
+ VectorSize::Quad
+ }
+ };
+
+ if is_pointer {
+ // NOTE: for lhs expression, this extra load ends up as an unused expr, because the
+ // assignment will extract the pointer and use it directly anyway. Unfortunately we
+ // need it for validation to pass, as swizzles cannot operate on pointer values.
+ expression = ctx.add_expression(
+ Expression::Load {
+ pointer: expression,
+ },
+ meta,
+ )?;
+ }
+
+ Ok(ctx.add_expression(
+ Expression::Swizzle {
+ size,
+ vector: expression,
+ pattern,
+ },
+ meta,
+ )?)
+ } else {
+ Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Invalid swizzle for vector \"{name}\"").into(),
+ ),
+ meta,
+ })
+ }
+ }
+ _ => Err(Error {
+ kind: ErrorKind::SemanticError(
+ format!("Can't lookup field on this type \"{name}\"").into(),
+ ),
+ meta,
+ }),
+ }
+ }
+
+ pub(crate) fn add_global_var(
+ &mut self,
+ ctx: &mut Context,
+ VarDeclaration {
+ qualifiers,
+ mut ty,
+ name,
+ init,
+ meta,
+ }: VarDeclaration,
+ ) -> Result<GlobalOrConstant> {
+ let storage = qualifiers.storage.0;
+ let (ret, lookup) = match storage {
+ StorageQualifier::Input | StorageQualifier::Output => {
+ let input = storage == StorageQualifier::Input;
+ // TODO: glslang seems to use a counter for variables without
+ // explicit location (even if that causes collisions)
+ let location = qualifiers
+ .uint_layout_qualifier("location", &mut self.errors)
+ .unwrap_or(0);
+ let interpolation = qualifiers.interpolation.take().map(|(i, _)| i).or_else(|| {
+ let kind = ctx.module.types[ty].inner.scalar_kind()?;
+ Some(match kind {
+ ScalarKind::Float => Interpolation::Perspective,
+ _ => Interpolation::Flat,
+ })
+ });
+ let sampling = qualifiers.sampling.take().map(|(s, _)| s);
+
+ let handle = ctx.module.global_variables.append(
+ GlobalVariable {
+ name: name.clone(),
+ space: AddressSpace::Private,
+ binding: None,
+ ty,
+ init,
+ },
+ meta,
+ );
+
+ let idx = self.entry_args.len();
+ self.entry_args.push(EntryArg {
+ name: name.clone(),
+ binding: Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source: false,
+ },
+ handle,
+ storage,
+ });
+
+ let lookup = GlobalLookup {
+ kind: GlobalLookupKind::Variable(handle),
+ entry_arg: Some(idx),
+ mutable: !input,
+ };
+
+ (GlobalOrConstant::Global(handle), lookup)
+ }
+ StorageQualifier::Const => {
+ let init = init.ok_or_else(|| Error {
+ kind: ErrorKind::SemanticError("const values must have an initializer".into()),
+ meta,
+ })?;
+
+ let constant = Constant {
+ name: name.clone(),
+ r#override: crate::Override::None,
+ ty,
+ init,
+ };
+ let handle = ctx.module.constants.fetch_or_append(constant, meta);
+
+ let lookup = GlobalLookup {
+ kind: GlobalLookupKind::Constant(handle, ty),
+ entry_arg: None,
+ mutable: false,
+ };
+
+ (GlobalOrConstant::Constant(handle), lookup)
+ }
+ StorageQualifier::AddressSpace(mut space) => {
+ match space {
+ AddressSpace::Storage { ref mut access } => {
+ if let Some((allowed_access, _)) = qualifiers.storage_access.take() {
+ *access = allowed_access;
+ }
+ }
+ AddressSpace::Uniform => match ctx.module.types[ty].inner {
+ TypeInner::Image {
+ class,
+ dim,
+ arrayed,
+ } => {
+ if let crate::ImageClass::Storage {
+ mut access,
+ mut format,
+ } = class
+ {
+ if let Some((allowed_access, _)) = qualifiers.storage_access.take()
+ {
+ access = allowed_access;
+ }
+
+ match qualifiers.layout_qualifiers.remove(&QualifierKey::Format) {
+ Some((QualifierValue::Format(f), _)) => format = f,
+ // TODO: glsl supports images without format qualifier
+ // if they are `writeonly`
+ None => self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "image types require a format layout qualifier".into(),
+ ),
+ meta,
+ }),
+ _ => unreachable!(),
+ }
+
+ ty = ctx.module.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim,
+ arrayed,
+ class: crate::ImageClass::Storage { format, access },
+ },
+ },
+ meta,
+ );
+ }
+
+ space = AddressSpace::Handle
+ }
+ TypeInner::Sampler { .. } => space = AddressSpace::Handle,
+ _ => {
+ if qualifiers.none_layout_qualifier("push_constant", &mut self.errors) {
+ space = AddressSpace::PushConstant
+ }
+ }
+ },
+ AddressSpace::Function => space = AddressSpace::Private,
+ _ => {}
+ };
+
+ let binding = match space {
+ AddressSpace::Uniform | AddressSpace::Storage { .. } | AddressSpace::Handle => {
+ let binding = qualifiers.uint_layout_qualifier("binding", &mut self.errors);
+ if binding.is_none() {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError(
+ "uniform/buffer blocks require layout(binding=X)".into(),
+ ),
+ meta,
+ });
+ }
+ let set = qualifiers.uint_layout_qualifier("set", &mut self.errors);
+ binding.map(|binding| ResourceBinding {
+ group: set.unwrap_or(0),
+ binding,
+ })
+ }
+ _ => None,
+ };
+
+ let handle = ctx.module.global_variables.append(
+ GlobalVariable {
+ name: name.clone(),
+ space,
+ binding,
+ ty,
+ init,
+ },
+ meta,
+ );
+
+ let lookup = GlobalLookup {
+ kind: GlobalLookupKind::Variable(handle),
+ entry_arg: None,
+ mutable: true,
+ };
+
+ (GlobalOrConstant::Global(handle), lookup)
+ }
+ };
+
+ if let Some(name) = name {
+ ctx.add_global(&name, lookup)?;
+
+ self.global_variables.push((name, lookup));
+ }
+
+ qualifiers.unused_errors(&mut self.errors);
+
+ Ok(ret)
+ }
+
+ pub(crate) fn add_local_var(
+ &mut self,
+ ctx: &mut Context,
+ decl: VarDeclaration,
+ ) -> Result<Handle<Expression>> {
+ let storage = decl.qualifiers.storage;
+ let mutable = match storage.0 {
+ StorageQualifier::AddressSpace(AddressSpace::Function) => true,
+ StorageQualifier::Const => false,
+ _ => {
+ self.errors.push(Error {
+ kind: ErrorKind::SemanticError("Locals cannot have a storage qualifier".into()),
+ meta: storage.1,
+ });
+ true
+ }
+ };
+
+ let handle = ctx.locals.append(
+ LocalVariable {
+ name: decl.name.clone(),
+ ty: decl.ty,
+ init: decl.init,
+ },
+ decl.meta,
+ );
+ let expr = ctx.add_expression(Expression::LocalVariable(handle), decl.meta)?;
+
+ if let Some(name) = decl.name {
+ let maybe_var = ctx.add_local_var(name.clone(), expr, mutable);
+
+ if maybe_var.is_some() {
+ self.errors.push(Error {
+ kind: ErrorKind::VariableAlreadyDeclared(name),
+ meta: decl.meta,
+ })
+ }
+ }
+
+ decl.qualifiers.unused_errors(&mut self.errors);
+
+ Ok(expr)
+ }
+}
diff --git a/third_party/rust/naga/src/front/interpolator.rs b/third_party/rust/naga/src/front/interpolator.rs
new file mode 100644
index 0000000000..0196a2254d
--- /dev/null
+++ b/third_party/rust/naga/src/front/interpolator.rs
@@ -0,0 +1,62 @@
+/*!
+Interpolation defaults.
+*/
+
+impl crate::Binding {
+ /// Apply the usual default interpolation for `ty` to `binding`.
+ ///
+ /// This function is a utility front ends may use to satisfy the Naga IR's
+ /// requirement, meant to ensure that input languages' policies have been
+ /// applied appropriately, that all I/O `Binding`s from the vertex shader to the
+ /// fragment shader must have non-`None` `interpolation` values.
+ ///
+ /// All the shader languages Naga supports have similar rules:
+ /// perspective-correct, center-sampled interpolation is the default for any
+ /// binding that can vary, and everything else either defaults to flat, or
+ /// requires an explicit flat qualifier/attribute/what-have-you.
+ ///
+ /// If `binding` is not a [`Location`] binding, or if its [`interpolation`] is
+ /// already set, then make no changes. Otherwise, set `binding`'s interpolation
+ /// and sampling to reasonable defaults depending on `ty`, the type of the value
+ /// being interpolated:
+ ///
+ /// - If `ty` is a floating-point scalar, vector, or matrix type, then
+ /// default to [`Perspective`] interpolation and [`Center`] sampling.
+ ///
+ /// - If `ty` is an integral scalar or vector, then default to [`Flat`]
+ /// interpolation, which has no associated sampling.
+ ///
+ /// - For any other types, make no change. Such types are not permitted as
+ /// user-defined IO values, and will probably be flagged by the verifier
+ ///
+ /// When structs appear in input or output types, each member ought to have its
+ /// own [`Binding`], so structs are simply covered by the third case.
+ ///
+ /// [`Binding`]: crate::Binding
+ /// [`Location`]: crate::Binding::Location
+ /// [`interpolation`]: crate::Binding::Location::interpolation
+ /// [`Perspective`]: crate::Interpolation::Perspective
+ /// [`Flat`]: crate::Interpolation::Flat
+ /// [`Center`]: crate::Sampling::Center
+ pub fn apply_default_interpolation(&mut self, ty: &crate::TypeInner) {
+ if let crate::Binding::Location {
+ location: _,
+ interpolation: ref mut interpolation @ None,
+ ref mut sampling,
+ second_blend_source: _,
+ } = *self
+ {
+ match ty.scalar_kind() {
+ Some(crate::ScalarKind::Float) => {
+ *interpolation = Some(crate::Interpolation::Perspective);
+ *sampling = Some(crate::Sampling::Center);
+ }
+ Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+ *interpolation = Some(crate::Interpolation::Flat);
+ *sampling = None;
+ }
+ Some(_) | None => {}
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/mod.rs b/third_party/rust/naga/src/front/mod.rs
new file mode 100644
index 0000000000..e1f99452e1
--- /dev/null
+++ b/third_party/rust/naga/src/front/mod.rs
@@ -0,0 +1,328 @@
+/*!
+Frontend parsers that consume binary and text shaders and load them into [`Module`](super::Module)s.
+*/
+
+mod interpolator;
+mod type_gen;
+
+#[cfg(feature = "glsl-in")]
+pub mod glsl;
+#[cfg(feature = "spv-in")]
+pub mod spv;
+#[cfg(feature = "wgsl-in")]
+pub mod wgsl;
+
+use crate::{
+ arena::{Arena, Handle, UniqueArena},
+ proc::{ResolveContext, ResolveError, TypeResolution},
+ FastHashMap,
+};
+use std::ops;
+
+/// A table of types for an `Arena<Expression>`.
+///
+/// A front end can use a `Typifier` to get types for an arena's expressions
+/// while it is still contributing expressions to it. At any point, you can call
+/// [`typifier.grow(expr, arena, ctx)`], where `expr` is a `Handle<Expression>`
+/// referring to something in `arena`, and the `Typifier` will resolve the types
+/// of all the expressions up to and including `expr`. Then you can write
+/// `typifier[handle]` to get the type of any handle at or before `expr`.
+///
+/// Note that `Typifier` does *not* build an `Arena<Type>` as a part of its
+/// usual operation. Ideally, a module's type arena should only contain types
+/// actually needed by `Handle<Type>`s elsewhere in the module — functions,
+/// variables, [`Compose`] expressions, other types, and so on — so we don't
+/// want every little thing that occurs as the type of some intermediate
+/// expression to show up there.
+///
+/// Instead, `Typifier` accumulates a [`TypeResolution`] for each expression,
+/// which refers to the `Arena<Type>` in the [`ResolveContext`] passed to `grow`
+/// as needed. [`TypeResolution`] is a lightweight representation for
+/// intermediate types like this; see its documentation for details.
+///
+/// If you do need to register a `Typifier`'s conclusion in an `Arena<Type>`
+/// (say, for a [`LocalVariable`] whose type you've inferred), you can use
+/// [`register_type`] to do so.
+///
+/// [`typifier.grow(expr, arena)`]: Typifier::grow
+/// [`register_type`]: Typifier::register_type
+/// [`Compose`]: crate::Expression::Compose
+/// [`LocalVariable`]: crate::LocalVariable
+#[derive(Debug, Default)]
+pub struct Typifier {
+ resolutions: Vec<TypeResolution>,
+}
+
+impl Typifier {
+ pub const fn new() -> Self {
+ Typifier {
+ resolutions: Vec::new(),
+ }
+ }
+
+ pub fn reset(&mut self) {
+ self.resolutions.clear()
+ }
+
+ pub fn get<'a>(
+ &'a self,
+ expr_handle: Handle<crate::Expression>,
+ types: &'a UniqueArena<crate::Type>,
+ ) -> &'a crate::TypeInner {
+ self.resolutions[expr_handle.index()].inner_with(types)
+ }
+
+ /// Add an expression's type to an `Arena<Type>`.
+ ///
+ /// Add the type of `expr_handle` to `types`, and return a `Handle<Type>`
+ /// referring to it.
+ ///
+ /// # Note
+ ///
+ /// If you just need a [`TypeInner`] for `expr_handle`'s type, consider
+ /// using `typifier[expression].inner_with(types)` instead. Calling
+ /// [`TypeResolution::inner_with`] often lets us avoid adding anything to
+ /// the arena, which can significantly reduce the number of types that end
+ /// up in the final module.
+ ///
+ /// [`TypeInner`]: crate::TypeInner
+ pub fn register_type(
+ &self,
+ expr_handle: Handle<crate::Expression>,
+ types: &mut UniqueArena<crate::Type>,
+ ) -> Handle<crate::Type> {
+ match self[expr_handle].clone() {
+ TypeResolution::Handle(handle) => handle,
+ TypeResolution::Value(inner) => {
+ types.insert(crate::Type { name: None, inner }, crate::Span::UNDEFINED)
+ }
+ }
+ }
+
+ /// Grow this typifier until it contains a type for `expr_handle`.
+ pub fn grow(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ ctx: &ResolveContext,
+ ) -> Result<(), ResolveError> {
+ if self.resolutions.len() <= expr_handle.index() {
+ for (eh, expr) in expressions.iter().skip(self.resolutions.len()) {
+ //Note: the closure can't `Err` by construction
+ let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?;
+ log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution);
+ self.resolutions.push(resolution);
+ }
+ }
+ Ok(())
+ }
+
+ /// Recompute the type resolution for `expr_handle`.
+ ///
+ /// If the type of `expr_handle` hasn't yet been calculated, call
+ /// [`grow`](Self::grow) to ensure it is covered.
+ ///
+ /// In either case, when this returns, `self[expr_handle]` should be an
+ /// updated type resolution for `expr_handle`.
+ pub fn invalidate(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ ctx: &ResolveContext,
+ ) -> Result<(), ResolveError> {
+ if self.resolutions.len() <= expr_handle.index() {
+ self.grow(expr_handle, expressions, ctx)
+ } else {
+ let expr = &expressions[expr_handle];
+ //Note: the closure can't `Err` by construction
+ let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?;
+ self.resolutions[expr_handle.index()] = resolution;
+ Ok(())
+ }
+ }
+}
+
+impl ops::Index<Handle<crate::Expression>> for Typifier {
+ type Output = TypeResolution;
+ fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
+ &self.resolutions[handle.index()]
+ }
+}
+
+/// Type representing a lexical scope, associating a name to a single variable
+///
+/// The scope is generic over the variable representation and name representation
+/// in order to allow larger flexibility on the frontends on how they might
+/// represent them.
+type Scope<Name, Var> = FastHashMap<Name, Var>;
+
+/// Structure responsible for managing variable lookups and keeping track of
+/// lexical scopes
+///
+/// The symbol table is generic over the variable representation and its name
+/// to allow larger flexibility on the frontends on how they might represent them.
+///
+/// ```
+/// use naga::front::SymbolTable;
+///
+/// // Create a new symbol table with `u32`s representing the variable
+/// let mut symbol_table: SymbolTable<&str, u32> = SymbolTable::default();
+///
+/// // Add two variables named `var1` and `var2` with 0 and 2 respectively
+/// symbol_table.add("var1", 0);
+/// symbol_table.add("var2", 2);
+///
+/// // Check that `var1` exists and is `0`
+/// assert_eq!(symbol_table.lookup("var1"), Some(&0));
+///
+/// // Push a new scope and add a variable to it named `var1` shadowing the
+/// // variable of our previous scope
+/// symbol_table.push_scope();
+/// symbol_table.add("var1", 1);
+///
+/// // Check that `var1` now points to the new value of `1` and `var2` still
+/// // exists with its value of `2`
+/// assert_eq!(symbol_table.lookup("var1"), Some(&1));
+/// assert_eq!(symbol_table.lookup("var2"), Some(&2));
+///
+/// // Pop the scope
+/// symbol_table.pop_scope();
+///
+/// // Check that `var1` now refers to our initial variable with value `0`
+/// assert_eq!(symbol_table.lookup("var1"), Some(&0));
+/// ```
+///
+/// Scopes are ordered as a LIFO stack so a variable defined in a later scope
+/// with the same name as another variable defined in a earlier scope will take
+/// precedence in the lookup. Scopes can be added with [`push_scope`] and
+/// removed with [`pop_scope`].
+///
+/// A root scope is added when the symbol table is created and must always be
+/// present. Trying to pop it will result in a panic.
+///
+/// Variables can be added with [`add`] and looked up with [`lookup`]. Adding a
+/// variable will do so in the currently active scope and as mentioned
+/// previously a lookup will search from the current scope to the root scope.
+///
+/// [`push_scope`]: Self::push_scope
+/// [`pop_scope`]: Self::push_scope
+/// [`add`]: Self::add
+/// [`lookup`]: Self::lookup
+pub struct SymbolTable<Name, Var> {
+ /// Stack of lexical scopes. Not all scopes are active; see [`cursor`].
+ ///
+ /// [`cursor`]: Self::cursor
+ scopes: Vec<Scope<Name, Var>>,
+ /// Limit of the [`scopes`] stack (exclusive). By using a separate value for
+ /// the stack length instead of `Vec`'s own internal length, the scopes can
+ /// be reused to cache memory allocations.
+ ///
+ /// [`scopes`]: Self::scopes
+ cursor: usize,
+}
+
+impl<Name, Var> SymbolTable<Name, Var> {
+ /// Adds a new lexical scope.
+ ///
+ /// All variables declared after this point will be added to this scope
+ /// until another scope is pushed or [`pop_scope`] is called, causing this
+ /// scope to be removed along with all variables added to it.
+ ///
+ /// [`pop_scope`]: Self::pop_scope
+ pub fn push_scope(&mut self) {
+ // If the cursor is equal to the scope's stack length then we need to
+ // push another empty scope. Otherwise we can reuse the already existing
+ // scope.
+ if self.scopes.len() == self.cursor {
+ self.scopes.push(FastHashMap::default())
+ } else {
+ self.scopes[self.cursor].clear();
+ }
+
+ self.cursor += 1;
+ }
+
+ /// Removes the current lexical scope and all its variables
+ ///
+ /// # PANICS
+ /// - If the current lexical scope is the root scope
+ pub fn pop_scope(&mut self) {
+ // Despite the method title, the variables are only deleted when the
+ // scope is reused. This is because while a clear is inevitable if the
+ // scope needs to be reused, there are cases where the scope might be
+ // popped and not reused, i.e. if another scope with the same nesting
+ // level is never pushed again.
+ assert!(self.cursor != 1, "Tried to pop the root scope");
+
+ self.cursor -= 1;
+ }
+}
+
+impl<Name, Var> SymbolTable<Name, Var>
+where
+ Name: std::hash::Hash + Eq,
+{
+ /// Perform a lookup for a variable named `name`.
+ ///
+ /// As stated in the struct level documentation the lookup will proceed from
+ /// the current scope to the root scope, returning `Some` when a variable is
+ /// found or `None` if there doesn't exist a variable with `name` in any
+ /// scope.
+ pub fn lookup<Q: ?Sized>(&self, name: &Q) -> Option<&Var>
+ where
+ Name: std::borrow::Borrow<Q>,
+ Q: std::hash::Hash + Eq,
+ {
+ // Iterate backwards trough the scopes and try to find the variable
+ for scope in self.scopes[..self.cursor].iter().rev() {
+ if let Some(var) = scope.get(name) {
+ return Some(var);
+ }
+ }
+
+ None
+ }
+
+ /// Adds a new variable to the current scope.
+ ///
+ /// Returns the previous variable with the same name in this scope if it
+ /// exists, so that the frontend might handle it in case variable shadowing
+ /// is disallowed.
+ pub fn add(&mut self, name: Name, var: Var) -> Option<Var> {
+ self.scopes[self.cursor - 1].insert(name, var)
+ }
+
+ /// Adds a new variable to the root scope.
+ ///
+ /// This is used in GLSL for builtins which aren't known in advance and only
+ /// when used for the first time, so there must be a way to add those
+ /// declarations to the root unconditionally from the current scope.
+ ///
+ /// Returns the previous variable with the same name in the root scope if it
+ /// exists, so that the frontend might handle it in case variable shadowing
+ /// is disallowed.
+ pub fn add_root(&mut self, name: Name, var: Var) -> Option<Var> {
+ self.scopes[0].insert(name, var)
+ }
+}
+
+impl<Name, Var> Default for SymbolTable<Name, Var> {
+ /// Constructs a new symbol table with a root scope
+ fn default() -> Self {
+ Self {
+ scopes: vec![FastHashMap::default()],
+ cursor: 1,
+ }
+ }
+}
+
+use std::fmt;
+
+impl<Name: fmt::Debug, Var: fmt::Debug> fmt::Debug for SymbolTable<Name, Var> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str("SymbolTable ")?;
+ f.debug_list()
+ .entries(self.scopes[..self.cursor].iter())
+ .finish()
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/convert.rs b/third_party/rust/naga/src/front/spv/convert.rs
new file mode 100644
index 0000000000..f0a714fbeb
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/convert.rs
@@ -0,0 +1,179 @@
+use super::error::Error;
+use std::convert::TryInto;
+
+pub(super) const fn map_binary_operator(word: spirv::Op) -> Result<crate::BinaryOperator, Error> {
+ use crate::BinaryOperator;
+ use spirv::Op;
+
+ match word {
+ // Arithmetic Instructions +, -, *, /, %
+ Op::IAdd | Op::FAdd => Ok(BinaryOperator::Add),
+ Op::ISub | Op::FSub => Ok(BinaryOperator::Subtract),
+ Op::IMul | Op::FMul => Ok(BinaryOperator::Multiply),
+ Op::UDiv | Op::SDiv | Op::FDiv => Ok(BinaryOperator::Divide),
+ Op::SRem => Ok(BinaryOperator::Modulo),
+ // Relational and Logical Instructions
+ Op::IEqual | Op::FOrdEqual | Op::FUnordEqual | Op::LogicalEqual => {
+ Ok(BinaryOperator::Equal)
+ }
+ Op::INotEqual | Op::FOrdNotEqual | Op::FUnordNotEqual | Op::LogicalNotEqual => {
+ Ok(BinaryOperator::NotEqual)
+ }
+ Op::ULessThan | Op::SLessThan | Op::FOrdLessThan | Op::FUnordLessThan => {
+ Ok(BinaryOperator::Less)
+ }
+ Op::ULessThanEqual
+ | Op::SLessThanEqual
+ | Op::FOrdLessThanEqual
+ | Op::FUnordLessThanEqual => Ok(BinaryOperator::LessEqual),
+ Op::UGreaterThan | Op::SGreaterThan | Op::FOrdGreaterThan | Op::FUnordGreaterThan => {
+ Ok(BinaryOperator::Greater)
+ }
+ Op::UGreaterThanEqual
+ | Op::SGreaterThanEqual
+ | Op::FOrdGreaterThanEqual
+ | Op::FUnordGreaterThanEqual => Ok(BinaryOperator::GreaterEqual),
+ Op::BitwiseOr => Ok(BinaryOperator::InclusiveOr),
+ Op::BitwiseXor => Ok(BinaryOperator::ExclusiveOr),
+ Op::BitwiseAnd => Ok(BinaryOperator::And),
+ _ => Err(Error::UnknownBinaryOperator(word)),
+ }
+}
+
+pub(super) const fn map_relational_fun(
+ word: spirv::Op,
+) -> Result<crate::RelationalFunction, Error> {
+ use crate::RelationalFunction as Rf;
+ use spirv::Op;
+
+ match word {
+ Op::All => Ok(Rf::All),
+ Op::Any => Ok(Rf::Any),
+ Op::IsNan => Ok(Rf::IsNan),
+ Op::IsInf => Ok(Rf::IsInf),
+ _ => Err(Error::UnknownRelationalFunction(word)),
+ }
+}
+
+pub(super) const fn map_vector_size(word: spirv::Word) -> Result<crate::VectorSize, Error> {
+ match word {
+ 2 => Ok(crate::VectorSize::Bi),
+ 3 => Ok(crate::VectorSize::Tri),
+ 4 => Ok(crate::VectorSize::Quad),
+ _ => Err(Error::InvalidVectorSize(word)),
+ }
+}
+
+pub(super) fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> {
+ use spirv::Dim as D;
+ match D::from_u32(word) {
+ Some(D::Dim1D) => Ok(crate::ImageDimension::D1),
+ Some(D::Dim2D) => Ok(crate::ImageDimension::D2),
+ Some(D::Dim3D) => Ok(crate::ImageDimension::D3),
+ Some(D::DimCube) => Ok(crate::ImageDimension::Cube),
+ _ => Err(Error::UnsupportedImageDim(word)),
+ }
+}
+
+pub(super) fn map_image_format(word: spirv::Word) -> Result<crate::StorageFormat, Error> {
+ match spirv::ImageFormat::from_u32(word) {
+ Some(spirv::ImageFormat::R8) => Ok(crate::StorageFormat::R8Unorm),
+ Some(spirv::ImageFormat::R8Snorm) => Ok(crate::StorageFormat::R8Snorm),
+ Some(spirv::ImageFormat::R8ui) => Ok(crate::StorageFormat::R8Uint),
+ Some(spirv::ImageFormat::R8i) => Ok(crate::StorageFormat::R8Sint),
+ Some(spirv::ImageFormat::R16) => Ok(crate::StorageFormat::R16Unorm),
+ Some(spirv::ImageFormat::R16Snorm) => Ok(crate::StorageFormat::R16Snorm),
+ Some(spirv::ImageFormat::R16ui) => Ok(crate::StorageFormat::R16Uint),
+ Some(spirv::ImageFormat::R16i) => Ok(crate::StorageFormat::R16Sint),
+ Some(spirv::ImageFormat::R16f) => Ok(crate::StorageFormat::R16Float),
+ Some(spirv::ImageFormat::Rg8) => Ok(crate::StorageFormat::Rg8Unorm),
+ Some(spirv::ImageFormat::Rg8Snorm) => Ok(crate::StorageFormat::Rg8Snorm),
+ Some(spirv::ImageFormat::Rg8ui) => Ok(crate::StorageFormat::Rg8Uint),
+ Some(spirv::ImageFormat::Rg8i) => Ok(crate::StorageFormat::Rg8Sint),
+ Some(spirv::ImageFormat::R32ui) => Ok(crate::StorageFormat::R32Uint),
+ Some(spirv::ImageFormat::R32i) => Ok(crate::StorageFormat::R32Sint),
+ Some(spirv::ImageFormat::R32f) => Ok(crate::StorageFormat::R32Float),
+ Some(spirv::ImageFormat::Rg16) => Ok(crate::StorageFormat::Rg16Unorm),
+ Some(spirv::ImageFormat::Rg16Snorm) => Ok(crate::StorageFormat::Rg16Snorm),
+ Some(spirv::ImageFormat::Rg16ui) => Ok(crate::StorageFormat::Rg16Uint),
+ Some(spirv::ImageFormat::Rg16i) => Ok(crate::StorageFormat::Rg16Sint),
+ Some(spirv::ImageFormat::Rg16f) => Ok(crate::StorageFormat::Rg16Float),
+ Some(spirv::ImageFormat::Rgba8) => Ok(crate::StorageFormat::Rgba8Unorm),
+ Some(spirv::ImageFormat::Rgba8Snorm) => Ok(crate::StorageFormat::Rgba8Snorm),
+ Some(spirv::ImageFormat::Rgba8ui) => Ok(crate::StorageFormat::Rgba8Uint),
+ Some(spirv::ImageFormat::Rgba8i) => Ok(crate::StorageFormat::Rgba8Sint),
+ Some(spirv::ImageFormat::Rgb10a2ui) => Ok(crate::StorageFormat::Rgb10a2Uint),
+ Some(spirv::ImageFormat::Rgb10A2) => Ok(crate::StorageFormat::Rgb10a2Unorm),
+ Some(spirv::ImageFormat::R11fG11fB10f) => Ok(crate::StorageFormat::Rg11b10Float),
+ Some(spirv::ImageFormat::Rg32ui) => Ok(crate::StorageFormat::Rg32Uint),
+ Some(spirv::ImageFormat::Rg32i) => Ok(crate::StorageFormat::Rg32Sint),
+ Some(spirv::ImageFormat::Rg32f) => Ok(crate::StorageFormat::Rg32Float),
+ Some(spirv::ImageFormat::Rgba16) => Ok(crate::StorageFormat::Rgba16Unorm),
+ Some(spirv::ImageFormat::Rgba16Snorm) => Ok(crate::StorageFormat::Rgba16Snorm),
+ Some(spirv::ImageFormat::Rgba16ui) => Ok(crate::StorageFormat::Rgba16Uint),
+ Some(spirv::ImageFormat::Rgba16i) => Ok(crate::StorageFormat::Rgba16Sint),
+ Some(spirv::ImageFormat::Rgba16f) => Ok(crate::StorageFormat::Rgba16Float),
+ Some(spirv::ImageFormat::Rgba32ui) => Ok(crate::StorageFormat::Rgba32Uint),
+ Some(spirv::ImageFormat::Rgba32i) => Ok(crate::StorageFormat::Rgba32Sint),
+ Some(spirv::ImageFormat::Rgba32f) => Ok(crate::StorageFormat::Rgba32Float),
+ _ => Err(Error::UnsupportedImageFormat(word)),
+ }
+}
+
+pub(super) fn map_width(word: spirv::Word) -> Result<crate::Bytes, Error> {
+ (word >> 3) // bits to bytes
+ .try_into()
+ .map_err(|_| Error::InvalidTypeWidth(word))
+}
+
+pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::BuiltIn, Error> {
+ use spirv::BuiltIn as Bi;
+ Ok(match spirv::BuiltIn::from_u32(word) {
+ Some(Bi::Position | Bi::FragCoord) => crate::BuiltIn::Position { invariant },
+ Some(Bi::ViewIndex) => crate::BuiltIn::ViewIndex,
+ // vertex
+ Some(Bi::BaseInstance) => crate::BuiltIn::BaseInstance,
+ Some(Bi::BaseVertex) => crate::BuiltIn::BaseVertex,
+ Some(Bi::ClipDistance) => crate::BuiltIn::ClipDistance,
+ Some(Bi::CullDistance) => crate::BuiltIn::CullDistance,
+ Some(Bi::InstanceIndex) => crate::BuiltIn::InstanceIndex,
+ Some(Bi::PointSize) => crate::BuiltIn::PointSize,
+ Some(Bi::VertexIndex) => crate::BuiltIn::VertexIndex,
+ // fragment
+ Some(Bi::FragDepth) => crate::BuiltIn::FragDepth,
+ Some(Bi::PointCoord) => crate::BuiltIn::PointCoord,
+ Some(Bi::FrontFacing) => crate::BuiltIn::FrontFacing,
+ Some(Bi::PrimitiveId) => crate::BuiltIn::PrimitiveIndex,
+ Some(Bi::SampleId) => crate::BuiltIn::SampleIndex,
+ Some(Bi::SampleMask) => crate::BuiltIn::SampleMask,
+ // compute
+ Some(Bi::GlobalInvocationId) => crate::BuiltIn::GlobalInvocationId,
+ Some(Bi::LocalInvocationId) => crate::BuiltIn::LocalInvocationId,
+ Some(Bi::LocalInvocationIndex) => crate::BuiltIn::LocalInvocationIndex,
+ Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
+ Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
+ Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
+ _ => return Err(Error::UnsupportedBuiltIn(word)),
+ })
+}
+
+pub(super) fn map_storage_class(word: spirv::Word) -> Result<super::ExtendedClass, Error> {
+ use super::ExtendedClass as Ec;
+ use spirv::StorageClass as Sc;
+ Ok(match Sc::from_u32(word) {
+ Some(Sc::Function) => Ec::Global(crate::AddressSpace::Function),
+ Some(Sc::Input) => Ec::Input,
+ Some(Sc::Output) => Ec::Output,
+ Some(Sc::Private) => Ec::Global(crate::AddressSpace::Private),
+ Some(Sc::UniformConstant) => Ec::Global(crate::AddressSpace::Handle),
+ Some(Sc::StorageBuffer) => Ec::Global(crate::AddressSpace::Storage {
+ //Note: this is restricted by decorations later
+ access: crate::StorageAccess::all(),
+ }),
+ // we expect the `Storage` case to be filtered out before calling this function.
+ Some(Sc::Uniform) => Ec::Global(crate::AddressSpace::Uniform),
+ Some(Sc::Workgroup) => Ec::Global(crate::AddressSpace::WorkGroup),
+ Some(Sc::PushConstant) => Ec::Global(crate::AddressSpace::PushConstant),
+ _ => return Err(Error::UnsupportedStorageClass(word)),
+ })
+}
diff --git a/third_party/rust/naga/src/front/spv/error.rs b/third_party/rust/naga/src/front/spv/error.rs
new file mode 100644
index 0000000000..af025636c0
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/error.rs
@@ -0,0 +1,154 @@
+use super::ModuleState;
+use crate::arena::Handle;
+use codespan_reporting::diagnostic::Diagnostic;
+use codespan_reporting::files::SimpleFile;
+use codespan_reporting::term;
+use termcolor::{NoColor, WriteColor};
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error("invalid header")]
+ InvalidHeader,
+ #[error("invalid word count")]
+ InvalidWordCount,
+ #[error("unknown instruction {0}")]
+ UnknownInstruction(u16),
+ #[error("unknown capability %{0}")]
+ UnknownCapability(spirv::Word),
+ #[error("unsupported instruction {1:?} at {0:?}")]
+ UnsupportedInstruction(ModuleState, spirv::Op),
+ #[error("unsupported capability {0:?}")]
+ UnsupportedCapability(spirv::Capability),
+ #[error("unsupported extension {0}")]
+ UnsupportedExtension(String),
+ #[error("unsupported extension set {0}")]
+ UnsupportedExtSet(String),
+ #[error("unsupported extension instantiation set %{0}")]
+ UnsupportedExtInstSet(spirv::Word),
+ #[error("unsupported extension instantiation %{0}")]
+ UnsupportedExtInst(spirv::Word),
+ #[error("unsupported type {0:?}")]
+ UnsupportedType(Handle<crate::Type>),
+ #[error("unsupported execution model %{0}")]
+ UnsupportedExecutionModel(spirv::Word),
+ #[error("unsupported execution mode %{0}")]
+ UnsupportedExecutionMode(spirv::Word),
+ #[error("unsupported storage class %{0}")]
+ UnsupportedStorageClass(spirv::Word),
+ #[error("unsupported image dimension %{0}")]
+ UnsupportedImageDim(spirv::Word),
+ #[error("unsupported image format %{0}")]
+ UnsupportedImageFormat(spirv::Word),
+ #[error("unsupported builtin %{0}")]
+ UnsupportedBuiltIn(spirv::Word),
+ #[error("unsupported control flow %{0}")]
+ UnsupportedControlFlow(spirv::Word),
+ #[error("unsupported binary operator %{0}")]
+ UnsupportedBinaryOperator(spirv::Word),
+ #[error("Naga supports OpTypeRuntimeArray in the StorageBuffer storage class only")]
+ UnsupportedRuntimeArrayStorageClass,
+ #[error("unsupported matrix stride {stride} for a {columns}x{rows} matrix with scalar width={width}")]
+ UnsupportedMatrixStride {
+ stride: u32,
+ columns: u8,
+ rows: u8,
+ width: u8,
+ },
+ #[error("unknown binary operator {0:?}")]
+ UnknownBinaryOperator(spirv::Op),
+ #[error("unknown relational function {0:?}")]
+ UnknownRelationalFunction(spirv::Op),
+ #[error("invalid parameter {0:?}")]
+ InvalidParameter(spirv::Op),
+ #[error("invalid operand count {1} for {0:?}")]
+ InvalidOperandCount(spirv::Op, u16),
+ #[error("invalid operand")]
+ InvalidOperand,
+ #[error("invalid id %{0}")]
+ InvalidId(spirv::Word),
+ #[error("invalid decoration %{0}")]
+ InvalidDecoration(spirv::Word),
+ #[error("invalid type width %{0}")]
+ InvalidTypeWidth(spirv::Word),
+ #[error("invalid sign %{0}")]
+ InvalidSign(spirv::Word),
+ #[error("invalid inner type %{0}")]
+ InvalidInnerType(spirv::Word),
+ #[error("invalid vector size %{0}")]
+ InvalidVectorSize(spirv::Word),
+ #[error("invalid access type %{0}")]
+ InvalidAccessType(spirv::Word),
+ #[error("invalid access {0:?}")]
+ InvalidAccess(crate::Expression),
+ #[error("invalid access index %{0}")]
+ InvalidAccessIndex(spirv::Word),
+ #[error("invalid index type %{0}")]
+ InvalidIndexType(spirv::Word),
+ #[error("invalid binding %{0}")]
+ InvalidBinding(spirv::Word),
+ #[error("invalid global var {0:?}")]
+ InvalidGlobalVar(crate::Expression),
+ #[error("invalid image/sampler expression {0:?}")]
+ InvalidImageExpression(crate::Expression),
+ #[error("invalid image base type {0:?}")]
+ InvalidImageBaseType(Handle<crate::Type>),
+ #[error("invalid image {0:?}")]
+ InvalidImage(Handle<crate::Type>),
+ #[error("invalid as type {0:?}")]
+ InvalidAsType(Handle<crate::Type>),
+ #[error("invalid vector type {0:?}")]
+ InvalidVectorType(Handle<crate::Type>),
+ #[error("inconsistent comparison sampling {0:?}")]
+ InconsistentComparisonSampling(Handle<crate::GlobalVariable>),
+ #[error("wrong function result type %{0}")]
+ WrongFunctionResultType(spirv::Word),
+ #[error("wrong function argument type %{0}")]
+ WrongFunctionArgumentType(spirv::Word),
+ #[error("missing decoration {0:?}")]
+ MissingDecoration(spirv::Decoration),
+ #[error("bad string")]
+ BadString,
+ #[error("incomplete data")]
+ IncompleteData,
+ #[error("invalid terminator")]
+ InvalidTerminator,
+ #[error("invalid edge classification")]
+ InvalidEdgeClassification,
+ #[error("cycle detected in the CFG during traversal at {0}")]
+ ControlFlowGraphCycle(crate::front::spv::BlockId),
+ #[error("recursive function call %{0}")]
+ FunctionCallCycle(spirv::Word),
+ #[error("invalid array size {0:?}")]
+ InvalidArraySize(Handle<crate::Constant>),
+ #[error("invalid barrier scope %{0}")]
+ InvalidBarrierScope(spirv::Word),
+ #[error("invalid barrier memory semantics %{0}")]
+ InvalidBarrierMemorySemantics(spirv::Word),
+ #[error(
+ "arrays of images / samplers are supported only through bindings for \
+ now (i.e. you can't create an array of images or samplers that doesn't \
+ come from a binding)"
+ )]
+ NonBindingArrayOfImageOrSamplers,
+}
+
+impl Error {
+ pub fn emit_to_writer(&self, writer: &mut impl WriteColor, source: &str) {
+ self.emit_to_writer_with_path(writer, source, "glsl");
+ }
+
+ pub fn emit_to_writer_with_path(&self, writer: &mut impl WriteColor, source: &str, path: &str) {
+ let path = path.to_string();
+ let files = SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+ let diagnostic = Diagnostic::error().with_message(format!("{self:?}"));
+
+ term::emit(writer, &config, &files, &diagnostic).expect("cannot write error");
+ }
+
+ pub fn emit_to_string(&self, source: &str) -> String {
+ let mut writer = NoColor::new(Vec::new());
+ self.emit_to_writer(&mut writer, source);
+ String::from_utf8(writer.into_inner()).unwrap()
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs
new file mode 100644
index 0000000000..198d9c52dd
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/function.rs
@@ -0,0 +1,674 @@
+use crate::{
+ arena::{Arena, Handle},
+ front::spv::{BlockContext, BodyIndex},
+};
+
+use super::{Error, Instruction, LookupExpression, LookupHelper as _};
+use crate::proc::Emitter;
+
+pub type BlockId = u32;
+
+#[derive(Copy, Clone, Debug)]
+pub struct MergeInstruction {
+ pub merge_block_id: BlockId,
+ pub continue_block_id: Option<BlockId>,
+}
+
+impl<I: Iterator<Item = u32>> super::Frontend<I> {
+ // Registers a function call. It will generate a dummy handle to call, which
+ // gets resolved after all the functions are processed.
+ pub(super) fn add_call(
+ &mut self,
+ from: spirv::Word,
+ to: spirv::Word,
+ ) -> Handle<crate::Function> {
+ let dummy_handle = self
+ .dummy_functions
+ .append(crate::Function::default(), Default::default());
+ self.deferred_function_calls.push(to);
+ self.function_call_graph.add_edge(from, to, ());
+ dummy_handle
+ }
+
+ pub(super) fn parse_function(&mut self, module: &mut crate::Module) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.lookup_expression.clear();
+ self.lookup_load_override.clear();
+ self.lookup_sampled_image.clear();
+
+ let result_type_id = self.next()?;
+ let fun_id = self.next()?;
+ let _fun_control = self.next()?;
+ let fun_type_id = self.next()?;
+
+ let mut fun = {
+ let ft = self.lookup_function_type.lookup(fun_type_id)?;
+ if ft.return_type_id != result_type_id {
+ return Err(Error::WrongFunctionResultType(result_type_id));
+ }
+ crate::Function {
+ name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name),
+ arguments: Vec::with_capacity(ft.parameter_type_ids.len()),
+ result: if self.lookup_void_type == Some(result_type_id) {
+ None
+ } else {
+ let lookup_result_ty = self.lookup_type.lookup(result_type_id)?;
+ Some(crate::FunctionResult {
+ ty: lookup_result_ty.handle,
+ binding: None,
+ })
+ },
+ local_variables: Arena::new(),
+ expressions: self
+ .make_expression_storage(&module.global_variables, &module.constants),
+ named_expressions: crate::NamedExpressions::default(),
+ body: crate::Block::new(),
+ }
+ };
+
+ // read parameters
+ for i in 0..fun.arguments.capacity() {
+ let start = self.data_offset;
+ match self.next_inst()? {
+ Instruction {
+ op: spirv::Op::FunctionParameter,
+ wc: 3,
+ } => {
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let handle = fun.expressions.append(
+ crate::Expression::FunctionArgument(i as u32),
+ self.span_from(start),
+ );
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id,
+ // Setting this to an invalid id will cause get_expr_handle
+ // to default to the main body making sure no load/stores
+ // are added.
+ block_id: 0,
+ },
+ );
+ //Note: we redo the lookup in order to work around `self` borrowing
+
+ if type_id
+ != self
+ .lookup_function_type
+ .lookup(fun_type_id)?
+ .parameter_type_ids[i]
+ {
+ return Err(Error::WrongFunctionArgumentType(type_id));
+ }
+ let ty = self.lookup_type.lookup(type_id)?.handle;
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ fun.arguments.push(crate::FunctionArgument {
+ name: decor.name,
+ ty,
+ binding: None,
+ });
+ }
+ Instruction { op, .. } => return Err(Error::InvalidParameter(op)),
+ }
+ }
+
+ // Read body
+ self.function_call_graph.add_node(fun_id);
+ let mut parameters_sampling =
+ vec![super::image::SamplingFlags::empty(); fun.arguments.len()];
+
+ let mut block_ctx = BlockContext {
+ phis: Default::default(),
+ blocks: Default::default(),
+ body_for_label: Default::default(),
+ mergers: Default::default(),
+ bodies: Default::default(),
+ function_id: fun_id,
+ expressions: &mut fun.expressions,
+ local_arena: &mut fun.local_variables,
+ const_arena: &mut module.constants,
+ const_expressions: &mut module.const_expressions,
+ type_arena: &module.types,
+ global_arena: &module.global_variables,
+ arguments: &fun.arguments,
+ parameter_sampling: &mut parameters_sampling,
+ };
+ // Insert the main body whose parent is also himself
+ block_ctx.bodies.push(super::Body::with_parent(0));
+
+ // Scan the blocks and add them as nodes
+ loop {
+ let fun_inst = self.next_inst()?;
+ log::debug!("{:?}", fun_inst.op);
+ match fun_inst.op {
+ spirv::Op::Line => {
+ fun_inst.expect(4)?;
+ let _file_id = self.next()?;
+ let _row_id = self.next()?;
+ let _col_id = self.next()?;
+ }
+ spirv::Op::Label => {
+ // Read the label ID
+ fun_inst.expect(2)?;
+ let block_id = self.next()?;
+
+ self.next_block(block_id, &mut block_ctx)?;
+ }
+ spirv::Op::FunctionEnd => {
+ fun_inst.expect(1)?;
+ break;
+ }
+ _ => {
+ return Err(Error::UnsupportedInstruction(self.state, fun_inst.op));
+ }
+ }
+ }
+
+ if let Some(ref prefix) = self.options.block_ctx_dump_prefix {
+ let dump_suffix = match self.lookup_entry_point.get(&fun_id) {
+ Some(ep) => format!("block_ctx.{:?}-{}.txt", ep.stage, ep.name),
+ None => format!("block_ctx.Fun-{}.txt", module.functions.len()),
+ };
+ let dest = prefix.join(dump_suffix);
+ let dump = format!("{block_ctx:#?}");
+ if let Err(e) = std::fs::write(&dest, dump) {
+ log::error!("Unable to dump the block context into {:?}: {}", dest, e);
+ }
+ }
+
+ // Emit `Store` statements to properly initialize all the local variables we
+ // created for `phi` expressions.
+ //
+ // Note that get_expr_handle also contributes slightly odd entries to this table,
+ // to get the spill.
+ for phi in block_ctx.phis.iter() {
+ // Get a pointer to the local variable for the phi's value.
+ let phi_pointer = block_ctx.expressions.append(
+ crate::Expression::LocalVariable(phi.local),
+ crate::Span::default(),
+ );
+
+ // At the end of each of `phi`'s predecessor blocks, store the corresponding
+ // source value in the phi's local variable.
+ for &(source, predecessor) in phi.expressions.iter() {
+ let source_lexp = &self.lookup_expression[&source];
+ let predecessor_body_idx = block_ctx.body_for_label[&predecessor];
+ // If the expression is a global/argument it will have a 0 block
+ // id so we must use a default value instead of panicking
+ let source_body_idx = block_ctx
+ .body_for_label
+ .get(&source_lexp.block_id)
+ .copied()
+ .unwrap_or(0);
+
+ // If the Naga `Expression` generated for `source` is in scope, then we
+ // can simply store that in the phi's local variable.
+ //
+ // Otherwise, spill the source value to a local variable in the block that
+ // defines it. (We know this store dominates the predecessor; otherwise,
+ // the phi wouldn't have been able to refer to that source expression in
+ // the first place.) Then, the predecessor block can count on finding the
+ // source's value in that local variable.
+ let value = if super::is_parent(predecessor_body_idx, source_body_idx, &block_ctx) {
+ source_lexp.handle
+ } else {
+ // The source SPIR-V expression is not defined in the phi's
+ // predecessor block, nor is it a globally available expression. So it
+ // must be defined off in some other block that merely dominates the
+ // predecessor. This means that the corresponding Naga `Expression`
+ // may not be in scope in the predecessor block.
+ //
+ // In the block that defines `source`, spill it to a fresh local
+ // variable, to ensure we can still use it at the end of the
+ // predecessor.
+ let ty = self.lookup_type[&source_lexp.type_id].handle;
+ let local = block_ctx.local_arena.append(
+ crate::LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ crate::Span::default(),
+ );
+
+ let pointer = block_ctx.expressions.append(
+ crate::Expression::LocalVariable(local),
+ crate::Span::default(),
+ );
+
+ // Get the spilled value of the source expression.
+ let start = block_ctx.expressions.len();
+ let expr = block_ctx
+ .expressions
+ .append(crate::Expression::Load { pointer }, crate::Span::default());
+ let range = block_ctx.expressions.range_from(start);
+
+ block_ctx
+ .blocks
+ .get_mut(&predecessor)
+ .unwrap()
+ .push(crate::Statement::Emit(range), crate::Span::default());
+
+ // At the end of the block that defines it, spill the source
+ // expression's value.
+ block_ctx
+ .blocks
+ .get_mut(&source_lexp.block_id)
+ .unwrap()
+ .push(
+ crate::Statement::Store {
+ pointer,
+ value: source_lexp.handle,
+ },
+ crate::Span::default(),
+ );
+
+ expr
+ };
+
+ // At the end of the phi predecessor block, store the source
+ // value in the phi's value.
+ block_ctx.blocks.get_mut(&predecessor).unwrap().push(
+ crate::Statement::Store {
+ pointer: phi_pointer,
+ value,
+ },
+ crate::Span::default(),
+ )
+ }
+ }
+
+ fun.body = block_ctx.lower();
+
+ // done
+ let fun_handle = module.functions.append(fun, self.span_from_with_op(start));
+ self.lookup_function.insert(
+ fun_id,
+ super::LookupFunction {
+ handle: fun_handle,
+ parameters_sampling,
+ },
+ );
+
+ if let Some(ep) = self.lookup_entry_point.remove(&fun_id) {
+ // create a wrapping function
+ let mut function = crate::Function {
+ name: Some(format!("{}_wrap", ep.name)),
+ arguments: Vec::new(),
+ result: None,
+ local_variables: Arena::new(),
+ expressions: Arena::new(),
+ named_expressions: crate::NamedExpressions::default(),
+ body: crate::Block::new(),
+ };
+
+ // 1. copy the inputs from arguments to privates
+ for &v_id in ep.variable_ids.iter() {
+ let lvar = self.lookup_variable.lookup(v_id)?;
+ if let super::Variable::Input(ref arg) = lvar.inner {
+ let span = module.global_variables.get_span(lvar.handle);
+ let arg_expr = function.expressions.append(
+ crate::Expression::FunctionArgument(function.arguments.len() as u32),
+ span,
+ );
+ let load_expr = if arg.ty == module.global_variables[lvar.handle].ty {
+ arg_expr
+ } else {
+ // The only case where the type is different is if we need to treat
+ // unsigned integer as signed.
+ let mut emitter = Emitter::default();
+ emitter.start(&function.expressions);
+ let handle = function.expressions.append(
+ crate::Expression::As {
+ expr: arg_expr,
+ kind: crate::ScalarKind::Sint,
+ convert: Some(4),
+ },
+ span,
+ );
+ function.body.extend(emitter.finish(&function.expressions));
+ handle
+ };
+ function.body.push(
+ crate::Statement::Store {
+ pointer: function
+ .expressions
+ .append(crate::Expression::GlobalVariable(lvar.handle), span),
+ value: load_expr,
+ },
+ span,
+ );
+
+ let mut arg = arg.clone();
+ if ep.stage == crate::ShaderStage::Fragment {
+ if let Some(ref mut binding) = arg.binding {
+ binding.apply_default_interpolation(&module.types[arg.ty].inner);
+ }
+ }
+ function.arguments.push(arg);
+ }
+ }
+ // 2. call the wrapped function
+ let fake_id = !(module.entry_points.len() as u32); // doesn't matter, as long as it's not a collision
+ let dummy_handle = self.add_call(fake_id, fun_id);
+ function.body.push(
+ crate::Statement::Call {
+ function: dummy_handle,
+ arguments: Vec::new(),
+ result: None,
+ },
+ crate::Span::default(),
+ );
+
+ // 3. copy the outputs from privates to the result
+ let mut members = Vec::new();
+ let mut components = Vec::new();
+ for &v_id in ep.variable_ids.iter() {
+ let lvar = self.lookup_variable.lookup(v_id)?;
+ if let super::Variable::Output(ref result) = lvar.inner {
+ let span = module.global_variables.get_span(lvar.handle);
+ let expr_handle = function
+ .expressions
+ .append(crate::Expression::GlobalVariable(lvar.handle), span);
+
+ // Cull problematic builtins of gl_PerVertex.
+ // See the docs for `Frontend::gl_per_vertex_builtin_access`.
+ {
+ let ty = &module.types[result.ty];
+ match ty.inner {
+ crate::TypeInner::Struct {
+ members: ref original_members,
+ span,
+ } if ty.name.as_deref() == Some("gl_PerVertex") => {
+ let mut new_members = original_members.clone();
+ for member in &mut new_members {
+ if let Some(crate::Binding::BuiltIn(built_in)) = member.binding
+ {
+ if !self.gl_per_vertex_builtin_access.contains(&built_in) {
+ member.binding = None
+ }
+ }
+ }
+ if &new_members != original_members {
+ module.types.replace(
+ result.ty,
+ crate::Type {
+ name: ty.name.clone(),
+ inner: crate::TypeInner::Struct {
+ members: new_members,
+ span,
+ },
+ },
+ );
+ }
+ }
+ _ => {}
+ }
+ }
+
+ match module.types[result.ty].inner {
+ crate::TypeInner::Struct {
+ members: ref sub_members,
+ ..
+ } => {
+ for (index, sm) in sub_members.iter().enumerate() {
+ if sm.binding.is_none() {
+ continue;
+ }
+ let mut sm = sm.clone();
+
+ if let Some(ref mut binding) = sm.binding {
+ if ep.stage == crate::ShaderStage::Vertex {
+ binding.apply_default_interpolation(
+ &module.types[sm.ty].inner,
+ );
+ }
+ }
+
+ members.push(sm);
+
+ components.push(function.expressions.append(
+ crate::Expression::AccessIndex {
+ base: expr_handle,
+ index: index as u32,
+ },
+ span,
+ ));
+ }
+ }
+ ref inner => {
+ let mut binding = result.binding.clone();
+ if let Some(ref mut binding) = binding {
+ if ep.stage == crate::ShaderStage::Vertex {
+ binding.apply_default_interpolation(inner);
+ }
+ }
+
+ members.push(crate::StructMember {
+ name: None,
+ ty: result.ty,
+ binding,
+ offset: 0,
+ });
+ // populate just the globals first, then do `Load` in a
+ // separate step, so that we can get a range.
+ components.push(expr_handle);
+ }
+ }
+ }
+ }
+
+ for (member_index, member) in members.iter().enumerate() {
+ match member.binding {
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::Position { .. }))
+ if self.options.adjust_coordinate_space =>
+ {
+ let mut emitter = Emitter::default();
+ emitter.start(&function.expressions);
+ let global_expr = components[member_index];
+ let span = function.expressions.get_span(global_expr);
+ let access_expr = function.expressions.append(
+ crate::Expression::AccessIndex {
+ base: global_expr,
+ index: 1,
+ },
+ span,
+ );
+ let load_expr = function.expressions.append(
+ crate::Expression::Load {
+ pointer: access_expr,
+ },
+ span,
+ );
+ let neg_expr = function.expressions.append(
+ crate::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr: load_expr,
+ },
+ span,
+ );
+ function.body.extend(emitter.finish(&function.expressions));
+ function.body.push(
+ crate::Statement::Store {
+ pointer: access_expr,
+ value: neg_expr,
+ },
+ span,
+ );
+ }
+ _ => {}
+ }
+ }
+
+ let mut emitter = Emitter::default();
+ emitter.start(&function.expressions);
+ for component in components.iter_mut() {
+ let load_expr = crate::Expression::Load {
+ pointer: *component,
+ };
+ let span = function.expressions.get_span(*component);
+ *component = function.expressions.append(load_expr, span);
+ }
+
+ match members[..] {
+ [] => {}
+ [ref member] => {
+ function.body.extend(emitter.finish(&function.expressions));
+ let span = function.expressions.get_span(components[0]);
+ function.body.push(
+ crate::Statement::Return {
+ value: components.first().cloned(),
+ },
+ span,
+ );
+ function.result = Some(crate::FunctionResult {
+ ty: member.ty,
+ binding: member.binding.clone(),
+ });
+ }
+ _ => {
+ let span = crate::Span::total_span(
+ components.iter().map(|h| function.expressions.get_span(*h)),
+ );
+ let ty = module.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Struct {
+ members,
+ span: 0xFFFF, // shouldn't matter
+ },
+ },
+ span,
+ );
+ let result_expr = function
+ .expressions
+ .append(crate::Expression::Compose { ty, components }, span);
+ function.body.extend(emitter.finish(&function.expressions));
+ function.body.push(
+ crate::Statement::Return {
+ value: Some(result_expr),
+ },
+ span,
+ );
+ function.result = Some(crate::FunctionResult { ty, binding: None });
+ }
+ }
+
+ module.entry_points.push(crate::EntryPoint {
+ name: ep.name,
+ stage: ep.stage,
+ early_depth_test: ep.early_depth_test,
+ workgroup_size: ep.workgroup_size,
+ function,
+ });
+ }
+
+ Ok(())
+ }
+}
+
+impl<'function> BlockContext<'function> {
+ pub(super) fn gctx(&self) -> crate::proc::GlobalCtx {
+ crate::proc::GlobalCtx {
+ types: self.type_arena,
+ constants: self.const_arena,
+ const_expressions: self.const_expressions,
+ }
+ }
+
+ /// Consumes the `BlockContext` producing a Ir [`Block`](crate::Block)
+ fn lower(mut self) -> crate::Block {
+ fn lower_impl(
+ blocks: &mut crate::FastHashMap<spirv::Word, crate::Block>,
+ bodies: &[super::Body],
+ body_idx: BodyIndex,
+ ) -> crate::Block {
+ let mut block = crate::Block::new();
+
+ for item in bodies[body_idx].data.iter() {
+ match *item {
+ super::BodyFragment::BlockId(id) => block.append(blocks.get_mut(&id).unwrap()),
+ super::BodyFragment::If {
+ condition,
+ accept,
+ reject,
+ } => {
+ let accept = lower_impl(blocks, bodies, accept);
+ let reject = lower_impl(blocks, bodies, reject);
+
+ block.push(
+ crate::Statement::If {
+ condition,
+ accept,
+ reject,
+ },
+ crate::Span::default(),
+ )
+ }
+ super::BodyFragment::Loop {
+ body,
+ continuing,
+ break_if,
+ } => {
+ let body = lower_impl(blocks, bodies, body);
+ let continuing = lower_impl(blocks, bodies, continuing);
+
+ block.push(
+ crate::Statement::Loop {
+ body,
+ continuing,
+ break_if,
+ },
+ crate::Span::default(),
+ )
+ }
+ super::BodyFragment::Switch {
+ selector,
+ ref cases,
+ default,
+ } => {
+ let mut ir_cases: Vec<_> = cases
+ .iter()
+ .map(|&(value, body_idx)| {
+ let body = lower_impl(blocks, bodies, body_idx);
+
+ // Handle simple cases that would make a fallthrough statement unreachable code
+ let fall_through = body.last().map_or(true, |s| !s.is_terminator());
+
+ crate::SwitchCase {
+ value: crate::SwitchValue::I32(value),
+ body,
+ fall_through,
+ }
+ })
+ .collect();
+ ir_cases.push(crate::SwitchCase {
+ value: crate::SwitchValue::Default,
+ body: lower_impl(blocks, bodies, default),
+ fall_through: false,
+ });
+
+ block.push(
+ crate::Statement::Switch {
+ selector,
+ cases: ir_cases,
+ },
+ crate::Span::default(),
+ )
+ }
+ super::BodyFragment::Break => {
+ block.push(crate::Statement::Break, crate::Span::default())
+ }
+ super::BodyFragment::Continue => {
+ block.push(crate::Statement::Continue, crate::Span::default())
+ }
+ }
+ }
+
+ block
+ }
+
+ lower_impl(&mut self.blocks, &self.bodies, 0)
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/image.rs b/third_party/rust/naga/src/front/spv/image.rs
new file mode 100644
index 0000000000..0f25dd626b
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/image.rs
@@ -0,0 +1,767 @@
+use crate::{
+ arena::{Handle, UniqueArena},
+ Scalar,
+};
+
+use super::{Error, LookupExpression, LookupHelper as _};
+
+#[derive(Clone, Debug)]
+pub(super) struct LookupSampledImage {
+ image: Handle<crate::Expression>,
+ sampler: Handle<crate::Expression>,
+}
+
+bitflags::bitflags! {
+ /// Flags describing sampling method.
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct SamplingFlags: u32 {
+ /// Regular sampling.
+ const REGULAR = 0x1;
+ /// Comparison sampling.
+ const COMPARISON = 0x2;
+ }
+}
+
+impl<'function> super::BlockContext<'function> {
+ fn get_image_expr_ty(
+ &self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Type>, Error> {
+ match self.expressions[handle] {
+ crate::Expression::GlobalVariable(handle) => Ok(self.global_arena[handle].ty),
+ crate::Expression::FunctionArgument(i) => Ok(self.arguments[i as usize].ty),
+ ref other => Err(Error::InvalidImageExpression(other.clone())),
+ }
+ }
+}
+
+/// Options of a sampling operation.
+#[derive(Debug)]
+pub struct SamplingOptions {
+ /// Projection sampling: the division by W is expected to happen
+ /// in the texture unit.
+ pub project: bool,
+ /// Depth comparison sampling with a reference value.
+ pub compare: bool,
+}
+
+enum ExtraCoordinate {
+ ArrayLayer,
+ Projection,
+ Garbage,
+}
+
+/// Return the texture coordinates separated from the array layer,
+/// and/or divided by the projection term.
+///
+/// The Proj sampling ops expect an extra coordinate for the W.
+/// The arrayed (can't be Proj!) images expect an extra coordinate for the layer.
+fn extract_image_coordinates(
+ image_dim: crate::ImageDimension,
+ extra_coordinate: ExtraCoordinate,
+ base: Handle<crate::Expression>,
+ coordinate_ty: Handle<crate::Type>,
+ ctx: &mut super::BlockContext,
+) -> (Handle<crate::Expression>, Option<Handle<crate::Expression>>) {
+ let (given_size, kind) = match ctx.type_arena[coordinate_ty].inner {
+ crate::TypeInner::Scalar(Scalar { kind, .. }) => (None, kind),
+ crate::TypeInner::Vector {
+ size,
+ scalar: Scalar { kind, .. },
+ } => (Some(size), kind),
+ ref other => unreachable!("Unexpected texture coordinate {:?}", other),
+ };
+
+ let required_size = image_dim.required_coordinate_size();
+ let required_ty = required_size.map(|size| {
+ ctx.type_arena
+ .get(&crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size,
+ scalar: Scalar { kind, width: 4 },
+ },
+ })
+ .expect("Required coordinate type should have been set up by `parse_type_image`!")
+ });
+ let extra_expr = crate::Expression::AccessIndex {
+ base,
+ index: required_size.map_or(1, |size| size as u32),
+ };
+
+ let base_span = ctx.expressions.get_span(base);
+
+ match extra_coordinate {
+ ExtraCoordinate::ArrayLayer => {
+ let extracted = match required_size {
+ None => ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index: 0 }, base_span),
+ Some(size) => {
+ let mut components = Vec::with_capacity(size as usize);
+ for index in 0..size as u32 {
+ let comp = ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index }, base_span);
+ components.push(comp);
+ }
+ ctx.expressions.append(
+ crate::Expression::Compose {
+ ty: required_ty.unwrap(),
+ components,
+ },
+ base_span,
+ )
+ }
+ };
+ let array_index_f32 = ctx.expressions.append(extra_expr, base_span);
+ let array_index = ctx.expressions.append(
+ crate::Expression::As {
+ kind: crate::ScalarKind::Sint,
+ expr: array_index_f32,
+ convert: Some(4),
+ },
+ base_span,
+ );
+ (extracted, Some(array_index))
+ }
+ ExtraCoordinate::Projection => {
+ let projection = ctx.expressions.append(extra_expr, base_span);
+ let divided = match required_size {
+ None => {
+ let temp = ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index: 0 }, base_span);
+ ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: temp,
+ right: projection,
+ },
+ base_span,
+ )
+ }
+ Some(size) => {
+ let mut components = Vec::with_capacity(size as usize);
+ for index in 0..size as u32 {
+ let temp = ctx
+ .expressions
+ .append(crate::Expression::AccessIndex { base, index }, base_span);
+ let comp = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: temp,
+ right: projection,
+ },
+ base_span,
+ );
+ components.push(comp);
+ }
+ ctx.expressions.append(
+ crate::Expression::Compose {
+ ty: required_ty.unwrap(),
+ components,
+ },
+ base_span,
+ )
+ }
+ };
+ (divided, None)
+ }
+ ExtraCoordinate::Garbage if given_size == required_size => (base, None),
+ ExtraCoordinate::Garbage => {
+ use crate::SwizzleComponent as Sc;
+ let cut_expr = match required_size {
+ None => crate::Expression::AccessIndex { base, index: 0 },
+ Some(size) => crate::Expression::Swizzle {
+ size,
+ vector: base,
+ pattern: [Sc::X, Sc::Y, Sc::Z, Sc::W],
+ },
+ };
+ (ctx.expressions.append(cut_expr, base_span), None)
+ }
+ }
+}
+
+pub(super) fn patch_comparison_type(
+ flags: SamplingFlags,
+ var: &mut crate::GlobalVariable,
+ arena: &mut UniqueArena<crate::Type>,
+) -> bool {
+ if !flags.contains(SamplingFlags::COMPARISON) {
+ return true;
+ }
+ if flags == SamplingFlags::all() {
+ return false;
+ }
+
+ log::debug!("Flipping comparison for {:?}", var);
+ let original_ty = &arena[var.ty];
+ let original_ty_span = arena.get_span(var.ty);
+ let ty_inner = match original_ty.inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Sampled { multi, .. },
+ dim,
+ arrayed,
+ } => crate::TypeInner::Image {
+ class: crate::ImageClass::Depth { multi },
+ dim,
+ arrayed,
+ },
+ crate::TypeInner::Sampler { .. } => crate::TypeInner::Sampler { comparison: true },
+ ref other => unreachable!("Unexpected type for comparison mutation: {:?}", other),
+ };
+
+ let name = original_ty.name.clone();
+ var.ty = arena.insert(
+ crate::Type {
+ name,
+ inner: ty_inner,
+ },
+ original_ty_span,
+ );
+ true
+}
+
+impl<I: Iterator<Item = u32>> super::Frontend<I> {
+ pub(super) fn parse_image_couple(&mut self) -> Result<(), Error> {
+ let _result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+ let sampler_id = self.next()?;
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+ let sampler_lexp = self.lookup_expression.lookup(sampler_id)?;
+ self.lookup_sampled_image.insert(
+ result_id,
+ LookupSampledImage {
+ image: image_lexp.handle,
+ sampler: sampler_lexp.handle,
+ },
+ );
+ Ok(())
+ }
+
+ pub(super) fn parse_image_uncouple(&mut self, block_id: spirv::Word) -> Result<(), Error> {
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let sampled_image_id = self.next()?;
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: self.lookup_sampled_image.lookup(sampled_image_id)?.image,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ pub(super) fn parse_image_write(
+ &mut self,
+ words_left: u16,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ body_idx: usize,
+ ) -> Result<crate::Statement, Error> {
+ let image_id = self.next()?;
+ let coordinate_id = self.next()?;
+ let value_id = self.next()?;
+
+ let image_ops = if words_left != 0 { self.next()? } else { 0 };
+
+ if image_ops != 0 {
+ let other = spirv::ImageOperands::from_bits_truncate(image_ops);
+ log::warn!("Unknown image write ops {:?}", other);
+ for _ in 1..words_left {
+ self.next()?;
+ }
+ }
+
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+ let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?;
+
+ let coord_lexp = self.lookup_expression.lookup(coordinate_id)?;
+ let coord_handle =
+ self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx);
+ let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle;
+ let (coordinate, array_index) = match ctx.type_arena[image_ty].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: _,
+ } => extract_image_coordinates(
+ dim,
+ if arrayed {
+ ExtraCoordinate::ArrayLayer
+ } else {
+ ExtraCoordinate::Garbage
+ },
+ coord_handle,
+ coord_type_handle,
+ ctx,
+ ),
+ _ => return Err(Error::InvalidImage(image_ty)),
+ };
+
+ let value_lexp = self.lookup_expression.lookup(value_id)?;
+ let value = self.get_expr_handle(value_id, value_lexp, ctx, emitter, block, body_idx);
+
+ Ok(crate::Statement::ImageStore {
+ image: image_lexp.handle,
+ coordinate,
+ array_index,
+ value,
+ })
+ }
+
+ pub(super) fn parse_image_load(
+ &mut self,
+ mut words_left: u16,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+ let coordinate_id = self.next()?;
+
+ let mut image_ops = if words_left != 0 {
+ words_left -= 1;
+ self.next()?
+ } else {
+ 0
+ };
+
+ let mut sample = None;
+ let mut level = None;
+ while image_ops != 0 {
+ let bit = 1 << image_ops.trailing_zeros();
+ match spirv::ImageOperands::from_bits_truncate(bit) {
+ spirv::ImageOperands::LOD => {
+ let lod_expr = self.next()?;
+ let lod_lexp = self.lookup_expression.lookup(lod_expr)?;
+ let lod_handle =
+ self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx);
+ level = Some(lod_handle);
+ words_left -= 1;
+ }
+ spirv::ImageOperands::SAMPLE => {
+ let sample_expr = self.next()?;
+ let sample_handle = self.lookup_expression.lookup(sample_expr)?.handle;
+ sample = Some(sample_handle);
+ words_left -= 1;
+ }
+ other => {
+ log::warn!("Unknown image load op {:?}", other);
+ for _ in 0..words_left {
+ self.next()?;
+ }
+ break;
+ }
+ }
+ image_ops ^= bit;
+ }
+
+ // No need to call get_expr_handle here since only globals/arguments are
+ // allowed as images and they are always in the root scope
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+ let image_ty = ctx.get_image_expr_ty(image_lexp.handle)?;
+
+ let coord_lexp = self.lookup_expression.lookup(coordinate_id)?;
+ let coord_handle =
+ self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx);
+ let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle;
+ let (coordinate, array_index) = match ctx.type_arena[image_ty].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: _,
+ } => extract_image_coordinates(
+ dim,
+ if arrayed {
+ ExtraCoordinate::ArrayLayer
+ } else {
+ ExtraCoordinate::Garbage
+ },
+ coord_handle,
+ coord_type_handle,
+ ctx,
+ ),
+ _ => return Err(Error::InvalidImage(image_ty)),
+ };
+
+ let expr = crate::Expression::ImageLoad {
+ image: image_lexp.handle,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn parse_image_sample(
+ &mut self,
+ mut words_left: u16,
+ options: SamplingOptions,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let sampled_image_id = self.next()?;
+ let coordinate_id = self.next()?;
+ let dref_id = if options.compare {
+ Some(self.next()?)
+ } else {
+ None
+ };
+
+ let mut image_ops = if words_left != 0 {
+ words_left -= 1;
+ self.next()?
+ } else {
+ 0
+ };
+
+ let mut level = crate::SampleLevel::Auto;
+ let mut offset = None;
+ while image_ops != 0 {
+ let bit = 1 << image_ops.trailing_zeros();
+ match spirv::ImageOperands::from_bits_truncate(bit) {
+ spirv::ImageOperands::BIAS => {
+ let bias_expr = self.next()?;
+ let bias_lexp = self.lookup_expression.lookup(bias_expr)?;
+ let bias_handle =
+ self.get_expr_handle(bias_expr, bias_lexp, ctx, emitter, block, body_idx);
+ level = crate::SampleLevel::Bias(bias_handle);
+ words_left -= 1;
+ }
+ spirv::ImageOperands::LOD => {
+ let lod_expr = self.next()?;
+ let lod_lexp = self.lookup_expression.lookup(lod_expr)?;
+ let lod_handle =
+ self.get_expr_handle(lod_expr, lod_lexp, ctx, emitter, block, body_idx);
+ level = if options.compare {
+ log::debug!("Assuming {:?} is zero", lod_handle);
+ crate::SampleLevel::Zero
+ } else {
+ crate::SampleLevel::Exact(lod_handle)
+ };
+ words_left -= 1;
+ }
+ spirv::ImageOperands::GRAD => {
+ let grad_x_expr = self.next()?;
+ let grad_x_lexp = self.lookup_expression.lookup(grad_x_expr)?;
+ let grad_x_handle = self.get_expr_handle(
+ grad_x_expr,
+ grad_x_lexp,
+ ctx,
+ emitter,
+ block,
+ body_idx,
+ );
+ let grad_y_expr = self.next()?;
+ let grad_y_lexp = self.lookup_expression.lookup(grad_y_expr)?;
+ let grad_y_handle = self.get_expr_handle(
+ grad_y_expr,
+ grad_y_lexp,
+ ctx,
+ emitter,
+ block,
+ body_idx,
+ );
+ level = if options.compare {
+ log::debug!(
+ "Assuming gradients {:?} and {:?} are not greater than 1",
+ grad_x_handle,
+ grad_y_handle
+ );
+ crate::SampleLevel::Zero
+ } else {
+ crate::SampleLevel::Gradient {
+ x: grad_x_handle,
+ y: grad_y_handle,
+ }
+ };
+ words_left -= 2;
+ }
+ spirv::ImageOperands::CONST_OFFSET => {
+ let offset_constant = self.next()?;
+ let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle;
+ let offset_handle = ctx.const_expressions.append(
+ crate::Expression::Constant(offset_handle),
+ Default::default(),
+ );
+ offset = Some(offset_handle);
+ words_left -= 1;
+ }
+ other => {
+ log::warn!("Unknown image sample operand {:?}", other);
+ for _ in 0..words_left {
+ self.next()?;
+ }
+ break;
+ }
+ }
+ image_ops ^= bit;
+ }
+
+ let si_lexp = self.lookup_sampled_image.lookup(sampled_image_id)?;
+ let coord_lexp = self.lookup_expression.lookup(coordinate_id)?;
+ let coord_handle =
+ self.get_expr_handle(coordinate_id, coord_lexp, ctx, emitter, block, body_idx);
+ let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle;
+
+ let sampling_bit = if options.compare {
+ SamplingFlags::COMPARISON
+ } else {
+ SamplingFlags::REGULAR
+ };
+
+ let image_ty = match ctx.expressions[si_lexp.image] {
+ crate::Expression::GlobalVariable(handle) => {
+ if let Some(flags) = self.handle_sampling.get_mut(&handle) {
+ *flags |= sampling_bit;
+ }
+
+ ctx.global_arena[handle].ty
+ }
+
+ crate::Expression::FunctionArgument(i) => {
+ ctx.parameter_sampling[i as usize] |= sampling_bit;
+ ctx.arguments[i as usize].ty
+ }
+
+ crate::Expression::Access { base, .. } => match ctx.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => {
+ if let Some(flags) = self.handle_sampling.get_mut(&handle) {
+ *flags |= sampling_bit;
+ }
+
+ match ctx.type_arena[ctx.global_arena[handle].ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => base,
+ _ => return Err(Error::InvalidGlobalVar(ctx.expressions[base].clone())),
+ }
+ }
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ },
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ };
+
+ match ctx.expressions[si_lexp.sampler] {
+ crate::Expression::GlobalVariable(handle) => {
+ *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit;
+ }
+
+ crate::Expression::FunctionArgument(i) => {
+ ctx.parameter_sampling[i as usize] |= sampling_bit;
+ }
+
+ crate::Expression::Access { base, .. } => match ctx.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => {
+ *self.handle_sampling.get_mut(&handle).unwrap() |= sampling_bit;
+ }
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ },
+
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ }
+
+ let ((coordinate, array_index), depth_ref) = match ctx.type_arena[image_ty].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: _,
+ } => (
+ extract_image_coordinates(
+ dim,
+ if options.project {
+ ExtraCoordinate::Projection
+ } else if arrayed {
+ ExtraCoordinate::ArrayLayer
+ } else {
+ ExtraCoordinate::Garbage
+ },
+ coord_handle,
+ coord_type_handle,
+ ctx,
+ ),
+ {
+ match dref_id {
+ Some(id) => {
+ let expr_lexp = self.lookup_expression.lookup(id)?;
+ let mut expr =
+ self.get_expr_handle(id, expr_lexp, ctx, emitter, block, body_idx);
+
+ if options.project {
+ let required_size = dim.required_coordinate_size();
+ let right = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: coord_handle,
+ index: required_size.map_or(1, |size| size as u32),
+ },
+ crate::Span::default(),
+ );
+ expr = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: expr,
+ right,
+ },
+ crate::Span::default(),
+ )
+ };
+ Some(expr)
+ }
+ None => None,
+ }
+ },
+ ),
+ _ => return Err(Error::InvalidImage(image_ty)),
+ };
+
+ let expr = crate::Expression::ImageSample {
+ image: si_lexp.image,
+ sampler: si_lexp.sampler,
+ gather: None, //TODO
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ pub(super) fn parse_image_query_size(
+ &mut self,
+ at_level: bool,
+ ctx: &mut super::BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+ let level = if at_level {
+ let level_id = self.next()?;
+ let level_lexp = self.lookup_expression.lookup(level_id)?;
+ Some(self.get_expr_handle(level_id, level_lexp, ctx, emitter, block, body_idx))
+ } else {
+ None
+ };
+
+ // No need to call get_expr_handle here since only globals/arguments are
+ // allowed as images and they are always in the root scope
+ //TODO: handle arrays and cubes
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+
+ let expr = crate::Expression::ImageQuery {
+ image: image_lexp.handle,
+ query: crate::ImageQuery::Size { level },
+ };
+
+ let result_type_handle = self.lookup_type.lookup(result_type_id)?.handle;
+ let maybe_scalar_kind = ctx.type_arena[result_type_handle].inner.scalar_kind();
+
+ let expr = if maybe_scalar_kind == Some(crate::ScalarKind::Sint) {
+ crate::Expression::As {
+ expr: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ kind: crate::ScalarKind::Sint,
+ convert: Some(4),
+ }
+ } else {
+ expr
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+
+ Ok(())
+ }
+
+ pub(super) fn parse_image_query_other(
+ &mut self,
+ query: crate::ImageQuery,
+ ctx: &mut super::BlockContext,
+ block_id: spirv::Word,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+
+ // No need to call get_expr_handle here since only globals/arguments are
+ // allowed as images and they are always in the root scope
+ let image_lexp = self.lookup_expression.lookup(image_id)?.clone();
+
+ let expr = crate::Expression::ImageQuery {
+ image: image_lexp.handle,
+ query,
+ };
+
+ let result_type_handle = self.lookup_type.lookup(result_type_id)?.handle;
+ let maybe_scalar_kind = ctx.type_arena[result_type_handle].inner.scalar_kind();
+
+ let expr = if maybe_scalar_kind == Some(crate::ScalarKind::Sint) {
+ crate::Expression::As {
+ expr: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ kind: crate::ScalarKind::Sint,
+ convert: Some(4),
+ }
+ } else {
+ expr
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs
new file mode 100644
index 0000000000..8b1c854358
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/mod.rs
@@ -0,0 +1,5356 @@
+/*!
+Frontend for [SPIR-V][spv] (Standard Portable Intermediate Representation).
+
+## ID lookups
+
+Our IR links to everything with `Handle`, while SPIR-V uses IDs.
+In order to keep track of the associations, the parser has many lookup tables.
+There map `spv::Word` into a specific IR handle, plus potentially a bit of
+extra info, such as the related SPIR-V type ID.
+TODO: would be nice to find ways that avoid looking up as much
+
+## Inputs/Outputs
+
+We create a private variable for each input/output. The relevant inputs are
+populated at the start of an entry point. The outputs are saved at the end.
+
+The function associated with an entry point is wrapped in another function,
+such that we can handle any `Return` statements without problems.
+
+## Row-major matrices
+
+We don't handle them natively, since the IR only expects column majority.
+Instead, we detect when such matrix is accessed in the `OpAccessChain`,
+and we generate a parallel expression that loads the value, but transposed.
+This value then gets used instead of `OpLoad` result later on.
+
+[spv]: https://www.khronos.org/registry/SPIR-V/
+*/
+
+mod convert;
+mod error;
+mod function;
+mod image;
+mod null;
+
+use convert::*;
+pub use error::Error;
+use function::*;
+
+use crate::{
+ arena::{Arena, Handle, UniqueArena},
+ proc::{Alignment, Layouter},
+ FastHashMap, FastHashSet, FastIndexMap,
+};
+
+use petgraph::graphmap::GraphMap;
+use std::{convert::TryInto, mem, num::NonZeroU32, path::PathBuf};
+
+pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
+ spirv::Capability::Shader,
+ spirv::Capability::VulkanMemoryModel,
+ spirv::Capability::ClipDistance,
+ spirv::Capability::CullDistance,
+ spirv::Capability::SampleRateShading,
+ spirv::Capability::DerivativeControl,
+ spirv::Capability::Matrix,
+ spirv::Capability::ImageQuery,
+ spirv::Capability::Sampled1D,
+ spirv::Capability::Image1D,
+ spirv::Capability::SampledCubeArray,
+ spirv::Capability::ImageCubeArray,
+ spirv::Capability::StorageImageExtendedFormats,
+ spirv::Capability::Int8,
+ spirv::Capability::Int16,
+ spirv::Capability::Int64,
+ spirv::Capability::Float16,
+ spirv::Capability::Float64,
+ spirv::Capability::Geometry,
+ spirv::Capability::MultiView,
+ // tricky ones
+ spirv::Capability::UniformBufferArrayDynamicIndexing,
+ spirv::Capability::StorageBufferArrayDynamicIndexing,
+];
+pub const SUPPORTED_EXTENSIONS: &[&str] = &[
+ "SPV_KHR_storage_buffer_storage_class",
+ "SPV_KHR_vulkan_memory_model",
+ "SPV_KHR_multiview",
+];
+pub const SUPPORTED_EXT_SETS: &[&str] = &["GLSL.std.450"];
+
+#[derive(Copy, Clone)]
+pub struct Instruction {
+ op: spirv::Op,
+ wc: u16,
+}
+
+impl Instruction {
+ const fn expect(self, count: u16) -> Result<(), Error> {
+ if self.wc == count {
+ Ok(())
+ } else {
+ Err(Error::InvalidOperandCount(self.op, self.wc))
+ }
+ }
+
+ fn expect_at_least(self, count: u16) -> Result<u16, Error> {
+ self.wc
+ .checked_sub(count)
+ .ok_or(Error::InvalidOperandCount(self.op, self.wc))
+ }
+}
+
+impl crate::TypeInner {
+ fn can_comparison_sample(&self, module: &crate::Module) -> bool {
+ match *self {
+ crate::TypeInner::Image {
+ class:
+ crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Float,
+ multi: false,
+ },
+ ..
+ } => true,
+ crate::TypeInner::Sampler { .. } => true,
+ crate::TypeInner::BindingArray { base, .. } => {
+ module.types[base].inner.can_comparison_sample(module)
+ }
+ _ => false,
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
+pub enum ModuleState {
+ Empty,
+ Capability,
+ Extension,
+ ExtInstImport,
+ MemoryModel,
+ EntryPoint,
+ ExecutionMode,
+ Source,
+ Name,
+ ModuleProcessed,
+ Annotation,
+ Type,
+ Function,
+}
+
+trait LookupHelper {
+ type Target;
+ fn lookup(&self, key: spirv::Word) -> Result<&Self::Target, Error>;
+}
+
+impl<T> LookupHelper for FastHashMap<spirv::Word, T> {
+ type Target = T;
+ fn lookup(&self, key: spirv::Word) -> Result<&T, Error> {
+ self.get(&key).ok_or(Error::InvalidId(key))
+ }
+}
+
+impl crate::ImageDimension {
+ const fn required_coordinate_size(&self) -> Option<crate::VectorSize> {
+ match *self {
+ crate::ImageDimension::D1 => None,
+ crate::ImageDimension::D2 => Some(crate::VectorSize::Bi),
+ crate::ImageDimension::D3 => Some(crate::VectorSize::Tri),
+ crate::ImageDimension::Cube => Some(crate::VectorSize::Tri),
+ }
+ }
+}
+
+type MemberIndex = u32;
+
+bitflags::bitflags! {
+ #[derive(Clone, Copy, Debug, Default)]
+ struct DecorationFlags: u32 {
+ const NON_READABLE = 0x1;
+ const NON_WRITABLE = 0x2;
+ }
+}
+
+impl DecorationFlags {
+ fn to_storage_access(self) -> crate::StorageAccess {
+ let mut access = crate::StorageAccess::all();
+ if self.contains(DecorationFlags::NON_READABLE) {
+ access &= !crate::StorageAccess::LOAD;
+ }
+ if self.contains(DecorationFlags::NON_WRITABLE) {
+ access &= !crate::StorageAccess::STORE;
+ }
+ access
+ }
+}
+
+#[derive(Debug, PartialEq)]
+enum Majority {
+ Column,
+ Row,
+}
+
+#[derive(Debug, Default)]
+struct Decoration {
+ name: Option<String>,
+ built_in: Option<spirv::Word>,
+ location: Option<spirv::Word>,
+ desc_set: Option<spirv::Word>,
+ desc_index: Option<spirv::Word>,
+ specialization: Option<spirv::Word>,
+ storage_buffer: bool,
+ offset: Option<spirv::Word>,
+ array_stride: Option<NonZeroU32>,
+ matrix_stride: Option<NonZeroU32>,
+ matrix_major: Option<Majority>,
+ invariant: bool,
+ interpolation: Option<crate::Interpolation>,
+ sampling: Option<crate::Sampling>,
+ flags: DecorationFlags,
+}
+
+impl Decoration {
+ fn debug_name(&self) -> &str {
+ match self.name {
+ Some(ref name) => name.as_str(),
+ None => "?",
+ }
+ }
+
+ fn specialization(&self) -> crate::Override {
+ self.specialization
+ .map_or(crate::Override::None, crate::Override::ByNameOrId)
+ }
+
+ const fn resource_binding(&self) -> Option<crate::ResourceBinding> {
+ match *self {
+ Decoration {
+ desc_set: Some(group),
+ desc_index: Some(binding),
+ ..
+ } => Some(crate::ResourceBinding { group, binding }),
+ _ => None,
+ }
+ }
+
+ fn io_binding(&self) -> Result<crate::Binding, Error> {
+ match *self {
+ Decoration {
+ built_in: Some(built_in),
+ location: None,
+ invariant,
+ ..
+ } => Ok(crate::Binding::BuiltIn(map_builtin(built_in, invariant)?)),
+ Decoration {
+ built_in: None,
+ location: Some(location),
+ interpolation,
+ sampling,
+ ..
+ } => Ok(crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source: false,
+ }),
+ _ => Err(Error::MissingDecoration(spirv::Decoration::Location)),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct LookupFunctionType {
+ parameter_type_ids: Vec<spirv::Word>,
+ return_type_id: spirv::Word,
+}
+
+struct LookupFunction {
+ handle: Handle<crate::Function>,
+ parameters_sampling: Vec<image::SamplingFlags>,
+}
+
+#[derive(Debug)]
+struct EntryPoint {
+ stage: crate::ShaderStage,
+ name: String,
+ early_depth_test: Option<crate::EarlyDepthTest>,
+ workgroup_size: [u32; 3],
+ variable_ids: Vec<spirv::Word>,
+}
+
+#[derive(Clone, Debug)]
+struct LookupType {
+ handle: Handle<crate::Type>,
+ base_id: Option<spirv::Word>,
+}
+
+#[derive(Debug)]
+struct LookupConstant {
+ handle: Handle<crate::Constant>,
+ type_id: spirv::Word,
+}
+
+#[derive(Debug)]
+enum Variable {
+ Global,
+ Input(crate::FunctionArgument),
+ Output(crate::FunctionResult),
+}
+
+#[derive(Debug)]
+struct LookupVariable {
+ inner: Variable,
+ handle: Handle<crate::GlobalVariable>,
+ type_id: spirv::Word,
+}
+
+/// Information about SPIR-V result ids, stored in `Parser::lookup_expression`.
+#[derive(Clone, Debug)]
+struct LookupExpression {
+ /// The `Expression` constructed for this result.
+ ///
+ /// Note that, while a SPIR-V result id can be used in any block dominated
+ /// by its definition, a Naga `Expression` is only in scope for the rest of
+ /// its subtree. `Parser::get_expr_handle` takes care of spilling the result
+ /// to a `LocalVariable` which can then be used anywhere.
+ handle: Handle<crate::Expression>,
+
+ /// The SPIR-V type of this result.
+ type_id: spirv::Word,
+
+ /// The label id of the block that defines this expression.
+ ///
+ /// This is zero for globals, constants, and function parameters, since they
+ /// originate outside any function's block.
+ block_id: spirv::Word,
+}
+
+#[derive(Debug)]
+struct LookupMember {
+ type_id: spirv::Word,
+ // This is true for either matrices, or arrays of matrices (yikes).
+ row_major: bool,
+}
+
+#[derive(Clone, Debug)]
+enum LookupLoadOverride {
+ /// For arrays of matrices, we track them but not loading yet.
+ Pending,
+ /// For matrices, vectors, and scalars, we pre-load the data.
+ Loaded(Handle<crate::Expression>),
+}
+
+#[derive(PartialEq)]
+enum ExtendedClass {
+ Global(crate::AddressSpace),
+ Input,
+ Output,
+}
+
+#[derive(Clone, Debug)]
+pub struct Options {
+ /// The IR coordinate space matches all the APIs except SPIR-V,
+ /// so by default we flip the Y coordinate of the `BuiltIn::Position`.
+ /// This flag can be used to avoid this.
+ pub adjust_coordinate_space: bool,
+ /// Only allow shaders with the known set of capabilities.
+ pub strict_capabilities: bool,
+ pub block_ctx_dump_prefix: Option<PathBuf>,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ adjust_coordinate_space: true,
+ strict_capabilities: false,
+ block_ctx_dump_prefix: None,
+ }
+ }
+}
+
+/// An index into the `BlockContext::bodies` table.
+type BodyIndex = usize;
+
+/// An intermediate representation of a Naga [`Statement`].
+///
+/// `Body` and `BodyFragment` values form a tree: the `BodyIndex` fields of the
+/// variants are indices of the child `Body` values in [`BlockContext::bodies`].
+/// The `lower` function assembles the final `Statement` tree from this `Body`
+/// tree. See [`BlockContext`] for details.
+///
+/// [`Statement`]: crate::Statement
+#[derive(Debug)]
+enum BodyFragment {
+ BlockId(spirv::Word),
+ If {
+ condition: Handle<crate::Expression>,
+ accept: BodyIndex,
+ reject: BodyIndex,
+ },
+ Loop {
+ /// The body of the loop. Its [`Body::parent`] is the block containing
+ /// this `Loop` fragment.
+ body: BodyIndex,
+
+ /// The loop's continuing block. This is a grandchild: its
+ /// [`Body::parent`] is the loop body block, whose index is above.
+ continuing: BodyIndex,
+
+ /// If the SPIR-V loop's back-edge branch is conditional, this is the
+ /// expression that must be `false` for the back-edge to be taken, with
+ /// `true` being for the "loop merge" (which breaks out of the loop).
+ break_if: Option<Handle<crate::Expression>>,
+ },
+ Switch {
+ selector: Handle<crate::Expression>,
+ cases: Vec<(i32, BodyIndex)>,
+ default: BodyIndex,
+ },
+ Break,
+ Continue,
+}
+
+/// An intermediate representation of a Naga [`Block`].
+///
+/// This will be assembled into a `Block` once we've added spills for phi nodes
+/// and out-of-scope expressions. See [`BlockContext`] for details.
+///
+/// [`Block`]: crate::Block
+#[derive(Debug)]
+struct Body {
+ /// The index of the direct parent of this body
+ parent: usize,
+ data: Vec<BodyFragment>,
+}
+
+impl Body {
+ /// Creates a new empty `Body` with the specified `parent`
+ pub const fn with_parent(parent: usize) -> Self {
+ Body {
+ parent,
+ data: Vec::new(),
+ }
+ }
+}
+
+#[derive(Debug)]
+struct PhiExpression {
+ /// The local variable used for the phi node
+ local: Handle<crate::LocalVariable>,
+ /// List of (expression, block)
+ expressions: Vec<(spirv::Word, spirv::Word)>,
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+enum MergeBlockInformation {
+ LoopMerge,
+ LoopContinue,
+ SelectionMerge,
+ SwitchMerge,
+}
+
+/// Fragments of Naga IR, to be assembled into `Statements` once data flow is
+/// resolved.
+///
+/// We can't build a Naga `Statement` tree directly from SPIR-V blocks for three
+/// main reasons:
+///
+/// - We parse a function's SPIR-V blocks in the order they appear in the file.
+/// Within a function, SPIR-V requires that a block must precede any blocks it
+/// structurally dominates, but doesn't say much else about the order in which
+/// they must appear. So while we know we'll see control flow header blocks
+/// before their child constructs and merge blocks, those children and the
+/// merge blocks may appear in any order - perhaps even intermingled with
+/// children of other constructs.
+///
+/// - A SPIR-V expression can be used in any SPIR-V block dominated by its
+/// definition, whereas Naga expressions are scoped to the rest of their
+/// subtree. This means that discovering an expression use later in the
+/// function retroactively requires us to have spilled that expression into a
+/// local variable back before we left its scope.
+///
+/// - We translate SPIR-V OpPhi expressions as Naga local variables in which we
+/// store the appropriate value before jumping to the OpPhi's block.
+///
+/// All these cases require us to go back and amend previously generated Naga IR
+/// based on things we discover later. But modifying old blocks in arbitrary
+/// spots in a `Statement` tree is awkward.
+///
+/// Instead, as we iterate through the function's body, we accumulate
+/// control-flow-free fragments of Naga IR in the [`blocks`] table, while
+/// building a skeleton of the Naga `Statement` tree in [`bodies`]. We note any
+/// spills and temporaries we must introduce in [`phis`].
+///
+/// Finally, once we've processed the entire function, we add temporaries and
+/// spills to the fragmentary `Blocks` as directed by `phis`, and assemble them
+/// into the final Naga `Statement` tree as directed by `bodies`.
+///
+/// [`blocks`]: BlockContext::blocks
+/// [`bodies`]: BlockContext::bodies
+/// [`phis`]: BlockContext::phis
+/// [`lower`]: function::lower
+#[derive(Debug)]
+struct BlockContext<'function> {
+ /// Phi nodes encountered when parsing the function, used to generate spills
+ /// to local variables.
+ phis: Vec<PhiExpression>,
+
+ /// Fragments of control-flow-free Naga IR.
+ ///
+ /// These will be stitched together into a proper [`Statement`] tree according
+ /// to `bodies`, once parsing is complete.
+ ///
+ /// [`Statement`]: crate::Statement
+ blocks: FastHashMap<spirv::Word, crate::Block>,
+
+ /// Map from each SPIR-V block's label id to the index of the [`Body`] in
+ /// [`bodies`] the block should append its contents to.
+ ///
+ /// Since each statement in a Naga [`Block`] dominates the next, we are sure
+ /// to encounter their SPIR-V blocks in order. Thus, by having this table
+ /// map a SPIR-V structured control flow construct's merge block to the same
+ /// body index as its header block, when we encounter the merge block, we
+ /// will simply pick up building the [`Body`] where the header left off.
+ ///
+ /// A function's first block is special: it is the only block we encounter
+ /// without having seen its label mentioned in advance. (It's simply the
+ /// first `OpLabel` after the `OpFunction`.) We thus assume that any block
+ /// missing an entry here must be the first block, which always has body
+ /// index zero.
+ ///
+ /// [`bodies`]: BlockContext::bodies
+ /// [`Block`]: crate::Block
+ body_for_label: FastHashMap<spirv::Word, BodyIndex>,
+
+ /// SPIR-V metadata about merge/continue blocks.
+ mergers: FastHashMap<spirv::Word, MergeBlockInformation>,
+
+ /// A table of `Body` values, each representing a block in the final IR.
+ ///
+ /// The first element is always the function's top-level block.
+ bodies: Vec<Body>,
+
+ /// Id of the function currently being processed
+ function_id: spirv::Word,
+ /// Expression arena of the function currently being processed
+ expressions: &'function mut Arena<crate::Expression>,
+ /// Local variables arena of the function currently being processed
+ local_arena: &'function mut Arena<crate::LocalVariable>,
+ /// Constants arena of the module being processed
+ const_arena: &'function mut Arena<crate::Constant>,
+ const_expressions: &'function mut Arena<crate::Expression>,
+ /// Type arena of the module being processed
+ type_arena: &'function UniqueArena<crate::Type>,
+ /// Global arena of the module being processed
+ global_arena: &'function Arena<crate::GlobalVariable>,
+ /// Arguments of the function currently being processed
+ arguments: &'function [crate::FunctionArgument],
+ /// Metadata about the usage of function parameters as sampling objects
+ parameter_sampling: &'function mut [image::SamplingFlags],
+}
+
+enum SignAnchor {
+ Result,
+ Operand,
+}
+
+pub struct Frontend<I> {
+ data: I,
+ data_offset: usize,
+ state: ModuleState,
+ layouter: Layouter,
+ temp_bytes: Vec<u8>,
+ ext_glsl_id: Option<spirv::Word>,
+ future_decor: FastHashMap<spirv::Word, Decoration>,
+ future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>,
+ lookup_member: FastHashMap<(Handle<crate::Type>, MemberIndex), LookupMember>,
+ handle_sampling: FastHashMap<Handle<crate::GlobalVariable>, image::SamplingFlags>,
+ lookup_type: FastHashMap<spirv::Word, LookupType>,
+ lookup_void_type: Option<spirv::Word>,
+ lookup_storage_buffer_types: FastHashMap<Handle<crate::Type>, crate::StorageAccess>,
+ // Lookup for samplers and sampled images, storing flags on how they are used.
+ lookup_constant: FastHashMap<spirv::Word, LookupConstant>,
+ lookup_variable: FastHashMap<spirv::Word, LookupVariable>,
+ lookup_expression: FastHashMap<spirv::Word, LookupExpression>,
+ // Load overrides are used to work around row-major matrices
+ lookup_load_override: FastHashMap<spirv::Word, LookupLoadOverride>,
+ lookup_sampled_image: FastHashMap<spirv::Word, image::LookupSampledImage>,
+ lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>,
+ lookup_function: FastHashMap<spirv::Word, LookupFunction>,
+ lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>,
+ //Note: each `OpFunctionCall` gets a single entry here, indexed by the
+ // dummy `Handle<crate::Function>` of the call site.
+ deferred_function_calls: Vec<spirv::Word>,
+ dummy_functions: Arena<crate::Function>,
+ // Graph of all function calls through the module.
+ // It's used to sort the functions (as nodes) topologically,
+ // so that in the IR any called function is already known.
+ function_call_graph: GraphMap<spirv::Word, (), petgraph::Directed>,
+ options: Options,
+
+ /// Maps for a switch from a case target to the respective body and associated literals that
+ /// use that target block id.
+ ///
+ /// Used to preserve allocations between instruction parsing.
+ switch_cases: FastIndexMap<spirv::Word, (BodyIndex, Vec<i32>)>,
+
+ /// Tracks access to gl_PerVertex's builtins, it is used to cull unused builtins since initializing those can
+ /// affect performance and the mere presence of some of these builtins might cause backends to error since they
+ /// might be unsupported.
+ ///
+ /// The problematic builtins are: PointSize, ClipDistance and CullDistance.
+ ///
+ /// glslang declares those by default even though they are never written to
+ /// (see <https://github.com/KhronosGroup/glslang/issues/1868>)
+ gl_per_vertex_builtin_access: FastHashSet<crate::BuiltIn>,
+}
+
+impl<I: Iterator<Item = u32>> Frontend<I> {
+ pub fn new(data: I, options: &Options) -> Self {
+ Frontend {
+ data,
+ data_offset: 0,
+ state: ModuleState::Empty,
+ layouter: Layouter::default(),
+ temp_bytes: Vec::new(),
+ ext_glsl_id: None,
+ future_decor: FastHashMap::default(),
+ future_member_decor: FastHashMap::default(),
+ handle_sampling: FastHashMap::default(),
+ lookup_member: FastHashMap::default(),
+ lookup_type: FastHashMap::default(),
+ lookup_void_type: None,
+ lookup_storage_buffer_types: FastHashMap::default(),
+ lookup_constant: FastHashMap::default(),
+ lookup_variable: FastHashMap::default(),
+ lookup_expression: FastHashMap::default(),
+ lookup_load_override: FastHashMap::default(),
+ lookup_sampled_image: FastHashMap::default(),
+ lookup_function_type: FastHashMap::default(),
+ lookup_function: FastHashMap::default(),
+ lookup_entry_point: FastHashMap::default(),
+ deferred_function_calls: Vec::default(),
+ dummy_functions: Arena::new(),
+ function_call_graph: GraphMap::new(),
+ options: options.clone(),
+ switch_cases: FastIndexMap::default(),
+ gl_per_vertex_builtin_access: FastHashSet::default(),
+ }
+ }
+
+ fn span_from(&self, from: usize) -> crate::Span {
+ crate::Span::from(from..self.data_offset)
+ }
+
+ fn span_from_with_op(&self, from: usize) -> crate::Span {
+ crate::Span::from((from - 4)..self.data_offset)
+ }
+
+ fn next(&mut self) -> Result<u32, Error> {
+ if let Some(res) = self.data.next() {
+ self.data_offset += 4;
+ Ok(res)
+ } else {
+ Err(Error::IncompleteData)
+ }
+ }
+
+ fn next_inst(&mut self) -> Result<Instruction, Error> {
+ let word = self.next()?;
+ let (wc, opcode) = ((word >> 16) as u16, (word & 0xffff) as u16);
+ if wc == 0 {
+ return Err(Error::InvalidWordCount);
+ }
+ let op = spirv::Op::from_u32(opcode as u32).ok_or(Error::UnknownInstruction(opcode))?;
+
+ Ok(Instruction { op, wc })
+ }
+
+ fn next_string(&mut self, mut count: u16) -> Result<(String, u16), Error> {
+ self.temp_bytes.clear();
+ loop {
+ if count == 0 {
+ return Err(Error::BadString);
+ }
+ count -= 1;
+ let chars = self.next()?.to_le_bytes();
+ let pos = chars.iter().position(|&c| c == 0).unwrap_or(4);
+ self.temp_bytes.extend_from_slice(&chars[..pos]);
+ if pos < 4 {
+ break;
+ }
+ }
+ std::str::from_utf8(&self.temp_bytes)
+ .map(|s| (s.to_owned(), count))
+ .map_err(|_| Error::BadString)
+ }
+
+ fn next_decoration(
+ &mut self,
+ inst: Instruction,
+ base_words: u16,
+ dec: &mut Decoration,
+ ) -> Result<(), Error> {
+ let raw = self.next()?;
+ let dec_typed = spirv::Decoration::from_u32(raw).ok_or(Error::InvalidDecoration(raw))?;
+ log::trace!("\t\t{}: {:?}", dec.debug_name(), dec_typed);
+ match dec_typed {
+ spirv::Decoration::BuiltIn => {
+ inst.expect(base_words + 2)?;
+ dec.built_in = Some(self.next()?);
+ }
+ spirv::Decoration::Location => {
+ inst.expect(base_words + 2)?;
+ dec.location = Some(self.next()?);
+ }
+ spirv::Decoration::DescriptorSet => {
+ inst.expect(base_words + 2)?;
+ dec.desc_set = Some(self.next()?);
+ }
+ spirv::Decoration::Binding => {
+ inst.expect(base_words + 2)?;
+ dec.desc_index = Some(self.next()?);
+ }
+ spirv::Decoration::BufferBlock => {
+ dec.storage_buffer = true;
+ }
+ spirv::Decoration::Offset => {
+ inst.expect(base_words + 2)?;
+ dec.offset = Some(self.next()?);
+ }
+ spirv::Decoration::ArrayStride => {
+ inst.expect(base_words + 2)?;
+ dec.array_stride = NonZeroU32::new(self.next()?);
+ }
+ spirv::Decoration::MatrixStride => {
+ inst.expect(base_words + 2)?;
+ dec.matrix_stride = NonZeroU32::new(self.next()?);
+ }
+ spirv::Decoration::Invariant => {
+ dec.invariant = true;
+ }
+ spirv::Decoration::NoPerspective => {
+ dec.interpolation = Some(crate::Interpolation::Linear);
+ }
+ spirv::Decoration::Flat => {
+ dec.interpolation = Some(crate::Interpolation::Flat);
+ }
+ spirv::Decoration::Centroid => {
+ dec.sampling = Some(crate::Sampling::Centroid);
+ }
+ spirv::Decoration::Sample => {
+ dec.sampling = Some(crate::Sampling::Sample);
+ }
+ spirv::Decoration::NonReadable => {
+ dec.flags |= DecorationFlags::NON_READABLE;
+ }
+ spirv::Decoration::NonWritable => {
+ dec.flags |= DecorationFlags::NON_WRITABLE;
+ }
+ spirv::Decoration::ColMajor => {
+ dec.matrix_major = Some(Majority::Column);
+ }
+ spirv::Decoration::RowMajor => {
+ dec.matrix_major = Some(Majority::Row);
+ }
+ spirv::Decoration::SpecId => {
+ dec.specialization = Some(self.next()?);
+ }
+ other => {
+ log::warn!("Unknown decoration {:?}", other);
+ for _ in base_words + 1..inst.wc {
+ let _var = self.next()?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ /// Return the Naga `Expression` for a given SPIR-V result `id`.
+ ///
+ /// `lookup` must be the `LookupExpression` for `id`.
+ ///
+ /// SPIR-V result ids can be used by any block dominated by the id's
+ /// definition, but Naga `Expressions` are only in scope for the remainder
+ /// of their `Statement` subtree. This means that the `Expression` generated
+ /// for `id` may no longer be in scope. In such cases, this function takes
+ /// care of spilling the value of `id` to a `LocalVariable` which can then
+ /// be used anywhere. The SPIR-V domination rule ensures that the
+ /// `LocalVariable` has been initialized before it is used.
+ ///
+ /// The `body_idx` argument should be the index of the `Body` that hopes to
+ /// use `id`'s `Expression`.
+ fn get_expr_handle(
+ &self,
+ id: spirv::Word,
+ lookup: &LookupExpression,
+ ctx: &mut BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ body_idx: BodyIndex,
+ ) -> Handle<crate::Expression> {
+ // What `Body` was `id` defined in?
+ let expr_body_idx = ctx
+ .body_for_label
+ .get(&lookup.block_id)
+ .copied()
+ .unwrap_or(0);
+
+ // Don't need to do a load/store if the expression is in the main body
+ // or if the expression is in the same body as where the query was
+ // requested. The body_idx might actually not be the final one if a loop
+ // or conditional occurs but in those cases we know that the new body
+ // will be a subscope of the body that was passed so we can still reuse
+ // the handle and not issue a load/store.
+ if is_parent(body_idx, expr_body_idx, ctx) {
+ lookup.handle
+ } else {
+ // Add a temporary variable of the same type which will be used to
+ // store the original expression and used in the current block
+ let ty = self.lookup_type[&lookup.type_id].handle;
+ let local = ctx.local_arena.append(
+ crate::LocalVariable {
+ name: None,
+ ty,
+ init: None,
+ },
+ crate::Span::default(),
+ );
+
+ block.extend(emitter.finish(ctx.expressions));
+ let pointer = ctx.expressions.append(
+ crate::Expression::LocalVariable(local),
+ crate::Span::default(),
+ );
+ emitter.start(ctx.expressions);
+ let expr = ctx
+ .expressions
+ .append(crate::Expression::Load { pointer }, crate::Span::default());
+
+ // Add a slightly odd entry to the phi table, so that while `id`'s
+ // `Expression` is still in scope, the usual phi processing will
+ // spill its value to `local`, where we can find it later.
+ //
+ // This pretends that the block in which `id` is defined is the
+ // predecessor of some other block with a phi in it that cites id as
+ // one of its sources, and uses `local` as its variable. There is no
+ // such phi, but nobody needs to know that.
+ ctx.phis.push(PhiExpression {
+ local,
+ expressions: vec![(id, lookup.block_id)],
+ });
+
+ expr
+ }
+ }
+
+ fn parse_expr_unary_op(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::UnaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p_id = self.next()?;
+
+ let p_lexp = self.lookup_expression.lookup(p_id)?;
+ let handle = self.get_expr_handle(p_id, p_lexp, ctx, emitter, block, body_idx);
+
+ let expr = crate::Expression::Unary { op, expr: handle };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_expr_binary_op(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+
+ let expr = crate::Expression::Binary { op, left, right };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ /// A more complicated version of the unary op,
+ /// where we force the operand to have the same type as the result.
+ fn parse_expr_unary_op_sign_adjusted(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::UnaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+
+ let result_lookup_ty = self.lookup_type.lookup(result_type_id)?;
+ let kind = ctx.type_arena[result_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let expr = crate::Expression::Unary {
+ op,
+ expr: if p1_lexp.type_id == result_type_id {
+ left
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ /// A more complicated version of the binary op,
+ /// where we force the operand to have the same type as the result.
+ /// This is mostly needed for "i++" and "i--" coming from GLSL.
+ #[allow(clippy::too_many_arguments)]
+ fn parse_expr_binary_op_sign_adjusted(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ // For arithmetic operations, we need the sign of operands to match the result.
+ // For boolean operations, however, the operands need to match the signs, but
+ // result is always different - a boolean.
+ anchor: SignAnchor,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+
+ let expected_type_id = match anchor {
+ SignAnchor::Result => result_type_id,
+ SignAnchor::Operand => p1_lexp.type_id,
+ };
+ let expected_lookup_ty = self.lookup_type.lookup(expected_type_id)?;
+ let kind = ctx.type_arena[expected_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let expr = crate::Expression::Binary {
+ op,
+ left: if p1_lexp.type_id == expected_type_id {
+ left
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ right: if p2_lexp.type_id == expected_type_id {
+ right
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: right,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ /// A version of the binary op where one or both of the arguments might need to be casted to a
+ /// specific integer kind (unsigned or signed), used for operations like OpINotEqual or
+ /// OpUGreaterThan.
+ #[allow(clippy::too_many_arguments)]
+ fn parse_expr_int_comparison(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ kind: crate::ScalarKind,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p1_lookup_ty = self.lookup_type.lookup(p1_lexp.type_id)?;
+ let p1_kind = ctx.type_arena[p1_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+ let p2_lookup_ty = self.lookup_type.lookup(p2_lexp.type_id)?;
+ let p2_kind = ctx.type_arena[p2_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let expr = crate::Expression::Binary {
+ op,
+ left: if p1_kind == kind {
+ left
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ right: if p2_kind == kind {
+ right
+ } else {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: right,
+ kind,
+ convert: None,
+ },
+ span,
+ )
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_expr_shift_op(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ op: crate::BinaryOperator,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(p1_id, p1_lexp, ctx, emitter, block, body_idx);
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let p2_handle = self.get_expr_handle(p2_id, p2_lexp, ctx, emitter, block, body_idx);
+ // convert the shift to Uint
+ let right = ctx.expressions.append(
+ crate::Expression::As {
+ expr: p2_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ );
+
+ let expr = crate::Expression::Binary { op, left, right };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_expr_derivative(
+ &mut self,
+ ctx: &mut BlockContext,
+ emitter: &mut crate::proc::Emitter,
+ block: &mut crate::Block,
+ block_id: spirv::Word,
+ body_idx: usize,
+ (axis, ctrl): (crate::DerivativeAxis, crate::DerivativeControl),
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let arg_id = self.next()?;
+
+ let arg_lexp = self.lookup_expression.lookup(arg_id)?;
+ let arg_handle = self.get_expr_handle(arg_id, arg_lexp, ctx, emitter, block, body_idx);
+
+ let expr = crate::Expression::Derivative {
+ axis,
+ ctrl,
+ expr: arg_handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, self.span_from_with_op(start)),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Ok(())
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn insert_composite(
+ &self,
+ root_expr: Handle<crate::Expression>,
+ root_type_id: spirv::Word,
+ object_expr: Handle<crate::Expression>,
+ selections: &[spirv::Word],
+ type_arena: &UniqueArena<crate::Type>,
+ expressions: &mut Arena<crate::Expression>,
+ span: crate::Span,
+ ) -> Result<Handle<crate::Expression>, Error> {
+ let selection = match selections.first() {
+ Some(&index) => index,
+ None => return Ok(object_expr),
+ };
+ let root_span = expressions.get_span(root_expr);
+ let root_lookup = self.lookup_type.lookup(root_type_id)?;
+
+ let (count, child_type_id) = match type_arena[root_lookup.handle].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let child_member = self
+ .lookup_member
+ .get(&(root_lookup.handle, selection))
+ .ok_or(Error::InvalidAccessType(root_type_id))?;
+ (members.len(), child_member.type_id)
+ }
+ crate::TypeInner::Array { size, .. } => {
+ let size = match size {
+ crate::ArraySize::Constant(size) => size.get(),
+ // A runtime sized array is not a composite type
+ crate::ArraySize::Dynamic => {
+ return Err(Error::InvalidAccessType(root_type_id))
+ }
+ };
+
+ let child_type_id = root_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(root_type_id))?;
+
+ (size as usize, child_type_id)
+ }
+ crate::TypeInner::Vector { size, .. }
+ | crate::TypeInner::Matrix { columns: size, .. } => {
+ let child_type_id = root_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(root_type_id))?;
+ (size as usize, child_type_id)
+ }
+ _ => return Err(Error::InvalidAccessType(root_type_id)),
+ };
+
+ let mut components = Vec::with_capacity(count);
+ for index in 0..count as u32 {
+ let expr = expressions.append(
+ crate::Expression::AccessIndex {
+ base: root_expr,
+ index,
+ },
+ if index == selection { span } else { root_span },
+ );
+ components.push(expr);
+ }
+ components[selection as usize] = self.insert_composite(
+ components[selection as usize],
+ child_type_id,
+ object_expr,
+ &selections[1..],
+ type_arena,
+ expressions,
+ span,
+ )?;
+
+ Ok(expressions.append(
+ crate::Expression::Compose {
+ ty: root_lookup.handle,
+ components,
+ },
+ span,
+ ))
+ }
+
+ /// Add the next SPIR-V block's contents to `block_ctx`.
+ ///
+ /// Except for the function's entry block, `block_id` should be the label of
+ /// a block we've seen mentioned before, with an entry in
+ /// `block_ctx.body_for_label` to tell us which `Body` it contributes to.
+ fn next_block(&mut self, block_id: spirv::Word, ctx: &mut BlockContext) -> Result<(), Error> {
+ // Extend `body` with the correct form for a branch to `target`.
+ fn merger(body: &mut Body, target: &MergeBlockInformation) {
+ body.data.push(match *target {
+ MergeBlockInformation::LoopContinue => BodyFragment::Continue,
+ MergeBlockInformation::LoopMerge | MergeBlockInformation::SwitchMerge => {
+ BodyFragment::Break
+ }
+
+ // Finishing a selection merge means just falling off the end of
+ // the `accept` or `reject` block of the `If` statement.
+ MergeBlockInformation::SelectionMerge => return,
+ })
+ }
+
+ let mut emitter = crate::proc::Emitter::default();
+ emitter.start(ctx.expressions);
+
+ // Find the `Body` to which this block contributes.
+ //
+ // If this is some SPIR-V structured control flow construct's merge
+ // block, then `body_idx` will refer to the same `Body` as the header,
+ // so that we simply pick up accumulating the `Body` where the header
+ // left off. Each of the statements in a block dominates the next, so
+ // we're sure to encounter their SPIR-V blocks in order, ensuring that
+ // the `Body` will be assembled in the proper order.
+ //
+ // Note that, unlike every other kind of SPIR-V block, we don't know the
+ // function's first block's label in advance. Thus, we assume that if
+ // this block has no entry in `ctx.body_for_label`, it must be the
+ // function's first block. This always has body index zero.
+ let mut body_idx = *ctx.body_for_label.entry(block_id).or_default();
+
+ // The Naga IR block this call builds. This will end up as
+ // `ctx.blocks[&block_id]`, and `ctx.bodies[body_idx]` will refer to it
+ // via a `BodyFragment::BlockId`.
+ let mut block = crate::Block::new();
+
+ // Stores the merge block as defined by a `OpSelectionMerge` otherwise is `None`
+ //
+ // This is used in `OpSwitch` to promote the `MergeBlockInformation` from
+ // `SelectionMerge` to `SwitchMerge` to allow `Break`s this isn't desirable for
+ // `LoopMerge`s because otherwise `Continue`s wouldn't be allowed
+ let mut selection_merge_block = None;
+
+ macro_rules! get_expr_handle {
+ ($id:expr, $lexp:expr) => {
+ self.get_expr_handle($id, $lexp, ctx, &mut emitter, &mut block, body_idx)
+ };
+ }
+ macro_rules! parse_expr_op {
+ ($op:expr, BINARY) => {
+ self.parse_expr_binary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op)
+ };
+
+ ($op:expr, SHIFT) => {
+ self.parse_expr_shift_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op)
+ };
+ ($op:expr, UNARY) => {
+ self.parse_expr_unary_op(ctx, &mut emitter, &mut block, block_id, body_idx, $op)
+ };
+ ($axis:expr, $ctrl:expr, DERIVATIVE) => {
+ self.parse_expr_derivative(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ ($axis, $ctrl),
+ )
+ };
+ }
+
+ let terminator = loop {
+ use spirv::Op;
+ let start = self.data_offset;
+ let inst = self.next_inst()?;
+ let span = crate::Span::from(start..(start + 4 * (inst.wc as usize)));
+ log::debug!("\t\t{:?} [{}]", inst.op, inst.wc);
+
+ match inst.op {
+ Op::Line => {
+ inst.expect(4)?;
+ let _file_id = self.next()?;
+ let _row_id = self.next()?;
+ let _col_id = self.next()?;
+ }
+ Op::NoLine => inst.expect(1)?,
+ Op::Undef => {
+ inst.expect(3)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle: ctx
+ .expressions
+ .append(crate::Expression::ZeroValue(ty), span),
+ type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Variable => {
+ inst.expect_at_least(4)?;
+ block.extend(emitter.finish(ctx.expressions));
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let _storage_class = self.next()?;
+ let init = if inst.wc > 4 {
+ inst.expect(5)?;
+ let init_id = self.next()?;
+ let lconst = self.lookup_constant.lookup(init_id)?;
+ Some(
+ ctx.expressions
+ .append(crate::Expression::Constant(lconst.handle), span),
+ )
+ } else {
+ None
+ };
+
+ let name = self
+ .future_decor
+ .remove(&result_id)
+ .and_then(|decor| decor.name);
+ if let Some(ref name) = name {
+ log::debug!("\t\t\tid={} name={}", result_id, name);
+ }
+ let lookup_ty = self.lookup_type.lookup(result_type_id)?;
+ let var_handle = ctx.local_arena.append(
+ crate::LocalVariable {
+ name,
+ ty: match ctx.type_arena[lookup_ty.handle].inner {
+ crate::TypeInner::Pointer { base, .. } => base,
+ _ => lookup_ty.handle,
+ },
+ init,
+ },
+ span,
+ );
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx
+ .expressions
+ .append(crate::Expression::LocalVariable(var_handle), span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ emitter.start(ctx.expressions);
+ }
+ Op::Phi => {
+ inst.expect_at_least(3)?;
+ block.extend(emitter.finish(ctx.expressions));
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+
+ let name = format!("phi_{result_id}");
+ let local = ctx.local_arena.append(
+ crate::LocalVariable {
+ name: Some(name),
+ ty: self.lookup_type.lookup(result_type_id)?.handle,
+ init: None,
+ },
+ self.span_from(start),
+ );
+ let pointer = ctx
+ .expressions
+ .append(crate::Expression::LocalVariable(local), span);
+
+ let in_count = (inst.wc - 3) / 2;
+ let mut phi = PhiExpression {
+ local,
+ expressions: Vec::with_capacity(in_count as usize),
+ };
+ for _ in 0..in_count {
+ let expr = self.next()?;
+ let block = self.next()?;
+ phi.expressions.push((expr, block));
+ }
+
+ ctx.phis.push(phi);
+ emitter.start(ctx.expressions);
+
+ // Associate the lookup with an actual value, which is emitted
+ // into the current block.
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx
+ .expressions
+ .append(crate::Expression::Load { pointer }, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::AccessChain | Op::InBoundsAccessChain => {
+ struct AccessExpression {
+ base_handle: Handle<crate::Expression>,
+ type_id: spirv::Word,
+ load_override: Option<LookupLoadOverride>,
+ }
+
+ inst.expect_at_least(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", base_id);
+
+ let mut acex = {
+ let lexp = self.lookup_expression.lookup(base_id)?;
+ let lty = self.lookup_type.lookup(lexp.type_id)?;
+
+ // HACK `OpAccessChain` and `OpInBoundsAccessChain`
+ // require for the result type to be a pointer, but if
+ // we're given a pointer to an image / sampler, it will
+ // be *already* dereferenced, since we do that early
+ // during `parse_type_pointer()`.
+ //
+ // This can happen only through `BindingArray`, since
+ // that's the only case where one can obtain a pointer
+ // to an image / sampler, and so let's match on that:
+ let dereference = match ctx.type_arena[lty.handle].inner {
+ crate::TypeInner::BindingArray { .. } => false,
+ _ => true,
+ };
+
+ let type_id = if dereference {
+ lty.base_id.ok_or(Error::InvalidAccessType(lexp.type_id))?
+ } else {
+ lexp.type_id
+ };
+
+ AccessExpression {
+ base_handle: get_expr_handle!(base_id, lexp),
+ type_id,
+ load_override: self.lookup_load_override.get(&base_id).cloned(),
+ }
+ };
+
+ for _ in 4..inst.wc {
+ let access_id = self.next()?;
+ log::trace!("\t\t\tlooking up index expr {:?}", access_id);
+ let index_expr = self.lookup_expression.lookup(access_id)?.clone();
+ let index_expr_handle = get_expr_handle!(access_id, &index_expr);
+ let index_expr_data = &ctx.expressions[index_expr.handle];
+ let index_maybe = match *index_expr_data {
+ crate::Expression::Constant(const_handle) => Some(
+ ctx.gctx()
+ .eval_expr_to_u32(ctx.const_arena[const_handle].init)
+ .map_err(|_| {
+ Error::InvalidAccess(crate::Expression::Constant(
+ const_handle,
+ ))
+ })?,
+ ),
+ _ => None,
+ };
+
+ log::trace!("\t\t\tlooking up type {:?}", acex.type_id);
+ let type_lookup = self.lookup_type.lookup(acex.type_id)?;
+ let ty = &ctx.type_arena[type_lookup.handle];
+ acex = match ty.inner {
+ // can only index a struct with a constant
+ crate::TypeInner::Struct { ref members, .. } => {
+ let index = index_maybe
+ .ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?;
+
+ let lookup_member = self
+ .lookup_member
+ .get(&(type_lookup.handle, index))
+ .ok_or(Error::InvalidAccessType(acex.type_id))?;
+ let base_handle = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: acex.base_handle,
+ index,
+ },
+ span,
+ );
+
+ if ty.name.as_deref() == Some("gl_PerVertex") {
+ if let Some(crate::Binding::BuiltIn(built_in)) =
+ members[index as usize].binding
+ {
+ self.gl_per_vertex_builtin_access.insert(built_in);
+ }
+ }
+
+ AccessExpression {
+ base_handle,
+ type_id: lookup_member.type_id,
+ load_override: if lookup_member.row_major {
+ debug_assert!(acex.load_override.is_none());
+ let sub_type_lookup =
+ self.lookup_type.lookup(lookup_member.type_id)?;
+ Some(match ctx.type_arena[sub_type_lookup.handle].inner {
+ // load it transposed, to match column major expectations
+ crate::TypeInner::Matrix { .. } => {
+ let loaded = ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: base_handle,
+ },
+ span,
+ );
+ let transposed = ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Transpose,
+ arg: loaded,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ );
+ LookupLoadOverride::Loaded(transposed)
+ }
+ _ => LookupLoadOverride::Pending,
+ })
+ } else {
+ None
+ },
+ }
+ }
+ crate::TypeInner::Matrix { .. } => {
+ let load_override = match acex.load_override {
+ // We are indexing inside a row-major matrix
+ Some(LookupLoadOverride::Loaded(load_expr)) => {
+ let index = index_maybe.ok_or_else(|| {
+ Error::InvalidAccess(index_expr_data.clone())
+ })?;
+ let sub_handle = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: load_expr,
+ index,
+ },
+ span,
+ );
+ Some(LookupLoadOverride::Loaded(sub_handle))
+ }
+ _ => None,
+ };
+ let sub_expr = match index_maybe {
+ Some(index) => crate::Expression::AccessIndex {
+ base: acex.base_handle,
+ index,
+ },
+ None => crate::Expression::Access {
+ base: acex.base_handle,
+ index: index_expr_handle,
+ },
+ };
+ AccessExpression {
+ base_handle: ctx.expressions.append(sub_expr, span),
+ type_id: type_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(acex.type_id))?,
+ load_override,
+ }
+ }
+ // This must be a vector or an array.
+ _ => {
+ let base_handle = ctx.expressions.append(
+ crate::Expression::Access {
+ base: acex.base_handle,
+ index: index_expr_handle,
+ },
+ span,
+ );
+ let load_override = match acex.load_override {
+ // If there is a load override in place, then we always end up
+ // with a side-loaded value here.
+ Some(lookup_load_override) => {
+ let sub_expr = match lookup_load_override {
+ // We must be indexing into the array of row-major matrices.
+ // Let's load the result of indexing and transpose it.
+ LookupLoadOverride::Pending => {
+ let loaded = ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: base_handle,
+ },
+ span,
+ );
+ ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Transpose,
+ arg: loaded,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ )
+ }
+ // We are indexing inside a row-major matrix.
+ LookupLoadOverride::Loaded(load_expr) => {
+ ctx.expressions.append(
+ crate::Expression::Access {
+ base: load_expr,
+ index: index_expr_handle,
+ },
+ span,
+ )
+ }
+ };
+ Some(LookupLoadOverride::Loaded(sub_expr))
+ }
+ None => None,
+ };
+ AccessExpression {
+ base_handle,
+ type_id: type_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(acex.type_id))?,
+ load_override,
+ }
+ }
+ };
+ }
+
+ if let Some(load_expr) = acex.load_override {
+ self.lookup_load_override.insert(result_id, load_expr);
+ }
+ let lookup_expression = LookupExpression {
+ handle: acex.base_handle,
+ type_id: result_type_id,
+ block_id,
+ };
+ self.lookup_expression.insert(result_id, lookup_expression);
+ }
+ Op::VectorExtractDynamic => {
+ inst.expect(5)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let composite_id = self.next()?;
+ let index_id = self.next()?;
+
+ let root_lexp = self.lookup_expression.lookup(composite_id)?;
+ let root_handle = get_expr_handle!(composite_id, root_lexp);
+ let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?;
+ let index_lexp = self.lookup_expression.lookup(index_id)?;
+ let index_handle = get_expr_handle!(index_id, index_lexp);
+ let index_type = self.lookup_type.lookup(index_lexp.type_id)?.handle;
+
+ let num_components = match ctx.type_arena[root_type_lookup.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as u32,
+ _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)),
+ };
+
+ let mut make_index = |ctx: &mut BlockContext, index: u32| {
+ make_index_literal(
+ ctx,
+ index,
+ &mut block,
+ &mut emitter,
+ index_type,
+ index_lexp.type_id,
+ span,
+ )
+ };
+
+ let index_expr = make_index(ctx, 0)?;
+ let mut handle = ctx.expressions.append(
+ crate::Expression::Access {
+ base: root_handle,
+ index: index_expr,
+ },
+ span,
+ );
+ for index in 1..num_components {
+ let index_expr = make_index(ctx, index)?;
+ let access_expr = ctx.expressions.append(
+ crate::Expression::Access {
+ base: root_handle,
+ index: index_expr,
+ },
+ span,
+ );
+ let cond = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Equal,
+ left: index_expr,
+ right: index_handle,
+ },
+ span,
+ );
+ handle = ctx.expressions.append(
+ crate::Expression::Select {
+ condition: cond,
+ accept: access_expr,
+ reject: handle,
+ },
+ span,
+ );
+ }
+
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::VectorInsertDynamic => {
+ inst.expect(6)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let composite_id = self.next()?;
+ let object_id = self.next()?;
+ let index_id = self.next()?;
+
+ let object_lexp = self.lookup_expression.lookup(object_id)?;
+ let object_handle = get_expr_handle!(object_id, object_lexp);
+ let root_lexp = self.lookup_expression.lookup(composite_id)?;
+ let root_handle = get_expr_handle!(composite_id, root_lexp);
+ let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?;
+ let index_lexp = self.lookup_expression.lookup(index_id)?;
+ let index_handle = get_expr_handle!(index_id, index_lexp);
+ let index_type = self.lookup_type.lookup(index_lexp.type_id)?.handle;
+
+ let num_components = match ctx.type_arena[root_type_lookup.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as u32,
+ _ => return Err(Error::InvalidVectorType(root_type_lookup.handle)),
+ };
+
+ let mut components = Vec::with_capacity(num_components as usize);
+ for index in 0..num_components {
+ let index_expr = make_index_literal(
+ ctx,
+ index,
+ &mut block,
+ &mut emitter,
+ index_type,
+ index_lexp.type_id,
+ span,
+ )?;
+ let access_expr = ctx.expressions.append(
+ crate::Expression::Access {
+ base: root_handle,
+ index: index_expr,
+ },
+ span,
+ );
+ let cond = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Equal,
+ left: index_expr,
+ right: index_handle,
+ },
+ span,
+ );
+ let handle = ctx.expressions.append(
+ crate::Expression::Select {
+ condition: cond,
+ accept: object_handle,
+ reject: access_expr,
+ },
+ span,
+ );
+ components.push(handle);
+ }
+ let handle = ctx.expressions.append(
+ crate::Expression::Compose {
+ ty: root_type_lookup.handle,
+ components,
+ },
+ span,
+ );
+
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CompositeExtract => {
+ inst.expect_at_least(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", base_id);
+ let mut lexp = self.lookup_expression.lookup(base_id)?.clone();
+ lexp.handle = get_expr_handle!(base_id, &lexp);
+ for _ in 4..inst.wc {
+ let index = self.next()?;
+ log::trace!("\t\t\tlooking up type {:?}", lexp.type_id);
+ let type_lookup = self.lookup_type.lookup(lexp.type_id)?;
+ let type_id = match ctx.type_arena[type_lookup.handle].inner {
+ crate::TypeInner::Struct { .. } => {
+ self.lookup_member
+ .get(&(type_lookup.handle, index))
+ .ok_or(Error::InvalidAccessType(lexp.type_id))?
+ .type_id
+ }
+ crate::TypeInner::Array { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. } => type_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(lexp.type_id))?,
+ ref other => {
+ log::warn!("composite type {:?}", other);
+ return Err(Error::UnsupportedType(type_lookup.handle));
+ }
+ };
+ lexp = LookupExpression {
+ handle: ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: lexp.handle,
+ index,
+ },
+ span,
+ ),
+ type_id,
+ block_id,
+ };
+ }
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: lexp.handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CompositeInsert => {
+ inst.expect_at_least(5)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let object_id = self.next()?;
+ let composite_id = self.next()?;
+ let mut selections = Vec::with_capacity(inst.wc as usize - 5);
+ for _ in 5..inst.wc {
+ selections.push(self.next()?);
+ }
+
+ let object_lexp = self.lookup_expression.lookup(object_id)?.clone();
+ let object_handle = get_expr_handle!(object_id, &object_lexp);
+ let root_lexp = self.lookup_expression.lookup(composite_id)?.clone();
+ let root_handle = get_expr_handle!(composite_id, &root_lexp);
+ let handle = self.insert_composite(
+ root_handle,
+ result_type_id,
+ object_handle,
+ &selections,
+ ctx.type_arena,
+ ctx.expressions,
+ span,
+ )?;
+
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CompositeConstruct => {
+ inst.expect_at_least(3)?;
+
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let mut components = Vec::with_capacity(inst.wc as usize - 2);
+ for _ in 3..inst.wc {
+ let comp_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", comp_id);
+ let lexp = self.lookup_expression.lookup(comp_id)?;
+ let handle = get_expr_handle!(comp_id, lexp);
+ components.push(handle);
+ }
+ let ty = self.lookup_type.lookup(result_type_id)?.handle;
+ let first = components[0];
+ let expr = match ctx.type_arena[ty].inner {
+ // this is an optimization to detect the splat
+ crate::TypeInner::Vector { size, .. }
+ if components.len() == size as usize
+ && components[1..].iter().all(|&c| c == first) =>
+ {
+ crate::Expression::Splat { size, value: first }
+ }
+ _ => crate::Expression::Compose { ty, components },
+ };
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Load => {
+ inst.expect_at_least(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let pointer_id = self.next()?;
+ if inst.wc != 4 {
+ inst.expect(5)?;
+ let _memory_access = self.next()?;
+ }
+
+ let base_lexp = self.lookup_expression.lookup(pointer_id)?;
+ let base_handle = get_expr_handle!(pointer_id, base_lexp);
+ let type_lookup = self.lookup_type.lookup(base_lexp.type_id)?;
+ let handle = match ctx.type_arena[type_lookup.handle].inner {
+ crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => {
+ base_handle
+ }
+ _ => match self.lookup_load_override.get(&pointer_id) {
+ Some(&LookupLoadOverride::Loaded(handle)) => handle,
+ //Note: we aren't handling `LookupLoadOverride::Pending` properly here
+ _ => ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: base_handle,
+ },
+ span,
+ ),
+ },
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Store => {
+ inst.expect_at_least(3)?;
+
+ let pointer_id = self.next()?;
+ let value_id = self.next()?;
+ if inst.wc != 3 {
+ inst.expect(4)?;
+ let _memory_access = self.next()?;
+ }
+ let base_expr = self.lookup_expression.lookup(pointer_id)?;
+ let base_handle = get_expr_handle!(pointer_id, base_expr);
+ let value_expr = self.lookup_expression.lookup(value_id)?;
+ let value_handle = get_expr_handle!(value_id, value_expr);
+
+ block.extend(emitter.finish(ctx.expressions));
+ block.push(
+ crate::Statement::Store {
+ pointer: base_handle,
+ value: value_handle,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ // Arithmetic Instructions +, -, *, /, %
+ Op::SNegate | Op::FNegate => {
+ inst.expect(4)?;
+ self.parse_expr_unary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ crate::UnaryOperator::Negate,
+ )?;
+ }
+ Op::IAdd
+ | Op::ISub
+ | Op::IMul
+ | Op::BitwiseOr
+ | Op::BitwiseXor
+ | Op::BitwiseAnd
+ | Op::SDiv
+ | Op::SRem => {
+ inst.expect(5)?;
+ let operator = map_binary_operator(inst.op)?;
+ self.parse_expr_binary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ operator,
+ SignAnchor::Result,
+ )?;
+ }
+ Op::IEqual | Op::INotEqual => {
+ inst.expect(5)?;
+ let operator = map_binary_operator(inst.op)?;
+ self.parse_expr_binary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ operator,
+ SignAnchor::Operand,
+ )?;
+ }
+ Op::FAdd => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Add, BINARY)?;
+ }
+ Op::FSub => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Subtract, BINARY)?;
+ }
+ Op::FMul => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?;
+ }
+ Op::UDiv | Op::FDiv => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Divide, BINARY)?;
+ }
+ Op::UMod | Op::FRem => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Modulo, BINARY)?;
+ }
+ Op::SMod => {
+ inst.expect(5)?;
+
+ // x - y * int(floor(float(x) / float(y)))
+
+ let start = self.data_offset;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(
+ p1_id,
+ p1_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(
+ p2_id,
+ p2_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+
+ let result_ty = self.lookup_type.lookup(result_type_id)?;
+ let inner = &ctx.type_arena[result_ty.handle].inner;
+ let kind = inner.scalar_kind().unwrap();
+ let size = inner.size(ctx.gctx()) as u8;
+
+ let left_cast = ctx.expressions.append(
+ crate::Expression::As {
+ expr: left,
+ kind: crate::ScalarKind::Float,
+ convert: Some(size),
+ },
+ span,
+ );
+ let right_cast = ctx.expressions.append(
+ crate::Expression::As {
+ expr: right,
+ kind: crate::ScalarKind::Float,
+ convert: Some(size),
+ },
+ span,
+ );
+ let div = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left: left_cast,
+ right: right_cast,
+ },
+ span,
+ );
+ let floor = ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Floor,
+ arg: div,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ );
+ let cast = ctx.expressions.append(
+ crate::Expression::As {
+ expr: floor,
+ kind,
+ convert: Some(size),
+ },
+ span,
+ );
+ let mult = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Multiply,
+ left: cast,
+ right,
+ },
+ span,
+ );
+ let sub = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Subtract,
+ left,
+ right: mult,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: sub,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::FMod => {
+ inst.expect(5)?;
+
+ // x - y * floor(x / y)
+
+ let start = self.data_offset;
+ let span = self.span_from_with_op(start);
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let left = self.get_expr_handle(
+ p1_id,
+ p1_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+ let right = self.get_expr_handle(
+ p2_id,
+ p2_lexp,
+ ctx,
+ &mut emitter,
+ &mut block,
+ body_idx,
+ );
+
+ let div = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Divide,
+ left,
+ right,
+ },
+ span,
+ );
+ let floor = ctx.expressions.append(
+ crate::Expression::Math {
+ fun: crate::MathFunction::Floor,
+ arg: div,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ },
+ span,
+ );
+ let mult = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Multiply,
+ left: floor,
+ right,
+ },
+ span,
+ );
+ let sub = ctx.expressions.append(
+ crate::Expression::Binary {
+ op: crate::BinaryOperator::Subtract,
+ left,
+ right: mult,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: sub,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::VectorTimesScalar
+ | Op::VectorTimesMatrix
+ | Op::MatrixTimesScalar
+ | Op::MatrixTimesVector
+ | Op::MatrixTimesMatrix => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::Multiply, BINARY)?;
+ }
+ Op::Transpose => {
+ inst.expect(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let matrix_id = self.next()?;
+ let matrix_lexp = self.lookup_expression.lookup(matrix_id)?;
+ let matrix_handle = get_expr_handle!(matrix_id, matrix_lexp);
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::Transpose,
+ arg: matrix_handle,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Dot => {
+ inst.expect(5)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let left_id = self.next()?;
+ let right_id = self.next()?;
+ let left_lexp = self.lookup_expression.lookup(left_id)?;
+ let left_handle = get_expr_handle!(left_id, left_lexp);
+ let right_lexp = self.lookup_expression.lookup(right_id)?;
+ let right_handle = get_expr_handle!(right_id, right_lexp);
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::Dot,
+ arg: left_handle,
+ arg1: Some(right_handle),
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::BitFieldInsert => {
+ inst.expect(7)?;
+
+ let start = self.data_offset;
+ let span = self.span_from_with_op(start);
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ let insert_id = self.next()?;
+ let offset_id = self.next()?;
+ let count_id = self.next()?;
+ let base_lexp = self.lookup_expression.lookup(base_id)?;
+ let base_handle = get_expr_handle!(base_id, base_lexp);
+ let insert_lexp = self.lookup_expression.lookup(insert_id)?;
+ let insert_handle = get_expr_handle!(insert_id, insert_lexp);
+ let offset_lexp = self.lookup_expression.lookup(offset_id)?;
+ let offset_handle = get_expr_handle!(offset_id, offset_lexp);
+ let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?;
+ let count_lexp = self.lookup_expression.lookup(count_id)?;
+ let count_handle = get_expr_handle!(count_id, count_lexp);
+ let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?;
+
+ let offset_kind = ctx.type_arena[offset_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+ let count_kind = ctx.type_arena[count_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: offset_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ offset_handle
+ };
+
+ let count_cast_handle = if count_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: count_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ count_handle
+ };
+
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::InsertBits,
+ arg: base_handle,
+ arg1: Some(insert_handle),
+ arg2: Some(offset_cast_handle),
+ arg3: Some(count_cast_handle),
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::BitFieldSExtract | Op::BitFieldUExtract => {
+ inst.expect(6)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ let offset_id = self.next()?;
+ let count_id = self.next()?;
+ let base_lexp = self.lookup_expression.lookup(base_id)?;
+ let base_handle = get_expr_handle!(base_id, base_lexp);
+ let offset_lexp = self.lookup_expression.lookup(offset_id)?;
+ let offset_handle = get_expr_handle!(offset_id, offset_lexp);
+ let offset_lookup_ty = self.lookup_type.lookup(offset_lexp.type_id)?;
+ let count_lexp = self.lookup_expression.lookup(count_id)?;
+ let count_handle = get_expr_handle!(count_id, count_lexp);
+ let count_lookup_ty = self.lookup_type.lookup(count_lexp.type_id)?;
+
+ let offset_kind = ctx.type_arena[offset_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+ let count_kind = ctx.type_arena[count_lookup_ty.handle]
+ .inner
+ .scalar_kind()
+ .unwrap();
+
+ let offset_cast_handle = if offset_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: offset_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ offset_handle
+ };
+
+ let count_cast_handle = if count_kind != crate::ScalarKind::Uint {
+ ctx.expressions.append(
+ crate::Expression::As {
+ expr: count_handle,
+ kind: crate::ScalarKind::Uint,
+ convert: None,
+ },
+ span,
+ )
+ } else {
+ count_handle
+ };
+
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::ExtractBits,
+ arg: base_handle,
+ arg1: Some(offset_cast_handle),
+ arg2: Some(count_cast_handle),
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::BitReverse | Op::BitCount => {
+ inst.expect(4)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ let base_lexp = self.lookup_expression.lookup(base_id)?;
+ let base_handle = get_expr_handle!(base_id, base_lexp);
+ let expr = crate::Expression::Math {
+ fun: match inst.op {
+ Op::BitReverse => crate::MathFunction::ReverseBits,
+ Op::BitCount => crate::MathFunction::CountOneBits,
+ _ => unreachable!(),
+ },
+ arg: base_handle,
+ arg1: None,
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::OuterProduct => {
+ inst.expect(5)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let left_id = self.next()?;
+ let right_id = self.next()?;
+ let left_lexp = self.lookup_expression.lookup(left_id)?;
+ let left_handle = get_expr_handle!(left_id, left_lexp);
+ let right_lexp = self.lookup_expression.lookup(right_id)?;
+ let right_handle = get_expr_handle!(right_id, right_lexp);
+ let expr = crate::Expression::Math {
+ fun: crate::MathFunction::Outer,
+ arg: left_handle,
+ arg1: Some(right_handle),
+ arg2: None,
+ arg3: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ // Bitwise instructions
+ Op::Not => {
+ inst.expect(4)?;
+ self.parse_expr_unary_op_sign_adjusted(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ crate::UnaryOperator::BitwiseNot,
+ )?;
+ }
+ Op::ShiftRightLogical => {
+ inst.expect(5)?;
+ //TODO: convert input and result to unsigned
+ parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?;
+ }
+ Op::ShiftRightArithmetic => {
+ inst.expect(5)?;
+ //TODO: convert input and result to signed
+ parse_expr_op!(crate::BinaryOperator::ShiftRight, SHIFT)?;
+ }
+ Op::ShiftLeftLogical => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::ShiftLeft, SHIFT)?;
+ }
+ // Sampling
+ Op::Image => {
+ inst.expect(4)?;
+ self.parse_image_uncouple(block_id)?;
+ }
+ Op::SampledImage => {
+ inst.expect(5)?;
+ self.parse_image_couple()?;
+ }
+ Op::ImageWrite => {
+ let extra = inst.expect_at_least(4)?;
+ let stmt =
+ self.parse_image_write(extra, ctx, &mut emitter, &mut block, body_idx)?;
+ block.extend(emitter.finish(ctx.expressions));
+ block.push(stmt, span);
+ emitter.start(ctx.expressions);
+ }
+ Op::ImageFetch | Op::ImageRead => {
+ let extra = inst.expect_at_least(5)?;
+ self.parse_image_load(
+ extra,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleImplicitLod | Op::ImageSampleExplicitLod => {
+ let extra = inst.expect_at_least(5)?;
+ let options = image::SamplingOptions {
+ compare: false,
+ project: false,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleProjImplicitLod | Op::ImageSampleProjExplicitLod => {
+ let extra = inst.expect_at_least(5)?;
+ let options = image::SamplingOptions {
+ compare: false,
+ project: true,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleDrefImplicitLod | Op::ImageSampleDrefExplicitLod => {
+ let extra = inst.expect_at_least(6)?;
+ let options = image::SamplingOptions {
+ compare: true,
+ project: false,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageSampleProjDrefImplicitLod | Op::ImageSampleProjDrefExplicitLod => {
+ let extra = inst.expect_at_least(6)?;
+ let options = image::SamplingOptions {
+ compare: true,
+ project: true,
+ };
+ self.parse_image_sample(
+ extra,
+ options,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageQuerySize => {
+ inst.expect(4)?;
+ self.parse_image_query_size(
+ false,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageQuerySizeLod => {
+ inst.expect(5)?;
+ self.parse_image_query_size(
+ true,
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ )?;
+ }
+ Op::ImageQueryLevels => {
+ inst.expect(4)?;
+ self.parse_image_query_other(crate::ImageQuery::NumLevels, ctx, block_id)?;
+ }
+ Op::ImageQuerySamples => {
+ inst.expect(4)?;
+ self.parse_image_query_other(crate::ImageQuery::NumSamples, ctx, block_id)?;
+ }
+ // other ops
+ Op::Select => {
+ inst.expect(6)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let condition = self.next()?;
+ let o1_id = self.next()?;
+ let o2_id = self.next()?;
+
+ let cond_lexp = self.lookup_expression.lookup(condition)?;
+ let cond_handle = get_expr_handle!(condition, cond_lexp);
+ let o1_lexp = self.lookup_expression.lookup(o1_id)?;
+ let o1_handle = get_expr_handle!(o1_id, o1_lexp);
+ let o2_lexp = self.lookup_expression.lookup(o2_id)?;
+ let o2_handle = get_expr_handle!(o2_id, o2_lexp);
+
+ let expr = crate::Expression::Select {
+ condition: cond_handle,
+ accept: o1_handle,
+ reject: o2_handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::VectorShuffle => {
+ inst.expect_at_least(5)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let v1_id = self.next()?;
+ let v2_id = self.next()?;
+
+ let v1_lexp = self.lookup_expression.lookup(v1_id)?;
+ let v1_lty = self.lookup_type.lookup(v1_lexp.type_id)?;
+ let v1_handle = get_expr_handle!(v1_id, v1_lexp);
+ let n1 = match ctx.type_arena[v1_lty.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as u32,
+ _ => return Err(Error::InvalidInnerType(v1_lexp.type_id)),
+ };
+ let v2_lexp = self.lookup_expression.lookup(v2_id)?;
+ let v2_lty = self.lookup_type.lookup(v2_lexp.type_id)?;
+ let v2_handle = get_expr_handle!(v2_id, v2_lexp);
+ let n2 = match ctx.type_arena[v2_lty.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as u32,
+ _ => return Err(Error::InvalidInnerType(v2_lexp.type_id)),
+ };
+
+ self.temp_bytes.clear();
+ let mut max_component = 0;
+ for _ in 5..inst.wc as usize {
+ let mut index = self.next()?;
+ if index == u32::MAX {
+ // treat Undefined as X
+ index = 0;
+ }
+ max_component = max_component.max(index);
+ self.temp_bytes.push(index as u8);
+ }
+
+ // Check for swizzle first.
+ let expr = if max_component < n1 {
+ use crate::SwizzleComponent as Sc;
+ let size = match self.temp_bytes.len() {
+ 2 => crate::VectorSize::Bi,
+ 3 => crate::VectorSize::Tri,
+ _ => crate::VectorSize::Quad,
+ };
+ let mut pattern = [Sc::X; 4];
+ for (pat, index) in pattern.iter_mut().zip(self.temp_bytes.drain(..)) {
+ *pat = match index {
+ 0 => Sc::X,
+ 1 => Sc::Y,
+ 2 => Sc::Z,
+ _ => Sc::W,
+ };
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector: v1_handle,
+ pattern,
+ }
+ } else {
+ // Fall back to access + compose
+ let mut components = Vec::with_capacity(self.temp_bytes.len());
+ for index in self.temp_bytes.drain(..).map(|i| i as u32) {
+ let expr = if index < n1 {
+ crate::Expression::AccessIndex {
+ base: v1_handle,
+ index,
+ }
+ } else if index < n1 + n2 {
+ crate::Expression::AccessIndex {
+ base: v2_handle,
+ index: index - n1,
+ }
+ } else {
+ return Err(Error::InvalidAccessIndex(index));
+ };
+ components.push(ctx.expressions.append(expr, span));
+ }
+ crate::Expression::Compose {
+ ty: self.lookup_type.lookup(result_type_id)?.handle,
+ components,
+ }
+ };
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Bitcast
+ | Op::ConvertSToF
+ | Op::ConvertUToF
+ | Op::ConvertFToU
+ | Op::ConvertFToS
+ | Op::FConvert
+ | Op::UConvert
+ | Op::SConvert => {
+ inst.expect(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let value_id = self.next()?;
+
+ let value_lexp = self.lookup_expression.lookup(value_id)?;
+ let ty_lookup = self.lookup_type.lookup(result_type_id)?;
+ let scalar = match ctx.type_arena[ty_lookup.handle].inner {
+ crate::TypeInner::Scalar(scalar)
+ | crate::TypeInner::Vector { scalar, .. }
+ | crate::TypeInner::Matrix { scalar, .. } => scalar,
+ _ => return Err(Error::InvalidAsType(ty_lookup.handle)),
+ };
+
+ let expr = crate::Expression::As {
+ expr: get_expr_handle!(value_id, value_lexp),
+ kind: scalar.kind,
+ convert: if scalar.kind == crate::ScalarKind::Bool {
+ Some(crate::BOOL_WIDTH)
+ } else if inst.op == Op::Bitcast {
+ None
+ } else {
+ Some(scalar.width)
+ },
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::FunctionCall => {
+ inst.expect_at_least(4)?;
+ block.extend(emitter.finish(ctx.expressions));
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let func_id = self.next()?;
+
+ let mut arguments = Vec::with_capacity(inst.wc as usize - 4);
+ for _ in 0..arguments.capacity() {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ arguments.push(get_expr_handle!(arg_id, lexp));
+ }
+
+ // We just need an unique handle here, nothing more.
+ let function = self.add_call(ctx.function_id, func_id);
+
+ let result = if self.lookup_void_type == Some(result_type_id) {
+ None
+ } else {
+ let expr_handle = ctx
+ .expressions
+ .append(crate::Expression::CallResult(function), span);
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expr_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ Some(expr_handle)
+ };
+ block.push(
+ crate::Statement::Call {
+ function,
+ arguments,
+ result,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ Op::ExtInst => {
+ use crate::MathFunction as Mf;
+ use spirv::GLOp as Glo;
+
+ let base_wc = 5;
+ inst.expect_at_least(base_wc)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let set_id = self.next()?;
+ if Some(set_id) != self.ext_glsl_id {
+ return Err(Error::UnsupportedExtInstSet(set_id));
+ }
+ let inst_id = self.next()?;
+ let gl_op = Glo::from_u32(inst_id).ok_or(Error::UnsupportedExtInst(inst_id))?;
+
+ let fun = match gl_op {
+ Glo::Round => Mf::Round,
+ Glo::RoundEven => Mf::Round,
+ Glo::Trunc => Mf::Trunc,
+ Glo::FAbs | Glo::SAbs => Mf::Abs,
+ Glo::FSign | Glo::SSign => Mf::Sign,
+ Glo::Floor => Mf::Floor,
+ Glo::Ceil => Mf::Ceil,
+ Glo::Fract => Mf::Fract,
+ Glo::Sin => Mf::Sin,
+ Glo::Cos => Mf::Cos,
+ Glo::Tan => Mf::Tan,
+ Glo::Asin => Mf::Asin,
+ Glo::Acos => Mf::Acos,
+ Glo::Atan => Mf::Atan,
+ Glo::Sinh => Mf::Sinh,
+ Glo::Cosh => Mf::Cosh,
+ Glo::Tanh => Mf::Tanh,
+ Glo::Atan2 => Mf::Atan2,
+ Glo::Asinh => Mf::Asinh,
+ Glo::Acosh => Mf::Acosh,
+ Glo::Atanh => Mf::Atanh,
+ Glo::Radians => Mf::Radians,
+ Glo::Degrees => Mf::Degrees,
+ Glo::Pow => Mf::Pow,
+ Glo::Exp => Mf::Exp,
+ Glo::Log => Mf::Log,
+ Glo::Exp2 => Mf::Exp2,
+ Glo::Log2 => Mf::Log2,
+ Glo::Sqrt => Mf::Sqrt,
+ Glo::InverseSqrt => Mf::InverseSqrt,
+ Glo::MatrixInverse => Mf::Inverse,
+ Glo::Determinant => Mf::Determinant,
+ Glo::ModfStruct => Mf::Modf,
+ Glo::FMin | Glo::UMin | Glo::SMin | Glo::NMin => Mf::Min,
+ Glo::FMax | Glo::UMax | Glo::SMax | Glo::NMax => Mf::Max,
+ Glo::FClamp | Glo::UClamp | Glo::SClamp | Glo::NClamp => Mf::Clamp,
+ Glo::FMix => Mf::Mix,
+ Glo::Step => Mf::Step,
+ Glo::SmoothStep => Mf::SmoothStep,
+ Glo::Fma => Mf::Fma,
+ Glo::FrexpStruct => Mf::Frexp,
+ Glo::Ldexp => Mf::Ldexp,
+ Glo::Length => Mf::Length,
+ Glo::Distance => Mf::Distance,
+ Glo::Cross => Mf::Cross,
+ Glo::Normalize => Mf::Normalize,
+ Glo::FaceForward => Mf::FaceForward,
+ Glo::Reflect => Mf::Reflect,
+ Glo::Refract => Mf::Refract,
+ Glo::PackUnorm4x8 => Mf::Pack4x8unorm,
+ Glo::PackSnorm4x8 => Mf::Pack4x8snorm,
+ Glo::PackHalf2x16 => Mf::Pack2x16float,
+ Glo::PackUnorm2x16 => Mf::Pack2x16unorm,
+ Glo::PackSnorm2x16 => Mf::Pack2x16snorm,
+ Glo::UnpackUnorm4x8 => Mf::Unpack4x8unorm,
+ Glo::UnpackSnorm4x8 => Mf::Unpack4x8snorm,
+ Glo::UnpackHalf2x16 => Mf::Unpack2x16float,
+ Glo::UnpackUnorm2x16 => Mf::Unpack2x16unorm,
+ Glo::UnpackSnorm2x16 => Mf::Unpack2x16snorm,
+ Glo::FindILsb => Mf::FindLsb,
+ Glo::FindUMsb | Glo::FindSMsb => Mf::FindMsb,
+ // TODO: https://github.com/gfx-rs/naga/issues/2526
+ Glo::Modf | Glo::Frexp => return Err(Error::UnsupportedExtInst(inst_id)),
+ Glo::IMix
+ | Glo::PackDouble2x32
+ | Glo::UnpackDouble2x32
+ | Glo::InterpolateAtCentroid
+ | Glo::InterpolateAtSample
+ | Glo::InterpolateAtOffset => {
+ return Err(Error::UnsupportedExtInst(inst_id))
+ }
+ };
+
+ let arg_count = fun.argument_count();
+ inst.expect(base_wc + arg_count as u16)?;
+ let arg = {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ get_expr_handle!(arg_id, lexp)
+ };
+ let arg1 = if arg_count > 1 {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ Some(get_expr_handle!(arg_id, lexp))
+ } else {
+ None
+ };
+ let arg2 = if arg_count > 2 {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ Some(get_expr_handle!(arg_id, lexp))
+ } else {
+ None
+ };
+ let arg3 = if arg_count > 3 {
+ let arg_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(arg_id)?;
+ Some(get_expr_handle!(arg_id, lexp))
+ } else {
+ None
+ };
+
+ let expr = crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ // Relational and Logical Instructions
+ Op::LogicalNot => {
+ inst.expect(4)?;
+ parse_expr_op!(crate::UnaryOperator::LogicalNot, UNARY)?;
+ }
+ Op::LogicalOr => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::LogicalOr, BINARY)?;
+ }
+ Op::LogicalAnd => {
+ inst.expect(5)?;
+ parse_expr_op!(crate::BinaryOperator::LogicalAnd, BINARY)?;
+ }
+ Op::SGreaterThan | Op::SGreaterThanEqual | Op::SLessThan | Op::SLessThanEqual => {
+ inst.expect(5)?;
+ self.parse_expr_int_comparison(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ map_binary_operator(inst.op)?,
+ crate::ScalarKind::Sint,
+ )?;
+ }
+ Op::UGreaterThan | Op::UGreaterThanEqual | Op::ULessThan | Op::ULessThanEqual => {
+ inst.expect(5)?;
+ self.parse_expr_int_comparison(
+ ctx,
+ &mut emitter,
+ &mut block,
+ block_id,
+ body_idx,
+ map_binary_operator(inst.op)?,
+ crate::ScalarKind::Uint,
+ )?;
+ }
+ Op::FOrdEqual
+ | Op::FUnordEqual
+ | Op::FOrdNotEqual
+ | Op::FUnordNotEqual
+ | Op::FOrdLessThan
+ | Op::FUnordLessThan
+ | Op::FOrdGreaterThan
+ | Op::FUnordGreaterThan
+ | Op::FOrdLessThanEqual
+ | Op::FUnordLessThanEqual
+ | Op::FOrdGreaterThanEqual
+ | Op::FUnordGreaterThanEqual
+ | Op::LogicalEqual
+ | Op::LogicalNotEqual => {
+ inst.expect(5)?;
+ let operator = map_binary_operator(inst.op)?;
+ parse_expr_op!(operator, BINARY)?;
+ }
+ Op::Any | Op::All | Op::IsNan | Op::IsInf | Op::IsFinite | Op::IsNormal => {
+ inst.expect(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let arg_id = self.next()?;
+
+ let arg_lexp = self.lookup_expression.lookup(arg_id)?;
+ let arg_handle = get_expr_handle!(arg_id, arg_lexp);
+
+ let expr = crate::Expression::Relational {
+ fun: map_relational_fun(inst.op)?,
+ argument: arg_handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: ctx.expressions.append(expr, span),
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::Kill => {
+ inst.expect(1)?;
+ break Some(crate::Statement::Kill);
+ }
+ Op::Unreachable => {
+ inst.expect(1)?;
+ break None;
+ }
+ Op::Return => {
+ inst.expect(1)?;
+ break Some(crate::Statement::Return { value: None });
+ }
+ Op::ReturnValue => {
+ inst.expect(2)?;
+ let value_id = self.next()?;
+ let value_lexp = self.lookup_expression.lookup(value_id)?;
+ let value_handle = get_expr_handle!(value_id, value_lexp);
+ break Some(crate::Statement::Return {
+ value: Some(value_handle),
+ });
+ }
+ Op::Branch => {
+ inst.expect(2)?;
+ let target_id = self.next()?;
+
+ // If this is a branch to a merge or continue block, then
+ // that ends the current body.
+ //
+ // Why can we count on finding an entry here when it's
+ // needed? SPIR-V requires dominators to appear before
+ // blocks they dominate, so we will have visited a
+ // structured control construct's header block before
+ // anything that could exit it.
+ if let Some(info) = ctx.mergers.get(&target_id) {
+ block.extend(emitter.finish(ctx.expressions));
+ ctx.blocks.insert(block_id, block);
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::BlockId(block_id));
+
+ merger(body, info);
+
+ return Ok(());
+ }
+
+ // If `target_id` has no entry in `ctx.body_for_label`, then
+ // this must be the only branch to it:
+ //
+ // - We've already established that it's not anybody's merge
+ // block.
+ //
+ // - It can't be a switch case. Only switch header blocks
+ // and other switch cases can branch to a switch case.
+ // Switch header blocks must dominate all their cases, so
+ // they must appear in the file before them, and when we
+ // see `Op::Switch` we populate `ctx.body_for_label` for
+ // every switch case.
+ //
+ // Thus, `target_id` must be a simple extension of the
+ // current block, which we dominate, so we know we'll
+ // encounter it later in the file.
+ ctx.body_for_label.entry(target_id).or_insert(body_idx);
+
+ break None;
+ }
+ Op::BranchConditional => {
+ inst.expect_at_least(4)?;
+
+ let condition = {
+ let condition_id = self.next()?;
+ let lexp = self.lookup_expression.lookup(condition_id)?;
+ get_expr_handle!(condition_id, lexp)
+ };
+
+ // HACK(eddyb) Naga doesn't seem to have this helper,
+ // so it's declared on the fly here for convenience.
+ #[derive(Copy, Clone)]
+ struct BranchTarget {
+ label_id: spirv::Word,
+ merge_info: Option<MergeBlockInformation>,
+ }
+ let branch_target = |label_id| BranchTarget {
+ label_id,
+ merge_info: ctx.mergers.get(&label_id).copied(),
+ };
+
+ let true_target = branch_target(self.next()?);
+ let false_target = branch_target(self.next()?);
+
+ // Consume branch weights
+ for _ in 4..inst.wc {
+ let _ = self.next()?;
+ }
+
+ // Handle `OpBranchConditional`s used at the end of a loop
+ // body's "continuing" section as a "conditional backedge",
+ // i.e. a `do`-`while` condition, or `break if` in WGSL.
+
+ // HACK(eddyb) this has to go to the parent *twice*, because
+ // `OpLoopMerge` left the "continuing" section nested in the
+ // loop body in terms of `parent`, but not `BodyFragment`.
+ let parent_body_idx = ctx.bodies[body_idx].parent;
+ let parent_parent_body_idx = ctx.bodies[parent_body_idx].parent;
+ match ctx.bodies[parent_parent_body_idx].data[..] {
+ // The `OpLoopMerge`'s `continuing` block and the loop's
+ // backedge block may not be the same, but they'll both
+ // belong to the same body.
+ [.., BodyFragment::Loop {
+ body: loop_body_idx,
+ continuing: loop_continuing_idx,
+ break_if: ref mut break_if_slot @ None,
+ }] if body_idx == loop_continuing_idx => {
+ // Try both orderings of break-vs-backedge, because
+ // SPIR-V is symmetrical here, unlike WGSL `break if`.
+ let break_if_cond = [true, false].into_iter().find_map(|true_breaks| {
+ let (break_candidate, backedge_candidate) = if true_breaks {
+ (true_target, false_target)
+ } else {
+ (false_target, true_target)
+ };
+
+ if break_candidate.merge_info
+ != Some(MergeBlockInformation::LoopMerge)
+ {
+ return None;
+ }
+
+ // HACK(eddyb) since Naga doesn't explicitly track
+ // backedges, this is checking for the outcome of
+ // `OpLoopMerge` below (even if it looks weird).
+ let backedge_candidate_is_backedge =
+ backedge_candidate.merge_info.is_none()
+ && ctx.body_for_label.get(&backedge_candidate.label_id)
+ == Some(&loop_body_idx);
+ if !backedge_candidate_is_backedge {
+ return None;
+ }
+
+ Some(if true_breaks {
+ condition
+ } else {
+ ctx.expressions.append(
+ crate::Expression::Unary {
+ op: crate::UnaryOperator::LogicalNot,
+ expr: condition,
+ },
+ span,
+ )
+ })
+ });
+
+ if let Some(break_if_cond) = break_if_cond {
+ *break_if_slot = Some(break_if_cond);
+
+ // This `OpBranchConditional` ends the "continuing"
+ // section of the loop body as normal, with the
+ // `break if` condition having been stashed above.
+ break None;
+ }
+ }
+ _ => {}
+ }
+
+ block.extend(emitter.finish(ctx.expressions));
+ ctx.blocks.insert(block_id, block);
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::BlockId(block_id));
+
+ let same_target = true_target.label_id == false_target.label_id;
+
+ // Start a body block for the `accept` branch.
+ let accept = ctx.bodies.len();
+ let mut accept_block = Body::with_parent(body_idx);
+
+ // If the `OpBranchConditional` target is somebody else's
+ // merge or continue block, then put a `Break` or `Continue`
+ // statement in this new body block.
+ if let Some(info) = true_target.merge_info {
+ merger(
+ match same_target {
+ true => &mut ctx.bodies[body_idx],
+ false => &mut accept_block,
+ },
+ &info,
+ )
+ } else {
+ // Note the body index for the block we're branching to.
+ let prev = ctx.body_for_label.insert(
+ true_target.label_id,
+ match same_target {
+ true => body_idx,
+ false => accept,
+ },
+ );
+ debug_assert!(prev.is_none());
+ }
+
+ if same_target {
+ return Ok(());
+ }
+
+ ctx.bodies.push(accept_block);
+
+ // Handle the `reject` branch just like the `accept` block.
+ let reject = ctx.bodies.len();
+ let mut reject_block = Body::with_parent(body_idx);
+
+ if let Some(info) = false_target.merge_info {
+ merger(&mut reject_block, &info)
+ } else {
+ let prev = ctx.body_for_label.insert(false_target.label_id, reject);
+ debug_assert!(prev.is_none());
+ }
+
+ ctx.bodies.push(reject_block);
+
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::If {
+ condition,
+ accept,
+ reject,
+ });
+
+ return Ok(());
+ }
+ Op::Switch => {
+ inst.expect_at_least(3)?;
+ let selector = self.next()?;
+ let default_id = self.next()?;
+
+ // If the previous instruction was a `OpSelectionMerge` then we must
+ // promote the `MergeBlockInformation` to a `SwitchMerge`
+ if let Some(merge) = selection_merge_block {
+ ctx.mergers
+ .insert(merge, MergeBlockInformation::SwitchMerge);
+ }
+
+ let default = ctx.bodies.len();
+ ctx.bodies.push(Body::with_parent(body_idx));
+ ctx.body_for_label.entry(default_id).or_insert(default);
+
+ let selector_lexp = &self.lookup_expression[&selector];
+ let selector_lty = self.lookup_type.lookup(selector_lexp.type_id)?;
+ let selector_handle = get_expr_handle!(selector, selector_lexp);
+ let selector = match ctx.type_arena[selector_lty.handle].inner {
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ }) => {
+ // IR expects a signed integer, so do a bitcast
+ ctx.expressions.append(
+ crate::Expression::As {
+ kind: crate::ScalarKind::Sint,
+ expr: selector_handle,
+ convert: None,
+ },
+ span,
+ )
+ }
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ }) => selector_handle,
+ ref other => unimplemented!("Unexpected selector {:?}", other),
+ };
+
+ // Clear past switch cases to prevent them from entering this one
+ self.switch_cases.clear();
+
+ for _ in 0..(inst.wc - 3) / 2 {
+ let literal = self.next()?;
+ let target = self.next()?;
+
+ let case_body_idx = ctx.bodies.len();
+
+ // Check if any previous case already used this target block id, if so
+ // group them together to reorder them later so that no weird
+ // fallthrough cases happen.
+ if let Some(&mut (_, ref mut literals)) = self.switch_cases.get_mut(&target)
+ {
+ literals.push(literal as i32);
+ continue;
+ }
+
+ let mut body = Body::with_parent(body_idx);
+
+ if let Some(info) = ctx.mergers.get(&target) {
+ merger(&mut body, info);
+ }
+
+ ctx.bodies.push(body);
+ ctx.body_for_label.entry(target).or_insert(case_body_idx);
+
+ // Register this target block id as already having been processed and
+ // the respective body index assigned and the first case value
+ self.switch_cases
+ .insert(target, (case_body_idx, vec![literal as i32]));
+ }
+
+ // Loop trough the collected target blocks creating a new case for each
+ // literal pointing to it, only one case will have the true body and all the
+ // others will be empty fallthrough so that they all execute the same body
+ // without duplicating code.
+ //
+ // Since `switch_cases` is an indexmap the order of insertion is preserved
+ // this is needed because spir-v defines fallthrough order in the switch
+ // instruction.
+ let mut cases = Vec::with_capacity((inst.wc as usize - 3) / 2);
+ for &(case_body_idx, ref literals) in self.switch_cases.values() {
+ let value = literals[0];
+
+ for &literal in literals.iter().skip(1) {
+ let empty_body_idx = ctx.bodies.len();
+ let body = Body::with_parent(body_idx);
+
+ ctx.bodies.push(body);
+
+ cases.push((literal, empty_body_idx));
+ }
+
+ cases.push((value, case_body_idx));
+ }
+
+ block.extend(emitter.finish(ctx.expressions));
+
+ let body = &mut ctx.bodies[body_idx];
+ ctx.blocks.insert(block_id, block);
+ // Make sure the vector has space for at least two more allocations
+ body.data.reserve(2);
+ body.data.push(BodyFragment::BlockId(block_id));
+ body.data.push(BodyFragment::Switch {
+ selector,
+ cases,
+ default,
+ });
+
+ return Ok(());
+ }
+ Op::SelectionMerge => {
+ inst.expect(3)?;
+ let merge_block_id = self.next()?;
+ // TODO: Selection Control Mask
+ let _selection_control = self.next()?;
+
+ // Indicate that the merge block is a continuation of the
+ // current `Body`.
+ ctx.body_for_label.entry(merge_block_id).or_insert(body_idx);
+
+ // Let subsequent branches to the merge block know that
+ // they've reached the end of the selection construct.
+ ctx.mergers
+ .insert(merge_block_id, MergeBlockInformation::SelectionMerge);
+
+ selection_merge_block = Some(merge_block_id);
+ }
+ Op::LoopMerge => {
+ inst.expect_at_least(4)?;
+ let merge_block_id = self.next()?;
+ let continuing = self.next()?;
+
+ // TODO: Loop Control Parameters
+ for _ in 0..inst.wc - 3 {
+ self.next()?;
+ }
+
+ // Indicate that the merge block is a continuation of the
+ // current `Body`.
+ ctx.body_for_label.entry(merge_block_id).or_insert(body_idx);
+ // Let subsequent branches to the merge block know that
+ // they're `Break` statements.
+ ctx.mergers
+ .insert(merge_block_id, MergeBlockInformation::LoopMerge);
+
+ let loop_body_idx = ctx.bodies.len();
+ ctx.bodies.push(Body::with_parent(body_idx));
+
+ let continue_idx = ctx.bodies.len();
+ // The continue block inherits the scope of the loop body
+ ctx.bodies.push(Body::with_parent(loop_body_idx));
+ ctx.body_for_label.entry(continuing).or_insert(continue_idx);
+ // Let subsequent branches to the continue block know that
+ // they're `Continue` statements.
+ ctx.mergers
+ .insert(continuing, MergeBlockInformation::LoopContinue);
+
+ // The loop header always belongs to the loop body
+ ctx.body_for_label.insert(block_id, loop_body_idx);
+
+ let parent_body = &mut ctx.bodies[body_idx];
+ parent_body.data.push(BodyFragment::Loop {
+ body: loop_body_idx,
+ continuing: continue_idx,
+ break_if: None,
+ });
+ body_idx = loop_body_idx;
+ }
+ Op::DPdxCoarse => {
+ parse_expr_op!(
+ crate::DerivativeAxis::X,
+ crate::DerivativeControl::Coarse,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdyCoarse => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Y,
+ crate::DerivativeControl::Coarse,
+ DERIVATIVE
+ )?;
+ }
+ Op::FwidthCoarse => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Width,
+ crate::DerivativeControl::Coarse,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdxFine => {
+ parse_expr_op!(
+ crate::DerivativeAxis::X,
+ crate::DerivativeControl::Fine,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdyFine => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Y,
+ crate::DerivativeControl::Fine,
+ DERIVATIVE
+ )?;
+ }
+ Op::FwidthFine => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Width,
+ crate::DerivativeControl::Fine,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdx => {
+ parse_expr_op!(
+ crate::DerivativeAxis::X,
+ crate::DerivativeControl::None,
+ DERIVATIVE
+ )?;
+ }
+ Op::DPdy => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Y,
+ crate::DerivativeControl::None,
+ DERIVATIVE
+ )?;
+ }
+ Op::Fwidth => {
+ parse_expr_op!(
+ crate::DerivativeAxis::Width,
+ crate::DerivativeControl::None,
+ DERIVATIVE
+ )?;
+ }
+ Op::ArrayLength => {
+ inst.expect(5)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let structure_id = self.next()?;
+ let member_index = self.next()?;
+
+ // We're assuming that the validation pass, if it's run, will catch if the
+ // wrong types or parameters are supplied here.
+
+ let structure_ptr = self.lookup_expression.lookup(structure_id)?;
+ let structure_handle = get_expr_handle!(structure_id, structure_ptr);
+
+ let member_ptr = ctx.expressions.append(
+ crate::Expression::AccessIndex {
+ base: structure_handle,
+ index: member_index,
+ },
+ span,
+ );
+
+ let length = ctx
+ .expressions
+ .append(crate::Expression::ArrayLength(member_ptr), span);
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: length,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ Op::CopyMemory => {
+ inst.expect_at_least(3)?;
+ let target_id = self.next()?;
+ let source_id = self.next()?;
+ let _memory_access = if inst.wc != 3 {
+ inst.expect(4)?;
+ spirv::MemoryAccess::from_bits(self.next()?)
+ .ok_or(Error::InvalidParameter(Op::CopyMemory))?
+ } else {
+ spirv::MemoryAccess::NONE
+ };
+
+ // TODO: check if the source and target types are the same?
+ let target = self.lookup_expression.lookup(target_id)?;
+ let target_handle = get_expr_handle!(target_id, target);
+ let source = self.lookup_expression.lookup(source_id)?;
+ let source_handle = get_expr_handle!(source_id, source);
+
+ // This operation is practically the same as loading and then storing, I think.
+ let value_expr = ctx.expressions.append(
+ crate::Expression::Load {
+ pointer: source_handle,
+ },
+ span,
+ );
+
+ block.extend(emitter.finish(ctx.expressions));
+ block.push(
+ crate::Statement::Store {
+ pointer: target_handle,
+ value: value_expr,
+ },
+ span,
+ );
+
+ emitter.start(ctx.expressions);
+ }
+ Op::ControlBarrier => {
+ inst.expect(4)?;
+ let exec_scope_id = self.next()?;
+ let _mem_scope_raw = self.next()?;
+ let semantics_id = self.next()?;
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let semantics_const = self.lookup_constant.lookup(semantics_id)?;
+
+ let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle)
+ .ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
+ let semantics = resolve_constant(ctx.gctx(), semantics_const.handle)
+ .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?;
+
+ if exec_scope == spirv::Scope::Workgroup as u32 {
+ let mut flags = crate::Barrier::empty();
+ flags.set(
+ crate::Barrier::STORAGE,
+ semantics & spirv::MemorySemantics::UNIFORM_MEMORY.bits() != 0,
+ );
+ flags.set(
+ crate::Barrier::WORK_GROUP,
+ semantics
+ & (spirv::MemorySemantics::SUBGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY)
+ .bits()
+ != 0,
+ );
+ block.push(crate::Statement::Barrier(flags), span);
+ } else {
+ log::warn!("Unsupported barrier execution scope: {}", exec_scope);
+ }
+ }
+ Op::CopyObject => {
+ inst.expect(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let operand_id = self.next()?;
+
+ let lookup = self.lookup_expression.lookup(operand_id)?;
+ let handle = get_expr_handle!(operand_id, lookup);
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+ }
+ _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
+ }
+ };
+
+ block.extend(emitter.finish(ctx.expressions));
+ if let Some(stmt) = terminator {
+ block.push(stmt, crate::Span::default());
+ }
+
+ // Save this block fragment in `block_ctx.blocks`, and mark it to be
+ // incorporated into the current body at `Statement` assembly time.
+ ctx.blocks.insert(block_id, block);
+ let body = &mut ctx.bodies[body_idx];
+ body.data.push(BodyFragment::BlockId(block_id));
+ Ok(())
+ }
+
+ fn make_expression_storage(
+ &mut self,
+ globals: &Arena<crate::GlobalVariable>,
+ constants: &Arena<crate::Constant>,
+ ) -> Arena<crate::Expression> {
+ let mut expressions = Arena::new();
+ #[allow(clippy::panic)]
+ {
+ assert!(self.lookup_expression.is_empty());
+ }
+ // register global variables
+ for (&id, var) in self.lookup_variable.iter() {
+ let span = globals.get_span(var.handle);
+ let handle = expressions.append(crate::Expression::GlobalVariable(var.handle), span);
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ type_id: var.type_id,
+ handle,
+ // Setting this to an invalid id will cause get_expr_handle
+ // to default to the main body making sure no load/stores
+ // are added.
+ block_id: 0,
+ },
+ );
+ }
+ // register constants
+ for (&id, con) in self.lookup_constant.iter() {
+ let span = constants.get_span(con.handle);
+ let handle = expressions.append(crate::Expression::Constant(con.handle), span);
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ type_id: con.type_id,
+ handle,
+ // Setting this to an invalid id will cause get_expr_handle
+ // to default to the main body making sure no load/stores
+ // are added.
+ block_id: 0,
+ },
+ );
+ }
+ // done
+ expressions
+ }
+
+ fn switch(&mut self, state: ModuleState, op: spirv::Op) -> Result<(), Error> {
+ if state < self.state {
+ Err(Error::UnsupportedInstruction(self.state, op))
+ } else {
+ self.state = state;
+ Ok(())
+ }
+ }
+
+ /// Walk the statement tree and patch it in the following cases:
+ /// 1. Function call targets are replaced by `deferred_function_calls` map
+ fn patch_statements(
+ &mut self,
+ statements: &mut crate::Block,
+ expressions: &mut Arena<crate::Expression>,
+ fun_parameter_sampling: &mut [image::SamplingFlags],
+ ) -> Result<(), Error> {
+ use crate::Statement as S;
+ let mut i = 0usize;
+ while i < statements.len() {
+ match statements[i] {
+ S::Emit(_) => {}
+ S::Block(ref mut block) => {
+ self.patch_statements(block, expressions, fun_parameter_sampling)?;
+ }
+ S::If {
+ condition: _,
+ ref mut accept,
+ ref mut reject,
+ } => {
+ self.patch_statements(reject, expressions, fun_parameter_sampling)?;
+ self.patch_statements(accept, expressions, fun_parameter_sampling)?;
+ }
+ S::Switch {
+ selector: _,
+ ref mut cases,
+ } => {
+ for case in cases.iter_mut() {
+ self.patch_statements(&mut case.body, expressions, fun_parameter_sampling)?;
+ }
+ }
+ S::Loop {
+ ref mut body,
+ ref mut continuing,
+ break_if: _,
+ } => {
+ self.patch_statements(body, expressions, fun_parameter_sampling)?;
+ self.patch_statements(continuing, expressions, fun_parameter_sampling)?;
+ }
+ S::Break
+ | S::Continue
+ | S::Return { .. }
+ | S::Kill
+ | S::Barrier(_)
+ | S::Store { .. }
+ | S::ImageStore { .. }
+ | S::Atomic { .. }
+ | S::RayQuery { .. } => {}
+ S::Call {
+ function: ref mut callee,
+ ref arguments,
+ ..
+ } => {
+ let fun_id = self.deferred_function_calls[callee.index()];
+ let fun_lookup = self.lookup_function.lookup(fun_id)?;
+ *callee = fun_lookup.handle;
+
+ // Patch sampling flags
+ for (arg_index, arg) in arguments.iter().enumerate() {
+ let flags = match fun_lookup.parameters_sampling.get(arg_index) {
+ Some(&flags) if !flags.is_empty() => flags,
+ _ => continue,
+ };
+
+ match expressions[*arg] {
+ crate::Expression::GlobalVariable(handle) => {
+ if let Some(sampling) = self.handle_sampling.get_mut(&handle) {
+ *sampling |= flags
+ }
+ }
+ crate::Expression::FunctionArgument(i) => {
+ fun_parameter_sampling[i as usize] |= flags;
+ }
+ ref other => return Err(Error::InvalidGlobalVar(other.clone())),
+ }
+ }
+ }
+ S::WorkGroupUniformLoad { .. } => unreachable!(),
+ }
+ i += 1;
+ }
+ Ok(())
+ }
+
+ fn patch_function(
+ &mut self,
+ handle: Option<Handle<crate::Function>>,
+ fun: &mut crate::Function,
+ ) -> Result<(), Error> {
+ // Note: this search is a bit unfortunate
+ let (fun_id, mut parameters_sampling) = match handle {
+ Some(h) => {
+ let (&fun_id, lookup) = self
+ .lookup_function
+ .iter_mut()
+ .find(|&(_, ref lookup)| lookup.handle == h)
+ .unwrap();
+ (fun_id, mem::take(&mut lookup.parameters_sampling))
+ }
+ None => (0, Vec::new()),
+ };
+
+ for (_, expr) in fun.expressions.iter_mut() {
+ if let crate::Expression::CallResult(ref mut function) = *expr {
+ let fun_id = self.deferred_function_calls[function.index()];
+ *function = self.lookup_function.lookup(fun_id)?.handle;
+ }
+ }
+
+ self.patch_statements(
+ &mut fun.body,
+ &mut fun.expressions,
+ &mut parameters_sampling,
+ )?;
+
+ if let Some(lookup) = self.lookup_function.get_mut(&fun_id) {
+ lookup.parameters_sampling = parameters_sampling;
+ }
+ Ok(())
+ }
+
+ pub fn parse(mut self) -> Result<crate::Module, Error> {
+ let mut module = {
+ if self.next()? != spirv::MAGIC_NUMBER {
+ return Err(Error::InvalidHeader);
+ }
+ let version_raw = self.next()?;
+ let generator = self.next()?;
+ let _bound = self.next()?;
+ let _schema = self.next()?;
+ log::info!("Generated by {} version {:x}", generator, version_raw);
+ crate::Module::default()
+ };
+
+ self.layouter.clear();
+ self.dummy_functions = Arena::new();
+ self.lookup_function.clear();
+ self.function_call_graph.clear();
+
+ loop {
+ use spirv::Op;
+
+ let inst = match self.next_inst() {
+ Ok(inst) => inst,
+ Err(Error::IncompleteData) => break,
+ Err(other) => return Err(other),
+ };
+ log::debug!("\t{:?} [{}]", inst.op, inst.wc);
+
+ match inst.op {
+ Op::Capability => self.parse_capability(inst),
+ Op::Extension => self.parse_extension(inst),
+ Op::ExtInstImport => self.parse_ext_inst_import(inst),
+ Op::MemoryModel => self.parse_memory_model(inst),
+ Op::EntryPoint => self.parse_entry_point(inst),
+ Op::ExecutionMode => self.parse_execution_mode(inst),
+ Op::String => self.parse_string(inst),
+ Op::Source => self.parse_source(inst),
+ Op::SourceExtension => self.parse_source_extension(inst),
+ Op::Name => self.parse_name(inst),
+ Op::MemberName => self.parse_member_name(inst),
+ Op::ModuleProcessed => self.parse_module_processed(inst),
+ Op::Decorate => self.parse_decorate(inst),
+ Op::MemberDecorate => self.parse_member_decorate(inst),
+ Op::TypeVoid => self.parse_type_void(inst),
+ Op::TypeBool => self.parse_type_bool(inst, &mut module),
+ Op::TypeInt => self.parse_type_int(inst, &mut module),
+ Op::TypeFloat => self.parse_type_float(inst, &mut module),
+ Op::TypeVector => self.parse_type_vector(inst, &mut module),
+ Op::TypeMatrix => self.parse_type_matrix(inst, &mut module),
+ Op::TypeFunction => self.parse_type_function(inst),
+ Op::TypePointer => self.parse_type_pointer(inst, &mut module),
+ Op::TypeArray => self.parse_type_array(inst, &mut module),
+ Op::TypeRuntimeArray => self.parse_type_runtime_array(inst, &mut module),
+ Op::TypeStruct => self.parse_type_struct(inst, &mut module),
+ Op::TypeImage => self.parse_type_image(inst, &mut module),
+ Op::TypeSampledImage => self.parse_type_sampled_image(inst),
+ Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
+ Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
+ Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
+ Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module),
+ Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),
+ Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module),
+ Op::Variable => self.parse_global_variable(inst, &mut module),
+ Op::Function => {
+ self.switch(ModuleState::Function, inst.op)?;
+ inst.expect(5)?;
+ self.parse_function(&mut module)
+ }
+ _ => Err(Error::UnsupportedInstruction(self.state, inst.op)), //TODO
+ }?;
+ }
+
+ log::info!("Patching...");
+ {
+ let mut nodes = petgraph::algo::toposort(&self.function_call_graph, None)
+ .map_err(|cycle| Error::FunctionCallCycle(cycle.node_id()))?;
+ nodes.reverse(); // we need dominated first
+ let mut functions = mem::take(&mut module.functions);
+ for fun_id in nodes {
+ if fun_id > !(functions.len() as u32) {
+ // skip all the fake IDs registered for the entry points
+ continue;
+ }
+ let lookup = self.lookup_function.get_mut(&fun_id).unwrap();
+ // take out the function from the old array
+ let fun = mem::take(&mut functions[lookup.handle]);
+ // add it to the newly formed arena, and adjust the lookup
+ lookup.handle = module
+ .functions
+ .append(fun, functions.get_span(lookup.handle));
+ }
+ }
+ // patch all the functions
+ for (handle, fun) in module.functions.iter_mut() {
+ self.patch_function(Some(handle), fun)?;
+ }
+ for ep in module.entry_points.iter_mut() {
+ self.patch_function(None, &mut ep.function)?;
+ }
+
+ // Check all the images and samplers to have consistent comparison property.
+ for (handle, flags) in self.handle_sampling.drain() {
+ if !image::patch_comparison_type(
+ flags,
+ module.global_variables.get_mut(handle),
+ &mut module.types,
+ ) {
+ return Err(Error::InconsistentComparisonSampling(handle));
+ }
+ }
+
+ if !self.future_decor.is_empty() {
+ log::warn!("Unused item decorations: {:?}", self.future_decor);
+ self.future_decor.clear();
+ }
+ if !self.future_member_decor.is_empty() {
+ log::warn!("Unused member decorations: {:?}", self.future_member_decor);
+ self.future_member_decor.clear();
+ }
+
+ Ok(module)
+ }
+
+ fn parse_capability(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Capability, inst.op)?;
+ inst.expect(2)?;
+ let capability = self.next()?;
+ let cap =
+ spirv::Capability::from_u32(capability).ok_or(Error::UnknownCapability(capability))?;
+ if !SUPPORTED_CAPABILITIES.contains(&cap) {
+ if self.options.strict_capabilities {
+ return Err(Error::UnsupportedCapability(cap));
+ } else {
+ log::warn!("Unknown capability {:?}", cap);
+ }
+ }
+ Ok(())
+ }
+
+ fn parse_extension(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Extension, inst.op)?;
+ inst.expect_at_least(2)?;
+ let (name, left) = self.next_string(inst.wc - 1)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ if !SUPPORTED_EXTENSIONS.contains(&name.as_str()) {
+ return Err(Error::UnsupportedExtension(name));
+ }
+ Ok(())
+ }
+
+ fn parse_ext_inst_import(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Extension, inst.op)?;
+ inst.expect_at_least(3)?;
+ let result_id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 2)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ if !SUPPORTED_EXT_SETS.contains(&name.as_str()) {
+ return Err(Error::UnsupportedExtSet(name));
+ }
+ self.ext_glsl_id = Some(result_id);
+ Ok(())
+ }
+
+ fn parse_memory_model(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::MemoryModel, inst.op)?;
+ inst.expect(3)?;
+ let _addressing_model = self.next()?;
+ let _memory_model = self.next()?;
+ Ok(())
+ }
+
+ fn parse_entry_point(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::EntryPoint, inst.op)?;
+ inst.expect_at_least(4)?;
+ let exec_model = self.next()?;
+ let exec_model = spirv::ExecutionModel::from_u32(exec_model)
+ .ok_or(Error::UnsupportedExecutionModel(exec_model))?;
+ let function_id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 3)?;
+ let ep = EntryPoint {
+ stage: match exec_model {
+ spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex,
+ spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment,
+ spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute,
+ _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)),
+ },
+ name,
+ early_depth_test: None,
+ workgroup_size: [0; 3],
+ variable_ids: self.data.by_ref().take(left as usize).collect(),
+ };
+ self.lookup_entry_point.insert(function_id, ep);
+ Ok(())
+ }
+
+ fn parse_execution_mode(&mut self, inst: Instruction) -> Result<(), Error> {
+ use spirv::ExecutionMode;
+
+ self.switch(ModuleState::ExecutionMode, inst.op)?;
+ inst.expect_at_least(3)?;
+
+ let ep_id = self.next()?;
+ let mode_id = self.next()?;
+ let args: Vec<spirv::Word> = self.data.by_ref().take(inst.wc as usize - 3).collect();
+
+ let ep = self
+ .lookup_entry_point
+ .get_mut(&ep_id)
+ .ok_or(Error::InvalidId(ep_id))?;
+ let mode = spirv::ExecutionMode::from_u32(mode_id)
+ .ok_or(Error::UnsupportedExecutionMode(mode_id))?;
+
+ match mode {
+ ExecutionMode::EarlyFragmentTests => {
+ if ep.early_depth_test.is_none() {
+ ep.early_depth_test = Some(crate::EarlyDepthTest { conservative: None });
+ }
+ }
+ ExecutionMode::DepthUnchanged => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::Unchanged),
+ });
+ }
+ ExecutionMode::DepthGreater => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::GreaterEqual),
+ });
+ }
+ ExecutionMode::DepthLess => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::LessEqual),
+ });
+ }
+ ExecutionMode::DepthReplacing => {
+ // Ignored because it can be deduced from the IR.
+ }
+ ExecutionMode::OriginUpperLeft => {
+ // Ignored because the other option (OriginLowerLeft) is not valid in Vulkan mode.
+ }
+ ExecutionMode::LocalSize => {
+ ep.workgroup_size = [args[0], args[1], args[2]];
+ }
+ _ => {
+ return Err(Error::UnsupportedExecutionMode(mode_id));
+ }
+ }
+
+ Ok(())
+ }
+
+ fn parse_string(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Source, inst.op)?;
+ inst.expect_at_least(3)?;
+ let _id = self.next()?;
+ let (_name, _) = self.next_string(inst.wc - 2)?;
+ Ok(())
+ }
+
+ fn parse_source(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Source, inst.op)?;
+ for _ in 1..inst.wc {
+ let _ = self.next()?;
+ }
+ Ok(())
+ }
+
+ fn parse_source_extension(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Source, inst.op)?;
+ inst.expect_at_least(2)?;
+ let (_name, _) = self.next_string(inst.wc - 1)?;
+ Ok(())
+ }
+
+ fn parse_name(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Name, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 2)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ self.future_decor.entry(id).or_default().name = Some(name);
+ Ok(())
+ }
+
+ fn parse_member_name(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Name, inst.op)?;
+ inst.expect_at_least(4)?;
+ let id = self.next()?;
+ let member = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 3)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+
+ self.future_member_decor
+ .entry((id, member))
+ .or_default()
+ .name = Some(name);
+ Ok(())
+ }
+
+ fn parse_module_processed(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Name, inst.op)?;
+ inst.expect_at_least(2)?;
+ let (_info, left) = self.next_string(inst.wc - 1)?;
+ //Note: string is ignored
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ Ok(())
+ }
+
+ fn parse_decorate(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Annotation, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let mut dec = self.future_decor.remove(&id).unwrap_or_default();
+ self.next_decoration(inst, 2, &mut dec)?;
+ self.future_decor.insert(id, dec);
+ Ok(())
+ }
+
+ fn parse_member_decorate(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Annotation, inst.op)?;
+ inst.expect_at_least(4)?;
+ let id = self.next()?;
+ let member = self.next()?;
+
+ let mut dec = self
+ .future_member_decor
+ .remove(&(id, member))
+ .unwrap_or_default();
+ self.next_decoration(inst, 3, &mut dec)?;
+ self.future_member_decor.insert((id, member), dec);
+ Ok(())
+ }
+
+ fn parse_type_void(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ self.lookup_void_type = Some(id);
+ Ok(())
+ }
+
+ fn parse_type_bool(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ let inner = crate::TypeInner::Scalar(crate::Scalar::BOOL);
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_int(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let width = self.next()?;
+ let sign = self.next()?;
+ let inner = crate::TypeInner::Scalar(crate::Scalar {
+ kind: match sign {
+ 0 => crate::ScalarKind::Uint,
+ 1 => crate::ScalarKind::Sint,
+ _ => return Err(Error::InvalidSign(sign)),
+ },
+ width: map_width(width)?,
+ });
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_float(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let width = self.next()?;
+ let inner = crate::TypeInner::Scalar(crate::Scalar::float(map_width(width)?));
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_vector(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let scalar = match module.types[type_lookup.handle].inner {
+ crate::TypeInner::Scalar(scalar) => scalar,
+ _ => return Err(Error::InvalidInnerType(type_id)),
+ };
+ let component_count = self.next()?;
+ let inner = crate::TypeInner::Vector {
+ size: map_vector_size(component_count)?,
+ scalar,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_matrix(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let vector_type_id = self.next()?;
+ let num_columns = self.next()?;
+ let decor = self.future_decor.remove(&id);
+
+ let vector_type_lookup = self.lookup_type.lookup(vector_type_id)?;
+ let inner = match module.types[vector_type_lookup.handle].inner {
+ crate::TypeInner::Vector { size, scalar } => crate::TypeInner::Matrix {
+ columns: map_vector_size(num_columns)?,
+ rows: size,
+ scalar,
+ },
+ _ => return Err(Error::InvalidInnerType(vector_type_id)),
+ };
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(vector_type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_function(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let return_type_id = self.next()?;
+ let parameter_type_ids = self.data.by_ref().take(inst.wc as usize - 3).collect();
+ self.lookup_function_type.insert(
+ id,
+ LookupFunctionType {
+ parameter_type_ids,
+ return_type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_pointer(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let storage_class = self.next()?;
+ let type_id = self.next()?;
+
+ let decor = self.future_decor.remove(&id);
+ let base_lookup_ty = self.lookup_type.lookup(type_id)?;
+ let base_inner = &module.types[base_lookup_ty.handle].inner;
+
+ let space = if let Some(space) = base_inner.pointer_space() {
+ space
+ } else if self
+ .lookup_storage_buffer_types
+ .contains_key(&base_lookup_ty.handle)
+ {
+ crate::AddressSpace::Storage {
+ access: crate::StorageAccess::default(),
+ }
+ } else {
+ match map_storage_class(storage_class)? {
+ ExtendedClass::Global(space) => space,
+ ExtendedClass::Input | ExtendedClass::Output => crate::AddressSpace::Private,
+ }
+ };
+
+ // We don't support pointers to runtime-sized arrays in the `Uniform`
+ // storage class with the `BufferBlock` decoration. Runtime-sized arrays
+ // should be in the StorageBuffer class.
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = *base_inner
+ {
+ match space {
+ crate::AddressSpace::Storage { .. } => {}
+ _ => {
+ return Err(Error::UnsupportedRuntimeArrayStorageClass);
+ }
+ }
+ }
+
+ // Don't bother with pointer stuff for `Handle` types.
+ let lookup_ty = if space == crate::AddressSpace::Handle {
+ base_lookup_ty.clone()
+ } else {
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.and_then(|dec| dec.name),
+ inner: crate::TypeInner::Pointer {
+ base: base_lookup_ty.handle,
+ space,
+ },
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ }
+ };
+ self.lookup_type.insert(id, lookup_ty);
+ Ok(())
+ }
+
+ fn parse_type_array(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+ let length_id = self.next()?;
+ let length_const = self.lookup_constant.lookup(length_id)?;
+
+ let size = resolve_constant(module.to_ctx(), length_const.handle)
+ .and_then(NonZeroU32::new)
+ .ok_or(Error::InvalidArraySize(length_const.handle))?;
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ let base = self.lookup_type.lookup(type_id)?.handle;
+
+ self.layouter.update(module.to_ctx()).unwrap();
+
+ // HACK if the underlying type is an image or a sampler, let's assume
+ // that we're dealing with a binding-array
+ //
+ // Note that it's not a strictly correct assumption, but rather a trade
+ // off caused by an impedance mismatch between SPIR-V's and Naga's type
+ // systems - Naga distinguishes between arrays and binding-arrays via
+ // types (i.e. both kinds of arrays are just different types), while
+ // SPIR-V distinguishes between them through usage - e.g. given:
+ //
+ // ```
+ // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f
+ // %uint_256 = OpConstant %uint 256
+ // %image_array = OpTypeArray %image %uint_256
+ // ```
+ //
+ // ```
+ // %image = OpTypeImage %float 2D 2 0 0 2 Rgba16f
+ // %uint_256 = OpConstant %uint 256
+ // %image_array = OpTypeArray %image %uint_256
+ // %image_array_ptr = OpTypePointer UniformConstant %image_array
+ // ```
+ //
+ // ... in the first case, `%image_array` should technically correspond
+ // to `TypeInner::Array`, while in the second case it should say
+ // `TypeInner::BindingArray` (kinda, depending on whether `%image_array`
+ // is ever used as a freestanding type or rather always through the
+ // pointer-indirection).
+ //
+ // Anyway, at the moment we don't support other kinds of image / sampler
+ // arrays than those binding-based, so this assumption is pretty safe
+ // for now.
+ let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } =
+ module.types[base].inner
+ {
+ crate::TypeInner::BindingArray {
+ base,
+ size: crate::ArraySize::Constant(size),
+ }
+ } else {
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ stride: match decor.array_stride {
+ Some(stride) => stride.get(),
+ None => self.layouter[base].to_stride(),
+ },
+ }
+ };
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_runtime_array(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ let base = self.lookup_type.lookup(type_id)?.handle;
+
+ self.layouter.update(module.to_ctx()).unwrap();
+
+ // HACK same case as in `parse_type_array()`
+ let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } =
+ module.types[base].inner
+ {
+ crate::TypeInner::BindingArray {
+ base: self.lookup_type.lookup(type_id)?.handle,
+ size: crate::ArraySize::Dynamic,
+ }
+ } else {
+ crate::TypeInner::Array {
+ base: self.lookup_type.lookup(type_id)?.handle,
+ size: crate::ArraySize::Dynamic,
+ stride: match decor.array_stride {
+ Some(stride) => stride.get(),
+ None => self.layouter[base].to_stride(),
+ },
+ }
+ };
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner,
+ },
+ self.span_from_with_op(start),
+ ),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_struct(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(2)?;
+ let id = self.next()?;
+ let parent_decor = self.future_decor.remove(&id);
+ let is_storage_buffer = parent_decor
+ .as_ref()
+ .map_or(false, |decor| decor.storage_buffer);
+
+ self.layouter.update(module.to_ctx()).unwrap();
+
+ let mut members = Vec::<crate::StructMember>::with_capacity(inst.wc as usize - 2);
+ let mut member_lookups = Vec::with_capacity(members.capacity());
+ let mut storage_access = crate::StorageAccess::empty();
+ let mut span = 0;
+ let mut alignment = Alignment::ONE;
+ for i in 0..u32::from(inst.wc) - 2 {
+ let type_id = self.next()?;
+ let ty = self.lookup_type.lookup(type_id)?.handle;
+ let decor = self
+ .future_member_decor
+ .remove(&(id, i))
+ .unwrap_or_default();
+
+ storage_access |= decor.flags.to_storage_access();
+
+ member_lookups.push(LookupMember {
+ type_id,
+ row_major: decor.matrix_major == Some(Majority::Row),
+ });
+
+ let member_alignment = self.layouter[ty].alignment;
+ span = member_alignment.round_up(span);
+ alignment = member_alignment.max(alignment);
+
+ let binding = decor.io_binding().ok();
+ if let Some(offset) = decor.offset {
+ span = offset;
+ }
+ let offset = span;
+
+ span += self.layouter[ty].size;
+
+ let inner = &module.types[ty].inner;
+ if let crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } = *inner
+ {
+ if let Some(stride) = decor.matrix_stride {
+ let expected_stride = Alignment::from(rows) * scalar.width as u32;
+ if stride.get() != expected_stride {
+ return Err(Error::UnsupportedMatrixStride {
+ stride: stride.get(),
+ columns: columns as u8,
+ rows: rows as u8,
+ width: scalar.width,
+ });
+ }
+ }
+ }
+
+ members.push(crate::StructMember {
+ name: decor.name,
+ ty,
+ binding,
+ offset,
+ });
+ }
+
+ span = alignment.round_up(span);
+
+ let inner = crate::TypeInner::Struct { span, members };
+
+ let ty_handle = module.types.insert(
+ crate::Type {
+ name: parent_decor.and_then(|dec| dec.name),
+ inner,
+ },
+ self.span_from_with_op(start),
+ );
+
+ if is_storage_buffer {
+ self.lookup_storage_buffer_types
+ .insert(ty_handle, storage_access);
+ }
+ for (i, member_lookup) in member_lookups.into_iter().enumerate() {
+ self.lookup_member
+ .insert((ty_handle, i as u32), member_lookup);
+ }
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: ty_handle,
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_image(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(9)?;
+
+ let id = self.next()?;
+ let sample_type_id = self.next()?;
+ let dim = self.next()?;
+ let is_depth = self.next()?;
+ let is_array = self.next()? != 0;
+ let is_msaa = self.next()? != 0;
+ let _is_sampled = self.next()?;
+ let format = self.next()?;
+
+ let dim = map_image_dim(dim)?;
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ // ensure there is a type for texture coordinate without extra components
+ module.types.insert(
+ crate::Type {
+ name: None,
+ inner: {
+ let scalar = crate::Scalar::F32;
+ match dim.required_coordinate_size() {
+ None => crate::TypeInner::Scalar(scalar),
+ Some(size) => crate::TypeInner::Vector { size, scalar },
+ }
+ },
+ },
+ Default::default(),
+ );
+
+ let base_handle = self.lookup_type.lookup(sample_type_id)?.handle;
+ let kind = module.types[base_handle]
+ .inner
+ .scalar_kind()
+ .ok_or(Error::InvalidImageBaseType(base_handle))?;
+
+ let inner = crate::TypeInner::Image {
+ class: if is_depth == 1 {
+ crate::ImageClass::Depth { multi: is_msaa }
+ } else if format != 0 {
+ crate::ImageClass::Storage {
+ format: map_image_format(format)?,
+ access: crate::StorageAccess::default(),
+ }
+ } else {
+ crate::ImageClass::Sampled {
+ kind,
+ multi: is_msaa,
+ }
+ },
+ dim,
+ arrayed: is_array,
+ };
+
+ let handle = module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner,
+ },
+ self.span_from_with_op(start),
+ );
+
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle,
+ base_id: Some(sample_type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_sampled_image(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let image_id = self.next()?;
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: self.lookup_type.lookup(image_id)?.handle,
+ base_id: Some(image_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_sampler(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ let handle = module.types.insert(
+ crate::Type {
+ name: decor.name,
+ inner: crate::TypeInner::Sampler { comparison: false },
+ },
+ self.span_from_with_op(start),
+ );
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle,
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_constant(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(4)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+
+ let literal = match module.types[ty].inner {
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width,
+ }) => {
+ let low = self.next()?;
+ match width {
+ 4 => crate::Literal::U32(low),
+ _ => return Err(Error::InvalidTypeWidth(width as u32)),
+ }
+ }
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width,
+ }) => {
+ let low = self.next()?;
+ match width {
+ 4 => crate::Literal::I32(low as i32),
+ 8 => {
+ inst.expect(5)?;
+ let high = self.next()?;
+ crate::Literal::I64((u64::from(high) << 32 | u64::from(low)) as i64)
+ }
+ _ => return Err(Error::InvalidTypeWidth(width as u32)),
+ }
+ }
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Float,
+ width,
+ }) => {
+ let low = self.next()?;
+ match width {
+ 4 => crate::Literal::F32(f32::from_bits(low)),
+ 8 => {
+ inst.expect(5)?;
+ let high = self.next()?;
+ crate::Literal::F64(f64::from_bits(
+ (u64::from(high) << 32) | u64::from(low),
+ ))
+ }
+ _ => return Err(Error::InvalidTypeWidth(width as u32)),
+ }
+ }
+ _ => return Err(Error::UnsupportedType(type_lookup.handle)),
+ };
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ let span = self.span_from_with_op(start);
+
+ let init = module
+ .const_expressions
+ .append(crate::Expression::Literal(literal), span);
+ self.lookup_constant.insert(
+ id,
+ LookupConstant {
+ handle: module.constants.append(
+ crate::Constant {
+ r#override: decor.specialization(),
+ name: decor.name,
+ ty,
+ init,
+ },
+ span,
+ ),
+ type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_composite_constant(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(3)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+
+ let mut components = Vec::with_capacity(inst.wc as usize - 3);
+ for _ in 0..components.capacity() {
+ let start = self.data_offset;
+ let component_id = self.next()?;
+ let span = self.span_from_with_op(start);
+ let constant = self.lookup_constant.lookup(component_id)?;
+ let expr = module
+ .const_expressions
+ .append(crate::Expression::Constant(constant.handle), span);
+ components.push(expr);
+ }
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ let span = self.span_from_with_op(start);
+
+ let init = module
+ .const_expressions
+ .append(crate::Expression::Compose { ty, components }, span);
+ self.lookup_constant.insert(
+ id,
+ LookupConstant {
+ handle: module.constants.append(
+ crate::Constant {
+ r#override: decor.specialization(),
+ name: decor.name,
+ ty,
+ init,
+ },
+ span,
+ ),
+ type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_null_constant(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ let init = module
+ .const_expressions
+ .append(crate::Expression::ZeroValue(ty), span);
+ let handle = module.constants.append(
+ crate::Constant {
+ r#override: decor.specialization(),
+ name: decor.name,
+ ty,
+ init,
+ },
+ span,
+ );
+ self.lookup_constant
+ .insert(id, LookupConstant { handle, type_id });
+ Ok(())
+ }
+
+ fn parse_bool_constant(
+ &mut self,
+ inst: Instruction,
+ value: bool,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let span = self.span_from_with_op(start);
+
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ let init = module.const_expressions.append(
+ crate::Expression::Literal(crate::Literal::Bool(value)),
+ span,
+ );
+ self.lookup_constant.insert(
+ id,
+ LookupConstant {
+ handle: module.constants.append(
+ crate::Constant {
+ r#override: decor.specialization(),
+ name: decor.name,
+ ty,
+ init,
+ },
+ span,
+ ),
+ type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_global_variable(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ let start = self.data_offset;
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(4)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let storage_class = self.next()?;
+ let init = if inst.wc > 4 {
+ inst.expect(5)?;
+ let start = self.data_offset;
+ let init_id = self.next()?;
+ let span = self.span_from_with_op(start);
+ let lconst = self.lookup_constant.lookup(init_id)?;
+ let expr = module
+ .const_expressions
+ .append(crate::Expression::Constant(lconst.handle), span);
+ Some(expr)
+ } else {
+ None
+ };
+ let span = self.span_from_with_op(start);
+ let mut dec = self.future_decor.remove(&id).unwrap_or_default();
+
+ let original_ty = self.lookup_type.lookup(type_id)?.handle;
+ let mut ty = original_ty;
+
+ if let crate::TypeInner::Pointer { base, space: _ } = module.types[original_ty].inner {
+ ty = base;
+ }
+
+ if let crate::TypeInner::BindingArray { .. } = module.types[original_ty].inner {
+ // Inside `parse_type_array()` we guess that an array of images or
+ // samplers must be a binding array, and here we validate that guess
+ if dec.desc_set.is_none() || dec.desc_index.is_none() {
+ return Err(Error::NonBindingArrayOfImageOrSamplers);
+ }
+ }
+
+ if let crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: crate::ImageClass::Storage { format, access: _ },
+ } = module.types[ty].inner
+ {
+ // Storage image types in IR have to contain the access, but not in the SPIR-V.
+ // The same image type in SPIR-V can be used (and has to be used) for multiple images.
+ // So we copy the type out and apply the variable access decorations.
+ let access = dec.flags.to_storage_access();
+
+ ty = module.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: crate::ImageClass::Storage { format, access },
+ },
+ },
+ Default::default(),
+ );
+ }
+
+ let ext_class = match self.lookup_storage_buffer_types.get(&ty) {
+ Some(&access) => ExtendedClass::Global(crate::AddressSpace::Storage { access }),
+ None => map_storage_class(storage_class)?,
+ };
+
+ // Fix empty name for gl_PerVertex struct generated by glslang
+ if let crate::TypeInner::Pointer { .. } = module.types[original_ty].inner {
+ if ext_class == ExtendedClass::Input || ext_class == ExtendedClass::Output {
+ if let Some(ref dec_name) = dec.name {
+ if dec_name.is_empty() {
+ dec.name = Some("perVertexStruct".to_string())
+ }
+ }
+ }
+ }
+
+ let (inner, var) = match ext_class {
+ ExtendedClass::Global(mut space) => {
+ if let crate::AddressSpace::Storage { ref mut access } = space {
+ *access &= dec.flags.to_storage_access();
+ }
+ let var = crate::GlobalVariable {
+ binding: dec.resource_binding(),
+ name: dec.name,
+ space,
+ ty,
+ init,
+ };
+ (Variable::Global, var)
+ }
+ ExtendedClass::Input => {
+ let binding = dec.io_binding()?;
+ let mut unsigned_ty = ty;
+ if let crate::Binding::BuiltIn(built_in) = binding {
+ let needs_inner_uint = match built_in {
+ crate::BuiltIn::BaseInstance
+ | crate::BuiltIn::BaseVertex
+ | crate::BuiltIn::InstanceIndex
+ | crate::BuiltIn::SampleIndex
+ | crate::BuiltIn::VertexIndex
+ | crate::BuiltIn::PrimitiveIndex
+ | crate::BuiltIn::LocalInvocationIndex => {
+ Some(crate::TypeInner::Scalar(crate::Scalar::U32))
+ }
+ crate::BuiltIn::GlobalInvocationId
+ | crate::BuiltIn::LocalInvocationId
+ | crate::BuiltIn::WorkGroupId
+ | crate::BuiltIn::WorkGroupSize => Some(crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: crate::Scalar::U32,
+ }),
+ _ => None,
+ };
+ if let (Some(inner), Some(crate::ScalarKind::Sint)) =
+ (needs_inner_uint, module.types[ty].inner.scalar_kind())
+ {
+ unsigned_ty = module
+ .types
+ .insert(crate::Type { name: None, inner }, Default::default());
+ }
+ }
+
+ let var = crate::GlobalVariable {
+ name: dec.name.clone(),
+ space: crate::AddressSpace::Private,
+ binding: None,
+ ty,
+ init: None,
+ };
+
+ let inner = Variable::Input(crate::FunctionArgument {
+ name: dec.name,
+ ty: unsigned_ty,
+ binding: Some(binding),
+ });
+ (inner, var)
+ }
+ ExtendedClass::Output => {
+ // For output interface blocks, this would be a structure.
+ let binding = dec.io_binding().ok();
+ let init = match binding {
+ Some(crate::Binding::BuiltIn(built_in)) => {
+ match null::generate_default_built_in(
+ Some(built_in),
+ ty,
+ &mut module.const_expressions,
+ span,
+ ) {
+ Ok(handle) => Some(handle),
+ Err(e) => {
+ log::warn!("Failed to initialize output built-in: {}", e);
+ None
+ }
+ }
+ }
+ Some(crate::Binding::Location { .. }) => None,
+ None => match module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let mut components = Vec::with_capacity(members.len());
+ for member in members.iter() {
+ let built_in = match member.binding {
+ Some(crate::Binding::BuiltIn(built_in)) => Some(built_in),
+ _ => None,
+ };
+ let handle = null::generate_default_built_in(
+ built_in,
+ member.ty,
+ &mut module.const_expressions,
+ span,
+ )?;
+ components.push(handle);
+ }
+ Some(
+ module
+ .const_expressions
+ .append(crate::Expression::Compose { ty, components }, span),
+ )
+ }
+ _ => None,
+ },
+ };
+
+ let var = crate::GlobalVariable {
+ name: dec.name,
+ space: crate::AddressSpace::Private,
+ binding: None,
+ ty,
+ init,
+ };
+ let inner = Variable::Output(crate::FunctionResult { ty, binding });
+ (inner, var)
+ }
+ };
+
+ let handle = module.global_variables.append(var, span);
+
+ if module.types[ty].inner.can_comparison_sample(module) {
+ log::debug!("\t\ttracking {:?} for sampling properties", handle);
+
+ self.handle_sampling
+ .insert(handle, image::SamplingFlags::empty());
+ }
+
+ self.lookup_variable.insert(
+ id,
+ LookupVariable {
+ inner,
+ handle,
+ type_id,
+ },
+ );
+ Ok(())
+ }
+}
+
+fn make_index_literal(
+ ctx: &mut BlockContext,
+ index: u32,
+ block: &mut crate::Block,
+ emitter: &mut crate::proc::Emitter,
+ index_type: Handle<crate::Type>,
+ index_type_id: spirv::Word,
+ span: crate::Span,
+) -> Result<Handle<crate::Expression>, Error> {
+ block.extend(emitter.finish(ctx.expressions));
+
+ let literal = match ctx.type_arena[index_type].inner.scalar_kind() {
+ Some(crate::ScalarKind::Uint) => crate::Literal::U32(index),
+ Some(crate::ScalarKind::Sint) => crate::Literal::I32(index as i32),
+ _ => return Err(Error::InvalidIndexType(index_type_id)),
+ };
+ let expr = ctx
+ .expressions
+ .append(crate::Expression::Literal(literal), span);
+
+ emitter.start(ctx.expressions);
+ Ok(expr)
+}
+
+fn resolve_constant(
+ gctx: crate::proc::GlobalCtx,
+ constant: Handle<crate::Constant>,
+) -> Option<u32> {
+ match gctx.const_expressions[gctx.constants[constant].init] {
+ crate::Expression::Literal(crate::Literal::U32(id)) => Some(id),
+ crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32),
+ _ => None,
+ }
+}
+
+pub fn parse_u8_slice(data: &[u8], options: &Options) -> Result<crate::Module, Error> {
+ if data.len() % 4 != 0 {
+ return Err(Error::IncompleteData);
+ }
+
+ let words = data
+ .chunks(4)
+ .map(|c| u32::from_le_bytes(c.try_into().unwrap()));
+ Frontend::new(words, options).parse()
+}
+
+#[cfg(test)]
+mod test {
+ #[test]
+ fn parse() {
+ let bin = vec![
+ // Magic number. Version number: 1.0.
+ 0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00,
+ // Generator number: 0. Bound: 0.
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Reserved word: 0.
+ 0x00, 0x00, 0x00, 0x00, // OpMemoryModel. Logical.
+ 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // GLSL450.
+ 0x01, 0x00, 0x00, 0x00,
+ ];
+ let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap();
+ }
+}
+
+/// Helper function to check if `child` is in the scope of `parent`
+fn is_parent(mut child: usize, parent: usize, block_ctx: &BlockContext) -> bool {
+ loop {
+ if child == parent {
+ // The child is in the scope parent
+ break true;
+ } else if child == 0 {
+ // Searched finished at the root the child isn't in the parent's body
+ break false;
+ }
+
+ child = block_ctx.bodies[child].parent;
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/null.rs b/third_party/rust/naga/src/front/spv/null.rs
new file mode 100644
index 0000000000..42cccca80a
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/null.rs
@@ -0,0 +1,31 @@
+use super::Error;
+use crate::arena::{Arena, Handle};
+
+/// Create a default value for an output built-in.
+pub fn generate_default_built_in(
+ built_in: Option<crate::BuiltIn>,
+ ty: Handle<crate::Type>,
+ const_expressions: &mut Arena<crate::Expression>,
+ span: crate::Span,
+) -> Result<Handle<crate::Expression>, Error> {
+ let expr = match built_in {
+ Some(crate::BuiltIn::Position { .. }) => {
+ let zero = const_expressions
+ .append(crate::Expression::Literal(crate::Literal::F32(0.0)), span);
+ let one = const_expressions
+ .append(crate::Expression::Literal(crate::Literal::F32(1.0)), span);
+ crate::Expression::Compose {
+ ty,
+ components: vec![zero, zero, zero, one],
+ }
+ }
+ Some(crate::BuiltIn::PointSize) => crate::Expression::Literal(crate::Literal::F32(1.0)),
+ Some(crate::BuiltIn::FragDepth) => crate::Expression::Literal(crate::Literal::F32(0.0)),
+ Some(crate::BuiltIn::SampleMask) => {
+ crate::Expression::Literal(crate::Literal::U32(u32::MAX))
+ }
+ // Note: `crate::BuiltIn::ClipDistance` is intentionally left for the default path
+ _ => crate::Expression::ZeroValue(ty),
+ };
+ Ok(const_expressions.append(expr, span))
+}
diff --git a/third_party/rust/naga/src/front/type_gen.rs b/third_party/rust/naga/src/front/type_gen.rs
new file mode 100644
index 0000000000..34730c1db5
--- /dev/null
+++ b/third_party/rust/naga/src/front/type_gen.rs
@@ -0,0 +1,437 @@
+/*!
+Type generators.
+*/
+
+use crate::{arena::Handle, span::Span};
+
+impl crate::Module {
+ /// Populate this module's [`SpecialTypes::ray_desc`] type.
+ ///
+ /// [`SpecialTypes::ray_desc`] is the type of the [`descriptor`] operand of
+ /// an [`Initialize`] [`RayQuery`] statement. In WGSL, it is a struct type
+ /// referred to as `RayDesc`.
+ ///
+ /// Backends consume values of this type to drive platform APIs, so if you
+ /// change any its fields, you must update the backends to match. Look for
+ /// backend code dealing with [`RayQueryFunction::Initialize`].
+ ///
+ /// [`SpecialTypes::ray_desc`]: crate::SpecialTypes::ray_desc
+ /// [`descriptor`]: crate::RayQueryFunction::Initialize::descriptor
+ /// [`Initialize`]: crate::RayQueryFunction::Initialize
+ /// [`RayQuery`]: crate::Statement::RayQuery
+ /// [`RayQueryFunction::Initialize`]: crate::RayQueryFunction::Initialize
+ pub fn generate_ray_desc_type(&mut self) -> Handle<crate::Type> {
+ if let Some(handle) = self.special_types.ray_desc {
+ return handle;
+ }
+
+ let ty_flag = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar(crate::Scalar::U32),
+ },
+ Span::UNDEFINED,
+ );
+ let ty_scalar = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar(crate::Scalar::F32),
+ },
+ Span::UNDEFINED,
+ );
+ let ty_vector = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: crate::Scalar::F32,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ let handle = self.types.insert(
+ crate::Type {
+ name: Some("RayDesc".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![
+ crate::StructMember {
+ name: Some("flags".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 0,
+ },
+ crate::StructMember {
+ name: Some("cull_mask".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 4,
+ },
+ crate::StructMember {
+ name: Some("tmin".to_string()),
+ ty: ty_scalar,
+ binding: None,
+ offset: 8,
+ },
+ crate::StructMember {
+ name: Some("tmax".to_string()),
+ ty: ty_scalar,
+ binding: None,
+ offset: 12,
+ },
+ crate::StructMember {
+ name: Some("origin".to_string()),
+ ty: ty_vector,
+ binding: None,
+ offset: 16,
+ },
+ crate::StructMember {
+ name: Some("dir".to_string()),
+ ty: ty_vector,
+ binding: None,
+ offset: 32,
+ },
+ ],
+ span: 48,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ self.special_types.ray_desc = Some(handle);
+ handle
+ }
+
+ /// Populate this module's [`SpecialTypes::ray_intersection`] type.
+ ///
+ /// [`SpecialTypes::ray_intersection`] is the type of a
+ /// `RayQueryGetIntersection` expression. In WGSL, it is a struct type
+ /// referred to as `RayIntersection`.
+ ///
+ /// Backends construct values of this type based on platform APIs, so if you
+ /// change any its fields, you must update the backends to match. Look for
+ /// the backend's handling for [`Expression::RayQueryGetIntersection`].
+ ///
+ /// [`SpecialTypes::ray_intersection`]: crate::SpecialTypes::ray_intersection
+ /// [`Expression::RayQueryGetIntersection`]: crate::Expression::RayQueryGetIntersection
+ pub fn generate_ray_intersection_type(&mut self) -> Handle<crate::Type> {
+ if let Some(handle) = self.special_types.ray_intersection {
+ return handle;
+ }
+
+ let ty_flag = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar(crate::Scalar::U32),
+ },
+ Span::UNDEFINED,
+ );
+ let ty_scalar = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar(crate::Scalar::F32),
+ },
+ Span::UNDEFINED,
+ );
+ let ty_barycentrics = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: crate::Scalar::F32,
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let ty_bool = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar(crate::Scalar::BOOL),
+ },
+ Span::UNDEFINED,
+ );
+ let ty_transform = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ scalar: crate::Scalar::F32,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ let handle = self.types.insert(
+ crate::Type {
+ name: Some("RayIntersection".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![
+ crate::StructMember {
+ name: Some("kind".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 0,
+ },
+ crate::StructMember {
+ name: Some("t".to_string()),
+ ty: ty_scalar,
+ binding: None,
+ offset: 4,
+ },
+ crate::StructMember {
+ name: Some("instance_custom_index".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 8,
+ },
+ crate::StructMember {
+ name: Some("instance_id".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 12,
+ },
+ crate::StructMember {
+ name: Some("sbt_record_offset".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 16,
+ },
+ crate::StructMember {
+ name: Some("geometry_index".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 20,
+ },
+ crate::StructMember {
+ name: Some("primitive_index".to_string()),
+ ty: ty_flag,
+ binding: None,
+ offset: 24,
+ },
+ crate::StructMember {
+ name: Some("barycentrics".to_string()),
+ ty: ty_barycentrics,
+ binding: None,
+ offset: 28,
+ },
+ crate::StructMember {
+ name: Some("front_face".to_string()),
+ ty: ty_bool,
+ binding: None,
+ offset: 36,
+ },
+ crate::StructMember {
+ name: Some("object_to_world".to_string()),
+ ty: ty_transform,
+ binding: None,
+ offset: 48,
+ },
+ crate::StructMember {
+ name: Some("world_to_object".to_string()),
+ ty: ty_transform,
+ binding: None,
+ offset: 112,
+ },
+ ],
+ span: 176,
+ },
+ },
+ Span::UNDEFINED,
+ );
+
+ self.special_types.ray_intersection = Some(handle);
+ handle
+ }
+
+ /// Populate this module's [`SpecialTypes::predeclared_types`] type and return the handle.
+ ///
+ /// [`SpecialTypes::predeclared_types`]: crate::SpecialTypes::predeclared_types
+ pub fn generate_predeclared_type(
+ &mut self,
+ special_type: crate::PredeclaredType,
+ ) -> Handle<crate::Type> {
+ use std::fmt::Write;
+
+ if let Some(value) = self.special_types.predeclared_types.get(&special_type) {
+ return *value;
+ }
+
+ let ty = match special_type {
+ crate::PredeclaredType::AtomicCompareExchangeWeakResult(scalar) => {
+ let bool_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar(crate::Scalar::BOOL),
+ },
+ Span::UNDEFINED,
+ );
+ let scalar_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar(scalar),
+ },
+ Span::UNDEFINED,
+ );
+
+ crate::Type {
+ name: Some(format!(
+ "__atomic_compare_exchange_result<{:?},{}>",
+ scalar.kind, scalar.width,
+ )),
+ inner: crate::TypeInner::Struct {
+ members: vec![
+ crate::StructMember {
+ name: Some("old_value".to_string()),
+ ty: scalar_ty,
+ binding: None,
+ offset: 0,
+ },
+ crate::StructMember {
+ name: Some("exchanged".to_string()),
+ ty: bool_ty,
+ binding: None,
+ offset: 4,
+ },
+ ],
+ span: 8,
+ },
+ }
+ }
+ crate::PredeclaredType::ModfResult { size, width } => {
+ let float_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar(crate::Scalar::float(width)),
+ },
+ Span::UNDEFINED,
+ );
+
+ let (member_ty, second_offset) = if let Some(size) = size {
+ let vec_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size,
+ scalar: crate::Scalar::float(width),
+ },
+ },
+ Span::UNDEFINED,
+ );
+ (vec_ty, size as u32 * width as u32)
+ } else {
+ (float_ty, width as u32)
+ };
+
+ let mut type_name = "__modf_result_".to_string();
+ if let Some(size) = size {
+ let _ = write!(type_name, "vec{}_", size as u8);
+ }
+ let _ = write!(type_name, "f{}", width * 8);
+
+ crate::Type {
+ name: Some(type_name),
+ inner: crate::TypeInner::Struct {
+ members: vec![
+ crate::StructMember {
+ name: Some("fract".to_string()),
+ ty: member_ty,
+ binding: None,
+ offset: 0,
+ },
+ crate::StructMember {
+ name: Some("whole".to_string()),
+ ty: member_ty,
+ binding: None,
+ offset: second_offset,
+ },
+ ],
+ span: second_offset * 2,
+ },
+ }
+ }
+ crate::PredeclaredType::FrexpResult { size, width } => {
+ let float_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar(crate::Scalar::float(width)),
+ },
+ Span::UNDEFINED,
+ );
+
+ let int_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width,
+ }),
+ },
+ Span::UNDEFINED,
+ );
+
+ let (fract_member_ty, exp_member_ty, second_offset) = if let Some(size) = size {
+ let vec_float_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size,
+ scalar: crate::Scalar::float(width),
+ },
+ },
+ Span::UNDEFINED,
+ );
+ let vec_int_ty = self.types.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size,
+ scalar: crate::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width,
+ },
+ },
+ },
+ Span::UNDEFINED,
+ );
+ (vec_float_ty, vec_int_ty, size as u32 * width as u32)
+ } else {
+ (float_ty, int_ty, width as u32)
+ };
+
+ let mut type_name = "__frexp_result_".to_string();
+ if let Some(size) = size {
+ let _ = write!(type_name, "vec{}_", size as u8);
+ }
+ let _ = write!(type_name, "f{}", width * 8);
+
+ crate::Type {
+ name: Some(type_name),
+ inner: crate::TypeInner::Struct {
+ members: vec![
+ crate::StructMember {
+ name: Some("fract".to_string()),
+ ty: fract_member_ty,
+ binding: None,
+ offset: 0,
+ },
+ crate::StructMember {
+ name: Some("exp".to_string()),
+ ty: exp_member_ty,
+ binding: None,
+ offset: second_offset,
+ },
+ ],
+ span: second_offset * 2,
+ },
+ }
+ }
+ };
+
+ let handle = self.types.insert(ty, Span::UNDEFINED);
+ self.special_types
+ .predeclared_types
+ .insert(special_type, handle);
+ handle
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs
new file mode 100644
index 0000000000..07e68f8dd9
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/error.rs
@@ -0,0 +1,775 @@
+use crate::front::wgsl::parse::lexer::Token;
+use crate::front::wgsl::Scalar;
+use crate::proc::{Alignment, ConstantEvaluatorError, ResolveError};
+use crate::{SourceLocation, Span};
+use codespan_reporting::diagnostic::{Diagnostic, Label};
+use codespan_reporting::files::SimpleFile;
+use codespan_reporting::term;
+use std::borrow::Cow;
+use std::ops::Range;
+use termcolor::{ColorChoice, NoColor, StandardStream};
+use thiserror::Error;
+
+#[derive(Clone, Debug)]
+pub struct ParseError {
+ message: String,
+ labels: Vec<(Span, Cow<'static, str>)>,
+ notes: Vec<String>,
+}
+
+impl ParseError {
+ pub fn labels(&self) -> impl ExactSizeIterator<Item = (Span, &str)> + '_ {
+ self.labels
+ .iter()
+ .map(|&(span, ref msg)| (span, msg.as_ref()))
+ }
+
+ pub fn message(&self) -> &str {
+ &self.message
+ }
+
+ fn diagnostic(&self) -> Diagnostic<()> {
+ let diagnostic = Diagnostic::error()
+ .with_message(self.message.to_string())
+ .with_labels(
+ self.labels
+ .iter()
+ .filter_map(|label| label.0.to_range().map(|range| (label, range)))
+ .map(|(label, range)| {
+ Label::primary((), range).with_message(label.1.to_string())
+ })
+ .collect(),
+ )
+ .with_notes(
+ self.notes
+ .iter()
+ .map(|note| format!("note: {note}"))
+ .collect(),
+ );
+ diagnostic
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ pub fn emit_to_stderr(&self, source: &str) {
+ self.emit_to_stderr_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ pub fn emit_to_stderr_with_path<P>(&self, source: &str, path: P)
+ where
+ P: AsRef<std::path::Path>,
+ {
+ let path = path.as_ref().display().to_string();
+ let files = SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+ let writer = StandardStream::stderr(ColorChoice::Auto);
+ term::emit(&mut writer.lock(), &config, &files, &self.diagnostic())
+ .expect("cannot write error");
+ }
+
+ /// Emits a summary of the error to a string.
+ pub fn emit_to_string(&self, source: &str) -> String {
+ self.emit_to_string_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to a string.
+ pub fn emit_to_string_with_path<P>(&self, source: &str, path: P) -> String
+ where
+ P: AsRef<std::path::Path>,
+ {
+ let path = path.as_ref().display().to_string();
+ let files = SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+ let mut writer = NoColor::new(Vec::new());
+ term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error");
+ String::from_utf8(writer.into_inner()).unwrap()
+ }
+
+ /// Returns a [`SourceLocation`] for the first label in the error message.
+ pub fn location(&self, source: &str) -> Option<SourceLocation> {
+ self.labels.get(0).map(|label| label.0.location(source))
+ }
+}
+
+impl std::fmt::Display for ParseError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.message)
+ }
+}
+
+impl std::error::Error for ParseError {
+ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+ None
+ }
+}
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum ExpectedToken<'a> {
+ Token(Token<'a>),
+ Identifier,
+ /// Expected: constant, parenthesized expression, identifier
+ PrimaryExpression,
+ /// Expected: assignment, increment/decrement expression
+ Assignment,
+ /// Expected: 'case', 'default', '}'
+ SwitchItem,
+ /// Expected: ',', ')'
+ WorkgroupSizeSeparator,
+ /// Expected: 'struct', 'let', 'var', 'type', ';', 'fn', eof
+ GlobalItem,
+ /// Expected a type.
+ Type,
+ /// Access of `var`, `let`, `const`.
+ Variable,
+ /// Access of a function
+ Function,
+}
+
+#[derive(Clone, Copy, Debug, Error, PartialEq)]
+pub enum NumberError {
+ #[error("invalid numeric literal format")]
+ Invalid,
+ #[error("numeric literal not representable by target type")]
+ NotRepresentable,
+ #[error("unimplemented f16 type")]
+ UnimplementedF16,
+}
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum InvalidAssignmentType {
+ Other,
+ Swizzle,
+ ImmutableBinding(Span),
+}
+
+#[derive(Clone, Debug)]
+pub enum Error<'a> {
+ Unexpected(Span, ExpectedToken<'a>),
+ UnexpectedComponents(Span),
+ UnexpectedOperationInConstContext(Span),
+ BadNumber(Span, NumberError),
+ BadMatrixScalarKind(Span, Scalar),
+ BadAccessor(Span),
+ BadTexture(Span),
+ BadTypeCast {
+ span: Span,
+ from_type: String,
+ to_type: String,
+ },
+ BadTextureSampleType {
+ span: Span,
+ scalar: Scalar,
+ },
+ BadIncrDecrReferenceType(Span),
+ InvalidResolve(ResolveError),
+ InvalidForInitializer(Span),
+ /// A break if appeared outside of a continuing block
+ InvalidBreakIf(Span),
+ InvalidGatherComponent(Span),
+ InvalidConstructorComponentType(Span, i32),
+ InvalidIdentifierUnderscore(Span),
+ ReservedIdentifierPrefix(Span),
+ UnknownAddressSpace(Span),
+ RepeatedAttribute(Span),
+ UnknownAttribute(Span),
+ UnknownBuiltin(Span),
+ UnknownAccess(Span),
+ UnknownIdent(Span, &'a str),
+ UnknownScalarType(Span),
+ UnknownType(Span),
+ UnknownStorageFormat(Span),
+ UnknownConservativeDepth(Span),
+ SizeAttributeTooLow(Span, u32),
+ AlignAttributeTooLow(Span, Alignment),
+ NonPowerOfTwoAlignAttribute(Span),
+ InconsistentBinding(Span),
+ TypeNotConstructible(Span),
+ TypeNotInferable(Span),
+ InitializationTypeMismatch {
+ name: Span,
+ expected: String,
+ got: String,
+ },
+ MissingType(Span),
+ MissingAttribute(&'static str, Span),
+ InvalidAtomicPointer(Span),
+ InvalidAtomicOperandType(Span),
+ InvalidRayQueryPointer(Span),
+ Pointer(&'static str, Span),
+ NotPointer(Span),
+ NotReference(&'static str, Span),
+ InvalidAssignment {
+ span: Span,
+ ty: InvalidAssignmentType,
+ },
+ ReservedKeyword(Span),
+ /// Redefinition of an identifier (used for both module-scope and local redefinitions).
+ Redefinition {
+ /// Span of the identifier in the previous definition.
+ previous: Span,
+
+ /// Span of the identifier in the new definition.
+ current: Span,
+ },
+ /// A declaration refers to itself directly.
+ RecursiveDeclaration {
+ /// The location of the name of the declaration.
+ ident: Span,
+
+ /// The point at which it is used.
+ usage: Span,
+ },
+ /// A declaration refers to itself indirectly, through one or more other
+ /// definitions.
+ CyclicDeclaration {
+ /// The location of the name of some declaration in the cycle.
+ ident: Span,
+
+ /// The edges of the cycle of references.
+ ///
+ /// Each `(decl, reference)` pair indicates that the declaration whose
+ /// name is `decl` has an identifier at `reference` whose definition is
+ /// the next declaration in the cycle. The last pair's `reference` is
+ /// the same identifier as `ident`, above.
+ path: Vec<(Span, Span)>,
+ },
+ InvalidSwitchValue {
+ uint: bool,
+ span: Span,
+ },
+ CalledEntryPoint(Span),
+ WrongArgumentCount {
+ span: Span,
+ expected: Range<u32>,
+ found: u32,
+ },
+ FunctionReturnsVoid(Span),
+ InvalidWorkGroupUniformLoad(Span),
+ Internal(&'static str),
+ ExpectedConstExprConcreteIntegerScalar(Span),
+ ExpectedNonNegative(Span),
+ ExpectedPositiveArrayLength(Span),
+ MissingWorkgroupSize(Span),
+ ConstantEvaluatorError(ConstantEvaluatorError, Span),
+ AutoConversion {
+ dest_span: Span,
+ dest_type: String,
+ source_span: Span,
+ source_type: String,
+ },
+ AutoConversionLeafScalar {
+ dest_span: Span,
+ dest_scalar: String,
+ source_span: Span,
+ source_type: String,
+ },
+ ConcretizationFailed {
+ expr_span: Span,
+ expr_type: String,
+ scalar: String,
+ inner: ConstantEvaluatorError,
+ },
+}
+
+impl<'a> Error<'a> {
+ pub(crate) fn as_parse_error(&self, source: &'a str) -> ParseError {
+ match *self {
+ Error::Unexpected(unexpected_span, expected) => {
+ let expected_str = match expected {
+ ExpectedToken::Token(token) => {
+ match token {
+ Token::Separator(c) => format!("'{c}'"),
+ Token::Paren(c) => format!("'{c}'"),
+ Token::Attribute => "@".to_string(),
+ Token::Number(_) => "number".to_string(),
+ Token::Word(s) => s.to_string(),
+ Token::Operation(c) => format!("operation ('{c}')"),
+ Token::LogicalOperation(c) => format!("logical operation ('{c}')"),
+ Token::ShiftOperation(c) => format!("bitshift ('{c}{c}')"),
+ Token::AssignmentOperation(c) if c=='<' || c=='>' => format!("bitshift ('{c}{c}=')"),
+ Token::AssignmentOperation(c) => format!("operation ('{c}=')"),
+ Token::IncrementOperation => "increment operation".to_string(),
+ Token::DecrementOperation => "decrement operation".to_string(),
+ Token::Arrow => "->".to_string(),
+ Token::Unknown(c) => format!("unknown ('{c}')"),
+ Token::Trivia => "trivia".to_string(),
+ Token::End => "end".to_string(),
+ }
+ }
+ ExpectedToken::Identifier => "identifier".to_string(),
+ ExpectedToken::PrimaryExpression => "expression".to_string(),
+ ExpectedToken::Assignment => "assignment or increment/decrement".to_string(),
+ ExpectedToken::SwitchItem => "switch item ('case' or 'default') or a closing curly bracket to signify the end of the switch statement ('}')".to_string(),
+ ExpectedToken::WorkgroupSizeSeparator => "workgroup size separator (',') or a closing parenthesis".to_string(),
+ ExpectedToken::GlobalItem => "global item ('struct', 'const', 'var', 'alias', ';', 'fn') or the end of the file".to_string(),
+ ExpectedToken::Type => "type".to_string(),
+ ExpectedToken::Variable => "variable access".to_string(),
+ ExpectedToken::Function => "function name".to_string(),
+ };
+ ParseError {
+ message: format!(
+ "expected {}, found '{}'",
+ expected_str, &source[unexpected_span],
+ ),
+ labels: vec![(unexpected_span, format!("expected {expected_str}").into())],
+ notes: vec![],
+ }
+ }
+ Error::UnexpectedComponents(bad_span) => ParseError {
+ message: "unexpected components".to_string(),
+ labels: vec![(bad_span, "unexpected components".into())],
+ notes: vec![],
+ },
+ Error::UnexpectedOperationInConstContext(span) => ParseError {
+ message: "this operation is not supported in a const context".to_string(),
+ labels: vec![(span, "operation not supported here".into())],
+ notes: vec![],
+ },
+ Error::BadNumber(bad_span, ref err) => ParseError {
+ message: format!("{}: `{}`", err, &source[bad_span],),
+ labels: vec![(bad_span, err.to_string().into())],
+ notes: vec![],
+ },
+ Error::BadMatrixScalarKind(span, scalar) => ParseError {
+ message: format!(
+ "matrix scalar type must be floating-point, but found `{}`",
+ scalar.to_wgsl()
+ ),
+ labels: vec![(span, "must be floating-point (e.g. `f32`)".into())],
+ notes: vec![],
+ },
+ Error::BadAccessor(accessor_span) => ParseError {
+ message: format!("invalid field accessor `{}`", &source[accessor_span],),
+ labels: vec![(accessor_span, "invalid accessor".into())],
+ notes: vec![],
+ },
+ Error::UnknownIdent(ident_span, ident) => ParseError {
+ message: format!("no definition in scope for identifier: '{ident}'"),
+ labels: vec![(ident_span, "unknown identifier".into())],
+ notes: vec![],
+ },
+ Error::UnknownScalarType(bad_span) => ParseError {
+ message: format!("unknown scalar type: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown scalar type".into())],
+ notes: vec!["Valid scalar types are f32, f64, i32, u32, bool".into()],
+ },
+ Error::BadTextureSampleType { span, scalar } => ParseError {
+ message: format!(
+ "texture sample type must be one of f32, i32 or u32, but found {}",
+ scalar.to_wgsl()
+ ),
+ labels: vec![(span, "must be one of f32, i32 or u32".into())],
+ notes: vec![],
+ },
+ Error::BadIncrDecrReferenceType(span) => ParseError {
+ message:
+ "increment/decrement operation requires reference type to be one of i32 or u32"
+ .to_string(),
+ labels: vec![(span, "must be a reference type of i32 or u32".into())],
+ notes: vec![],
+ },
+ Error::BadTexture(bad_span) => ParseError {
+ message: format!(
+ "expected an image, but found '{}' which is not an image",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "not an image".into())],
+ notes: vec![],
+ },
+ Error::BadTypeCast {
+ span,
+ ref from_type,
+ ref to_type,
+ } => {
+ let msg = format!("cannot cast a {from_type} to a {to_type}");
+ ParseError {
+ message: msg.clone(),
+ labels: vec![(span, msg.into())],
+ notes: vec![],
+ }
+ }
+ Error::InvalidResolve(ref resolve_error) => ParseError {
+ message: resolve_error.to_string(),
+ labels: vec![],
+ notes: vec![],
+ },
+ Error::InvalidForInitializer(bad_span) => ParseError {
+ message: format!(
+ "for(;;) initializer is not an assignment or a function call: '{}'",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "not an assignment or function call".into())],
+ notes: vec![],
+ },
+ Error::InvalidBreakIf(bad_span) => ParseError {
+ message: "A break if is only allowed in a continuing block".to_string(),
+ labels: vec![(bad_span, "not in a continuing block".into())],
+ notes: vec![],
+ },
+ Error::InvalidGatherComponent(bad_span) => ParseError {
+ message: format!(
+ "textureGather component '{}' doesn't exist, must be 0, 1, 2, or 3",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "invalid component".into())],
+ notes: vec![],
+ },
+ Error::InvalidConstructorComponentType(bad_span, component) => ParseError {
+ message: format!("invalid type for constructor component at index [{component}]"),
+ labels: vec![(bad_span, "invalid component type".into())],
+ notes: vec![],
+ },
+ Error::InvalidIdentifierUnderscore(bad_span) => ParseError {
+ message: "Identifier can't be '_'".to_string(),
+ labels: vec![(bad_span, "invalid identifier".into())],
+ notes: vec![
+ "Use phony assignment instead ('_ =' notice the absence of 'let' or 'var')"
+ .to_string(),
+ ],
+ },
+ Error::ReservedIdentifierPrefix(bad_span) => ParseError {
+ message: format!(
+ "Identifier starts with a reserved prefix: '{}'",
+ &source[bad_span]
+ ),
+ labels: vec![(bad_span, "invalid identifier".into())],
+ notes: vec![],
+ },
+ Error::UnknownAddressSpace(bad_span) => ParseError {
+ message: format!("unknown address space: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown address space".into())],
+ notes: vec![],
+ },
+ Error::RepeatedAttribute(bad_span) => ParseError {
+ message: format!("repeated attribute: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "repeated attribute".into())],
+ notes: vec![],
+ },
+ Error::UnknownAttribute(bad_span) => ParseError {
+ message: format!("unknown attribute: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown attribute".into())],
+ notes: vec![],
+ },
+ Error::UnknownBuiltin(bad_span) => ParseError {
+ message: format!("unknown builtin: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown builtin".into())],
+ notes: vec![],
+ },
+ Error::UnknownAccess(bad_span) => ParseError {
+ message: format!("unknown access: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown access".into())],
+ notes: vec![],
+ },
+ Error::UnknownStorageFormat(bad_span) => ParseError {
+ message: format!("unknown storage format: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown storage format".into())],
+ notes: vec![],
+ },
+ Error::UnknownConservativeDepth(bad_span) => ParseError {
+ message: format!("unknown conservative depth: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown conservative depth".into())],
+ notes: vec![],
+ },
+ Error::UnknownType(bad_span) => ParseError {
+ message: format!("unknown type: '{}'", &source[bad_span]),
+ labels: vec![(bad_span, "unknown type".into())],
+ notes: vec![],
+ },
+ Error::SizeAttributeTooLow(bad_span, min_size) => ParseError {
+ message: format!("struct member size must be at least {min_size}"),
+ labels: vec![(bad_span, format!("must be at least {min_size}").into())],
+ notes: vec![],
+ },
+ Error::AlignAttributeTooLow(bad_span, min_align) => ParseError {
+ message: format!("struct member alignment must be at least {min_align}"),
+ labels: vec![(bad_span, format!("must be at least {min_align}").into())],
+ notes: vec![],
+ },
+ Error::NonPowerOfTwoAlignAttribute(bad_span) => ParseError {
+ message: "struct member alignment must be a power of 2".to_string(),
+ labels: vec![(bad_span, "must be a power of 2".into())],
+ notes: vec![],
+ },
+ Error::InconsistentBinding(span) => ParseError {
+ message: "input/output binding is not consistent".to_string(),
+ labels: vec![(span, "input/output binding is not consistent".into())],
+ notes: vec![],
+ },
+ Error::TypeNotConstructible(span) => ParseError {
+ message: format!("type `{}` is not constructible", &source[span]),
+ labels: vec![(span, "type is not constructible".into())],
+ notes: vec![],
+ },
+ Error::TypeNotInferable(span) => ParseError {
+ message: "type can't be inferred".to_string(),
+ labels: vec![(span, "type can't be inferred".into())],
+ notes: vec![],
+ },
+ Error::InitializationTypeMismatch { name, ref expected, ref got } => {
+ ParseError {
+ message: format!(
+ "the type of `{}` is expected to be `{}`, but got `{}`",
+ &source[name], expected, got,
+ ),
+ labels: vec![(
+ name,
+ format!("definition of `{}`", &source[name]).into(),
+ )],
+ notes: vec![],
+ }
+ }
+ Error::MissingType(name_span) => ParseError {
+ message: format!("variable `{}` needs a type", &source[name_span]),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::MissingAttribute(name, name_span) => ParseError {
+ message: format!(
+ "variable `{}` needs a '{}' attribute",
+ &source[name_span], name
+ ),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::InvalidAtomicPointer(span) => ParseError {
+ message: "atomic operation is done on a pointer to a non-atomic".to_string(),
+ labels: vec![(span, "atomic pointer is invalid".into())],
+ notes: vec![],
+ },
+ Error::InvalidAtomicOperandType(span) => ParseError {
+ message: "atomic operand type is inconsistent with the operation".to_string(),
+ labels: vec![(span, "atomic operand type is invalid".into())],
+ notes: vec![],
+ },
+ Error::InvalidRayQueryPointer(span) => ParseError {
+ message: "ray query operation is done on a pointer to a non-ray-query".to_string(),
+ labels: vec![(span, "ray query pointer is invalid".into())],
+ notes: vec![],
+ },
+ Error::NotPointer(span) => ParseError {
+ message: "the operand of the `*` operator must be a pointer".to_string(),
+ labels: vec![(span, "expression is not a pointer".into())],
+ notes: vec![],
+ },
+ Error::NotReference(what, span) => ParseError {
+ message: format!("{what} must be a reference"),
+ labels: vec![(span, "expression is not a reference".into())],
+ notes: vec![],
+ },
+ Error::InvalidAssignment { span, ty } => {
+ let (extra_label, notes) = match ty {
+ InvalidAssignmentType::Swizzle => (
+ None,
+ vec![
+ "WGSL does not support assignments to swizzles".into(),
+ "consider assigning each component individually".into(),
+ ],
+ ),
+ InvalidAssignmentType::ImmutableBinding(binding_span) => (
+ Some((binding_span, "this is an immutable binding".into())),
+ vec![format!(
+ "consider declaring '{}' with `var` instead of `let`",
+ &source[binding_span]
+ )],
+ ),
+ InvalidAssignmentType::Other => (None, vec![]),
+ };
+
+ ParseError {
+ message: "invalid left-hand side of assignment".into(),
+ labels: std::iter::once((span, "cannot assign to this expression".into()))
+ .chain(extra_label)
+ .collect(),
+ notes,
+ }
+ }
+ Error::Pointer(what, span) => ParseError {
+ message: format!("{what} must not be a pointer"),
+ labels: vec![(span, "expression is a pointer".into())],
+ notes: vec![],
+ },
+ Error::ReservedKeyword(name_span) => ParseError {
+ message: format!("name `{}` is a reserved keyword", &source[name_span]),
+ labels: vec![(
+ name_span,
+ format!("definition of `{}`", &source[name_span]).into(),
+ )],
+ notes: vec![],
+ },
+ Error::Redefinition { previous, current } => ParseError {
+ message: format!("redefinition of `{}`", &source[current]),
+ labels: vec![
+ (
+ current,
+ format!("redefinition of `{}`", &source[current]).into(),
+ ),
+ (
+ previous,
+ format!("previous definition of `{}`", &source[previous]).into(),
+ ),
+ ],
+ notes: vec![],
+ },
+ Error::RecursiveDeclaration { ident, usage } => ParseError {
+ message: format!("declaration of `{}` is recursive", &source[ident]),
+ labels: vec![(ident, "".into()), (usage, "uses itself here".into())],
+ notes: vec![],
+ },
+ Error::CyclicDeclaration { ident, ref path } => ParseError {
+ message: format!("declaration of `{}` is cyclic", &source[ident]),
+ labels: path
+ .iter()
+ .enumerate()
+ .flat_map(|(i, &(ident, usage))| {
+ [
+ (ident, "".into()),
+ (
+ usage,
+ if i == path.len() - 1 {
+ "ending the cycle".into()
+ } else {
+ format!("uses `{}`", &source[ident]).into()
+ },
+ ),
+ ]
+ })
+ .collect(),
+ notes: vec![],
+ },
+ Error::InvalidSwitchValue { uint, span } => ParseError {
+ message: "invalid switch value".to_string(),
+ labels: vec![(
+ span,
+ if uint {
+ "expected unsigned integer"
+ } else {
+ "expected signed integer"
+ }
+ .into(),
+ )],
+ notes: vec![if uint {
+ format!("suffix the integer with a `u`: '{}u'", &source[span])
+ } else {
+ let span = span.to_range().unwrap();
+ format!(
+ "remove the `u` suffix: '{}'",
+ &source[span.start..span.end - 1]
+ )
+ }],
+ },
+ Error::CalledEntryPoint(span) => ParseError {
+ message: "entry point cannot be called".to_string(),
+ labels: vec![(span, "entry point cannot be called".into())],
+ notes: vec![],
+ },
+ Error::WrongArgumentCount {
+ span,
+ ref expected,
+ found,
+ } => ParseError {
+ message: format!(
+ "wrong number of arguments: expected {}, found {}",
+ if expected.len() < 2 {
+ format!("{}", expected.start)
+ } else {
+ format!("{}..{}", expected.start, expected.end)
+ },
+ found
+ ),
+ labels: vec![(span, "wrong number of arguments".into())],
+ notes: vec![],
+ },
+ Error::FunctionReturnsVoid(span) => ParseError {
+ message: "function does not return any value".to_string(),
+ labels: vec![(span, "".into())],
+ notes: vec![
+ "perhaps you meant to call the function in a separate statement?".into(),
+ ],
+ },
+ Error::InvalidWorkGroupUniformLoad(span) => ParseError {
+ message: "incorrect type passed to workgroupUniformLoad".into(),
+ labels: vec![(span, "".into())],
+ notes: vec!["passed type must be a workgroup pointer".into()],
+ },
+ Error::Internal(message) => ParseError {
+ message: "internal WGSL front end error".to_string(),
+ labels: vec![],
+ notes: vec![message.into()],
+ },
+ Error::ExpectedConstExprConcreteIntegerScalar(span) => ParseError {
+ message: "must be a const-expression that resolves to a concrete integer scalar (u32 or i32)".to_string(),
+ labels: vec![(span, "must resolve to u32 or i32".into())],
+ notes: vec![],
+ },
+ Error::ExpectedNonNegative(span) => ParseError {
+ message: "must be non-negative (>= 0)".to_string(),
+ labels: vec![(span, "must be non-negative".into())],
+ notes: vec![],
+ },
+ Error::ExpectedPositiveArrayLength(span) => ParseError {
+ message: "array element count must be positive (> 0)".to_string(),
+ labels: vec![(span, "must be positive".into())],
+ notes: vec![],
+ },
+ Error::ConstantEvaluatorError(ref e, span) => ParseError {
+ message: e.to_string(),
+ labels: vec![(span, "see msg".into())],
+ notes: vec![],
+ },
+ Error::MissingWorkgroupSize(span) => ParseError {
+ message: "workgroup size is missing on compute shader entry point".to_string(),
+ labels: vec![(
+ span,
+ "must be paired with a @workgroup_size attribute".into(),
+ )],
+ notes: vec![],
+ },
+ Error::AutoConversion { dest_span, ref dest_type, source_span, ref source_type } => ParseError {
+ message: format!("automatic conversions cannot convert `{source_type}` to `{dest_type}`"),
+ labels: vec![
+ (
+ dest_span,
+ format!("a value of type {dest_type} is required here").into(),
+ ),
+ (
+ source_span,
+ format!("this expression has type {source_type}").into(),
+ )
+ ],
+ notes: vec![],
+ },
+ Error::AutoConversionLeafScalar { dest_span, ref dest_scalar, source_span, ref source_type } => ParseError {
+ message: format!("automatic conversions cannot convert elements of `{source_type}` to `{dest_scalar}`"),
+ labels: vec![
+ (
+ dest_span,
+ format!("a value with elements of type {dest_scalar} is required here").into(),
+ ),
+ (
+ source_span,
+ format!("this expression has type {source_type}").into(),
+ )
+ ],
+ notes: vec![],
+ },
+ Error::ConcretizationFailed { expr_span, ref expr_type, ref scalar, ref inner } => ParseError {
+ message: format!("failed to convert expression to a concrete type: {}", inner),
+ labels: vec![
+ (
+ expr_span,
+ format!("this expression has type {}", expr_type).into(),
+ )
+ ],
+ notes: vec![
+ format!("the expression should have been converted to have {} scalar type", scalar),
+ ]
+ },
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/index.rs b/third_party/rust/naga/src/front/wgsl/index.rs
new file mode 100644
index 0000000000..a5524fe8f1
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/index.rs
@@ -0,0 +1,193 @@
+use super::Error;
+use crate::front::wgsl::parse::ast;
+use crate::{FastHashMap, Handle, Span};
+
+/// A `GlobalDecl` list in which each definition occurs before all its uses.
+pub struct Index<'a> {
+ dependency_order: Vec<Handle<ast::GlobalDecl<'a>>>,
+}
+
+impl<'a> Index<'a> {
+ /// Generate an `Index` for the given translation unit.
+ ///
+ /// Perform a topological sort on `tu`'s global declarations, placing
+ /// referents before the definitions that refer to them.
+ ///
+ /// Return an error if the graph of references between declarations contains
+ /// any cycles.
+ pub fn generate(tu: &ast::TranslationUnit<'a>) -> Result<Self, Error<'a>> {
+ // Produce a map from global definitions' names to their `Handle<GlobalDecl>`s.
+ // While doing so, reject conflicting definitions.
+ let mut globals = FastHashMap::with_capacity_and_hasher(tu.decls.len(), Default::default());
+ for (handle, decl) in tu.decls.iter() {
+ let ident = decl_ident(decl);
+ let name = ident.name;
+ if let Some(old) = globals.insert(name, handle) {
+ return Err(Error::Redefinition {
+ previous: decl_ident(&tu.decls[old]).span,
+ current: ident.span,
+ });
+ }
+ }
+
+ let len = tu.decls.len();
+ let solver = DependencySolver {
+ globals: &globals,
+ module: tu,
+ visited: vec![false; len],
+ temp_visited: vec![false; len],
+ path: Vec::new(),
+ out: Vec::with_capacity(len),
+ };
+ let dependency_order = solver.solve()?;
+
+ Ok(Self { dependency_order })
+ }
+
+ /// Iterate over `GlobalDecl`s, visiting each definition before all its uses.
+ ///
+ /// Produce handles for all of the `GlobalDecl`s of the `TranslationUnit`
+ /// passed to `Index::generate`, ordered so that a given declaration is
+ /// produced before any other declaration that uses it.
+ pub fn visit_ordered(&self) -> impl Iterator<Item = Handle<ast::GlobalDecl<'a>>> + '_ {
+ self.dependency_order.iter().copied()
+ }
+}
+
+/// An edge from a reference to its referent in the current depth-first
+/// traversal.
+///
+/// This is like `ast::Dependency`, except that we've determined which
+/// `GlobalDecl` it refers to.
+struct ResolvedDependency<'a> {
+ /// The referent of some identifier used in the current declaration.
+ decl: Handle<ast::GlobalDecl<'a>>,
+
+ /// Where that use occurs within the current declaration.
+ usage: Span,
+}
+
+/// Local state for ordering a `TranslationUnit`'s module-scope declarations.
+///
+/// Values of this type are used temporarily by `Index::generate`
+/// to perform a depth-first sort on the declarations.
+/// Technically, what we want is a topological sort, but a depth-first sort
+/// has one key benefit - it's much more efficient in storing
+/// the path of each node for error generation.
+struct DependencySolver<'source, 'temp> {
+ /// A map from module-scope definitions' names to their handles.
+ globals: &'temp FastHashMap<&'source str, Handle<ast::GlobalDecl<'source>>>,
+
+ /// The translation unit whose declarations we're ordering.
+ module: &'temp ast::TranslationUnit<'source>,
+
+ /// For each handle, whether we have pushed it onto `out` yet.
+ visited: Vec<bool>,
+
+ /// For each handle, whether it is an predecessor in the current depth-first
+ /// traversal. This is used to detect cycles in the reference graph.
+ temp_visited: Vec<bool>,
+
+ /// The current path in our depth-first traversal. Used for generating
+ /// error messages for non-trivial reference cycles.
+ path: Vec<ResolvedDependency<'source>>,
+
+ /// The list of declaration handles, with declarations before uses.
+ out: Vec<Handle<ast::GlobalDecl<'source>>>,
+}
+
+impl<'a> DependencySolver<'a, '_> {
+ /// Produce the sorted list of declaration handles, and check for cycles.
+ fn solve(mut self) -> Result<Vec<Handle<ast::GlobalDecl<'a>>>, Error<'a>> {
+ for (id, _) in self.module.decls.iter() {
+ if self.visited[id.index()] {
+ continue;
+ }
+
+ self.dfs(id)?;
+ }
+
+ Ok(self.out)
+ }
+
+ /// Ensure that all declarations used by `id` have been added to the
+ /// ordering, and then append `id` itself.
+ fn dfs(&mut self, id: Handle<ast::GlobalDecl<'a>>) -> Result<(), Error<'a>> {
+ let decl = &self.module.decls[id];
+ let id_usize = id.index();
+
+ self.temp_visited[id_usize] = true;
+ for dep in decl.dependencies.iter() {
+ if let Some(&dep_id) = self.globals.get(dep.ident) {
+ self.path.push(ResolvedDependency {
+ decl: dep_id,
+ usage: dep.usage,
+ });
+ let dep_id_usize = dep_id.index();
+
+ if self.temp_visited[dep_id_usize] {
+ // Found a cycle.
+ return if dep_id == id {
+ // A declaration refers to itself directly.
+ Err(Error::RecursiveDeclaration {
+ ident: decl_ident(decl).span,
+ usage: dep.usage,
+ })
+ } else {
+ // A declaration refers to itself indirectly, through
+ // one or more other definitions. Report the entire path
+ // of references.
+ let start_at = self
+ .path
+ .iter()
+ .rev()
+ .enumerate()
+ .find_map(|(i, dep)| (dep.decl == dep_id).then_some(i))
+ .unwrap_or(0);
+
+ Err(Error::CyclicDeclaration {
+ ident: decl_ident(&self.module.decls[dep_id]).span,
+ path: self.path[start_at..]
+ .iter()
+ .map(|curr_dep| {
+ let curr_id = curr_dep.decl;
+ let curr_decl = &self.module.decls[curr_id];
+
+ (decl_ident(curr_decl).span, curr_dep.usage)
+ })
+ .collect(),
+ })
+ };
+ } else if !self.visited[dep_id_usize] {
+ self.dfs(dep_id)?;
+ }
+
+ // Remove this edge from the current path.
+ self.path.pop();
+ }
+
+ // Ignore unresolved identifiers; they may be predeclared objects.
+ }
+
+ // Remove this node from the current path.
+ self.temp_visited[id_usize] = false;
+
+ // Now everything this declaration uses has been visited, and is already
+ // present in `out`. That means we we can append this one to the
+ // ordering, and mark it as visited.
+ self.out.push(id);
+ self.visited[id_usize] = true;
+
+ Ok(())
+ }
+}
+
+const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
+ match decl.kind {
+ ast::GlobalDeclKind::Fn(ref f) => f.name,
+ ast::GlobalDeclKind::Var(ref v) => v.name,
+ ast::GlobalDeclKind::Const(ref c) => c.name,
+ ast::GlobalDeclKind::Struct(ref s) => s.name,
+ ast::GlobalDeclKind::Type(ref t) => t.name,
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/construction.rs b/third_party/rust/naga/src/front/wgsl/lower/construction.rs
new file mode 100644
index 0000000000..de0d11d227
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/lower/construction.rs
@@ -0,0 +1,616 @@
+use std::num::NonZeroU32;
+
+use crate::front::wgsl::parse::ast;
+use crate::{Handle, Span};
+
+use crate::front::wgsl::error::Error;
+use crate::front::wgsl::lower::{ExpressionContext, Lowerer};
+
+/// A cooked form of `ast::ConstructorType` that uses Naga types whenever
+/// possible.
+enum Constructor<T> {
+ /// A vector construction whose component type is inferred from the
+ /// argument: `vec3(1.0)`.
+ PartialVector { size: crate::VectorSize },
+
+ /// A matrix construction whose component type is inferred from the
+ /// argument: `mat2x2(1,2,3,4)`.
+ PartialMatrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ },
+
+ /// An array whose component type and size are inferred from the arguments:
+ /// `array(3,4,5)`.
+ PartialArray,
+
+ /// A known Naga type.
+ ///
+ /// When we match on this type, we need to see the `TypeInner` here, but at
+ /// the point that we build this value we'll still need mutable access to
+ /// the module later. To avoid borrowing from the module, the type parameter
+ /// `T` is `Handle<Type>` initially. Then we use `borrow_inner` to produce a
+ /// version holding a tuple `(Handle<Type>, &TypeInner)`.
+ Type(T),
+}
+
+impl Constructor<Handle<crate::Type>> {
+ /// Return an equivalent `Constructor` value that includes borrowed
+ /// `TypeInner` values alongside any type handles.
+ ///
+ /// The returned form is more convenient to match on, since the patterns
+ /// can actually see what the handle refers to.
+ fn borrow_inner(
+ self,
+ module: &crate::Module,
+ ) -> Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
+ match self {
+ Constructor::PartialVector { size } => Constructor::PartialVector { size },
+ Constructor::PartialMatrix { columns, rows } => {
+ Constructor::PartialMatrix { columns, rows }
+ }
+ Constructor::PartialArray => Constructor::PartialArray,
+ Constructor::Type(handle) => Constructor::Type((handle, &module.types[handle].inner)),
+ }
+ }
+}
+
+impl Constructor<(Handle<crate::Type>, &crate::TypeInner)> {
+ fn to_error_string(&self, ctx: &ExpressionContext) -> String {
+ match *self {
+ Self::PartialVector { size } => {
+ format!("vec{}<?>", size as u32,)
+ }
+ Self::PartialMatrix { columns, rows } => {
+ format!("mat{}x{}<?>", columns as u32, rows as u32,)
+ }
+ Self::PartialArray => "array<?, ?>".to_string(),
+ Self::Type((handle, _inner)) => handle.to_wgsl(&ctx.module.to_ctx()),
+ }
+ }
+}
+
+enum Components<'a> {
+ None,
+ One {
+ component: Handle<crate::Expression>,
+ span: Span,
+ ty_inner: &'a crate::TypeInner,
+ },
+ Many {
+ components: Vec<Handle<crate::Expression>>,
+ spans: Vec<Span>,
+ },
+}
+
+impl Components<'_> {
+ fn into_components_vec(self) -> Vec<Handle<crate::Expression>> {
+ match self {
+ Self::None => vec![],
+ Self::One { component, .. } => vec![component],
+ Self::Many { components, .. } => components,
+ }
+ }
+}
+
+impl<'source, 'temp> Lowerer<'source, 'temp> {
+ /// Generate Naga IR for a type constructor expression.
+ ///
+ /// The `constructor` value represents the head of the constructor
+ /// expression, which is at least a hint of which type is being built; if
+ /// it's one of the `Partial` variants, we need to consider the argument
+ /// types as well.
+ ///
+ /// This is used for [`Construct`] expressions, but also for [`Call`]
+ /// expressions, once we've determined that the "callable" (in WGSL spec
+ /// terms) is actually a type.
+ ///
+ /// [`Construct`]: ast::Expression::Construct
+ /// [`Call`]: ast::Expression::Call
+ pub fn construct(
+ &mut self,
+ span: Span,
+ constructor: &ast::ConstructorType<'source>,
+ ty_span: Span,
+ components: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ use crate::proc::TypeResolution as Tr;
+
+ let constructor_h = self.constructor(constructor, ctx)?;
+
+ let components = match *components {
+ [] => Components::None,
+ [component] => {
+ let span = ctx.ast_expressions.get_span(component);
+ let component = self.expression_for_abstract(component, ctx)?;
+ let ty_inner = super::resolve_inner!(ctx, component);
+
+ Components::One {
+ component,
+ span,
+ ty_inner,
+ }
+ }
+ ref ast_components @ [_, _, ..] => {
+ let components = ast_components
+ .iter()
+ .map(|&expr| self.expression_for_abstract(expr, ctx))
+ .collect::<Result<_, _>>()?;
+ let spans = ast_components
+ .iter()
+ .map(|&expr| ctx.ast_expressions.get_span(expr))
+ .collect();
+
+ for &component in &components {
+ ctx.grow_types(component)?;
+ }
+
+ Components::Many { components, spans }
+ }
+ };
+
+ // Even though we computed `constructor` above, wait until now to borrow
+ // a reference to the `TypeInner`, so that the component-handling code
+ // above can have mutable access to the type arena.
+ let constructor = constructor_h.borrow_inner(ctx.module);
+
+ let expr;
+ match (components, constructor) {
+ // Empty constructor
+ (Components::None, dst_ty) => match dst_ty {
+ Constructor::Type((result_ty, _)) => {
+ return ctx.append_expression(crate::Expression::ZeroValue(result_ty), span)
+ }
+ Constructor::PartialVector { .. }
+ | Constructor::PartialMatrix { .. }
+ | Constructor::PartialArray => {
+ // We have no arguments from which to infer the result type, so
+ // partial constructors aren't acceptable here.
+ return Err(Error::TypeNotInferable(ty_span));
+ }
+ },
+
+ // Scalar constructor & conversion (scalar -> scalar)
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Scalar { .. },
+ ..
+ },
+ Constructor::Type((_, &crate::TypeInner::Scalar(scalar))),
+ ) => {
+ expr = crate::Expression::As {
+ expr: component,
+ kind: scalar.kind,
+ convert: Some(scalar.width),
+ };
+ }
+
+ // Vector conversion (vector -> vector)
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
+ ..
+ },
+ Constructor::Type((
+ _,
+ &crate::TypeInner::Vector {
+ size: dst_size,
+ scalar: dst_scalar,
+ },
+ )),
+ ) if dst_size == src_size => {
+ expr = crate::Expression::As {
+ expr: component,
+ kind: dst_scalar.kind,
+ convert: Some(dst_scalar.width),
+ };
+ }
+
+ // Vector conversion (vector -> vector) - partial
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Vector { size: src_size, .. },
+ ..
+ },
+ Constructor::PartialVector { size: dst_size },
+ ) if dst_size == src_size => {
+ // This is a trivial conversion: the sizes match, and a Partial
+ // constructor doesn't specify a scalar type, so nothing can
+ // possibly happen.
+ return Ok(component);
+ }
+
+ // Matrix conversion (matrix -> matrix)
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Matrix {
+ columns: src_columns,
+ rows: src_rows,
+ ..
+ },
+ ..
+ },
+ Constructor::Type((
+ _,
+ &crate::TypeInner::Matrix {
+ columns: dst_columns,
+ rows: dst_rows,
+ scalar: dst_scalar,
+ },
+ )),
+ ) if dst_columns == src_columns && dst_rows == src_rows => {
+ expr = crate::Expression::As {
+ expr: component,
+ kind: dst_scalar.kind,
+ convert: Some(dst_scalar.width),
+ };
+ }
+
+ // Matrix conversion (matrix -> matrix) - partial
+ (
+ Components::One {
+ component,
+ ty_inner:
+ &crate::TypeInner::Matrix {
+ columns: src_columns,
+ rows: src_rows,
+ ..
+ },
+ ..
+ },
+ Constructor::PartialMatrix {
+ columns: dst_columns,
+ rows: dst_rows,
+ },
+ ) if dst_columns == src_columns && dst_rows == src_rows => {
+ // This is a trivial conversion: the sizes match, and a Partial
+ // constructor doesn't specify a scalar type, so nothing can
+ // possibly happen.
+ return Ok(component);
+ }
+
+ // Vector constructor (splat) - infer type
+ (
+ Components::One {
+ component,
+ ty_inner: &crate::TypeInner::Scalar { .. },
+ ..
+ },
+ Constructor::PartialVector { size },
+ ) => {
+ expr = crate::Expression::Splat {
+ size,
+ value: component,
+ };
+ }
+
+ // Vector constructor (splat)
+ (
+ Components::One {
+ mut component,
+ ty_inner: &crate::TypeInner::Scalar(_),
+ ..
+ },
+ Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })),
+ ) => {
+ ctx.convert_slice_to_common_leaf_scalar(
+ std::slice::from_mut(&mut component),
+ scalar,
+ )?;
+ expr = crate::Expression::Splat {
+ size,
+ value: component,
+ };
+ }
+
+ // Vector constructor (by elements), partial
+ (
+ Components::Many {
+ mut components,
+ spans,
+ },
+ Constructor::PartialVector { size },
+ ) => {
+ let consensus_scalar =
+ ctx.automatic_conversion_consensus(&components)
+ .map_err(|index| {
+ Error::InvalidConstructorComponentType(spans[index], index as i32)
+ })?;
+ ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
+ let inner = consensus_scalar.to_inner_vector(size);
+ let ty = ctx.ensure_type_exists(inner);
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Vector constructor (by elements), full type given
+ (
+ Components::Many { mut components, .. },
+ Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })),
+ ) => {
+ ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?;
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Matrix constructor (by elements), partial
+ (
+ Components::Many {
+ mut components,
+ spans,
+ },
+ Constructor::PartialMatrix { columns, rows },
+ ) if components.len() == columns as usize * rows as usize => {
+ let consensus_scalar =
+ ctx.automatic_conversion_consensus(&components)
+ .map_err(|index| {
+ Error::InvalidConstructorComponentType(spans[index], index as i32)
+ })?;
+ // We actually only accept floating-point elements.
+ let consensus_scalar = consensus_scalar
+ .automatic_conversion_combine(crate::Scalar::ABSTRACT_FLOAT)
+ .unwrap_or(consensus_scalar);
+ ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
+ let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows));
+
+ let components = components
+ .chunks(rows as usize)
+ .map(|vec_components| {
+ ctx.append_expression(
+ crate::Expression::Compose {
+ ty: vec_ty,
+ components: Vec::from(vec_components),
+ },
+ Default::default(),
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: consensus_scalar,
+ });
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Matrix constructor (by elements), type given
+ (
+ Components::Many { mut components, .. },
+ Constructor::Type((
+ _,
+ &crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ },
+ )),
+ ) if components.len() == columns as usize * rows as usize => {
+ let element = Tr::Value(crate::TypeInner::Scalar(scalar));
+ ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?;
+ let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows));
+
+ let components = components
+ .chunks(rows as usize)
+ .map(|vec_components| {
+ ctx.append_expression(
+ crate::Expression::Compose {
+ ty: vec_ty,
+ components: Vec::from(vec_components),
+ },
+ Default::default(),
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ });
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Matrix constructor (by columns), partial
+ (
+ Components::Many {
+ mut components,
+ spans,
+ },
+ Constructor::PartialMatrix { columns, rows },
+ ) => {
+ let consensus_scalar =
+ ctx.automatic_conversion_consensus(&components)
+ .map_err(|index| {
+ Error::InvalidConstructorComponentType(spans[index], index as i32)
+ })?;
+ ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: consensus_scalar,
+ });
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Matrix constructor (by columns), type given
+ (
+ Components::Many { mut components, .. },
+ Constructor::Type((
+ ty,
+ &crate::TypeInner::Matrix {
+ columns: _,
+ rows,
+ scalar,
+ },
+ )),
+ ) => {
+ let component_ty = crate::TypeInner::Vector { size: rows, scalar };
+ ctx.try_automatic_conversions_slice(
+ &mut components,
+ &Tr::Value(component_ty),
+ ty_span,
+ )?;
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Array constructor - infer type
+ (components, Constructor::PartialArray) => {
+ let mut components = components.into_components_vec();
+ if let Ok(consensus_scalar) = ctx.automatic_conversion_consensus(&components) {
+ // Note that this will *not* necessarily convert all the
+ // components to the same type! The `automatic_conversion_consensus`
+ // method only considers the parameters' leaf scalar
+ // types; the parameters themselves could be any mix of
+ // vectors, matrices, and scalars.
+ //
+ // But *if* it is possible for this array construction
+ // expression to be well-typed at all, then all the
+ // parameters must have the same type constructors (vec,
+ // matrix, scalar) applied to their leaf scalars, so
+ // reconciling their scalars is always the right thing to
+ // do. And if this array construction is not well-typed,
+ // these conversions will not make it so, and we can let
+ // validation catch the error.
+ ctx.convert_slice_to_common_leaf_scalar(&mut components, consensus_scalar)?;
+ } else {
+ // There's no consensus scalar. Emit the `Compose`
+ // expression anyway, and let validation catch the problem.
+ }
+
+ let base = ctx.register_type(components[0])?;
+
+ let inner = crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(
+ NonZeroU32::new(u32::try_from(components.len()).unwrap()).unwrap(),
+ ),
+ stride: {
+ self.layouter.update(ctx.module.to_ctx()).unwrap();
+ self.layouter[base].to_stride()
+ },
+ };
+ let ty = ctx.ensure_type_exists(inner);
+
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Array constructor, explicit type
+ (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => {
+ let mut components = components.into_components_vec();
+ ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), ty_span)?;
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // Struct constructor
+ (
+ components,
+ Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })),
+ ) => {
+ let mut components = components.into_components_vec();
+ let struct_ty_span = ctx.module.types.get_span(ty);
+
+ // Make a vector of the members' type handles in advance, to
+ // avoid borrowing `members` from `ctx` while we generate
+ // new code.
+ let members: Vec<Handle<crate::Type>> = members.iter().map(|m| m.ty).collect();
+
+ for (component, &ty) in components.iter_mut().zip(&members) {
+ *component =
+ ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?;
+ }
+ expr = crate::Expression::Compose { ty, components };
+ }
+
+ // ERRORS
+
+ // Bad conversion (type cast)
+ (Components::One { span, ty_inner, .. }, constructor) => {
+ let from_type = ty_inner.to_wgsl(&ctx.module.to_ctx());
+ return Err(Error::BadTypeCast {
+ span,
+ from_type,
+ to_type: constructor.to_error_string(ctx),
+ });
+ }
+
+ // Too many parameters for scalar constructor
+ (
+ Components::Many { spans, .. },
+ Constructor::Type((_, &crate::TypeInner::Scalar { .. })),
+ ) => {
+ let span = spans[1].until(spans.last().unwrap());
+ return Err(Error::UnexpectedComponents(span));
+ }
+
+ // Other types can't be constructed
+ _ => return Err(Error::TypeNotConstructible(ty_span)),
+ }
+
+ let expr = ctx.append_expression(expr, span)?;
+ Ok(expr)
+ }
+
+ /// Build a [`Constructor`] for a WGSL construction expression.
+ ///
+ /// If `constructor` conveys enough information to determine which Naga [`Type`]
+ /// we're actually building (i.e., it's not a partial constructor), then
+ /// ensure the `Type` exists in [`ctx.module`], and return
+ /// [`Constructor::Type`].
+ ///
+ /// Otherwise, return the [`Constructor`] partial variant corresponding to
+ /// `constructor`.
+ ///
+ /// [`Type`]: crate::Type
+ /// [`ctx.module`]: ExpressionContext::module
+ fn constructor<'out>(
+ &mut self,
+ constructor: &ast::ConstructorType<'source>,
+ ctx: &mut ExpressionContext<'source, '_, 'out>,
+ ) -> Result<Constructor<Handle<crate::Type>>, Error<'source>> {
+ let handle = match *constructor {
+ ast::ConstructorType::Scalar(scalar) => {
+ let ty = ctx.ensure_type_exists(scalar.to_inner_scalar());
+ Constructor::Type(ty)
+ }
+ ast::ConstructorType::PartialVector { size } => Constructor::PartialVector { size },
+ ast::ConstructorType::Vector { size, scalar } => {
+ let ty = ctx.ensure_type_exists(scalar.to_inner_vector(size));
+ Constructor::Type(ty)
+ }
+ ast::ConstructorType::PartialMatrix { columns, rows } => {
+ Constructor::PartialMatrix { columns, rows }
+ }
+ ast::ConstructorType::Matrix {
+ rows,
+ columns,
+ width,
+ } => {
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: crate::Scalar::float(width),
+ });
+ Constructor::Type(ty)
+ }
+ ast::ConstructorType::PartialArray => Constructor::PartialArray,
+ ast::ConstructorType::Array { base, size } => {
+ let base = self.resolve_ast_type(base, &mut ctx.as_global())?;
+ let size = self.array_size(size, &mut ctx.as_global())?;
+
+ self.layouter.update(ctx.module.to_ctx()).unwrap();
+ let stride = self.layouter[base].to_stride();
+
+ let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, stride });
+ Constructor::Type(ty)
+ }
+ ast::ConstructorType::Type(ty) => Constructor::Type(ty),
+ };
+
+ Ok(handle)
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/conversion.rs b/third_party/rust/naga/src/front/wgsl/lower/conversion.rs
new file mode 100644
index 0000000000..2a2690f096
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/lower/conversion.rs
@@ -0,0 +1,503 @@
+//! WGSL's automatic conversions for abstract types.
+
+use crate::{Handle, Span};
+
+impl<'source, 'temp, 'out> super::ExpressionContext<'source, 'temp, 'out> {
+ /// Try to use WGSL's automatic conversions to convert `expr` to `goal_ty`.
+ ///
+ /// If no conversions are necessary, return `expr` unchanged.
+ ///
+ /// If automatic conversions cannot convert `expr` to `goal_ty`, return an
+ /// [`AutoConversion`] error.
+ ///
+ /// Although the Load Rule is one of the automatic conversions, this
+ /// function assumes it has already been applied if appropriate, as
+ /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`.
+ ///
+ /// [`AutoConversion`]: super::Error::AutoConversion
+ pub fn try_automatic_conversions(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ goal_ty: &crate::proc::TypeResolution,
+ goal_span: Span,
+ ) -> Result<Handle<crate::Expression>, super::Error<'source>> {
+ let expr_span = self.get_expression_span(expr);
+ // Keep the TypeResolution so we can get type names for
+ // structs in error messages.
+ let expr_resolution = super::resolve!(self, expr);
+ let types = &self.module.types;
+ let expr_inner = expr_resolution.inner_with(types);
+ let goal_inner = goal_ty.inner_with(types);
+
+ // If `expr` already has the requested type, we're done.
+ if expr_inner.equivalent(goal_inner, types) {
+ return Ok(expr);
+ }
+
+ let (_expr_scalar, goal_scalar) =
+ match expr_inner.automatically_converts_to(goal_inner, types) {
+ Some(scalars) => scalars,
+ None => {
+ let gctx = &self.module.to_ctx();
+ let source_type = expr_resolution.to_wgsl(gctx);
+ let dest_type = goal_ty.to_wgsl(gctx);
+
+ return Err(super::Error::AutoConversion {
+ dest_span: goal_span,
+ dest_type,
+ source_span: expr_span,
+ source_type,
+ });
+ }
+ };
+
+ self.convert_leaf_scalar(expr, expr_span, goal_scalar)
+ }
+
+ /// Try to convert `expr`'s leaf scalar to `goal` using automatic conversions.
+ ///
+ /// If no conversions are necessary, return `expr` unchanged.
+ ///
+ /// If automatic conversions cannot convert `expr` to `goal_scalar`, return
+ /// an [`AutoConversionLeafScalar`] error.
+ ///
+ /// Although the Load Rule is one of the automatic conversions, this
+ /// function assumes it has already been applied if appropriate, as
+ /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`.
+ ///
+ /// [`AutoConversionLeafScalar`]: super::Error::AutoConversionLeafScalar
+ pub fn try_automatic_conversion_for_leaf_scalar(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ goal_scalar: crate::Scalar,
+ goal_span: Span,
+ ) -> Result<Handle<crate::Expression>, super::Error<'source>> {
+ let expr_span = self.get_expression_span(expr);
+ let expr_resolution = super::resolve!(self, expr);
+ let types = &self.module.types;
+ let expr_inner = expr_resolution.inner_with(types);
+
+ let make_error = || {
+ let gctx = &self.module.to_ctx();
+ let source_type = expr_resolution.to_wgsl(gctx);
+ super::Error::AutoConversionLeafScalar {
+ dest_span: goal_span,
+ dest_scalar: goal_scalar.to_wgsl(),
+ source_span: expr_span,
+ source_type,
+ }
+ };
+
+ let expr_scalar = match expr_inner.scalar() {
+ Some(scalar) => scalar,
+ None => return Err(make_error()),
+ };
+
+ if expr_scalar == goal_scalar {
+ return Ok(expr);
+ }
+
+ if !expr_scalar.automatically_converts_to(goal_scalar) {
+ return Err(make_error());
+ }
+
+ assert!(expr_scalar.is_abstract());
+
+ self.convert_leaf_scalar(expr, expr_span, goal_scalar)
+ }
+
+ fn convert_leaf_scalar(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ expr_span: Span,
+ goal_scalar: crate::Scalar,
+ ) -> Result<Handle<crate::Expression>, super::Error<'source>> {
+ let expr_inner = super::resolve_inner!(self, expr);
+ if let crate::TypeInner::Array { .. } = *expr_inner {
+ self.as_const_evaluator()
+ .cast_array(expr, goal_scalar, expr_span)
+ .map_err(|err| super::Error::ConstantEvaluatorError(err, expr_span))
+ } else {
+ let cast = crate::Expression::As {
+ expr,
+ kind: goal_scalar.kind,
+ convert: Some(goal_scalar.width),
+ };
+ self.append_expression(cast, expr_span)
+ }
+ }
+
+ /// Try to convert `exprs` to `goal_ty` using WGSL's automatic conversions.
+ pub fn try_automatic_conversions_slice(
+ &mut self,
+ exprs: &mut [Handle<crate::Expression>],
+ goal_ty: &crate::proc::TypeResolution,
+ goal_span: Span,
+ ) -> Result<(), super::Error<'source>> {
+ for expr in exprs.iter_mut() {
+ *expr = self.try_automatic_conversions(*expr, goal_ty, goal_span)?;
+ }
+
+ Ok(())
+ }
+
+ /// Apply WGSL's automatic conversions to a vector constructor's arguments.
+ ///
+ /// When calling a vector constructor like `vec3<f32>(...)`, the parameters
+ /// can be a mix of scalars and vectors, with the latter being spread out to
+ /// contribute each of their components as a component of the new value.
+ /// When the element type is explicit, as with `<f32>` in the example above,
+ /// WGSL's automatic conversions should convert abstract scalar and vector
+ /// parameters to the constructor's required scalar type.
+ pub fn try_automatic_conversions_for_vector(
+ &mut self,
+ exprs: &mut [Handle<crate::Expression>],
+ goal_scalar: crate::Scalar,
+ goal_span: Span,
+ ) -> Result<(), super::Error<'source>> {
+ use crate::proc::TypeResolution as Tr;
+ use crate::TypeInner as Ti;
+ let goal_scalar_res = Tr::Value(Ti::Scalar(goal_scalar));
+
+ for (i, expr) in exprs.iter_mut().enumerate() {
+ // Keep the TypeResolution so we can get full type names
+ // in error messages.
+ let expr_resolution = super::resolve!(self, *expr);
+ let types = &self.module.types;
+ let expr_inner = expr_resolution.inner_with(types);
+
+ match *expr_inner {
+ Ti::Scalar(_) => {
+ *expr = self.try_automatic_conversions(*expr, &goal_scalar_res, goal_span)?;
+ }
+ Ti::Vector { size, scalar: _ } => {
+ let goal_vector_res = Tr::Value(Ti::Vector {
+ size,
+ scalar: goal_scalar,
+ });
+ *expr = self.try_automatic_conversions(*expr, &goal_vector_res, goal_span)?;
+ }
+ _ => {
+ let span = self.get_expression_span(*expr);
+ return Err(super::Error::InvalidConstructorComponentType(
+ span, i as i32,
+ ));
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Convert `expr` to the leaf scalar type `scalar`.
+ pub fn convert_to_leaf_scalar(
+ &mut self,
+ expr: &mut Handle<crate::Expression>,
+ goal: crate::Scalar,
+ ) -> Result<(), super::Error<'source>> {
+ let inner = super::resolve_inner!(self, *expr);
+ // Do nothing if `inner` doesn't even have leaf scalars;
+ // it's a type error that validation will catch.
+ if inner.scalar() != Some(goal) {
+ let cast = crate::Expression::As {
+ expr: *expr,
+ kind: goal.kind,
+ convert: Some(goal.width),
+ };
+ let expr_span = self.get_expression_span(*expr);
+ *expr = self.append_expression(cast, expr_span)?;
+ }
+
+ Ok(())
+ }
+
+ /// Convert all expressions in `exprs` to a common scalar type.
+ ///
+ /// Note that the caller is responsible for making sure these
+ /// conversions are actually justified. This function simply
+ /// generates `As` expressions, regardless of whether they are
+ /// permitted WGSL automatic conversions. Callers intending to
+ /// implement automatic conversions need to determine for
+ /// themselves whether the casts we we generate are justified,
+ /// perhaps by calling `TypeInner::automatically_converts_to` or
+ /// `Scalar::automatic_conversion_combine`.
+ pub fn convert_slice_to_common_leaf_scalar(
+ &mut self,
+ exprs: &mut [Handle<crate::Expression>],
+ goal: crate::Scalar,
+ ) -> Result<(), super::Error<'source>> {
+ for expr in exprs.iter_mut() {
+ self.convert_to_leaf_scalar(expr, goal)?;
+ }
+
+ Ok(())
+ }
+
+ /// Return an expression for the concretized value of `expr`.
+ ///
+ /// If `expr` is already concrete, return it unchanged.
+ pub fn concretize(
+ &mut self,
+ mut expr: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Expression>, super::Error<'source>> {
+ let inner = super::resolve_inner!(self, expr);
+ if let Some(scalar) = inner.automatically_convertible_scalar(&self.module.types) {
+ let concretized = scalar.concretize();
+ if concretized != scalar {
+ assert!(scalar.is_abstract());
+ let expr_span = self.get_expression_span(expr);
+ expr = self
+ .as_const_evaluator()
+ .cast_array(expr, concretized, expr_span)
+ .map_err(|err| {
+ // A `TypeResolution` includes the type's full name, if
+ // it has one. Also, avoid holding the borrow of `inner`
+ // across the call to `cast_array`.
+ let expr_type = &self.typifier()[expr];
+ super::Error::ConcretizationFailed {
+ expr_span,
+ expr_type: expr_type.to_wgsl(&self.module.to_ctx()),
+ scalar: concretized.to_wgsl(),
+ inner: err,
+ }
+ })?;
+ }
+ }
+
+ Ok(expr)
+ }
+
+ /// Find the consensus scalar of `components` under WGSL's automatic
+ /// conversions.
+ ///
+ /// If `components` can all be converted to any common scalar via
+ /// WGSL's automatic conversions, return the best such scalar.
+ ///
+ /// The `components` slice must not be empty. All elements' types must
+ /// have been resolved.
+ ///
+ /// If `components` are definitely not acceptable as arguments to such
+ /// constructors, return `Err(i)`, where `i` is the index in
+ /// `components` of some problematic argument.
+ ///
+ /// This function doesn't fully type-check the arguments - it only
+ /// considers their leaf scalar types. This means it may return `Ok`
+ /// even when the Naga validator will reject the resulting
+ /// construction expression later.
+ pub fn automatic_conversion_consensus<'handle, I>(
+ &self,
+ components: I,
+ ) -> Result<crate::Scalar, usize>
+ where
+ I: IntoIterator<Item = &'handle Handle<crate::Expression>>,
+ I::IntoIter: Clone, // for debugging
+ {
+ let types = &self.module.types;
+ let mut inners = components
+ .into_iter()
+ .map(|&c| self.typifier()[c].inner_with(types));
+ log::debug!(
+ "wgsl automatic_conversion_consensus: {:?}",
+ inners
+ .clone()
+ .map(|inner| inner.to_wgsl(&self.module.to_ctx()))
+ .collect::<Vec<String>>()
+ );
+ let mut best = inners.next().unwrap().scalar().ok_or(0_usize)?;
+ for (inner, i) in inners.zip(1..) {
+ let scalar = inner.scalar().ok_or(i)?;
+ match best.automatic_conversion_combine(scalar) {
+ Some(new_best) => {
+ best = new_best;
+ }
+ None => return Err(i),
+ }
+ }
+
+ log::debug!(" consensus: {:?}", best.to_wgsl());
+ Ok(best)
+ }
+}
+
+impl crate::TypeInner {
+ /// Determine whether `self` automatically converts to `goal`.
+ ///
+ /// If WGSL's automatic conversions (excluding the Load Rule) will
+ /// convert `self` to `goal`, then return a pair `(from, to)`,
+ /// where `from` and `to` are the scalar types of the leaf values
+ /// of `self` and `goal`.
+ ///
+ /// This function assumes that `self` and `goal` are different
+ /// types. Callers should first check whether any conversion is
+ /// needed at all.
+ ///
+ /// If the automatic conversions cannot convert `self` to `goal`,
+ /// return `None`.
+ fn automatically_converts_to(
+ &self,
+ goal: &Self,
+ types: &crate::UniqueArena<crate::Type>,
+ ) -> Option<(crate::Scalar, crate::Scalar)> {
+ use crate::ScalarKind as Sk;
+ use crate::TypeInner as Ti;
+
+ // Automatic conversions only change the scalar type of a value's leaves
+ // (e.g., `vec4<AbstractFloat>` to `vec4<f32>`), never the type
+ // constructors applied to those scalar types (e.g., never scalar to
+ // `vec4`, or `vec2` to `vec3`). So first we check that the type
+ // constructors match, extracting the leaf scalar types in the process.
+ let expr_scalar;
+ let goal_scalar;
+ match (self, goal) {
+ (&Ti::Scalar(expr), &Ti::Scalar(goal)) => {
+ expr_scalar = expr;
+ goal_scalar = goal;
+ }
+ (
+ &Ti::Vector {
+ size: expr_size,
+ scalar: expr,
+ },
+ &Ti::Vector {
+ size: goal_size,
+ scalar: goal,
+ },
+ ) if expr_size == goal_size => {
+ expr_scalar = expr;
+ goal_scalar = goal;
+ }
+ (
+ &Ti::Matrix {
+ rows: expr_rows,
+ columns: expr_columns,
+ scalar: expr,
+ },
+ &Ti::Matrix {
+ rows: goal_rows,
+ columns: goal_columns,
+ scalar: goal,
+ },
+ ) if expr_rows == goal_rows && expr_columns == goal_columns => {
+ expr_scalar = expr;
+ goal_scalar = goal;
+ }
+ (
+ &Ti::Array {
+ base: expr_base,
+ size: expr_size,
+ stride: _,
+ },
+ &Ti::Array {
+ base: goal_base,
+ size: goal_size,
+ stride: _,
+ },
+ ) if expr_size == goal_size => {
+ return types[expr_base]
+ .inner
+ .automatically_converts_to(&types[goal_base].inner, types);
+ }
+ _ => return None,
+ }
+
+ match (expr_scalar.kind, goal_scalar.kind) {
+ (Sk::AbstractFloat, Sk::Float) => {}
+ (Sk::AbstractInt, Sk::Sint | Sk::Uint | Sk::AbstractFloat | Sk::Float) => {}
+ _ => return None,
+ }
+
+ log::trace!(" okay: expr {expr_scalar:?}, goal {goal_scalar:?}");
+ Some((expr_scalar, goal_scalar))
+ }
+
+ fn automatically_convertible_scalar(
+ &self,
+ types: &crate::UniqueArena<crate::Type>,
+ ) -> Option<crate::Scalar> {
+ use crate::TypeInner as Ti;
+ match *self {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => {
+ Some(scalar)
+ }
+ Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types),
+ Ti::Atomic(_)
+ | Ti::Pointer { .. }
+ | Ti::ValuePointer { .. }
+ | Ti::Struct { .. }
+ | Ti::Image { .. }
+ | Ti::Sampler { .. }
+ | Ti::AccelerationStructure
+ | Ti::RayQuery
+ | Ti::BindingArray { .. } => None,
+ }
+ }
+}
+
+impl crate::Scalar {
+ /// Find the common type of `self` and `other` under WGSL's
+ /// automatic conversions.
+ ///
+ /// If there are any scalars to which WGSL's automatic conversions
+ /// will convert both `self` and `other`, return the best such
+ /// scalar. Otherwise, return `None`.
+ pub const fn automatic_conversion_combine(self, other: Self) -> Option<crate::Scalar> {
+ use crate::ScalarKind as Sk;
+
+ match (self.kind, other.kind) {
+ // When the kinds match...
+ (Sk::AbstractFloat, Sk::AbstractFloat)
+ | (Sk::AbstractInt, Sk::AbstractInt)
+ | (Sk::Sint, Sk::Sint)
+ | (Sk::Uint, Sk::Uint)
+ | (Sk::Float, Sk::Float)
+ | (Sk::Bool, Sk::Bool) => {
+ if self.width == other.width {
+ // ... either no conversion is necessary ...
+ Some(self)
+ } else {
+ // ... or no conversion is possible.
+ // We never convert concrete to concrete, and
+ // abstract types should have only one size.
+ None
+ }
+ }
+
+ // AbstractInt converts to AbstractFloat.
+ (Sk::AbstractFloat, Sk::AbstractInt) => Some(self),
+ (Sk::AbstractInt, Sk::AbstractFloat) => Some(other),
+
+ // AbstractFloat converts to Float.
+ (Sk::AbstractFloat, Sk::Float) => Some(other),
+ (Sk::Float, Sk::AbstractFloat) => Some(self),
+
+ // AbstractInt converts to concrete integer or float.
+ (Sk::AbstractInt, Sk::Uint | Sk::Sint | Sk::Float) => Some(other),
+ (Sk::Uint | Sk::Sint | Sk::Float, Sk::AbstractInt) => Some(self),
+
+ // AbstractFloat can't be reconciled with concrete integer types.
+ (Sk::AbstractFloat, Sk::Uint | Sk::Sint) | (Sk::Uint | Sk::Sint, Sk::AbstractFloat) => {
+ None
+ }
+
+ // Nothing can be reconciled with `bool`.
+ (Sk::Bool, _) | (_, Sk::Bool) => None,
+
+ // Different concrete types cannot be reconciled.
+ (Sk::Sint | Sk::Uint | Sk::Float, Sk::Sint | Sk::Uint | Sk::Float) => None,
+ }
+ }
+
+ /// Return `true` if automatic conversions will covert `self` to `goal`.
+ pub fn automatically_converts_to(self, goal: Self) -> bool {
+ self.automatic_conversion_combine(goal) == Some(goal)
+ }
+
+ const fn concretize(self) -> Self {
+ use crate::ScalarKind as Sk;
+ match self.kind {
+ Sk::Sint | Sk::Uint | Sk::Float | Sk::Bool => self,
+ Sk::AbstractInt => Self::I32,
+ Sk::AbstractFloat => Self::F32,
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
new file mode 100644
index 0000000000..ba9b49e135
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
@@ -0,0 +1,2760 @@
+use std::num::NonZeroU32;
+
+use crate::front::wgsl::error::{Error, ExpectedToken, InvalidAssignmentType};
+use crate::front::wgsl::index::Index;
+use crate::front::wgsl::parse::number::Number;
+use crate::front::wgsl::parse::{ast, conv};
+use crate::front::Typifier;
+use crate::proc::{
+ ensure_block_returns, Alignment, ConstantEvaluator, Emitter, Layouter, ResolveContext,
+};
+use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span};
+
+mod construction;
+mod conversion;
+
+/// Resolves the inner type of a given expression.
+///
+/// Expects a &mut [`ExpressionContext`] and a [`Handle<Expression>`].
+///
+/// Returns a &[`crate::TypeInner`].
+///
+/// Ideally, we would simply have a function that takes a `&mut ExpressionContext`
+/// and returns a `&TypeResolution`. Unfortunately, this leads the borrow checker
+/// to conclude that the mutable borrow lasts for as long as we are using the
+/// `&TypeResolution`, so we can't use the `ExpressionContext` for anything else -
+/// like, say, resolving another operand's type. Using a macro that expands to
+/// two separate calls, only the first of which needs a `&mut`,
+/// lets the borrow checker see that the mutable borrow is over.
+macro_rules! resolve_inner {
+ ($ctx:ident, $expr:expr) => {{
+ $ctx.grow_types($expr)?;
+ $ctx.typifier()[$expr].inner_with(&$ctx.module.types)
+ }};
+}
+pub(super) use resolve_inner;
+
+/// Resolves the inner types of two given expressions.
+///
+/// Expects a &mut [`ExpressionContext`] and two [`Handle<Expression>`]s.
+///
+/// Returns a tuple containing two &[`crate::TypeInner`].
+///
+/// See the documentation of [`resolve_inner!`] for why this macro is necessary.
+macro_rules! resolve_inner_binary {
+ ($ctx:ident, $left:expr, $right:expr) => {{
+ $ctx.grow_types($left)?;
+ $ctx.grow_types($right)?;
+ (
+ $ctx.typifier()[$left].inner_with(&$ctx.module.types),
+ $ctx.typifier()[$right].inner_with(&$ctx.module.types),
+ )
+ }};
+}
+
+/// Resolves the type of a given expression.
+///
+/// Expects a &mut [`ExpressionContext`] and a [`Handle<Expression>`].
+///
+/// Returns a &[`TypeResolution`].
+///
+/// See the documentation of [`resolve_inner!`] for why this macro is necessary.
+///
+/// [`TypeResolution`]: crate::proc::TypeResolution
+macro_rules! resolve {
+ ($ctx:ident, $expr:expr) => {{
+ $ctx.grow_types($expr)?;
+ &$ctx.typifier()[$expr]
+ }};
+}
+pub(super) use resolve;
+
+/// State for constructing a `crate::Module`.
+pub struct GlobalContext<'source, 'temp, 'out> {
+ /// The `TranslationUnit`'s expressions arena.
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+
+ /// The `TranslationUnit`'s types arena.
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ /// The module we're constructing.
+ module: &'out mut crate::Module,
+
+ const_typifier: &'temp mut Typifier,
+}
+
+impl<'source> GlobalContext<'source, '_, '_> {
+ fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> {
+ ExpressionContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ const_typifier: self.const_typifier,
+ expr_type: ExpressionContextType::Constant,
+ }
+ }
+
+ fn ensure_type_exists(
+ &mut self,
+ name: Option<String>,
+ inner: crate::TypeInner,
+ ) -> Handle<crate::Type> {
+ self.module
+ .types
+ .insert(crate::Type { inner, name }, Span::UNDEFINED)
+ }
+}
+
+/// State for lowering a statement within a function.
+pub struct StatementContext<'source, 'temp, 'out> {
+ // WGSL AST values.
+ /// A reference to [`TranslationUnit::expressions`] for the translation unit
+ /// we're lowering.
+ ///
+ /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+
+ /// A reference to [`TranslationUnit::types`] for the translation unit
+ /// we're lowering.
+ ///
+ /// [`TranslationUnit::types`]: ast::TranslationUnit::types
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ /// A map from each `ast::Local` handle to the Naga expression
+ /// we've built for it:
+ ///
+ /// - WGSL function arguments become Naga [`FunctionArgument`] expressions.
+ ///
+ /// - WGSL `var` declarations become Naga [`LocalVariable`] expressions.
+ ///
+ /// - WGSL `let` declararations become arbitrary Naga expressions.
+ ///
+ /// This always borrows the `local_table` local variable in
+ /// [`Lowerer::function`].
+ ///
+ /// [`LocalVariable`]: crate::Expression::LocalVariable
+ /// [`FunctionArgument`]: crate::Expression::FunctionArgument
+ local_table: &'temp mut FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>,
+
+ const_typifier: &'temp mut Typifier,
+ typifier: &'temp mut Typifier,
+ function: &'out mut crate::Function,
+ /// Stores the names of expressions that are assigned in `let` statement
+ /// Also stores the spans of the names, for use in errors.
+ named_expressions: &'out mut FastIndexMap<Handle<crate::Expression>, (String, Span)>,
+ module: &'out mut crate::Module,
+
+ /// Which `Expression`s in `self.naga_expressions` are const expressions, in
+ /// the WGSL sense.
+ ///
+ /// According to the WGSL spec, a const expression must not refer to any
+ /// `let` declarations, even if those declarations' initializers are
+ /// themselves const expressions. So this tracker is not simply concerned
+ /// with the form of the expressions; it is also tracking whether WGSL says
+ /// we should consider them to be const. See the use of `force_non_const` in
+ /// the code for lowering `let` bindings.
+ expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+}
+
+impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
+ fn as_expression<'t>(
+ &'t mut self,
+ block: &'t mut crate::Block,
+ emitter: &'t mut Emitter,
+ ) -> ExpressionContext<'a, 't, '_>
+ where
+ 'temp: 't,
+ {
+ ExpressionContext {
+ globals: self.globals,
+ types: self.types,
+ ast_expressions: self.ast_expressions,
+ const_typifier: self.const_typifier,
+ module: self.module,
+ expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext {
+ local_table: self.local_table,
+ function: self.function,
+ block,
+ emitter,
+ typifier: self.typifier,
+ expression_constness: self.expression_constness,
+ }),
+ }
+ }
+
+ fn as_global(&mut self) -> GlobalContext<'a, '_, '_> {
+ GlobalContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ const_typifier: self.const_typifier,
+ }
+ }
+
+ fn invalid_assignment_type(&self, expr: Handle<crate::Expression>) -> InvalidAssignmentType {
+ if let Some(&(_, span)) = self.named_expressions.get(&expr) {
+ InvalidAssignmentType::ImmutableBinding(span)
+ } else {
+ match self.function.expressions[expr] {
+ crate::Expression::Swizzle { .. } => InvalidAssignmentType::Swizzle,
+ crate::Expression::Access { base, .. } => self.invalid_assignment_type(base),
+ crate::Expression::AccessIndex { base, .. } => self.invalid_assignment_type(base),
+ _ => InvalidAssignmentType::Other,
+ }
+ }
+ }
+}
+
+pub struct RuntimeExpressionContext<'temp, 'out> {
+ /// A map from [`ast::Local`] handles to the Naga expressions we've built for them.
+ ///
+ /// This is always [`StatementContext::local_table`] for the
+ /// enclosing statement; see that documentation for details.
+ local_table: &'temp FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>,
+
+ function: &'out mut crate::Function,
+ block: &'temp mut crate::Block,
+ emitter: &'temp mut Emitter,
+ typifier: &'temp mut Typifier,
+
+ /// Which `Expression`s in `self.naga_expressions` are const expressions, in
+ /// the WGSL sense.
+ ///
+ /// See [`StatementContext::expression_constness`] for details.
+ expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+}
+
+/// The type of Naga IR expression we are lowering an [`ast::Expression`] to.
+pub enum ExpressionContextType<'temp, 'out> {
+ /// We are lowering to an arbitrary runtime expression, to be
+ /// included in a function's body.
+ ///
+ /// The given [`RuntimeExpressionContext`] holds information about local
+ /// variables, arguments, and other definitions available only to runtime
+ /// expressions, not constant or override expressions.
+ Runtime(RuntimeExpressionContext<'temp, 'out>),
+
+ /// We are lowering to a constant expression, to be included in the module's
+ /// constant expression arena.
+ ///
+ /// Everything constant expressions are allowed to refer to is
+ /// available in the [`ExpressionContext`], so this variant
+ /// carries no further information.
+ Constant,
+}
+
+/// State for lowering an [`ast::Expression`] to Naga IR.
+///
+/// [`ExpressionContext`]s come in two kinds, distinguished by
+/// the value of the [`expr_type`] field:
+///
+/// - A [`Runtime`] context contributes [`naga::Expression`]s to a [`naga::Function`]'s
+/// runtime expression arena.
+///
+/// - A [`Constant`] context contributes [`naga::Expression`]s to a [`naga::Module`]'s
+/// constant expression arena.
+///
+/// [`ExpressionContext`]s are constructed in restricted ways:
+///
+/// - To get a [`Runtime`] [`ExpressionContext`], call
+/// [`StatementContext::as_expression`].
+///
+/// - To get a [`Constant`] [`ExpressionContext`], call
+/// [`GlobalContext::as_const`].
+///
+/// - You can demote a [`Runtime`] context to a [`Constant`] context
+/// by calling [`as_const`], but there's no way to go in the other
+/// direction, producing a runtime context from a constant one. This
+/// is because runtime expressions can refer to constant
+/// expressions, via [`Expression::Constant`], but constant
+/// expressions can't refer to a function's expressions.
+///
+/// Not to be confused with `wgsl::parse::ExpressionContext`, which is
+/// for parsing the `ast::Expression` in the first place.
+///
+/// [`expr_type`]: ExpressionContext::expr_type
+/// [`Runtime`]: ExpressionContextType::Runtime
+/// [`naga::Expression`]: crate::Expression
+/// [`naga::Function`]: crate::Function
+/// [`Constant`]: ExpressionContextType::Constant
+/// [`naga::Module`]: crate::Module
+/// [`as_const`]: ExpressionContext::as_const
+/// [`Expression::Constant`]: crate::Expression::Constant
+pub struct ExpressionContext<'source, 'temp, 'out> {
+ // WGSL AST values.
+ ast_expressions: &'temp Arena<ast::Expression<'source>>,
+ types: &'temp Arena<ast::Type<'source>>,
+
+ // Naga IR values.
+ /// The map from the names of module-scope declarations to the Naga IR
+ /// `Handle`s we have built for them, owned by `Lowerer::lower`.
+ globals: &'temp mut FastHashMap<&'source str, LoweredGlobalDecl>,
+
+ /// The IR [`Module`] we're constructing.
+ ///
+ /// [`Module`]: crate::Module
+ module: &'out mut crate::Module,
+
+ /// Type judgments for [`module::const_expressions`].
+ ///
+ /// [`module::const_expressions`]: crate::Module::const_expressions
+ const_typifier: &'temp mut Typifier,
+
+ /// Whether we are lowering a constant expression or a general
+ /// runtime expression, and the data needed in each case.
+ expr_type: ExpressionContextType<'temp, 'out>,
+}
+
+impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
+ fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> {
+ ExpressionContext {
+ globals: self.globals,
+ types: self.types,
+ ast_expressions: self.ast_expressions,
+ const_typifier: self.const_typifier,
+ module: self.module,
+ expr_type: ExpressionContextType::Constant,
+ }
+ }
+
+ fn as_global(&mut self) -> GlobalContext<'source, '_, '_> {
+ GlobalContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ const_typifier: self.const_typifier,
+ }
+ }
+
+ fn as_const_evaluator(&mut self) -> ConstantEvaluator {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function(
+ self.module,
+ &mut rctx.function.expressions,
+ rctx.expression_constness,
+ rctx.emitter,
+ rctx.block,
+ ),
+ ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module),
+ }
+ }
+
+ fn append_expression(
+ &mut self,
+ expr: crate::Expression,
+ span: Span,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut eval = self.as_const_evaluator();
+ match eval.try_eval_and_append(&expr, span) {
+ Ok(expr) => Ok(expr),
+
+ // `expr` is not a constant expression. This is fine as
+ // long as we're not building `Module::const_expressions`.
+ Err(err) => match self.expr_type {
+ ExpressionContextType::Runtime(ref mut rctx) => {
+ Ok(rctx.function.expressions.append(expr, span))
+ }
+ ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)),
+ },
+ }
+ }
+
+ fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref ctx) => {
+ if !ctx.expression_constness.is_const(handle) {
+ return None;
+ }
+
+ self.module
+ .to_ctx()
+ .eval_expr_to_u32_from(handle, &ctx.function.expressions)
+ .ok()
+ }
+ ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(),
+ }
+ }
+
+ fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle),
+ ExpressionContextType::Constant => self.module.const_expressions.get_span(handle),
+ }
+ }
+
+ fn typifier(&self) -> &Typifier {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
+ ExpressionContextType::Constant => self.const_typifier,
+ }
+ }
+
+ fn runtime_expression_ctx(
+ &mut self,
+ span: Span,
+ ) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx),
+ ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)),
+ }
+ }
+
+ fn gather_component(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ component_span: Span,
+ gather_span: Span,
+ ) -> Result<crate::SwizzleComponent, Error<'source>> {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref rctx) => {
+ if !rctx.expression_constness.is_const(expr) {
+ return Err(Error::ExpectedConstExprConcreteIntegerScalar(
+ component_span,
+ ));
+ }
+
+ let index = self
+ .module
+ .to_ctx()
+ .eval_expr_to_u32_from(expr, &rctx.function.expressions)
+ .map_err(|err| match err {
+ crate::proc::U32EvalError::NonConst => {
+ Error::ExpectedConstExprConcreteIntegerScalar(component_span)
+ }
+ crate::proc::U32EvalError::Negative => {
+ Error::ExpectedNonNegative(component_span)
+ }
+ })?;
+ crate::SwizzleComponent::XYZW
+ .get(index as usize)
+ .copied()
+ .ok_or(Error::InvalidGatherComponent(component_span))
+ }
+ // This means a `gather` operation appeared in a constant expression.
+ // This error refers to the `gather` itself, not its "component" argument.
+ ExpressionContextType::Constant => {
+ Err(Error::UnexpectedOperationInConstContext(gather_span))
+ }
+ }
+ }
+
+ /// Determine the type of `handle`, and add it to the module's arena.
+ ///
+ /// If you just need a `TypeInner` for `handle`'s type, use the
+ /// [`resolve_inner!`] macro instead. This function
+ /// should only be used when the type of `handle` needs to appear
+ /// in the module's final `Arena<Type>`, for example, if you're
+ /// creating a [`LocalVariable`] whose type is inferred from its
+ /// initializer.
+ ///
+ /// [`LocalVariable`]: crate::LocalVariable
+ fn register_type(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ self.grow_types(handle)?;
+ // This is equivalent to calling ExpressionContext::typifier(),
+ // except that this lets the borrow checker see that it's okay
+ // to also borrow self.module.types mutably below.
+ let typifier = match self.expr_type {
+ ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
+ ExpressionContextType::Constant => &*self.const_typifier,
+ };
+ Ok(typifier.register_type(handle, &mut self.module.types))
+ }
+
+ /// Resolve the types of all expressions up through `handle`.
+ ///
+ /// Ensure that [`self.typifier`] has a [`TypeResolution`] for
+ /// every expression in [`self.function.expressions`].
+ ///
+ /// This does not add types to any arena. The [`Typifier`]
+ /// documentation explains the steps we take to avoid filling
+ /// arenas with intermediate types.
+ ///
+ /// This function takes `&mut self`, so it can't conveniently
+ /// return a shared reference to the resulting `TypeResolution`:
+ /// the shared reference would extend the mutable borrow, and you
+ /// wouldn't be able to use `self` for anything else. Instead, you
+ /// should use [`register_type`] or one of [`resolve!`],
+ /// [`resolve_inner!`] or [`resolve_inner_binary!`].
+ ///
+ /// [`self.typifier`]: ExpressionContext::typifier
+ /// [`TypeResolution`]: crate::proc::TypeResolution
+ /// [`register_type`]: Self::register_type
+ /// [`Typifier`]: Typifier
+ fn grow_types(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<&mut Self, Error<'source>> {
+ let empty_arena = Arena::new();
+ let resolve_ctx;
+ let typifier;
+ let expressions;
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref mut ctx) => {
+ resolve_ctx = ResolveContext::with_locals(
+ self.module,
+ &ctx.function.local_variables,
+ &ctx.function.arguments,
+ );
+ typifier = &mut *ctx.typifier;
+ expressions = &ctx.function.expressions;
+ }
+ ExpressionContextType::Constant => {
+ resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]);
+ typifier = self.const_typifier;
+ expressions = &self.module.const_expressions;
+ }
+ };
+ typifier
+ .grow(handle, expressions, &resolve_ctx)
+ .map_err(Error::InvalidResolve)?;
+
+ Ok(self)
+ }
+
+ fn image_data(
+ &mut self,
+ image: Handle<crate::Expression>,
+ span: Span,
+ ) -> Result<(crate::ImageClass, bool), Error<'source>> {
+ match *resolve_inner!(self, image) {
+ crate::TypeInner::Image { class, arrayed, .. } => Ok((class, arrayed)),
+ _ => Err(Error::BadTexture(span)),
+ }
+ }
+
+ fn prepare_args<'b>(
+ &mut self,
+ args: &'b [Handle<ast::Expression<'source>>],
+ min_args: u32,
+ span: Span,
+ ) -> ArgumentContext<'b, 'source> {
+ ArgumentContext {
+ args: args.iter(),
+ min_args,
+ args_used: 0,
+ total_args: args.len() as u32,
+ span,
+ }
+ }
+
+ /// Insert splats, if needed by the non-'*' operations.
+ ///
+ /// See the "Binary arithmetic expressions with mixed scalar and vector operands"
+ /// table in the WebGPU Shading Language specification for relevant operators.
+ ///
+ /// Multiply is not handled here as backends are expected to handle vec*scalar
+ /// operations, so inserting splats into the IR increases size needlessly.
+ fn binary_op_splat(
+ &mut self,
+ op: crate::BinaryOperator,
+ left: &mut Handle<crate::Expression>,
+ right: &mut Handle<crate::Expression>,
+ ) -> Result<(), Error<'source>> {
+ if matches!(
+ op,
+ crate::BinaryOperator::Add
+ | crate::BinaryOperator::Subtract
+ | crate::BinaryOperator::Divide
+ | crate::BinaryOperator::Modulo
+ ) {
+ match resolve_inner_binary!(self, *left, *right) {
+ (&crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar { .. }) => {
+ *right = self.append_expression(
+ crate::Expression::Splat {
+ size,
+ value: *right,
+ },
+ self.get_expression_span(*right),
+ )?;
+ }
+ (&crate::TypeInner::Scalar { .. }, &crate::TypeInner::Vector { size, .. }) => {
+ *left = self.append_expression(
+ crate::Expression::Splat { size, value: *left },
+ self.get_expression_span(*left),
+ )?;
+ }
+ _ => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Add a single expression to the expression table that is not covered by `self.emitter`.
+ ///
+ /// This is useful for `CallResult` and `AtomicResult` expressions, which should not be covered by
+ /// `Emit` statements.
+ fn interrupt_emitter(
+ &mut self,
+ expression: crate::Expression,
+ span: Span,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref mut rctx) => {
+ rctx.block
+ .extend(rctx.emitter.finish(&rctx.function.expressions));
+ }
+ ExpressionContextType::Constant => {}
+ }
+ let result = self.append_expression(expression, span);
+ match self.expr_type {
+ ExpressionContextType::Runtime(ref mut rctx) => {
+ rctx.emitter.start(&rctx.function.expressions);
+ }
+ ExpressionContextType::Constant => {}
+ }
+ result
+ }
+
+ /// Apply the WGSL Load Rule to `expr`.
+ ///
+ /// If `expr` is has type `ref<SC, T, A>`, perform a load to produce a value of type
+ /// `T`. Otherwise, return `expr` unchanged.
+ fn apply_load_rule(
+ &mut self,
+ expr: Typed<Handle<crate::Expression>>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ match expr {
+ Typed::Reference(pointer) => {
+ let load = crate::Expression::Load { pointer };
+ let span = self.get_expression_span(pointer);
+ self.append_expression(load, span)
+ }
+ Typed::Plain(handle) => Ok(handle),
+ }
+ }
+
+ fn ensure_type_exists(&mut self, inner: crate::TypeInner) -> Handle<crate::Type> {
+ self.as_global().ensure_type_exists(None, inner)
+ }
+}
+
+struct ArgumentContext<'ctx, 'source> {
+ args: std::slice::Iter<'ctx, Handle<ast::Expression<'source>>>,
+ min_args: u32,
+ args_used: u32,
+ total_args: u32,
+ span: Span,
+}
+
+impl<'source> ArgumentContext<'_, 'source> {
+ pub fn finish(self) -> Result<(), Error<'source>> {
+ if self.args.len() == 0 {
+ Ok(())
+ } else {
+ Err(Error::WrongArgumentCount {
+ found: self.total_args,
+ expected: self.min_args..self.args_used + 1,
+ span: self.span,
+ })
+ }
+ }
+
+ pub fn next(&mut self) -> Result<Handle<ast::Expression<'source>>, Error<'source>> {
+ match self.args.next().copied() {
+ Some(arg) => {
+ self.args_used += 1;
+ Ok(arg)
+ }
+ None => Err(Error::WrongArgumentCount {
+ found: self.total_args,
+ expected: self.min_args..self.args_used + 1,
+ span: self.span,
+ }),
+ }
+ }
+}
+
+/// WGSL type annotations on expressions, types, values, etc.
+///
+/// Naga and WGSL types are very close, but Naga lacks WGSL's `ref` types, which
+/// we need to know to apply the Load Rule. This enum carries some WGSL or Naga
+/// datum along with enough information to determine its corresponding WGSL
+/// type.
+///
+/// The `T` type parameter can be any expression-like thing:
+///
+/// - `Typed<Handle<crate::Type>>` can represent a full WGSL type. For example,
+/// given some Naga `Pointer` type `ptr`, a WGSL reference type is a
+/// `Typed::Reference(ptr)` whereas a WGSL pointer type is a
+/// `Typed::Plain(ptr)`.
+///
+/// - `Typed<crate::Expression>` or `Typed<Handle<crate::Expression>>` can
+/// represent references similarly.
+///
+/// Use the `map` and `try_map` methods to convert from one expression
+/// representation to another.
+///
+/// [`Expression`]: crate::Expression
+#[derive(Debug, Copy, Clone)]
+enum Typed<T> {
+ /// A WGSL reference.
+ Reference(T),
+
+ /// A WGSL plain type.
+ Plain(T),
+}
+
+impl<T> Typed<T> {
+ fn map<U>(self, mut f: impl FnMut(T) -> U) -> Typed<U> {
+ match self {
+ Self::Reference(v) => Typed::Reference(f(v)),
+ Self::Plain(v) => Typed::Plain(f(v)),
+ }
+ }
+
+ fn try_map<U, E>(self, mut f: impl FnMut(T) -> Result<U, E>) -> Result<Typed<U>, E> {
+ Ok(match self {
+ Self::Reference(expr) => Typed::Reference(f(expr)?),
+ Self::Plain(expr) => Typed::Plain(f(expr)?),
+ })
+ }
+}
+
+/// A single vector component or swizzle.
+///
+/// This represents the things that can appear after the `.` in a vector access
+/// expression: either a single component name, or a series of them,
+/// representing a swizzle.
+enum Components {
+ Single(u32),
+ Swizzle {
+ size: crate::VectorSize,
+ pattern: [crate::SwizzleComponent; 4],
+ },
+}
+
+impl Components {
+ const fn letter_component(letter: char) -> Option<crate::SwizzleComponent> {
+ use crate::SwizzleComponent as Sc;
+ match letter {
+ 'x' | 'r' => Some(Sc::X),
+ 'y' | 'g' => Some(Sc::Y),
+ 'z' | 'b' => Some(Sc::Z),
+ 'w' | 'a' => Some(Sc::W),
+ _ => None,
+ }
+ }
+
+ fn single_component(name: &str, name_span: Span) -> Result<u32, Error> {
+ let ch = name.chars().next().ok_or(Error::BadAccessor(name_span))?;
+ match Self::letter_component(ch) {
+ Some(sc) => Ok(sc as u32),
+ None => Err(Error::BadAccessor(name_span)),
+ }
+ }
+
+ /// Construct a `Components` value from a 'member' name, like `"wzy"` or `"x"`.
+ ///
+ /// Use `name_span` for reporting errors in parsing the component string.
+ fn new(name: &str, name_span: Span) -> Result<Self, Error> {
+ let size = match name.len() {
+ 1 => return Ok(Components::Single(Self::single_component(name, name_span)?)),
+ 2 => crate::VectorSize::Bi,
+ 3 => crate::VectorSize::Tri,
+ 4 => crate::VectorSize::Quad,
+ _ => return Err(Error::BadAccessor(name_span)),
+ };
+
+ let mut pattern = [crate::SwizzleComponent::X; 4];
+ for (comp, ch) in pattern.iter_mut().zip(name.chars()) {
+ *comp = Self::letter_component(ch).ok_or(Error::BadAccessor(name_span))?;
+ }
+
+ Ok(Components::Swizzle { size, pattern })
+ }
+}
+
+/// An `ast::GlobalDecl` for which we have built the Naga IR equivalent.
+enum LoweredGlobalDecl {
+ Function(Handle<crate::Function>),
+ Var(Handle<crate::GlobalVariable>),
+ Const(Handle<crate::Constant>),
+ Type(Handle<crate::Type>),
+ EntryPoint,
+}
+
+enum Texture {
+ Gather,
+ GatherCompare,
+
+ Sample,
+ SampleBias,
+ SampleCompare,
+ SampleCompareLevel,
+ SampleGrad,
+ SampleLevel,
+ // SampleBaseClampToEdge,
+}
+
+impl Texture {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "textureGather" => Self::Gather,
+ "textureGatherCompare" => Self::GatherCompare,
+
+ "textureSample" => Self::Sample,
+ "textureSampleBias" => Self::SampleBias,
+ "textureSampleCompare" => Self::SampleCompare,
+ "textureSampleCompareLevel" => Self::SampleCompareLevel,
+ "textureSampleGrad" => Self::SampleGrad,
+ "textureSampleLevel" => Self::SampleLevel,
+ // "textureSampleBaseClampToEdge" => Some(Self::SampleBaseClampToEdge),
+ _ => return None,
+ })
+ }
+
+ pub const fn min_argument_count(&self) -> u32 {
+ match *self {
+ Self::Gather => 3,
+ Self::GatherCompare => 4,
+
+ Self::Sample => 3,
+ Self::SampleBias => 5,
+ Self::SampleCompare => 5,
+ Self::SampleCompareLevel => 5,
+ Self::SampleGrad => 6,
+ Self::SampleLevel => 5,
+ // Self::SampleBaseClampToEdge => 3,
+ }
+ }
+}
+
+pub struct Lowerer<'source, 'temp> {
+ index: &'temp Index<'source>,
+ layouter: Layouter,
+}
+
+impl<'source, 'temp> Lowerer<'source, 'temp> {
+ pub fn new(index: &'temp Index<'source>) -> Self {
+ Self {
+ index,
+ layouter: Layouter::default(),
+ }
+ }
+
+ pub fn lower(
+ &mut self,
+ tu: &'temp ast::TranslationUnit<'source>,
+ ) -> Result<crate::Module, Error<'source>> {
+ let mut module = crate::Module::default();
+
+ let mut ctx = GlobalContext {
+ ast_expressions: &tu.expressions,
+ globals: &mut FastHashMap::default(),
+ types: &tu.types,
+ module: &mut module,
+ const_typifier: &mut Typifier::new(),
+ };
+
+ for decl_handle in self.index.visit_ordered() {
+ let span = tu.decls.get_span(decl_handle);
+ let decl = &tu.decls[decl_handle];
+
+ match decl.kind {
+ ast::GlobalDeclKind::Fn(ref f) => {
+ let lowered_decl = self.function(f, span, &mut ctx)?;
+ ctx.globals.insert(f.name.name, lowered_decl);
+ }
+ ast::GlobalDeclKind::Var(ref v) => {
+ let ty = self.resolve_ast_type(v.ty, &mut ctx)?;
+
+ let init;
+ if let Some(init_ast) = v.init {
+ let mut ectx = ctx.as_const();
+ let lowered = self.expression_for_abstract(init_ast, &mut ectx)?;
+ let ty_res = crate::proc::TypeResolution::Handle(ty);
+ let converted = ectx
+ .try_automatic_conversions(lowered, &ty_res, v.name.span)
+ .map_err(|error| match error {
+ Error::AutoConversion {
+ dest_span: _,
+ dest_type,
+ source_span: _,
+ source_type,
+ } => Error::InitializationTypeMismatch {
+ name: v.name.span,
+ expected: dest_type,
+ got: source_type,
+ },
+ other => other,
+ })?;
+ init = Some(converted);
+ } else {
+ init = None;
+ }
+
+ let binding = if let Some(ref binding) = v.binding {
+ Some(crate::ResourceBinding {
+ group: self.const_u32(binding.group, &mut ctx.as_const())?.0,
+ binding: self.const_u32(binding.binding, &mut ctx.as_const())?.0,
+ })
+ } else {
+ None
+ };
+
+ let handle = ctx.module.global_variables.append(
+ crate::GlobalVariable {
+ name: Some(v.name.name.to_string()),
+ space: v.space,
+ binding,
+ ty,
+ init,
+ },
+ span,
+ );
+
+ ctx.globals
+ .insert(v.name.name, LoweredGlobalDecl::Var(handle));
+ }
+ ast::GlobalDeclKind::Const(ref c) => {
+ let mut ectx = ctx.as_const();
+ let mut init = self.expression_for_abstract(c.init, &mut ectx)?;
+
+ let ty;
+ if let Some(explicit_ty) = c.ty {
+ let explicit_ty =
+ self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?;
+ let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty);
+ init = ectx
+ .try_automatic_conversions(init, &explicit_ty_res, c.name.span)
+ .map_err(|error| match error {
+ Error::AutoConversion {
+ dest_span: _,
+ dest_type,
+ source_span: _,
+ source_type,
+ } => Error::InitializationTypeMismatch {
+ name: c.name.span,
+ expected: dest_type,
+ got: source_type,
+ },
+ other => other,
+ })?;
+ ty = explicit_ty;
+ } else {
+ init = ectx.concretize(init)?;
+ ty = ectx.register_type(init)?;
+ }
+
+ let handle = ctx.module.constants.append(
+ crate::Constant {
+ name: Some(c.name.name.to_string()),
+ r#override: crate::Override::None,
+ ty,
+ init,
+ },
+ span,
+ );
+
+ ctx.globals
+ .insert(c.name.name, LoweredGlobalDecl::Const(handle));
+ }
+ ast::GlobalDeclKind::Struct(ref s) => {
+ let handle = self.r#struct(s, span, &mut ctx)?;
+ ctx.globals
+ .insert(s.name.name, LoweredGlobalDecl::Type(handle));
+ }
+ ast::GlobalDeclKind::Type(ref alias) => {
+ let ty = self.resolve_named_ast_type(
+ alias.ty,
+ Some(alias.name.name.to_string()),
+ &mut ctx,
+ )?;
+ ctx.globals
+ .insert(alias.name.name, LoweredGlobalDecl::Type(ty));
+ }
+ }
+ }
+
+ // Constant evaluation may leave abstract-typed literals and
+ // compositions in expression arenas, so we need to compact the module
+ // to remove unused expressions and types.
+ crate::compact::compact(&mut module);
+
+ Ok(module)
+ }
+
+ fn function(
+ &mut self,
+ f: &ast::Function<'source>,
+ span: Span,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<LoweredGlobalDecl, Error<'source>> {
+ let mut local_table = FastHashMap::default();
+ let mut expressions = Arena::new();
+ let mut named_expressions = FastIndexMap::default();
+
+ let arguments = f
+ .arguments
+ .iter()
+ .enumerate()
+ .map(|(i, arg)| {
+ let ty = self.resolve_ast_type(arg.ty, ctx)?;
+ let expr = expressions
+ .append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
+ local_table.insert(arg.handle, Typed::Plain(expr));
+ named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
+
+ Ok(crate::FunctionArgument {
+ name: Some(arg.name.name.to_string()),
+ ty,
+ binding: self.binding(&arg.binding, ty, ctx)?,
+ })
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let result = f
+ .result
+ .as_ref()
+ .map(|res| {
+ let ty = self.resolve_ast_type(res.ty, ctx)?;
+ Ok(crate::FunctionResult {
+ ty,
+ binding: self.binding(&res.binding, ty, ctx)?,
+ })
+ })
+ .transpose()?;
+
+ let mut function = crate::Function {
+ name: Some(f.name.name.to_string()),
+ arguments,
+ result,
+ local_variables: Arena::new(),
+ expressions,
+ named_expressions: crate::NamedExpressions::default(),
+ body: crate::Block::default(),
+ };
+
+ let mut typifier = Typifier::default();
+ let mut stmt_ctx = StatementContext {
+ local_table: &mut local_table,
+ globals: ctx.globals,
+ ast_expressions: ctx.ast_expressions,
+ const_typifier: ctx.const_typifier,
+ typifier: &mut typifier,
+ function: &mut function,
+ named_expressions: &mut named_expressions,
+ types: ctx.types,
+ module: ctx.module,
+ expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(),
+ };
+ let mut body = self.block(&f.body, false, &mut stmt_ctx)?;
+ ensure_block_returns(&mut body);
+
+ function.body = body;
+ function.named_expressions = named_expressions
+ .into_iter()
+ .map(|(key, (name, _))| (key, name))
+ .collect();
+
+ if let Some(ref entry) = f.entry_point {
+ let workgroup_size = if let Some(workgroup_size) = entry.workgroup_size {
+ // TODO: replace with try_map once stabilized
+ let mut workgroup_size_out = [1; 3];
+ for (i, size) in workgroup_size.into_iter().enumerate() {
+ if let Some(size_expr) = size {
+ workgroup_size_out[i] = self.const_u32(size_expr, &mut ctx.as_const())?.0;
+ }
+ }
+ workgroup_size_out
+ } else {
+ [0; 3]
+ };
+
+ ctx.module.entry_points.push(crate::EntryPoint {
+ name: f.name.name.to_string(),
+ stage: entry.stage,
+ early_depth_test: entry.early_depth_test,
+ workgroup_size,
+ function,
+ });
+ Ok(LoweredGlobalDecl::EntryPoint)
+ } else {
+ let handle = ctx.module.functions.append(function, span);
+ Ok(LoweredGlobalDecl::Function(handle))
+ }
+ }
+
+ fn block(
+ &mut self,
+ b: &ast::Block<'source>,
+ is_inside_loop: bool,
+ ctx: &mut StatementContext<'source, '_, '_>,
+ ) -> Result<crate::Block, Error<'source>> {
+ let mut block = crate::Block::default();
+
+ for stmt in b.stmts.iter() {
+ self.statement(stmt, &mut block, is_inside_loop, ctx)?;
+ }
+
+ Ok(block)
+ }
+
+ fn statement(
+ &mut self,
+ stmt: &ast::Statement<'source>,
+ block: &mut crate::Block,
+ is_inside_loop: bool,
+ ctx: &mut StatementContext<'source, '_, '_>,
+ ) -> Result<(), Error<'source>> {
+ let out = match stmt.kind {
+ ast::StatementKind::Block(ref block) => {
+ let block = self.block(block, is_inside_loop, ctx)?;
+ crate::Statement::Block(block)
+ }
+ ast::StatementKind::LocalDecl(ref decl) => match *decl {
+ ast::LocalDecl::Let(ref l) => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let value =
+ self.expression(l.init, &mut ctx.as_expression(block, &mut emitter))?;
+
+ // The WGSL spec says that any expression that refers to a
+ // `let`-bound variable is not a const expression. This
+ // affects when errors must be reported, so we can't even
+ // treat suitable `let` bindings as constant as an
+ // optimization.
+ ctx.expression_constness.force_non_const(value);
+
+ let explicit_ty =
+ l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
+ .transpose()?;
+
+ if let Some(ty) = explicit_ty {
+ let mut ctx = ctx.as_expression(block, &mut emitter);
+ let init_ty = ctx.register_type(value)?;
+ if !ctx.module.types[ty]
+ .inner
+ .equivalent(&ctx.module.types[init_ty].inner, &ctx.module.types)
+ {
+ let gctx = &ctx.module.to_ctx();
+ return Err(Error::InitializationTypeMismatch {
+ name: l.name.span,
+ expected: ty.to_wgsl(gctx),
+ got: init_ty.to_wgsl(gctx),
+ });
+ }
+ }
+
+ block.extend(emitter.finish(&ctx.function.expressions));
+ ctx.local_table.insert(l.handle, Typed::Plain(value));
+ ctx.named_expressions
+ .insert(value, (l.name.name.to_string(), l.name.span));
+
+ return Ok(());
+ }
+ ast::LocalDecl::Var(ref v) => {
+ let explicit_ty =
+ v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx.as_global()))
+ .transpose()?;
+
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+ let mut ectx = ctx.as_expression(block, &mut emitter);
+
+ let ty;
+ let initializer;
+ match (v.init, explicit_ty) {
+ (Some(init), Some(explicit_ty)) => {
+ let init = self.expression_for_abstract(init, &mut ectx)?;
+ let ty_res = crate::proc::TypeResolution::Handle(explicit_ty);
+ let init = ectx
+ .try_automatic_conversions(init, &ty_res, v.name.span)
+ .map_err(|error| match error {
+ Error::AutoConversion {
+ dest_span: _,
+ dest_type,
+ source_span: _,
+ source_type,
+ } => Error::InitializationTypeMismatch {
+ name: v.name.span,
+ expected: dest_type,
+ got: source_type,
+ },
+ other => other,
+ })?;
+ ty = explicit_ty;
+ initializer = Some(init);
+ }
+ (Some(init), None) => {
+ let concretized = self.expression(init, &mut ectx)?;
+ ty = ectx.register_type(concretized)?;
+ initializer = Some(concretized);
+ }
+ (None, Some(explicit_ty)) => {
+ ty = explicit_ty;
+ initializer = None;
+ }
+ (None, None) => return Err(Error::MissingType(v.name.span)),
+ }
+
+ let (const_initializer, initializer) = {
+ match initializer {
+ Some(init) => {
+ // It's not correct to hoist the initializer up
+ // to the top of the function if:
+ // - the initialization is inside a loop, and should
+ // take place on every iteration, or
+ // - the initialization is not a constant
+ // expression, so its value depends on the
+ // state at the point of initialization.
+ if is_inside_loop || !ctx.expression_constness.is_const(init) {
+ (None, Some(init))
+ } else {
+ (Some(init), None)
+ }
+ }
+ None => (None, None),
+ }
+ };
+
+ let var = ctx.function.local_variables.append(
+ crate::LocalVariable {
+ name: Some(v.name.name.to_string()),
+ ty,
+ init: const_initializer,
+ },
+ stmt.span,
+ );
+
+ let handle = ctx.as_expression(block, &mut emitter).interrupt_emitter(
+ crate::Expression::LocalVariable(var),
+ Span::UNDEFINED,
+ )?;
+ block.extend(emitter.finish(&ctx.function.expressions));
+ ctx.local_table.insert(v.handle, Typed::Reference(handle));
+
+ match initializer {
+ Some(initializer) => crate::Statement::Store {
+ pointer: handle,
+ value: initializer,
+ },
+ None => return Ok(()),
+ }
+ }
+ },
+ ast::StatementKind::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let condition =
+ self.expression(condition, &mut ctx.as_expression(block, &mut emitter))?;
+ block.extend(emitter.finish(&ctx.function.expressions));
+
+ let accept = self.block(accept, is_inside_loop, ctx)?;
+ let reject = self.block(reject, is_inside_loop, ctx)?;
+
+ crate::Statement::If {
+ condition,
+ accept,
+ reject,
+ }
+ }
+ ast::StatementKind::Switch {
+ selector,
+ ref cases,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let mut ectx = ctx.as_expression(block, &mut emitter);
+ let selector = self.expression(selector, &mut ectx)?;
+
+ let uint =
+ resolve_inner!(ectx, selector).scalar_kind() == Some(crate::ScalarKind::Uint);
+ block.extend(emitter.finish(&ctx.function.expressions));
+
+ let cases = cases
+ .iter()
+ .map(|case| {
+ Ok(crate::SwitchCase {
+ value: match case.value {
+ ast::SwitchValue::Expr(expr) => {
+ let span = ctx.ast_expressions.get_span(expr);
+ let expr =
+ self.expression(expr, &mut ctx.as_global().as_const())?;
+ match ctx.module.to_ctx().eval_expr_to_literal(expr) {
+ Some(crate::Literal::I32(value)) if !uint => {
+ crate::SwitchValue::I32(value)
+ }
+ Some(crate::Literal::U32(value)) if uint => {
+ crate::SwitchValue::U32(value)
+ }
+ _ => {
+ return Err(Error::InvalidSwitchValue { uint, span });
+ }
+ }
+ }
+ ast::SwitchValue::Default => crate::SwitchValue::Default,
+ },
+ body: self.block(&case.body, is_inside_loop, ctx)?,
+ fall_through: case.fall_through,
+ })
+ })
+ .collect::<Result<_, _>>()?;
+
+ crate::Statement::Switch { selector, cases }
+ }
+ ast::StatementKind::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let body = self.block(body, true, ctx)?;
+ let mut continuing = self.block(continuing, true, ctx)?;
+
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+ let break_if = break_if
+ .map(|expr| {
+ self.expression(expr, &mut ctx.as_expression(&mut continuing, &mut emitter))
+ })
+ .transpose()?;
+ continuing.extend(emitter.finish(&ctx.function.expressions));
+
+ crate::Statement::Loop {
+ body,
+ continuing,
+ break_if,
+ }
+ }
+ ast::StatementKind::Break => crate::Statement::Break,
+ ast::StatementKind::Continue => crate::Statement::Continue,
+ ast::StatementKind::Return { value } => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let value = value
+ .map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter)))
+ .transpose()?;
+ block.extend(emitter.finish(&ctx.function.expressions));
+
+ crate::Statement::Return { value }
+ }
+ ast::StatementKind::Kill => crate::Statement::Kill,
+ ast::StatementKind::Call {
+ ref function,
+ ref arguments,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let _ = self.call(
+ stmt.span,
+ function,
+ arguments,
+ &mut ctx.as_expression(block, &mut emitter),
+ )?;
+ block.extend(emitter.finish(&ctx.function.expressions));
+ return Ok(());
+ }
+ ast::StatementKind::Assign {
+ target: ast_target,
+ op,
+ value,
+ } => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let target = self.expression_for_reference(
+ ast_target,
+ &mut ctx.as_expression(block, &mut emitter),
+ )?;
+ let mut value =
+ self.expression(value, &mut ctx.as_expression(block, &mut emitter))?;
+
+ let target_handle = match target {
+ Typed::Reference(handle) => handle,
+ Typed::Plain(handle) => {
+ let ty = ctx.invalid_assignment_type(handle);
+ return Err(Error::InvalidAssignment {
+ span: ctx.ast_expressions.get_span(ast_target),
+ ty,
+ });
+ }
+ };
+
+ let value = match op {
+ Some(op) => {
+ let mut ctx = ctx.as_expression(block, &mut emitter);
+ let mut left = ctx.apply_load_rule(target)?;
+ ctx.binary_op_splat(op, &mut left, &mut value)?;
+ ctx.append_expression(
+ crate::Expression::Binary {
+ op,
+ left,
+ right: value,
+ },
+ stmt.span,
+ )?
+ }
+ None => value,
+ };
+ block.extend(emitter.finish(&ctx.function.expressions));
+
+ crate::Statement::Store {
+ pointer: target_handle,
+ value,
+ }
+ }
+ ast::StatementKind::Increment(value) | ast::StatementKind::Decrement(value) => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let op = match stmt.kind {
+ ast::StatementKind::Increment(_) => crate::BinaryOperator::Add,
+ ast::StatementKind::Decrement(_) => crate::BinaryOperator::Subtract,
+ _ => unreachable!(),
+ };
+
+ let value_span = ctx.ast_expressions.get_span(value);
+ let target = self
+ .expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?;
+ let target_handle = match target {
+ Typed::Reference(handle) => handle,
+ Typed::Plain(_) => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ };
+
+ let mut ectx = ctx.as_expression(block, &mut emitter);
+ let scalar = match *resolve_inner!(ectx, target_handle) {
+ crate::TypeInner::ValuePointer {
+ size: None, scalar, ..
+ } => scalar,
+ crate::TypeInner::Pointer { base, .. } => match ectx.module.types[base].inner {
+ crate::TypeInner::Scalar(scalar) => scalar,
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ },
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ };
+ let literal = match scalar.kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ crate::Literal::one(scalar)
+ .ok_or(Error::BadIncrDecrReferenceType(value_span))?
+ }
+ _ => return Err(Error::BadIncrDecrReferenceType(value_span)),
+ };
+
+ let right =
+ ectx.interrupt_emitter(crate::Expression::Literal(literal), Span::UNDEFINED)?;
+ let rctx = ectx.runtime_expression_ctx(stmt.span)?;
+ let left = rctx.function.expressions.append(
+ crate::Expression::Load {
+ pointer: target_handle,
+ },
+ value_span,
+ );
+ let value = rctx
+ .function
+ .expressions
+ .append(crate::Expression::Binary { op, left, right }, stmt.span);
+
+ block.extend(emitter.finish(&ctx.function.expressions));
+ crate::Statement::Store {
+ pointer: target_handle,
+ value,
+ }
+ }
+ ast::StatementKind::Ignore(expr) => {
+ let mut emitter = Emitter::default();
+ emitter.start(&ctx.function.expressions);
+
+ let _ = self.expression(expr, &mut ctx.as_expression(block, &mut emitter))?;
+ block.extend(emitter.finish(&ctx.function.expressions));
+ return Ok(());
+ }
+ };
+
+ block.push(out, stmt.span);
+
+ Ok(())
+ }
+
+ /// Lower `expr` and apply the Load Rule if possible.
+ ///
+ /// For the time being, this concretizes abstract values, to support
+ /// consumers that haven't been adapted to consume them yet. Consumers
+ /// prepared for abstract values can call [`expression_for_abstract`].
+ ///
+ /// [`expression_for_abstract`]: Lowerer::expression_for_abstract
+ fn expression(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let expr = self.expression_for_abstract(expr, ctx)?;
+ ctx.concretize(expr)
+ }
+
+ fn expression_for_abstract(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let expr = self.expression_for_reference(expr, ctx)?;
+ ctx.apply_load_rule(expr)
+ }
+
+ fn expression_for_reference(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Typed<Handle<crate::Expression>>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let expr = &ctx.ast_expressions[expr];
+
+ let expr: Typed<crate::Expression> = match *expr {
+ ast::Expression::Literal(literal) => {
+ let literal = match literal {
+ ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f),
+ ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i),
+ ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u),
+ ast::Literal::Number(Number::F64(f)) => crate::Literal::F64(f),
+ ast::Literal::Number(Number::AbstractInt(i)) => crate::Literal::AbstractInt(i),
+ ast::Literal::Number(Number::AbstractFloat(f)) => {
+ crate::Literal::AbstractFloat(f)
+ }
+ ast::Literal::Bool(b) => crate::Literal::Bool(b),
+ };
+ let handle = ctx.interrupt_emitter(crate::Expression::Literal(literal), span)?;
+ return Ok(Typed::Plain(handle));
+ }
+ ast::Expression::Ident(ast::IdentExpr::Local(local)) => {
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ return Ok(rctx.local_table[&local]);
+ }
+ ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => {
+ let global = ctx
+ .globals
+ .get(name)
+ .ok_or(Error::UnknownIdent(span, name))?;
+ let expr = match *global {
+ LoweredGlobalDecl::Var(handle) => {
+ let expr = crate::Expression::GlobalVariable(handle);
+ match ctx.module.global_variables[handle].space {
+ crate::AddressSpace::Handle => Typed::Plain(expr),
+ _ => Typed::Reference(expr),
+ }
+ }
+ LoweredGlobalDecl::Const(handle) => {
+ Typed::Plain(crate::Expression::Constant(handle))
+ }
+ _ => {
+ return Err(Error::Unexpected(span, ExpectedToken::Variable));
+ }
+ };
+
+ return expr.try_map(|handle| ctx.interrupt_emitter(handle, span));
+ }
+ ast::Expression::Construct {
+ ref ty,
+ ty_span,
+ ref components,
+ } => {
+ let handle = self.construct(span, ty, ty_span, components, ctx)?;
+ return Ok(Typed::Plain(handle));
+ }
+ ast::Expression::Unary { op, expr } => {
+ let expr = self.expression_for_abstract(expr, ctx)?;
+ Typed::Plain(crate::Expression::Unary { op, expr })
+ }
+ ast::Expression::AddrOf(expr) => {
+ // The `&` operator simply converts a reference to a pointer. And since a
+ // reference is required, the Load Rule is not applied.
+ match self.expression_for_reference(expr, ctx)? {
+ Typed::Reference(handle) => {
+ // No code is generated. We just declare the reference a pointer now.
+ return Ok(Typed::Plain(handle));
+ }
+ Typed::Plain(_) => {
+ return Err(Error::NotReference("the operand of the `&` operator", span));
+ }
+ }
+ }
+ ast::Expression::Deref(expr) => {
+ // The pointer we dereference must be loaded.
+ let pointer = self.expression(expr, ctx)?;
+
+ if resolve_inner!(ctx, pointer).pointer_space().is_none() {
+ return Err(Error::NotPointer(span));
+ }
+
+ // No code is generated. We just declare the pointer a reference now.
+ return Ok(Typed::Reference(pointer));
+ }
+ ast::Expression::Binary { op, left, right } => {
+ self.binary(op, left, right, span, ctx)?
+ }
+ ast::Expression::Call {
+ ref function,
+ ref arguments,
+ } => {
+ let handle = self
+ .call(span, function, arguments, ctx)?
+ .ok_or(Error::FunctionReturnsVoid(function.span))?;
+ return Ok(Typed::Plain(handle));
+ }
+ ast::Expression::Index { base, index } => {
+ let lowered_base = self.expression_for_reference(base, ctx)?;
+ let index = self.expression(index, ctx)?;
+
+ if let Typed::Plain(handle) = lowered_base {
+ if resolve_inner!(ctx, handle).pointer_space().is_some() {
+ return Err(Error::Pointer(
+ "the value indexed by a `[]` subscripting expression",
+ ctx.ast_expressions.get_span(base),
+ ));
+ }
+ }
+
+ lowered_base.map(|base| match ctx.const_access(index) {
+ Some(index) => crate::Expression::AccessIndex { base, index },
+ None => crate::Expression::Access { base, index },
+ })
+ }
+ ast::Expression::Member { base, ref field } => {
+ let lowered_base = self.expression_for_reference(base, ctx)?;
+
+ let temp_inner;
+ let composite_type: &crate::TypeInner = match lowered_base {
+ Typed::Reference(handle) => {
+ let inner = resolve_inner!(ctx, handle);
+ match *inner {
+ crate::TypeInner::Pointer { base, .. } => &ctx.module.types[base].inner,
+ crate::TypeInner::ValuePointer {
+ size: None, scalar, ..
+ } => {
+ temp_inner = crate::TypeInner::Scalar(scalar);
+ &temp_inner
+ }
+ crate::TypeInner::ValuePointer {
+ size: Some(size),
+ scalar,
+ ..
+ } => {
+ temp_inner = crate::TypeInner::Vector { size, scalar };
+ &temp_inner
+ }
+ _ => unreachable!(
+ "In Typed::Reference(handle), handle must be a Naga pointer"
+ ),
+ }
+ }
+
+ Typed::Plain(handle) => {
+ let inner = resolve_inner!(ctx, handle);
+ if let crate::TypeInner::Pointer { .. }
+ | crate::TypeInner::ValuePointer { .. } = *inner
+ {
+ return Err(Error::Pointer(
+ "the value accessed by a `.member` expression",
+ ctx.ast_expressions.get_span(base),
+ ));
+ }
+ inner
+ }
+ };
+
+ let access = match *composite_type {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let index = members
+ .iter()
+ .position(|m| m.name.as_deref() == Some(field.name))
+ .ok_or(Error::BadAccessor(field.span))?
+ as u32;
+
+ lowered_base.map(|base| crate::Expression::AccessIndex { base, index })
+ }
+ crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => {
+ match Components::new(field.name, field.span)? {
+ Components::Swizzle { size, pattern } => {
+ // Swizzles aren't allowed on matrices, but
+ // validation will catch that.
+ Typed::Plain(crate::Expression::Swizzle {
+ size,
+ vector: ctx.apply_load_rule(lowered_base)?,
+ pattern,
+ })
+ }
+ Components::Single(index) => lowered_base
+ .map(|base| crate::Expression::AccessIndex { base, index }),
+ }
+ }
+ _ => return Err(Error::BadAccessor(field.span)),
+ };
+
+ access
+ }
+ ast::Expression::Bitcast { expr, to, ty_span } => {
+ let expr = self.expression(expr, ctx)?;
+ let to_resolved = self.resolve_ast_type(to, &mut ctx.as_global())?;
+
+ let element_scalar = match ctx.module.types[to_resolved].inner {
+ crate::TypeInner::Scalar(scalar) => scalar,
+ crate::TypeInner::Vector { scalar, .. } => scalar,
+ _ => {
+ let ty = resolve!(ctx, expr);
+ let gctx = &ctx.module.to_ctx();
+ return Err(Error::BadTypeCast {
+ from_type: ty.to_wgsl(gctx),
+ span: ty_span,
+ to_type: to_resolved.to_wgsl(gctx),
+ });
+ }
+ };
+
+ Typed::Plain(crate::Expression::As {
+ expr,
+ kind: element_scalar.kind,
+ convert: None,
+ })
+ }
+ };
+
+ expr.try_map(|handle| ctx.append_expression(handle, span))
+ }
+
+ fn binary(
+ &mut self,
+ op: crate::BinaryOperator,
+ left: Handle<ast::Expression<'source>>,
+ right: Handle<ast::Expression<'source>>,
+ span: Span,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Typed<crate::Expression>, Error<'source>> {
+ // Load both operands.
+ let mut left = self.expression_for_abstract(left, ctx)?;
+ let mut right = self.expression_for_abstract(right, ctx)?;
+
+ // Convert `scalar op vector` to `vector op vector` by introducing
+ // `Splat` expressions.
+ ctx.binary_op_splat(op, &mut left, &mut right)?;
+
+ // Apply automatic conversions.
+ match op {
+ // Shift operators require the right operand to be `u32` or
+ // `vecN<u32>`. We can let the validator sort out vector length
+ // issues, but the right operand must be, or convert to, a u32 leaf
+ // scalar.
+ crate::BinaryOperator::ShiftLeft | crate::BinaryOperator::ShiftRight => {
+ right =
+ ctx.try_automatic_conversion_for_leaf_scalar(right, crate::Scalar::U32, span)?;
+ }
+
+ // All other operators follow the same pattern: reconcile the
+ // scalar leaf types. If there's no reconciliation possible,
+ // leave the expressions as they are: validation will report the
+ // problem.
+ _ => {
+ ctx.grow_types(left)?;
+ ctx.grow_types(right)?;
+ if let Ok(consensus_scalar) =
+ ctx.automatic_conversion_consensus([left, right].iter())
+ {
+ ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?;
+ ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?;
+ }
+ }
+ }
+
+ Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
+ }
+
+ /// Generate Naga IR for call expressions and statements, and type
+ /// constructor expressions.
+ ///
+ /// The "function" being called is simply an `Ident` that we know refers to
+ /// some module-scope definition.
+ ///
+ /// - If it is the name of a type, then the expression is a type constructor
+ /// expression: either constructing a value from components, a conversion
+ /// expression, or a zero value expression.
+ ///
+ /// - If it is the name of a function, then we're generating a [`Call`]
+ /// statement. We may be in the midst of generating code for an
+ /// expression, in which case we must generate an `Emit` statement to
+ /// force evaluation of the IR expressions we've generated so far, add the
+ /// `Call` statement to the current block, and then resume generating
+ /// expressions.
+ ///
+ /// [`Call`]: crate::Statement::Call
+ fn call(
+ &mut self,
+ span: Span,
+ function: &ast::Ident<'source>,
+ arguments: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Option<Handle<crate::Expression>>, Error<'source>> {
+ match ctx.globals.get(function.name) {
+ Some(&LoweredGlobalDecl::Type(ty)) => {
+ let handle = self.construct(
+ span,
+ &ast::ConstructorType::Type(ty),
+ function.span,
+ arguments,
+ ctx,
+ )?;
+ Ok(Some(handle))
+ }
+ Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
+ Err(Error::Unexpected(function.span, ExpectedToken::Function))
+ }
+ Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
+ Some(&LoweredGlobalDecl::Function(function)) => {
+ let arguments = arguments
+ .iter()
+ .map(|&arg| self.expression(arg, ctx))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let has_result = ctx.module.functions[function].result.is_some();
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ // we need to always do this before a fn call since all arguments need to be emitted before the fn call
+ rctx.block
+ .extend(rctx.emitter.finish(&rctx.function.expressions));
+ let result = has_result.then(|| {
+ rctx.function
+ .expressions
+ .append(crate::Expression::CallResult(function), span)
+ });
+ rctx.emitter.start(&rctx.function.expressions);
+ rctx.block.push(
+ crate::Statement::Call {
+ function,
+ arguments,
+ result,
+ },
+ span,
+ );
+
+ Ok(result)
+ }
+ None => {
+ let span = function.span;
+ let expr = if let Some(fun) = conv::map_relational_fun(function.name) {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let argument = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ // Check for no-op all(bool) and any(bool):
+ let argument_unmodified = matches!(
+ fun,
+ crate::RelationalFunction::All | crate::RelationalFunction::Any
+ ) && {
+ matches!(
+ resolve_inner!(ctx, argument),
+ &crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ })
+ )
+ };
+
+ if argument_unmodified {
+ return Ok(Some(argument));
+ } else {
+ crate::Expression::Relational { fun, argument }
+ }
+ } else if let Some((axis, ctrl)) = conv::map_derivative(function.name) {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let expr = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::Derivative { axis, ctrl, expr }
+ } else if let Some(fun) = conv::map_standard_fun(function.name) {
+ let expected = fun.argument_count() as _;
+ let mut args = ctx.prepare_args(arguments, expected, span);
+
+ let arg = self.expression(args.next()?, ctx)?;
+ let arg1 = args
+ .next()
+ .map(|x| self.expression(x, ctx))
+ .ok()
+ .transpose()?;
+ let arg2 = args
+ .next()
+ .map(|x| self.expression(x, ctx))
+ .ok()
+ .transpose()?;
+ let arg3 = args
+ .next()
+ .map(|x| self.expression(x, ctx))
+ .ok()
+ .transpose()?;
+
+ args.finish()?;
+
+ if fun == crate::MathFunction::Modf || fun == crate::MathFunction::Frexp {
+ if let Some((size, width)) = match *resolve_inner!(ctx, arg) {
+ crate::TypeInner::Scalar(crate::Scalar { width, .. }) => {
+ Some((None, width))
+ }
+ crate::TypeInner::Vector {
+ size,
+ scalar: crate::Scalar { width, .. },
+ ..
+ } => Some((Some(size), width)),
+ _ => None,
+ } {
+ ctx.module.generate_predeclared_type(
+ if fun == crate::MathFunction::Modf {
+ crate::PredeclaredType::ModfResult { size, width }
+ } else {
+ crate::PredeclaredType::FrexpResult { size, width }
+ },
+ );
+ }
+ }
+
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ }
+ } else if let Some(fun) = Texture::map(function.name) {
+ self.texture_sample_helper(fun, arguments, span, ctx)?
+ } else {
+ match function.name {
+ "select" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let reject = self.expression(args.next()?, ctx)?;
+ let accept = self.expression(args.next()?, ctx)?;
+ let condition = self.expression(args.next()?, ctx)?;
+
+ args.finish()?;
+
+ crate::Expression::Select {
+ reject,
+ accept,
+ condition,
+ }
+ }
+ "arrayLength" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let expr = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::ArrayLength(expr)
+ }
+ "atomicLoad" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let pointer = self.atomic_pointer(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::Load { pointer }
+ }
+ "atomicStore" => {
+ let mut args = ctx.prepare_args(arguments, 2, span);
+ let pointer = self.atomic_pointer(args.next()?, ctx)?;
+ let value = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .extend(rctx.emitter.finish(&rctx.function.expressions));
+ rctx.emitter.start(&rctx.function.expressions);
+ rctx.block
+ .push(crate::Statement::Store { pointer, value }, span);
+ return Ok(None);
+ }
+ "atomicAdd" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Add,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicSub" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Subtract,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicAnd" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::And,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicOr" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::InclusiveOr,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicXor" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::ExclusiveOr,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicMin" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Min,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicMax" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Max,
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicExchange" => {
+ return Ok(Some(self.atomic_helper(
+ span,
+ crate::AtomicFunction::Exchange { compare: None },
+ arguments,
+ ctx,
+ )?))
+ }
+ "atomicCompareExchangeWeak" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let pointer = self.atomic_pointer(args.next()?, ctx)?;
+
+ let compare = self.expression(args.next()?, ctx)?;
+
+ let value = args.next()?;
+ let value_span = ctx.ast_expressions.get_span(value);
+ let value = self.expression(value, ctx)?;
+
+ args.finish()?;
+
+ let expression = match *resolve_inner!(ctx, value) {
+ crate::TypeInner::Scalar(scalar) => {
+ crate::Expression::AtomicResult {
+ ty: ctx.module.generate_predeclared_type(
+ crate::PredeclaredType::AtomicCompareExchangeWeakResult(
+ scalar,
+ ),
+ ),
+ comparison: true,
+ }
+ }
+ _ => return Err(Error::InvalidAtomicOperandType(value_span)),
+ };
+
+ let result = ctx.interrupt_emitter(expression, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::Atomic {
+ pointer,
+ fun: crate::AtomicFunction::Exchange {
+ compare: Some(compare),
+ },
+ value,
+ result,
+ },
+ span,
+ );
+ return Ok(Some(result));
+ }
+ "storageBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::STORAGE), span);
+ return Ok(None);
+ }
+ "workgroupBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
+ return Ok(None);
+ }
+ "workgroupUniformLoad" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let expr = args.next()?;
+ args.finish()?;
+
+ let pointer = self.expression(expr, ctx)?;
+ let result_ty = match *resolve_inner!(ctx, pointer) {
+ crate::TypeInner::Pointer {
+ base,
+ space: crate::AddressSpace::WorkGroup,
+ } => base,
+ ref other => {
+ log::error!("Type {other:?} passed to workgroupUniformLoad");
+ let span = ctx.ast_expressions.get_span(expr);
+ return Err(Error::InvalidWorkGroupUniformLoad(span));
+ }
+ };
+ let result = ctx.interrupt_emitter(
+ crate::Expression::WorkGroupUniformLoadResult { ty: result_ty },
+ span,
+ )?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::WorkGroupUniformLoad { pointer, result },
+ span,
+ );
+
+ return Ok(Some(result));
+ }
+ "textureStore" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+
+ let image = args.next()?;
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = self.expression(image, ctx)?;
+
+ let coordinate = self.expression(args.next()?, ctx)?;
+
+ let (_, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| {
+ args.min_args += 1;
+ self.expression(args.next()?, ctx)
+ })
+ .transpose()?;
+
+ let value = self.expression(args.next()?, ctx)?;
+
+ args.finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .extend(rctx.emitter.finish(&rctx.function.expressions));
+ rctx.emitter.start(&rctx.function.expressions);
+ let stmt = crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ };
+ rctx.block.push(stmt, span);
+ return Ok(None);
+ }
+ "textureLoad" => {
+ let mut args = ctx.prepare_args(arguments, 2, span);
+
+ let image = args.next()?;
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = self.expression(image, ctx)?;
+
+ let coordinate = self.expression(args.next()?, ctx)?;
+
+ let (class, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| {
+ args.min_args += 1;
+ self.expression(args.next()?, ctx)
+ })
+ .transpose()?;
+
+ let level = class
+ .is_mipmapped()
+ .then(|| {
+ args.min_args += 1;
+ self.expression(args.next()?, ctx)
+ })
+ .transpose()?;
+
+ let sample = class
+ .is_multisampled()
+ .then(|| self.expression(args.next()?, ctx))
+ .transpose()?;
+
+ args.finish()?;
+
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ level,
+ sample,
+ }
+ }
+ "textureDimensions" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx)?;
+ let level = args
+ .next()
+ .map(|arg| self.expression(arg, ctx))
+ .ok()
+ .transpose()?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::Size { level },
+ }
+ }
+ "textureNumLevels" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumLevels,
+ }
+ }
+ "textureNumLayers" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumLayers,
+ }
+ }
+ "textureNumSamples" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let image = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ crate::Expression::ImageQuery {
+ image,
+ query: crate::ImageQuery::NumSamples,
+ }
+ }
+ "rayQueryInitialize" => {
+ let mut args = ctx.prepare_args(arguments, 3, span);
+ let query = self.ray_query_pointer(args.next()?, ctx)?;
+ let acceleration_structure = self.expression(args.next()?, ctx)?;
+ let descriptor = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ let _ = ctx.module.generate_ray_desc_type();
+ let fun = crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ };
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .extend(rctx.emitter.finish(&rctx.function.expressions));
+ rctx.emitter.start(&rctx.function.expressions);
+ rctx.block
+ .push(crate::Statement::RayQuery { query, fun }, span);
+ return Ok(None);
+ }
+ "rayQueryProceed" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let query = self.ray_query_pointer(args.next()?, ctx)?;
+ args.finish()?;
+
+ let result = ctx.interrupt_emitter(
+ crate::Expression::RayQueryProceedResult,
+ span,
+ )?;
+ let fun = crate::RayQueryFunction::Proceed { result };
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::RayQuery { query, fun }, span);
+ return Ok(Some(result));
+ }
+ "rayQueryGetCommittedIntersection" => {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+ let query = self.ray_query_pointer(args.next()?, ctx)?;
+ args.finish()?;
+
+ let _ = ctx.module.generate_ray_intersection_type();
+
+ crate::Expression::RayQueryGetIntersection {
+ query,
+ committed: true,
+ }
+ }
+ "RayDesc" => {
+ let ty = ctx.module.generate_ray_desc_type();
+ let handle = self.construct(
+ span,
+ &ast::ConstructorType::Type(ty),
+ function.span,
+ arguments,
+ ctx,
+ )?;
+ return Ok(Some(handle));
+ }
+ _ => return Err(Error::UnknownIdent(function.span, function.name)),
+ }
+ };
+
+ let expr = ctx.append_expression(expr, span)?;
+ Ok(Some(expr))
+ }
+ }
+ }
+
+ fn atomic_pointer(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let pointer = self.expression(expr, ctx)?;
+
+ match *resolve_inner!(ctx, pointer) {
+ crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Ok(pointer),
+ ref other => {
+ log::error!("Pointer type to {:?} passed to atomic op", other);
+ Err(Error::InvalidAtomicPointer(span))
+ }
+ },
+ ref other => {
+ log::error!("Type {:?} passed to atomic op", other);
+ Err(Error::InvalidAtomicPointer(span))
+ }
+ }
+ }
+
+ fn atomic_helper(
+ &mut self,
+ span: Span,
+ fun: crate::AtomicFunction,
+ args: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(args, 2, span);
+
+ let pointer = self.atomic_pointer(args.next()?, ctx)?;
+
+ let value = args.next()?;
+ let value = self.expression(value, ctx)?;
+ let ty = ctx.register_type(value)?;
+
+ args.finish()?;
+
+ let result = ctx.interrupt_emitter(
+ crate::Expression::AtomicResult {
+ ty,
+ comparison: false,
+ },
+ span,
+ )?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::Atomic {
+ pointer,
+ fun,
+ value,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
+ fn texture_sample_helper(
+ &mut self,
+ fun: Texture,
+ args: &[Handle<ast::Expression<'source>>],
+ span: Span,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<crate::Expression, Error<'source>> {
+ let mut args = ctx.prepare_args(args, fun.min_argument_count(), span);
+
+ fn get_image_and_span<'source>(
+ lowerer: &mut Lowerer<'source, '_>,
+ args: &mut ArgumentContext<'_, 'source>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<(Handle<crate::Expression>, Span), Error<'source>> {
+ let image = args.next()?;
+ let image_span = ctx.ast_expressions.get_span(image);
+ let image = lowerer.expression(image, ctx)?;
+ Ok((image, image_span))
+ }
+
+ let (image, image_span, gather) = match fun {
+ Texture::Gather => {
+ let image_or_component = args.next()?;
+ let image_or_component_span = ctx.ast_expressions.get_span(image_or_component);
+ // Gathers from depth textures don't take an initial `component` argument.
+ let lowered_image_or_component = self.expression(image_or_component, ctx)?;
+
+ match *resolve_inner!(ctx, lowered_image_or_component) {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Depth { .. },
+ ..
+ } => (
+ lowered_image_or_component,
+ image_or_component_span,
+ Some(crate::SwizzleComponent::X),
+ ),
+ _ => {
+ let (image, image_span) = get_image_and_span(self, &mut args, ctx)?;
+ (
+ image,
+ image_span,
+ Some(ctx.gather_component(
+ lowered_image_or_component,
+ image_or_component_span,
+ span,
+ )?),
+ )
+ }
+ }
+ }
+ Texture::GatherCompare => {
+ let (image, image_span) = get_image_and_span(self, &mut args, ctx)?;
+ (image, image_span, Some(crate::SwizzleComponent::X))
+ }
+
+ _ => {
+ let (image, image_span) = get_image_and_span(self, &mut args, ctx)?;
+ (image, image_span, None)
+ }
+ };
+
+ let sampler = self.expression(args.next()?, ctx)?;
+
+ let coordinate = self.expression(args.next()?, ctx)?;
+
+ let (_, arrayed) = ctx.image_data(image, image_span)?;
+ let array_index = arrayed
+ .then(|| self.expression(args.next()?, ctx))
+ .transpose()?;
+
+ let (level, depth_ref) = match fun {
+ Texture::Gather => (crate::SampleLevel::Zero, None),
+ Texture::GatherCompare => {
+ let reference = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Zero, Some(reference))
+ }
+
+ Texture::Sample => (crate::SampleLevel::Auto, None),
+ Texture::SampleBias => {
+ let bias = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Bias(bias), None)
+ }
+ Texture::SampleCompare => {
+ let reference = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Auto, Some(reference))
+ }
+ Texture::SampleCompareLevel => {
+ let reference = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Zero, Some(reference))
+ }
+ Texture::SampleGrad => {
+ let x = self.expression(args.next()?, ctx)?;
+ let y = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Gradient { x, y }, None)
+ }
+ Texture::SampleLevel => {
+ let level = self.expression(args.next()?, ctx)?;
+ (crate::SampleLevel::Exact(level), None)
+ }
+ };
+
+ let offset = args
+ .next()
+ .map(|arg| self.expression(arg, &mut ctx.as_const()))
+ .ok()
+ .transpose()?;
+
+ args.finish()?;
+
+ Ok(crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ })
+ }
+
+ fn r#struct(
+ &mut self,
+ s: &ast::Struct<'source>,
+ span: Span,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ let mut offset = 0;
+ let mut struct_alignment = Alignment::ONE;
+ let mut members = Vec::with_capacity(s.members.len());
+
+ for member in s.members.iter() {
+ let ty = self.resolve_ast_type(member.ty, ctx)?;
+
+ self.layouter.update(ctx.module.to_ctx()).unwrap();
+
+ let member_min_size = self.layouter[ty].size;
+ let member_min_alignment = self.layouter[ty].alignment;
+
+ let member_size = if let Some(size_expr) = member.size {
+ let (size, span) = self.const_u32(size_expr, &mut ctx.as_const())?;
+ if size < member_min_size {
+ return Err(Error::SizeAttributeTooLow(span, member_min_size));
+ } else {
+ size
+ }
+ } else {
+ member_min_size
+ };
+
+ let member_alignment = if let Some(align_expr) = member.align {
+ let (align, span) = self.const_u32(align_expr, &mut ctx.as_const())?;
+ if let Some(alignment) = Alignment::new(align) {
+ if alignment < member_min_alignment {
+ return Err(Error::AlignAttributeTooLow(span, member_min_alignment));
+ } else {
+ alignment
+ }
+ } else {
+ return Err(Error::NonPowerOfTwoAlignAttribute(span));
+ }
+ } else {
+ member_min_alignment
+ };
+
+ let binding = self.binding(&member.binding, ty, ctx)?;
+
+ offset = member_alignment.round_up(offset);
+ struct_alignment = struct_alignment.max(member_alignment);
+
+ members.push(crate::StructMember {
+ name: Some(member.name.name.to_owned()),
+ ty,
+ binding,
+ offset,
+ });
+
+ offset += member_size;
+ }
+
+ let size = struct_alignment.round_up(offset);
+ let inner = crate::TypeInner::Struct {
+ members,
+ span: size,
+ };
+
+ let handle = ctx.module.types.insert(
+ crate::Type {
+ name: Some(s.name.name.to_string()),
+ inner,
+ },
+ span,
+ );
+ Ok(handle)
+ }
+
+ fn const_u32(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<(u32, Span), Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let expr = self.expression(expr, ctx)?;
+ let value = ctx
+ .module
+ .to_ctx()
+ .eval_expr_to_u32(expr)
+ .map_err(|err| match err {
+ crate::proc::U32EvalError::NonConst => {
+ Error::ExpectedConstExprConcreteIntegerScalar(span)
+ }
+ crate::proc::U32EvalError::Negative => Error::ExpectedNonNegative(span),
+ })?;
+ Ok((value, span))
+ }
+
+ fn array_size(
+ &mut self,
+ size: ast::ArraySize<'source>,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<crate::ArraySize, Error<'source>> {
+ Ok(match size {
+ ast::ArraySize::Constant(expr) => {
+ let span = ctx.ast_expressions.get_span(expr);
+ let const_expr = self.expression(expr, &mut ctx.as_const())?;
+ let len =
+ ctx.module
+ .to_ctx()
+ .eval_expr_to_u32(const_expr)
+ .map_err(|err| match err {
+ crate::proc::U32EvalError::NonConst => {
+ Error::ExpectedConstExprConcreteIntegerScalar(span)
+ }
+ crate::proc::U32EvalError::Negative => {
+ Error::ExpectedPositiveArrayLength(span)
+ }
+ })?;
+ let size = NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?;
+ crate::ArraySize::Constant(size)
+ }
+ ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
+ })
+ }
+
+ /// Build the Naga equivalent of a named AST type.
+ ///
+ /// Return a Naga `Handle<Type>` representing the front-end type
+ /// `handle`, which should be named `name`, if given.
+ ///
+ /// If `handle` refers to a type cached in [`SpecialTypes`],
+ /// `name` may be ignored.
+ ///
+ /// [`SpecialTypes`]: crate::SpecialTypes
+ fn resolve_named_ast_type(
+ &mut self,
+ handle: Handle<ast::Type<'source>>,
+ name: Option<String>,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ let inner = match ctx.types[handle] {
+ ast::Type::Scalar(scalar) => scalar.to_inner_scalar(),
+ ast::Type::Vector { size, scalar } => scalar.to_inner_vector(size),
+ ast::Type::Matrix {
+ rows,
+ columns,
+ width,
+ } => crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: crate::Scalar::float(width),
+ },
+ ast::Type::Atomic(scalar) => scalar.to_inner_atomic(),
+ ast::Type::Pointer { base, space } => {
+ let base = self.resolve_ast_type(base, ctx)?;
+ crate::TypeInner::Pointer { base, space }
+ }
+ ast::Type::Array { base, size } => {
+ let base = self.resolve_ast_type(base, ctx)?;
+ let size = self.array_size(size, ctx)?;
+
+ self.layouter.update(ctx.module.to_ctx()).unwrap();
+ let stride = self.layouter[base].to_stride();
+
+ crate::TypeInner::Array { base, size, stride }
+ }
+ ast::Type::Image {
+ dim,
+ arrayed,
+ class,
+ } => crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ },
+ ast::Type::Sampler { comparison } => crate::TypeInner::Sampler { comparison },
+ ast::Type::AccelerationStructure => crate::TypeInner::AccelerationStructure,
+ ast::Type::RayQuery => crate::TypeInner::RayQuery,
+ ast::Type::BindingArray { base, size } => {
+ let base = self.resolve_ast_type(base, ctx)?;
+ let size = self.array_size(size, ctx)?;
+ crate::TypeInner::BindingArray { base, size }
+ }
+ ast::Type::RayDesc => {
+ return Ok(ctx.module.generate_ray_desc_type());
+ }
+ ast::Type::RayIntersection => {
+ return Ok(ctx.module.generate_ray_intersection_type());
+ }
+ ast::Type::User(ref ident) => {
+ return match ctx.globals.get(ident.name) {
+ Some(&LoweredGlobalDecl::Type(handle)) => Ok(handle),
+ Some(_) => Err(Error::Unexpected(ident.span, ExpectedToken::Type)),
+ None => Err(Error::UnknownType(ident.span)),
+ }
+ }
+ };
+
+ Ok(ctx.ensure_type_exists(name, inner))
+ }
+
+ /// Return a Naga `Handle<Type>` representing the front-end type `handle`.
+ fn resolve_ast_type(
+ &mut self,
+ handle: Handle<ast::Type<'source>>,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Type>, Error<'source>> {
+ self.resolve_named_ast_type(handle, None, ctx)
+ }
+
+ fn binding(
+ &mut self,
+ binding: &Option<ast::Binding<'source>>,
+ ty: Handle<crate::Type>,
+ ctx: &mut GlobalContext<'source, '_, '_>,
+ ) -> Result<Option<crate::Binding>, Error<'source>> {
+ Ok(match *binding {
+ Some(ast::Binding::BuiltIn(b)) => Some(crate::Binding::BuiltIn(b)),
+ Some(ast::Binding::Location {
+ location,
+ second_blend_source,
+ interpolation,
+ sampling,
+ }) => {
+ let mut binding = crate::Binding::Location {
+ location: self.const_u32(location, &mut ctx.as_const())?.0,
+ second_blend_source,
+ interpolation,
+ sampling,
+ };
+ binding.apply_default_interpolation(&ctx.module.types[ty].inner);
+ Some(binding)
+ }
+ None => None,
+ })
+ }
+
+ fn ray_query_pointer(
+ &mut self,
+ expr: Handle<ast::Expression<'source>>,
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let span = ctx.ast_expressions.get_span(expr);
+ let pointer = self.expression(expr, ctx)?;
+
+ match *resolve_inner!(ctx, pointer) {
+ crate::TypeInner::Pointer { base, .. } => match ctx.module.types[base].inner {
+ crate::TypeInner::RayQuery => Ok(pointer),
+ ref other => {
+ log::error!("Pointer type to {:?} passed to ray query op", other);
+ Err(Error::InvalidRayQueryPointer(span))
+ }
+ },
+ ref other => {
+ log::error!("Type {:?} passed to ray query op", other);
+ Err(Error::InvalidRayQueryPointer(span))
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs
new file mode 100644
index 0000000000..b6151fe1c0
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/mod.rs
@@ -0,0 +1,49 @@
+/*!
+Frontend for [WGSL][wgsl] (WebGPU Shading Language).
+
+[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
+*/
+
+mod error;
+mod index;
+mod lower;
+mod parse;
+#[cfg(test)]
+mod tests;
+mod to_wgsl;
+
+use crate::front::wgsl::error::Error;
+use crate::front::wgsl::parse::Parser;
+use thiserror::Error;
+
+pub use crate::front::wgsl::error::ParseError;
+use crate::front::wgsl::lower::Lowerer;
+use crate::Scalar;
+
+pub struct Frontend {
+ parser: Parser,
+}
+
+impl Frontend {
+ pub const fn new() -> Self {
+ Self {
+ parser: Parser::new(),
+ }
+ }
+
+ pub fn parse(&mut self, source: &str) -> Result<crate::Module, ParseError> {
+ self.inner(source).map_err(|x| x.as_parse_error(source))
+ }
+
+ fn inner<'a>(&mut self, source: &'a str) -> Result<crate::Module, Error<'a>> {
+ let tu = self.parser.parse(source)?;
+ let index = index::Index::generate(&tu)?;
+ let module = Lowerer::new(&index).lower(&tu)?;
+
+ Ok(module)
+ }
+}
+
+pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> {
+ Frontend::new().parse(source)
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
new file mode 100644
index 0000000000..dbaac523cb
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
@@ -0,0 +1,491 @@
+use crate::front::wgsl::parse::number::Number;
+use crate::front::wgsl::Scalar;
+use crate::{Arena, FastIndexSet, Handle, Span};
+use std::hash::Hash;
+
+#[derive(Debug, Default)]
+pub struct TranslationUnit<'a> {
+ pub decls: Arena<GlobalDecl<'a>>,
+ /// The common expressions arena for the entire translation unit.
+ ///
+ /// All functions, global initializers, array lengths, etc. store their
+ /// expressions here. We apportion these out to individual Naga
+ /// [`Function`]s' expression arenas at lowering time. Keeping them all in a
+ /// single arena simplifies handling of things like array lengths (which are
+ /// effectively global and thus don't clearly belong to any function) and
+ /// initializers (which can appear in both function-local and module-scope
+ /// contexts).
+ ///
+ /// [`Function`]: crate::Function
+ pub expressions: Arena<Expression<'a>>,
+
+ /// Non-user-defined types, like `vec4<f32>` or `array<i32, 10>`.
+ ///
+ /// These are referred to by `Handle<ast::Type<'a>>` values.
+ /// User-defined types are referred to by name until lowering.
+ pub types: Arena<Type<'a>>,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct Ident<'a> {
+ pub name: &'a str,
+ pub span: Span,
+}
+
+#[derive(Debug)]
+pub enum IdentExpr<'a> {
+ Unresolved(&'a str),
+ Local(Handle<Local>),
+}
+
+/// A reference to a module-scope definition or predeclared object.
+///
+/// Each [`GlobalDecl`] holds a set of these values, to be resolved to
+/// specific definitions later. To support de-duplication, `Eq` and
+/// `Hash` on a `Dependency` value consider only the name, not the
+/// source location at which the reference occurs.
+#[derive(Debug)]
+pub struct Dependency<'a> {
+ /// The name referred to.
+ pub ident: &'a str,
+
+ /// The location at which the reference to that name occurs.
+ pub usage: Span,
+}
+
+impl Hash for Dependency<'_> {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.ident.hash(state);
+ }
+}
+
+impl PartialEq for Dependency<'_> {
+ fn eq(&self, other: &Self) -> bool {
+ self.ident == other.ident
+ }
+}
+
+impl Eq for Dependency<'_> {}
+
+/// A module-scope declaration.
+#[derive(Debug)]
+pub struct GlobalDecl<'a> {
+ pub kind: GlobalDeclKind<'a>,
+
+ /// Names of all module-scope or predeclared objects this
+ /// declaration uses.
+ pub dependencies: FastIndexSet<Dependency<'a>>,
+}
+
+#[derive(Debug)]
+pub enum GlobalDeclKind<'a> {
+ Fn(Function<'a>),
+ Var(GlobalVariable<'a>),
+ Const(Const<'a>),
+ Struct(Struct<'a>),
+ Type(TypeAlias<'a>),
+}
+
+#[derive(Debug)]
+pub struct FunctionArgument<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<Binding<'a>>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub struct FunctionResult<'a> {
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<Binding<'a>>,
+}
+
+#[derive(Debug)]
+pub struct EntryPoint<'a> {
+ pub stage: crate::ShaderStage,
+ pub early_depth_test: Option<crate::EarlyDepthTest>,
+ pub workgroup_size: Option<[Option<Handle<Expression<'a>>>; 3]>,
+}
+
+#[cfg(doc)]
+use crate::front::wgsl::lower::{RuntimeExpressionContext, StatementContext};
+
+#[derive(Debug)]
+pub struct Function<'a> {
+ pub entry_point: Option<EntryPoint<'a>>,
+ pub name: Ident<'a>,
+ pub arguments: Vec<FunctionArgument<'a>>,
+ pub result: Option<FunctionResult<'a>>,
+
+ /// Local variable and function argument arena.
+ ///
+ /// Note that the `Local` here is actually a zero-sized type. The AST keeps
+ /// all the detailed information about locals - names, types, etc. - in
+ /// [`LocalDecl`] statements. For arguments, that information is kept in
+ /// [`arguments`]. This `Arena`'s only role is to assign a unique `Handle`
+ /// to each of them, and track their definitions' spans for use in
+ /// diagnostics.
+ ///
+ /// In the AST, when an [`Ident`] expression refers to a local variable or
+ /// argument, its [`IdentExpr`] holds the referent's `Handle<Local>` in this
+ /// arena.
+ ///
+ /// During lowering, [`LocalDecl`] statements add entries to a per-function
+ /// table that maps `Handle<Local>` values to their Naga representations,
+ /// accessed via [`StatementContext::local_table`] and
+ /// [`RuntimeExpressionContext::local_table`]. This table is then consulted when
+ /// lowering subsequent [`Ident`] expressions.
+ ///
+ /// [`LocalDecl`]: StatementKind::LocalDecl
+ /// [`arguments`]: Function::arguments
+ /// [`Ident`]: Expression::Ident
+ /// [`StatementContext::local_table`]: StatementContext::local_table
+ /// [`RuntimeExpressionContext::local_table`]: RuntimeExpressionContext::local_table
+ pub locals: Arena<Local>,
+
+ pub body: Block<'a>,
+}
+
+#[derive(Debug)]
+pub enum Binding<'a> {
+ BuiltIn(crate::BuiltIn),
+ Location {
+ location: Handle<Expression<'a>>,
+ second_blend_source: bool,
+ interpolation: Option<crate::Interpolation>,
+ sampling: Option<crate::Sampling>,
+ },
+}
+
+#[derive(Debug)]
+pub struct ResourceBinding<'a> {
+ pub group: Handle<Expression<'a>>,
+ pub binding: Handle<Expression<'a>>,
+}
+
+#[derive(Debug)]
+pub struct GlobalVariable<'a> {
+ pub name: Ident<'a>,
+ pub space: crate::AddressSpace,
+ pub binding: Option<ResourceBinding<'a>>,
+ pub ty: Handle<Type<'a>>,
+ pub init: Option<Handle<Expression<'a>>>,
+}
+
+#[derive(Debug)]
+pub struct StructMember<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+ pub binding: Option<Binding<'a>>,
+ pub align: Option<Handle<Expression<'a>>>,
+ pub size: Option<Handle<Expression<'a>>>,
+}
+
+#[derive(Debug)]
+pub struct Struct<'a> {
+ pub name: Ident<'a>,
+ pub members: Vec<StructMember<'a>>,
+}
+
+#[derive(Debug)]
+pub struct TypeAlias<'a> {
+ pub name: Ident<'a>,
+ pub ty: Handle<Type<'a>>,
+}
+
+#[derive(Debug)]
+pub struct Const<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Handle<Expression<'a>>,
+}
+
+/// The size of an [`Array`] or [`BindingArray`].
+///
+/// [`Array`]: Type::Array
+/// [`BindingArray`]: Type::BindingArray
+#[derive(Debug, Copy, Clone)]
+pub enum ArraySize<'a> {
+ /// The length as a constant expression.
+ Constant(Handle<Expression<'a>>),
+ Dynamic,
+}
+
+#[derive(Debug)]
+pub enum Type<'a> {
+ Scalar(Scalar),
+ Vector {
+ size: crate::VectorSize,
+ scalar: Scalar,
+ },
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+ Atomic(Scalar),
+ Pointer {
+ base: Handle<Type<'a>>,
+ space: crate::AddressSpace,
+ },
+ Array {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+ Image {
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ },
+ Sampler {
+ comparison: bool,
+ },
+ AccelerationStructure,
+ RayQuery,
+ RayDesc,
+ RayIntersection,
+ BindingArray {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+
+ /// A user-defined type, like a struct or a type alias.
+ User(Ident<'a>),
+}
+
+#[derive(Debug, Default)]
+pub struct Block<'a> {
+ pub stmts: Vec<Statement<'a>>,
+}
+
+#[derive(Debug)]
+pub struct Statement<'a> {
+ pub kind: StatementKind<'a>,
+ pub span: Span,
+}
+
+#[derive(Debug)]
+pub enum StatementKind<'a> {
+ LocalDecl(LocalDecl<'a>),
+ Block(Block<'a>),
+ If {
+ condition: Handle<Expression<'a>>,
+ accept: Block<'a>,
+ reject: Block<'a>,
+ },
+ Switch {
+ selector: Handle<Expression<'a>>,
+ cases: Vec<SwitchCase<'a>>,
+ },
+ Loop {
+ body: Block<'a>,
+ continuing: Block<'a>,
+ break_if: Option<Handle<Expression<'a>>>,
+ },
+ Break,
+ Continue,
+ Return {
+ value: Option<Handle<Expression<'a>>>,
+ },
+ Kill,
+ Call {
+ function: Ident<'a>,
+ arguments: Vec<Handle<Expression<'a>>>,
+ },
+ Assign {
+ target: Handle<Expression<'a>>,
+ op: Option<crate::BinaryOperator>,
+ value: Handle<Expression<'a>>,
+ },
+ Increment(Handle<Expression<'a>>),
+ Decrement(Handle<Expression<'a>>),
+ Ignore(Handle<Expression<'a>>),
+}
+
+#[derive(Debug)]
+pub enum SwitchValue<'a> {
+ Expr(Handle<Expression<'a>>),
+ Default,
+}
+
+#[derive(Debug)]
+pub struct SwitchCase<'a> {
+ pub value: SwitchValue<'a>,
+ pub body: Block<'a>,
+ pub fall_through: bool,
+}
+
+/// A type at the head of a [`Construct`] expression.
+///
+/// WGSL has two types of [`type constructor expressions`]:
+///
+/// - Those that fully specify the type being constructed, like
+/// `vec3<f32>(x,y,z)`, which obviously constructs a `vec3<f32>`.
+///
+/// - Those that leave the component type of the composite being constructed
+/// implicit, to be inferred from the argument types, like `vec3(x,y,z)`,
+/// which constructs a `vec3<T>` where `T` is the type of `x`, `y`, and `z`.
+///
+/// This enum represents the head type of both cases. The `PartialFoo` variants
+/// represent the second case, where the component type is implicit.
+///
+/// This does not cover structs or types referred to by type aliases. See the
+/// documentation for [`Construct`] and [`Call`] expressions for details.
+///
+/// [`Construct`]: Expression::Construct
+/// [`type constructor expressions`]: https://gpuweb.github.io/gpuweb/wgsl/#type-constructor-expr
+/// [`Call`]: Expression::Call
+#[derive(Debug)]
+pub enum ConstructorType<'a> {
+ /// A scalar type or conversion: `f32(1)`.
+ Scalar(Scalar),
+
+ /// A vector construction whose component type is inferred from the
+ /// argument: `vec3(1.0)`.
+ PartialVector { size: crate::VectorSize },
+
+ /// A vector construction whose component type is written out:
+ /// `vec3<f32>(1.0)`.
+ Vector {
+ size: crate::VectorSize,
+ scalar: Scalar,
+ },
+
+ /// A matrix construction whose component type is inferred from the
+ /// argument: `mat2x2(1,2,3,4)`.
+ PartialMatrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ },
+
+ /// A matrix construction whose component type is written out:
+ /// `mat2x2<f32>(1,2,3,4)`.
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+
+ /// An array whose component type and size are inferred from the arguments:
+ /// `array(3,4,5)`.
+ PartialArray,
+
+ /// An array whose component type and size are written out:
+ /// `array<u32, 4>(3,4,5)`.
+ Array {
+ base: Handle<Type<'a>>,
+ size: ArraySize<'a>,
+ },
+
+ /// Constructing a value of a known Naga IR type.
+ ///
+ /// This variant is produced only during lowering, when we have Naga types
+ /// available, never during parsing.
+ Type(Handle<crate::Type>),
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum Literal {
+ Bool(bool),
+ Number(Number),
+}
+
+#[cfg(doc)]
+use crate::front::wgsl::lower::Lowerer;
+
+#[derive(Debug)]
+pub enum Expression<'a> {
+ Literal(Literal),
+ Ident(IdentExpr<'a>),
+
+ /// A type constructor expression.
+ ///
+ /// This is only used for expressions like `KEYWORD(EXPR...)` and
+ /// `KEYWORD<PARAM>(EXPR...)`, where `KEYWORD` is a [type-defining keyword] like
+ /// `vec3`. These keywords cannot be shadowed by user definitions, so we can
+ /// tell that such an expression is a construction immediately.
+ ///
+ /// For ordinary identifiers, we can't tell whether an expression like
+ /// `IDENTIFIER(EXPR, ...)` is a construction expression or a function call
+ /// until we know `IDENTIFIER`'s definition, so we represent those as
+ /// [`Call`] expressions.
+ ///
+ /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords
+ /// [`Call`]: Expression::Call
+ Construct {
+ ty: ConstructorType<'a>,
+ ty_span: Span,
+ components: Vec<Handle<Expression<'a>>>,
+ },
+ Unary {
+ op: crate::UnaryOperator,
+ expr: Handle<Expression<'a>>,
+ },
+ AddrOf(Handle<Expression<'a>>),
+ Deref(Handle<Expression<'a>>),
+ Binary {
+ op: crate::BinaryOperator,
+ left: Handle<Expression<'a>>,
+ right: Handle<Expression<'a>>,
+ },
+
+ /// A function call or type constructor expression.
+ ///
+ /// We can't tell whether an expression like `IDENTIFIER(EXPR, ...)` is a
+ /// construction expression or a function call until we know `IDENTIFIER`'s
+ /// definition, so we represent everything of that form as one of these
+ /// expressions until lowering. At that point, [`Lowerer::call`] has
+ /// everything's definition in hand, and can decide whether to emit a Naga
+ /// [`Constant`], [`As`], [`Splat`], or [`Compose`] expression.
+ ///
+ /// [`Lowerer::call`]: Lowerer::call
+ /// [`Constant`]: crate::Expression::Constant
+ /// [`As`]: crate::Expression::As
+ /// [`Splat`]: crate::Expression::Splat
+ /// [`Compose`]: crate::Expression::Compose
+ Call {
+ function: Ident<'a>,
+ arguments: Vec<Handle<Expression<'a>>>,
+ },
+ Index {
+ base: Handle<Expression<'a>>,
+ index: Handle<Expression<'a>>,
+ },
+ Member {
+ base: Handle<Expression<'a>>,
+ field: Ident<'a>,
+ },
+ Bitcast {
+ expr: Handle<Expression<'a>>,
+ to: Handle<Type<'a>>,
+ ty_span: Span,
+ },
+}
+
+#[derive(Debug)]
+pub struct LocalVariable<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Option<Handle<Expression<'a>>>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub struct Let<'a> {
+ pub name: Ident<'a>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Handle<Expression<'a>>,
+ pub handle: Handle<Local>,
+}
+
+#[derive(Debug)]
+pub enum LocalDecl<'a> {
+ Var(LocalVariable<'a>),
+ Let(Let<'a>),
+}
+
+#[derive(Debug)]
+/// A placeholder for a local variable declaration.
+///
+/// See [`Function::locals`] for more information.
+pub struct Local;
diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
new file mode 100644
index 0000000000..08f1e39285
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
@@ -0,0 +1,254 @@
+use super::Error;
+use crate::front::wgsl::Scalar;
+use crate::Span;
+
+pub fn map_address_space(word: &str, span: Span) -> Result<crate::AddressSpace, Error<'_>> {
+ match word {
+ "private" => Ok(crate::AddressSpace::Private),
+ "workgroup" => Ok(crate::AddressSpace::WorkGroup),
+ "uniform" => Ok(crate::AddressSpace::Uniform),
+ "storage" => Ok(crate::AddressSpace::Storage {
+ access: crate::StorageAccess::default(),
+ }),
+ "push_constant" => Ok(crate::AddressSpace::PushConstant),
+ "function" => Ok(crate::AddressSpace::Function),
+ _ => Err(Error::UnknownAddressSpace(span)),
+ }
+}
+
+pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>> {
+ Ok(match word {
+ "position" => crate::BuiltIn::Position { invariant: false },
+ // vertex
+ "vertex_index" => crate::BuiltIn::VertexIndex,
+ "instance_index" => crate::BuiltIn::InstanceIndex,
+ "view_index" => crate::BuiltIn::ViewIndex,
+ // fragment
+ "front_facing" => crate::BuiltIn::FrontFacing,
+ "frag_depth" => crate::BuiltIn::FragDepth,
+ "primitive_index" => crate::BuiltIn::PrimitiveIndex,
+ "sample_index" => crate::BuiltIn::SampleIndex,
+ "sample_mask" => crate::BuiltIn::SampleMask,
+ // compute
+ "global_invocation_id" => crate::BuiltIn::GlobalInvocationId,
+ "local_invocation_id" => crate::BuiltIn::LocalInvocationId,
+ "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
+ "workgroup_id" => crate::BuiltIn::WorkGroupId,
+ "num_workgroups" => crate::BuiltIn::NumWorkGroups,
+ _ => return Err(Error::UnknownBuiltin(span)),
+ })
+}
+
+pub fn map_interpolation(word: &str, span: Span) -> Result<crate::Interpolation, Error<'_>> {
+ match word {
+ "linear" => Ok(crate::Interpolation::Linear),
+ "flat" => Ok(crate::Interpolation::Flat),
+ "perspective" => Ok(crate::Interpolation::Perspective),
+ _ => Err(Error::UnknownAttribute(span)),
+ }
+}
+
+pub fn map_sampling(word: &str, span: Span) -> Result<crate::Sampling, Error<'_>> {
+ match word {
+ "center" => Ok(crate::Sampling::Center),
+ "centroid" => Ok(crate::Sampling::Centroid),
+ "sample" => Ok(crate::Sampling::Sample),
+ _ => Err(Error::UnknownAttribute(span)),
+ }
+}
+
+pub fn map_storage_format(word: &str, span: Span) -> Result<crate::StorageFormat, Error<'_>> {
+ use crate::StorageFormat as Sf;
+ Ok(match word {
+ "r8unorm" => Sf::R8Unorm,
+ "r8snorm" => Sf::R8Snorm,
+ "r8uint" => Sf::R8Uint,
+ "r8sint" => Sf::R8Sint,
+ "r16unorm" => Sf::R16Unorm,
+ "r16snorm" => Sf::R16Snorm,
+ "r16uint" => Sf::R16Uint,
+ "r16sint" => Sf::R16Sint,
+ "r16float" => Sf::R16Float,
+ "rg8unorm" => Sf::Rg8Unorm,
+ "rg8snorm" => Sf::Rg8Snorm,
+ "rg8uint" => Sf::Rg8Uint,
+ "rg8sint" => Sf::Rg8Sint,
+ "r32uint" => Sf::R32Uint,
+ "r32sint" => Sf::R32Sint,
+ "r32float" => Sf::R32Float,
+ "rg16unorm" => Sf::Rg16Unorm,
+ "rg16snorm" => Sf::Rg16Snorm,
+ "rg16uint" => Sf::Rg16Uint,
+ "rg16sint" => Sf::Rg16Sint,
+ "rg16float" => Sf::Rg16Float,
+ "rgba8unorm" => Sf::Rgba8Unorm,
+ "rgba8snorm" => Sf::Rgba8Snorm,
+ "rgba8uint" => Sf::Rgba8Uint,
+ "rgba8sint" => Sf::Rgba8Sint,
+ "rgb10a2uint" => Sf::Rgb10a2Uint,
+ "rgb10a2unorm" => Sf::Rgb10a2Unorm,
+ "rg11b10float" => Sf::Rg11b10Float,
+ "rg32uint" => Sf::Rg32Uint,
+ "rg32sint" => Sf::Rg32Sint,
+ "rg32float" => Sf::Rg32Float,
+ "rgba16unorm" => Sf::Rgba16Unorm,
+ "rgba16snorm" => Sf::Rgba16Snorm,
+ "rgba16uint" => Sf::Rgba16Uint,
+ "rgba16sint" => Sf::Rgba16Sint,
+ "rgba16float" => Sf::Rgba16Float,
+ "rgba32uint" => Sf::Rgba32Uint,
+ "rgba32sint" => Sf::Rgba32Sint,
+ "rgba32float" => Sf::Rgba32Float,
+ "bgra8unorm" => Sf::Bgra8Unorm,
+ _ => return Err(Error::UnknownStorageFormat(span)),
+ })
+}
+
+pub fn get_scalar_type(word: &str) -> Option<Scalar> {
+ use crate::ScalarKind as Sk;
+ match word {
+ // "f16" => Some(Scalar { kind: Sk::Float, width: 2 }),
+ "f32" => Some(Scalar {
+ kind: Sk::Float,
+ width: 4,
+ }),
+ "f64" => Some(Scalar {
+ kind: Sk::Float,
+ width: 8,
+ }),
+ "i32" => Some(Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ }),
+ "u32" => Some(Scalar {
+ kind: Sk::Uint,
+ width: 4,
+ }),
+ "bool" => Some(Scalar {
+ kind: Sk::Bool,
+ width: crate::BOOL_WIDTH,
+ }),
+ _ => None,
+ }
+}
+
+pub fn map_derivative(word: &str) -> Option<(crate::DerivativeAxis, crate::DerivativeControl)> {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ match word {
+ "dpdxCoarse" => Some((Axis::X, Ctrl::Coarse)),
+ "dpdyCoarse" => Some((Axis::Y, Ctrl::Coarse)),
+ "fwidthCoarse" => Some((Axis::Width, Ctrl::Coarse)),
+ "dpdxFine" => Some((Axis::X, Ctrl::Fine)),
+ "dpdyFine" => Some((Axis::Y, Ctrl::Fine)),
+ "fwidthFine" => Some((Axis::Width, Ctrl::Fine)),
+ "dpdx" => Some((Axis::X, Ctrl::None)),
+ "dpdy" => Some((Axis::Y, Ctrl::None)),
+ "fwidth" => Some((Axis::Width, Ctrl::None)),
+ _ => None,
+ }
+}
+
+pub fn map_relational_fun(word: &str) -> Option<crate::RelationalFunction> {
+ match word {
+ "any" => Some(crate::RelationalFunction::Any),
+ "all" => Some(crate::RelationalFunction::All),
+ _ => None,
+ }
+}
+
+pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
+ use crate::MathFunction as Mf;
+ Some(match word {
+ // comparison
+ "abs" => Mf::Abs,
+ "min" => Mf::Min,
+ "max" => Mf::Max,
+ "clamp" => Mf::Clamp,
+ "saturate" => Mf::Saturate,
+ // trigonometry
+ "cos" => Mf::Cos,
+ "cosh" => Mf::Cosh,
+ "sin" => Mf::Sin,
+ "sinh" => Mf::Sinh,
+ "tan" => Mf::Tan,
+ "tanh" => Mf::Tanh,
+ "acos" => Mf::Acos,
+ "acosh" => Mf::Acosh,
+ "asin" => Mf::Asin,
+ "asinh" => Mf::Asinh,
+ "atan" => Mf::Atan,
+ "atanh" => Mf::Atanh,
+ "atan2" => Mf::Atan2,
+ "radians" => Mf::Radians,
+ "degrees" => Mf::Degrees,
+ // decomposition
+ "ceil" => Mf::Ceil,
+ "floor" => Mf::Floor,
+ "round" => Mf::Round,
+ "fract" => Mf::Fract,
+ "trunc" => Mf::Trunc,
+ "modf" => Mf::Modf,
+ "frexp" => Mf::Frexp,
+ "ldexp" => Mf::Ldexp,
+ // exponent
+ "exp" => Mf::Exp,
+ "exp2" => Mf::Exp2,
+ "log" => Mf::Log,
+ "log2" => Mf::Log2,
+ "pow" => Mf::Pow,
+ // geometry
+ "dot" => Mf::Dot,
+ "cross" => Mf::Cross,
+ "distance" => Mf::Distance,
+ "length" => Mf::Length,
+ "normalize" => Mf::Normalize,
+ "faceForward" => Mf::FaceForward,
+ "reflect" => Mf::Reflect,
+ "refract" => Mf::Refract,
+ // computational
+ "sign" => Mf::Sign,
+ "fma" => Mf::Fma,
+ "mix" => Mf::Mix,
+ "step" => Mf::Step,
+ "smoothstep" => Mf::SmoothStep,
+ "sqrt" => Mf::Sqrt,
+ "inverseSqrt" => Mf::InverseSqrt,
+ "transpose" => Mf::Transpose,
+ "determinant" => Mf::Determinant,
+ // bits
+ "countTrailingZeros" => Mf::CountTrailingZeros,
+ "countLeadingZeros" => Mf::CountLeadingZeros,
+ "countOneBits" => Mf::CountOneBits,
+ "reverseBits" => Mf::ReverseBits,
+ "extractBits" => Mf::ExtractBits,
+ "insertBits" => Mf::InsertBits,
+ "firstTrailingBit" => Mf::FindLsb,
+ "firstLeadingBit" => Mf::FindMsb,
+ // data packing
+ "pack4x8snorm" => Mf::Pack4x8snorm,
+ "pack4x8unorm" => Mf::Pack4x8unorm,
+ "pack2x16snorm" => Mf::Pack2x16snorm,
+ "pack2x16unorm" => Mf::Pack2x16unorm,
+ "pack2x16float" => Mf::Pack2x16float,
+ // data unpacking
+ "unpack4x8snorm" => Mf::Unpack4x8snorm,
+ "unpack4x8unorm" => Mf::Unpack4x8unorm,
+ "unpack2x16snorm" => Mf::Unpack2x16snorm,
+ "unpack2x16unorm" => Mf::Unpack2x16unorm,
+ "unpack2x16float" => Mf::Unpack2x16float,
+ _ => return None,
+ })
+}
+
+pub fn map_conservative_depth(
+ word: &str,
+ span: Span,
+) -> Result<crate::ConservativeDepth, Error<'_>> {
+ use crate::ConservativeDepth as Cd;
+ match word {
+ "greater_equal" => Ok(Cd::GreaterEqual),
+ "less_equal" => Ok(Cd::LessEqual),
+ "unchanged" => Ok(Cd::Unchanged),
+ _ => Err(Error::UnknownConservativeDepth(span)),
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/lexer.rs b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs
new file mode 100644
index 0000000000..d03a448561
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/lexer.rs
@@ -0,0 +1,739 @@
+use super::{number::consume_number, Error, ExpectedToken};
+use crate::front::wgsl::error::NumberError;
+use crate::front::wgsl::parse::{conv, Number};
+use crate::front::wgsl::Scalar;
+use crate::Span;
+
+type TokenSpan<'a> = (Token<'a>, Span);
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum Token<'a> {
+ Separator(char),
+ Paren(char),
+ Attribute,
+ Number(Result<Number, NumberError>),
+ Word(&'a str),
+ Operation(char),
+ LogicalOperation(char),
+ ShiftOperation(char),
+ AssignmentOperation(char),
+ IncrementOperation,
+ DecrementOperation,
+ Arrow,
+ Unknown(char),
+ Trivia,
+ End,
+}
+
+fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) {
+ let pos = input.find(|c| !what(c)).unwrap_or(input.len());
+ input.split_at(pos)
+}
+
+/// Return the token at the start of `input`.
+///
+/// If `generic` is `false`, then the bit shift operators `>>` or `<<`
+/// are valid lookahead tokens for the current parser state (see [§3.1
+/// Parsing] in the WGSL specification). In other words:
+///
+/// - If `generic` is `true`, then we are expecting an angle bracket
+/// around a generic type parameter, like the `<` and `>` in
+/// `vec3<f32>`, so interpret `<` and `>` as `Token::Paren` tokens,
+/// even if they're part of `<<` or `>>` sequences.
+///
+/// - Otherwise, interpret `<<` and `>>` as shift operators:
+/// `Token::LogicalOperation` tokens.
+///
+/// [§3.1 Parsing]: https://gpuweb.github.io/gpuweb/wgsl/#parsing
+fn consume_token(input: &str, generic: bool) -> (Token<'_>, &str) {
+ let mut chars = input.chars();
+ let cur = match chars.next() {
+ Some(c) => c,
+ None => return (Token::End, ""),
+ };
+ match cur {
+ ':' | ';' | ',' => (Token::Separator(cur), chars.as_str()),
+ '.' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('0'..='9') => consume_number(input),
+ _ => (Token::Separator(cur), og_chars),
+ }
+ }
+ '@' => (Token::Attribute, chars.as_str()),
+ '(' | ')' | '{' | '}' | '[' | ']' => (Token::Paren(cur), chars.as_str()),
+ '<' | '>' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') if !generic => (Token::LogicalOperation(cur), chars.as_str()),
+ Some(c) if c == cur && !generic => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::ShiftOperation(cur), og_chars),
+ }
+ }
+ _ => (Token::Paren(cur), og_chars),
+ }
+ }
+ '0'..='9' => consume_number(input),
+ '/' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('/') => {
+ let _ = chars.position(is_comment_end);
+ (Token::Trivia, chars.as_str())
+ }
+ Some('*') => {
+ let mut depth = 1;
+ let mut prev = None;
+
+ for c in &mut chars {
+ match (prev, c) {
+ (Some('*'), '/') => {
+ prev = None;
+ depth -= 1;
+ if depth == 0 {
+ return (Token::Trivia, chars.as_str());
+ }
+ }
+ (Some('/'), '*') => {
+ prev = None;
+ depth += 1;
+ }
+ _ => {
+ prev = Some(c);
+ }
+ }
+ }
+
+ (Token::End, "")
+ }
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '-' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('>') => (Token::Arrow, chars.as_str()),
+ Some('-') => (Token::DecrementOperation, chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '+' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('+') => (Token::IncrementOperation, chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '*' | '%' | '^' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '~' => (Token::Operation(cur), chars.as_str()),
+ '=' | '!' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('=') => (Token::LogicalOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '&' | '|' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some(c) if c == cur => (Token::LogicalOperation(cur), chars.as_str()),
+ Some('=') => (Token::AssignmentOperation(cur), chars.as_str()),
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ _ if is_blankspace(cur) => {
+ let (_, rest) = consume_any(input, is_blankspace);
+ (Token::Trivia, rest)
+ }
+ _ if is_word_start(cur) => {
+ let (word, rest) = consume_any(input, is_word_part);
+ (Token::Word(word), rest)
+ }
+ _ => (Token::Unknown(cur), chars.as_str()),
+ }
+}
+
+/// Returns whether or not a char is a comment end
+/// (Unicode Pattern_White_Space excluding U+0020, U+0009, U+200E and U+200F)
+const fn is_comment_end(c: char) -> bool {
+ match c {
+ '\u{000a}'..='\u{000d}' | '\u{0085}' | '\u{2028}' | '\u{2029}' => true,
+ _ => false,
+ }
+}
+
+/// Returns whether or not a char is a blankspace (Unicode Pattern_White_Space)
+const fn is_blankspace(c: char) -> bool {
+ match c {
+ '\u{0020}'
+ | '\u{0009}'..='\u{000d}'
+ | '\u{0085}'
+ | '\u{200e}'
+ | '\u{200f}'
+ | '\u{2028}'
+ | '\u{2029}' => true,
+ _ => false,
+ }
+}
+
+/// Returns whether or not a char is a word start (Unicode XID_Start + '_')
+fn is_word_start(c: char) -> bool {
+ c == '_' || unicode_xid::UnicodeXID::is_xid_start(c)
+}
+
+/// Returns whether or not a char is a word part (Unicode XID_Continue)
+fn is_word_part(c: char) -> bool {
+ unicode_xid::UnicodeXID::is_xid_continue(c)
+}
+
+#[derive(Clone)]
+pub(in crate::front::wgsl) struct Lexer<'a> {
+ input: &'a str,
+ pub(in crate::front::wgsl) source: &'a str,
+ // The byte offset of the end of the last non-trivia token.
+ last_end_offset: usize,
+}
+
+impl<'a> Lexer<'a> {
+ pub(in crate::front::wgsl) const fn new(input: &'a str) -> Self {
+ Lexer {
+ input,
+ source: input,
+ last_end_offset: 0,
+ }
+ }
+
+ /// Calls the function with a lexer and returns the result of the function as well as the span for everything the function parsed
+ ///
+ /// # Examples
+ /// ```ignore
+ /// let lexer = Lexer::new("5");
+ /// let (value, span) = lexer.capture_span(Lexer::next_uint_literal);
+ /// assert_eq!(value, 5);
+ /// ```
+ #[inline]
+ pub fn capture_span<T, E>(
+ &mut self,
+ inner: impl FnOnce(&mut Self) -> Result<T, E>,
+ ) -> Result<(T, Span), E> {
+ let start = self.current_byte_offset();
+ let res = inner(self)?;
+ let end = self.current_byte_offset();
+ Ok((res, Span::from(start..end)))
+ }
+
+ pub(in crate::front::wgsl) fn start_byte_offset(&mut self) -> usize {
+ loop {
+ // Eat all trivia because `next` doesn't eat trailing trivia.
+ let (token, rest) = consume_token(self.input, false);
+ if let Token::Trivia = token {
+ self.input = rest;
+ } else {
+ return self.current_byte_offset();
+ }
+ }
+ }
+
+ fn peek_token_and_rest(&mut self) -> (TokenSpan<'a>, &'a str) {
+ let mut cloned = self.clone();
+ let token = cloned.next();
+ let rest = cloned.input;
+ (token, rest)
+ }
+
+ const fn current_byte_offset(&self) -> usize {
+ self.source.len() - self.input.len()
+ }
+
+ pub(in crate::front::wgsl) fn span_from(&self, offset: usize) -> Span {
+ Span::from(offset..self.last_end_offset)
+ }
+
+ /// Return the next non-whitespace token from `self`.
+ ///
+ /// Assume we are a parse state where bit shift operators may
+ /// occur, but not angle brackets.
+ #[must_use]
+ pub(in crate::front::wgsl) fn next(&mut self) -> TokenSpan<'a> {
+ self.next_impl(false)
+ }
+
+ /// Return the next non-whitespace token from `self`.
+ ///
+ /// Assume we are in a parse state where angle brackets may occur,
+ /// but not bit shift operators.
+ #[must_use]
+ pub(in crate::front::wgsl) fn next_generic(&mut self) -> TokenSpan<'a> {
+ self.next_impl(true)
+ }
+
+ /// Return the next non-whitespace token from `self`, with a span.
+ ///
+ /// See [`consume_token`] for the meaning of `generic`.
+ fn next_impl(&mut self, generic: bool) -> TokenSpan<'a> {
+ let mut start_byte_offset = self.current_byte_offset();
+ loop {
+ let (token, rest) = consume_token(self.input, generic);
+ self.input = rest;
+ match token {
+ Token::Trivia => start_byte_offset = self.current_byte_offset(),
+ _ => {
+ self.last_end_offset = self.current_byte_offset();
+ return (token, self.span_from(start_byte_offset));
+ }
+ }
+ }
+ }
+
+ #[must_use]
+ pub(in crate::front::wgsl) fn peek(&mut self) -> TokenSpan<'a> {
+ let (token, _) = self.peek_token_and_rest();
+ token
+ }
+
+ pub(in crate::front::wgsl) fn expect_span(
+ &mut self,
+ expected: Token<'a>,
+ ) -> Result<Span, Error<'a>> {
+ let next = self.next();
+ if next.0 == expected {
+ Ok(next.1)
+ } else {
+ Err(Error::Unexpected(next.1, ExpectedToken::Token(expected)))
+ }
+ }
+
+ pub(in crate::front::wgsl) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> {
+ self.expect_span(expected)?;
+ Ok(())
+ }
+
+ pub(in crate::front::wgsl) fn expect_generic_paren(
+ &mut self,
+ expected: char,
+ ) -> Result<(), Error<'a>> {
+ let next = self.next_generic();
+ if next.0 == Token::Paren(expected) {
+ Ok(())
+ } else {
+ Err(Error::Unexpected(
+ next.1,
+ ExpectedToken::Token(Token::Paren(expected)),
+ ))
+ }
+ }
+
+ /// If the next token matches it is skipped and true is returned
+ pub(in crate::front::wgsl) fn skip(&mut self, what: Token<'_>) -> bool {
+ let (peeked_token, rest) = self.peek_token_and_rest();
+ if peeked_token.0 == what {
+ self.input = rest;
+ true
+ } else {
+ false
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_ident_with_span(
+ &mut self,
+ ) -> Result<(&'a str, Span), Error<'a>> {
+ match self.next() {
+ (Token::Word("_"), span) => Err(Error::InvalidIdentifierUnderscore(span)),
+ (Token::Word(word), span) if word.starts_with("__") => {
+ Err(Error::ReservedIdentifierPrefix(span))
+ }
+ (Token::Word(word), span) => Ok((word, span)),
+ other => Err(Error::Unexpected(other.1, ExpectedToken::Identifier)),
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_ident(
+ &mut self,
+ ) -> Result<super::ast::Ident<'a>, Error<'a>> {
+ let ident = self
+ .next_ident_with_span()
+ .map(|(name, span)| super::ast::Ident { name, span })?;
+
+ if crate::keywords::wgsl::RESERVED.contains(&ident.name) {
+ return Err(Error::ReservedKeyword(ident.span));
+ }
+
+ Ok(ident)
+ }
+
+ /// Parses a generic scalar type, for example `<f32>`.
+ pub(in crate::front::wgsl) fn next_scalar_generic(&mut self) -> Result<Scalar, Error<'a>> {
+ self.expect_generic_paren('<')?;
+ let pair = match self.next() {
+ (Token::Word(word), span) => {
+ conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(span))
+ }
+ (_, span) => Err(Error::UnknownScalarType(span)),
+ }?;
+ self.expect_generic_paren('>')?;
+ Ok(pair)
+ }
+
+ /// Parses a generic scalar type, for example `<f32>`.
+ ///
+ /// Returns the span covering the inner type, excluding the brackets.
+ pub(in crate::front::wgsl) fn next_scalar_generic_with_span(
+ &mut self,
+ ) -> Result<(Scalar, Span), Error<'a>> {
+ self.expect_generic_paren('<')?;
+ let pair = match self.next() {
+ (Token::Word(word), span) => conv::get_scalar_type(word)
+ .map(|scalar| (scalar, span))
+ .ok_or(Error::UnknownScalarType(span)),
+ (_, span) => Err(Error::UnknownScalarType(span)),
+ }?;
+ self.expect_generic_paren('>')?;
+ Ok(pair)
+ }
+
+ pub(in crate::front::wgsl) fn next_storage_access(
+ &mut self,
+ ) -> Result<crate::StorageAccess, Error<'a>> {
+ let (ident, span) = self.next_ident_with_span()?;
+ match ident {
+ "read" => Ok(crate::StorageAccess::LOAD),
+ "write" => Ok(crate::StorageAccess::STORE),
+ "read_write" => Ok(crate::StorageAccess::LOAD | crate::StorageAccess::STORE),
+ _ => Err(Error::UnknownAccess(span)),
+ }
+ }
+
+ pub(in crate::front::wgsl) fn next_format_generic(
+ &mut self,
+ ) -> Result<(crate::StorageFormat, crate::StorageAccess), Error<'a>> {
+ self.expect(Token::Paren('<'))?;
+ let (ident, ident_span) = self.next_ident_with_span()?;
+ let format = conv::map_storage_format(ident, ident_span)?;
+ self.expect(Token::Separator(','))?;
+ let access = self.next_storage_access()?;
+ self.expect(Token::Paren('>'))?;
+ Ok((format, access))
+ }
+
+ pub(in crate::front::wgsl) fn open_arguments(&mut self) -> Result<(), Error<'a>> {
+ self.expect(Token::Paren('('))
+ }
+
+ pub(in crate::front::wgsl) fn close_arguments(&mut self) -> Result<(), Error<'a>> {
+ let _ = self.skip(Token::Separator(','));
+ self.expect(Token::Paren(')'))
+ }
+
+ pub(in crate::front::wgsl) fn next_argument(&mut self) -> Result<bool, Error<'a>> {
+ let paren = Token::Paren(')');
+ if self.skip(Token::Separator(',')) {
+ Ok(!self.skip(paren))
+ } else {
+ self.expect(paren).map(|()| false)
+ }
+ }
+}
+
+#[cfg(test)]
+#[track_caller]
+fn sub_test(source: &str, expected_tokens: &[Token]) {
+ let mut lex = Lexer::new(source);
+ for &token in expected_tokens {
+ assert_eq!(lex.next().0, token);
+ }
+ assert_eq!(lex.next().0, Token::End);
+}
+
+#[test]
+fn test_numbers() {
+ // WGSL spec examples //
+
+ // decimal integer
+ sub_test(
+ "0x123 0X123u 1u 123 0 0i 0x3f",
+ &[
+ Token::Number(Ok(Number::AbstractInt(291))),
+ Token::Number(Ok(Number::U32(291))),
+ Token::Number(Ok(Number::U32(1))),
+ Token::Number(Ok(Number::AbstractInt(123))),
+ Token::Number(Ok(Number::AbstractInt(0))),
+ Token::Number(Ok(Number::I32(0))),
+ Token::Number(Ok(Number::AbstractInt(63))),
+ ],
+ );
+ // decimal floating point
+ sub_test(
+ "0.e+4f 01. .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h",
+ &[
+ Token::Number(Ok(Number::F32(0.))),
+ Token::Number(Ok(Number::AbstractFloat(1.))),
+ Token::Number(Ok(Number::AbstractFloat(0.01))),
+ Token::Number(Ok(Number::AbstractFloat(12.34))),
+ Token::Number(Ok(Number::F32(0.))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ Token::Number(Ok(Number::AbstractFloat(0.001))),
+ Token::Number(Ok(Number::AbstractFloat(43.75))),
+ Token::Number(Ok(Number::F32(16.))),
+ Token::Number(Ok(Number::AbstractFloat(0.1875))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ Token::Number(Ok(Number::AbstractFloat(0.12109375))),
+ Token::Number(Err(NumberError::UnimplementedF16)),
+ ],
+ );
+
+ // MIN / MAX //
+
+ // min / max decimal integer
+ sub_test(
+ "0i 2147483647i 2147483648i",
+ &[
+ Token::Number(Ok(Number::I32(0))),
+ Token::Number(Ok(Number::I32(i32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+ // min / max decimal unsigned integer
+ sub_test(
+ "0u 4294967295u 4294967296u",
+ &[
+ Token::Number(Ok(Number::U32(u32::MIN))),
+ Token::Number(Ok(Number::U32(u32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ // min / max hexadecimal signed integer
+ sub_test(
+ "0x0i 0x7FFFFFFFi 0x80000000i",
+ &[
+ Token::Number(Ok(Number::I32(0))),
+ Token::Number(Ok(Number::I32(i32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+ // min / max hexadecimal unsigned integer
+ sub_test(
+ "0x0u 0xFFFFFFFFu 0x100000000u",
+ &[
+ Token::Number(Ok(Number::U32(u32::MIN))),
+ Token::Number(Ok(Number::U32(u32::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ // min/max decimal abstract int
+ sub_test(
+ "0 9223372036854775807 9223372036854775808",
+ &[
+ Token::Number(Ok(Number::AbstractInt(0))),
+ Token::Number(Ok(Number::AbstractInt(i64::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ // min/max hexadecimal abstract int
+ sub_test(
+ "0 0x7fffffffffffffff 0x8000000000000000",
+ &[
+ Token::Number(Ok(Number::AbstractInt(0))),
+ Token::Number(Ok(Number::AbstractInt(i64::MAX))),
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+
+ /// ≈ 2^-126 * 2^−23 (= 2^−149)
+ const SMALLEST_POSITIVE_SUBNORMAL_F32: f32 = 1e-45;
+ /// ≈ 2^-126 * (1 − 2^−23)
+ const LARGEST_SUBNORMAL_F32: f32 = 1.1754942e-38;
+ /// ≈ 2^-126
+ const SMALLEST_POSITIVE_NORMAL_F32: f32 = f32::MIN_POSITIVE;
+ /// ≈ 1 − 2^−24
+ const LARGEST_F32_LESS_THAN_ONE: f32 = 0.99999994;
+ /// ≈ 1 + 2^−23
+ const SMALLEST_F32_LARGER_THAN_ONE: f32 = 1.0000001;
+ /// ≈ 2^127 * (2 − 2^−23)
+ const LARGEST_NORMAL_F32: f32 = f32::MAX;
+
+ // decimal floating point
+ sub_test(
+ "1e-45f 1.1754942e-38f 1.17549435e-38f 0.99999994f 1.0000001f 3.40282347e+38f",
+ &[
+ Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))),
+ Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))),
+ Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))),
+ Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))),
+ Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))),
+ Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))),
+ ],
+ );
+ sub_test(
+ "3.40282367e+38f",
+ &[
+ Token::Number(Err(NumberError::NotRepresentable)), // ≈ 2^128
+ ],
+ );
+
+ // hexadecimal floating point
+ sub_test(
+ "0x1p-149f 0x7FFFFFp-149f 0x1p-126f 0xFFFFFFp-24f 0x800001p-23f 0xFFFFFFp+104f",
+ &[
+ Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_SUBNORMAL_F32))),
+ Token::Number(Ok(Number::F32(LARGEST_SUBNORMAL_F32))),
+ Token::Number(Ok(Number::F32(SMALLEST_POSITIVE_NORMAL_F32))),
+ Token::Number(Ok(Number::F32(LARGEST_F32_LESS_THAN_ONE))),
+ Token::Number(Ok(Number::F32(SMALLEST_F32_LARGER_THAN_ONE))),
+ Token::Number(Ok(Number::F32(LARGEST_NORMAL_F32))),
+ ],
+ );
+ sub_test(
+ "0x1p128f 0x1.000001p0f",
+ &[
+ Token::Number(Err(NumberError::NotRepresentable)), // = 2^128
+ Token::Number(Err(NumberError::NotRepresentable)),
+ ],
+ );
+}
+
+#[test]
+fn double_floats() {
+ sub_test(
+ "0x1.2p4lf 0x1p8lf 0.0625lf 625e-4lf 10lf 10l",
+ &[
+ Token::Number(Ok(Number::F64(18.0))),
+ Token::Number(Ok(Number::F64(256.0))),
+ Token::Number(Ok(Number::F64(0.0625))),
+ Token::Number(Ok(Number::F64(0.0625))),
+ Token::Number(Ok(Number::F64(10.0))),
+ Token::Number(Ok(Number::AbstractInt(10))),
+ Token::Word("l"),
+ ],
+ )
+}
+
+#[test]
+fn test_tokens() {
+ sub_test("id123_OK", &[Token::Word("id123_OK")]);
+ sub_test(
+ "92No",
+ &[
+ Token::Number(Ok(Number::AbstractInt(92))),
+ Token::Word("No"),
+ ],
+ );
+ sub_test(
+ "2u3o",
+ &[
+ Token::Number(Ok(Number::U32(2))),
+ Token::Number(Ok(Number::AbstractInt(3))),
+ Token::Word("o"),
+ ],
+ );
+ sub_test(
+ "2.4f44po",
+ &[
+ Token::Number(Ok(Number::F32(2.4))),
+ Token::Number(Ok(Number::AbstractInt(44))),
+ Token::Word("po"),
+ ],
+ );
+ sub_test(
+ "Δέλτα réflexion Кызыл 𐰓𐰏𐰇 朝焼け سلام 검정 שָׁלוֹם गुलाबी փիրուզ",
+ &[
+ Token::Word("Δέλτα"),
+ Token::Word("réflexion"),
+ Token::Word("Кызыл"),
+ Token::Word("𐰓𐰏𐰇"),
+ Token::Word("朝焼け"),
+ Token::Word("سلام"),
+ Token::Word("검정"),
+ Token::Word("שָׁלוֹם"),
+ Token::Word("गुलाबी"),
+ Token::Word("փիրուզ"),
+ ],
+ );
+ sub_test("æNoø", &[Token::Word("æNoø")]);
+ sub_test("No¾", &[Token::Word("No"), Token::Unknown('¾')]);
+ sub_test("No好", &[Token::Word("No好")]);
+ sub_test("_No", &[Token::Word("_No")]);
+ sub_test(
+ "*/*/***/*//=/*****//",
+ &[
+ Token::Operation('*'),
+ Token::AssignmentOperation('/'),
+ Token::Operation('/'),
+ ],
+ );
+
+ // Type suffixes are only allowed on hex float literals
+ // if you provided an exponent.
+ sub_test(
+ "0x1.2f 0x1.2f 0x1.2h 0x1.2H 0x1.2lf",
+ &[
+ // The 'f' suffixes are taken as a hex digit:
+ // the fractional part is 0x2f / 256.
+ Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))),
+ Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))),
+ Token::Number(Ok(Number::AbstractFloat(1.125))),
+ Token::Word("h"),
+ Token::Number(Ok(Number::AbstractFloat(1.125))),
+ Token::Word("H"),
+ Token::Number(Ok(Number::AbstractFloat(1.125))),
+ Token::Word("lf"),
+ ],
+ )
+}
+
+#[test]
+fn test_variable_decl() {
+ sub_test(
+ "@group(0 ) var< uniform> texture: texture_multisampled_2d <f32 >;",
+ &[
+ Token::Attribute,
+ Token::Word("group"),
+ Token::Paren('('),
+ Token::Number(Ok(Number::AbstractInt(0))),
+ Token::Paren(')'),
+ Token::Word("var"),
+ Token::Paren('<'),
+ Token::Word("uniform"),
+ Token::Paren('>'),
+ Token::Word("texture"),
+ Token::Separator(':'),
+ Token::Word("texture_multisampled_2d"),
+ Token::Paren('<'),
+ Token::Word("f32"),
+ Token::Paren('>'),
+ Token::Separator(';'),
+ ],
+ );
+ sub_test(
+ "var<storage,read_write> buffer: array<u32>;",
+ &[
+ Token::Word("var"),
+ Token::Paren('<'),
+ Token::Word("storage"),
+ Token::Separator(','),
+ Token::Word("read_write"),
+ Token::Paren('>'),
+ Token::Word("buffer"),
+ Token::Separator(':'),
+ Token::Word("array"),
+ Token::Paren('<'),
+ Token::Word("u32"),
+ Token::Paren('>'),
+ Token::Separator(';'),
+ ],
+ );
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
new file mode 100644
index 0000000000..51fc2f013b
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
@@ -0,0 +1,2350 @@
+use crate::front::wgsl::error::{Error, ExpectedToken};
+use crate::front::wgsl::parse::lexer::{Lexer, Token};
+use crate::front::wgsl::parse::number::Number;
+use crate::front::wgsl::Scalar;
+use crate::front::SymbolTable;
+use crate::{Arena, FastIndexSet, Handle, ShaderStage, Span};
+
+pub mod ast;
+pub mod conv;
+pub mod lexer;
+pub mod number;
+
+/// State for constructing an AST expression.
+///
+/// Not to be confused with [`lower::ExpressionContext`], which is for producing
+/// Naga IR from the AST we produce here.
+///
+/// [`lower::ExpressionContext`]: super::lower::ExpressionContext
+struct ExpressionContext<'input, 'temp, 'out> {
+ /// The [`TranslationUnit::expressions`] arena to which we should contribute
+ /// expressions.
+ ///
+ /// [`TranslationUnit::expressions`]: ast::TranslationUnit::expressions
+ expressions: &'out mut Arena<ast::Expression<'input>>,
+
+ /// The [`TranslationUnit::types`] arena to which we should contribute new
+ /// types.
+ ///
+ /// [`TranslationUnit::types`]: ast::TranslationUnit::types
+ types: &'out mut Arena<ast::Type<'input>>,
+
+ /// A map from identifiers in scope to the locals/arguments they represent.
+ ///
+ /// The handles refer to the [`Function::locals`] area; see that field's
+ /// documentation for details.
+ ///
+ /// [`Function::locals`]: ast::Function::locals
+ local_table: &'temp mut SymbolTable<&'input str, Handle<ast::Local>>,
+
+ /// The [`Function::locals`] arena for the function we're building.
+ ///
+ /// [`Function::locals`]: ast::Function::locals
+ locals: &'out mut Arena<ast::Local>,
+
+ /// Identifiers used by the current global declaration that have no local definition.
+ ///
+ /// This becomes the [`GlobalDecl`]'s [`dependencies`] set.
+ ///
+ /// Note that we don't know at parse time what kind of [`GlobalDecl`] the
+ /// name refers to. We can't look up names until we've seen the entire
+ /// translation unit.
+ ///
+ /// [`GlobalDecl`]: ast::GlobalDecl
+ /// [`dependencies`]: ast::GlobalDecl::dependencies
+ unresolved: &'out mut FastIndexSet<ast::Dependency<'input>>,
+}
+
+impl<'a> ExpressionContext<'a, '_, '_> {
+ fn parse_binary_op(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ classifier: impl Fn(Token<'a>) -> Option<crate::BinaryOperator>,
+ mut parser: impl FnMut(
+ &mut Lexer<'a>,
+ &mut Self,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let start = lexer.start_byte_offset();
+ let mut accumulator = parser(lexer, self)?;
+ while let Some(op) = classifier(lexer.peek().0) {
+ let _ = lexer.next();
+ let left = accumulator;
+ let right = parser(lexer, self)?;
+ accumulator = self.expressions.append(
+ ast::Expression::Binary { op, left, right },
+ lexer.span_from(start),
+ );
+ }
+ Ok(accumulator)
+ }
+
+ fn declare_local(&mut self, name: ast::Ident<'a>) -> Result<Handle<ast::Local>, Error<'a>> {
+ let handle = self.locals.append(ast::Local, name.span);
+ if let Some(old) = self.local_table.add(name.name, handle) {
+ Err(Error::Redefinition {
+ previous: self.locals.get_span(old),
+ current: name.span,
+ })
+ } else {
+ Ok(handle)
+ }
+ }
+}
+
+/// Which grammar rule we are in the midst of parsing.
+///
+/// This is used for error checking. `Parser` maintains a stack of
+/// these and (occasionally) checks that it is being pushed and popped
+/// as expected.
+#[derive(Clone, Debug, PartialEq)]
+enum Rule {
+ Attribute,
+ VariableDecl,
+ TypeDecl,
+ FunctionDecl,
+ Block,
+ Statement,
+ PrimaryExpr,
+ SingularExpr,
+ UnaryExpr,
+ GeneralExpr,
+}
+
+struct ParsedAttribute<T> {
+ value: Option<T>,
+}
+
+impl<T> Default for ParsedAttribute<T> {
+ fn default() -> Self {
+ Self { value: None }
+ }
+}
+
+impl<T> ParsedAttribute<T> {
+ fn set(&mut self, value: T, name_span: Span) -> Result<(), Error<'static>> {
+ if self.value.is_some() {
+ return Err(Error::RepeatedAttribute(name_span));
+ }
+ self.value = Some(value);
+ Ok(())
+ }
+}
+
+#[derive(Default)]
+struct BindingParser<'a> {
+ location: ParsedAttribute<Handle<ast::Expression<'a>>>,
+ second_blend_source: ParsedAttribute<bool>,
+ built_in: ParsedAttribute<crate::BuiltIn>,
+ interpolation: ParsedAttribute<crate::Interpolation>,
+ sampling: ParsedAttribute<crate::Sampling>,
+ invariant: ParsedAttribute<bool>,
+}
+
+impl<'a> BindingParser<'a> {
+ fn parse(
+ &mut self,
+ parser: &mut Parser,
+ lexer: &mut Lexer<'a>,
+ name: &'a str,
+ name_span: Span,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<(), Error<'a>> {
+ match name {
+ "location" => {
+ lexer.expect(Token::Paren('('))?;
+ self.location
+ .set(parser.general_expression(lexer, ctx)?, name_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "builtin" => {
+ lexer.expect(Token::Paren('('))?;
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.built_in
+ .set(conv::map_built_in(raw, span)?, name_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "interpolate" => {
+ lexer.expect(Token::Paren('('))?;
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.interpolation
+ .set(conv::map_interpolation(raw, span)?, name_span)?;
+ if lexer.skip(Token::Separator(',')) {
+ let (raw, span) = lexer.next_ident_with_span()?;
+ self.sampling
+ .set(conv::map_sampling(raw, span)?, name_span)?;
+ }
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "second_blend_source" => {
+ self.second_blend_source.set(true, name_span)?;
+ }
+ "invariant" => {
+ self.invariant.set(true, name_span)?;
+ }
+ _ => return Err(Error::UnknownAttribute(name_span)),
+ }
+ Ok(())
+ }
+
+ fn finish(self, span: Span) -> Result<Option<ast::Binding<'a>>, Error<'a>> {
+ match (
+ self.location.value,
+ self.built_in.value,
+ self.interpolation.value,
+ self.sampling.value,
+ self.invariant.value.unwrap_or_default(),
+ ) {
+ (None, None, None, None, false) => Ok(None),
+ (Some(location), None, interpolation, sampling, false) => {
+ // Before handing over the completed `Module`, we call
+ // `apply_default_interpolation` to ensure that the interpolation and
+ // sampling have been explicitly specified on all vertex shader output and fragment
+ // shader input user bindings, so leaving them potentially `None` here is fine.
+ Ok(Some(ast::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source: self.second_blend_source.value.unwrap_or(false),
+ }))
+ }
+ (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant) => {
+ Ok(Some(ast::Binding::BuiltIn(crate::BuiltIn::Position {
+ invariant,
+ })))
+ }
+ (None, Some(built_in), None, None, false) => Ok(Some(ast::Binding::BuiltIn(built_in))),
+ (_, _, _, _, _) => Err(Error::InconsistentBinding(span)),
+ }
+ }
+}
+
+pub struct Parser {
+ rules: Vec<(Rule, usize)>,
+}
+
+impl Parser {
+ pub const fn new() -> Self {
+ Parser { rules: Vec::new() }
+ }
+
+ fn reset(&mut self) {
+ self.rules.clear();
+ }
+
+ fn push_rule_span(&mut self, rule: Rule, lexer: &mut Lexer<'_>) {
+ self.rules.push((rule, lexer.start_byte_offset()));
+ }
+
+ fn pop_rule_span(&mut self, lexer: &Lexer<'_>) -> Span {
+ let (_, initial) = self.rules.pop().unwrap();
+ lexer.span_from(initial)
+ }
+
+ fn peek_rule_span(&mut self, lexer: &Lexer<'_>) -> Span {
+ let &(_, initial) = self.rules.last().unwrap();
+ lexer.span_from(initial)
+ }
+
+ fn switch_value<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<ast::SwitchValue<'a>, Error<'a>> {
+ if let Token::Word("default") = lexer.peek().0 {
+ let _ = lexer.next();
+ return Ok(ast::SwitchValue::Default);
+ }
+
+ let expr = self.general_expression(lexer, ctx)?;
+ Ok(ast::SwitchValue::Expr(expr))
+ }
+
+ /// Decide if we're looking at a construction expression, and return its
+ /// type if so.
+ ///
+ /// If the identifier `word` is a [type-defining keyword], then return a
+ /// [`ConstructorType`] value describing the type to build. Return an error
+ /// if the type is not constructible (like `sampler`).
+ ///
+ /// If `word` isn't a type name, then return `None`.
+ ///
+ /// [type-defining keyword]: https://gpuweb.github.io/gpuweb/wgsl/#type-defining-keywords
+ /// [`ConstructorType`]: ast::ConstructorType
+ fn constructor_type<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ word: &'a str,
+ span: Span,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Option<ast::ConstructorType<'a>>, Error<'a>> {
+ if let Some(scalar) = conv::get_scalar_type(word) {
+ return Ok(Some(ast::ConstructorType::Scalar(scalar)));
+ }
+
+ let partial = match word {
+ "vec2" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Bi,
+ },
+ "vec2i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ }))
+ }
+ "vec2u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ }))
+ }
+ "vec2f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar::F32,
+ }))
+ }
+ "vec3" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Tri,
+ },
+ "vec3i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar::I32,
+ }))
+ }
+ "vec3u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar::U32,
+ }))
+ }
+ "vec3f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar::F32,
+ }))
+ }
+ "vec4" => ast::ConstructorType::PartialVector {
+ size: crate::VectorSize::Quad,
+ },
+ "vec4i" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar::I32,
+ }))
+ }
+ "vec4u" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar::U32,
+ }))
+ }
+ "vec4f" => {
+ return Ok(Some(ast::ConstructorType::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar::F32,
+ }))
+ }
+ "mat2x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat2x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat2x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat2x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat2x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat2x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "mat3x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat3x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat3x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat3x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat3x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat3x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "mat4x2" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ },
+ "mat4x2f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }))
+ }
+ "mat4x3" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ },
+ "mat4x3f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }))
+ }
+ "mat4x4" => ast::ConstructorType::PartialMatrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ },
+ "mat4x4f" => {
+ return Ok(Some(ast::ConstructorType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ }))
+ }
+ "array" => ast::ConstructorType::PartialArray,
+ "atomic"
+ | "binding_array"
+ | "sampler"
+ | "sampler_comparison"
+ | "texture_1d"
+ | "texture_1d_array"
+ | "texture_2d"
+ | "texture_2d_array"
+ | "texture_3d"
+ | "texture_cube"
+ | "texture_cube_array"
+ | "texture_multisampled_2d"
+ | "texture_multisampled_2d_array"
+ | "texture_depth_2d"
+ | "texture_depth_2d_array"
+ | "texture_depth_cube"
+ | "texture_depth_cube_array"
+ | "texture_depth_multisampled_2d"
+ | "texture_storage_1d"
+ | "texture_storage_1d_array"
+ | "texture_storage_2d"
+ | "texture_storage_2d_array"
+ | "texture_storage_3d" => return Err(Error::TypeNotConstructible(span)),
+ _ => return Ok(None),
+ };
+
+ // parse component type if present
+ match (lexer.peek().0, partial) {
+ (Token::Paren('<'), ast::ConstructorType::PartialVector { size }) => {
+ let scalar = lexer.next_scalar_generic()?;
+ Ok(Some(ast::ConstructorType::Vector { size, scalar }))
+ }
+ (Token::Paren('<'), ast::ConstructorType::PartialMatrix { columns, rows }) => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ match scalar.kind {
+ crate::ScalarKind::Float => Ok(Some(ast::ConstructorType::Matrix {
+ columns,
+ rows,
+ width: scalar.width,
+ })),
+ _ => Err(Error::BadMatrixScalarKind(span, scalar)),
+ }
+ }
+ (Token::Paren('<'), ast::ConstructorType::PartialArray) => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx)?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let expr = self.unary_expression(lexer, ctx)?;
+ ast::ArraySize::Constant(expr)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ Ok(Some(ast::ConstructorType::Array { base, size }))
+ }
+ (_, partial) => Ok(Some(partial)),
+ }
+ }
+
+ /// Expects `name` to be consumed (not in lexer).
+ fn arguments<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Vec<Handle<ast::Expression<'a>>>, Error<'a>> {
+ lexer.open_arguments()?;
+ let mut arguments = Vec::new();
+ loop {
+ if !arguments.is_empty() {
+ if !lexer.next_argument()? {
+ break;
+ }
+ } else if lexer.skip(Token::Paren(')')) {
+ break;
+ }
+ let arg = self.general_expression(lexer, ctx)?;
+ arguments.push(arg);
+ }
+
+ Ok(arguments)
+ }
+
+ /// Expects [`Rule::PrimaryExpr`] or [`Rule::SingularExpr`] on top; does not pop it.
+ /// Expects `name` to be consumed (not in lexer).
+ fn function_call<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ name: &'a str,
+ name_span: Span,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ assert!(self.rules.last().is_some());
+
+ let expr = match name {
+ // bitcast looks like a function call, but it's an operator and must be handled differently.
+ "bitcast" => {
+ lexer.expect_generic_paren('<')?;
+ let start = lexer.start_byte_offset();
+ let to = self.type_decl(lexer, ctx)?;
+ let span = lexer.span_from(start);
+ lexer.expect_generic_paren('>')?;
+
+ lexer.open_arguments()?;
+ let expr = self.general_expression(lexer, ctx)?;
+ lexer.close_arguments()?;
+
+ ast::Expression::Bitcast {
+ expr,
+ to,
+ ty_span: span,
+ }
+ }
+ // everything else must be handled later, since they can be hidden by user-defined functions.
+ _ => {
+ let arguments = self.arguments(lexer, ctx)?;
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: name_span,
+ });
+ ast::Expression::Call {
+ function: ast::Ident {
+ name,
+ span: name_span,
+ },
+ arguments,
+ }
+ }
+ };
+
+ let span = self.peek_rule_span(lexer);
+ let expr = ctx.expressions.append(expr, span);
+ Ok(expr)
+ }
+
+ fn ident_expr<'a>(
+ &mut self,
+ name: &'a str,
+ name_span: Span,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> ast::IdentExpr<'a> {
+ match ctx.local_table.lookup(name) {
+ Some(&local) => ast::IdentExpr::Local(local),
+ None => {
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: name_span,
+ });
+ ast::IdentExpr::Unresolved(name)
+ }
+ }
+ }
+
+ fn primary_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::PrimaryExpr, lexer);
+
+ let expr = match lexer.peek() {
+ (Token::Paren('('), _) => {
+ let _ = lexer.next();
+ let expr = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Paren(')'))?;
+ self.pop_rule_span(lexer);
+ return Ok(expr);
+ }
+ (Token::Word("true"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Bool(true))
+ }
+ (Token::Word("false"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Bool(false))
+ }
+ (Token::Number(res), span) => {
+ let _ = lexer.next();
+ let num = res.map_err(|err| Error::BadNumber(span, err))?;
+ ast::Expression::Literal(ast::Literal::Number(num))
+ }
+ (Token::Word("RAY_FLAG_NONE"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
+ }
+ (Token::Word("RAY_FLAG_TERMINATE_ON_FIRST_HIT"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(4)))
+ }
+ (Token::Word("RAY_QUERY_INTERSECTION_NONE"), _) => {
+ let _ = lexer.next();
+ ast::Expression::Literal(ast::Literal::Number(Number::U32(0)))
+ }
+ (Token::Word(word), span) => {
+ let start = lexer.start_byte_offset();
+ let _ = lexer.next();
+
+ if let Some(ty) = self.constructor_type(lexer, word, span, ctx)? {
+ let ty_span = lexer.span_from(start);
+ let components = self.arguments(lexer, ctx)?;
+ ast::Expression::Construct {
+ ty,
+ ty_span,
+ components,
+ }
+ } else if let Token::Paren('(') = lexer.peek().0 {
+ self.pop_rule_span(lexer);
+ return self.function_call(lexer, word, span, ctx);
+ } else if word == "bitcast" {
+ self.pop_rule_span(lexer);
+ return self.function_call(lexer, word, span, ctx);
+ } else {
+ let ident = self.ident_expr(word, span, ctx);
+ ast::Expression::Ident(ident)
+ }
+ }
+ other => return Err(Error::Unexpected(other.1, ExpectedToken::PrimaryExpression)),
+ };
+
+ let span = self.pop_rule_span(lexer);
+ let expr = ctx.expressions.append(expr, span);
+ Ok(expr)
+ }
+
+ fn postfix<'a>(
+ &mut self,
+ span_start: usize,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ expr: Handle<ast::Expression<'a>>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let mut expr = expr;
+
+ loop {
+ let expression = match lexer.peek().0 {
+ Token::Separator('.') => {
+ let _ = lexer.next();
+ let field = lexer.next_ident()?;
+
+ ast::Expression::Member { base: expr, field }
+ }
+ Token::Paren('[') => {
+ let _ = lexer.next();
+ let index = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Paren(']'))?;
+
+ ast::Expression::Index { base: expr, index }
+ }
+ _ => break,
+ };
+
+ let span = lexer.span_from(span_start);
+ expr = ctx.expressions.append(expression, span);
+ }
+
+ Ok(expr)
+ }
+
+ /// Parse a `unary_expression`.
+ fn unary_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::UnaryExpr, lexer);
+ //TODO: refactor this to avoid backing up
+ let expr = match lexer.peek().0 {
+ Token::Operation('-') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx)?;
+ let expr = ast::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr,
+ };
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('!') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx)?;
+ let expr = ast::Expression::Unary {
+ op: crate::UnaryOperator::LogicalNot,
+ expr,
+ };
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('~') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx)?;
+ let expr = ast::Expression::Unary {
+ op: crate::UnaryOperator::BitwiseNot,
+ expr,
+ };
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('*') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx)?;
+ let expr = ast::Expression::Deref(expr);
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ Token::Operation('&') => {
+ let _ = lexer.next();
+ let expr = self.unary_expression(lexer, ctx)?;
+ let expr = ast::Expression::AddrOf(expr);
+ let span = self.peek_rule_span(lexer);
+ ctx.expressions.append(expr, span)
+ }
+ _ => self.singular_expression(lexer, ctx)?,
+ };
+
+ self.pop_rule_span(lexer);
+ Ok(expr)
+ }
+
+ /// Parse a `singular_expression`.
+ fn singular_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ let start = lexer.start_byte_offset();
+ self.push_rule_span(Rule::SingularExpr, lexer);
+ let primary_expr = self.primary_expression(lexer, ctx)?;
+ let singular_expr = self.postfix(start, lexer, ctx, primary_expr)?;
+ self.pop_rule_span(lexer);
+
+ Ok(singular_expr)
+ }
+
+ fn equality_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ context: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ // equality_expression
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('=') => Some(crate::BinaryOperator::Equal),
+ Token::LogicalOperation('!') => Some(crate::BinaryOperator::NotEqual),
+ _ => None,
+ },
+ // relational_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Paren('<') => Some(crate::BinaryOperator::Less),
+ Token::Paren('>') => Some(crate::BinaryOperator::Greater),
+ Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual),
+ Token::LogicalOperation('>') => Some(crate::BinaryOperator::GreaterEqual),
+ _ => None,
+ },
+ // shift_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::ShiftOperation('<') => {
+ Some(crate::BinaryOperator::ShiftLeft)
+ }
+ Token::ShiftOperation('>') => {
+ Some(crate::BinaryOperator::ShiftRight)
+ }
+ _ => None,
+ },
+ // additive_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('+') => Some(crate::BinaryOperator::Add),
+ Token::Operation('-') => {
+ Some(crate::BinaryOperator::Subtract)
+ }
+ _ => None,
+ },
+ // multiplicative_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('*') => {
+ Some(crate::BinaryOperator::Multiply)
+ }
+ Token::Operation('/') => {
+ Some(crate::BinaryOperator::Divide)
+ }
+ Token::Operation('%') => {
+ Some(crate::BinaryOperator::Modulo)
+ }
+ _ => None,
+ },
+ |lexer, context| self.unary_expression(lexer, context),
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ }
+
+ fn general_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Expression<'a>>, Error<'a>> {
+ self.general_expression_with_span(lexer, ctx)
+ .map(|(expr, _)| expr)
+ }
+
+ fn general_expression_with_span<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ context: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<(Handle<ast::Expression<'a>>, Span), Error<'a>> {
+ self.push_rule_span(Rule::GeneralExpr, lexer);
+ // logical_or_expression
+ let handle = context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr),
+ _ => None,
+ },
+ // logical_and_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd),
+ _ => None,
+ },
+ // inclusive_or_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('|') => Some(crate::BinaryOperator::InclusiveOr),
+ _ => None,
+ },
+ // exclusive_or_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('^') => {
+ Some(crate::BinaryOperator::ExclusiveOr)
+ }
+ _ => None,
+ },
+ // and_expression
+ |lexer, context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('&') => {
+ Some(crate::BinaryOperator::And)
+ }
+ _ => None,
+ },
+ |lexer, context| {
+ self.equality_expression(lexer, context)
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )?;
+ Ok((handle, self.pop_rule_span(lexer)))
+ }
+
+ fn variable_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<ast::GlobalVariable<'a>, Error<'a>> {
+ self.push_rule_span(Rule::VariableDecl, lexer);
+ let mut space = crate::AddressSpace::Handle;
+
+ if lexer.skip(Token::Paren('<')) {
+ let (class_str, span) = lexer.next_ident_with_span()?;
+ space = match class_str {
+ "storage" => {
+ let access = if lexer.skip(Token::Separator(',')) {
+ lexer.next_storage_access()?
+ } else {
+ // defaulting to `read`
+ crate::StorageAccess::LOAD
+ };
+ crate::AddressSpace::Storage { access }
+ }
+ _ => conv::map_address_space(class_str, span)?,
+ };
+ lexer.expect(Token::Paren('>'))?;
+ }
+ let name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(':'))?;
+ let ty = self.type_decl(lexer, ctx)?;
+
+ let init = if lexer.skip(Token::Operation('=')) {
+ let handle = self.general_expression(lexer, ctx)?;
+ Some(handle)
+ } else {
+ None
+ };
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+
+ Ok(ast::GlobalVariable {
+ name,
+ space,
+ binding: None,
+ ty,
+ init,
+ })
+ }
+
+ fn struct_body<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Vec<ast::StructMember<'a>>, Error<'a>> {
+ let mut members = Vec::new();
+
+ lexer.expect(Token::Paren('{'))?;
+ let mut ready = true;
+ while !lexer.skip(Token::Paren('}')) {
+ if !ready {
+ return Err(Error::Unexpected(
+ lexer.next().1,
+ ExpectedToken::Token(Token::Separator(',')),
+ ));
+ }
+ let (mut size, mut align) = (ParsedAttribute::default(), ParsedAttribute::default());
+ self.push_rule_span(Rule::Attribute, lexer);
+ let mut bind_parser = BindingParser::default();
+ while lexer.skip(Token::Attribute) {
+ match lexer.next_ident_with_span()? {
+ ("size", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ let expr = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Paren(')'))?;
+ size.set(expr, name_span)?;
+ }
+ ("align", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ let expr = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Paren(')'))?;
+ align.set(expr, name_span)?;
+ }
+ (word, word_span) => bind_parser.parse(self, lexer, word, word_span, ctx)?,
+ }
+ }
+
+ let bind_span = self.pop_rule_span(lexer);
+ let binding = bind_parser.finish(bind_span)?;
+
+ let name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(':'))?;
+ let ty = self.type_decl(lexer, ctx)?;
+ ready = lexer.skip(Token::Separator(','));
+
+ members.push(ast::StructMember {
+ name,
+ ty,
+ binding,
+ size: size.value,
+ align: align.value,
+ });
+ }
+
+ Ok(members)
+ }
+
+ fn matrix_scalar_type<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ ) -> Result<ast::Type<'a>, Error<'a>> {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ match scalar.kind {
+ crate::ScalarKind::Float => Ok(ast::Type::Matrix {
+ columns,
+ rows,
+ width: scalar.width,
+ }),
+ _ => Err(Error::BadMatrixScalarKind(span, scalar)),
+ }
+ }
+
+ fn type_decl_impl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ word: &'a str,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Option<ast::Type<'a>>, Error<'a>> {
+ if let Some(scalar) = conv::get_scalar_type(word) {
+ return Ok(Some(ast::Type::Scalar(scalar)));
+ }
+
+ Ok(Some(match word {
+ "vec2" => {
+ let scalar = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ scalar,
+ }
+ }
+ "vec2i" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ },
+ "vec2u" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ },
+ "vec2f" => ast::Type::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: Scalar::F32,
+ },
+ "vec3" => {
+ let scalar = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ scalar,
+ }
+ }
+ "vec3i" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ },
+ "vec3u" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ },
+ "vec3f" => ast::Type::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: Scalar::F32,
+ },
+ "vec4" => {
+ let scalar = lexer.next_scalar_generic()?;
+ ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ scalar,
+ }
+ }
+ "vec4i" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ },
+ "vec4u" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ },
+ "vec4f" => ast::Type::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: Scalar::F32,
+ },
+ "mat2x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Bi)?
+ }
+ "mat2x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat2x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Tri)?
+ }
+ "mat2x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat2x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Bi, crate::VectorSize::Quad)?
+ }
+ "mat2x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "mat3x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Bi)?
+ }
+ "mat3x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat3x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Tri)?
+ }
+ "mat3x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat3x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Tri, crate::VectorSize::Quad)?
+ }
+ "mat3x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "mat4x2" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Bi)?
+ }
+ "mat4x2f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ },
+ "mat4x3" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Tri)?
+ }
+ "mat4x3f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ },
+ "mat4x4" => {
+ self.matrix_scalar_type(lexer, crate::VectorSize::Quad, crate::VectorSize::Quad)?
+ }
+ "mat4x4f" => ast::Type::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ width: 4,
+ },
+ "atomic" => {
+ let scalar = lexer.next_scalar_generic()?;
+ ast::Type::Atomic(scalar)
+ }
+ "ptr" => {
+ lexer.expect_generic_paren('<')?;
+ let (ident, span) = lexer.next_ident_with_span()?;
+ let mut space = conv::map_address_space(ident, span)?;
+ lexer.expect(Token::Separator(','))?;
+ let base = self.type_decl(lexer, ctx)?;
+ if let crate::AddressSpace::Storage { ref mut access } = space {
+ *access = if lexer.skip(Token::Separator(',')) {
+ lexer.next_storage_access()?
+ } else {
+ crate::StorageAccess::LOAD
+ };
+ }
+ lexer.expect_generic_paren('>')?;
+ ast::Type::Pointer { base, space }
+ }
+ "array" => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx)?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let size = self.unary_expression(lexer, ctx)?;
+ ast::ArraySize::Constant(size)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ ast::Type::Array { base, size }
+ }
+ "binding_array" => {
+ lexer.expect_generic_paren('<')?;
+ let base = self.type_decl(lexer, ctx)?;
+ let size = if lexer.skip(Token::Separator(',')) {
+ let size = self.unary_expression(lexer, ctx)?;
+ ast::ArraySize::Constant(size)
+ } else {
+ ast::ArraySize::Dynamic
+ };
+ lexer.expect_generic_paren('>')?;
+
+ ast::Type::BindingArray { base, size }
+ }
+ "sampler" => ast::Type::Sampler { comparison: false },
+ "sampler_comparison" => ast::Type::Sampler { comparison: true },
+ "texture_1d" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_1d_array" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: true,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_2d" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_2d_array" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_3d" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D3,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_cube" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_cube_array" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: false,
+ },
+ }
+ }
+ "texture_multisampled_2d" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: true,
+ },
+ }
+ }
+ "texture_multisampled_2d_array" => {
+ let (scalar, span) = lexer.next_scalar_generic_with_span()?;
+ Self::check_texture_sample_type(scalar, span)?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Sampled {
+ kind: scalar.kind,
+ multi: true,
+ },
+ }
+ }
+ "texture_depth_2d" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_2d_array" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_cube" => ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_cube_array" => ast::Type::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ },
+ "texture_depth_multisampled_2d" => ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: true },
+ },
+ "texture_storage_1d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_1d_array" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: true,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_2d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_2d_array" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "texture_storage_3d" => {
+ let (format, access) = lexer.next_format_generic()?;
+ ast::Type::Image {
+ dim: crate::ImageDimension::D3,
+ arrayed: false,
+ class: crate::ImageClass::Storage { format, access },
+ }
+ }
+ "acceleration_structure" => ast::Type::AccelerationStructure,
+ "ray_query" => ast::Type::RayQuery,
+ "RayDesc" => ast::Type::RayDesc,
+ "RayIntersection" => ast::Type::RayIntersection,
+ _ => return Ok(None),
+ }))
+ }
+
+ const fn check_texture_sample_type(scalar: Scalar, span: Span) -> Result<(), Error<'static>> {
+ use crate::ScalarKind::*;
+ // Validate according to https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ match scalar {
+ Scalar {
+ kind: Float | Sint | Uint,
+ width: 4,
+ } => Ok(()),
+ _ => Err(Error::BadTextureSampleType { span, scalar }),
+ }
+ }
+
+ /// Parse type declaration of a given name.
+ fn type_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<ast::Type<'a>>, Error<'a>> {
+ self.push_rule_span(Rule::TypeDecl, lexer);
+
+ let (name, span) = lexer.next_ident_with_span()?;
+
+ let ty = match self.type_decl_impl(lexer, name, ctx)? {
+ Some(ty) => ty,
+ None => {
+ ctx.unresolved.insert(ast::Dependency {
+ ident: name,
+ usage: span,
+ });
+ ast::Type::User(ast::Ident { name, span })
+ }
+ };
+
+ self.pop_rule_span(lexer);
+
+ let handle = ctx.types.append(ty, Span::UNDEFINED);
+ Ok(handle)
+ }
+
+ fn assignment_op_and_rhs<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ block: &mut ast::Block<'a>,
+ target: Handle<ast::Expression<'a>>,
+ span_start: usize,
+ ) -> Result<(), Error<'a>> {
+ use crate::BinaryOperator as Bo;
+
+ let op = lexer.next();
+ let (op, value) = match op {
+ (Token::Operation('='), _) => {
+ let value = self.general_expression(lexer, ctx)?;
+ (None, value)
+ }
+ (Token::AssignmentOperation(c), _) => {
+ let op = match c {
+ '<' => Bo::ShiftLeft,
+ '>' => Bo::ShiftRight,
+ '+' => Bo::Add,
+ '-' => Bo::Subtract,
+ '*' => Bo::Multiply,
+ '/' => Bo::Divide,
+ '%' => Bo::Modulo,
+ '&' => Bo::And,
+ '|' => Bo::InclusiveOr,
+ '^' => Bo::ExclusiveOr,
+ // Note: `consume_token` shouldn't produce any other assignment ops
+ _ => unreachable!(),
+ };
+
+ let value = self.general_expression(lexer, ctx)?;
+ (Some(op), value)
+ }
+ token @ (Token::IncrementOperation | Token::DecrementOperation, _) => {
+ let op = match token.0 {
+ Token::IncrementOperation => ast::StatementKind::Increment,
+ Token::DecrementOperation => ast::StatementKind::Decrement,
+ _ => unreachable!(),
+ };
+
+ let span = lexer.span_from(span_start);
+ block.stmts.push(ast::Statement {
+ kind: op(target),
+ span,
+ });
+ return Ok(());
+ }
+ _ => return Err(Error::Unexpected(op.1, ExpectedToken::Assignment)),
+ };
+
+ let span = lexer.span_from(span_start);
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Assign { target, op, value },
+ span,
+ });
+ Ok(())
+ }
+
+ /// Parse an assignment statement (will also parse increment and decrement statements)
+ fn assignment_statement<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ block: &mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ let span_start = lexer.start_byte_offset();
+ let target = self.general_expression(lexer, ctx)?;
+ self.assignment_op_and_rhs(lexer, ctx, block, target, span_start)
+ }
+
+ /// Parse a function call statement.
+ /// Expects `ident` to be consumed (not in the lexer).
+ fn function_statement<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ident: &'a str,
+ ident_span: Span,
+ span_start: usize,
+ context: &mut ExpressionContext<'a, '_, '_>,
+ block: &mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ self.push_rule_span(Rule::SingularExpr, lexer);
+
+ context.unresolved.insert(ast::Dependency {
+ ident,
+ usage: ident_span,
+ });
+ let arguments = self.arguments(lexer, context)?;
+ let span = lexer.span_from(span_start);
+
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Call {
+ function: ast::Ident {
+ name: ident,
+ span: ident_span,
+ },
+ arguments,
+ },
+ span,
+ });
+
+ self.pop_rule_span(lexer);
+
+ Ok(())
+ }
+
+ fn function_call_or_assignment_statement<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ context: &mut ExpressionContext<'a, '_, '_>,
+ block: &mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ let span_start = lexer.start_byte_offset();
+ match lexer.peek() {
+ (Token::Word(name), span) => {
+ // A little hack for 2 token lookahead.
+ let cloned = lexer.clone();
+ let _ = lexer.next();
+ match lexer.peek() {
+ (Token::Paren('('), _) => {
+ self.function_statement(lexer, name, span, span_start, context, block)
+ }
+ _ => {
+ *lexer = cloned;
+ self.assignment_statement(lexer, context, block)
+ }
+ }
+ }
+ _ => self.assignment_statement(lexer, context, block),
+ }
+ }
+
+ fn statement<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ block: &mut ast::Block<'a>,
+ ) -> Result<(), Error<'a>> {
+ self.push_rule_span(Rule::Statement, lexer);
+ match lexer.peek() {
+ (Token::Separator(';'), _) => {
+ let _ = lexer.next();
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ (Token::Paren('{'), _) => {
+ let (inner, span) = self.block(lexer, ctx)?;
+ block.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(inner),
+ span,
+ });
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ (Token::Word(word), _) => {
+ let kind = match word {
+ "_" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Operation('='))?;
+ let expr = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+
+ ast::StatementKind::Ignore(expr)
+ }
+ "let" => {
+ let _ = lexer.next();
+ let name = lexer.next_ident()?;
+
+ let given_ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, ctx)?;
+ Some(ty)
+ } else {
+ None
+ };
+ lexer.expect(Token::Operation('='))?;
+ let expr_id = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+
+ let handle = ctx.declare_local(name)?;
+ ast::StatementKind::LocalDecl(ast::LocalDecl::Let(ast::Let {
+ name,
+ ty: given_ty,
+ init: expr_id,
+ handle,
+ }))
+ }
+ "var" => {
+ let _ = lexer.next();
+
+ let name = lexer.next_ident()?;
+ let ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, ctx)?;
+ Some(ty)
+ } else {
+ None
+ };
+
+ let init = if lexer.skip(Token::Operation('=')) {
+ let init = self.general_expression(lexer, ctx)?;
+ Some(init)
+ } else {
+ None
+ };
+
+ lexer.expect(Token::Separator(';'))?;
+
+ let handle = ctx.declare_local(name)?;
+ ast::StatementKind::LocalDecl(ast::LocalDecl::Var(ast::LocalVariable {
+ name,
+ ty,
+ init,
+ handle,
+ }))
+ }
+ "return" => {
+ let _ = lexer.next();
+ let value = if lexer.peek().0 != Token::Separator(';') {
+ let handle = self.general_expression(lexer, ctx)?;
+ Some(handle)
+ } else {
+ None
+ };
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Return { value }
+ }
+ "if" => {
+ let _ = lexer.next();
+ let condition = self.general_expression(lexer, ctx)?;
+
+ let accept = self.block(lexer, ctx)?.0;
+
+ let mut elsif_stack = Vec::new();
+ let mut elseif_span_start = lexer.start_byte_offset();
+ let mut reject = loop {
+ if !lexer.skip(Token::Word("else")) {
+ break ast::Block::default();
+ }
+
+ if !lexer.skip(Token::Word("if")) {
+ // ... else { ... }
+ break self.block(lexer, ctx)?.0;
+ }
+
+ // ... else if (...) { ... }
+ let other_condition = self.general_expression(lexer, ctx)?;
+ let other_block = self.block(lexer, ctx)?;
+ elsif_stack.push((elseif_span_start, other_condition, other_block));
+ elseif_span_start = lexer.start_byte_offset();
+ };
+
+ // reverse-fold the else-if blocks
+ //Note: we may consider uplifting this to the IR
+ for (other_span_start, other_cond, other_block) in
+ elsif_stack.into_iter().rev()
+ {
+ let sub_stmt = ast::StatementKind::If {
+ condition: other_cond,
+ accept: other_block.0,
+ reject,
+ };
+ reject = ast::Block::default();
+ let span = lexer.span_from(other_span_start);
+ reject.stmts.push(ast::Statement {
+ kind: sub_stmt,
+ span,
+ })
+ }
+
+ ast::StatementKind::If {
+ condition,
+ accept,
+ reject,
+ }
+ }
+ "switch" => {
+ let _ = lexer.next();
+ let selector = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Paren('{'))?;
+ let mut cases = Vec::new();
+
+ loop {
+ // cases + default
+ match lexer.next() {
+ (Token::Word("case"), _) => {
+ // parse a list of values
+ let value = loop {
+ let value = self.switch_value(lexer, ctx)?;
+ if lexer.skip(Token::Separator(',')) {
+ if lexer.skip(Token::Separator(':')) {
+ break value;
+ }
+ } else {
+ lexer.skip(Token::Separator(':'));
+ break value;
+ }
+ cases.push(ast::SwitchCase {
+ value,
+ body: ast::Block::default(),
+ fall_through: true,
+ });
+ };
+
+ let body = self.block(lexer, ctx)?.0;
+
+ cases.push(ast::SwitchCase {
+ value,
+ body,
+ fall_through: false,
+ });
+ }
+ (Token::Word("default"), _) => {
+ lexer.skip(Token::Separator(':'));
+ let body = self.block(lexer, ctx)?.0;
+ cases.push(ast::SwitchCase {
+ value: ast::SwitchValue::Default,
+ body,
+ fall_through: false,
+ });
+ }
+ (Token::Paren('}'), _) => break,
+ (_, span) => {
+ return Err(Error::Unexpected(span, ExpectedToken::SwitchItem))
+ }
+ }
+ }
+
+ ast::StatementKind::Switch { selector, cases }
+ }
+ "loop" => self.r#loop(lexer, ctx)?,
+ "while" => {
+ let _ = lexer.next();
+ let mut body = ast::Block::default();
+
+ let (condition, span) = lexer.capture_span(|lexer| {
+ let condition = self.general_expression(lexer, ctx)?;
+ Ok(condition)
+ })?;
+ let mut reject = ast::Block::default();
+ reject.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Break,
+ span,
+ });
+
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::If {
+ condition,
+ accept: ast::Block::default(),
+ reject,
+ },
+ span,
+ });
+
+ let (block, span) = self.block(lexer, ctx)?;
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(block),
+ span,
+ });
+
+ ast::StatementKind::Loop {
+ body,
+ continuing: ast::Block::default(),
+ break_if: None,
+ }
+ }
+ "for" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Paren('('))?;
+
+ ctx.local_table.push_scope();
+
+ if !lexer.skip(Token::Separator(';')) {
+ let num_statements = block.stmts.len();
+ let (_, span) = {
+ let ctx = &mut *ctx;
+ let block = &mut *block;
+ lexer.capture_span(|lexer| self.statement(lexer, ctx, block))?
+ };
+
+ if block.stmts.len() != num_statements {
+ match block.stmts.last().unwrap().kind {
+ ast::StatementKind::Call { .. }
+ | ast::StatementKind::Assign { .. }
+ | ast::StatementKind::LocalDecl(_) => {}
+ _ => return Err(Error::InvalidForInitializer(span)),
+ }
+ }
+ };
+
+ let mut body = ast::Block::default();
+ if !lexer.skip(Token::Separator(';')) {
+ let (condition, span) = lexer.capture_span(|lexer| {
+ let condition = self.general_expression(lexer, ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+ Ok(condition)
+ })?;
+ let mut reject = ast::Block::default();
+ reject.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Break,
+ span,
+ });
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::If {
+ condition,
+ accept: ast::Block::default(),
+ reject,
+ },
+ span,
+ });
+ };
+
+ let mut continuing = ast::Block::default();
+ if !lexer.skip(Token::Paren(')')) {
+ self.function_call_or_assignment_statement(
+ lexer,
+ ctx,
+ &mut continuing,
+ )?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+
+ let (block, span) = self.block(lexer, ctx)?;
+ body.stmts.push(ast::Statement {
+ kind: ast::StatementKind::Block(block),
+ span,
+ });
+
+ ctx.local_table.pop_scope();
+
+ ast::StatementKind::Loop {
+ body,
+ continuing,
+ break_if: None,
+ }
+ }
+ "break" => {
+ let (_, span) = lexer.next();
+ // Check if the next token is an `if`, this indicates
+ // that the user tried to type out a `break if` which
+ // is illegal in this position.
+ let (peeked_token, peeked_span) = lexer.peek();
+ if let Token::Word("if") = peeked_token {
+ let span = span.until(&peeked_span);
+ return Err(Error::InvalidBreakIf(span));
+ }
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Break
+ }
+ "continue" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Continue
+ }
+ "discard" => {
+ let _ = lexer.next();
+ lexer.expect(Token::Separator(';'))?;
+ ast::StatementKind::Kill
+ }
+ // assignment or a function call
+ _ => {
+ self.function_call_or_assignment_statement(lexer, ctx, block)?;
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+ return Ok(());
+ }
+ };
+
+ let span = self.pop_rule_span(lexer);
+ block.stmts.push(ast::Statement { kind, span });
+ }
+ _ => {
+ self.assignment_statement(lexer, ctx, block)?;
+ lexer.expect(Token::Separator(';'))?;
+ self.pop_rule_span(lexer);
+ }
+ }
+ Ok(())
+ }
+
+ fn r#loop<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<ast::StatementKind<'a>, Error<'a>> {
+ let _ = lexer.next();
+ let mut body = ast::Block::default();
+ let mut continuing = ast::Block::default();
+ let mut break_if = None;
+
+ lexer.expect(Token::Paren('{'))?;
+
+ ctx.local_table.push_scope();
+
+ loop {
+ if lexer.skip(Token::Word("continuing")) {
+ // Branch for the `continuing` block, this must be
+ // the last thing in the loop body
+
+ // Expect a opening brace to start the continuing block
+ lexer.expect(Token::Paren('{'))?;
+ loop {
+ if lexer.skip(Token::Word("break")) {
+ // Branch for the `break if` statement, this statement
+ // has the form `break if <expr>;` and must be the last
+ // statement in a continuing block
+
+ // The break must be followed by an `if` to form
+ // the break if
+ lexer.expect(Token::Word("if"))?;
+
+ let condition = self.general_expression(lexer, ctx)?;
+ // Set the condition of the break if to the newly parsed
+ // expression
+ break_if = Some(condition);
+
+ // Expect a semicolon to close the statement
+ lexer.expect(Token::Separator(';'))?;
+ // Expect a closing brace to close the continuing block,
+ // since the break if must be the last statement
+ lexer.expect(Token::Paren('}'))?;
+ // Stop parsing the continuing block
+ break;
+ } else if lexer.skip(Token::Paren('}')) {
+ // If we encounter a closing brace it means we have reached
+ // the end of the continuing block and should stop processing
+ break;
+ } else {
+ // Otherwise try to parse a statement
+ self.statement(lexer, ctx, &mut continuing)?;
+ }
+ }
+ // Since the continuing block must be the last part of the loop body,
+ // we expect to see a closing brace to end the loop body
+ lexer.expect(Token::Paren('}'))?;
+ break;
+ }
+ if lexer.skip(Token::Paren('}')) {
+ // If we encounter a closing brace it means we have reached
+ // the end of the loop body and should stop processing
+ break;
+ }
+ // Otherwise try to parse a statement
+ self.statement(lexer, ctx, &mut body)?;
+ }
+
+ ctx.local_table.pop_scope();
+
+ Ok(ast::StatementKind::Loop {
+ body,
+ continuing,
+ break_if,
+ })
+ }
+
+ /// compound_statement
+ fn block<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<(ast::Block<'a>, Span), Error<'a>> {
+ self.push_rule_span(Rule::Block, lexer);
+
+ ctx.local_table.push_scope();
+
+ lexer.expect(Token::Paren('{'))?;
+ let mut block = ast::Block::default();
+ while !lexer.skip(Token::Paren('}')) {
+ self.statement(lexer, ctx, &mut block)?;
+ }
+
+ ctx.local_table.pop_scope();
+
+ let span = self.pop_rule_span(lexer);
+ Ok((block, span))
+ }
+
+ fn varying_binding<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ ctx: &mut ExpressionContext<'a, '_, '_>,
+ ) -> Result<Option<ast::Binding<'a>>, Error<'a>> {
+ let mut bind_parser = BindingParser::default();
+ self.push_rule_span(Rule::Attribute, lexer);
+
+ while lexer.skip(Token::Attribute) {
+ let (word, span) = lexer.next_ident_with_span()?;
+ bind_parser.parse(self, lexer, word, span, ctx)?;
+ }
+
+ let span = self.pop_rule_span(lexer);
+ bind_parser.finish(span)
+ }
+
+ fn function_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ out: &mut ast::TranslationUnit<'a>,
+ dependencies: &mut FastIndexSet<ast::Dependency<'a>>,
+ ) -> Result<ast::Function<'a>, Error<'a>> {
+ self.push_rule_span(Rule::FunctionDecl, lexer);
+ // read function name
+ let fun_name = lexer.next_ident()?;
+
+ let mut locals = Arena::new();
+
+ let mut ctx = ExpressionContext {
+ expressions: &mut out.expressions,
+ local_table: &mut SymbolTable::default(),
+ locals: &mut locals,
+ types: &mut out.types,
+ unresolved: dependencies,
+ };
+
+ // start a scope that contains arguments as well as the function body
+ ctx.local_table.push_scope();
+
+ // read parameter list
+ let mut arguments = Vec::new();
+ lexer.expect(Token::Paren('('))?;
+ let mut ready = true;
+ while !lexer.skip(Token::Paren(')')) {
+ if !ready {
+ return Err(Error::Unexpected(
+ lexer.next().1,
+ ExpectedToken::Token(Token::Separator(',')),
+ ));
+ }
+ let binding = self.varying_binding(lexer, &mut ctx)?;
+
+ let param_name = lexer.next_ident()?;
+
+ lexer.expect(Token::Separator(':'))?;
+ let param_type = self.type_decl(lexer, &mut ctx)?;
+
+ let handle = ctx.declare_local(param_name)?;
+ arguments.push(ast::FunctionArgument {
+ name: param_name,
+ ty: param_type,
+ binding,
+ handle,
+ });
+ ready = lexer.skip(Token::Separator(','));
+ }
+ // read return type
+ let result = if lexer.skip(Token::Arrow) && !lexer.skip(Token::Word("void")) {
+ let binding = self.varying_binding(lexer, &mut ctx)?;
+ let ty = self.type_decl(lexer, &mut ctx)?;
+ Some(ast::FunctionResult { ty, binding })
+ } else {
+ None
+ };
+
+ // do not use `self.block` here, since we must not push a new scope
+ lexer.expect(Token::Paren('{'))?;
+ let mut body = ast::Block::default();
+ while !lexer.skip(Token::Paren('}')) {
+ self.statement(lexer, &mut ctx, &mut body)?;
+ }
+
+ ctx.local_table.pop_scope();
+
+ let fun = ast::Function {
+ entry_point: None,
+ name: fun_name,
+ arguments,
+ result,
+ body,
+ locals,
+ };
+
+ // done
+ self.pop_rule_span(lexer);
+
+ Ok(fun)
+ }
+
+ fn global_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ out: &mut ast::TranslationUnit<'a>,
+ ) -> Result<(), Error<'a>> {
+ // read attributes
+ let mut binding = None;
+ let mut stage = ParsedAttribute::default();
+ let mut compute_span = Span::new(0, 0);
+ let mut workgroup_size = ParsedAttribute::default();
+ let mut early_depth_test = ParsedAttribute::default();
+ let (mut bind_index, mut bind_group) =
+ (ParsedAttribute::default(), ParsedAttribute::default());
+
+ let mut dependencies = FastIndexSet::default();
+ let mut ctx = ExpressionContext {
+ expressions: &mut out.expressions,
+ local_table: &mut SymbolTable::default(),
+ locals: &mut Arena::new(),
+ types: &mut out.types,
+ unresolved: &mut dependencies,
+ };
+
+ self.push_rule_span(Rule::Attribute, lexer);
+ while lexer.skip(Token::Attribute) {
+ match lexer.next_ident_with_span()? {
+ ("binding", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ bind_index.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+ ("group", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ }
+ ("vertex", name_span) => {
+ stage.set(crate::ShaderStage::Vertex, name_span)?;
+ }
+ ("fragment", name_span) => {
+ stage.set(crate::ShaderStage::Fragment, name_span)?;
+ }
+ ("compute", name_span) => {
+ stage.set(crate::ShaderStage::Compute, name_span)?;
+ compute_span = name_span;
+ }
+ ("workgroup_size", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ let mut new_workgroup_size = [None; 3];
+ for (i, size) in new_workgroup_size.iter_mut().enumerate() {
+ *size = Some(self.general_expression(lexer, &mut ctx)?);
+ match lexer.next() {
+ (Token::Paren(')'), _) => break,
+ (Token::Separator(','), _) if i != 2 => (),
+ other => {
+ return Err(Error::Unexpected(
+ other.1,
+ ExpectedToken::WorkgroupSizeSeparator,
+ ))
+ }
+ }
+ }
+ workgroup_size.set(new_workgroup_size, name_span)?;
+ }
+ ("early_depth_test", name_span) => {
+ let conservative = if lexer.skip(Token::Paren('(')) {
+ let (ident, ident_span) = lexer.next_ident_with_span()?;
+ let value = conv::map_conservative_depth(ident, ident_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(value)
+ } else {
+ None
+ };
+ early_depth_test.set(crate::EarlyDepthTest { conservative }, name_span)?;
+ }
+ (_, word_span) => return Err(Error::UnknownAttribute(word_span)),
+ }
+ }
+
+ let attrib_span = self.pop_rule_span(lexer);
+ match (bind_group.value, bind_index.value) {
+ (Some(group), Some(index)) => {
+ binding = Some(ast::ResourceBinding {
+ group,
+ binding: index,
+ });
+ }
+ (Some(_), None) => return Err(Error::MissingAttribute("binding", attrib_span)),
+ (None, Some(_)) => return Err(Error::MissingAttribute("group", attrib_span)),
+ (None, None) => {}
+ }
+
+ // read item
+ let start = lexer.start_byte_offset();
+ let kind = match lexer.next() {
+ (Token::Separator(';'), _) => None,
+ (Token::Word("struct"), _) => {
+ let name = lexer.next_ident()?;
+
+ let members = self.struct_body(lexer, &mut ctx)?;
+ Some(ast::GlobalDeclKind::Struct(ast::Struct { name, members }))
+ }
+ (Token::Word("alias"), _) => {
+ let name = lexer.next_ident()?;
+
+ lexer.expect(Token::Operation('='))?;
+ let ty = self.type_decl(lexer, &mut ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+ Some(ast::GlobalDeclKind::Type(ast::TypeAlias { name, ty }))
+ }
+ (Token::Word("const"), _) => {
+ let name = lexer.next_ident()?;
+
+ let ty = if lexer.skip(Token::Separator(':')) {
+ let ty = self.type_decl(lexer, &mut ctx)?;
+ Some(ty)
+ } else {
+ None
+ };
+
+ lexer.expect(Token::Operation('='))?;
+ let init = self.general_expression(lexer, &mut ctx)?;
+ lexer.expect(Token::Separator(';'))?;
+
+ Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init }))
+ }
+ (Token::Word("var"), _) => {
+ let mut var = self.variable_decl(lexer, &mut ctx)?;
+ var.binding = binding.take();
+ Some(ast::GlobalDeclKind::Var(var))
+ }
+ (Token::Word("fn"), _) => {
+ let function = self.function_decl(lexer, out, &mut dependencies)?;
+ Some(ast::GlobalDeclKind::Fn(ast::Function {
+ entry_point: if let Some(stage) = stage.value {
+ if stage == ShaderStage::Compute && workgroup_size.value.is_none() {
+ return Err(Error::MissingWorkgroupSize(compute_span));
+ }
+ Some(ast::EntryPoint {
+ stage,
+ early_depth_test: early_depth_test.value,
+ workgroup_size: workgroup_size.value,
+ })
+ } else {
+ None
+ },
+ ..function
+ }))
+ }
+ (Token::End, _) => return Ok(()),
+ other => return Err(Error::Unexpected(other.1, ExpectedToken::GlobalItem)),
+ };
+
+ if let Some(kind) = kind {
+ out.decls.append(
+ ast::GlobalDecl { kind, dependencies },
+ lexer.span_from(start),
+ );
+ }
+
+ if !self.rules.is_empty() {
+ log::error!("Reached the end of global decl, but rule stack is not empty");
+ log::error!("Rules: {:?}", self.rules);
+ return Err(Error::Internal("rule stack is not empty"));
+ };
+
+ match binding {
+ None => Ok(()),
+ Some(_) => Err(Error::Internal("we had the attribute but no var?")),
+ }
+ }
+
+ pub fn parse<'a>(&mut self, source: &'a str) -> Result<ast::TranslationUnit<'a>, Error<'a>> {
+ self.reset();
+
+ let mut lexer = Lexer::new(source);
+ let mut tu = ast::TranslationUnit::default();
+ loop {
+ match self.global_decl(&mut lexer, &mut tu) {
+ Err(error) => return Err(error),
+ Ok(()) => {
+ if lexer.peek().0 == Token::End {
+ break;
+ }
+ }
+ }
+ }
+
+ Ok(tu)
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/number.rs b/third_party/rust/naga/src/front/wgsl/parse/number.rs
new file mode 100644
index 0000000000..7b09ac59bb
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/parse/number.rs
@@ -0,0 +1,420 @@
+use crate::front::wgsl::error::NumberError;
+use crate::front::wgsl::parse::lexer::Token;
+
+/// When using this type assume no Abstract Int/Float for now
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum Number {
+ /// Abstract Int (-2^63 ≤ i < 2^63)
+ AbstractInt(i64),
+ /// Abstract Float (IEEE-754 binary64)
+ AbstractFloat(f64),
+ /// Concrete i32
+ I32(i32),
+ /// Concrete u32
+ U32(u32),
+ /// Concrete f32
+ F32(f32),
+ /// Concrete f64
+ F64(f64),
+}
+
+pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) {
+ let (result, rest) = parse(input);
+ (Token::Number(result), rest)
+}
+
+enum Kind {
+ Int(IntKind),
+ Float(FloatKind),
+}
+
+enum IntKind {
+ I32,
+ U32,
+}
+
+#[derive(Debug)]
+enum FloatKind {
+ F16,
+ F32,
+ F64,
+}
+
+// The following regexes (from the WGSL spec) will be matched:
+
+// int_literal:
+// | / 0 [iu]? /
+// | / [1-9][0-9]* [iu]? /
+// | / 0[xX][0-9a-fA-F]+ [iu]? /
+
+// decimal_float_literal:
+// | / 0 [fh] /
+// | / [1-9][0-9]* [fh] /
+// | / [0-9]* \.[0-9]+ ([eE][+-]?[0-9]+)? [fh]? /
+// | / [0-9]+ \.[0-9]* ([eE][+-]?[0-9]+)? [fh]? /
+// | / [0-9]+ [eE][+-]?[0-9]+ [fh]? /
+
+// hex_float_literal:
+// | / 0[xX][0-9a-fA-F]* \.[0-9a-fA-F]+ ([pP][+-]?[0-9]+ [fh]?)? /
+// | / 0[xX][0-9a-fA-F]+ \.[0-9a-fA-F]* ([pP][+-]?[0-9]+ [fh]?)? /
+// | / 0[xX][0-9a-fA-F]+ [pP][+-]?[0-9]+ [fh]? /
+
+// You could visualize the regex below via https://debuggex.com to get a rough idea what `parse` is doing
+// (?:0[xX](?:([0-9a-fA-F]+\.[0-9a-fA-F]*|[0-9a-fA-F]*\.[0-9a-fA-F]+)(?:([pP][+-]?[0-9]+)([fh]?))?|([0-9a-fA-F]+)([pP][+-]?[0-9]+)([fh]?)|([0-9a-fA-F]+)([iu]?))|((?:[0-9]+[eE][+-]?[0-9]+|(?:[0-9]+\.[0-9]*|[0-9]*\.[0-9]+)(?:[eE][+-]?[0-9]+)?))([fh]?)|((?:[0-9]|[1-9][0-9]+))([iufh]?))
+
+// Leading signs are handled as unary operators.
+
+fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
+ /// returns `true` and consumes `X` bytes from the given byte buffer
+ /// if the given `X` nr of patterns are found at the start of the buffer
+ macro_rules! consume {
+ ($bytes:ident, $($pattern:pat),*) => {
+ match $bytes {
+ &[$($pattern),*, ref rest @ ..] => { $bytes = rest; true },
+ _ => false,
+ }
+ };
+ }
+
+ /// consumes one byte from the given byte buffer
+ /// if one of the given patterns are found at the start of the buffer
+ /// returning the corresponding expr for the matched pattern
+ macro_rules! consume_map {
+ ($bytes:ident, [$( $($pattern:pat_param),* => $to:expr),* $(,)?]) => {
+ match $bytes {
+ $( &[ $($pattern),*, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
+ _ => None,
+ }
+ };
+ }
+
+ /// consumes all consecutive bytes matched by the `0-9` pattern from the given byte buffer
+ /// returning the number of consumed bytes
+ macro_rules! consume_dec_digits {
+ ($bytes:ident) => {{
+ let start_len = $bytes.len();
+ while let &[b'0'..=b'9', ref rest @ ..] = $bytes {
+ $bytes = rest;
+ }
+ start_len - $bytes.len()
+ }};
+ }
+
+ /// consumes all consecutive bytes matched by the `0-9 | a-f | A-F` pattern from the given byte buffer
+ /// returning the number of consumed bytes
+ macro_rules! consume_hex_digits {
+ ($bytes:ident) => {{
+ let start_len = $bytes.len();
+ while let &[b'0'..=b'9' | b'a'..=b'f' | b'A'..=b'F', ref rest @ ..] = $bytes {
+ $bytes = rest;
+ }
+ start_len - $bytes.len()
+ }};
+ }
+
+ macro_rules! consume_float_suffix {
+ ($bytes:ident) => {
+ consume_map!($bytes, [
+ b'h' => FloatKind::F16,
+ b'f' => FloatKind::F32,
+ b'l', b'f' => FloatKind::F64,
+ ])
+ };
+ }
+
+ /// maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str`
+ macro_rules! rest_to_str {
+ ($bytes:ident) => {
+ &input[input.len() - $bytes.len()..]
+ };
+ }
+
+ struct ExtractSubStr<'a>(&'a str);
+
+ impl<'a> ExtractSubStr<'a> {
+ /// given an `input` and a `start` (tail of the `input`)
+ /// creates a new [`ExtractSubStr`](`Self`)
+ fn start(input: &'a str, start: &'a [u8]) -> Self {
+ let start = input.len() - start.len();
+ Self(&input[start..])
+ }
+ /// given an `end` (tail of the initial `input`)
+ /// returns a substring of `input`
+ fn end(&self, end: &'a [u8]) -> &'a str {
+ let end = self.0.len() - end.len();
+ &self.0[..end]
+ }
+ }
+
+ let mut bytes = input.as_bytes();
+
+ let general_extract = ExtractSubStr::start(input, bytes);
+
+ if consume!(bytes, b'0', b'x' | b'X') {
+ let digits_extract = ExtractSubStr::start(input, bytes);
+
+ let consumed = consume_hex_digits!(bytes);
+
+ if consume!(bytes, b'.') {
+ let consumed_after_period = consume_hex_digits!(bytes);
+
+ if consumed + consumed_after_period == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let significand = general_extract.end(bytes);
+
+ if consume!(bytes, b'p' | b'P') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_float_suffix!(bytes);
+
+ (parse_hex_float(number, kind), rest_to_str!(bytes))
+ } else {
+ (
+ parse_hex_float_missing_exponent(significand, None),
+ rest_to_str!(bytes),
+ )
+ }
+ } else {
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let significand = general_extract.end(bytes);
+ let digits = digits_extract.end(bytes);
+
+ let exp_extract = ExtractSubStr::start(input, bytes);
+
+ if consume!(bytes, b'p' | b'P') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let exponent = exp_extract.end(bytes);
+
+ let kind = consume_float_suffix!(bytes);
+
+ (
+ parse_hex_float_missing_period(significand, exponent, kind),
+ rest_to_str!(bytes),
+ )
+ } else {
+ let kind = consume_map!(bytes, [b'i' => IntKind::I32, b'u' => IntKind::U32]);
+
+ (parse_hex_int(digits, kind), rest_to_str!(bytes))
+ }
+ }
+ } else {
+ let is_first_zero = bytes.first() == Some(&b'0');
+
+ let consumed = consume_dec_digits!(bytes);
+
+ if consume!(bytes, b'.') {
+ let consumed_after_period = consume_dec_digits!(bytes);
+
+ if consumed + consumed_after_period == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ if consume!(bytes, b'e' | b'E') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_float_suffix!(bytes);
+
+ (parse_dec_float(number, kind), rest_to_str!(bytes))
+ } else {
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ if consume!(bytes, b'e' | b'E') {
+ consume!(bytes, b'+' | b'-');
+ let consumed = consume_dec_digits!(bytes);
+
+ if consumed == 0 {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let number = general_extract.end(bytes);
+
+ let kind = consume_float_suffix!(bytes);
+
+ (parse_dec_float(number, kind), rest_to_str!(bytes))
+ } else {
+ // make sure the multi-digit numbers don't start with zero
+ if consumed > 1 && is_first_zero {
+ return (Err(NumberError::Invalid), rest_to_str!(bytes));
+ }
+
+ let digits = general_extract.end(bytes);
+
+ let kind = consume_map!(bytes, [
+ b'i' => Kind::Int(IntKind::I32),
+ b'u' => Kind::Int(IntKind::U32),
+ b'h' => Kind::Float(FloatKind::F16),
+ b'f' => Kind::Float(FloatKind::F32),
+ b'l', b'f' => Kind::Float(FloatKind::F64),
+ ]);
+
+ (parse_dec(digits, kind), rest_to_str!(bytes))
+ }
+ }
+ }
+}
+
+fn parse_hex_float_missing_exponent(
+ // format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ )
+ significand: &str,
+ kind: Option<FloatKind>,
+) -> Result<Number, NumberError> {
+ let hexf_input = format!("{}{}", significand, "p0");
+ parse_hex_float(&hexf_input, kind)
+}
+
+fn parse_hex_float_missing_period(
+ // format: 0[xX] [0-9a-fA-F]+
+ significand: &str,
+ // format: [pP][+-]?[0-9]+
+ exponent: &str,
+ kind: Option<FloatKind>,
+) -> Result<Number, NumberError> {
+ let hexf_input = format!("{significand}.{exponent}");
+ parse_hex_float(&hexf_input, kind)
+}
+
+fn parse_hex_int(
+ // format: [0-9a-fA-F]+
+ digits: &str,
+ kind: Option<IntKind>,
+) -> Result<Number, NumberError> {
+ parse_int(digits, kind, 16)
+}
+
+fn parse_dec(
+ // format: ( [0-9] | [1-9][0-9]+ )
+ digits: &str,
+ kind: Option<Kind>,
+) -> Result<Number, NumberError> {
+ match kind {
+ None => parse_int(digits, None, 10),
+ Some(Kind::Int(kind)) => parse_int(digits, Some(kind), 10),
+ Some(Kind::Float(kind)) => parse_dec_float(digits, Some(kind)),
+ }
+}
+
+// Float parsing notes
+
+// The following chapters of IEEE 754-2019 are relevant:
+//
+// 7.4 Overflow (largest finite number is exceeded by what would have been
+// the rounded floating-point result were the exponent range unbounded)
+//
+// 7.5 Underflow (tiny non-zero result is detected;
+// for decimal formats tininess is detected before rounding when a non-zero result
+// computed as though both the exponent range and the precision were unbounded
+// would lie strictly between 2^−126)
+//
+// 7.6 Inexact (rounded result differs from what would have been computed
+// were both exponent range and precision unbounded)
+
+// The WGSL spec requires us to error:
+// on overflow for decimal floating point literals
+// on overflow and inexact for hexadecimal floating point literals
+// (underflow is not mentioned)
+
+// hexf_parse errors on overflow, underflow, inexact
+// rust std lib float from str handles overflow, underflow, inexact transparently (rounds and will not error)
+
+// Therefore we only check for overflow manually for decimal floating point literals
+
+// input format: 0[xX] ( [0-9a-fA-F]+\.[0-9a-fA-F]* | [0-9a-fA-F]*\.[0-9a-fA-F]+ ) [pP][+-]?[0-9]+
+fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
+ match kind {
+ None => match hexf_parse::parse_hexf64(input, false) {
+ Ok(num) => Ok(Number::AbstractFloat(num)),
+ // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
+ _ => Err(NumberError::NotRepresentable),
+ },
+ Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
+ Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
+ Ok(num) => Ok(Number::F32(num)),
+ // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
+ _ => Err(NumberError::NotRepresentable),
+ },
+ Some(FloatKind::F64) => match hexf_parse::parse_hexf64(input, false) {
+ Ok(num) => Ok(Number::F64(num)),
+ // can only be ParseHexfErrorKind::Inexact but we can't check since it's private
+ _ => Err(NumberError::NotRepresentable),
+ },
+ }
+}
+
+// input format: ( [0-9]+\.[0-9]* | [0-9]*\.[0-9]+ ) ([eE][+-]?[0-9]+)?
+// | [0-9]+ [eE][+-]?[0-9]+
+fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, NumberError> {
+ match kind {
+ None => {
+ let num = input.parse::<f64>().unwrap(); // will never fail
+ num.is_finite()
+ .then_some(Number::AbstractFloat(num))
+ .ok_or(NumberError::NotRepresentable)
+ }
+ Some(FloatKind::F32) => {
+ let num = input.parse::<f32>().unwrap(); // will never fail
+ num.is_finite()
+ .then_some(Number::F32(num))
+ .ok_or(NumberError::NotRepresentable)
+ }
+ Some(FloatKind::F64) => {
+ let num = input.parse::<f64>().unwrap(); // will never fail
+ num.is_finite()
+ .then_some(Number::F64(num))
+ .ok_or(NumberError::NotRepresentable)
+ }
+ Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
+ }
+}
+
+fn parse_int(input: &str, kind: Option<IntKind>, radix: u32) -> Result<Number, NumberError> {
+ fn map_err(e: core::num::ParseIntError) -> NumberError {
+ match *e.kind() {
+ core::num::IntErrorKind::PosOverflow | core::num::IntErrorKind::NegOverflow => {
+ NumberError::NotRepresentable
+ }
+ _ => unreachable!(),
+ }
+ }
+ match kind {
+ None => match i64::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::AbstractInt(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ Some(IntKind::I32) => match i32::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::I32(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ Some(IntKind::U32) => match u32::from_str_radix(input, radix) {
+ Ok(num) => Ok(Number::U32(num)),
+ Err(e) => Err(map_err(e)),
+ },
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/tests.rs b/third_party/rust/naga/src/front/wgsl/tests.rs
new file mode 100644
index 0000000000..eb2f8a2eb3
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/tests.rs
@@ -0,0 +1,637 @@
+use super::parse_str;
+
+#[test]
+fn parse_comment() {
+ parse_str(
+ "//
+ ////
+ ///////////////////////////////////////////////////////// asda
+ //////////////////// dad ////////// /
+ /////////////////////////////////////////////////////////////////////////////////////////////////////
+ //
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_types() {
+ parse_str("const a : i32 = 2;").unwrap();
+ assert!(parse_str("const a : x32 = 2;").is_err());
+ parse_str("var t: texture_2d<f32>;").unwrap();
+ parse_str("var t: texture_cube_array<i32>;").unwrap();
+ parse_str("var t: texture_multisampled_2d<u32>;").unwrap();
+ parse_str("var t: texture_storage_1d<rgba8uint,write>;").unwrap();
+ parse_str("var t: texture_storage_3d<r32float,read>;").unwrap();
+}
+
+#[test]
+fn parse_type_inference() {
+ parse_str(
+ "
+ fn foo() {
+ let a = 2u;
+ let b: u32 = a;
+ var x = 3.;
+ var y = vec2<f32>(1, 2);
+ }",
+ )
+ .unwrap();
+ assert!(parse_str(
+ "
+ fn foo() { let c : i32 = 2.0; }",
+ )
+ .is_err());
+}
+
+#[test]
+fn parse_type_cast() {
+ parse_str(
+ "
+ const a : i32 = 2;
+ fn main() {
+ var x: f32 = f32(a);
+ x = f32(i32(a + 1) / 2);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(1.0, 2.0);
+ let y: vec2<u32> = vec2<u32>(x);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(0.0);
+ }
+ ",
+ )
+ .unwrap();
+ assert!(parse_str(
+ "
+ fn main() {
+ let x: vec2<f32> = vec2<f32>(0i, 0i);
+ }
+ ",
+ )
+ .is_err());
+}
+
+#[test]
+fn parse_struct() {
+ parse_str(
+ "
+ struct Foo { x: i32 }
+ struct Bar {
+ @size(16) x: vec2<i32>,
+ @align(16) y: f32,
+ @size(32) @align(128) z: vec3<f32>,
+ };
+ struct Empty {}
+ var<storage,read_write> s: Foo;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_standard_fun() {
+ parse_str(
+ "
+ fn main() {
+ var x: i32 = min(max(1, 2), 3);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_statement() {
+ parse_str(
+ "
+ fn main() {
+ ;
+ {}
+ {;}
+ }
+ ",
+ )
+ .unwrap();
+
+ parse_str(
+ "
+ fn foo() {}
+ fn bar() { foo(); }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_if() {
+ parse_str(
+ "
+ fn main() {
+ if true {
+ discard;
+ } else {}
+ if 0 != 1 {}
+ if false {
+ return;
+ } else if true {
+ return;
+ } else {}
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_parentheses_if() {
+ parse_str(
+ "
+ fn main() {
+ if (true) {
+ discard;
+ } else {}
+ if (0 != 1) {}
+ if (false) {
+ return;
+ } else if (true) {
+ return;
+ } else {}
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_loop() {
+ parse_str(
+ "
+ fn main() {
+ var i: i32 = 0;
+ loop {
+ if i == 1 { break; }
+ continuing { i = 1; }
+ }
+ loop {
+ if i == 0 { continue; }
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ var found: bool = false;
+ var i: i32 = 0;
+ while !found {
+ if i == 10 {
+ found = true;
+ }
+
+ i = i + 1;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ while true {
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ var a: i32 = 0;
+ for(var i: i32 = 0; i < 4; i = i + 1) {
+ a = a + 2;
+ }
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ fn main() {
+ for(;;) {
+ break;
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1: { pos = 0.0; }
+ case 2: { pos = 1.0; }
+ default: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch_optional_colon_in_case() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1 { pos = 0.0; }
+ case 2 { pos = 1.0; }
+ default { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_switch_default_in_case() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch (3) {
+ case 0, 1: { pos = 0.0; }
+ case 2: {}
+ case default, 3: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_parentheses_switch() {
+ parse_str(
+ "
+ fn main() {
+ var pos: f32;
+ switch pos > 1.0 {
+ default: { pos = 3.0; }
+ }
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_load() {
+ parse_str(
+ "
+ var t: texture_3d<u32>;
+ fn foo() {
+ let r: vec4<u32> = textureLoad(t, vec3<u32>(0u, 1u, 2u), 1);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ var t: texture_multisampled_2d_array<i32>;
+ fn foo() {
+ let r: vec4<i32> = textureLoad(t, vec2<i32>(10, 20), 2, 3);
+ }
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ var t: texture_storage_1d_array<r32float,read>;
+ fn foo() {
+ let r: vec4<f32> = textureLoad(t, 10, 2);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_store() {
+ parse_str(
+ "
+ var t: texture_storage_2d<rgba8unorm,write>;
+ fn foo() {
+ textureStore(t, vec2<i32>(10, 20), vec4<f32>(0.0, 1.0, 2.0, 3.0));
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_query() {
+ parse_str(
+ "
+ var t: texture_multisampled_2d_array<f32>;
+ fn foo() {
+ var dim: vec2<u32> = textureDimensions(t);
+ dim = textureDimensions(t, 0);
+ let layers: u32 = textureNumLayers(t);
+ let samples: u32 = textureNumSamples(t);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_postfix() {
+ parse_str(
+ "fn foo() {
+ let x: f32 = vec4<f32>(1.0, 2.0, 3.0, 4.0).xyz.rgbr.aaaa.wz.g;
+ let y: f32 = fract(vec2<f32>(0.5, x)).x;
+ }",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_expressions() {
+ parse_str("fn foo() {
+ let x: f32 = select(0.0, 1.0, true);
+ let y: vec2<f32> = select(vec2<f32>(1.0, 1.0), vec2<f32>(x, x), vec2<bool>(x < 0.5, x > 0.5));
+ let z: bool = !(0.0 == 1.0);
+ }").unwrap();
+}
+
+#[test]
+fn binary_expression_mixed_scalar_and_vector_operands() {
+ for (operand, expect_splat) in [
+ ('<', false),
+ ('>', false),
+ ('&', false),
+ ('|', false),
+ ('+', true),
+ ('-', true),
+ ('*', false),
+ ('/', true),
+ ('%', true),
+ ] {
+ let module = parse_str(&format!(
+ "
+ @fragment
+ fn main(@location(0) some_vec: vec3<f32>) -> @location(0) vec4<f32> {{
+ if (all(1.0 {operand} some_vec)) {{
+ return vec4(0.0);
+ }}
+ return vec4(1.0);
+ }}
+ "
+ ))
+ .unwrap();
+
+ let expressions = &&module.entry_points[0].function.expressions;
+
+ let found_expressions = expressions
+ .iter()
+ .filter(|&(_, e)| {
+ if let crate::Expression::Binary { left, .. } = *e {
+ matches!(
+ (expect_splat, &expressions[left]),
+ (false, &crate::Expression::Literal(crate::Literal::F32(..)))
+ | (true, &crate::Expression::Splat { .. })
+ )
+ } else {
+ false
+ }
+ })
+ .count();
+
+ assert_eq!(
+ found_expressions,
+ 1,
+ "expected `{operand}` expression {} splat",
+ if expect_splat { "with" } else { "without" }
+ );
+ }
+
+ let module = parse_str(
+ "@fragment
+ fn main(mat: mat3x3<f32>) {
+ let vec = vec3<f32>(1.0, 1.0, 1.0);
+ let result = mat / vec;
+ }",
+ )
+ .unwrap();
+ let expressions = &&module.entry_points[0].function.expressions;
+ let found_splat = expressions.iter().any(|(_, e)| {
+ if let crate::Expression::Binary { left, .. } = *e {
+ matches!(&expressions[left], &crate::Expression::Splat { .. })
+ } else {
+ false
+ }
+ });
+ assert!(!found_splat, "'mat / vec' should not be splatted");
+}
+
+#[test]
+fn parse_pointers() {
+ parse_str(
+ "fn foo(a: ptr<private, f32>) -> f32 { return *a; }
+ fn bar() {
+ var x: f32 = 1.0;
+ let px = &x;
+ let py = foo(px);
+ }",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_struct_instantiation() {
+ parse_str(
+ "
+ struct Foo {
+ a: f32,
+ b: vec3<f32>,
+ }
+
+ @fragment
+ fn fs_main() {
+ var foo: Foo = Foo(0.0, vec3<f32>(0.0, 1.0, 42.0));
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_array_length() {
+ parse_str(
+ "
+ struct Foo {
+ data: array<u32>
+ } // this is used as both input and output for convenience
+
+ @group(0) @binding(0)
+ var<storage> foo: Foo;
+
+ @group(0) @binding(1)
+ var<storage> bar: array<u32>;
+
+ fn baz() {
+ var x: u32 = arrayLength(foo.data);
+ var y: u32 = arrayLength(bar);
+ }
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_storage_buffers() {
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,read> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,write> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+ parse_str(
+ "
+ @group(0) @binding(0)
+ var<storage,read_write> foo: array<u32>;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_alias() {
+ parse_str(
+ "
+ alias Vec4 = vec4<f32>;
+ ",
+ )
+ .unwrap();
+}
+
+#[test]
+fn parse_texture_load_store_expecting_four_args() {
+ for (func, texture) in [
+ (
+ "textureStore",
+ "texture_storage_2d_array<rg11b10float, write>",
+ ),
+ ("textureLoad", "texture_2d_array<i32>"),
+ ] {
+ let error = parse_str(&format!(
+ "
+ @group(0) @binding(0) var tex_los_res: {texture};
+ @compute
+ @workgroup_size(1)
+ fn main(@builtin(global_invocation_id) id: vec3<u32>) {{
+ var color = vec4(1, 1, 1, 1);
+ {func}(tex_los_res, id, color);
+ }}
+ "
+ ))
+ .unwrap_err();
+ assert_eq!(
+ error.message(),
+ "wrong number of arguments: expected 4, found 3"
+ );
+ }
+}
+
+#[test]
+fn parse_repeated_attributes() {
+ use crate::{
+ front::wgsl::{error::Error, Frontend},
+ Span,
+ };
+
+ let template_vs = "@vertex fn vs() -> __REPLACE__ vec4<f32> { return vec4<f32>(0.0); }";
+ let template_struct = "struct A { __REPLACE__ data: vec3<f32> }";
+ let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array<i32>;";
+ let template_stage = "__REPLACE__ fn vs() -> vec4<f32> { return vec4<f32>(0.0); }";
+ for (attribute, template) in [
+ ("align(16)", template_struct),
+ ("binding(0)", template_resource),
+ ("builtin(position)", template_vs),
+ ("compute", template_stage),
+ ("fragment", template_stage),
+ ("group(0)", template_resource),
+ ("interpolate(flat)", template_vs),
+ ("invariant", template_vs),
+ ("location(0)", template_vs),
+ ("size(16)", template_struct),
+ ("vertex", template_stage),
+ ("early_depth_test(less_equal)", template_resource),
+ ("workgroup_size(1)", template_stage),
+ ] {
+ let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}"));
+ let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32;
+ let span_start = shader.rfind(attribute).unwrap() as u32;
+ let span_end = span_start + name_length;
+ let expected_span = Span::new(span_start, span_end);
+
+ let result = Frontend::new().inner(&shader);
+ assert!(matches!(
+ result.unwrap_err(),
+ Error::RepeatedAttribute(span) if span == expected_span
+ ));
+ }
+}
+
+#[test]
+fn parse_missing_workgroup_size() {
+ use crate::{
+ front::wgsl::{error::Error, Frontend},
+ Span,
+ };
+
+ let shader = "@compute fn vs() -> vec4<f32> { return vec4<f32>(0.0); }";
+ let result = Frontend::new().inner(shader);
+ assert!(matches!(
+ result.unwrap_err(),
+ Error::MissingWorkgroupSize(span) if span == Span::new(1, 8)
+ ));
+}
diff --git a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
new file mode 100644
index 0000000000..c8331ace09
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
@@ -0,0 +1,283 @@
+//! Producing the WGSL forms of types, for use in error messages.
+
+use crate::proc::GlobalCtx;
+use crate::Handle;
+
+impl crate::proc::TypeResolution {
+ pub fn to_wgsl(&self, gctx: &GlobalCtx) -> String {
+ match *self {
+ crate::proc::TypeResolution::Handle(handle) => handle.to_wgsl(gctx),
+ crate::proc::TypeResolution::Value(ref inner) => inner.to_wgsl(gctx),
+ }
+ }
+}
+
+impl Handle<crate::Type> {
+ /// Formats the type as it is written in wgsl.
+ ///
+ /// For example `vec3<f32>`.
+ pub fn to_wgsl(self, gctx: &GlobalCtx) -> String {
+ let ty = &gctx.types[self];
+ match ty.name {
+ Some(ref name) => name.clone(),
+ None => ty.inner.to_wgsl(gctx),
+ }
+ }
+}
+
+impl crate::TypeInner {
+ /// Formats the type as it is written in wgsl.
+ ///
+ /// For example `vec3<f32>`.
+ ///
+ /// Note: `TypeInner::Struct` doesn't include the name of the
+ /// struct type. Therefore this method will simply return "struct"
+ /// for them.
+ pub fn to_wgsl(&self, gctx: &GlobalCtx) -> String {
+ use crate::TypeInner as Ti;
+
+ match *self {
+ Ti::Scalar(scalar) => scalar.to_wgsl(),
+ Ti::Vector { size, scalar } => {
+ format!("vec{}<{}>", size as u32, scalar.to_wgsl())
+ }
+ Ti::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ format!(
+ "mat{}x{}<{}>",
+ columns as u32,
+ rows as u32,
+ scalar.to_wgsl(),
+ )
+ }
+ Ti::Atomic(scalar) => {
+ format!("atomic<{}>", scalar.to_wgsl())
+ }
+ Ti::Pointer { base, .. } => {
+ let name = base.to_wgsl(gctx);
+ format!("ptr<{name}>")
+ }
+ Ti::ValuePointer { scalar, .. } => {
+ format!("ptr<{}>", scalar.to_wgsl())
+ }
+ Ti::Array { base, size, .. } => {
+ let base = base.to_wgsl(gctx);
+ match size {
+ crate::ArraySize::Constant(size) => format!("array<{base}, {size}>"),
+ crate::ArraySize::Dynamic => format!("array<{base}>"),
+ }
+ }
+ Ti::Struct { .. } => {
+ // TODO: Actually output the struct?
+ "struct".to_string()
+ }
+ Ti::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let dim_suffix = match dim {
+ crate::ImageDimension::D1 => "_1d",
+ crate::ImageDimension::D2 => "_2d",
+ crate::ImageDimension::D3 => "_3d",
+ crate::ImageDimension::Cube => "_cube",
+ };
+ let array_suffix = if arrayed { "_array" } else { "" };
+
+ let class_suffix = match class {
+ crate::ImageClass::Sampled { multi: true, .. } => "_multisampled",
+ crate::ImageClass::Depth { multi: false } => "_depth",
+ crate::ImageClass::Depth { multi: true } => "_depth_multisampled",
+ crate::ImageClass::Sampled { multi: false, .. }
+ | crate::ImageClass::Storage { .. } => "",
+ };
+
+ let type_in_brackets = match class {
+ crate::ImageClass::Sampled { kind, .. } => {
+ // Note: The only valid widths are 4 bytes wide.
+ // The lexer has already verified this, so we can safely assume it here.
+ // https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ let element_type = crate::Scalar { kind, width: 4 }.to_wgsl();
+ format!("<{element_type}>")
+ }
+ crate::ImageClass::Depth { multi: _ } => String::new(),
+ crate::ImageClass::Storage { format, access } => {
+ if access.contains(crate::StorageAccess::STORE) {
+ format!("<{},write>", format.to_wgsl())
+ } else {
+ format!("<{}>", format.to_wgsl())
+ }
+ }
+ };
+
+ format!("texture{class_suffix}{dim_suffix}{array_suffix}{type_in_brackets}")
+ }
+ Ti::Sampler { .. } => "sampler".to_string(),
+ Ti::AccelerationStructure => "acceleration_structure".to_string(),
+ Ti::RayQuery => "ray_query".to_string(),
+ Ti::BindingArray { base, size, .. } => {
+ let member_type = &gctx.types[base];
+ let base = member_type.name.as_deref().unwrap_or("unknown");
+ match size {
+ crate::ArraySize::Constant(size) => format!("binding_array<{base}, {size}>"),
+ crate::ArraySize::Dynamic => format!("binding_array<{base}>"),
+ }
+ }
+ }
+ }
+}
+
+impl crate::Scalar {
+ /// Format a scalar kind+width as a type is written in wgsl.
+ ///
+ /// Examples: `f32`, `u64`, `bool`.
+ pub fn to_wgsl(self) -> String {
+ let prefix = match self.kind {
+ crate::ScalarKind::Sint => "i",
+ crate::ScalarKind::Uint => "u",
+ crate::ScalarKind::Float => "f",
+ crate::ScalarKind::Bool => return "bool".to_string(),
+ crate::ScalarKind::AbstractInt => return "{AbstractInt}".to_string(),
+ crate::ScalarKind::AbstractFloat => return "{AbstractFloat}".to_string(),
+ };
+ format!("{}{}", prefix, self.width * 8)
+ }
+}
+
+impl crate::StorageFormat {
+ pub const fn to_wgsl(self) -> &'static str {
+ use crate::StorageFormat as Sf;
+ match self {
+ Sf::R8Unorm => "r8unorm",
+ Sf::R8Snorm => "r8snorm",
+ Sf::R8Uint => "r8uint",
+ Sf::R8Sint => "r8sint",
+ Sf::R16Uint => "r16uint",
+ Sf::R16Sint => "r16sint",
+ Sf::R16Float => "r16float",
+ Sf::Rg8Unorm => "rg8unorm",
+ Sf::Rg8Snorm => "rg8snorm",
+ Sf::Rg8Uint => "rg8uint",
+ Sf::Rg8Sint => "rg8sint",
+ Sf::R32Uint => "r32uint",
+ Sf::R32Sint => "r32sint",
+ Sf::R32Float => "r32float",
+ Sf::Rg16Uint => "rg16uint",
+ Sf::Rg16Sint => "rg16sint",
+ Sf::Rg16Float => "rg16float",
+ Sf::Rgba8Unorm => "rgba8unorm",
+ Sf::Rgba8Snorm => "rgba8snorm",
+ Sf::Rgba8Uint => "rgba8uint",
+ Sf::Rgba8Sint => "rgba8sint",
+ Sf::Bgra8Unorm => "bgra8unorm",
+ Sf::Rgb10a2Uint => "rgb10a2uint",
+ Sf::Rgb10a2Unorm => "rgb10a2unorm",
+ Sf::Rg11b10Float => "rg11b10float",
+ Sf::Rg32Uint => "rg32uint",
+ Sf::Rg32Sint => "rg32sint",
+ Sf::Rg32Float => "rg32float",
+ Sf::Rgba16Uint => "rgba16uint",
+ Sf::Rgba16Sint => "rgba16sint",
+ Sf::Rgba16Float => "rgba16float",
+ Sf::Rgba32Uint => "rgba32uint",
+ Sf::Rgba32Sint => "rgba32sint",
+ Sf::Rgba32Float => "rgba32float",
+ Sf::R16Unorm => "r16unorm",
+ Sf::R16Snorm => "r16snorm",
+ Sf::Rg16Unorm => "rg16unorm",
+ Sf::Rg16Snorm => "rg16snorm",
+ Sf::Rgba16Unorm => "rgba16unorm",
+ Sf::Rgba16Snorm => "rgba16snorm",
+ }
+ }
+}
+
+mod tests {
+ #[test]
+ fn to_wgsl() {
+ use std::num::NonZeroU32;
+
+ let mut types = crate::UniqueArena::new();
+
+ let mytype1 = types.insert(
+ crate::Type {
+ name: Some("MyType1".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![],
+ span: 0,
+ },
+ },
+ Default::default(),
+ );
+ let mytype2 = types.insert(
+ crate::Type {
+ name: Some("MyType2".to_string()),
+ inner: crate::TypeInner::Struct {
+ members: vec![],
+ span: 0,
+ },
+ },
+ Default::default(),
+ );
+
+ let gctx = crate::proc::GlobalCtx {
+ types: &types,
+ constants: &crate::Arena::new(),
+ const_expressions: &crate::Arena::new(),
+ };
+ let array = crate::TypeInner::Array {
+ base: mytype1,
+ stride: 4,
+ size: crate::ArraySize::Constant(unsafe { NonZeroU32::new_unchecked(32) }),
+ };
+ assert_eq!(array.to_wgsl(&gctx), "array<MyType1, 32>");
+
+ let mat = crate::TypeInner::Matrix {
+ rows: crate::VectorSize::Quad,
+ columns: crate::VectorSize::Bi,
+ scalar: crate::Scalar::F64,
+ };
+ assert_eq!(mat.to_wgsl(&gctx), "mat2x4<f64>");
+
+ let ptr = crate::TypeInner::Pointer {
+ base: mytype2,
+ space: crate::AddressSpace::Storage {
+ access: crate::StorageAccess::default(),
+ },
+ };
+ assert_eq!(ptr.to_wgsl(&gctx), "ptr<MyType2>");
+
+ let img1 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Float,
+ multi: true,
+ },
+ };
+ assert_eq!(img1.to_wgsl(&gctx), "texture_multisampled_2d<f32>");
+
+ let img2 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Depth { multi: false },
+ };
+ assert_eq!(img2.to_wgsl(&gctx), "texture_depth_cube_array");
+
+ let img3 = crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth { multi: true },
+ };
+ assert_eq!(img3.to_wgsl(&gctx), "texture_depth_multisampled_2d");
+
+ let array = crate::TypeInner::BindingArray {
+ base: mytype1,
+ size: crate::ArraySize::Constant(unsafe { NonZeroU32::new_unchecked(32) }),
+ };
+ assert_eq!(array.to_wgsl(&gctx), "binding_array<MyType1, 32>");
+ }
+}
diff --git a/third_party/rust/naga/src/keywords/mod.rs b/third_party/rust/naga/src/keywords/mod.rs
new file mode 100644
index 0000000000..d54a1704f7
--- /dev/null
+++ b/third_party/rust/naga/src/keywords/mod.rs
@@ -0,0 +1,6 @@
+/*!
+Lists of reserved keywords for each shading language with a [frontend][crate::front] or [backend][crate::back].
+*/
+
+#[cfg(any(feature = "wgsl-in", feature = "wgsl-out"))]
+pub mod wgsl;
diff --git a/third_party/rust/naga/src/keywords/wgsl.rs b/third_party/rust/naga/src/keywords/wgsl.rs
new file mode 100644
index 0000000000..7b47a13128
--- /dev/null
+++ b/third_party/rust/naga/src/keywords/wgsl.rs
@@ -0,0 +1,229 @@
+/*!
+Keywords for [WGSL][wgsl] (WebGPU Shading Language).
+
+[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
+*/
+
+// https://gpuweb.github.io/gpuweb/wgsl/#keyword-summary
+// last sync: https://github.com/gpuweb/gpuweb/blob/39f2321f547c8f0b7f473cf1d47fba30b1691303/wgsl/index.bs
+pub const RESERVED: &[&str] = &[
+ // Type-defining Keywords
+ "array",
+ "atomic",
+ "bool",
+ "f32",
+ "f16",
+ "i32",
+ "mat2x2",
+ "mat2x3",
+ "mat2x4",
+ "mat3x2",
+ "mat3x3",
+ "mat3x4",
+ "mat4x2",
+ "mat4x3",
+ "mat4x4",
+ "ptr",
+ "sampler",
+ "sampler_comparison",
+ "texture_1d",
+ "texture_2d",
+ "texture_2d_array",
+ "texture_3d",
+ "texture_cube",
+ "texture_cube_array",
+ "texture_multisampled_2d",
+ "texture_storage_1d",
+ "texture_storage_2d",
+ "texture_storage_2d_array",
+ "texture_storage_3d",
+ "texture_depth_2d",
+ "texture_depth_2d_array",
+ "texture_depth_cube",
+ "texture_depth_cube_array",
+ "texture_depth_multisampled_2d",
+ "u32",
+ "vec2",
+ "vec3",
+ "vec4",
+ // Other Keywords
+ "alias",
+ "bitcast",
+ "break",
+ "case",
+ "const",
+ "continue",
+ "continuing",
+ "default",
+ "discard",
+ "else",
+ "enable",
+ "false",
+ "fn",
+ "for",
+ "if",
+ "let",
+ "loop",
+ "override",
+ "return",
+ "static_assert",
+ "struct",
+ "switch",
+ "true",
+ "type",
+ "var",
+ "while",
+ // Reserved Words
+ "CompileShader",
+ "ComputeShader",
+ "DomainShader",
+ "GeometryShader",
+ "Hullshader",
+ "NULL",
+ "Self",
+ "abstract",
+ "active",
+ "alignas",
+ "alignof",
+ "as",
+ "asm",
+ "asm_fragment",
+ "async",
+ "attribute",
+ "auto",
+ "await",
+ "become",
+ "binding_array",
+ "cast",
+ "catch",
+ "class",
+ "co_await",
+ "co_return",
+ "co_yield",
+ "coherent",
+ "column_major",
+ "common",
+ "compile",
+ "compile_fragment",
+ "concept",
+ "const_cast",
+ "consteval",
+ "constexpr",
+ "constinit",
+ "crate",
+ "debugger",
+ "decltype",
+ "delete",
+ "demote",
+ "demote_to_helper",
+ "do",
+ "dynamic_cast",
+ "enum",
+ "explicit",
+ "export",
+ "extends",
+ "extern",
+ "external",
+ "fallthrough",
+ "filter",
+ "final",
+ "finally",
+ "friend",
+ "from",
+ "fxgroup",
+ "get",
+ "goto",
+ "groupshared",
+ "handle",
+ "highp",
+ "impl",
+ "implements",
+ "import",
+ "inline",
+ "inout",
+ "instanceof",
+ "interface",
+ "layout",
+ "lowp",
+ "macro",
+ "macro_rules",
+ "match",
+ "mediump",
+ "meta",
+ "mod",
+ "module",
+ "move",
+ "mut",
+ "mutable",
+ "namespace",
+ "new",
+ "nil",
+ "noexcept",
+ "noinline",
+ "nointerpolation",
+ "noperspective",
+ "null",
+ "nullptr",
+ "of",
+ "operator",
+ "package",
+ "packoffset",
+ "partition",
+ "pass",
+ "patch",
+ "pixelfragment",
+ "precise",
+ "precision",
+ "premerge",
+ "priv",
+ "protected",
+ "pub",
+ "public",
+ "readonly",
+ "ref",
+ "regardless",
+ "register",
+ "reinterpret_cast",
+ "requires",
+ "resource",
+ "restrict",
+ "self",
+ "set",
+ "shared",
+ "signed",
+ "sizeof",
+ "smooth",
+ "snorm",
+ "static",
+ "static_assert",
+ "static_cast",
+ "std",
+ "subroutine",
+ "super",
+ "target",
+ "template",
+ "this",
+ "thread_local",
+ "throw",
+ "trait",
+ "try",
+ "typedef",
+ "typeid",
+ "typename",
+ "typeof",
+ "union",
+ "unless",
+ "unorm",
+ "unsafe",
+ "unsized",
+ "use",
+ "using",
+ "varying",
+ "virtual",
+ "volatile",
+ "wgsl",
+ "where",
+ "with",
+ "writeonly",
+ "yield",
+];
diff --git a/third_party/rust/naga/src/lib.rs b/third_party/rust/naga/src/lib.rs
new file mode 100644
index 0000000000..d6b9c6a7f4
--- /dev/null
+++ b/third_party/rust/naga/src/lib.rs
@@ -0,0 +1,2072 @@
+/*! Universal shader translator.
+
+The central structure of the crate is [`Module`]. A `Module` contains:
+
+- [`Function`]s, which have arguments, a return type, local variables, and a body,
+
+- [`EntryPoint`]s, which are specialized functions that can serve as the entry
+ point for pipeline stages like vertex shading or fragment shading,
+
+- [`Constant`]s and [`GlobalVariable`]s used by `EntryPoint`s and `Function`s, and
+
+- [`Type`]s used by the above.
+
+The body of an `EntryPoint` or `Function` is represented using two types:
+
+- An [`Expression`] produces a value, but has no side effects or control flow.
+ `Expressions` include variable references, unary and binary operators, and so
+ on.
+
+- A [`Statement`] can have side effects and structured control flow.
+ `Statement`s do not produce a value, other than by storing one in some
+ designated place. `Statements` include blocks, conditionals, and loops, but also
+ operations that have side effects, like stores and function calls.
+
+`Statement`s form a tree, with pointers into the DAG of `Expression`s.
+
+Restricting side effects to statements simplifies analysis and code generation.
+A Naga backend can generate code to evaluate an `Expression` however and
+whenever it pleases, as long as it is certain to observe the side effects of all
+previously executed `Statement`s.
+
+Many `Statement` variants use the [`Block`] type, which is `Vec<Statement>`,
+with optional span info, representing a series of statements executed in order. The body of an
+`EntryPoint`s or `Function` is a `Block`, and `Statement` has a
+[`Block`][Statement::Block] variant.
+
+If the `clone` feature is enabled, [`Arena`], [`UniqueArena`], [`Type`], [`TypeInner`],
+[`Constant`], [`Function`], [`EntryPoint`] and [`Module`] can be cloned.
+
+## Arenas
+
+To improve translator performance and reduce memory usage, most structures are
+stored in an [`Arena`]. An `Arena<T>` stores a series of `T` values, indexed by
+[`Handle<T>`](Handle) values, which are just wrappers around integer indexes.
+For example, a `Function`'s expressions are stored in an `Arena<Expression>`,
+and compound expressions refer to their sub-expressions via `Handle<Expression>`
+values. (When examining the serialized form of a `Module`, note that the first
+element of an `Arena` has an index of 1, not 0.)
+
+A [`UniqueArena`] is just like an `Arena`, except that it stores only a single
+instance of each value. The value type must implement `Eq` and `Hash`. Like an
+`Arena`, inserting a value into a `UniqueArena` returns a `Handle` which can be
+used to efficiently access the value, without a hash lookup. Inserting a value
+multiple times returns the same `Handle`.
+
+If the `span` feature is enabled, both `Arena` and `UniqueArena` can associate a
+source code span with each element.
+
+## Function Calls
+
+Naga's representation of function calls is unusual. Most languages treat
+function calls as expressions, but because calls may have side effects, Naga
+represents them as a kind of statement, [`Statement::Call`]. If the function
+returns a value, a call statement designates a particular [`Expression::CallResult`]
+expression to represent its return value, for use by subsequent statements and
+expressions.
+
+## `Expression` evaluation time
+
+It is essential to know when an [`Expression`] should be evaluated, because its
+value may depend on previous [`Statement`]s' effects. But whereas the order of
+execution for a tree of `Statement`s is apparent from its structure, it is not
+so clear for `Expressions`, since an expression may be referred to by any number
+of `Statement`s and other `Expression`s.
+
+Naga's rules for when `Expression`s are evaluated are as follows:
+
+- [`Literal`], [`Constant`], and [`ZeroValue`] expressions are
+ considered to be implicitly evaluated before execution begins.
+
+- [`FunctionArgument`] and [`LocalVariable`] expressions are considered
+ implicitly evaluated upon entry to the function to which they belong.
+ Function arguments cannot be assigned to, and `LocalVariable` expressions
+ produce a *pointer to* the variable's value (for use with [`Load`] and
+ [`Store`]). Neither varies while the function executes, so it suffices to
+ consider these expressions evaluated once on entry.
+
+- Similarly, [`GlobalVariable`] expressions are considered implicitly
+ evaluated before execution begins, since their value does not change while
+ code executes, for one of two reasons:
+
+ - Most `GlobalVariable` expressions produce a pointer to the variable's
+ value, for use with [`Load`] and [`Store`], as `LocalVariable`
+ expressions do. Although the variable's value may change, its address
+ does not.
+
+ - A `GlobalVariable` expression referring to a global in the
+ [`AddressSpace::Handle`] address space produces the value directly, not
+ a pointer. Such global variables hold opaque types like shaders or
+ images, and cannot be assigned to.
+
+- A [`CallResult`] expression that is the `result` of a [`Statement::Call`],
+ representing the call's return value, is evaluated when the `Call` statement
+ is executed.
+
+- Similarly, an [`AtomicResult`] expression that is the `result` of an
+ [`Atomic`] statement, representing the result of the atomic operation, is
+ evaluated when the `Atomic` statement is executed.
+
+- A [`RayQueryProceedResult`] expression, which is a boolean
+ indicating if the ray query is finished, is evaluated when the
+ [`RayQuery`] statement whose [`Proceed::result`] points to it is
+ executed.
+
+- All other expressions are evaluated when the (unique) [`Statement::Emit`]
+ statement that covers them is executed.
+
+Now, strictly speaking, not all `Expression` variants actually care when they're
+evaluated. For example, you can evaluate a [`BinaryOperator::Add`] expression
+any time you like, as long as you give it the right operands. It's really only a
+very small set of expressions that are affected by timing:
+
+- [`Load`], [`ImageSample`], and [`ImageLoad`] expressions are influenced by
+ stores to the variables or images they access, and must execute at the
+ proper time relative to them.
+
+- [`Derivative`] expressions are sensitive to control flow uniformity: they
+ must not be moved out of an area of uniform control flow into a non-uniform
+ area.
+
+- More generally, any expression that's used by more than one other expression
+ or statement should probably be evaluated only once, and then stored in a
+ variable to be cited at each point of use.
+
+Naga tries to help back ends handle all these cases correctly in a somewhat
+circuitous way. The [`ModuleInfo`] structure returned by [`Validator::validate`]
+provides a reference count for each expression in each function in the module.
+Naturally, any expression with a reference count of two or more deserves to be
+evaluated and stored in a temporary variable at the point that the `Emit`
+statement covering it is executed. But if we selectively lower the reference
+count threshold to _one_ for the sensitive expression types listed above, so
+that we _always_ generate a temporary variable and save their value, then the
+same code that manages multiply referenced expressions will take care of
+introducing temporaries for time-sensitive expressions as well. The
+`Expression::bake_ref_count` method (private to the back ends) is meant to help
+with this.
+
+## `Expression` scope
+
+Each `Expression` has a *scope*, which is the region of the function within
+which it can be used by `Statement`s and other `Expression`s. It is a validation
+error to use an `Expression` outside its scope.
+
+An expression's scope is defined as follows:
+
+- The scope of a [`Constant`], [`GlobalVariable`], [`FunctionArgument`] or
+ [`LocalVariable`] expression covers the entire `Function` in which it
+ occurs.
+
+- The scope of an expression evaluated by an [`Emit`] statement covers the
+ subsequent expressions in that `Emit`, the subsequent statements in the `Block`
+ to which that `Emit` belongs (if any) and their sub-statements (if any).
+
+- The `result` expression of a [`Call`] or [`Atomic`] statement has a scope
+ covering the subsequent statements in the `Block` in which the statement
+ occurs (if any) and their sub-statements (if any).
+
+For example, this implies that an expression evaluated by some statement in a
+nested `Block` is not available in the `Block`'s parents. Such a value would
+need to be stored in a local variable to be carried upwards in the statement
+tree.
+
+## Constant expressions
+
+A Naga *constant expression* is one of the following [`Expression`]
+variants, whose operands (if any) are also constant expressions:
+- [`Literal`]
+- [`Constant`], for [`Constant`s][const_type] whose [`override`] is [`None`]
+- [`ZeroValue`], for fixed-size types
+- [`Compose`]
+- [`Access`]
+- [`AccessIndex`]
+- [`Splat`]
+- [`Swizzle`]
+- [`Unary`]
+- [`Binary`]
+- [`Select`]
+- [`Relational`]
+- [`Math`]
+- [`As`]
+
+A constant expression can be evaluated at module translation time.
+
+## Override expressions
+
+A Naga *override expression* is the same as a [constant expression],
+except that it is also allowed to refer to [`Constant`s][const_type]
+whose [`override`] is something other than [`None`].
+
+An override expression can be evaluated at pipeline creation time.
+
+[`AtomicResult`]: Expression::AtomicResult
+[`RayQueryProceedResult`]: Expression::RayQueryProceedResult
+[`CallResult`]: Expression::CallResult
+[`Constant`]: Expression::Constant
+[`ZeroValue`]: Expression::ZeroValue
+[`Literal`]: Expression::Literal
+[`Derivative`]: Expression::Derivative
+[`FunctionArgument`]: Expression::FunctionArgument
+[`GlobalVariable`]: Expression::GlobalVariable
+[`ImageLoad`]: Expression::ImageLoad
+[`ImageSample`]: Expression::ImageSample
+[`Load`]: Expression::Load
+[`LocalVariable`]: Expression::LocalVariable
+
+[`Atomic`]: Statement::Atomic
+[`Call`]: Statement::Call
+[`Emit`]: Statement::Emit
+[`Store`]: Statement::Store
+[`RayQuery`]: Statement::RayQuery
+
+[`Proceed::result`]: RayQueryFunction::Proceed::result
+
+[`Validator::validate`]: valid::Validator::validate
+[`ModuleInfo`]: valid::ModuleInfo
+
+[`Literal`]: Expression::Literal
+[`ZeroValue`]: Expression::ZeroValue
+[`Compose`]: Expression::Compose
+[`Access`]: Expression::Access
+[`AccessIndex`]: Expression::AccessIndex
+[`Splat`]: Expression::Splat
+[`Swizzle`]: Expression::Swizzle
+[`Unary`]: Expression::Unary
+[`Binary`]: Expression::Binary
+[`Select`]: Expression::Select
+[`Relational`]: Expression::Relational
+[`Math`]: Expression::Math
+[`As`]: Expression::As
+
+[const_type]: Constant
+[`override`]: Constant::override
+[`None`]: Override::None
+
+[constant expression]: index.html#constant-expressions
+*/
+
+#![allow(
+ clippy::new_without_default,
+ clippy::unneeded_field_pattern,
+ clippy::match_like_matches_macro,
+ clippy::collapsible_if,
+ clippy::derive_partial_eq_without_eq,
+ clippy::needless_borrowed_reference,
+ clippy::single_match
+)]
+#![warn(
+ trivial_casts,
+ trivial_numeric_casts,
+ unused_extern_crates,
+ unused_qualifications,
+ clippy::pattern_type_mismatch,
+ clippy::missing_const_for_fn,
+ clippy::rest_pat_in_fully_bound_structs,
+ clippy::match_wildcard_for_single_variants
+)]
+#![deny(clippy::exit)]
+#![cfg_attr(
+ not(test),
+ warn(
+ clippy::dbg_macro,
+ clippy::panic,
+ clippy::print_stderr,
+ clippy::print_stdout,
+ clippy::todo
+ )
+)]
+
+mod arena;
+pub mod back;
+mod block;
+#[cfg(feature = "compact")]
+pub mod compact;
+pub mod front;
+pub mod keywords;
+pub mod proc;
+mod span;
+pub mod valid;
+
+pub use crate::arena::{Arena, Handle, Range, UniqueArena};
+
+pub use crate::span::{SourceLocation, Span, SpanContext, WithSpan};
+#[cfg(feature = "arbitrary")]
+use arbitrary::Arbitrary;
+#[cfg(feature = "deserialize")]
+use serde::Deserialize;
+#[cfg(feature = "serialize")]
+use serde::Serialize;
+
+/// Width of a boolean type, in bytes.
+pub const BOOL_WIDTH: Bytes = 1;
+
+/// Width of abstract types, in bytes.
+pub const ABSTRACT_WIDTH: Bytes = 8;
+
+/// Hash map that is faster but not resilient to DoS attacks.
+pub type FastHashMap<K, T> = rustc_hash::FxHashMap<K, T>;
+/// Hash set that is faster but not resilient to DoS attacks.
+pub type FastHashSet<K> = rustc_hash::FxHashSet<K>;
+
+/// Insertion-order-preserving hash set (`IndexSet<K>`), but with the same
+/// hasher as `FastHashSet<K>` (faster but not resilient to DoS attacks).
+pub type FastIndexSet<K> =
+ indexmap::IndexSet<K, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>;
+
+/// Insertion-order-preserving hash map (`IndexMap<K, V>`), but with the same
+/// hasher as `FastHashMap<K, V>` (faster but not resilient to DoS attacks).
+pub type FastIndexMap<K, V> =
+ indexmap::IndexMap<K, V, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>;
+
+/// Map of expressions that have associated variable names
+pub(crate) type NamedExpressions = FastIndexMap<Handle<Expression>, String>;
+
+/// Early fragment tests.
+///
+/// In a standard situation, if a driver determines that it is possible to switch on early depth test, it will.
+///
+/// Typical situations when early depth test is switched off:
+/// - Calling `discard` in a shader.
+/// - Writing to the depth buffer, unless ConservativeDepth is enabled.
+///
+/// To use in a shader:
+/// - GLSL: `layout(early_fragment_tests) in;`
+/// - HLSL: `Attribute earlydepthstencil`
+/// - SPIR-V: `ExecutionMode EarlyFragmentTests`
+/// - WGSL: `@early_depth_test`
+///
+/// For more, see:
+/// - <https://www.khronos.org/opengl/wiki/Early_Fragment_Test#Explicit_specification>
+/// - <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-earlydepthstencil>
+/// - <https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#Execution_Mode>
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct EarlyDepthTest {
+ pub conservative: Option<ConservativeDepth>,
+}
+/// Enables adjusting depth without disabling early Z.
+///
+/// To use in a shader:
+/// - GLSL: `layout (depth_<greater/less/unchanged/any>) out float gl_FragDepth;`
+/// - `depth_any` option behaves as if the layout qualifier was not present.
+/// - HLSL: `SV_DepthGreaterEqual`/`SV_DepthLessEqual`/`SV_Depth`
+/// - SPIR-V: `ExecutionMode Depth<Greater/Less/Unchanged>`
+/// - WGSL: `@early_depth_test(greater_equal/less_equal/unchanged)`
+///
+/// For more, see:
+/// - <https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt>
+/// - <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-semantics#system-value-semantics>
+/// - <https://www.khronos.org/registry/SPIR-V/specs/unified1/SPIRV.html#Execution_Mode>
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ConservativeDepth {
+ /// Shader may rewrite depth only with a value greater than calculated.
+ GreaterEqual,
+
+ /// Shader may rewrite depth smaller than one that would have been written without the modification.
+ LessEqual,
+
+ /// Shader may not rewrite depth value.
+ Unchanged,
+}
+
+/// Stage of the programmable pipeline.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+#[allow(missing_docs)] // The names are self evident
+pub enum ShaderStage {
+ Vertex,
+ Fragment,
+ Compute,
+}
+
+/// Addressing space of variables.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum AddressSpace {
+ /// Function locals.
+ Function,
+ /// Private data, per invocation, mutable.
+ Private,
+ /// Workgroup shared data, mutable.
+ WorkGroup,
+ /// Uniform buffer data.
+ Uniform,
+ /// Storage buffer data, potentially mutable.
+ Storage { access: StorageAccess },
+ /// Opaque handles, such as samplers and images.
+ Handle,
+ /// Push constants.
+ PushConstant,
+}
+
+/// Built-in inputs and outputs.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum BuiltIn {
+ Position { invariant: bool },
+ ViewIndex,
+ // vertex
+ BaseInstance,
+ BaseVertex,
+ ClipDistance,
+ CullDistance,
+ InstanceIndex,
+ PointSize,
+ VertexIndex,
+ // fragment
+ FragDepth,
+ PointCoord,
+ FrontFacing,
+ PrimitiveIndex,
+ SampleIndex,
+ SampleMask,
+ // compute
+ GlobalInvocationId,
+ LocalInvocationId,
+ LocalInvocationIndex,
+ WorkGroupId,
+ WorkGroupSize,
+ NumWorkGroups,
+}
+
+/// Number of bytes per scalar.
+pub type Bytes = u8;
+
+/// Number of components in a vector.
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum VectorSize {
+ /// 2D vector
+ Bi = 2,
+ /// 3D vector
+ Tri = 3,
+ /// 4D vector
+ Quad = 4,
+}
+
+impl VectorSize {
+ const MAX: usize = Self::Quad as u8 as usize;
+}
+
+/// Primitive type for a scalar.
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ScalarKind {
+ /// Signed integer type.
+ Sint,
+ /// Unsigned integer type.
+ Uint,
+ /// Floating point type.
+ Float,
+ /// Boolean type.
+ Bool,
+
+ /// WGSL abstract integer type.
+ ///
+ /// These are forbidden by validation, and should never reach backends.
+ AbstractInt,
+
+ /// Abstract floating-point type.
+ ///
+ /// These are forbidden by validation, and should never reach backends.
+ AbstractFloat,
+}
+
+/// Characteristics of a scalar type.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct Scalar {
+ /// How the value's bits are to be interpreted.
+ pub kind: ScalarKind,
+
+ /// This size of the value in bytes.
+ pub width: Bytes,
+}
+
+/// Size of an array.
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ArraySize {
+ /// The array size is constant.
+ Constant(std::num::NonZeroU32),
+ /// The array size can change at runtime.
+ Dynamic,
+}
+
+/// The interpolation qualifier of a binding or struct field.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Interpolation {
+ /// The value will be interpolated in a perspective-correct fashion.
+ /// Also known as "smooth" in glsl.
+ Perspective,
+ /// Indicates that linear, non-perspective, correct
+ /// interpolation must be used.
+ /// Also known as "no_perspective" in glsl.
+ Linear,
+ /// Indicates that no interpolation will be performed.
+ Flat,
+}
+
+/// The sampling qualifiers of a binding or struct field.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Sampling {
+ /// Interpolate the value at the center of the pixel.
+ Center,
+
+ /// Interpolate the value at a point that lies within all samples covered by
+ /// the fragment within the current primitive. In multisampling, use a
+ /// single value for all samples in the primitive.
+ Centroid,
+
+ /// Interpolate the value at each sample location. In multisampling, invoke
+ /// the fragment shader once per sample.
+ Sample,
+}
+
+/// Member of a user-defined structure.
+// Clone is used only for error reporting and is not intended for end users
+#[derive(Clone, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct StructMember {
+ pub name: Option<String>,
+ /// Type of the field.
+ pub ty: Handle<Type>,
+ /// For I/O structs, defines the binding.
+ pub binding: Option<Binding>,
+ /// Offset from the beginning from the struct.
+ pub offset: u32,
+}
+
+/// The number of dimensions an image has.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ImageDimension {
+ /// 1D image
+ D1,
+ /// 2D image
+ D2,
+ /// 3D image
+ D3,
+ /// Cube map
+ Cube,
+}
+
+bitflags::bitflags! {
+ /// Flags describing an image.
+ #[cfg_attr(feature = "serialize", derive(Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(Deserialize))]
+ #[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+ #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
+ pub struct StorageAccess: u32 {
+ /// Storage can be used as a source for load ops.
+ const LOAD = 0x1;
+ /// Storage can be used as a target for store ops.
+ const STORE = 0x2;
+ }
+}
+
+/// Image storage format.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum StorageFormat {
+ // 8-bit formats
+ R8Unorm,
+ R8Snorm,
+ R8Uint,
+ R8Sint,
+
+ // 16-bit formats
+ R16Uint,
+ R16Sint,
+ R16Float,
+ Rg8Unorm,
+ Rg8Snorm,
+ Rg8Uint,
+ Rg8Sint,
+
+ // 32-bit formats
+ R32Uint,
+ R32Sint,
+ R32Float,
+ Rg16Uint,
+ Rg16Sint,
+ Rg16Float,
+ Rgba8Unorm,
+ Rgba8Snorm,
+ Rgba8Uint,
+ Rgba8Sint,
+ Bgra8Unorm,
+
+ // Packed 32-bit formats
+ Rgb10a2Uint,
+ Rgb10a2Unorm,
+ Rg11b10Float,
+
+ // 64-bit formats
+ Rg32Uint,
+ Rg32Sint,
+ Rg32Float,
+ Rgba16Uint,
+ Rgba16Sint,
+ Rgba16Float,
+
+ // 128-bit formats
+ Rgba32Uint,
+ Rgba32Sint,
+ Rgba32Float,
+
+ // Normalized 16-bit per channel formats
+ R16Unorm,
+ R16Snorm,
+ Rg16Unorm,
+ Rg16Snorm,
+ Rgba16Unorm,
+ Rgba16Snorm,
+}
+
+/// Sub-class of the image type.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ImageClass {
+ /// Regular sampled image.
+ Sampled {
+ /// Kind of values to sample.
+ kind: ScalarKind,
+ /// Multi-sampled image.
+ ///
+ /// A multi-sampled image holds several samples per texel. Multi-sampled
+ /// images cannot have mipmaps.
+ multi: bool,
+ },
+ /// Depth comparison image.
+ Depth {
+ /// Multi-sampled depth image.
+ multi: bool,
+ },
+ /// Storage image.
+ Storage {
+ format: StorageFormat,
+ access: StorageAccess,
+ },
+}
+
+/// A data type declared in the module.
+#[derive(Clone, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct Type {
+ /// The name of the type, if any.
+ pub name: Option<String>,
+ /// Inner structure that depends on the kind of the type.
+ pub inner: TypeInner,
+}
+
+/// Enum with additional information, depending on the kind of type.
+#[derive(Clone, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum TypeInner {
+ /// Number of integral or floating-point kind.
+ Scalar(Scalar),
+ /// Vector of numbers.
+ Vector { size: VectorSize, scalar: Scalar },
+ /// Matrix of numbers.
+ Matrix {
+ columns: VectorSize,
+ rows: VectorSize,
+ scalar: Scalar,
+ },
+ /// Atomic scalar.
+ Atomic(Scalar),
+ /// Pointer to another type.
+ ///
+ /// Pointers to scalars and vectors should be treated as equivalent to
+ /// [`ValuePointer`] types. Use the [`TypeInner::equivalent`] method to
+ /// compare types in a way that treats pointers correctly.
+ ///
+ /// ## Pointers to non-`SIZED` types
+ ///
+ /// The `base` type of a pointer may be a non-[`SIZED`] type like a
+ /// dynamically-sized [`Array`], or a [`Struct`] whose last member is a
+ /// dynamically sized array. Such pointers occur as the types of
+ /// [`GlobalVariable`] or [`AccessIndex`] expressions referring to
+ /// dynamically-sized arrays.
+ ///
+ /// However, among pointers to non-`SIZED` types, only pointers to `Struct`s
+ /// are [`DATA`]. Pointers to dynamically sized `Array`s cannot be passed as
+ /// arguments, stored in variables, or held in arrays or structures. Their
+ /// only use is as the types of `AccessIndex` expressions.
+ ///
+ /// [`SIZED`]: valid::TypeFlags::SIZED
+ /// [`DATA`]: valid::TypeFlags::DATA
+ /// [`Array`]: TypeInner::Array
+ /// [`Struct`]: TypeInner::Struct
+ /// [`ValuePointer`]: TypeInner::ValuePointer
+ /// [`GlobalVariable`]: Expression::GlobalVariable
+ /// [`AccessIndex`]: Expression::AccessIndex
+ Pointer {
+ base: Handle<Type>,
+ space: AddressSpace,
+ },
+
+ /// Pointer to a scalar or vector.
+ ///
+ /// A `ValuePointer` type is equivalent to a `Pointer` whose `base` is a
+ /// `Scalar` or `Vector` type. This is for use in [`TypeResolution::Value`]
+ /// variants; see the documentation for [`TypeResolution`] for details.
+ ///
+ /// Use the [`TypeInner::equivalent`] method to compare types that could be
+ /// pointers, to ensure that `Pointer` and `ValuePointer` types are
+ /// recognized as equivalent.
+ ///
+ /// [`TypeResolution`]: proc::TypeResolution
+ /// [`TypeResolution::Value`]: proc::TypeResolution::Value
+ ValuePointer {
+ size: Option<VectorSize>,
+ scalar: Scalar,
+ space: AddressSpace,
+ },
+
+ /// Homogeneous list of elements.
+ ///
+ /// The `base` type must be a [`SIZED`], [`DATA`] type.
+ ///
+ /// ## Dynamically sized arrays
+ ///
+ /// An `Array` is [`SIZED`] unless its `size` is [`Dynamic`].
+ /// Dynamically-sized arrays may only appear in a few situations:
+ ///
+ /// - They may appear as the type of a [`GlobalVariable`], or as the last
+ /// member of a [`Struct`].
+ ///
+ /// - They may appear as the base type of a [`Pointer`]. An
+ /// [`AccessIndex`] expression referring to a struct's final
+ /// unsized array member would have such a pointer type. However, such
+ /// pointer types may only appear as the types of such intermediate
+ /// expressions. They are not [`DATA`], and cannot be stored in
+ /// variables, held in arrays or structs, or passed as parameters.
+ ///
+ /// [`SIZED`]: crate::valid::TypeFlags::SIZED
+ /// [`DATA`]: crate::valid::TypeFlags::DATA
+ /// [`Dynamic`]: ArraySize::Dynamic
+ /// [`Struct`]: TypeInner::Struct
+ /// [`Pointer`]: TypeInner::Pointer
+ /// [`AccessIndex`]: Expression::AccessIndex
+ Array {
+ base: Handle<Type>,
+ size: ArraySize,
+ stride: u32,
+ },
+
+ /// User-defined structure.
+ ///
+ /// There must always be at least one member.
+ ///
+ /// A `Struct` type is [`DATA`], and the types of its members must be
+ /// `DATA` as well.
+ ///
+ /// Member types must be [`SIZED`], except for the final member of a
+ /// struct, which may be a dynamically sized [`Array`]. The
+ /// `Struct` type itself is `SIZED` when all its members are `SIZED`.
+ ///
+ /// [`DATA`]: crate::valid::TypeFlags::DATA
+ /// [`SIZED`]: crate::valid::TypeFlags::SIZED
+ /// [`Array`]: TypeInner::Array
+ Struct {
+ members: Vec<StructMember>,
+ //TODO: should this be unaligned?
+ span: u32,
+ },
+ /// Possibly multidimensional array of texels.
+ Image {
+ dim: ImageDimension,
+ arrayed: bool,
+ //TODO: consider moving `multisampled: bool` out
+ class: ImageClass,
+ },
+ /// Can be used to sample values from images.
+ Sampler { comparison: bool },
+
+ /// Opaque object representing an acceleration structure of geometry.
+ AccelerationStructure,
+
+ /// Locally used handle for ray queries.
+ RayQuery,
+
+ /// Array of bindings.
+ ///
+ /// A `BindingArray` represents an array where each element draws its value
+ /// from a separate bound resource. The array's element type `base` may be
+ /// [`Image`], [`Sampler`], or any type that would be permitted for a global
+ /// in the [`Uniform`] or [`Storage`] address spaces. Only global variables
+ /// may be binding arrays; on the host side, their values are provided by
+ /// [`TextureViewArray`], [`SamplerArray`], or [`BufferArray`]
+ /// bindings.
+ ///
+ /// Since each element comes from a distinct resource, a binding array of
+ /// images could have images of varying sizes (but not varying dimensions;
+ /// they must all have the same `Image` type). Or, a binding array of
+ /// buffers could have elements that are dynamically sized arrays, each with
+ /// a different length.
+ ///
+ /// Binding arrays are in the same address spaces as their underlying type.
+ /// As such, referring to an array of images produces an [`Image`] value
+ /// directly (as opposed to a pointer). The only operation permitted on
+ /// `BindingArray` values is indexing, which works transparently: indexing
+ /// a binding array of samplers yields a [`Sampler`], indexing a pointer to the
+ /// binding array of storage buffers produces a pointer to the storage struct.
+ ///
+ /// Unlike textures and samplers, binding arrays are not [`ARGUMENT`], so
+ /// they cannot be passed as arguments to functions.
+ ///
+ /// Naga's WGSL front end supports binding arrays with the type syntax
+ /// `binding_array<T, N>`.
+ ///
+ /// [`Image`]: TypeInner::Image
+ /// [`Sampler`]: TypeInner::Sampler
+ /// [`Uniform`]: AddressSpace::Uniform
+ /// [`Storage`]: AddressSpace::Storage
+ /// [`TextureViewArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.TextureViewArray
+ /// [`SamplerArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.SamplerArray
+ /// [`BufferArray`]: https://docs.rs/wgpu/latest/wgpu/enum.BindingResource.html#variant.BufferArray
+ /// [`DATA`]: crate::valid::TypeFlags::DATA
+ /// [`ARGUMENT`]: crate::valid::TypeFlags::ARGUMENT
+ /// [naga#1864]: https://github.com/gfx-rs/naga/issues/1864
+ BindingArray { base: Handle<Type>, size: ArraySize },
+}
+
+#[derive(Debug, Clone, Copy, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Literal {
+ /// May not be NaN or infinity.
+ F64(f64),
+ /// May not be NaN or infinity.
+ F32(f32),
+ U32(u32),
+ I32(i32),
+ I64(i64),
+ Bool(bool),
+ AbstractInt(i64),
+ AbstractFloat(f64),
+}
+
+#[derive(Debug, PartialEq)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Override {
+ None,
+ ByName,
+ ByNameOrId(u32),
+}
+
+/// Constant value.
+#[derive(Debug, PartialEq)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct Constant {
+ pub name: Option<String>,
+ pub r#override: Override,
+ pub ty: Handle<Type>,
+
+ /// The value of the constant.
+ ///
+ /// This [`Handle`] refers to [`Module::const_expressions`], not
+ /// any [`Function::expressions`] arena.
+ ///
+ /// If [`override`] is [`None`], then this must be a Naga
+ /// [constant expression]. Otherwise, this may be a Naga
+ /// [override expression] or [constant expression].
+ ///
+ /// [`override`]: Constant::override
+ /// [`None`]: Override::None
+ /// [constant expression]: index.html#constant-expressions
+ /// [override expression]: index.html#override-expressions
+ pub init: Handle<Expression>,
+}
+
+/// Describes how an input/output variable is to be bound.
+#[derive(Clone, Debug, Eq, PartialEq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Binding {
+ /// Built-in shader variable.
+ BuiltIn(BuiltIn),
+
+ /// Indexed location.
+ ///
+ /// Values passed from the [`Vertex`] stage to the [`Fragment`] stage must
+ /// have their `interpolation` defaulted (i.e. not `None`) by the front end
+ /// as appropriate for that language.
+ ///
+ /// For other stages, we permit interpolations even though they're ignored.
+ /// When a front end is parsing a struct type, it usually doesn't know what
+ /// stages will be using it for IO, so it's easiest if it can apply the
+ /// defaults to anything with a `Location` binding, just in case.
+ ///
+ /// For anything other than floating-point scalars and vectors, the
+ /// interpolation must be `Flat`.
+ ///
+ /// [`Vertex`]: crate::ShaderStage::Vertex
+ /// [`Fragment`]: crate::ShaderStage::Fragment
+ Location {
+ location: u32,
+ /// Indicates the 2nd input to the blender when dual-source blending.
+ second_blend_source: bool,
+ interpolation: Option<Interpolation>,
+ sampling: Option<Sampling>,
+ },
+}
+
+/// Pipeline binding information for global resources.
+#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct ResourceBinding {
+ /// The bind group index.
+ pub group: u32,
+ /// Binding number within the group.
+ pub binding: u32,
+}
+
+/// Variable defined at module level.
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct GlobalVariable {
+ /// Name of the variable, if any.
+ pub name: Option<String>,
+ /// How this variable is to be stored.
+ pub space: AddressSpace,
+ /// For resources, defines the binding point.
+ pub binding: Option<ResourceBinding>,
+ /// The type of this variable.
+ pub ty: Handle<Type>,
+ /// Initial value for this variable.
+ ///
+ /// Expression handle lives in const_expressions
+ pub init: Option<Handle<Expression>>,
+}
+
+/// Variable defined at function level.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct LocalVariable {
+ /// Name of the variable, if any.
+ pub name: Option<String>,
+ /// The type of this variable.
+ pub ty: Handle<Type>,
+ /// Initial value for this variable.
+ ///
+ /// This handle refers to this `LocalVariable`'s function's
+ /// [`expressions`] arena, but it is required to be an evaluated
+ /// constant expression.
+ ///
+ /// [`expressions`]: Function::expressions
+ pub init: Option<Handle<Expression>>,
+}
+
+/// Operation that can be applied on a single value.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum UnaryOperator {
+ Negate,
+ LogicalNot,
+ BitwiseNot,
+}
+
+/// Operation that can be applied on two values.
+///
+/// ## Arithmetic type rules
+///
+/// The arithmetic operations `Add`, `Subtract`, `Multiply`, `Divide`, and
+/// `Modulo` can all be applied to [`Scalar`] types other than [`Bool`], or
+/// [`Vector`]s thereof. Both operands must have the same type.
+///
+/// `Add` and `Subtract` can also be applied to [`Matrix`] values. Both operands
+/// must have the same type.
+///
+/// `Multiply` supports additional cases:
+///
+/// - A [`Matrix`] or [`Vector`] can be multiplied by a scalar [`Float`],
+/// either on the left or the right.
+///
+/// - A [`Matrix`] on the left can be multiplied by a [`Vector`] on the right
+/// if the matrix has as many columns as the vector has components (`matCxR
+/// * VecC`).
+///
+/// - A [`Vector`] on the left can be multiplied by a [`Matrix`] on the right
+/// if the matrix has as many rows as the vector has components (`VecR *
+/// matCxR`).
+///
+/// - Two matrices can be multiplied if the left operand has as many columns
+/// as the right operand has rows (`matNxR * matCxN`).
+///
+/// In all the above `Multiply` cases, the byte widths of the underlying scalar
+/// types of both operands must be the same.
+///
+/// Note that `Multiply` supports mixed vector and scalar operations directly,
+/// whereas the other arithmetic operations require an explicit [`Splat`] for
+/// mixed-type use.
+///
+/// [`Scalar`]: TypeInner::Scalar
+/// [`Vector`]: TypeInner::Vector
+/// [`Matrix`]: TypeInner::Matrix
+/// [`Float`]: ScalarKind::Float
+/// [`Bool`]: ScalarKind::Bool
+/// [`Splat`]: Expression::Splat
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum BinaryOperator {
+ Add,
+ Subtract,
+ Multiply,
+ Divide,
+ /// Equivalent of the WGSL's `%` operator or SPIR-V's `OpFRem`
+ Modulo,
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+ And,
+ ExclusiveOr,
+ InclusiveOr,
+ LogicalAnd,
+ LogicalOr,
+ ShiftLeft,
+ /// Right shift carries the sign of signed integers only.
+ ShiftRight,
+}
+
+/// Function on an atomic value.
+///
+/// Note: these do not include load/store, which use the existing
+/// [`Expression::Load`] and [`Statement::Store`].
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum AtomicFunction {
+ Add,
+ Subtract,
+ And,
+ ExclusiveOr,
+ InclusiveOr,
+ Min,
+ Max,
+ Exchange { compare: Option<Handle<Expression>> },
+}
+
+/// Hint at which precision to compute a derivative.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum DerivativeControl {
+ Coarse,
+ Fine,
+ None,
+}
+
+/// Axis on which to compute a derivative.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum DerivativeAxis {
+ X,
+ Y,
+ Width,
+}
+
+/// Built-in shader function for testing relation between values.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum RelationalFunction {
+ All,
+ Any,
+ IsNan,
+ IsInf,
+}
+
+/// Built-in shader function for math.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum MathFunction {
+ // comparison
+ Abs,
+ Min,
+ Max,
+ Clamp,
+ Saturate,
+ // trigonometry
+ Cos,
+ Cosh,
+ Sin,
+ Sinh,
+ Tan,
+ Tanh,
+ Acos,
+ Asin,
+ Atan,
+ Atan2,
+ Asinh,
+ Acosh,
+ Atanh,
+ Radians,
+ Degrees,
+ // decomposition
+ Ceil,
+ Floor,
+ Round,
+ Fract,
+ Trunc,
+ Modf,
+ Frexp,
+ Ldexp,
+ // exponent
+ Exp,
+ Exp2,
+ Log,
+ Log2,
+ Pow,
+ // geometry
+ Dot,
+ Outer,
+ Cross,
+ Distance,
+ Length,
+ Normalize,
+ FaceForward,
+ Reflect,
+ Refract,
+ // computational
+ Sign,
+ Fma,
+ Mix,
+ Step,
+ SmoothStep,
+ Sqrt,
+ InverseSqrt,
+ Inverse,
+ Transpose,
+ Determinant,
+ // bits
+ CountTrailingZeros,
+ CountLeadingZeros,
+ CountOneBits,
+ ReverseBits,
+ ExtractBits,
+ InsertBits,
+ FindLsb,
+ FindMsb,
+ // data packing
+ Pack4x8snorm,
+ Pack4x8unorm,
+ Pack2x16snorm,
+ Pack2x16unorm,
+ Pack2x16float,
+ // data unpacking
+ Unpack4x8snorm,
+ Unpack4x8unorm,
+ Unpack2x16snorm,
+ Unpack2x16unorm,
+ Unpack2x16float,
+}
+
+/// Sampling modifier to control the level of detail.
+#[derive(Clone, Copy, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum SampleLevel {
+ Auto,
+ Zero,
+ Exact(Handle<Expression>),
+ Bias(Handle<Expression>),
+ Gradient {
+ x: Handle<Expression>,
+ y: Handle<Expression>,
+ },
+}
+
+/// Type of an image query.
+#[derive(Clone, Copy, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum ImageQuery {
+ /// Get the size at the specified level.
+ Size {
+ /// If `None`, the base level is considered.
+ level: Option<Handle<Expression>>,
+ },
+ /// Get the number of mipmap levels.
+ NumLevels,
+ /// Get the number of array layers.
+ NumLayers,
+ /// Get the number of samples.
+ NumSamples,
+}
+
+/// Component selection for a vector swizzle.
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum SwizzleComponent {
+ ///
+ X = 0,
+ ///
+ Y = 1,
+ ///
+ Z = 2,
+ ///
+ W = 3,
+}
+
+bitflags::bitflags! {
+ /// Memory barrier flags.
+ #[cfg_attr(feature = "serialize", derive(Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(Deserialize))]
+ #[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+ #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
+ pub struct Barrier: u32 {
+ /// Barrier affects all `AddressSpace::Storage` accesses.
+ const STORAGE = 0x1;
+ /// Barrier affects all `AddressSpace::WorkGroup` accesses.
+ const WORK_GROUP = 0x2;
+ }
+}
+
+/// An expression that can be evaluated to obtain a value.
+///
+/// This is a Single Static Assignment (SSA) scheme similar to SPIR-V.
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Expression {
+ /// Literal.
+ Literal(Literal),
+ /// Constant value.
+ Constant(Handle<Constant>),
+ /// Zero value of a type.
+ ZeroValue(Handle<Type>),
+ /// Composite expression.
+ Compose {
+ ty: Handle<Type>,
+ components: Vec<Handle<Expression>>,
+ },
+
+ /// Array access with a computed index.
+ ///
+ /// ## Typing rules
+ ///
+ /// The `base` operand must be some composite type: [`Vector`], [`Matrix`],
+ /// [`Array`], a [`Pointer`] to one of those, or a [`ValuePointer`] with a
+ /// `size`.
+ ///
+ /// The `index` operand must be an integer, signed or unsigned.
+ ///
+ /// Indexing a [`Vector`] or [`Array`] produces a value of its element type.
+ /// Indexing a [`Matrix`] produces a [`Vector`].
+ ///
+ /// Indexing a [`Pointer`] to any of the above produces a pointer to the
+ /// element/component type, in the same [`space`]. In the case of [`Array`],
+ /// the result is an actual [`Pointer`], but for vectors and matrices, there
+ /// may not be any type in the arena representing the component's type, so
+ /// those produce [`ValuePointer`] types equivalent to the appropriate
+ /// [`Pointer`].
+ ///
+ /// ## Dynamic indexing restrictions
+ ///
+ /// To accommodate restrictions in some of the shader languages that Naga
+ /// targets, it is not permitted to subscript a matrix or array with a
+ /// dynamically computed index unless that matrix or array appears behind a
+ /// pointer. In other words, if the inner type of `base` is [`Array`] or
+ /// [`Matrix`], then `index` must be a constant. But if the type of `base`
+ /// is a [`Pointer`] to an array or matrix or a [`ValuePointer`] with a
+ /// `size`, then the index may be any expression of integer type.
+ ///
+ /// You can use the [`Expression::is_dynamic_index`] method to determine
+ /// whether a given index expression requires matrix or array base operands
+ /// to be behind a pointer.
+ ///
+ /// (It would be simpler to always require the use of `AccessIndex` when
+ /// subscripting arrays and matrices that are not behind pointers, but to
+ /// accommodate existing front ends, Naga also permits `Access`, with a
+ /// restricted `index`.)
+ ///
+ /// [`Vector`]: TypeInner::Vector
+ /// [`Matrix`]: TypeInner::Matrix
+ /// [`Array`]: TypeInner::Array
+ /// [`Pointer`]: TypeInner::Pointer
+ /// [`space`]: TypeInner::Pointer::space
+ /// [`ValuePointer`]: TypeInner::ValuePointer
+ /// [`Float`]: ScalarKind::Float
+ Access {
+ base: Handle<Expression>,
+ index: Handle<Expression>,
+ },
+ /// Access the same types as [`Access`], plus [`Struct`] with a known index.
+ ///
+ /// [`Access`]: Expression::Access
+ /// [`Struct`]: TypeInner::Struct
+ AccessIndex {
+ base: Handle<Expression>,
+ index: u32,
+ },
+ /// Splat scalar into a vector.
+ Splat {
+ size: VectorSize,
+ value: Handle<Expression>,
+ },
+ /// Vector swizzle.
+ Swizzle {
+ size: VectorSize,
+ vector: Handle<Expression>,
+ pattern: [SwizzleComponent; 4],
+ },
+
+ /// Reference a function parameter, by its index.
+ ///
+ /// A `FunctionArgument` expression evaluates to a pointer to the argument's
+ /// value. You must use a [`Load`] expression to retrieve its value, or a
+ /// [`Store`] statement to assign it a new value.
+ ///
+ /// [`Load`]: Expression::Load
+ /// [`Store`]: Statement::Store
+ FunctionArgument(u32),
+
+ /// Reference a global variable.
+ ///
+ /// If the given `GlobalVariable`'s [`space`] is [`AddressSpace::Handle`],
+ /// then the variable stores some opaque type like a sampler or an image,
+ /// and a `GlobalVariable` expression referring to it produces the
+ /// variable's value directly.
+ ///
+ /// For any other address space, a `GlobalVariable` expression produces a
+ /// pointer to the variable's value. You must use a [`Load`] expression to
+ /// retrieve its value, or a [`Store`] statement to assign it a new value.
+ ///
+ /// [`space`]: GlobalVariable::space
+ /// [`Load`]: Expression::Load
+ /// [`Store`]: Statement::Store
+ GlobalVariable(Handle<GlobalVariable>),
+
+ /// Reference a local variable.
+ ///
+ /// A `LocalVariable` expression evaluates to a pointer to the variable's value.
+ /// You must use a [`Load`](Expression::Load) expression to retrieve its value,
+ /// or a [`Store`](Statement::Store) statement to assign it a new value.
+ LocalVariable(Handle<LocalVariable>),
+
+ /// Load a value indirectly.
+ ///
+ /// For [`TypeInner::Atomic`] the result is a corresponding scalar.
+ /// For other types behind the `pointer<T>`, the result is `T`.
+ Load { pointer: Handle<Expression> },
+ /// Sample a point from a sampled or a depth image.
+ ImageSample {
+ image: Handle<Expression>,
+ sampler: Handle<Expression>,
+ /// If Some(), this operation is a gather operation
+ /// on the selected component.
+ gather: Option<SwizzleComponent>,
+ coordinate: Handle<Expression>,
+ array_index: Option<Handle<Expression>>,
+ /// Expression handle lives in const_expressions
+ offset: Option<Handle<Expression>>,
+ level: SampleLevel,
+ depth_ref: Option<Handle<Expression>>,
+ },
+
+ /// Load a texel from an image.
+ ///
+ /// For most images, this returns a four-element vector of the same
+ /// [`ScalarKind`] as the image. If the format of the image does not have
+ /// four components, default values are provided: the first three components
+ /// (typically R, G, and B) default to zero, and the final component
+ /// (typically alpha) defaults to one.
+ ///
+ /// However, if the image's [`class`] is [`Depth`], then this returns a
+ /// [`Float`] scalar value.
+ ///
+ /// [`ScalarKind`]: ScalarKind
+ /// [`class`]: TypeInner::Image::class
+ /// [`Depth`]: ImageClass::Depth
+ /// [`Float`]: ScalarKind::Float
+ ImageLoad {
+ /// The image to load a texel from. This must have type [`Image`]. (This
+ /// will necessarily be a [`GlobalVariable`] or [`FunctionArgument`]
+ /// expression, since no other expressions are allowed to have that
+ /// type.)
+ ///
+ /// [`Image`]: TypeInner::Image
+ /// [`GlobalVariable`]: Expression::GlobalVariable
+ /// [`FunctionArgument`]: Expression::FunctionArgument
+ image: Handle<Expression>,
+
+ /// The coordinate of the texel we wish to load. This must be a scalar
+ /// for [`D1`] images, a [`Bi`] vector for [`D2`] images, and a [`Tri`]
+ /// vector for [`D3`] images. (Array indices, sample indices, and
+ /// explicit level-of-detail values are supplied separately.) Its
+ /// component type must be [`Sint`].
+ ///
+ /// [`D1`]: ImageDimension::D1
+ /// [`D2`]: ImageDimension::D2
+ /// [`D3`]: ImageDimension::D3
+ /// [`Bi`]: VectorSize::Bi
+ /// [`Tri`]: VectorSize::Tri
+ /// [`Sint`]: ScalarKind::Sint
+ coordinate: Handle<Expression>,
+
+ /// The index into an arrayed image. If the [`arrayed`] flag in
+ /// `image`'s type is `true`, then this must be `Some(expr)`, where
+ /// `expr` is a [`Sint`] scalar. Otherwise, it must be `None`.
+ ///
+ /// [`arrayed`]: TypeInner::Image::arrayed
+ /// [`Sint`]: ScalarKind::Sint
+ array_index: Option<Handle<Expression>>,
+
+ /// A sample index, for multisampled [`Sampled`] and [`Depth`] images.
+ ///
+ /// [`Sampled`]: ImageClass::Sampled
+ /// [`Depth`]: ImageClass::Depth
+ sample: Option<Handle<Expression>>,
+
+ /// A level of detail, for mipmapped images.
+ ///
+ /// This must be present when accessing non-multisampled
+ /// [`Sampled`] and [`Depth`] images, even if only the
+ /// full-resolution level is present (in which case the only
+ /// valid level is zero).
+ ///
+ /// [`Sampled`]: ImageClass::Sampled
+ /// [`Depth`]: ImageClass::Depth
+ level: Option<Handle<Expression>>,
+ },
+
+ /// Query information from an image.
+ ImageQuery {
+ image: Handle<Expression>,
+ query: ImageQuery,
+ },
+ /// Apply an unary operator.
+ Unary {
+ op: UnaryOperator,
+ expr: Handle<Expression>,
+ },
+ /// Apply a binary operator.
+ Binary {
+ op: BinaryOperator,
+ left: Handle<Expression>,
+ right: Handle<Expression>,
+ },
+ /// Select between two values based on a condition.
+ ///
+ /// Note that, because expressions have no side effects, it is unobservable
+ /// whether the non-selected branch is evaluated.
+ Select {
+ /// Boolean expression
+ condition: Handle<Expression>,
+ accept: Handle<Expression>,
+ reject: Handle<Expression>,
+ },
+ /// Compute the derivative on an axis.
+ Derivative {
+ axis: DerivativeAxis,
+ ctrl: DerivativeControl,
+ expr: Handle<Expression>,
+ },
+ /// Call a relational function.
+ Relational {
+ fun: RelationalFunction,
+ argument: Handle<Expression>,
+ },
+ /// Call a math function
+ Math {
+ fun: MathFunction,
+ arg: Handle<Expression>,
+ arg1: Option<Handle<Expression>>,
+ arg2: Option<Handle<Expression>>,
+ arg3: Option<Handle<Expression>>,
+ },
+ /// Cast a simple type to another kind.
+ As {
+ /// Source expression, which can only be a scalar or a vector.
+ expr: Handle<Expression>,
+ /// Target scalar kind.
+ kind: ScalarKind,
+ /// If provided, converts to the specified byte width.
+ /// Otherwise, bitcast.
+ convert: Option<Bytes>,
+ },
+ /// Result of calling another function.
+ CallResult(Handle<Function>),
+ /// Result of an atomic operation.
+ AtomicResult { ty: Handle<Type>, comparison: bool },
+ /// Result of a [`WorkGroupUniformLoad`] statement.
+ ///
+ /// [`WorkGroupUniformLoad`]: Statement::WorkGroupUniformLoad
+ WorkGroupUniformLoadResult {
+ /// The type of the result
+ ty: Handle<Type>,
+ },
+ /// Get the length of an array.
+ /// The expression must resolve to a pointer to an array with a dynamic size.
+ ///
+ /// This doesn't match the semantics of spirv's `OpArrayLength`, which must be passed
+ /// a pointer to a structure containing a runtime array in its' last field.
+ ArrayLength(Handle<Expression>),
+
+ /// Result of a [`Proceed`] [`RayQuery`] statement.
+ ///
+ /// [`Proceed`]: RayQueryFunction::Proceed
+ /// [`RayQuery`]: Statement::RayQuery
+ RayQueryProceedResult,
+
+ /// Return an intersection found by `query`.
+ ///
+ /// If `committed` is true, return the committed result available when
+ RayQueryGetIntersection {
+ query: Handle<Expression>,
+ committed: bool,
+ },
+}
+
+pub use block::Block;
+
+/// The value of the switch case.
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum SwitchValue {
+ I32(i32),
+ U32(u32),
+ Default,
+}
+
+/// A case for a switch statement.
+// Clone is used only for error reporting and is not intended for end users
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct SwitchCase {
+ /// Value, upon which the case is considered true.
+ pub value: SwitchValue,
+ /// Body of the case.
+ pub body: Block,
+ /// If true, the control flow continues to the next case in the list,
+ /// or default.
+ pub fall_through: bool,
+}
+
+/// An operation that a [`RayQuery` statement] applies to its [`query`] operand.
+///
+/// [`RayQuery` statement]: Statement::RayQuery
+/// [`query`]: Statement::RayQuery::query
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum RayQueryFunction {
+ /// Initialize the `RayQuery` object.
+ Initialize {
+ /// The acceleration structure within which this query should search for hits.
+ ///
+ /// The expression must be an [`AccelerationStructure`].
+ ///
+ /// [`AccelerationStructure`]: TypeInner::AccelerationStructure
+ acceleration_structure: Handle<Expression>,
+
+ #[allow(rustdoc::private_intra_doc_links)]
+ /// A struct of detailed parameters for the ray query.
+ ///
+ /// This expression should have the struct type given in
+ /// [`SpecialTypes::ray_desc`]. This is available in the WGSL
+ /// front end as the `RayDesc` type.
+ descriptor: Handle<Expression>,
+ },
+
+ /// Start or continue the query given by the statement's [`query`] operand.
+ ///
+ /// After executing this statement, the `result` expression is a
+ /// [`Bool`] scalar indicating whether there are more intersection
+ /// candidates to consider.
+ ///
+ /// [`query`]: Statement::RayQuery::query
+ /// [`Bool`]: ScalarKind::Bool
+ Proceed {
+ result: Handle<Expression>,
+ },
+
+ Terminate,
+}
+
+//TODO: consider removing `Clone`. It's not valid to clone `Statement::Emit` anyway.
+/// Instructions which make up an executable block.
+// Clone is used only for error reporting and is not intended for end users
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum Statement {
+ /// Emit a range of expressions, visible to all statements that follow in this block.
+ ///
+ /// See the [module-level documentation][emit] for details.
+ ///
+ /// [emit]: index.html#expression-evaluation-time
+ Emit(Range<Expression>),
+ /// A block containing more statements, to be executed sequentially.
+ Block(Block),
+ /// Conditionally executes one of two blocks, based on the value of the condition.
+ If {
+ condition: Handle<Expression>, //bool
+ accept: Block,
+ reject: Block,
+ },
+ /// Conditionally executes one of multiple blocks, based on the value of the selector.
+ ///
+ /// Each case must have a distinct [`value`], exactly one of which must be
+ /// [`Default`]. The `Default` may appear at any position, and covers all
+ /// values not explicitly appearing in other cases. A `Default` appearing in
+ /// the midst of the list of cases does not shadow the cases that follow.
+ ///
+ /// Some backend languages don't support fallthrough (HLSL due to FXC,
+ /// WGSL), and may translate fallthrough cases in the IR by duplicating
+ /// code. However, all backend languages do support cases selected by
+ /// multiple values, like `case 1: case 2: case 3: { ... }`. This is
+ /// represented in the IR as a series of fallthrough cases with empty
+ /// bodies, except for the last.
+ ///
+ /// [`value`]: SwitchCase::value
+ /// [`body`]: SwitchCase::body
+ /// [`Default`]: SwitchValue::Default
+ Switch {
+ selector: Handle<Expression>,
+ cases: Vec<SwitchCase>,
+ },
+
+ /// Executes a block repeatedly.
+ ///
+ /// Each iteration of the loop executes the `body` block, followed by the
+ /// `continuing` block.
+ ///
+ /// Executing a [`Break`], [`Return`] or [`Kill`] statement exits the loop.
+ ///
+ /// A [`Continue`] statement in `body` jumps to the `continuing` block. The
+ /// `continuing` block is meant to be used to represent structures like the
+ /// third expression of a C-style `for` loop head, to which `continue`
+ /// statements in the loop's body jump.
+ ///
+ /// The `continuing` block and its substatements must not contain `Return`
+ /// or `Kill` statements, or any `Break` or `Continue` statements targeting
+ /// this loop. (It may have `Break` and `Continue` statements targeting
+ /// loops or switches nested within the `continuing` block.) Expressions
+ /// emitted in `body` are in scope in `continuing`.
+ ///
+ /// If present, `break_if` is an expression which is evaluated after the
+ /// continuing block. Expressions emitted in `body` or `continuing` are
+ /// considered to be in scope. If the expression's value is true, control
+ /// continues after the `Loop` statement, rather than branching back to the
+ /// top of body as usual. The `break_if` expression corresponds to a "break
+ /// if" statement in WGSL, or a loop whose back edge is an
+ /// `OpBranchConditional` instruction in SPIR-V.
+ ///
+ /// [`Break`]: Statement::Break
+ /// [`Continue`]: Statement::Continue
+ /// [`Kill`]: Statement::Kill
+ /// [`Return`]: Statement::Return
+ /// [`break if`]: Self::Loop::break_if
+ Loop {
+ body: Block,
+ continuing: Block,
+ break_if: Option<Handle<Expression>>,
+ },
+
+ /// Exits the innermost enclosing [`Loop`] or [`Switch`].
+ ///
+ /// A `Break` statement may only appear within a [`Loop`] or [`Switch`]
+ /// statement. It may not break out of a [`Loop`] from within the loop's
+ /// `continuing` block.
+ ///
+ /// [`Loop`]: Statement::Loop
+ /// [`Switch`]: Statement::Switch
+ Break,
+
+ /// Skips to the `continuing` block of the innermost enclosing [`Loop`].
+ ///
+ /// A `Continue` statement may only appear within the `body` block of the
+ /// innermost enclosing [`Loop`] statement. It must not appear within that
+ /// loop's `continuing` block.
+ ///
+ /// [`Loop`]: Statement::Loop
+ Continue,
+
+ /// Returns from the function (possibly with a value).
+ ///
+ /// `Return` statements are forbidden within the `continuing` block of a
+ /// [`Loop`] statement.
+ ///
+ /// [`Loop`]: Statement::Loop
+ Return { value: Option<Handle<Expression>> },
+
+ /// Aborts the current shader execution.
+ ///
+ /// `Kill` statements are forbidden within the `continuing` block of a
+ /// [`Loop`] statement.
+ ///
+ /// [`Loop`]: Statement::Loop
+ Kill,
+
+ /// Synchronize invocations within the work group.
+ /// The `Barrier` flags control which memory accesses should be synchronized.
+ /// If empty, this becomes purely an execution barrier.
+ Barrier(Barrier),
+ /// Stores a value at an address.
+ ///
+ /// For [`TypeInner::Atomic`] type behind the pointer, the value
+ /// has to be a corresponding scalar.
+ /// For other types behind the `pointer<T>`, the value is `T`.
+ ///
+ /// This statement is a barrier for any operations on the
+ /// `Expression::LocalVariable` or `Expression::GlobalVariable`
+ /// that is the destination of an access chain, started
+ /// from the `pointer`.
+ Store {
+ pointer: Handle<Expression>,
+ value: Handle<Expression>,
+ },
+ /// Stores a texel value to an image.
+ ///
+ /// The `image`, `coordinate`, and `array_index` fields have the same
+ /// meanings as the corresponding operands of an [`ImageLoad`] expression;
+ /// see that documentation for details. Storing into multisampled images or
+ /// images with mipmaps is not supported, so there are no `level` or
+ /// `sample` operands.
+ ///
+ /// This statement is a barrier for any operations on the corresponding
+ /// [`Expression::GlobalVariable`] for this image.
+ ///
+ /// [`ImageLoad`]: Expression::ImageLoad
+ ImageStore {
+ image: Handle<Expression>,
+ coordinate: Handle<Expression>,
+ array_index: Option<Handle<Expression>>,
+ value: Handle<Expression>,
+ },
+ /// Atomic function.
+ Atomic {
+ /// Pointer to an atomic value.
+ pointer: Handle<Expression>,
+ /// Function to run on the atomic.
+ fun: AtomicFunction,
+ /// Value to use in the function.
+ value: Handle<Expression>,
+ /// [`AtomicResult`] expression representing this function's result.
+ ///
+ /// [`AtomicResult`]: crate::Expression::AtomicResult
+ result: Handle<Expression>,
+ },
+ /// Load uniformly from a uniform pointer in the workgroup address space.
+ ///
+ /// Corresponds to the [`workgroupUniformLoad`](https://www.w3.org/TR/WGSL/#workgroupUniformLoad-builtin)
+ /// built-in function of wgsl, and has the same barrier semantics
+ WorkGroupUniformLoad {
+ /// This must be of type [`Pointer`] in the [`WorkGroup`] address space
+ ///
+ /// [`Pointer`]: TypeInner::Pointer
+ /// [`WorkGroup`]: AddressSpace::WorkGroup
+ pointer: Handle<Expression>,
+ /// The [`WorkGroupUniformLoadResult`] expression representing this load's result.
+ ///
+ /// [`WorkGroupUniformLoadResult`]: Expression::WorkGroupUniformLoadResult
+ result: Handle<Expression>,
+ },
+ /// Calls a function.
+ ///
+ /// If the `result` is `Some`, the corresponding expression has to be
+ /// `Expression::CallResult`, and this statement serves as a barrier for any
+ /// operations on that expression.
+ Call {
+ function: Handle<Function>,
+ arguments: Vec<Handle<Expression>>,
+ result: Option<Handle<Expression>>,
+ },
+ RayQuery {
+ /// The [`RayQuery`] object this statement operates on.
+ ///
+ /// [`RayQuery`]: TypeInner::RayQuery
+ query: Handle<Expression>,
+
+ /// The specific operation we're performing on `query`.
+ fun: RayQueryFunction,
+ },
+}
+
+/// A function argument.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct FunctionArgument {
+ /// Name of the argument, if any.
+ pub name: Option<String>,
+ /// Type of the argument.
+ pub ty: Handle<Type>,
+ /// For entry points, an argument has to have a binding
+ /// unless it's a structure.
+ pub binding: Option<Binding>,
+}
+
+/// A function result.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct FunctionResult {
+ /// Type of the result.
+ pub ty: Handle<Type>,
+ /// For entry points, the result has to have a binding
+ /// unless it's a structure.
+ pub binding: Option<Binding>,
+}
+
+/// A function defined in the module.
+#[derive(Debug, Default)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct Function {
+ /// Name of the function, if any.
+ pub name: Option<String>,
+ /// Information about function argument.
+ pub arguments: Vec<FunctionArgument>,
+ /// The result of this function, if any.
+ pub result: Option<FunctionResult>,
+ /// Local variables defined and used in the function.
+ pub local_variables: Arena<LocalVariable>,
+ /// Expressions used inside this function.
+ ///
+ /// An `Expression` must occur before all other `Expression`s that use its
+ /// value.
+ pub expressions: Arena<Expression>,
+ /// Map of expressions that have associated variable names
+ pub named_expressions: NamedExpressions,
+ /// Block of instructions comprising the body of the function.
+ pub body: Block,
+}
+
+/// The main function for a pipeline stage.
+///
+/// An [`EntryPoint`] is a [`Function`] that serves as the main function for a
+/// graphics or compute pipeline stage. For example, an `EntryPoint` whose
+/// [`stage`] is [`ShaderStage::Vertex`] can serve as a graphics pipeline's
+/// vertex shader.
+///
+/// Since an entry point is called directly by the graphics or compute pipeline,
+/// not by other WGSL functions, you must specify what the pipeline should pass
+/// as the entry point's arguments, and what values it will return. For example,
+/// a vertex shader needs a vertex's attributes as its arguments, but if it's
+/// used for instanced draw calls, it will also want to know the instance id.
+/// The vertex shader's return value will usually include an output vertex
+/// position, and possibly other attributes to be interpolated and passed along
+/// to a fragment shader.
+///
+/// To specify this, the arguments and result of an `EntryPoint`'s [`function`]
+/// must each have a [`Binding`], or be structs whose members all have
+/// `Binding`s. This associates every value passed to or returned from the entry
+/// point with either a [`BuiltIn`] or a [`Location`]:
+///
+/// - A [`BuiltIn`] has special semantics, usually specific to its pipeline
+/// stage. For example, the result of a vertex shader can include a
+/// [`BuiltIn::Position`] value, which determines the position of a vertex
+/// of a rendered primitive. Or, a compute shader might take an argument
+/// whose binding is [`BuiltIn::WorkGroupSize`], through which the compute
+/// pipeline would pass the number of invocations in your workgroup.
+///
+/// - A [`Location`] indicates user-defined IO to be passed from one pipeline
+/// stage to the next. For example, a vertex shader might also produce a
+/// `uv` texture location as a user-defined IO value.
+///
+/// In other words, the pipeline stage's input and output interface are
+/// determined by the bindings of the arguments and result of the `EntryPoint`'s
+/// [`function`].
+///
+/// [`Function`]: crate::Function
+/// [`Location`]: Binding::Location
+/// [`function`]: EntryPoint::function
+/// [`stage`]: EntryPoint::stage
+#[derive(Debug)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct EntryPoint {
+ /// Name of this entry point, visible externally.
+ ///
+ /// Entry point names for a given `stage` must be distinct within a module.
+ pub name: String,
+ /// Shader stage.
+ pub stage: ShaderStage,
+ /// Early depth test for fragment stages.
+ pub early_depth_test: Option<EarlyDepthTest>,
+ /// Workgroup size for compute stages
+ pub workgroup_size: [u32; 3],
+ /// The entrance function.
+ pub function: Function,
+}
+
+/// Return types predeclared for the frexp, modf, and atomicCompareExchangeWeak built-in functions.
+///
+/// These cannot be spelled in WGSL source.
+///
+/// Stored in [`SpecialTypes::predeclared_types`] and created by [`Module::generate_predeclared_type`].
+#[derive(Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum PredeclaredType {
+ AtomicCompareExchangeWeakResult(Scalar),
+ ModfResult {
+ size: Option<VectorSize>,
+ width: Bytes,
+ },
+ FrexpResult {
+ size: Option<VectorSize>,
+ width: Bytes,
+ },
+}
+
+/// Set of special types that can be optionally generated by the frontends.
+#[derive(Debug, Default)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct SpecialTypes {
+ /// Type for `RayDesc`.
+ ///
+ /// Call [`Module::generate_ray_desc_type`] to populate this if
+ /// needed and return the handle.
+ pub ray_desc: Option<Handle<Type>>,
+
+ /// Type for `RayIntersection`.
+ ///
+ /// Call [`Module::generate_ray_intersection_type`] to populate
+ /// this if needed and return the handle.
+ pub ray_intersection: Option<Handle<Type>>,
+
+ /// Types for predeclared wgsl types instantiated on demand.
+ ///
+ /// Call [`Module::generate_predeclared_type`] to populate this if
+ /// needed and return the handle.
+ pub predeclared_types: FastIndexMap<PredeclaredType, Handle<Type>>,
+}
+
+/// Shader module.
+///
+/// A module is a set of constants, global variables and functions, as well as
+/// the types required to define them.
+///
+/// Some functions are marked as entry points, to be used in a certain shader stage.
+///
+/// To create a new module, use the `Default` implementation.
+/// Alternatively, you can load an existing shader using one of the [available front ends][front].
+///
+/// When finished, you can export modules using one of the [available backends][back].
+#[derive(Debug, Default)]
+#[cfg_attr(feature = "clone", derive(Clone))]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub struct Module {
+ /// Arena for the types defined in this module.
+ pub types: UniqueArena<Type>,
+ /// Dictionary of special type handles.
+ pub special_types: SpecialTypes,
+ /// Arena for the constants defined in this module.
+ pub constants: Arena<Constant>,
+ /// Arena for the global variables defined in this module.
+ pub global_variables: Arena<GlobalVariable>,
+ /// [Constant expressions] and [override expressions] used by this module.
+ ///
+ /// Each `Expression` must occur in the arena before any
+ /// `Expression` that uses its value.
+ ///
+ /// [Constant expressions]: index.html#constant-expressions
+ /// [override expressions]: index.html#override-expressions
+ pub const_expressions: Arena<Expression>,
+ /// Arena for the functions defined in this module.
+ ///
+ /// Each function must appear in this arena strictly before all its callers.
+ /// Recursion is not supported.
+ pub functions: Arena<Function>,
+ /// Entry points.
+ pub entry_points: Vec<EntryPoint>,
+}
diff --git a/third_party/rust/naga/src/proc/constant_evaluator.rs b/third_party/rust/naga/src/proc/constant_evaluator.rs
new file mode 100644
index 0000000000..b3884b04b1
--- /dev/null
+++ b/third_party/rust/naga/src/proc/constant_evaluator.rs
@@ -0,0 +1,2475 @@
+use std::iter;
+
+use arrayvec::ArrayVec;
+
+use crate::{
+ arena::{Arena, Handle, UniqueArena},
+ ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner,
+ UnaryOperator,
+};
+
+/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
+/// `macro_rules!` items that, in turn, emit their own `macro_rules!` items.
+///
+/// Technique stolen directly from
+/// <https://github.com/rust-lang/rust/issues/35853#issuecomment-415993963>.
+macro_rules! with_dollar_sign {
+ ($($body:tt)*) => {
+ macro_rules! __with_dollar_sign { $($body)* }
+ __with_dollar_sign!($);
+ }
+}
+
+macro_rules! gen_component_wise_extractor {
+ (
+ $ident:ident -> $target:ident,
+ literals: [$( $literal:ident => $mapping:ident: $ty:ident ),+ $(,)?],
+ scalar_kinds: [$( $scalar_kind:ident ),* $(,)?],
+ ) => {
+ /// A subset of [`Literal`]s intended to be used for implementing numeric built-ins.
+ enum $target<const N: usize> {
+ $(
+ #[doc = concat!(
+ "Maps to [`Literal::",
+ stringify!($mapping),
+ "`]",
+ )]
+ $mapping([$ty; N]),
+ )+
+ }
+
+ impl From<$target<1>> for Expression {
+ fn from(value: $target<1>) -> Self {
+ match value {
+ $(
+ $target::$mapping([value]) => {
+ Expression::Literal(Literal::$literal(value))
+ }
+ )+
+ }
+ }
+ }
+
+ #[doc = concat!(
+ "Attempts to evaluate multiple `exprs` as a combined [`",
+ stringify!($target),
+ "`] to pass to `handler`. ",
+ )]
+ /// If `exprs` are vectors of the same length, `handler` is called for each corresponding
+ /// component of each vector.
+ ///
+ /// `handler`'s output is registered as a new expression. If `exprs` are vectors of the
+ /// same length, a new vector expression is registered, composed of each component emitted
+ /// by `handler`.
+ fn $ident<const N: usize, const M: usize, F>(
+ eval: &mut ConstantEvaluator<'_>,
+ span: Span,
+ exprs: [Handle<Expression>; N],
+ mut handler: F,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError>
+ where
+ $target<M>: Into<Expression>,
+ F: FnMut($target<N>) -> Result<$target<M>, ConstantEvaluatorError> + Clone,
+ {
+ assert!(N > 0);
+ let err = ConstantEvaluatorError::InvalidMathArg;
+ let mut exprs = exprs.into_iter();
+
+ macro_rules! sanitize {
+ ($expr:expr) => {
+ eval.eval_zero_value_and_splat($expr, span)
+ .map(|expr| &eval.expressions[expr])
+ };
+ }
+
+ let new_expr = match sanitize!(exprs.next().unwrap())? {
+ $(
+ &Expression::Literal(Literal::$literal(x)) => iter::once(Ok(x))
+ .chain(exprs.map(|expr| {
+ sanitize!(expr).and_then(|expr| match expr {
+ &Expression::Literal(Literal::$literal(x)) => Ok(x),
+ _ => Err(err.clone()),
+ })
+ }))
+ .collect::<Result<ArrayVec<_, N>, _>>()
+ .map(|a| a.into_inner().unwrap())
+ .map($target::$mapping)
+ .and_then(|comps| Ok(handler(comps)?.into())),
+ )+
+ &Expression::Compose { ty, ref components } => match &eval.types[ty].inner {
+ &TypeInner::Vector { size, scalar } => match scalar.kind {
+ $(ScalarKind::$scalar_kind)|* => {
+ let first_ty = ty;
+ let mut component_groups =
+ ArrayVec::<ArrayVec<_, { crate::VectorSize::MAX }>, N>::new();
+ component_groups.push(crate::proc::flatten_compose(
+ first_ty,
+ components,
+ eval.expressions,
+ eval.types,
+ ).collect());
+ component_groups.extend(
+ exprs
+ .map(|expr| {
+ sanitize!(expr).and_then(|expr| match expr {
+ &Expression::Compose { ty, ref components }
+ if &eval.types[ty].inner
+ == &eval.types[first_ty].inner =>
+ {
+ Ok(crate::proc::flatten_compose(
+ ty,
+ components,
+ eval.expressions,
+ eval.types,
+ ).collect())
+ }
+ _ => Err(err.clone()),
+ })
+ })
+ .collect::<Result<ArrayVec<_, { crate::VectorSize::MAX }>, _>>(
+ )?,
+ );
+ let component_groups = component_groups.into_inner().unwrap();
+ let mut new_components =
+ ArrayVec::<_, { crate::VectorSize::MAX }>::new();
+ for idx in 0..(size as u8).into() {
+ let group = component_groups
+ .iter()
+ .map(|cs| cs[idx])
+ .collect::<ArrayVec<_, N>>()
+ .into_inner()
+ .unwrap();
+ new_components.push($ident(
+ eval,
+ span,
+ group,
+ handler.clone(),
+ )?);
+ }
+ Ok(Expression::Compose {
+ ty: first_ty,
+ components: new_components.into_iter().collect(),
+ })
+ }
+ _ => return Err(err),
+ },
+ _ => return Err(err),
+ },
+ _ => return Err(err),
+ }?;
+ eval.register_evaluated_expr(new_expr, span)
+ }
+
+ with_dollar_sign! {
+ ($d:tt) => {
+ #[allow(unused)]
+ #[doc = concat!(
+ "A convenience macro for using the same RHS for each [`",
+ stringify!($target),
+ "`] variant in a call to [`",
+ stringify!($ident),
+ "`].",
+ )]
+ macro_rules! $ident {
+ (
+ $eval:expr,
+ $span:expr,
+ [$d ($d expr:expr),+ $d (,)?],
+ |$d ($d arg:ident),+| $d tt:tt
+ ) => {
+ $ident($eval, $span, [$d ($d expr),+], |args| match args {
+ $(
+ $target::$mapping([$d ($d arg),+]) => {
+ let res = $d tt;
+ Result::map(res, $target::$mapping)
+ },
+ )+
+ })
+ };
+ }
+ };
+ }
+ };
+}
+
+gen_component_wise_extractor! {
+ component_wise_scalar -> Scalar,
+ literals: [
+ AbstractFloat => AbstractFloat: f64,
+ F32 => F32: f32,
+ AbstractInt => AbstractInt: i64,
+ U32 => U32: u32,
+ I32 => I32: i32,
+ ],
+ scalar_kinds: [
+ Float,
+ AbstractFloat,
+ Sint,
+ Uint,
+ AbstractInt,
+ ],
+}
+
+gen_component_wise_extractor! {
+ component_wise_float -> Float,
+ literals: [
+ AbstractFloat => Abstract: f64,
+ F32 => F32: f32,
+ ],
+ scalar_kinds: [
+ Float,
+ AbstractFloat,
+ ],
+}
+
+gen_component_wise_extractor! {
+ component_wise_concrete_int -> ConcreteInt,
+ literals: [
+ U32 => U32: u32,
+ I32 => I32: i32,
+ ],
+ scalar_kinds: [
+ Sint,
+ Uint,
+ ],
+}
+
+gen_component_wise_extractor! {
+ component_wise_signed -> Signed,
+ literals: [
+ AbstractFloat => AbstractFloat: f64,
+ AbstractInt => AbstractInt: i64,
+ F32 => F32: f32,
+ I32 => I32: i32,
+ ],
+ scalar_kinds: [
+ Sint,
+ AbstractInt,
+ Float,
+ AbstractFloat,
+ ],
+}
+
+#[derive(Debug)]
+enum Behavior {
+ Wgsl,
+ Glsl,
+}
+
+/// A context for evaluating constant expressions.
+///
+/// A `ConstantEvaluator` points at an expression arena to which it can append
+/// newly evaluated expressions: you pass [`try_eval_and_append`] whatever kind
+/// of Naga [`Expression`] you like, and if its value can be computed at compile
+/// time, `try_eval_and_append` appends an expression representing the computed
+/// value - a tree of [`Literal`], [`Compose`], [`ZeroValue`], and [`Swizzle`]
+/// expressions - to the arena. See the [`try_eval_and_append`] method for details.
+///
+/// A `ConstantEvaluator` also holds whatever information we need to carry out
+/// that evaluation: types, other constants, and so on.
+///
+/// [`try_eval_and_append`]: ConstantEvaluator::try_eval_and_append
+/// [`Compose`]: Expression::Compose
+/// [`ZeroValue`]: Expression::ZeroValue
+/// [`Literal`]: Expression::Literal
+/// [`Swizzle`]: Expression::Swizzle
+#[derive(Debug)]
+pub struct ConstantEvaluator<'a> {
+ /// Which language's evaluation rules we should follow.
+ behavior: Behavior,
+
+ /// The module's type arena.
+ ///
+ /// Because expressions like [`Splat`] contain type handles, we need to be
+ /// able to add new types to produce those expressions.
+ ///
+ /// [`Splat`]: Expression::Splat
+ types: &'a mut UniqueArena<Type>,
+
+ /// The module's constant arena.
+ constants: &'a Arena<Constant>,
+
+ /// The arena to which we are contributing expressions.
+ expressions: &'a mut Arena<Expression>,
+
+ /// When `self.expressions` refers to a function's local expression
+ /// arena, this needs to be populated
+ function_local_data: Option<FunctionLocalData<'a>>,
+}
+
+#[derive(Debug)]
+struct FunctionLocalData<'a> {
+ /// Global constant expressions
+ const_expressions: &'a Arena<Expression>,
+ /// Tracks the constness of expressions residing in `ConstantEvaluator.expressions`
+ expression_constness: &'a mut ExpressionConstnessTracker,
+ emitter: &'a mut super::Emitter,
+ block: &'a mut crate::Block,
+}
+
+#[derive(Debug)]
+pub struct ExpressionConstnessTracker {
+ inner: bit_set::BitSet,
+}
+
+impl ExpressionConstnessTracker {
+ pub fn new() -> Self {
+ Self {
+ inner: bit_set::BitSet::new(),
+ }
+ }
+
+ /// Forces the the expression to not be const
+ pub fn force_non_const(&mut self, value: Handle<Expression>) {
+ self.inner.remove(value.index());
+ }
+
+ fn insert(&mut self, value: Handle<Expression>) {
+ self.inner.insert(value.index());
+ }
+
+ pub fn is_const(&self, value: Handle<Expression>) -> bool {
+ self.inner.contains(value.index())
+ }
+
+ pub fn from_arena(arena: &Arena<Expression>) -> Self {
+ let mut tracker = Self::new();
+ for (handle, expr) in arena.iter() {
+ let insert = match *expr {
+ crate::Expression::Literal(_)
+ | crate::Expression::ZeroValue(_)
+ | crate::Expression::Constant(_) => true,
+ crate::Expression::Compose { ref components, .. } => {
+ components.iter().all(|h| tracker.is_const(*h))
+ }
+ crate::Expression::Splat { value, .. } => tracker.is_const(value),
+ _ => false,
+ };
+ if insert {
+ tracker.insert(handle);
+ }
+ }
+ tracker
+ }
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum ConstantEvaluatorError {
+ #[error("Constants cannot access function arguments")]
+ FunctionArg,
+ #[error("Constants cannot access global variables")]
+ GlobalVariable,
+ #[error("Constants cannot access local variables")]
+ LocalVariable,
+ #[error("Cannot get the array length of a non array type")]
+ InvalidArrayLengthArg,
+ #[error("Constants cannot get the array length of a dynamically sized array")]
+ ArrayLengthDynamic,
+ #[error("Constants cannot call functions")]
+ Call,
+ #[error("Constants don't support workGroupUniformLoad")]
+ WorkGroupUniformLoadResult,
+ #[error("Constants don't support atomic functions")]
+ Atomic,
+ #[error("Constants don't support derivative functions")]
+ Derivative,
+ #[error("Constants don't support load expressions")]
+ Load,
+ #[error("Constants don't support image expressions")]
+ ImageExpression,
+ #[error("Constants don't support ray query expressions")]
+ RayQueryExpression,
+ #[error("Cannot access the type")]
+ InvalidAccessBase,
+ #[error("Cannot access at the index")]
+ InvalidAccessIndex,
+ #[error("Cannot access with index of type")]
+ InvalidAccessIndexTy,
+ #[error("Constants don't support array length expressions")]
+ ArrayLength,
+ #[error("Cannot cast scalar components of expression `{from}` to type `{to}`")]
+ InvalidCastArg { from: String, to: String },
+ #[error("Cannot apply the unary op to the argument")]
+ InvalidUnaryOpArg,
+ #[error("Cannot apply the binary op to the arguments")]
+ InvalidBinaryOpArgs,
+ #[error("Cannot apply math function to type")]
+ InvalidMathArg,
+ #[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
+ InvalidMathArgCount(crate::MathFunction, usize, usize),
+ #[error("value of `low` is greater than `high` for clamp built-in function")]
+ InvalidClamp,
+ #[error("Splat is defined only on scalar values")]
+ SplatScalarOnly,
+ #[error("Can only swizzle vector constants")]
+ SwizzleVectorOnly,
+ #[error("swizzle component not present in source expression")]
+ SwizzleOutOfBounds,
+ #[error("Type is not constructible")]
+ TypeNotConstructible,
+ #[error("Subexpression(s) are not constant")]
+ SubexpressionsAreNotConstant,
+ #[error("Not implemented as constant expression: {0}")]
+ NotImplemented(String),
+ #[error("{0} operation overflowed")]
+ Overflow(String),
+ #[error(
+ "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately"
+ )]
+ AutomaticConversionLossy {
+ value: String,
+ to_type: &'static str,
+ },
+ #[error("abstract floating-point values cannot be automatically converted to integers")]
+ AutomaticConversionFloatToInt { to_type: &'static str },
+ #[error("Division by zero")]
+ DivisionByZero,
+ #[error("Remainder by zero")]
+ RemainderByZero,
+ #[error("RHS of shift operation is greater than or equal to 32")]
+ ShiftedMoreThan32Bits,
+ #[error(transparent)]
+ Literal(#[from] crate::valid::LiteralError),
+}
+
+impl<'a> ConstantEvaluator<'a> {
+ /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
+ /// constant expression arena.
+ ///
+ /// Report errors according to WGSL's rules for constant evaluation.
+ pub fn for_wgsl_module(module: &'a mut crate::Module) -> Self {
+ Self::for_module(Behavior::Wgsl, module)
+ }
+
+ /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
+ /// constant expression arena.
+ ///
+ /// Report errors according to GLSL's rules for constant evaluation.
+ pub fn for_glsl_module(module: &'a mut crate::Module) -> Self {
+ Self::for_module(Behavior::Glsl, module)
+ }
+
+ fn for_module(behavior: Behavior, module: &'a mut crate::Module) -> Self {
+ Self {
+ behavior,
+ types: &mut module.types,
+ constants: &module.constants,
+ expressions: &mut module.const_expressions,
+ function_local_data: None,
+ }
+ }
+
+ /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
+ /// expression arena.
+ ///
+ /// Report errors according to WGSL's rules for constant evaluation.
+ pub fn for_wgsl_function(
+ module: &'a mut crate::Module,
+ expressions: &'a mut Arena<Expression>,
+ expression_constness: &'a mut ExpressionConstnessTracker,
+ emitter: &'a mut super::Emitter,
+ block: &'a mut crate::Block,
+ ) -> Self {
+ Self::for_function(
+ Behavior::Wgsl,
+ module,
+ expressions,
+ expression_constness,
+ emitter,
+ block,
+ )
+ }
+
+ /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
+ /// expression arena.
+ ///
+ /// Report errors according to GLSL's rules for constant evaluation.
+ pub fn for_glsl_function(
+ module: &'a mut crate::Module,
+ expressions: &'a mut Arena<Expression>,
+ expression_constness: &'a mut ExpressionConstnessTracker,
+ emitter: &'a mut super::Emitter,
+ block: &'a mut crate::Block,
+ ) -> Self {
+ Self::for_function(
+ Behavior::Glsl,
+ module,
+ expressions,
+ expression_constness,
+ emitter,
+ block,
+ )
+ }
+
+ fn for_function(
+ behavior: Behavior,
+ module: &'a mut crate::Module,
+ expressions: &'a mut Arena<Expression>,
+ expression_constness: &'a mut ExpressionConstnessTracker,
+ emitter: &'a mut super::Emitter,
+ block: &'a mut crate::Block,
+ ) -> Self {
+ Self {
+ behavior,
+ types: &mut module.types,
+ constants: &module.constants,
+ expressions,
+ function_local_data: Some(FunctionLocalData {
+ const_expressions: &module.const_expressions,
+ expression_constness,
+ emitter,
+ block,
+ }),
+ }
+ }
+
+ pub fn to_ctx(&self) -> crate::proc::GlobalCtx {
+ crate::proc::GlobalCtx {
+ types: self.types,
+ constants: self.constants,
+ const_expressions: match self.function_local_data {
+ Some(ref data) => data.const_expressions,
+ None => self.expressions,
+ },
+ }
+ }
+
+ fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
+ if let Some(ref function_local_data) = self.function_local_data {
+ if !function_local_data.expression_constness.is_const(expr) {
+ log::debug!("check: SubexpressionsAreNotConstant");
+ return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
+ }
+ }
+ Ok(())
+ }
+
+ fn check_and_get(
+ &mut self,
+ expr: Handle<Expression>,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ match self.expressions[expr] {
+ Expression::Constant(c) => {
+ // Are we working in a function's expression arena, or the
+ // module's constant expression arena?
+ if let Some(ref function_local_data) = self.function_local_data {
+ // Deep-copy the constant's value into our arena.
+ self.copy_from(
+ self.constants[c].init,
+ function_local_data.const_expressions,
+ )
+ } else {
+ // "See through" the constant and use its initializer.
+ Ok(self.constants[c].init)
+ }
+ }
+ _ => {
+ self.check(expr)?;
+ Ok(expr)
+ }
+ }
+ }
+
+ /// Try to evaluate `expr` at compile time.
+ ///
+ /// The `expr` argument can be any sort of Naga [`Expression`] you like. If
+ /// we can determine its value at compile time, we append an expression
+ /// representing its value - a tree of [`Literal`], [`Compose`],
+ /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
+ /// `self` contributes to.
+ ///
+ /// If `expr`'s value cannot be determined at compile time, return a an
+ /// error. If it's acceptable to evaluate `expr` at runtime, this error can
+ /// be ignored, and the caller can append `expr` to the arena itself.
+ ///
+ /// We only consider `expr` itself, without recursing into its operands. Its
+ /// operands must all have been produced by prior calls to
+ /// `try_eval_and_append`, to ensure that they have already been reduced to
+ /// an evaluated form if possible.
+ ///
+ /// [`Literal`]: Expression::Literal
+ /// [`Compose`]: Expression::Compose
+ /// [`ZeroValue`]: Expression::ZeroValue
+ /// [`Swizzle`]: Expression::Swizzle
+ pub fn try_eval_and_append(
+ &mut self,
+ expr: &Expression,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ log::trace!("try_eval_and_append: {:?}", expr);
+ match *expr {
+ Expression::Constant(c) if self.function_local_data.is_none() => {
+ // "See through" the constant and use its initializer.
+ // This is mainly done to avoid having constants pointing to other constants.
+ Ok(self.constants[c].init)
+ }
+ Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
+ self.register_evaluated_expr(expr.clone(), span)
+ }
+ Expression::Compose { ty, ref components } => {
+ let components = components
+ .iter()
+ .map(|component| self.check_and_get(*component))
+ .collect::<Result<Vec<_>, _>>()?;
+ self.register_evaluated_expr(Expression::Compose { ty, components }, span)
+ }
+ Expression::Splat { size, value } => {
+ let value = self.check_and_get(value)?;
+ self.register_evaluated_expr(Expression::Splat { size, value }, span)
+ }
+ Expression::AccessIndex { base, index } => {
+ let base = self.check_and_get(base)?;
+
+ self.access(base, index as usize, span)
+ }
+ Expression::Access { base, index } => {
+ let base = self.check_and_get(base)?;
+ let index = self.check_and_get(index)?;
+
+ self.access(base, self.constant_index(index)?, span)
+ }
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ let vector = self.check_and_get(vector)?;
+
+ self.swizzle(size, span, vector, pattern)
+ }
+ Expression::Unary { expr, op } => {
+ let expr = self.check_and_get(expr)?;
+
+ self.unary_op(op, expr, span)
+ }
+ Expression::Binary { left, right, op } => {
+ let left = self.check_and_get(left)?;
+ let right = self.check_and_get(right)?;
+
+ self.binary_op(op, left, right, span)
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ let arg = self.check_and_get(arg)?;
+ let arg1 = arg1.map(|arg| self.check_and_get(arg)).transpose()?;
+ let arg2 = arg2.map(|arg| self.check_and_get(arg)).transpose()?;
+ let arg3 = arg3.map(|arg| self.check_and_get(arg)).transpose()?;
+
+ self.math(arg, arg1, arg2, arg3, fun, span)
+ }
+ Expression::As {
+ convert,
+ expr,
+ kind,
+ } => {
+ let expr = self.check_and_get(expr)?;
+
+ match convert {
+ Some(width) => self.cast(expr, crate::Scalar { kind, width }, span),
+ None => Err(ConstantEvaluatorError::NotImplemented(
+ "bitcast built-in function".into(),
+ )),
+ }
+ }
+ Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
+ "select built-in function".into(),
+ )),
+ Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
+ format!("{fun:?} built-in function"),
+ )),
+ Expression::ArrayLength(expr) => match self.behavior {
+ Behavior::Wgsl => Err(ConstantEvaluatorError::ArrayLength),
+ Behavior::Glsl => {
+ let expr = self.check_and_get(expr)?;
+ self.array_length(expr, span)
+ }
+ },
+ Expression::Load { .. } => Err(ConstantEvaluatorError::Load),
+ Expression::LocalVariable(_) => Err(ConstantEvaluatorError::LocalVariable),
+ Expression::Derivative { .. } => Err(ConstantEvaluatorError::Derivative),
+ Expression::CallResult { .. } => Err(ConstantEvaluatorError::Call),
+ Expression::WorkGroupUniformLoadResult { .. } => {
+ Err(ConstantEvaluatorError::WorkGroupUniformLoadResult)
+ }
+ Expression::AtomicResult { .. } => Err(ConstantEvaluatorError::Atomic),
+ Expression::FunctionArgument(_) => Err(ConstantEvaluatorError::FunctionArg),
+ Expression::GlobalVariable(_) => Err(ConstantEvaluatorError::GlobalVariable),
+ Expression::ImageSample { .. }
+ | Expression::ImageLoad { .. }
+ | Expression::ImageQuery { .. } => Err(ConstantEvaluatorError::ImageExpression),
+ Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
+ Err(ConstantEvaluatorError::RayQueryExpression)
+ }
+ }
+ }
+
+ /// Splat `value` to `size`, without using [`Splat`] expressions.
+ ///
+ /// This constructs [`Compose`] or [`ZeroValue`] expressions to
+ /// build a vector with the given `size` whose components are all
+ /// `value`.
+ ///
+ /// Use `span` as the span of the inserted expressions and
+ /// resulting types.
+ ///
+ /// [`Splat`]: Expression::Splat
+ /// [`Compose`]: Expression::Compose
+ /// [`ZeroValue`]: Expression::ZeroValue
+ fn splat(
+ &mut self,
+ value: Handle<Expression>,
+ size: crate::VectorSize,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ match self.expressions[value] {
+ Expression::Literal(literal) => {
+ let scalar = literal.scalar();
+ let ty = self.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector { size, scalar },
+ },
+ span,
+ );
+ let expr = Expression::Compose {
+ ty,
+ components: vec![value; size as usize],
+ };
+ self.register_evaluated_expr(expr, span)
+ }
+ Expression::ZeroValue(ty) => {
+ let inner = match self.types[ty].inner {
+ TypeInner::Scalar(scalar) => TypeInner::Vector { size, scalar },
+ _ => return Err(ConstantEvaluatorError::SplatScalarOnly),
+ };
+ let res_ty = self.types.insert(Type { name: None, inner }, span);
+ let expr = Expression::ZeroValue(res_ty);
+ self.register_evaluated_expr(expr, span)
+ }
+ _ => Err(ConstantEvaluatorError::SplatScalarOnly),
+ }
+ }
+
+ fn swizzle(
+ &mut self,
+ size: crate::VectorSize,
+ span: Span,
+ src_constant: Handle<Expression>,
+ pattern: [crate::SwizzleComponent; 4],
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ let mut get_dst_ty = |ty| match self.types[ty].inner {
+ crate::TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
+ Type {
+ name: None,
+ inner: crate::TypeInner::Vector { size, scalar },
+ },
+ span,
+ )),
+ _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
+ };
+
+ match self.expressions[src_constant] {
+ Expression::ZeroValue(ty) => {
+ let dst_ty = get_dst_ty(ty)?;
+ let expr = Expression::ZeroValue(dst_ty);
+ self.register_evaluated_expr(expr, span)
+ }
+ Expression::Splat { value, .. } => {
+ let expr = Expression::Splat { size, value };
+ self.register_evaluated_expr(expr, span)
+ }
+ Expression::Compose { ty, ref components } => {
+ let dst_ty = get_dst_ty(ty)?;
+
+ let mut flattened = [src_constant; 4]; // dummy value
+ let len =
+ crate::proc::flatten_compose(ty, components, self.expressions, self.types)
+ .zip(flattened.iter_mut())
+ .map(|(component, elt)| *elt = component)
+ .count();
+ let flattened = &flattened[..len];
+
+ let swizzled_components = pattern[..size as usize]
+ .iter()
+ .map(|&sc| {
+ let sc = sc as usize;
+ if let Some(elt) = flattened.get(sc) {
+ Ok(*elt)
+ } else {
+ Err(ConstantEvaluatorError::SwizzleOutOfBounds)
+ }
+ })
+ .collect::<Result<Vec<Handle<Expression>>, _>>()?;
+ let expr = Expression::Compose {
+ ty: dst_ty,
+ components: swizzled_components,
+ };
+ self.register_evaluated_expr(expr, span)
+ }
+ _ => Err(ConstantEvaluatorError::SwizzleVectorOnly),
+ }
+ }
+
+ fn math(
+ &mut self,
+ arg: Handle<Expression>,
+ arg1: Option<Handle<Expression>>,
+ arg2: Option<Handle<Expression>>,
+ arg3: Option<Handle<Expression>>,
+ fun: crate::MathFunction,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ let expected = fun.argument_count();
+ let given = Some(arg)
+ .into_iter()
+ .chain(arg1)
+ .chain(arg2)
+ .chain(arg3)
+ .count();
+ if expected != given {
+ return Err(ConstantEvaluatorError::InvalidMathArgCount(
+ fun, expected, given,
+ ));
+ }
+
+ // NOTE: We try to match the declaration order of `MathFunction` here.
+ match fun {
+ // comparison
+ crate::MathFunction::Abs => {
+ component_wise_scalar(self, span, [arg], |args| match args {
+ Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])),
+ Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])),
+ Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])),
+ Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])),
+ Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz
+ })
+ }
+ crate::MathFunction::Min => {
+ component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
+ Ok([e1.min(e2)])
+ })
+ }
+ crate::MathFunction::Max => {
+ component_wise_scalar!(self, span, [arg, arg1.unwrap()], |e1, e2| {
+ Ok([e1.max(e2)])
+ })
+ }
+ crate::MathFunction::Clamp => {
+ component_wise_scalar!(
+ self,
+ span,
+ [arg, arg1.unwrap(), arg2.unwrap()],
+ |e, low, high| {
+ if low > high {
+ Err(ConstantEvaluatorError::InvalidClamp)
+ } else {
+ Ok([e.clamp(low, high)])
+ }
+ }
+ )
+ }
+ crate::MathFunction::Saturate => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.clamp(0., 1.)]) })
+ }
+
+ // trigonometry
+ crate::MathFunction::Cos => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.cos()]) })
+ }
+ crate::MathFunction::Cosh => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.cosh()]) })
+ }
+ crate::MathFunction::Sin => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.sin()]) })
+ }
+ crate::MathFunction::Sinh => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.sinh()]) })
+ }
+ crate::MathFunction::Tan => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.tan()]) })
+ }
+ crate::MathFunction::Tanh => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.tanh()]) })
+ }
+ crate::MathFunction::Acos => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.acos()]) })
+ }
+ crate::MathFunction::Asin => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.asin()]) })
+ }
+ crate::MathFunction::Atan => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.atan()]) })
+ }
+ crate::MathFunction::Asinh => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.asinh()]) })
+ }
+ crate::MathFunction::Acosh => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.acosh()]) })
+ }
+ crate::MathFunction::Atanh => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.atanh()]) })
+ }
+ crate::MathFunction::Radians => {
+ component_wise_float!(self, span, [arg], |e1| { Ok([e1.to_radians()]) })
+ }
+ crate::MathFunction::Degrees => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.to_degrees()]) })
+ }
+
+ // decomposition
+ crate::MathFunction::Ceil => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.ceil()]) })
+ }
+ crate::MathFunction::Floor => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.floor()]) })
+ }
+ crate::MathFunction::Round => {
+ // TODO: Use `f{32,64}.round_ties_even()` when available on stable. This polyfill
+ // is shamelessly [~~stolen from~~ inspired by `ndarray-image`][polyfill source],
+ // which has licensing compatible with ours. See also
+ // <https://github.com/rust-lang/rust/issues/96710>.
+ //
+ // [polyfill source]: https://github.com/imeka/ndarray-ndimage/blob/8b14b4d6ecfbc96a8a052f802e342a7049c68d8f/src/lib.rs#L98
+ fn round_ties_even(x: f64) -> f64 {
+ let i = x as i64;
+ let f = (x - i as f64).abs();
+ if f == 0.5 {
+ if i & 1 == 1 {
+ // -1.5, 1.5, 3.5, ...
+ (x.abs() + 0.5).copysign(x)
+ } else {
+ (x.abs() - 0.5).copysign(x)
+ }
+ } else {
+ x.round()
+ }
+ }
+ component_wise_float(self, span, [arg], |e| match e {
+ Float::Abstract([e]) => Ok(Float::Abstract([round_ties_even(e)])),
+ Float::F32([e]) => Ok(Float::F32([(round_ties_even(e as f64) as f32)])),
+ })
+ }
+ crate::MathFunction::Fract => {
+ component_wise_float!(self, span, [arg], |e| {
+ // N.B., Rust's definition of `fract` is `e - e.trunc()`, so we can't use that
+ // here.
+ Ok([e - e.floor()])
+ })
+ }
+ crate::MathFunction::Trunc => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.trunc()]) })
+ }
+
+ // exponent
+ crate::MathFunction::Exp => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.exp()]) })
+ }
+ crate::MathFunction::Exp2 => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.exp2()]) })
+ }
+ crate::MathFunction::Log => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.ln()]) })
+ }
+ crate::MathFunction::Log2 => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.log2()]) })
+ }
+ crate::MathFunction::Pow => {
+ component_wise_float!(self, span, [arg, arg1.unwrap()], |e1, e2| {
+ Ok([e1.powf(e2)])
+ })
+ }
+
+ // computational
+ crate::MathFunction::Sign => {
+ component_wise_signed!(self, span, [arg], |e| { Ok([e.signum()]) })
+ }
+ crate::MathFunction::Fma => {
+ component_wise_float!(
+ self,
+ span,
+ [arg, arg1.unwrap(), arg2.unwrap()],
+ |e1, e2, e3| { Ok([e1.mul_add(e2, e3)]) }
+ )
+ }
+ crate::MathFunction::Step => {
+ component_wise_float!(self, span, [arg, arg1.unwrap()], |edge, x| {
+ Ok([if edge <= x { 1.0 } else { 0.0 }])
+ })
+ }
+ crate::MathFunction::Sqrt => {
+ component_wise_float!(self, span, [arg], |e| { Ok([e.sqrt()]) })
+ }
+ crate::MathFunction::InverseSqrt => {
+ component_wise_float!(self, span, [arg], |e| { Ok([1. / e.sqrt()]) })
+ }
+
+ // bits
+ crate::MathFunction::CountTrailingZeros => {
+ component_wise_concrete_int!(self, span, [arg], |e| {
+ #[allow(clippy::useless_conversion)]
+ Ok([e
+ .trailing_zeros()
+ .try_into()
+ .expect("bit count overflowed 32 bits, somehow!?")])
+ })
+ }
+ crate::MathFunction::CountLeadingZeros => {
+ component_wise_concrete_int!(self, span, [arg], |e| {
+ #[allow(clippy::useless_conversion)]
+ Ok([e
+ .leading_zeros()
+ .try_into()
+ .expect("bit count overflowed 32 bits, somehow!?")])
+ })
+ }
+ crate::MathFunction::CountOneBits => {
+ component_wise_concrete_int!(self, span, [arg], |e| {
+ #[allow(clippy::useless_conversion)]
+ Ok([e
+ .count_ones()
+ .try_into()
+ .expect("bit count overflowed 32 bits, somehow!?")])
+ })
+ }
+ crate::MathFunction::ReverseBits => {
+ component_wise_concrete_int!(self, span, [arg], |e| { Ok([e.reverse_bits()]) })
+ }
+
+ fun => Err(ConstantEvaluatorError::NotImplemented(format!(
+ "{fun:?} built-in function"
+ ))),
+ }
+ }
+
+ fn array_length(
+ &mut self,
+ array: Handle<Expression>,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ match self.expressions[array] {
+ Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
+ match self.types[ty].inner {
+ TypeInner::Array { size, .. } => match size {
+ crate::ArraySize::Constant(len) => {
+ let expr = Expression::Literal(Literal::U32(len.get()));
+ self.register_evaluated_expr(expr, span)
+ }
+ crate::ArraySize::Dynamic => {
+ Err(ConstantEvaluatorError::ArrayLengthDynamic)
+ }
+ },
+ _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
+ }
+ }
+ _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
+ }
+ }
+
+ fn access(
+ &mut self,
+ base: Handle<Expression>,
+ index: usize,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ match self.expressions[base] {
+ Expression::ZeroValue(ty) => {
+ let ty_inner = &self.types[ty].inner;
+ let components = ty_inner
+ .components()
+ .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
+
+ if index >= components as usize {
+ Err(ConstantEvaluatorError::InvalidAccessBase)
+ } else {
+ let ty_res = ty_inner
+ .component_type(index)
+ .ok_or(ConstantEvaluatorError::InvalidAccessIndex)?;
+ let ty = match ty_res {
+ crate::proc::TypeResolution::Handle(ty) => ty,
+ crate::proc::TypeResolution::Value(inner) => {
+ self.types.insert(Type { name: None, inner }, span)
+ }
+ };
+ self.register_evaluated_expr(Expression::ZeroValue(ty), span)
+ }
+ }
+ Expression::Splat { size, value } => {
+ if index >= size as usize {
+ Err(ConstantEvaluatorError::InvalidAccessBase)
+ } else {
+ Ok(value)
+ }
+ }
+ Expression::Compose { ty, ref components } => {
+ let _ = self.types[ty]
+ .inner
+ .components()
+ .ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
+
+ crate::proc::flatten_compose(ty, components, self.expressions, self.types)
+ .nth(index)
+ .ok_or(ConstantEvaluatorError::InvalidAccessIndex)
+ }
+ _ => Err(ConstantEvaluatorError::InvalidAccessBase),
+ }
+ }
+
+ fn constant_index(&self, expr: Handle<Expression>) -> Result<usize, ConstantEvaluatorError> {
+ match self.expressions[expr] {
+ Expression::ZeroValue(ty)
+ if matches!(
+ self.types[ty].inner,
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: ScalarKind::Uint,
+ ..
+ })
+ ) =>
+ {
+ Ok(0)
+ }
+ Expression::Literal(Literal::U32(index)) => Ok(index as usize),
+ _ => Err(ConstantEvaluatorError::InvalidAccessIndexTy),
+ }
+ }
+
+ /// Lower [`ZeroValue`] and [`Splat`] expressions to [`Literal`] and [`Compose`] expressions.
+ ///
+ /// [`ZeroValue`]: Expression::ZeroValue
+ /// [`Splat`]: Expression::Splat
+ /// [`Literal`]: Expression::Literal
+ /// [`Compose`]: Expression::Compose
+ fn eval_zero_value_and_splat(
+ &mut self,
+ expr: Handle<Expression>,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ match self.expressions[expr] {
+ Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
+ Expression::Splat { size, value } => self.splat(value, size, span),
+ _ => Ok(expr),
+ }
+ }
+
+ /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
+ ///
+ /// [`ZeroValue`]: Expression::ZeroValue
+ /// [`Literal`]: Expression::Literal
+ /// [`Compose`]: Expression::Compose
+ fn eval_zero_value(
+ &mut self,
+ expr: Handle<Expression>,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ match self.expressions[expr] {
+ Expression::ZeroValue(ty) => self.eval_zero_value_impl(ty, span),
+ _ => Ok(expr),
+ }
+ }
+
+ /// Lower [`ZeroValue`] expressions to [`Literal`] and [`Compose`] expressions.
+ ///
+ /// [`ZeroValue`]: Expression::ZeroValue
+ /// [`Literal`]: Expression::Literal
+ /// [`Compose`]: Expression::Compose
+ fn eval_zero_value_impl(
+ &mut self,
+ ty: Handle<Type>,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ match self.types[ty].inner {
+ TypeInner::Scalar(scalar) => {
+ let expr = Expression::Literal(
+ Literal::zero(scalar).ok_or(ConstantEvaluatorError::TypeNotConstructible)?,
+ );
+ self.register_evaluated_expr(expr, span)
+ }
+ TypeInner::Vector { size, scalar } => {
+ let scalar_ty = self.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar(scalar),
+ },
+ span,
+ );
+ let el = self.eval_zero_value_impl(scalar_ty, span)?;
+ let expr = Expression::Compose {
+ ty,
+ components: vec![el; size as usize],
+ };
+ self.register_evaluated_expr(expr, span)
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ let vec_ty = self.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector { size: rows, scalar },
+ },
+ span,
+ );
+ let el = self.eval_zero_value_impl(vec_ty, span)?;
+ let expr = Expression::Compose {
+ ty,
+ components: vec![el; columns as usize],
+ };
+ self.register_evaluated_expr(expr, span)
+ }
+ TypeInner::Array {
+ base,
+ size: ArraySize::Constant(size),
+ ..
+ } => {
+ let el = self.eval_zero_value_impl(base, span)?;
+ let expr = Expression::Compose {
+ ty,
+ components: vec![el; size.get() as usize],
+ };
+ self.register_evaluated_expr(expr, span)
+ }
+ TypeInner::Struct { ref members, .. } => {
+ let types: Vec<_> = members.iter().map(|m| m.ty).collect();
+ let mut components = Vec::with_capacity(members.len());
+ for ty in types {
+ components.push(self.eval_zero_value_impl(ty, span)?);
+ }
+ let expr = Expression::Compose { ty, components };
+ self.register_evaluated_expr(expr, span)
+ }
+ _ => Err(ConstantEvaluatorError::TypeNotConstructible),
+ }
+ }
+
+ /// Convert the scalar components of `expr` to `target`.
+ ///
+ /// Treat `span` as the location of the resulting expression.
+ pub fn cast(
+ &mut self,
+ expr: Handle<Expression>,
+ target: crate::Scalar,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ use crate::Scalar as Sc;
+
+ let expr = self.eval_zero_value(expr, span)?;
+
+ let make_error = || -> Result<_, ConstantEvaluatorError> {
+ let from = format!("{:?} {:?}", expr, self.expressions[expr]);
+
+ #[cfg(feature = "wgsl-in")]
+ let to = target.to_wgsl();
+
+ #[cfg(not(feature = "wgsl-in"))]
+ let to = format!("{target:?}");
+
+ Err(ConstantEvaluatorError::InvalidCastArg { from, to })
+ };
+
+ let expr = match self.expressions[expr] {
+ Expression::Literal(literal) => {
+ let literal = match target {
+ Sc::I32 => Literal::I32(match literal {
+ Literal::I32(v) => v,
+ Literal::U32(v) => v as i32,
+ Literal::F32(v) => v as i32,
+ Literal::Bool(v) => v as i32,
+ Literal::F64(_) | Literal::I64(_) => {
+ return make_error();
+ }
+ Literal::AbstractInt(v) => i32::try_from_abstract(v)?,
+ Literal::AbstractFloat(v) => i32::try_from_abstract(v)?,
+ }),
+ Sc::U32 => Literal::U32(match literal {
+ Literal::I32(v) => v as u32,
+ Literal::U32(v) => v,
+ Literal::F32(v) => v as u32,
+ Literal::Bool(v) => v as u32,
+ Literal::F64(_) | Literal::I64(_) => {
+ return make_error();
+ }
+ Literal::AbstractInt(v) => u32::try_from_abstract(v)?,
+ Literal::AbstractFloat(v) => u32::try_from_abstract(v)?,
+ }),
+ Sc::F32 => Literal::F32(match literal {
+ Literal::I32(v) => v as f32,
+ Literal::U32(v) => v as f32,
+ Literal::F32(v) => v,
+ Literal::Bool(v) => v as u32 as f32,
+ Literal::F64(_) | Literal::I64(_) => {
+ return make_error();
+ }
+ Literal::AbstractInt(v) => f32::try_from_abstract(v)?,
+ Literal::AbstractFloat(v) => f32::try_from_abstract(v)?,
+ }),
+ Sc::F64 => Literal::F64(match literal {
+ Literal::I32(v) => v as f64,
+ Literal::U32(v) => v as f64,
+ Literal::F32(v) => v as f64,
+ Literal::F64(v) => v,
+ Literal::Bool(v) => v as u32 as f64,
+ Literal::I64(_) => return make_error(),
+ Literal::AbstractInt(v) => f64::try_from_abstract(v)?,
+ Literal::AbstractFloat(v) => f64::try_from_abstract(v)?,
+ }),
+ Sc::BOOL => Literal::Bool(match literal {
+ Literal::I32(v) => v != 0,
+ Literal::U32(v) => v != 0,
+ Literal::F32(v) => v != 0.0,
+ Literal::Bool(v) => v,
+ Literal::F64(_)
+ | Literal::I64(_)
+ | Literal::AbstractInt(_)
+ | Literal::AbstractFloat(_) => {
+ return make_error();
+ }
+ }),
+ Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal {
+ Literal::AbstractInt(v) => {
+ // Overflow is forbidden, but inexact conversions
+ // are fine. The range of f64 is far larger than
+ // that of i64, so we don't have to check anything
+ // here.
+ v as f64
+ }
+ Literal::AbstractFloat(v) => v,
+ _ => return make_error(),
+ }),
+ _ => {
+ log::debug!("Constant evaluator refused to convert value to {target:?}");
+ return make_error();
+ }
+ };
+ Expression::Literal(literal)
+ }
+ Expression::Compose {
+ ty,
+ components: ref src_components,
+ } => {
+ let ty_inner = match self.types[ty].inner {
+ TypeInner::Vector { size, .. } => TypeInner::Vector {
+ size,
+ scalar: target,
+ },
+ TypeInner::Matrix { columns, rows, .. } => TypeInner::Matrix {
+ columns,
+ rows,
+ scalar: target,
+ },
+ _ => return make_error(),
+ };
+
+ let mut components = src_components.clone();
+ for component in &mut components {
+ *component = self.cast(*component, target, span)?;
+ }
+
+ let ty = self.types.insert(
+ Type {
+ name: None,
+ inner: ty_inner,
+ },
+ span,
+ );
+
+ Expression::Compose { ty, components }
+ }
+ Expression::Splat { size, value } => {
+ let value_span = self.expressions.get_span(value);
+ let cast_value = self.cast(value, target, value_span)?;
+ Expression::Splat {
+ size,
+ value: cast_value,
+ }
+ }
+ _ => return make_error(),
+ };
+
+ self.register_evaluated_expr(expr, span)
+ }
+
+ /// Convert the scalar leaves of `expr` to `target`, handling arrays.
+ ///
+ /// `expr` must be a `Compose` expression whose type is a scalar, vector,
+ /// matrix, or nested arrays of such.
+ ///
+ /// This is basically the same as the [`cast`] method, except that that
+ /// should only handle Naga [`As`] expressions, which cannot convert arrays.
+ ///
+ /// Treat `span` as the location of the resulting expression.
+ ///
+ /// [`cast`]: ConstantEvaluator::cast
+ /// [`As`]: crate::Expression::As
+ pub fn cast_array(
+ &mut self,
+ expr: Handle<Expression>,
+ target: crate::Scalar,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ let Expression::Compose { ty, ref components } = self.expressions[expr] else {
+ return self.cast(expr, target, span);
+ };
+
+ let crate::TypeInner::Array {
+ base: _,
+ size,
+ stride: _,
+ } = self.types[ty].inner
+ else {
+ return self.cast(expr, target, span);
+ };
+
+ let mut components = components.clone();
+ for component in &mut components {
+ *component = self.cast_array(*component, target, span)?;
+ }
+
+ let first = components.first().unwrap();
+ let new_base = match self.resolve_type(*first)? {
+ crate::proc::TypeResolution::Handle(ty) => ty,
+ crate::proc::TypeResolution::Value(inner) => {
+ self.types.insert(Type { name: None, inner }, span)
+ }
+ };
+ let new_base_stride = self.types[new_base].inner.size(self.to_ctx());
+ let new_array_ty = self.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Array {
+ base: new_base,
+ size,
+ stride: new_base_stride,
+ },
+ },
+ span,
+ );
+
+ let compose = Expression::Compose {
+ ty: new_array_ty,
+ components,
+ };
+ self.register_evaluated_expr(compose, span)
+ }
+
+ fn unary_op(
+ &mut self,
+ op: UnaryOperator,
+ expr: Handle<Expression>,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ let expr = self.eval_zero_value_and_splat(expr, span)?;
+
+ let expr = match self.expressions[expr] {
+ Expression::Literal(value) => Expression::Literal(match op {
+ UnaryOperator::Negate => match value {
+ Literal::I32(v) => Literal::I32(v.wrapping_neg()),
+ Literal::F32(v) => Literal::F32(-v),
+ Literal::AbstractInt(v) => Literal::AbstractInt(v.wrapping_neg()),
+ Literal::AbstractFloat(v) => Literal::AbstractFloat(-v),
+ _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
+ },
+ UnaryOperator::LogicalNot => match value {
+ Literal::Bool(v) => Literal::Bool(!v),
+ _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
+ },
+ UnaryOperator::BitwiseNot => match value {
+ Literal::I32(v) => Literal::I32(!v),
+ Literal::U32(v) => Literal::U32(!v),
+ Literal::AbstractInt(v) => Literal::AbstractInt(!v),
+ _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
+ },
+ }),
+ Expression::Compose {
+ ty,
+ components: ref src_components,
+ } => {
+ match self.types[ty].inner {
+ TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
+ _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
+ }
+
+ let mut components = src_components.clone();
+ for component in &mut components {
+ *component = self.unary_op(op, *component, span)?;
+ }
+
+ Expression::Compose { ty, components }
+ }
+ _ => return Err(ConstantEvaluatorError::InvalidUnaryOpArg),
+ };
+
+ self.register_evaluated_expr(expr, span)
+ }
+
+ fn binary_op(
+ &mut self,
+ op: BinaryOperator,
+ left: Handle<Expression>,
+ right: Handle<Expression>,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ let left = self.eval_zero_value_and_splat(left, span)?;
+ let right = self.eval_zero_value_and_splat(right, span)?;
+
+ let expr = match (&self.expressions[left], &self.expressions[right]) {
+ (&Expression::Literal(left_value), &Expression::Literal(right_value)) => {
+ let literal = match op {
+ BinaryOperator::Equal => Literal::Bool(left_value == right_value),
+ BinaryOperator::NotEqual => Literal::Bool(left_value != right_value),
+ BinaryOperator::Less => Literal::Bool(left_value < right_value),
+ BinaryOperator::LessEqual => Literal::Bool(left_value <= right_value),
+ BinaryOperator::Greater => Literal::Bool(left_value > right_value),
+ BinaryOperator::GreaterEqual => Literal::Bool(left_value >= right_value),
+
+ _ => match (left_value, right_value) {
+ (Literal::I32(a), Literal::I32(b)) => Literal::I32(match op {
+ BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
+ ConstantEvaluatorError::Overflow("addition".into())
+ })?,
+ BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
+ ConstantEvaluatorError::Overflow("subtraction".into())
+ })?,
+ BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
+ ConstantEvaluatorError::Overflow("multiplication".into())
+ })?,
+ BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
+ if b == 0 {
+ ConstantEvaluatorError::DivisionByZero
+ } else {
+ ConstantEvaluatorError::Overflow("division".into())
+ }
+ })?,
+ BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
+ if b == 0 {
+ ConstantEvaluatorError::RemainderByZero
+ } else {
+ ConstantEvaluatorError::Overflow("remainder".into())
+ }
+ })?,
+ BinaryOperator::And => a & b,
+ BinaryOperator::ExclusiveOr => a ^ b,
+ BinaryOperator::InclusiveOr => a | b,
+ _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
+ }),
+ (Literal::I32(a), Literal::U32(b)) => Literal::I32(match op {
+ BinaryOperator::ShiftLeft => a
+ .checked_shl(b)
+ .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
+ BinaryOperator::ShiftRight => a
+ .checked_shr(b)
+ .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
+ _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
+ }),
+ (Literal::U32(a), Literal::U32(b)) => Literal::U32(match op {
+ BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
+ ConstantEvaluatorError::Overflow("addition".into())
+ })?,
+ BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
+ ConstantEvaluatorError::Overflow("subtraction".into())
+ })?,
+ BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
+ ConstantEvaluatorError::Overflow("multiplication".into())
+ })?,
+ BinaryOperator::Divide => a
+ .checked_div(b)
+ .ok_or(ConstantEvaluatorError::DivisionByZero)?,
+ BinaryOperator::Modulo => a
+ .checked_rem(b)
+ .ok_or(ConstantEvaluatorError::RemainderByZero)?,
+ BinaryOperator::And => a & b,
+ BinaryOperator::ExclusiveOr => a ^ b,
+ BinaryOperator::InclusiveOr => a | b,
+ BinaryOperator::ShiftLeft => a
+ .checked_shl(b)
+ .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
+ BinaryOperator::ShiftRight => a
+ .checked_shr(b)
+ .ok_or(ConstantEvaluatorError::ShiftedMoreThan32Bits)?,
+ _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
+ }),
+ (Literal::F32(a), Literal::F32(b)) => Literal::F32(match op {
+ BinaryOperator::Add => a + b,
+ BinaryOperator::Subtract => a - b,
+ BinaryOperator::Multiply => a * b,
+ BinaryOperator::Divide => a / b,
+ BinaryOperator::Modulo => a % b,
+ _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
+ }),
+ (Literal::AbstractInt(a), Literal::AbstractInt(b)) => {
+ Literal::AbstractInt(match op {
+ BinaryOperator::Add => a.checked_add(b).ok_or_else(|| {
+ ConstantEvaluatorError::Overflow("addition".into())
+ })?,
+ BinaryOperator::Subtract => a.checked_sub(b).ok_or_else(|| {
+ ConstantEvaluatorError::Overflow("subtraction".into())
+ })?,
+ BinaryOperator::Multiply => a.checked_mul(b).ok_or_else(|| {
+ ConstantEvaluatorError::Overflow("multiplication".into())
+ })?,
+ BinaryOperator::Divide => a.checked_div(b).ok_or_else(|| {
+ if b == 0 {
+ ConstantEvaluatorError::DivisionByZero
+ } else {
+ ConstantEvaluatorError::Overflow("division".into())
+ }
+ })?,
+ BinaryOperator::Modulo => a.checked_rem(b).ok_or_else(|| {
+ if b == 0 {
+ ConstantEvaluatorError::RemainderByZero
+ } else {
+ ConstantEvaluatorError::Overflow("remainder".into())
+ }
+ })?,
+ BinaryOperator::And => a & b,
+ BinaryOperator::ExclusiveOr => a ^ b,
+ BinaryOperator::InclusiveOr => a | b,
+ _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
+ })
+ }
+ (Literal::AbstractFloat(a), Literal::AbstractFloat(b)) => {
+ Literal::AbstractFloat(match op {
+ BinaryOperator::Add => a + b,
+ BinaryOperator::Subtract => a - b,
+ BinaryOperator::Multiply => a * b,
+ BinaryOperator::Divide => a / b,
+ BinaryOperator::Modulo => a % b,
+ _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
+ })
+ }
+ (Literal::Bool(a), Literal::Bool(b)) => Literal::Bool(match op {
+ BinaryOperator::LogicalAnd => a && b,
+ BinaryOperator::LogicalOr => a || b,
+ _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
+ }),
+ _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
+ },
+ };
+ Expression::Literal(literal)
+ }
+ (
+ &Expression::Compose {
+ components: ref src_components,
+ ty,
+ },
+ &Expression::Literal(_),
+ ) => {
+ let mut components = src_components.clone();
+ for component in &mut components {
+ *component = self.binary_op(op, *component, right, span)?;
+ }
+ Expression::Compose { ty, components }
+ }
+ (
+ &Expression::Literal(_),
+ &Expression::Compose {
+ components: ref src_components,
+ ty,
+ },
+ ) => {
+ let mut components = src_components.clone();
+ for component in &mut components {
+ *component = self.binary_op(op, left, *component, span)?;
+ }
+ Expression::Compose { ty, components }
+ }
+ (
+ &Expression::Compose {
+ components: ref left_components,
+ ty: left_ty,
+ },
+ &Expression::Compose {
+ components: ref right_components,
+ ty: right_ty,
+ },
+ ) => {
+ // We have to make a copy of the component lists, because the
+ // call to `binary_op_vector` needs `&mut self`, but `self` owns
+ // the component lists.
+ let left_flattened = crate::proc::flatten_compose(
+ left_ty,
+ left_components,
+ self.expressions,
+ self.types,
+ );
+ let right_flattened = crate::proc::flatten_compose(
+ right_ty,
+ right_components,
+ self.expressions,
+ self.types,
+ );
+
+ // `flatten_compose` doesn't return an `ExactSizeIterator`, so
+ // make a reasonable guess of the capacity we'll need.
+ let mut flattened = Vec::with_capacity(left_components.len());
+ flattened.extend(left_flattened.zip(right_flattened));
+
+ match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
+ (
+ &TypeInner::Vector {
+ size: left_size, ..
+ },
+ &TypeInner::Vector {
+ size: right_size, ..
+ },
+ ) if left_size == right_size => {
+ self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
+ }
+ _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
+ }
+ }
+ _ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
+ };
+
+ self.register_evaluated_expr(expr, span)
+ }
+
+ fn binary_op_vector(
+ &mut self,
+ op: BinaryOperator,
+ size: crate::VectorSize,
+ components: &[(Handle<Expression>, Handle<Expression>)],
+ left_ty: Handle<Type>,
+ span: Span,
+ ) -> Result<Expression, ConstantEvaluatorError> {
+ let ty = match op {
+ // Relational operators produce vectors of booleans.
+ BinaryOperator::Equal
+ | BinaryOperator::NotEqual
+ | BinaryOperator::Less
+ | BinaryOperator::LessEqual
+ | BinaryOperator::Greater
+ | BinaryOperator::GreaterEqual => self.types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size,
+ scalar: crate::Scalar::BOOL,
+ },
+ },
+ span,
+ ),
+
+ // Other operators produce the same type as their left
+ // operand.
+ BinaryOperator::Add
+ | BinaryOperator::Subtract
+ | BinaryOperator::Multiply
+ | BinaryOperator::Divide
+ | BinaryOperator::Modulo
+ | BinaryOperator::And
+ | BinaryOperator::ExclusiveOr
+ | BinaryOperator::InclusiveOr
+ | BinaryOperator::LogicalAnd
+ | BinaryOperator::LogicalOr
+ | BinaryOperator::ShiftLeft
+ | BinaryOperator::ShiftRight => left_ty,
+ };
+
+ let components = components
+ .iter()
+ .map(|&(left, right)| self.binary_op(op, left, right, span))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ Ok(Expression::Compose { ty, components })
+ }
+
+ /// Deep copy `expr` from `expressions` into `self.expressions`.
+ ///
+ /// Return the root of the new copy.
+ ///
+ /// This is used when we're evaluating expressions in a function's
+ /// expression arena that refer to a constant: we need to copy the
+ /// constant's value into the function's arena so we can operate on it.
+ fn copy_from(
+ &mut self,
+ expr: Handle<Expression>,
+ expressions: &Arena<Expression>,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ let span = expressions.get_span(expr);
+ match expressions[expr] {
+ ref expr @ (Expression::Literal(_)
+ | Expression::Constant(_)
+ | Expression::ZeroValue(_)) => self.register_evaluated_expr(expr.clone(), span),
+ Expression::Compose { ty, ref components } => {
+ let mut components = components.clone();
+ for component in &mut components {
+ *component = self.copy_from(*component, expressions)?;
+ }
+ self.register_evaluated_expr(Expression::Compose { ty, components }, span)
+ }
+ Expression::Splat { size, value } => {
+ let value = self.copy_from(value, expressions)?;
+ self.register_evaluated_expr(Expression::Splat { size, value }, span)
+ }
+ _ => {
+ log::debug!("copy_from: SubexpressionsAreNotConstant");
+ Err(ConstantEvaluatorError::SubexpressionsAreNotConstant)
+ }
+ }
+ }
+
+ fn register_evaluated_expr(
+ &mut self,
+ expr: Expression,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ // It suffices to only check literals, since we only register one
+ // expression at a time, `Compose` expressions can only refer to other
+ // expressions, and `ZeroValue` expressions are always okay.
+ if let Expression::Literal(literal) = expr {
+ crate::valid::check_literal_value(literal)?;
+ }
+
+ if let Some(FunctionLocalData {
+ ref mut emitter,
+ ref mut block,
+ ref mut expression_constness,
+ ..
+ }) = self.function_local_data
+ {
+ let is_running = emitter.is_running();
+ let needs_pre_emit = expr.needs_pre_emit();
+ if is_running && needs_pre_emit {
+ block.extend(emitter.finish(self.expressions));
+ let h = self.expressions.append(expr, span);
+ emitter.start(self.expressions);
+ expression_constness.insert(h);
+ Ok(h)
+ } else {
+ let h = self.expressions.append(expr, span);
+ expression_constness.insert(h);
+ Ok(h)
+ }
+ } else {
+ Ok(self.expressions.append(expr, span))
+ }
+ }
+
+ fn resolve_type(
+ &self,
+ expr: Handle<Expression>,
+ ) -> Result<crate::proc::TypeResolution, ConstantEvaluatorError> {
+ use crate::proc::TypeResolution as Tr;
+ use crate::Expression as Ex;
+ let resolution = match self.expressions[expr] {
+ Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()),
+ Ex::Constant(c) => Tr::Handle(self.constants[c].ty),
+ Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty),
+ Ex::Splat { size, value } => {
+ let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else {
+ return Err(ConstantEvaluatorError::SplatScalarOnly);
+ };
+ Tr::Value(TypeInner::Vector { scalar, size })
+ }
+ _ => {
+ log::debug!("resolve_type: SubexpressionsAreNotConstant");
+ return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
+ }
+ };
+
+ Ok(resolution)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::vec;
+
+ use crate::{
+ Arena, Constant, Expression, Literal, ScalarKind, Type, TypeInner, UnaryOperator,
+ UniqueArena, VectorSize,
+ };
+
+ use super::{Behavior, ConstantEvaluator};
+
+ #[test]
+ fn unary_op() {
+ let mut types = UniqueArena::new();
+ let mut constants = Arena::new();
+ let mut const_expressions = Arena::new();
+
+ let scalar_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar(crate::Scalar::I32),
+ },
+ Default::default(),
+ );
+
+ let vec_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: VectorSize::Bi,
+ scalar: crate::Scalar::I32,
+ },
+ },
+ Default::default(),
+ );
+
+ let h = constants.append(
+ Constant {
+ name: None,
+ r#override: crate::Override::None,
+ ty: scalar_ty,
+ init: const_expressions
+ .append(Expression::Literal(Literal::I32(4)), Default::default()),
+ },
+ Default::default(),
+ );
+
+ let h1 = constants.append(
+ Constant {
+ name: None,
+ r#override: crate::Override::None,
+ ty: scalar_ty,
+ init: const_expressions
+ .append(Expression::Literal(Literal::I32(8)), Default::default()),
+ },
+ Default::default(),
+ );
+
+ let vec_h = constants.append(
+ Constant {
+ name: None,
+ r#override: crate::Override::None,
+ ty: vec_ty,
+ init: const_expressions.append(
+ Expression::Compose {
+ ty: vec_ty,
+ components: vec![constants[h].init, constants[h1].init],
+ },
+ Default::default(),
+ ),
+ },
+ Default::default(),
+ );
+
+ let expr = const_expressions.append(Expression::Constant(h), Default::default());
+ let expr1 = const_expressions.append(Expression::Constant(vec_h), Default::default());
+
+ let expr2 = Expression::Unary {
+ op: UnaryOperator::Negate,
+ expr,
+ };
+
+ let expr3 = Expression::Unary {
+ op: UnaryOperator::BitwiseNot,
+ expr,
+ };
+
+ let expr4 = Expression::Unary {
+ op: UnaryOperator::BitwiseNot,
+ expr: expr1,
+ };
+
+ let mut solver = ConstantEvaluator {
+ behavior: Behavior::Wgsl,
+ types: &mut types,
+ constants: &constants,
+ expressions: &mut const_expressions,
+ function_local_data: None,
+ };
+
+ let res1 = solver
+ .try_eval_and_append(&expr2, Default::default())
+ .unwrap();
+ let res2 = solver
+ .try_eval_and_append(&expr3, Default::default())
+ .unwrap();
+ let res3 = solver
+ .try_eval_and_append(&expr4, Default::default())
+ .unwrap();
+
+ assert_eq!(
+ const_expressions[res1],
+ Expression::Literal(Literal::I32(-4))
+ );
+
+ assert_eq!(
+ const_expressions[res2],
+ Expression::Literal(Literal::I32(!4))
+ );
+
+ let res3_inner = &const_expressions[res3];
+
+ match *res3_inner {
+ Expression::Compose {
+ ref ty,
+ ref components,
+ } => {
+ assert_eq!(*ty, vec_ty);
+ let mut components_iter = components.iter().copied();
+ assert_eq!(
+ const_expressions[components_iter.next().unwrap()],
+ Expression::Literal(Literal::I32(!4))
+ );
+ assert_eq!(
+ const_expressions[components_iter.next().unwrap()],
+ Expression::Literal(Literal::I32(!8))
+ );
+ assert!(components_iter.next().is_none());
+ }
+ _ => panic!("Expected vector"),
+ }
+ }
+
+ #[test]
+ fn cast() {
+ let mut types = UniqueArena::new();
+ let mut constants = Arena::new();
+ let mut const_expressions = Arena::new();
+
+ let scalar_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar(crate::Scalar::I32),
+ },
+ Default::default(),
+ );
+
+ let h = constants.append(
+ Constant {
+ name: None,
+ r#override: crate::Override::None,
+ ty: scalar_ty,
+ init: const_expressions
+ .append(Expression::Literal(Literal::I32(4)), Default::default()),
+ },
+ Default::default(),
+ );
+
+ let expr = const_expressions.append(Expression::Constant(h), Default::default());
+
+ let root = Expression::As {
+ expr,
+ kind: ScalarKind::Bool,
+ convert: Some(crate::BOOL_WIDTH),
+ };
+
+ let mut solver = ConstantEvaluator {
+ behavior: Behavior::Wgsl,
+ types: &mut types,
+ constants: &constants,
+ expressions: &mut const_expressions,
+ function_local_data: None,
+ };
+
+ let res = solver
+ .try_eval_and_append(&root, Default::default())
+ .unwrap();
+
+ assert_eq!(
+ const_expressions[res],
+ Expression::Literal(Literal::Bool(true))
+ );
+ }
+
+ #[test]
+ fn access() {
+ let mut types = UniqueArena::new();
+ let mut constants = Arena::new();
+ let mut const_expressions = Arena::new();
+
+ let matrix_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns: VectorSize::Bi,
+ rows: VectorSize::Tri,
+ scalar: crate::Scalar::F32,
+ },
+ },
+ Default::default(),
+ );
+
+ let vec_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: VectorSize::Tri,
+ scalar: crate::Scalar::F32,
+ },
+ },
+ Default::default(),
+ );
+
+ let mut vec1_components = Vec::with_capacity(3);
+ let mut vec2_components = Vec::with_capacity(3);
+
+ for i in 0..3 {
+ let h = const_expressions.append(
+ Expression::Literal(Literal::F32(i as f32)),
+ Default::default(),
+ );
+
+ vec1_components.push(h)
+ }
+
+ for i in 3..6 {
+ let h = const_expressions.append(
+ Expression::Literal(Literal::F32(i as f32)),
+ Default::default(),
+ );
+
+ vec2_components.push(h)
+ }
+
+ let vec1 = constants.append(
+ Constant {
+ name: None,
+ r#override: crate::Override::None,
+ ty: vec_ty,
+ init: const_expressions.append(
+ Expression::Compose {
+ ty: vec_ty,
+ components: vec1_components,
+ },
+ Default::default(),
+ ),
+ },
+ Default::default(),
+ );
+
+ let vec2 = constants.append(
+ Constant {
+ name: None,
+ r#override: crate::Override::None,
+ ty: vec_ty,
+ init: const_expressions.append(
+ Expression::Compose {
+ ty: vec_ty,
+ components: vec2_components,
+ },
+ Default::default(),
+ ),
+ },
+ Default::default(),
+ );
+
+ let h = constants.append(
+ Constant {
+ name: None,
+ r#override: crate::Override::None,
+ ty: matrix_ty,
+ init: const_expressions.append(
+ Expression::Compose {
+ ty: matrix_ty,
+ components: vec![constants[vec1].init, constants[vec2].init],
+ },
+ Default::default(),
+ ),
+ },
+ Default::default(),
+ );
+
+ let base = const_expressions.append(Expression::Constant(h), Default::default());
+
+ let mut solver = ConstantEvaluator {
+ behavior: Behavior::Wgsl,
+ types: &mut types,
+ constants: &constants,
+ expressions: &mut const_expressions,
+ function_local_data: None,
+ };
+
+ let root1 = Expression::AccessIndex { base, index: 1 };
+
+ let res1 = solver
+ .try_eval_and_append(&root1, Default::default())
+ .unwrap();
+
+ let root2 = Expression::AccessIndex {
+ base: res1,
+ index: 2,
+ };
+
+ let res2 = solver
+ .try_eval_and_append(&root2, Default::default())
+ .unwrap();
+
+ match const_expressions[res1] {
+ Expression::Compose {
+ ref ty,
+ ref components,
+ } => {
+ assert_eq!(*ty, vec_ty);
+ let mut components_iter = components.iter().copied();
+ assert_eq!(
+ const_expressions[components_iter.next().unwrap()],
+ Expression::Literal(Literal::F32(3.))
+ );
+ assert_eq!(
+ const_expressions[components_iter.next().unwrap()],
+ Expression::Literal(Literal::F32(4.))
+ );
+ assert_eq!(
+ const_expressions[components_iter.next().unwrap()],
+ Expression::Literal(Literal::F32(5.))
+ );
+ assert!(components_iter.next().is_none());
+ }
+ _ => panic!("Expected vector"),
+ }
+
+ assert_eq!(
+ const_expressions[res2],
+ Expression::Literal(Literal::F32(5.))
+ );
+ }
+
+ #[test]
+ fn compose_of_constants() {
+ let mut types = UniqueArena::new();
+ let mut constants = Arena::new();
+ let mut const_expressions = Arena::new();
+
+ let i32_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar(crate::Scalar::I32),
+ },
+ Default::default(),
+ );
+
+ let vec2_i32_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: VectorSize::Bi,
+ scalar: crate::Scalar::I32,
+ },
+ },
+ Default::default(),
+ );
+
+ let h = constants.append(
+ Constant {
+ name: None,
+ r#override: crate::Override::None,
+ ty: i32_ty,
+ init: const_expressions
+ .append(Expression::Literal(Literal::I32(4)), Default::default()),
+ },
+ Default::default(),
+ );
+
+ let h_expr = const_expressions.append(Expression::Constant(h), Default::default());
+
+ let mut solver = ConstantEvaluator {
+ behavior: Behavior::Wgsl,
+ types: &mut types,
+ constants: &constants,
+ expressions: &mut const_expressions,
+ function_local_data: None,
+ };
+
+ let solved_compose = solver
+ .try_eval_and_append(
+ &Expression::Compose {
+ ty: vec2_i32_ty,
+ components: vec![h_expr, h_expr],
+ },
+ Default::default(),
+ )
+ .unwrap();
+ let solved_negate = solver
+ .try_eval_and_append(
+ &Expression::Unary {
+ op: UnaryOperator::Negate,
+ expr: solved_compose,
+ },
+ Default::default(),
+ )
+ .unwrap();
+
+ let pass = match const_expressions[solved_negate] {
+ Expression::Compose { ty, ref components } => {
+ ty == vec2_i32_ty
+ && components.iter().all(|&component| {
+ let component = &const_expressions[component];
+ matches!(*component, Expression::Literal(Literal::I32(-4)))
+ })
+ }
+ _ => false,
+ };
+ if !pass {
+ panic!("unexpected evaluation result")
+ }
+ }
+
+ #[test]
+ fn splat_of_constant() {
+ let mut types = UniqueArena::new();
+ let mut constants = Arena::new();
+ let mut const_expressions = Arena::new();
+
+ let i32_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar(crate::Scalar::I32),
+ },
+ Default::default(),
+ );
+
+ let vec2_i32_ty = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: VectorSize::Bi,
+ scalar: crate::Scalar::I32,
+ },
+ },
+ Default::default(),
+ );
+
+ let h = constants.append(
+ Constant {
+ name: None,
+ r#override: crate::Override::None,
+ ty: i32_ty,
+ init: const_expressions
+ .append(Expression::Literal(Literal::I32(4)), Default::default()),
+ },
+ Default::default(),
+ );
+
+ let h_expr = const_expressions.append(Expression::Constant(h), Default::default());
+
+ let mut solver = ConstantEvaluator {
+ behavior: Behavior::Wgsl,
+ types: &mut types,
+ constants: &constants,
+ expressions: &mut const_expressions,
+ function_local_data: None,
+ };
+
+ let solved_compose = solver
+ .try_eval_and_append(
+ &Expression::Splat {
+ size: VectorSize::Bi,
+ value: h_expr,
+ },
+ Default::default(),
+ )
+ .unwrap();
+ let solved_negate = solver
+ .try_eval_and_append(
+ &Expression::Unary {
+ op: UnaryOperator::Negate,
+ expr: solved_compose,
+ },
+ Default::default(),
+ )
+ .unwrap();
+
+ let pass = match const_expressions[solved_negate] {
+ Expression::Compose { ty, ref components } => {
+ ty == vec2_i32_ty
+ && components.iter().all(|&component| {
+ let component = &const_expressions[component];
+ matches!(*component, Expression::Literal(Literal::I32(-4)))
+ })
+ }
+ _ => false,
+ };
+ if !pass {
+ panic!("unexpected evaluation result")
+ }
+ }
+}
+
+/// Trait for conversions of abstract values to concrete types.
+trait TryFromAbstract<T>: Sized {
+ /// Convert an abstract literal `value` to `Self`.
+ ///
+ /// Since Naga's `AbstractInt` and `AbstractFloat` exist to support
+ /// WGSL, we follow WGSL's conversion rules here:
+ ///
+ /// - WGSL §6.1.2. Conversion Rank says that automatic conversions
+ /// to integers are either lossless or an error.
+ ///
+ /// - WGSL §14.6.4 Floating Point Conversion says that conversions
+ /// to floating point in constant expressions and override
+ /// expressions are errors if the value is out of range for the
+ /// destination type, but rounding is okay.
+ ///
+ /// [`AbstractInt`]: crate::Literal::AbstractInt
+ /// [`Float`]: crate::Literal::Float
+ fn try_from_abstract(value: T) -> Result<Self, ConstantEvaluatorError>;
+}
+
+impl TryFromAbstract<i64> for i32 {
+ fn try_from_abstract(value: i64) -> Result<i32, ConstantEvaluatorError> {
+ i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
+ value: format!("{value:?}"),
+ to_type: "i32",
+ })
+ }
+}
+
+impl TryFromAbstract<i64> for u32 {
+ fn try_from_abstract(value: i64) -> Result<u32, ConstantEvaluatorError> {
+ u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy {
+ value: format!("{value:?}"),
+ to_type: "u32",
+ })
+ }
+}
+
+impl TryFromAbstract<i64> for f32 {
+ fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
+ let f = value as f32;
+ // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
+ // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for
+ // overflow here.
+ Ok(f)
+ }
+}
+
+impl TryFromAbstract<f64> for f32 {
+ fn try_from_abstract(value: f64) -> Result<f32, ConstantEvaluatorError> {
+ let f = value as f32;
+ if f.is_infinite() {
+ return Err(ConstantEvaluatorError::AutomaticConversionLossy {
+ value: format!("{value:?}"),
+ to_type: "f32",
+ });
+ }
+ Ok(f)
+ }
+}
+
+impl TryFromAbstract<i64> for f64 {
+ fn try_from_abstract(value: i64) -> Result<Self, ConstantEvaluatorError> {
+ let f = value as f64;
+ // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of
+ // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for
+ // overflow here.
+ Ok(f)
+ }
+}
+
+impl TryFromAbstract<f64> for f64 {
+ fn try_from_abstract(value: f64) -> Result<f64, ConstantEvaluatorError> {
+ Ok(value)
+ }
+}
+
+impl TryFromAbstract<f64> for i32 {
+ fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
+ Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i32" })
+ }
+}
+
+impl TryFromAbstract<f64> for u32 {
+ fn try_from_abstract(_: f64) -> Result<Self, ConstantEvaluatorError> {
+ Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u32" })
+ }
+}
diff --git a/third_party/rust/naga/src/proc/emitter.rs b/third_party/rust/naga/src/proc/emitter.rs
new file mode 100644
index 0000000000..0df804fff2
--- /dev/null
+++ b/third_party/rust/naga/src/proc/emitter.rs
@@ -0,0 +1,39 @@
+use crate::arena::Arena;
+
+/// Helper class to emit expressions
+#[allow(dead_code)]
+#[derive(Default, Debug)]
+pub struct Emitter {
+ start_len: Option<usize>,
+}
+
+#[allow(dead_code)]
+impl Emitter {
+ pub fn start(&mut self, arena: &Arena<crate::Expression>) {
+ if self.start_len.is_some() {
+ unreachable!("Emitting has already started!");
+ }
+ self.start_len = Some(arena.len());
+ }
+ pub const fn is_running(&self) -> bool {
+ self.start_len.is_some()
+ }
+ #[must_use]
+ pub fn finish(
+ &mut self,
+ arena: &Arena<crate::Expression>,
+ ) -> Option<(crate::Statement, crate::span::Span)> {
+ let start_len = self.start_len.take().unwrap();
+ if start_len != arena.len() {
+ #[allow(unused_mut)]
+ let mut span = crate::span::Span::default();
+ let range = arena.range_from(start_len);
+ for handle in range.clone() {
+ span.subsume(arena.get_span(handle))
+ }
+ Some((crate::Statement::Emit(range), span))
+ } else {
+ None
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/proc/index.rs b/third_party/rust/naga/src/proc/index.rs
new file mode 100644
index 0000000000..af3221c0fe
--- /dev/null
+++ b/third_party/rust/naga/src/proc/index.rs
@@ -0,0 +1,435 @@
+/*!
+Definitions for index bounds checking.
+*/
+
+use crate::{valid, Handle, UniqueArena};
+use bit_set::BitSet;
+
+/// How should code generated by Naga do bounds checks?
+///
+/// When a vector, matrix, or array index is out of bounds—either negative, or
+/// greater than or equal to the number of elements in the type—WGSL requires
+/// that some other index of the implementation's choice that is in bounds is
+/// used instead. (There are no types with zero elements.)
+///
+/// Similarly, when out-of-bounds coordinates, array indices, or sample indices
+/// are presented to the WGSL `textureLoad` and `textureStore` operations, the
+/// operation is redirected to do something safe.
+///
+/// Different users of Naga will prefer different defaults:
+///
+/// - When used as part of a WebGPU implementation, the WGSL specification
+/// requires the `Restrict` behavior for array, vector, and matrix accesses,
+/// and either the `Restrict` or `ReadZeroSkipWrite` behaviors for texture
+/// accesses.
+///
+/// - When used by the `wgpu` crate for native development, `wgpu` selects
+/// `ReadZeroSkipWrite` as its default.
+///
+/// - Naga's own default is `Unchecked`, so that shader translations
+/// are as faithful to the original as possible.
+///
+/// Sometimes the underlying hardware and drivers can perform bounds checks
+/// themselves, in a way that performs better than the checks Naga would inject.
+/// If you're using native checks like this, then having Naga inject its own
+/// checks as well would be redundant, and the `Unchecked` policy is
+/// appropriate.
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum BoundsCheckPolicy {
+ /// Replace out-of-bounds indexes with some arbitrary in-bounds index.
+ ///
+ /// (This does not necessarily mean clamping. For example, interpreting the
+ /// index as unsigned and taking the minimum with the largest valid index
+ /// would also be a valid implementation. That would map negative indices to
+ /// the last element, not the first.)
+ Restrict,
+
+ /// Out-of-bounds reads return zero, and writes have no effect.
+ ///
+ /// When applied to a chain of accesses, like `a[i][j].b[k]`, all index
+ /// expressions are evaluated, regardless of whether prior or later index
+ /// expressions were in bounds. But all the accesses per se are skipped
+ /// if any index is out of bounds.
+ ReadZeroSkipWrite,
+
+ /// Naga adds no checks to indexing operations. Generate the fastest code
+ /// possible. This is the default for Naga, as a translator, but consumers
+ /// should consider defaulting to a safer behavior.
+ Unchecked,
+}
+
+/// Policies for injecting bounds checks during code generation.
+#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BoundsCheckPolicies {
+ /// How should the generated code handle array, vector, or matrix indices
+ /// that are out of range?
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub index: BoundsCheckPolicy,
+
+ /// How should the generated code handle array, vector, or matrix indices
+ /// that are out of range, when those values live in a [`GlobalVariable`] in
+ /// the [`Storage`] or [`Uniform`] address spaces?
+ ///
+ /// Some graphics hardware provides "robust buffer access", a feature that
+ /// ensures that using a pointer cannot access memory outside the 'buffer'
+ /// that it was derived from. In Naga terms, this means that the hardware
+ /// ensures that pointers computed by applying [`Access`] and
+ /// [`AccessIndex`] expressions to a [`GlobalVariable`] whose [`space`] is
+ /// [`Storage`] or [`Uniform`] will never read or write memory outside that
+ /// global variable.
+ ///
+ /// When hardware offers such a feature, it is probably undesirable to have
+ /// Naga inject bounds checking code for such accesses, since the hardware
+ /// can probably provide the same protection more efficiently. However,
+ /// bounds checks are still needed on accesses to indexable values that do
+ /// not live in buffers, like local variables.
+ ///
+ /// So, this option provides a separate policy that applies only to accesses
+ /// to storage and uniform globals. When depending on hardware bounds
+ /// checking, this policy can be `Unchecked` to avoid unnecessary overhead.
+ ///
+ /// When special hardware support is not available, this should probably be
+ /// the same as `index_bounds_check_policy`.
+ ///
+ /// [`GlobalVariable`]: crate::GlobalVariable
+ /// [`space`]: crate::GlobalVariable::space
+ /// [`Restrict`]: crate::back::BoundsCheckPolicy::Restrict
+ /// [`ReadZeroSkipWrite`]: crate::back::BoundsCheckPolicy::ReadZeroSkipWrite
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`Uniform`]: crate::AddressSpace::Uniform
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub buffer: BoundsCheckPolicy,
+
+ /// How should the generated code handle image texel loads that are out
+ /// of range?
+ ///
+ /// This controls the behavior of [`ImageLoad`] expressions when a coordinate,
+ /// texture array index, level of detail, or multisampled sample number is out of range.
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub image_load: BoundsCheckPolicy,
+
+ /// How should the generated code handle image texel stores that are out
+ /// of range?
+ ///
+ /// This controls the behavior of [`ImageStore`] statements when a coordinate,
+ /// texture array index, level of detail, or multisampled sample number is out of range.
+ ///
+ /// This policy should't be needed since all backends should ignore OOB writes.
+ ///
+ /// [`ImageStore`]: crate::Statement::ImageStore
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub image_store: BoundsCheckPolicy,
+
+ /// How should the generated code handle binding array indexes that are out of bounds.
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub binding_array: BoundsCheckPolicy,
+}
+
+/// The default `BoundsCheckPolicy` is `Unchecked`.
+impl Default for BoundsCheckPolicy {
+ fn default() -> Self {
+ BoundsCheckPolicy::Unchecked
+ }
+}
+
+impl BoundsCheckPolicies {
+ /// Determine which policy applies to `base`.
+ ///
+ /// `base` is the "base" expression (the expression being indexed) of a `Access`
+ /// and `AccessIndex` expression. This is either a pointer, a value, being directly
+ /// indexed, or a binding array.
+ ///
+ /// See the documentation for [`BoundsCheckPolicy`] for details about
+ /// when each policy applies.
+ pub fn choose_policy(
+ &self,
+ base: Handle<crate::Expression>,
+ types: &UniqueArena<crate::Type>,
+ info: &valid::FunctionInfo,
+ ) -> BoundsCheckPolicy {
+ let ty = info[base].ty.inner_with(types);
+
+ if let crate::TypeInner::BindingArray { .. } = *ty {
+ return self.binding_array;
+ }
+
+ match ty.pointer_space() {
+ Some(crate::AddressSpace::Storage { access: _ } | crate::AddressSpace::Uniform) => {
+ self.buffer
+ }
+ // This covers other address spaces, but also accessing vectors and
+ // matrices by value, where no pointer is involved.
+ _ => self.index,
+ }
+ }
+
+ /// Return `true` if any of `self`'s policies are `policy`.
+ pub fn contains(&self, policy: BoundsCheckPolicy) -> bool {
+ self.index == policy
+ || self.buffer == policy
+ || self.image_load == policy
+ || self.image_store == policy
+ }
+}
+
+/// An index that may be statically known, or may need to be computed at runtime.
+///
+/// This enum lets us handle both [`Access`] and [`AccessIndex`] expressions
+/// with the same code.
+///
+/// [`Access`]: crate::Expression::Access
+/// [`AccessIndex`]: crate::Expression::AccessIndex
+#[derive(Clone, Copy, Debug)]
+pub enum GuardedIndex {
+ Known(u32),
+ Expression(Handle<crate::Expression>),
+}
+
+/// Build a set of expressions used as indices, to cache in temporary variables when
+/// emitted.
+///
+/// Given the bounds-check policies `policies`, construct a `BitSet` containing the handle
+/// indices of all the expressions in `function` that are ever used as guarded indices
+/// under the [`ReadZeroSkipWrite`] policy. The `module` argument must be the module to
+/// which `function` belongs, and `info` should be that function's analysis results.
+///
+/// Such index expressions will be used twice in the generated code: first for the
+/// comparison to see if the index is in bounds, and then for the access itself, should
+/// the comparison succeed. To avoid computing the expressions twice, the generated code
+/// should cache them in temporary variables.
+///
+/// Why do we need to build such a set in advance, instead of just processing access
+/// expressions as we encounter them? Whether an expression needs to be cached depends on
+/// whether it appears as something like the [`index`] operand of an [`Access`] expression
+/// or the [`level`] operand of an [`ImageLoad`] expression, and on the index bounds check
+/// policies that apply to those accesses. But [`Emit`] statements just identify a range
+/// of expressions by index; there's no good way to tell what an expression is used
+/// for. The only way to do it is to just iterate over all the expressions looking for
+/// relevant `Access` expressions --- which is what this function does.
+///
+/// Simple expressions like variable loads and constants don't make sense to cache: it's
+/// no better than just re-evaluating them. But constants are not covered by `Emit`
+/// statements, and `Load`s are always cached to ensure they occur at the right time, so
+/// we don't bother filtering them out from this set.
+///
+/// Fortunately, we don't need to deal with [`ImageStore`] statements here. When we emit
+/// code for a statement, the writer isn't in the middle of an expression, so we can just
+/// emit declarations for temporaries, initialized appropriately.
+///
+/// None of these concerns apply for SPIR-V output, since it's easy to just reuse an
+/// instruction ID in two places; that has the same semantics as a temporary variable, and
+/// it's inherent in the design of SPIR-V. This function is more useful for text-based
+/// back ends.
+///
+/// [`ReadZeroSkipWrite`]: BoundsCheckPolicy::ReadZeroSkipWrite
+/// [`index`]: crate::Expression::Access::index
+/// [`Access`]: crate::Expression::Access
+/// [`level`]: crate::Expression::ImageLoad::level
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+/// [`Emit`]: crate::Statement::Emit
+/// [`ImageStore`]: crate::Statement::ImageStore
+pub fn find_checked_indexes(
+ module: &crate::Module,
+ function: &crate::Function,
+ info: &crate::valid::FunctionInfo,
+ policies: BoundsCheckPolicies,
+) -> BitSet {
+ use crate::Expression as Ex;
+
+ let mut guarded_indices = BitSet::new();
+
+ // Don't bother scanning if we never need `ReadZeroSkipWrite`.
+ if policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) {
+ for (_handle, expr) in function.expressions.iter() {
+ // There's no need to handle `AccessIndex` expressions, as their
+ // indices never need to be cached.
+ match *expr {
+ Ex::Access { base, index } => {
+ if policies.choose_policy(base, &module.types, info)
+ == BoundsCheckPolicy::ReadZeroSkipWrite
+ && access_needs_check(
+ base,
+ GuardedIndex::Expression(index),
+ module,
+ function,
+ info,
+ )
+ .is_some()
+ {
+ guarded_indices.insert(index.index());
+ }
+ }
+ Ex::ImageLoad {
+ coordinate,
+ array_index,
+ sample,
+ level,
+ ..
+ } => {
+ if policies.image_load == BoundsCheckPolicy::ReadZeroSkipWrite {
+ guarded_indices.insert(coordinate.index());
+ if let Some(array_index) = array_index {
+ guarded_indices.insert(array_index.index());
+ }
+ if let Some(sample) = sample {
+ guarded_indices.insert(sample.index());
+ }
+ if let Some(level) = level {
+ guarded_indices.insert(level.index());
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+
+ guarded_indices
+}
+
+/// Determine whether `index` is statically known to be in bounds for `base`.
+///
+/// If we can't be sure that the index is in bounds, return the limit within
+/// which valid indices must fall.
+///
+/// The return value is one of the following:
+///
+/// - `Some(Known(n))` indicates that `n` is the largest valid index.
+///
+/// - `Some(Computed(global))` indicates that the largest valid index is one
+/// less than the length of the array that is the last member of the
+/// struct held in `global`.
+///
+/// - `None` indicates that the index need not be checked, either because it
+/// is statically known to be in bounds, or because the applicable policy
+/// is `Unchecked`.
+///
+/// This function only handles subscriptable types: arrays, vectors, and
+/// matrices. It does not handle struct member indices; those never require
+/// run-time checks, so it's best to deal with them further up the call
+/// chain.
+pub fn access_needs_check(
+ base: Handle<crate::Expression>,
+ mut index: GuardedIndex,
+ module: &crate::Module,
+ function: &crate::Function,
+ info: &crate::valid::FunctionInfo,
+) -> Option<IndexableLength> {
+ let base_inner = info[base].ty.inner_with(&module.types);
+ // Unwrap safety: `Err` here indicates unindexable base types and invalid
+ // length constants, but `access_needs_check` is only used by back ends, so
+ // validation should have caught those problems.
+ let length = base_inner.indexable_length(module).unwrap();
+ index.try_resolve_to_constant(function, module);
+ if let (&GuardedIndex::Known(index), &IndexableLength::Known(length)) = (&index, &length) {
+ if index < length {
+ // Index is statically known to be in bounds, no check needed.
+ return None;
+ }
+ };
+
+ Some(length)
+}
+
+impl GuardedIndex {
+ /// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible.
+ ///
+ /// Return values that are already `Known` unchanged.
+ fn try_resolve_to_constant(&mut self, function: &crate::Function, module: &crate::Module) {
+ if let GuardedIndex::Expression(expr) = *self {
+ if let Ok(value) = module
+ .to_ctx()
+ .eval_expr_to_u32_from(expr, &function.expressions)
+ {
+ *self = GuardedIndex::Known(value);
+ }
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)]
+pub enum IndexableLengthError {
+ #[error("Type is not indexable, and has no length (validation error)")]
+ TypeNotIndexable,
+ #[error("Array length constant {0:?} is invalid")]
+ InvalidArrayLength(Handle<crate::Expression>),
+}
+
+impl crate::TypeInner {
+ /// Return the length of a subscriptable type.
+ ///
+ /// The `self` parameter should be a handle to a vector, matrix, or array
+ /// type, a pointer to one of those, or a value pointer. Arrays may be
+ /// fixed-size, dynamically sized, or sized by a specializable constant.
+ /// This function does not handle struct member references, as with
+ /// `AccessIndex`.
+ ///
+ /// The value returned is appropriate for bounds checks on subscripting.
+ ///
+ /// Return an error if `self` does not describe a subscriptable type at all.
+ pub fn indexable_length(
+ &self,
+ module: &crate::Module,
+ ) -> Result<IndexableLength, IndexableLengthError> {
+ use crate::TypeInner as Ti;
+ let known_length = match *self {
+ Ti::Vector { size, .. } => size as _,
+ Ti::Matrix { columns, .. } => columns as _,
+ Ti::Array { size, .. } | Ti::BindingArray { size, .. } => {
+ return size.to_indexable_length(module);
+ }
+ Ti::ValuePointer {
+ size: Some(size), ..
+ } => size as _,
+ Ti::Pointer { base, .. } => {
+ // When assigning types to expressions, ResolveContext::Resolve
+ // does a separate sub-match here instead of a full recursion,
+ // so we'll do the same.
+ let base_inner = &module.types[base].inner;
+ match *base_inner {
+ Ti::Vector { size, .. } => size as _,
+ Ti::Matrix { columns, .. } => columns as _,
+ Ti::Array { size, .. } | Ti::BindingArray { size, .. } => {
+ return size.to_indexable_length(module)
+ }
+ _ => return Err(IndexableLengthError::TypeNotIndexable),
+ }
+ }
+ _ => return Err(IndexableLengthError::TypeNotIndexable),
+ };
+ Ok(IndexableLength::Known(known_length))
+ }
+}
+
+/// The number of elements in an indexable type.
+///
+/// This summarizes the length of vectors, matrices, and arrays in a way that is
+/// convenient for indexing and bounds-checking code.
+#[derive(Debug)]
+pub enum IndexableLength {
+ /// Values of this type always have the given number of elements.
+ Known(u32),
+
+ /// The number of elements is determined at runtime.
+ Dynamic,
+}
+
+impl crate::ArraySize {
+ pub const fn to_indexable_length(
+ self,
+ _module: &crate::Module,
+ ) -> Result<IndexableLength, IndexableLengthError> {
+ Ok(match self {
+ Self::Constant(length) => IndexableLength::Known(length.get()),
+ Self::Dynamic => IndexableLength::Dynamic,
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/proc/layouter.rs b/third_party/rust/naga/src/proc/layouter.rs
new file mode 100644
index 0000000000..1c78a594d1
--- /dev/null
+++ b/third_party/rust/naga/src/proc/layouter.rs
@@ -0,0 +1,251 @@
+use crate::arena::Handle;
+use std::{fmt::Display, num::NonZeroU32, ops};
+
+/// A newtype struct where its only valid values are powers of 2
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Alignment(NonZeroU32);
+
+impl Alignment {
+ pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) });
+ pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) });
+ pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) });
+ pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) });
+ pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) });
+
+ pub const MIN_UNIFORM: Self = Self::SIXTEEN;
+
+ pub const fn new(n: u32) -> Option<Self> {
+ if n.is_power_of_two() {
+ // SAFETY: value can't be 0 since we just checked if it's a power of 2
+ Some(Self(unsafe { NonZeroU32::new_unchecked(n) }))
+ } else {
+ None
+ }
+ }
+
+ /// # Panics
+ /// If `width` is not a power of 2
+ pub fn from_width(width: u8) -> Self {
+ Self::new(width as u32).unwrap()
+ }
+
+ /// Returns whether or not `n` is a multiple of this alignment.
+ pub const fn is_aligned(&self, n: u32) -> bool {
+ // equivalent to: `n % self.0.get() == 0` but much faster
+ n & (self.0.get() - 1) == 0
+ }
+
+ /// Round `n` up to the nearest alignment boundary.
+ pub const fn round_up(&self, n: u32) -> u32 {
+ // equivalent to:
+ // match n % self.0.get() {
+ // 0 => n,
+ // rem => n + (self.0.get() - rem),
+ // }
+ let mask = self.0.get() - 1;
+ (n + mask) & !mask
+ }
+}
+
+impl Display for Alignment {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.get().fmt(f)
+ }
+}
+
+impl ops::Mul<u32> for Alignment {
+ type Output = u32;
+
+ fn mul(self, rhs: u32) -> Self::Output {
+ self.0.get() * rhs
+ }
+}
+
+impl ops::Mul for Alignment {
+ type Output = Alignment;
+
+ fn mul(self, rhs: Alignment) -> Self::Output {
+ // SAFETY: both lhs and rhs are powers of 2, the result will be a power of 2
+ Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) })
+ }
+}
+
+impl From<crate::VectorSize> for Alignment {
+ fn from(size: crate::VectorSize) -> Self {
+ match size {
+ crate::VectorSize::Bi => Alignment::TWO,
+ crate::VectorSize::Tri => Alignment::FOUR,
+ crate::VectorSize::Quad => Alignment::FOUR,
+ }
+ }
+}
+
+/// Size and alignment information for a type.
+#[derive(Clone, Copy, Debug, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct TypeLayout {
+ pub size: u32,
+ pub alignment: Alignment,
+}
+
+impl TypeLayout {
+ /// Produce the stride as if this type is a base of an array.
+ pub const fn to_stride(&self) -> u32 {
+ self.alignment.round_up(self.size)
+ }
+}
+
+/// Helper processor that derives the sizes of all types.
+///
+/// `Layouter` uses the default layout algorithm/table, described in
+/// [WGSL §4.3.7, "Memory Layout"]
+///
+/// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the
+/// layout of the type whose handle is `handle`.
+///
+/// [WGSL §4.3.7, "Memory Layout"](https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts)
+#[derive(Debug, Default)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Layouter {
+ /// Layouts for types in an arena, indexed by `Handle` index.
+ layouts: Vec<TypeLayout>,
+}
+
+impl ops::Index<Handle<crate::Type>> for Layouter {
+ type Output = TypeLayout;
+ fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
+ &self.layouts[handle.index()]
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
+pub enum LayoutErrorInner {
+ #[error("Array element type {0:?} doesn't exist")]
+ InvalidArrayElementType(Handle<crate::Type>),
+ #[error("Struct member[{0}] type {1:?} doesn't exist")]
+ InvalidStructMemberType(u32, Handle<crate::Type>),
+ #[error("Type width must be a power of two")]
+ NonPowerOfTwoWidth,
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
+#[error("Error laying out type {ty:?}: {inner}")]
+pub struct LayoutError {
+ pub ty: Handle<crate::Type>,
+ pub inner: LayoutErrorInner,
+}
+
+impl LayoutErrorInner {
+ const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
+ LayoutError { ty, inner: self }
+ }
+}
+
+impl Layouter {
+ /// Remove all entries from this `Layouter`, retaining storage.
+ pub fn clear(&mut self) {
+ self.layouts.clear();
+ }
+
+ /// Extend this `Layouter` with layouts for any new entries in `gctx.types`.
+ ///
+ /// Ensure that every type in `gctx.types` has a corresponding [TypeLayout]
+ /// in [`self.layouts`].
+ ///
+ /// Some front ends need to be able to compute layouts for existing types
+ /// while module construction is still in progress and new types are still
+ /// being added. This function assumes that the `TypeLayout` values already
+ /// present in `self.layouts` cover their corresponding entries in `types`,
+ /// and extends `self.layouts` as needed to cover the rest. Thus, a front
+ /// end can call this function at any time, passing its current type and
+ /// constant arenas, and then assume that layouts are available for all
+ /// types.
+ #[allow(clippy::or_fun_call)]
+ pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
+ use crate::TypeInner as Ti;
+
+ for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
+ let size = ty.inner.size(gctx);
+ let layout = match ty.inner {
+ Ti::Scalar(scalar) | Ti::Atomic(scalar) => {
+ let alignment = Alignment::new(scalar.width as u32)
+ .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
+ TypeLayout { size, alignment }
+ }
+ Ti::Vector {
+ size: vec_size,
+ scalar,
+ } => {
+ let alignment = Alignment::new(scalar.width as u32)
+ .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
+ TypeLayout {
+ size,
+ alignment: Alignment::from(vec_size) * alignment,
+ }
+ }
+ Ti::Matrix {
+ columns: _,
+ rows,
+ scalar,
+ } => {
+ let alignment = Alignment::new(scalar.width as u32)
+ .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
+ TypeLayout {
+ size,
+ alignment: Alignment::from(rows) * alignment,
+ }
+ }
+ Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
+ size,
+ alignment: Alignment::ONE,
+ },
+ Ti::Array {
+ base,
+ stride: _,
+ size: _,
+ } => TypeLayout {
+ size,
+ alignment: if base < ty_handle {
+ self[base].alignment
+ } else {
+ return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
+ },
+ },
+ Ti::Struct { span, ref members } => {
+ let mut alignment = Alignment::ONE;
+ for (index, member) in members.iter().enumerate() {
+ alignment = if member.ty < ty_handle {
+ alignment.max(self[member.ty].alignment)
+ } else {
+ return Err(LayoutErrorInner::InvalidStructMemberType(
+ index as u32,
+ member.ty,
+ )
+ .with(ty_handle));
+ };
+ }
+ TypeLayout {
+ size: span,
+ alignment,
+ }
+ }
+ Ti::Image { .. }
+ | Ti::Sampler { .. }
+ | Ti::AccelerationStructure
+ | Ti::RayQuery
+ | Ti::BindingArray { .. } => TypeLayout {
+ size,
+ alignment: Alignment::ONE,
+ },
+ };
+ debug_assert!(size <= layout.size);
+ self.layouts.push(layout);
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs
new file mode 100644
index 0000000000..b9ce80b5ea
--- /dev/null
+++ b/third_party/rust/naga/src/proc/mod.rs
@@ -0,0 +1,809 @@
+/*!
+[`Module`](super::Module) processing functionality.
+*/
+
+mod constant_evaluator;
+mod emitter;
+pub mod index;
+mod layouter;
+mod namer;
+mod terminator;
+mod typifier;
+
+pub use constant_evaluator::{
+ ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker,
+};
+pub use emitter::Emitter;
+pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
+pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
+pub use namer::{EntryPointIndex, NameKey, Namer};
+pub use terminator::ensure_block_returns;
+pub use typifier::{ResolveContext, ResolveError, TypeResolution};
+
+impl From<super::StorageFormat> for super::ScalarKind {
+ fn from(format: super::StorageFormat) -> Self {
+ use super::{ScalarKind as Sk, StorageFormat as Sf};
+ match format {
+ Sf::R8Unorm => Sk::Float,
+ Sf::R8Snorm => Sk::Float,
+ Sf::R8Uint => Sk::Uint,
+ Sf::R8Sint => Sk::Sint,
+ Sf::R16Uint => Sk::Uint,
+ Sf::R16Sint => Sk::Sint,
+ Sf::R16Float => Sk::Float,
+ Sf::Rg8Unorm => Sk::Float,
+ Sf::Rg8Snorm => Sk::Float,
+ Sf::Rg8Uint => Sk::Uint,
+ Sf::Rg8Sint => Sk::Sint,
+ Sf::R32Uint => Sk::Uint,
+ Sf::R32Sint => Sk::Sint,
+ Sf::R32Float => Sk::Float,
+ Sf::Rg16Uint => Sk::Uint,
+ Sf::Rg16Sint => Sk::Sint,
+ Sf::Rg16Float => Sk::Float,
+ Sf::Rgba8Unorm => Sk::Float,
+ Sf::Rgba8Snorm => Sk::Float,
+ Sf::Rgba8Uint => Sk::Uint,
+ Sf::Rgba8Sint => Sk::Sint,
+ Sf::Bgra8Unorm => Sk::Float,
+ Sf::Rgb10a2Uint => Sk::Uint,
+ Sf::Rgb10a2Unorm => Sk::Float,
+ Sf::Rg11b10Float => Sk::Float,
+ Sf::Rg32Uint => Sk::Uint,
+ Sf::Rg32Sint => Sk::Sint,
+ Sf::Rg32Float => Sk::Float,
+ Sf::Rgba16Uint => Sk::Uint,
+ Sf::Rgba16Sint => Sk::Sint,
+ Sf::Rgba16Float => Sk::Float,
+ Sf::Rgba32Uint => Sk::Uint,
+ Sf::Rgba32Sint => Sk::Sint,
+ Sf::Rgba32Float => Sk::Float,
+ Sf::R16Unorm => Sk::Float,
+ Sf::R16Snorm => Sk::Float,
+ Sf::Rg16Unorm => Sk::Float,
+ Sf::Rg16Snorm => Sk::Float,
+ Sf::Rgba16Unorm => Sk::Float,
+ Sf::Rgba16Snorm => Sk::Float,
+ }
+ }
+}
+
+impl super::ScalarKind {
+ pub const fn is_numeric(self) -> bool {
+ match self {
+ crate::ScalarKind::Sint
+ | crate::ScalarKind::Uint
+ | crate::ScalarKind::Float
+ | crate::ScalarKind::AbstractInt
+ | crate::ScalarKind::AbstractFloat => true,
+ crate::ScalarKind::Bool => false,
+ }
+ }
+}
+
+impl super::Scalar {
+ pub const I32: Self = Self {
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ };
+ pub const U32: Self = Self {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ };
+ pub const F32: Self = Self {
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ };
+ pub const F64: Self = Self {
+ kind: crate::ScalarKind::Float,
+ width: 8,
+ };
+ pub const I64: Self = Self {
+ kind: crate::ScalarKind::Sint,
+ width: 8,
+ };
+ pub const BOOL: Self = Self {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ };
+ pub const ABSTRACT_INT: Self = Self {
+ kind: crate::ScalarKind::AbstractInt,
+ width: crate::ABSTRACT_WIDTH,
+ };
+ pub const ABSTRACT_FLOAT: Self = Self {
+ kind: crate::ScalarKind::AbstractFloat,
+ width: crate::ABSTRACT_WIDTH,
+ };
+
+ pub const fn is_abstract(self) -> bool {
+ match self.kind {
+ crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => true,
+ crate::ScalarKind::Sint
+ | crate::ScalarKind::Uint
+ | crate::ScalarKind::Float
+ | crate::ScalarKind::Bool => false,
+ }
+ }
+
+ /// Construct a float `Scalar` with the given width.
+ ///
+ /// This is especially common when dealing with
+ /// `TypeInner::Matrix`, where the scalar kind is implicit.
+ pub const fn float(width: crate::Bytes) -> Self {
+ Self {
+ kind: crate::ScalarKind::Float,
+ width,
+ }
+ }
+
+ pub const fn to_inner_scalar(self) -> crate::TypeInner {
+ crate::TypeInner::Scalar(self)
+ }
+
+ pub const fn to_inner_vector(self, size: crate::VectorSize) -> crate::TypeInner {
+ crate::TypeInner::Vector { size, scalar: self }
+ }
+
+ pub const fn to_inner_atomic(self) -> crate::TypeInner {
+ crate::TypeInner::Atomic(self)
+ }
+}
+
+impl PartialEq for crate::Literal {
+ fn eq(&self, other: &Self) -> bool {
+ match (*self, *other) {
+ (Self::F64(a), Self::F64(b)) => a.to_bits() == b.to_bits(),
+ (Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(),
+ (Self::U32(a), Self::U32(b)) => a == b,
+ (Self::I32(a), Self::I32(b)) => a == b,
+ (Self::I64(a), Self::I64(b)) => a == b,
+ (Self::Bool(a), Self::Bool(b)) => a == b,
+ _ => false,
+ }
+ }
+}
+impl Eq for crate::Literal {}
+impl std::hash::Hash for crate::Literal {
+ fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
+ match *self {
+ Self::F64(v) | Self::AbstractFloat(v) => {
+ hasher.write_u8(0);
+ v.to_bits().hash(hasher);
+ }
+ Self::F32(v) => {
+ hasher.write_u8(1);
+ v.to_bits().hash(hasher);
+ }
+ Self::U32(v) => {
+ hasher.write_u8(2);
+ v.hash(hasher);
+ }
+ Self::I32(v) => {
+ hasher.write_u8(3);
+ v.hash(hasher);
+ }
+ Self::Bool(v) => {
+ hasher.write_u8(4);
+ v.hash(hasher);
+ }
+ Self::I64(v) | Self::AbstractInt(v) => {
+ hasher.write_u8(5);
+ v.hash(hasher);
+ }
+ }
+ }
+}
+
+impl crate::Literal {
+ pub const fn new(value: u8, scalar: crate::Scalar) -> Option<Self> {
+ match (value, scalar.kind, scalar.width) {
+ (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)),
+ (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
+ (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
+ (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
+ (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
+ (1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)),
+ (0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)),
+ _ => None,
+ }
+ }
+
+ pub const fn zero(scalar: crate::Scalar) -> Option<Self> {
+ Self::new(0, scalar)
+ }
+
+ pub const fn one(scalar: crate::Scalar) -> Option<Self> {
+ Self::new(1, scalar)
+ }
+
+ pub const fn width(&self) -> crate::Bytes {
+ match *self {
+ Self::F64(_) | Self::I64(_) => 8,
+ Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
+ Self::Bool(_) => crate::BOOL_WIDTH,
+ Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH,
+ }
+ }
+ pub const fn scalar(&self) -> crate::Scalar {
+ match *self {
+ Self::F64(_) => crate::Scalar::F64,
+ Self::F32(_) => crate::Scalar::F32,
+ Self::U32(_) => crate::Scalar::U32,
+ Self::I32(_) => crate::Scalar::I32,
+ Self::I64(_) => crate::Scalar::I64,
+ Self::Bool(_) => crate::Scalar::BOOL,
+ Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT,
+ Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT,
+ }
+ }
+ pub const fn scalar_kind(&self) -> crate::ScalarKind {
+ self.scalar().kind
+ }
+ pub const fn ty_inner(&self) -> crate::TypeInner {
+ crate::TypeInner::Scalar(self.scalar())
+ }
+}
+
+pub const POINTER_SPAN: u32 = 4;
+
+impl super::TypeInner {
+ /// Return the scalar type of `self`.
+ ///
+ /// If `inner` is a scalar, vector, or matrix type, return
+ /// its scalar type. Otherwise, return `None`.
+ pub const fn scalar(&self) -> Option<super::Scalar> {
+ use crate::TypeInner as Ti;
+ match *self {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => Some(scalar),
+ Ti::Matrix { scalar, .. } => Some(scalar),
+ _ => None,
+ }
+ }
+
+ pub fn scalar_kind(&self) -> Option<super::ScalarKind> {
+ self.scalar().map(|scalar| scalar.kind)
+ }
+
+ pub fn scalar_width(&self) -> Option<u8> {
+ self.scalar().map(|scalar| scalar.width * 8)
+ }
+
+ pub const fn pointer_space(&self) -> Option<crate::AddressSpace> {
+ match *self {
+ Self::Pointer { space, .. } => Some(space),
+ Self::ValuePointer { space, .. } => Some(space),
+ _ => None,
+ }
+ }
+
+ pub fn is_atomic_pointer(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
+ match *self {
+ crate::TypeInner::Pointer { base, .. } => match types[base].inner {
+ crate::TypeInner::Atomic { .. } => true,
+ _ => false,
+ },
+ _ => false,
+ }
+ }
+
+ /// Get the size of this type.
+ pub fn size(&self, _gctx: GlobalCtx) -> u32 {
+ match *self {
+ Self::Scalar(scalar) | Self::Atomic(scalar) => scalar.width as u32,
+ Self::Vector { size, scalar } => size as u32 * scalar.width as u32,
+ // matrices are treated as arrays of aligned columns
+ Self::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => Alignment::from(rows) * scalar.width as u32 * columns as u32,
+ Self::Pointer { .. } | Self::ValuePointer { .. } => POINTER_SPAN,
+ Self::Array {
+ base: _,
+ size,
+ stride,
+ } => {
+ let count = match size {
+ super::ArraySize::Constant(count) => count.get(),
+ // A dynamically-sized array has to have at least one element
+ super::ArraySize::Dynamic => 1,
+ };
+ count * stride
+ }
+ Self::Struct { span, .. } => span,
+ Self::Image { .. }
+ | Self::Sampler { .. }
+ | Self::AccelerationStructure
+ | Self::RayQuery
+ | Self::BindingArray { .. } => 0,
+ }
+ }
+
+ /// Return the canonical form of `self`, or `None` if it's already in
+ /// canonical form.
+ ///
+ /// Certain types have multiple representations in `TypeInner`. This
+ /// function converts all forms of equivalent types to a single
+ /// representative of their class, so that simply applying `Eq` to the
+ /// result indicates whether the types are equivalent, as far as Naga IR is
+ /// concerned.
+ pub fn canonical_form(
+ &self,
+ types: &crate::UniqueArena<crate::Type>,
+ ) -> Option<crate::TypeInner> {
+ use crate::TypeInner as Ti;
+ match *self {
+ Ti::Pointer { base, space } => match types[base].inner {
+ Ti::Scalar(scalar) => Some(Ti::ValuePointer {
+ size: None,
+ scalar,
+ space,
+ }),
+ Ti::Vector { size, scalar } => Some(Ti::ValuePointer {
+ size: Some(size),
+ scalar,
+ space,
+ }),
+ _ => None,
+ },
+ _ => None,
+ }
+ }
+
+ /// Compare `self` and `rhs` as types.
+ ///
+ /// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats
+ /// `ValuePointer` and `Pointer` types as equivalent.
+ ///
+ /// When you know that one side of the comparison is never a pointer, it's
+ /// fine to not bother with canonicalization, and just compare `TypeInner`
+ /// values with `==`.
+ pub fn equivalent(
+ &self,
+ rhs: &crate::TypeInner,
+ types: &crate::UniqueArena<crate::Type>,
+ ) -> bool {
+ let left = self.canonical_form(types);
+ let right = rhs.canonical_form(types);
+ left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs)
+ }
+
+ pub fn is_dynamically_sized(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
+ use crate::TypeInner as Ti;
+ match *self {
+ Ti::Array { size, .. } => size == crate::ArraySize::Dynamic,
+ Ti::Struct { ref members, .. } => members
+ .last()
+ .map(|last| types[last.ty].inner.is_dynamically_sized(types))
+ .unwrap_or(false),
+ _ => false,
+ }
+ }
+
+ pub fn components(&self) -> Option<u32> {
+ Some(match *self {
+ Self::Vector { size, .. } => size as u32,
+ Self::Matrix { columns, .. } => columns as u32,
+ Self::Array {
+ size: crate::ArraySize::Constant(len),
+ ..
+ } => len.get(),
+ Self::Struct { ref members, .. } => members.len() as u32,
+ _ => return None,
+ })
+ }
+
+ pub fn component_type(&self, index: usize) -> Option<TypeResolution> {
+ Some(match *self {
+ Self::Vector { scalar, .. } => TypeResolution::Value(crate::TypeInner::Scalar(scalar)),
+ Self::Matrix { rows, scalar, .. } => {
+ TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
+ }
+ Self::Array {
+ base,
+ size: crate::ArraySize::Constant(_),
+ ..
+ } => TypeResolution::Handle(base),
+ Self::Struct { ref members, .. } => TypeResolution::Handle(members[index].ty),
+ _ => return None,
+ })
+ }
+}
+
+impl super::AddressSpace {
+ pub fn access(self) -> crate::StorageAccess {
+ use crate::StorageAccess as Sa;
+ match self {
+ crate::AddressSpace::Function
+ | crate::AddressSpace::Private
+ | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
+ crate::AddressSpace::Uniform => Sa::LOAD,
+ crate::AddressSpace::Storage { access } => access,
+ crate::AddressSpace::Handle => Sa::LOAD,
+ crate::AddressSpace::PushConstant => Sa::LOAD,
+ }
+ }
+}
+
+impl super::MathFunction {
+ pub const fn argument_count(&self) -> usize {
+ match *self {
+ // comparison
+ Self::Abs => 1,
+ Self::Min => 2,
+ Self::Max => 2,
+ Self::Clamp => 3,
+ Self::Saturate => 1,
+ // trigonometry
+ Self::Cos => 1,
+ Self::Cosh => 1,
+ Self::Sin => 1,
+ Self::Sinh => 1,
+ Self::Tan => 1,
+ Self::Tanh => 1,
+ Self::Acos => 1,
+ Self::Asin => 1,
+ Self::Atan => 1,
+ Self::Atan2 => 2,
+ Self::Asinh => 1,
+ Self::Acosh => 1,
+ Self::Atanh => 1,
+ Self::Radians => 1,
+ Self::Degrees => 1,
+ // decomposition
+ Self::Ceil => 1,
+ Self::Floor => 1,
+ Self::Round => 1,
+ Self::Fract => 1,
+ Self::Trunc => 1,
+ Self::Modf => 1,
+ Self::Frexp => 1,
+ Self::Ldexp => 2,
+ // exponent
+ Self::Exp => 1,
+ Self::Exp2 => 1,
+ Self::Log => 1,
+ Self::Log2 => 1,
+ Self::Pow => 2,
+ // geometry
+ Self::Dot => 2,
+ Self::Outer => 2,
+ Self::Cross => 2,
+ Self::Distance => 2,
+ Self::Length => 1,
+ Self::Normalize => 1,
+ Self::FaceForward => 3,
+ Self::Reflect => 2,
+ Self::Refract => 3,
+ // computational
+ Self::Sign => 1,
+ Self::Fma => 3,
+ Self::Mix => 3,
+ Self::Step => 2,
+ Self::SmoothStep => 3,
+ Self::Sqrt => 1,
+ Self::InverseSqrt => 1,
+ Self::Inverse => 1,
+ Self::Transpose => 1,
+ Self::Determinant => 1,
+ // bits
+ Self::CountTrailingZeros => 1,
+ Self::CountLeadingZeros => 1,
+ Self::CountOneBits => 1,
+ Self::ReverseBits => 1,
+ Self::ExtractBits => 3,
+ Self::InsertBits => 4,
+ Self::FindLsb => 1,
+ Self::FindMsb => 1,
+ // data packing
+ Self::Pack4x8snorm => 1,
+ Self::Pack4x8unorm => 1,
+ Self::Pack2x16snorm => 1,
+ Self::Pack2x16unorm => 1,
+ Self::Pack2x16float => 1,
+ // data unpacking
+ Self::Unpack4x8snorm => 1,
+ Self::Unpack4x8unorm => 1,
+ Self::Unpack2x16snorm => 1,
+ Self::Unpack2x16unorm => 1,
+ Self::Unpack2x16float => 1,
+ }
+ }
+}
+
+impl crate::Expression {
+ /// Returns true if the expression is considered emitted at the start of a function.
+ pub const fn needs_pre_emit(&self) -> bool {
+ match *self {
+ Self::Literal(_)
+ | Self::Constant(_)
+ | Self::ZeroValue(_)
+ | Self::FunctionArgument(_)
+ | Self::GlobalVariable(_)
+ | Self::LocalVariable(_) => true,
+ _ => false,
+ }
+ }
+
+ /// Return true if this expression is a dynamic array index, for [`Access`].
+ ///
+ /// This method returns true if this expression is a dynamically computed
+ /// index, and as such can only be used to index matrices and arrays when
+ /// they appear behind a pointer. See the documentation for [`Access`] for
+ /// details.
+ ///
+ /// Note, this does not check the _type_ of the given expression. It's up to
+ /// the caller to establish that the `Access` expression is well-typed
+ /// through other means, like [`ResolveContext`].
+ ///
+ /// [`Access`]: crate::Expression::Access
+ /// [`ResolveContext`]: crate::proc::ResolveContext
+ pub fn is_dynamic_index(&self, module: &crate::Module) -> bool {
+ match *self {
+ Self::Literal(_) | Self::ZeroValue(_) => false,
+ Self::Constant(handle) => {
+ let constant = &module.constants[handle];
+ !matches!(constant.r#override, crate::Override::None)
+ }
+ _ => true,
+ }
+ }
+}
+
+impl crate::Function {
+ /// Return the global variable being accessed by the expression `pointer`.
+ ///
+ /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
+ /// expressions that ultimately access some part of a `GlobalVariable`,
+ /// return a handle for that global.
+ ///
+ /// If the expression does not ultimately access a global variable, return
+ /// `None`.
+ pub fn originating_global(
+ &self,
+ mut pointer: crate::Handle<crate::Expression>,
+ ) -> Option<crate::Handle<crate::GlobalVariable>> {
+ loop {
+ pointer = match self.expressions[pointer] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ crate::Expression::GlobalVariable(handle) => return Some(handle),
+ crate::Expression::LocalVariable(_) => return None,
+ crate::Expression::FunctionArgument(_) => return None,
+ // There are no other expressions that produce pointer values.
+ _ => unreachable!(),
+ }
+ }
+ }
+}
+
+impl crate::SampleLevel {
+ pub const fn implicit_derivatives(&self) -> bool {
+ match *self {
+ Self::Auto | Self::Bias(_) => true,
+ Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
+ }
+ }
+}
+
+impl crate::Binding {
+ pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
+ match *self {
+ crate::Binding::BuiltIn(built_in) => Some(built_in),
+ Self::Location { .. } => None,
+ }
+ }
+}
+
+impl super::SwizzleComponent {
+ pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
+
+ pub const fn index(&self) -> u32 {
+ match *self {
+ Self::X => 0,
+ Self::Y => 1,
+ Self::Z => 2,
+ Self::W => 3,
+ }
+ }
+ pub const fn from_index(idx: u32) -> Self {
+ match idx {
+ 0 => Self::X,
+ 1 => Self::Y,
+ 2 => Self::Z,
+ _ => Self::W,
+ }
+ }
+}
+
+impl super::ImageClass {
+ pub const fn is_multisampled(self) -> bool {
+ match self {
+ crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
+ crate::ImageClass::Storage { .. } => false,
+ }
+ }
+
+ pub const fn is_mipmapped(self) -> bool {
+ match self {
+ crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
+ crate::ImageClass::Storage { .. } => false,
+ }
+ }
+}
+
+impl crate::Module {
+ pub const fn to_ctx(&self) -> GlobalCtx<'_> {
+ GlobalCtx {
+ types: &self.types,
+ constants: &self.constants,
+ const_expressions: &self.const_expressions,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(super) enum U32EvalError {
+ NonConst,
+ Negative,
+}
+
+#[derive(Clone, Copy)]
+pub struct GlobalCtx<'a> {
+ pub types: &'a crate::UniqueArena<crate::Type>,
+ pub constants: &'a crate::Arena<crate::Constant>,
+ pub const_expressions: &'a crate::Arena<crate::Expression>,
+}
+
+impl GlobalCtx<'_> {
+ /// Try to evaluate the expression in `self.const_expressions` using its `handle` and return it as a `u32`.
+ #[allow(dead_code)]
+ pub(super) fn eval_expr_to_u32(
+ &self,
+ handle: crate::Handle<crate::Expression>,
+ ) -> Result<u32, U32EvalError> {
+ self.eval_expr_to_u32_from(handle, self.const_expressions)
+ }
+
+ /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`.
+ pub(super) fn eval_expr_to_u32_from(
+ &self,
+ handle: crate::Handle<crate::Expression>,
+ arena: &crate::Arena<crate::Expression>,
+ ) -> Result<u32, U32EvalError> {
+ match self.eval_expr_to_literal_from(handle, arena) {
+ Some(crate::Literal::U32(value)) => Ok(value),
+ Some(crate::Literal::I32(value)) => {
+ value.try_into().map_err(|_| U32EvalError::Negative)
+ }
+ _ => Err(U32EvalError::NonConst),
+ }
+ }
+
+ #[allow(dead_code)]
+ pub(crate) fn eval_expr_to_literal(
+ &self,
+ handle: crate::Handle<crate::Expression>,
+ ) -> Option<crate::Literal> {
+ self.eval_expr_to_literal_from(handle, self.const_expressions)
+ }
+
+ fn eval_expr_to_literal_from(
+ &self,
+ handle: crate::Handle<crate::Expression>,
+ arena: &crate::Arena<crate::Expression>,
+ ) -> Option<crate::Literal> {
+ fn get(
+ gctx: GlobalCtx,
+ handle: crate::Handle<crate::Expression>,
+ arena: &crate::Arena<crate::Expression>,
+ ) -> Option<crate::Literal> {
+ match arena[handle] {
+ crate::Expression::Literal(literal) => Some(literal),
+ crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
+ crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar),
+ _ => None,
+ },
+ _ => None,
+ }
+ }
+ match arena[handle] {
+ crate::Expression::Constant(c) => {
+ get(*self, self.constants[c].init, self.const_expressions)
+ }
+ _ => get(*self, handle, arena),
+ }
+ }
+}
+
+/// Return an iterator over the individual components assembled by a
+/// `Compose` expression.
+///
+/// Given `ty` and `components` from an `Expression::Compose`, return an
+/// iterator over the components of the resulting value.
+///
+/// Normally, this would just be an iterator over `components`. However,
+/// `Compose` expressions can concatenate vectors, in which case the i'th
+/// value being composed is not generally the i'th element of `components`.
+/// This function consults `ty` to decide if this concatenation is occurring,
+/// and returns an iterator that produces the components of the result of
+/// the `Compose` expression in either case.
+pub fn flatten_compose<'arenas>(
+ ty: crate::Handle<crate::Type>,
+ components: &'arenas [crate::Handle<crate::Expression>],
+ expressions: &'arenas crate::Arena<crate::Expression>,
+ types: &'arenas crate::UniqueArena<crate::Type>,
+) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
+ // Returning `impl Iterator` is a bit tricky. We may or may not
+ // want to flatten the components, but we have to settle on a
+ // single concrete type to return. This function returns a single
+ // iterator chain that handles both the flattening and
+ // non-flattening cases.
+ let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
+ (size as usize, true)
+ } else {
+ (components.len(), false)
+ };
+
+ /// Flatten `Compose` expressions if `is_vector` is true.
+ fn flatten_compose<'c>(
+ component: &'c crate::Handle<crate::Expression>,
+ is_vector: bool,
+ expressions: &'c crate::Arena<crate::Expression>,
+ ) -> &'c [crate::Handle<crate::Expression>] {
+ if is_vector {
+ if let crate::Expression::Compose {
+ ty: _,
+ components: ref subcomponents,
+ } = expressions[*component]
+ {
+ return subcomponents;
+ }
+ }
+ std::slice::from_ref(component)
+ }
+
+ /// Flatten `Splat` expressions if `is_vector` is true.
+ fn flatten_splat<'c>(
+ component: &'c crate::Handle<crate::Expression>,
+ is_vector: bool,
+ expressions: &'c crate::Arena<crate::Expression>,
+ ) -> impl Iterator<Item = crate::Handle<crate::Expression>> {
+ let mut expr = *component;
+ let mut count = 1;
+ if is_vector {
+ if let crate::Expression::Splat { size, value } = expressions[expr] {
+ expr = value;
+ count = size as usize;
+ }
+ }
+ std::iter::repeat(expr).take(count)
+ }
+
+ // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to
+ // flatten up to two levels of `Compose` expressions.
+ //
+ // Expressions like `vec4(vec3(1.0), 1.0)` require us to flatten
+ // `Splat` expressions. Fortunately, the operand of a `Splat` must
+ // be a scalar, so we can stop there.
+ components
+ .iter()
+ .flat_map(move |component| flatten_compose(component, is_vector, expressions))
+ .flat_map(move |component| flatten_compose(component, is_vector, expressions))
+ .flat_map(move |component| flatten_splat(component, is_vector, expressions))
+ .take(size)
+}
+
+#[test]
+fn test_matrix_size() {
+ let module = crate::Module::default();
+ assert_eq!(
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ scalar: crate::Scalar::F32,
+ }
+ .size(module.to_ctx()),
+ 48,
+ );
+}
diff --git a/third_party/rust/naga/src/proc/namer.rs b/third_party/rust/naga/src/proc/namer.rs
new file mode 100644
index 0000000000..8afacb593d
--- /dev/null
+++ b/third_party/rust/naga/src/proc/namer.rs
@@ -0,0 +1,281 @@
+use crate::{arena::Handle, FastHashMap, FastHashSet};
+use std::borrow::Cow;
+use std::hash::{Hash, Hasher};
+
+pub type EntryPointIndex = u16;
+const SEPARATOR: char = '_';
+
+#[derive(Debug, Eq, Hash, PartialEq)]
+pub enum NameKey {
+ Constant(Handle<crate::Constant>),
+ GlobalVariable(Handle<crate::GlobalVariable>),
+ Type(Handle<crate::Type>),
+ StructMember(Handle<crate::Type>, u32),
+ Function(Handle<crate::Function>),
+ FunctionArgument(Handle<crate::Function>, u32),
+ FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>),
+ EntryPoint(EntryPointIndex),
+ EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>),
+ EntryPointArgument(EntryPointIndex, u32),
+}
+
+/// This processor assigns names to all the things in a module
+/// that may need identifiers in a textual backend.
+#[derive(Default)]
+pub struct Namer {
+ /// The last numeric suffix used for each base name. Zero means "no suffix".
+ unique: FastHashMap<String, u32>,
+ keywords: FastHashSet<&'static str>,
+ keywords_case_insensitive: FastHashSet<AsciiUniCase<&'static str>>,
+ reserved_prefixes: Vec<&'static str>,
+}
+
+impl Namer {
+ /// Return a form of `string` suitable for use as the base of an identifier.
+ ///
+ /// - Drop leading digits.
+ /// - Retain only alphanumeric and `_` characters.
+ /// - Avoid prefixes in [`Namer::reserved_prefixes`].
+ /// - Replace consecutive `_` characters with a single `_` character.
+ ///
+ /// The return value is a valid identifier prefix in all of Naga's output languages,
+ /// and it never ends with a `SEPARATOR` character.
+ /// It is used as a key into the unique table.
+ fn sanitize<'s>(&self, string: &'s str) -> Cow<'s, str> {
+ let string = string
+ .trim_start_matches(|c: char| c.is_numeric())
+ .trim_end_matches(SEPARATOR);
+
+ let base = if !string.is_empty()
+ && !string.contains("__")
+ && string
+ .chars()
+ .all(|c: char| c.is_ascii_alphanumeric() || c == '_')
+ {
+ Cow::Borrowed(string)
+ } else {
+ let mut filtered = string
+ .chars()
+ .filter(|&c| c.is_ascii_alphanumeric() || c == '_')
+ .fold(String::new(), |mut s, c| {
+ if s.ends_with('_') && c == '_' {
+ return s;
+ }
+ s.push(c);
+ s
+ });
+ let stripped_len = filtered.trim_end_matches(SEPARATOR).len();
+ filtered.truncate(stripped_len);
+ if filtered.is_empty() {
+ filtered.push_str("unnamed");
+ }
+ Cow::Owned(filtered)
+ };
+
+ for prefix in &self.reserved_prefixes {
+ if base.starts_with(prefix) {
+ return format!("gen_{base}").into();
+ }
+ }
+
+ base
+ }
+
+ /// Return a new identifier based on `label_raw`.
+ ///
+ /// The result:
+ /// - is a valid identifier even if `label_raw` is not
+ /// - conflicts with no keywords listed in `Namer::keywords`, and
+ /// - is different from any identifier previously constructed by this
+ /// `Namer`.
+ ///
+ /// Guarantee uniqueness by applying a numeric suffix when necessary. If `label_raw`
+ /// itself ends with digits, separate them from the suffix with an underscore.
+ pub fn call(&mut self, label_raw: &str) -> String {
+ use std::fmt::Write as _; // for write!-ing to Strings
+
+ let base = self.sanitize(label_raw);
+ debug_assert!(!base.is_empty() && !base.ends_with(SEPARATOR));
+
+ // This would seem to be a natural place to use `HashMap::entry`. However, `entry`
+ // requires an owned key, and we'd like to avoid heap-allocating strings we're
+ // just going to throw away. The approach below double-hashes only when we create
+ // a new entry, in which case the heap allocation of the owned key was more
+ // expensive anyway.
+ match self.unique.get_mut(base.as_ref()) {
+ Some(count) => {
+ *count += 1;
+ // Add the suffix. This may fit in base's existing allocation.
+ let mut suffixed = base.into_owned();
+ write!(suffixed, "{}{}", SEPARATOR, *count).unwrap();
+ suffixed
+ }
+ None => {
+ let mut suffixed = base.to_string();
+ if base.ends_with(char::is_numeric)
+ || self.keywords.contains(base.as_ref())
+ || self
+ .keywords_case_insensitive
+ .contains(&AsciiUniCase(base.as_ref()))
+ {
+ suffixed.push(SEPARATOR);
+ }
+ debug_assert!(!self.keywords.contains::<str>(&suffixed));
+ // `self.unique` wants to own its keys. This allocates only if we haven't
+ // already done so earlier.
+ self.unique.insert(base.into_owned(), 0);
+ suffixed
+ }
+ }
+ }
+
+ pub fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String {
+ self.call(match *label {
+ Some(ref name) => name,
+ None => fallback,
+ })
+ }
+
+ /// Enter a local namespace for things like structs.
+ ///
+ /// Struct member names only need to be unique amongst themselves, not
+ /// globally. This function temporarily establishes a fresh, empty naming
+ /// context for the duration of the call to `body`.
+ fn namespace(&mut self, capacity: usize, body: impl FnOnce(&mut Self)) {
+ let fresh = FastHashMap::with_capacity_and_hasher(capacity, Default::default());
+ let outer = std::mem::replace(&mut self.unique, fresh);
+ body(self);
+ self.unique = outer;
+ }
+
+ pub fn reset(
+ &mut self,
+ module: &crate::Module,
+ reserved_keywords: &[&'static str],
+ extra_reserved_keywords: &[&'static str],
+ reserved_keywords_case_insensitive: &[&'static str],
+ reserved_prefixes: &[&'static str],
+ output: &mut FastHashMap<NameKey, String>,
+ ) {
+ self.reserved_prefixes.clear();
+ self.reserved_prefixes.extend(reserved_prefixes.iter());
+
+ self.unique.clear();
+ self.keywords.clear();
+ self.keywords.extend(reserved_keywords.iter());
+ self.keywords.extend(extra_reserved_keywords.iter());
+
+ debug_assert!(reserved_keywords_case_insensitive
+ .iter()
+ .all(|s| s.is_ascii()));
+ self.keywords_case_insensitive.clear();
+ self.keywords_case_insensitive.extend(
+ reserved_keywords_case_insensitive
+ .iter()
+ .map(|string| (AsciiUniCase(*string))),
+ );
+
+ let mut temp = String::new();
+
+ for (ty_handle, ty) in module.types.iter() {
+ let ty_name = self.call_or(&ty.name, "type");
+ output.insert(NameKey::Type(ty_handle), ty_name);
+
+ if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
+ // struct members have their own namespace, because access is always prefixed
+ self.namespace(members.len(), |namer| {
+ for (index, member) in members.iter().enumerate() {
+ let name = namer.call_or(&member.name, "member");
+ output.insert(NameKey::StructMember(ty_handle, index as u32), name);
+ }
+ })
+ }
+ }
+
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let ep_name = self.call(&ep.name);
+ output.insert(NameKey::EntryPoint(ep_index as _), ep_name);
+ for (index, arg) in ep.function.arguments.iter().enumerate() {
+ let name = self.call_or(&arg.name, "param");
+ output.insert(
+ NameKey::EntryPointArgument(ep_index as _, index as u32),
+ name,
+ );
+ }
+ for (handle, var) in ep.function.local_variables.iter() {
+ let name = self.call_or(&var.name, "local");
+ output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name);
+ }
+ }
+
+ for (fun_handle, fun) in module.functions.iter() {
+ let fun_name = self.call_or(&fun.name, "function");
+ output.insert(NameKey::Function(fun_handle), fun_name);
+ for (index, arg) in fun.arguments.iter().enumerate() {
+ let name = self.call_or(&arg.name, "param");
+ output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name);
+ }
+ for (handle, var) in fun.local_variables.iter() {
+ let name = self.call_or(&var.name, "local");
+ output.insert(NameKey::FunctionLocal(fun_handle, handle), name);
+ }
+ }
+
+ for (handle, var) in module.global_variables.iter() {
+ let name = self.call_or(&var.name, "global");
+ output.insert(NameKey::GlobalVariable(handle), name);
+ }
+
+ for (handle, constant) in module.constants.iter() {
+ let label = match constant.name {
+ Some(ref name) => name,
+ None => {
+ use std::fmt::Write;
+ // Try to be more descriptive about the constant values
+ temp.clear();
+ write!(temp, "const_{}", output[&NameKey::Type(constant.ty)]).unwrap();
+ &temp
+ }
+ };
+ let name = self.call(label);
+ output.insert(NameKey::Constant(handle), name);
+ }
+ }
+}
+
+/// A string wrapper type with an ascii case insensitive Eq and Hash impl
+struct AsciiUniCase<S: AsRef<str> + ?Sized>(S);
+
+impl<S: AsRef<str>> PartialEq<Self> for AsciiUniCase<S> {
+ #[inline]
+ fn eq(&self, other: &Self) -> bool {
+ self.0.as_ref().eq_ignore_ascii_case(other.0.as_ref())
+ }
+}
+
+impl<S: AsRef<str>> Eq for AsciiUniCase<S> {}
+
+impl<S: AsRef<str>> Hash for AsciiUniCase<S> {
+ #[inline]
+ fn hash<H: Hasher>(&self, hasher: &mut H) {
+ for byte in self
+ .0
+ .as_ref()
+ .as_bytes()
+ .iter()
+ .map(|b| b.to_ascii_lowercase())
+ {
+ hasher.write_u8(byte);
+ }
+ }
+}
+
+#[test]
+fn test() {
+ let mut namer = Namer::default();
+ assert_eq!(namer.call("x"), "x");
+ assert_eq!(namer.call("x"), "x_1");
+ assert_eq!(namer.call("x1"), "x1_");
+ assert_eq!(namer.call("__x"), "_x");
+ assert_eq!(namer.call("1___x"), "_x_1");
+}
diff --git a/third_party/rust/naga/src/proc/terminator.rs b/third_party/rust/naga/src/proc/terminator.rs
new file mode 100644
index 0000000000..a5239d4eca
--- /dev/null
+++ b/third_party/rust/naga/src/proc/terminator.rs
@@ -0,0 +1,44 @@
+/// Ensure that the given block has return statements
+/// at the end of its control flow.
+///
+/// Note: we don't want to blindly append a return statement
+/// to the end, because it may be either redundant or invalid,
+/// e.g. when the user already has returns in if/else branches.
+pub fn ensure_block_returns(block: &mut crate::Block) {
+ use crate::Statement as S;
+ match block.last_mut() {
+ Some(&mut S::Block(ref mut b)) => {
+ ensure_block_returns(b);
+ }
+ Some(&mut S::If {
+ condition: _,
+ ref mut accept,
+ ref mut reject,
+ }) => {
+ ensure_block_returns(accept);
+ ensure_block_returns(reject);
+ }
+ Some(&mut S::Switch {
+ selector: _,
+ ref mut cases,
+ }) => {
+ for case in cases.iter_mut() {
+ if !case.fall_through {
+ ensure_block_returns(&mut case.body);
+ }
+ }
+ }
+ Some(&mut (S::Emit(_) | S::Break | S::Continue | S::Return { .. } | S::Kill)) => (),
+ Some(
+ &mut (S::Loop { .. }
+ | S::Store { .. }
+ | S::ImageStore { .. }
+ | S::Call { .. }
+ | S::RayQuery { .. }
+ | S::Atomic { .. }
+ | S::WorkGroupUniformLoad { .. }
+ | S::Barrier(_)),
+ )
+ | None => block.push(S::Return { value: None }, Default::default()),
+ }
+}
diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs
new file mode 100644
index 0000000000..9c4403445c
--- /dev/null
+++ b/third_party/rust/naga/src/proc/typifier.rs
@@ -0,0 +1,893 @@
+use crate::arena::{Arena, Handle, UniqueArena};
+
+use thiserror::Error;
+
+/// The result of computing an expression's type.
+///
+/// This is the (Rust) type returned by [`ResolveContext::resolve`] to represent
+/// the (Naga) type it ascribes to some expression.
+///
+/// You might expect such a function to simply return a `Handle<Type>`. However,
+/// we want type resolution to be a read-only process, and that would limit the
+/// possible results to types already present in the expression's associated
+/// `UniqueArena<Type>`. Naga IR does have certain expressions whose types are
+/// not certain to be present.
+///
+/// So instead, type resolution returns a `TypeResolution` enum: either a
+/// [`Handle`], referencing some type in the arena, or a [`Value`], holding a
+/// free-floating [`TypeInner`]. This extends the range to cover anything that
+/// can be represented with a `TypeInner` referring to the existing arena.
+///
+/// What sorts of expressions can have types not available in the arena?
+///
+/// - An [`Access`] or [`AccessIndex`] expression applied to a [`Vector`] or
+/// [`Matrix`] must have a [`Scalar`] or [`Vector`] type. But since `Vector`
+/// and `Matrix` represent their element and column types implicitly, not
+/// via a handle, there may not be a suitable type in the expression's
+/// associated arena. Instead, resolving such an expression returns a
+/// `TypeResolution::Value(TypeInner::X { ... })`, where `X` is `Scalar` or
+/// `Vector`.
+///
+/// - Similarly, the type of an [`Access`] or [`AccessIndex`] expression
+/// applied to a *pointer to* a vector or matrix must produce a *pointer to*
+/// a scalar or vector type. These cannot be represented with a
+/// [`TypeInner::Pointer`], since the `Pointer`'s `base` must point into the
+/// arena, and as before, we cannot assume that a suitable scalar or vector
+/// type is there. So we take things one step further and provide
+/// [`TypeInner::ValuePointer`], specifically for the case of pointers to
+/// scalars or vectors. This type fits in a `TypeInner` and is exactly
+/// equivalent to a `Pointer` to a `Vector` or `Scalar`.
+///
+/// So, for example, the type of an `Access` expression applied to a value of type:
+///
+/// ```ignore
+/// TypeInner::Matrix { columns, rows, width }
+/// ```
+///
+/// might be:
+///
+/// ```ignore
+/// TypeResolution::Value(TypeInner::Vector {
+/// size: rows,
+/// kind: ScalarKind::Float,
+/// width,
+/// })
+/// ```
+///
+/// and the type of an access to a pointer of address space `space` to such a
+/// matrix might be:
+///
+/// ```ignore
+/// TypeResolution::Value(TypeInner::ValuePointer {
+/// size: Some(rows),
+/// kind: ScalarKind::Float,
+/// width,
+/// space,
+/// })
+/// ```
+///
+/// [`Handle`]: TypeResolution::Handle
+/// [`Value`]: TypeResolution::Value
+///
+/// [`Access`]: crate::Expression::Access
+/// [`AccessIndex`]: crate::Expression::AccessIndex
+///
+/// [`TypeInner`]: crate::TypeInner
+/// [`Matrix`]: crate::TypeInner::Matrix
+/// [`Pointer`]: crate::TypeInner::Pointer
+/// [`Scalar`]: crate::TypeInner::Scalar
+/// [`ValuePointer`]: crate::TypeInner::ValuePointer
+/// [`Vector`]: crate::TypeInner::Vector
+///
+/// [`TypeInner::Pointer`]: crate::TypeInner::Pointer
+/// [`TypeInner::ValuePointer`]: crate::TypeInner::ValuePointer
+#[derive(Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum TypeResolution {
+ /// A type stored in the associated arena.
+ Handle(Handle<crate::Type>),
+
+ /// A free-floating [`TypeInner`], representing a type that may not be
+ /// available in the associated arena. However, the `TypeInner` itself may
+ /// contain `Handle<Type>` values referring to types from the arena.
+ ///
+ /// [`TypeInner`]: crate::TypeInner
+ Value(crate::TypeInner),
+}
+
+impl TypeResolution {
+ pub const fn handle(&self) -> Option<Handle<crate::Type>> {
+ match *self {
+ Self::Handle(handle) => Some(handle),
+ Self::Value(_) => None,
+ }
+ }
+
+ pub fn inner_with<'a>(&'a self, arena: &'a UniqueArena<crate::Type>) -> &'a crate::TypeInner {
+ match *self {
+ Self::Handle(handle) => &arena[handle].inner,
+ Self::Value(ref inner) => inner,
+ }
+ }
+}
+
+// Clone is only implemented for numeric variants of `TypeInner`.
+impl Clone for TypeResolution {
+ fn clone(&self) -> Self {
+ use crate::TypeInner as Ti;
+ match *self {
+ Self::Handle(handle) => Self::Handle(handle),
+ Self::Value(ref v) => Self::Value(match *v {
+ Ti::Scalar(scalar) => Ti::Scalar(scalar),
+ Ti::Vector { size, scalar } => Ti::Vector { size, scalar },
+ Ti::Matrix {
+ rows,
+ columns,
+ scalar,
+ } => Ti::Matrix {
+ rows,
+ columns,
+ scalar,
+ },
+ Ti::Pointer { base, space } => Ti::Pointer { base, space },
+ Ti::ValuePointer {
+ size,
+ scalar,
+ space,
+ } => Ti::ValuePointer {
+ size,
+ scalar,
+ space,
+ },
+ _ => unreachable!("Unexpected clone type: {:?}", v),
+ }),
+ }
+ }
+}
+
+#[derive(Clone, Debug, Error, PartialEq)]
+pub enum ResolveError {
+ #[error("Index {index} is out of bounds for expression {expr:?}")]
+ OutOfBoundsIndex {
+ expr: Handle<crate::Expression>,
+ index: u32,
+ },
+ #[error("Invalid access into expression {expr:?}, indexed: {indexed}")]
+ InvalidAccess {
+ expr: Handle<crate::Expression>,
+ indexed: bool,
+ },
+ #[error("Invalid sub-access into type {ty:?}, indexed: {indexed}")]
+ InvalidSubAccess {
+ ty: Handle<crate::Type>,
+ indexed: bool,
+ },
+ #[error("Invalid scalar {0:?}")]
+ InvalidScalar(Handle<crate::Expression>),
+ #[error("Invalid vector {0:?}")]
+ InvalidVector(Handle<crate::Expression>),
+ #[error("Invalid pointer {0:?}")]
+ InvalidPointer(Handle<crate::Expression>),
+ #[error("Invalid image {0:?}")]
+ InvalidImage(Handle<crate::Expression>),
+ #[error("Function {name} not defined")]
+ FunctionNotDefined { name: String },
+ #[error("Function without return type")]
+ FunctionReturnsVoid,
+ #[error("Incompatible operands: {0}")]
+ IncompatibleOperands(String),
+ #[error("Function argument {0} doesn't exist")]
+ FunctionArgumentNotFound(u32),
+ #[error("Special type is not registered within the module")]
+ MissingSpecialType,
+}
+
+pub struct ResolveContext<'a> {
+ pub constants: &'a Arena<crate::Constant>,
+ pub types: &'a UniqueArena<crate::Type>,
+ pub special_types: &'a crate::SpecialTypes,
+ pub global_vars: &'a Arena<crate::GlobalVariable>,
+ pub local_vars: &'a Arena<crate::LocalVariable>,
+ pub functions: &'a Arena<crate::Function>,
+ pub arguments: &'a [crate::FunctionArgument],
+}
+
+impl<'a> ResolveContext<'a> {
+ /// Initialize a resolve context from the module.
+ pub const fn with_locals(
+ module: &'a crate::Module,
+ local_vars: &'a Arena<crate::LocalVariable>,
+ arguments: &'a [crate::FunctionArgument],
+ ) -> Self {
+ Self {
+ constants: &module.constants,
+ types: &module.types,
+ special_types: &module.special_types,
+ global_vars: &module.global_variables,
+ local_vars,
+ functions: &module.functions,
+ arguments,
+ }
+ }
+
+ /// Determine the type of `expr`.
+ ///
+ /// The `past` argument must be a closure that can resolve the types of any
+ /// expressions that `expr` refers to. These can be gathered by caching the
+ /// results of prior calls to `resolve`, perhaps as done by the
+ /// [`front::Typifier`] utility type.
+ ///
+ /// Type resolution is a read-only process: this method takes `self` by
+ /// shared reference. However, this means that we cannot add anything to
+ /// `self.types` that we might need to describe `expr`. To work around this,
+ /// this method returns a [`TypeResolution`], rather than simply returning a
+ /// `Handle<Type>`; see the documentation for [`TypeResolution`] for
+ /// details.
+ ///
+ /// [`front::Typifier`]: crate::front::Typifier
+ pub fn resolve(
+ &self,
+ expr: &crate::Expression,
+ past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>,
+ ) -> Result<TypeResolution, ResolveError> {
+ use crate::TypeInner as Ti;
+ let types = self.types;
+ Ok(match *expr {
+ crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) {
+ // Arrays and matrices can only be indexed dynamically behind a
+ // pointer, but that's a validation error, not a type error, so
+ // go ahead provide a type here.
+ Ti::Array { base, .. } => TypeResolution::Handle(base),
+ Ti::Matrix { rows, scalar, .. } => {
+ TypeResolution::Value(Ti::Vector { size: rows, scalar })
+ }
+ Ti::Vector { size: _, scalar } => TypeResolution::Value(Ti::Scalar(scalar)),
+ Ti::ValuePointer {
+ size: Some(_),
+ scalar,
+ space,
+ } => TypeResolution::Value(Ti::ValuePointer {
+ size: None,
+ scalar,
+ space,
+ }),
+ Ti::Pointer { base, space } => {
+ TypeResolution::Value(match types[base].inner {
+ Ti::Array { base, .. } => Ti::Pointer { base, space },
+ Ti::Vector { size: _, scalar } => Ti::ValuePointer {
+ size: None,
+ scalar,
+ space,
+ },
+ // Matrices are only dynamically indexed behind a pointer
+ Ti::Matrix {
+ columns: _,
+ rows,
+ scalar,
+ } => Ti::ValuePointer {
+ size: Some(rows),
+ scalar,
+ space,
+ },
+ Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
+ ref other => {
+ log::error!("Access sub-type {:?}", other);
+ return Err(ResolveError::InvalidSubAccess {
+ ty: base,
+ indexed: false,
+ });
+ }
+ })
+ }
+ Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
+ ref other => {
+ log::error!("Access type {:?}", other);
+ return Err(ResolveError::InvalidAccess {
+ expr: base,
+ indexed: false,
+ });
+ }
+ },
+ crate::Expression::AccessIndex { base, index } => {
+ match *past(base)?.inner_with(types) {
+ Ti::Vector { size, scalar } => {
+ if index >= size as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ TypeResolution::Value(Ti::Scalar(scalar))
+ }
+ Ti::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ if index >= columns as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ TypeResolution::Value(crate::TypeInner::Vector { size: rows, scalar })
+ }
+ Ti::Array { base, .. } => TypeResolution::Handle(base),
+ Ti::Struct { ref members, .. } => {
+ let member = members
+ .get(index as usize)
+ .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
+ TypeResolution::Handle(member.ty)
+ }
+ Ti::ValuePointer {
+ size: Some(size),
+ scalar,
+ space,
+ } => {
+ if index >= size as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ TypeResolution::Value(Ti::ValuePointer {
+ size: None,
+ scalar,
+ space,
+ })
+ }
+ Ti::Pointer {
+ base: ty_base,
+ space,
+ } => TypeResolution::Value(match types[ty_base].inner {
+ Ti::Array { base, .. } => Ti::Pointer { base, space },
+ Ti::Vector { size, scalar } => {
+ if index >= size as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ Ti::ValuePointer {
+ size: None,
+ scalar,
+ space,
+ }
+ }
+ Ti::Matrix {
+ rows,
+ columns,
+ scalar,
+ } => {
+ if index >= columns as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ Ti::ValuePointer {
+ size: Some(rows),
+ scalar,
+ space,
+ }
+ }
+ Ti::Struct { ref members, .. } => {
+ let member = members
+ .get(index as usize)
+ .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
+ Ti::Pointer {
+ base: member.ty,
+ space,
+ }
+ }
+ Ti::BindingArray { base, .. } => Ti::Pointer { base, space },
+ ref other => {
+ log::error!("Access index sub-type {:?}", other);
+ return Err(ResolveError::InvalidSubAccess {
+ ty: ty_base,
+ indexed: true,
+ });
+ }
+ }),
+ Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
+ ref other => {
+ log::error!("Access index type {:?}", other);
+ return Err(ResolveError::InvalidAccess {
+ expr: base,
+ indexed: true,
+ });
+ }
+ }
+ }
+ crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
+ Ti::Scalar(scalar) => TypeResolution::Value(Ti::Vector { size, scalar }),
+ ref other => {
+ log::error!("Scalar type {:?}", other);
+ return Err(ResolveError::InvalidScalar(value));
+ }
+ },
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern: _,
+ } => match *past(vector)?.inner_with(types) {
+ Ti::Vector { size: _, scalar } => {
+ TypeResolution::Value(Ti::Vector { size, scalar })
+ }
+ ref other => {
+ log::error!("Vector type {:?}", other);
+ return Err(ResolveError::InvalidVector(vector));
+ }
+ },
+ crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()),
+ crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty),
+ crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty),
+ crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty),
+ crate::Expression::FunctionArgument(index) => {
+ let arg = self
+ .arguments
+ .get(index as usize)
+ .ok_or(ResolveError::FunctionArgumentNotFound(index))?;
+ TypeResolution::Handle(arg.ty)
+ }
+ crate::Expression::GlobalVariable(h) => {
+ let var = &self.global_vars[h];
+ if var.space == crate::AddressSpace::Handle {
+ TypeResolution::Handle(var.ty)
+ } else {
+ TypeResolution::Value(Ti::Pointer {
+ base: var.ty,
+ space: var.space,
+ })
+ }
+ }
+ crate::Expression::LocalVariable(h) => {
+ let var = &self.local_vars[h];
+ TypeResolution::Value(Ti::Pointer {
+ base: var.ty,
+ space: crate::AddressSpace::Function,
+ })
+ }
+ crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
+ Ti::Pointer { base, space: _ } => {
+ if let Ti::Atomic(scalar) = types[base].inner {
+ TypeResolution::Value(Ti::Scalar(scalar))
+ } else {
+ TypeResolution::Handle(base)
+ }
+ }
+ Ti::ValuePointer {
+ size,
+ scalar,
+ space: _,
+ } => TypeResolution::Value(match size {
+ Some(size) => Ti::Vector { size, scalar },
+ None => Ti::Scalar(scalar),
+ }),
+ ref other => {
+ log::error!("Pointer type {:?}", other);
+ return Err(ResolveError::InvalidPointer(pointer));
+ }
+ },
+ crate::Expression::ImageSample {
+ image,
+ gather: Some(_),
+ ..
+ } => match *past(image)?.inner_with(types) {
+ Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector {
+ scalar: crate::Scalar {
+ kind: match class {
+ crate::ImageClass::Sampled { kind, multi: _ } => kind,
+ _ => crate::ScalarKind::Float,
+ },
+ width: 4,
+ },
+ size: crate::VectorSize::Quad,
+ }),
+ ref other => {
+ log::error!("Image type {:?}", other);
+ return Err(ResolveError::InvalidImage(image));
+ }
+ },
+ crate::Expression::ImageSample { image, .. }
+ | crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) {
+ Ti::Image { class, .. } => TypeResolution::Value(match class {
+ crate::ImageClass::Depth { multi: _ } => Ti::Scalar(crate::Scalar::F32),
+ crate::ImageClass::Sampled { kind, multi: _ } => Ti::Vector {
+ scalar: crate::Scalar { kind, width: 4 },
+ size: crate::VectorSize::Quad,
+ },
+ crate::ImageClass::Storage { format, .. } => Ti::Vector {
+ scalar: crate::Scalar {
+ kind: format.into(),
+ width: 4,
+ },
+ size: crate::VectorSize::Quad,
+ },
+ }),
+ ref other => {
+ log::error!("Image type {:?}", other);
+ return Err(ResolveError::InvalidImage(image));
+ }
+ },
+ crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query {
+ crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) {
+ Ti::Image { dim, .. } => match dim {
+ crate::ImageDimension::D1 => Ti::Scalar(crate::Scalar::U32),
+ crate::ImageDimension::D2 | crate::ImageDimension::Cube => Ti::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: crate::Scalar::U32,
+ },
+ crate::ImageDimension::D3 => Ti::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: crate::Scalar::U32,
+ },
+ },
+ ref other => {
+ log::error!("Image type {:?}", other);
+ return Err(ResolveError::InvalidImage(image));
+ }
+ },
+ crate::ImageQuery::NumLevels
+ | crate::ImageQuery::NumLayers
+ | crate::ImageQuery::NumSamples => Ti::Scalar(crate::Scalar::U32),
+ }),
+ crate::Expression::Unary { expr, .. } => past(expr)?.clone(),
+ crate::Expression::Binary { op, left, right } => match op {
+ crate::BinaryOperator::Add
+ | crate::BinaryOperator::Subtract
+ | crate::BinaryOperator::Divide
+ | crate::BinaryOperator::Modulo => past(left)?.clone(),
+ crate::BinaryOperator::Multiply => {
+ let (res_left, res_right) = (past(left)?, past(right)?);
+ match (res_left.inner_with(types), res_right.inner_with(types)) {
+ (
+ &Ti::Matrix {
+ columns: _,
+ rows,
+ scalar,
+ },
+ &Ti::Matrix { columns, .. },
+ ) => TypeResolution::Value(Ti::Matrix {
+ columns,
+ rows,
+ scalar,
+ }),
+ (
+ &Ti::Matrix {
+ columns: _,
+ rows,
+ scalar,
+ },
+ &Ti::Vector { .. },
+ ) => TypeResolution::Value(Ti::Vector { size: rows, scalar }),
+ (
+ &Ti::Vector { .. },
+ &Ti::Matrix {
+ columns,
+ rows: _,
+ scalar,
+ },
+ ) => TypeResolution::Value(Ti::Vector {
+ size: columns,
+ scalar,
+ }),
+ (&Ti::Scalar { .. }, _) => res_right.clone(),
+ (_, &Ti::Scalar { .. }) => res_left.clone(),
+ (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
+ (tl, tr) => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{tl:?} * {tr:?}"
+ )))
+ }
+ }
+ }
+ crate::BinaryOperator::Equal
+ | crate::BinaryOperator::NotEqual
+ | crate::BinaryOperator::Less
+ | crate::BinaryOperator::LessEqual
+ | crate::BinaryOperator::Greater
+ | crate::BinaryOperator::GreaterEqual
+ | crate::BinaryOperator::LogicalAnd
+ | crate::BinaryOperator::LogicalOr => {
+ let scalar = crate::Scalar::BOOL;
+ let inner = match *past(left)?.inner_with(types) {
+ Ti::Scalar { .. } => Ti::Scalar(scalar),
+ Ti::Vector { size, .. } => Ti::Vector { size, scalar },
+ ref other => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{op:?}({other:?}, _)"
+ )))
+ }
+ };
+ TypeResolution::Value(inner)
+ }
+ crate::BinaryOperator::And
+ | crate::BinaryOperator::ExclusiveOr
+ | crate::BinaryOperator::InclusiveOr
+ | crate::BinaryOperator::ShiftLeft
+ | crate::BinaryOperator::ShiftRight => past(left)?.clone(),
+ },
+ crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
+ crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
+ crate::Expression::Select { accept, .. } => past(accept)?.clone(),
+ crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
+ crate::Expression::Relational { fun, argument } => match fun {
+ crate::RelationalFunction::All | crate::RelationalFunction::Any => {
+ TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
+ }
+ crate::RelationalFunction::IsNan | crate::RelationalFunction::IsInf => {
+ match *past(argument)?.inner_with(types) {
+ Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL)),
+ Ti::Vector { size, .. } => TypeResolution::Value(Ti::Vector {
+ scalar: crate::Scalar::BOOL,
+ size,
+ }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{fun:?}({other:?})"
+ )))
+ }
+ }
+ }
+ },
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2: _,
+ arg3: _,
+ } => {
+ use crate::MathFunction as Mf;
+ let res_arg = past(arg)?;
+ match fun {
+ // comparison
+ Mf::Abs |
+ Mf::Min |
+ Mf::Max |
+ Mf::Clamp |
+ Mf::Saturate |
+ // trigonometry
+ Mf::Cos |
+ Mf::Cosh |
+ Mf::Sin |
+ Mf::Sinh |
+ Mf::Tan |
+ Mf::Tanh |
+ Mf::Acos |
+ Mf::Asin |
+ Mf::Atan |
+ Mf::Atan2 |
+ Mf::Asinh |
+ Mf::Acosh |
+ Mf::Atanh |
+ Mf::Radians |
+ Mf::Degrees |
+ // decomposition
+ Mf::Ceil |
+ Mf::Floor |
+ Mf::Round |
+ Mf::Fract |
+ Mf::Trunc |
+ Mf::Ldexp |
+ // exponent
+ Mf::Exp |
+ Mf::Exp2 |
+ Mf::Log |
+ Mf::Log2 |
+ Mf::Pow => res_arg.clone(),
+ Mf::Modf | Mf::Frexp => {
+ let (size, width) = match res_arg.inner_with(types) {
+ &Ti::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Float,
+ width,
+ }) => (None, width),
+ &Ti::Vector {
+ scalar: crate::Scalar {
+ kind: crate::ScalarKind::Float,
+ width,
+ },
+ size,
+ } => (Some(size), width),
+ ref other =>
+ return Err(ResolveError::IncompatibleOperands(format!("{fun:?}({other:?}, _)")))
+ };
+ let result = self
+ .special_types
+ .predeclared_types
+ .get(&if fun == Mf::Modf {
+ crate::PredeclaredType::ModfResult { size, width }
+ } else {
+ crate::PredeclaredType::FrexpResult { size, width }
+ })
+ .ok_or(ResolveError::MissingSpecialType)?;
+ TypeResolution::Handle(*result)
+ },
+ // geometry
+ Mf::Dot => match *res_arg.inner_with(types) {
+ Ti::Vector {
+ size: _,
+ scalar,
+ } => TypeResolution::Value(Ti::Scalar(scalar)),
+ ref other =>
+ return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?}, _)")
+ )),
+ },
+ Mf::Outer => {
+ let arg1 = arg1.ok_or_else(|| ResolveError::IncompatibleOperands(
+ format!("{fun:?}(_, None)")
+ ))?;
+ match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) {
+ (
+ &Ti::Vector { size: columns, scalar },
+ &Ti::Vector{ size: rows, .. }
+ ) => TypeResolution::Value(Ti::Matrix {
+ columns,
+ rows,
+ scalar,
+ }),
+ (left, right) =>
+ return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({left:?}, {right:?})")
+ )),
+ }
+ },
+ Mf::Cross => res_arg.clone(),
+ Mf::Distance |
+ Mf::Length => match *res_arg.inner_with(types) {
+ Ti::Scalar(scalar) |
+ Ti::Vector {scalar,size:_} => TypeResolution::Value(Ti::Scalar(scalar)),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ Mf::Normalize |
+ Mf::FaceForward |
+ Mf::Reflect |
+ Mf::Refract => res_arg.clone(),
+ // computational
+ Mf::Sign |
+ Mf::Fma |
+ Mf::Mix |
+ Mf::Step |
+ Mf::SmoothStep |
+ Mf::Sqrt |
+ Mf::InverseSqrt => res_arg.clone(),
+ Mf::Transpose => match *res_arg.inner_with(types) {
+ Ti::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => TypeResolution::Value(Ti::Matrix {
+ columns: rows,
+ rows: columns,
+ scalar,
+ }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ Mf::Inverse => match *res_arg.inner_with(types) {
+ Ti::Matrix {
+ columns,
+ rows,
+ scalar,
+ } if columns == rows => TypeResolution::Value(Ti::Matrix {
+ columns,
+ rows,
+ scalar,
+ }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ Mf::Determinant => match *res_arg.inner_with(types) {
+ Ti::Matrix {
+ scalar,
+ ..
+ } => TypeResolution::Value(Ti::Scalar(scalar)),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ // bits
+ Mf::CountTrailingZeros |
+ Mf::CountLeadingZeros |
+ Mf::CountOneBits |
+ Mf::ReverseBits |
+ Mf::ExtractBits |
+ Mf::InsertBits |
+ Mf::FindLsb |
+ Mf::FindMsb => match *res_arg.inner_with(types) {
+ Ti::Scalar(scalar @ crate::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ ..
+ }) => TypeResolution::Value(Ti::Scalar(scalar)),
+ Ti::Vector {
+ size,
+ scalar: scalar @ crate::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ ..
+ }
+ } => TypeResolution::Value(Ti::Vector { size, scalar }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ // data packing
+ Mf::Pack4x8snorm |
+ Mf::Pack4x8unorm |
+ Mf::Pack2x16snorm |
+ Mf::Pack2x16unorm |
+ Mf::Pack2x16float => TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)),
+ // data unpacking
+ Mf::Unpack4x8snorm |
+ Mf::Unpack4x8unorm => TypeResolution::Value(Ti::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: crate::Scalar::F32
+ }),
+ Mf::Unpack2x16snorm |
+ Mf::Unpack2x16unorm |
+ Mf::Unpack2x16float => TypeResolution::Value(Ti::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: crate::Scalar::F32
+ }),
+ }
+ }
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => match *past(expr)?.inner_with(types) {
+ Ti::Scalar(crate::Scalar { width, .. }) => {
+ TypeResolution::Value(Ti::Scalar(crate::Scalar {
+ kind,
+ width: convert.unwrap_or(width),
+ }))
+ }
+ Ti::Vector {
+ size,
+ scalar: crate::Scalar { kind: _, width },
+ } => TypeResolution::Value(Ti::Vector {
+ size,
+ scalar: crate::Scalar {
+ kind,
+ width: convert.unwrap_or(width),
+ },
+ }),
+ Ti::Matrix {
+ columns,
+ rows,
+ mut scalar,
+ } => {
+ if let Some(width) = convert {
+ scalar.width = width;
+ }
+ TypeResolution::Value(Ti::Matrix {
+ columns,
+ rows,
+ scalar,
+ })
+ }
+ ref other => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{other:?} as {kind:?}"
+ )))
+ }
+ },
+ crate::Expression::CallResult(function) => {
+ let result = self.functions[function]
+ .result
+ .as_ref()
+ .ok_or(ResolveError::FunctionReturnsVoid)?;
+ TypeResolution::Handle(result.ty)
+ }
+ crate::Expression::ArrayLength(_) => {
+ TypeResolution::Value(Ti::Scalar(crate::Scalar::U32))
+ }
+ crate::Expression::RayQueryProceedResult => {
+ TypeResolution::Value(Ti::Scalar(crate::Scalar::BOOL))
+ }
+ crate::Expression::RayQueryGetIntersection { .. } => {
+ let result = self
+ .special_types
+ .ray_intersection
+ .ok_or(ResolveError::MissingSpecialType)?;
+ TypeResolution::Handle(result)
+ }
+ })
+ }
+}
+
+#[test]
+fn test_error_size() {
+ use std::mem::size_of;
+ assert_eq!(size_of::<ResolveError>(), 32);
+}
diff --git a/third_party/rust/naga/src/span.rs b/third_party/rust/naga/src/span.rs
new file mode 100644
index 0000000000..53246b25d6
--- /dev/null
+++ b/third_party/rust/naga/src/span.rs
@@ -0,0 +1,501 @@
+use crate::{Arena, Handle, UniqueArena};
+use std::{error::Error, fmt, ops::Range};
+
+/// A source code span, used for error reporting.
+#[derive(Clone, Copy, Debug, PartialEq, Default)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct Span {
+ start: u32,
+ end: u32,
+}
+
+impl Span {
+ pub const UNDEFINED: Self = Self { start: 0, end: 0 };
+ /// Creates a new `Span` from a range of byte indices
+ ///
+ /// Note: end is exclusive, it doesn't belong to the `Span`
+ pub const fn new(start: u32, end: u32) -> Self {
+ Span { start, end }
+ }
+
+ /// Returns a new `Span` starting at `self` and ending at `other`
+ pub const fn until(&self, other: &Self) -> Self {
+ Span {
+ start: self.start,
+ end: other.end,
+ }
+ }
+
+ /// Modifies `self` to contain the smallest `Span` possible that
+ /// contains both `self` and `other`
+ pub fn subsume(&mut self, other: Self) {
+ *self = if !self.is_defined() {
+ // self isn't defined so use other
+ other
+ } else if !other.is_defined() {
+ // other isn't defined so don't try to subsume
+ *self
+ } else {
+ // Both self and other are defined so calculate the span that contains them both
+ Span {
+ start: self.start.min(other.start),
+ end: self.end.max(other.end),
+ }
+ }
+ }
+
+ /// Returns the smallest `Span` possible that contains all the `Span`s
+ /// defined in the `from` iterator
+ pub fn total_span<T: Iterator<Item = Self>>(from: T) -> Self {
+ let mut span: Self = Default::default();
+ for other in from {
+ span.subsume(other);
+ }
+ span
+ }
+
+ /// Converts `self` to a range if the span is not unknown
+ pub fn to_range(self) -> Option<Range<usize>> {
+ if self.is_defined() {
+ Some(self.start as usize..self.end as usize)
+ } else {
+ None
+ }
+ }
+
+ /// Check whether `self` was defined or is a default/unknown span
+ pub fn is_defined(&self) -> bool {
+ *self != Self::default()
+ }
+
+ /// Return a [`SourceLocation`] for this span in the provided source.
+ pub fn location(&self, source: &str) -> SourceLocation {
+ let prefix = &source[..self.start as usize];
+ let line_number = prefix.matches('\n').count() as u32 + 1;
+ let line_start = prefix.rfind('\n').map(|pos| pos + 1).unwrap_or(0);
+ let line_position = source[line_start..self.start as usize].chars().count() as u32 + 1;
+
+ SourceLocation {
+ line_number,
+ line_position,
+ offset: self.start,
+ length: self.end - self.start,
+ }
+ }
+}
+
+impl From<Range<usize>> for Span {
+ fn from(range: Range<usize>) -> Self {
+ Span {
+ start: range.start as u32,
+ end: range.end as u32,
+ }
+ }
+}
+
+impl std::ops::Index<Span> for str {
+ type Output = str;
+
+ #[inline]
+ fn index(&self, span: Span) -> &str {
+ &self[span.start as usize..span.end as usize]
+ }
+}
+
+/// A human-readable representation for a span, tailored for text source.
+///
+/// Corresponds to the positional members of [`GPUCompilationMessage`][gcm] from
+/// the WebGPU specification, except that `offset` and `length` are in bytes
+/// (UTF-8 code units), instead of UTF-16 code units.
+///
+/// [gcm]: https://www.w3.org/TR/webgpu/#gpucompilationmessage
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub struct SourceLocation {
+ /// 1-based line number.
+ pub line_number: u32,
+ /// 1-based column of the start of this span
+ pub line_position: u32,
+ /// 0-based Offset in code units (in bytes) of the start of the span.
+ pub offset: u32,
+ /// Length in code units (in bytes) of the span.
+ pub length: u32,
+}
+
+/// A source code span together with "context", a user-readable description of what part of the error it refers to.
+pub type SpanContext = (Span, String);
+
+/// Wrapper class for [`Error`], augmenting it with a list of [`SpanContext`]s.
+#[derive(Debug, Clone)]
+pub struct WithSpan<E> {
+ inner: E,
+ spans: Vec<SpanContext>,
+}
+
+impl<E> fmt::Display for WithSpan<E>
+where
+ E: fmt::Display,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+ self.inner.fmt(f)
+ }
+}
+
+#[cfg(test)]
+impl<E> PartialEq for WithSpan<E>
+where
+ E: PartialEq,
+{
+ fn eq(&self, other: &Self) -> bool {
+ self.inner.eq(&other.inner)
+ }
+}
+
+impl<E> Error for WithSpan<E>
+where
+ E: Error,
+{
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ self.inner.source()
+ }
+}
+
+impl<E> WithSpan<E> {
+ /// Create a new [`WithSpan`] from an [`Error`], containing no spans.
+ pub const fn new(inner: E) -> Self {
+ Self {
+ inner,
+ spans: Vec::new(),
+ }
+ }
+
+ /// Reverse of [`Self::new`], discards span information and returns an inner error.
+ #[allow(clippy::missing_const_for_fn)] // ignore due to requirement of #![feature(const_precise_live_drops)]
+ pub fn into_inner(self) -> E {
+ self.inner
+ }
+
+ pub const fn as_inner(&self) -> &E {
+ &self.inner
+ }
+
+ /// Iterator over stored [`SpanContext`]s.
+ pub fn spans(&self) -> impl ExactSizeIterator<Item = &SpanContext> {
+ self.spans.iter()
+ }
+
+ /// Add a new span with description.
+ pub fn with_span<S>(mut self, span: Span, description: S) -> Self
+ where
+ S: ToString,
+ {
+ if span.is_defined() {
+ self.spans.push((span, description.to_string()));
+ }
+ self
+ }
+
+ /// Add a [`SpanContext`].
+ pub fn with_context(self, span_context: SpanContext) -> Self {
+ let (span, description) = span_context;
+ self.with_span(span, description)
+ }
+
+ /// Add a [`Handle`] from either [`Arena`] or [`UniqueArena`], borrowing its span information from there
+ /// and annotating with a type and the handle representation.
+ pub(crate) fn with_handle<T, A: SpanProvider<T>>(self, handle: Handle<T>, arena: &A) -> Self {
+ self.with_context(arena.get_span_context(handle))
+ }
+
+ /// Convert inner error using [`From`].
+ pub fn into_other<E2>(self) -> WithSpan<E2>
+ where
+ E2: From<E>,
+ {
+ WithSpan {
+ inner: self.inner.into(),
+ spans: self.spans,
+ }
+ }
+
+ /// Convert inner error into another type. Joins span information contained in `self`
+ /// with what is returned from `func`.
+ pub fn and_then<F, E2>(self, func: F) -> WithSpan<E2>
+ where
+ F: FnOnce(E) -> WithSpan<E2>,
+ {
+ let mut res = func(self.inner);
+ res.spans.extend(self.spans);
+ res
+ }
+
+ /// Return a [`SourceLocation`] for our first span, if we have one.
+ pub fn location(&self, source: &str) -> Option<SourceLocation> {
+ if self.spans.is_empty() {
+ return None;
+ }
+
+ Some(self.spans[0].0.location(source))
+ }
+
+ fn diagnostic(&self) -> codespan_reporting::diagnostic::Diagnostic<()>
+ where
+ E: Error,
+ {
+ use codespan_reporting::diagnostic::{Diagnostic, Label};
+ let diagnostic = Diagnostic::error()
+ .with_message(self.inner.to_string())
+ .with_labels(
+ self.spans()
+ .map(|&(span, ref desc)| {
+ Label::primary((), span.to_range().unwrap()).with_message(desc.to_owned())
+ })
+ .collect(),
+ )
+ .with_notes({
+ let mut notes = Vec::new();
+ let mut source: &dyn Error = &self.inner;
+ while let Some(next) = Error::source(source) {
+ notes.push(next.to_string());
+ source = next;
+ }
+ notes
+ });
+ diagnostic
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ pub fn emit_to_stderr(&self, source: &str)
+ where
+ E: Error,
+ {
+ self.emit_to_stderr_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to standard error stream.
+ pub fn emit_to_stderr_with_path(&self, source: &str, path: &str)
+ where
+ E: Error,
+ {
+ use codespan_reporting::{files, term};
+ use term::termcolor::{ColorChoice, StandardStream};
+
+ let files = files::SimpleFile::new(path, source);
+ let config = term::Config::default();
+ let writer = StandardStream::stderr(ColorChoice::Auto);
+ term::emit(&mut writer.lock(), &config, &files, &self.diagnostic())
+ .expect("cannot write error");
+ }
+
+ /// Emits a summary of the error to a string.
+ pub fn emit_to_string(&self, source: &str) -> String
+ where
+ E: Error,
+ {
+ self.emit_to_string_with_path(source, "wgsl")
+ }
+
+ /// Emits a summary of the error to a string.
+ pub fn emit_to_string_with_path(&self, source: &str, path: &str) -> String
+ where
+ E: Error,
+ {
+ use codespan_reporting::{files, term};
+ use term::termcolor::NoColor;
+
+ let files = files::SimpleFile::new(path, source);
+ let config = codespan_reporting::term::Config::default();
+ let mut writer = NoColor::new(Vec::new());
+ term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error");
+ String::from_utf8(writer.into_inner()).unwrap()
+ }
+}
+
+/// Convenience trait for [`Error`] to be able to apply spans to anything.
+pub(crate) trait AddSpan: Sized {
+ type Output;
+ /// See [`WithSpan::new`].
+ fn with_span(self) -> Self::Output;
+ /// See [`WithSpan::with_span`].
+ fn with_span_static(self, span: Span, description: &'static str) -> Self::Output;
+ /// See [`WithSpan::with_context`].
+ fn with_span_context(self, span_context: SpanContext) -> Self::Output;
+ /// See [`WithSpan::with_handle`].
+ fn with_span_handle<T, A: SpanProvider<T>>(self, handle: Handle<T>, arena: &A) -> Self::Output;
+}
+
+/// Trait abstracting over getting a span from an [`Arena`] or a [`UniqueArena`].
+pub(crate) trait SpanProvider<T> {
+ fn get_span(&self, handle: Handle<T>) -> Span;
+ fn get_span_context(&self, handle: Handle<T>) -> SpanContext {
+ match self.get_span(handle) {
+ x if !x.is_defined() => (Default::default(), "".to_string()),
+ known => (
+ known,
+ format!("{} {:?}", std::any::type_name::<T>(), handle),
+ ),
+ }
+ }
+}
+
+impl<T> SpanProvider<T> for Arena<T> {
+ fn get_span(&self, handle: Handle<T>) -> Span {
+ self.get_span(handle)
+ }
+}
+
+impl<T> SpanProvider<T> for UniqueArena<T> {
+ fn get_span(&self, handle: Handle<T>) -> Span {
+ self.get_span(handle)
+ }
+}
+
+impl<E> AddSpan for E
+where
+ E: Error,
+{
+ type Output = WithSpan<Self>;
+ fn with_span(self) -> WithSpan<Self> {
+ WithSpan::new(self)
+ }
+
+ fn with_span_static(self, span: Span, description: &'static str) -> WithSpan<Self> {
+ WithSpan::new(self).with_span(span, description)
+ }
+
+ fn with_span_context(self, span_context: SpanContext) -> WithSpan<Self> {
+ WithSpan::new(self).with_context(span_context)
+ }
+
+ fn with_span_handle<T, A: SpanProvider<T>>(
+ self,
+ handle: Handle<T>,
+ arena: &A,
+ ) -> WithSpan<Self> {
+ WithSpan::new(self).with_handle(handle, arena)
+ }
+}
+
+/// Convenience trait for [`Result`], adding a [`MapErrWithSpan::map_err_inner`]
+/// mapping to [`WithSpan::and_then`].
+pub trait MapErrWithSpan<E, E2>: Sized {
+ type Output: Sized;
+ fn map_err_inner<F, E3>(self, func: F) -> Self::Output
+ where
+ F: FnOnce(E) -> WithSpan<E3>,
+ E2: From<E3>;
+}
+
+impl<T, E, E2> MapErrWithSpan<E, E2> for Result<T, WithSpan<E>> {
+ type Output = Result<T, WithSpan<E2>>;
+ fn map_err_inner<F, E3>(self, func: F) -> Result<T, WithSpan<E2>>
+ where
+ F: FnOnce(E) -> WithSpan<E3>,
+ E2: From<E3>,
+ {
+ self.map_err(|e| e.and_then(func).into_other::<E2>())
+ }
+}
+
+#[test]
+fn span_location() {
+ let source = "12\n45\n\n89\n";
+ assert_eq!(
+ Span { start: 0, end: 1 }.location(source),
+ SourceLocation {
+ line_number: 1,
+ line_position: 1,
+ offset: 0,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 1, end: 2 }.location(source),
+ SourceLocation {
+ line_number: 1,
+ line_position: 2,
+ offset: 1,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 2, end: 3 }.location(source),
+ SourceLocation {
+ line_number: 1,
+ line_position: 3,
+ offset: 2,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 3, end: 5 }.location(source),
+ SourceLocation {
+ line_number: 2,
+ line_position: 1,
+ offset: 3,
+ length: 2
+ }
+ );
+ assert_eq!(
+ Span { start: 4, end: 6 }.location(source),
+ SourceLocation {
+ line_number: 2,
+ line_position: 2,
+ offset: 4,
+ length: 2
+ }
+ );
+ assert_eq!(
+ Span { start: 5, end: 6 }.location(source),
+ SourceLocation {
+ line_number: 2,
+ line_position: 3,
+ offset: 5,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 6, end: 7 }.location(source),
+ SourceLocation {
+ line_number: 3,
+ line_position: 1,
+ offset: 6,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 7, end: 8 }.location(source),
+ SourceLocation {
+ line_number: 4,
+ line_position: 1,
+ offset: 7,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 8, end: 9 }.location(source),
+ SourceLocation {
+ line_number: 4,
+ line_position: 2,
+ offset: 8,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 9, end: 10 }.location(source),
+ SourceLocation {
+ line_number: 4,
+ line_position: 3,
+ offset: 9,
+ length: 1
+ }
+ );
+ assert_eq!(
+ Span { start: 10, end: 11 }.location(source),
+ SourceLocation {
+ line_number: 5,
+ line_position: 1,
+ offset: 10,
+ length: 1
+ }
+ );
+}
diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs
new file mode 100644
index 0000000000..df6fc5e9b0
--- /dev/null
+++ b/third_party/rust/naga/src/valid/analyzer.rs
@@ -0,0 +1,1281 @@
+/*! Module analyzer.
+
+Figures out the following properties:
+ - control flow uniformity
+ - texture/sampler pairs
+ - expression reference counts
+!*/
+
+use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
+use crate::span::{AddSpan as _, WithSpan};
+use crate::{
+ arena::{Arena, Handle},
+ proc::{ResolveContext, TypeResolution},
+};
+use std::ops;
+
+pub type NonUniformResult = Option<Handle<crate::Expression>>;
+
+// Remove this once we update our uniformity analysis and
+// add support for the `derivative_uniformity` diagnostic
+const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true;
+
+bitflags::bitflags! {
+ /// Kinds of expressions that require uniform control flow.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct UniformityRequirements: u8 {
+ const WORK_GROUP_BARRIER = 0x1;
+ const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 };
+ const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 };
+ }
+}
+
+/// Uniform control flow characteristics.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Uniformity {
+ /// A child expression with non-uniform result.
+ ///
+ /// This means, when the relevant invocations are scheduled on a compute unit,
+ /// they have to use vector registers to store an individual value
+ /// per invocation.
+ ///
+ /// Whenever the control flow is conditioned on such value,
+ /// the hardware needs to keep track of the mask of invocations,
+ /// and process all branches of the control flow.
+ ///
+ /// Any operations that depend on non-uniform results also produce non-uniform.
+ pub non_uniform_result: NonUniformResult,
+ /// If this expression requires uniform control flow, store the reason here.
+ pub requirements: UniformityRequirements,
+}
+
+impl Uniformity {
+ const fn new() -> Self {
+ Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+}
+
+bitflags::bitflags! {
+ #[derive(Clone, Copy, Debug, PartialEq)]
+ struct ExitFlags: u8 {
+ /// Control flow may return from the function, which makes all the
+ /// subsequent statements within the current function (only!)
+ /// to be executed in a non-uniform control flow.
+ const MAY_RETURN = 0x1;
+ /// Control flow may be killed. Anything after `Statement::Kill` is
+ /// considered inside non-uniform context.
+ const MAY_KILL = 0x2;
+ }
+}
+
+/// Uniformity characteristics of a function.
+#[cfg_attr(test, derive(Debug, PartialEq))]
+struct FunctionUniformity {
+ result: Uniformity,
+ exit: ExitFlags,
+}
+
+impl ops::BitOr for FunctionUniformity {
+ type Output = Self;
+ fn bitor(self, other: Self) -> Self {
+ FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: self
+ .result
+ .non_uniform_result
+ .or(other.result.non_uniform_result),
+ requirements: self.result.requirements | other.result.requirements,
+ },
+ exit: self.exit | other.exit,
+ }
+ }
+}
+
+impl FunctionUniformity {
+ const fn new() -> Self {
+ FunctionUniformity {
+ result: Uniformity::new(),
+ exit: ExitFlags::empty(),
+ }
+ }
+
+ /// Returns a disruptor based on the stored exit flags, if any.
+ const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
+ if self.exit.contains(ExitFlags::MAY_RETURN) {
+ Some(UniformityDisruptor::Return)
+ } else if self.exit.contains(ExitFlags::MAY_KILL) {
+ Some(UniformityDisruptor::Discard)
+ } else {
+ None
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Indicates how a global variable is used.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct GlobalUse: u8 {
+ /// Data will be read from the variable.
+ const READ = 0x1;
+ /// Data will be written to the variable.
+ const WRITE = 0x2;
+ /// The information about the data is queried.
+ const QUERY = 0x4;
+ }
+}
+
+#[derive(Clone, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct SamplingKey {
+ pub image: Handle<crate::GlobalVariable>,
+ pub sampler: Handle<crate::GlobalVariable>,
+}
+
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct ExpressionInfo {
+ pub uniformity: Uniformity,
+ pub ref_count: usize,
+ assignable_global: Option<Handle<crate::GlobalVariable>>,
+ pub ty: TypeResolution,
+}
+
+impl ExpressionInfo {
+ const fn new() -> Self {
+ ExpressionInfo {
+ uniformity: Uniformity::new(),
+ ref_count: 0,
+ assignable_global: None,
+ // this doesn't matter at this point, will be overwritten
+ ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: 0,
+ })),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+enum GlobalOrArgument {
+ Global(Handle<crate::GlobalVariable>),
+ Argument(u32),
+}
+
+impl GlobalOrArgument {
+ fn from_expression(
+ expression_arena: &Arena<crate::Expression>,
+ expression: Handle<crate::Expression>,
+ ) -> Result<GlobalOrArgument, ExpressionError> {
+ Ok(match expression_arena[expression] {
+ crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
+ crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
+ crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
+ _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
+ },
+ _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
+ })
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+struct Sampling {
+ image: GlobalOrArgument,
+ sampler: GlobalOrArgument,
+}
+
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct FunctionInfo {
+ /// Validation flags.
+ #[allow(dead_code)]
+ flags: ValidationFlags,
+ /// Set of shader stages where calling this function is valid.
+ pub available_stages: ShaderStages,
+ /// Uniformity characteristics.
+ pub uniformity: Uniformity,
+ /// Function may kill the invocation.
+ pub may_kill: bool,
+
+ /// All pairs of (texture, sampler) globals that may be used together in
+ /// sampling operations by this function and its callees. This includes
+ /// pairings that arise when this function passes textures and samplers as
+ /// arguments to its callees.
+ ///
+ /// This table does not include uses of textures and samplers passed as
+ /// arguments to this function itself, since we do not know which globals
+ /// those will be. However, this table *is* exhaustive when computed for an
+ /// entry point function: entry points never receive textures or samplers as
+ /// arguments, so all an entry point's sampling can be reported in terms of
+ /// globals.
+ ///
+ /// The GLSL back end uses this table to construct reflection info that
+ /// clients need to construct texture-combined sampler values.
+ pub sampling_set: crate::FastHashSet<SamplingKey>,
+
+ /// How this function and its callees use this module's globals.
+ ///
+ /// This is indexed by `Handle<GlobalVariable>` indices. However,
+ /// `FunctionInfo` implements `std::ops::Index<Handle<GlobalVariable>>`,
+ /// so you can simply index this struct with a global handle to retrieve
+ /// its usage information.
+ global_uses: Box<[GlobalUse]>,
+
+ /// Information about each expression in this function's body.
+ ///
+ /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
+ /// implements `std::ops::Index<Handle<Expression>>`, so you can simply
+ /// index this struct with an expression handle to retrieve its
+ /// `ExpressionInfo`.
+ expressions: Box<[ExpressionInfo]>,
+
+ /// All (texture, sampler) pairs that may be used together in sampling
+ /// operations by this function and its callees, whether they are accessed
+ /// as globals or passed as arguments.
+ ///
+ /// Participants are represented by [`GlobalVariable`] handles whenever
+ /// possible, and otherwise by indices of this function's arguments.
+ ///
+ /// When analyzing a function call, we combine this data about the callee
+ /// with the actual arguments being passed to produce the callers' own
+ /// `sampling_set` and `sampling` tables.
+ ///
+ /// [`GlobalVariable`]: crate::GlobalVariable
+ sampling: crate::FastHashSet<Sampling>,
+
+ /// Indicates that the function is using dual source blending.
+ pub dual_source_blending: bool,
+}
+
+impl FunctionInfo {
+ pub const fn global_variable_count(&self) -> usize {
+ self.global_uses.len()
+ }
+ pub const fn expression_count(&self) -> usize {
+ self.expressions.len()
+ }
+ pub fn dominates_global_use(&self, other: &Self) -> bool {
+ for (self_global_uses, other_global_uses) in
+ self.global_uses.iter().zip(other.global_uses.iter())
+ {
+ if !self_global_uses.contains(*other_global_uses) {
+ return false;
+ }
+ }
+ true
+ }
+}
+
+impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
+ type Output = GlobalUse;
+ fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
+ &self.global_uses[handle.index()]
+ }
+}
+
+impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
+ type Output = ExpressionInfo;
+ fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
+ &self.expressions[handle.index()]
+ }
+}
+
+/// Disruptor of the uniform control flow.
+#[derive(Clone, Copy, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum UniformityDisruptor {
+ #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
+ Expression(Handle<crate::Expression>),
+ #[error("There is a Return earlier in the control flow of the function")]
+ Return,
+ #[error("There is a Discard earlier in the entry point across all called functions")]
+ Discard,
+}
+
+impl FunctionInfo {
+ /// Adds a value-type reference to an expression.
+ #[must_use]
+ fn add_ref_impl(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ global_use: GlobalUse,
+ ) -> NonUniformResult {
+ let info = &mut self.expressions[handle.index()];
+ info.ref_count += 1;
+ // mark the used global as read
+ if let Some(global) = info.assignable_global {
+ self.global_uses[global.index()] |= global_use;
+ }
+ info.uniformity.non_uniform_result
+ }
+
+ /// Adds a value-type reference to an expression.
+ #[must_use]
+ fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult {
+ self.add_ref_impl(handle, GlobalUse::READ)
+ }
+
+ /// Adds a potentially assignable reference to an expression.
+ /// These are destinations for `Store` and `ImageStore` statements,
+ /// which can transit through `Access` and `AccessIndex`.
+ #[must_use]
+ fn add_assignable_ref(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
+ ) -> NonUniformResult {
+ let info = &mut self.expressions[handle.index()];
+ info.ref_count += 1;
+ // propagate the assignable global up the chain, till it either hits
+ // a value-type expression, or the assignment statement.
+ if let Some(global) = info.assignable_global {
+ if let Some(_old) = assignable_global.replace(global) {
+ unreachable!()
+ }
+ }
+ info.uniformity.non_uniform_result
+ }
+
+ /// Inherit information from a called function.
+ fn process_call(
+ &mut self,
+ callee: &Self,
+ arguments: &[Handle<crate::Expression>],
+ expression_arena: &Arena<crate::Expression>,
+ ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
+ self.sampling_set
+ .extend(callee.sampling_set.iter().cloned());
+ for sampling in callee.sampling.iter() {
+ // If the callee was passed the texture or sampler as an argument,
+ // we may now be able to determine which globals those referred to.
+ let image_storage = match sampling.image {
+ GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
+ GlobalOrArgument::Argument(i) => {
+ let handle = arguments[i as usize];
+ GlobalOrArgument::from_expression(expression_arena, handle).map_err(
+ |source| {
+ FunctionError::Expression { handle, source }
+ .with_span_handle(handle, expression_arena)
+ },
+ )?
+ }
+ };
+
+ let sampler_storage = match sampling.sampler {
+ GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
+ GlobalOrArgument::Argument(i) => {
+ let handle = arguments[i as usize];
+ GlobalOrArgument::from_expression(expression_arena, handle).map_err(
+ |source| {
+ FunctionError::Expression { handle, source }
+ .with_span_handle(handle, expression_arena)
+ },
+ )?
+ }
+ };
+
+ // If we've managed to pin both the image and sampler down to
+ // specific globals, record that in our `sampling_set`. Otherwise,
+ // record as much as we do know in our own `sampling` table, for our
+ // callers to sort out.
+ match (image_storage, sampler_storage) {
+ (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
+ self.sampling_set.insert(SamplingKey { image, sampler });
+ }
+ (image, sampler) => {
+ self.sampling.insert(Sampling { image, sampler });
+ }
+ }
+ }
+
+ // Inherit global use from our callees.
+ for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
+ *mine |= *other;
+ }
+
+ Ok(FunctionUniformity {
+ result: callee.uniformity.clone(),
+ exit: if callee.may_kill {
+ ExitFlags::MAY_KILL
+ } else {
+ ExitFlags::empty()
+ },
+ })
+ }
+
+ /// Compute the [`ExpressionInfo`] for `handle`.
+ ///
+ /// Replace the dummy entry in [`self.expressions`] for `handle`
+ /// with a real `ExpressionInfo` value describing that expression.
+ ///
+ /// This function is called as part of a forward sweep through the
+ /// arena, so we can assume that all earlier expressions in the
+ /// arena already have valid info. Since expressions only depend
+ /// on earlier expressions, this includes all our subexpressions.
+ ///
+ /// Adjust the reference counts on all expressions we use.
+ ///
+ /// Also populate the [`sampling_set`], [`sampling`] and
+ /// [`global_uses`] fields of `self`.
+ ///
+ /// [`self.expressions`]: FunctionInfo::expressions
+ /// [`sampling_set`]: FunctionInfo::sampling_set
+ /// [`sampling`]: FunctionInfo::sampling
+ /// [`global_uses`]: FunctionInfo::global_uses
+ #[allow(clippy::or_fun_call)]
+ fn process_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ expression_arena: &Arena<crate::Expression>,
+ other_functions: &[FunctionInfo],
+ resolve_context: &ResolveContext,
+ capabilities: super::Capabilities,
+ ) -> Result<(), ExpressionError> {
+ use crate::{Expression as E, SampleLevel as Sl};
+
+ let expression = &expression_arena[handle];
+ let mut assignable_global = None;
+ let uniformity = match *expression {
+ E::Access { base, index } => {
+ let base_ty = self[base].ty.inner_with(resolve_context.types);
+
+ // build up the caps needed if this is indexed non-uniformly
+ let mut needed_caps = super::Capabilities::empty();
+ let is_binding_array = match *base_ty {
+ crate::TypeInner::BindingArray {
+ base: array_element_ty_handle,
+ ..
+ } => {
+ // these are nasty aliases, but these idents are too long and break rustfmt
+ let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
+ let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
+ let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;
+
+ // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it.
+ let array_element_ty =
+ &resolve_context.types[array_element_ty_handle].inner;
+
+ needed_caps |= match *array_element_ty {
+ // If we're an image, use the appropriate limit.
+ crate::TypeInner::Image { class, .. } => match class {
+ crate::ImageClass::Storage { .. } => ub_st,
+ _ => st_sb,
+ },
+ crate::TypeInner::Sampler { .. } => sampler,
+ // If we're anything but an image, assume we're a buffer and use the address space.
+ _ => {
+ if let E::GlobalVariable(global_handle) = expression_arena[base] {
+ let global = &resolve_context.global_vars[global_handle];
+ match global.space {
+ crate::AddressSpace::Uniform => ub_st,
+ crate::AddressSpace::Storage { .. } => st_sb,
+ _ => unreachable!(),
+ }
+ } else {
+ unreachable!()
+ }
+ }
+ };
+
+ true
+ }
+ _ => false,
+ };
+
+ if self[index].uniformity.non_uniform_result.is_some()
+ && !capabilities.contains(needed_caps)
+ && is_binding_array
+ {
+ return Err(ExpressionError::MissingCapabilities(needed_caps));
+ }
+
+ Uniformity {
+ non_uniform_result: self
+ .add_assignable_ref(base, &mut assignable_global)
+ .or(self.add_ref(index)),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::AccessIndex { base, .. } => Uniformity {
+ non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
+ requirements: UniformityRequirements::empty(),
+ },
+ // always uniform
+ E::Splat { size: _, value } => Uniformity {
+ non_uniform_result: self.add_ref(value),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Swizzle { vector, .. } => Uniformity {
+ non_uniform_result: self.add_ref(vector),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(),
+ E::Compose { ref components, .. } => {
+ let non_uniform_result = components
+ .iter()
+ .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
+ Uniformity {
+ non_uniform_result,
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ // depends on the builtin or interpolation
+ E::FunctionArgument(index) => {
+ let arg = &resolve_context.arguments[index as usize];
+ let uniform = match arg.binding {
+ Some(crate::Binding::BuiltIn(built_in)) => match built_in {
+ // per-polygon built-ins are uniform
+ crate::BuiltIn::FrontFacing
+ // per-work-group built-ins are uniform
+ | crate::BuiltIn::WorkGroupId
+ | crate::BuiltIn::WorkGroupSize
+ | crate::BuiltIn::NumWorkGroups => true,
+ _ => false,
+ },
+ // only flat inputs are uniform
+ Some(crate::Binding::Location {
+ interpolation: Some(crate::Interpolation::Flat),
+ ..
+ }) => true,
+ _ => false,
+ };
+ Uniformity {
+ non_uniform_result: if uniform { None } else { Some(handle) },
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ // depends on the address space
+ E::GlobalVariable(gh) => {
+ use crate::AddressSpace as As;
+ assignable_global = Some(gh);
+ let var = &resolve_context.global_vars[gh];
+ let uniform = match var.space {
+ // local data is non-uniform
+ As::Function | As::Private => false,
+ // workgroup memory is exclusively accessed by the group
+ As::WorkGroup => true,
+ // uniform data
+ As::Uniform | As::PushConstant => true,
+ // storage data is only uniform when read-only
+ As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
+ As::Handle => false,
+ };
+ Uniformity {
+ non_uniform_result: if uniform { None } else { Some(handle) },
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::LocalVariable(_) => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Load { pointer } => Uniformity {
+ non_uniform_result: self.add_ref(pointer),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::ImageSample {
+ image,
+ sampler,
+ gather: _,
+ coordinate,
+ array_index,
+ offset: _,
+ level,
+ depth_ref,
+ } => {
+ let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
+ let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
+
+ match (image_storage, sampler_storage) {
+ (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
+ self.sampling_set.insert(SamplingKey { image, sampler });
+ }
+ _ => {
+ self.sampling.insert(Sampling {
+ image: image_storage,
+ sampler: sampler_storage,
+ });
+ }
+ }
+
+ // "nur" == "Non-Uniform Result"
+ let array_nur = array_index.and_then(|h| self.add_ref(h));
+ let level_nur = match level {
+ Sl::Auto | Sl::Zero => None,
+ Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
+ Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
+ };
+ let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self
+ .add_ref(image)
+ .or(self.add_ref(sampler))
+ .or(self.add_ref(coordinate))
+ .or(array_nur)
+ .or(level_nur)
+ .or(dref_nur),
+ requirements: if level.implicit_derivatives() {
+ UniformityRequirements::IMPLICIT_LEVEL
+ } else {
+ UniformityRequirements::empty()
+ },
+ }
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let array_nur = array_index.and_then(|h| self.add_ref(h));
+ let sample_nur = sample.and_then(|h| self.add_ref(h));
+ let level_nur = level.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self
+ .add_ref(image)
+ .or(self.add_ref(coordinate))
+ .or(array_nur)
+ .or(sample_nur)
+ .or(level_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::ImageQuery { image, query } => {
+ let query_nur = match query {
+ crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
+ _ => None,
+ };
+ Uniformity {
+ non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::Unary { expr, .. } => Uniformity {
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Binary { left, right, .. } => Uniformity {
+ non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => Uniformity {
+ non_uniform_result: self
+ .add_ref(condition)
+ .or(self.add_ref(accept))
+ .or(self.add_ref(reject)),
+ requirements: UniformityRequirements::empty(),
+ },
+ // explicit derivatives require uniform
+ E::Derivative { expr, .. } => Uniformity {
+ //Note: taking a derivative of a uniform doesn't make it non-uniform
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::DERIVATIVE,
+ },
+ E::Relational { argument, .. } => Uniformity {
+ non_uniform_result: self.add_ref(argument),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Math {
+ fun: _,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ let arg1_nur = arg1.and_then(|h| self.add_ref(h));
+ let arg2_nur = arg2.and_then(|h| self.add_ref(h));
+ let arg3_nur = arg3.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::As { expr, .. } => Uniformity {
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
+ E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::WorkGroupUniformLoadResult { .. } => Uniformity {
+ // The result of WorkGroupUniformLoad is always uniform by definition
+ non_uniform_result: None,
+ // The call is what cares about uniformity, not the expression
+ // This expression is never emitted, so this requirement should never be used anyway?
+ requirements: UniformityRequirements::empty(),
+ },
+ E::ArrayLength(expr) => Uniformity {
+ non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::RayQueryGetIntersection {
+ query,
+ committed: _,
+ } => Uniformity {
+ non_uniform_result: self.add_ref(query),
+ requirements: UniformityRequirements::empty(),
+ },
+ };
+
+ let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
+ self.expressions[handle.index()] = ExpressionInfo {
+ uniformity,
+ ref_count: 0,
+ assignable_global,
+ ty,
+ };
+ Ok(())
+ }
+
+ /// Analyzes the uniformity requirements of a block (as a sequence of statements).
+ /// Returns the uniformity characteristics at the *function* level, i.e.
+ /// whether or not the function requires to be called in uniform control flow,
+ /// and whether the produced result is not disrupting the control flow.
+ ///
+ /// The parent control flow is uniform if `disruptor.is_none()`.
+ ///
+ /// Returns a `NonUniformControlFlow` error if any of the expressions in the block
+ /// require uniformity, but the current flow is non-uniform.
+ #[allow(clippy::or_fun_call)]
+ fn process_block(
+ &mut self,
+ statements: &crate::Block,
+ other_functions: &[FunctionInfo],
+ mut disruptor: Option<UniformityDisruptor>,
+ expression_arena: &Arena<crate::Expression>,
+ ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
+ use crate::Statement as S;
+
+ let mut combined_uniformity = FunctionUniformity::new();
+ for statement in statements {
+ let uniformity = match *statement {
+ S::Emit(ref range) => {
+ let mut requirements = UniformityRequirements::empty();
+ for expr in range.clone() {
+ let req = self.expressions[expr.index()].uniformity.requirements;
+ if self
+ .flags
+ .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
+ && !req.is_empty()
+ {
+ if let Some(cause) = disruptor {
+ return Err(FunctionError::NonUniformControlFlow(req, expr, cause)
+ .with_span_handle(expr, expression_arena));
+ }
+ }
+ requirements |= req;
+ }
+ FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements,
+ },
+ exit: ExitFlags::empty(),
+ }
+ }
+ S::Break | S::Continue => FunctionUniformity::new(),
+ S::Kill => FunctionUniformity {
+ result: Uniformity::new(),
+ exit: if disruptor.is_some() {
+ ExitFlags::MAY_KILL
+ } else {
+ ExitFlags::empty()
+ },
+ },
+ S::Barrier(_) => FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::WORK_GROUP_BARRIER,
+ },
+ exit: ExitFlags::empty(),
+ },
+ S::WorkGroupUniformLoad { pointer, .. } => {
+ let _condition_nur = self.add_ref(pointer);
+
+ // Don't check that this call occurs in uniform control flow until Naga implements WGSL's standard
+ // uniformity analysis (https://github.com/gfx-rs/naga/issues/1744).
+ // The uniformity analysis Naga uses now is less accurate than the one in the WGSL standard,
+ // causing Naga to reject correct uses of `workgroupUniformLoad` in some interesting programs.
+
+ /*
+ if self
+ .flags
+ .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
+ {
+ let condition_nur = self.add_ref(pointer);
+ let this_disruptor =
+ disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
+ if let Some(cause) = this_disruptor {
+ return Err(FunctionError::NonUniformWorkgroupUniformLoad(cause)
+ .with_span_static(*span, "WorkGroupUniformLoad"));
+ }
+ } */
+ FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::WORK_GROUP_BARRIER,
+ },
+ exit: ExitFlags::empty(),
+ }
+ }
+ S::Block(ref b) => {
+ self.process_block(b, other_functions, disruptor, expression_arena)?
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let condition_nur = self.add_ref(condition);
+ let branch_disruptor =
+ disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
+ let accept_uniformity = self.process_block(
+ accept,
+ other_functions,
+ branch_disruptor,
+ expression_arena,
+ )?;
+ let reject_uniformity = self.process_block(
+ reject,
+ other_functions,
+ branch_disruptor,
+ expression_arena,
+ )?;
+ accept_uniformity | reject_uniformity
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ let selector_nur = self.add_ref(selector);
+ let branch_disruptor =
+ disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
+ let mut uniformity = FunctionUniformity::new();
+ let mut case_disruptor = branch_disruptor;
+ for case in cases.iter() {
+ let case_uniformity = self.process_block(
+ &case.body,
+ other_functions,
+ case_disruptor,
+ expression_arena,
+ )?;
+ case_disruptor = if case.fall_through {
+ case_disruptor.or(case_uniformity.exit_disruptor())
+ } else {
+ branch_disruptor
+ };
+ uniformity = uniformity | case_uniformity;
+ }
+ uniformity
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let body_uniformity =
+ self.process_block(body, other_functions, disruptor, expression_arena)?;
+ let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
+ let continuing_uniformity = self.process_block(
+ continuing,
+ other_functions,
+ continuing_disruptor,
+ expression_arena,
+ )?;
+ if let Some(expr) = break_if {
+ let _ = self.add_ref(expr);
+ }
+ body_uniformity | continuing_uniformity
+ }
+ S::Return { value } => FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: if disruptor.is_some() {
+ ExitFlags::MAY_RETURN
+ } else {
+ ExitFlags::empty()
+ },
+ },
+ // Here and below, the used expressions are already emitted,
+ // and their results do not affect the function return value,
+ // so we can ignore their non-uniformity.
+ S::Store { pointer, value } => {
+ let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
+ let _ = self.add_ref(value);
+ FunctionUniformity::new()
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ let _ = self.add_ref_impl(image, GlobalUse::WRITE);
+ if let Some(expr) = array_index {
+ let _ = self.add_ref(expr);
+ }
+ let _ = self.add_ref(coordinate);
+ let _ = self.add_ref(value);
+ FunctionUniformity::new()
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result: _,
+ } => {
+ for &argument in arguments {
+ let _ = self.add_ref(argument);
+ }
+ let info = &other_functions[function.index()];
+ //Note: the result is validated by the Validator, not here
+ self.process_call(info, arguments, expression_arena)?
+ }
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result: _,
+ } => {
+ let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
+ let _ = self.add_ref(value);
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ let _ = self.add_ref(cmp);
+ }
+ FunctionUniformity::new()
+ }
+ S::RayQuery { query, ref fun } => {
+ let _ = self.add_ref(query);
+ if let crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } = *fun
+ {
+ let _ = self.add_ref(acceleration_structure);
+ let _ = self.add_ref(descriptor);
+ }
+ FunctionUniformity::new()
+ }
+ };
+
+ disruptor = disruptor.or(uniformity.exit_disruptor());
+ combined_uniformity = combined_uniformity | uniformity;
+ }
+ Ok(combined_uniformity)
+ }
+}
+
+impl ModuleInfo {
+ /// Populates `self.const_expression_types`
+ pub(super) fn process_const_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ resolve_context: &ResolveContext,
+ gctx: crate::proc::GlobalCtx,
+ ) -> Result<(), super::ConstExpressionError> {
+ self.const_expression_types[handle.index()] =
+ resolve_context.resolve(&gctx.const_expressions[handle], |h| Ok(&self[h]))?;
+ Ok(())
+ }
+
+ /// Builds the `FunctionInfo` based on the function, and validates the
+ /// uniform control flow if required by the expressions of this function.
+ pub(super) fn process_function(
+ &self,
+ fun: &crate::Function,
+ module: &crate::Module,
+ flags: ValidationFlags,
+ capabilities: super::Capabilities,
+ ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
+ let mut info = FunctionInfo {
+ flags,
+ available_stages: ShaderStages::all(),
+ uniformity: Uniformity::new(),
+ may_kill: false,
+ sampling_set: crate::FastHashSet::default(),
+ global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
+ expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
+ sampling: crate::FastHashSet::default(),
+ dual_source_blending: false,
+ };
+ let resolve_context =
+ ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
+
+ for (handle, _) in fun.expressions.iter() {
+ if let Err(source) = info.process_expression(
+ handle,
+ &fun.expressions,
+ &self.functions,
+ &resolve_context,
+ capabilities,
+ ) {
+ return Err(FunctionError::Expression { handle, source }
+ .with_span_handle(handle, &fun.expressions));
+ }
+ }
+
+ for (_, expr) in fun.local_variables.iter() {
+ if let Some(init) = expr.init {
+ let _ = info.add_ref(init);
+ }
+ }
+
+ let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?;
+ info.uniformity = uniformity.result;
+ info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
+
+ Ok(info)
+ }
+
+ pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
+ &self.entry_points[index]
+ }
+}
+
+#[test]
+fn uniform_control_flow() {
+ use crate::{Expression as E, Statement as S};
+
+ let mut type_arena = crate::UniqueArena::new();
+ let ty = type_arena.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size: crate::VectorSize::Bi,
+ scalar: crate::Scalar::F32,
+ },
+ },
+ Default::default(),
+ );
+ let mut global_var_arena = Arena::new();
+ let non_uniform_global = global_var_arena.append(
+ crate::GlobalVariable {
+ name: None,
+ init: None,
+ ty,
+ space: crate::AddressSpace::Handle,
+ binding: None,
+ },
+ Default::default(),
+ );
+ let uniform_global = global_var_arena.append(
+ crate::GlobalVariable {
+ name: None,
+ init: None,
+ ty,
+ binding: None,
+ space: crate::AddressSpace::Uniform,
+ },
+ Default::default(),
+ );
+
+ let mut expressions = Arena::new();
+ // checks the uniform control flow
+ let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default());
+ // checks the non-uniform control flow
+ let derivative_expr = expressions.append(
+ E::Derivative {
+ axis: crate::DerivativeAxis::X,
+ ctrl: crate::DerivativeControl::None,
+ expr: constant_expr,
+ },
+ Default::default(),
+ );
+ let emit_range_constant_derivative = expressions.range_from(0);
+ let non_uniform_global_expr =
+ expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
+ let uniform_global_expr =
+ expressions.append(E::GlobalVariable(uniform_global), Default::default());
+ let emit_range_globals = expressions.range_from(2);
+
+ // checks the QUERY flag
+ let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
+ // checks the transitive WRITE flag
+ let access_expr = expressions.append(
+ E::AccessIndex {
+ base: non_uniform_global_expr,
+ index: 1,
+ },
+ Default::default(),
+ );
+ let emit_range_query_access_globals = expressions.range_from(2);
+
+ let mut info = FunctionInfo {
+ flags: ValidationFlags::all(),
+ available_stages: ShaderStages::all(),
+ uniformity: Uniformity::new(),
+ may_kill: false,
+ sampling_set: crate::FastHashSet::default(),
+ global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
+ expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
+ sampling: crate::FastHashSet::default(),
+ dual_source_blending: false,
+ };
+ let resolve_context = ResolveContext {
+ constants: &Arena::new(),
+ types: &type_arena,
+ special_types: &crate::SpecialTypes::default(),
+ global_vars: &global_var_arena,
+ local_vars: &Arena::new(),
+ functions: &Arena::new(),
+ arguments: &[],
+ };
+ for (handle, _) in expressions.iter() {
+ info.process_expression(
+ handle,
+ &expressions,
+ &[],
+ &resolve_context,
+ super::Capabilities::empty(),
+ )
+ .unwrap();
+ }
+ assert_eq!(info[non_uniform_global_expr].ref_count, 1);
+ assert_eq!(info[uniform_global_expr].ref_count, 1);
+ assert_eq!(info[query_expr].ref_count, 0);
+ assert_eq!(info[access_expr].ref_count, 0);
+ assert_eq!(info[non_uniform_global], GlobalUse::empty());
+ assert_eq!(info[uniform_global], GlobalUse::QUERY);
+
+ let stmt_emit1 = S::Emit(emit_range_globals.clone());
+ let stmt_if_uniform = S::If {
+ condition: uniform_global_expr,
+ accept: crate::Block::new(),
+ reject: vec![
+ S::Emit(emit_range_constant_derivative.clone()),
+ S::Store {
+ pointer: constant_expr,
+ value: derivative_expr,
+ },
+ ]
+ .into(),
+ };
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit1, stmt_if_uniform].into(),
+ &[],
+ None,
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::DERIVATIVE,
+ },
+ exit: ExitFlags::empty(),
+ }),
+ );
+ assert_eq!(info[constant_expr].ref_count, 2);
+ assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
+
+ let stmt_emit2 = S::Emit(emit_range_globals.clone());
+ let stmt_if_non_uniform = S::If {
+ condition: non_uniform_global_expr,
+ accept: vec![
+ S::Emit(emit_range_constant_derivative),
+ S::Store {
+ pointer: constant_expr,
+ value: derivative_expr,
+ },
+ ]
+ .into(),
+ reject: crate::Block::new(),
+ };
+ {
+ let block_info = info.process_block(
+ &vec![stmt_emit2, stmt_if_non_uniform].into(),
+ &[],
+ None,
+ &expressions,
+ );
+ if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE {
+ assert_eq!(info[derivative_expr].ref_count, 2);
+ } else {
+ assert_eq!(
+ block_info,
+ Err(FunctionError::NonUniformControlFlow(
+ UniformityRequirements::DERIVATIVE,
+ derivative_expr,
+ UniformityDisruptor::Expression(non_uniform_global_expr)
+ )
+ .with_span()),
+ );
+ assert_eq!(info[derivative_expr].ref_count, 1);
+ }
+ }
+ assert_eq!(info[non_uniform_global], GlobalUse::READ);
+
+ let stmt_emit3 = S::Emit(emit_range_globals);
+ let stmt_return_non_uniform = S::Return {
+ value: Some(non_uniform_global_expr),
+ };
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit3, stmt_return_non_uniform].into(),
+ &[],
+ Some(UniformityDisruptor::Return),
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: Some(non_uniform_global_expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: ExitFlags::MAY_RETURN,
+ }),
+ );
+ assert_eq!(info[non_uniform_global_expr].ref_count, 3);
+
+ // Check that uniformity requirements reach through a pointer
+ let stmt_emit4 = S::Emit(emit_range_query_access_globals);
+ let stmt_assign = S::Store {
+ pointer: access_expr,
+ value: query_expr,
+ };
+ let stmt_return_pointer = S::Return {
+ value: Some(access_expr),
+ };
+ let stmt_kill = S::Kill;
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
+ &[],
+ Some(UniformityDisruptor::Discard),
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: Some(non_uniform_global_expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: ExitFlags::all(),
+ }),
+ );
+ assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
+}
diff --git a/third_party/rust/naga/src/valid/compose.rs b/third_party/rust/naga/src/valid/compose.rs
new file mode 100644
index 0000000000..c21e98c6f2
--- /dev/null
+++ b/third_party/rust/naga/src/valid/compose.rs
@@ -0,0 +1,128 @@
+use crate::proc::TypeResolution;
+
+use crate::arena::Handle;
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum ComposeError {
+ #[error("Composing of type {0:?} can't be done")]
+ Type(Handle<crate::Type>),
+ #[error("Composing expects {expected} components but {given} were given")]
+ ComponentCount { given: u32, expected: u32 },
+ #[error("Composing {index}'s component type is not expected")]
+ ComponentType { index: u32 },
+}
+
+pub fn validate_compose(
+ self_ty_handle: Handle<crate::Type>,
+ gctx: crate::proc::GlobalCtx,
+ component_resolutions: impl ExactSizeIterator<Item = TypeResolution>,
+) -> Result<(), ComposeError> {
+ use crate::TypeInner as Ti;
+
+ match gctx.types[self_ty_handle].inner {
+ // vectors are composed from scalars or other vectors
+ Ti::Vector { size, scalar } => {
+ let mut total = 0;
+ for (index, comp_res) in component_resolutions.enumerate() {
+ total += match *comp_res.inner_with(gctx.types) {
+ Ti::Scalar(comp_scalar) if comp_scalar == scalar => 1,
+ Ti::Vector {
+ size: comp_size,
+ scalar: comp_scalar,
+ } if comp_scalar == scalar => comp_size as u32,
+ ref other => {
+ log::error!(
+ "Vector component[{}] type {:?}, building {:?}",
+ index,
+ other,
+ scalar
+ );
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ };
+ }
+ if size as u32 != total {
+ return Err(ComposeError::ComponentCount {
+ expected: size as u32,
+ given: total,
+ });
+ }
+ }
+ // matrix are composed from column vectors
+ Ti::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ let inner = Ti::Vector { size: rows, scalar };
+ if columns as usize != component_resolutions.len() {
+ return Err(ComposeError::ComponentCount {
+ expected: columns as u32,
+ given: component_resolutions.len() as u32,
+ });
+ }
+ for (index, comp_res) in component_resolutions.enumerate() {
+ if comp_res.inner_with(gctx.types) != &inner {
+ log::error!("Matrix component[{}] type {:?}", index, comp_res);
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ }
+ }
+ Ti::Array {
+ base,
+ size: crate::ArraySize::Constant(count),
+ stride: _,
+ } => {
+ if count.get() as usize != component_resolutions.len() {
+ return Err(ComposeError::ComponentCount {
+ expected: count.get(),
+ given: component_resolutions.len() as u32,
+ });
+ }
+ for (index, comp_res) in component_resolutions.enumerate() {
+ let base_inner = &gctx.types[base].inner;
+ let comp_res_inner = comp_res.inner_with(gctx.types);
+ // We don't support arrays of pointers, but it seems best not to
+ // embed that assumption here, so use `TypeInner::equivalent`.
+ if !base_inner.equivalent(comp_res_inner, gctx.types) {
+ log::error!("Array component[{}] type {:?}", index, comp_res);
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ }
+ }
+ Ti::Struct { ref members, .. } => {
+ if members.len() != component_resolutions.len() {
+ return Err(ComposeError::ComponentCount {
+ given: component_resolutions.len() as u32,
+ expected: members.len() as u32,
+ });
+ }
+ for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate()
+ {
+ let member_inner = &gctx.types[member.ty].inner;
+ let comp_res_inner = comp_res.inner_with(gctx.types);
+ // We don't support pointers in structs, but it seems best not to embed
+ // that assumption here, so use `TypeInner::equivalent`.
+ if !comp_res_inner.equivalent(member_inner, gctx.types) {
+ log::error!("Struct component[{}] type {:?}", index, comp_res);
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ }
+ }
+ ref other => {
+ log::error!("Composing of {:?}", other);
+ return Err(ComposeError::Type(self_ty_handle));
+ }
+ }
+
+ Ok(())
+}
diff --git a/third_party/rust/naga/src/valid/expression.rs b/third_party/rust/naga/src/valid/expression.rs
new file mode 100644
index 0000000000..c82d60f062
--- /dev/null
+++ b/third_party/rust/naga/src/valid/expression.rs
@@ -0,0 +1,1797 @@
+use super::{
+ compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ModuleInfo,
+ ShaderStages, TypeFlags,
+};
+use crate::arena::UniqueArena;
+
+use crate::{
+ arena::Handle,
+ proc::{IndexableLengthError, ResolveError},
+};
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum ExpressionError {
+ #[error("Doesn't exist")]
+ DoesntExist,
+ #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
+ NotInScope,
+ #[error("Base type {0:?} is not compatible with this expression")]
+ InvalidBaseType(Handle<crate::Expression>),
+ #[error("Accessing with index {0:?} can't be done")]
+ InvalidIndexType(Handle<crate::Expression>),
+ #[error("Accessing {0:?} via a negative index is invalid")]
+ NegativeIndex(Handle<crate::Expression>),
+ #[error("Accessing index {1} is out of {0:?} bounds")]
+ IndexOutOfBounds(Handle<crate::Expression>, u32),
+ #[error("The expression {0:?} may only be indexed by a constant")]
+ IndexMustBeConstant(Handle<crate::Expression>),
+ #[error("Function argument {0:?} doesn't exist")]
+ FunctionArgumentDoesntExist(u32),
+ #[error("Loading of {0:?} can't be done")]
+ InvalidPointerType(Handle<crate::Expression>),
+ #[error("Array length of {0:?} can't be done")]
+ InvalidArrayType(Handle<crate::Expression>),
+ #[error("Get intersection of {0:?} can't be done")]
+ InvalidRayQueryType(Handle<crate::Expression>),
+ #[error("Splatting {0:?} can't be done")]
+ InvalidSplatType(Handle<crate::Expression>),
+ #[error("Swizzling {0:?} can't be done")]
+ InvalidVectorType(Handle<crate::Expression>),
+ #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
+ InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
+ #[error(transparent)]
+ Compose(#[from] super::ComposeError),
+ #[error(transparent)]
+ IndexableLength(#[from] IndexableLengthError),
+ #[error("Operation {0:?} can't work with {1:?}")]
+ InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
+ #[error("Operation {0:?} can't work with {1:?} and {2:?}")]
+ InvalidBinaryOperandTypes(
+ crate::BinaryOperator,
+ Handle<crate::Expression>,
+ Handle<crate::Expression>,
+ ),
+ #[error("Selecting is not possible")]
+ InvalidSelectTypes,
+ #[error("Relational argument {0:?} is not a boolean vector")]
+ InvalidBooleanVector(Handle<crate::Expression>),
+ #[error("Relational argument {0:?} is not a float")]
+ InvalidFloatArgument(Handle<crate::Expression>),
+ #[error("Type resolution failed")]
+ Type(#[from] ResolveError),
+ #[error("Not a global variable")]
+ ExpectedGlobalVariable,
+ #[error("Not a global variable or a function argument")]
+ ExpectedGlobalOrArgument,
+ #[error("Needs to be an binding array instead of {0:?}")]
+ ExpectedBindingArrayType(Handle<crate::Type>),
+ #[error("Needs to be an image instead of {0:?}")]
+ ExpectedImageType(Handle<crate::Type>),
+ #[error("Needs to be an image instead of {0:?}")]
+ ExpectedSamplerType(Handle<crate::Type>),
+ #[error("Unable to operate on image class {0:?}")]
+ InvalidImageClass(crate::ImageClass),
+ #[error("Derivatives can only be taken from scalar and vector floats")]
+ InvalidDerivative,
+ #[error("Image array index parameter is misplaced")]
+ InvalidImageArrayIndex,
+ #[error("Inappropriate sample or level-of-detail index for texel access")]
+ InvalidImageOtherIndex,
+ #[error("Image array index type of {0:?} is not an integer scalar")]
+ InvalidImageArrayIndexType(Handle<crate::Expression>),
+ #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
+ InvalidImageOtherIndexType(Handle<crate::Expression>),
+ #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
+ InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
+ #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
+ ComparisonSamplingMismatch {
+ image: crate::ImageClass,
+ sampler: bool,
+ has_ref: bool,
+ },
+ #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
+ InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
+ #[error("Depth reference {0:?} is not a scalar float")]
+ InvalidDepthReference(Handle<crate::Expression>),
+ #[error("Depth sample level can only be Auto or Zero")]
+ InvalidDepthSampleLevel,
+ #[error("Gather level can only be Zero")]
+ InvalidGatherLevel,
+ #[error("Gather component {0:?} doesn't exist in the image")]
+ InvalidGatherComponent(crate::SwizzleComponent),
+ #[error("Gather can't be done for image dimension {0:?}")]
+ InvalidGatherDimension(crate::ImageDimension),
+ #[error("Sample level (exact) type {0:?} is not a scalar float")]
+ InvalidSampleLevelExactType(Handle<crate::Expression>),
+ #[error("Sample level (bias) type {0:?} is not a scalar float")]
+ InvalidSampleLevelBiasType(Handle<crate::Expression>),
+ #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
+ InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
+ #[error("Unable to cast")]
+ InvalidCastArgument,
+ #[error("Invalid argument count for {0:?}")]
+ WrongArgumentCount(crate::MathFunction),
+ #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
+ InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
+ #[error("Atomic result type can't be {0:?}")]
+ InvalidAtomicResultType(Handle<crate::Type>),
+ #[error(
+ "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
+ )]
+ InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>),
+ #[error("Shader requires capability {0:?}")]
+ MissingCapabilities(super::Capabilities),
+ #[error(transparent)]
+ Literal(#[from] LiteralError),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum ConstExpressionError {
+ #[error("The expression is not a constant expression")]
+ NonConst,
+ #[error(transparent)]
+ Compose(#[from] super::ComposeError),
+ #[error("Splatting {0:?} can't be done")]
+ InvalidSplatType(Handle<crate::Expression>),
+ #[error("Type resolution failed")]
+ Type(#[from] ResolveError),
+ #[error(transparent)]
+ Literal(#[from] LiteralError),
+ #[error(transparent)]
+ Width(#[from] super::r#type::WidthError),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum LiteralError {
+ #[error("Float literal is NaN")]
+ NaN,
+ #[error("Float literal is infinite")]
+ Infinity,
+ #[error(transparent)]
+ Width(#[from] super::r#type::WidthError),
+}
+
+struct ExpressionTypeResolver<'a> {
+ root: Handle<crate::Expression>,
+ types: &'a UniqueArena<crate::Type>,
+ info: &'a FunctionInfo,
+}
+
+impl<'a> std::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'a> {
+ type Output = crate::TypeInner;
+
+ #[allow(clippy::panic)]
+ fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
+ if handle < self.root {
+ self.info[handle].ty.inner_with(self.types)
+ } else {
+ // `Validator::validate_module_handles` should have caught this.
+ panic!(
+ "Depends on {:?}, which has not been processed yet",
+ self.root
+ )
+ }
+ }
+}
+
+impl super::Validator {
+ pub(super) fn validate_const_expression(
+ &self,
+ handle: Handle<crate::Expression>,
+ gctx: crate::proc::GlobalCtx,
+ mod_info: &ModuleInfo,
+ ) -> Result<(), ConstExpressionError> {
+ use crate::Expression as E;
+
+ match gctx.const_expressions[handle] {
+ E::Literal(literal) => {
+ self.validate_literal(literal)?;
+ }
+ E::Constant(_) | E::ZeroValue(_) => {}
+ E::Compose { ref components, ty } => {
+ validate_compose(
+ ty,
+ gctx,
+ components.iter().map(|&handle| mod_info[handle].clone()),
+ )?;
+ }
+ E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
+ crate::TypeInner::Scalar { .. } => {}
+ _ => return Err(super::ConstExpressionError::InvalidSplatType(value)),
+ },
+ _ => return Err(super::ConstExpressionError::NonConst),
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn validate_expression(
+ &self,
+ root: Handle<crate::Expression>,
+ expression: &crate::Expression,
+ function: &crate::Function,
+ module: &crate::Module,
+ info: &FunctionInfo,
+ mod_info: &ModuleInfo,
+ ) -> Result<ShaderStages, ExpressionError> {
+ use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
+
+ let resolver = ExpressionTypeResolver {
+ root,
+ types: &module.types,
+ info,
+ };
+
+ let stages = match *expression {
+ E::Access { base, index } => {
+ let base_type = &resolver[base];
+ // See the documentation for `Expression::Access`.
+ let dynamic_indexing_restricted = match *base_type {
+ Ti::Vector { .. } => false,
+ Ti::Matrix { .. } | Ti::Array { .. } => true,
+ Ti::Pointer { .. }
+ | Ti::ValuePointer { size: Some(_), .. }
+ | Ti::BindingArray { .. } => false,
+ ref other => {
+ log::error!("Indexing of {:?}", other);
+ return Err(ExpressionError::InvalidBaseType(base));
+ }
+ };
+ match resolver[index] {
+ //TODO: only allow one of these
+ Ti::Scalar(Sc {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ }) => {}
+ ref other => {
+ log::error!("Indexing by {:?}", other);
+ return Err(ExpressionError::InvalidIndexType(index));
+ }
+ }
+ if dynamic_indexing_restricted
+ && function.expressions[index].is_dynamic_index(module)
+ {
+ return Err(ExpressionError::IndexMustBeConstant(base));
+ }
+
+ // If we know both the length and the index, we can do the
+ // bounds check now.
+ if let crate::proc::IndexableLength::Known(known_length) =
+ base_type.indexable_length(module)?
+ {
+ match module
+ .to_ctx()
+ .eval_expr_to_u32_from(index, &function.expressions)
+ {
+ Ok(value) => {
+ if value >= known_length {
+ return Err(ExpressionError::IndexOutOfBounds(base, value));
+ }
+ }
+ Err(crate::proc::U32EvalError::Negative) => {
+ return Err(ExpressionError::NegativeIndex(base))
+ }
+ Err(crate::proc::U32EvalError::NonConst) => {}
+ }
+ }
+
+ ShaderStages::all()
+ }
+ E::AccessIndex { base, index } => {
+ fn resolve_index_limit(
+ module: &crate::Module,
+ top: Handle<crate::Expression>,
+ ty: &crate::TypeInner,
+ top_level: bool,
+ ) -> Result<u32, ExpressionError> {
+ let limit = match *ty {
+ Ti::Vector { size, .. }
+ | Ti::ValuePointer {
+ size: Some(size), ..
+ } => size as u32,
+ Ti::Matrix { columns, .. } => columns as u32,
+ Ti::Array {
+ size: crate::ArraySize::Constant(len),
+ ..
+ } => len.get(),
+ Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks
+ Ti::Pointer { base, .. } if top_level => {
+ resolve_index_limit(module, top, &module.types[base].inner, false)?
+ }
+ Ti::Struct { ref members, .. } => members.len() as u32,
+ ref other => {
+ log::error!("Indexing of {:?}", other);
+ return Err(ExpressionError::InvalidBaseType(top));
+ }
+ };
+ Ok(limit)
+ }
+
+ let limit = resolve_index_limit(module, base, &resolver[base], true)?;
+ if index >= limit {
+ return Err(ExpressionError::IndexOutOfBounds(base, limit));
+ }
+ ShaderStages::all()
+ }
+ E::Splat { size: _, value } => match resolver[value] {
+ Ti::Scalar { .. } => ShaderStages::all(),
+ ref other => {
+ log::error!("Splat scalar type {:?}", other);
+ return Err(ExpressionError::InvalidSplatType(value));
+ }
+ },
+ E::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ let vec_size = match resolver[vector] {
+ Ti::Vector { size: vec_size, .. } => vec_size,
+ ref other => {
+ log::error!("Swizzle vector type {:?}", other);
+ return Err(ExpressionError::InvalidVectorType(vector));
+ }
+ };
+ for &sc in pattern[..size as usize].iter() {
+ if sc as u8 >= vec_size as u8 {
+ return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
+ }
+ }
+ ShaderStages::all()
+ }
+ E::Literal(literal) => {
+ self.validate_literal(literal)?;
+ ShaderStages::all()
+ }
+ E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(),
+ E::Compose { ref components, ty } => {
+ validate_compose(
+ ty,
+ module.to_ctx(),
+ components.iter().map(|&handle| info[handle].ty.clone()),
+ )?;
+ ShaderStages::all()
+ }
+ E::FunctionArgument(index) => {
+ if index >= function.arguments.len() as u32 {
+ return Err(ExpressionError::FunctionArgumentDoesntExist(index));
+ }
+ ShaderStages::all()
+ }
+ E::GlobalVariable(_handle) => ShaderStages::all(),
+ E::LocalVariable(_handle) => ShaderStages::all(),
+ E::Load { pointer } => {
+ match resolver[pointer] {
+ Ti::Pointer { base, .. }
+ if self.types[base.index()]
+ .flags
+ .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
+ Ti::ValuePointer { .. } => {}
+ ref other => {
+ log::error!("Loading {:?}", other);
+ return Err(ExpressionError::InvalidPointerType(pointer));
+ }
+ }
+ ShaderStages::all()
+ }
+ E::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ // check the validity of expressions
+ let image_ty = Self::global_var_ty(module, function, image)?;
+ let sampler_ty = Self::global_var_ty(module, function, sampler)?;
+
+ let comparison = match module.types[sampler_ty].inner {
+ Ti::Sampler { comparison } => comparison,
+ _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
+ };
+
+ let (class, dim) = match module.types[image_ty].inner {
+ Ti::Image {
+ class,
+ arrayed,
+ dim,
+ } => {
+ // check the array property
+ if arrayed != array_index.is_some() {
+ return Err(ExpressionError::InvalidImageArrayIndex);
+ }
+ if let Some(expr) = array_index {
+ match resolver[expr] {
+ Ti::Scalar(Sc {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ }) => {}
+ _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
+ }
+ }
+ (class, dim)
+ }
+ _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
+ };
+
+ // check sampling and comparison properties
+ let image_depth = match class {
+ crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Float,
+ multi: false,
+ } => false,
+ crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
+ multi: false,
+ } if gather.is_some() => false,
+ crate::ImageClass::Depth { multi: false } => true,
+ _ => return Err(ExpressionError::InvalidImageClass(class)),
+ };
+ if comparison != depth_ref.is_some() || (comparison && !image_depth) {
+ return Err(ExpressionError::ComparisonSamplingMismatch {
+ image: class,
+ sampler: comparison,
+ has_ref: depth_ref.is_some(),
+ });
+ }
+
+ // check texture coordinates type
+ let num_components = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
+ };
+ match resolver[coordinate] {
+ Ti::Scalar(Sc {
+ kind: Sk::Float, ..
+ }) if num_components == 1 => {}
+ Ti::Vector {
+ size,
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ } if size as u32 == num_components => {}
+ _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
+ }
+
+ // check constant offset
+ if let Some(const_expr) = offset {
+ match *mod_info[const_expr].inner_with(&module.types) {
+ Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
+ Ti::Vector {
+ size,
+ scalar: Sc { kind: Sk::Sint, .. },
+ } if size as u32 == num_components => {}
+ _ => {
+ return Err(ExpressionError::InvalidSampleOffset(dim, const_expr));
+ }
+ }
+ }
+
+ // check depth reference type
+ if let Some(expr) = depth_ref {
+ match resolver[expr] {
+ Ti::Scalar(Sc {
+ kind: Sk::Float, ..
+ }) => {}
+ _ => return Err(ExpressionError::InvalidDepthReference(expr)),
+ }
+ match level {
+ crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
+ _ => return Err(ExpressionError::InvalidDepthSampleLevel),
+ }
+ }
+
+ if let Some(component) = gather {
+ match dim {
+ crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
+ crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
+ return Err(ExpressionError::InvalidGatherDimension(dim))
+ }
+ };
+ let max_component = match class {
+ crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
+ _ => crate::SwizzleComponent::W,
+ };
+ if component > max_component {
+ return Err(ExpressionError::InvalidGatherComponent(component));
+ }
+ match level {
+ crate::SampleLevel::Zero => {}
+ _ => return Err(ExpressionError::InvalidGatherLevel),
+ }
+ }
+
+ // check level properties
+ match level {
+ crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
+ crate::SampleLevel::Zero => ShaderStages::all(),
+ crate::SampleLevel::Exact(expr) => {
+ match resolver[expr] {
+ Ti::Scalar(Sc {
+ kind: Sk::Float, ..
+ }) => {}
+ _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)),
+ }
+ ShaderStages::all()
+ }
+ crate::SampleLevel::Bias(expr) => {
+ match resolver[expr] {
+ Ti::Scalar(Sc {
+ kind: Sk::Float, ..
+ }) => {}
+ _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
+ }
+ ShaderStages::FRAGMENT
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ match resolver[x] {
+ Ti::Scalar(Sc {
+ kind: Sk::Float, ..
+ }) if num_components == 1 => {}
+ Ti::Vector {
+ size,
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ } if size as u32 == num_components => {}
+ _ => {
+ return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
+ }
+ }
+ match resolver[y] {
+ Ti::Scalar(Sc {
+ kind: Sk::Float, ..
+ }) if num_components == 1 => {}
+ Ti::Vector {
+ size,
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ } if size as u32 == num_components => {}
+ _ => {
+ return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
+ }
+ }
+ ShaderStages::all()
+ }
+ }
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let ty = Self::global_var_ty(module, function, image)?;
+ match module.types[ty].inner {
+ Ti::Image {
+ class,
+ arrayed,
+ dim,
+ } => {
+ match resolver[coordinate].image_storage_coordinates() {
+ Some(coord_dim) if coord_dim == dim => {}
+ _ => {
+ return Err(ExpressionError::InvalidImageCoordinateType(
+ dim, coordinate,
+ ))
+ }
+ };
+ if arrayed != array_index.is_some() {
+ return Err(ExpressionError::InvalidImageArrayIndex);
+ }
+ if let Some(expr) = array_index {
+ match resolver[expr] {
+ Ti::Scalar(Sc {
+ kind: Sk::Sint | Sk::Uint,
+ width: _,
+ }) => {}
+ _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
+ }
+ }
+
+ match (sample, class.is_multisampled()) {
+ (None, false) => {}
+ (Some(sample), true) => {
+ if resolver[sample].scalar_kind() != Some(Sk::Sint) {
+ return Err(ExpressionError::InvalidImageOtherIndexType(
+ sample,
+ ));
+ }
+ }
+ _ => {
+ return Err(ExpressionError::InvalidImageOtherIndex);
+ }
+ }
+
+ match (level, class.is_mipmapped()) {
+ (None, false) => {}
+ (Some(level), true) => {
+ if resolver[level].scalar_kind() != Some(Sk::Sint) {
+ return Err(ExpressionError::InvalidImageOtherIndexType(level));
+ }
+ }
+ _ => {
+ return Err(ExpressionError::InvalidImageOtherIndex);
+ }
+ }
+ }
+ _ => return Err(ExpressionError::ExpectedImageType(ty)),
+ }
+ ShaderStages::all()
+ }
+ E::ImageQuery { image, query } => {
+ let ty = Self::global_var_ty(module, function, image)?;
+ match module.types[ty].inner {
+ Ti::Image { class, arrayed, .. } => {
+ let good = match query {
+ crate::ImageQuery::NumLayers => arrayed,
+ crate::ImageQuery::Size { level: None } => true,
+ crate::ImageQuery::Size { level: Some(_) }
+ | crate::ImageQuery::NumLevels => class.is_mipmapped(),
+ crate::ImageQuery::NumSamples => class.is_multisampled(),
+ };
+ if !good {
+ return Err(ExpressionError::InvalidImageClass(class));
+ }
+ }
+ _ => return Err(ExpressionError::ExpectedImageType(ty)),
+ }
+ ShaderStages::all()
+ }
+ E::Unary { op, expr } => {
+ use crate::UnaryOperator as Uo;
+ let inner = &resolver[expr];
+ match (op, inner.scalar_kind()) {
+ (Uo::Negate, Some(Sk::Float | Sk::Sint))
+ | (Uo::LogicalNot, Some(Sk::Bool))
+ | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {}
+ other => {
+ log::error!("Op {:?} kind {:?}", op, other);
+ return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
+ }
+ }
+ ShaderStages::all()
+ }
+ E::Binary { op, left, right } => {
+ use crate::BinaryOperator as Bo;
+ let left_inner = &resolver[left];
+ let right_inner = &resolver[right];
+ let good = match op {
+ Bo::Add | Bo::Subtract => match *left_inner {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
+ Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
+ Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
+ },
+ Ti::Matrix { .. } => left_inner == right_inner,
+ _ => false,
+ },
+ Bo::Divide | Bo::Modulo => match *left_inner {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
+ Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
+ Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
+ },
+ _ => false,
+ },
+ Bo::Multiply => {
+ let kind_allowed = match left_inner.scalar_kind() {
+ Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
+ Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false,
+ };
+ let types_match = match (left_inner, right_inner) {
+ // Straight scalar and mixed scalar/vector.
+ (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2))
+ | (
+ &Ti::Vector {
+ scalar: scalar1, ..
+ },
+ &Ti::Scalar(scalar2),
+ )
+ | (
+ &Ti::Scalar(scalar1),
+ &Ti::Vector {
+ scalar: scalar2, ..
+ },
+ ) => scalar1 == scalar2,
+ // Scalar/matrix.
+ (
+ &Ti::Scalar(Sc {
+ kind: Sk::Float, ..
+ }),
+ &Ti::Matrix { .. },
+ )
+ | (
+ &Ti::Matrix { .. },
+ &Ti::Scalar(Sc {
+ kind: Sk::Float, ..
+ }),
+ ) => true,
+ // Vector/vector.
+ (
+ &Ti::Vector {
+ size: size1,
+ scalar: scalar1,
+ },
+ &Ti::Vector {
+ size: size2,
+ scalar: scalar2,
+ },
+ ) => scalar1 == scalar2 && size1 == size2,
+ // Matrix * vector.
+ (
+ &Ti::Matrix { columns, .. },
+ &Ti::Vector {
+ size,
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ },
+ ) => columns == size,
+ // Vector * matrix.
+ (
+ &Ti::Vector {
+ size,
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ },
+ &Ti::Matrix { rows, .. },
+ ) => size == rows,
+ (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
+ columns == rows
+ }
+ _ => false,
+ };
+ let left_width = left_inner.scalar_width().unwrap_or(0);
+ let right_width = right_inner.scalar_width().unwrap_or(0);
+ kind_allowed && types_match && left_width == right_width
+ }
+ Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
+ Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
+ match *left_inner {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
+ Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
+ Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false,
+ },
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ }
+ }
+ Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
+ Ti::Scalar(Sc { kind: Sk::Bool, .. })
+ | Ti::Vector {
+ scalar: Sc { kind: Sk::Bool, .. },
+ ..
+ } => left_inner == right_inner,
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ },
+ Bo::And | Bo::InclusiveOr => match *left_inner {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
+ Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
+ Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
+ },
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ },
+ Bo::ExclusiveOr => match *left_inner {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind {
+ Sk::Sint | Sk::Uint => left_inner == right_inner,
+ Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false,
+ },
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ },
+ Bo::ShiftLeft | Bo::ShiftRight => {
+ let (base_size, base_scalar) = match *left_inner {
+ Ti::Scalar(scalar) => (Ok(None), scalar),
+ Ti::Vector { size, scalar } => (Ok(Some(size)), scalar),
+ ref other => {
+ log::error!("Op {:?} base type {:?}", op, other);
+ (Err(()), Sc::BOOL)
+ }
+ };
+ let shift_size = match *right_inner {
+ Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None),
+ Ti::Vector {
+ size,
+ scalar: Sc { kind: Sk::Uint, .. },
+ } => Ok(Some(size)),
+ ref other => {
+ log::error!("Op {:?} shift type {:?}", op, other);
+ Err(())
+ }
+ };
+ match base_scalar.kind {
+ Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
+ Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false,
+ }
+ }
+ };
+ if !good {
+ log::error!(
+ "Left: {:?} of type {:?}",
+ function.expressions[left],
+ left_inner
+ );
+ log::error!(
+ "Right: {:?} of type {:?}",
+ function.expressions[right],
+ right_inner
+ );
+ return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right));
+ }
+ ShaderStages::all()
+ }
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ let accept_inner = &resolver[accept];
+ let reject_inner = &resolver[reject];
+ let condition_good = match resolver[condition] {
+ Ti::Scalar(Sc {
+ kind: Sk::Bool,
+ width: _,
+ }) => {
+ // When `condition` is a single boolean, `accept` and
+ // `reject` can be vectors or scalars.
+ match *accept_inner {
+ Ti::Scalar { .. } | Ti::Vector { .. } => true,
+ _ => false,
+ }
+ }
+ Ti::Vector {
+ size,
+ scalar:
+ Sc {
+ kind: Sk::Bool,
+ width: _,
+ },
+ } => match *accept_inner {
+ Ti::Vector {
+ size: other_size, ..
+ } => size == other_size,
+ _ => false,
+ },
+ _ => false,
+ };
+ if !condition_good || accept_inner != reject_inner {
+ return Err(ExpressionError::InvalidSelectTypes);
+ }
+ ShaderStages::all()
+ }
+ E::Derivative { expr, .. } => {
+ match resolver[expr] {
+ Ti::Scalar(Sc {
+ kind: Sk::Float, ..
+ })
+ | Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidDerivative),
+ }
+ ShaderStages::FRAGMENT
+ }
+ E::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+ let argument_inner = &resolver[argument];
+ match fun {
+ Rf::All | Rf::Any => match *argument_inner {
+ Ti::Vector {
+ scalar: Sc { kind: Sk::Bool, .. },
+ ..
+ } => {}
+ ref other => {
+ log::error!("All/Any of type {:?}", other);
+ return Err(ExpressionError::InvalidBooleanVector(argument));
+ }
+ },
+ Rf::IsNan | Rf::IsInf => match *argument_inner {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
+ if scalar.kind == Sk::Float => {}
+ ref other => {
+ log::error!("Float test of type {:?}", other);
+ return Err(ExpressionError::InvalidFloatArgument(argument));
+ }
+ },
+ }
+ ShaderStages::all()
+ }
+ E::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let resolve = |arg| &resolver[arg];
+ let arg_ty = resolve(arg);
+ let arg1_ty = arg1.map(resolve);
+ let arg2_ty = arg2.map(resolve);
+ let arg3_ty = arg3.map(resolve);
+ match fun {
+ Mf::Abs => {
+ if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ let good = match *arg_ty {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
+ scalar.kind != Sk::Bool
+ }
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ }
+ Mf::Min | Mf::Max => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let good = match *arg_ty {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
+ scalar.kind != Sk::Bool
+ }
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Clamp => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let good = match *arg_ty {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => {
+ scalar.kind != Sk::Bool
+ }
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ if arg2_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ));
+ }
+ }
+ Mf::Saturate
+ | Mf::Cos
+ | Mf::Cosh
+ | Mf::Sin
+ | Mf::Sinh
+ | Mf::Tan
+ | Mf::Tanh
+ | Mf::Acos
+ | Mf::Asin
+ | Mf::Atan
+ | Mf::Asinh
+ | Mf::Acosh
+ | Mf::Atanh
+ | Mf::Radians
+ | Mf::Degrees
+ | Mf::Ceil
+ | Mf::Floor
+ | Mf::Round
+ | Mf::Fract
+ | Mf::Trunc
+ | Mf::Exp
+ | Mf::Exp2
+ | Mf::Log
+ | Mf::Log2
+ | Mf::Length
+ | Mf::Sqrt
+ | Mf::InverseSqrt => {
+ if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
+ if scalar.kind == Sk::Float => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::Sign => {
+ if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Scalar(Sc {
+ kind: Sk::Float | Sk::Sint,
+ ..
+ })
+ | Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Float | Sk::Sint,
+ ..
+ },
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
+ if scalar.kind == Sk::Float => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Modf | Mf::Frexp => {
+ if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ if !matches!(*arg_ty,
+ Ti::Scalar(scalar) | Ti::Vector { scalar, .. }
+ if scalar.kind == Sk::Float)
+ {
+ return Err(ExpressionError::InvalidArgumentType(fun, 1, arg));
+ }
+ }
+ Mf::Ldexp => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let size0 = match *arg_ty {
+ Ti::Scalar(Sc {
+ kind: Sk::Float, ..
+ }) => None,
+ Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ size,
+ } => Some(size),
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ };
+ let good = match *arg1_ty {
+ Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true,
+ Ti::Vector {
+ size,
+ scalar: Sc { kind: Sk::Sint, .. },
+ } if Some(size) == size0 => true,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Dot => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Float | Sk::Sint | Sk::Uint,
+ ..
+ },
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Outer | Mf::Cross | Mf::Reflect => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Refract => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+
+ match *arg_ty {
+ Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+
+ match (arg_ty, arg2_ty) {
+ (
+ &Ti::Vector {
+ scalar:
+ Sc {
+ width: vector_width,
+ ..
+ },
+ ..
+ },
+ &Ti::Scalar(Sc {
+ width: scalar_width,
+ kind: Sk::Float,
+ }),
+ ) if vector_width == scalar_width => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ))
+ }
+ }
+ }
+ Mf::Normalize => {
+ if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::FaceForward | Mf::Fma | Mf::SmoothStep => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar(Sc {
+ kind: Sk::Float, ..
+ })
+ | Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ if arg2_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ));
+ }
+ }
+ Mf::Mix => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let arg_width = match *arg_ty {
+ Ti::Scalar(Sc {
+ kind: Sk::Float,
+ width,
+ })
+ | Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Float,
+ width,
+ },
+ ..
+ } => width,
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ };
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ // the last argument can always be a scalar
+ match *arg2_ty {
+ Ti::Scalar(Sc {
+ kind: Sk::Float,
+ width,
+ }) if width == arg_width => {}
+ _ if arg2_ty == arg_ty => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ));
+ }
+ }
+ }
+ Mf::Inverse | Mf::Determinant => {
+ if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ let good = match *arg_ty {
+ Ti::Matrix { columns, rows, .. } => columns == rows,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ }
+ Mf::Transpose => {
+ if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Matrix { .. } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::CountTrailingZeros
+ | Mf::CountLeadingZeros
+ | Mf::CountOneBits
+ | Mf::ReverseBits
+ | Mf::FindLsb
+ | Mf::FindMsb => {
+ if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Scalar(Sc {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ })
+ | Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ },
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::InsertBits => {
+ let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar(Sc {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ })
+ | Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ },
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ match *arg2_ty {
+ Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ))
+ }
+ }
+ match *arg3_ty {
+ Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg3.unwrap(),
+ ))
+ }
+ }
+ }
+ Mf::ExtractBits => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar(Sc {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ })
+ | Ti::Vector {
+ scalar:
+ Sc {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ },
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ match *arg1_ty {
+ Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg1.unwrap(),
+ ))
+ }
+ }
+ match *arg2_ty {
+ Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ))
+ }
+ }
+ }
+ Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => {
+ if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Vector {
+ size: crate::VectorSize::Bi,
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::Pack4x8snorm | Mf::Pack4x8unorm => {
+ if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Vector {
+ size: crate::VectorSize::Quad,
+ scalar:
+ Sc {
+ kind: Sk::Float, ..
+ },
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::Unpack2x16float
+ | Mf::Unpack2x16snorm
+ | Mf::Unpack2x16unorm
+ | Mf::Unpack4x8snorm
+ | Mf::Unpack4x8unorm => {
+ if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ }
+ ShaderStages::all()
+ }
+ E::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let mut base_scalar = match resolver[expr] {
+ crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => {
+ scalar
+ }
+ crate::TypeInner::Matrix { scalar, .. } => scalar,
+ _ => return Err(ExpressionError::InvalidCastArgument),
+ };
+ base_scalar.kind = kind;
+ if let Some(width) = convert {
+ base_scalar.width = width;
+ }
+ if self.check_width(base_scalar).is_err() {
+ return Err(ExpressionError::InvalidCastArgument);
+ }
+ ShaderStages::all()
+ }
+ E::CallResult(function) => mod_info.functions[function.index()].available_stages,
+ E::AtomicResult { ty, comparison } => {
+ let scalar_predicate = |ty: &crate::TypeInner| match ty {
+ &crate::TypeInner::Scalar(
+ scalar @ Sc {
+ kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
+ ..
+ },
+ ) => self.check_width(scalar).is_ok(),
+ _ => false,
+ };
+ let good = match &module.types[ty].inner {
+ ty if !comparison => scalar_predicate(ty),
+ &crate::TypeInner::Struct { ref members, .. } if comparison => {
+ validate_atomic_compare_exchange_struct(
+ &module.types,
+ members,
+ scalar_predicate,
+ )
+ }
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidAtomicResultType(ty));
+ }
+ ShaderStages::all()
+ }
+ E::WorkGroupUniformLoadResult { ty } => {
+ if self.types[ty.index()]
+ .flags
+ // Sized | Constructible is exactly the types currently supported by
+ // WorkGroupUniformLoad
+ .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE)
+ {
+ ShaderStages::COMPUTE
+ } else {
+ return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty));
+ }
+ }
+ E::ArrayLength(expr) => match resolver[expr] {
+ Ti::Pointer { base, .. } => {
+ let base_ty = &resolver.types[base];
+ if let Ti::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = base_ty.inner
+ {
+ ShaderStages::all()
+ } else {
+ return Err(ExpressionError::InvalidArrayType(expr));
+ }
+ }
+ ref other => {
+ log::error!("Array length of {:?}", other);
+ return Err(ExpressionError::InvalidArrayType(expr));
+ }
+ },
+ E::RayQueryProceedResult => ShaderStages::all(),
+ E::RayQueryGetIntersection {
+ query,
+ committed: _,
+ } => match resolver[query] {
+ Ti::Pointer {
+ base,
+ space: crate::AddressSpace::Function,
+ } => match resolver.types[base].inner {
+ Ti::RayQuery => ShaderStages::all(),
+ ref other => {
+ log::error!("Intersection result of a pointer to {:?}", other);
+ return Err(ExpressionError::InvalidRayQueryType(query));
+ }
+ },
+ ref other => {
+ log::error!("Intersection result of {:?}", other);
+ return Err(ExpressionError::InvalidRayQueryType(query));
+ }
+ },
+ };
+ Ok(stages)
+ }
+
+ fn global_var_ty(
+ module: &crate::Module,
+ function: &crate::Function,
+ expr: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Type>, ExpressionError> {
+ use crate::Expression as Ex;
+
+ match function.expressions[expr] {
+ Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
+ Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
+ Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
+ match function.expressions[base] {
+ Ex::GlobalVariable(var_handle) => {
+ let array_ty = module.global_variables[var_handle].ty;
+
+ match module.types[array_ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => Ok(base),
+ _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
+ }
+ }
+ _ => Err(ExpressionError::ExpectedGlobalVariable),
+ }
+ }
+ _ => Err(ExpressionError::ExpectedGlobalVariable),
+ }
+ }
+
+ pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
+ self.check_width(literal.scalar())?;
+ check_literal_value(literal)?;
+
+ Ok(())
+ }
+}
+
+pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
+ let is_nan = match literal {
+ crate::Literal::F64(v) => v.is_nan(),
+ crate::Literal::F32(v) => v.is_nan(),
+ _ => false,
+ };
+ if is_nan {
+ return Err(LiteralError::NaN);
+ }
+
+ let is_infinite = match literal {
+ crate::Literal::F64(v) => v.is_infinite(),
+ crate::Literal::F32(v) => v.is_infinite(),
+ _ => false,
+ };
+ if is_infinite {
+ return Err(LiteralError::Infinity);
+ }
+
+ Ok(())
+}
+
+#[cfg(all(test, feature = "validate"))]
+/// Validate a module containing the given expression, expecting an error.
+fn validate_with_expression(
+ expr: crate::Expression,
+ caps: super::Capabilities,
+) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
+ use crate::span::Span;
+
+ let mut function = crate::Function::default();
+ function.expressions.append(expr, Span::default());
+ function.body.push(
+ crate::Statement::Emit(function.expressions.range_from(0)),
+ Span::default(),
+ );
+
+ let mut module = crate::Module::default();
+ module.functions.append(function, Span::default());
+
+ let mut validator = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps);
+
+ validator.validate(&module)
+}
+
+#[cfg(all(test, feature = "validate"))]
+/// Validate a module containing the given constant expression, expecting an error.
+fn validate_with_const_expression(
+ expr: crate::Expression,
+ caps: super::Capabilities,
+) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
+ use crate::span::Span;
+
+ let mut module = crate::Module::default();
+ module.const_expressions.append(expr, Span::default());
+
+ let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);
+
+ validator.validate(&module)
+}
+
+/// Using F64 in a function's expression arena is forbidden.
+#[cfg(feature = "validate")]
+#[test]
+fn f64_runtime_literals() {
+ let result = validate_with_expression(
+ crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
+ super::Capabilities::default(),
+ );
+ let error = result.unwrap_err().into_inner();
+ assert!(matches!(
+ error,
+ crate::valid::ValidationError::Function {
+ source: super::FunctionError::Expression {
+ source: super::ExpressionError::Literal(super::LiteralError::Width(
+ super::r#type::WidthError::MissingCapability {
+ name: "f64",
+ flag: "FLOAT64",
+ }
+ ),),
+ ..
+ },
+ ..
+ }
+ ));
+
+ let result = validate_with_expression(
+ crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
+ super::Capabilities::default() | super::Capabilities::FLOAT64,
+ );
+ assert!(result.is_ok());
+}
+
+/// Using F64 in a module's constant expression arena is forbidden.
+#[cfg(feature = "validate")]
+#[test]
+fn f64_const_literals() {
+ let result = validate_with_const_expression(
+ crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
+ super::Capabilities::default(),
+ );
+ let error = result.unwrap_err().into_inner();
+ assert!(matches!(
+ error,
+ crate::valid::ValidationError::ConstExpression {
+ source: super::ConstExpressionError::Literal(super::LiteralError::Width(
+ super::r#type::WidthError::MissingCapability {
+ name: "f64",
+ flag: "FLOAT64",
+ }
+ )),
+ ..
+ }
+ ));
+
+ let result = validate_with_const_expression(
+ crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
+ super::Capabilities::default() | super::Capabilities::FLOAT64,
+ );
+ assert!(result.is_ok());
+}
+
+/// Using I64 in a function's expression arena is forbidden.
+#[cfg(feature = "validate")]
+#[test]
+fn i64_runtime_literals() {
+ let result = validate_with_expression(
+ crate::Expression::Literal(crate::Literal::I64(1729)),
+ // There is no capability that enables this.
+ super::Capabilities::all(),
+ );
+ let error = result.unwrap_err().into_inner();
+ assert!(matches!(
+ error,
+ crate::valid::ValidationError::Function {
+ source: super::FunctionError::Expression {
+ source: super::ExpressionError::Literal(super::LiteralError::Width(
+ super::r#type::WidthError::Unsupported64Bit
+ ),),
+ ..
+ },
+ ..
+ }
+ ));
+}
+
+/// Using I64 in a module's constant expression arena is forbidden.
+#[cfg(feature = "validate")]
+#[test]
+fn i64_const_literals() {
+ let result = validate_with_const_expression(
+ crate::Expression::Literal(crate::Literal::I64(1729)),
+ // There is no capability that enables this.
+ super::Capabilities::all(),
+ );
+ let error = result.unwrap_err().into_inner();
+ assert!(matches!(
+ error,
+ crate::valid::ValidationError::ConstExpression {
+ source: super::ConstExpressionError::Literal(super::LiteralError::Width(
+ super::r#type::WidthError::Unsupported64Bit,
+ ),),
+ ..
+ }
+ ));
+}
diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs
new file mode 100644
index 0000000000..f0ca22cbda
--- /dev/null
+++ b/third_party/rust/naga/src/valid/function.rs
@@ -0,0 +1,1056 @@
+use crate::arena::Handle;
+use crate::arena::{Arena, UniqueArena};
+
+use super::validate_atomic_compare_exchange_struct;
+
+use super::{
+ analyzer::{UniformityDisruptor, UniformityRequirements},
+ ExpressionError, FunctionInfo, ModuleInfo,
+};
+use crate::span::WithSpan;
+use crate::span::{AddSpan as _, MapErrWithSpan as _};
+
+use bit_set::BitSet;
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum CallError {
+ #[error("Argument {index} expression is invalid")]
+ Argument {
+ index: usize,
+ source: ExpressionError,
+ },
+ #[error("Result expression {0:?} has already been introduced earlier")]
+ ResultAlreadyInScope(Handle<crate::Expression>),
+ #[error("Result value is invalid")]
+ ResultValue(#[source] ExpressionError),
+ #[error("Requires {required} arguments, but {seen} are provided")]
+ ArgumentCount { required: usize, seen: usize },
+ #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
+ ArgumentType {
+ index: usize,
+ required: Handle<crate::Type>,
+ seen_expression: Handle<crate::Expression>,
+ },
+ #[error("The emitted expression doesn't match the call")]
+ ExpressionMismatch(Option<Handle<crate::Expression>>),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum AtomicError {
+ #[error("Pointer {0:?} to atomic is invalid.")]
+ InvalidPointer(Handle<crate::Expression>),
+ #[error("Operand {0:?} has invalid type.")]
+ InvalidOperand(Handle<crate::Expression>),
+ #[error("Result type for {0:?} doesn't match the statement")]
+ ResultTypeMismatch(Handle<crate::Expression>),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum LocalVariableError {
+ #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
+ InvalidType(Handle<crate::Type>),
+ #[error("Initializer doesn't match the variable type")]
+ InitializerType,
+ #[error("Initializer is not const")]
+ NonConstInitializer,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum FunctionError {
+ #[error("Expression {handle:?} is invalid")]
+ Expression {
+ handle: Handle<crate::Expression>,
+ source: ExpressionError,
+ },
+ #[error("Expression {0:?} can't be introduced - it's already in scope")]
+ ExpressionAlreadyInScope(Handle<crate::Expression>),
+ #[error("Local variable {handle:?} '{name}' is invalid")]
+ LocalVariable {
+ handle: Handle<crate::LocalVariable>,
+ name: String,
+ source: LocalVariableError,
+ },
+ #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
+ InvalidArgumentType { index: usize, name: String },
+ #[error("The function's given return type cannot be returned from functions")]
+ NonConstructibleReturnType,
+ #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
+ InvalidArgumentPointerSpace {
+ index: usize,
+ name: String,
+ space: crate::AddressSpace,
+ },
+ #[error("There are instructions after `return`/`break`/`continue`")]
+ InstructionsAfterReturn,
+ #[error("The `break` is used outside of a `loop` or `switch` context")]
+ BreakOutsideOfLoopOrSwitch,
+ #[error("The `continue` is used outside of a `loop` context")]
+ ContinueOutsideOfLoop,
+ #[error("The `return` is called within a `continuing` block")]
+ InvalidReturnSpot,
+ #[error("The `return` value {0:?} does not match the function return value")]
+ InvalidReturnType(Option<Handle<crate::Expression>>),
+ #[error("The `if` condition {0:?} is not a boolean scalar")]
+ InvalidIfType(Handle<crate::Expression>),
+ #[error("The `switch` value {0:?} is not an integer scalar")]
+ InvalidSwitchType(Handle<crate::Expression>),
+ #[error("Multiple `switch` cases for {0:?} are present")]
+ ConflictingSwitchCase(crate::SwitchValue),
+ #[error("The `switch` contains cases with conflicting types")]
+ ConflictingCaseType,
+ #[error("The `switch` is missing a `default` case")]
+ MissingDefaultCase,
+ #[error("Multiple `default` cases are present")]
+ MultipleDefaultCases,
+ #[error("The last `switch` case contains a `fallthrough`")]
+ LastCaseFallTrough,
+ #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
+ InvalidStorePointer(Handle<crate::Expression>),
+ #[error("The value {0:?} can not be stored")]
+ InvalidStoreValue(Handle<crate::Expression>),
+ #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")]
+ InvalidStoreTypes {
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ },
+ #[error("Image store parameters are invalid")]
+ InvalidImageStore(#[source] ExpressionError),
+ #[error("Call to {function:?} is invalid")]
+ InvalidCall {
+ function: Handle<crate::Function>,
+ #[source]
+ error: CallError,
+ },
+ #[error("Atomic operation is invalid")]
+ InvalidAtomic(#[from] AtomicError),
+ #[error("Ray Query {0:?} is not a local variable")]
+ InvalidRayQueryExpression(Handle<crate::Expression>),
+ #[error("Acceleration structure {0:?} is not a matching expression")]
+ InvalidAccelerationStructure(Handle<crate::Expression>),
+ #[error("Ray descriptor {0:?} is not a matching expression")]
+ InvalidRayDescriptor(Handle<crate::Expression>),
+ #[error("Ray Query {0:?} does not have a matching type")]
+ InvalidRayQueryType(Handle<crate::Type>),
+ #[error(
+ "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
+ )]
+ NonUniformControlFlow(
+ UniformityRequirements,
+ Handle<crate::Expression>,
+ UniformityDisruptor,
+ ),
+ #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
+ PipelineInputRegularFunction { name: String },
+ #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
+ PipelineOutputRegularFunction,
+ #[error("Required uniformity for WorkGroupUniformLoad is not fulfilled because of {0:?}")]
+ // The actual load statement will be "pointed to" by the span
+ NonUniformWorkgroupUniformLoad(UniformityDisruptor),
+ // This is only possible with a misbehaving frontend
+ #[error("The expression {0:?} for a WorkGroupUniformLoad isn't a WorkgroupUniformLoadResult")]
+ WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
+ #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
+ WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
+}
+
+bitflags::bitflags! {
+ #[repr(transparent)]
+ #[derive(Clone, Copy)]
+ struct ControlFlowAbility: u8 {
+ /// The control can return out of this block.
+ const RETURN = 0x1;
+ /// The control can break.
+ const BREAK = 0x2;
+ /// The control can continue.
+ const CONTINUE = 0x4;
+ }
+}
+
+struct BlockInfo {
+ stages: super::ShaderStages,
+ finished: bool,
+}
+
+struct BlockContext<'a> {
+ abilities: ControlFlowAbility,
+ info: &'a FunctionInfo,
+ expressions: &'a Arena<crate::Expression>,
+ types: &'a UniqueArena<crate::Type>,
+ local_vars: &'a Arena<crate::LocalVariable>,
+ global_vars: &'a Arena<crate::GlobalVariable>,
+ functions: &'a Arena<crate::Function>,
+ special_types: &'a crate::SpecialTypes,
+ prev_infos: &'a [FunctionInfo],
+ return_type: Option<Handle<crate::Type>>,
+}
+
+impl<'a> BlockContext<'a> {
+ fn new(
+ fun: &'a crate::Function,
+ module: &'a crate::Module,
+ info: &'a FunctionInfo,
+ prev_infos: &'a [FunctionInfo],
+ ) -> Self {
+ Self {
+ abilities: ControlFlowAbility::RETURN,
+ info,
+ expressions: &fun.expressions,
+ types: &module.types,
+ local_vars: &fun.local_variables,
+ global_vars: &module.global_variables,
+ functions: &module.functions,
+ special_types: &module.special_types,
+ prev_infos,
+ return_type: fun.result.as_ref().map(|fr| fr.ty),
+ }
+ }
+
+ const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
+ BlockContext { abilities, ..*self }
+ }
+
+ fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
+ &self.expressions[handle]
+ }
+
+ fn resolve_type_impl(
+ &self,
+ handle: Handle<crate::Expression>,
+ valid_expressions: &BitSet,
+ ) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
+ if handle.index() >= self.expressions.len() {
+ Err(ExpressionError::DoesntExist.with_span())
+ } else if !valid_expressions.contains(handle.index()) {
+ Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
+ } else {
+ Ok(self.info[handle].ty.inner_with(self.types))
+ }
+ }
+
+ fn resolve_type(
+ &self,
+ handle: Handle<crate::Expression>,
+ valid_expressions: &BitSet,
+ ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
+ self.resolve_type_impl(handle, valid_expressions)
+ .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
+ }
+
+ fn resolve_pointer_type(
+ &self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<&crate::TypeInner, FunctionError> {
+ if handle.index() >= self.expressions.len() {
+ Err(FunctionError::Expression {
+ handle,
+ source: ExpressionError::DoesntExist,
+ })
+ } else {
+ Ok(self.info[handle].ty.inner_with(self.types))
+ }
+ }
+}
+
+impl super::Validator {
+ fn validate_call(
+ &mut self,
+ function: Handle<crate::Function>,
+ arguments: &[Handle<crate::Expression>],
+ result: Option<Handle<crate::Expression>>,
+ context: &BlockContext,
+ ) -> Result<super::ShaderStages, WithSpan<CallError>> {
+ let fun = &context.functions[function];
+ if fun.arguments.len() != arguments.len() {
+ return Err(CallError::ArgumentCount {
+ required: fun.arguments.len(),
+ seen: arguments.len(),
+ }
+ .with_span());
+ }
+ for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
+ let ty = context
+ .resolve_type_impl(expr, &self.valid_expression_set)
+ .map_err_inner(|source| {
+ CallError::Argument { index, source }
+ .with_span_handle(expr, context.expressions)
+ })?;
+ let arg_inner = &context.types[arg.ty].inner;
+ if !ty.equivalent(arg_inner, context.types) {
+ return Err(CallError::ArgumentType {
+ index,
+ required: arg.ty,
+ seen_expression: expr,
+ }
+ .with_span_handle(expr, context.expressions));
+ }
+ }
+
+ if let Some(expr) = result {
+ if self.valid_expression_set.insert(expr.index()) {
+ self.valid_expression_list.push(expr);
+ } else {
+ return Err(CallError::ResultAlreadyInScope(expr)
+ .with_span_handle(expr, context.expressions));
+ }
+ match context.expressions[expr] {
+ crate::Expression::CallResult(callee)
+ if fun.result.is_some() && callee == function => {}
+ _ => {
+ return Err(CallError::ExpressionMismatch(result)
+ .with_span_handle(expr, context.expressions))
+ }
+ }
+ } else if fun.result.is_some() {
+ return Err(CallError::ExpressionMismatch(result).with_span());
+ }
+
+ let callee_info = &context.prev_infos[function.index()];
+ Ok(callee_info.available_stages)
+ }
+
+ fn emit_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ if self.valid_expression_set.insert(handle.index()) {
+ self.valid_expression_list.push(handle);
+ Ok(())
+ } else {
+ Err(FunctionError::ExpressionAlreadyInScope(handle)
+ .with_span_handle(handle, context.expressions))
+ }
+ }
+
+ fn validate_atomic(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ fun: &crate::AtomicFunction,
+ value: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
+ let ptr_scalar = match *pointer_inner {
+ crate::TypeInner::Pointer { base, .. } => match context.types[base].inner {
+ crate::TypeInner::Atomic(scalar) => scalar,
+ ref other => {
+ log::error!("Atomic pointer to type {:?}", other);
+ return Err(AtomicError::InvalidPointer(pointer)
+ .with_span_handle(pointer, context.expressions)
+ .into_other());
+ }
+ },
+ ref other => {
+ log::error!("Atomic on type {:?}", other);
+ return Err(AtomicError::InvalidPointer(pointer)
+ .with_span_handle(pointer, context.expressions)
+ .into_other());
+ }
+ };
+
+ let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
+ match *value_inner {
+ crate::TypeInner::Scalar(scalar) if scalar == ptr_scalar => {}
+ ref other => {
+ log::error!("Atomic operand type {:?}", other);
+ return Err(AtomicError::InvalidOperand(value)
+ .with_span_handle(value, context.expressions)
+ .into_other());
+ }
+ }
+
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ if context.resolve_type(cmp, &self.valid_expression_set)? != value_inner {
+ log::error!("Atomic exchange comparison has a different type from the value");
+ return Err(AtomicError::InvalidOperand(cmp)
+ .with_span_handle(cmp, context.expressions)
+ .into_other());
+ }
+ }
+
+ self.emit_expression(result, context)?;
+ match context.expressions[result] {
+ crate::Expression::AtomicResult { ty, comparison }
+ if {
+ let scalar_predicate =
+ |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(ptr_scalar);
+ match &context.types[ty].inner {
+ ty if !comparison => scalar_predicate(ty),
+ &crate::TypeInner::Struct { ref members, .. } if comparison => {
+ validate_atomic_compare_exchange_struct(
+ context.types,
+ members,
+ scalar_predicate,
+ )
+ }
+ _ => false,
+ }
+ } => {}
+ _ => {
+ return Err(AtomicError::ResultTypeMismatch(result)
+ .with_span_handle(result, context.expressions)
+ .into_other())
+ }
+ }
+ Ok(())
+ }
+
+ fn validate_block_impl(
+ &mut self,
+ statements: &crate::Block,
+ context: &BlockContext,
+ ) -> Result<BlockInfo, WithSpan<FunctionError>> {
+ use crate::{AddressSpace, Statement as S, TypeInner as Ti};
+ let mut finished = false;
+ let mut stages = super::ShaderStages::all();
+ for (statement, &span) in statements.span_iter() {
+ if finished {
+ return Err(FunctionError::InstructionsAfterReturn
+ .with_span_static(span, "instructions after return"));
+ }
+ match *statement {
+ S::Emit(ref range) => {
+ for handle in range.clone() {
+ self.emit_expression(handle, context)?;
+ }
+ }
+ S::Block(ref block) => {
+ let info = self.validate_block(block, context)?;
+ stages &= info.stages;
+ finished = info.finished;
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ match *context.resolve_type(condition, &self.valid_expression_set)? {
+ Ti::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: _,
+ }) => {}
+ _ => {
+ return Err(FunctionError::InvalidIfType(condition)
+ .with_span_handle(condition, context.expressions))
+ }
+ }
+ stages &= self.validate_block(accept, context)?.stages;
+ stages &= self.validate_block(reject, context)?.stages;
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ let uint = match context
+ .resolve_type(selector, &self.valid_expression_set)?
+ .scalar_kind()
+ {
+ Some(crate::ScalarKind::Uint) => true,
+ Some(crate::ScalarKind::Sint) => false,
+ _ => {
+ return Err(FunctionError::InvalidSwitchType(selector)
+ .with_span_handle(selector, context.expressions))
+ }
+ };
+ self.switch_values.clear();
+ for case in cases {
+ match case.value {
+ crate::SwitchValue::I32(_) if !uint => {}
+ crate::SwitchValue::U32(_) if uint => {}
+ crate::SwitchValue::Default => {}
+ _ => {
+ return Err(FunctionError::ConflictingCaseType.with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "conflicting switch arm here",
+ ));
+ }
+ };
+ if !self.switch_values.insert(case.value) {
+ return Err(match case.value {
+ crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
+ .with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "duplicated switch arm here",
+ ),
+ _ => FunctionError::ConflictingSwitchCase(case.value)
+ .with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "conflicting switch arm here",
+ ),
+ });
+ }
+ }
+ if !self.switch_values.contains(&crate::SwitchValue::Default) {
+ return Err(FunctionError::MissingDefaultCase
+ .with_span_static(span, "missing default case"));
+ }
+ if let Some(case) = cases.last() {
+ if case.fall_through {
+ return Err(FunctionError::LastCaseFallTrough.with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "bad switch arm here",
+ ));
+ }
+ }
+ let pass_through_abilities = context.abilities
+ & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
+ let sub_context =
+ context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
+ for case in cases {
+ stages &= self.validate_block(&case.body, &sub_context)?.stages;
+ }
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ // special handling for block scoping is needed here,
+ // because the continuing{} block inherits the scope
+ let base_expression_count = self.valid_expression_list.len();
+ let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
+ stages &= self
+ .validate_block_impl(
+ body,
+ &context.with_abilities(
+ pass_through_abilities
+ | ControlFlowAbility::BREAK
+ | ControlFlowAbility::CONTINUE,
+ ),
+ )?
+ .stages;
+ stages &= self
+ .validate_block_impl(
+ continuing,
+ &context.with_abilities(ControlFlowAbility::empty()),
+ )?
+ .stages;
+
+ if let Some(condition) = break_if {
+ match *context.resolve_type(condition, &self.valid_expression_set)? {
+ Ti::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: _,
+ }) => {}
+ _ => {
+ return Err(FunctionError::InvalidIfType(condition)
+ .with_span_handle(condition, context.expressions))
+ }
+ }
+ }
+
+ for handle in self.valid_expression_list.drain(base_expression_count..) {
+ self.valid_expression_set.remove(handle.index());
+ }
+ }
+ S::Break => {
+ if !context.abilities.contains(ControlFlowAbility::BREAK) {
+ return Err(FunctionError::BreakOutsideOfLoopOrSwitch
+ .with_span_static(span, "invalid break"));
+ }
+ finished = true;
+ }
+ S::Continue => {
+ if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
+ return Err(FunctionError::ContinueOutsideOfLoop
+ .with_span_static(span, "invalid continue"));
+ }
+ finished = true;
+ }
+ S::Return { value } => {
+ if !context.abilities.contains(ControlFlowAbility::RETURN) {
+ return Err(FunctionError::InvalidReturnSpot
+ .with_span_static(span, "invalid return"));
+ }
+ let value_ty = value
+ .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
+ .transpose()?;
+ let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
+ // We can't return pointers, but it seems best not to embed that
+ // assumption here, so use `TypeInner::equivalent` for comparison.
+ let okay = match (value_ty, expected_ty) {
+ (None, None) => true,
+ (Some(value_inner), Some(expected_inner)) => {
+ value_inner.equivalent(expected_inner, context.types)
+ }
+ (_, _) => false,
+ };
+
+ if !okay {
+ log::error!(
+ "Returning {:?} where {:?} is expected",
+ value_ty,
+ expected_ty
+ );
+ if let Some(handle) = value {
+ return Err(FunctionError::InvalidReturnType(value)
+ .with_span_handle(handle, context.expressions));
+ } else {
+ return Err(FunctionError::InvalidReturnType(value)
+ .with_span_static(span, "invalid return"));
+ }
+ }
+ finished = true;
+ }
+ S::Kill => {
+ stages &= super::ShaderStages::FRAGMENT;
+ finished = true;
+ }
+ S::Barrier(_) => {
+ stages &= super::ShaderStages::COMPUTE;
+ }
+ S::Store { pointer, value } => {
+ let mut current = pointer;
+ loop {
+ let _ = context
+ .resolve_pointer_type(current)
+ .map_err(|e| e.with_span())?;
+ match context.expressions[current] {
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => current = base,
+ crate::Expression::LocalVariable(_)
+ | crate::Expression::GlobalVariable(_)
+ | crate::Expression::FunctionArgument(_) => break,
+ _ => {
+ return Err(FunctionError::InvalidStorePointer(current)
+ .with_span_handle(pointer, context.expressions))
+ }
+ }
+ }
+
+ let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
+ match *value_ty {
+ Ti::Image { .. } | Ti::Sampler { .. } => {
+ return Err(FunctionError::InvalidStoreValue(value)
+ .with_span_handle(value, context.expressions));
+ }
+ _ => {}
+ }
+
+ let pointer_ty = context
+ .resolve_pointer_type(pointer)
+ .map_err(|e| e.with_span())?;
+
+ let good = match *pointer_ty {
+ Ti::Pointer { base, space: _ } => match context.types[base].inner {
+ Ti::Atomic(scalar) => *value_ty == Ti::Scalar(scalar),
+ ref other => value_ty == other,
+ },
+ Ti::ValuePointer {
+ size: Some(size),
+ scalar,
+ space: _,
+ } => *value_ty == Ti::Vector { size, scalar },
+ Ti::ValuePointer {
+ size: None,
+ scalar,
+ space: _,
+ } => *value_ty == Ti::Scalar(scalar),
+ _ => false,
+ };
+ if !good {
+ return Err(FunctionError::InvalidStoreTypes { pointer, value }
+ .with_span()
+ .with_handle(pointer, context.expressions)
+ .with_handle(value, context.expressions));
+ }
+
+ if let Some(space) = pointer_ty.pointer_space() {
+ if !space.access().contains(crate::StorageAccess::STORE) {
+ return Err(FunctionError::InvalidStorePointer(pointer)
+ .with_span_static(
+ context.expressions.get_span(pointer),
+ "writing to this location is not permitted",
+ ));
+ }
+ }
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ //Note: this code uses a lot of `FunctionError::InvalidImageStore`,
+ // and could probably be refactored.
+ let var = match *context.get_expression(image) {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &context.global_vars[var_handle]
+ }
+ // We're looking at a binding index situation, so punch through the index and look at the global behind it.
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => {
+ match *context.get_expression(base) {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &context.global_vars[var_handle]
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedGlobalVariable,
+ )
+ .with_span_handle(image, context.expressions))
+ }
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedGlobalVariable,
+ )
+ .with_span_handle(image, context.expressions))
+ }
+ };
+
+ // Punch through a binding array to get the underlying type
+ let global_ty = match context.types[var.ty].inner {
+ Ti::BindingArray { base, .. } => &context.types[base].inner,
+ ref inner => inner,
+ };
+
+ let value_ty = match *global_ty {
+ Ti::Image {
+ class,
+ arrayed,
+ dim,
+ } => {
+ match context
+ .resolve_type(coordinate, &self.valid_expression_set)?
+ .image_storage_coordinates()
+ {
+ Some(coord_dim) if coord_dim == dim => {}
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageCoordinateType(
+ dim, coordinate,
+ ),
+ )
+ .with_span_handle(coordinate, context.expressions));
+ }
+ };
+ if arrayed != array_index.is_some() {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageArrayIndex,
+ )
+ .with_span_handle(coordinate, context.expressions));
+ }
+ if let Some(expr) = array_index {
+ match *context.resolve_type(expr, &self.valid_expression_set)? {
+ Ti::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ width: _,
+ }) => {}
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageArrayIndexType(expr),
+ )
+ .with_span_handle(expr, context.expressions));
+ }
+ }
+ }
+ match class {
+ crate::ImageClass::Storage { format, .. } => {
+ crate::TypeInner::Vector {
+ size: crate::VectorSize::Quad,
+ scalar: crate::Scalar {
+ kind: format.into(),
+ width: 4,
+ },
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageClass(class),
+ )
+ .with_span_handle(image, context.expressions));
+ }
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedImageType(var.ty),
+ )
+ .with_span()
+ .with_handle(var.ty, context.types)
+ .with_handle(image, context.expressions))
+ }
+ };
+
+ if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
+ return Err(FunctionError::InvalidStoreValue(value)
+ .with_span_handle(value, context.expressions));
+ }
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result,
+ } => match self.validate_call(function, arguments, result, context) {
+ Ok(callee_stages) => stages &= callee_stages,
+ Err(error) => {
+ return Err(error.and_then(|error| {
+ FunctionError::InvalidCall { function, error }
+ .with_span_static(span, "invalid function call")
+ }))
+ }
+ },
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ self.validate_atomic(pointer, fun, value, result, context)?;
+ }
+ S::WorkGroupUniformLoad { pointer, result } => {
+ stages &= super::ShaderStages::COMPUTE;
+ let pointer_inner =
+ context.resolve_type(pointer, &self.valid_expression_set)?;
+ match *pointer_inner {
+ Ti::Pointer {
+ space: AddressSpace::WorkGroup,
+ ..
+ } => {}
+ Ti::ValuePointer {
+ space: AddressSpace::WorkGroup,
+ ..
+ } => {}
+ _ => {
+ return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
+ .with_span_static(span, "WorkGroupUniformLoad"))
+ }
+ }
+ self.emit_expression(result, context)?;
+ let ty = match &context.expressions[result] {
+ &crate::Expression::WorkGroupUniformLoadResult { ty } => ty,
+ _ => {
+ return Err(FunctionError::WorkgroupUniformLoadExpressionMismatch(
+ result,
+ )
+ .with_span_static(span, "WorkGroupUniformLoad"));
+ }
+ };
+ let expected_pointer_inner = Ti::Pointer {
+ base: ty,
+ space: AddressSpace::WorkGroup,
+ };
+ if !expected_pointer_inner.equivalent(pointer_inner, context.types) {
+ return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer)
+ .with_span_static(span, "WorkGroupUniformLoad"));
+ }
+ }
+ S::RayQuery { query, ref fun } => {
+ let query_var = match *context.get_expression(query) {
+ crate::Expression::LocalVariable(var) => &context.local_vars[var],
+ ref other => {
+ log::error!("Unexpected ray query expression {other:?}");
+ return Err(FunctionError::InvalidRayQueryExpression(query)
+ .with_span_static(span, "invalid query expression"));
+ }
+ };
+ match context.types[query_var.ty].inner {
+ Ti::RayQuery => {}
+ ref other => {
+ log::error!("Unexpected ray query type {other:?}");
+ return Err(FunctionError::InvalidRayQueryType(query_var.ty)
+ .with_span_static(span, "invalid query type"));
+ }
+ }
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ match *context
+ .resolve_type(acceleration_structure, &self.valid_expression_set)?
+ {
+ Ti::AccelerationStructure => {}
+ _ => {
+ return Err(FunctionError::InvalidAccelerationStructure(
+ acceleration_structure,
+ )
+ .with_span_static(span, "invalid acceleration structure"))
+ }
+ }
+ let desc_ty_given =
+ context.resolve_type(descriptor, &self.valid_expression_set)?;
+ let desc_ty_expected = context
+ .special_types
+ .ray_desc
+ .map(|handle| &context.types[handle].inner);
+ if Some(desc_ty_given) != desc_ty_expected {
+ return Err(FunctionError::InvalidRayDescriptor(descriptor)
+ .with_span_static(span, "invalid ray descriptor"));
+ }
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ self.emit_expression(result, context)?;
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ }
+ }
+ }
+ Ok(BlockInfo { stages, finished })
+ }
+
+ fn validate_block(
+ &mut self,
+ statements: &crate::Block,
+ context: &BlockContext,
+ ) -> Result<BlockInfo, WithSpan<FunctionError>> {
+ let base_expression_count = self.valid_expression_list.len();
+ let info = self.validate_block_impl(statements, context)?;
+ for handle in self.valid_expression_list.drain(base_expression_count..) {
+ self.valid_expression_set.remove(handle.index());
+ }
+ Ok(info)
+ }
+
+ fn validate_local_var(
+ &self,
+ var: &crate::LocalVariable,
+ gctx: crate::proc::GlobalCtx,
+ fun_info: &FunctionInfo,
+ expression_constness: &crate::proc::ExpressionConstnessTracker,
+ ) -> Result<(), LocalVariableError> {
+ log::debug!("var {:?}", var);
+ let type_info = self
+ .types
+ .get(var.ty.index())
+ .ok_or(LocalVariableError::InvalidType(var.ty))?;
+ if !type_info.flags.contains(super::TypeFlags::CONSTRUCTIBLE) {
+ return Err(LocalVariableError::InvalidType(var.ty));
+ }
+
+ if let Some(init) = var.init {
+ let decl_ty = &gctx.types[var.ty].inner;
+ let init_ty = fun_info[init].ty.inner_with(gctx.types);
+ if !decl_ty.equivalent(init_ty, gctx.types) {
+ return Err(LocalVariableError::InitializerType);
+ }
+
+ if !expression_constness.is_const(init) {
+ return Err(LocalVariableError::NonConstInitializer);
+ }
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn validate_function(
+ &mut self,
+ fun: &crate::Function,
+ module: &crate::Module,
+ mod_info: &ModuleInfo,
+ entry_point: bool,
+ ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
+ let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
+
+ let expression_constness =
+ crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions);
+
+ for (var_handle, var) in fun.local_variables.iter() {
+ self.validate_local_var(var, module.to_ctx(), &info, &expression_constness)
+ .map_err(|source| {
+ FunctionError::LocalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(var.ty, &module.types)
+ .with_handle(var_handle, &fun.local_variables)
+ })?;
+ }
+
+ for (index, argument) in fun.arguments.iter().enumerate() {
+ match module.types[argument.ty].inner.pointer_space() {
+ Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {}
+ Some(other) => {
+ return Err(FunctionError::InvalidArgumentPointerSpace {
+ index,
+ name: argument.name.clone().unwrap_or_default(),
+ space: other,
+ }
+ .with_span_handle(argument.ty, &module.types))
+ }
+ }
+ // Check for the least informative error last.
+ if !self.types[argument.ty.index()]
+ .flags
+ .contains(super::TypeFlags::ARGUMENT)
+ {
+ return Err(FunctionError::InvalidArgumentType {
+ index,
+ name: argument.name.clone().unwrap_or_default(),
+ }
+ .with_span_handle(argument.ty, &module.types));
+ }
+
+ if !entry_point && argument.binding.is_some() {
+ return Err(FunctionError::PipelineInputRegularFunction {
+ name: argument.name.clone().unwrap_or_default(),
+ }
+ .with_span_handle(argument.ty, &module.types));
+ }
+ }
+
+ if let Some(ref result) = fun.result {
+ if !self.types[result.ty.index()]
+ .flags
+ .contains(super::TypeFlags::CONSTRUCTIBLE)
+ {
+ return Err(FunctionError::NonConstructibleReturnType
+ .with_span_handle(result.ty, &module.types));
+ }
+
+ if !entry_point && result.binding.is_some() {
+ return Err(FunctionError::PipelineOutputRegularFunction
+ .with_span_handle(result.ty, &module.types));
+ }
+ }
+
+ self.valid_expression_set.clear();
+ self.valid_expression_list.clear();
+ for (handle, expr) in fun.expressions.iter() {
+ if expr.needs_pre_emit() {
+ self.valid_expression_set.insert(handle.index());
+ }
+ if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
+ match self.validate_expression(handle, expr, fun, module, &info, mod_info) {
+ Ok(stages) => info.available_stages &= stages,
+ Err(source) => {
+ return Err(FunctionError::Expression { handle, source }
+ .with_span_handle(handle, &fun.expressions))
+ }
+ }
+ }
+ }
+
+ if self.flags.contains(super::ValidationFlags::BLOCKS) {
+ let stages = self
+ .validate_block(
+ &fun.body,
+ &BlockContext::new(fun, module, &info, &mod_info.functions),
+ )?
+ .stages;
+ info.available_stages &= stages;
+ }
+ Ok(info)
+ }
+}
diff --git a/third_party/rust/naga/src/valid/handles.rs b/third_party/rust/naga/src/valid/handles.rs
new file mode 100644
index 0000000000..e482f293bb
--- /dev/null
+++ b/third_party/rust/naga/src/valid/handles.rs
@@ -0,0 +1,699 @@
+//! Implementation of `Validator::validate_module_handles`.
+
+use crate::{
+ arena::{BadHandle, BadRangeError},
+ Handle,
+};
+
+use crate::{Arena, UniqueArena};
+
+use super::ValidationError;
+
+use std::{convert::TryInto, hash::Hash, num::NonZeroU32};
+
+impl super::Validator {
+ /// Validates that all handles within `module` are:
+ ///
+ /// * Valid, in the sense that they contain indices within each arena structure inside the
+ /// [`crate::Module`] type.
+ /// * No arena contents contain any items that have forward dependencies; that is, the value
+ /// associated with a handle only may contain references to handles in the same arena that
+ /// were constructed before it.
+ ///
+ /// By validating the above conditions, we free up subsequent logic to assume that handle
+ /// accesses are infallible.
+ ///
+ /// # Errors
+ ///
+ /// Errors returned by this method are intentionally sparse, for simplicity of implementation.
+ /// It is expected that only buggy frontends or fuzzers should ever emit IR that fails this
+ /// validation pass.
+ pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> {
+ let &crate::Module {
+ ref constants,
+ ref entry_points,
+ ref functions,
+ ref global_variables,
+ ref types,
+ ref special_types,
+ ref const_expressions,
+ } = module;
+
+ // NOTE: Types being first is important. All other forms of validation depend on this.
+ for (this_handle, ty) in types.iter() {
+ match ty.inner {
+ crate::TypeInner::Scalar { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. }
+ | crate::TypeInner::ValuePointer { .. }
+ | crate::TypeInner::Atomic { .. }
+ | crate::TypeInner::Image { .. }
+ | crate::TypeInner::Sampler { .. }
+ | crate::TypeInner::AccelerationStructure
+ | crate::TypeInner::RayQuery => (),
+ crate::TypeInner::Pointer { base, space: _ } => {
+ this_handle.check_dep(base)?;
+ }
+ crate::TypeInner::Array { base, .. }
+ | crate::TypeInner::BindingArray { base, .. } => {
+ this_handle.check_dep(base)?;
+ }
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => {
+ this_handle.check_dep_iter(members.iter().map(|m| m.ty))?;
+ }
+ }
+ }
+
+ for handle_and_expr in const_expressions.iter() {
+ Self::validate_const_expression_handles(handle_and_expr, constants, types)?;
+ }
+
+ let validate_type = |handle| Self::validate_type_handle(handle, types);
+ let validate_const_expr =
+ |handle| Self::validate_expression_handle(handle, const_expressions);
+
+ for (_handle, constant) in constants.iter() {
+ let &crate::Constant {
+ name: _,
+ r#override: _,
+ ty,
+ init,
+ } = constant;
+ validate_type(ty)?;
+ validate_const_expr(init)?;
+ }
+
+ for (_handle, global_variable) in global_variables.iter() {
+ let &crate::GlobalVariable {
+ name: _,
+ space: _,
+ binding: _,
+ ty,
+ init,
+ } = global_variable;
+ validate_type(ty)?;
+ if let Some(init_expr) = init {
+ validate_const_expr(init_expr)?;
+ }
+ }
+
+ let validate_function = |function_handle, function: &_| -> Result<_, InvalidHandleError> {
+ let &crate::Function {
+ name: _,
+ ref arguments,
+ ref result,
+ ref local_variables,
+ ref expressions,
+ ref named_expressions,
+ ref body,
+ } = function;
+
+ for arg in arguments.iter() {
+ let &crate::FunctionArgument {
+ name: _,
+ ty,
+ binding: _,
+ } = arg;
+ validate_type(ty)?;
+ }
+
+ if let &Some(crate::FunctionResult { ty, binding: _ }) = result {
+ validate_type(ty)?;
+ }
+
+ for (_handle, local_variable) in local_variables.iter() {
+ let &crate::LocalVariable { name: _, ty, init } = local_variable;
+ validate_type(ty)?;
+ if let Some(init) = init {
+ Self::validate_expression_handle(init, expressions)?;
+ }
+ }
+
+ for handle in named_expressions.keys().copied() {
+ Self::validate_expression_handle(handle, expressions)?;
+ }
+
+ for handle_and_expr in expressions.iter() {
+ Self::validate_expression_handles(
+ handle_and_expr,
+ constants,
+ const_expressions,
+ types,
+ local_variables,
+ global_variables,
+ functions,
+ function_handle,
+ )?;
+ }
+
+ Self::validate_block_handles(body, expressions, functions)?;
+
+ Ok(())
+ };
+
+ for entry_point in entry_points.iter() {
+ validate_function(None, &entry_point.function)?;
+ }
+
+ for (function_handle, function) in functions.iter() {
+ validate_function(Some(function_handle), function)?;
+ }
+
+ if let Some(ty) = special_types.ray_desc {
+ validate_type(ty)?;
+ }
+ if let Some(ty) = special_types.ray_intersection {
+ validate_type(ty)?;
+ }
+
+ Ok(())
+ }
+
+ fn validate_type_handle(
+ handle: Handle<crate::Type>,
+ types: &UniqueArena<crate::Type>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for_uniq(types).map(|_| ())
+ }
+
+ fn validate_constant_handle(
+ handle: Handle<crate::Constant>,
+ constants: &Arena<crate::Constant>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for(constants).map(|_| ())
+ }
+
+ fn validate_expression_handle(
+ handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for(expressions).map(|_| ())
+ }
+
+ fn validate_function_handle(
+ handle: Handle<crate::Function>,
+ functions: &Arena<crate::Function>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for(functions).map(|_| ())
+ }
+
+ fn validate_const_expression_handles(
+ (handle, expression): (Handle<crate::Expression>, &crate::Expression),
+ constants: &Arena<crate::Constant>,
+ types: &UniqueArena<crate::Type>,
+ ) -> Result<(), InvalidHandleError> {
+ let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
+ let validate_type = |handle| Self::validate_type_handle(handle, types);
+
+ match *expression {
+ crate::Expression::Literal(_) => {}
+ crate::Expression::Constant(constant) => {
+ validate_constant(constant)?;
+ handle.check_dep(constants[constant].init)?;
+ }
+ crate::Expression::ZeroValue(ty) => {
+ validate_type(ty)?;
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ validate_type(ty)?;
+ handle.check_dep_iter(components.iter().copied())?;
+ }
+ _ => {}
+ }
+ Ok(())
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn validate_expression_handles(
+ (handle, expression): (Handle<crate::Expression>, &crate::Expression),
+ constants: &Arena<crate::Constant>,
+ const_expressions: &Arena<crate::Expression>,
+ types: &UniqueArena<crate::Type>,
+ local_variables: &Arena<crate::LocalVariable>,
+ global_variables: &Arena<crate::GlobalVariable>,
+ functions: &Arena<crate::Function>,
+ // The handle of the current function or `None` if it's an entry point
+ current_function: Option<Handle<crate::Function>>,
+ ) -> Result<(), InvalidHandleError> {
+ let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
+ let validate_const_expr =
+ |handle| Self::validate_expression_handle(handle, const_expressions);
+ let validate_type = |handle| Self::validate_type_handle(handle, types);
+
+ match *expression {
+ crate::Expression::Access { base, index } => {
+ handle.check_dep(base)?.check_dep(index)?;
+ }
+ crate::Expression::AccessIndex { base, .. } => {
+ handle.check_dep(base)?;
+ }
+ crate::Expression::Splat { value, .. } => {
+ handle.check_dep(value)?;
+ }
+ crate::Expression::Swizzle { vector, .. } => {
+ handle.check_dep(vector)?;
+ }
+ crate::Expression::Literal(_) => {}
+ crate::Expression::Constant(constant) => {
+ validate_constant(constant)?;
+ }
+ crate::Expression::ZeroValue(ty) => {
+ validate_type(ty)?;
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ validate_type(ty)?;
+ handle.check_dep_iter(components.iter().copied())?;
+ }
+ crate::Expression::FunctionArgument(_arg_idx) => (),
+ crate::Expression::GlobalVariable(global_variable) => {
+ global_variable.check_valid_for(global_variables)?;
+ }
+ crate::Expression::LocalVariable(local_variable) => {
+ local_variable.check_valid_for(local_variables)?;
+ }
+ crate::Expression::Load { pointer } => {
+ handle.check_dep(pointer)?;
+ }
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather: _,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ if let Some(offset) = offset {
+ validate_const_expr(offset)?;
+ }
+
+ handle
+ .check_dep(image)?
+ .check_dep(sampler)?
+ .check_dep(coordinate)?
+ .check_dep_opt(array_index)?;
+
+ match level {
+ crate::SampleLevel::Auto | crate::SampleLevel::Zero => (),
+ crate::SampleLevel::Exact(expr) => {
+ handle.check_dep(expr)?;
+ }
+ crate::SampleLevel::Bias(expr) => {
+ handle.check_dep(expr)?;
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ handle.check_dep(x)?.check_dep(y)?;
+ }
+ };
+
+ handle.check_dep_opt(depth_ref)?;
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ handle
+ .check_dep(image)?
+ .check_dep(coordinate)?
+ .check_dep_opt(array_index)?
+ .check_dep_opt(sample)?
+ .check_dep_opt(level)?;
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ handle.check_dep(image)?;
+ match query {
+ crate::ImageQuery::Size { level } => {
+ handle.check_dep_opt(level)?;
+ }
+ crate::ImageQuery::NumLevels
+ | crate::ImageQuery::NumLayers
+ | crate::ImageQuery::NumSamples => (),
+ };
+ }
+ crate::Expression::Unary {
+ op: _,
+ expr: operand,
+ } => {
+ handle.check_dep(operand)?;
+ }
+ crate::Expression::Binary { op: _, left, right } => {
+ handle.check_dep(left)?.check_dep(right)?;
+ }
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ handle
+ .check_dep(condition)?
+ .check_dep(accept)?
+ .check_dep(reject)?;
+ }
+ crate::Expression::Derivative { expr: argument, .. } => {
+ handle.check_dep(argument)?;
+ }
+ crate::Expression::Relational { fun: _, argument } => {
+ handle.check_dep(argument)?;
+ }
+ crate::Expression::Math {
+ fun: _,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ handle
+ .check_dep(arg)?
+ .check_dep_opt(arg1)?
+ .check_dep_opt(arg2)?
+ .check_dep_opt(arg3)?;
+ }
+ crate::Expression::As {
+ expr: input,
+ kind: _,
+ convert: _,
+ } => {
+ handle.check_dep(input)?;
+ }
+ crate::Expression::CallResult(function) => {
+ Self::validate_function_handle(function, functions)?;
+ if let Some(handle) = current_function {
+ handle.check_dep(function)?;
+ }
+ }
+ crate::Expression::AtomicResult { .. }
+ | crate::Expression::RayQueryProceedResult
+ | crate::Expression::WorkGroupUniformLoadResult { .. } => (),
+ crate::Expression::ArrayLength(array) => {
+ handle.check_dep(array)?;
+ }
+ crate::Expression::RayQueryGetIntersection {
+ query,
+ committed: _,
+ } => {
+ handle.check_dep(query)?;
+ }
+ }
+ Ok(())
+ }
+
+ fn validate_block_handles(
+ block: &crate::Block,
+ expressions: &Arena<crate::Expression>,
+ functions: &Arena<crate::Function>,
+ ) -> Result<(), InvalidHandleError> {
+ let validate_block = |block| Self::validate_block_handles(block, expressions, functions);
+ let validate_expr = |handle| Self::validate_expression_handle(handle, expressions);
+ let validate_expr_opt = |handle_opt| {
+ if let Some(handle) = handle_opt {
+ validate_expr(handle)?;
+ }
+ Ok(())
+ };
+
+ block.iter().try_for_each(|stmt| match *stmt {
+ crate::Statement::Emit(ref expr_range) => {
+ expr_range.check_valid_for(expressions)?;
+ Ok(())
+ }
+ crate::Statement::Block(ref block) => {
+ validate_block(block)?;
+ Ok(())
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ validate_expr(condition)?;
+ validate_block(accept)?;
+ validate_block(reject)?;
+ Ok(())
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ validate_expr(selector)?;
+ for &crate::SwitchCase {
+ value: _,
+ ref body,
+ fall_through: _,
+ } in cases
+ {
+ validate_block(body)?;
+ }
+ Ok(())
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ validate_block(body)?;
+ validate_block(continuing)?;
+ validate_expr_opt(break_if)?;
+ Ok(())
+ }
+ crate::Statement::Return { value } => validate_expr_opt(value),
+ crate::Statement::Store { pointer, value } => {
+ validate_expr(pointer)?;
+ validate_expr(value)?;
+ Ok(())
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ validate_expr(image)?;
+ validate_expr(coordinate)?;
+ validate_expr_opt(array_index)?;
+ validate_expr(value)?;
+ Ok(())
+ }
+ crate::Statement::Atomic {
+ pointer,
+ fun,
+ value,
+ result,
+ } => {
+ validate_expr(pointer)?;
+ match fun {
+ crate::AtomicFunction::Add
+ | crate::AtomicFunction::Subtract
+ | crate::AtomicFunction::And
+ | crate::AtomicFunction::ExclusiveOr
+ | crate::AtomicFunction::InclusiveOr
+ | crate::AtomicFunction::Min
+ | crate::AtomicFunction::Max => (),
+ crate::AtomicFunction::Exchange { compare } => validate_expr_opt(compare)?,
+ };
+ validate_expr(value)?;
+ validate_expr(result)?;
+ Ok(())
+ }
+ crate::Statement::WorkGroupUniformLoad { pointer, result } => {
+ validate_expr(pointer)?;
+ validate_expr(result)?;
+ Ok(())
+ }
+ crate::Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ Self::validate_function_handle(function, functions)?;
+ for arg in arguments.iter().copied() {
+ validate_expr(arg)?;
+ }
+ validate_expr_opt(result)?;
+ Ok(())
+ }
+ crate::Statement::RayQuery { query, ref fun } => {
+ validate_expr(query)?;
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ validate_expr(acceleration_structure)?;
+ validate_expr(descriptor)?;
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ validate_expr(result)?;
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ Ok(())
+ }
+ crate::Statement::Break
+ | crate::Statement::Continue
+ | crate::Statement::Kill
+ | crate::Statement::Barrier(_) => Ok(()),
+ })
+ }
+}
+
+impl From<BadHandle> for ValidationError {
+ fn from(source: BadHandle) -> Self {
+ Self::InvalidHandle(source.into())
+ }
+}
+
+impl From<FwdDepError> for ValidationError {
+ fn from(source: FwdDepError) -> Self {
+ Self::InvalidHandle(source.into())
+ }
+}
+
+impl From<BadRangeError> for ValidationError {
+ fn from(source: BadRangeError) -> Self {
+ Self::InvalidHandle(source.into())
+ }
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum InvalidHandleError {
+ #[error(transparent)]
+ BadHandle(#[from] BadHandle),
+ #[error(transparent)]
+ ForwardDependency(#[from] FwdDepError),
+ #[error(transparent)]
+ BadRange(#[from] BadRangeError),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[error(
+ "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
+ which has not been processed yet"
+)]
+pub struct FwdDepError {
+ // This error is used for many `Handle` types, but there's no point in making this generic, so
+ // we just flatten them all to `Handle<()>` here.
+ subject: Handle<()>,
+ subject_kind: &'static str,
+ depends_on: Handle<()>,
+ depends_on_kind: &'static str,
+}
+
+impl<T> Handle<T> {
+ /// Check that `self` is valid within `arena` using [`Arena::check_contains_handle`].
+ pub(self) fn check_valid_for(self, arena: &Arena<T>) -> Result<(), InvalidHandleError> {
+ arena.check_contains_handle(self)?;
+ Ok(())
+ }
+
+ /// Check that `self` is valid within `arena` using [`UniqueArena::check_contains_handle`].
+ pub(self) fn check_valid_for_uniq(
+ self,
+ arena: &UniqueArena<T>,
+ ) -> Result<(), InvalidHandleError>
+ where
+ T: Eq + Hash,
+ {
+ arena.check_contains_handle(self)?;
+ Ok(())
+ }
+
+ /// Check that `depends_on` was constructed before `self` by comparing handle indices.
+ ///
+ /// If `self` is a valid handle (i.e., it has been validated using [`Self::check_valid_for`])
+ /// and this function returns [`Ok`], then it may be assumed that `depends_on` is also valid.
+ /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating
+ /// recursive definitions of arena-based values in linear time.
+ ///
+ /// # Errors
+ ///
+ /// If `depends_on`'s handle is from the same [`Arena`] as `self'`s, but not constructed earlier
+ /// than `self`'s, this function returns an error.
+ pub(self) fn check_dep(self, depends_on: Self) -> Result<Self, FwdDepError> {
+ if depends_on < self {
+ Ok(self)
+ } else {
+ let erase_handle_type = |handle: Handle<_>| {
+ Handle::new(NonZeroU32::new((handle.index() + 1).try_into().unwrap()).unwrap())
+ };
+ Err(FwdDepError {
+ subject: erase_handle_type(self),
+ subject_kind: std::any::type_name::<T>(),
+ depends_on: erase_handle_type(depends_on),
+ depends_on_kind: std::any::type_name::<T>(),
+ })
+ }
+ }
+
+ /// Like [`Self::check_dep`], except for [`Option`]al handle values.
+ pub(self) fn check_dep_opt(self, depends_on: Option<Self>) -> Result<Self, FwdDepError> {
+ self.check_dep_iter(depends_on.into_iter())
+ }
+
+ /// Like [`Self::check_dep`], except for [`Iterator`]s over handle values.
+ pub(self) fn check_dep_iter(
+ self,
+ depends_on: impl Iterator<Item = Self>,
+ ) -> Result<Self, FwdDepError> {
+ for handle in depends_on {
+ self.check_dep(handle)?;
+ }
+ Ok(self)
+ }
+}
+
+impl<T> crate::arena::Range<T> {
+ pub(self) fn check_valid_for(&self, arena: &Arena<T>) -> Result<(), BadRangeError> {
+ arena.check_contains_range(self)
+ }
+}
+
+#[test]
+fn constant_deps() {
+ use crate::{Constant, Expression, Literal, Span, Type, TypeInner};
+
+ let nowhere = Span::default();
+
+ let mut types = UniqueArena::new();
+ let mut const_exprs = Arena::new();
+ let mut fun_exprs = Arena::new();
+ let mut constants = Arena::new();
+
+ let i32_handle = types.insert(
+ Type {
+ name: None,
+ inner: TypeInner::Scalar(crate::Scalar::I32),
+ },
+ nowhere,
+ );
+
+ // Construct a self-referential constant by misusing a handle to
+ // fun_exprs as a constant initializer.
+ let fun_expr = fun_exprs.append(Expression::Literal(Literal::I32(42)), nowhere);
+ let self_referential_const = constants.append(
+ Constant {
+ name: None,
+ r#override: crate::Override::None,
+ ty: i32_handle,
+ init: fun_expr,
+ },
+ nowhere,
+ );
+ let _self_referential_expr =
+ const_exprs.append(Expression::Constant(self_referential_const), nowhere);
+
+ for handle_and_expr in const_exprs.iter() {
+ assert!(super::Validator::validate_const_expression_handles(
+ handle_and_expr,
+ &constants,
+ &types,
+ )
+ .is_err());
+ }
+}
diff --git a/third_party/rust/naga/src/valid/interface.rs b/third_party/rust/naga/src/valid/interface.rs
new file mode 100644
index 0000000000..84c8b09ddb
--- /dev/null
+++ b/third_party/rust/naga/src/valid/interface.rs
@@ -0,0 +1,709 @@
+use super::{
+ analyzer::{FunctionInfo, GlobalUse},
+ Capabilities, Disalignment, FunctionError, ModuleInfo,
+};
+use crate::arena::{Handle, UniqueArena};
+
+use crate::span::{AddSpan as _, MapErrWithSpan as _, SpanProvider as _, WithSpan};
+use bit_set::BitSet;
+
+const MAX_WORKGROUP_SIZE: u32 = 0x4000;
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum GlobalVariableError {
+ #[error("Usage isn't compatible with address space {0:?}")]
+ InvalidUsage(crate::AddressSpace),
+ #[error("Type isn't compatible with address space {0:?}")]
+ InvalidType(crate::AddressSpace),
+ #[error("Type flags {seen:?} do not meet the required {required:?}")]
+ MissingTypeFlags {
+ required: super::TypeFlags,
+ seen: super::TypeFlags,
+ },
+ #[error("Capability {0:?} is not supported")]
+ UnsupportedCapability(Capabilities),
+ #[error("Binding decoration is missing or not applicable")]
+ InvalidBinding,
+ #[error("Alignment requirements for address space {0:?} are not met by {1:?}")]
+ Alignment(
+ crate::AddressSpace,
+ Handle<crate::Type>,
+ #[source] Disalignment,
+ ),
+ #[error("Initializer doesn't match the variable type")]
+ InitializerType,
+ #[error("Initializer can't be used with address space {0:?}")]
+ InitializerNotAllowed(crate::AddressSpace),
+ #[error("Storage address space doesn't support write-only access")]
+ StorageAddressSpaceWriteOnlyNotSupported,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum VaryingError {
+ #[error("The type {0:?} does not match the varying")]
+ InvalidType(Handle<crate::Type>),
+ #[error("The type {0:?} cannot be used for user-defined entry point inputs or outputs")]
+ NotIOShareableType(Handle<crate::Type>),
+ #[error("Interpolation is not valid")]
+ InvalidInterpolation,
+ #[error("Interpolation must be specified on vertex shader outputs and fragment shader inputs")]
+ MissingInterpolation,
+ #[error("Built-in {0:?} is not available at this stage")]
+ InvalidBuiltInStage(crate::BuiltIn),
+ #[error("Built-in type for {0:?} is invalid")]
+ InvalidBuiltInType(crate::BuiltIn),
+ #[error("Entry point arguments and return values must all have bindings")]
+ MissingBinding,
+ #[error("Struct member {0} is missing a binding")]
+ MemberMissingBinding(u32),
+ #[error("Multiple bindings at location {location} are present")]
+ BindingCollision { location: u32 },
+ #[error("Built-in {0:?} is present more than once")]
+ DuplicateBuiltIn(crate::BuiltIn),
+ #[error("Capability {0:?} is not supported")]
+ UnsupportedCapability(Capabilities),
+ #[error("The attribute {0:?} is only valid as an output for stage {1:?}")]
+ InvalidInputAttributeInStage(&'static str, crate::ShaderStage),
+ #[error("The attribute {0:?} is not valid for stage {1:?}")]
+ InvalidAttributeInStage(&'static str, crate::ShaderStage),
+ #[error(
+ "The location index {location} cannot be used together with the attribute {attribute:?}"
+ )]
+ InvalidLocationAttributeCombination {
+ location: u32,
+ attribute: &'static str,
+ },
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum EntryPointError {
+ #[error("Multiple conflicting entry points")]
+ Conflict,
+ #[error("Vertex shaders must return a `@builtin(position)` output value")]
+ MissingVertexOutputPosition,
+ #[error("Early depth test is not applicable")]
+ UnexpectedEarlyDepthTest,
+ #[error("Workgroup size is not applicable")]
+ UnexpectedWorkgroupSize,
+ #[error("Workgroup size is out of range")]
+ OutOfRangeWorkgroupSize,
+ #[error("Uses operations forbidden at this stage")]
+ ForbiddenStageOperations,
+ #[error("Global variable {0:?} is used incorrectly as {1:?}")]
+ InvalidGlobalUsage(Handle<crate::GlobalVariable>, GlobalUse),
+ #[error("More than 1 push constant variable is used")]
+ MoreThanOnePushConstantUsed,
+ #[error("Bindings for {0:?} conflict with other resource")]
+ BindingCollision(Handle<crate::GlobalVariable>),
+ #[error("Argument {0} varying error")]
+ Argument(u32, #[source] VaryingError),
+ #[error(transparent)]
+ Result(#[from] VaryingError),
+ #[error("Location {location} interpolation of an integer has to be flat")]
+ InvalidIntegerInterpolation { location: u32 },
+ #[error(transparent)]
+ Function(#[from] FunctionError),
+ #[error(
+ "Invalid locations {location_mask:?} are set while dual source blending. Only location 0 may be set."
+ )]
+ InvalidLocationsWhileDualSourceBlending { location_mask: BitSet },
+}
+
+fn storage_usage(access: crate::StorageAccess) -> GlobalUse {
+ let mut storage_usage = GlobalUse::QUERY;
+ if access.contains(crate::StorageAccess::LOAD) {
+ storage_usage |= GlobalUse::READ;
+ }
+ if access.contains(crate::StorageAccess::STORE) {
+ storage_usage |= GlobalUse::WRITE;
+ }
+ storage_usage
+}
+
+struct VaryingContext<'a> {
+ stage: crate::ShaderStage,
+ output: bool,
+ second_blend_source: bool,
+ types: &'a UniqueArena<crate::Type>,
+ type_info: &'a Vec<super::r#type::TypeInfo>,
+ location_mask: &'a mut BitSet,
+ built_ins: &'a mut crate::FastHashSet<crate::BuiltIn>,
+ capabilities: Capabilities,
+ flags: super::ValidationFlags,
+}
+
+impl VaryingContext<'_> {
+ fn validate_impl(
+ &mut self,
+ ty: Handle<crate::Type>,
+ binding: &crate::Binding,
+ ) -> Result<(), VaryingError> {
+ use crate::{BuiltIn as Bi, ShaderStage as St, TypeInner as Ti, VectorSize as Vs};
+
+ let ty_inner = &self.types[ty].inner;
+ match *binding {
+ crate::Binding::BuiltIn(built_in) => {
+ // Ignore the `invariant` field for the sake of duplicate checks,
+ // but use the original in error messages.
+ let canonical = if let crate::BuiltIn::Position { .. } = built_in {
+ crate::BuiltIn::Position { invariant: false }
+ } else {
+ built_in
+ };
+
+ if self.built_ins.contains(&canonical) {
+ return Err(VaryingError::DuplicateBuiltIn(built_in));
+ }
+ self.built_ins.insert(canonical);
+
+ let required = match built_in {
+ Bi::ClipDistance => Capabilities::CLIP_DISTANCE,
+ Bi::CullDistance => Capabilities::CULL_DISTANCE,
+ Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
+ Bi::ViewIndex => Capabilities::MULTIVIEW,
+ Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
+ _ => Capabilities::empty(),
+ };
+ if !self.capabilities.contains(required) {
+ return Err(VaryingError::UnsupportedCapability(required));
+ }
+
+ let (visible, type_good) = match built_in {
+ Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
+ self.stage == St::Vertex && !self.output,
+ *ty_inner == Ti::Scalar(crate::Scalar::U32),
+ ),
+ Bi::ClipDistance | Bi::CullDistance => (
+ self.stage == St::Vertex && self.output,
+ match *ty_inner {
+ Ti::Array { base, .. } => {
+ self.types[base].inner == Ti::Scalar(crate::Scalar::F32)
+ }
+ _ => false,
+ },
+ ),
+ Bi::PointSize => (
+ self.stage == St::Vertex && self.output,
+ *ty_inner == Ti::Scalar(crate::Scalar::F32),
+ ),
+ Bi::PointCoord => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner
+ == Ti::Vector {
+ size: Vs::Bi,
+ scalar: crate::Scalar::F32,
+ },
+ ),
+ Bi::Position { .. } => (
+ match self.stage {
+ St::Vertex => self.output,
+ St::Fragment => !self.output,
+ St::Compute => false,
+ },
+ *ty_inner
+ == Ti::Vector {
+ size: Vs::Quad,
+ scalar: crate::Scalar::F32,
+ },
+ ),
+ Bi::ViewIndex => (
+ match self.stage {
+ St::Vertex | St::Fragment => !self.output,
+ St::Compute => false,
+ },
+ *ty_inner == Ti::Scalar(crate::Scalar::I32),
+ ),
+ Bi::FragDepth => (
+ self.stage == St::Fragment && self.output,
+ *ty_inner == Ti::Scalar(crate::Scalar::F32),
+ ),
+ Bi::FrontFacing => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner == Ti::Scalar(crate::Scalar::BOOL),
+ ),
+ Bi::PrimitiveIndex => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner == Ti::Scalar(crate::Scalar::U32),
+ ),
+ Bi::SampleIndex => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner == Ti::Scalar(crate::Scalar::U32),
+ ),
+ Bi::SampleMask => (
+ self.stage == St::Fragment,
+ *ty_inner == Ti::Scalar(crate::Scalar::U32),
+ ),
+ Bi::LocalInvocationIndex => (
+ self.stage == St::Compute && !self.output,
+ *ty_inner == Ti::Scalar(crate::Scalar::U32),
+ ),
+ Bi::GlobalInvocationId
+ | Bi::LocalInvocationId
+ | Bi::WorkGroupId
+ | Bi::WorkGroupSize
+ | Bi::NumWorkGroups => (
+ self.stage == St::Compute && !self.output,
+ *ty_inner
+ == Ti::Vector {
+ size: Vs::Tri,
+ scalar: crate::Scalar::U32,
+ },
+ ),
+ };
+
+ if !visible {
+ return Err(VaryingError::InvalidBuiltInStage(built_in));
+ }
+ if !type_good {
+ log::warn!("Wrong builtin type: {:?}", ty_inner);
+ return Err(VaryingError::InvalidBuiltInType(built_in));
+ }
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source,
+ } => {
+ // Only IO-shareable types may be stored in locations.
+ if !self.type_info[ty.index()]
+ .flags
+ .contains(super::TypeFlags::IO_SHAREABLE)
+ {
+ return Err(VaryingError::NotIOShareableType(ty));
+ }
+
+ if second_blend_source {
+ if !self
+ .capabilities
+ .contains(Capabilities::DUAL_SOURCE_BLENDING)
+ {
+ return Err(VaryingError::UnsupportedCapability(
+ Capabilities::DUAL_SOURCE_BLENDING,
+ ));
+ }
+ if self.stage != crate::ShaderStage::Fragment {
+ return Err(VaryingError::InvalidAttributeInStage(
+ "second_blend_source",
+ self.stage,
+ ));
+ }
+ if !self.output {
+ return Err(VaryingError::InvalidInputAttributeInStage(
+ "second_blend_source",
+ self.stage,
+ ));
+ }
+ if location != 0 {
+ return Err(VaryingError::InvalidLocationAttributeCombination {
+ location,
+ attribute: "second_blend_source",
+ });
+ }
+
+ self.second_blend_source = true;
+ } else if !self.location_mask.insert(location as usize) {
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(VaryingError::BindingCollision { location });
+ }
+ }
+
+ let needs_interpolation = match self.stage {
+ crate::ShaderStage::Vertex => self.output,
+ crate::ShaderStage::Fragment => !self.output,
+ crate::ShaderStage::Compute => false,
+ };
+
+ // It doesn't make sense to specify a sampling when `interpolation` is `Flat`, but
+ // SPIR-V and GLSL both explicitly tolerate such combinations of decorators /
+ // qualifiers, so we won't complain about that here.
+ let _ = sampling;
+
+ let required = match sampling {
+ Some(crate::Sampling::Sample) => Capabilities::MULTISAMPLED_SHADING,
+ _ => Capabilities::empty(),
+ };
+ if !self.capabilities.contains(required) {
+ return Err(VaryingError::UnsupportedCapability(required));
+ }
+
+ match ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Float) => {
+ if needs_interpolation && interpolation.is_none() {
+ return Err(VaryingError::MissingInterpolation);
+ }
+ }
+ Some(_) => {
+ if needs_interpolation && interpolation != Some(crate::Interpolation::Flat)
+ {
+ return Err(VaryingError::InvalidInterpolation);
+ }
+ }
+ None => return Err(VaryingError::InvalidType(ty)),
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ fn validate(
+ &mut self,
+ ty: Handle<crate::Type>,
+ binding: Option<&crate::Binding>,
+ ) -> Result<(), WithSpan<VaryingError>> {
+ let span_context = self.types.get_span_context(ty);
+ match binding {
+ Some(binding) => self
+ .validate_impl(ty, binding)
+ .map_err(|e| e.with_span_context(span_context)),
+ None => {
+ match self.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (index, member) in members.iter().enumerate() {
+ let span_context = self.types.get_span_context(ty);
+ match member.binding {
+ None => {
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(VaryingError::MemberMissingBinding(
+ index as u32,
+ )
+ .with_span_context(span_context));
+ }
+ }
+ Some(ref binding) => self
+ .validate_impl(member.ty, binding)
+ .map_err(|e| e.with_span_context(span_context))?,
+ }
+ }
+ }
+ _ => {
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(VaryingError::MissingBinding.with_span());
+ }
+ }
+ }
+ Ok(())
+ }
+ }
+ }
+}
+
+impl super::Validator {
+ pub(super) fn validate_global_var(
+ &self,
+ var: &crate::GlobalVariable,
+ gctx: crate::proc::GlobalCtx,
+ mod_info: &ModuleInfo,
+ ) -> Result<(), GlobalVariableError> {
+ use super::TypeFlags;
+
+ log::debug!("var {:?}", var);
+ let inner_ty = match gctx.types[var.ty].inner {
+ // A binding array is (mostly) supposed to behave the same as a
+ // series of individually bound resources, so we can (mostly)
+ // validate a `binding_array<T>` as if it were just a plain `T`.
+ crate::TypeInner::BindingArray { base, .. } => match var.space {
+ crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::Uniform
+ | crate::AddressSpace::Handle => base,
+ _ => return Err(GlobalVariableError::InvalidUsage(var.space)),
+ },
+ _ => var.ty,
+ };
+ let type_info = &self.types[inner_ty.index()];
+
+ let (required_type_flags, is_resource) = match var.space {
+ crate::AddressSpace::Function => {
+ return Err(GlobalVariableError::InvalidUsage(var.space))
+ }
+ crate::AddressSpace::Storage { access } => {
+ if let Err((ty_handle, disalignment)) = type_info.storage_layout {
+ if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
+ return Err(GlobalVariableError::Alignment(
+ var.space,
+ ty_handle,
+ disalignment,
+ ));
+ }
+ }
+ if access == crate::StorageAccess::STORE {
+ return Err(GlobalVariableError::StorageAddressSpaceWriteOnlyNotSupported);
+ }
+ (TypeFlags::DATA | TypeFlags::HOST_SHAREABLE, true)
+ }
+ crate::AddressSpace::Uniform => {
+ if let Err((ty_handle, disalignment)) = type_info.uniform_layout {
+ if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
+ return Err(GlobalVariableError::Alignment(
+ var.space,
+ ty_handle,
+ disalignment,
+ ));
+ }
+ }
+ (
+ TypeFlags::DATA
+ | TypeFlags::COPY
+ | TypeFlags::SIZED
+ | TypeFlags::HOST_SHAREABLE,
+ true,
+ )
+ }
+ crate::AddressSpace::Handle => {
+ match gctx.types[inner_ty].inner {
+ crate::TypeInner::Image { class, .. } => match class {
+ crate::ImageClass::Storage {
+ format:
+ crate::StorageFormat::R16Unorm
+ | crate::StorageFormat::R16Snorm
+ | crate::StorageFormat::Rg16Unorm
+ | crate::StorageFormat::Rg16Snorm
+ | crate::StorageFormat::Rgba16Unorm
+ | crate::StorageFormat::Rgba16Snorm,
+ ..
+ } => {
+ if !self
+ .capabilities
+ .contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS)
+ {
+ return Err(GlobalVariableError::UnsupportedCapability(
+ Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
+ ));
+ }
+ }
+ _ => {}
+ },
+ crate::TypeInner::Sampler { .. }
+ | crate::TypeInner::AccelerationStructure
+ | crate::TypeInner::RayQuery => {}
+ _ => {
+ return Err(GlobalVariableError::InvalidType(var.space));
+ }
+ }
+
+ (TypeFlags::empty(), true)
+ }
+ crate::AddressSpace::Private => (TypeFlags::CONSTRUCTIBLE, false),
+ crate::AddressSpace::WorkGroup => (TypeFlags::DATA | TypeFlags::SIZED, false),
+ crate::AddressSpace::PushConstant => {
+ if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) {
+ return Err(GlobalVariableError::UnsupportedCapability(
+ Capabilities::PUSH_CONSTANT,
+ ));
+ }
+ (
+ TypeFlags::DATA
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::SIZED,
+ false,
+ )
+ }
+ };
+
+ if !type_info.flags.contains(required_type_flags) {
+ return Err(GlobalVariableError::MissingTypeFlags {
+ seen: type_info.flags,
+ required: required_type_flags,
+ });
+ }
+
+ if is_resource != var.binding.is_some() {
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(GlobalVariableError::InvalidBinding);
+ }
+ }
+
+ if let Some(init) = var.init {
+ match var.space {
+ crate::AddressSpace::Private | crate::AddressSpace::Function => {}
+ _ => {
+ return Err(GlobalVariableError::InitializerNotAllowed(var.space));
+ }
+ }
+
+ let decl_ty = &gctx.types[var.ty].inner;
+ let init_ty = mod_info[init].inner_with(gctx.types);
+ if !decl_ty.equivalent(init_ty, gctx.types) {
+ return Err(GlobalVariableError::InitializerType);
+ }
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn validate_entry_point(
+ &mut self,
+ ep: &crate::EntryPoint,
+ module: &crate::Module,
+ mod_info: &ModuleInfo,
+ ) -> Result<FunctionInfo, WithSpan<EntryPointError>> {
+ if ep.early_depth_test.is_some() {
+ let required = Capabilities::EARLY_DEPTH_TEST;
+ if !self.capabilities.contains(required) {
+ return Err(
+ EntryPointError::Result(VaryingError::UnsupportedCapability(required))
+ .with_span(),
+ );
+ }
+
+ if ep.stage != crate::ShaderStage::Fragment {
+ return Err(EntryPointError::UnexpectedEarlyDepthTest.with_span());
+ }
+ }
+
+ if ep.stage == crate::ShaderStage::Compute {
+ if ep
+ .workgroup_size
+ .iter()
+ .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE)
+ {
+ return Err(EntryPointError::OutOfRangeWorkgroupSize.with_span());
+ }
+ } else if ep.workgroup_size != [0; 3] {
+ return Err(EntryPointError::UnexpectedWorkgroupSize.with_span());
+ }
+
+ let mut info = self
+ .validate_function(&ep.function, module, mod_info, true)
+ .map_err(WithSpan::into_other)?;
+
+ {
+ use super::ShaderStages;
+
+ let stage_bit = match ep.stage {
+ crate::ShaderStage::Vertex => ShaderStages::VERTEX,
+ crate::ShaderStage::Fragment => ShaderStages::FRAGMENT,
+ crate::ShaderStage::Compute => ShaderStages::COMPUTE,
+ };
+
+ if !info.available_stages.contains(stage_bit) {
+ return Err(EntryPointError::ForbiddenStageOperations.with_span());
+ }
+ }
+
+ self.location_mask.clear();
+ let mut argument_built_ins = crate::FastHashSet::default();
+ // TODO: add span info to function arguments
+ for (index, fa) in ep.function.arguments.iter().enumerate() {
+ let mut ctx = VaryingContext {
+ stage: ep.stage,
+ output: false,
+ second_blend_source: false,
+ types: &module.types,
+ type_info: &self.types,
+ location_mask: &mut self.location_mask,
+ built_ins: &mut argument_built_ins,
+ capabilities: self.capabilities,
+ flags: self.flags,
+ };
+ ctx.validate(fa.ty, fa.binding.as_ref())
+ .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
+ }
+
+ self.location_mask.clear();
+ if let Some(ref fr) = ep.function.result {
+ let mut result_built_ins = crate::FastHashSet::default();
+ let mut ctx = VaryingContext {
+ stage: ep.stage,
+ output: true,
+ second_blend_source: false,
+ types: &module.types,
+ type_info: &self.types,
+ location_mask: &mut self.location_mask,
+ built_ins: &mut result_built_ins,
+ capabilities: self.capabilities,
+ flags: self.flags,
+ };
+ ctx.validate(fr.ty, fr.binding.as_ref())
+ .map_err_inner(|e| EntryPointError::Result(e).with_span())?;
+ if ctx.second_blend_source {
+ // Only the first location may be used when dual source blending
+ if ctx.location_mask.len() == 1 && ctx.location_mask.contains(0) {
+ info.dual_source_blending = true;
+ } else {
+ return Err(EntryPointError::InvalidLocationsWhileDualSourceBlending {
+ location_mask: self.location_mask.clone(),
+ }
+ .with_span());
+ }
+ }
+
+ if ep.stage == crate::ShaderStage::Vertex
+ && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false })
+ {
+ return Err(EntryPointError::MissingVertexOutputPosition.with_span());
+ }
+ } else if ep.stage == crate::ShaderStage::Vertex {
+ return Err(EntryPointError::MissingVertexOutputPosition.with_span());
+ }
+
+ {
+ let used_push_constants = module
+ .global_variables
+ .iter()
+ .filter(|&(_, var)| var.space == crate::AddressSpace::PushConstant)
+ .map(|(handle, _)| handle)
+ .filter(|&handle| !info[handle].is_empty());
+ // Check if there is more than one push constant, and error if so.
+ // Use a loop for when returning multiple errors is supported.
+ #[allow(clippy::never_loop)]
+ for handle in used_push_constants.skip(1) {
+ return Err(EntryPointError::MoreThanOnePushConstantUsed
+ .with_span_handle(handle, &module.global_variables));
+ }
+ }
+
+ self.ep_resource_bindings.clear();
+ for (var_handle, var) in module.global_variables.iter() {
+ let usage = info[var_handle];
+ if usage.is_empty() {
+ continue;
+ }
+
+ let allowed_usage = match var.space {
+ crate::AddressSpace::Function => unreachable!(),
+ crate::AddressSpace::Uniform => GlobalUse::READ | GlobalUse::QUERY,
+ crate::AddressSpace::Storage { access } => storage_usage(access),
+ crate::AddressSpace::Handle => match module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => storage_usage(access),
+ _ => GlobalUse::READ | GlobalUse::QUERY,
+ },
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => storage_usage(access),
+ _ => GlobalUse::READ | GlobalUse::QUERY,
+ },
+ crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => GlobalUse::all(),
+ crate::AddressSpace::PushConstant => GlobalUse::READ,
+ };
+ if !allowed_usage.contains(usage) {
+ log::warn!("\tUsage error for: {:?}", var);
+ log::warn!(
+ "\tAllowed usage: {:?}, requested: {:?}",
+ allowed_usage,
+ usage
+ );
+ return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage)
+ .with_span_handle(var_handle, &module.global_variables));
+ }
+
+ if let Some(ref bind) = var.binding {
+ if !self.ep_resource_bindings.insert(bind.clone()) {
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(EntryPointError::BindingCollision(var_handle)
+ .with_span_handle(var_handle, &module.global_variables));
+ }
+ }
+ }
+ }
+
+ Ok(info)
+ }
+}
diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs
new file mode 100644
index 0000000000..388495a3ac
--- /dev/null
+++ b/third_party/rust/naga/src/valid/mod.rs
@@ -0,0 +1,477 @@
+/*!
+Shader validator.
+*/
+
+mod analyzer;
+mod compose;
+mod expression;
+mod function;
+mod handles;
+mod interface;
+mod r#type;
+
+use crate::{
+ arena::Handle,
+ proc::{LayoutError, Layouter, TypeResolution},
+ FastHashSet,
+};
+use bit_set::BitSet;
+use std::ops;
+
+//TODO: analyze the model at the same time as we validate it,
+// merge the corresponding matches over expressions and statements.
+
+use crate::span::{AddSpan as _, WithSpan};
+pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
+pub use compose::ComposeError;
+pub use expression::{check_literal_value, LiteralError};
+pub use expression::{ConstExpressionError, ExpressionError};
+pub use function::{CallError, FunctionError, LocalVariableError};
+pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
+pub use r#type::{Disalignment, TypeError, TypeFlags};
+
+use self::handles::InvalidHandleError;
+
+bitflags::bitflags! {
+ /// Validation flags.
+ ///
+ /// If you are working with trusted shaders, then you may be able
+ /// to save some time by skipping validation.
+ ///
+ /// If you do not perform full validation, invalid shaders may
+ /// cause Naga to panic. If you do perform full validation and
+ /// [`Validator::validate`] returns `Ok`, then Naga promises that
+ /// code generation will either succeed or return an error; it
+ /// should never panic.
+ ///
+ /// The default value for `ValidationFlags` is
+ /// `ValidationFlags::all()`.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct ValidationFlags: u8 {
+ /// Expressions.
+ const EXPRESSIONS = 0x1;
+ /// Statements and blocks of them.
+ const BLOCKS = 0x2;
+ /// Uniformity of control flow for operations that require it.
+ const CONTROL_FLOW_UNIFORMITY = 0x4;
+ /// Host-shareable structure layouts.
+ const STRUCT_LAYOUTS = 0x8;
+ /// Constants.
+ const CONSTANTS = 0x10;
+ /// Group, binding, and location attributes.
+ const BINDINGS = 0x20;
+ }
+}
+
+impl Default for ValidationFlags {
+ fn default() -> Self {
+ Self::all()
+ }
+}
+
+bitflags::bitflags! {
+ /// Allowed IR capabilities.
+ #[must_use]
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct Capabilities: u16 {
+ /// Support for [`AddressSpace:PushConstant`].
+ const PUSH_CONSTANT = 0x1;
+ /// Float values with width = 8.
+ const FLOAT64 = 0x2;
+ /// Support for [`Builtin:PrimitiveIndex`].
+ const PRIMITIVE_INDEX = 0x4;
+ /// Support for non-uniform indexing of sampled textures and storage buffer arrays.
+ const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8;
+ /// Support for non-uniform indexing of uniform buffers and storage texture arrays.
+ const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10;
+ /// Support for non-uniform indexing of samplers.
+ const SAMPLER_NON_UNIFORM_INDEXING = 0x20;
+ /// Support for [`Builtin::ClipDistance`].
+ const CLIP_DISTANCE = 0x40;
+ /// Support for [`Builtin::CullDistance`].
+ const CULL_DISTANCE = 0x80;
+ /// Support for 16-bit normalized storage texture formats.
+ const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100;
+ /// Support for [`BuiltIn::ViewIndex`].
+ const MULTIVIEW = 0x200;
+ /// Support for `early_depth_test`.
+ const EARLY_DEPTH_TEST = 0x400;
+ /// Support for [`Builtin::SampleIndex`] and [`Sampling::Sample`].
+ const MULTISAMPLED_SHADING = 0x800;
+ /// Support for ray queries and acceleration structures.
+ const RAY_QUERY = 0x1000;
+ /// Support for generating two sources for blending from fragment shaders.
+ const DUAL_SOURCE_BLENDING = 0x2000;
+ /// Support for arrayed cube textures.
+ const CUBE_ARRAY_TEXTURES = 0x4000;
+ }
+}
+
+impl Default for Capabilities {
+ fn default() -> Self {
+ Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
+ }
+}
+
+bitflags::bitflags! {
+ /// Validation flags.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct ShaderStages: u8 {
+ const VERTEX = 0x1;
+ const FRAGMENT = 0x2;
+ const COMPUTE = 0x4;
+ }
+}
+
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct ModuleInfo {
+ type_flags: Vec<TypeFlags>,
+ functions: Vec<FunctionInfo>,
+ entry_points: Vec<FunctionInfo>,
+ const_expression_types: Box<[TypeResolution]>,
+}
+
+impl ops::Index<Handle<crate::Type>> for ModuleInfo {
+ type Output = TypeFlags;
+ fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
+ &self.type_flags[handle.index()]
+ }
+}
+
+impl ops::Index<Handle<crate::Function>> for ModuleInfo {
+ type Output = FunctionInfo;
+ fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
+ &self.functions[handle.index()]
+ }
+}
+
+impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
+ type Output = TypeResolution;
+ fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
+ &self.const_expression_types[handle.index()]
+ }
+}
+
+#[derive(Debug)]
+pub struct Validator {
+ flags: ValidationFlags,
+ capabilities: Capabilities,
+ types: Vec<r#type::TypeInfo>,
+ layouter: Layouter,
+ location_mask: BitSet,
+ ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
+ #[allow(dead_code)]
+ switch_values: FastHashSet<crate::SwitchValue>,
+ valid_expression_list: Vec<Handle<crate::Expression>>,
+ valid_expression_set: BitSet,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum ConstantError {
+ #[error("The type doesn't match the constant")]
+ InvalidType,
+ #[error("The type is not constructible")]
+ NonConstructibleType,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum ValidationError {
+ #[error(transparent)]
+ InvalidHandle(#[from] InvalidHandleError),
+ #[error(transparent)]
+ Layouter(#[from] LayoutError),
+ #[error("Type {handle:?} '{name}' is invalid")]
+ Type {
+ handle: Handle<crate::Type>,
+ name: String,
+ source: TypeError,
+ },
+ #[error("Constant expression {handle:?} is invalid")]
+ ConstExpression {
+ handle: Handle<crate::Expression>,
+ source: ConstExpressionError,
+ },
+ #[error("Constant {handle:?} '{name}' is invalid")]
+ Constant {
+ handle: Handle<crate::Constant>,
+ name: String,
+ source: ConstantError,
+ },
+ #[error("Global variable {handle:?} '{name}' is invalid")]
+ GlobalVariable {
+ handle: Handle<crate::GlobalVariable>,
+ name: String,
+ source: GlobalVariableError,
+ },
+ #[error("Function {handle:?} '{name}' is invalid")]
+ Function {
+ handle: Handle<crate::Function>,
+ name: String,
+ source: FunctionError,
+ },
+ #[error("Entry point {name} at {stage:?} is invalid")]
+ EntryPoint {
+ stage: crate::ShaderStage,
+ name: String,
+ source: EntryPointError,
+ },
+ #[error("Module is corrupted")]
+ Corrupted,
+}
+
+impl crate::TypeInner {
+ const fn is_sized(&self) -> bool {
+ match *self {
+ Self::Scalar { .. }
+ | Self::Vector { .. }
+ | Self::Matrix { .. }
+ | Self::Array {
+ size: crate::ArraySize::Constant(_),
+ ..
+ }
+ | Self::Atomic { .. }
+ | Self::Pointer { .. }
+ | Self::ValuePointer { .. }
+ | Self::Struct { .. } => true,
+ Self::Array { .. }
+ | Self::Image { .. }
+ | Self::Sampler { .. }
+ | Self::AccelerationStructure
+ | Self::RayQuery
+ | Self::BindingArray { .. } => false,
+ }
+ }
+
+ /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
+ const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
+ match *self {
+ Self::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ ..
+ }) => Some(crate::ImageDimension::D1),
+ Self::Vector {
+ size: crate::VectorSize::Bi,
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ ..
+ },
+ } => Some(crate::ImageDimension::D2),
+ Self::Vector {
+ size: crate::VectorSize::Tri,
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ ..
+ },
+ } => Some(crate::ImageDimension::D3),
+ _ => None,
+ }
+ }
+}
+
+impl Validator {
+ /// Construct a new validator instance.
+ pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
+ Validator {
+ flags,
+ capabilities,
+ types: Vec::new(),
+ layouter: Layouter::default(),
+ location_mask: BitSet::new(),
+ ep_resource_bindings: FastHashSet::default(),
+ switch_values: FastHashSet::default(),
+ valid_expression_list: Vec::new(),
+ valid_expression_set: BitSet::new(),
+ }
+ }
+
+ /// Reset the validator internals
+ pub fn reset(&mut self) {
+ self.types.clear();
+ self.layouter.clear();
+ self.location_mask.clear();
+ self.ep_resource_bindings.clear();
+ self.switch_values.clear();
+ self.valid_expression_list.clear();
+ self.valid_expression_set.clear();
+ }
+
+ fn validate_constant(
+ &self,
+ handle: Handle<crate::Constant>,
+ gctx: crate::proc::GlobalCtx,
+ mod_info: &ModuleInfo,
+ ) -> Result<(), ConstantError> {
+ let con = &gctx.constants[handle];
+
+ let type_info = &self.types[con.ty.index()];
+ if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
+ return Err(ConstantError::NonConstructibleType);
+ }
+
+ let decl_ty = &gctx.types[con.ty].inner;
+ let init_ty = mod_info[con.init].inner_with(gctx.types);
+ if !decl_ty.equivalent(init_ty, gctx.types) {
+ return Err(ConstantError::InvalidType);
+ }
+
+ Ok(())
+ }
+
+ /// Check the given module to be valid.
+ pub fn validate(
+ &mut self,
+ module: &crate::Module,
+ ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
+ self.reset();
+ self.reset_types(module.types.len());
+
+ Self::validate_module_handles(module).map_err(|e| e.with_span())?;
+
+ self.layouter.update(module.to_ctx()).map_err(|e| {
+ let handle = e.ty;
+ ValidationError::from(e).with_span_handle(handle, &module.types)
+ })?;
+
+ // These should all get overwritten.
+ let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: 0,
+ }));
+
+ let mut mod_info = ModuleInfo {
+ type_flags: Vec::with_capacity(module.types.len()),
+ functions: Vec::with_capacity(module.functions.len()),
+ entry_points: Vec::with_capacity(module.entry_points.len()),
+ const_expression_types: vec![placeholder; module.const_expressions.len()]
+ .into_boxed_slice(),
+ };
+
+ for (handle, ty) in module.types.iter() {
+ let ty_info = self
+ .validate_type(handle, module.to_ctx())
+ .map_err(|source| {
+ ValidationError::Type {
+ handle,
+ name: ty.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.types)
+ })?;
+ mod_info.type_flags.push(ty_info.flags);
+ self.types[handle.index()] = ty_info;
+ }
+
+ {
+ let t = crate::Arena::new();
+ let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
+ for (handle, _) in module.const_expressions.iter() {
+ mod_info
+ .process_const_expression(handle, &resolve_context, module.to_ctx())
+ .map_err(|source| {
+ ValidationError::ConstExpression { handle, source }
+ .with_span_handle(handle, &module.const_expressions)
+ })?
+ }
+ }
+
+ if self.flags.contains(ValidationFlags::CONSTANTS) {
+ for (handle, _) in module.const_expressions.iter() {
+ self.validate_const_expression(handle, module.to_ctx(), &mod_info)
+ .map_err(|source| {
+ ValidationError::ConstExpression { handle, source }
+ .with_span_handle(handle, &module.const_expressions)
+ })?
+ }
+
+ for (handle, constant) in module.constants.iter() {
+ self.validate_constant(handle, module.to_ctx(), &mod_info)
+ .map_err(|source| {
+ ValidationError::Constant {
+ handle,
+ name: constant.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.constants)
+ })?
+ }
+ }
+
+ for (var_handle, var) in module.global_variables.iter() {
+ self.validate_global_var(var, module.to_ctx(), &mod_info)
+ .map_err(|source| {
+ ValidationError::GlobalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(var_handle, &module.global_variables)
+ })?;
+ }
+
+ for (handle, fun) in module.functions.iter() {
+ match self.validate_function(fun, module, &mod_info, false) {
+ Ok(info) => mod_info.functions.push(info),
+ Err(error) => {
+ return Err(error.and_then(|source| {
+ ValidationError::Function {
+ handle,
+ name: fun.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.functions)
+ }))
+ }
+ }
+ }
+
+ let mut ep_map = FastHashSet::default();
+ for ep in module.entry_points.iter() {
+ if !ep_map.insert((ep.stage, &ep.name)) {
+ return Err(ValidationError::EntryPoint {
+ stage: ep.stage,
+ name: ep.name.clone(),
+ source: EntryPointError::Conflict,
+ }
+ .with_span()); // TODO: keep some EP span information?
+ }
+
+ match self.validate_entry_point(ep, module, &mod_info) {
+ Ok(info) => mod_info.entry_points.push(info),
+ Err(error) => {
+ return Err(error.and_then(|source| {
+ ValidationError::EntryPoint {
+ stage: ep.stage,
+ name: ep.name.clone(),
+ source,
+ }
+ .with_span()
+ }));
+ }
+ }
+ }
+
+ Ok(mod_info)
+ }
+}
+
+fn validate_atomic_compare_exchange_struct(
+ types: &crate::UniqueArena<crate::Type>,
+ members: &[crate::StructMember],
+ scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
+) -> bool {
+ members.len() == 2
+ && members[0].name.as_deref() == Some("old_value")
+ && scalar_predicate(&types[members[0].ty].inner)
+ && members[1].name.as_deref() == Some("exchanged")
+ && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
+}
diff --git a/third_party/rust/naga/src/valid/type.rs b/third_party/rust/naga/src/valid/type.rs
new file mode 100644
index 0000000000..1e3e03fe19
--- /dev/null
+++ b/third_party/rust/naga/src/valid/type.rs
@@ -0,0 +1,643 @@
+use super::Capabilities;
+use crate::{arena::Handle, proc::Alignment};
+
+bitflags::bitflags! {
+ /// Flags associated with [`Type`]s by [`Validator`].
+ ///
+ /// [`Type`]: crate::Type
+ /// [`Validator`]: crate::valid::Validator
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[repr(transparent)]
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct TypeFlags: u8 {
+ /// Can be used for data variables.
+ ///
+ /// This flag is required on types of local variables, function
+ /// arguments, array elements, and struct members.
+ ///
+ /// This includes all types except `Image`, `Sampler`,
+ /// and some `Pointer` types.
+ const DATA = 0x1;
+
+ /// The data type has a size known by pipeline creation time.
+ ///
+ /// Unsized types are quite restricted. The only unsized types permitted
+ /// by Naga, other than the non-[`DATA`] types like [`Image`] and
+ /// [`Sampler`], are dynamically-sized [`Array`s], and [`Struct`s] whose
+ /// last members are such arrays. See the documentation for those types
+ /// for details.
+ ///
+ /// [`DATA`]: TypeFlags::DATA
+ /// [`Image`]: crate::Type::Image
+ /// [`Sampler`]: crate::Type::Sampler
+ /// [`Array`]: crate::Type::Array
+ /// [`Struct`]: crate::Type::struct
+ const SIZED = 0x2;
+
+ /// The data can be copied around.
+ const COPY = 0x4;
+
+ /// Can be be used for user-defined IO between pipeline stages.
+ ///
+ /// This covers anything that can be in [`Location`] binding:
+ /// non-bool scalars and vectors, matrices, and structs and
+ /// arrays containing only interface types.
+ const IO_SHAREABLE = 0x8;
+
+ /// Can be used for host-shareable structures.
+ const HOST_SHAREABLE = 0x10;
+
+ /// This type can be passed as a function argument.
+ const ARGUMENT = 0x40;
+
+ /// A WGSL [constructible] type.
+ ///
+ /// The constructible types are scalars, vectors, matrices, fixed-size
+ /// arrays of constructible types, and structs whose members are all
+ /// constructible.
+ ///
+ /// [constructible]: https://gpuweb.github.io/gpuweb/wgsl/#constructible
+ const CONSTRUCTIBLE = 0x80;
+ }
+}
+
+#[derive(Clone, Copy, Debug, thiserror::Error)]
+pub enum Disalignment {
+ #[error("The array stride {stride} is not a multiple of the required alignment {alignment}")]
+ ArrayStride { stride: u32, alignment: Alignment },
+ #[error("The struct span {span}, is not a multiple of the required alignment {alignment}")]
+ StructSpan { span: u32, alignment: Alignment },
+ #[error("The struct member[{index}] offset {offset} is not a multiple of the required alignment {alignment}")]
+ MemberOffset {
+ index: u32,
+ offset: u32,
+ alignment: Alignment,
+ },
+ #[error("The struct member[{index}] offset {offset} must be at least {expected}")]
+ MemberOffsetAfterStruct {
+ index: u32,
+ offset: u32,
+ expected: u32,
+ },
+ #[error("The struct member[{index}] is not statically sized")]
+ UnsizedMember { index: u32 },
+ #[error("The type is not host-shareable")]
+ NonHostShareable,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum TypeError {
+ #[error("Capability {0:?} is required")]
+ MissingCapability(Capabilities),
+ #[error("The {0:?} scalar width {1} is not supported for an atomic")]
+ InvalidAtomicWidth(crate::ScalarKind, crate::Bytes),
+ #[error("Invalid type for pointer target {0:?}")]
+ InvalidPointerBase(Handle<crate::Type>),
+ #[error("Unsized types like {base:?} must be in the `Storage` address space, not `{space:?}`")]
+ InvalidPointerToUnsized {
+ base: Handle<crate::Type>,
+ space: crate::AddressSpace,
+ },
+ #[error("Expected data type, found {0:?}")]
+ InvalidData(Handle<crate::Type>),
+ #[error("Base type {0:?} for the array is invalid")]
+ InvalidArrayBaseType(Handle<crate::Type>),
+ #[error("Matrix elements must always be floating-point types")]
+ MatrixElementNotFloat,
+ #[error("The constant {0:?} is specialized, and cannot be used as an array size")]
+ UnsupportedSpecializedArrayLength(Handle<crate::Constant>),
+ #[error("Array stride {stride} does not match the expected {expected}")]
+ InvalidArrayStride { stride: u32, expected: u32 },
+ #[error("Field '{0}' can't be dynamically-sized, has type {1:?}")]
+ InvalidDynamicArray(String, Handle<crate::Type>),
+ #[error("The base handle {0:?} has to be a struct")]
+ BindingArrayBaseTypeNotStruct(Handle<crate::Type>),
+ #[error("Structure member[{index}] at {offset} overlaps the previous member")]
+ MemberOverlap { index: u32, offset: u32 },
+ #[error(
+ "Structure member[{index}] at {offset} and size {size} crosses the structure boundary of size {span}"
+ )]
+ MemberOutOfBounds {
+ index: u32,
+ offset: u32,
+ size: u32,
+ span: u32,
+ },
+ #[error("Structure types must have at least one member")]
+ EmptyStruct,
+ #[error(transparent)]
+ WidthError(#[from] WidthError),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum WidthError {
+ #[error("The {0:?} scalar width {1} is not supported")]
+ Invalid(crate::ScalarKind, crate::Bytes),
+ #[error("Using `{name}` values requires the `naga::valid::Capabilities::{flag}` flag")]
+ MissingCapability {
+ name: &'static str,
+ flag: &'static str,
+ },
+
+ #[error("64-bit integers are not yet supported")]
+ Unsupported64Bit,
+
+ #[error("Abstract types may only appear in constant expressions")]
+ Abstract,
+}
+
+// Only makes sense if `flags.contains(HOST_SHAREABLE)`
+type LayoutCompatibility = Result<Alignment, (Handle<crate::Type>, Disalignment)>;
+
+fn check_member_layout(
+ accum: &mut LayoutCompatibility,
+ member: &crate::StructMember,
+ member_index: u32,
+ member_layout: LayoutCompatibility,
+ parent_handle: Handle<crate::Type>,
+) {
+ *accum = match (*accum, member_layout) {
+ (Ok(cur_alignment), Ok(alignment)) => {
+ if alignment.is_aligned(member.offset) {
+ Ok(cur_alignment.max(alignment))
+ } else {
+ Err((
+ parent_handle,
+ Disalignment::MemberOffset {
+ index: member_index,
+ offset: member.offset,
+ alignment,
+ },
+ ))
+ }
+ }
+ (Err(e), _) | (_, Err(e)) => Err(e),
+ };
+}
+
+/// Determine whether a pointer in `space` can be passed as an argument.
+///
+/// If a pointer in `space` is permitted to be passed as an argument to a
+/// user-defined function, return `TypeFlags::ARGUMENT`. Otherwise, return
+/// `TypeFlags::empty()`.
+///
+/// Pointers passed as arguments to user-defined functions must be in the
+/// `Function` or `Private` address space.
+const fn ptr_space_argument_flag(space: crate::AddressSpace) -> TypeFlags {
+ use crate::AddressSpace as As;
+ match space {
+ As::Function | As::Private => TypeFlags::ARGUMENT,
+ As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant | As::WorkGroup => {
+ TypeFlags::empty()
+ }
+ }
+}
+
+#[derive(Clone, Debug)]
+pub(super) struct TypeInfo {
+ pub flags: TypeFlags,
+ pub uniform_layout: LayoutCompatibility,
+ pub storage_layout: LayoutCompatibility,
+}
+
+impl TypeInfo {
+ const fn dummy() -> Self {
+ TypeInfo {
+ flags: TypeFlags::empty(),
+ uniform_layout: Ok(Alignment::ONE),
+ storage_layout: Ok(Alignment::ONE),
+ }
+ }
+
+ const fn new(flags: TypeFlags, alignment: Alignment) -> Self {
+ TypeInfo {
+ flags,
+ uniform_layout: Ok(alignment),
+ storage_layout: Ok(alignment),
+ }
+ }
+}
+
+impl super::Validator {
+ const fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> {
+ if self.capabilities.contains(capability) {
+ Ok(())
+ } else {
+ Err(TypeError::MissingCapability(capability))
+ }
+ }
+
+ pub(super) const fn check_width(&self, scalar: crate::Scalar) -> Result<(), WidthError> {
+ let good = match scalar.kind {
+ crate::ScalarKind::Bool => scalar.width == crate::BOOL_WIDTH,
+ crate::ScalarKind::Float => {
+ if scalar.width == 8 {
+ if !self.capabilities.contains(Capabilities::FLOAT64) {
+ return Err(WidthError::MissingCapability {
+ name: "f64",
+ flag: "FLOAT64",
+ });
+ }
+ true
+ } else {
+ scalar.width == 4
+ }
+ }
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ if scalar.width == 8 {
+ return Err(WidthError::Unsupported64Bit);
+ }
+ scalar.width == 4
+ }
+ crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => {
+ return Err(WidthError::Abstract);
+ }
+ };
+ if good {
+ Ok(())
+ } else {
+ Err(WidthError::Invalid(scalar.kind, scalar.width))
+ }
+ }
+
+ pub(super) fn reset_types(&mut self, size: usize) {
+ self.types.clear();
+ self.types.resize(size, TypeInfo::dummy());
+ self.layouter.clear();
+ }
+
+ pub(super) fn validate_type(
+ &self,
+ handle: Handle<crate::Type>,
+ gctx: crate::proc::GlobalCtx,
+ ) -> Result<TypeInfo, TypeError> {
+ use crate::TypeInner as Ti;
+ Ok(match gctx.types[handle].inner {
+ Ti::Scalar(scalar) => {
+ self.check_width(scalar)?;
+ let shareable = if scalar.kind.is_numeric() {
+ TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE
+ } else {
+ TypeFlags::empty()
+ };
+ TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE
+ | shareable,
+ Alignment::from_width(scalar.width),
+ )
+ }
+ Ti::Vector { size, scalar } => {
+ self.check_width(scalar)?;
+ let shareable = if scalar.kind.is_numeric() {
+ TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE
+ } else {
+ TypeFlags::empty()
+ };
+ TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE
+ | shareable,
+ Alignment::from(size) * Alignment::from_width(scalar.width),
+ )
+ }
+ Ti::Matrix {
+ columns: _,
+ rows,
+ scalar,
+ } => {
+ if scalar.kind != crate::ScalarKind::Float {
+ return Err(TypeError::MatrixElementNotFloat);
+ }
+ self.check_width(scalar)?;
+ TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE,
+ Alignment::from(rows) * Alignment::from_width(scalar.width),
+ )
+ }
+ Ti::Atomic(crate::Scalar { kind, width }) => {
+ let good = match kind {
+ crate::ScalarKind::Bool
+ | crate::ScalarKind::Float
+ | crate::ScalarKind::AbstractInt
+ | crate::ScalarKind::AbstractFloat => false,
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4,
+ };
+ if !good {
+ return Err(TypeError::InvalidAtomicWidth(kind, width));
+ }
+ TypeInfo::new(
+ TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE,
+ Alignment::from_width(width),
+ )
+ }
+ Ti::Pointer { base, space } => {
+ use crate::AddressSpace as As;
+
+ let base_info = &self.types[base.index()];
+ if !base_info.flags.contains(TypeFlags::DATA) {
+ return Err(TypeError::InvalidPointerBase(base));
+ }
+
+ // Runtime-sized values can only live in the `Storage` address
+ // space, so it's useless to have a pointer to such a type in
+ // any other space.
+ //
+ // Detecting this problem here prevents the definition of
+ // functions like:
+ //
+ // fn f(p: ptr<workgroup, UnsizedType>) -> ... { ... }
+ //
+ // which would otherwise be permitted, but uncallable. (They
+ // may also present difficulties in code generation).
+ if !base_info.flags.contains(TypeFlags::SIZED) {
+ match space {
+ As::Storage { .. } => {}
+ _ => {
+ return Err(TypeError::InvalidPointerToUnsized { base, space });
+ }
+ }
+ }
+
+ // `Validator::validate_function` actually checks the address
+ // space of pointer arguments explicitly before checking the
+ // `ARGUMENT` flag, to give better error messages. But it seems
+ // best to set `ARGUMENT` accurately anyway.
+ let argument_flag = ptr_space_argument_flag(space);
+
+ // Pointers cannot be stored in variables, structure members, or
+ // array elements, so we do not mark them as `DATA`.
+ TypeInfo::new(
+ argument_flag | TypeFlags::SIZED | TypeFlags::COPY,
+ Alignment::ONE,
+ )
+ }
+ Ti::ValuePointer {
+ size: _,
+ scalar,
+ space,
+ } => {
+ // ValuePointer should be treated the same way as the equivalent
+ // Pointer / Scalar / Vector combination, so each step in those
+ // variants' match arms should have a counterpart here.
+ //
+ // However, some cases are trivial: All our implicit base types
+ // are DATA and SIZED, so we can never return
+ // `InvalidPointerBase` or `InvalidPointerToUnsized`.
+ self.check_width(scalar)?;
+
+ // `Validator::validate_function` actually checks the address
+ // space of pointer arguments explicitly before checking the
+ // `ARGUMENT` flag, to give better error messages. But it seems
+ // best to set `ARGUMENT` accurately anyway.
+ let argument_flag = ptr_space_argument_flag(space);
+
+ // Pointers cannot be stored in variables, structure members, or
+ // array elements, so we do not mark them as `DATA`.
+ TypeInfo::new(
+ argument_flag | TypeFlags::SIZED | TypeFlags::COPY,
+ Alignment::ONE,
+ )
+ }
+ Ti::Array { base, size, stride } => {
+ let base_info = &self.types[base.index()];
+ if !base_info.flags.contains(TypeFlags::DATA | TypeFlags::SIZED) {
+ return Err(TypeError::InvalidArrayBaseType(base));
+ }
+
+ let base_layout = self.layouter[base];
+ let general_alignment = base_layout.alignment;
+ let uniform_layout = match base_info.uniform_layout {
+ Ok(base_alignment) => {
+ let alignment = base_alignment
+ .max(general_alignment)
+ .max(Alignment::MIN_UNIFORM);
+ if alignment.is_aligned(stride) {
+ Ok(alignment)
+ } else {
+ Err((handle, Disalignment::ArrayStride { stride, alignment }))
+ }
+ }
+ Err(e) => Err(e),
+ };
+ let storage_layout = match base_info.storage_layout {
+ Ok(base_alignment) => {
+ let alignment = base_alignment.max(general_alignment);
+ if alignment.is_aligned(stride) {
+ Ok(alignment)
+ } else {
+ Err((handle, Disalignment::ArrayStride { stride, alignment }))
+ }
+ }
+ Err(e) => Err(e),
+ };
+
+ let type_info_mask = match size {
+ crate::ArraySize::Constant(_) => {
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE
+ }
+ crate::ArraySize::Dynamic => {
+ // Non-SIZED types may only appear as the last element of a structure.
+ // This is enforced by checks for SIZED-ness for all compound types,
+ // and a special case for structs.
+ TypeFlags::DATA | TypeFlags::COPY | TypeFlags::HOST_SHAREABLE
+ }
+ };
+
+ TypeInfo {
+ flags: base_info.flags & type_info_mask,
+ uniform_layout,
+ storage_layout,
+ }
+ }
+ Ti::Struct { ref members, span } => {
+ if members.is_empty() {
+ return Err(TypeError::EmptyStruct);
+ }
+
+ let mut ti = TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::IO_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE,
+ Alignment::ONE,
+ );
+ ti.uniform_layout = Ok(Alignment::MIN_UNIFORM);
+
+ let mut min_offset = 0;
+
+ let mut prev_struct_data: Option<(u32, u32)> = None;
+
+ for (i, member) in members.iter().enumerate() {
+ let base_info = &self.types[member.ty.index()];
+ if !base_info.flags.contains(TypeFlags::DATA) {
+ return Err(TypeError::InvalidData(member.ty));
+ }
+ if !base_info.flags.contains(TypeFlags::HOST_SHAREABLE) {
+ if ti.uniform_layout.is_ok() {
+ ti.uniform_layout = Err((member.ty, Disalignment::NonHostShareable));
+ }
+ if ti.storage_layout.is_ok() {
+ ti.storage_layout = Err((member.ty, Disalignment::NonHostShareable));
+ }
+ }
+ ti.flags &= base_info.flags;
+
+ if member.offset < min_offset {
+ // HACK: this could be nicer. We want to allow some structures
+ // to not bother with offsets/alignments if they are never
+ // used for host sharing.
+ if member.offset == 0 {
+ ti.flags.set(TypeFlags::HOST_SHAREABLE, false);
+ } else {
+ return Err(TypeError::MemberOverlap {
+ index: i as u32,
+ offset: member.offset,
+ });
+ }
+ }
+
+ let base_size = gctx.types[member.ty].inner.size(gctx);
+ min_offset = member.offset + base_size;
+ if min_offset > span {
+ return Err(TypeError::MemberOutOfBounds {
+ index: i as u32,
+ offset: member.offset,
+ size: base_size,
+ span,
+ });
+ }
+
+ check_member_layout(
+ &mut ti.uniform_layout,
+ member,
+ i as u32,
+ base_info.uniform_layout,
+ handle,
+ );
+ check_member_layout(
+ &mut ti.storage_layout,
+ member,
+ i as u32,
+ base_info.storage_layout,
+ handle,
+ );
+
+ // Validate rule: If a structure member itself has a structure type S,
+ // then the number of bytes between the start of that member and
+ // the start of any following member must be at least roundUp(16, SizeOf(S)).
+ if let Some((span, offset)) = prev_struct_data {
+ let diff = member.offset - offset;
+ let min = Alignment::MIN_UNIFORM.round_up(span);
+ if diff < min {
+ ti.uniform_layout = Err((
+ handle,
+ Disalignment::MemberOffsetAfterStruct {
+ index: i as u32,
+ offset: member.offset,
+ expected: offset + min,
+ },
+ ));
+ }
+ };
+
+ prev_struct_data = match gctx.types[member.ty].inner {
+ crate::TypeInner::Struct { span, .. } => Some((span, member.offset)),
+ _ => None,
+ };
+
+ // The last field may be an unsized array.
+ if !base_info.flags.contains(TypeFlags::SIZED) {
+ let is_array = match gctx.types[member.ty].inner {
+ crate::TypeInner::Array { .. } => true,
+ _ => false,
+ };
+ if !is_array || i + 1 != members.len() {
+ let name = member.name.clone().unwrap_or_default();
+ return Err(TypeError::InvalidDynamicArray(name, member.ty));
+ }
+ if ti.uniform_layout.is_ok() {
+ ti.uniform_layout =
+ Err((handle, Disalignment::UnsizedMember { index: i as u32 }));
+ }
+ }
+ }
+
+ let alignment = self.layouter[handle].alignment;
+ if !alignment.is_aligned(span) {
+ ti.uniform_layout = Err((handle, Disalignment::StructSpan { span, alignment }));
+ ti.storage_layout = Err((handle, Disalignment::StructSpan { span, alignment }));
+ }
+
+ ti
+ }
+ Ti::Image {
+ dim,
+ arrayed,
+ class: _,
+ } => {
+ if arrayed && matches!(dim, crate::ImageDimension::Cube) {
+ self.require_type_capability(Capabilities::CUBE_ARRAY_TEXTURES)?;
+ }
+ TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE)
+ }
+ Ti::Sampler { .. } => TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE),
+ Ti::AccelerationStructure => {
+ self.require_type_capability(Capabilities::RAY_QUERY)?;
+ TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE)
+ }
+ Ti::RayQuery => {
+ self.require_type_capability(Capabilities::RAY_QUERY)?;
+ TypeInfo::new(
+ TypeFlags::DATA | TypeFlags::CONSTRUCTIBLE | TypeFlags::SIZED,
+ Alignment::ONE,
+ )
+ }
+ Ti::BindingArray { base, size } => {
+ if base >= handle {
+ return Err(TypeError::InvalidArrayBaseType(base));
+ }
+ let type_info_mask = match size {
+ crate::ArraySize::Constant(_) => TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE,
+ crate::ArraySize::Dynamic => {
+ // Final type is non-sized
+ TypeFlags::HOST_SHAREABLE
+ }
+ };
+ let base_info = &self.types[base.index()];
+
+ if base_info.flags.contains(TypeFlags::DATA) {
+ // Currently Naga only supports binding arrays of structs for non-handle types.
+ match gctx.types[base].inner {
+ crate::TypeInner::Struct { .. } => {}
+ _ => return Err(TypeError::BindingArrayBaseTypeNotStruct(base)),
+ };
+ }
+
+ TypeInfo::new(base_info.flags & type_info_mask, Alignment::ONE)
+ }
+ })
+ }
+}