summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-28 14:29:10 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-28 14:29:10 +0000
commit2aa4a82499d4becd2284cdb482213d541b8804dd (patch)
treeb80bf8bf13c3766139fbacc530efd0dd9d54394c /third_party/rust/naga/src
parentInitial commit. (diff)
downloadfirefox-upstream.tar.xz
firefox-upstream.zip
Adding upstream version 86.0.1.upstream/86.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src')
-rw-r--r--third_party/rust/naga/src/arena.rs225
-rw-r--r--third_party/rust/naga/src/back/glsl.rs1579
-rw-r--r--third_party/rust/naga/src/back/glsl/keywords.rs204
-rw-r--r--third_party/rust/naga/src/back/mod.rs8
-rw-r--r--third_party/rust/naga/src/back/msl/keywords.rs102
-rw-r--r--third_party/rust/naga/src/back/msl/mod.rs211
-rw-r--r--third_party/rust/naga/src/back/msl/writer.rs990
-rw-r--r--third_party/rust/naga/src/back/spv/helpers.rs20
-rw-r--r--third_party/rust/naga/src/back/spv/instructions.rs708
-rw-r--r--third_party/rust/naga/src/back/spv/layout.rs91
-rw-r--r--third_party/rust/naga/src/back/spv/layout_tests.rs166
-rw-r--r--third_party/rust/naga/src/back/spv/mod.rs52
-rw-r--r--third_party/rust/naga/src/back/spv/test_framework.rs27
-rw-r--r--third_party/rust/naga/src/back/spv/writer.rs1776
-rw-r--r--third_party/rust/naga/src/front/glsl/ast.rs178
-rw-r--r--third_party/rust/naga/src/front/glsl/error.rs87
-rw-r--r--third_party/rust/naga/src/front/glsl/lex.rs380
-rw-r--r--third_party/rust/naga/src/front/glsl/lex_tests.rs346
-rw-r--r--third_party/rust/naga/src/front/glsl/mod.rs43
-rw-r--r--third_party/rust/naga/src/front/glsl/parser.rs1131
-rw-r--r--third_party/rust/naga/src/front/glsl/parser_tests.rs182
-rw-r--r--third_party/rust/naga/src/front/glsl/preprocess.rs152
-rw-r--r--third_party/rust/naga/src/front/glsl/preprocess_tests.rs218
-rw-r--r--third_party/rust/naga/src/front/glsl/token.rs8
-rw-r--r--third_party/rust/naga/src/front/glsl/types.rs120
-rw-r--r--third_party/rust/naga/src/front/glsl/variables.rs185
-rw-r--r--third_party/rust/naga/src/front/mod.rs32
-rw-r--r--third_party/rust/naga/src/front/spv/convert.rs123
-rw-r--r--third_party/rust/naga/src/front/spv/error.rs56
-rw-r--r--third_party/rust/naga/src/front/spv/flow.rs569
-rw-r--r--third_party/rust/naga/src/front/spv/function.rs202
-rw-r--r--third_party/rust/naga/src/front/spv/mod.rs2416
-rw-r--r--third_party/rust/naga/src/front/spv/rosetta.rs23
-rw-r--r--third_party/rust/naga/src/front/wgsl/conv.rs117
-rw-r--r--third_party/rust/naga/src/front/wgsl/lexer.rs292
-rw-r--r--third_party/rust/naga/src/front/wgsl/mod.rs1850
-rw-r--r--third_party/rust/naga/src/lib.rs788
-rw-r--r--third_party/rust/naga/src/proc/call_graph.rs74
-rw-r--r--third_party/rust/naga/src/proc/interface.rs290
-rw-r--r--third_party/rust/naga/src/proc/mod.rs67
-rw-r--r--third_party/rust/naga/src/proc/namer.rs113
-rw-r--r--third_party/rust/naga/src/proc/typifier.rs424
-rw-r--r--third_party/rust/naga/src/proc/validator.rs489
43 files changed, 17114 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/arena.rs b/third_party/rust/naga/src/arena.rs
new file mode 100644
index 0000000000..460d909c2d
--- /dev/null
+++ b/third_party/rust/naga/src/arena.rs
@@ -0,0 +1,225 @@
+use std::{cmp::Ordering, fmt, hash, marker::PhantomData, num::NonZeroU32};
+
+/// An unique index in the arena array that a handle points to.
+/// The "non-zero" part ensures that an `Option<Handle<T>>` has
+/// the same size and representation as `Handle<T>`.
+type Index = NonZeroU32;
+
+/// A strongly typed reference to a SPIR-V element.
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(
+ any(feature = "serialize", feature = "deserialize"),
+ serde(transparent)
+)]
+pub struct Handle<T> {
+ index: Index,
+ #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))]
+ marker: PhantomData<T>,
+}
+
+impl<T> Clone for Handle<T> {
+ fn clone(&self) -> Self {
+ Handle {
+ index: self.index,
+ marker: self.marker,
+ }
+ }
+}
+impl<T> Copy for Handle<T> {}
+impl<T> PartialEq for Handle<T> {
+ fn eq(&self, other: &Self) -> bool {
+ self.index == other.index
+ }
+}
+impl<T> Eq for Handle<T> {}
+impl<T> PartialOrd for Handle<T> {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ self.index.partial_cmp(&other.index)
+ }
+}
+impl<T> Ord for Handle<T> {
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.index.cmp(&other.index)
+ }
+}
+impl<T> fmt::Debug for Handle<T> {
+ fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ write!(formatter, "Handle({})", self.index)
+ }
+}
+impl<T> hash::Hash for Handle<T> {
+ fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
+ self.index.hash(hasher)
+ }
+}
+
+impl<T> Handle<T> {
+ #[cfg(test)]
+ pub const DUMMY: Self = Handle {
+ index: unsafe { NonZeroU32::new_unchecked(!0) },
+ marker: PhantomData,
+ };
+
+ pub(crate) fn new(index: Index) -> Self {
+ Handle {
+ index,
+ marker: PhantomData,
+ }
+ }
+
+ /// Returns the zero-based index of this handle.
+ pub fn index(self) -> usize {
+ let index = self.index.get() - 1;
+ index as usize
+ }
+}
+
+/// An arena holding some kind of component (e.g., type, constant,
+/// instruction, etc.) that can be referenced.
+///
+/// Adding new items to the arena produces a strongly-typed [`Handle`].
+/// The arena can be indexed using the given handle to obtain
+/// a reference to the stored item.
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(
+ any(feature = "serialize", feature = "deserialize"),
+ serde(transparent)
+)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Arena<T> {
+ /// Values of this arena.
+ data: Vec<T>,
+}
+
+impl<T> Default for Arena<T> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<T> Arena<T> {
+ /// Create a new arena with no initial capacity allocated.
+ pub fn new() -> Self {
+ Arena { data: Vec::new() }
+ }
+
+ /// Returns the current number of items stored in this arena.
+ pub fn len(&self) -> usize {
+ self.data.len()
+ }
+
+ /// Returns `true` if the arena contains no elements.
+ pub fn is_empty(&self) -> bool {
+ self.data.is_empty()
+ }
+
+ /// Returns an iterator over the items stored in this arena, returning both
+ /// the item's handle and a reference to it.
+ pub fn iter(&self) -> impl Iterator<Item = (Handle<T>, &T)> {
+ self.data.iter().enumerate().map(|(i, v)| {
+ let position = i + 1;
+ let index = unsafe { Index::new_unchecked(position as u32) };
+ (Handle::new(index), v)
+ })
+ }
+
+ /// Adds a new value to the arena, returning a typed handle.
+ ///
+ /// The value is not linked to any SPIR-V module.
+ pub fn append(&mut self, value: T) -> Handle<T> {
+ let position = self.data.len() + 1;
+ let index = unsafe { Index::new_unchecked(position as u32) };
+ self.data.push(value);
+ Handle::new(index)
+ }
+
+ /// Fetch a handle to an existing type.
+ pub fn fetch_if<F: Fn(&T) -> bool>(&self, fun: F) -> Option<Handle<T>> {
+ self.data
+ .iter()
+ .position(fun)
+ .map(|index| Handle::new(unsafe { Index::new_unchecked((index + 1) as u32) }))
+ }
+
+ /// Adds a value with a custom check for uniqueness:
+ /// returns a handle pointing to
+ /// an existing element if the check succeeds, or adds a new
+ /// element otherwise.
+ pub fn fetch_if_or_append<F: Fn(&T, &T) -> bool>(&mut self, value: T, fun: F) -> Handle<T> {
+ if let Some(index) = self.data.iter().position(|d| fun(d, &value)) {
+ let index = unsafe { Index::new_unchecked((index + 1) as u32) };
+ Handle::new(index)
+ } else {
+ self.append(value)
+ }
+ }
+
+ /// Adds a value with a check for uniqueness, where the check is plain comparison.
+ pub fn fetch_or_append(&mut self, value: T) -> Handle<T>
+ where
+ T: PartialEq,
+ {
+ self.fetch_if_or_append(value, T::eq)
+ }
+
+ pub fn try_get(&self, handle: Handle<T>) -> Option<&T> {
+ self.data.get(handle.index.get() as usize - 1)
+ }
+
+ /// Get a mutable reference to an element in the arena.
+ pub fn get_mut(&mut self, handle: Handle<T>) -> &mut T {
+ self.data.get_mut(handle.index.get() as usize - 1).unwrap()
+ }
+}
+
+impl<T> std::ops::Index<Handle<T>> for Arena<T> {
+ type Output = T;
+ fn index(&self, handle: Handle<T>) -> &T {
+ let index = handle.index.get() - 1;
+ &self.data[index as usize]
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn append_non_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.append(0);
+ let t2 = arena.append(0);
+ assert!(t1 != t2);
+ assert!(arena[t1] == arena[t2]);
+ }
+
+ #[test]
+ fn append_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.append(0);
+ let t2 = arena.append(1);
+ assert!(t1 != t2);
+ assert!(arena[t1] != arena[t2]);
+ }
+
+ #[test]
+ fn fetch_or_append_non_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.fetch_or_append(0);
+ let t2 = arena.fetch_or_append(0);
+ assert!(t1 == t2);
+ assert!(arena[t1] == arena[t2])
+ }
+
+ #[test]
+ fn fetch_or_append_unique() {
+ let mut arena: Arena<u8> = Arena::new();
+ let t1 = arena.fetch_or_append(0);
+ let t2 = arena.fetch_or_append(1);
+ assert!(t1 != t2);
+ assert!(arena[t1] != arena[t2]);
+ }
+}
diff --git a/third_party/rust/naga/src/back/glsl.rs b/third_party/rust/naga/src/back/glsl.rs
new file mode 100644
index 0000000000..6a497a675b
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl.rs
@@ -0,0 +1,1579 @@
+//! OpenGL shading language backend
+//!
+//! The main structure is [`Writer`](struct.Writer.html), it maintains internal state that is used
+//! to output a `Module` into glsl
+//!
+//! # Supported versions
+//! ### Core
+//! - 330
+//! - 400
+//! - 410
+//! - 420
+//! - 430
+//! - 450
+//! - 460
+//!
+//! ### ES
+//! - 300
+//! - 310
+//!
+
+use crate::{
+ proc::{
+ CallGraph, CallGraphBuilder, Interface, NameKey, Namer, ResolveContext, ResolveError,
+ Typifier, Visitor,
+ },
+ Arena, ArraySize, BinaryOperator, BuiltIn, ConservativeDepth, Constant, ConstantInner,
+ DerivativeAxis, Expression, FastHashMap, Function, FunctionOrigin, GlobalVariable, Handle,
+ ImageClass, Interpolation, IntrinsicFunction, LocalVariable, Module, ScalarKind, ShaderStage,
+ Statement, StorageAccess, StorageClass, StorageFormat, StructMember, Type, TypeInner,
+ UnaryOperator,
+};
+use std::{
+ cmp::Ordering,
+ fmt::{self, Error as FmtError},
+ io::{Error as IoError, Write},
+};
+
+/// Contains a constant with a slice of all the reserved keywords RESERVED_KEYWORDS
+mod keywords;
+
+const SUPPORTED_CORE_VERSIONS: &[u16] = &[330, 400, 410, 420, 430, 440, 450];
+const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310];
+
+#[derive(Debug)]
+pub enum Error {
+ FormatError(FmtError),
+ IoError(IoError),
+ Type(ResolveError),
+ Custom(String),
+}
+
+impl From<FmtError> for Error {
+ fn from(err: FmtError) -> Self {
+ Error::FormatError(err)
+ }
+}
+
+impl From<IoError> for Error {
+ fn from(err: IoError) -> Self {
+ Error::IoError(err)
+ }
+}
+
+impl From<ResolveError> for Error {
+ fn from(err: ResolveError) -> Self {
+ Error::Type(err)
+ }
+}
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Error::FormatError(err) => write!(f, "Formatting error {}", err),
+ Error::IoError(err) => write!(f, "Io error: {}", err),
+ Error::Type(err) => write!(f, "Type error: {:?}", err),
+ Error::Custom(err) => write!(f, "{}", err),
+ }
+ }
+}
+
+#[derive(Debug, Copy, Clone, PartialEq)]
+pub enum Version {
+ Desktop(u16),
+ Embedded(u16),
+}
+
+impl Version {
+ fn is_es(&self) -> bool {
+ match self {
+ Version::Desktop(_) => false,
+ Version::Embedded(_) => true,
+ }
+ }
+
+ fn is_supported(&self) -> bool {
+ match self {
+ Version::Desktop(v) => SUPPORTED_CORE_VERSIONS.contains(v),
+ Version::Embedded(v) => SUPPORTED_ES_VERSIONS.contains(v),
+ }
+ }
+}
+
+impl PartialOrd for Version {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ match (*self, *other) {
+ (Version::Desktop(x), Version::Desktop(y)) => Some(x.cmp(&y)),
+ (Version::Embedded(x), Version::Embedded(y)) => Some(x.cmp(&y)),
+ _ => None,
+ }
+ }
+}
+
+impl fmt::Display for Version {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Version::Desktop(v) => write!(f, "{} core", v),
+ Version::Embedded(v) => write!(f, "{} es", v),
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct Options {
+ pub version: Version,
+ pub entry_point: (ShaderStage, String),
+}
+
+#[derive(Debug, Clone)]
+pub struct TextureMapping {
+ pub texture: Handle<GlobalVariable>,
+ pub sampler: Option<Handle<GlobalVariable>>,
+}
+
+bitflags::bitflags! {
+ struct Features: u32 {
+ const BUFFER_STORAGE = 1;
+ const ARRAY_OF_ARRAYS = 1 << 1;
+ const DOUBLE_TYPE = 1 << 2;
+ const FULL_IMAGE_FORMATS = 1 << 3;
+ const MULTISAMPLED_TEXTURES = 1 << 4;
+ const MULTISAMPLED_TEXTURE_ARRAYS = 1 << 5;
+ const CUBE_TEXTURES_ARRAY = 1 << 6;
+ const COMPUTE_SHADER = 1 << 7;
+ const IMAGE_LOAD_STORE = 1 << 8;
+ const CONSERVATIVE_DEPTH = 1 << 9;
+ const TEXTURE_1D = 1 << 10;
+ const PUSH_CONSTANT = 1 << 11;
+ }
+}
+
+struct FeaturesManager(Features);
+
+impl FeaturesManager {
+ pub fn new() -> Self {
+ Self(Features::empty())
+ }
+
+ pub fn request(&mut self, features: Features) {
+ self.0 |= features
+ }
+
+ #[allow(clippy::collapsible_if)]
+ pub fn write(&self, version: Version, mut out: impl Write) -> Result<(), Error> {
+ if self.0.contains(Features::COMPUTE_SHADER) {
+ if version < Version::Embedded(310) || version < Version::Desktop(420) {
+ return Err(Error::Custom(format!(
+ "Version {} doesn't support compute shaders",
+ version
+ )));
+ }
+
+ if !version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_compute_shader.txt
+ writeln!(out, "#extension GL_ARB_compute_shader : require")?;
+ }
+ }
+
+ if self.0.contains(Features::BUFFER_STORAGE) {
+ if version < Version::Embedded(310) || version < Version::Desktop(400) {
+ return Err(Error::Custom(format!(
+ "Version {} doesn't support buffer storage class",
+ version
+ )));
+ }
+
+ if let Version::Desktop(_) = version {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_storage_buffer_object.txt
+ writeln!(
+ out,
+ "#extension GL_ARB_shader_storage_buffer_object : require"
+ )?;
+ }
+ }
+
+ if self.0.contains(Features::DOUBLE_TYPE) {
+ if version.is_es() || version < Version::Desktop(150) {
+ return Err(Error::Custom(format!(
+ "Version {} doesn't support doubles",
+ version
+ )));
+ }
+
+ if version < Version::Desktop(400) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_gpu_shader_fp64.txt
+ writeln!(out, "#extension GL_ARB_gpu_shader_fp64 : require")?;
+ }
+ }
+
+ if self.0.contains(Features::CUBE_TEXTURES_ARRAY) {
+ if version < Version::Embedded(310) || version < Version::Desktop(130) {
+ return Err(Error::Custom(format!(
+ "Version {} doesn't support cube map array textures",
+ version
+ )));
+ }
+
+ if version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_cube_map_array.txt
+ writeln!(out, "#extension GL_EXT_texture_cube_map_array : require")?;
+ } else if version < Version::Desktop(400) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_cube_map_array.txt
+ writeln!(out, "#extension GL_ARB_texture_cube_map_array : require")?;
+ }
+ }
+
+ if self.0.contains(Features::MULTISAMPLED_TEXTURES) {
+ if version < Version::Embedded(300) {
+ return Err(Error::Custom(format!(
+ "Version {} doesn't support multi sampled textures",
+ version
+ )));
+ }
+ }
+
+ if self.0.contains(Features::MULTISAMPLED_TEXTURE_ARRAYS) {
+ if version < Version::Embedded(310) {
+ return Err(Error::Custom(format!(
+ "Version {} doesn't support multi sampled texture arrays",
+ version
+ )));
+ }
+
+ if version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_texture_storage_multisample_2d_array.txt
+ writeln!(
+ out,
+ "#extension GL_OES_texture_storage_multisample_2d_array : require"
+ )?;
+ }
+ }
+
+ if self.0.contains(Features::ARRAY_OF_ARRAYS) {
+ if version < Version::Embedded(310) || version < Version::Desktop(120) {
+ return Err(Error::Custom(format!(
+ "Version {} doesn't arrays of arrays",
+ version
+ )));
+ }
+
+ if version < Version::Desktop(430) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_arrays_of_arrays.txt
+ writeln!(out, "#extension ARB_arrays_of_arrays : require")?;
+ }
+ }
+
+ if self.0.contains(Features::IMAGE_LOAD_STORE) {
+ if version < Version::Embedded(310) || version < Version::Desktop(130) {
+ return Err(Error::Custom(format!(
+ "Version {} doesn't support images load/stores",
+ version
+ )));
+ }
+
+ if self.0.contains(Features::FULL_IMAGE_FORMATS) && version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/NV/NV_image_formats.txt
+ writeln!(out, "#extension GL_NV_image_formats : require")?;
+ }
+
+ if version < Version::Desktop(420) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_image_load_store.txt
+ writeln!(out, "#extension GL_ARB_shader_image_load_store : require")?;
+ }
+ }
+
+ if self.0.contains(Features::CONSERVATIVE_DEPTH) {
+ if version < Version::Embedded(300) || version < Version::Desktop(130) {
+ return Err(Error::Custom(format!(
+ "Version {} doesn't support conservative depth",
+ version
+ )));
+ }
+
+ if version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_conservative_depth.txt
+ writeln!(out, "#extension GL_EXT_conservative_depth : require")?;
+ }
+
+ if version < Version::Desktop(420) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt
+ writeln!(out, "#extension GL_ARB_conservative_depth : require")?;
+ }
+ }
+
+ if self.0.contains(Features::TEXTURE_1D) {
+ if version.is_es() {
+ return Err(Error::Custom(format!(
+ "Version {} doesn't support 1d textures",
+ version
+ )));
+ }
+ }
+
+ Ok(())
+ }
+}
+
+enum FunctionType {
+ Function(Handle<Function>),
+ EntryPoint(crate::proc::EntryPointIndex),
+}
+
+struct FunctionCtx<'a, 'b> {
+ func: FunctionType,
+ expressions: &'a Arena<Expression>,
+ typifier: &'b Typifier,
+}
+
+impl<'a, 'b> FunctionCtx<'a, 'b> {
+ fn name_key(&self, local: Handle<LocalVariable>) -> NameKey {
+ match self.func {
+ FunctionType::Function(handle) => NameKey::FunctionLocal(handle, local),
+ FunctionType::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local),
+ }
+ }
+
+ fn get_arg<'c>(&self, arg: u32, names: &'c FastHashMap<NameKey, String>) -> &'c str {
+ match self.func {
+ FunctionType::Function(handle) => &names[&NameKey::FunctionArgument(handle, arg)],
+ FunctionType::EntryPoint(_) => unreachable!(),
+ }
+ }
+}
+
+/// Helper structure that generates a number
+#[derive(Default)]
+struct IdGenerator(u32);
+
+impl IdGenerator {
+ fn generate(&mut self) -> u32 {
+ let ret = self.0;
+ self.0 += 1;
+ ret
+ }
+}
+
+/// Main structure of the glsl backend responsible for all code generation
+pub struct Writer<'a, W> {
+ // Inputs
+ module: &'a Module,
+ out: W,
+ options: &'a Options,
+
+ // Internal State
+ features: FeaturesManager,
+ names: FastHashMap<NameKey, String>,
+ entry_point: &'a crate::EntryPoint,
+ entry_point_idx: crate::proc::EntryPointIndex,
+ call_graph: CallGraph,
+
+ /// Used to generate a unique number for blocks
+ block_id: IdGenerator,
+}
+
+impl<'a, W: Write> Writer<'a, W> {
+ pub fn new(out: W, module: &'a Module, options: &'a Options) -> Result<Self, Error> {
+ if !options.version.is_supported() {
+ return Err(Error::Custom(format!(
+ "Version not supported {}",
+ options.version
+ )));
+ }
+
+ let (ep_idx, ep) = module
+ .entry_points
+ .iter()
+ .enumerate()
+ .find_map(|(i, (key, entry_point))| {
+ Some((i as u16, entry_point)).filter(|_| &options.entry_point == key)
+ })
+ .ok_or_else(|| Error::Custom(String::from("Entry point not found")))?;
+
+ let mut names = FastHashMap::default();
+ Namer::process(module, keywords::RESERVED_KEYWORDS, &mut names);
+
+ let call_graph = CallGraphBuilder {
+ functions: &module.functions,
+ }
+ .process(&ep.function);
+
+ let mut this = Self {
+ module,
+ out,
+ options,
+
+ features: FeaturesManager::new(),
+ names,
+ entry_point: ep,
+ entry_point_idx: ep_idx,
+ call_graph,
+
+ block_id: IdGenerator::default(),
+ };
+
+ this.collect_required_features()?;
+
+ Ok(this)
+ }
+
+ fn collect_required_features(&mut self) -> Result<(), Error> {
+ let stage = self.options.entry_point.0;
+
+ if let Some(depth_test) = self.entry_point.early_depth_test {
+ self.features.request(Features::IMAGE_LOAD_STORE);
+
+ if depth_test.conservative.is_some() {
+ self.features.request(Features::CONSERVATIVE_DEPTH);
+ }
+ }
+
+ if let ShaderStage::Compute = stage {
+ self.features.request(Features::COMPUTE_SHADER)
+ }
+
+ for (_, ty) in self.module.types.iter() {
+ match ty.inner {
+ TypeInner::Scalar { kind, width } => self.scalar_required_features(kind, width),
+ TypeInner::Vector { kind, width, .. } => self.scalar_required_features(kind, width),
+ TypeInner::Matrix { .. } => self.scalar_required_features(ScalarKind::Float, 8),
+ TypeInner::Array { base, .. } => {
+ if let TypeInner::Array { .. } = self.module.types[base].inner {
+ self.features.request(Features::ARRAY_OF_ARRAYS)
+ }
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ if arrayed && dim == crate::ImageDimension::Cube {
+ self.features.request(Features::CUBE_TEXTURES_ARRAY)
+ } else if dim == crate::ImageDimension::D1 {
+ self.features.request(Features::TEXTURE_1D)
+ }
+
+ match class {
+ ImageClass::Sampled { multi: true, .. } => {
+ self.features.request(Features::MULTISAMPLED_TEXTURES);
+ if arrayed {
+ self.features.request(Features::MULTISAMPLED_TEXTURE_ARRAYS);
+ }
+ }
+ ImageClass::Storage(format) => match format {
+ StorageFormat::R8Unorm
+ | StorageFormat::R8Snorm
+ | StorageFormat::R8Uint
+ | StorageFormat::R8Sint
+ | StorageFormat::R16Uint
+ | StorageFormat::R16Sint
+ | StorageFormat::R16Float
+ | StorageFormat::Rg8Unorm
+ | StorageFormat::Rg8Snorm
+ | StorageFormat::Rg8Uint
+ | StorageFormat::Rg8Sint
+ | StorageFormat::Rg16Uint
+ | StorageFormat::Rg16Sint
+ | StorageFormat::Rg16Float
+ | StorageFormat::Rgb10a2Unorm
+ | StorageFormat::Rg11b10Float
+ | StorageFormat::Rg32Uint
+ | StorageFormat::Rg32Sint
+ | StorageFormat::Rg32Float => {
+ self.features.request(Features::FULL_IMAGE_FORMATS)
+ }
+ _ => {}
+ },
+ _ => {}
+ }
+ }
+ _ => {}
+ }
+ }
+
+ for (_, global) in self.module.global_variables.iter() {
+ match global.class {
+ StorageClass::WorkGroup => self.features.request(Features::COMPUTE_SHADER),
+ StorageClass::Storage => self.features.request(Features::BUFFER_STORAGE),
+ StorageClass::PushConstant => self.features.request(Features::PUSH_CONSTANT),
+ _ => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ fn scalar_required_features(&mut self, kind: ScalarKind, width: crate::Bytes) {
+ if kind == ScalarKind::Float && width == 8 {
+ self.features.request(Features::DOUBLE_TYPE);
+ }
+ }
+
+ pub fn write(&mut self) -> Result<FastHashMap<String, TextureMapping>, Error> {
+ let es = self.options.version.is_es();
+
+ writeln!(self.out, "#version {}", self.options.version)?;
+ self.features.write(self.options.version, &mut self.out)?;
+ writeln!(self.out)?;
+
+ if es {
+ writeln!(self.out, "precision highp float;\n")?;
+ }
+
+ if let Some(depth_test) = self.entry_point.early_depth_test {
+ writeln!(self.out, "layout(early_fragment_tests) in;\n")?;
+
+ if let Some(conservative) = depth_test.conservative {
+ writeln!(
+ self.out,
+ "layout (depth_{}) out float gl_FragDepth;\n",
+ match conservative {
+ ConservativeDepth::GreaterEqual => "greater",
+ ConservativeDepth::LessEqual => "less",
+ ConservativeDepth::Unchanged => "unchanged",
+ }
+ )?;
+ }
+ }
+
+ for (handle, ty) in self.module.types.iter() {
+ if let TypeInner::Struct { ref members } = ty.inner {
+ self.write_struct(handle, members)?
+ }
+ }
+
+ writeln!(self.out)?;
+
+ let texture_mappings = self.collect_texture_mapping(
+ self.call_graph
+ .raw_nodes()
+ .iter()
+ .map(|node| &self.module.functions[node.weight])
+ .chain(std::iter::once(&self.entry_point.function)),
+ )?;
+
+ for (handle, global) in self
+ .module
+ .global_variables
+ .iter()
+ .zip(&self.entry_point.function.global_usage)
+ .filter_map(|(global, usage)| Some(global).filter(|_| !usage.is_empty()))
+ {
+ if let Some(crate::Binding::BuiltIn(_)) = global.binding {
+ continue;
+ }
+
+ match self.module.types[global.ty].inner {
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ if let TypeInner::Image {
+ class: ImageClass::Storage(format),
+ ..
+ } = self.module.types[global.ty].inner
+ {
+ write!(self.out, "layout({}) ", glsl_storage_format(format))?;
+ }
+
+ if global.storage_access == StorageAccess::LOAD {
+ write!(self.out, "readonly ")?;
+ } else if global.storage_access == StorageAccess::STORE {
+ write!(self.out, "writeonly ")?;
+ }
+
+ write!(self.out, "uniform ")?;
+
+ self.write_image_type(dim, arrayed, class)?;
+
+ writeln!(
+ self.out,
+ " {};",
+ self.names[&NameKey::GlobalVariable(handle)]
+ )?
+ }
+ TypeInner::Sampler { .. } => continue,
+ _ => self.write_global(handle, global)?,
+ }
+ }
+
+ writeln!(self.out)?;
+
+ // Sort the graph topologically so that functions calls are valid
+ // It's impossible for this to panic because the IR forbids cycles
+ let functions = petgraph::algo::toposort(&self.call_graph, None).unwrap();
+
+ for node in functions {
+ let handle = self.call_graph[node];
+ let name = self.names[&NameKey::Function(handle)].clone();
+ self.write_function(
+ FunctionType::Function(handle),
+ &self.module.functions[handle],
+ name,
+ )?;
+ }
+
+ self.write_function(
+ FunctionType::EntryPoint(self.entry_point_idx),
+ &self.entry_point.function,
+ "main",
+ )?;
+
+ Ok(texture_mappings)
+ }
+
+ fn write_global(
+ &mut self,
+ handle: Handle<GlobalVariable>,
+ global: &GlobalVariable,
+ ) -> Result<(), Error> {
+ if global.storage_access == StorageAccess::LOAD {
+ write!(self.out, "readonly ")?;
+ } else if global.storage_access == StorageAccess::STORE {
+ write!(self.out, "writeonly ")?;
+ }
+
+ if let Some(interpolation) = global.interpolation {
+ match (self.options.entry_point.0, global.class) {
+ (ShaderStage::Fragment, StorageClass::Input)
+ | (ShaderStage::Vertex, StorageClass::Output) => {
+ write!(self.out, "{} ", glsl_interpolation(interpolation)?)?;
+ }
+ _ => (),
+ };
+ }
+
+ let block = match global.class {
+ StorageClass::Storage | StorageClass::Uniform => {
+ let block_name = self.names[&NameKey::Type(global.ty)].clone();
+
+ Some(block_name)
+ }
+ _ => None,
+ };
+
+ write!(self.out, "{} ", glsl_storage_class(global.class))?;
+
+ self.write_type(global.ty, block)?;
+
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ writeln!(self.out, " {};", name)?;
+
+ Ok(())
+ }
+
+ fn write_function<N: AsRef<str>>(
+ &mut self,
+ ty: FunctionType,
+ func: &Function,
+ name: N,
+ ) -> Result<(), Error> {
+ let mut typifier = Typifier::new();
+
+ typifier.resolve_all(
+ &func.expressions,
+ &self.module.types,
+ &ResolveContext {
+ constants: &self.module.constants,
+ global_vars: &self.module.global_variables,
+ local_vars: &func.local_variables,
+ functions: &self.module.functions,
+ arguments: &func.arguments,
+ },
+ )?;
+
+ let ctx = FunctionCtx {
+ func: ty,
+ expressions: &func.expressions,
+ typifier: &typifier,
+ };
+
+ self.write_fn_header(name.as_ref(), func, &ctx)?;
+ writeln!(self.out, " {{",)?;
+
+ for (handle, local) in func.local_variables.iter() {
+ write!(self.out, "\t")?;
+ self.write_type(local.ty, None)?;
+
+ write!(self.out, " {}", self.names[&ctx.name_key(handle)])?;
+
+ if let Some(init) = local.init {
+ write!(self.out, " = ",)?;
+
+ self.write_constant(&self.module.constants[init])?;
+ }
+
+ writeln!(self.out, ";")?
+ }
+
+ writeln!(self.out)?;
+
+ for sta in func.body.iter() {
+ self.write_stmt(sta, &ctx, 1)?;
+ }
+
+ Ok(writeln!(self.out, "}}")?)
+ }
+
+ fn write_slice<T, F: FnMut(&mut Self, u32, &T) -> Result<(), Error>>(
+ &mut self,
+ data: &[T],
+ mut f: F,
+ ) -> Result<(), Error> {
+ for (i, item) in data.iter().enumerate() {
+ f(self, i as u32, item)?;
+
+ if i != data.len().saturating_sub(1) {
+ write!(self.out, ",")?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn write_fn_header(
+ &mut self,
+ name: &str,
+ func: &Function,
+ ctx: &FunctionCtx<'_, '_>,
+ ) -> Result<(), Error> {
+ if let Some(ty) = func.return_type {
+ self.write_type(ty, None)?;
+ } else {
+ write!(self.out, "void")?;
+ }
+
+ write!(self.out, " {}(", name)?;
+
+ self.write_slice(&func.arguments, |this, i, arg| {
+ this.write_type(arg.ty, None)?;
+
+ let name = ctx.get_arg(i, &this.names);
+
+ Ok(write!(this.out, " {}", name)?)
+ })?;
+
+ write!(self.out, ")")?;
+
+ Ok(())
+ }
+
+ fn write_type(&mut self, ty: Handle<Type>, block: Option<String>) -> Result<(), Error> {
+ match self.module.types[ty].inner {
+ TypeInner::Scalar { kind, width } => {
+ write!(self.out, "{}", glsl_scalar(kind, width)?.full)?
+ }
+ TypeInner::Vector { size, kind, width } => write!(
+ self.out,
+ "{}vec{}",
+ glsl_scalar(kind, width)?.prefix,
+ size as u8
+ )?,
+ TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => write!(
+ self.out,
+ "{}mat{}x{}",
+ glsl_scalar(ScalarKind::Float, width)?.prefix,
+ columns as u8,
+ rows as u8
+ )?,
+ TypeInner::Pointer { base, .. } => self.write_type(base, None)?,
+ TypeInner::Array { base, size, .. } => {
+ self.write_type(base, None)?;
+
+ write!(self.out, "[")?;
+ self.write_array_size(size)?;
+ write!(self.out, "]")?
+ }
+ TypeInner::Struct { ref members } => {
+ if let Some(name) = block {
+ writeln!(self.out, "{}_block_{} {{", name, self.block_id.generate())?;
+
+ for (idx, member) in members.iter().enumerate() {
+ self.write_type(member.ty, None)?;
+
+ writeln!(
+ self.out,
+ " {};",
+ &self.names[&NameKey::StructMember(ty, idx as u32)]
+ )?;
+ }
+
+ write!(self.out, "}}")?
+ } else {
+ write!(self.out, "{}", &self.names[&NameKey::Type(ty)])?
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ fn write_image_type(
+ &mut self,
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: ImageClass,
+ ) -> Result<(), Error> {
+ let (base, kind, ms, comparison) = match class {
+ ImageClass::Sampled { kind, multi: true } => ("sampler", kind, "MS", ""),
+ ImageClass::Sampled { kind, multi: false } => ("sampler", kind, "", ""),
+ ImageClass::Depth => ("sampler", crate::ScalarKind::Float, "", "Shadow"),
+ ImageClass::Storage(format) => ("image", format.into(), "", ""),
+ };
+
+ Ok(write!(
+ self.out,
+ "{}{}{}{}{}{}",
+ glsl_scalar(kind, 4)?.prefix,
+ base,
+ ImageDimension(dim),
+ ms,
+ if arrayed { "Array" } else { "" },
+ comparison
+ )?)
+ }
+
+ fn write_array_size(&mut self, size: ArraySize) -> Result<(), Error> {
+ match size {
+ ArraySize::Constant(const_handle) => match self.module.constants[const_handle].inner {
+ ConstantInner::Uint(size) => write!(self.out, "{}", size)?,
+ _ => unreachable!(),
+ },
+ ArraySize::Dynamic => (),
+ }
+
+ Ok(())
+ }
+
+ fn collect_texture_mapping(
+ &self,
+ functions: impl Iterator<Item = &'a Function>,
+ ) -> Result<FastHashMap<String, TextureMapping>, Error> {
+ let mut mappings = FastHashMap::default();
+
+ for func in functions {
+ let mut interface = Interface {
+ expressions: &func.expressions,
+ local_variables: &func.local_variables,
+ visitor: TextureMappingVisitor {
+ names: &self.names,
+ expressions: &func.expressions,
+ map: &mut mappings,
+ error: None,
+ },
+ };
+ interface.traverse(&func.body);
+
+ if let Some(error) = interface.visitor.error {
+ return Err(error);
+ }
+ }
+
+ Ok(mappings)
+ }
+
+ fn write_struct(
+ &mut self,
+ handle: Handle<Type>,
+ members: &[StructMember],
+ ) -> Result<(), Error> {
+ writeln!(self.out, "struct {} {{", self.names[&NameKey::Type(handle)])?;
+
+ for (idx, member) in members.iter().enumerate() {
+ write!(self.out, "\t")?;
+ self.write_type(member.ty, None)?;
+ writeln!(
+ self.out,
+ " {};",
+ self.names[&NameKey::StructMember(handle, idx as u32)]
+ )?;
+ }
+
+ writeln!(self.out, "}};")?;
+ Ok(())
+ }
+
+ fn write_stmt(
+ &mut self,
+ sta: &Statement,
+ ctx: &FunctionCtx<'_, '_>,
+ indent: usize,
+ ) -> Result<(), Error> {
+ write!(self.out, "{}", "\t".repeat(indent))?;
+
+ match sta {
+ Statement::Block(block) => {
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ self.write_stmt(sta, ctx, indent + 1)?
+ }
+ writeln!(self.out, "{}}}", "\t".repeat(indent))?
+ }
+ Statement::If {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "if(")?;
+ self.write_expr(*condition, ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ for sta in accept {
+ self.write_stmt(sta, ctx, indent + 1)?;
+ }
+
+ if !reject.is_empty() {
+ writeln!(self.out, "{}}} else {{", "\t".repeat(indent))?;
+
+ for sta in reject {
+ self.write_stmt(sta, ctx, indent + 1)?;
+ }
+ }
+
+ writeln!(self.out, "{}}}", "\t".repeat(indent))?
+ }
+ Statement::Switch {
+ selector,
+ cases,
+ default,
+ } => {
+ write!(self.out, "switch(")?;
+ self.write_expr(*selector, ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ for (label, (block, fallthrough)) in cases {
+ writeln!(self.out, "{}case {}:", "\t".repeat(indent + 1), label)?;
+
+ for sta in block {
+ self.write_stmt(sta, ctx, indent + 2)?;
+ }
+
+ if fallthrough.is_none() {
+ writeln!(self.out, "{}break;", "\t".repeat(indent + 2))?;
+ }
+ }
+
+ if !default.is_empty() {
+ writeln!(self.out, "{}default:", "\t".repeat(indent + 1))?;
+
+ for sta in default {
+ self.write_stmt(sta, ctx, indent + 2)?;
+ }
+ }
+
+ writeln!(self.out, "{}}}", "\t".repeat(indent))?
+ }
+ Statement::Loop { body, continuing } => {
+ writeln!(self.out, "while(true) {{")?;
+
+ for sta in body.iter().chain(continuing.iter()) {
+ self.write_stmt(sta, ctx, indent + 1)?;
+ }
+
+ writeln!(self.out, "{}}}", "\t".repeat(indent))?
+ }
+ Statement::Break => writeln!(self.out, "break;")?,
+ Statement::Continue => writeln!(self.out, "continue;")?,
+ Statement::Return { value } => {
+ write!(self.out, "return")?;
+ if let Some(expr) = value {
+ write!(self.out, " ")?;
+ self.write_expr(*expr, ctx)?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ Statement::Kill => writeln!(self.out, "discard;")?,
+ Statement::Store { pointer, value } => {
+ self.write_expr(*pointer, ctx)?;
+ write!(self.out, " = ")?;
+ self.write_expr(*value, ctx)?;
+ writeln!(self.out, ";")?
+ }
+ }
+
+ Ok(())
+ }
+
+ fn write_expr(
+ &mut self,
+ expr: Handle<Expression>,
+ ctx: &FunctionCtx<'_, '_>,
+ ) -> Result<(), Error> {
+ match ctx.expressions[expr] {
+ Expression::Access { base, index } => {
+ self.write_expr(base, ctx)?;
+ write!(self.out, "[")?;
+ self.write_expr(index, ctx)?;
+ write!(self.out, "]")?
+ }
+ Expression::AccessIndex { base, index } => {
+ self.write_expr(base, ctx)?;
+
+ match ctx.typifier.get(base, &self.module.types) {
+ TypeInner::Vector { .. }
+ | TypeInner::Matrix { .. }
+ | TypeInner::Array { .. } => write!(self.out, "[{}]", index)?,
+ TypeInner::Struct { .. } => {
+ let ty = ctx.typifier.get_handle(base).unwrap();
+
+ write!(
+ self.out,
+ ".{}",
+ &self.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => return Err(Error::Custom(format!("Cannot index {:?}", other))),
+ }
+ }
+ Expression::Constant(constant) => {
+ self.write_constant(&self.module.constants[constant])?
+ }
+ Expression::Compose { ty, ref components } => {
+ match self.module.types[ty].inner {
+ TypeInner::Vector { .. }
+ | TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::Struct { .. } => self.write_type(ty, None)?,
+ _ => unreachable!(),
+ }
+
+ write!(self.out, "(")?;
+ self.write_slice(components, |this, _, arg| this.write_expr(*arg, ctx))?;
+ write!(self.out, ")")?
+ }
+ Expression::FunctionArgument(pos) => {
+ write!(self.out, "{}", ctx.get_arg(pos, &self.names))?
+ }
+ Expression::GlobalVariable(handle) => {
+ if let Some(crate::Binding::BuiltIn(built_in)) =
+ self.module.global_variables[handle].binding
+ {
+ write!(self.out, "{}", glsl_built_in(built_in))?
+ } else {
+ write!(
+ self.out,
+ "{}",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?
+ }
+ }
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&ctx.name_key(handle)])?
+ }
+ Expression::Load { pointer } => self.write_expr(pointer, ctx)?,
+ Expression::ImageSample {
+ image,
+ coordinate,
+ level,
+ depth_ref,
+ ..
+ } => {
+ //TODO: handle MS
+ write!(
+ self.out,
+ "{}(",
+ match level {
+ crate::SampleLevel::Auto | crate::SampleLevel::Bias(_) => "texture",
+ crate::SampleLevel::Zero | crate::SampleLevel::Exact(_) => "textureLod",
+ }
+ )?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ", ")?;
+
+ let size = match *ctx.typifier.get(coordinate, &self.module.types) {
+ TypeInner::Vector { size, .. } => size,
+ ref other => {
+ return Err(Error::Custom(format!(
+ "Cannot sample with coordinates of type {:?}",
+ other
+ )))
+ }
+ };
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, "vec{}(", size as u8 + 1)?;
+ self.write_expr(coordinate, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(depth_ref, ctx)?;
+ write!(self.out, ")")?
+ } else {
+ self.write_expr(coordinate, ctx)?
+ }
+
+ match level {
+ crate::SampleLevel::Auto => (),
+ crate::SampleLevel::Zero => write!(self.out, ", 0")?,
+ crate::SampleLevel::Exact(expr) | crate::SampleLevel::Bias(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ }
+
+ write!(self.out, ")")?
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ index,
+ } => {
+ let class = match ctx.typifier.get(image, &self.module.types) {
+ TypeInner::Image { class, .. } => class,
+ _ => unreachable!(),
+ };
+
+ match class {
+ ImageClass::Sampled { .. } => write!(self.out, "texelFetch(")?,
+ ImageClass::Storage(_) => write!(self.out, "imageLoad(")?,
+ ImageClass::Depth => todo!(),
+ }
+
+ self.write_expr(image, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(coordinate, ctx)?;
+
+ match class {
+ ImageClass::Sampled { .. } => {
+ write!(self.out, ", ")?;
+ self.write_expr(index.unwrap(), ctx)?;
+ write!(self.out, ")")?
+ }
+ ImageClass::Storage(_) => write!(self.out, ")")?,
+ ImageClass::Depth => todo!(),
+ }
+ }
+ Expression::Unary { op, expr } => {
+ write!(
+ self.out,
+ "({} ",
+ match op {
+ UnaryOperator::Negate => "-",
+ UnaryOperator::Not => match *ctx.typifier.get(expr, &self.module.types) {
+ TypeInner::Scalar {
+ kind: ScalarKind::Sint,
+ ..
+ } => "~",
+ TypeInner::Scalar {
+ kind: ScalarKind::Uint,
+ ..
+ } => "~",
+ TypeInner::Scalar {
+ kind: ScalarKind::Bool,
+ ..
+ } => "!",
+ ref other =>
+ return Err(Error::Custom(format!(
+ "Cannot apply not to type {:?}",
+ other
+ ))),
+ },
+ }
+ )?;
+
+ self.write_expr(expr, ctx)?;
+
+ write!(self.out, ")")?
+ }
+ Expression::Binary { op, left, right } => {
+ write!(self.out, "(")?;
+ self.write_expr(left, ctx)?;
+
+ write!(
+ self.out,
+ " {} ",
+ match op {
+ BinaryOperator::Add => "+",
+ BinaryOperator::Subtract => "-",
+ BinaryOperator::Multiply => "*",
+ BinaryOperator::Divide => "/",
+ BinaryOperator::Modulo => "%",
+ BinaryOperator::Equal => "==",
+ BinaryOperator::NotEqual => "!=",
+ BinaryOperator::Less => "<",
+ BinaryOperator::LessEqual => "<=",
+ BinaryOperator::Greater => ">",
+ BinaryOperator::GreaterEqual => ">=",
+ BinaryOperator::And => "&",
+ BinaryOperator::ExclusiveOr => "^",
+ BinaryOperator::InclusiveOr => "|",
+ BinaryOperator::LogicalAnd => "&&",
+ BinaryOperator::LogicalOr => "||",
+ BinaryOperator::ShiftLeft => "<<",
+ BinaryOperator::ShiftRight => ">>",
+ }
+ )?;
+
+ self.write_expr(right, ctx)?;
+
+ write!(self.out, ")")?
+ }
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "(")?;
+ self.write_expr(condition, ctx)?;
+ write!(self.out, " ? ")?;
+ self.write_expr(accept, ctx)?;
+ write!(self.out, " : ")?;
+ self.write_expr(reject, ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Intrinsic { fun, argument } => {
+ write!(
+ self.out,
+ "{}(",
+ match fun {
+ IntrinsicFunction::IsFinite => "!isinf",
+ IntrinsicFunction::IsInf => "isinf",
+ IntrinsicFunction::IsNan => "isnan",
+ IntrinsicFunction::IsNormal => "!isnan",
+ IntrinsicFunction::All => "all",
+ IntrinsicFunction::Any => "any",
+ }
+ )?;
+
+ self.write_expr(argument, ctx)?;
+
+ write!(self.out, ")")?
+ }
+ Expression::Transpose(matrix) => {
+ write!(self.out, "transpose(")?;
+ self.write_expr(matrix, ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::DotProduct(left, right) => {
+ write!(self.out, "dot(")?;
+ self.write_expr(left, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::CrossProduct(left, right) => {
+ write!(self.out, "cross(")?;
+ self.write_expr(left, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ if convert {
+ self.write_type(ctx.typifier.get_handle(expr).unwrap(), None)?;
+ } else {
+ let source_kind = match *ctx.typifier.get(expr, &self.module.types) {
+ TypeInner::Scalar {
+ kind: source_kind, ..
+ } => source_kind,
+ TypeInner::Vector {
+ kind: source_kind, ..
+ } => source_kind,
+ _ => unreachable!(),
+ };
+
+ write!(
+ self.out,
+ "{}",
+ match (source_kind, kind) {
+ (ScalarKind::Float, ScalarKind::Sint) => "floatBitsToInt",
+ (ScalarKind::Float, ScalarKind::Uint) => "floatBitsToUInt",
+ (ScalarKind::Sint, ScalarKind::Float) => "intBitsToFloat",
+ (ScalarKind::Uint, ScalarKind::Float) => "uintBitsToFloat",
+ _ => {
+ return Err(Error::Custom(format!(
+ "Cannot bitcast {:?} to {:?}",
+ source_kind, kind
+ )));
+ }
+ }
+ )?;
+ }
+
+ write!(self.out, "(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Derivative { axis, expr } => {
+ write!(
+ self.out,
+ "{}(",
+ match axis {
+ DerivativeAxis::X => "dFdx",
+ DerivativeAxis::Y => "dFdy",
+ DerivativeAxis::Width => "fwidth",
+ }
+ )?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Call {
+ origin: FunctionOrigin::Local(ref function),
+ ref arguments,
+ } => {
+ write!(self.out, "{}(", &self.names[&NameKey::Function(*function)])?;
+ self.write_slice(arguments, |this, _, arg| this.write_expr(*arg, ctx))?;
+ write!(self.out, ")")?
+ }
+ Expression::Call {
+ origin: crate::FunctionOrigin::External(ref name),
+ ref arguments,
+ } => match name.as_str() {
+ "cos" | "normalize" | "sin" | "length" | "abs" | "floor" | "inverse"
+ | "distance" | "dot" | "min" | "max" | "reflect" | "pow" | "step" | "cross"
+ | "fclamp" | "clamp" | "mix" | "smoothstep" => {
+ let name = match name.as_str() {
+ "fclamp" => "clamp",
+ name => name,
+ };
+
+ write!(self.out, "{}(", name)?;
+ self.write_slice(arguments, |this, _, arg| this.write_expr(*arg, ctx))?;
+ write!(self.out, ")")?
+ }
+ "atan2" => {
+ write!(self.out, "atan(")?;
+ self.write_expr(arguments[1], ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(arguments[0], ctx)?;
+ write!(self.out, ")")?
+ }
+ other => {
+ return Err(Error::Custom(format!(
+ "Unsupported function call {}",
+ other
+ )))
+ }
+ },
+ Expression::ArrayLength(expr) => {
+ write!(self.out, "uint(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ".length())")?
+ }
+ }
+
+ Ok(())
+ }
+
+ fn write_constant(&mut self, constant: &Constant) -> Result<(), Error> {
+ match constant.inner {
+ ConstantInner::Sint(int) => write!(self.out, "{}", int)?,
+ ConstantInner::Uint(int) => write!(self.out, "{}u", int)?,
+ ConstantInner::Float(float) => write!(self.out, "{:?}", float)?,
+ ConstantInner::Bool(boolean) => write!(self.out, "{}", boolean)?,
+ ConstantInner::Composite(ref components) => {
+ self.write_type(constant.ty, None)?;
+ write!(self.out, "(")?;
+ self.write_slice(components, |this, _, arg| {
+ this.write_constant(&this.module.constants[*arg])
+ })?;
+ write!(self.out, ")")?
+ }
+ }
+
+ Ok(())
+ }
+}
+
+struct ScalarString<'a> {
+ prefix: &'a str,
+ full: &'a str,
+}
+
+fn glsl_scalar(kind: ScalarKind, width: crate::Bytes) -> Result<ScalarString<'static>, Error> {
+ Ok(match kind {
+ ScalarKind::Sint => ScalarString {
+ prefix: "i",
+ full: "int",
+ },
+ ScalarKind::Uint => ScalarString {
+ prefix: "u",
+ full: "uint",
+ },
+ ScalarKind::Float => match width {
+ 4 => ScalarString {
+ prefix: "",
+ full: "float",
+ },
+ 8 => ScalarString {
+ prefix: "d",
+ full: "double",
+ },
+ _ => {
+ return Err(Error::Custom(format!(
+ "Cannot build float of width {}",
+ width
+ )))
+ }
+ },
+ ScalarKind::Bool => ScalarString {
+ prefix: "b",
+ full: "bool",
+ },
+ })
+}
+
+fn glsl_built_in(built_in: BuiltIn) -> &'static str {
+ match built_in {
+ BuiltIn::Position => "gl_Position",
+ BuiltIn::GlobalInvocationId => "gl_GlobalInvocationID",
+ BuiltIn::BaseInstance => "gl_BaseInstance",
+ BuiltIn::BaseVertex => "gl_BaseVertex",
+ BuiltIn::ClipDistance => "gl_ClipDistance",
+ BuiltIn::InstanceIndex => "gl_InstanceIndex",
+ BuiltIn::VertexIndex => "gl_VertexIndex",
+ BuiltIn::PointSize => "gl_PointSize",
+ BuiltIn::FragCoord => "gl_FragCoord",
+ BuiltIn::FrontFacing => "gl_FrontFacing",
+ BuiltIn::SampleIndex => "gl_SampleID",
+ BuiltIn::FragDepth => "gl_FragDepth",
+ BuiltIn::LocalInvocationId => "gl_LocalInvocationID",
+ BuiltIn::LocalInvocationIndex => "gl_LocalInvocationIndex",
+ BuiltIn::WorkGroupId => "gl_WorkGroupID",
+ }
+}
+
+fn glsl_storage_class(class: StorageClass) -> &'static str {
+ match class {
+ StorageClass::Function => "",
+ StorageClass::Input => "in",
+ StorageClass::Output => "out",
+ StorageClass::Private => "",
+ StorageClass::Storage => "buffer",
+ StorageClass::Uniform => "uniform",
+ StorageClass::Handle => "uniform",
+ StorageClass::WorkGroup => "shared",
+ StorageClass::PushConstant => "",
+ }
+}
+
+fn glsl_interpolation(interpolation: Interpolation) -> Result<&'static str, Error> {
+ Ok(match interpolation {
+ Interpolation::Perspective => "smooth",
+ Interpolation::Linear => "noperspective",
+ Interpolation::Flat => "flat",
+ Interpolation::Centroid => "centroid",
+ Interpolation::Sample => "sample",
+ Interpolation::Patch => {
+ return Err(Error::Custom(
+ "patch interpolation qualifier not supported".to_string(),
+ ))
+ }
+ })
+}
+
+struct ImageDimension(crate::ImageDimension);
+impl fmt::Display for ImageDimension {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(
+ f,
+ "{}",
+ match self.0 {
+ crate::ImageDimension::D1 => "1D",
+ crate::ImageDimension::D2 => "2D",
+ crate::ImageDimension::D3 => "3D",
+ crate::ImageDimension::Cube => "Cube",
+ }
+ )
+ }
+}
+
+fn glsl_storage_format(format: StorageFormat) -> &'static str {
+ match format {
+ StorageFormat::R8Unorm => "r8",
+ StorageFormat::R8Snorm => "r8_snorm",
+ StorageFormat::R8Uint => "r8ui",
+ StorageFormat::R8Sint => "r8i",
+ StorageFormat::R16Uint => "r16ui",
+ StorageFormat::R16Sint => "r16i",
+ StorageFormat::R16Float => "r16f",
+ StorageFormat::Rg8Unorm => "rg8",
+ StorageFormat::Rg8Snorm => "rg8_snorm",
+ StorageFormat::Rg8Uint => "rg8ui",
+ StorageFormat::Rg8Sint => "rg8i",
+ StorageFormat::R32Uint => "r32ui",
+ StorageFormat::R32Sint => "r32i",
+ StorageFormat::R32Float => "r32f",
+ StorageFormat::Rg16Uint => "rg16ui",
+ StorageFormat::Rg16Sint => "rg16i",
+ StorageFormat::Rg16Float => "rg16f",
+ StorageFormat::Rgba8Unorm => "rgba8ui",
+ StorageFormat::Rgba8Snorm => "rgba8_snorm",
+ StorageFormat::Rgba8Uint => "rgba8ui",
+ StorageFormat::Rgba8Sint => "rgba8i",
+ StorageFormat::Rgb10a2Unorm => "rgb10_a2ui",
+ StorageFormat::Rg11b10Float => "r11f_g11f_b10f",
+ StorageFormat::Rg32Uint => "rg32ui",
+ StorageFormat::Rg32Sint => "rg32i",
+ StorageFormat::Rg32Float => "rg32f",
+ StorageFormat::Rgba16Uint => "rgba16ui",
+ StorageFormat::Rgba16Sint => "rgba16i",
+ StorageFormat::Rgba16Float => "rgba16f",
+ StorageFormat::Rgba32Uint => "rgba32ui",
+ StorageFormat::Rgba32Sint => "rgba32i",
+ StorageFormat::Rgba32Float => "rgba32f",
+ }
+}
+
+struct TextureMappingVisitor<'a> {
+ names: &'a FastHashMap<NameKey, String>,
+ expressions: &'a Arena<Expression>,
+ map: &'a mut FastHashMap<String, TextureMapping>,
+ error: Option<Error>,
+}
+
+impl<'a> Visitor for TextureMappingVisitor<'a> {
+ fn visit_expr(&mut self, expr: &crate::Expression) {
+ match expr {
+ Expression::ImageSample { image, sampler, .. } => {
+ let tex_handle = match self.expressions[*image] {
+ Expression::GlobalVariable(global) => global,
+ _ => unreachable!(),
+ };
+ let tex_name = self.names[&NameKey::GlobalVariable(tex_handle)].clone();
+
+ let sampler_handle = match self.expressions[*sampler] {
+ Expression::GlobalVariable(global) => global,
+ _ => unreachable!(),
+ };
+
+ let mapping = self.map.entry(tex_name).or_insert(TextureMapping {
+ texture: tex_handle,
+ sampler: Some(sampler_handle),
+ });
+
+ if mapping.sampler != Some(sampler_handle) {
+ self.error = Some(Error::Custom(String::from(
+ "Cannot use texture with two different samplers",
+ )));
+ }
+ }
+ Expression::ImageLoad { image, .. } => {
+ let tex_handle = match self.expressions[*image] {
+ Expression::GlobalVariable(global) => global,
+ _ => unreachable!(),
+ };
+ let tex_name = self.names[&NameKey::GlobalVariable(tex_handle)].clone();
+
+ let mapping = self.map.entry(tex_name).or_insert(TextureMapping {
+ texture: tex_handle,
+ sampler: None,
+ });
+
+ if mapping.sampler != None {
+ self.error = Some(Error::Custom(String::from(
+ "Cannot use texture with two different samplers",
+ )));
+ }
+ }
+ _ => {}
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/glsl/keywords.rs b/third_party/rust/naga/src/back/glsl/keywords.rs
new file mode 100644
index 0000000000..5a2836c189
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/keywords.rs
@@ -0,0 +1,204 @@
+pub const RESERVED_KEYWORDS: &[&str] = &[
+ "attribute",
+ "const",
+ "uniform",
+ "varying",
+ "buffer",
+ "shared",
+ "coherent",
+ "volatile",
+ "restrict",
+ "readonly",
+ "writeonly",
+ "atomic_uint",
+ "layout",
+ "centroid",
+ "flat",
+ "smooth",
+ "noperspective",
+ "patch",
+ "sample",
+ "break",
+ "continue",
+ "do",
+ "for",
+ "while",
+ "switch",
+ "case",
+ "default",
+ "if",
+ "else",
+ "subroutine",
+ "in",
+ "out",
+ "inout",
+ "float",
+ "double",
+ "int",
+ "void",
+ "bool",
+ "true",
+ "false",
+ "invariant",
+ "precise",
+ "discard",
+ "return",
+ "mat2",
+ "mat3",
+ "mat4",
+ "dmat2",
+ "dmat3",
+ "dmat4",
+ "mat2x2",
+ "mat2x3",
+ "mat2x4",
+ "dmat2x2",
+ "dmat2x3",
+ "dmat2x4",
+ "mat3x2",
+ "mat3x3",
+ "mat3x4",
+ "dmat3x2",
+ "dmat3x3",
+ "dmat3x4",
+ "mat4x2",
+ "mat4x3",
+ "mat4x4",
+ "dmat4x2",
+ "dmat4x3",
+ "dmat4x4",
+ "vec2",
+ "vec3",
+ "vec4",
+ "ivec2",
+ "ivec3",
+ "ivec4",
+ "bvec2",
+ "bvec3",
+ "bvec4",
+ "dvec2",
+ "dvec3",
+ "dvec4",
+ "uint",
+ "uvec2",
+ "uvec3",
+ "uvec4",
+ "lowp",
+ "mediump",
+ "highp",
+ "precision",
+ "sampler1D",
+ "sampler2D",
+ "sampler3D",
+ "samplerCube",
+ "sampler1DShadow",
+ "sampler2DShadow",
+ "samplerCubeShadow",
+ "sampler1DArray",
+ "sampler2DArray",
+ "sampler1DArrayShadow",
+ "sampler2DArrayShadow",
+ "isampler1D",
+ "isampler2D",
+ "isampler3D",
+ "isamplerCube",
+ "isampler1DArray",
+ "isampler2DArray",
+ "usampler1D",
+ "usampler2D",
+ "usampler3D",
+ "usamplerCube",
+ "usampler1DArray",
+ "usampler2DArray",
+ "sampler2DRect",
+ "sampler2DRectShadow",
+ "isampler2D",
+ "Rect",
+ "usampler2DRect",
+ "samplerBuffer",
+ "isamplerBuffer",
+ "usamplerBuffer",
+ "sampler2DMS",
+ "isampler2DMS",
+ "usampler2DMS",
+ "sampler2DMSArray",
+ "isampler2DMSArray",
+ "usampler2DMSArray",
+ "samplerCubeArray",
+ "samplerCubeArrayShadow",
+ "isamplerCubeArray",
+ "usamplerCubeArray",
+ "image1D",
+ "iimage1D",
+ "uimage1D",
+ "image2D",
+ "iimage2D",
+ "uimage2D",
+ "image3D",
+ "iimage3D",
+ "uimage3D",
+ "image2DRect",
+ "iimage2DRect",
+ "uimage2DRect",
+ "imageCube",
+ "iimageCube",
+ "uimageCube",
+ "imageBuffer",
+ "iimageBuffer",
+ "uimageBuffer",
+ "image1DArray",
+ "iimage1DArray",
+ "uimage1DArray",
+ "image2DArray",
+ "iimage2DArray",
+ "uimage2DArray",
+ "imageCubeArray",
+ "iimageCubeArray",
+ "uimageCubeArray",
+ "image2DMS",
+ "iimage2DMS",
+ "uimage2DMS",
+ "image2DMSArray",
+ "iimage2DMSArray",
+ "uimage2DMSArraystruct",
+ "common",
+ "partition",
+ "active",
+ "asm",
+ "class",
+ "union",
+ "enum",
+ "typedef",
+ "template",
+ "this",
+ "resource",
+ "goto",
+ "inline",
+ "noinline",
+ "public",
+ "static",
+ "extern",
+ "external",
+ "interface",
+ "long",
+ "short",
+ "half",
+ "fixed",
+ "unsigned",
+ "superp",
+ "input",
+ "output",
+ "hvec2",
+ "hvec3",
+ "hvec4",
+ "fvec2",
+ "fvec3",
+ "fvec4",
+ "sampler3DRect",
+ "filter",
+ "sizeof",
+ "cast",
+ "namespace",
+ "using",
+ "main",
+];
diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs
new file mode 100644
index 0000000000..bc96dd3496
--- /dev/null
+++ b/third_party/rust/naga/src/back/mod.rs
@@ -0,0 +1,8 @@
+//! Functions which export shader modules into binary and text formats.
+
+#[cfg(feature = "glsl-out")]
+pub mod glsl;
+#[cfg(feature = "msl-out")]
+pub mod msl;
+#[cfg(feature = "spv-out")]
+pub mod spv;
diff --git a/third_party/rust/naga/src/back/msl/keywords.rs b/third_party/rust/naga/src/back/msl/keywords.rs
new file mode 100644
index 0000000000..cd074ab43f
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/keywords.rs
@@ -0,0 +1,102 @@
+//TODO: find a complete list
+pub const RESERVED: &[&str] = &[
+ // control flow
+ "break",
+ "if",
+ "else",
+ "continue",
+ "goto",
+ "do",
+ "while",
+ "for",
+ "switch",
+ "case",
+ // types and values
+ "void",
+ "unsigned",
+ "signed",
+ "bool",
+ "char",
+ "int",
+ "long",
+ "float",
+ "double",
+ "char8_t",
+ "wchar_t",
+ "true",
+ "false",
+ "nullptr",
+ "union",
+ "class",
+ "struct",
+ "enum",
+ // other
+ "main",
+ "using",
+ "decltype",
+ "sizeof",
+ "typeof",
+ "typedef",
+ "explicit",
+ "export",
+ "friend",
+ "namespace",
+ "operator",
+ "public",
+ "template",
+ "typename",
+ "typeid",
+ "co_await",
+ "co_return",
+ "co_yield",
+ "module",
+ "import",
+ "ray_data",
+ "vec_step",
+ "visible",
+ "as_type",
+ // qualifiers
+ "mutable",
+ "static",
+ "volatile",
+ "restrict",
+ "const",
+ "non-temporal",
+ "dereferenceable",
+ "invariant",
+ // exceptions
+ "throw",
+ "try",
+ "catch",
+ // operators
+ "const_cast",
+ "dynamic_cast",
+ "reinterpret_cast",
+ "static_cast",
+ "new",
+ "delete",
+ "and",
+ "and_eq",
+ "bitand",
+ "bitor",
+ "compl",
+ "not",
+ "not_eq",
+ "or",
+ "or_eq",
+ "xor",
+ "xor_eq",
+ "compl",
+ // Metal-specific
+ "constant",
+ "device",
+ "threadgroup",
+ "threadgroup_imageblock",
+ "kernel",
+ "compute",
+ "vertex",
+ "fragment",
+ "read_only",
+ "write_only",
+ "read_write",
+];
diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs
new file mode 100644
index 0000000000..493e7d0c85
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/mod.rs
@@ -0,0 +1,211 @@
+/*! Metal Shading Language (MSL) backend
+
+## Binding model
+
+Metal's bindings are flat per resource. Since there isn't an obvious mapping
+from SPIR-V's descriptor sets, we require a separate mapping provided in the options.
+This mapping may have one or more resource end points for each descriptor set + index
+pair.
+
+## Outputs
+
+In Metal, built-in shader outputs can not be nested into structures within
+the output struct. If there is a structure in the outputs, and it contains any built-ins,
+we move them up to the root output structure that we define ourselves.
+!*/
+
+use crate::{arena::Handle, proc::ResolveError, FastHashMap};
+use std::{
+ io::{Error as IoError, Write},
+ string::FromUtf8Error,
+};
+
+mod keywords;
+mod writer;
+
+pub use writer::Writer;
+
+#[derive(Clone, Debug, Default, PartialEq)]
+pub struct BindTarget {
+ pub buffer: Option<u8>,
+ pub texture: Option<u8>,
+ pub sampler: Option<u8>,
+ pub mutable: bool,
+}
+
+#[derive(Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub struct BindSource {
+ pub stage: crate::ShaderStage,
+ pub group: u32,
+ pub binding: u32,
+}
+
+pub type BindingMap = FastHashMap<BindSource, BindTarget>;
+
+enum ResolvedBinding {
+ BuiltIn(crate::BuiltIn),
+ Attribute(u32),
+ Color(u32),
+ User { prefix: &'static str, index: u32 },
+ Resource(BindTarget),
+}
+
+// Note: some of these should be removed in favor of proper IR validation.
+
+#[derive(Debug)]
+pub enum Error {
+ IO(IoError),
+ Utf8(FromUtf8Error),
+ Type(ResolveError),
+ UnexpectedLocation,
+ MissingBinding(Handle<crate::GlobalVariable>),
+ MissingBindTarget(BindSource),
+ InvalidImageAccess(crate::StorageAccess),
+ MutabilityViolation(Handle<crate::GlobalVariable>),
+ BadName(String),
+ UnexpectedGlobalType(Handle<crate::Type>),
+ UnimplementedBindTarget(BindTarget),
+ UnsupportedCompose(Handle<crate::Type>),
+ UnsupportedBinaryOp(crate::BinaryOperator),
+ UnexpectedSampleLevel(crate::SampleLevel),
+ UnsupportedCall(String),
+ UnsupportedDynamicArrayLength,
+ UnableToReturnValue(Handle<crate::Expression>),
+ /// The source IR is not valid.
+ Validation,
+}
+
+impl From<IoError> for Error {
+ fn from(e: IoError) -> Self {
+ Error::IO(e)
+ }
+}
+
+impl From<FromUtf8Error> for Error {
+ fn from(e: FromUtf8Error) -> Self {
+ Error::Utf8(e)
+ }
+}
+
+impl From<ResolveError> for Error {
+ fn from(e: ResolveError) -> Self {
+ Error::Type(e)
+ }
+}
+
+#[derive(Clone, Copy, Debug)]
+enum LocationMode {
+ VertexInput,
+ FragmentOutput,
+ Intermediate,
+ Uniform,
+}
+
+#[derive(Debug, Default, Clone)]
+pub struct Options {
+ /// (Major, Minor) target version of the Metal Shading Language.
+ pub lang_version: (u8, u8),
+ /// Make it possible to link different stages via SPIRV-Cross.
+ pub spirv_cross_compatibility: bool,
+ /// Binding model mapping to Metal.
+ pub binding_map: BindingMap,
+}
+
+impl Options {
+ fn resolve_binding(
+ &self,
+ stage: crate::ShaderStage,
+ binding: &crate::Binding,
+ mode: LocationMode,
+ ) -> Result<ResolvedBinding, Error> {
+ match *binding {
+ crate::Binding::BuiltIn(built_in) => Ok(ResolvedBinding::BuiltIn(built_in)),
+ crate::Binding::Location(index) => match mode {
+ LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(index)),
+ LocationMode::FragmentOutput => Ok(ResolvedBinding::Color(index)),
+ LocationMode::Intermediate => Ok(ResolvedBinding::User {
+ prefix: if self.spirv_cross_compatibility {
+ "locn"
+ } else {
+ "loc"
+ },
+ index,
+ }),
+ LocationMode::Uniform => Err(Error::UnexpectedLocation),
+ },
+ crate::Binding::Resource { group, binding } => {
+ let source = BindSource {
+ stage,
+ group,
+ binding,
+ };
+ self.binding_map
+ .get(&source)
+ .cloned()
+ .map(ResolvedBinding::Resource)
+ .ok_or(Error::MissingBindTarget(source))
+ }
+ }
+ }
+}
+
+impl ResolvedBinding {
+ fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> {
+ match *self {
+ ResolvedBinding::BuiltIn(built_in) => {
+ use crate::BuiltIn as Bi;
+ let name = match built_in {
+ // vertex
+ Bi::BaseInstance => "base_instance",
+ Bi::BaseVertex => "base_vertex",
+ Bi::ClipDistance => "clip_distance",
+ Bi::InstanceIndex => "instance_id",
+ Bi::PointSize => "point_size",
+ Bi::Position => "position",
+ Bi::VertexIndex => "vertex_id",
+ // fragment
+ Bi::FragCoord => "position",
+ Bi::FragDepth => "depth(any)",
+ Bi::FrontFacing => "front_facing",
+ Bi::SampleIndex => "sample_id",
+ // compute
+ Bi::GlobalInvocationId => "thread_position_in_grid",
+ Bi::LocalInvocationId => "thread_position_in_threadgroup",
+ Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
+ Bi::WorkGroupId => "threadgroup_position_in_grid",
+ };
+ Ok(write!(out, "{}", name)?)
+ }
+ ResolvedBinding::Attribute(index) => Ok(write!(out, "attribute({})", index)?),
+ ResolvedBinding::Color(index) => Ok(write!(out, "color({})", index)?),
+ ResolvedBinding::User { prefix, index } => {
+ Ok(write!(out, "user({}{})", prefix, index)?)
+ }
+ ResolvedBinding::Resource(ref target) => {
+ if let Some(id) = target.buffer {
+ Ok(write!(out, "buffer({})", id)?)
+ } else if let Some(id) = target.texture {
+ Ok(write!(out, "texture({})", id)?)
+ } else if let Some(id) = target.sampler {
+ Ok(write!(out, "sampler({})", id)?)
+ } else {
+ Err(Error::UnimplementedBindTarget(target.clone()))
+ }
+ }
+ }
+ }
+
+ fn try_fmt_decorated<W: Write>(&self, out: &mut W, terminator: &str) -> Result<(), Error> {
+ write!(out, " [[")?;
+ self.try_fmt(out)?;
+ write!(out, "]]")?;
+ write!(out, "{}", terminator)?;
+ Ok(())
+ }
+}
+
+pub fn write_string(module: &crate::Module, options: &Options) -> Result<String, Error> {
+ let mut w = writer::Writer::new(Vec::new());
+ w.write(module, options)?;
+ Ok(String::from_utf8(w.finish())?)
+}
diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs
new file mode 100644
index 0000000000..e2adbea6b7
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/writer.rs
@@ -0,0 +1,990 @@
+use super::{keywords::RESERVED, Error, LocationMode, Options, ResolvedBinding};
+use crate::{
+ arena::Handle,
+ proc::{EntryPointIndex, NameKey, Namer, ResolveContext, Typifier},
+ FastHashMap,
+};
+use std::{
+ fmt::{Display, Error as FmtError, Formatter},
+ io::Write,
+};
+
+struct Level(usize);
+impl Level {
+ fn next(&self) -> Self {
+ Level(self.0 + 1)
+ }
+}
+impl Display for Level {
+ fn fmt(&self, formatter: &mut Formatter<'_>) -> Result<(), FmtError> {
+ (0..self.0).map(|_| formatter.write_str("\t")).collect()
+ }
+}
+
+struct TypedGlobalVariable<'a> {
+ module: &'a crate::Module,
+ names: &'a FastHashMap<NameKey, String>,
+ handle: Handle<crate::GlobalVariable>,
+ usage: crate::GlobalUse,
+}
+
+impl<'a> TypedGlobalVariable<'a> {
+ fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> {
+ let var = &self.module.global_variables[self.handle];
+ let name = &self.names[&NameKey::GlobalVariable(self.handle)];
+ let ty = &self.module.types[var.ty];
+ let ty_name = &self.names[&NameKey::Type(var.ty)];
+
+ let (space_qualifier, reference) = match ty.inner {
+ crate::TypeInner::Struct { .. } => match var.class {
+ crate::StorageClass::Uniform | crate::StorageClass::Storage => {
+ let space = if self.usage.contains(crate::GlobalUse::STORE) {
+ "device "
+ } else {
+ "constant "
+ };
+ (space, "&")
+ }
+ _ => ("", ""),
+ },
+ _ => ("", ""),
+ };
+ Ok(write!(
+ out,
+ "{}{}{} {}",
+ space_qualifier, ty_name, reference, name
+ )?)
+ }
+}
+
+pub struct Writer<W> {
+ out: W,
+ names: FastHashMap<NameKey, String>,
+ typifier: Typifier,
+}
+
+fn scalar_kind_string(kind: crate::ScalarKind) -> &'static str {
+ match kind {
+ crate::ScalarKind::Float => "float",
+ crate::ScalarKind::Sint => "int",
+ crate::ScalarKind::Uint => "uint",
+ crate::ScalarKind::Bool => "bool",
+ }
+}
+
+fn vector_size_string(size: crate::VectorSize) -> &'static str {
+ match size {
+ crate::VectorSize::Bi => "2",
+ crate::VectorSize::Tri => "3",
+ crate::VectorSize::Quad => "4",
+ }
+}
+
+const OUTPUT_STRUCT_NAME: &str = "output";
+const LOCATION_INPUT_STRUCT_NAME: &str = "input";
+const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
+
+fn separate(is_last: bool) -> &'static str {
+ if is_last {
+ ""
+ } else {
+ ","
+ }
+}
+
+enum FunctionOrigin {
+ Handle(Handle<crate::Function>),
+ EntryPoint(EntryPointIndex),
+}
+
+struct ExpressionContext<'a> {
+ function: &'a crate::Function,
+ origin: FunctionOrigin,
+ module: &'a crate::Module,
+}
+
+impl<W: Write> Writer<W> {
+ /// Creates a new `Writer` instance.
+ pub fn new(out: W) -> Self {
+ Writer {
+ out,
+ names: FastHashMap::default(),
+ typifier: Typifier::new(),
+ }
+ }
+
+ /// Finishes writing and returns the output.
+ pub fn finish(self) -> W {
+ self.out
+ }
+
+ fn put_call(
+ &mut self,
+ name: &str,
+ parameters: &[Handle<crate::Expression>],
+ context: &ExpressionContext,
+ ) -> Result<(), Error> {
+ if !name.is_empty() {
+ write!(self.out, "metal::{}", name)?;
+ }
+ write!(self.out, "(")?;
+ for (i, &handle) in parameters.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.put_expression(handle, context)?;
+ }
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_expression(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> Result<(), Error> {
+ let expression = &context.function.expressions[expr_handle];
+ log::trace!("expression {:?} = {:?}", expr_handle, expression);
+ match *expression {
+ crate::Expression::Access { base, index } => {
+ self.put_expression(base, context)?;
+ write!(self.out, "[")?;
+ self.put_expression(index, context)?;
+ write!(self.out, "]")?;
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ self.put_expression(base, context)?;
+ let resolved = self.typifier.get(base, &context.module.types);
+ match *resolved {
+ crate::TypeInner::Struct { .. } => {
+ let base_ty = self.typifier.get_handle(base).unwrap();
+ let name = &self.names[&NameKey::StructMember(base_ty, index)];
+ write!(self.out, ".{}", name)?;
+ }
+ crate::TypeInner::Matrix { .. } | crate::TypeInner::Vector { .. } => {
+ write!(self.out, ".{}", COMPONENTS[index as usize])?;
+ }
+ crate::TypeInner::Array { .. } => {
+ write!(self.out, "[{}]", index)?;
+ }
+ _ => {
+ // unexpected indexing, should fail validation
+ }
+ }
+ }
+ crate::Expression::Constant(handle) => self.put_constant(handle, context.module)?,
+ crate::Expression::Compose { ty, ref components } => {
+ let inner = &context.module.types[ty].inner;
+ match *inner {
+ crate::TypeInner::Vector { size, kind, .. } => {
+ write!(
+ self.out,
+ "{}{}",
+ scalar_kind_string(kind),
+ vector_size_string(size)
+ )?;
+ self.put_call("", components, context)?;
+ }
+ crate::TypeInner::Scalar { width: 4, kind } if components.len() == 1 => {
+ write!(self.out, "{}", scalar_kind_string(kind),)?;
+ self.put_call("", components, context)?;
+ }
+ _ => return Err(Error::UnsupportedCompose(ty)),
+ }
+ }
+ crate::Expression::FunctionArgument(index) => {
+ let fun_handle = match context.origin {
+ FunctionOrigin::Handle(handle) => handle,
+ FunctionOrigin::EntryPoint(_) => unreachable!(),
+ };
+ let name = &self.names[&NameKey::FunctionArgument(fun_handle, index)];
+ write!(self.out, "{}", name)?;
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let var = &context.module.global_variables[handle];
+ match var.class {
+ crate::StorageClass::Output => {
+ if let crate::TypeInner::Struct { .. } = context.module.types[var.ty].inner
+ {
+ return Ok(());
+ }
+ write!(self.out, "{}.", OUTPUT_STRUCT_NAME)?;
+ }
+ crate::StorageClass::Input => {
+ if let Some(crate::Binding::Location(_)) = var.binding {
+ write!(self.out, "{}.", LOCATION_INPUT_STRUCT_NAME)?;
+ }
+ }
+ _ => {}
+ }
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{}", name)?;
+ }
+ crate::Expression::LocalVariable(handle) => {
+ let name_key = match context.origin {
+ FunctionOrigin::Handle(fun_handle) => {
+ NameKey::FunctionLocal(fun_handle, handle)
+ }
+ FunctionOrigin::EntryPoint(ep_index) => {
+ NameKey::EntryPointLocal(ep_index, handle)
+ }
+ };
+ let name = &self.names[&name_key];
+ write!(self.out, "{}", name)?;
+ }
+ crate::Expression::Load { pointer } => {
+ //write!(self.out, "*")?;
+ self.put_expression(pointer, context)?;
+ }
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ coordinate,
+ level,
+ depth_ref,
+ } => {
+ let op = match depth_ref {
+ Some(_) => "sample_compare",
+ None => "sample",
+ };
+ //TODO: handle arrayed images
+ self.put_expression(image, context)?;
+ write!(self.out, ".{}(", op)?;
+ self.put_expression(sampler, context)?;
+ write!(self.out, ", ")?;
+ self.put_expression(coordinate, context)?;
+ if let Some(dref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.put_expression(dref, context)?;
+ }
+ match level {
+ crate::SampleLevel::Auto => {}
+ crate::SampleLevel::Zero => {
+ write!(self.out, ", level(0)")?;
+ }
+ crate::SampleLevel::Exact(h) => {
+ write!(self.out, ", level(")?;
+ self.put_expression(h, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::SampleLevel::Bias(h) => {
+ write!(self.out, ", bias(")?;
+ self.put_expression(h, context)?;
+ write!(self.out, ")")?;
+ }
+ }
+ write!(self.out, ")")?;
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ index,
+ } => {
+ //TODO: handle arrayed images
+ self.put_expression(image, context)?;
+ write!(self.out, ".read(")?;
+ self.put_expression(coordinate, context)?;
+ if let Some(index) = index {
+ write!(self.out, ", ")?;
+ self.put_expression(index, context)?;
+ }
+ write!(self.out, ")")?;
+ }
+ crate::Expression::Unary { op, expr } => {
+ let op_str = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::Not => "!",
+ };
+ write!(self.out, "{}", op_str)?;
+ self.put_expression(expr, context)?;
+ }
+ crate::Expression::Binary { op, left, right } => {
+ let op_str = match op {
+ crate::BinaryOperator::Add => "+",
+ crate::BinaryOperator::Subtract => "-",
+ crate::BinaryOperator::Multiply => "*",
+ crate::BinaryOperator::Divide => "/",
+ crate::BinaryOperator::Modulo => "%",
+ crate::BinaryOperator::Equal => "==",
+ crate::BinaryOperator::NotEqual => "!=",
+ crate::BinaryOperator::Less => "<",
+ crate::BinaryOperator::LessEqual => "<=",
+ crate::BinaryOperator::Greater => "==",
+ crate::BinaryOperator::GreaterEqual => ">=",
+ crate::BinaryOperator::And => "&",
+ _ => return Err(Error::UnsupportedBinaryOp(op)),
+ };
+ let kind = self
+ .typifier
+ .get(left, &context.module.types)
+ .scalar_kind()
+ .ok_or(Error::UnsupportedBinaryOp(op))?;
+ if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
+ write!(self.out, "fmod(")?;
+ self.put_expression(left, context)?;
+ write!(self.out, ", ")?;
+ self.put_expression(right, context)?;
+ write!(self.out, ")")?;
+ } else {
+ //write!(self.out, "(")?;
+ self.put_expression(left, context)?;
+ write!(self.out, " {} ", op_str)?;
+ self.put_expression(right, context)?;
+ //write!(self.out, ")")?;
+ }
+ }
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "(")?;
+ self.put_expression(condition, context)?;
+ write!(self.out, " ? ")?;
+ self.put_expression(accept, context)?;
+ write!(self.out, " : ")?;
+ self.put_expression(reject, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::Expression::Intrinsic { fun, argument } => {
+ let op = match fun {
+ crate::IntrinsicFunction::Any => "any",
+ crate::IntrinsicFunction::All => "all",
+ crate::IntrinsicFunction::IsNan => "",
+ crate::IntrinsicFunction::IsInf => "",
+ crate::IntrinsicFunction::IsFinite => "",
+ crate::IntrinsicFunction::IsNormal => "",
+ };
+ self.put_call(op, &[argument], context)?;
+ }
+ crate::Expression::Transpose(expr) => {
+ self.put_call("transpose", &[expr], context)?;
+ }
+ crate::Expression::DotProduct(a, b) => {
+ self.put_call("dot", &[a, b], context)?;
+ }
+ crate::Expression::CrossProduct(a, b) => {
+ self.put_call("cross", &[a, b], context)?;
+ }
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let scalar = scalar_kind_string(kind);
+ let size = match *self.typifier.get(expr, &context.module.types) {
+ crate::TypeInner::Scalar { .. } => "",
+ crate::TypeInner::Vector { size, .. } => vector_size_string(size),
+ _ => return Err(Error::Validation),
+ };
+ let op = if convert { "static_cast" } else { "as_type" };
+ write!(self.out, "{}<{}{}>(", op, scalar, size)?;
+ self.put_expression(expr, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::Expression::Derivative { axis, expr } => {
+ let op = match axis {
+ crate::DerivativeAxis::X => "dfdx",
+ crate::DerivativeAxis::Y => "dfdy",
+ crate::DerivativeAxis::Width => "fwidth",
+ };
+ self.put_call(op, &[expr], context)?;
+ }
+ crate::Expression::Call {
+ origin: crate::FunctionOrigin::Local(handle),
+ ref arguments,
+ } => {
+ let name = &self.names[&NameKey::Function(handle)];
+ write!(self.out, "{}", name)?;
+ self.put_call("", arguments, context)?;
+ }
+ crate::Expression::Call {
+ origin: crate::FunctionOrigin::External(ref name),
+ ref arguments,
+ } => match name.as_str() {
+ "atan2" | "cos" | "distance" | "length" | "mix" | "normalize" | "sin" => {
+ self.put_call(name, arguments, context)?;
+ }
+ "fclamp" => {
+ self.put_call("clamp", arguments, context)?;
+ }
+ other => return Err(Error::UnsupportedCall(other.to_owned())),
+ },
+ crate::Expression::ArrayLength(expr) => match *self
+ .typifier
+ .get(expr, &context.module.types)
+ {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(const_handle),
+ ..
+ } => {
+ self.put_constant(const_handle, context.module)?;
+ }
+ crate::TypeInner::Array { .. } => return Err(Error::UnsupportedDynamicArrayLength),
+ _ => return Err(Error::Validation),
+ },
+ }
+ Ok(())
+ }
+
+ fn put_constant(
+ &mut self,
+ handle: Handle<crate::Constant>,
+ module: &crate::Module,
+ ) -> Result<(), Error> {
+ let constant = &module.constants[handle];
+ match constant.inner {
+ crate::ConstantInner::Sint(value) => {
+ write!(self.out, "{}", value)?;
+ }
+ crate::ConstantInner::Uint(value) => {
+ write!(self.out, "{}", value)?;
+ }
+ crate::ConstantInner::Float(value) => {
+ write!(self.out, "{}", value)?;
+ if value.fract() == 0.0 {
+ write!(self.out, ".0")?;
+ }
+ }
+ crate::ConstantInner::Bool(value) => {
+ write!(self.out, "{}", value)?;
+ }
+ crate::ConstantInner::Composite(ref constituents) => {
+ let ty_name = &self.names[&NameKey::Type(constant.ty)];
+ write!(self.out, "{}(", ty_name)?;
+ for (i, &handle) in constituents.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.put_constant(handle, module)?;
+ }
+ write!(self.out, ")")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn put_block(
+ &mut self,
+ level: Level,
+ statements: &[crate::Statement],
+ context: &ExpressionContext,
+ return_value: Option<&str>,
+ ) -> Result<(), Error> {
+ for statement in statements {
+ log::trace!("statement[{}] {:?}", level.0, statement);
+ match *statement {
+ crate::Statement::Block(ref block) => {
+ if !block.is_empty() {
+ writeln!(self.out, "{}{{", level)?;
+ self.put_block(level.next(), block, context, return_value)?;
+ writeln!(self.out, "{}}}", level)?;
+ }
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{}if (", level)?;
+ self.put_expression(condition, context)?;
+ writeln!(self.out, ") {{")?;
+ self.put_block(level.next(), accept, context, return_value)?;
+ if !reject.is_empty() {
+ writeln!(self.out, "{}}} else {{", level)?;
+ self.put_block(level.next(), reject, context, return_value)?;
+ }
+ writeln!(self.out, "{}}}", level)?;
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ ref default,
+ } => {
+ write!(self.out, "{}switch(", level)?;
+ self.put_expression(selector, context)?;
+ writeln!(self.out, ") {{")?;
+ let lcase = level.next();
+ for (&value, &(ref block, ref fall_through)) in cases.iter() {
+ writeln!(self.out, "{}case {}: {{", lcase, value)?;
+ self.put_block(lcase.next(), block, context, return_value)?;
+ if fall_through.is_none() {
+ writeln!(self.out, "{}break;", lcase.next())?;
+ }
+ writeln!(self.out, "{}}}", lcase)?;
+ }
+ writeln!(self.out, "{}default: {{", lcase)?;
+ self.put_block(lcase.next(), default, context, return_value)?;
+ writeln!(self.out, "{}}}", lcase)?;
+ writeln!(self.out, "{}}}", level)?;
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ } => {
+ writeln!(self.out, "{}while(true) {{", level)?;
+ self.put_block(level.next(), body, context, return_value)?;
+ if !continuing.is_empty() {
+ //TODO
+ }
+ writeln!(self.out, "{}}}", level)?;
+ }
+ crate::Statement::Break => {
+ writeln!(self.out, "{}break;", level)?;
+ }
+ crate::Statement::Continue => {
+ writeln!(self.out, "{}continue;", level)?;
+ }
+ crate::Statement::Return {
+ value: Some(expr_handle),
+ } => {
+ write!(self.out, "{}return ", level)?;
+ self.put_expression(expr_handle, context)?;
+ writeln!(self.out, ";")?;
+ }
+ crate::Statement::Return { value: None } => {
+ if let Some(string) = return_value {
+ writeln!(self.out, "{}return {};", level, string)?;
+ }
+ }
+ crate::Statement::Kill => {
+ writeln!(self.out, "{}discard_fragment();", level)?;
+ }
+ crate::Statement::Store { pointer, value } => {
+ //write!(self.out, "\t*")?;
+ write!(self.out, "{}", level)?;
+ self.put_expression(pointer, context)?;
+ write!(self.out, " = ")?;
+ self.put_expression(value, context)?;
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ pub fn write(&mut self, module: &crate::Module, options: &Options) -> Result<(), Error> {
+ self.names.clear();
+ Namer::process(module, RESERVED, &mut self.names);
+
+ writeln!(self.out, "#include <metal_stdlib>")?;
+ writeln!(self.out, "#include <simd/simd.h>")?;
+
+ writeln!(self.out)?;
+ self.write_type_defs(module)?;
+
+ writeln!(self.out)?;
+ self.write_functions(module, options)?;
+
+ Ok(())
+ }
+
+ fn write_type_defs(&mut self, module: &crate::Module) -> Result<(), Error> {
+ for (handle, ty) in module.types.iter() {
+ let name = &self.names[&NameKey::Type(handle)];
+ match ty.inner {
+ crate::TypeInner::Scalar { kind, .. } => {
+ write!(self.out, "typedef {} {}", scalar_kind_string(kind), name)?;
+ }
+ crate::TypeInner::Vector { size, kind, .. } => {
+ write!(
+ self.out,
+ "typedef {}{} {}",
+ scalar_kind_string(kind),
+ vector_size_string(size),
+ name
+ )?;
+ }
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ write!(
+ self.out,
+ "typedef {}{}x{} {}",
+ scalar_kind_string(crate::ScalarKind::Float),
+ vector_size_string(columns),
+ vector_size_string(rows),
+ name
+ )?;
+ }
+ crate::TypeInner::Pointer { base, class } => {
+ use crate::StorageClass as Sc;
+ let base_name = &self.names[&NameKey::Type(base)];
+ let class_name = match class {
+ Sc::Input | Sc::Output => continue,
+ Sc::Uniform => "constant",
+ Sc::Storage => "device",
+ Sc::Handle
+ | Sc::Private
+ | Sc::Function
+ | Sc::WorkGroup
+ | Sc::PushConstant => "",
+ };
+ write!(self.out, "typedef {} {} *{}", class_name, base_name, name)?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ let base_name = &self.names[&NameKey::Type(base)];
+ write!(self.out, "typedef {} {}[", base_name, name)?;
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ self.put_constant(const_handle, module)?;
+ write!(self.out, "]")?;
+ }
+ crate::ArraySize::Dynamic => write!(self.out, "1]")?,
+ }
+ }
+ crate::TypeInner::Struct { ref members } => {
+ writeln!(self.out, "struct {} {{", name)?;
+ for (index, member) in members.iter().enumerate() {
+ let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
+ let base_name = &self.names[&NameKey::Type(member.ty)];
+ write!(self.out, "\t{} {}", base_name, member_name)?;
+ match member.origin {
+ crate::MemberOrigin::Empty => {}
+ crate::MemberOrigin::BuiltIn(built_in) => {
+ ResolvedBinding::BuiltIn(built_in)
+ .try_fmt_decorated(&mut self.out, "")?;
+ }
+ crate::MemberOrigin::Offset(_) => {
+ //TODO
+ }
+ }
+ writeln!(self.out, ";")?;
+ }
+ write!(self.out, "}}")?;
+ }
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let dim_str = match dim {
+ crate::ImageDimension::D1 => "1d",
+ crate::ImageDimension::D2 => "2d",
+ crate::ImageDimension::D3 => "3d",
+ crate::ImageDimension::Cube => "Cube",
+ };
+ let (texture_str, msaa_str, kind, access) = match class {
+ crate::ImageClass::Sampled { kind, multi } => {
+ ("texture", if multi { "_ms" } else { "" }, kind, "sample")
+ }
+ crate::ImageClass::Depth => {
+ ("depth", "", crate::ScalarKind::Float, "sample")
+ }
+ crate::ImageClass::Storage(format) => {
+ let (_, global) = module
+ .global_variables
+ .iter()
+ .find(|(_, var)| var.ty == handle)
+ .expect("Unable to find a global variable using the image type");
+ let access = if global
+ .storage_access
+ .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
+ {
+ "read_write"
+ } else if global.storage_access.contains(crate::StorageAccess::STORE) {
+ "write"
+ } else if global.storage_access.contains(crate::StorageAccess::LOAD) {
+ "read"
+ } else {
+ return Err(Error::InvalidImageAccess(global.storage_access));
+ };
+ ("texture", "", format.into(), access)
+ }
+ };
+ let base_name = scalar_kind_string(kind);
+ let array_str = if arrayed { "_array" } else { "" };
+ write!(
+ self.out,
+ "typedef {}{}{}{}<{}, access::{}> {}",
+ texture_str, dim_str, msaa_str, array_str, base_name, access, name
+ )?;
+ }
+ crate::TypeInner::Sampler { comparison: _ } => {
+ write!(self.out, "typedef sampler {}", name)?;
+ }
+ }
+ writeln!(self.out, ";")?;
+ }
+ Ok(())
+ }
+
+ fn write_functions(&mut self, module: &crate::Module, options: &Options) -> Result<(), Error> {
+ for (fun_handle, fun) in module.functions.iter() {
+ self.typifier.resolve_all(
+ &fun.expressions,
+ &module.types,
+ &ResolveContext {
+ constants: &module.constants,
+ global_vars: &module.global_variables,
+ local_vars: &fun.local_variables,
+ functions: &module.functions,
+ arguments: &fun.arguments,
+ },
+ )?;
+
+ let fun_name = &self.names[&NameKey::Function(fun_handle)];
+ let result_type_name = match fun.return_type {
+ Some(ret_ty) => &self.names[&NameKey::Type(ret_ty)],
+ None => "void",
+ };
+ writeln!(self.out, "{} {}(", result_type_name, fun_name)?;
+
+ for (index, arg) in fun.arguments.iter().enumerate() {
+ let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
+ let param_type_name = &self.names[&NameKey::Type(arg.ty)];
+ let separator = separate(index + 1 == fun.arguments.len());
+ writeln!(self.out, "\t{} {}{}", param_type_name, name, separator)?;
+ }
+ writeln!(self.out, ") {{")?;
+
+ for (local_handle, local) in fun.local_variables.iter() {
+ let ty_name = &self.names[&NameKey::Type(local.ty)];
+ let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)];
+ write!(self.out, "\t{} {}", ty_name, local_name)?;
+ if let Some(value) = local.init {
+ write!(self.out, " = ")?;
+ self.put_constant(value, module)?;
+ }
+ writeln!(self.out, ";")?;
+ }
+
+ let context = ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::Handle(fun_handle),
+ module,
+ };
+ self.put_block(Level(1), &fun.body, &context, None)?;
+ writeln!(self.out, "}}")?;
+ }
+
+ for (ep_index, (&(stage, _), ep)) in module.entry_points.iter().enumerate() {
+ let fun = &ep.function;
+ self.typifier.resolve_all(
+ &fun.expressions,
+ &module.types,
+ &ResolveContext {
+ constants: &module.constants,
+ global_vars: &module.global_variables,
+ local_vars: &fun.local_variables,
+ functions: &module.functions,
+ arguments: &fun.arguments,
+ },
+ )?;
+
+ // find the entry point(s) and inputs/outputs
+ let mut last_used_global = None;
+ for ((handle, var), &usage) in module.global_variables.iter().zip(&fun.global_usage) {
+ match var.class {
+ crate::StorageClass::Input => {
+ if let Some(crate::Binding::Location(_)) = var.binding {
+ continue;
+ }
+ }
+ crate::StorageClass::Output => continue,
+ _ => {}
+ }
+ if !usage.is_empty() {
+ last_used_global = Some(handle);
+ }
+ }
+
+ let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
+ let output_name = format!("{}Output", fun_name);
+ let location_input_name = format!("{}Input", fun_name);
+
+ let (em_str, in_mode, out_mode) = match stage {
+ crate::ShaderStage::Vertex => (
+ "vertex",
+ LocationMode::VertexInput,
+ LocationMode::Intermediate,
+ ),
+ crate::ShaderStage::Fragment { .. } => (
+ "fragment",
+ LocationMode::Intermediate,
+ LocationMode::FragmentOutput,
+ ),
+ crate::ShaderStage::Compute { .. } => {
+ ("kernel", LocationMode::Uniform, LocationMode::Uniform)
+ }
+ };
+
+ let return_value = match stage {
+ crate::ShaderStage::Vertex | crate::ShaderStage::Fragment => {
+ // make dedicated input/output structs
+ writeln!(self.out, "struct {} {{", location_input_name)?;
+
+ for ((handle, var), &usage) in
+ module.global_variables.iter().zip(&fun.global_usage)
+ {
+ if var.class != crate::StorageClass::Input
+ || !usage.contains(crate::GlobalUse::LOAD)
+ {
+ continue;
+ }
+ // if it's a struct, lift all the built-in contents up to the root
+ if let crate::TypeInner::Struct { ref members } = module.types[var.ty].inner
+ {
+ for (index, member) in members.iter().enumerate() {
+ if let crate::MemberOrigin::BuiltIn(built_in) = member.origin {
+ let name =
+ &self.names[&NameKey::StructMember(var.ty, index as u32)];
+ let ty_name = &self.names[&NameKey::Type(member.ty)];
+ write!(self.out, "\t{} {}", ty_name, name)?;
+ ResolvedBinding::BuiltIn(built_in)
+ .try_fmt_decorated(&mut self.out, ";\n")?;
+ }
+ }
+ } else if let Some(ref binding @ crate::Binding::Location(_)) = var.binding
+ {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage: crate::GlobalUse::empty(),
+ };
+ let resolved = options.resolve_binding(stage, binding, in_mode)?;
+
+ write!(self.out, "\t")?;
+ tyvar.try_fmt(&mut self.out)?;
+ resolved.try_fmt_decorated(&mut self.out, ";\n")?;
+ }
+ }
+ writeln!(self.out, "}};")?;
+
+ writeln!(self.out, "struct {} {{", output_name)?;
+ for ((handle, var), &usage) in
+ module.global_variables.iter().zip(&fun.global_usage)
+ {
+ if var.class != crate::StorageClass::Output
+ || !usage.contains(crate::GlobalUse::STORE)
+ {
+ continue;
+ }
+ // if it's a struct, lift all the built-in contents up to the root
+ if let crate::TypeInner::Struct { ref members } = module.types[var.ty].inner
+ {
+ for (index, member) in members.iter().enumerate() {
+ let name =
+ &self.names[&NameKey::StructMember(var.ty, index as u32)];
+ let ty_name = &self.names[&NameKey::Type(member.ty)];
+ match member.origin {
+ crate::MemberOrigin::Empty => {}
+ crate::MemberOrigin::BuiltIn(built_in) => {
+ write!(self.out, "\t{} {}", ty_name, name)?;
+ ResolvedBinding::BuiltIn(built_in)
+ .try_fmt_decorated(&mut self.out, ";\n")?;
+ }
+ crate::MemberOrigin::Offset(_) => {
+ //TODO
+ }
+ }
+ }
+ } else {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage: crate::GlobalUse::empty(),
+ };
+ write!(self.out, "\t")?;
+ tyvar.try_fmt(&mut self.out)?;
+ if let Some(ref binding) = var.binding {
+ let resolved = options.resolve_binding(stage, binding, out_mode)?;
+ resolved.try_fmt_decorated(&mut self.out, "")?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ }
+ writeln!(self.out, "}};")?;
+
+ writeln!(self.out, "{} {} {}(", em_str, output_name, fun_name)?;
+ let separator = separate(last_used_global.is_none());
+ writeln!(
+ self.out,
+ "\t{} {} [[stage_in]]{}",
+ location_input_name, LOCATION_INPUT_STRUCT_NAME, separator
+ )?;
+
+ Some(OUTPUT_STRUCT_NAME)
+ }
+ crate::ShaderStage::Compute => {
+ writeln!(self.out, "{} void {}(", em_str, fun_name)?;
+ None
+ }
+ };
+
+ for ((handle, var), &usage) in module.global_variables.iter().zip(&fun.global_usage) {
+ if usage.is_empty() || var.class == crate::StorageClass::Output {
+ continue;
+ }
+ if var.class == crate::StorageClass::Input {
+ if let Some(crate::Binding::Location(_)) = var.binding {
+ // location inputs are put into a separate struct
+ continue;
+ }
+ }
+ let loc_mode = match (stage, var.class) {
+ (crate::ShaderStage::Vertex, crate::StorageClass::Input) => {
+ LocationMode::VertexInput
+ }
+ (crate::ShaderStage::Vertex, crate::StorageClass::Output)
+ | (crate::ShaderStage::Fragment { .. }, crate::StorageClass::Input) => {
+ LocationMode::Intermediate
+ }
+ (crate::ShaderStage::Fragment { .. }, crate::StorageClass::Output) => {
+ LocationMode::FragmentOutput
+ }
+ _ => LocationMode::Uniform,
+ };
+ let resolved =
+ options.resolve_binding(stage, var.binding.as_ref().unwrap(), loc_mode)?;
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage,
+ };
+ let separator = separate(last_used_global == Some(handle));
+ write!(self.out, "\t")?;
+ tyvar.try_fmt(&mut self.out)?;
+ resolved.try_fmt_decorated(&mut self.out, separator)?;
+ if let Some(value) = var.init {
+ write!(self.out, " = ")?;
+ self.put_constant(value, module)?;
+ }
+ writeln!(self.out)?;
+ }
+ writeln!(self.out, ") {{")?;
+
+ match stage {
+ crate::ShaderStage::Vertex | crate::ShaderStage::Fragment => {
+ writeln!(self.out, "\t{} {};", output_name, OUTPUT_STRUCT_NAME)?;
+ }
+ crate::ShaderStage::Compute => {}
+ }
+ for (local_handle, local) in fun.local_variables.iter() {
+ let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)];
+ let ty_name = &self.names[&NameKey::Type(local.ty)];
+ write!(self.out, "\t{} {}", ty_name, name)?;
+ if let Some(value) = local.init {
+ write!(self.out, " = ")?;
+ self.put_constant(value, module)?;
+ }
+ writeln!(self.out, ";")?;
+ }
+
+ let context = ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::EntryPoint(ep_index as _),
+ module,
+ };
+ self.put_block(Level(1), &fun.body, &context, return_value)?;
+ writeln!(self.out, "}}")?;
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs
new file mode 100644
index 0000000000..5facbe8b69
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/helpers.rs
@@ -0,0 +1,20 @@
+use spirv::Word;
+
+pub(crate) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> {
+ bytes
+ .chunks(4)
+ .map(|chars| chars.iter().rev().fold(0u32, |u, c| (u << 8) | *c as u32))
+ .collect()
+}
+
+pub(crate) fn string_to_words(input: &str) -> Vec<Word> {
+ let bytes = input.as_bytes();
+ let mut words = bytes_to_words(bytes);
+
+ if bytes.len() % 4 == 0 {
+ // nul-termination
+ words.push(0x0u32);
+ }
+
+ words
+}
diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs
new file mode 100644
index 0000000000..ab8e56844a
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/instructions.rs
@@ -0,0 +1,708 @@
+use crate::back::spv::{helpers, Instruction};
+use spirv::{Op, Word};
+
+pub(super) enum Signedness {
+ Unsigned = 0,
+ Signed = 1,
+}
+
+//
+// Debug Instructions
+//
+
+pub(super) fn instruction_source(
+ source_language: spirv::SourceLanguage,
+ version: u32,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::Source);
+ instruction.add_operand(source_language as u32);
+ instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes()));
+ instruction
+}
+
+pub(super) fn instruction_name(target_id: Word, name: &str) -> Instruction {
+ let mut instruction = Instruction::new(Op::Name);
+ instruction.add_operand(target_id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+}
+
+//
+// Annotation Instructions
+//
+
+pub(super) fn instruction_decorate(
+ target_id: Word,
+ decoration: spirv::Decoration,
+ operands: &[Word],
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::Decorate);
+ instruction.add_operand(target_id);
+ instruction.add_operand(decoration as u32);
+
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+
+ instruction
+}
+
+//
+// Extension Instructions
+//
+
+pub(super) fn instruction_ext_inst_import(id: Word, name: &str) -> Instruction {
+ let mut instruction = Instruction::new(Op::ExtInstImport);
+ instruction.set_result(id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+}
+
+//
+// Mode-Setting Instructions
+//
+
+pub(super) fn instruction_memory_model(
+ addressing_model: spirv::AddressingModel,
+ memory_model: spirv::MemoryModel,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::MemoryModel);
+ instruction.add_operand(addressing_model as u32);
+ instruction.add_operand(memory_model as u32);
+ instruction
+}
+
+pub(super) fn instruction_entry_point(
+ execution_model: spirv::ExecutionModel,
+ entry_point_id: Word,
+ name: &str,
+ interface_ids: &[Word],
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::EntryPoint);
+ instruction.add_operand(execution_model as u32);
+ instruction.add_operand(entry_point_id);
+ instruction.add_operands(helpers::string_to_words(name));
+
+ for interface_id in interface_ids {
+ instruction.add_operand(*interface_id);
+ }
+
+ instruction
+}
+
+pub(super) fn instruction_execution_mode(
+ entry_point_id: Word,
+ execution_mode: spirv::ExecutionMode,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::ExecutionMode);
+ instruction.add_operand(entry_point_id);
+ instruction.add_operand(execution_mode as u32);
+ instruction
+}
+
+pub(super) fn instruction_capability(capability: spirv::Capability) -> Instruction {
+ let mut instruction = Instruction::new(Op::Capability);
+ instruction.add_operand(capability as u32);
+ instruction
+}
+
+//
+// Type-Declaration Instructions
+//
+
+pub(super) fn instruction_type_void(id: Word) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeVoid);
+ instruction.set_result(id);
+ instruction
+}
+
+pub(super) fn instruction_type_bool(id: Word) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeBool);
+ instruction.set_result(id);
+ instruction
+}
+
+pub(super) fn instruction_type_int(id: Word, width: Word, signedness: Signedness) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeInt);
+ instruction.set_result(id);
+ instruction.add_operand(width);
+ instruction.add_operand(signedness as u32);
+ instruction
+}
+
+pub(super) fn instruction_type_float(id: Word, width: Word) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeFloat);
+ instruction.set_result(id);
+ instruction.add_operand(width);
+ instruction
+}
+
+pub(super) fn instruction_type_vector(
+ id: Word,
+ component_type_id: Word,
+ component_count: crate::VectorSize,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeVector);
+ instruction.set_result(id);
+ instruction.add_operand(component_type_id);
+ instruction.add_operand(component_count as u32);
+ instruction
+}
+
+pub(super) fn instruction_type_matrix(
+ id: Word,
+ column_type_id: Word,
+ column_count: crate::VectorSize,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeMatrix);
+ instruction.set_result(id);
+ instruction.add_operand(column_type_id);
+ instruction.add_operand(column_count as u32);
+ instruction
+}
+
+pub(super) fn instruction_type_image(
+ id: Word,
+ sampled_type_id: Word,
+ dim: spirv::Dim,
+ arrayed: bool,
+ image_class: crate::ImageClass,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeImage);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_type_id);
+ instruction.add_operand(dim as u32);
+
+ instruction.add_operand(match image_class {
+ crate::ImageClass::Depth => 1,
+ _ => 0,
+ });
+ instruction.add_operand(arrayed as u32);
+ instruction.add_operand(match image_class {
+ crate::ImageClass::Sampled { multi: true, .. } => 1,
+ _ => 0,
+ });
+ instruction.add_operand(match image_class {
+ crate::ImageClass::Sampled { .. } => 1,
+ _ => 0,
+ });
+
+ let format = match image_class {
+ crate::ImageClass::Storage(format) => match format {
+ crate::StorageFormat::R8Unorm => spirv::ImageFormat::R8,
+ crate::StorageFormat::R8Snorm => spirv::ImageFormat::R8Snorm,
+ crate::StorageFormat::R8Uint => spirv::ImageFormat::R8ui,
+ crate::StorageFormat::R8Sint => spirv::ImageFormat::R8i,
+ crate::StorageFormat::R16Uint => spirv::ImageFormat::R16ui,
+ crate::StorageFormat::R16Sint => spirv::ImageFormat::R16i,
+ crate::StorageFormat::R16Float => spirv::ImageFormat::R16f,
+ crate::StorageFormat::Rg8Unorm => spirv::ImageFormat::Rg8,
+ crate::StorageFormat::Rg8Snorm => spirv::ImageFormat::Rg8Snorm,
+ crate::StorageFormat::Rg8Uint => spirv::ImageFormat::Rg8ui,
+ crate::StorageFormat::Rg8Sint => spirv::ImageFormat::Rg8i,
+ crate::StorageFormat::R32Uint => spirv::ImageFormat::R32ui,
+ crate::StorageFormat::R32Sint => spirv::ImageFormat::R32i,
+ crate::StorageFormat::R32Float => spirv::ImageFormat::R32f,
+ crate::StorageFormat::Rg16Uint => spirv::ImageFormat::Rg16ui,
+ crate::StorageFormat::Rg16Sint => spirv::ImageFormat::Rg16i,
+ crate::StorageFormat::Rg16Float => spirv::ImageFormat::Rg16f,
+ crate::StorageFormat::Rgba8Unorm => spirv::ImageFormat::Rgba8,
+ crate::StorageFormat::Rgba8Snorm => spirv::ImageFormat::Rgba8Snorm,
+ crate::StorageFormat::Rgba8Uint => spirv::ImageFormat::Rgba8ui,
+ crate::StorageFormat::Rgba8Sint => spirv::ImageFormat::Rgba8i,
+ crate::StorageFormat::Rgb10a2Unorm => spirv::ImageFormat::Rgb10a2ui,
+ crate::StorageFormat::Rg11b10Float => spirv::ImageFormat::R11fG11fB10f,
+ crate::StorageFormat::Rg32Uint => spirv::ImageFormat::Rg32ui,
+ crate::StorageFormat::Rg32Sint => spirv::ImageFormat::Rg32i,
+ crate::StorageFormat::Rg32Float => spirv::ImageFormat::Rg32f,
+ crate::StorageFormat::Rgba16Uint => spirv::ImageFormat::Rgba16ui,
+ crate::StorageFormat::Rgba16Sint => spirv::ImageFormat::Rgba16i,
+ crate::StorageFormat::Rgba16Float => spirv::ImageFormat::Rgba16f,
+ crate::StorageFormat::Rgba32Uint => spirv::ImageFormat::Rgba32ui,
+ crate::StorageFormat::Rgba32Sint => spirv::ImageFormat::Rgba32i,
+ crate::StorageFormat::Rgba32Float => spirv::ImageFormat::Rgba32f,
+ },
+ _ => spirv::ImageFormat::Unknown,
+ };
+
+ instruction.add_operand(format as u32);
+ instruction
+}
+
+pub(super) fn instruction_type_sampler(id: Word) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeSampler);
+ instruction.set_result(id);
+ instruction
+}
+
+pub(super) fn instruction_type_sampled_image(id: Word, image_type_id: Word) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeSampledImage);
+ instruction.set_result(id);
+ instruction.add_operand(image_type_id);
+ instruction
+}
+
+pub(super) fn instruction_type_array(
+ id: Word,
+ element_type_id: Word,
+ length_id: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeArray);
+ instruction.set_result(id);
+ instruction.add_operand(element_type_id);
+ instruction.add_operand(length_id);
+ instruction
+}
+
+pub(super) fn instruction_type_runtime_array(id: Word, element_type_id: Word) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeRuntimeArray);
+ instruction.set_result(id);
+ instruction.add_operand(element_type_id);
+ instruction
+}
+
+pub(super) fn instruction_type_struct(id: Word, member_ids: &[Word]) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeStruct);
+ instruction.set_result(id);
+
+ for member_id in member_ids {
+ instruction.add_operand(*member_id)
+ }
+
+ instruction
+}
+
+pub(super) fn instruction_type_pointer(
+ id: Word,
+ storage_class: spirv::StorageClass,
+ type_id: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypePointer);
+ instruction.set_result(id);
+ instruction.add_operand(storage_class as u32);
+ instruction.add_operand(type_id);
+ instruction
+}
+
+pub(super) fn instruction_type_function(
+ id: Word,
+ return_type_id: Word,
+ parameter_ids: &[Word],
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::TypeFunction);
+ instruction.set_result(id);
+ instruction.add_operand(return_type_id);
+
+ for parameter_id in parameter_ids {
+ instruction.add_operand(*parameter_id);
+ }
+
+ instruction
+}
+
+//
+// Constant-Creation Instructions
+//
+
+pub(super) fn instruction_constant_true(result_type_id: Word, id: Word) -> Instruction {
+ let mut instruction = Instruction::new(Op::ConstantTrue);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+}
+
+pub(super) fn instruction_constant_false(result_type_id: Word, id: Word) -> Instruction {
+ let mut instruction = Instruction::new(Op::ConstantFalse);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+}
+
+pub(super) fn instruction_constant(result_type_id: Word, id: Word, values: &[Word]) -> Instruction {
+ let mut instruction = Instruction::new(Op::Constant);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for value in values {
+ instruction.add_operand(*value);
+ }
+
+ instruction
+}
+
+pub(super) fn instruction_constant_composite(
+ result_type_id: Word,
+ id: Word,
+ constituent_ids: &[Word],
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::ConstantComposite);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for constituent_id in constituent_ids {
+ instruction.add_operand(*constituent_id);
+ }
+
+ instruction
+}
+
+//
+// Memory Instructions
+//
+
+pub(super) fn instruction_variable(
+ result_type_id: Word,
+ id: Word,
+ storage_class: spirv::StorageClass,
+ initializer_id: Option<Word>,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::Variable);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(storage_class as u32);
+
+ if let Some(initializer_id) = initializer_id {
+ instruction.add_operand(initializer_id);
+ }
+
+ instruction
+}
+
+pub(super) fn instruction_load(
+ result_type_id: Word,
+ id: Word,
+ pointer_type_id: Word,
+ memory_access: Option<spirv::MemoryAccess>,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::Load);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer_type_id);
+
+ if let Some(memory_access) = memory_access {
+ instruction.add_operand(memory_access.bits());
+ }
+
+ instruction
+}
+
+pub(super) fn instruction_store(
+ pointer_type_id: Word,
+ object_id: Word,
+ memory_access: Option<spirv::MemoryAccess>,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::Store);
+ instruction.add_operand(pointer_type_id);
+ instruction.add_operand(object_id);
+
+ if let Some(memory_access) = memory_access {
+ instruction.add_operand(memory_access.bits());
+ }
+
+ instruction
+}
+
+pub(super) fn instruction_access_chain(
+ result_type_id: Word,
+ id: Word,
+ base_id: Word,
+ index_ids: &[Word],
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::AccessChain);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(base_id);
+
+ for index_id in index_ids {
+ instruction.add_operand(*index_id);
+ }
+
+ instruction
+}
+
+//
+// Function Instructions
+//
+
+pub(super) fn instruction_function(
+ return_type_id: Word,
+ id: Word,
+ function_control: spirv::FunctionControl,
+ function_type_id: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::Function);
+ instruction.set_type(return_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(function_control.bits());
+ instruction.add_operand(function_type_id);
+ instruction
+}
+
+pub(super) fn instruction_function_parameter(result_type_id: Word, id: Word) -> Instruction {
+ let mut instruction = Instruction::new(Op::FunctionParameter);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+}
+
+pub(super) fn instruction_function_end() -> Instruction {
+ Instruction::new(Op::FunctionEnd)
+}
+
+pub(super) fn instruction_function_call(
+ result_type_id: Word,
+ id: Word,
+ function_id: Word,
+ argument_ids: &[Word],
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::FunctionCall);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(function_id);
+
+ for argument_id in argument_ids {
+ instruction.add_operand(*argument_id);
+ }
+
+ instruction
+}
+
+//
+// Image Instructions
+//
+pub(super) fn instruction_sampled_image(
+ result_type_id: Word,
+ id: Word,
+ image: Word,
+ sampler: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::SampledImage);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction.add_operand(sampler);
+ instruction
+}
+
+pub(super) fn instruction_image_sample_implicit_lod(
+ result_type_id: Word,
+ id: Word,
+ sampled_image: Word,
+ coordinates: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::ImageSampleImplicitLod);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_image);
+ instruction.add_operand(coordinates);
+ instruction
+}
+
+//
+// Conversion Instructions
+//
+pub(super) fn instruction_unary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ value: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(value);
+ instruction
+}
+
+//
+// Composite Instructions
+//
+
+pub(super) fn instruction_composite_construct(
+ result_type_id: Word,
+ id: Word,
+ constituent_ids: &[Word],
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::CompositeConstruct);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for constituent_id in constituent_ids {
+ instruction.add_operand(*constituent_id);
+ }
+
+ instruction
+}
+
+//
+// Arithmetic Instructions
+//
+fn instruction_binary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction
+}
+
+pub(super) fn instruction_i_sub(
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+) -> Instruction {
+ instruction_binary(Op::ISub, result_type_id, id, operand_1, operand_2)
+}
+
+pub(super) fn instruction_f_sub(
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+) -> Instruction {
+ instruction_binary(Op::FSub, result_type_id, id, operand_1, operand_2)
+}
+
+pub(super) fn instruction_i_mul(
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+) -> Instruction {
+ instruction_binary(Op::IMul, result_type_id, id, operand_1, operand_2)
+}
+
+pub(super) fn instruction_f_mul(
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+) -> Instruction {
+ instruction_binary(Op::FMul, result_type_id, id, operand_1, operand_2)
+}
+
+pub(super) fn instruction_vector_times_scalar(
+ result_type_id: Word,
+ id: Word,
+ vector_type_id: Word,
+ scalar_type_id: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::VectorTimesScalar);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(vector_type_id);
+ instruction.add_operand(scalar_type_id);
+ instruction
+}
+
+pub(super) fn instruction_matrix_times_scalar(
+ result_type_id: Word,
+ id: Word,
+ matrix_id: Word,
+ scalar_id: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::MatrixTimesScalar);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(matrix_id);
+ instruction.add_operand(scalar_id);
+ instruction
+}
+
+pub(super) fn instruction_vector_times_matrix(
+ result_type_id: Word,
+ id: Word,
+ vector_id: Word,
+ matrix_id: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::VectorTimesMatrix);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(vector_id);
+ instruction.add_operand(matrix_id);
+ instruction
+}
+
+pub(super) fn instruction_matrix_times_vector(
+ result_type_id: Word,
+ id: Word,
+ matrix_id: Word,
+ vector_id: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::MatrixTimesVector);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(matrix_id);
+ instruction.add_operand(vector_id);
+ instruction
+}
+
+pub(super) fn instruction_matrix_times_matrix(
+ result_type_id: Word,
+ id: Word,
+ left_matrix: Word,
+ right_matrix: Word,
+) -> Instruction {
+ let mut instruction = Instruction::new(Op::MatrixTimesMatrix);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(left_matrix);
+ instruction.add_operand(right_matrix);
+ instruction
+}
+
+//
+// Bit Instructions
+//
+
+pub(super) fn instruction_bitwise_and(
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+) -> Instruction {
+ instruction_binary(Op::BitwiseAnd, result_type_id, id, operand_1, operand_2)
+}
+
+//
+// Relational and Logical Instructions
+//
+
+//
+// Derivative Instructions
+//
+
+//
+// Control-Flow Instructions
+//
+
+pub(super) fn instruction_label(id: Word) -> Instruction {
+ let mut instruction = Instruction::new(Op::Label);
+ instruction.set_result(id);
+ instruction
+}
+
+pub(super) fn instruction_return() -> Instruction {
+ Instruction::new(Op::Return)
+}
+
+pub(super) fn instruction_return_value(value_id: Word) -> Instruction {
+ let mut instruction = Instruction::new(Op::ReturnValue);
+ instruction.add_operand(value_id);
+ instruction
+}
+
+//
+// Atomic Instructions
+//
+
+//
+// Primitive Instructions
+//
diff --git a/third_party/rust/naga/src/back/spv/layout.rs b/third_party/rust/naga/src/back/spv/layout.rs
new file mode 100644
index 0000000000..006e785317
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/layout.rs
@@ -0,0 +1,91 @@
+use crate::back::spv::{Instruction, LogicalLayout, PhysicalLayout};
+use spirv::*;
+use std::iter;
+
+impl PhysicalLayout {
+ pub(super) fn new(header: &crate::Header) -> Self {
+ let version: Word = ((header.version.0 as u32) << 16)
+ | ((header.version.1 as u32) << 8)
+ | header.version.2 as u32;
+
+ PhysicalLayout {
+ magic_number: MAGIC_NUMBER,
+ version,
+ generator: header.generator,
+ bound: 0,
+ instruction_schema: 0x0u32,
+ }
+ }
+
+ pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(iter::once(self.magic_number));
+ sink.extend(iter::once(self.version));
+ sink.extend(iter::once(self.generator));
+ sink.extend(iter::once(self.bound));
+ sink.extend(iter::once(self.instruction_schema));
+ }
+
+ pub(super) fn supports_storage_buffers(&self) -> bool {
+ self.version >= 0x10300
+ }
+}
+
+impl LogicalLayout {
+ pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(self.capabilities.iter().cloned());
+ sink.extend(self.extensions.iter().cloned());
+ sink.extend(self.ext_inst_imports.iter().cloned());
+ sink.extend(self.memory_model.iter().cloned());
+ sink.extend(self.entry_points.iter().cloned());
+ sink.extend(self.execution_modes.iter().cloned());
+ sink.extend(self.debugs.iter().cloned());
+ sink.extend(self.annotations.iter().cloned());
+ sink.extend(self.declarations.iter().cloned());
+ sink.extend(self.function_declarations.iter().cloned());
+ sink.extend(self.function_definitions.iter().cloned());
+ }
+}
+
+impl Instruction {
+ pub(super) fn new(op: Op) -> Self {
+ Instruction {
+ op,
+ wc: 1, // Always start at 1 for the first word (OP + WC),
+ type_id: None,
+ result_id: None,
+ operands: vec![],
+ }
+ }
+
+ #[allow(clippy::panic)]
+ pub(super) fn set_type(&mut self, id: Word) {
+ assert!(self.type_id.is_none(), "Type can only be set once");
+ self.type_id = Some(id);
+ self.wc += 1;
+ }
+
+ #[allow(clippy::panic)]
+ pub(super) fn set_result(&mut self, id: Word) {
+ assert!(self.result_id.is_none(), "Result can only be set once");
+ self.result_id = Some(id);
+ self.wc += 1;
+ }
+
+ pub(super) fn add_operand(&mut self, operand: Word) {
+ self.operands.push(operand);
+ self.wc += 1;
+ }
+
+ pub(super) fn add_operands(&mut self, operands: Vec<Word>) {
+ for operand in operands.into_iter() {
+ self.add_operand(operand)
+ }
+ }
+
+ pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(Some((self.wc << 16 | self.op as u32) as u32));
+ sink.extend(self.type_id);
+ sink.extend(self.result_id);
+ sink.extend(self.operands.iter().cloned());
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/layout_tests.rs b/third_party/rust/naga/src/back/spv/layout_tests.rs
new file mode 100644
index 0000000000..37024b238f
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/layout_tests.rs
@@ -0,0 +1,166 @@
+use crate::back::spv::test_framework::*;
+use crate::back::spv::{helpers, Instruction, LogicalLayout, PhysicalLayout};
+use crate::Header;
+use spirv::*;
+
+#[test]
+fn test_physical_layout_in_words() {
+ let header = Header {
+ generator: 0,
+ version: (1, 2, 3),
+ };
+ let bound = 5;
+
+ let mut output = vec![];
+ let mut layout = PhysicalLayout::new(&header);
+ layout.bound = bound;
+
+ layout.in_words(&mut output);
+
+ assert_eq!(output[0], spirv::MAGIC_NUMBER);
+ assert_eq!(
+ output[1],
+ to_word(&[header.version.0, header.version.1, header.version.2, 1])
+ );
+ assert_eq!(output[2], 0);
+ assert_eq!(output[3], bound);
+ assert_eq!(output[4], 0);
+}
+
+#[test]
+fn test_logical_layout_in_words() {
+ let mut output = vec![];
+ let mut layout = LogicalLayout::default();
+ let layout_vectors = 11;
+ let mut instructions = Vec::with_capacity(layout_vectors);
+
+ let vector_names = &[
+ "Capabilities",
+ "Extensions",
+ "External Instruction Imports",
+ "Memory Model",
+ "Entry Points",
+ "Execution Modes",
+ "Debugs",
+ "Annotations",
+ "Declarations",
+ "Function Declarations",
+ "Function Definitions",
+ ];
+
+ for i in 0..layout_vectors {
+ let mut dummy_instruction = Instruction::new(Op::Constant);
+ dummy_instruction.set_type((i + 1) as u32);
+ dummy_instruction.set_result((i + 2) as u32);
+ dummy_instruction.add_operand((i + 3) as u32);
+ dummy_instruction.add_operands(helpers::string_to_words(
+ format!("This is the vector: {}", vector_names[i]).as_str(),
+ ));
+ instructions.push(dummy_instruction);
+ }
+
+ instructions[0].to_words(&mut layout.capabilities);
+ instructions[1].to_words(&mut layout.extensions);
+ instructions[2].to_words(&mut layout.ext_inst_imports);
+ instructions[3].to_words(&mut layout.memory_model);
+ instructions[4].to_words(&mut layout.entry_points);
+ instructions[5].to_words(&mut layout.execution_modes);
+ instructions[6].to_words(&mut layout.debugs);
+ instructions[7].to_words(&mut layout.annotations);
+ instructions[8].to_words(&mut layout.declarations);
+ instructions[9].to_words(&mut layout.function_declarations);
+ instructions[10].to_words(&mut layout.function_definitions);
+
+ layout.in_words(&mut output);
+
+ let mut index: usize = 0;
+ for instruction in instructions {
+ let wc = instruction.wc as usize;
+ let instruction_output = &output[index..index + wc];
+ validate_instruction(instruction_output, &instruction);
+ index += wc;
+ }
+}
+
+#[test]
+fn test_instruction_set_type() {
+ let ty = 1;
+ let mut instruction = Instruction::new(Op::Constant);
+ assert_eq!(instruction.wc, 1);
+
+ instruction.set_type(ty);
+ assert_eq!(instruction.type_id.unwrap(), ty);
+ assert_eq!(instruction.wc, 2);
+}
+
+#[test]
+#[should_panic]
+fn test_instruction_set_type_twice() {
+ let ty = 1;
+ let mut instruction = Instruction::new(Op::Constant);
+ instruction.set_type(ty);
+ instruction.set_type(ty);
+}
+
+#[test]
+fn test_instruction_set_result() {
+ let result = 1;
+ let mut instruction = Instruction::new(Op::Constant);
+ assert_eq!(instruction.wc, 1);
+
+ instruction.set_result(result);
+ assert_eq!(instruction.result_id.unwrap(), result);
+ assert_eq!(instruction.wc, 2);
+}
+
+#[test]
+#[should_panic]
+fn test_instruction_set_result_twice() {
+ let result = 1;
+ let mut instruction = Instruction::new(Op::Constant);
+ instruction.set_result(result);
+ instruction.set_result(result);
+}
+
+#[test]
+fn test_instruction_add_operand() {
+ let operand = 1;
+ let mut instruction = Instruction::new(Op::Constant);
+ assert_eq!(instruction.operands.len(), 0);
+ assert_eq!(instruction.wc, 1);
+
+ instruction.add_operand(operand);
+ assert_eq!(instruction.operands.len(), 1);
+ assert_eq!(instruction.wc, 2);
+}
+
+#[test]
+fn test_instruction_add_operands() {
+ let operands = vec![1, 2, 3];
+ let mut instruction = Instruction::new(Op::Constant);
+ assert_eq!(instruction.operands.len(), 0);
+ assert_eq!(instruction.wc, 1);
+
+ instruction.add_operands(operands);
+ assert_eq!(instruction.operands.len(), 3);
+ assert_eq!(instruction.wc, 4);
+}
+
+#[test]
+fn test_instruction_to_words() {
+ let ty = 1;
+ let result = 2;
+ let operand = 3;
+ let mut instruction = Instruction::new(Op::Constant);
+ instruction.set_type(ty);
+ instruction.set_result(result);
+ instruction.add_operand(operand);
+
+ let mut output = vec![];
+ instruction.to_words(&mut output);
+ validate_instruction(output.as_slice(), &instruction);
+}
+
+fn to_word(bytes: &[u8]) -> Word {
+ ((bytes[0] as u32) << 16) | ((bytes[1] as u32) << 8) | bytes[2] as u32
+}
diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs
new file mode 100644
index 0000000000..15f1598357
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/mod.rs
@@ -0,0 +1,52 @@
+mod helpers;
+mod instructions;
+mod layout;
+mod writer;
+
+#[cfg(test)]
+mod test_framework;
+
+#[cfg(test)]
+mod layout_tests;
+
+pub use writer::Writer;
+
+use spirv::*;
+
+bitflags::bitflags! {
+ pub struct WriterFlags: u32 {
+ const NONE = 0x0;
+ const DEBUG = 0x1;
+ }
+}
+
+struct PhysicalLayout {
+ magic_number: Word,
+ version: Word,
+ generator: Word,
+ bound: Word,
+ instruction_schema: Word,
+}
+
+#[derive(Default)]
+struct LogicalLayout {
+ capabilities: Vec<Word>,
+ extensions: Vec<Word>,
+ ext_inst_imports: Vec<Word>,
+ memory_model: Vec<Word>,
+ entry_points: Vec<Word>,
+ execution_modes: Vec<Word>,
+ debugs: Vec<Word>,
+ annotations: Vec<Word>,
+ declarations: Vec<Word>,
+ function_declarations: Vec<Word>,
+ function_definitions: Vec<Word>,
+}
+
+pub(self) struct Instruction {
+ op: Op,
+ wc: u32,
+ type_id: Option<Word>,
+ result_id: Option<Word>,
+ operands: Vec<Word>,
+}
diff --git a/third_party/rust/naga/src/back/spv/test_framework.rs b/third_party/rust/naga/src/back/spv/test_framework.rs
new file mode 100644
index 0000000000..be2fa74fe1
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/test_framework.rs
@@ -0,0 +1,27 @@
+pub(super) fn validate_instruction(
+ words: &[spirv::Word],
+ instruction: &crate::back::spv::Instruction,
+) {
+ let mut inst_index = 0;
+ let (wc, op) = ((words[inst_index] >> 16) as u16, words[inst_index] as u16);
+ inst_index += 1;
+
+ assert_eq!(wc, words.len() as u16);
+ assert_eq!(op, instruction.op as u16);
+
+ if instruction.type_id.is_some() {
+ assert_eq!(words[inst_index], instruction.type_id.unwrap());
+ inst_index += 1;
+ }
+
+ if instruction.result_id.is_some() {
+ assert_eq!(words[inst_index], instruction.result_id.unwrap());
+ inst_index += 1;
+ }
+
+ let mut op_index = 0;
+ for i in inst_index..wc as usize {
+ assert_eq!(words[i as usize], instruction.operands[op_index]);
+ op_index += 1;
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs
new file mode 100644
index 0000000000..f1b43f5289
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/writer.rs
@@ -0,0 +1,1776 @@
+/*! Standard Portable Intermediate Representation (SPIR-V) backend !*/
+use super::{Instruction, LogicalLayout, PhysicalLayout, WriterFlags};
+use spirv::Word;
+use std::{collections::hash_map::Entry, ops};
+use thiserror::Error;
+
+const BITS_PER_BYTE: crate::Bytes = 8;
+
+#[derive(Clone, Debug, Error)]
+pub enum Error {
+ #[error("can't find local variable: {0:?}")]
+ UnknownLocalVariable(crate::LocalVariable),
+ #[error("bad image class for op: {0:?}")]
+ BadImageClass(crate::ImageClass),
+ #[error("not an image")]
+ NotImage,
+ #[error("empty value")]
+ FeatureNotImplemented(),
+}
+
+struct Block {
+ label: Option<Instruction>,
+ body: Vec<Instruction>,
+ termination: Option<Instruction>,
+}
+
+impl Block {
+ pub fn new() -> Self {
+ Block {
+ label: None,
+ body: vec![],
+ termination: None,
+ }
+ }
+}
+
+struct LocalVariable {
+ id: Word,
+ name: Option<String>,
+ instruction: Instruction,
+}
+
+struct Function {
+ signature: Option<Instruction>,
+ parameters: Vec<Instruction>,
+ variables: Vec<LocalVariable>,
+ blocks: Vec<Block>,
+}
+
+impl Function {
+ pub fn new() -> Self {
+ Function {
+ signature: None,
+ parameters: vec![],
+ variables: vec![],
+ blocks: vec![],
+ }
+ }
+
+ fn to_words(&self, sink: &mut impl Extend<Word>) {
+ self.signature.as_ref().unwrap().to_words(sink);
+ for instruction in self.parameters.iter() {
+ instruction.to_words(sink);
+ }
+ for (index, block) in self.blocks.iter().enumerate() {
+ block.label.as_ref().unwrap().to_words(sink);
+ if index == 0 {
+ for local_var in self.variables.iter() {
+ local_var.instruction.to_words(sink);
+ }
+ }
+ for instruction in block.body.iter() {
+ instruction.to_words(sink);
+ }
+ block.termination.as_ref().unwrap().to_words(sink);
+ }
+ }
+}
+
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+enum LocalType {
+ Void,
+ Scalar {
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+ Vector {
+ size: crate::VectorSize,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ },
+ Pointer {
+ base: crate::Handle<crate::Type>,
+ class: crate::StorageClass,
+ },
+ SampledImage {
+ image_type: crate::Handle<crate::Type>,
+ },
+}
+
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+enum LookupType {
+ Handle(crate::Handle<crate::Type>),
+ Local(LocalType),
+}
+
+fn map_dim(dim: crate::ImageDimension) -> spirv::Dim {
+ match dim {
+ crate::ImageDimension::D1 => spirv::Dim::Dim1D,
+ crate::ImageDimension::D2 => spirv::Dim::Dim2D,
+ crate::ImageDimension::D3 => spirv::Dim::Dim2D,
+ crate::ImageDimension::Cube => spirv::Dim::DimCube,
+ }
+}
+
+#[derive(Debug, PartialEq, Clone, Hash, Eq)]
+struct LookupFunctionType {
+ parameter_type_ids: Vec<Word>,
+ return_type_id: Word,
+}
+
+enum MaybeOwned<'a, T> {
+ Owned(T),
+ Borrowed(&'a T),
+}
+
+impl<'a, T> ops::Deref for MaybeOwned<'a, T> {
+ type Target = T;
+ fn deref(&self) -> &T {
+ match *self {
+ MaybeOwned::Owned(ref value) => value,
+ MaybeOwned::Borrowed(reference) => reference,
+ }
+ }
+}
+
+enum Dimension {
+ Scalar,
+ Vector,
+ Matrix,
+}
+
+fn get_dimension(ty_inner: &crate::TypeInner) -> Dimension {
+ match *ty_inner {
+ crate::TypeInner::Scalar { .. } => Dimension::Scalar,
+ crate::TypeInner::Vector { .. } => Dimension::Vector,
+ crate::TypeInner::Matrix { .. } => Dimension::Matrix,
+ _ => unreachable!(),
+ }
+}
+
+pub struct Writer {
+ physical_layout: PhysicalLayout,
+ logical_layout: LogicalLayout,
+ id_count: u32,
+ capabilities: crate::FastHashSet<spirv::Capability>,
+ debugs: Vec<Instruction>,
+ annotations: Vec<Instruction>,
+ writer_flags: WriterFlags,
+ void_type: Option<u32>,
+ lookup_type: crate::FastHashMap<LookupType, Word>,
+ lookup_function: crate::FastHashMap<crate::Handle<crate::Function>, Word>,
+ lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>,
+ lookup_constant: crate::FastHashMap<crate::Handle<crate::Constant>, Word>,
+ lookup_global_variable: crate::FastHashMap<crate::Handle<crate::GlobalVariable>, Word>,
+}
+
+// type alias, for success return of write_expression
+type WriteExpressionOutput = (Word, LookupType);
+
+impl Writer {
+ pub fn new(header: &crate::Header, writer_flags: WriterFlags) -> Self {
+ Writer {
+ physical_layout: PhysicalLayout::new(header),
+ logical_layout: LogicalLayout::default(),
+ id_count: 0,
+ capabilities: crate::FastHashSet::default(),
+ debugs: vec![],
+ annotations: vec![],
+ writer_flags,
+ void_type: None,
+ lookup_type: crate::FastHashMap::default(),
+ lookup_function: crate::FastHashMap::default(),
+ lookup_function_type: crate::FastHashMap::default(),
+ lookup_constant: crate::FastHashMap::default(),
+ lookup_global_variable: crate::FastHashMap::default(),
+ }
+ }
+
+ fn generate_id(&mut self) -> Word {
+ self.id_count += 1;
+ self.id_count
+ }
+
+ fn try_add_capabilities(&mut self, capabilities: &[spirv::Capability]) {
+ for capability in capabilities.iter() {
+ self.capabilities.insert(*capability);
+ }
+ }
+
+ fn get_type_id(&mut self, arena: &crate::Arena<crate::Type>, lookup_ty: LookupType) -> Word {
+ if let Entry::Occupied(e) = self.lookup_type.entry(lookup_ty) {
+ *e.get()
+ } else {
+ match lookup_ty {
+ LookupType::Handle(handle) => match arena[handle].inner {
+ crate::TypeInner::Scalar { kind, width } => self
+ .get_type_id(arena, LookupType::Local(LocalType::Scalar { kind, width })),
+ _ => self.write_type_declaration_arena(arena, handle),
+ },
+ LookupType::Local(local_ty) => self.write_type_declaration_local(arena, local_ty),
+ }
+ }
+ }
+
+ fn get_constant_id(
+ &mut self,
+ handle: crate::Handle<crate::Constant>,
+ ir_module: &crate::Module,
+ ) -> Word {
+ match self.lookup_constant.entry(handle) {
+ Entry::Occupied(e) => *e.get(),
+ _ => {
+ let (instruction, id) = self.write_constant_type(handle, ir_module);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ id
+ }
+ }
+ }
+
+ fn get_global_variable_id(
+ &mut self,
+ ir_module: &crate::Module,
+ handle: crate::Handle<crate::GlobalVariable>,
+ ) -> Word {
+ match self.lookup_global_variable.entry(handle) {
+ Entry::Occupied(e) => *e.get(),
+ _ => {
+ let (instruction, id) = self.write_global_variable(ir_module, handle);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ id
+ }
+ }
+ }
+
+ fn get_function_return_type(
+ &mut self,
+ ty: Option<crate::Handle<crate::Type>>,
+ arena: &crate::Arena<crate::Type>,
+ ) -> Word {
+ match ty {
+ Some(handle) => self.get_type_id(arena, LookupType::Handle(handle)),
+ None => match self.void_type {
+ Some(id) => id,
+ None => {
+ let id = self.generate_id();
+ self.void_type = Some(id);
+ super::instructions::instruction_type_void(id)
+ .to_words(&mut self.logical_layout.declarations);
+ id
+ }
+ },
+ }
+ }
+
+ fn get_pointer_id(
+ &mut self,
+ arena: &crate::Arena<crate::Type>,
+ handle: crate::Handle<crate::Type>,
+ class: crate::StorageClass,
+ ) -> Word {
+ let ty = &arena[handle];
+ let ty_id = self.get_type_id(arena, LookupType::Handle(handle));
+ match ty.inner {
+ crate::TypeInner::Pointer { .. } => ty_id,
+ _ => {
+ match self
+ .lookup_type
+ .entry(LookupType::Local(LocalType::Pointer {
+ base: handle,
+ class,
+ })) {
+ Entry::Occupied(e) => *e.get(),
+ _ => {
+ let id =
+ self.create_pointer(ty_id, self.parse_to_spirv_storage_class(class));
+ self.lookup_type.insert(
+ LookupType::Local(LocalType::Pointer {
+ base: handle,
+ class,
+ }),
+ id,
+ );
+ id
+ }
+ }
+ }
+ }
+ }
+
+ fn create_pointer(&mut self, ty_id: Word, class: spirv::StorageClass) -> Word {
+ let id = self.generate_id();
+ let instruction = super::instructions::instruction_type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ id
+ }
+
+ fn create_constant(&mut self, type_id: Word, value: &[Word]) -> Word {
+ let id = self.generate_id();
+ let instruction = super::instructions::instruction_constant(type_id, id, value);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ id
+ }
+
+ fn write_function(
+ &mut self,
+ ir_function: &crate::Function,
+ ir_module: &crate::Module,
+ ) -> spirv::Word {
+ let mut function = Function::new();
+
+ for (_, variable) in ir_function.local_variables.iter() {
+ let id = self.generate_id();
+
+ let init_word = variable
+ .init
+ .map(|constant| self.get_constant_id(constant, ir_module));
+
+ let pointer_id =
+ self.get_pointer_id(&ir_module.types, variable.ty, crate::StorageClass::Function);
+ function.variables.push(LocalVariable {
+ id,
+ name: variable.name.clone(),
+ instruction: super::instructions::instruction_variable(
+ pointer_id,
+ id,
+ spirv::StorageClass::Function,
+ init_word,
+ ),
+ });
+ }
+
+ let return_type_id =
+ self.get_function_return_type(ir_function.return_type, &ir_module.types);
+ let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
+
+ let mut function_parameter_pointer_ids = vec![];
+
+ for argument in ir_function.arguments.iter() {
+ let id = self.generate_id();
+ let pointer_id =
+ self.get_pointer_id(&ir_module.types, argument.ty, crate::StorageClass::Function);
+
+ function_parameter_pointer_ids.push(pointer_id);
+ parameter_type_ids
+ .push(self.get_type_id(&ir_module.types, LookupType::Handle(argument.ty)));
+ function
+ .parameters
+ .push(super::instructions::instruction_function_parameter(
+ pointer_id, id,
+ ));
+ }
+
+ let lookup_function_type = LookupFunctionType {
+ return_type_id,
+ parameter_type_ids,
+ };
+
+ let function_id = self.generate_id();
+ let function_type =
+ self.get_function_type(lookup_function_type, function_parameter_pointer_ids);
+ function.signature = Some(super::instructions::instruction_function(
+ return_type_id,
+ function_id,
+ spirv::FunctionControl::empty(),
+ function_type,
+ ));
+
+ self.write_block(&ir_function.body, ir_module, ir_function, &mut function);
+
+ function.to_words(&mut self.logical_layout.function_definitions);
+ super::instructions::instruction_function_end()
+ .to_words(&mut self.logical_layout.function_definitions);
+
+ function_id
+ }
+
+ // TODO Move to instructions module
+ fn write_entry_point(
+ &mut self,
+ entry_point: &crate::EntryPoint,
+ stage: crate::ShaderStage,
+ name: &str,
+ ir_module: &crate::Module,
+ ) -> Instruction {
+ let function_id = self.write_function(&entry_point.function, ir_module);
+
+ let exec_model = match stage {
+ crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
+ crate::ShaderStage::Fragment { .. } => spirv::ExecutionModel::Fragment,
+ crate::ShaderStage::Compute { .. } => spirv::ExecutionModel::GLCompute,
+ };
+
+ let mut interface_ids = vec![];
+ for ((handle, _), &usage) in ir_module
+ .global_variables
+ .iter()
+ .filter(|&(_, var)| {
+ var.class == crate::StorageClass::Input || var.class == crate::StorageClass::Output
+ })
+ .zip(&entry_point.function.global_usage)
+ {
+ if usage.contains(crate::GlobalUse::STORE) || usage.contains(crate::GlobalUse::LOAD) {
+ let id = self.get_global_variable_id(ir_module, handle);
+ interface_ids.push(id);
+ }
+ }
+
+ self.try_add_capabilities(exec_model.required_capabilities());
+ match stage {
+ crate::ShaderStage::Vertex => {}
+ crate::ShaderStage::Fragment => {
+ let execution_mode = spirv::ExecutionMode::OriginUpperLeft;
+ self.try_add_capabilities(execution_mode.required_capabilities());
+ super::instructions::instruction_execution_mode(function_id, execution_mode)
+ .to_words(&mut self.logical_layout.execution_modes);
+ }
+ crate::ShaderStage::Compute => {}
+ }
+
+ if self.writer_flags.contains(WriterFlags::DEBUG) {
+ self.debugs
+ .push(super::instructions::instruction_name(function_id, name));
+ }
+
+ super::instructions::instruction_entry_point(
+ exec_model,
+ function_id,
+ name,
+ interface_ids.as_slice(),
+ )
+ }
+
+ fn write_scalar(&self, id: Word, kind: crate::ScalarKind, width: crate::Bytes) -> Instruction {
+ let bits = (width * BITS_PER_BYTE) as u32;
+ match kind {
+ crate::ScalarKind::Sint => super::instructions::instruction_type_int(
+ id,
+ bits,
+ super::instructions::Signedness::Signed,
+ ),
+ crate::ScalarKind::Uint => super::instructions::instruction_type_int(
+ id,
+ bits,
+ super::instructions::Signedness::Unsigned,
+ ),
+ crate::ScalarKind::Float => super::instructions::instruction_type_float(id, bits),
+ crate::ScalarKind::Bool => super::instructions::instruction_type_bool(id),
+ }
+ }
+
+ fn parse_to_spirv_storage_class(&self, class: crate::StorageClass) -> spirv::StorageClass {
+ match class {
+ crate::StorageClass::Handle => spirv::StorageClass::UniformConstant,
+ crate::StorageClass::Function => spirv::StorageClass::Function,
+ crate::StorageClass::Input => spirv::StorageClass::Input,
+ crate::StorageClass::Output => spirv::StorageClass::Output,
+ crate::StorageClass::Private => spirv::StorageClass::Private,
+ crate::StorageClass::Storage if self.physical_layout.supports_storage_buffers() => {
+ spirv::StorageClass::StorageBuffer
+ }
+ crate::StorageClass::Storage | crate::StorageClass::Uniform => {
+ spirv::StorageClass::Uniform
+ }
+ crate::StorageClass::WorkGroup => spirv::StorageClass::Workgroup,
+ crate::StorageClass::PushConstant => spirv::StorageClass::PushConstant,
+ }
+ }
+
+ fn write_type_declaration_local(
+ &mut self,
+ arena: &crate::Arena<crate::Type>,
+ local_ty: LocalType,
+ ) -> Word {
+ let id = self.generate_id();
+ let instruction = match local_ty {
+ LocalType::Void => unreachable!(),
+ LocalType::Scalar { kind, width } => self.write_scalar(id, kind, width),
+ LocalType::Vector { size, kind, width } => {
+ let scalar_id =
+ self.get_type_id(arena, LookupType::Local(LocalType::Scalar { kind, width }));
+ super::instructions::instruction_type_vector(id, scalar_id, size)
+ }
+ LocalType::Pointer { .. } => unimplemented!(),
+ LocalType::SampledImage { image_type } => {
+ let image_type_id = self.get_type_id(arena, LookupType::Handle(image_type));
+ super::instructions::instruction_type_sampled_image(id, image_type_id)
+ }
+ };
+
+ self.lookup_type.insert(LookupType::Local(local_ty), id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ id
+ }
+
+ fn write_type_declaration_arena(
+ &mut self,
+ arena: &crate::Arena<crate::Type>,
+ handle: crate::Handle<crate::Type>,
+ ) -> Word {
+ let ty = &arena[handle];
+ let id = self.generate_id();
+
+ let instruction = match ty.inner {
+ crate::TypeInner::Scalar { kind, width } => {
+ self.lookup_type
+ .insert(LookupType::Local(LocalType::Scalar { kind, width }), id);
+ self.write_scalar(id, kind, width)
+ }
+ crate::TypeInner::Vector { size, kind, width } => {
+ let scalar_id =
+ self.get_type_id(arena, LookupType::Local(LocalType::Scalar { kind, width }));
+ self.lookup_type.insert(
+ LookupType::Local(LocalType::Vector { size, kind, width }),
+ id,
+ );
+ super::instructions::instruction_type_vector(id, scalar_id, size)
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows: _,
+ width,
+ } => {
+ let vector_id = self.get_type_id(
+ arena,
+ LookupType::Local(LocalType::Vector {
+ size: columns,
+ kind: crate::ScalarKind::Float,
+ width,
+ }),
+ );
+ super::instructions::instruction_type_matrix(id, vector_id, columns)
+ }
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let width = 4;
+ let local_type = match class {
+ crate::ImageClass::Sampled { kind, multi: _ } => {
+ LocalType::Scalar { kind, width }
+ }
+ crate::ImageClass::Depth => LocalType::Scalar {
+ kind: crate::ScalarKind::Float,
+ width,
+ },
+ crate::ImageClass::Storage(format) => LocalType::Scalar {
+ kind: format.into(),
+ width,
+ },
+ };
+ let type_id = self.get_type_id(arena, LookupType::Local(local_type));
+ let dim = map_dim(dim);
+ self.try_add_capabilities(dim.required_capabilities());
+ super::instructions::instruction_type_image(id, type_id, dim, arrayed, class)
+ }
+ crate::TypeInner::Sampler { comparison: _ } => {
+ super::instructions::instruction_type_sampler(id)
+ }
+ crate::TypeInner::Array { base, size, stride } => {
+ if let Some(array_stride) = stride {
+ self.annotations
+ .push(super::instructions::instruction_decorate(
+ id,
+ spirv::Decoration::ArrayStride,
+ &[array_stride.get()],
+ ));
+ }
+
+ let type_id = self.get_type_id(arena, LookupType::Handle(base));
+ match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let length_id = self.lookup_constant[&const_handle];
+ super::instructions::instruction_type_array(id, type_id, length_id)
+ }
+ crate::ArraySize::Dynamic => {
+ super::instructions::instruction_type_runtime_array(id, type_id)
+ }
+ }
+ }
+ crate::TypeInner::Struct { ref members } => {
+ let mut member_ids = Vec::with_capacity(members.len());
+ for member in members {
+ let member_id = self.get_type_id(arena, LookupType::Handle(member.ty));
+ member_ids.push(member_id);
+ }
+ super::instructions::instruction_type_struct(id, member_ids.as_slice())
+ }
+ crate::TypeInner::Pointer { base, class } => {
+ let type_id = self.get_type_id(arena, LookupType::Handle(base));
+ self.lookup_type
+ .insert(LookupType::Local(LocalType::Pointer { base, class }), id);
+ super::instructions::instruction_type_pointer(
+ id,
+ self.parse_to_spirv_storage_class(class),
+ type_id,
+ )
+ }
+ };
+
+ self.lookup_type.insert(LookupType::Handle(handle), id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ id
+ }
+
+ fn write_constant_type(
+ &mut self,
+ handle: crate::Handle<crate::Constant>,
+ ir_module: &crate::Module,
+ ) -> (Instruction, Word) {
+ let id = self.generate_id();
+ self.lookup_constant.insert(handle, id);
+ let constant = &ir_module.constants[handle];
+ let arena = &ir_module.types;
+
+ match constant.inner {
+ crate::ConstantInner::Sint(val) => {
+ let ty = &ir_module.types[constant.ty];
+ let type_id = self.get_type_id(arena, LookupType::Handle(constant.ty));
+
+ let instruction = match ty.inner {
+ crate::TypeInner::Scalar { kind: _, width } => match width {
+ 4 => super::instructions::instruction_constant(type_id, id, &[val as u32]),
+ 8 => {
+ let (low, high) = ((val >> 32) as u32, val as u32);
+ super::instructions::instruction_constant(type_id, id, &[low, high])
+ }
+ _ => unreachable!(),
+ },
+ _ => unreachable!(),
+ };
+ (instruction, id)
+ }
+ crate::ConstantInner::Uint(val) => {
+ let ty = &ir_module.types[constant.ty];
+ let type_id = self.get_type_id(arena, LookupType::Handle(constant.ty));
+
+ let instruction = match ty.inner {
+ crate::TypeInner::Scalar { kind: _, width } => match width {
+ 4 => super::instructions::instruction_constant(type_id, id, &[val as u32]),
+ 8 => {
+ let (low, high) = ((val >> 32) as u32, val as u32);
+ super::instructions::instruction_constant(type_id, id, &[low, high])
+ }
+ _ => unreachable!(),
+ },
+ _ => unreachable!(),
+ };
+
+ (instruction, id)
+ }
+ crate::ConstantInner::Float(val) => {
+ let ty = &ir_module.types[constant.ty];
+ let type_id = self.get_type_id(arena, LookupType::Handle(constant.ty));
+
+ let instruction = match ty.inner {
+ crate::TypeInner::Scalar { kind: _, width } => match width {
+ 4 => super::instructions::instruction_constant(
+ type_id,
+ id,
+ &[(val as f32).to_bits()],
+ ),
+ 8 => {
+ let bits = f64::to_bits(val);
+ let (low, high) = ((bits >> 32) as u32, bits as u32);
+ super::instructions::instruction_constant(type_id, id, &[low, high])
+ }
+ _ => unreachable!(),
+ },
+ _ => unreachable!(),
+ };
+ (instruction, id)
+ }
+ crate::ConstantInner::Bool(val) => {
+ let type_id = self.get_type_id(arena, LookupType::Handle(constant.ty));
+
+ let instruction = if val {
+ super::instructions::instruction_constant_true(type_id, id)
+ } else {
+ super::instructions::instruction_constant_false(type_id, id)
+ };
+
+ (instruction, id)
+ }
+ crate::ConstantInner::Composite(ref constituents) => {
+ let mut constituent_ids = Vec::with_capacity(constituents.len());
+ for constituent in constituents.iter() {
+ let constituent_id = self.get_constant_id(*constituent, &ir_module);
+ constituent_ids.push(constituent_id);
+ }
+
+ let type_id = self.get_type_id(arena, LookupType::Handle(constant.ty));
+ let instruction = super::instructions::instruction_constant_composite(
+ type_id,
+ id,
+ constituent_ids.as_slice(),
+ );
+ (instruction, id)
+ }
+ }
+ }
+
+ fn write_global_variable(
+ &mut self,
+ ir_module: &crate::Module,
+ handle: crate::Handle<crate::GlobalVariable>,
+ ) -> (Instruction, Word) {
+ let global_variable = &ir_module.global_variables[handle];
+ let id = self.generate_id();
+
+ let class = self.parse_to_spirv_storage_class(global_variable.class);
+ self.try_add_capabilities(class.required_capabilities());
+
+ let init_word = global_variable
+ .init
+ .map(|constant| self.get_constant_id(constant, ir_module));
+ let pointer_id =
+ self.get_pointer_id(&ir_module.types, global_variable.ty, global_variable.class);
+ let instruction =
+ super::instructions::instruction_variable(pointer_id, id, class, init_word);
+
+ if self.writer_flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = global_variable.name {
+ self.debugs
+ .push(super::instructions::instruction_name(id, name.as_str()));
+ }
+ }
+
+ if let Some(interpolation) = global_variable.interpolation {
+ let decoration = match interpolation {
+ crate::Interpolation::Linear => Some(spirv::Decoration::NoPerspective),
+ crate::Interpolation::Flat => Some(spirv::Decoration::Flat),
+ crate::Interpolation::Patch => Some(spirv::Decoration::Patch),
+ crate::Interpolation::Centroid => Some(spirv::Decoration::Centroid),
+ crate::Interpolation::Sample => Some(spirv::Decoration::Sample),
+ crate::Interpolation::Perspective => None,
+ };
+ if let Some(decoration) = decoration {
+ self.annotations
+ .push(super::instructions::instruction_decorate(
+ id,
+ decoration,
+ &[],
+ ));
+ }
+ }
+
+ match *global_variable.binding.as_ref().unwrap() {
+ crate::Binding::Location(location) => {
+ self.annotations
+ .push(super::instructions::instruction_decorate(
+ id,
+ spirv::Decoration::Location,
+ &[location],
+ ));
+ }
+ crate::Binding::Resource { group, binding } => {
+ self.annotations
+ .push(super::instructions::instruction_decorate(
+ id,
+ spirv::Decoration::DescriptorSet,
+ &[group],
+ ));
+ self.annotations
+ .push(super::instructions::instruction_decorate(
+ id,
+ spirv::Decoration::Binding,
+ &[binding],
+ ));
+ }
+ crate::Binding::BuiltIn(built_in) => {
+ let built_in = match built_in {
+ crate::BuiltIn::BaseInstance => spirv::BuiltIn::BaseInstance,
+ crate::BuiltIn::BaseVertex => spirv::BuiltIn::BaseVertex,
+ crate::BuiltIn::ClipDistance => spirv::BuiltIn::ClipDistance,
+ crate::BuiltIn::InstanceIndex => spirv::BuiltIn::InstanceIndex,
+ crate::BuiltIn::Position => spirv::BuiltIn::Position,
+ crate::BuiltIn::VertexIndex => spirv::BuiltIn::VertexIndex,
+ crate::BuiltIn::PointSize => spirv::BuiltIn::PointSize,
+ crate::BuiltIn::FragCoord => spirv::BuiltIn::FragCoord,
+ crate::BuiltIn::FrontFacing => spirv::BuiltIn::FrontFacing,
+ crate::BuiltIn::SampleIndex => spirv::BuiltIn::SampleId,
+ crate::BuiltIn::FragDepth => spirv::BuiltIn::FragDepth,
+ crate::BuiltIn::GlobalInvocationId => spirv::BuiltIn::GlobalInvocationId,
+ crate::BuiltIn::LocalInvocationId => spirv::BuiltIn::LocalInvocationId,
+ crate::BuiltIn::LocalInvocationIndex => spirv::BuiltIn::LocalInvocationIndex,
+ crate::BuiltIn::WorkGroupId => spirv::BuiltIn::WorkgroupId,
+ };
+
+ self.annotations
+ .push(super::instructions::instruction_decorate(
+ id,
+ spirv::Decoration::BuiltIn,
+ &[built_in as u32],
+ ));
+ }
+ }
+
+ // TODO Initializer is optional and not (yet) included in the IR
+
+ self.lookup_global_variable.insert(handle, id);
+ (instruction, id)
+ }
+
+ fn get_function_type(
+ &mut self,
+ lookup_function_type: LookupFunctionType,
+ parameter_pointer_ids: Vec<Word>,
+ ) -> Word {
+ match self
+ .lookup_function_type
+ .entry(lookup_function_type.clone())
+ {
+ Entry::Occupied(e) => *e.get(),
+ _ => {
+ let id = self.generate_id();
+ let instruction = super::instructions::instruction_type_function(
+ id,
+ lookup_function_type.return_type_id,
+ parameter_pointer_ids.as_slice(),
+ );
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_function_type.insert(lookup_function_type, id);
+ id
+ }
+ }
+ }
+
+ fn write_composite_construct(
+ &mut self,
+ base_type_id: Word,
+ constituent_ids: &[Word],
+ block: &mut Block,
+ ) -> Word {
+ let id = self.generate_id();
+ block
+ .body
+ .push(super::instructions::instruction_composite_construct(
+ base_type_id,
+ id,
+ constituent_ids,
+ ));
+ id
+ }
+
+ fn get_type_inner<'a>(
+ &self,
+ ty_arena: &'a crate::Arena<crate::Type>,
+ lookup_ty: LookupType,
+ ) -> MaybeOwned<'a, crate::TypeInner> {
+ match lookup_ty {
+ LookupType::Handle(handle) => MaybeOwned::Borrowed(&ty_arena[handle].inner),
+ LookupType::Local(local_ty) => match local_ty {
+ LocalType::Scalar { kind, width } => {
+ MaybeOwned::Owned(crate::TypeInner::Scalar { kind, width })
+ }
+ LocalType::Vector { size, kind, width } => {
+ MaybeOwned::Owned(crate::TypeInner::Vector { size, kind, width })
+ }
+ LocalType::Pointer { base, class } => {
+ MaybeOwned::Owned(crate::TypeInner::Pointer { base, class })
+ }
+ _ => unreachable!(),
+ },
+ }
+ }
+
+ fn write_expression<'a>(
+ &mut self,
+ ir_module: &'a crate::Module,
+ ir_function: &crate::Function,
+ expression: &crate::Expression,
+ block: &mut Block,
+ function: &mut Function,
+ ) -> Result<WriteExpressionOutput, Error> {
+ match *expression {
+ crate::Expression::Access { base, index } => {
+ let id = self.generate_id();
+
+ let (base_id, base_lookup_ty) = self.write_expression(
+ ir_module,
+ ir_function,
+ &ir_function.expressions[base],
+ block,
+ function,
+ )?;
+ let (index_id, _) = self.write_expression(
+ ir_module,
+ ir_function,
+ &ir_function.expressions[index],
+ block,
+ function,
+ )?;
+
+ let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty);
+
+ let (pointer_id, type_id, lookup_ty) = match *base_ty_inner {
+ crate::TypeInner::Vector { kind, width, .. } => {
+ let scalar_id = self.get_type_id(
+ &ir_module.types,
+ LookupType::Local(LocalType::Scalar { kind, width }),
+ );
+ (
+ self.create_pointer(scalar_id, spirv::StorageClass::Function),
+ scalar_id,
+ LookupType::Local(LocalType::Scalar { kind, width }),
+ )
+ }
+ _ => unimplemented!(),
+ };
+
+ block
+ .body
+ .push(super::instructions::instruction_access_chain(
+ pointer_id,
+ id,
+ base_id,
+ &[index_id],
+ ));
+
+ let load_id = self.generate_id();
+ block.body.push(super::instructions::instruction_load(
+ type_id, load_id, id, None,
+ ));
+
+ Ok((load_id, lookup_ty))
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ let id = self.generate_id();
+ let (base_id, base_lookup_ty) = self
+ .write_expression(
+ ir_module,
+ ir_function,
+ &ir_function.expressions[base],
+ block,
+ function,
+ )
+ .unwrap();
+
+ let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty);
+
+ let (pointer_id, type_id, lookup_ty) = match *base_ty_inner {
+ crate::TypeInner::Vector { kind, width, .. } => {
+ let scalar_id = self.get_type_id(
+ &ir_module.types,
+ LookupType::Local(LocalType::Scalar { kind, width }),
+ );
+ (
+ self.create_pointer(scalar_id, spirv::StorageClass::Function),
+ scalar_id,
+ LookupType::Local(LocalType::Scalar { kind, width }),
+ )
+ }
+ crate::TypeInner::Struct { ref members } => {
+ let member = &members[index as usize];
+ let type_id =
+ self.get_type_id(&ir_module.types, LookupType::Handle(member.ty));
+ (
+ self.create_pointer(type_id, spirv::StorageClass::Uniform),
+ type_id,
+ LookupType::Handle(member.ty),
+ )
+ }
+ _ => unimplemented!(),
+ };
+
+ let const_ty_id = self.get_type_id(
+ &ir_module.types,
+ LookupType::Local(LocalType::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ }),
+ );
+ let const_id = self.create_constant(const_ty_id, &[index]);
+
+ block
+ .body
+ .push(super::instructions::instruction_access_chain(
+ pointer_id,
+ id,
+ base_id,
+ &[const_id],
+ ));
+
+ let load_id = self.generate_id();
+ block.body.push(super::instructions::instruction_load(
+ type_id, load_id, id, None,
+ ));
+
+ Ok((load_id, lookup_ty))
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let var = &ir_module.global_variables[handle];
+ let id = self.get_global_variable_id(&ir_module, handle);
+
+ Ok((id, LookupType::Handle(var.ty)))
+ }
+ crate::Expression::Constant(handle) => {
+ let var = &ir_module.constants[handle];
+ let id = self.get_constant_id(handle, ir_module);
+ Ok((id, LookupType::Handle(var.ty)))
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ let base_type_id = self.get_type_id(&ir_module.types, LookupType::Handle(ty));
+
+ let mut constituent_ids = Vec::with_capacity(components.len());
+ for component in components {
+ let expression = &ir_function.expressions[*component];
+ let (component_id, component_local_ty) = self.write_expression(
+ ir_module,
+ &ir_function,
+ expression,
+ block,
+ function,
+ )?;
+
+ let component_id = match expression {
+ crate::Expression::LocalVariable(_)
+ | crate::Expression::GlobalVariable(_) => {
+ let load_id = self.generate_id();
+ block.body.push(super::instructions::instruction_load(
+ self.get_type_id(&ir_module.types, component_local_ty),
+ load_id,
+ component_id,
+ None,
+ ));
+ load_id
+ }
+ _ => component_id,
+ };
+
+ constituent_ids.push(component_id);
+ }
+ let constituent_ids_slice = constituent_ids.as_slice();
+
+ let id = match ir_module.types[ty].inner {
+ crate::TypeInner::Vector { .. } => {
+ self.write_composite_construct(base_type_id, constituent_ids_slice, block)
+ }
+ crate::TypeInner::Matrix {
+ rows,
+ columns,
+ width,
+ } => {
+ let vector_type_id = self.get_type_id(
+ &ir_module.types,
+ LookupType::Local(LocalType::Vector {
+ width,
+ kind: crate::ScalarKind::Float,
+ size: columns,
+ }),
+ );
+
+ let capacity = match rows {
+ crate::VectorSize::Bi => 2,
+ crate::VectorSize::Tri => 3,
+ crate::VectorSize::Quad => 4,
+ };
+
+ let mut vector_ids = Vec::with_capacity(capacity);
+
+ for _ in 0..capacity {
+ let vector_id = self.write_composite_construct(
+ vector_type_id,
+ constituent_ids_slice,
+ block,
+ );
+ vector_ids.push(vector_id);
+ }
+
+ self.write_composite_construct(base_type_id, vector_ids.as_slice(), block)
+ }
+ _ => unreachable!(),
+ };
+
+ Ok((id, LookupType::Handle(ty)))
+ }
+ crate::Expression::Binary { op, left, right } => {
+ let id = self.generate_id();
+ let left_expression = &ir_function.expressions[left];
+ let right_expression = &ir_function.expressions[right];
+ let (left_id, left_lookup_ty) = self.write_expression(
+ ir_module,
+ ir_function,
+ left_expression,
+ block,
+ function,
+ )?;
+ let (right_id, right_lookup_ty) = self.write_expression(
+ ir_module,
+ ir_function,
+ right_expression,
+ block,
+ function,
+ )?;
+
+ let left_lookup_ty = left_lookup_ty;
+ let right_lookup_ty = right_lookup_ty;
+
+ let left_ty_inner = self.get_type_inner(&ir_module.types, left_lookup_ty);
+ let right_ty_inner = self.get_type_inner(&ir_module.types, right_lookup_ty);
+
+ let left_result_type_id = self.get_type_id(&ir_module.types, left_lookup_ty);
+
+ let right_result_type_id = self.get_type_id(&ir_module.types, right_lookup_ty);
+
+ let left_id = match *left_expression {
+ crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => {
+ let load_id = self.generate_id();
+ block.body.push(super::instructions::instruction_load(
+ left_result_type_id,
+ load_id,
+ left_id,
+ None,
+ ));
+ load_id
+ }
+ _ => left_id,
+ };
+
+ let right_id = match *right_expression {
+ crate::Expression::LocalVariable(..)
+ | crate::Expression::GlobalVariable(..) => {
+ let load_id = self.generate_id();
+ block.body.push(super::instructions::instruction_load(
+ right_result_type_id,
+ load_id,
+ right_id,
+ None,
+ ));
+ load_id
+ }
+ _ => right_id,
+ };
+
+ let left_dimension = get_dimension(&left_ty_inner);
+ let right_dimension = get_dimension(&right_ty_inner);
+
+ let (instruction, lookup_ty) = match op {
+ crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) {
+ (Dimension::Vector, Dimension::Scalar { .. }) => (
+ super::instructions::instruction_vector_times_scalar(
+ left_result_type_id,
+ id,
+ left_id,
+ right_id,
+ ),
+ left_lookup_ty,
+ ),
+ (Dimension::Vector, Dimension::Matrix) => (
+ super::instructions::instruction_vector_times_matrix(
+ left_result_type_id,
+ id,
+ left_id,
+ right_id,
+ ),
+ left_lookup_ty,
+ ),
+ (Dimension::Matrix, Dimension::Scalar { .. }) => (
+ super::instructions::instruction_matrix_times_scalar(
+ left_result_type_id,
+ id,
+ left_id,
+ right_id,
+ ),
+ left_lookup_ty,
+ ),
+ (Dimension::Matrix, Dimension::Vector) => (
+ super::instructions::instruction_matrix_times_vector(
+ right_result_type_id,
+ id,
+ left_id,
+ right_id,
+ ),
+ right_lookup_ty,
+ ),
+ (Dimension::Matrix, Dimension::Matrix) => (
+ super::instructions::instruction_matrix_times_matrix(
+ left_result_type_id,
+ id,
+ left_id,
+ right_id,
+ ),
+ left_lookup_ty,
+ ),
+ (Dimension::Vector, Dimension::Vector)
+ | (Dimension::Scalar, Dimension::Scalar)
+ if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) =>
+ {
+ (
+ super::instructions::instruction_f_mul(
+ left_result_type_id,
+ id,
+ left_id,
+ right_id,
+ ),
+ left_lookup_ty,
+ )
+ }
+ (Dimension::Vector, Dimension::Vector)
+ | (Dimension::Scalar, Dimension::Scalar) => (
+ super::instructions::instruction_i_mul(
+ left_result_type_id,
+ id,
+ left_id,
+ right_id,
+ ),
+ left_lookup_ty,
+ ),
+ _ => unreachable!(),
+ },
+ crate::BinaryOperator::Subtract => match *left_ty_inner {
+ crate::TypeInner::Scalar { kind, .. } => match kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => (
+ super::instructions::instruction_i_sub(
+ left_result_type_id,
+ id,
+ left_id,
+ right_id,
+ ),
+ left_lookup_ty,
+ ),
+ crate::ScalarKind::Float => (
+ super::instructions::instruction_f_sub(
+ left_result_type_id,
+ id,
+ left_id,
+ right_id,
+ ),
+ left_lookup_ty,
+ ),
+ _ => unreachable!(),
+ },
+ _ => unreachable!(),
+ },
+ crate::BinaryOperator::And => (
+ super::instructions::instruction_bitwise_and(
+ left_result_type_id,
+ id,
+ left_id,
+ right_id,
+ ),
+ left_lookup_ty,
+ ),
+ _ => unimplemented!("{:?}", op),
+ };
+
+ block.body.push(instruction);
+ Ok((id, lookup_ty))
+ }
+ crate::Expression::LocalVariable(variable) => {
+ let var = &ir_function.local_variables[variable];
+ function
+ .variables
+ .iter()
+ .find(|&v| v.name.as_ref().unwrap() == var.name.as_ref().unwrap())
+ .map(|local_var| (local_var.id, LookupType::Handle(var.ty)))
+ .ok_or_else(|| Error::UnknownLocalVariable(var.clone()))
+ }
+ crate::Expression::FunctionArgument(index) => {
+ let handle = ir_function.arguments[index as usize].ty;
+ let type_id = self.get_type_id(&ir_module.types, LookupType::Handle(handle));
+ let load_id = self.generate_id();
+
+ block.body.push(super::instructions::instruction_load(
+ type_id,
+ load_id,
+ function.parameters[index as usize].result_id.unwrap(),
+ None,
+ ));
+ Ok((load_id, LookupType::Handle(handle)))
+ }
+ crate::Expression::Call {
+ ref origin,
+ ref arguments,
+ } => match *origin {
+ crate::FunctionOrigin::Local(local_function) => {
+ let origin_function = &ir_module.functions[local_function];
+ let id = self.generate_id();
+ let mut argument_ids = vec![];
+
+ for argument in arguments {
+ let expression = &ir_function.expressions[*argument];
+ let (id, lookup_ty) = self.write_expression(
+ ir_module,
+ ir_function,
+ expression,
+ block,
+ function,
+ )?;
+
+ // Create variable - OpVariable
+ // Store value to variable - OpStore
+ // Use id of variable
+
+ let handle = match lookup_ty {
+ LookupType::Handle(handle) => handle,
+ LookupType::Local(_) => unreachable!(),
+ };
+
+ let pointer_id = self.get_pointer_id(
+ &ir_module.types,
+ handle,
+ crate::StorageClass::Function,
+ );
+
+ let variable_id = self.generate_id();
+ function.variables.push(LocalVariable {
+ id: variable_id,
+ name: None,
+ instruction: super::instructions::instruction_variable(
+ pointer_id,
+ variable_id,
+ spirv::StorageClass::Function,
+ None,
+ ),
+ });
+ block.body.push(super::instructions::instruction_store(
+ variable_id,
+ id,
+ None,
+ ));
+ argument_ids.push(variable_id);
+ }
+
+ let return_type_id = self
+ .get_function_return_type(origin_function.return_type, &ir_module.types);
+
+ block
+ .body
+ .push(super::instructions::instruction_function_call(
+ return_type_id,
+ id,
+ *self.lookup_function.get(&local_function).unwrap(),
+ argument_ids.as_slice(),
+ ));
+
+ let result_type = match origin_function.return_type {
+ Some(ty_handle) => LookupType::Handle(ty_handle),
+ None => LookupType::Local(LocalType::Void),
+ };
+ Ok((id, result_type))
+ }
+ _ => unimplemented!("{:?}", origin),
+ },
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ if !convert {
+ return Err(Error::FeatureNotImplemented());
+ }
+
+ let (expr_id, expr_type) = self.write_expression(
+ ir_module,
+ ir_function,
+ &ir_function.expressions[expr],
+ block,
+ function,
+ )?;
+
+ let expr_type_inner = self.get_type_inner(&ir_module.types, expr_type);
+
+ let (expr_kind, local_type) = match *expr_type_inner {
+ crate::TypeInner::Scalar {
+ kind: expr_kind,
+ width,
+ } => (expr_kind, LocalType::Scalar { kind, width }),
+ crate::TypeInner::Vector {
+ size,
+ kind: expr_kind,
+ width,
+ } => (expr_kind, LocalType::Vector { size, kind, width }),
+ _ => unreachable!(),
+ };
+
+ let lookup_type = LookupType::Local(local_type);
+ let op = match (expr_kind, kind) {
+ _ if !convert => spirv::Op::Bitcast,
+ (crate::ScalarKind::Float, crate::ScalarKind::Uint) => spirv::Op::ConvertFToU,
+ (crate::ScalarKind::Float, crate::ScalarKind::Sint) => spirv::Op::ConvertFToS,
+ (crate::ScalarKind::Sint, crate::ScalarKind::Float) => spirv::Op::ConvertSToF,
+ (crate::ScalarKind::Uint, crate::ScalarKind::Float) => spirv::Op::ConvertUToF,
+ // We assume it's either an identity cast, or int-uint.
+ // In both cases no SPIR-V instructions need to be generated.
+ _ => {
+ let id = match ir_function.expressions[expr] {
+ crate::Expression::LocalVariable(_)
+ | crate::Expression::GlobalVariable(_) => {
+ let load_id = self.generate_id();
+ let kind_type_id = self.get_type_id(&ir_module.types, expr_type);
+ block.body.push(super::instructions::instruction_load(
+ kind_type_id,
+ load_id,
+ expr_id,
+ None,
+ ));
+ load_id
+ }
+ _ => expr_id,
+ };
+ return Ok((id, lookup_type));
+ }
+ };
+
+ let id = self.generate_id();
+ let kind_type_id = self.get_type_id(&ir_module.types, lookup_type);
+ let instruction =
+ super::instructions::instruction_unary(op, kind_type_id, id, expr_id);
+ block.body.push(instruction);
+
+ Ok((id, lookup_type))
+ }
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ coordinate,
+ level: _,
+ depth_ref: _,
+ } => {
+ // image
+ let image_expression = &ir_function.expressions[image];
+ let (image_id, image_lookup_ty) = self.write_expression(
+ ir_module,
+ ir_function,
+ image_expression,
+ block,
+ function,
+ )?;
+
+ let image_result_type_id = self.get_type_id(&ir_module.types, image_lookup_ty);
+ let image_id = match *image_expression {
+ crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => {
+ let load_id = self.generate_id();
+ block.body.push(super::instructions::instruction_load(
+ image_result_type_id,
+ load_id,
+ image_id,
+ None,
+ ));
+ load_id
+ }
+ _ => image_id,
+ };
+
+ let image_ty = match image_lookup_ty {
+ LookupType::Handle(handle) => handle,
+ LookupType::Local(_) => unreachable!(),
+ };
+
+ // OpTypeSampledImage
+ let sampled_image_type_id = self.get_type_id(
+ &ir_module.types,
+ LookupType::Local(LocalType::SampledImage {
+ image_type: image_ty,
+ }),
+ );
+
+ // sampler
+ let sampler_expression = &ir_function.expressions[sampler];
+ let (sampler_id, sampler_lookup_ty) = self.write_expression(
+ ir_module,
+ ir_function,
+ sampler_expression,
+ block,
+ function,
+ )?;
+
+ let sampler_result_type_id = self.get_type_id(&ir_module.types, sampler_lookup_ty);
+ let sampler_id = match *sampler_expression {
+ crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => {
+ let load_id = self.generate_id();
+ block.body.push(super::instructions::instruction_load(
+ sampler_result_type_id,
+ load_id,
+ sampler_id,
+ None,
+ ));
+ load_id
+ }
+ _ => sampler_id,
+ };
+
+ // coordinate
+ let coordinate_expression = &ir_function.expressions[coordinate];
+ let (coordinate_id, coordinate_lookup_ty) = self.write_expression(
+ ir_module,
+ ir_function,
+ coordinate_expression,
+ block,
+ function,
+ )?;
+
+ let coordinate_result_type_id =
+ self.get_type_id(&ir_module.types, coordinate_lookup_ty);
+ let coordinate_id = match *coordinate_expression {
+ crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => {
+ let load_id = self.generate_id();
+ block.body.push(super::instructions::instruction_load(
+ coordinate_result_type_id,
+ load_id,
+ coordinate_id,
+ None,
+ ));
+ load_id
+ }
+ _ => coordinate_id,
+ };
+
+ // component kind
+ let image_type = &ir_module.types[image_ty];
+ let image_sample_result_type =
+ if let crate::TypeInner::Image { class, .. } = image_type.inner {
+ let width = 4;
+ LookupType::Local(match class {
+ crate::ImageClass::Sampled { kind, multi: _ } => LocalType::Vector {
+ kind,
+ width,
+ size: crate::VectorSize::Quad,
+ },
+ crate::ImageClass::Depth => LocalType::Scalar {
+ kind: crate::ScalarKind::Float,
+ width,
+ },
+ _ => return Err(Error::BadImageClass(class)),
+ })
+ } else {
+ return Err(Error::NotImage);
+ };
+
+ let sampled_image_id = self.generate_id();
+ block
+ .body
+ .push(super::instructions::instruction_sampled_image(
+ sampled_image_type_id,
+ sampled_image_id,
+ image_id,
+ sampler_id,
+ ));
+ let id = self.generate_id();
+ let image_sample_result_type_id =
+ self.get_type_id(&ir_module.types, image_sample_result_type);
+ block
+ .body
+ .push(super::instructions::instruction_image_sample_implicit_lod(
+ image_sample_result_type_id,
+ id,
+ sampled_image_id,
+ coordinate_id,
+ ));
+ Ok((id, image_sample_result_type))
+ }
+ _ => unimplemented!("{:?}", expression),
+ }
+ }
+
+ fn write_block(
+ &mut self,
+ statements: &[crate::Statement],
+ ir_module: &crate::Module,
+ ir_function: &crate::Function,
+ function: &mut Function,
+ ) -> spirv::Word {
+ let mut block = Block::new();
+ let id = self.generate_id();
+ block.label = Some(super::instructions::instruction_label(id));
+
+ for statement in statements {
+ match *statement {
+ crate::Statement::Block(ref ir_block) => {
+ if !ir_block.is_empty() {
+ //TODO: link the block with `OpBranch`
+ self.write_block(ir_block, ir_module, ir_function, function);
+ }
+ }
+ crate::Statement::Return { value } => {
+ block.termination = Some(match ir_function.return_type {
+ Some(_) => {
+ let expression = &ir_function.expressions[value.unwrap()];
+ let (id, lookup_ty) = self
+ .write_expression(
+ ir_module,
+ ir_function,
+ expression,
+ &mut block,
+ function,
+ )
+ .unwrap();
+
+ let id = match *expression {
+ crate::Expression::LocalVariable(_)
+ | crate::Expression::GlobalVariable(_) => {
+ let load_id = self.generate_id();
+ let value_ty_id = self.get_type_id(&ir_module.types, lookup_ty);
+ block.body.push(super::instructions::instruction_load(
+ value_ty_id,
+ load_id,
+ id,
+ None,
+ ));
+ load_id
+ }
+
+ _ => id,
+ };
+ super::instructions::instruction_return_value(id)
+ }
+ None => super::instructions::instruction_return(),
+ });
+ }
+ crate::Statement::Store { pointer, value } => {
+ let pointer_expression = &ir_function.expressions[pointer];
+ let value_expression = &ir_function.expressions[value];
+ let (pointer_id, _) = self
+ .write_expression(
+ ir_module,
+ ir_function,
+ pointer_expression,
+ &mut block,
+ function,
+ )
+ .unwrap();
+ let (value_id, value_lookup_ty) = self
+ .write_expression(
+ ir_module,
+ ir_function,
+ value_expression,
+ &mut block,
+ function,
+ )
+ .unwrap();
+
+ let value_id = match value_expression {
+ crate::Expression::LocalVariable(_)
+ | crate::Expression::GlobalVariable(_) => {
+ let load_id = self.generate_id();
+ let value_ty_id = self.get_type_id(&ir_module.types, value_lookup_ty);
+ block.body.push(super::instructions::instruction_load(
+ value_ty_id,
+ load_id,
+ value_id,
+ None,
+ ));
+ load_id
+ }
+ _ => value_id,
+ };
+
+ block.body.push(super::instructions::instruction_store(
+ pointer_id, value_id, None,
+ ));
+ }
+ _ => unimplemented!("{:?}", statement),
+ }
+ }
+
+ function.blocks.push(block);
+ id
+ }
+
+ fn write_physical_layout(&mut self) {
+ self.physical_layout.bound = self.id_count + 1;
+ }
+
+ fn write_logical_layout(&mut self, ir_module: &crate::Module) {
+ let id = self.generate_id();
+ super::instructions::instruction_ext_inst_import(id, "GLSL.std.450")
+ .to_words(&mut self.logical_layout.ext_inst_imports);
+
+ if self.writer_flags.contains(WriterFlags::DEBUG) {
+ self.debugs.push(super::instructions::instruction_source(
+ spirv::SourceLanguage::GLSL,
+ 450,
+ ));
+ }
+
+ for (handle, ir_function) in ir_module.functions.iter() {
+ let id = self.write_function(ir_function, ir_module);
+ self.lookup_function.insert(handle, id);
+ }
+
+ for (&(stage, ref name), ir_ep) in ir_module.entry_points.iter() {
+ let entry_point_instruction = self.write_entry_point(ir_ep, stage, name, ir_module);
+ entry_point_instruction.to_words(&mut self.logical_layout.entry_points);
+ }
+
+ for capability in self.capabilities.iter() {
+ super::instructions::instruction_capability(*capability)
+ .to_words(&mut self.logical_layout.capabilities);
+ }
+
+ let addressing_model = spirv::AddressingModel::Logical;
+ let memory_model = spirv::MemoryModel::GLSL450;
+ self.try_add_capabilities(addressing_model.required_capabilities());
+ self.try_add_capabilities(memory_model.required_capabilities());
+
+ super::instructions::instruction_memory_model(addressing_model, memory_model)
+ .to_words(&mut self.logical_layout.memory_model);
+
+ if self.writer_flags.contains(WriterFlags::DEBUG) {
+ for debug in self.debugs.iter() {
+ debug.to_words(&mut self.logical_layout.debugs);
+ }
+ }
+
+ for annotation in self.annotations.iter() {
+ annotation.to_words(&mut self.logical_layout.annotations);
+ }
+ }
+
+ pub fn write(&mut self, ir_module: &crate::Module) -> Vec<Word> {
+ let mut words: Vec<Word> = vec![];
+
+ self.write_logical_layout(ir_module);
+ self.write_physical_layout();
+
+ self.physical_layout.in_words(&mut words);
+ self.logical_layout.in_words(&mut words);
+ words
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::back::spv::{Writer, WriterFlags};
+ use crate::Header;
+
+ #[test]
+ fn test_writer_generate_id() {
+ let mut writer = create_writer();
+
+ assert_eq!(writer.id_count, 0);
+ writer.generate_id();
+ assert_eq!(writer.id_count, 1);
+ }
+
+ #[test]
+ fn test_try_add_capabilities() {
+ let mut writer = create_writer();
+
+ assert_eq!(writer.capabilities.len(), 0);
+ writer.try_add_capabilities(&[spirv::Capability::Shader]);
+ assert_eq!(writer.capabilities.len(), 1);
+
+ writer.try_add_capabilities(&[spirv::Capability::Shader]);
+ assert_eq!(writer.capabilities.len(), 1);
+ }
+
+ #[test]
+ fn test_write_physical_layout() {
+ let mut writer = create_writer();
+ assert_eq!(writer.physical_layout.bound, 0);
+ writer.write_physical_layout();
+ assert_eq!(writer.physical_layout.bound, 1);
+ }
+
+ fn create_writer() -> Writer {
+ let header = Header {
+ generator: 0,
+ version: (1, 0, 0),
+ };
+ Writer::new(&header, WriterFlags::NONE)
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/ast.rs b/third_party/rust/naga/src/front/glsl/ast.rs
new file mode 100644
index 0000000000..58161ad7be
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/ast.rs
@@ -0,0 +1,178 @@
+use super::error::ErrorKind;
+use crate::{
+ proc::{ResolveContext, Typifier},
+ Arena, BinaryOperator, Binding, Expression, FastHashMap, Function, GlobalVariable, Handle,
+ Interpolation, LocalVariable, Module, ShaderStage, Statement, StorageClass, Type,
+};
+
+#[derive(Debug)]
+pub struct Program {
+ pub version: u16,
+ pub profile: Profile,
+ pub shader_stage: ShaderStage,
+ pub entry: Option<String>,
+ pub lookup_function: FastHashMap<String, Handle<Function>>,
+ pub lookup_type: FastHashMap<String, Handle<Type>>,
+ pub lookup_global_variables: FastHashMap<String, Handle<GlobalVariable>>,
+ pub context: Context,
+ pub module: Module,
+}
+
+impl Program {
+ pub fn new(shader_stage: ShaderStage, entry: &str) -> Program {
+ Program {
+ version: 0,
+ profile: Profile::Core,
+ shader_stage,
+ entry: Some(entry.to_string()),
+ lookup_function: FastHashMap::default(),
+ lookup_type: FastHashMap::default(),
+ lookup_global_variables: FastHashMap::default(),
+ context: Context {
+ expressions: Arena::<Expression>::new(),
+ local_variables: Arena::<LocalVariable>::new(),
+ scopes: vec![FastHashMap::default()],
+ lookup_global_var_exps: FastHashMap::default(),
+ typifier: Typifier::new(),
+ },
+ module: Module::generate_empty(),
+ }
+ }
+
+ pub fn binary_expr(
+ &mut self,
+ op: BinaryOperator,
+ left: &ExpressionRule,
+ right: &ExpressionRule,
+ ) -> ExpressionRule {
+ ExpressionRule::from_expression(self.context.expressions.append(Expression::Binary {
+ op,
+ left: left.expression,
+ right: right.expression,
+ }))
+ }
+
+ pub fn resolve_type(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<&crate::TypeInner, ErrorKind> {
+ let functions = Arena::new(); //TODO
+ let arguments = Vec::new(); //TODO
+ let resolve_ctx = ResolveContext {
+ constants: &self.module.constants,
+ global_vars: &self.module.global_variables,
+ local_vars: &self.context.local_variables,
+ functions: &functions,
+ arguments: &arguments,
+ };
+ match self.context.typifier.grow(
+ handle,
+ &self.context.expressions,
+ &mut self.module.types,
+ &resolve_ctx,
+ ) {
+ //TODO: better error report
+ Err(_) => Err(ErrorKind::SemanticError("Can't resolve type")),
+ Ok(()) => Ok(self.context.typifier.get(handle, &self.module.types)),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub enum Profile {
+ Core,
+}
+
+#[derive(Debug)]
+pub struct Context {
+ pub expressions: Arena<Expression>,
+ pub local_variables: Arena<LocalVariable>,
+ //TODO: Find less allocation heavy representation
+ pub scopes: Vec<FastHashMap<String, Handle<Expression>>>,
+ pub lookup_global_var_exps: FastHashMap<String, Handle<Expression>>,
+ pub typifier: Typifier,
+}
+
+impl Context {
+ pub fn lookup_local_var(&self, name: &str) -> Option<Handle<Expression>> {
+ for scope in self.scopes.iter().rev() {
+ if let Some(var) = scope.get(name) {
+ return Some(*var);
+ }
+ }
+ None
+ }
+
+ #[cfg(feature = "glsl-validate")]
+ pub fn lookup_local_var_current_scope(&self, name: &str) -> Option<Handle<Expression>> {
+ if let Some(current) = self.scopes.last() {
+ current.get(name).cloned()
+ } else {
+ None
+ }
+ }
+
+ pub fn clear_scopes(&mut self) {
+ self.scopes.clear();
+ self.scopes.push(FastHashMap::default());
+ }
+
+ /// Add variable to current scope
+ pub fn add_local_var(&mut self, name: String, handle: Handle<Expression>) {
+ if let Some(current) = self.scopes.last_mut() {
+ (*current).insert(name, handle);
+ }
+ }
+
+ /// Add new empty scope
+ pub fn push_scope(&mut self) {
+ self.scopes.push(FastHashMap::default());
+ }
+
+ pub fn remove_current_scope(&mut self) {
+ self.scopes.pop();
+ }
+}
+
+#[derive(Debug)]
+pub struct ExpressionRule {
+ pub expression: Handle<Expression>,
+ pub statements: Vec<Statement>,
+ pub sampler: Option<Handle<Expression>>,
+}
+
+impl ExpressionRule {
+ pub fn from_expression(expression: Handle<Expression>) -> ExpressionRule {
+ ExpressionRule {
+ expression,
+ statements: vec![],
+ sampler: None,
+ }
+ }
+}
+
+#[derive(Debug)]
+pub enum TypeQualifier {
+ StorageClass(StorageClass),
+ Binding(Binding),
+ Interpolation(Interpolation),
+}
+
+#[derive(Debug)]
+pub struct VarDeclaration {
+ pub type_qualifiers: Vec<TypeQualifier>,
+ pub ids_initializers: Vec<(Option<String>, Option<ExpressionRule>)>,
+ pub ty: Handle<Type>,
+}
+
+#[derive(Debug)]
+pub enum FunctionCallKind {
+ TypeConstructor(Handle<Type>),
+ Function(String),
+}
+
+#[derive(Debug)]
+pub struct FunctionCall {
+ pub kind: FunctionCallKind,
+ pub args: Vec<ExpressionRule>,
+}
diff --git a/third_party/rust/naga/src/front/glsl/error.rs b/third_party/rust/naga/src/front/glsl/error.rs
new file mode 100644
index 0000000000..db5ffbe9f2
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/error.rs
@@ -0,0 +1,87 @@
+use super::parser::Token;
+use super::token::TokenMetadata;
+use std::{fmt, io};
+
+#[derive(Debug)]
+pub enum ErrorKind {
+ EndOfFile,
+ InvalidInput,
+ InvalidProfile(TokenMetadata, String),
+ InvalidToken(Token),
+ InvalidVersion(TokenMetadata, i64),
+ IoError(io::Error),
+ ParserFail,
+ ParserStackOverflow,
+ NotImplemented(&'static str),
+ UnknownVariable(TokenMetadata, String),
+ UnknownField(TokenMetadata, String),
+ #[cfg(feature = "glsl-validate")]
+ VariableAlreadyDeclared(String),
+ #[cfg(feature = "glsl-validate")]
+ VariableNotAvailable(String),
+ ExpectedConstant,
+ SemanticError(&'static str),
+ PreprocessorError(String),
+}
+
+impl fmt::Display for ErrorKind {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ ErrorKind::EndOfFile => write!(f, "Unexpected end of file"),
+ ErrorKind::InvalidInput => write!(f, "InvalidInput"),
+ ErrorKind::InvalidProfile(meta, val) => {
+ write!(f, "Invalid profile {} at {:?}", val, meta)
+ }
+ ErrorKind::InvalidToken(token) => write!(f, "Invalid Token {:?}", token),
+ ErrorKind::InvalidVersion(meta, val) => {
+ write!(f, "Invalid version {} at {:?}", val, meta)
+ }
+ ErrorKind::IoError(error) => write!(f, "IO Error {}", error),
+ ErrorKind::ParserFail => write!(f, "Parser failed"),
+ ErrorKind::ParserStackOverflow => write!(f, "Parser stack overflow"),
+ ErrorKind::NotImplemented(msg) => write!(f, "Not implemented: {}", msg),
+ ErrorKind::UnknownVariable(meta, val) => {
+ write!(f, "Unknown variable {} at {:?}", val, meta)
+ }
+ ErrorKind::UnknownField(meta, val) => write!(f, "Unknown field {} at {:?}", val, meta),
+ #[cfg(feature = "glsl-validate")]
+ ErrorKind::VariableAlreadyDeclared(val) => {
+ write!(f, "Variable {} already decalred in current scope", val)
+ }
+ #[cfg(feature = "glsl-validate")]
+ ErrorKind::VariableNotAvailable(val) => {
+ write!(f, "Variable {} not available in this stage", val)
+ }
+ ErrorKind::ExpectedConstant => write!(f, "Expected constant"),
+ ErrorKind::SemanticError(msg) => write!(f, "Semantic error: {}", msg),
+ ErrorKind::PreprocessorError(val) => write!(f, "Preprocessor error: {}", val),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct ParseError {
+ pub kind: ErrorKind,
+}
+
+impl fmt::Display for ParseError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "{:?}", self)
+ }
+}
+
+impl From<io::Error> for ParseError {
+ fn from(error: io::Error) -> Self {
+ ParseError {
+ kind: ErrorKind::IoError(error),
+ }
+ }
+}
+
+impl From<ErrorKind> for ParseError {
+ fn from(kind: ErrorKind) -> Self {
+ ParseError { kind }
+ }
+}
+
+impl std::error::Error for ParseError {}
diff --git a/third_party/rust/naga/src/front/glsl/lex.rs b/third_party/rust/naga/src/front/glsl/lex.rs
new file mode 100644
index 0000000000..57e86a4c2c
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/lex.rs
@@ -0,0 +1,380 @@
+use super::parser::Token;
+use super::{preprocess::LinePreProcessor, token::TokenMetadata, types::parse_type};
+use std::{iter::Enumerate, str::Lines};
+
+fn _consume_str<'a>(input: &'a str, what: &str) -> Option<&'a str> {
+ if input.starts_with(what) {
+ Some(&input[what.len()..])
+ } else {
+ None
+ }
+}
+
+fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str, usize) {
+ let pos = input.find(|c| !what(c)).unwrap_or_else(|| input.len());
+ let (o, i) = input.split_at(pos);
+ (o, i, pos)
+}
+
+#[derive(Clone, Debug)]
+pub struct Lexer<'a> {
+ lines: Enumerate<Lines<'a>>,
+ input: String,
+ line: usize,
+ offset: usize,
+ inside_comment: bool,
+ pub pp: LinePreProcessor,
+}
+
+impl<'a> Lexer<'a> {
+ pub fn consume_token(&mut self) -> Option<Token> {
+ let start = self
+ .input
+ .find(|c: char| !c.is_whitespace())
+ .unwrap_or_else(|| self.input.chars().count());
+ let input = &self.input[start..];
+
+ let mut chars = input.chars();
+ let cur = match chars.next() {
+ Some(c) => c,
+ None => {
+ self.input = self.input[start..].into();
+ return None;
+ }
+ };
+ let mut meta = TokenMetadata {
+ line: 0,
+ chars: start..start + 1,
+ };
+ let mut consume_all = false;
+ let token = match cur {
+ ':' => Some(Token::Colon(meta)),
+ ';' => Some(Token::Semicolon(meta)),
+ ',' => Some(Token::Comma(meta)),
+ '.' => Some(Token::Dot(meta)),
+
+ '(' => Some(Token::LeftParen(meta)),
+ ')' => Some(Token::RightParen(meta)),
+ '{' => Some(Token::LeftBrace(meta)),
+ '}' => Some(Token::RightBrace(meta)),
+ '[' => Some(Token::LeftBracket(meta)),
+ ']' => Some(Token::RightBracket(meta)),
+ '<' | '>' => {
+ let n1 = chars.next();
+ let n2 = chars.next();
+ match (cur, n1, n2) {
+ ('<', Some('<'), Some('=')) => {
+ meta.chars.end = start + 3;
+ Some(Token::LeftAssign(meta))
+ }
+ ('>', Some('>'), Some('=')) => {
+ meta.chars.end = start + 3;
+ Some(Token::RightAssign(meta))
+ }
+ ('<', Some('<'), _) => {
+ meta.chars.end = start + 2;
+ Some(Token::LeftOp(meta))
+ }
+ ('>', Some('>'), _) => {
+ meta.chars.end = start + 2;
+ Some(Token::RightOp(meta))
+ }
+ ('<', Some('='), _) => {
+ meta.chars.end = start + 2;
+ Some(Token::LeOp(meta))
+ }
+ ('>', Some('='), _) => {
+ meta.chars.end = start + 2;
+ Some(Token::GeOp(meta))
+ }
+ ('<', _, _) => Some(Token::LeftAngle(meta)),
+ ('>', _, _) => Some(Token::RightAngle(meta)),
+ _ => None,
+ }
+ }
+ '0'..='9' => {
+ let (number, _, pos) = consume_any(input, |c| (c >= '0' && c <= '9' || c == '.'));
+ if number.find('.').is_some() {
+ if (
+ chars.next().map(|c| c.to_lowercase().next().unwrap()),
+ chars.next().map(|c| c.to_lowercase().next().unwrap()),
+ ) == (Some('l'), Some('f'))
+ {
+ meta.chars.end = start + pos + 2;
+ Some(Token::DoubleConstant((meta, number.parse().unwrap())))
+ } else {
+ meta.chars.end = start + pos;
+
+ Some(Token::FloatConstant((meta, number.parse().unwrap())))
+ }
+ } else {
+ meta.chars.end = start + pos;
+ Some(Token::IntConstant((meta, number.parse().unwrap())))
+ }
+ }
+ 'a'..='z' | 'A'..='Z' | '_' => {
+ let (word, _, pos) = consume_any(input, |c| c.is_ascii_alphanumeric() || c == '_');
+ meta.chars.end = start + pos;
+ match word {
+ "layout" => Some(Token::Layout(meta)),
+ "in" => Some(Token::In(meta)),
+ "out" => Some(Token::Out(meta)),
+ "uniform" => Some(Token::Uniform(meta)),
+ "flat" => Some(Token::Interpolation((meta, crate::Interpolation::Flat))),
+ "noperspective" => {
+ Some(Token::Interpolation((meta, crate::Interpolation::Linear)))
+ }
+ "smooth" => Some(Token::Interpolation((
+ meta,
+ crate::Interpolation::Perspective,
+ ))),
+ "centroid" => {
+ Some(Token::Interpolation((meta, crate::Interpolation::Centroid)))
+ }
+ "sample" => Some(Token::Interpolation((meta, crate::Interpolation::Sample))),
+ // values
+ "true" => Some(Token::BoolConstant((meta, true))),
+ "false" => Some(Token::BoolConstant((meta, false))),
+ // jump statements
+ "continue" => Some(Token::Continue(meta)),
+ "break" => Some(Token::Break(meta)),
+ "return" => Some(Token::Return(meta)),
+ "discard" => Some(Token::Discard(meta)),
+ // selection statements
+ "if" => Some(Token::If(meta)),
+ "else" => Some(Token::Else(meta)),
+ "switch" => Some(Token::Switch(meta)),
+ "case" => Some(Token::Case(meta)),
+ "default" => Some(Token::Default(meta)),
+ // iteration statements
+ "while" => Some(Token::While(meta)),
+ "do" => Some(Token::Do(meta)),
+ "for" => Some(Token::For(meta)),
+ // types
+ "void" => Some(Token::Void(meta)),
+ word => {
+ let token = match parse_type(word) {
+ Some(t) => Token::TypeName((meta, t)),
+ None => Token::Identifier((meta, String::from(word))),
+ };
+ Some(token)
+ }
+ }
+ }
+ '+' | '-' | '&' | '|' | '^' => {
+ let next = chars.next();
+ if next == Some(cur) {
+ meta.chars.end = start + 2;
+ match cur {
+ '+' => Some(Token::IncOp(meta)),
+ '-' => Some(Token::DecOp(meta)),
+ '&' => Some(Token::AndOp(meta)),
+ '|' => Some(Token::OrOp(meta)),
+ '^' => Some(Token::XorOp(meta)),
+ _ => None,
+ }
+ } else {
+ match next {
+ Some('=') => {
+ meta.chars.end = start + 2;
+ match cur {
+ '+' => Some(Token::AddAssign(meta)),
+ '-' => Some(Token::SubAssign(meta)),
+ '&' => Some(Token::AndAssign(meta)),
+ '|' => Some(Token::OrAssign(meta)),
+ '^' => Some(Token::XorAssign(meta)),
+ _ => None,
+ }
+ }
+ _ => match cur {
+ '+' => Some(Token::Plus(meta)),
+ '-' => Some(Token::Dash(meta)),
+ '&' => Some(Token::Ampersand(meta)),
+ '|' => Some(Token::VerticalBar(meta)),
+ '^' => Some(Token::Caret(meta)),
+ _ => None,
+ },
+ }
+ }
+ }
+
+ '%' | '!' | '=' => match chars.next() {
+ Some('=') => {
+ meta.chars.end = start + 2;
+ match cur {
+ '%' => Some(Token::ModAssign(meta)),
+ '!' => Some(Token::NeOp(meta)),
+ '=' => Some(Token::EqOp(meta)),
+ _ => None,
+ }
+ }
+ _ => match cur {
+ '%' => Some(Token::Percent(meta)),
+ '!' => Some(Token::Bang(meta)),
+ '=' => Some(Token::Equal(meta)),
+ _ => None,
+ },
+ },
+
+ '*' => match chars.next() {
+ Some('=') => {
+ meta.chars.end = start + 2;
+ Some(Token::MulAssign(meta))
+ }
+ Some('/') => {
+ meta.chars.end = start + 2;
+ Some(Token::CommentEnd((meta, ())))
+ }
+ _ => Some(Token::Star(meta)),
+ },
+ '/' => {
+ match chars.next() {
+ Some('=') => {
+ meta.chars.end = start + 2;
+ Some(Token::DivAssign(meta))
+ }
+ Some('/') => {
+ // consume rest of line
+ consume_all = true;
+ None
+ }
+ Some('*') => {
+ meta.chars.end = start + 2;
+ Some(Token::CommentStart((meta, ())))
+ }
+ _ => Some(Token::Slash(meta)),
+ }
+ }
+ '#' => {
+ if self.offset == 0 {
+ let mut input = chars.as_str();
+
+ // skip whitespace
+ let word_start = input
+ .find(|c: char| !c.is_whitespace())
+ .unwrap_or_else(|| input.chars().count());
+ input = &input[word_start..];
+
+ let (word, _, pos) = consume_any(input, |c| c.is_alphanumeric() || c == '_');
+ meta.chars.end = start + word_start + 1 + pos;
+ match word {
+ "version" => Some(Token::Version(meta)),
+ w => Some(Token::Unknown((meta, w.into()))),
+ }
+
+ //TODO: preprocessor
+ // if chars.next() == Some(cur) {
+ // (Token::TokenPasting, chars.as_str(), start, start + 2)
+ // } else {
+ // (Token::Preprocessor, input, start, start + 1)
+ // }
+ } else {
+ Some(Token::Unknown((meta, '#'.to_string())))
+ }
+ }
+ '~' => Some(Token::Tilde(meta)),
+ '?' => Some(Token::Question(meta)),
+ ch => Some(Token::Unknown((meta, ch.to_string()))),
+ };
+ if let Some(token) = token {
+ let skip_bytes = input
+ .chars()
+ .take(token.extra().chars.end - start)
+ .fold(0, |acc, c| acc + c.len_utf8());
+ self.input = input[skip_bytes..].into();
+ Some(token)
+ } else {
+ if consume_all {
+ self.input = "".into();
+ } else {
+ self.input = self.input[start..].into();
+ }
+ None
+ }
+ }
+
+ pub fn new(input: &'a str) -> Self {
+ let mut lexer = Lexer {
+ lines: input.lines().enumerate(),
+ input: "".to_string(),
+ line: 0,
+ offset: 0,
+ inside_comment: false,
+ pp: LinePreProcessor::new(),
+ };
+ lexer.next_line();
+ lexer
+ }
+
+ fn next_line(&mut self) -> bool {
+ if let Some((line, input)) = self.lines.next() {
+ let mut input = String::from(input);
+
+ while input.ends_with('\\') {
+ if let Some((_, next)) = self.lines.next() {
+ input.pop();
+ input.push_str(next);
+ } else {
+ break;
+ }
+ }
+
+ if let Ok(processed) = self.pp.process_line(&input) {
+ self.input = processed.unwrap_or_default();
+ self.line = line;
+ self.offset = 0;
+ true
+ } else {
+ //TODO: handle preprocessor error
+ false
+ }
+ } else {
+ false
+ }
+ }
+
+ #[must_use]
+ pub fn next(&mut self) -> Option<Token> {
+ let token = self.consume_token();
+
+ if let Some(mut token) = token {
+ let meta = token.extra_mut();
+ let end = meta.chars.end;
+ meta.line = self.line;
+ meta.chars.start += self.offset;
+ meta.chars.end += self.offset;
+ self.offset += end;
+ if !self.inside_comment {
+ match token {
+ Token::CommentStart(_) => {
+ self.inside_comment = true;
+ self.next()
+ }
+ _ => Some(token),
+ }
+ } else {
+ if let Token::CommentEnd(_) = token {
+ self.inside_comment = false;
+ }
+ self.next()
+ }
+ } else {
+ if !self.next_line() {
+ return None;
+ }
+ self.next()
+ }
+ }
+
+ // #[must_use]
+ // pub fn peek(&mut self) -> Option<Token> {
+ // self.clone().next()
+ // }
+}
+
+impl<'a> Iterator for Lexer<'a> {
+ type Item = Token;
+ fn next(&mut self) -> Option<Self::Item> {
+ self.next()
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/lex_tests.rs b/third_party/rust/naga/src/front/glsl/lex_tests.rs
new file mode 100644
index 0000000000..fde7fc9790
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/lex_tests.rs
@@ -0,0 +1,346 @@
+use super::{lex::Lexer, parser::Token::*, token::TokenMetadata};
+
+#[test]
+fn tokens() {
+ // line comments
+ let mut lex = Lexer::new("void main // myfunction\n//()\n{}");
+ assert_eq!(
+ lex.next().unwrap(),
+ Void(TokenMetadata {
+ line: 0,
+ chars: 0..4
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Identifier((
+ TokenMetadata {
+ line: 0,
+ chars: 5..9
+ },
+ "main".into()
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LeftBrace(TokenMetadata {
+ line: 2,
+ chars: 0..1
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ RightBrace(TokenMetadata {
+ line: 2,
+ chars: 1..2
+ })
+ );
+ assert_eq!(lex.next(), None);
+
+ // multi line comment
+ let mut lex = Lexer::new("void main /* comment [] {}\n/**\n{}*/{}");
+ assert_eq!(
+ lex.next().unwrap(),
+ Void(TokenMetadata {
+ line: 0,
+ chars: 0..4
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Identifier((
+ TokenMetadata {
+ line: 0,
+ chars: 5..9
+ },
+ "main".into()
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ LeftBrace(TokenMetadata {
+ line: 2,
+ chars: 4..5
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ RightBrace(TokenMetadata {
+ line: 2,
+ chars: 5..6
+ })
+ );
+ assert_eq!(lex.next(), None);
+
+ // identifiers
+ let mut lex = Lexer::new("id123_OK 92No æNoø No¾ No好");
+ assert_eq!(
+ lex.next().unwrap(),
+ Identifier((
+ TokenMetadata {
+ line: 0,
+ chars: 0..8
+ },
+ "id123_OK".into()
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ IntConstant((
+ TokenMetadata {
+ line: 0,
+ chars: 9..11
+ },
+ 92
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Identifier((
+ TokenMetadata {
+ line: 0,
+ chars: 11..13
+ },
+ "No".into()
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Unknown((
+ TokenMetadata {
+ line: 0,
+ chars: 14..15
+ },
+ 'æ'.to_string()
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Identifier((
+ TokenMetadata {
+ line: 0,
+ chars: 15..17
+ },
+ "No".into()
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Unknown((
+ TokenMetadata {
+ line: 0,
+ chars: 17..18
+ },
+ 'ø'.to_string()
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Identifier((
+ TokenMetadata {
+ line: 0,
+ chars: 19..21
+ },
+ "No".into()
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Unknown((
+ TokenMetadata {
+ line: 0,
+ chars: 21..22
+ },
+ '¾'.to_string()
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Identifier((
+ TokenMetadata {
+ line: 0,
+ chars: 23..25
+ },
+ "No".into()
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Unknown((
+ TokenMetadata {
+ line: 0,
+ chars: 25..26
+ },
+ '好'.to_string()
+ ))
+ );
+ assert_eq!(lex.next(), None);
+
+ // version
+ let mut lex = Lexer::new("#version 890 core");
+ assert_eq!(
+ lex.next().unwrap(),
+ Version(TokenMetadata {
+ line: 0,
+ chars: 0..8
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ IntConstant((
+ TokenMetadata {
+ line: 0,
+ chars: 9..12
+ },
+ 890
+ ))
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Identifier((
+ TokenMetadata {
+ line: 0,
+ chars: 13..17
+ },
+ "core".into()
+ ))
+ );
+ assert_eq!(lex.next(), None);
+
+ // operators
+ let mut lex = Lexer::new("+ - * | & % / += -= *= |= &= %= /= ++ -- || && ^^");
+ assert_eq!(
+ lex.next().unwrap(),
+ Plus(TokenMetadata {
+ line: 0,
+ chars: 0..1
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Dash(TokenMetadata {
+ line: 0,
+ chars: 2..3
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Star(TokenMetadata {
+ line: 0,
+ chars: 4..5
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ VerticalBar(TokenMetadata {
+ line: 0,
+ chars: 6..7
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Ampersand(TokenMetadata {
+ line: 0,
+ chars: 8..9
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Percent(TokenMetadata {
+ line: 0,
+ chars: 10..11
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ Slash(TokenMetadata {
+ line: 0,
+ chars: 12..13
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ AddAssign(TokenMetadata {
+ line: 0,
+ chars: 14..16
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ SubAssign(TokenMetadata {
+ line: 0,
+ chars: 17..19
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ MulAssign(TokenMetadata {
+ line: 0,
+ chars: 20..22
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ OrAssign(TokenMetadata {
+ line: 0,
+ chars: 23..25
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ AndAssign(TokenMetadata {
+ line: 0,
+ chars: 26..28
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ ModAssign(TokenMetadata {
+ line: 0,
+ chars: 29..31
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ DivAssign(TokenMetadata {
+ line: 0,
+ chars: 32..34
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ IncOp(TokenMetadata {
+ line: 0,
+ chars: 35..37
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ DecOp(TokenMetadata {
+ line: 0,
+ chars: 38..40
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ OrOp(TokenMetadata {
+ line: 0,
+ chars: 41..43
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ AndOp(TokenMetadata {
+ line: 0,
+ chars: 44..46
+ })
+ );
+ assert_eq!(
+ lex.next().unwrap(),
+ XorOp(TokenMetadata {
+ line: 0,
+ chars: 47..49
+ })
+ );
+ assert_eq!(lex.next(), None);
+}
diff --git a/third_party/rust/naga/src/front/glsl/mod.rs b/third_party/rust/naga/src/front/glsl/mod.rs
new file mode 100644
index 0000000000..4661011828
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/mod.rs
@@ -0,0 +1,43 @@
+use crate::{FastHashMap, Module, ShaderStage};
+
+mod lex;
+#[cfg(test)]
+mod lex_tests;
+
+mod preprocess;
+#[cfg(test)]
+mod preprocess_tests;
+
+mod ast;
+use ast::Program;
+
+use lex::Lexer;
+mod error;
+use error::ParseError;
+mod parser;
+#[cfg(test)]
+mod parser_tests;
+mod token;
+mod types;
+mod variables;
+
+pub fn parse_str(
+ source: &str,
+ entry: &str,
+ stage: ShaderStage,
+ defines: FastHashMap<String, String>,
+) -> Result<Module, ParseError> {
+ let mut program = Program::new(stage, entry);
+
+ let mut lex = Lexer::new(source);
+ lex.pp.defines = defines;
+
+ let mut parser = parser::Parser::new(&mut program);
+
+ for token in lex {
+ parser.parse(token)?;
+ }
+ parser.end_of_input()?;
+
+ Ok(program.module)
+}
diff --git a/third_party/rust/naga/src/front/glsl/parser.rs b/third_party/rust/naga/src/front/glsl/parser.rs
new file mode 100644
index 0000000000..db4935b0b4
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser.rs
@@ -0,0 +1,1131 @@
+#![allow(clippy::panic)]
+use pomelo::pomelo;
+
+pomelo! {
+ //%verbose;
+ %include {
+ use super::super::{error::ErrorKind, token::*, ast::*};
+ use crate::{proc::Typifier, Arena, BinaryOperator, Binding, Block, Constant,
+ ConstantInner, EntryPoint, Expression, FallThrough, FastHashMap, Function, GlobalVariable, Handle, Interpolation,
+ LocalVariable, MemberOrigin, SampleLevel, ScalarKind, Statement, StorageAccess,
+ StorageClass, StructMember, Type, TypeInner, UnaryOperator};
+ }
+ %token #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub enum Token {};
+ %parser pub struct Parser<'a> {};
+ %extra_argument &'a mut Program;
+ %extra_token TokenMetadata;
+ %error ErrorKind;
+ %syntax_error {
+ match token {
+ Some(token) => Err(ErrorKind::InvalidToken(token)),
+ None => Err(ErrorKind::EndOfFile),
+ }
+ }
+ %parse_fail {
+ ErrorKind::ParserFail
+ }
+ %stack_overflow {
+ ErrorKind::ParserStackOverflow
+ }
+
+ %type Unknown String;
+ %type CommentStart ();
+ %type CommentEnd ();
+
+ %type Identifier String;
+ // constants
+ %type IntConstant i64;
+ %type UintConstant u64;
+ %type FloatConstant f32;
+ %type BoolConstant bool;
+ %type DoubleConstant f64;
+ %type String String;
+ // function
+ %type function_prototype Function;
+ %type function_declarator Function;
+ %type function_header Function;
+ %type function_definition Function;
+
+ // statements
+ %type compound_statement Block;
+ %type compound_statement_no_new_scope Block;
+ %type statement_list Block;
+ %type statement Statement;
+ %type simple_statement Statement;
+ %type expression_statement Statement;
+ %type declaration_statement Statement;
+ %type jump_statement Statement;
+ %type iteration_statement Statement;
+ %type selection_statement Statement;
+ %type switch_statement_list Vec<(Option<i32>, Block, Option<FallThrough>)>;
+ %type switch_statement (Option<i32>, Block, Option<FallThrough>);
+ %type for_init_statement Statement;
+ %type for_rest_statement (Option<ExpressionRule>, Option<ExpressionRule>);
+ %type condition_opt Option<ExpressionRule>;
+
+ // expressions
+ %type unary_expression ExpressionRule;
+ %type postfix_expression ExpressionRule;
+ %type primary_expression ExpressionRule;
+ %type variable_identifier ExpressionRule;
+
+ %type function_call ExpressionRule;
+ %type function_call_or_method FunctionCall;
+ %type function_call_generic FunctionCall;
+ %type function_call_header_no_parameters FunctionCall;
+ %type function_call_header_with_parameters FunctionCall;
+ %type function_call_header FunctionCall;
+ %type function_identifier FunctionCallKind;
+
+ %type multiplicative_expression ExpressionRule;
+ %type additive_expression ExpressionRule;
+ %type shift_expression ExpressionRule;
+ %type relational_expression ExpressionRule;
+ %type equality_expression ExpressionRule;
+ %type and_expression ExpressionRule;
+ %type exclusive_or_expression ExpressionRule;
+ %type inclusive_or_expression ExpressionRule;
+ %type logical_and_expression ExpressionRule;
+ %type logical_xor_expression ExpressionRule;
+ %type logical_or_expression ExpressionRule;
+ %type conditional_expression ExpressionRule;
+
+ %type assignment_expression ExpressionRule;
+ %type assignment_operator BinaryOperator;
+ %type expression ExpressionRule;
+ %type constant_expression Handle<Constant>;
+
+ %type initializer ExpressionRule;
+
+ // declarations
+ %type declaration Option<VarDeclaration>;
+ %type init_declarator_list VarDeclaration;
+ %type single_declaration VarDeclaration;
+ %type layout_qualifier Binding;
+ %type layout_qualifier_id_list Vec<(String, u32)>;
+ %type layout_qualifier_id (String, u32);
+ %type type_qualifier Vec<TypeQualifier>;
+ %type single_type_qualifier TypeQualifier;
+ %type storage_qualifier StorageClass;
+ %type interpolation_qualifier Interpolation;
+ %type Interpolation Interpolation;
+
+ // types
+ %type fully_specified_type (Vec<TypeQualifier>, Option<Handle<Type>>);
+ %type type_specifier Option<Handle<Type>>;
+ %type type_specifier_nonarray Option<Type>;
+ %type struct_specifier Type;
+ %type struct_declaration_list Vec<StructMember>;
+ %type struct_declaration Vec<StructMember>;
+ %type struct_declarator_list Vec<String>;
+ %type struct_declarator String;
+
+ %type TypeName Type;
+
+ // precedence
+ %right Else;
+
+ root ::= version_pragma translation_unit;
+ version_pragma ::= Version IntConstant(V) Identifier?(P) {
+ match V.1 {
+ 440 => (),
+ 450 => (),
+ 460 => (),
+ _ => return Err(ErrorKind::InvalidVersion(V.0, V.1))
+ }
+ extra.version = V.1 as u16;
+ extra.profile = match P {
+ Some((meta, profile)) => {
+ match profile.as_str() {
+ "core" => Profile::Core,
+ _ => return Err(ErrorKind::InvalidProfile(meta, profile))
+ }
+ },
+ None => Profile::Core,
+ }
+ };
+
+ // expression
+ variable_identifier ::= Identifier(v) {
+ let var = extra.lookup_variable(&v.1)?;
+ match var {
+ Some(expression) => {
+ ExpressionRule::from_expression(expression)
+ },
+ None => {
+ return Err(ErrorKind::UnknownVariable(v.0, v.1));
+ }
+ }
+ }
+
+ primary_expression ::= variable_identifier;
+ primary_expression ::= IntConstant(i) {
+ let ty = extra.module.types.fetch_or_append(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Sint,
+ width: 4,
+ }
+ });
+ let ch = extra.module.constants.fetch_or_append(Constant {
+ name: None,
+ specialization: None,
+ ty,
+ inner: ConstantInner::Sint(i.1)
+ });
+ ExpressionRule::from_expression(
+ extra.context.expressions.append(Expression::Constant(ch))
+ )
+ }
+ // primary_expression ::= UintConstant;
+ primary_expression ::= FloatConstant(f) {
+ let ty = extra.module.types.fetch_or_append(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Float,
+ width: 4,
+ }
+ });
+ let ch = extra.module.constants.fetch_or_append(Constant {
+ name: None,
+ specialization: None,
+ ty,
+ inner: ConstantInner::Float(f.1 as f64)
+ });
+ ExpressionRule::from_expression(
+ extra.context.expressions.append(Expression::Constant(ch))
+ )
+ }
+ primary_expression ::= BoolConstant(b) {
+ let ty = extra.module.types.fetch_or_append(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Bool,
+ width: 4,
+ }
+ });
+ let ch = extra.module.constants.fetch_or_append(Constant {
+ name: None,
+ specialization: None,
+ ty,
+ inner: ConstantInner::Bool(b.1)
+ });
+ ExpressionRule::from_expression(
+ extra.context.expressions.append(Expression::Constant(ch))
+ )
+ }
+ // primary_expression ::= DoubleConstant;
+ primary_expression ::= LeftParen expression(e) RightParen {
+ e
+ }
+
+ postfix_expression ::= primary_expression;
+ postfix_expression ::= postfix_expression LeftBracket integer_expression RightBracket {
+ //TODO
+ return Err(ErrorKind::NotImplemented("[]"))
+ }
+ postfix_expression ::= function_call;
+ postfix_expression ::= postfix_expression(e) Dot Identifier(i) /* FieldSelection in spec */ {
+ //TODO: how will this work as l-value?
+ let expression = extra.field_selection(e.expression, &*i.1, i.0)?;
+ ExpressionRule { expression, statements: e.statements, sampler: None }
+ }
+ postfix_expression ::= postfix_expression(pe) IncOp {
+ //TODO
+ return Err(ErrorKind::NotImplemented("post++"))
+ }
+ postfix_expression ::= postfix_expression(pe) DecOp {
+ //TODO
+ return Err(ErrorKind::NotImplemented("post--"))
+ }
+
+ integer_expression ::= expression;
+
+ function_call ::= function_call_or_method(fc) {
+ match fc.kind {
+ FunctionCallKind::TypeConstructor(ty) => {
+ let h = if fc.args.len() == 1 {
+ let kind = extra.module.types[ty].inner
+ .scalar_kind()
+ .ok_or(ErrorKind::SemanticError("Can only cast to scalar or vector"))?;
+ extra.context.expressions.append(Expression::As {
+ kind,
+ expr: fc.args[0].expression,
+ convert: true,
+ })
+ } else {
+ extra.context.expressions.append(Expression::Compose {
+ ty,
+ components: fc.args.iter().map(|a| a.expression).collect(),
+ })
+ };
+ ExpressionRule {
+ expression: h,
+ statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(),
+ sampler: None
+ }
+ }
+ FunctionCallKind::Function(name) => {
+ match name.as_str() {
+ "sampler2D" => {
+ //TODO: check args len
+ ExpressionRule{
+ expression: fc.args[0].expression,
+ sampler: Some(fc.args[1].expression),
+ statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(),
+ }
+ }
+ "texture" => {
+ //TODO: check args len
+ if let Some(sampler) = fc.args[0].sampler {
+ ExpressionRule{
+ expression: extra.context.expressions.append(Expression::ImageSample {
+ image: fc.args[0].expression,
+ sampler,
+ coordinate: fc.args[1].expression,
+ level: SampleLevel::Auto,
+ depth_ref: None,
+ }),
+ sampler: None,
+ statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(),
+ }
+ } else {
+ return Err(ErrorKind::SemanticError("Bad call to texture"));
+ }
+ }
+ _ => { return Err(ErrorKind::NotImplemented("Function call")); }
+ }
+ }
+ }
+ }
+ function_call_or_method ::= function_call_generic;
+ function_call_generic ::= function_call_header_with_parameters(h) RightParen {
+ h
+ }
+ function_call_generic ::= function_call_header_no_parameters(h) RightParen {
+ h
+ }
+ function_call_header_no_parameters ::= function_call_header(h) Void {
+ h
+ }
+ function_call_header_no_parameters ::= function_call_header;
+ function_call_header_with_parameters ::= function_call_header(mut h) assignment_expression(ae) {
+ h.args.push(ae);
+ h
+ }
+ function_call_header_with_parameters ::= function_call_header_with_parameters(mut h) Comma assignment_expression(ae) {
+ h.args.push(ae);
+ h
+ }
+ function_call_header ::= function_identifier(i) LeftParen {
+ FunctionCall {
+ kind: i,
+ args: vec![],
+ }
+ }
+
+ // Grammar Note: Constructors look like functions, but lexical analysis recognized most of them as
+ // keywords. They are now recognized through “type_specifier”.
+ function_identifier ::= type_specifier(t) {
+ if let Some(ty) = t {
+ FunctionCallKind::TypeConstructor(ty)
+ } else {
+ return Err(ErrorKind::NotImplemented("bad type ctor"))
+ }
+ }
+
+ //TODO
+ // Methods (.length), subroutine array calls, and identifiers are recognized through postfix_expression.
+ // function_identifier ::= postfix_expression(e) {
+ // FunctionCallKind::Function(e.expression)
+ // }
+
+ // Simplification of above
+ function_identifier ::= Identifier(i) {
+ FunctionCallKind::Function(i.1)
+ }
+
+
+ unary_expression ::= postfix_expression;
+
+ unary_expression ::= IncOp unary_expression {
+ //TODO
+ return Err(ErrorKind::NotImplemented("++pre"))
+ }
+ unary_expression ::= DecOp unary_expression {
+ //TODO
+ return Err(ErrorKind::NotImplemented("--pre"))
+ }
+ unary_expression ::= unary_operator unary_expression {
+ //TODO
+ return Err(ErrorKind::NotImplemented("unary_op"))
+ }
+
+ unary_operator ::= Plus;
+ unary_operator ::= Dash;
+ unary_operator ::= Bang;
+ unary_operator ::= Tilde;
+ multiplicative_expression ::= unary_expression;
+ multiplicative_expression ::= multiplicative_expression(left) Star unary_expression(right) {
+ extra.binary_expr(BinaryOperator::Multiply, &left, &right)
+ }
+ multiplicative_expression ::= multiplicative_expression(left) Slash unary_expression(right) {
+ extra.binary_expr(BinaryOperator::Divide, &left, &right)
+ }
+ multiplicative_expression ::= multiplicative_expression(left) Percent unary_expression(right) {
+ extra.binary_expr(BinaryOperator::Modulo, &left, &right)
+ }
+ additive_expression ::= multiplicative_expression;
+ additive_expression ::= additive_expression(left) Plus multiplicative_expression(right) {
+ extra.binary_expr(BinaryOperator::Add, &left, &right)
+ }
+ additive_expression ::= additive_expression(left) Dash multiplicative_expression(right) {
+ extra.binary_expr(BinaryOperator::Subtract, &left, &right)
+ }
+ shift_expression ::= additive_expression;
+ shift_expression ::= shift_expression(left) LeftOp additive_expression(right) {
+ extra.binary_expr(BinaryOperator::ShiftLeft, &left, &right)
+ }
+ shift_expression ::= shift_expression(left) RightOp additive_expression(right) {
+ extra.binary_expr(BinaryOperator::ShiftRight, &left, &right)
+ }
+ relational_expression ::= shift_expression;
+ relational_expression ::= relational_expression(left) LeftAngle shift_expression(right) {
+ extra.binary_expr(BinaryOperator::Less, &left, &right)
+ }
+ relational_expression ::= relational_expression(left) RightAngle shift_expression(right) {
+ extra.binary_expr(BinaryOperator::Greater, &left, &right)
+ }
+ relational_expression ::= relational_expression(left) LeOp shift_expression(right) {
+ extra.binary_expr(BinaryOperator::LessEqual, &left, &right)
+ }
+ relational_expression ::= relational_expression(left) GeOp shift_expression(right) {
+ extra.binary_expr(BinaryOperator::GreaterEqual, &left, &right)
+ }
+ equality_expression ::= relational_expression;
+ equality_expression ::= equality_expression(left) EqOp relational_expression(right) {
+ extra.binary_expr(BinaryOperator::Equal, &left, &right)
+ }
+ equality_expression ::= equality_expression(left) NeOp relational_expression(right) {
+ extra.binary_expr(BinaryOperator::NotEqual, &left, &right)
+ }
+ and_expression ::= equality_expression;
+ and_expression ::= and_expression(left) Ampersand equality_expression(right) {
+ extra.binary_expr(BinaryOperator::And, &left, &right)
+ }
+ exclusive_or_expression ::= and_expression;
+ exclusive_or_expression ::= exclusive_or_expression(left) Caret and_expression(right) {
+ extra.binary_expr(BinaryOperator::ExclusiveOr, &left, &right)
+ }
+ inclusive_or_expression ::= exclusive_or_expression;
+ inclusive_or_expression ::= inclusive_or_expression(left) VerticalBar exclusive_or_expression(right) {
+ extra.binary_expr(BinaryOperator::InclusiveOr, &left, &right)
+ }
+ logical_and_expression ::= inclusive_or_expression;
+ logical_and_expression ::= logical_and_expression(left) AndOp inclusive_or_expression(right) {
+ extra.binary_expr(BinaryOperator::LogicalAnd, &left, &right)
+ }
+ logical_xor_expression ::= logical_and_expression;
+ logical_xor_expression ::= logical_xor_expression(left) XorOp logical_and_expression(right) {
+ let exp1 = extra.binary_expr(BinaryOperator::LogicalOr, &left, &right);
+ let exp2 = {
+ let tmp = extra.binary_expr(BinaryOperator::LogicalAnd, &left, &right).expression;
+ ExpressionRule::from_expression(extra.context.expressions.append(Expression::Unary { op: UnaryOperator::Not, expr: tmp }))
+ };
+ extra.binary_expr(BinaryOperator::LogicalAnd, &exp1, &exp2)
+ }
+ logical_or_expression ::= logical_xor_expression;
+ logical_or_expression ::= logical_or_expression(left) OrOp logical_xor_expression(right) {
+ extra.binary_expr(BinaryOperator::LogicalOr, &left, &right)
+ }
+
+ conditional_expression ::= logical_or_expression;
+ conditional_expression ::= logical_or_expression Question expression Colon assignment_expression(ae) {
+ //TODO: how to do ternary here in naga?
+ return Err(ErrorKind::NotImplemented("ternary exp"))
+ }
+
+ assignment_expression ::= conditional_expression;
+ assignment_expression ::= unary_expression(mut pointer) assignment_operator(op) assignment_expression(value) {
+ pointer.statements.extend(value.statements);
+ match op {
+ BinaryOperator::Equal => {
+ pointer.statements.push(Statement::Store{
+ pointer: pointer.expression,
+ value: value.expression
+ });
+ pointer
+ },
+ _ => {
+ let h = extra.context.expressions.append(
+ Expression::Binary{
+ op,
+ left: pointer.expression,
+ right: value.expression,
+ }
+ );
+ pointer.statements.push(Statement::Store{
+ pointer: pointer.expression,
+ value: h,
+ });
+ pointer
+ }
+ }
+ }
+
+ assignment_operator ::= Equal {
+ BinaryOperator::Equal
+ }
+ assignment_operator ::= MulAssign {
+ BinaryOperator::Multiply
+ }
+ assignment_operator ::= DivAssign {
+ BinaryOperator::Divide
+ }
+ assignment_operator ::= ModAssign {
+ BinaryOperator::Modulo
+ }
+ assignment_operator ::= AddAssign {
+ BinaryOperator::Add
+ }
+ assignment_operator ::= SubAssign {
+ BinaryOperator::Subtract
+ }
+ assignment_operator ::= LeftAssign {
+ BinaryOperator::ShiftLeft
+ }
+ assignment_operator ::= RightAssign {
+ BinaryOperator::ShiftRight
+ }
+ assignment_operator ::= AndAssign {
+ BinaryOperator::And
+ }
+ assignment_operator ::= XorAssign {
+ BinaryOperator::ExclusiveOr
+ }
+ assignment_operator ::= OrAssign {
+ BinaryOperator::InclusiveOr
+ }
+
+ expression ::= assignment_expression;
+ expression ::= expression(e) Comma assignment_expression(mut ae) {
+ ae.statements.extend(e.statements);
+ ExpressionRule {
+ expression: e.expression,
+ statements: ae.statements,
+ sampler: None,
+ }
+ }
+
+ //TODO: properly handle constant expressions
+ // constant_expression ::= conditional_expression(e) {
+ // if let Expression::Constant(h) = extra.context.expressions[e] {
+ // h
+ // } else {
+ // return Err(ErrorKind::ExpectedConstant)
+ // }
+ // }
+
+ // declaration
+ declaration ::= init_declarator_list(idl) Semicolon {
+ Some(idl)
+ }
+
+ declaration ::= type_qualifier(t) Identifier(i) LeftBrace
+ struct_declaration_list(sdl) RightBrace Semicolon {
+ if i.1 == "gl_PerVertex" {
+ None
+ } else {
+ Some(VarDeclaration {
+ type_qualifiers: t,
+ ids_initializers: vec![(None, None)],
+ ty: extra.module.types.fetch_or_append(Type{
+ name: Some(i.1),
+ inner: TypeInner::Struct {
+ members: sdl
+ }
+ }),
+ })
+ }
+ }
+
+ declaration ::= type_qualifier(t) Identifier(i1) LeftBrace
+ struct_declaration_list(sdl) RightBrace Identifier(i2) Semicolon {
+ Some(VarDeclaration {
+ type_qualifiers: t,
+ ids_initializers: vec![(Some(i2.1), None)],
+ ty: extra.module.types.fetch_or_append(Type{
+ name: Some(i1.1),
+ inner: TypeInner::Struct {
+ members: sdl
+ }
+ }),
+ })
+ }
+
+ // declaration ::= type_qualifier(t) Identifier(i1) LeftBrace
+ // struct_declaration_list RightBrace Identifier(i2) array_specifier Semicolon;
+
+ init_declarator_list ::= single_declaration;
+ init_declarator_list ::= init_declarator_list(mut idl) Comma Identifier(i) {
+ idl.ids_initializers.push((Some(i.1), None));
+ idl
+ }
+ // init_declarator_list ::= init_declarator_list Comma Identifier array_specifier;
+ // init_declarator_list ::= init_declarator_list Comma Identifier array_specifier Equal initializer;
+ init_declarator_list ::= init_declarator_list(mut idl) Comma Identifier(i) Equal initializer(init) {
+ idl.ids_initializers.push((Some(i.1), Some(init)));
+ idl
+ }
+
+ single_declaration ::= fully_specified_type(t) {
+ let ty = t.1.ok_or(ErrorKind::SemanticError("Empty type for declaration"))?;
+
+ VarDeclaration {
+ type_qualifiers: t.0,
+ ids_initializers: vec![],
+ ty,
+ }
+ }
+ single_declaration ::= fully_specified_type(t) Identifier(i) {
+ let ty = t.1.ok_or(ErrorKind::SemanticError("Empty type for declaration"))?;
+
+ VarDeclaration {
+ type_qualifiers: t.0,
+ ids_initializers: vec![(Some(i.1), None)],
+ ty,
+ }
+ }
+ // single_declaration ::= fully_specified_type Identifier array_specifier;
+ // single_declaration ::= fully_specified_type Identifier array_specifier Equal initializer;
+ single_declaration ::= fully_specified_type(t) Identifier(i) Equal initializer(init) {
+ let ty = t.1.ok_or(ErrorKind::SemanticError("Empty type for declaration"))?;
+
+ VarDeclaration {
+ type_qualifiers: t.0,
+ ids_initializers: vec![(Some(i.1), Some(init))],
+ ty,
+ }
+ }
+
+ fully_specified_type ::= type_specifier(t) {
+ (vec![], t)
+ }
+ fully_specified_type ::= type_qualifier(q) type_specifier(t) {
+ (q,t)
+ }
+
+ interpolation_qualifier ::= Interpolation((_, i)) {
+ i
+ }
+
+ layout_qualifier ::= Layout LeftParen layout_qualifier_id_list(l) RightParen {
+ if let Some(&(_, loc)) = l.iter().find(|&q| q.0.as_str() == "location") {
+ Binding::Location(loc)
+ } else if let Some(&(_, binding)) = l.iter().find(|&q| q.0.as_str() == "binding") {
+ let group = if let Some(&(_, set)) = l.iter().find(|&q| q.0.as_str() == "set") {
+ set
+ } else {
+ 0
+ };
+ Binding::Resource{ group, binding }
+ } else {
+ return Err(ErrorKind::NotImplemented("unsupported layout qualifier(s)"));
+ }
+ }
+ layout_qualifier_id_list ::= layout_qualifier_id(lqi) {
+ vec![lqi]
+ }
+ layout_qualifier_id_list ::= layout_qualifier_id_list(mut l) Comma layout_qualifier_id(lqi) {
+ l.push(lqi);
+ l
+ }
+ layout_qualifier_id ::= Identifier(i) {
+ (i.1, 0)
+ }
+ //TODO: handle full constant_expression instead of IntConstant
+ layout_qualifier_id ::= Identifier(i) Equal IntConstant(ic) {
+ (i.1, ic.1 as u32)
+ }
+ // layout_qualifier_id ::= Shared;
+
+ // precise_qualifier ::= Precise;
+
+ type_qualifier ::= single_type_qualifier(t) {
+ vec![t]
+ }
+ type_qualifier ::= type_qualifier(mut l) single_type_qualifier(t) {
+ l.push(t);
+ l
+ }
+
+ single_type_qualifier ::= storage_qualifier(s) {
+ TypeQualifier::StorageClass(s)
+ }
+ single_type_qualifier ::= layout_qualifier(l) {
+ TypeQualifier::Binding(l)
+ }
+ // single_type_qualifier ::= precision_qualifier;
+ single_type_qualifier ::= interpolation_qualifier(i) {
+ TypeQualifier::Interpolation(i)
+ }
+ // single_type_qualifier ::= invariant_qualifier;
+ // single_type_qualifier ::= precise_qualifier;
+
+ // storage_qualifier ::= Const
+ // storage_qualifier ::= InOut;
+ storage_qualifier ::= In {
+ StorageClass::Input
+ }
+ storage_qualifier ::= Out {
+ StorageClass::Output
+ }
+ // storage_qualifier ::= Centroid;
+ // storage_qualifier ::= Patch;
+ // storage_qualifier ::= Sample;
+ storage_qualifier ::= Uniform {
+ StorageClass::Uniform
+ }
+ //TODO: other storage qualifiers
+
+ type_specifier ::= type_specifier_nonarray(t) {
+ t.map(|t| {
+ let name = t.name.clone();
+ let handle = extra.module.types.fetch_or_append(t);
+ if let Some(name) = name {
+ extra.lookup_type.insert(name, handle);
+ }
+ handle
+ })
+ }
+ //TODO: array
+
+ type_specifier_nonarray ::= Void {
+ None
+ }
+ type_specifier_nonarray ::= TypeName(t) {
+ Some(t.1)
+ };
+ type_specifier_nonarray ::= struct_specifier(s) {
+ Some(s)
+ }
+
+ // struct
+ struct_specifier ::= Struct Identifier(i) LeftBrace struct_declaration_list RightBrace {
+ Type{
+ name: Some(i.1),
+ inner: TypeInner::Struct {
+ members: vec![]
+ }
+ }
+ }
+ //struct_specifier ::= Struct LeftBrace struct_declaration_list RightBrace;
+
+ struct_declaration_list ::= struct_declaration(sd) {
+ sd
+ }
+ struct_declaration_list ::= struct_declaration_list(mut sdl) struct_declaration(sd) {
+ sdl.extend(sd);
+ sdl
+ }
+
+ struct_declaration ::= type_specifier(t) struct_declarator_list(sdl) Semicolon {
+ if let Some(ty) = t {
+ sdl.iter().map(|name| StructMember {
+ name: Some(name.clone()),
+ origin: MemberOrigin::Empty,
+ ty,
+ }).collect()
+ } else {
+ return Err(ErrorKind::SemanticError("Struct member can't be void"))
+ }
+ }
+ //struct_declaration ::= type_qualifier type_specifier struct_declarator_list Semicolon;
+
+ struct_declarator_list ::= struct_declarator(sd) {
+ vec![sd]
+ }
+ struct_declarator_list ::= struct_declarator_list(mut sdl) Comma struct_declarator(sd) {
+ sdl.push(sd);
+ sdl
+ }
+
+ struct_declarator ::= Identifier(i) {
+ i.1
+ }
+ //struct_declarator ::= Identifier array_specifier;
+
+
+ initializer ::= assignment_expression;
+ // initializer ::= LeftBrace initializer_list RightBrace;
+ // initializer ::= LeftBrace initializer_list Comma RightBrace;
+
+ // initializer_list ::= initializer;
+ // initializer_list ::= initializer_list Comma initializer;
+
+ declaration_statement ::= declaration(d) {
+ let mut statements = Vec::<Statement>::new();
+ // local variables
+ if let Some(d) = d {
+ for (id, initializer) in d.ids_initializers {
+ let id = id.ok_or(ErrorKind::SemanticError("local var must be named"))?;
+ // check if already declared in current scope
+ #[cfg(feature = "glsl-validate")]
+ {
+ if extra.context.lookup_local_var_current_scope(&id).is_some() {
+ return Err(ErrorKind::VariableAlreadyDeclared(id))
+ }
+ }
+ let mut init_exp: Option<Handle<Expression>> = None;
+ let localVar = extra.context.local_variables.append(
+ LocalVariable {
+ name: Some(id.clone()),
+ ty: d.ty,
+ init: initializer.map(|i| {
+ statements.extend(i.statements);
+ if let Expression::Constant(constant) = extra.context.expressions[i.expression] {
+ Some(constant)
+ } else {
+ init_exp = Some(i.expression);
+ None
+ }
+ }).flatten(),
+ }
+ );
+ let exp = extra.context.expressions.append(Expression::LocalVariable(localVar));
+ extra.context.add_local_var(id, exp);
+
+ if let Some(value) = init_exp {
+ statements.push(
+ Statement::Store {
+ pointer: exp,
+ value,
+ }
+ );
+ }
+ }
+ };
+ match statements.len() {
+ 1 => statements.remove(0),
+ _ => Statement::Block(statements),
+ }
+ }
+
+ // statement
+ statement ::= compound_statement(cs) {
+ Statement::Block(cs)
+ }
+ statement ::= simple_statement;
+
+ simple_statement ::= declaration_statement;
+ simple_statement ::= expression_statement;
+ simple_statement ::= selection_statement;
+ simple_statement ::= jump_statement;
+ simple_statement ::= iteration_statement;
+
+
+ selection_statement ::= If LeftParen expression(e) RightParen statement(s1) Else statement(s2) {
+ Statement::If {
+ condition: e.expression,
+ accept: vec![s1],
+ reject: vec![s2],
+ }
+ }
+
+ selection_statement ::= If LeftParen expression(e) RightParen statement(s) [Else] {
+ Statement::If {
+ condition: e.expression,
+ accept: vec![s],
+ reject: vec![],
+ }
+ }
+
+ selection_statement ::= Switch LeftParen expression(e) RightParen LeftBrace switch_statement_list(ls) RightBrace {
+ let mut default = Vec::new();
+ let mut cases = FastHashMap::default();
+ for (v, s, ft) in ls {
+ if let Some(v) = v {
+ cases.insert(v, (s, ft));
+ } else {
+ default.extend_from_slice(&s);
+ }
+ }
+ Statement::Switch {
+ selector: e.expression,
+ cases,
+ default,
+ }
+ }
+
+ switch_statement_list ::= {
+ vec![]
+ }
+ switch_statement_list ::= switch_statement_list(mut ssl) switch_statement((v, sl, ft)) {
+ ssl.push((v, sl, ft));
+ ssl
+ }
+ switch_statement ::= Case IntConstant(v) Colon statement_list(sl) {
+ let fallthrough = match sl.last() {
+ Some(Statement::Break) => None,
+ _ => Some(FallThrough),
+ };
+ (Some(v.1 as i32), sl, fallthrough)
+ }
+ switch_statement ::= Default Colon statement_list(sl) {
+ let fallthrough = match sl.last() {
+ Some(Statement::Break) => Some(FallThrough),
+ _ => None,
+ };
+ (None, sl, fallthrough)
+ }
+
+ iteration_statement ::= While LeftParen expression(e) RightParen compound_statement_no_new_scope(sl) {
+ let mut body = Vec::with_capacity(sl.len() + 1);
+ body.push(
+ Statement::If {
+ condition: e.expression,
+ accept: vec![Statement::Break],
+ reject: vec![],
+ }
+ );
+ body.extend_from_slice(&sl);
+ Statement::Loop {
+ body,
+ continuing: vec![],
+ }
+ }
+
+ iteration_statement ::= Do compound_statement(sl) While LeftParen expression(e) RightParen {
+ let mut body = sl;
+ body.push(
+ Statement::If {
+ condition: e.expression,
+ accept: vec![Statement::Break],
+ reject: vec![],
+ }
+ );
+ Statement::Loop {
+ body,
+ continuing: vec![],
+ }
+ }
+
+ iteration_statement ::= For LeftParen for_init_statement(s_init) for_rest_statement((cond_e, loop_e)) RightParen compound_statement_no_new_scope(sl) {
+ let mut body = Vec::with_capacity(sl.len() + 2);
+ if let Some(cond_e) = cond_e {
+ body.push(
+ Statement::If {
+ condition: cond_e.expression,
+ accept: vec![Statement::Break],
+ reject: vec![],
+ }
+ );
+ }
+ body.extend_from_slice(&sl);
+ if let Some(loop_e) = loop_e {
+ body.extend_from_slice(&loop_e.statements);
+ }
+ Statement::Block(vec![
+ s_init,
+ Statement::Loop {
+ body,
+ continuing: vec![],
+ }
+ ])
+ }
+
+ for_init_statement ::= expression_statement;
+ for_init_statement ::= declaration_statement;
+ for_rest_statement ::= condition_opt(c) Semicolon {
+ (c, None)
+ }
+ for_rest_statement ::= condition_opt(c) Semicolon expression(e) {
+ (c, Some(e))
+ }
+
+ condition_opt ::= {
+ None
+ }
+ condition_opt ::= conditional_expression(c) {
+ Some(c)
+ }
+
+ compound_statement ::= LeftBrace RightBrace {
+ vec![]
+ }
+ compound_statement ::= left_brace_scope statement_list(sl) RightBrace {
+ extra.context.remove_current_scope();
+ sl
+ }
+
+ // extra rule to add scope before statement_list
+ left_brace_scope ::= LeftBrace {
+ extra.context.push_scope();
+ }
+
+
+ compound_statement_no_new_scope ::= LeftBrace RightBrace {
+ vec![]
+ }
+ compound_statement_no_new_scope ::= LeftBrace statement_list(sl) RightBrace {
+ sl
+ }
+
+ statement_list ::= statement(s) {
+ vec![s]
+ }
+ statement_list ::= statement_list(mut ss) statement(s) { ss.push(s); ss }
+
+ expression_statement ::= Semicolon {
+ Statement::Block(Vec::new())
+ }
+ expression_statement ::= expression(mut e) Semicolon {
+ match e.statements.len() {
+ 1 => e.statements.remove(0),
+ _ => Statement::Block(e.statements),
+ }
+ }
+
+
+
+ // function
+ function_prototype ::= function_declarator(f) RightParen {
+ // prelude, add global var expressions
+ for (var_handle, var) in extra.module.global_variables.iter() {
+ if let Some(name) = var.name.as_ref() {
+ let exp = extra.context.expressions.append(
+ Expression::GlobalVariable(var_handle)
+ );
+ extra.context.lookup_global_var_exps.insert(name.clone(), exp);
+ } else {
+ let ty = &extra.module.types[var.ty];
+ // anonymous structs
+ if let TypeInner::Struct { members } = &ty.inner {
+ let base = extra.context.expressions.append(
+ Expression::GlobalVariable(var_handle)
+ );
+ for (idx, member) in members.iter().enumerate() {
+ if let Some(name) = member.name.as_ref() {
+ let exp = extra.context.expressions.append(
+ Expression::AccessIndex{
+ base,
+ index: idx as u32,
+ }
+ );
+ extra.context.lookup_global_var_exps.insert(name.clone(), exp);
+ }
+ }
+ }
+ }
+ }
+ f
+ }
+ function_declarator ::= function_header;
+ function_header ::= fully_specified_type(t) Identifier(n) LeftParen {
+ Function {
+ name: Some(n.1),
+ arguments: vec![],
+ return_type: t.1,
+ global_usage: vec![],
+ local_variables: Arena::<LocalVariable>::new(),
+ expressions: Arena::<Expression>::new(),
+ body: vec![],
+ }
+ }
+
+ jump_statement ::= Continue Semicolon {
+ Statement::Continue
+ }
+ jump_statement ::= Break Semicolon {
+ Statement::Break
+ }
+ jump_statement ::= Return Semicolon {
+ Statement::Return { value: None }
+ }
+ jump_statement ::= Return expression(mut e) Semicolon {
+ let ret = Statement::Return{ value: Some(e.expression) };
+ if !e.statements.is_empty() {
+ e.statements.push(ret);
+ Statement::Block(e.statements)
+ } else {
+ ret
+ }
+ }
+ jump_statement ::= Discard Semicolon {
+ Statement::Kill
+ } // Fragment shader only
+
+ // Grammar Note: No 'goto'. Gotos are not supported.
+
+ // misc
+ translation_unit ::= external_declaration;
+ translation_unit ::= translation_unit external_declaration;
+
+ external_declaration ::= function_definition(f) {
+ if f.name == extra.entry {
+ let name = extra.entry.take().unwrap();
+ extra.module.entry_points.insert(
+ (extra.shader_stage, name),
+ EntryPoint {
+ early_depth_test: None,
+ workgroup_size: [0; 3], //TODO
+ function: f,
+ },
+ );
+ } else {
+ let name = f.name.clone().unwrap();
+ let handle = extra.module.functions.append(f);
+ extra.lookup_function.insert(name, handle);
+ }
+ }
+ external_declaration ::= declaration(d) {
+ if let Some(d) = d {
+ let class = d.type_qualifiers.iter().find_map(|tq| {
+ if let TypeQualifier::StorageClass(sc) = tq { Some(*sc) } else { None }
+ }).ok_or(ErrorKind::SemanticError("Missing storage class for global var"))?;
+
+ let binding = d.type_qualifiers.iter().find_map(|tq| {
+ if let TypeQualifier::Binding(b) = tq { Some(b.clone()) } else { None }
+ });
+
+ let interpolation = d.type_qualifiers.iter().find_map(|tq| {
+ if let TypeQualifier::Interpolation(i) = tq { Some(*i) } else { None }
+ });
+
+ for (id, initializer) in d.ids_initializers {
+ let h = extra.module.global_variables.fetch_or_append(
+ GlobalVariable {
+ name: id.clone(),
+ class,
+ binding: binding.clone(),
+ ty: d.ty,
+ init: None,
+ interpolation,
+ storage_access: StorageAccess::empty(), //TODO
+ },
+ );
+ if let Some(id) = id {
+ extra.lookup_global_variables.insert(id, h);
+ }
+ }
+ }
+ }
+
+ function_definition ::= function_prototype(mut f) compound_statement_no_new_scope(mut cs) {
+ std::mem::swap(&mut f.expressions, &mut extra.context.expressions);
+ std::mem::swap(&mut f.local_variables, &mut extra.context.local_variables);
+ extra.context.clear_scopes();
+ extra.context.lookup_global_var_exps.clear();
+ extra.context.typifier = Typifier::new();
+ // make sure function ends with return
+ match cs.last() {
+ Some(Statement::Return {..}) => {}
+ _ => {cs.push(Statement::Return { value:None });}
+ }
+ f.body = cs;
+ f.fill_global_use(&extra.module.global_variables);
+ f
+ };
+}
+
+pub use parser::*;
diff --git a/third_party/rust/naga/src/front/glsl/parser_tests.rs b/third_party/rust/naga/src/front/glsl/parser_tests.rs
new file mode 100644
index 0000000000..ccdc657a78
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/parser_tests.rs
@@ -0,0 +1,182 @@
+use super::ast::Program;
+use super::error::ErrorKind;
+use super::lex::Lexer;
+use super::parser;
+use crate::ShaderStage;
+
+fn parse_program(source: &str, stage: ShaderStage) -> Result<Program, ErrorKind> {
+ let mut program = Program::new(stage, "");
+ let lex = Lexer::new(source);
+ let mut parser = parser::Parser::new(&mut program);
+
+ for token in lex {
+ parser.parse(token)?;
+ }
+ parser.end_of_input()?;
+ Ok(program)
+}
+
+#[test]
+fn version() {
+ // invalid versions
+ assert_eq!(
+ format!(
+ "{:?}",
+ parse_program("#version 99000", ShaderStage::Vertex)
+ .err()
+ .unwrap()
+ ),
+ "InvalidVersion(TokenMetadata { line: 0, chars: 9..14 }, 99000)"
+ );
+
+ assert_eq!(
+ format!(
+ "{:?}",
+ parse_program("#version 449", ShaderStage::Vertex)
+ .err()
+ .unwrap()
+ ),
+ "InvalidVersion(TokenMetadata { line: 0, chars: 9..12 }, 449)"
+ );
+
+ assert_eq!(
+ format!(
+ "{:?}",
+ parse_program("#version 450 smart", ShaderStage::Vertex)
+ .err()
+ .unwrap()
+ ),
+ "InvalidProfile(TokenMetadata { line: 0, chars: 13..18 }, \"smart\")"
+ );
+
+ assert_eq!(
+ format!(
+ "{:?}",
+ parse_program("#version 450\nvoid f(){} #version 450", ShaderStage::Vertex)
+ .err()
+ .unwrap()
+ ),
+ "InvalidToken(Unknown((TokenMetadata { line: 1, chars: 11..12 }, \"#\")))"
+ );
+
+ // valid versions
+ let program = parse_program(" # version 450\nvoid main() {}", ShaderStage::Vertex).unwrap();
+ assert_eq!(
+ format!("{:?}", (program.version, program.profile)),
+ "(450, Core)"
+ );
+
+ let program = parse_program("#version 450\nvoid main() {}", ShaderStage::Vertex).unwrap();
+ assert_eq!(
+ format!("{:?}", (program.version, program.profile)),
+ "(450, Core)"
+ );
+
+ let program = parse_program("#version 450 core\nvoid main() {}", ShaderStage::Vertex).unwrap();
+ assert_eq!(
+ format!("{:?}", (program.version, program.profile)),
+ "(450, Core)"
+ );
+}
+
+#[test]
+fn control_flow() {
+ let _program = parse_program(
+ r#"
+ # version 450
+ void main() {
+ if (true) {
+ return 1;
+ } else {
+ return 2;
+ }
+ }
+ "#,
+ ShaderStage::Vertex,
+ )
+ .unwrap();
+
+ let _program = parse_program(
+ r#"
+ # version 450
+ void main() {
+ if (true) {
+ return 1;
+ }
+ }
+ "#,
+ ShaderStage::Vertex,
+ )
+ .unwrap();
+
+ let _program = parse_program(
+ r#"
+ # version 450
+ void main() {
+ int x;
+ int y = 3;
+ switch (5) {
+ case 2:
+ x = 2;
+ case 5:
+ x = 5;
+ y = 2;
+ break;
+ default:
+ x = 0;
+ }
+ }
+ "#,
+ ShaderStage::Vertex,
+ )
+ .unwrap();
+ let _program = parse_program(
+ r#"
+ # version 450
+ void main() {
+ int x = 0;
+ while(x < 5) {
+ x = x + 1;
+ }
+ do {
+ x = x - 1;
+ } while(x >= 4)
+ }
+ "#,
+ ShaderStage::Vertex,
+ )
+ .unwrap();
+
+ let _program = parse_program(
+ r#"
+ # version 450
+ void main() {
+ int x = 0;
+ for(int i = 0; i < 10;) {
+ x = x + 2;
+ }
+ return x;
+ }
+ "#,
+ ShaderStage::Vertex,
+ )
+ .unwrap();
+}
+
+#[test]
+fn textures() {
+ let _program = parse_program(
+ r#"
+ #version 450
+ layout(location = 0) in vec2 v_uv;
+ layout(location = 0) out vec4 o_color;
+ layout(set = 1, binding = 1) uniform texture2D tex;
+ layout(set = 1, binding = 2) uniform sampler tex_sampler;
+ void main() {
+ o_color = texture(sampler2D(tex, tex_sampler), v_uv);
+ }
+ "#,
+ ShaderStage::Fragment,
+ )
+ .unwrap();
+}
diff --git a/third_party/rust/naga/src/front/glsl/preprocess.rs b/third_party/rust/naga/src/front/glsl/preprocess.rs
new file mode 100644
index 0000000000..6050594719
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/preprocess.rs
@@ -0,0 +1,152 @@
+use crate::FastHashMap;
+use thiserror::Error;
+
+#[derive(Clone, Debug, Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum Error {
+ #[error("unmatched else")]
+ UnmatchedElse,
+ #[error("unmatched endif")]
+ UnmatchedEndif,
+ #[error("missing macro name")]
+ MissingMacro,
+}
+
+#[derive(Clone, Debug)]
+pub struct IfState {
+ true_branch: bool,
+ else_seen: bool,
+}
+
+#[derive(Clone, Debug)]
+pub struct LinePreProcessor {
+ pub defines: FastHashMap<String, String>,
+ if_stack: Vec<IfState>,
+ inside_comment: bool,
+ in_preprocess: bool,
+}
+
+impl LinePreProcessor {
+ pub fn new() -> Self {
+ LinePreProcessor {
+ defines: FastHashMap::default(),
+ if_stack: vec![],
+ inside_comment: false,
+ in_preprocess: false,
+ }
+ }
+
+ fn subst_defines(&self, input: &str) -> String {
+ //TODO: don't subst in commments, strings literals?
+ self.defines
+ .iter()
+ .fold(input.to_string(), |acc, (k, v)| acc.replace(k, v))
+ }
+
+ pub fn process_line(&mut self, line: &str) -> Result<Option<String>, Error> {
+ let mut skip = !self.if_stack.last().map(|i| i.true_branch).unwrap_or(true);
+ let mut inside_comment = self.inside_comment;
+ let mut in_preprocess = inside_comment && self.in_preprocess;
+ // single-line comment
+ let mut processed = line;
+ if let Some(pos) = line.find("//") {
+ processed = line.split_at(pos).0;
+ }
+ // multi-line comment
+ let mut processed_string: String;
+ loop {
+ if inside_comment {
+ if let Some(pos) = processed.find("*/") {
+ processed = processed.split_at(pos + 2).1;
+ inside_comment = false;
+ self.inside_comment = false;
+ continue;
+ }
+ } else if let Some(pos) = processed.find("/*") {
+ if let Some(end_pos) = processed[pos + 2..].find("*/") {
+ // comment ends during this line
+ processed_string = processed.to_string();
+ processed_string.replace_range(pos..pos + end_pos + 4, "");
+ processed = &processed_string;
+ } else {
+ processed = processed.split_at(pos).0;
+ inside_comment = true;
+ }
+ continue;
+ }
+ break;
+ }
+ // strip leading whitespace
+ processed = processed.trim_start();
+ if processed.starts_with('#') && !self.inside_comment {
+ let mut iter = processed[1..]
+ .trim_start()
+ .splitn(2, |c: char| c.is_whitespace());
+ if let Some(directive) = iter.next() {
+ skip = true;
+ in_preprocess = true;
+ match directive {
+ "version" => {
+ skip = false;
+ }
+ "define" => {
+ let rest = iter.next().ok_or(Error::MissingMacro)?;
+ let pos = rest
+ .find(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '(')
+ .unwrap_or_else(|| rest.len());
+ let (key, mut value) = rest.split_at(pos);
+ value = value.trim();
+ self.defines.insert(key.into(), self.subst_defines(value));
+ }
+ "undef" => {
+ let rest = iter.next().ok_or(Error::MissingMacro)?;
+ let key = rest.trim();
+ self.defines.remove(key);
+ }
+ "ifdef" => {
+ let rest = iter.next().ok_or(Error::MissingMacro)?;
+ let key = rest.trim();
+ self.if_stack.push(IfState {
+ true_branch: self.defines.contains_key(key),
+ else_seen: false,
+ });
+ }
+ "ifndef" => {
+ let rest = iter.next().ok_or(Error::MissingMacro)?;
+ let key = rest.trim();
+ self.if_stack.push(IfState {
+ true_branch: !self.defines.contains_key(key),
+ else_seen: false,
+ });
+ }
+ "else" => {
+ let if_state = self.if_stack.last_mut().ok_or(Error::UnmatchedElse)?;
+ if !if_state.else_seen {
+ // this is first else
+ if_state.true_branch = !if_state.true_branch;
+ if_state.else_seen = true;
+ } else {
+ return Err(Error::UnmatchedElse);
+ }
+ }
+ "endif" => {
+ self.if_stack.pop().ok_or(Error::UnmatchedEndif)?;
+ }
+ _ => {}
+ }
+ }
+ }
+ let res = if !skip && !self.inside_comment {
+ Ok(Some(self.subst_defines(&line)))
+ } else {
+ Ok(if in_preprocess && !self.in_preprocess {
+ Some("".to_string())
+ } else {
+ None
+ })
+ };
+ self.in_preprocess = in_preprocess || skip;
+ self.inside_comment = inside_comment;
+ res
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/preprocess_tests.rs b/third_party/rust/naga/src/front/glsl/preprocess_tests.rs
new file mode 100644
index 0000000000..253a99935e
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/preprocess_tests.rs
@@ -0,0 +1,218 @@
+use super::preprocess::{Error, LinePreProcessor};
+use std::{iter::Enumerate, str::Lines};
+
+#[derive(Clone, Debug)]
+pub struct PreProcessor<'a> {
+ lines: Enumerate<Lines<'a>>,
+ input: String,
+ line: usize,
+ offset: usize,
+ line_pp: LinePreProcessor,
+}
+
+impl<'a> PreProcessor<'a> {
+ pub fn new(input: &'a str) -> Self {
+ let mut lexer = PreProcessor {
+ lines: input.lines().enumerate(),
+ input: "".to_string(),
+ line: 0,
+ offset: 0,
+ line_pp: LinePreProcessor::new(),
+ };
+ lexer.next_line();
+ lexer
+ }
+
+ fn next_line(&mut self) -> bool {
+ if let Some((line, input)) = self.lines.next() {
+ let mut input = String::from(input);
+
+ while input.ends_with('\\') {
+ if let Some((_, next)) = self.lines.next() {
+ input.pop();
+ input.push_str(next);
+ } else {
+ break;
+ }
+ }
+
+ self.input = input;
+ self.line = line;
+ self.offset = 0;
+ true
+ } else {
+ false
+ }
+ }
+
+ pub fn process(&mut self) -> Result<String, Error> {
+ let mut res = String::new();
+ loop {
+ let line = &self.line_pp.process_line(&self.input)?;
+ if let Some(line) = line {
+ res.push_str(line);
+ }
+ if !self.next_line() {
+ break;
+ }
+ if line.is_some() {
+ res.push_str("\n");
+ }
+ }
+ Ok(res)
+ }
+}
+
+#[test]
+fn preprocess() {
+ // line continuation
+ let mut pp = PreProcessor::new(
+ "void main my_\
+ func",
+ );
+ assert_eq!(pp.process().unwrap(), "void main my_func");
+
+ // preserve #version
+ let mut pp = PreProcessor::new(
+ "#version 450 core\n\
+ void main()",
+ );
+ assert_eq!(pp.process().unwrap(), "#version 450 core\nvoid main()");
+
+ // simple define
+ let mut pp = PreProcessor::new(
+ "#define FOO 42 \n\
+ fun=FOO",
+ );
+ assert_eq!(pp.process().unwrap(), "\nfun=42");
+
+ // ifdef with else
+ let mut pp = PreProcessor::new(
+ "#define FOO\n\
+ #ifdef FOO\n\
+ foo=42\n\
+ #endif\n\
+ some=17\n\
+ #ifdef BAR\n\
+ bar=88\n\
+ #else\n\
+ mm=49\n\
+ #endif\n\
+ done=1",
+ );
+ assert_eq!(
+ pp.process().unwrap(),
+ "\n\
+ foo=42\n\
+ \n\
+ some=17\n\
+ \n\
+ mm=49\n\
+ \n\
+ done=1"
+ );
+
+ // nested ifdef/ifndef
+ let mut pp = PreProcessor::new(
+ "#define FOO\n\
+ #define BOO\n\
+ #ifdef FOO\n\
+ foo=42\n\
+ #ifdef BOO\n\
+ boo=44\n\
+ #endif\n\
+ ifd=0\n\
+ #ifndef XYZ\n\
+ nxyz=8\n\
+ #endif\n\
+ #endif\n\
+ some=17\n\
+ #ifdef BAR\n\
+ bar=88\n\
+ #else\n\
+ mm=49\n\
+ #endif\n\
+ done=1",
+ );
+ assert_eq!(
+ pp.process().unwrap(),
+ "\n\
+ foo=42\n\
+ \n\
+ boo=44\n\
+ \n\
+ ifd=0\n\
+ \n\
+ nxyz=8\n\
+ \n\
+ some=17\n\
+ \n\
+ mm=49\n\
+ \n\
+ done=1"
+ );
+
+ // undef
+ let mut pp = PreProcessor::new(
+ "#define FOO\n\
+ #ifdef FOO\n\
+ foo=42\n\
+ #endif\n\
+ some=17\n\
+ #undef FOO\n\
+ #ifdef FOO\n\
+ foo=88\n\
+ #else\n\
+ nofoo=66\n\
+ #endif\n\
+ done=1",
+ );
+ assert_eq!(
+ pp.process().unwrap(),
+ "\n\
+ foo=42\n\
+ \n\
+ some=17\n\
+ \n\
+ nofoo=66\n\
+ \n\
+ done=1"
+ );
+
+ // single-line comment
+ let mut pp = PreProcessor::new(
+ "#define FOO 42//1234\n\
+ fun=FOO",
+ );
+ assert_eq!(pp.process().unwrap(), "\nfun=42");
+
+ // multi-line comments
+ let mut pp = PreProcessor::new(
+ "#define FOO 52/*/1234\n\
+ #define FOO 88\n\
+ end of comment*/ /* one more comment */ #define FOO 56\n\
+ fun=FOO",
+ );
+ assert_eq!(pp.process().unwrap(), "\nfun=56");
+
+ // unmatched endif
+ let mut pp = PreProcessor::new(
+ "#ifdef FOO\n\
+ foo=42\n\
+ #endif\n\
+ #endif",
+ );
+ assert_eq!(pp.process(), Err(Error::UnmatchedEndif));
+
+ // unmatched else
+ let mut pp = PreProcessor::new(
+ "#ifdef FOO\n\
+ foo=42\n\
+ #else\n\
+ bar=88\n\
+ #else\n\
+ bad=true\n\
+ #endif",
+ );
+ assert_eq!(pp.process(), Err(Error::UnmatchedElse));
+}
diff --git a/third_party/rust/naga/src/front/glsl/token.rs b/third_party/rust/naga/src/front/glsl/token.rs
new file mode 100644
index 0000000000..7b849d361f
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/token.rs
@@ -0,0 +1,8 @@
+use std::ops::Range;
+
+#[derive(Debug, Clone)]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct TokenMetadata {
+ pub line: usize,
+ pub chars: Range<usize>,
+}
diff --git a/third_party/rust/naga/src/front/glsl/types.rs b/third_party/rust/naga/src/front/glsl/types.rs
new file mode 100644
index 0000000000..715c0c7430
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/types.rs
@@ -0,0 +1,120 @@
+use crate::{ScalarKind, Type, TypeInner, VectorSize};
+
+pub fn parse_type(type_name: &str) -> Option<Type> {
+ match type_name {
+ "bool" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Bool,
+ width: 4, // https://stackoverflow.com/questions/9419781/what-is-the-size-of-glsl-boolean
+ },
+ }),
+ "float" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ }),
+ "double" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Float,
+ width: 8,
+ },
+ }),
+ "int" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Sint,
+ width: 4,
+ },
+ }),
+ "uint" => Some(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Uint,
+ width: 4,
+ },
+ }),
+ "texture2D" => Some(Type {
+ name: None,
+ inner: TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled {
+ kind: ScalarKind::Float,
+ multi: false,
+ },
+ },
+ }),
+ "sampler" => Some(Type {
+ name: None,
+ inner: TypeInner::Sampler { comparison: false },
+ }),
+ word => {
+ fn kind_width_parse(ty: &str) -> Option<(ScalarKind, u8)> {
+ Some(match ty {
+ "" => (ScalarKind::Float, 4),
+ "b" => (ScalarKind::Bool, 4),
+ "i" => (ScalarKind::Sint, 4),
+ "u" => (ScalarKind::Uint, 4),
+ "d" => (ScalarKind::Float, 8),
+ _ => return None,
+ })
+ }
+
+ fn size_parse(n: &str) -> Option<VectorSize> {
+ Some(match n {
+ "2" => VectorSize::Bi,
+ "3" => VectorSize::Tri,
+ "4" => VectorSize::Quad,
+ _ => return None,
+ })
+ }
+
+ let vec_parse = |word: &str| {
+ let mut iter = word.split("vec");
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ let (kind, width) = kind_width_parse(kind)?;
+ let size = size_parse(size)?;
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Vector { size, kind, width },
+ })
+ };
+
+ let mat_parse = |word: &str| {
+ let mut iter = word.split("mat");
+
+ let kind = iter.next()?;
+ let size = iter.next()?;
+ let (_, width) = kind_width_parse(kind)?;
+
+ let (columns, rows) = if let Some(size) = size_parse(size) {
+ (size, size)
+ } else {
+ let mut iter = size.split('x');
+ match (iter.next()?, iter.next()?, iter.next()) {
+ (col, row, None) => (size_parse(col)?, size_parse(row)?),
+ _ => return None,
+ }
+ };
+
+ Some(Type {
+ name: None,
+ inner: TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ },
+ })
+ };
+
+ vec_parse(word).or_else(|| mat_parse(word))
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs
new file mode 100644
index 0000000000..9f76335685
--- /dev/null
+++ b/third_party/rust/naga/src/front/glsl/variables.rs
@@ -0,0 +1,185 @@
+use crate::{
+ Binding, BuiltIn, Expression, GlobalVariable, Handle, ScalarKind, ShaderStage, StorageAccess,
+ StorageClass, Type, TypeInner, VectorSize,
+};
+
+use super::ast::*;
+use super::error::ErrorKind;
+use super::token::TokenMetadata;
+
+impl Program {
+ pub fn lookup_variable(&mut self, name: &str) -> Result<Option<Handle<Expression>>, ErrorKind> {
+ let mut expression: Option<Handle<Expression>> = None;
+ match name {
+ "gl_Position" => {
+ #[cfg(feature = "glsl-validate")]
+ match self.shader_stage {
+ ShaderStage::Vertex | ShaderStage::Fragment { .. } => {}
+ _ => {
+ return Err(ErrorKind::VariableNotAvailable(name.into()));
+ }
+ };
+ let h = self
+ .module
+ .global_variables
+ .fetch_or_append(GlobalVariable {
+ name: Some(name.into()),
+ class: if self.shader_stage == ShaderStage::Vertex {
+ StorageClass::Output
+ } else {
+ StorageClass::Input
+ },
+ binding: Some(Binding::BuiltIn(BuiltIn::Position)),
+ ty: self.module.types.fetch_or_append(Type {
+ name: None,
+ inner: TypeInner::Vector {
+ size: VectorSize::Quad,
+ kind: ScalarKind::Float,
+ width: 4,
+ },
+ }),
+ init: None,
+ interpolation: None,
+ storage_access: StorageAccess::empty(),
+ });
+ self.lookup_global_variables.insert(name.into(), h);
+ let exp = self
+ .context
+ .expressions
+ .append(Expression::GlobalVariable(h));
+ self.context.lookup_global_var_exps.insert(name.into(), exp);
+
+ expression = Some(exp);
+ }
+ "gl_VertexIndex" => {
+ #[cfg(feature = "glsl-validate")]
+ match self.shader_stage {
+ ShaderStage::Vertex => {}
+ _ => {
+ return Err(ErrorKind::VariableNotAvailable(name.into()));
+ }
+ };
+ let h = self
+ .module
+ .global_variables
+ .fetch_or_append(GlobalVariable {
+ name: Some(name.into()),
+ class: StorageClass::Input,
+ binding: Some(Binding::BuiltIn(BuiltIn::VertexIndex)),
+ ty: self.module.types.fetch_or_append(Type {
+ name: None,
+ inner: TypeInner::Scalar {
+ kind: ScalarKind::Uint,
+ width: 4,
+ },
+ }),
+ init: None,
+ interpolation: None,
+ storage_access: StorageAccess::empty(),
+ });
+ self.lookup_global_variables.insert(name.into(), h);
+ let exp = self
+ .context
+ .expressions
+ .append(Expression::GlobalVariable(h));
+ self.context.lookup_global_var_exps.insert(name.into(), exp);
+
+ expression = Some(exp);
+ }
+ _ => {}
+ }
+
+ if let Some(expression) = expression {
+ Ok(Some(expression))
+ } else if let Some(local_var) = self.context.lookup_local_var(name) {
+ Ok(Some(local_var))
+ } else if let Some(global_var) = self.context.lookup_global_var_exps.get(name) {
+ Ok(Some(*global_var))
+ } else {
+ Ok(None)
+ }
+ }
+
+ pub fn field_selection(
+ &mut self,
+ expression: Handle<Expression>,
+ name: &str,
+ meta: TokenMetadata,
+ ) -> Result<Handle<Expression>, ErrorKind> {
+ match *self.resolve_type(expression)? {
+ TypeInner::Struct { ref members } => {
+ let index = members
+ .iter()
+ .position(|m| m.name == Some(name.into()))
+ .ok_or_else(|| ErrorKind::UnknownField(meta, name.into()))?;
+ Ok(self.context.expressions.append(Expression::AccessIndex {
+ base: expression,
+ index: index as u32,
+ }))
+ }
+ // swizzles (xyzw, rgba, stpq)
+ TypeInner::Vector { size, kind, width } => {
+ let check_swizzle_components = |comps: &str| {
+ name.chars()
+ .map(|c| {
+ comps
+ .find(c)
+ .and_then(|i| if i < size as usize { Some(i) } else { None })
+ })
+ .fold(Some(Vec::<usize>::new()), |acc, cur| {
+ cur.and_then(|i| {
+ acc.map(|mut v| {
+ v.push(i);
+ v
+ })
+ })
+ })
+ };
+
+ let indices = check_swizzle_components("xyzw")
+ .or_else(|| check_swizzle_components("rgba"))
+ .or_else(|| check_swizzle_components("stpq"));
+
+ if let Some(v) = indices {
+ let components: Vec<Handle<Expression>> = v
+ .iter()
+ .map(|idx| {
+ self.context.expressions.append(Expression::AccessIndex {
+ base: expression,
+ index: *idx as u32,
+ })
+ })
+ .collect();
+ if components.len() == 1 {
+ // only single element swizzle, like pos.y, just return that component
+ Ok(components[0])
+ } else {
+ Ok(self.context.expressions.append(Expression::Compose {
+ ty: self.module.types.fetch_or_append(Type {
+ name: None,
+ inner: TypeInner::Vector {
+ kind,
+ width,
+ size: match components.len() {
+ 2 => VectorSize::Bi,
+ 3 => VectorSize::Tri,
+ 4 => VectorSize::Quad,
+ _ => {
+ return Err(ErrorKind::SemanticError(
+ "Bad swizzle size",
+ ));
+ }
+ },
+ },
+ }),
+ components,
+ }))
+ }
+ } else {
+ Err(ErrorKind::SemanticError("Invalid swizzle for vector"))
+ }
+ }
+ _ => Err(ErrorKind::SemanticError("Can't lookup field on this type")),
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/front/mod.rs b/third_party/rust/naga/src/front/mod.rs
new file mode 100644
index 0000000000..cb543c72a9
--- /dev/null
+++ b/third_party/rust/naga/src/front/mod.rs
@@ -0,0 +1,32 @@
+//! Parsers which load shaders into memory.
+
+#[cfg(feature = "glsl-in")]
+pub mod glsl;
+#[cfg(feature = "spv-in")]
+pub mod spv;
+#[cfg(feature = "wgsl-in")]
+pub mod wgsl;
+
+use crate::arena::Arena;
+
+pub const GENERATOR: u32 = 0;
+
+impl crate::Module {
+ pub fn from_header(header: crate::Header) -> Self {
+ crate::Module {
+ header,
+ types: Arena::new(),
+ constants: Arena::new(),
+ global_variables: Arena::new(),
+ functions: Arena::new(),
+ entry_points: crate::FastHashMap::default(),
+ }
+ }
+
+ pub fn generate_empty() -> Self {
+ Self::from_header(crate::Header {
+ version: (1, 0, 0),
+ generator: GENERATOR,
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/convert.rs b/third_party/rust/naga/src/front/spv/convert.rs
new file mode 100644
index 0000000000..6036d08946
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/convert.rs
@@ -0,0 +1,123 @@
+use super::error::Error;
+use num_traits::cast::FromPrimitive;
+use std::convert::TryInto;
+
+pub fn map_binary_operator(word: spirv::Op) -> Result<crate::BinaryOperator, Error> {
+ use crate::BinaryOperator;
+ use spirv::Op;
+
+ match word {
+ // Arithmetic Instructions +, -, *, /, %
+ Op::IAdd | Op::FAdd => Ok(BinaryOperator::Add),
+ Op::ISub | Op::FSub => Ok(BinaryOperator::Subtract),
+ Op::IMul | Op::FMul => Ok(BinaryOperator::Multiply),
+ Op::UDiv | Op::SDiv | Op::FDiv => Ok(BinaryOperator::Divide),
+ Op::UMod | Op::SMod | Op::FMod => Ok(BinaryOperator::Modulo),
+ // Relational and Logical Instructions
+ Op::IEqual | Op::FOrdEqual | Op::FUnordEqual => Ok(BinaryOperator::Equal),
+ Op::INotEqual | Op::FOrdNotEqual | Op::FUnordNotEqual => Ok(BinaryOperator::NotEqual),
+ Op::ULessThan | Op::SLessThan | Op::FOrdLessThan | Op::FUnordLessThan => {
+ Ok(BinaryOperator::Less)
+ }
+ Op::ULessThanEqual
+ | Op::SLessThanEqual
+ | Op::FOrdLessThanEqual
+ | Op::FUnordLessThanEqual => Ok(BinaryOperator::LessEqual),
+ Op::UGreaterThan | Op::SGreaterThan | Op::FOrdGreaterThan | Op::FUnordGreaterThan => {
+ Ok(BinaryOperator::Greater)
+ }
+ Op::UGreaterThanEqual
+ | Op::SGreaterThanEqual
+ | Op::FOrdGreaterThanEqual
+ | Op::FUnordGreaterThanEqual => Ok(BinaryOperator::GreaterEqual),
+ _ => Err(Error::UnknownInstruction(word as u16)),
+ }
+}
+
+pub fn map_vector_size(word: spirv::Word) -> Result<crate::VectorSize, Error> {
+ match word {
+ 2 => Ok(crate::VectorSize::Bi),
+ 3 => Ok(crate::VectorSize::Tri),
+ 4 => Ok(crate::VectorSize::Quad),
+ _ => Err(Error::InvalidVectorSize(word)),
+ }
+}
+
+pub fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> {
+ use spirv::Dim as D;
+ match D::from_u32(word) {
+ Some(D::Dim1D) => Ok(crate::ImageDimension::D1),
+ Some(D::Dim2D) => Ok(crate::ImageDimension::D2),
+ Some(D::Dim3D) => Ok(crate::ImageDimension::D3),
+ Some(D::DimCube) => Ok(crate::ImageDimension::Cube),
+ _ => Err(Error::UnsupportedImageDim(word)),
+ }
+}
+
+pub fn map_image_format(word: spirv::Word) -> Result<crate::StorageFormat, Error> {
+ match spirv::ImageFormat::from_u32(word) {
+ Some(spirv::ImageFormat::R8) => Ok(crate::StorageFormat::R8Unorm),
+ Some(spirv::ImageFormat::R8Snorm) => Ok(crate::StorageFormat::R8Snorm),
+ Some(spirv::ImageFormat::R8ui) => Ok(crate::StorageFormat::R8Uint),
+ Some(spirv::ImageFormat::R8i) => Ok(crate::StorageFormat::R8Sint),
+ Some(spirv::ImageFormat::R16ui) => Ok(crate::StorageFormat::R16Uint),
+ Some(spirv::ImageFormat::R16i) => Ok(crate::StorageFormat::R16Sint),
+ Some(spirv::ImageFormat::R16f) => Ok(crate::StorageFormat::R16Float),
+ Some(spirv::ImageFormat::Rg8) => Ok(crate::StorageFormat::Rg8Unorm),
+ Some(spirv::ImageFormat::Rg8Snorm) => Ok(crate::StorageFormat::Rg8Snorm),
+ Some(spirv::ImageFormat::Rg8ui) => Ok(crate::StorageFormat::Rg8Uint),
+ Some(spirv::ImageFormat::Rg8i) => Ok(crate::StorageFormat::Rg8Sint),
+ Some(spirv::ImageFormat::R32ui) => Ok(crate::StorageFormat::R32Uint),
+ Some(spirv::ImageFormat::R32i) => Ok(crate::StorageFormat::R32Sint),
+ Some(spirv::ImageFormat::R32f) => Ok(crate::StorageFormat::R32Float),
+ Some(spirv::ImageFormat::Rg16ui) => Ok(crate::StorageFormat::Rg16Uint),
+ Some(spirv::ImageFormat::Rg16i) => Ok(crate::StorageFormat::Rg16Sint),
+ Some(spirv::ImageFormat::Rg16f) => Ok(crate::StorageFormat::Rg16Float),
+ Some(spirv::ImageFormat::Rgba8) => Ok(crate::StorageFormat::Rgba8Unorm),
+ Some(spirv::ImageFormat::Rgba8Snorm) => Ok(crate::StorageFormat::Rgba8Snorm),
+ Some(spirv::ImageFormat::Rgba8ui) => Ok(crate::StorageFormat::Rgba8Uint),
+ Some(spirv::ImageFormat::Rgba8i) => Ok(crate::StorageFormat::Rgba8Sint),
+ Some(spirv::ImageFormat::Rgb10a2ui) => Ok(crate::StorageFormat::Rgb10a2Unorm),
+ Some(spirv::ImageFormat::R11fG11fB10f) => Ok(crate::StorageFormat::Rg11b10Float),
+ Some(spirv::ImageFormat::Rg32ui) => Ok(crate::StorageFormat::Rg32Uint),
+ Some(spirv::ImageFormat::Rg32i) => Ok(crate::StorageFormat::Rg32Sint),
+ Some(spirv::ImageFormat::Rg32f) => Ok(crate::StorageFormat::Rg32Float),
+ Some(spirv::ImageFormat::Rgba16ui) => Ok(crate::StorageFormat::Rgba16Uint),
+ Some(spirv::ImageFormat::Rgba16i) => Ok(crate::StorageFormat::Rgba16Sint),
+ Some(spirv::ImageFormat::Rgba16f) => Ok(crate::StorageFormat::Rgba16Float),
+ Some(spirv::ImageFormat::Rgba32ui) => Ok(crate::StorageFormat::Rgba32Uint),
+ Some(spirv::ImageFormat::Rgba32i) => Ok(crate::StorageFormat::Rgba32Sint),
+ Some(spirv::ImageFormat::Rgba32f) => Ok(crate::StorageFormat::Rgba32Float),
+ _ => Err(Error::UnsupportedImageFormat(word)),
+ }
+}
+
+pub fn map_width(word: spirv::Word) -> Result<crate::Bytes, Error> {
+ (word >> 3) // bits to bytes
+ .try_into()
+ .map_err(|_| Error::InvalidTypeWidth(word))
+}
+
+pub fn map_builtin(word: spirv::Word) -> Result<crate::BuiltIn, Error> {
+ use spirv::BuiltIn as Bi;
+ Ok(match spirv::BuiltIn::from_u32(word) {
+ Some(Bi::BaseInstance) => crate::BuiltIn::BaseInstance,
+ Some(Bi::BaseVertex) => crate::BuiltIn::BaseVertex,
+ Some(Bi::ClipDistance) => crate::BuiltIn::ClipDistance,
+ Some(Bi::InstanceIndex) => crate::BuiltIn::InstanceIndex,
+ Some(Bi::Position) => crate::BuiltIn::Position,
+ Some(Bi::VertexIndex) => crate::BuiltIn::VertexIndex,
+ // fragment
+ Some(Bi::PointSize) => crate::BuiltIn::PointSize,
+ Some(Bi::FragCoord) => crate::BuiltIn::FragCoord,
+ Some(Bi::FrontFacing) => crate::BuiltIn::FrontFacing,
+ Some(Bi::SampleId) => crate::BuiltIn::SampleIndex,
+ Some(Bi::FragDepth) => crate::BuiltIn::FragDepth,
+ // compute
+ Some(Bi::GlobalInvocationId) => crate::BuiltIn::GlobalInvocationId,
+ Some(Bi::LocalInvocationId) => crate::BuiltIn::LocalInvocationId,
+ Some(Bi::LocalInvocationIndex) => crate::BuiltIn::LocalInvocationIndex,
+ Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
+ _ => return Err(Error::UnsupportedBuiltIn(word)),
+ })
+}
diff --git a/third_party/rust/naga/src/front/spv/error.rs b/third_party/rust/naga/src/front/spv/error.rs
new file mode 100644
index 0000000000..0ebf603912
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/error.rs
@@ -0,0 +1,56 @@
+use super::ModuleState;
+use crate::arena::Handle;
+
+#[derive(Debug)]
+pub enum Error {
+ InvalidHeader,
+ InvalidWordCount,
+ UnknownInstruction(u16),
+ UnknownCapability(spirv::Word),
+ UnsupportedInstruction(ModuleState, spirv::Op),
+ UnsupportedCapability(spirv::Capability),
+ UnsupportedExtension(String),
+ UnsupportedExtSet(String),
+ UnsupportedExtInstSet(spirv::Word),
+ UnsupportedExtInst(spirv::Word),
+ UnsupportedType(Handle<crate::Type>),
+ UnsupportedExecutionModel(spirv::Word),
+ UnsupportedExecutionMode(spirv::Word),
+ UnsupportedStorageClass(spirv::Word),
+ UnsupportedImageDim(spirv::Word),
+ UnsupportedImageFormat(spirv::Word),
+ UnsupportedBuiltIn(spirv::Word),
+ UnsupportedControlFlow(spirv::Word),
+ UnsupportedBinaryOperator(spirv::Word),
+ InvalidParameter(spirv::Op),
+ InvalidOperandCount(spirv::Op, u16),
+ InvalidOperand,
+ InvalidId(spirv::Word),
+ InvalidDecoration(spirv::Word),
+ InvalidTypeWidth(spirv::Word),
+ InvalidSign(spirv::Word),
+ InvalidInnerType(spirv::Word),
+ InvalidVectorSize(spirv::Word),
+ InvalidVariableClass(spirv::StorageClass),
+ InvalidAccessType(spirv::Word),
+ InvalidAccess(Handle<crate::Expression>),
+ InvalidAccessIndex(spirv::Word),
+ InvalidBinding(spirv::Word),
+ InvalidImageExpression(Handle<crate::Expression>),
+ InvalidImageBaseType(Handle<crate::Type>),
+ InvalidSamplerExpression(Handle<crate::Expression>),
+ InvalidSampleImage(Handle<crate::Type>),
+ InvalidSampleSampler(Handle<crate::Type>),
+ InvalidSampleCoordinates(Handle<crate::Type>),
+ InvalidDepthReference(Handle<crate::Type>),
+ InvalidAsType(Handle<crate::Type>),
+ InconsistentComparisonSampling(Handle<crate::Type>),
+ WrongFunctionResultType(spirv::Word),
+ WrongFunctionArgumentType(spirv::Word),
+ MissingDecoration(spirv::Decoration),
+ BadString,
+ IncompleteData,
+ InvalidTerminator,
+ InvalidEdgeClassification,
+ UnexpectedComparisonType(Handle<crate::Type>),
+}
diff --git a/third_party/rust/naga/src/front/spv/flow.rs b/third_party/rust/naga/src/front/spv/flow.rs
new file mode 100644
index 0000000000..32eac66941
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/flow.rs
@@ -0,0 +1,569 @@
+#![allow(dead_code)]
+
+use super::error::Error;
+///! see https://en.wikipedia.org/wiki/Control-flow_graph
+///! see https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_structuredcontrolflow_a_structured_control_flow
+use super::{
+ function::{BlockId, MergeInstruction, Terminator},
+ LookupExpression, PhiInstruction,
+};
+
+use crate::FastHashMap;
+
+use petgraph::{
+ algo::has_path_connecting,
+ graph::{node_index, NodeIndex},
+ visit::EdgeRef,
+ Directed, Direction,
+};
+
+use std::fmt::Write;
+
+/// Index of a block node in the `ControlFlowGraph`.
+type BlockNodeIndex = NodeIndex<u32>;
+
+/// Internal representation of a CFG constisting of function's basic blocks.
+type ControlFlowGraph = petgraph::Graph<ControlFlowNode, ControlFlowEdgeType, Directed, u32>;
+
+/// Control flow graph (CFG) containing relationships between blocks.
+pub(super) struct FlowGraph {
+ ///
+ flow: ControlFlowGraph,
+
+ /// Block ID to Node index mapping. Internal helper to speed up the classification.
+ block_to_node: FastHashMap<BlockId, BlockNodeIndex>,
+}
+
+impl FlowGraph {
+ /// Creates empty flow graph.
+ pub(super) fn new() -> Self {
+ Self {
+ flow: ControlFlowGraph::default(),
+ block_to_node: FastHashMap::default(),
+ }
+ }
+
+ /// Add a control flow node.
+ pub(super) fn add_node(&mut self, node: ControlFlowNode) {
+ let block_id = node.id;
+ let node_index = self.flow.add_node(node);
+ self.block_to_node.insert(block_id, node_index);
+ }
+
+ ///
+ /// 1. Creates edges in the CFG.
+ /// 2. Classifies types of blocks and edges in the CFG.
+ pub(super) fn classify(&mut self) {
+ let block_to_node = &mut self.block_to_node;
+
+ // 1.
+ // Add all edges
+ // Classify Nodes as one of [Header, Loop, Kill, Return]
+ for source_node_index in self.flow.node_indices() {
+ // Merge edges
+ if let Some(merge) = self.flow[source_node_index].merge {
+ let merge_block_index = block_to_node[&merge.merge_block_id];
+
+ self.flow[source_node_index].ty = Some(ControlFlowNodeType::Header);
+ self.flow[merge_block_index].ty = Some(ControlFlowNodeType::Merge);
+ self.flow.add_edge(
+ source_node_index,
+ merge_block_index,
+ ControlFlowEdgeType::ForwardMerge,
+ );
+
+ if let Some(continue_block_id) = merge.continue_block_id {
+ let continue_block_index = block_to_node[&continue_block_id];
+
+ self.flow[source_node_index].ty = Some(ControlFlowNodeType::Loop);
+ self.flow.add_edge(
+ source_node_index,
+ continue_block_index,
+ ControlFlowEdgeType::ForwardContinue,
+ );
+ }
+ }
+
+ // Branch Edges
+ let terminator = self.flow[source_node_index].terminator.clone();
+ match terminator {
+ Terminator::Branch { target_id } => {
+ let target_node_index = block_to_node[&target_id];
+
+ self.flow.add_edge(
+ source_node_index,
+ target_node_index,
+ ControlFlowEdgeType::Forward,
+ );
+ }
+ Terminator::BranchConditional {
+ true_id, false_id, ..
+ } => {
+ let true_node_index = block_to_node[&true_id];
+ let false_node_index = block_to_node[&false_id];
+
+ self.flow.add_edge(
+ source_node_index,
+ true_node_index,
+ ControlFlowEdgeType::IfTrue,
+ );
+ self.flow.add_edge(
+ source_node_index,
+ false_node_index,
+ ControlFlowEdgeType::IfFalse,
+ );
+ }
+ Terminator::Switch {
+ selector: _,
+ default,
+ ref targets,
+ } => {
+ let default_node_index = block_to_node[&default];
+
+ self.flow.add_edge(
+ source_node_index,
+ default_node_index,
+ ControlFlowEdgeType::Forward,
+ );
+
+ for (_, target_block_id) in targets.iter() {
+ let target_node_index = block_to_node[&target_block_id];
+
+ self.flow.add_edge(
+ source_node_index,
+ target_node_index,
+ ControlFlowEdgeType::Forward,
+ );
+ }
+ }
+ Terminator::Return { .. } => {
+ self.flow[source_node_index].ty = Some(ControlFlowNodeType::Return)
+ }
+ Terminator::Kill => {
+ self.flow[source_node_index].ty = Some(ControlFlowNodeType::Kill)
+ }
+ _ => {}
+ };
+ }
+
+ // 2.
+ // Classify Nodes/Edges as one of [Break, Continue, Back]
+ for edge_index in self.flow.edge_indices() {
+ let (node_source_index, node_target_index) =
+ self.flow.edge_endpoints(edge_index).unwrap();
+
+ if self.flow[node_source_index].ty == Some(ControlFlowNodeType::Header)
+ || self.flow[node_source_index].ty == Some(ControlFlowNodeType::Loop)
+ {
+ continue;
+ }
+
+ // Back
+ if self.flow[node_target_index].ty == Some(ControlFlowNodeType::Loop)
+ && self.flow[node_source_index].id > self.flow[node_target_index].id
+ {
+ self.flow[node_source_index].ty = Some(ControlFlowNodeType::Back);
+ self.flow[edge_index] = ControlFlowEdgeType::Back;
+ }
+
+ let mut target_incoming_edges = self
+ .flow
+ .neighbors_directed(node_target_index, Direction::Incoming)
+ .detach();
+ while let Some((incoming_edge, incoming_source)) =
+ target_incoming_edges.next(&self.flow)
+ {
+ // Loop continue
+ if self.flow[incoming_edge] == ControlFlowEdgeType::ForwardContinue {
+ self.flow[node_source_index].ty = Some(ControlFlowNodeType::Continue);
+ self.flow[edge_index] = ControlFlowEdgeType::LoopContinue;
+ }
+ // Loop break
+ if self.flow[incoming_source].ty == Some(ControlFlowNodeType::Loop)
+ && self.flow[incoming_edge] == ControlFlowEdgeType::ForwardMerge
+ {
+ self.flow[node_source_index].ty = Some(ControlFlowNodeType::Break);
+ self.flow[edge_index] = ControlFlowEdgeType::LoopBreak;
+ }
+ }
+ }
+ }
+
+ /// Removes OpPhi instructions from the control flow graph and turns them into ordinary variables.
+ ///
+ /// Phi instructions are not supported inside Naga nor do they exist as instructions on CPUs. It is neccessary
+ /// to remove them and turn into ordinary variables before converting to Naga's IR and shader code.
+ pub(super) fn remove_phi_instructions(
+ &mut self,
+ lookup_expression: &FastHashMap<spirv::Word, LookupExpression>,
+ ) {
+ for node_index in self.flow.node_indices() {
+ let phis = std::mem::replace(&mut self.flow[node_index].phis, Vec::new());
+ for phi in phis.iter() {
+ let phi_var = &lookup_expression[&phi.id];
+ for (variable_id, parent_id) in phi.variables.iter() {
+ let variable = &lookup_expression[&variable_id];
+ let parent_node = &mut self.flow[self.block_to_node[&parent_id]];
+
+ parent_node.block.push(crate::Statement::Store {
+ pointer: phi_var.handle,
+ value: variable.handle,
+ });
+ }
+ }
+ self.flow[node_index].phis = phis;
+ }
+ }
+
+ /// Traverses the flow graph and returns a list of Naga's statements.
+ pub(super) fn to_naga(&self) -> Result<crate::Block, Error> {
+ self.naga_traverse(node_index(0), None)
+ }
+
+ fn naga_traverse(
+ &self,
+ node_index: BlockNodeIndex,
+ stop_node_index: Option<BlockNodeIndex>,
+ ) -> Result<crate::Block, Error> {
+ if stop_node_index == Some(node_index) {
+ return Ok(vec![]);
+ }
+
+ let node = &self.flow[node_index];
+
+ match node.ty {
+ Some(ControlFlowNodeType::Header) => match node.terminator {
+ Terminator::BranchConditional {
+ condition,
+ true_id,
+ false_id,
+ } => {
+ let true_node_index = self.block_to_node[&true_id];
+ let false_node_index = self.block_to_node[&false_id];
+ let merge_node_index = self.block_to_node[&node.merge.unwrap().merge_block_id];
+
+ let mut result = node.block.clone();
+
+ if false_node_index != merge_node_index {
+ result.push(crate::Statement::If {
+ condition,
+ accept: self.naga_traverse(true_node_index, Some(merge_node_index))?,
+ reject: self.naga_traverse(false_node_index, Some(merge_node_index))?,
+ });
+ result.extend(self.naga_traverse(merge_node_index, stop_node_index)?);
+ } else {
+ result.push(crate::Statement::If {
+ condition,
+ accept: self.naga_traverse(
+ self.block_to_node[&true_id],
+ Some(merge_node_index),
+ )?,
+ reject: self.naga_traverse(merge_node_index, stop_node_index)?,
+ });
+ }
+
+ Ok(result)
+ }
+ Terminator::Switch {
+ selector,
+ default,
+ ref targets,
+ } => {
+ let merge_node_index = self.block_to_node[&node.merge.unwrap().merge_block_id];
+ let mut result = node.block.clone();
+
+ let mut cases = FastHashMap::default();
+
+ for i in 0..targets.len() {
+ let left_target_node_index = self.block_to_node[&targets[i].1];
+
+ let fallthrough: Option<crate::FallThrough> = if i < targets.len() - 1 {
+ let right_target_node_index = self.block_to_node[&targets[i + 1].1];
+ if has_path_connecting(
+ &self.flow,
+ left_target_node_index,
+ right_target_node_index,
+ None,
+ ) {
+ Some(crate::FallThrough {})
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ cases.insert(
+ targets[i].0,
+ (
+ self.naga_traverse(left_target_node_index, Some(merge_node_index))?,
+ fallthrough,
+ ),
+ );
+ }
+
+ result.push(crate::Statement::Switch {
+ selector,
+ cases,
+ default: self
+ .naga_traverse(self.block_to_node[&default], Some(merge_node_index))?,
+ });
+
+ result.extend(self.naga_traverse(merge_node_index, stop_node_index)?);
+
+ Ok(result)
+ }
+ _ => Err(Error::InvalidTerminator),
+ },
+ Some(ControlFlowNodeType::Loop) => {
+ let merge_node_index = self.block_to_node[&node.merge.unwrap().merge_block_id];
+ let continuing: crate::Block = {
+ let continue_edge = self
+ .flow
+ .edges_directed(node_index, Direction::Outgoing)
+ .find(|&ty| *ty.weight() == ControlFlowEdgeType::ForwardContinue)
+ .unwrap();
+
+ self.flow[continue_edge.target()].block.clone()
+ };
+
+ let mut body = node.block.clone();
+ match node.terminator {
+ Terminator::BranchConditional {
+ condition,
+ true_id,
+ false_id,
+ } => body.push(crate::Statement::If {
+ condition,
+ accept: self
+ .naga_traverse(self.block_to_node[&true_id], Some(merge_node_index))?,
+ reject: self
+ .naga_traverse(self.block_to_node[&false_id], Some(merge_node_index))?,
+ }),
+ Terminator::Branch { target_id } => body.extend(
+ self.naga_traverse(self.block_to_node[&target_id], Some(merge_node_index))?,
+ ),
+ _ => return Err(Error::InvalidTerminator),
+ };
+
+ let mut result = vec![crate::Statement::Loop { body, continuing }];
+ result.extend(self.naga_traverse(merge_node_index, stop_node_index)?);
+
+ Ok(result)
+ }
+ Some(ControlFlowNodeType::Break) => {
+ let mut result = node.block.clone();
+ match node.terminator {
+ Terminator::BranchConditional {
+ condition,
+ true_id,
+ false_id,
+ } => {
+ let true_node_id = self.block_to_node[&true_id];
+ let false_node_id = self.block_to_node[&false_id];
+
+ let true_edge =
+ self.flow[self.flow.find_edge(node_index, true_node_id).unwrap()];
+ let false_edge =
+ self.flow[self.flow.find_edge(node_index, false_node_id).unwrap()];
+
+ if true_edge == ControlFlowEdgeType::LoopBreak {
+ result.push(crate::Statement::If {
+ condition,
+ accept: vec![crate::Statement::Break],
+ reject: self.naga_traverse(false_node_id, stop_node_index)?,
+ });
+ } else if false_edge == ControlFlowEdgeType::LoopBreak {
+ result.push(crate::Statement::If {
+ condition,
+ accept: self.naga_traverse(true_node_id, stop_node_index)?,
+ reject: vec![crate::Statement::Break],
+ });
+ } else {
+ return Err(Error::InvalidEdgeClassification);
+ }
+ }
+ Terminator::Branch { .. } => {
+ result.push(crate::Statement::Break);
+ }
+ _ => return Err(Error::InvalidTerminator),
+ };
+ Ok(result)
+ }
+ Some(ControlFlowNodeType::Continue) => {
+ let back_block = match node.terminator {
+ Terminator::Branch { target_id } => {
+ self.naga_traverse(self.block_to_node[&target_id], None)?
+ }
+ _ => return Err(Error::InvalidTerminator),
+ };
+
+ let mut result = node.block.clone();
+ result.extend(back_block);
+ result.push(crate::Statement::Continue);
+ Ok(result)
+ }
+ Some(ControlFlowNodeType::Back) => Ok(node.block.clone()),
+ Some(ControlFlowNodeType::Kill) => {
+ let mut result = node.block.clone();
+ result.push(crate::Statement::Kill);
+ Ok(result)
+ }
+ Some(ControlFlowNodeType::Return) => {
+ let value = match node.terminator {
+ Terminator::Return { value } => value,
+ _ => return Err(Error::InvalidTerminator),
+ };
+ let mut result = node.block.clone();
+ result.push(crate::Statement::Return { value });
+ Ok(result)
+ }
+ Some(ControlFlowNodeType::Merge) | None => match node.terminator {
+ Terminator::Branch { target_id } => {
+ let mut result = node.block.clone();
+ result.extend(
+ self.naga_traverse(self.block_to_node[&target_id], stop_node_index)?,
+ );
+ Ok(result)
+ }
+ _ => Ok(node.block.clone()),
+ },
+ }
+ }
+
+ /// Get the entire graph in a graphviz dot format for visualization. Useful for debugging purposes.
+ pub(super) fn to_graphviz(&self) -> Result<String, std::fmt::Error> {
+ let mut output = String::new();
+
+ output += "digraph ControlFlowGraph {\n";
+
+ for node_index in self.flow.node_indices() {
+ let node = &self.flow[node_index];
+ writeln!(
+ output,
+ "{} [ label = \"%{} {:?}\" ]",
+ node_index.index(),
+ node.id,
+ node.ty
+ )?;
+ }
+
+ for edge in self.flow.raw_edges() {
+ let source = edge.source();
+ let target = edge.target();
+
+ let style = match edge.weight {
+ ControlFlowEdgeType::Forward => "",
+ ControlFlowEdgeType::ForwardMerge => "style=dotted",
+ ControlFlowEdgeType::ForwardContinue => "color=green",
+ ControlFlowEdgeType::Back => "style=dashed",
+ ControlFlowEdgeType::LoopBreak => "color=yellow",
+ ControlFlowEdgeType::LoopContinue => "color=green",
+ ControlFlowEdgeType::IfTrue => "color=blue",
+ ControlFlowEdgeType::IfFalse => "color=red",
+ ControlFlowEdgeType::SwitchBreak => "color=yellow",
+ ControlFlowEdgeType::CaseFallThrough => "style=dotted",
+ };
+
+ writeln!(
+ &mut output,
+ "{} -> {} [ {} ]",
+ source.index(),
+ target.index(),
+ style
+ )?;
+ }
+
+ output += "}\n";
+
+ Ok(output)
+ }
+}
+
+/// Type of an edge(flow) in the `ControlFlowGraph`.
+#[derive(Copy, Clone, Eq, PartialEq, Debug)]
+pub(super) enum ControlFlowEdgeType {
+ /// Default
+ Forward,
+
+ /// Forward edge to a merge block.
+ ForwardMerge,
+
+ /// Forward edge to a OpLoopMerge continue's instruction.
+ ForwardContinue,
+
+ /// A back-edge: An edge from a node to one of its ancestors in a depth-first
+ /// search from the entry block.
+ /// Can only be to a ControlFlowNodeType::Loop.
+ Back,
+
+ /// An edge from a node to the merge block of the nearest enclosing loop, where
+ /// there is no intervening switch.
+ /// The source block is a "break block" as defined by SPIR-V.
+ LoopBreak,
+
+ /// An edge from a node in a loop body to the associated continue target, where
+ /// there are no other intervening loops or switches.
+ /// The source block is a "continue block" as defined by SPIR-V.
+ LoopContinue,
+
+ /// An edge from a node with OpBranchConditional to the block of true operand.
+ IfTrue,
+
+ /// An edge from a node with OpBranchConditional to the block of false operand.
+ IfFalse,
+
+ /// An edge from a node to the merge block of the nearest enclosing switch,
+ /// where there is no intervening loop.
+ SwitchBreak,
+
+ /// An edge from one switch case to the next sibling switch case.
+ CaseFallThrough,
+}
+/// Type of a node(block) in the `ControlFlowGraph`.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub(super) enum ControlFlowNodeType {
+ /// A block whose merge instruction is an OpSelectionMerge.
+ Header,
+
+ /// A header block whose merge instruction is an OpLoopMerge.
+ Loop,
+
+ /// A block declared by the Merge Block operand of a merge instruction.
+ Merge,
+
+ /// A block containing a branch to the Merge Block of a loop header’s merge instruction.
+ Break,
+
+ /// A block containing a branch to an OpLoopMerge instruction’s Continue Target.
+ Continue,
+
+ /// A block containing an OpBranch to a Loop block.
+ Back,
+
+ /// A block containing an OpKill instruction.
+ Kill,
+
+ /// A block containing an OpReturn or OpReturnValue branch.
+ Return,
+}
+/// ControlFlowGraph's node representing a block in the control flow.
+pub(super) struct ControlFlowNode {
+ /// SPIR-V ID.
+ pub id: BlockId,
+
+ /// Type of the node. See *ControlFlowNodeType*.
+ pub ty: Option<ControlFlowNodeType>,
+
+ /// Phi instructions.
+ pub phis: Vec<PhiInstruction>,
+
+ /// Naga's statements inside this block.
+ pub block: crate::Block,
+
+ /// Termination instruction of the block.
+ pub terminator: Terminator,
+
+ /// Merge Instruction
+ pub merge: Option<MergeInstruction>,
+}
diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs
new file mode 100644
index 0000000000..d2cb0551a1
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/function.rs
@@ -0,0 +1,202 @@
+use crate::arena::Handle;
+
+use super::flow::*;
+use super::*;
+
+pub type BlockId = u32;
+
+#[derive(Copy, Clone, Debug)]
+pub struct MergeInstruction {
+ pub merge_block_id: BlockId,
+ pub continue_block_id: Option<BlockId>,
+}
+/// Terminator instruction of a SPIR-V's block.
+#[derive(Clone, Debug)]
+#[allow(dead_code)]
+pub enum Terminator {
+ ///
+ Return {
+ value: Option<Handle<crate::Expression>>,
+ },
+ ///
+ Branch { target_id: BlockId },
+ ///
+ BranchConditional {
+ condition: Handle<crate::Expression>,
+ true_id: BlockId,
+ false_id: BlockId,
+ },
+ ///
+ /// switch(SELECTOR) {
+ /// case TARGET_LITERAL#: {
+ /// TARGET_BLOCK#
+ /// }
+ /// default: {
+ /// DEFAULT
+ /// }
+ /// }
+ Switch {
+ ///
+ selector: Handle<crate::Expression>,
+ /// Default block of the switch case.
+ default: BlockId,
+ /// Tuples of (literal, target block)
+ targets: Vec<(i32, BlockId)>,
+ },
+ /// Fragment shader discard
+ Kill,
+ ///
+ Unreachable,
+}
+
+impl<I: Iterator<Item = u32>> super::Parser<I> {
+ pub fn parse_function(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Function, inst.op)?;
+ inst.expect(5)?;
+ let result_type = self.next()?;
+ let fun_id = self.next()?;
+ let _fun_control = self.next()?;
+ let fun_type = self.next()?;
+
+ let mut fun = {
+ let ft = self.lookup_function_type.lookup(fun_type)?;
+ if ft.return_type_id != result_type {
+ return Err(Error::WrongFunctionResultType(result_type));
+ }
+ crate::Function {
+ name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name),
+ arguments: Vec::with_capacity(ft.parameter_type_ids.len()),
+ return_type: if self.lookup_void_type.contains(&result_type) {
+ None
+ } else {
+ Some(self.lookup_type.lookup(result_type)?.handle)
+ },
+ global_usage: Vec::new(),
+ local_variables: Arena::new(),
+ expressions: self.make_expression_storage(),
+ body: Vec::new(),
+ }
+ };
+
+ // read parameters
+ for i in 0..fun.arguments.capacity() {
+ match self.next_inst()? {
+ Instruction {
+ op: spirv::Op::FunctionParameter,
+ wc: 3,
+ } => {
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let handle = fun
+ .expressions
+ .append(crate::Expression::FunctionArgument(i as u32));
+ self.lookup_expression
+ .insert(id, LookupExpression { type_id, handle });
+ //Note: we redo the lookup in order to work around `self` borrowing
+
+ if type_id
+ != self
+ .lookup_function_type
+ .lookup(fun_type)?
+ .parameter_type_ids[i]
+ {
+ return Err(Error::WrongFunctionArgumentType(type_id));
+ }
+ let ty = self.lookup_type.lookup(type_id)?.handle;
+ fun.arguments
+ .push(crate::FunctionArgument { name: None, ty });
+ }
+ Instruction { op, .. } => return Err(Error::InvalidParameter(op)),
+ }
+ }
+
+ // Read body
+ let mut local_function_calls = FastHashMap::default();
+ let mut flow_graph = FlowGraph::new();
+
+ // Scan the blocks and add them as nodes
+ loop {
+ let fun_inst = self.next_inst()?;
+ log::debug!("{:?}", fun_inst.op);
+ match fun_inst.op {
+ spirv::Op::Label => {
+ // Read the label ID
+ fun_inst.expect(2)?;
+ let block_id = self.next()?;
+
+ let node = self.next_block(
+ block_id,
+ &mut fun.expressions,
+ &mut fun.local_variables,
+ &module.types,
+ &module.constants,
+ &module.global_variables,
+ &mut local_function_calls,
+ )?;
+
+ flow_graph.add_node(node);
+ }
+ spirv::Op::FunctionEnd => {
+ fun_inst.expect(1)?;
+ break;
+ }
+ _ => {
+ return Err(Error::UnsupportedInstruction(self.state, fun_inst.op));
+ }
+ }
+ }
+
+ flow_graph.classify();
+ flow_graph.remove_phi_instructions(&self.lookup_expression);
+ fun.body = flow_graph.to_naga()?;
+
+ // done
+ fun.fill_global_use(&module.global_variables);
+
+ let source = match self.lookup_entry_point.remove(&fun_id) {
+ Some(ep) => {
+ module.entry_points.insert(
+ (ep.stage, ep.name.clone()),
+ crate::EntryPoint {
+ early_depth_test: ep.early_depth_test,
+ workgroup_size: ep.workgroup_size,
+ function: fun,
+ },
+ );
+ DeferredSource::EntryPoint(ep.stage, ep.name)
+ }
+ None => {
+ let handle = module.functions.append(fun);
+ self.lookup_function.insert(fun_id, handle);
+ DeferredSource::Function(handle)
+ }
+ };
+
+ if let Some(ref prefix) = self.options.flow_graph_dump_prefix {
+ let dump = flow_graph.to_graphviz().unwrap_or_default();
+ let suffix = match source {
+ DeferredSource::EntryPoint(stage, ref name) => {
+ format!("flow.{:?}-{}.dot", stage, name)
+ }
+ DeferredSource::Function(handle) => format!("flow.Fun-{}.dot", handle.index()),
+ };
+ let _ = std::fs::write(prefix.join(suffix), dump);
+ }
+
+ for (expr_handle, dst_id) in local_function_calls {
+ self.deferred_function_calls.push(DeferredFunctionCall {
+ source: source.clone(),
+ expr_handle,
+ dst_id,
+ });
+ }
+
+ self.lookup_expression.clear();
+ self.lookup_sampled_image.clear();
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs
new file mode 100644
index 0000000000..f13c3c7577
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/mod.rs
@@ -0,0 +1,2416 @@
+/*! SPIR-V frontend
+
+## ID lookups
+
+Our IR links to everything with `Handle`, while SPIR-V uses IDs.
+In order to keep track of the associations, the parser has many lookup tables.
+There map `spv::Word` into a specific IR handle, plus potentially a bit of
+extra info, such as the related SPIR-V type ID.
+TODO: would be nice to find ways that avoid looking up as much
+
+!*/
+#![allow(dead_code)]
+
+mod convert;
+mod error;
+mod flow;
+mod function;
+#[cfg(all(test, feature = "serialize"))]
+mod rosetta;
+
+use convert::*;
+use error::Error;
+use flow::*;
+use function::*;
+
+use crate::{
+ arena::{Arena, Handle},
+ FastHashMap, FastHashSet,
+};
+
+use num_traits::cast::FromPrimitive;
+use std::{convert::TryInto, num::NonZeroU32, path::PathBuf};
+
+pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[
+ spirv::Capability::Shader,
+ spirv::Capability::CullDistance,
+ spirv::Capability::StorageImageExtendedFormats,
+];
+pub const SUPPORTED_EXTENSIONS: &[&str] = &[];
+pub const SUPPORTED_EXT_SETS: &[&str] = &["GLSL.std.450"];
+
+#[derive(Copy, Clone)]
+pub struct Instruction {
+ op: spirv::Op,
+ wc: u16,
+}
+
+impl Instruction {
+ fn expect(self, count: u16) -> Result<(), Error> {
+ if self.wc == count {
+ Ok(())
+ } else {
+ Err(Error::InvalidOperandCount(self.op, self.wc))
+ }
+ }
+
+ fn expect_at_least(self, count: u16) -> Result<(), Error> {
+ if self.wc >= count {
+ Ok(())
+ } else {
+ Err(Error::InvalidOperandCount(self.op, self.wc))
+ }
+ }
+}
+/// OpPhi instruction.
+#[derive(Clone, Default, Debug)]
+struct PhiInstruction {
+ /// SPIR-V's ID.
+ id: u32,
+
+ /// Tuples of (variable, parent).
+ variables: Vec<(u32, u32)>,
+}
+#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
+pub enum ModuleState {
+ Empty,
+ Capability,
+ Extension,
+ ExtInstImport,
+ MemoryModel,
+ EntryPoint,
+ ExecutionMode,
+ Source,
+ Name,
+ ModuleProcessed,
+ Annotation,
+ Type,
+ Function,
+}
+
+trait LookupHelper {
+ type Target;
+ fn lookup(&self, key: spirv::Word) -> Result<&Self::Target, Error>;
+}
+
+impl<T> LookupHelper for FastHashMap<spirv::Word, T> {
+ type Target = T;
+ fn lookup(&self, key: spirv::Word) -> Result<&T, Error> {
+ self.get(&key).ok_or(Error::InvalidId(key))
+ }
+}
+
+//TODO: this method may need to be gone, depending on whether
+// WGSL allows treating images and samplers as expressions and pass them around.
+fn reach_global_type(
+ mut expr_handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ globals: &Arena<crate::GlobalVariable>,
+) -> Option<Handle<crate::Type>> {
+ loop {
+ expr_handle = match expressions[expr_handle] {
+ crate::Expression::Load { pointer } => pointer,
+ crate::Expression::GlobalVariable(var) => return Some(globals[var].ty),
+ _ => return None,
+ };
+ }
+}
+
+fn check_sample_coordinates(
+ ty: &crate::Type,
+ expect_kind: crate::ScalarKind,
+ dim: crate::ImageDimension,
+ is_array: bool,
+) -> bool {
+ let base_count = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
+ };
+ let extra_count = if is_array { 1 } else { 0 };
+ let count = base_count + extra_count;
+ match ty.inner {
+ crate::TypeInner::Scalar { kind, width: _ } => count == 1 && kind == expect_kind,
+ crate::TypeInner::Vector {
+ size,
+ kind,
+ width: _,
+ } => size as u8 == count && kind == expect_kind,
+ _ => false,
+ }
+}
+
+type MemberIndex = u32;
+
+#[derive(Debug, Default)]
+struct Block {
+ buffer: bool,
+}
+
+bitflags::bitflags! {
+ #[derive(Default)]
+ struct DecorationFlags: u32 {
+ const NON_READABLE = 0x1;
+ const NON_WRITABLE = 0x2;
+ }
+}
+
+#[derive(Debug, Default)]
+struct Decoration {
+ name: Option<String>,
+ built_in: Option<crate::BuiltIn>,
+ location: Option<spirv::Word>,
+ desc_set: Option<spirv::Word>,
+ desc_index: Option<spirv::Word>,
+ block: Option<Block>,
+ offset: Option<spirv::Word>,
+ array_stride: Option<NonZeroU32>,
+ interpolation: Option<crate::Interpolation>,
+ flags: DecorationFlags,
+}
+
+impl Decoration {
+ fn debug_name(&self) -> &str {
+ match self.name {
+ Some(ref name) => name.as_str(),
+ None => "?",
+ }
+ }
+
+ fn get_binding(&self) -> Option<crate::Binding> {
+ //TODO: validate this better
+ match *self {
+ Decoration {
+ built_in: Some(built_in),
+ location: None,
+ desc_set: None,
+ desc_index: None,
+ ..
+ } => Some(crate::Binding::BuiltIn(built_in)),
+ Decoration {
+ built_in: None,
+ location: Some(loc),
+ desc_set: None,
+ desc_index: None,
+ ..
+ } => Some(crate::Binding::Location(loc)),
+ Decoration {
+ built_in: None,
+ location: None,
+ desc_set: Some(group),
+ desc_index: Some(binding),
+ ..
+ } => Some(crate::Binding::Resource { group, binding }),
+ _ => None,
+ }
+ }
+
+ fn get_origin(&self) -> Result<crate::MemberOrigin, Error> {
+ match *self {
+ Decoration {
+ location: Some(_), ..
+ }
+ | Decoration {
+ desc_set: Some(_), ..
+ }
+ | Decoration {
+ desc_index: Some(_),
+ ..
+ } => Err(Error::MissingDecoration(spirv::Decoration::Offset)),
+ Decoration {
+ built_in: Some(built_in),
+ offset: None,
+ ..
+ } => Ok(crate::MemberOrigin::BuiltIn(built_in)),
+ Decoration {
+ built_in: None,
+ offset: Some(offset),
+ ..
+ } => Ok(crate::MemberOrigin::Offset(offset)),
+ _ => Ok(crate::MemberOrigin::Empty),
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Flags describing sampling method.
+ pub struct SamplingFlags: u32 {
+ /// Regular sampling.
+ const REGULAR = 0x1;
+ /// Comparison sampling.
+ const COMPARISON = 0x2;
+ }
+}
+
+#[derive(Debug)]
+struct LookupFunctionType {
+ parameter_type_ids: Vec<spirv::Word>,
+ return_type_id: spirv::Word,
+}
+
+#[derive(Debug)]
+struct EntryPoint {
+ stage: crate::ShaderStage,
+ name: String,
+ early_depth_test: Option<crate::EarlyDepthTest>,
+ workgroup_size: [u32; 3],
+ function_id: spirv::Word,
+ variable_ids: Vec<spirv::Word>,
+}
+
+#[derive(Clone, Debug)]
+struct LookupType {
+ handle: Handle<crate::Type>,
+ base_id: Option<spirv::Word>,
+}
+
+#[derive(Debug)]
+struct LookupConstant {
+ handle: Handle<crate::Constant>,
+ type_id: spirv::Word,
+}
+
+#[derive(Debug)]
+struct LookupVariable {
+ handle: Handle<crate::GlobalVariable>,
+ type_id: spirv::Word,
+}
+
+#[derive(Clone, Debug)]
+struct LookupExpression {
+ handle: Handle<crate::Expression>,
+ type_id: spirv::Word,
+}
+
+#[derive(Clone, Debug)]
+struct LookupSampledImage {
+ image: Handle<crate::Expression>,
+ sampler: Handle<crate::Expression>,
+}
+#[derive(Clone, Debug)]
+enum DeferredSource {
+ EntryPoint(crate::ShaderStage, String),
+ Function(Handle<crate::Function>),
+}
+struct DeferredFunctionCall {
+ source: DeferredSource,
+ expr_handle: Handle<crate::Expression>,
+ dst_id: spirv::Word,
+}
+
+#[derive(Clone, Debug)]
+pub struct Assignment {
+ to: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct Options {
+ pub flow_graph_dump_prefix: Option<PathBuf>,
+}
+
+pub struct Parser<I> {
+ data: I,
+ state: ModuleState,
+ temp_bytes: Vec<u8>,
+ ext_glsl_id: Option<spirv::Word>,
+ future_decor: FastHashMap<spirv::Word, Decoration>,
+ future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>,
+ lookup_member_type_id: FastHashMap<(Handle<crate::Type>, MemberIndex), spirv::Word>,
+ handle_sampling: FastHashMap<Handle<crate::Type>, SamplingFlags>,
+ lookup_type: FastHashMap<spirv::Word, LookupType>,
+ lookup_void_type: FastHashSet<spirv::Word>,
+ lookup_storage_buffer_types: FastHashSet<Handle<crate::Type>>,
+ // Lookup for samplers and sampled images, storing flags on how they are used.
+ lookup_constant: FastHashMap<spirv::Word, LookupConstant>,
+ lookup_variable: FastHashMap<spirv::Word, LookupVariable>,
+ lookup_expression: FastHashMap<spirv::Word, LookupExpression>,
+ lookup_sampled_image: FastHashMap<spirv::Word, LookupSampledImage>,
+ lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>,
+ lookup_function: FastHashMap<spirv::Word, Handle<crate::Function>>,
+ lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>,
+ deferred_function_calls: Vec<DeferredFunctionCall>,
+ options: Options,
+}
+
+impl<I: Iterator<Item = u32>> Parser<I> {
+ pub fn new(data: I, options: &Options) -> Self {
+ Parser {
+ data,
+ state: ModuleState::Empty,
+ temp_bytes: Vec::new(),
+ ext_glsl_id: None,
+ future_decor: FastHashMap::default(),
+ future_member_decor: FastHashMap::default(),
+ handle_sampling: FastHashMap::default(),
+ lookup_member_type_id: FastHashMap::default(),
+ lookup_type: FastHashMap::default(),
+ lookup_void_type: FastHashSet::default(),
+ lookup_storage_buffer_types: FastHashSet::default(),
+ lookup_constant: FastHashMap::default(),
+ lookup_variable: FastHashMap::default(),
+ lookup_expression: FastHashMap::default(),
+ lookup_sampled_image: FastHashMap::default(),
+ lookup_function_type: FastHashMap::default(),
+ lookup_function: FastHashMap::default(),
+ lookup_entry_point: FastHashMap::default(),
+ deferred_function_calls: Vec::new(),
+ options: options.clone(),
+ }
+ }
+
+ fn next(&mut self) -> Result<u32, Error> {
+ self.data.next().ok_or(Error::IncompleteData)
+ }
+
+ fn next_inst(&mut self) -> Result<Instruction, Error> {
+ let word = self.next()?;
+ let (wc, opcode) = ((word >> 16) as u16, (word & 0xffff) as u16);
+ if wc == 0 {
+ return Err(Error::InvalidWordCount);
+ }
+ let op = spirv::Op::from_u16(opcode).ok_or(Error::UnknownInstruction(opcode))?;
+
+ Ok(Instruction { op, wc })
+ }
+
+ fn next_string(&mut self, mut count: u16) -> Result<(String, u16), Error> {
+ self.temp_bytes.clear();
+ loop {
+ if count == 0 {
+ return Err(Error::BadString);
+ }
+ count -= 1;
+ let chars = self.next()?.to_le_bytes();
+ let pos = chars.iter().position(|&c| c == 0).unwrap_or(4);
+ self.temp_bytes.extend_from_slice(&chars[..pos]);
+ if pos < 4 {
+ break;
+ }
+ }
+ std::str::from_utf8(&self.temp_bytes)
+ .map(|s| (s.to_owned(), count))
+ .map_err(|_| Error::BadString)
+ }
+
+ fn next_decoration(
+ &mut self,
+ inst: Instruction,
+ base_words: u16,
+ dec: &mut Decoration,
+ ) -> Result<(), Error> {
+ let raw = self.next()?;
+ let dec_typed = spirv::Decoration::from_u32(raw).ok_or(Error::InvalidDecoration(raw))?;
+ log::trace!("\t\t{}: {:?}", dec.debug_name(), dec_typed);
+ match dec_typed {
+ spirv::Decoration::BuiltIn => {
+ inst.expect(base_words + 2)?;
+ let raw = self.next()?;
+ match map_builtin(raw) {
+ Ok(built_in) => dec.built_in = Some(built_in),
+ Err(_e) => log::warn!("Unsupported builtin {}", raw),
+ };
+ }
+ spirv::Decoration::Location => {
+ inst.expect(base_words + 2)?;
+ dec.location = Some(self.next()?);
+ }
+ spirv::Decoration::DescriptorSet => {
+ inst.expect(base_words + 2)?;
+ dec.desc_set = Some(self.next()?);
+ }
+ spirv::Decoration::Binding => {
+ inst.expect(base_words + 2)?;
+ dec.desc_index = Some(self.next()?);
+ }
+ spirv::Decoration::Block => {
+ dec.block = Some(Block { buffer: false });
+ }
+ spirv::Decoration::BufferBlock => {
+ dec.block = Some(Block { buffer: true });
+ }
+ spirv::Decoration::Offset => {
+ inst.expect(base_words + 2)?;
+ dec.offset = Some(self.next()?);
+ }
+ spirv::Decoration::ArrayStride => {
+ inst.expect(base_words + 2)?;
+ dec.array_stride = NonZeroU32::new(self.next()?);
+ }
+ spirv::Decoration::NoPerspective => {
+ dec.interpolation = Some(crate::Interpolation::Linear);
+ }
+ spirv::Decoration::Flat => {
+ dec.interpolation = Some(crate::Interpolation::Flat);
+ }
+ spirv::Decoration::Patch => {
+ dec.interpolation = Some(crate::Interpolation::Patch);
+ }
+ spirv::Decoration::Centroid => {
+ dec.interpolation = Some(crate::Interpolation::Centroid);
+ }
+ spirv::Decoration::Sample => {
+ dec.interpolation = Some(crate::Interpolation::Sample);
+ }
+ spirv::Decoration::NonReadable => {
+ dec.flags |= DecorationFlags::NON_READABLE;
+ }
+ spirv::Decoration::NonWritable => {
+ dec.flags |= DecorationFlags::NON_WRITABLE;
+ }
+ other => {
+ log::warn!("Unknown decoration {:?}", other);
+ for _ in base_words + 1..inst.wc {
+ let _var = self.next()?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn parse_expr_unary_op(
+ &mut self,
+ expressions: &mut Arena<crate::Expression>,
+ op: crate::UnaryOperator,
+ ) -> Result<(), Error> {
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p_id = self.next()?;
+
+ let p_lexp = self.lookup_expression.lookup(p_id)?;
+
+ let expr = crate::Expression::Unary {
+ op,
+ expr: p_lexp.handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr),
+ type_id: result_type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_expr_binary_op(
+ &mut self,
+ expressions: &mut Arena<crate::Expression>,
+ op: crate::BinaryOperator,
+ ) -> Result<(), Error> {
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let p1_id = self.next()?;
+ let p2_id = self.next()?;
+
+ let p1_lexp = self.lookup_expression.lookup(p1_id)?;
+ let p2_lexp = self.lookup_expression.lookup(p2_id)?;
+
+ let expr = crate::Expression::Binary {
+ op,
+ left: p1_lexp.handle,
+ right: p2_lexp.handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr),
+ type_id: result_type_id,
+ },
+ );
+ Ok(())
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn next_block(
+ &mut self,
+ block_id: spirv::Word,
+ expressions: &mut Arena<crate::Expression>,
+ local_arena: &mut Arena<crate::LocalVariable>,
+ type_arena: &Arena<crate::Type>,
+ const_arena: &Arena<crate::Constant>,
+ global_arena: &Arena<crate::GlobalVariable>,
+ local_function_calls: &mut FastHashMap<Handle<crate::Expression>, spirv::Word>,
+ ) -> Result<ControlFlowNode, Error> {
+ let mut assignments = Vec::new();
+ let mut phis = Vec::new();
+ let mut merge = None;
+ let terminator = loop {
+ use spirv::Op;
+ let inst = self.next_inst()?;
+ log::debug!("\t\t{:?} [{}]", inst.op, inst.wc);
+
+ match inst.op {
+ Op::Variable => {
+ inst.expect_at_least(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let storage = self.next()?;
+ match spirv::StorageClass::from_u32(storage) {
+ Some(spirv::StorageClass::Function) => (),
+ Some(class) => return Err(Error::InvalidVariableClass(class)),
+ None => return Err(Error::UnsupportedStorageClass(storage)),
+ }
+ let init = if inst.wc > 4 {
+ inst.expect(5)?;
+ let init_id = self.next()?;
+ let lconst = self.lookup_constant.lookup(init_id)?;
+ Some(lconst.handle)
+ } else {
+ None
+ };
+ let name = self
+ .future_decor
+ .remove(&result_id)
+ .and_then(|decor| decor.name);
+ if let Some(ref name) = name {
+ log::debug!("\t\t\tid={} name={}", result_id, name);
+ }
+ let var_handle = local_arena.append(crate::LocalVariable {
+ name,
+ ty: self.lookup_type.lookup(result_type_id)?.handle,
+ init,
+ });
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions
+ .append(crate::Expression::LocalVariable(var_handle)),
+ type_id: result_type_id,
+ },
+ );
+ }
+ Op::Phi => {
+ inst.expect_at_least(3)?;
+
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+
+ let name = format!("phi_{}", result_id);
+ let var_handle = local_arena.append(crate::LocalVariable {
+ name: Some(name),
+ ty: self.lookup_type.lookup(result_type_id)?.handle,
+ init: None,
+ });
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions
+ .append(crate::Expression::LocalVariable(var_handle)),
+ type_id: result_type_id,
+ },
+ );
+
+ let mut phi = PhiInstruction::default();
+ phi.id = result_id;
+ for _ in 0..(inst.wc - 3) / 2 {
+ phi.variables.push((self.next()?, self.next()?));
+ }
+
+ phis.push(phi);
+ }
+ Op::AccessChain => {
+ struct AccessExpression {
+ base_handle: Handle<crate::Expression>,
+ type_id: spirv::Word,
+ }
+ inst.expect_at_least(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", base_id);
+ let mut acex = {
+ let expr = self.lookup_expression.lookup(base_id)?;
+ AccessExpression {
+ base_handle: expr.handle,
+ type_id: expr.type_id,
+ }
+ };
+ for _ in 4..inst.wc {
+ let access_id = self.next()?;
+ log::trace!("\t\t\tlooking up index expr {:?}", access_id);
+ let index_expr = self.lookup_expression.lookup(access_id)?.clone();
+ let index_type_handle = self.lookup_type.lookup(index_expr.type_id)?.handle;
+ match type_arena[index_type_handle].inner {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ }
+ | crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ ..
+ } => (),
+ _ => return Err(Error::UnsupportedType(index_type_handle)),
+ }
+ log::trace!("\t\t\tlooking up type {:?}", acex.type_id);
+ let type_lookup = self.lookup_type.lookup(acex.type_id)?;
+ acex = match type_arena[type_lookup.handle].inner {
+ crate::TypeInner::Struct { .. } => {
+ let index = match expressions[index_expr.handle] {
+ crate::Expression::Constant(const_handle) => {
+ match const_arena[const_handle].inner {
+ crate::ConstantInner::Uint(v) => v as u32,
+ crate::ConstantInner::Sint(v) => v as u32,
+ _ => {
+ return Err(Error::InvalidAccess(index_expr.handle))
+ }
+ }
+ }
+ _ => return Err(Error::InvalidAccess(index_expr.handle)),
+ };
+ AccessExpression {
+ base_handle: expressions.append(
+ crate::Expression::AccessIndex {
+ base: acex.base_handle,
+ index,
+ },
+ ),
+ type_id: *self
+ .lookup_member_type_id
+ .get(&(type_lookup.handle, index))
+ .ok_or(Error::InvalidAccessType(acex.type_id))?,
+ }
+ }
+ crate::TypeInner::Array { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. } => AccessExpression {
+ base_handle: expressions.append(crate::Expression::Access {
+ base: acex.base_handle,
+ index: index_expr.handle,
+ }),
+ type_id: type_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(acex.type_id))?,
+ },
+ _ => return Err(Error::UnsupportedType(type_lookup.handle)),
+ };
+ }
+
+ let lookup_expression = LookupExpression {
+ handle: acex.base_handle,
+ type_id: result_type_id,
+ };
+ self.lookup_expression.insert(result_id, lookup_expression);
+ }
+ Op::CompositeExtract => {
+ inst.expect_at_least(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let base_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", base_id);
+ let mut lexp = {
+ let expr = self.lookup_expression.lookup(base_id)?;
+ LookupExpression {
+ handle: expr.handle,
+ type_id: expr.type_id,
+ }
+ };
+ for _ in 4..inst.wc {
+ let index = self.next()?;
+ log::trace!("\t\t\tlooking up type {:?}", lexp.type_id);
+ let type_lookup = self.lookup_type.lookup(lexp.type_id)?;
+ let type_id = match type_arena[type_lookup.handle].inner {
+ crate::TypeInner::Struct { .. } => *self
+ .lookup_member_type_id
+ .get(&(type_lookup.handle, index))
+ .ok_or(Error::InvalidAccessType(lexp.type_id))?,
+ crate::TypeInner::Array { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. } => type_lookup
+ .base_id
+ .ok_or(Error::InvalidAccessType(lexp.type_id))?,
+ _ => return Err(Error::UnsupportedType(type_lookup.handle)),
+ };
+ lexp = LookupExpression {
+ handle: expressions.append(crate::Expression::AccessIndex {
+ base: lexp.handle,
+ index,
+ }),
+ type_id,
+ };
+ }
+
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: lexp.handle,
+ type_id: result_type_id,
+ },
+ );
+ }
+ Op::CompositeConstruct => {
+ inst.expect_at_least(3)?;
+ let result_type_id = self.next()?;
+ let id = self.next()?;
+ let mut components = Vec::with_capacity(inst.wc as usize - 2);
+ for _ in 3..inst.wc {
+ let comp_id = self.next()?;
+ log::trace!("\t\t\tlooking up expr {:?}", comp_id);
+ let lexp = self.lookup_expression.lookup(comp_id)?;
+ components.push(lexp.handle);
+ }
+ let expr = crate::Expression::Compose {
+ ty: self.lookup_type.lookup(result_type_id)?.handle,
+ components,
+ };
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ handle: expressions.append(expr),
+ type_id: result_type_id,
+ },
+ );
+ }
+ Op::Load => {
+ inst.expect_at_least(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let pointer_id = self.next()?;
+ if inst.wc != 4 {
+ inst.expect(5)?;
+ let _memory_access = self.next()?;
+ }
+ let base_expr = self.lookup_expression.lookup(pointer_id)?.clone();
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: base_expr.handle, // pass-through pointers
+ type_id: result_type_id,
+ },
+ );
+ }
+ Op::Store => {
+ inst.expect_at_least(3)?;
+ let pointer_id = self.next()?;
+ let value_id = self.next()?;
+ if inst.wc != 3 {
+ inst.expect(4)?;
+ let _memory_access = self.next()?;
+ }
+ let base_expr = self.lookup_expression.lookup(pointer_id)?;
+ let value_expr = self.lookup_expression.lookup(value_id)?;
+ assignments.push(Assignment {
+ to: base_expr.handle,
+ value: value_expr.handle,
+ });
+ }
+ // Arithmetic Instructions +, -, *, /, %
+ Op::SNegate | Op::FNegate => {
+ inst.expect(4)?;
+ self.parse_expr_unary_op(expressions, crate::UnaryOperator::Negate)?;
+ }
+ Op::IAdd | Op::FAdd => {
+ inst.expect(5)?;
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::Add)?;
+ }
+ Op::ISub | Op::FSub => {
+ inst.expect(5)?;
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::Subtract)?;
+ }
+ Op::IMul | Op::FMul => {
+ inst.expect(5)?;
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::Multiply)?;
+ }
+ Op::SDiv | Op::UDiv | Op::FDiv => {
+ inst.expect(5)?;
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::Divide)?;
+ }
+ Op::UMod | Op::FMod | Op::SRem | Op::FRem => {
+ inst.expect(5)?;
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::Modulo)?;
+ }
+ Op::VectorTimesScalar
+ | Op::VectorTimesMatrix
+ | Op::MatrixTimesScalar
+ | Op::MatrixTimesVector
+ | Op::MatrixTimesMatrix => {
+ inst.expect(5)?;
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::Multiply)?;
+ }
+ Op::Transpose => {
+ inst.expect(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let matrix_id = self.next()?;
+ let matrix_lexp = self.lookup_expression.lookup(matrix_id)?;
+ let expr = crate::Expression::Transpose(matrix_lexp.handle);
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr),
+ type_id: result_type_id,
+ },
+ );
+ }
+ Op::Dot => {
+ inst.expect(5)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let left_id = self.next()?;
+ let right_id = self.next()?;
+ let left_lexp = self.lookup_expression.lookup(left_id)?;
+ let right_lexp = self.lookup_expression.lookup(right_id)?;
+ let expr = crate::Expression::DotProduct(left_lexp.handle, right_lexp.handle);
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr),
+ type_id: result_type_id,
+ },
+ );
+ }
+ // Bitwise instructions
+ Op::Not => {
+ inst.expect(4)?;
+ self.parse_expr_unary_op(expressions, crate::UnaryOperator::Not)?;
+ }
+ Op::BitwiseOr => {
+ inst.expect(5)?;
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::InclusiveOr)?;
+ }
+ Op::BitwiseXor => {
+ inst.expect(5)?;
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::ExclusiveOr)?;
+ }
+ Op::BitwiseAnd => {
+ inst.expect(5)?;
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::And)?;
+ }
+ Op::ShiftRightLogical => {
+ inst.expect(5)?;
+ //TODO: convert input and result to usigned
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftRight)?;
+ }
+ Op::ShiftRightArithmetic => {
+ inst.expect(5)?;
+ //TODO: convert input and result to signed
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftRight)?;
+ }
+ Op::ShiftLeftLogical => {
+ inst.expect(5)?;
+ self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftLeft)?;
+ }
+ // Sampling
+ Op::SampledImage => {
+ inst.expect(5)?;
+ let _result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let image_id = self.next()?;
+ let sampler_id = self.next()?;
+ let image_lexp = self.lookup_expression.lookup(image_id)?;
+ let sampler_lexp = self.lookup_expression.lookup(sampler_id)?;
+ //TODO: compare the result type
+ self.lookup_sampled_image.insert(
+ result_id,
+ LookupSampledImage {
+ image: image_lexp.handle,
+ sampler: sampler_lexp.handle,
+ },
+ );
+ }
+ Op::ImageSampleImplicitLod | Op::ImageSampleExplicitLod => {
+ inst.expect_at_least(5)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let sampled_image_id = self.next()?;
+ let coordinate_id = self.next()?;
+ let si_lexp = self.lookup_sampled_image.lookup(sampled_image_id)?.clone();
+ let coord_lexp = self.lookup_expression.lookup(coordinate_id)?.clone();
+ let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle;
+
+ let sampler_type_handle =
+ reach_global_type(si_lexp.sampler, &expressions, global_arena)
+ .ok_or(Error::InvalidSamplerExpression(si_lexp.sampler))?;
+ let image_type_handle =
+ reach_global_type(si_lexp.image, &expressions, global_arena)
+ .ok_or(Error::InvalidImageExpression(si_lexp.image))?;
+ log::debug!(
+ "\t\t\tImage {:?} with sampler {:?}",
+ image_type_handle,
+ sampler_type_handle
+ );
+ *self.handle_sampling.get_mut(&sampler_type_handle).unwrap() |=
+ SamplingFlags::REGULAR;
+ *self.handle_sampling.get_mut(&image_type_handle).unwrap() |=
+ SamplingFlags::REGULAR;
+ match type_arena[sampler_type_handle].inner {
+ crate::TypeInner::Sampler { comparison: false } => (),
+ _ => return Err(Error::InvalidSampleSampler(sampler_type_handle)),
+ };
+ match type_arena[image_type_handle].inner {
+ //TODO: compare the result type
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class:
+ crate::ImageClass::Sampled {
+ kind: _,
+ multi: false,
+ },
+ } => {
+ if !check_sample_coordinates(
+ &type_arena[coord_type_handle],
+ crate::ScalarKind::Float,
+ dim,
+ arrayed,
+ ) {
+ return Err(Error::InvalidSampleCoordinates(coord_type_handle));
+ }
+ }
+ _ => return Err(Error::InvalidSampleImage(image_type_handle)),
+ };
+
+ let mut level = crate::SampleLevel::Auto;
+ let mut base_wc = 5;
+ if base_wc < inst.wc {
+ let image_ops = self.next()?;
+ base_wc += 1;
+ let mask = spirv::ImageOperands::from_bits_truncate(image_ops);
+ if mask.contains(spirv::ImageOperands::BIAS) {
+ let bias_expr = self.next()?;
+ let bias_handle = self.lookup_expression.lookup(bias_expr)?.handle;
+ level = crate::SampleLevel::Bias(bias_handle);
+ base_wc += 1;
+ }
+ if mask.contains(spirv::ImageOperands::LOD) {
+ let lod_expr = self.next()?;
+ let lod_handle = self.lookup_expression.lookup(lod_expr)?.handle;
+ level = crate::SampleLevel::Exact(lod_handle);
+ base_wc += 1;
+ }
+ for _ in base_wc..inst.wc {
+ self.next()?;
+ }
+ }
+
+ let expr = crate::Expression::ImageSample {
+ image: si_lexp.image,
+ sampler: si_lexp.sampler,
+ coordinate: coord_lexp.handle,
+ level,
+ depth_ref: None,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr),
+ type_id: result_type_id,
+ },
+ );
+ }
+ Op::ImageSampleDrefImplicitLod => {
+ inst.expect_at_least(6)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let sampled_image_id = self.next()?;
+ let coordinate_id = self.next()?;
+ let dref_id = self.next()?;
+
+ let si_lexp = self.lookup_sampled_image.lookup(sampled_image_id)?;
+ let coord_lexp = self.lookup_expression.lookup(coordinate_id)?;
+ let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle;
+ let sampler_type_handle =
+ reach_global_type(si_lexp.sampler, &expressions, global_arena)
+ .ok_or(Error::InvalidSamplerExpression(si_lexp.sampler))?;
+ let image_type_handle =
+ reach_global_type(si_lexp.image, &expressions, global_arena)
+ .ok_or(Error::InvalidImageExpression(si_lexp.image))?;
+ *self.handle_sampling.get_mut(&sampler_type_handle).unwrap() |=
+ SamplingFlags::COMPARISON;
+ *self.handle_sampling.get_mut(&image_type_handle).unwrap() |=
+ SamplingFlags::COMPARISON;
+ match type_arena[sampler_type_handle].inner {
+ crate::TypeInner::Sampler { comparison: true } => (),
+ _ => return Err(Error::InvalidSampleSampler(sampler_type_handle)),
+ };
+ match type_arena[image_type_handle].inner {
+ //TODO: compare the result type
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class: crate::ImageClass::Depth,
+ } => {
+ if !check_sample_coordinates(
+ &type_arena[coord_type_handle],
+ crate::ScalarKind::Float,
+ dim,
+ arrayed,
+ ) {
+ return Err(Error::InvalidSampleCoordinates(coord_type_handle));
+ }
+ }
+ _ => return Err(Error::InvalidSampleImage(image_type_handle)),
+ };
+
+ let dref_lexp = self.lookup_expression.lookup(dref_id)?;
+ let dref_type_handle = self.lookup_type.lookup(dref_lexp.type_id)?.handle;
+ match type_arena[dref_type_handle].inner {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width: _,
+ } => (),
+ _ => return Err(Error::InvalidDepthReference(dref_type_handle)),
+ }
+
+ let expr = crate::Expression::ImageSample {
+ image: si_lexp.image,
+ sampler: si_lexp.sampler,
+ coordinate: coord_lexp.handle,
+ level: crate::SampleLevel::Auto,
+ depth_ref: Some(dref_lexp.handle),
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr),
+ type_id: result_type_id,
+ },
+ );
+ }
+ Op::Select => {
+ inst.expect(6)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let condition = self.next()?;
+ let o1_id = self.next()?;
+ let o2_id = self.next()?;
+
+ let cond_lexp = self.lookup_expression.lookup(condition)?;
+ let o1_lexp = self.lookup_expression.lookup(o1_id)?;
+ let o2_lexp = self.lookup_expression.lookup(o2_id)?;
+
+ let expr = crate::Expression::Select {
+ condition: cond_lexp.handle,
+ accept: o1_lexp.handle,
+ reject: o2_lexp.handle,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr),
+ type_id: result_type_id,
+ },
+ );
+ }
+ Op::VectorShuffle => {
+ inst.expect_at_least(5)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let v1_id = self.next()?;
+ let v2_id = self.next()?;
+
+ let v1_lexp = self.lookup_expression.lookup(v1_id)?;
+ let v1_lty = self.lookup_type.lookup(v1_lexp.type_id)?;
+ let n1 = match type_arena[v1_lty.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as u8,
+ _ => return Err(Error::InvalidInnerType(v1_lexp.type_id)),
+ };
+ let v1_handle = v1_lexp.handle;
+ let v2_lexp = self.lookup_expression.lookup(v2_id)?;
+ let v2_lty = self.lookup_type.lookup(v2_lexp.type_id)?;
+ let n2 = match type_arena[v2_lty.handle].inner {
+ crate::TypeInner::Vector { size, .. } => size as u8,
+ _ => return Err(Error::InvalidInnerType(v2_lexp.type_id)),
+ };
+ let v2_handle = v2_lexp.handle;
+
+ let mut components = Vec::with_capacity(inst.wc as usize - 5);
+ for _ in 0..components.capacity() {
+ let index = self.next()?;
+ let expr = if index < n1 as u32 {
+ crate::Expression::AccessIndex {
+ base: v1_handle,
+ index,
+ }
+ } else if index < n1 as u32 + n2 as u32 {
+ crate::Expression::AccessIndex {
+ base: v2_handle,
+ index: index - n1 as u32,
+ }
+ } else {
+ return Err(Error::InvalidAccessIndex(index));
+ };
+ components.push(expressions.append(expr));
+ }
+ let expr = crate::Expression::Compose {
+ ty: self.lookup_type.lookup(result_type_id)?.handle,
+ components,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr),
+ type_id: result_type_id,
+ },
+ );
+ }
+ Op::Bitcast
+ | Op::ConvertSToF
+ | Op::ConvertUToF
+ | Op::ConvertFToU
+ | Op::ConvertFToS => {
+ inst.expect_at_least(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let value_id = self.next()?;
+
+ let value_lexp = self.lookup_expression.lookup(value_id)?;
+ let ty_lookup = self.lookup_type.lookup(result_type_id)?;
+ let kind = type_arena[ty_lookup.handle]
+ .inner
+ .scalar_kind()
+ .ok_or(Error::InvalidAsType(ty_lookup.handle))?;
+
+ let expr = crate::Expression::As {
+ expr: value_lexp.handle,
+ kind,
+ convert: inst.op != Op::Bitcast,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr),
+ type_id: result_type_id,
+ },
+ );
+ }
+ Op::FunctionCall => {
+ inst.expect_at_least(4)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let func_id = self.next()?;
+
+ let mut arguments = Vec::with_capacity(inst.wc as usize - 4);
+ for _ in 0..arguments.capacity() {
+ let arg_id = self.next()?;
+ arguments.push(self.lookup_expression.lookup(arg_id)?.handle);
+ }
+ let expr = crate::Expression::Call {
+ // will be replaced by `Local()` after all the functions are parsed
+ origin: crate::FunctionOrigin::External(String::new()),
+ arguments,
+ };
+ let expr_handle = expressions.append(expr);
+ local_function_calls.insert(expr_handle, func_id);
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expr_handle,
+ type_id: result_type_id,
+ },
+ );
+ }
+ Op::ExtInst => {
+ let base_wc = 5;
+ inst.expect_at_least(base_wc)?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let set_id = self.next()?;
+ if Some(set_id) != self.ext_glsl_id {
+ return Err(Error::UnsupportedExtInstSet(set_id));
+ }
+ let inst_id = self.next()?;
+ let name = match spirv::GLOp::from_u32(inst_id) {
+ Some(spirv::GLOp::FAbs) | Some(spirv::GLOp::SAbs) => {
+ inst.expect(base_wc + 1)?;
+ "abs"
+ }
+ Some(spirv::GLOp::FSign) | Some(spirv::GLOp::SSign) => {
+ inst.expect(base_wc + 1)?;
+ "sign"
+ }
+ Some(spirv::GLOp::Floor) => {
+ inst.expect(base_wc + 1)?;
+ "floor"
+ }
+ Some(spirv::GLOp::Ceil) => {
+ inst.expect(base_wc + 1)?;
+ "ceil"
+ }
+ Some(spirv::GLOp::Fract) => {
+ inst.expect(base_wc + 1)?;
+ "fract"
+ }
+ Some(spirv::GLOp::Sin) => {
+ inst.expect(base_wc + 1)?;
+ "sin"
+ }
+ Some(spirv::GLOp::Cos) => {
+ inst.expect(base_wc + 1)?;
+ "cos"
+ }
+ Some(spirv::GLOp::Tan) => {
+ inst.expect(base_wc + 1)?;
+ "tan"
+ }
+ Some(spirv::GLOp::Atan2) => {
+ inst.expect(base_wc + 2)?;
+ "atan2"
+ }
+ Some(spirv::GLOp::Pow) => {
+ inst.expect(base_wc + 2)?;
+ "pow"
+ }
+ Some(spirv::GLOp::MatrixInverse) => {
+ inst.expect(base_wc + 1)?;
+ "inverse"
+ }
+ Some(spirv::GLOp::FMix) => {
+ inst.expect(base_wc + 3)?;
+ "mix"
+ }
+ Some(spirv::GLOp::Step) => {
+ inst.expect(base_wc + 2)?;
+ "step"
+ }
+ Some(spirv::GLOp::SmoothStep) => {
+ inst.expect(base_wc + 3)?;
+ "smoothstep"
+ }
+ Some(spirv::GLOp::FMin) => {
+ inst.expect(base_wc + 2)?;
+ "min"
+ }
+ Some(spirv::GLOp::FMax) => {
+ inst.expect(base_wc + 2)?;
+ "max"
+ }
+ Some(spirv::GLOp::FClamp) => {
+ inst.expect(base_wc + 3)?;
+ "clamp"
+ }
+ Some(spirv::GLOp::Length) => {
+ inst.expect(base_wc + 1)?;
+ "length"
+ }
+ Some(spirv::GLOp::Distance) => {
+ inst.expect(base_wc + 2)?;
+ "distance"
+ }
+ Some(spirv::GLOp::Cross) => {
+ inst.expect(base_wc + 2)?;
+ "cross"
+ }
+ Some(spirv::GLOp::Normalize) => {
+ inst.expect(base_wc + 1)?;
+ "normalize"
+ }
+ Some(spirv::GLOp::Reflect) => {
+ inst.expect(base_wc + 2)?;
+ "reflect"
+ }
+ _ => return Err(Error::UnsupportedExtInst(inst_id)),
+ };
+
+ let mut arguments = Vec::with_capacity((inst.wc - base_wc) as usize);
+ for _ in 0..arguments.capacity() {
+ let arg_id = self.next()?;
+ arguments.push(self.lookup_expression.lookup(arg_id)?.handle);
+ }
+ let expr = crate::Expression::Call {
+ origin: crate::FunctionOrigin::External(name.to_string()),
+ arguments,
+ };
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: expressions.append(expr),
+ type_id: result_type_id,
+ },
+ );
+ }
+ // Relational and Logical Instructions
+ Op::LogicalNot => {
+ inst.expect(4)?;
+ self.parse_expr_unary_op(expressions, crate::UnaryOperator::Not)?;
+ }
+ op if inst.op >= Op::IEqual && inst.op <= Op::FUnordGreaterThanEqual => {
+ inst.expect(5)?;
+ self.parse_expr_binary_op(expressions, map_binary_operator(op)?)?;
+ }
+ Op::Kill => {
+ inst.expect(1)?;
+ break Terminator::Kill;
+ }
+ Op::Unreachable => {
+ inst.expect(1)?;
+ break Terminator::Unreachable;
+ }
+ Op::Return => {
+ inst.expect(1)?;
+ break Terminator::Return { value: None };
+ }
+ Op::ReturnValue => {
+ inst.expect(2)?;
+ let value_id = self.next()?;
+ let value_lexp = self.lookup_expression.lookup(value_id)?;
+ break Terminator::Return {
+ value: Some(value_lexp.handle),
+ };
+ }
+ Op::Branch => {
+ inst.expect(2)?;
+ let target_id = self.next()?;
+ break Terminator::Branch { target_id };
+ }
+ Op::BranchConditional => {
+ inst.expect_at_least(4)?;
+
+ let condition_id = self.next()?;
+ let condition = self.lookup_expression.lookup(condition_id)?.handle;
+
+ let true_id = self.next()?;
+ let false_id = self.next()?;
+
+ break Terminator::BranchConditional {
+ condition,
+ true_id,
+ false_id,
+ };
+ }
+ Op::Switch => {
+ inst.expect_at_least(3)?;
+
+ let selector = self.next()?;
+ let selector = self.lookup_expression[&selector].handle;
+ let default = self.next()?;
+
+ let mut targets = Vec::new();
+ for _ in 0..(inst.wc - 3) / 2 {
+ let literal = self.next()?;
+ let target = self.next()?;
+ targets.push((literal as i32, target));
+ }
+
+ break Terminator::Switch {
+ selector,
+ default,
+ targets,
+ };
+ }
+ Op::SelectionMerge => {
+ inst.expect(3)?;
+ let merge_block_id = self.next()?;
+ // TODO: Selection Control Mask
+ let _selection_control = self.next()?;
+ let continue_block_id = None;
+ merge = Some(MergeInstruction {
+ merge_block_id,
+ continue_block_id,
+ });
+ }
+ Op::LoopMerge => {
+ inst.expect_at_least(4)?;
+ let merge_block_id = self.next()?;
+ let continue_block_id = Some(self.next()?);
+
+ // TODO: Loop Control Parameters
+ for _ in 0..inst.wc - 3 {
+ self.next()?;
+ }
+
+ merge = Some(MergeInstruction {
+ merge_block_id,
+ continue_block_id,
+ });
+ }
+ _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
+ }
+ };
+
+ let mut block = Vec::new();
+ for assignment in assignments.iter() {
+ block.push(crate::Statement::Store {
+ pointer: assignment.to,
+ value: assignment.value,
+ });
+ }
+
+ Ok(ControlFlowNode {
+ id: block_id,
+ ty: None,
+ phis,
+ block,
+ terminator,
+ merge,
+ })
+ }
+
+ fn make_expression_storage(&mut self) -> Arena<crate::Expression> {
+ let mut expressions = Arena::new();
+ #[allow(clippy::panic)]
+ {
+ assert!(self.lookup_expression.is_empty());
+ }
+ // register global variables
+ for (&id, var) in self.lookup_variable.iter() {
+ let handle = expressions.append(crate::Expression::GlobalVariable(var.handle));
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ type_id: var.type_id,
+ handle,
+ },
+ );
+ }
+ // register constants
+ for (&id, con) in self.lookup_constant.iter() {
+ let handle = expressions.append(crate::Expression::Constant(con.handle));
+ self.lookup_expression.insert(
+ id,
+ LookupExpression {
+ type_id: con.type_id,
+ handle,
+ },
+ );
+ }
+ // done
+ expressions
+ }
+
+ fn switch(&mut self, state: ModuleState, op: spirv::Op) -> Result<(), Error> {
+ if state < self.state {
+ Err(Error::UnsupportedInstruction(self.state, op))
+ } else {
+ self.state = state;
+ Ok(())
+ }
+ }
+
+ pub fn parse(mut self) -> Result<crate::Module, Error> {
+ let mut module = {
+ if self.next()? != spirv::MAGIC_NUMBER {
+ return Err(Error::InvalidHeader);
+ }
+ let _version_raw = self.next()?.to_le_bytes();
+ let _generator = self.next()?;
+ let _bound = self.next()?;
+ let _schema = self.next()?;
+ crate::Module::generate_empty()
+ };
+
+ while let Ok(inst) = self.next_inst() {
+ use spirv::Op;
+ log::debug!("\t{:?} [{}]", inst.op, inst.wc);
+ match inst.op {
+ Op::Capability => self.parse_capability(inst),
+ Op::Extension => self.parse_extension(inst),
+ Op::ExtInstImport => self.parse_ext_inst_import(inst),
+ Op::MemoryModel => self.parse_memory_model(inst),
+ Op::EntryPoint => self.parse_entry_point(inst),
+ Op::ExecutionMode => self.parse_execution_mode(inst),
+ Op::Source => self.parse_source(inst),
+ Op::SourceExtension => self.parse_source_extension(inst),
+ Op::Name => self.parse_name(inst),
+ Op::MemberName => self.parse_member_name(inst),
+ Op::Decorate => self.parse_decorate(inst),
+ Op::MemberDecorate => self.parse_member_decorate(inst),
+ Op::TypeVoid => self.parse_type_void(inst),
+ Op::TypeBool => self.parse_type_bool(inst, &mut module),
+ Op::TypeInt => self.parse_type_int(inst, &mut module),
+ Op::TypeFloat => self.parse_type_float(inst, &mut module),
+ Op::TypeVector => self.parse_type_vector(inst, &mut module),
+ Op::TypeMatrix => self.parse_type_matrix(inst, &mut module),
+ Op::TypeFunction => self.parse_type_function(inst),
+ Op::TypePointer => self.parse_type_pointer(inst, &mut module),
+ Op::TypeArray => self.parse_type_array(inst, &mut module),
+ Op::TypeRuntimeArray => self.parse_type_runtime_array(inst, &mut module),
+ Op::TypeStruct => self.parse_type_struct(inst, &mut module),
+ Op::TypeImage => self.parse_type_image(inst, &mut module),
+ Op::TypeSampledImage => self.parse_type_sampled_image(inst),
+ Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
+ Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
+ Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
+ Op::Variable => self.parse_global_variable(inst, &mut module),
+ Op::Function => self.parse_function(inst, &mut module),
+ _ => Err(Error::UnsupportedInstruction(self.state, inst.op)), //TODO
+ }?;
+ }
+
+ // Check all the images and samplers to have consistent comparison property.
+ for (handle, flags) in self.handle_sampling.drain() {
+ if !flags.contains(SamplingFlags::COMPARISON) {
+ continue;
+ }
+ if flags == SamplingFlags::all() {
+ return Err(Error::InconsistentComparisonSampling(handle));
+ }
+ let ty = module.types.get_mut(handle);
+ match ty.inner {
+ crate::TypeInner::Sampler { ref mut comparison } => {
+ #[allow(clippy::panic)]
+ {
+ assert!(!*comparison)
+ };
+ *comparison = true;
+ }
+ _ => {
+ return Err(Error::UnexpectedComparisonType(handle));
+ }
+ }
+ }
+
+ for dfc in self.deferred_function_calls.drain(..) {
+ let dst_handle = *self.lookup_function.lookup(dfc.dst_id)?;
+ let fun = match dfc.source {
+ DeferredSource::Function(fun_handle) => module.functions.get_mut(fun_handle),
+ DeferredSource::EntryPoint(stage, name) => {
+ &mut module
+ .entry_points
+ .get_mut(&(stage, name))
+ .unwrap()
+ .function
+ }
+ };
+ match *fun.expressions.get_mut(dfc.expr_handle) {
+ crate::Expression::Call {
+ ref mut origin,
+ arguments: _,
+ } => *origin = crate::FunctionOrigin::Local(dst_handle),
+ _ => unreachable!(),
+ }
+ }
+
+ if !self.future_decor.is_empty() {
+ log::warn!("Unused item decorations: {:?}", self.future_decor);
+ self.future_decor.clear();
+ }
+ if !self.future_member_decor.is_empty() {
+ log::warn!("Unused member decorations: {:?}", self.future_member_decor);
+ self.future_member_decor.clear();
+ }
+
+ Ok(module)
+ }
+
+ fn parse_capability(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Capability, inst.op)?;
+ inst.expect(2)?;
+ let capability = self.next()?;
+ let cap =
+ spirv::Capability::from_u32(capability).ok_or(Error::UnknownCapability(capability))?;
+ if !SUPPORTED_CAPABILITIES.contains(&cap) {
+ return Err(Error::UnsupportedCapability(cap));
+ }
+ Ok(())
+ }
+
+ fn parse_extension(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Extension, inst.op)?;
+ inst.expect_at_least(2)?;
+ let (name, left) = self.next_string(inst.wc - 1)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ if !SUPPORTED_EXTENSIONS.contains(&name.as_str()) {
+ return Err(Error::UnsupportedExtension(name));
+ }
+ Ok(())
+ }
+
+ fn parse_ext_inst_import(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Extension, inst.op)?;
+ inst.expect_at_least(3)?;
+ let result_id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 2)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ if !SUPPORTED_EXT_SETS.contains(&name.as_str()) {
+ return Err(Error::UnsupportedExtSet(name));
+ }
+ self.ext_glsl_id = Some(result_id);
+ Ok(())
+ }
+
+ fn parse_memory_model(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::MemoryModel, inst.op)?;
+ inst.expect(3)?;
+ let _addressing_model = self.next()?;
+ let _memory_model = self.next()?;
+ Ok(())
+ }
+
+ fn parse_entry_point(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::EntryPoint, inst.op)?;
+ inst.expect_at_least(4)?;
+ let exec_model = self.next()?;
+ let exec_model = spirv::ExecutionModel::from_u32(exec_model)
+ .ok_or(Error::UnsupportedExecutionModel(exec_model))?;
+ let function_id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 3)?;
+ let ep = EntryPoint {
+ stage: match exec_model {
+ spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex,
+ spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment,
+ spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute,
+ _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)),
+ },
+ name,
+ early_depth_test: None,
+ workgroup_size: [0; 3],
+ function_id,
+ variable_ids: self.data.by_ref().take(left as usize).collect(),
+ };
+ self.lookup_entry_point.insert(function_id, ep);
+ Ok(())
+ }
+
+ fn parse_execution_mode(&mut self, inst: Instruction) -> Result<(), Error> {
+ use spirv::ExecutionMode;
+
+ self.switch(ModuleState::ExecutionMode, inst.op)?;
+ inst.expect_at_least(3)?;
+
+ let ep_id = self.next()?;
+ let mode_id = self.next()?;
+ let args: Vec<spirv::Word> = self.data.by_ref().take(inst.wc as usize - 3).collect();
+
+ let ep = self
+ .lookup_entry_point
+ .get_mut(&ep_id)
+ .ok_or(Error::InvalidId(ep_id))?;
+ let mode = spirv::ExecutionMode::from_u32(mode_id)
+ .ok_or(Error::UnsupportedExecutionMode(mode_id))?;
+
+ match mode {
+ ExecutionMode::EarlyFragmentTests => {
+ if ep.early_depth_test.is_none() {
+ ep.early_depth_test = Some(crate::EarlyDepthTest { conservative: None });
+ }
+ }
+ ExecutionMode::DepthUnchanged => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::Unchanged),
+ });
+ }
+ ExecutionMode::DepthGreater => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::GreaterEqual),
+ });
+ }
+ ExecutionMode::DepthLess => {
+ ep.early_depth_test = Some(crate::EarlyDepthTest {
+ conservative: Some(crate::ConservativeDepth::LessEqual),
+ });
+ }
+ ExecutionMode::DepthReplacing => {
+ // Ignored because it can be deduced from the IR.
+ }
+ ExecutionMode::OriginUpperLeft => {
+ // Ignored because the other option (OriginLowerLeft) is not valid in Vulkan mode.
+ }
+ ExecutionMode::LocalSize => {
+ ep.workgroup_size = [args[0], args[1], args[2]];
+ }
+ _ => {
+ return Err(Error::UnsupportedExecutionMode(mode_id));
+ }
+ }
+
+ Ok(())
+ }
+
+ fn parse_source(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Source, inst.op)?;
+ for _ in 1..inst.wc {
+ let _ = self.next()?;
+ }
+ Ok(())
+ }
+
+ fn parse_source_extension(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Source, inst.op)?;
+ inst.expect_at_least(2)?;
+ let (_name, _) = self.next_string(inst.wc - 1)?;
+ Ok(())
+ }
+
+ fn parse_name(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Name, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 2)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+ self.future_decor.entry(id).or_default().name = Some(name);
+ Ok(())
+ }
+
+ fn parse_member_name(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Name, inst.op)?;
+ inst.expect_at_least(4)?;
+ let id = self.next()?;
+ let member = self.next()?;
+ let (name, left) = self.next_string(inst.wc - 3)?;
+ if left != 0 {
+ return Err(Error::InvalidOperand);
+ }
+
+ self.future_member_decor
+ .entry((id, member))
+ .or_default()
+ .name = Some(name);
+ Ok(())
+ }
+
+ fn parse_decorate(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Annotation, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let mut dec = self.future_decor.remove(&id).unwrap_or_default();
+ self.next_decoration(inst, 2, &mut dec)?;
+ self.future_decor.insert(id, dec);
+ Ok(())
+ }
+
+ fn parse_member_decorate(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Annotation, inst.op)?;
+ inst.expect_at_least(4)?;
+ let id = self.next()?;
+ let member = self.next()?;
+
+ let mut dec = self
+ .future_member_decor
+ .remove(&(id, member))
+ .unwrap_or_default();
+ self.next_decoration(inst, 3, &mut dec)?;
+ self.future_member_decor.insert((id, member), dec);
+ Ok(())
+ }
+
+ fn parse_type_void(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ self.lookup_void_type.insert(id);
+ Ok(())
+ }
+
+ fn parse_type_bool(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ let inner = crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: 1,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.append(crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ }),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_int(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let width = self.next()?;
+ let sign = self.next()?;
+ let inner = crate::TypeInner::Scalar {
+ kind: match sign {
+ 0 => crate::ScalarKind::Uint,
+ 1 => crate::ScalarKind::Sint,
+ _ => return Err(Error::InvalidSign(sign)),
+ },
+ width: map_width(width)?,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.append(crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ }),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_float(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let width = self.next()?;
+ let inner = crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width: map_width(width)?,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.append(crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ }),
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_vector(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let (kind, width) = match module.types[type_lookup.handle].inner {
+ crate::TypeInner::Scalar { kind, width } => (kind, width),
+ _ => return Err(Error::InvalidInnerType(type_id)),
+ };
+ let component_count = self.next()?;
+ let inner = crate::TypeInner::Vector {
+ size: map_vector_size(component_count)?,
+ kind,
+ width,
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.append(crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ }),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_matrix(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let vector_type_id = self.next()?;
+ let num_columns = self.next()?;
+ let vector_type_lookup = self.lookup_type.lookup(vector_type_id)?;
+ let inner = match module.types[vector_type_lookup.handle].inner {
+ crate::TypeInner::Vector { size, width, .. } => crate::TypeInner::Matrix {
+ columns: map_vector_size(num_columns)?,
+ rows: size,
+ width,
+ },
+ _ => return Err(Error::InvalidInnerType(vector_type_id)),
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.append(crate::Type {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ inner,
+ }),
+ base_id: Some(vector_type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_function(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(3)?;
+ let id = self.next()?;
+ let return_type_id = self.next()?;
+ let parameter_type_ids = self.data.by_ref().take(inst.wc as usize - 3).collect();
+ self.lookup_function_type.insert(
+ id,
+ LookupFunctionType {
+ parameter_type_ids,
+ return_type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_pointer(
+ &mut self,
+ inst: Instruction,
+ _module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let _storage = self.next()?;
+ let type_id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?.clone();
+ self.lookup_type.insert(id, type_lookup); // don't register pointers in the IR
+ Ok(())
+ }
+
+ fn parse_type_array(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(4)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+ let length_id = self.next()?;
+ let length_const = self.lookup_constant.lookup(length_id)?;
+
+ let decor = self.future_decor.remove(&id);
+ let inner = crate::TypeInner::Array {
+ base: self.lookup_type.lookup(type_id)?.handle,
+ size: crate::ArraySize::Constant(length_const.handle),
+ stride: decor.as_ref().and_then(|dec| dec.array_stride),
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.append(crate::Type {
+ name: decor.and_then(|dec| dec.name),
+ inner,
+ }),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_runtime_array(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let type_id = self.next()?;
+
+ let decor = self.future_decor.remove(&id);
+ let inner = crate::TypeInner::Array {
+ base: self.lookup_type.lookup(type_id)?.handle,
+ size: crate::ArraySize::Dynamic,
+ stride: decor.as_ref().and_then(|dec| dec.array_stride),
+ };
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: module.types.append(crate::Type {
+ name: decor.and_then(|dec| dec.name),
+ inner,
+ }),
+ base_id: Some(type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_struct(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(2)?;
+ let id = self.next()?;
+ let parent_decor = self.future_decor.remove(&id);
+ let is_buffer_block = parent_decor
+ .as_ref()
+ .map_or(false, |decor| match decor.block {
+ Some(Block { buffer }) => buffer,
+ _ => false,
+ });
+
+ let mut members = Vec::with_capacity(inst.wc as usize - 2);
+ let mut member_type_ids = Vec::with_capacity(members.capacity());
+ for i in 0..u32::from(inst.wc) - 2 {
+ let type_id = self.next()?;
+ member_type_ids.push(type_id);
+ let ty = self.lookup_type.lookup(type_id)?.handle;
+ let decor = self
+ .future_member_decor
+ .remove(&(id, i))
+ .unwrap_or_default();
+ let origin = decor.get_origin()?;
+ members.push(crate::StructMember {
+ name: decor.name,
+ origin,
+ ty,
+ });
+ }
+ let inner = crate::TypeInner::Struct { members };
+ let ty_handle = module.types.append(crate::Type {
+ name: parent_decor.and_then(|dec| dec.name),
+ inner,
+ });
+
+ if is_buffer_block {
+ self.lookup_storage_buffer_types.insert(ty_handle);
+ }
+ for (i, type_id) in member_type_ids.into_iter().enumerate() {
+ self.lookup_member_type_id
+ .insert((ty_handle, i as u32), type_id);
+ }
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: ty_handle,
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_image(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(9)?;
+
+ let id = self.next()?;
+ let sample_type_id = self.next()?;
+ let dim = self.next()?;
+ let _is_depth = self.next()?;
+ let is_array = self.next()? != 0;
+ let is_msaa = self.next()? != 0;
+ let _is_sampled = self.next()?;
+ let format = self.next()?;
+
+ let base_handle = self.lookup_type.lookup(sample_type_id)?.handle;
+ let kind = module.types[base_handle]
+ .inner
+ .scalar_kind()
+ .ok_or(Error::InvalidImageBaseType(base_handle))?;
+
+ let class = if format != 0 {
+ crate::ImageClass::Storage(map_image_format(format)?)
+ } else {
+ crate::ImageClass::Sampled {
+ kind,
+ multi: is_msaa,
+ }
+ };
+
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ let inner = crate::TypeInner::Image {
+ class,
+ dim: map_image_dim(dim)?,
+ arrayed: is_array,
+ };
+ let handle = module.types.append(crate::Type {
+ name: decor.name,
+ inner,
+ });
+ log::debug!("\t\ttracking {:?} for sampling properties", handle);
+ self.handle_sampling.insert(handle, SamplingFlags::empty());
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle,
+ base_id: Some(sample_type_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_sampled_image(&mut self, inst: Instruction) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(3)?;
+ let id = self.next()?;
+ let image_id = self.next()?;
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle: self.lookup_type.lookup(image_id)?.handle,
+ base_id: Some(image_id),
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_type_sampler(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect(2)?;
+ let id = self.next()?;
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+ // The comparison bit is temporary, will be overwritten based on the
+ // accumulated sampling flags at the end.
+ let inner = crate::TypeInner::Sampler { comparison: false };
+ let handle = module.types.append(crate::Type {
+ name: decor.name,
+ inner,
+ });
+ log::debug!("\t\ttracking {:?} for sampling properties", handle);
+ self.handle_sampling.insert(handle, SamplingFlags::empty());
+ self.lookup_type.insert(
+ id,
+ LookupType {
+ handle,
+ base_id: None,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_constant(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(3)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+ let inner = match module.types[ty].inner {
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width,
+ } => {
+ let low = self.next()?;
+ let high = if width > 4 {
+ inst.expect(4)?;
+ self.next()?
+ } else {
+ 0
+ };
+ crate::ConstantInner::Uint((u64::from(high) << 32) | u64::from(low))
+ }
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width,
+ } => {
+ use std::cmp::Ordering;
+ let low = self.next()?;
+ let high = match width.cmp(&4) {
+ Ordering::Less => return Err(Error::InvalidTypeWidth(u32::from(width))),
+ Ordering::Greater => {
+ inst.expect(4)?;
+ self.next()?
+ }
+ Ordering::Equal => 0,
+ };
+ crate::ConstantInner::Sint(((u64::from(high) << 32) | u64::from(low)) as i64)
+ }
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width,
+ } => {
+ let low = self.next()?;
+ let extended = match width {
+ 4 => f64::from(f32::from_bits(low)),
+ 8 => {
+ inst.expect(4)?;
+ let high = self.next()?;
+ f64::from_bits((u64::from(high) << 32) | u64::from(low))
+ }
+ _ => return Err(Error::InvalidTypeWidth(u32::from(width))),
+ };
+ crate::ConstantInner::Float(extended)
+ }
+ _ => return Err(Error::UnsupportedType(type_lookup.handle)),
+ };
+ self.lookup_constant.insert(
+ id,
+ LookupConstant {
+ handle: module.constants.append(crate::Constant {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ specialization: None, //TODO
+ inner,
+ ty,
+ }),
+ type_id,
+ },
+ );
+ Ok(())
+ }
+
+ fn parse_composite_constant(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(3)?;
+ let type_id = self.next()?;
+ let type_lookup = self.lookup_type.lookup(type_id)?;
+ let ty = type_lookup.handle;
+
+ let id = self.next()?;
+
+ let constituents_count = inst.wc - 3;
+ let mut constituents = Vec::with_capacity(constituents_count as usize);
+ for _ in 0..constituents_count {
+ let constituent_id = self.next()?;
+ let constant = self.lookup_constant.lookup(constituent_id)?;
+ constituents.push(constant.handle);
+ }
+
+ self.lookup_constant.insert(
+ id,
+ LookupConstant {
+ handle: module.constants.append(crate::Constant {
+ name: self.future_decor.remove(&id).and_then(|dec| dec.name),
+ specialization: None,
+ inner: crate::ConstantInner::Composite(constituents),
+ ty,
+ }),
+ type_id,
+ },
+ );
+
+ Ok(())
+ }
+
+ fn parse_global_variable(
+ &mut self,
+ inst: Instruction,
+ module: &mut crate::Module,
+ ) -> Result<(), Error> {
+ self.switch(ModuleState::Type, inst.op)?;
+ inst.expect_at_least(4)?;
+ let type_id = self.next()?;
+ let id = self.next()?;
+ let storage_class = self.next()?;
+ let init = if inst.wc > 4 {
+ inst.expect(5)?;
+ let init_id = self.next()?;
+ let lconst = self.lookup_constant.lookup(init_id)?;
+ Some(lconst.handle)
+ } else {
+ None
+ };
+ let lookup_type = self.lookup_type.lookup(type_id)?;
+ let dec = self
+ .future_decor
+ .remove(&id)
+ .ok_or(Error::InvalidBinding(id))?;
+
+ let class = {
+ use spirv::StorageClass as Sc;
+ match Sc::from_u32(storage_class) {
+ Some(Sc::Function) => crate::StorageClass::Function,
+ Some(Sc::Input) => crate::StorageClass::Input,
+ Some(Sc::Output) => crate::StorageClass::Output,
+ Some(Sc::Private) => crate::StorageClass::Private,
+ Some(Sc::UniformConstant) => crate::StorageClass::Handle,
+ Some(Sc::StorageBuffer) => crate::StorageClass::Storage,
+ Some(Sc::Uniform) => {
+ if self
+ .lookup_storage_buffer_types
+ .contains(&lookup_type.handle)
+ {
+ crate::StorageClass::Storage
+ } else {
+ crate::StorageClass::Uniform
+ }
+ }
+ Some(Sc::Workgroup) => crate::StorageClass::WorkGroup,
+ Some(Sc::PushConstant) => crate::StorageClass::PushConstant,
+ _ => return Err(Error::UnsupportedStorageClass(storage_class)),
+ }
+ };
+
+ let binding = match (class, &module.types[lookup_type.handle].inner) {
+ (crate::StorageClass::Input, &crate::TypeInner::Struct { .. })
+ | (crate::StorageClass::Output, &crate::TypeInner::Struct { .. }) => None,
+ _ => Some(dec.get_binding().ok_or(Error::InvalidBinding(id))?),
+ };
+ let is_storage = match module.types[lookup_type.handle].inner {
+ crate::TypeInner::Struct { .. } => class == crate::StorageClass::Storage,
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage(_),
+ ..
+ } => true,
+ _ => false,
+ };
+
+ let storage_access = if is_storage {
+ let mut access = crate::StorageAccess::all();
+ if dec.flags.contains(DecorationFlags::NON_READABLE) {
+ access ^= crate::StorageAccess::LOAD;
+ }
+ if dec.flags.contains(DecorationFlags::NON_WRITABLE) {
+ access ^= crate::StorageAccess::STORE;
+ }
+ access
+ } else {
+ crate::StorageAccess::empty()
+ };
+
+ let var = crate::GlobalVariable {
+ name: dec.name,
+ class,
+ binding,
+ ty: lookup_type.handle,
+ init,
+ interpolation: dec.interpolation,
+ storage_access,
+ };
+ self.lookup_variable.insert(
+ id,
+ LookupVariable {
+ handle: module.global_variables.append(var),
+ type_id,
+ },
+ );
+ Ok(())
+ }
+}
+
+pub fn parse_u8_slice(data: &[u8], options: &Options) -> Result<crate::Module, Error> {
+ if data.len() % 4 != 0 {
+ return Err(Error::IncompleteData);
+ }
+
+ let words = data
+ .chunks(4)
+ .map(|c| u32::from_le_bytes(c.try_into().unwrap()));
+ Parser::new(words, options).parse()
+}
+
+#[cfg(test)]
+mod test {
+ #[test]
+ fn parse() {
+ let bin = vec![
+ // Magic number. Version number: 1.0.
+ 0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00,
+ // Generator number: 0. Bound: 0.
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Reserved word: 0.
+ 0x00, 0x00, 0x00, 0x00, // OpMemoryModel. Logical.
+ 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // GLSL450.
+ 0x01, 0x00, 0x00, 0x00,
+ ];
+ let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap();
+ }
+}
diff --git a/third_party/rust/naga/src/front/spv/rosetta.rs b/third_party/rust/naga/src/front/spv/rosetta.rs
new file mode 100644
index 0000000000..027c0d0adc
--- /dev/null
+++ b/third_party/rust/naga/src/front/spv/rosetta.rs
@@ -0,0 +1,23 @@
+use std::{fs, path::Path};
+
+const TEST_PATH: &str = "test-data";
+
+fn rosetta_test(file_name: &str) {
+ if true {
+ return; //TODO: fix this test
+ }
+ let file_path = Path::new(TEST_PATH).join(file_name);
+ let input = fs::read(&file_path).unwrap();
+
+ let module = super::parse_u8_slice(&input, &Default::default()).unwrap();
+ let output = ron::ser::to_string_pretty(&module, Default::default()).unwrap();
+
+ let expected = fs::read_to_string(file_path.with_extension("expected.ron")).unwrap();
+
+ difference::assert_diff!(output.as_str(), expected.as_str(), "", 0);
+}
+
+#[test]
+fn simple() {
+ rosetta_test("simple/simple.spv")
+}
diff --git a/third_party/rust/naga/src/front/wgsl/conv.rs b/third_party/rust/naga/src/front/wgsl/conv.rs
new file mode 100644
index 0000000000..6fdea89e7a
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/conv.rs
@@ -0,0 +1,117 @@
+use super::Error;
+
+pub fn map_storage_class(word: &str) -> Result<crate::StorageClass, Error<'_>> {
+ match word {
+ "in" => Ok(crate::StorageClass::Input),
+ "out" => Ok(crate::StorageClass::Output),
+ "private" => Ok(crate::StorageClass::Private),
+ "uniform" => Ok(crate::StorageClass::Uniform),
+ "storage" => Ok(crate::StorageClass::Storage),
+ _ => Err(Error::UnknownStorageClass(word)),
+ }
+}
+
+pub fn map_built_in(word: &str) -> Result<crate::BuiltIn, Error<'_>> {
+ Ok(match word {
+ // vertex
+ "position" => crate::BuiltIn::Position,
+ "vertex_idx" => crate::BuiltIn::VertexIndex,
+ "instance_idx" => crate::BuiltIn::InstanceIndex,
+ // fragment
+ "front_facing" => crate::BuiltIn::FrontFacing,
+ "frag_coord" => crate::BuiltIn::FragCoord,
+ "frag_depth" => crate::BuiltIn::FragDepth,
+ // compute
+ "global_invocation_id" => crate::BuiltIn::GlobalInvocationId,
+ "local_invocation_id" => crate::BuiltIn::LocalInvocationId,
+ "local_invocation_idx" => crate::BuiltIn::LocalInvocationIndex,
+ _ => return Err(Error::UnknownBuiltin(word)),
+ })
+}
+
+pub fn map_shader_stage(word: &str) -> Result<crate::ShaderStage, Error<'_>> {
+ match word {
+ "vertex" => Ok(crate::ShaderStage::Vertex),
+ "fragment" => Ok(crate::ShaderStage::Fragment),
+ "compute" => Ok(crate::ShaderStage::Compute),
+ _ => Err(Error::UnknownShaderStage(word)),
+ }
+}
+
+pub fn map_interpolation(word: &str) -> Result<crate::Interpolation, Error<'_>> {
+ match word {
+ "linear" => Ok(crate::Interpolation::Linear),
+ "flat" => Ok(crate::Interpolation::Flat),
+ "centroid" => Ok(crate::Interpolation::Centroid),
+ "sample" => Ok(crate::Interpolation::Sample),
+ "perspective" => Ok(crate::Interpolation::Perspective),
+ _ => Err(Error::UnknownDecoration(word)),
+ }
+}
+
+pub fn map_storage_format(word: &str) -> Result<crate::StorageFormat, Error<'_>> {
+ use crate::StorageFormat as Sf;
+ Ok(match word {
+ "r8unorm" => Sf::R8Unorm,
+ "r8snorm" => Sf::R8Snorm,
+ "r8uint" => Sf::R8Uint,
+ "r8sint" => Sf::R8Sint,
+ "r16uint" => Sf::R16Uint,
+ "r16sint" => Sf::R16Sint,
+ "r16float" => Sf::R16Float,
+ "rg8unorm" => Sf::Rg8Unorm,
+ "rg8snorm" => Sf::Rg8Snorm,
+ "rg8uint" => Sf::Rg8Uint,
+ "rg8sint" => Sf::Rg8Sint,
+ "r32uint" => Sf::R32Uint,
+ "r32sint" => Sf::R32Sint,
+ "r32float" => Sf::R32Float,
+ "rg16uint" => Sf::Rg16Uint,
+ "rg16sint" => Sf::Rg16Sint,
+ "rg16float" => Sf::Rg16Float,
+ "rgba8unorm" => Sf::Rgba8Unorm,
+ "rgba8snorm" => Sf::Rgba8Snorm,
+ "rgba8uint" => Sf::Rgba8Uint,
+ "rgba8sint" => Sf::Rgba8Sint,
+ "rgb10a2unorm" => Sf::Rgb10a2Unorm,
+ "rg11b10float" => Sf::Rg11b10Float,
+ "rg32uint" => Sf::Rg32Uint,
+ "rg32sint" => Sf::Rg32Sint,
+ "rg32float" => Sf::Rg32Float,
+ "rgba16uint" => Sf::Rgba16Uint,
+ "rgba16sint" => Sf::Rgba16Sint,
+ "rgba16float" => Sf::Rgba16Float,
+ "rgba32uint" => Sf::Rgba32Uint,
+ "rgba32sint" => Sf::Rgba32Sint,
+ "rgba32float" => Sf::Rgba32Float,
+ _ => return Err(Error::UnknownStorageFormat(word)),
+ })
+}
+
+pub fn get_scalar_type(word: &str) -> Option<(crate::ScalarKind, crate::Bytes)> {
+ match word {
+ "f32" => Some((crate::ScalarKind::Float, 4)),
+ "i32" => Some((crate::ScalarKind::Sint, 4)),
+ "u32" => Some((crate::ScalarKind::Uint, 4)),
+ _ => None,
+ }
+}
+
+pub fn get_intrinsic(word: &str) -> Option<crate::IntrinsicFunction> {
+ match word {
+ "any" => Some(crate::IntrinsicFunction::Any),
+ "all" => Some(crate::IntrinsicFunction::All),
+ "is_nan" => Some(crate::IntrinsicFunction::IsNan),
+ "is_inf" => Some(crate::IntrinsicFunction::IsInf),
+ "is_normal" => Some(crate::IntrinsicFunction::IsNormal),
+ _ => None,
+ }
+}
+pub fn get_derivative(word: &str) -> Option<crate::DerivativeAxis> {
+ match word {
+ "dpdx" => Some(crate::DerivativeAxis::X),
+ "dpdy" => Some(crate::DerivativeAxis::Y),
+ "dwidth" => Some(crate::DerivativeAxis::Width),
+ _ => None,
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/lexer.rs b/third_party/rust/naga/src/front/wgsl/lexer.rs
new file mode 100644
index 0000000000..b991ecf619
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/lexer.rs
@@ -0,0 +1,292 @@
+use super::{conv, Error, Token};
+
+fn _consume_str<'a>(input: &'a str, what: &str) -> Option<&'a str> {
+ if input.starts_with(what) {
+ Some(&input[what.len()..])
+ } else {
+ None
+ }
+}
+
+fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) {
+ let pos = input.find(|c| !what(c)).unwrap_or_else(|| input.len());
+ input.split_at(pos)
+}
+
+fn consume_number(input: &str) -> (&str, &str) {
+ let mut is_first_char = true;
+ let mut right_after_exponent = false;
+
+ let mut what = |c| {
+ if is_first_char {
+ is_first_char = false;
+ c == '-' || c >= '0' && c <= '9' || c == '.'
+ } else if c == 'e' || c == 'E' {
+ right_after_exponent = true;
+ true
+ } else if right_after_exponent {
+ right_after_exponent = false;
+ c >= '0' && c <= '9' || c == '-'
+ } else {
+ c >= '0' && c <= '9' || c == '.'
+ }
+ };
+ let pos = input.find(|c| !what(c)).unwrap_or_else(|| input.len());
+ input.split_at(pos)
+}
+
+fn consume_token(mut input: &str) -> (Token<'_>, &str) {
+ input = input.trim_start();
+ let mut chars = input.chars();
+ let cur = match chars.next() {
+ Some(c) => c,
+ None => return (Token::End, input),
+ };
+ match cur {
+ ':' => {
+ input = chars.as_str();
+ if chars.next() == Some(':') {
+ (Token::DoubleColon, chars.as_str())
+ } else {
+ (Token::Separator(cur), input)
+ }
+ }
+ ';' | ',' => (Token::Separator(cur), chars.as_str()),
+ '.' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('0'..='9') => {
+ let (number, rest) = consume_number(input);
+ (Token::Number(number), rest)
+ }
+ _ => (Token::Separator(cur), og_chars),
+ }
+ }
+ '(' | ')' | '{' | '}' => (Token::Paren(cur), chars.as_str()),
+ '<' | '>' => {
+ input = chars.as_str();
+ let next = chars.next();
+ if next == Some('=') {
+ (Token::LogicalOperation(cur), chars.as_str())
+ } else if next == Some(cur) {
+ (Token::ShiftOperation(cur), chars.as_str())
+ } else {
+ (Token::Paren(cur), input)
+ }
+ }
+ '[' | ']' => {
+ input = chars.as_str();
+ if chars.next() == Some(cur) {
+ (Token::DoubleParen(cur), chars.as_str())
+ } else {
+ (Token::Paren(cur), input)
+ }
+ }
+ '0'..='9' => {
+ let (number, rest) = consume_number(input);
+ (Token::Number(number), rest)
+ }
+ 'a'..='z' | 'A'..='Z' | '_' => {
+ let (word, rest) = consume_any(input, |c| c.is_ascii_alphanumeric() || c == '_');
+ (Token::Word(word), rest)
+ }
+ '"' => {
+ let mut iter = chars.as_str().splitn(2, '"');
+
+ // splitn returns an iterator with at least one element, so unwrapping is fine
+ let quote_content = iter.next().unwrap();
+ if let Some(rest) = iter.next() {
+ (Token::String(quote_content), rest)
+ } else {
+ (Token::UnterminatedString, quote_content)
+ }
+ }
+ '-' => {
+ let og_chars = chars.as_str();
+ match chars.next() {
+ Some('>') => (Token::Arrow, chars.as_str()),
+ Some('0'..='9') | Some('.') => {
+ let (number, rest) = consume_number(input);
+ (Token::Number(number), rest)
+ }
+ _ => (Token::Operation(cur), og_chars),
+ }
+ }
+ '+' | '*' | '/' | '%' | '^' => (Token::Operation(cur), chars.as_str()),
+ '!' => {
+ if chars.next() == Some('=') {
+ (Token::LogicalOperation(cur), chars.as_str())
+ } else {
+ (Token::Operation(cur), input)
+ }
+ }
+ '=' | '&' | '|' => {
+ input = chars.as_str();
+ if chars.next() == Some(cur) {
+ (Token::LogicalOperation(cur), chars.as_str())
+ } else {
+ (Token::Operation(cur), input)
+ }
+ }
+ '#' => match chars.position(|c| c == '\n' || c == '\r') {
+ Some(_) => consume_token(chars.as_str()),
+ None => (Token::End, chars.as_str()),
+ },
+ _ => (Token::Unknown(cur), chars.as_str()),
+ }
+}
+
+#[derive(Clone)]
+pub(super) struct Lexer<'a> {
+ input: &'a str,
+}
+
+impl<'a> Lexer<'a> {
+ pub(super) fn new(input: &'a str) -> Self {
+ Lexer { input }
+ }
+
+ #[must_use]
+ pub(super) fn next(&mut self) -> Token<'a> {
+ let (token, rest) = consume_token(self.input);
+ self.input = rest;
+ token
+ }
+
+ #[must_use]
+ pub(super) fn peek(&mut self) -> Token<'a> {
+ self.clone().next()
+ }
+
+ pub(super) fn expect(&mut self, expected: Token<'_>) -> Result<(), Error<'a>> {
+ let token = self.next();
+ if token == expected {
+ Ok(())
+ } else {
+ Err(Error::Unexpected(token))
+ }
+ }
+
+ pub(super) fn skip(&mut self, what: Token<'_>) -> bool {
+ let (token, rest) = consume_token(self.input);
+ if token == what {
+ self.input = rest;
+ true
+ } else {
+ false
+ }
+ }
+
+ pub(super) fn next_ident(&mut self) -> Result<&'a str, Error<'a>> {
+ match self.next() {
+ Token::Word(word) => Ok(word),
+ other => Err(Error::Unexpected(other)),
+ }
+ }
+
+ fn _next_float_literal(&mut self) -> Result<f32, Error<'a>> {
+ match self.next() {
+ Token::Number(word) => word.parse().map_err(|err| Error::BadFloat(word, err)),
+ other => Err(Error::Unexpected(other)),
+ }
+ }
+
+ pub(super) fn next_uint_literal(&mut self) -> Result<u32, Error<'a>> {
+ match self.next() {
+ Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)),
+ other => Err(Error::Unexpected(other)),
+ }
+ }
+
+ fn _next_sint_literal(&mut self) -> Result<i32, Error<'a>> {
+ match self.next() {
+ Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)),
+ other => Err(Error::Unexpected(other)),
+ }
+ }
+
+ pub(super) fn next_scalar_generic(
+ &mut self,
+ ) -> Result<(crate::ScalarKind, crate::Bytes), Error<'a>> {
+ self.expect(Token::Paren('<'))?;
+ let word = self.next_ident()?;
+ let pair = conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(word))?;
+ self.expect(Token::Paren('>'))?;
+ Ok(pair)
+ }
+
+ pub(super) fn next_format_generic(&mut self) -> Result<crate::StorageFormat, Error<'a>> {
+ self.expect(Token::Paren('<'))?;
+ let format = conv::map_storage_format(self.next_ident()?)?;
+ self.expect(Token::Paren('>'))?;
+ Ok(format)
+ }
+
+ pub(super) fn take_until(&mut self, what: Token<'_>) -> Result<Lexer<'a>, Error<'a>> {
+ let original_input = self.input;
+ let initial_len = self.input.len();
+ let mut used_len = 0;
+ loop {
+ if self.next() == what {
+ break;
+ }
+ used_len = initial_len - self.input.len();
+ }
+
+ Ok(Lexer {
+ input: &original_input[..used_len],
+ })
+ }
+
+ pub(super) fn offset_from(&self, source: &'a str) -> usize {
+ source.len() - self.input.len()
+ }
+}
+
+#[cfg(test)]
+fn sub_test(source: &str, expected_tokens: &[Token]) {
+ let mut lex = Lexer::new(source);
+ for &token in expected_tokens {
+ assert_eq!(lex.next(), token);
+ }
+ assert_eq!(lex.next(), Token::End);
+}
+
+#[test]
+fn test_tokens() {
+ sub_test("id123_OK", &[Token::Word("id123_OK")]);
+ sub_test("92No", &[Token::Number("92"), Token::Word("No")]);
+ sub_test(
+ "æNoø",
+ &[Token::Unknown('æ'), Token::Word("No"), Token::Unknown('ø')],
+ );
+ sub_test("No¾", &[Token::Word("No"), Token::Unknown('¾')]);
+ sub_test("No好", &[Token::Word("No"), Token::Unknown('好')]);
+ sub_test("\"\u{2}ПЀ\u{0}\"", &[Token::String("\u{2}ПЀ\u{0}")]); // https://github.com/gfx-rs/naga/issues/90
+}
+
+#[test]
+fn test_variable_decl() {
+ sub_test(
+ "[[ group(0 )]] var< uniform> texture: texture_multisampled_2d <f32 >;",
+ &[
+ Token::DoubleParen('['),
+ Token::Word("group"),
+ Token::Paren('('),
+ Token::Number("0"),
+ Token::Paren(')'),
+ Token::DoubleParen(']'),
+ Token::Word("var"),
+ Token::Paren('<'),
+ Token::Word("uniform"),
+ Token::Paren('>'),
+ Token::Word("texture"),
+ Token::Separator(':'),
+ Token::Word("texture_multisampled_2d"),
+ Token::Paren('<'),
+ Token::Word("f32"),
+ Token::Paren('>'),
+ Token::Separator(';'),
+ ],
+ )
+}
diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs
new file mode 100644
index 0000000000..3b17719006
--- /dev/null
+++ b/third_party/rust/naga/src/front/wgsl/mod.rs
@@ -0,0 +1,1850 @@
+//! Front end for consuming [WebGPU Shading Language][wgsl].
+//!
+//! [wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
+
+mod conv;
+mod lexer;
+
+use crate::{
+ arena::{Arena, Handle},
+ proc::{ResolveContext, ResolveError, Typifier},
+ FastHashMap,
+};
+
+use self::lexer::Lexer;
+use thiserror::Error;
+
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum Token<'a> {
+ Separator(char),
+ DoubleColon,
+ Paren(char),
+ DoubleParen(char),
+ Number(&'a str),
+ String(&'a str),
+ Word(&'a str),
+ Operation(char),
+ LogicalOperation(char),
+ ShiftOperation(char),
+ Arrow,
+ Unknown(char),
+ UnterminatedString,
+ End,
+}
+
+#[derive(Clone, Debug, Error)]
+pub enum Error<'a> {
+ #[error("unexpected token: {0:?}")]
+ Unexpected(Token<'a>),
+ #[error("unable to parse `{0}` as integer: {1}")]
+ BadInteger(&'a str, std::num::ParseIntError),
+ #[error("unable to parse `{1}` as float: {1}")]
+ BadFloat(&'a str, std::num::ParseFloatError),
+ #[error("bad field accessor `{0}`")]
+ BadAccessor(&'a str),
+ #[error(transparent)]
+ InvalidResolve(ResolveError),
+ #[error("unknown import: `{0}`")]
+ UnknownImport(&'a str),
+ #[error("unknown storage class: `{0}`")]
+ UnknownStorageClass(&'a str),
+ #[error("unknown decoration: `{0}`")]
+ UnknownDecoration(&'a str),
+ #[error("unknown scalar kind: `{0}`")]
+ UnknownScalarKind(&'a str),
+ #[error("unknown builtin: `{0}`")]
+ UnknownBuiltin(&'a str),
+ #[error("unknown shader stage: `{0}`")]
+ UnknownShaderStage(&'a str),
+ #[error("unknown identifier: `{0}`")]
+ UnknownIdent(&'a str),
+ #[error("unknown scalar type: `{0}`")]
+ UnknownScalarType(&'a str),
+ #[error("unknown type: `{0}`")]
+ UnknownType(&'a str),
+ #[error("unknown function: `{0}`")]
+ UnknownFunction(&'a str),
+ #[error("unknown storage format: `{0}`")]
+ UnknownStorageFormat(&'a str),
+ #[error("missing offset for structure member `{0}`")]
+ MissingMemberOffset(&'a str),
+ #[error("array stride must not be 0")]
+ ZeroStride,
+ #[error("not a composite type: {0:?}")]
+ NotCompositeType(Handle<crate::Type>),
+ #[error("function redefinition: `{0}`")]
+ FunctionRedefinition(&'a str),
+ #[error("other error")]
+ Other,
+}
+
+trait StringValueLookup<'a> {
+ type Value;
+ fn lookup(&self, key: &'a str) -> Result<Self::Value, Error<'a>>;
+}
+impl<'a> StringValueLookup<'a> for FastHashMap<&'a str, Handle<crate::Expression>> {
+ type Value = Handle<crate::Expression>;
+ fn lookup(&self, key: &'a str) -> Result<Self::Value, Error<'a>> {
+ self.get(key).cloned().ok_or(Error::UnknownIdent(key))
+ }
+}
+
+struct StatementContext<'input, 'temp, 'out> {
+ lookup_ident: &'temp mut FastHashMap<&'input str, Handle<crate::Expression>>,
+ typifier: &'temp mut Typifier,
+ variables: &'out mut Arena<crate::LocalVariable>,
+ expressions: &'out mut Arena<crate::Expression>,
+ types: &'out mut Arena<crate::Type>,
+ constants: &'out mut Arena<crate::Constant>,
+ global_vars: &'out Arena<crate::GlobalVariable>,
+ arguments: &'out [crate::FunctionArgument],
+}
+
+impl<'a> StatementContext<'a, '_, '_> {
+ fn reborrow(&mut self) -> StatementContext<'a, '_, '_> {
+ StatementContext {
+ lookup_ident: self.lookup_ident,
+ typifier: self.typifier,
+ variables: self.variables,
+ expressions: self.expressions,
+ types: self.types,
+ constants: self.constants,
+ global_vars: self.global_vars,
+ arguments: self.arguments,
+ }
+ }
+
+ fn as_expression(&mut self) -> ExpressionContext<'a, '_, '_> {
+ ExpressionContext {
+ lookup_ident: self.lookup_ident,
+ typifier: self.typifier,
+ expressions: self.expressions,
+ types: self.types,
+ constants: self.constants,
+ global_vars: self.global_vars,
+ local_vars: self.variables,
+ arguments: self.arguments,
+ }
+ }
+}
+
+struct ExpressionContext<'input, 'temp, 'out> {
+ lookup_ident: &'temp FastHashMap<&'input str, Handle<crate::Expression>>,
+ typifier: &'temp mut Typifier,
+ expressions: &'out mut Arena<crate::Expression>,
+ types: &'out mut Arena<crate::Type>,
+ constants: &'out mut Arena<crate::Constant>,
+ global_vars: &'out Arena<crate::GlobalVariable>,
+ local_vars: &'out Arena<crate::LocalVariable>,
+ arguments: &'out [crate::FunctionArgument],
+}
+
+impl<'a> ExpressionContext<'a, '_, '_> {
+ fn reborrow(&mut self) -> ExpressionContext<'a, '_, '_> {
+ ExpressionContext {
+ lookup_ident: self.lookup_ident,
+ typifier: self.typifier,
+ expressions: self.expressions,
+ types: self.types,
+ constants: self.constants,
+ global_vars: self.global_vars,
+ local_vars: self.local_vars,
+ arguments: self.arguments,
+ }
+ }
+
+ fn resolve_type(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<&crate::TypeInner, Error<'a>> {
+ let functions = Arena::new(); //TODO
+ let resolve_ctx = ResolveContext {
+ constants: self.constants,
+ global_vars: self.global_vars,
+ local_vars: self.local_vars,
+ functions: &functions,
+ arguments: self.arguments,
+ };
+ match self
+ .typifier
+ .grow(handle, self.expressions, self.types, &resolve_ctx)
+ {
+ Err(e) => Err(Error::InvalidResolve(e)),
+ Ok(()) => Ok(self.typifier.get(handle, self.types)),
+ }
+ }
+
+ fn parse_binary_op(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ classifier: impl Fn(Token<'a>) -> Option<crate::BinaryOperator>,
+ mut parser: impl FnMut(
+ &mut Lexer<'a>,
+ ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'a>>,
+ ) -> Result<Handle<crate::Expression>, Error<'a>> {
+ let mut left = parser(lexer, self.reborrow())?;
+ while let Some(op) = classifier(lexer.peek()) {
+ let _ = lexer.next();
+ let expression = crate::Expression::Binary {
+ op,
+ left,
+ right: parser(lexer, self.reborrow())?,
+ };
+ left = self.expressions.append(expression);
+ }
+ Ok(left)
+ }
+}
+
+enum Composition {
+ Single(crate::Expression),
+ Multi(crate::VectorSize, Vec<Handle<crate::Expression>>),
+}
+
+impl Composition {
+ fn make<'a>(
+ base: Handle<crate::Expression>,
+ base_size: crate::VectorSize,
+ name: &'a str,
+ expressions: &mut Arena<crate::Expression>,
+ ) -> Result<Self, Error<'a>> {
+ const MEMBERS: [char; 4] = ['x', 'y', 'z', 'w'];
+
+ Ok(if name.len() > 1 {
+ let mut components = Vec::with_capacity(name.len());
+ for ch in name.chars() {
+ let expr = crate::Expression::AccessIndex {
+ base,
+ index: MEMBERS[..base_size as usize]
+ .iter()
+ .position(|&m| m == ch)
+ .ok_or(Error::BadAccessor(name))? as u32,
+ };
+ components.push(expressions.append(expr));
+ }
+
+ let size = match name.len() {
+ 2 => crate::VectorSize::Bi,
+ 3 => crate::VectorSize::Tri,
+ 4 => crate::VectorSize::Quad,
+ _ => return Err(Error::BadAccessor(name)),
+ };
+ Composition::Multi(size, components)
+ } else {
+ let ch = name.chars().next().ok_or(Error::BadAccessor(name))?;
+ let index = MEMBERS[..base_size as usize]
+ .iter()
+ .position(|&m| m == ch)
+ .ok_or(Error::BadAccessor(name))? as u32;
+ Composition::Single(crate::Expression::AccessIndex { base, index })
+ })
+ }
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub enum Scope {
+ Decoration,
+ ImportDecl,
+ VariableDecl,
+ TypeDecl,
+ FunctionDecl,
+ Block,
+ Statement,
+ ConstantExpr,
+ PrimaryExpr,
+ SingularExpr,
+ GeneralExpr,
+}
+
+struct ParsedVariable<'a> {
+ name: &'a str,
+ class: Option<crate::StorageClass>,
+ ty: Handle<crate::Type>,
+ access: crate::StorageAccess,
+ init: Option<Handle<crate::Constant>>,
+}
+
+#[derive(Clone, Debug, Error)]
+#[error("error while parsing WGSL in scopes {scopes:?} at position {pos:?}: {error}")]
+pub struct ParseError<'a> {
+ pub error: Error<'a>,
+ pub scopes: Vec<Scope>,
+ pub pos: (usize, usize),
+}
+
+pub struct Parser {
+ scopes: Vec<Scope>,
+ lookup_type: FastHashMap<String, Handle<crate::Type>>,
+ function_lookup: FastHashMap<String, Handle<crate::Function>>,
+ std_namespace: Option<Vec<String>>,
+}
+
+impl Parser {
+ pub fn new() -> Self {
+ Parser {
+ scopes: Vec::new(),
+ lookup_type: FastHashMap::default(),
+ function_lookup: FastHashMap::default(),
+ std_namespace: None,
+ }
+ }
+
+ fn deconstruct_composite_type(
+ type_arena: &mut Arena<crate::Type>,
+ ty: Handle<crate::Type>,
+ index: usize,
+ ) -> Result<Handle<crate::Type>, Error<'static>> {
+ match type_arena[ty].inner {
+ crate::TypeInner::Vector { kind, width, .. } => {
+ let inner = crate::TypeInner::Scalar { kind, width };
+ Ok(type_arena.fetch_or_append(crate::Type { name: None, inner }))
+ }
+ crate::TypeInner::Matrix { width, .. } => {
+ let inner = crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width,
+ };
+ Ok(type_arena.fetch_or_append(crate::Type { name: None, inner }))
+ }
+ crate::TypeInner::Array { base, .. } => Ok(base),
+ crate::TypeInner::Struct { ref members } => Ok(members[index].ty),
+ _ => Err(Error::NotCompositeType(ty)),
+ }
+ }
+
+ fn get_constant_inner(
+ word: &str,
+ ) -> Result<(crate::ConstantInner, crate::ScalarKind), Error<'_>> {
+ if word.contains('.') {
+ word.parse()
+ .map(|f| (crate::ConstantInner::Float(f), crate::ScalarKind::Float))
+ .map_err(|err| Error::BadFloat(word, err))
+ } else {
+ word.parse()
+ .map(|i| (crate::ConstantInner::Sint(i), crate::ScalarKind::Sint))
+ .map_err(|err| Error::BadInteger(word, err))
+ }
+ }
+
+ fn parse_function_call<'a>(
+ &mut self,
+ lexer: &Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Option<(crate::Expression, Lexer<'a>)>, Error<'a>> {
+ let mut lexer = lexer.clone();
+
+ let external_function = if let Some(std_namespaces) = self.std_namespace.as_deref() {
+ std_namespaces.iter().all(|namespace| {
+ lexer.skip(Token::Word(namespace)) && lexer.skip(Token::DoubleColon)
+ })
+ } else {
+ false
+ };
+
+ let origin = if external_function {
+ let function = lexer.next_ident()?;
+ crate::FunctionOrigin::External(function.to_string())
+ } else if let Ok(function) = lexer.next_ident() {
+ if let Some(&function) = self.function_lookup.get(function) {
+ crate::FunctionOrigin::Local(function)
+ } else {
+ return Ok(None);
+ }
+ } else {
+ return Ok(None);
+ };
+
+ if !lexer.skip(Token::Paren('(')) {
+ return Ok(None);
+ }
+
+ let mut arguments = Vec::new();
+ while !lexer.skip(Token::Paren(')')) {
+ if !arguments.is_empty() {
+ lexer.expect(Token::Separator(','))?;
+ }
+ let arg = self.parse_general_expression(&mut lexer, ctx.reborrow())?;
+ arguments.push(arg);
+ }
+ Ok(Some((crate::Expression::Call { origin, arguments }, lexer)))
+ }
+
+ fn parse_const_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ self_ty: Handle<crate::Type>,
+ type_arena: &mut Arena<crate::Type>,
+ const_arena: &mut Arena<crate::Constant>,
+ ) -> Result<Handle<crate::Constant>, Error<'a>> {
+ self.scopes.push(Scope::ConstantExpr);
+ let inner = match lexer.peek() {
+ Token::Word("true") => {
+ let _ = lexer.next();
+ crate::ConstantInner::Bool(true)
+ }
+ Token::Word("false") => {
+ let _ = lexer.next();
+ crate::ConstantInner::Bool(false)
+ }
+ Token::Number(word) => {
+ let _ = lexer.next();
+ let (inner, _) = Self::get_constant_inner(word)?;
+ inner
+ }
+ _ => {
+ let composite_ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
+ lexer.expect(Token::Paren('('))?;
+ let mut components = Vec::new();
+ while !lexer.skip(Token::Paren(')')) {
+ if !components.is_empty() {
+ lexer.expect(Token::Separator(','))?;
+ }
+ let ty = Self::deconstruct_composite_type(
+ type_arena,
+ composite_ty,
+ components.len(),
+ )?;
+ let component =
+ self.parse_const_expression(lexer, ty, type_arena, const_arena)?;
+ components.push(component);
+ }
+ crate::ConstantInner::Composite(components)
+ }
+ };
+ let handle = const_arena.fetch_or_append(crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ ty: self_ty,
+ });
+ self.scopes.pop();
+ Ok(handle)
+ }
+
+ fn parse_primary_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'a>> {
+ self.scopes.push(Scope::PrimaryExpr);
+ let backup = lexer.clone();
+ let expression = match lexer.next() {
+ Token::Paren('(') => {
+ let expr = self.parse_general_expression(lexer, ctx)?;
+ lexer.expect(Token::Paren(')'))?;
+ self.scopes.pop();
+ return Ok(expr);
+ }
+ Token::Word("true") => {
+ let handle = ctx.constants.fetch_or_append(crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Bool(true),
+ ty: ctx.types.fetch_or_append(crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: 1,
+ },
+ }),
+ });
+ crate::Expression::Constant(handle)
+ }
+ Token::Word("false") => {
+ let handle = ctx.constants.fetch_or_append(crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Bool(false),
+ ty: ctx.types.fetch_or_append(crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: 1,
+ },
+ }),
+ });
+ crate::Expression::Constant(handle)
+ }
+ Token::Number(word) => {
+ let (inner, kind) = Self::get_constant_inner(word)?;
+ let handle = ctx.constants.fetch_or_append(crate::Constant {
+ name: None,
+ specialization: None,
+ inner,
+ ty: ctx.types.fetch_or_append(crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar { kind, width: 4 },
+ }),
+ });
+ crate::Expression::Constant(handle)
+ }
+ Token::Word(word) => {
+ if let Some(handle) = ctx.lookup_ident.get(word) {
+ self.scopes.pop();
+ return Ok(*handle);
+ }
+ if let Some((expr, new_lexer)) =
+ self.parse_function_call(&backup, ctx.reborrow())?
+ {
+ *lexer = new_lexer;
+ expr
+ } else {
+ *lexer = backup;
+ let ty = self.parse_type_decl(lexer, None, ctx.types, ctx.constants)?;
+ lexer.expect(Token::Paren('('))?;
+ let mut components = Vec::new();
+ while !lexer.skip(Token::Paren(')')) {
+ if !components.is_empty() {
+ lexer.expect(Token::Separator(','))?;
+ }
+ let sub_expr = self.parse_general_expression(lexer, ctx.reborrow())?;
+ components.push(sub_expr);
+ }
+ crate::Expression::Compose { ty, components }
+ }
+ }
+ other => return Err(Error::Unexpected(other)),
+ };
+ self.scopes.pop();
+ Ok(ctx.expressions.append(expression))
+ }
+
+ fn parse_postfix<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ mut handle: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Expression>, Error<'a>> {
+ loop {
+ match lexer.peek() {
+ Token::Separator('.') => {
+ let _ = lexer.next();
+ let name = lexer.next_ident()?;
+ let expression = match *ctx.resolve_type(handle)? {
+ crate::TypeInner::Struct { ref members } => {
+ let index = members
+ .iter()
+ .position(|m| m.name.as_deref() == Some(name))
+ .ok_or(Error::BadAccessor(name))?
+ as u32;
+ crate::Expression::AccessIndex {
+ base: handle,
+ index,
+ }
+ }
+ crate::TypeInner::Vector { size, kind, width } => {
+ match Composition::make(handle, size, name, ctx.expressions)? {
+ Composition::Multi(size, components) => {
+ let inner = crate::TypeInner::Vector { size, kind, width };
+ crate::Expression::Compose {
+ ty: ctx
+ .types
+ .fetch_or_append(crate::Type { name: None, inner }),
+ components,
+ }
+ }
+ Composition::Single(expr) => expr,
+ }
+ }
+ crate::TypeInner::Matrix {
+ rows,
+ columns,
+ width,
+ } => match Composition::make(handle, columns, name, ctx.expressions)? {
+ Composition::Multi(columns, components) => {
+ let inner = crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ };
+ crate::Expression::Compose {
+ ty: ctx
+ .types
+ .fetch_or_append(crate::Type { name: None, inner }),
+ components,
+ }
+ }
+ Composition::Single(expr) => expr,
+ },
+ _ => return Err(Error::BadAccessor(name)),
+ };
+ handle = ctx.expressions.append(expression);
+ }
+ Token::Paren('[') => {
+ let _ = lexer.next();
+ let index = self.parse_general_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(']'))?;
+ let expr = crate::Expression::Access {
+ base: handle,
+ index,
+ };
+ handle = ctx.expressions.append(expr);
+ }
+ _ => return Ok(handle),
+ }
+ }
+ }
+
+ fn parse_singular_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut ctx: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'a>> {
+ self.scopes.push(Scope::SingularExpr);
+ let backup = lexer.clone();
+ let expression = match lexer.next() {
+ Token::Operation('-') => Some(crate::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr: self.parse_singular_expression(lexer, ctx.reborrow())?,
+ }),
+ Token::Operation('!') => Some(crate::Expression::Unary {
+ op: crate::UnaryOperator::Not,
+ expr: self.parse_singular_expression(lexer, ctx.reborrow())?,
+ }),
+ Token::Word(word) => {
+ if let Some(fun) = conv::get_intrinsic(word) {
+ lexer.expect(Token::Paren('('))?;
+ let argument = self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(crate::Expression::Intrinsic { fun, argument })
+ } else if let Some(axis) = conv::get_derivative(word) {
+ lexer.expect(Token::Paren('('))?;
+ let expr = self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(crate::Expression::Derivative { axis, expr })
+ } else if let Some((kind, _width)) = conv::get_scalar_type(word) {
+ lexer.expect(Token::Paren('('))?;
+ let expr = self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(crate::Expression::As {
+ expr,
+ kind,
+ convert: true,
+ })
+ } else {
+ match word {
+ "dot" => {
+ lexer.expect(Token::Paren('('))?;
+ let a = self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Separator(','))?;
+ let b = self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(crate::Expression::DotProduct(a, b))
+ }
+ "cross" => {
+ lexer.expect(Token::Paren('('))?;
+ let a = self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Separator(','))?;
+ let b = self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(crate::Expression::CrossProduct(a, b))
+ }
+ "textureSample" => {
+ lexer.expect(Token::Paren('('))?;
+ let image_name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(','))?;
+ let sampler_name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(','))?;
+ let coordinate =
+ self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(crate::Expression::ImageSample {
+ image: ctx.lookup_ident.lookup(image_name)?,
+ sampler: ctx.lookup_ident.lookup(sampler_name)?,
+ coordinate,
+ level: crate::SampleLevel::Auto,
+ depth_ref: None,
+ })
+ }
+ "textureSampleLevel" => {
+ lexer.expect(Token::Paren('('))?;
+ let image_name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(','))?;
+ let sampler_name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(','))?;
+ let coordinate =
+ self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Separator(','))?;
+ let level = self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(crate::Expression::ImageSample {
+ image: ctx.lookup_ident.lookup(image_name)?,
+ sampler: ctx.lookup_ident.lookup(sampler_name)?,
+ coordinate,
+ level: crate::SampleLevel::Exact(level),
+ depth_ref: None,
+ })
+ }
+ "textureSampleBias" => {
+ lexer.expect(Token::Paren('('))?;
+ let image_name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(','))?;
+ let sampler_name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(','))?;
+ let coordinate =
+ self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Separator(','))?;
+ let bias = self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(crate::Expression::ImageSample {
+ image: ctx.lookup_ident.lookup(image_name)?,
+ sampler: ctx.lookup_ident.lookup(sampler_name)?,
+ coordinate,
+ level: crate::SampleLevel::Bias(bias),
+ depth_ref: None,
+ })
+ }
+ "textureSampleCompare" => {
+ lexer.expect(Token::Paren('('))?;
+ let image_name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(','))?;
+ let sampler_name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(','))?;
+ let coordinate =
+ self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Separator(','))?;
+ let reference = self.parse_primary_expression(lexer, ctx.reborrow())?;
+ lexer.expect(Token::Paren(')'))?;
+ Some(crate::Expression::ImageSample {
+ image: ctx.lookup_ident.lookup(image_name)?,
+ sampler: ctx.lookup_ident.lookup(sampler_name)?,
+ coordinate,
+ level: crate::SampleLevel::Zero,
+ depth_ref: Some(reference),
+ })
+ }
+ "textureLoad" => {
+ lexer.expect(Token::Paren('('))?;
+ let image_name = lexer.next_ident()?;
+ let image = ctx.lookup_ident.lookup(image_name)?;
+ lexer.expect(Token::Separator(','))?;
+ let coordinate =
+ self.parse_primary_expression(lexer, ctx.reborrow())?;
+ let is_storage = match *ctx.resolve_type(image)? {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage(_),
+ ..
+ } => true,
+ _ => false,
+ };
+ let index = if is_storage {
+ None
+ } else {
+ lexer.expect(Token::Separator(','))?;
+ let index_name = lexer.next_ident()?;
+ Some(ctx.lookup_ident.lookup(index_name)?)
+ };
+ lexer.expect(Token::Paren(')'))?;
+ Some(crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ index,
+ })
+ }
+ _ => None,
+ }
+ }
+ }
+ _ => None,
+ };
+
+ let handle = match expression {
+ Some(expr) => ctx.expressions.append(expr),
+ None => {
+ *lexer = backup;
+ let handle = self.parse_primary_expression(lexer, ctx.reborrow())?;
+ self.parse_postfix(lexer, ctx, handle)?
+ }
+ };
+ self.scopes.pop();
+ Ok(handle)
+ }
+
+ fn parse_equality_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut context: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'a>> {
+ // equality_expression
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('=') => Some(crate::BinaryOperator::Equal),
+ Token::LogicalOperation('!') => Some(crate::BinaryOperator::NotEqual),
+ _ => None,
+ },
+ // relational_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Paren('<') => Some(crate::BinaryOperator::Less),
+ Token::Paren('>') => Some(crate::BinaryOperator::Greater),
+ Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual),
+ Token::LogicalOperation('>') => Some(crate::BinaryOperator::GreaterEqual),
+ _ => None,
+ },
+ // shift_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::ShiftOperation('<') => {
+ Some(crate::BinaryOperator::ShiftLeft)
+ }
+ Token::ShiftOperation('>') => {
+ Some(crate::BinaryOperator::ShiftRight)
+ }
+ _ => None,
+ },
+ // additive_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('+') => Some(crate::BinaryOperator::Add),
+ Token::Operation('-') => {
+ Some(crate::BinaryOperator::Subtract)
+ }
+ _ => None,
+ },
+ // multiplicative_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('*') => {
+ Some(crate::BinaryOperator::Multiply)
+ }
+ Token::Operation('/') => {
+ Some(crate::BinaryOperator::Divide)
+ }
+ Token::Operation('%') => {
+ Some(crate::BinaryOperator::Modulo)
+ }
+ _ => None,
+ },
+ |lexer, context| {
+ self.parse_singular_expression(lexer, context)
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ }
+
+ fn parse_general_expression<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut context: ExpressionContext<'a, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'a>> {
+ self.scopes.push(Scope::GeneralExpr);
+ // logical_or_expression
+ let handle = context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr),
+ _ => None,
+ },
+ // logical_and_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd),
+ _ => None,
+ },
+ // inclusive_or_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('|') => Some(crate::BinaryOperator::InclusiveOr),
+ _ => None,
+ },
+ // exclusive_or_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('^') => {
+ Some(crate::BinaryOperator::ExclusiveOr)
+ }
+ _ => None,
+ },
+ // and_expression
+ |lexer, mut context| {
+ context.parse_binary_op(
+ lexer,
+ |token| match token {
+ Token::Operation('&') => {
+ Some(crate::BinaryOperator::And)
+ }
+ _ => None,
+ },
+ |lexer, context| {
+ self.parse_equality_expression(lexer, context)
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )
+ },
+ )?;
+ self.scopes.pop();
+ Ok(handle)
+ }
+
+ fn parse_variable_ident_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ type_arena: &mut Arena<crate::Type>,
+ const_arena: &mut Arena<crate::Constant>,
+ ) -> Result<(&'a str, Handle<crate::Type>), Error<'a>> {
+ let name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(':'))?;
+ let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
+ Ok((name, ty))
+ }
+
+ fn parse_variable_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ type_arena: &mut Arena<crate::Type>,
+ const_arena: &mut Arena<crate::Constant>,
+ ) -> Result<ParsedVariable<'a>, Error<'a>> {
+ self.scopes.push(Scope::VariableDecl);
+ let mut class = None;
+ if lexer.skip(Token::Paren('<')) {
+ let class_str = lexer.next_ident()?;
+ class = Some(conv::map_storage_class(class_str)?);
+ lexer.expect(Token::Paren('>'))?;
+ }
+ let name = lexer.next_ident()?;
+ lexer.expect(Token::Separator(':'))?;
+ let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
+ let access = match class {
+ Some(crate::StorageClass::Storage) => crate::StorageAccess::all(),
+ Some(crate::StorageClass::Handle) => {
+ match type_arena[ty].inner {
+ //TODO: RW textures
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage(_),
+ ..
+ } => crate::StorageAccess::LOAD,
+ _ => crate::StorageAccess::empty(),
+ }
+ }
+ _ => crate::StorageAccess::empty(),
+ };
+ let init = if lexer.skip(Token::Operation('=')) {
+ let handle = self.parse_const_expression(lexer, ty, type_arena, const_arena)?;
+ Some(handle)
+ } else {
+ None
+ };
+ lexer.expect(Token::Separator(';'))?;
+ self.scopes.pop();
+ Ok(ParsedVariable {
+ name,
+ class,
+ ty,
+ access,
+ init,
+ })
+ }
+
+ fn parse_struct_body<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ type_arena: &mut Arena<crate::Type>,
+ const_arena: &mut Arena<crate::Constant>,
+ ) -> Result<Vec<crate::StructMember>, Error<'a>> {
+ let mut members = Vec::new();
+ lexer.expect(Token::Paren('{'))?;
+ loop {
+ let mut offset = !0;
+ if lexer.skip(Token::DoubleParen('[')) {
+ self.scopes.push(Scope::Decoration);
+ let mut ready = true;
+ loop {
+ match lexer.next() {
+ Token::DoubleParen(']') => {
+ break;
+ }
+ Token::Separator(',') if !ready => {
+ ready = true;
+ }
+ Token::Word("offset") if ready => {
+ lexer.expect(Token::Paren('('))?;
+ offset = lexer.next_uint_literal()?;
+ lexer.expect(Token::Paren(')'))?;
+ ready = false;
+ }
+ other => return Err(Error::Unexpected(other)),
+ }
+ }
+ self.scopes.pop();
+ }
+ let name = match lexer.next() {
+ Token::Word(word) => word,
+ Token::Paren('}') => return Ok(members),
+ other => return Err(Error::Unexpected(other)),
+ };
+ if offset == !0 {
+ return Err(Error::MissingMemberOffset(name));
+ }
+ lexer.expect(Token::Separator(':'))?;
+ let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
+ lexer.expect(Token::Separator(';'))?;
+ members.push(crate::StructMember {
+ name: Some(name.to_owned()),
+ origin: crate::MemberOrigin::Offset(offset),
+ ty,
+ });
+ }
+ }
+
+ fn parse_type_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ self_name: Option<&'a str>,
+ type_arena: &mut Arena<crate::Type>,
+ const_arena: &mut Arena<crate::Constant>,
+ ) -> Result<Handle<crate::Type>, Error<'a>> {
+ self.scopes.push(Scope::TypeDecl);
+ let decoration_lexer = if lexer.skip(Token::DoubleParen('[')) {
+ Some(lexer.take_until(Token::DoubleParen(']'))?)
+ } else {
+ None
+ };
+
+ let inner = match lexer.next() {
+ Token::Word("f32") => crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ Token::Word("i32") => crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: 4,
+ },
+ Token::Word("u32") => crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ Token::Word("vec2") => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Vector {
+ size: crate::VectorSize::Bi,
+ kind,
+ width,
+ }
+ }
+ Token::Word("vec3") => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ kind,
+ width,
+ }
+ }
+ Token::Word("vec4") => {
+ let (kind, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Vector {
+ size: crate::VectorSize::Quad,
+ kind,
+ width,
+ }
+ }
+ Token::Word("mat2x2") => {
+ let (_, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Bi,
+ width,
+ }
+ }
+ Token::Word("mat2x3") => {
+ let (_, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Tri,
+ width,
+ }
+ }
+ Token::Word("mat2x4") => {
+ let (_, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Bi,
+ rows: crate::VectorSize::Quad,
+ width,
+ }
+ }
+ Token::Word("mat3x2") => {
+ let (_, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Bi,
+ width,
+ }
+ }
+ Token::Word("mat3x3") => {
+ let (_, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ width,
+ }
+ }
+ Token::Word("mat3x4") => {
+ let (_, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Quad,
+ width,
+ }
+ }
+ Token::Word("mat4x2") => {
+ let (_, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Bi,
+ width,
+ }
+ }
+ Token::Word("mat4x3") => {
+ let (_, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width,
+ }
+ }
+ Token::Word("mat4x4") => {
+ let (_, width) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Quad,
+ width,
+ }
+ }
+ Token::Word("ptr") => {
+ lexer.expect(Token::Paren('<'))?;
+ let class = conv::map_storage_class(lexer.next_ident()?)?;
+ lexer.expect(Token::Separator(','))?;
+ let base = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
+ lexer.expect(Token::Paren('>'))?;
+ crate::TypeInner::Pointer { base, class }
+ }
+ Token::Word("array") => {
+ lexer.expect(Token::Paren('<'))?;
+ let base = self.parse_type_decl(lexer, None, type_arena, const_arena)?;
+ let size = match lexer.next() {
+ Token::Separator(',') => {
+ let value = lexer.next_uint_literal()?;
+ lexer.expect(Token::Paren('>'))?;
+ let const_handle = const_arena.fetch_or_append(crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Uint(value as u64),
+ ty: type_arena.fetch_or_append(crate::Type {
+ name: None,
+ inner: crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ }),
+ });
+ crate::ArraySize::Constant(const_handle)
+ }
+ Token::Paren('>') => crate::ArraySize::Dynamic,
+ other => return Err(Error::Unexpected(other)),
+ };
+
+ let mut stride = None;
+ if let Some(mut lexer) = decoration_lexer {
+ self.scopes.push(Scope::Decoration);
+ loop {
+ match lexer.next() {
+ Token::Word("stride") => {
+ use std::num::NonZeroU32;
+ stride = Some(
+ NonZeroU32::new(lexer.next_uint_literal()?)
+ .ok_or(Error::ZeroStride)?,
+ );
+ }
+ Token::End => break,
+ other => return Err(Error::Unexpected(other)),
+ }
+ }
+ self.scopes.pop();
+ }
+
+ crate::TypeInner::Array { base, size, stride }
+ }
+ Token::Word("struct") => {
+ let members = self.parse_struct_body(lexer, type_arena, const_arena)?;
+ crate::TypeInner::Struct { members }
+ }
+ Token::Word("sampler") => crate::TypeInner::Sampler { comparison: false },
+ Token::Word("sampler_comparison") => crate::TypeInner::Sampler { comparison: true },
+ Token::Word("texture_sampled_1d") => {
+ let (kind, _) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ Token::Word("texture_sampled_1d_array") => {
+ let (kind, _) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: true,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ Token::Word("texture_sampled_2d") => {
+ let (kind, _) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ Token::Word("texture_sampled_2d_array") => {
+ let (kind, _) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ Token::Word("texture_sampled_3d") => {
+ let (kind, _) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D3,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ Token::Word("texture_sampled_cube") => {
+ let (kind, _) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ Token::Word("texture_sampled_cube_array") => {
+ let (kind, _) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Sampled { kind, multi: false },
+ }
+ }
+ Token::Word("texture_multisampled_2d") => {
+ let (kind, _) = lexer.next_scalar_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Sampled { kind, multi: true },
+ }
+ }
+ Token::Word("texture_depth_2d") => crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Depth,
+ },
+ Token::Word("texture_depth_2d_array") => crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Depth,
+ },
+ Token::Word("texture_depth_cube") => crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: false,
+ class: crate::ImageClass::Depth,
+ },
+ Token::Word("texture_depth_cube_array") => crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ arrayed: true,
+ class: crate::ImageClass::Depth,
+ },
+ Token::Word("texture_ro_1d") => {
+ let format = lexer.next_format_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: false,
+ class: crate::ImageClass::Storage(format),
+ }
+ }
+ Token::Word("texture_ro_1d_array") => {
+ let format = lexer.next_format_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: true,
+ class: crate::ImageClass::Storage(format),
+ }
+ }
+ Token::Word("texture_ro_2d") => {
+ let format = lexer.next_format_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Storage(format),
+ }
+ }
+ Token::Word("texture_ro_2d_array") => {
+ let format = lexer.next_format_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Storage(format),
+ }
+ }
+ Token::Word("texture_ro_3d") => {
+ let format = lexer.next_format_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D3,
+ arrayed: false,
+ class: crate::ImageClass::Storage(format),
+ }
+ }
+ Token::Word("texture_wo_1d") => {
+ let format = lexer.next_format_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: false,
+ class: crate::ImageClass::Storage(format),
+ }
+ }
+ Token::Word("texture_wo_1d_array") => {
+ let format = lexer.next_format_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D1,
+ arrayed: true,
+ class: crate::ImageClass::Storage(format),
+ }
+ }
+ Token::Word("texture_wo_2d") => {
+ let format = lexer.next_format_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: false,
+ class: crate::ImageClass::Storage(format),
+ }
+ }
+ Token::Word("texture_wo_2d_array") => {
+ let format = lexer.next_format_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D2,
+ arrayed: true,
+ class: crate::ImageClass::Storage(format),
+ }
+ }
+ Token::Word("texture_wo_3d") => {
+ let format = lexer.next_format_generic()?;
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::D3,
+ arrayed: false,
+ class: crate::ImageClass::Storage(format),
+ }
+ }
+ Token::Word(name) => {
+ self.scopes.pop();
+ return match self.lookup_type.get(name) {
+ Some(&handle) => Ok(handle),
+ None => Err(Error::UnknownType(name)),
+ };
+ }
+ other => return Err(Error::Unexpected(other)),
+ };
+ self.scopes.pop();
+
+ let handle = type_arena.fetch_or_append(crate::Type {
+ name: self_name.map(|s| s.to_string()),
+ inner,
+ });
+ Ok(handle)
+ }
+
+ fn parse_statement<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut context: StatementContext<'a, '_, '_>,
+ ) -> Result<crate::Statement, Error<'a>> {
+ let backup = lexer.clone();
+ match lexer.next() {
+ Token::Separator(';') => Ok(crate::Statement::Block(Vec::new())),
+ Token::Word(word) => {
+ self.scopes.push(Scope::Statement);
+ let statement = match word {
+ "var" => {
+ enum Init {
+ Empty,
+ Constant(Handle<crate::Constant>),
+ Variable(Handle<crate::Expression>),
+ }
+ let (name, ty) = self.parse_variable_ident_decl(
+ lexer,
+ context.types,
+ context.constants,
+ )?;
+ let init = if lexer.skip(Token::Operation('=')) {
+ let value =
+ self.parse_general_expression(lexer, context.as_expression())?;
+ if let crate::Expression::Constant(handle) = context.expressions[value]
+ {
+ Init::Constant(handle)
+ } else {
+ Init::Variable(value)
+ }
+ } else {
+ Init::Empty
+ };
+ lexer.expect(Token::Separator(';'))?;
+ let var_id = context.variables.append(crate::LocalVariable {
+ name: Some(name.to_owned()),
+ ty,
+ init: match init {
+ Init::Constant(value) => Some(value),
+ _ => None,
+ },
+ });
+ let expr_id = context
+ .expressions
+ .append(crate::Expression::LocalVariable(var_id));
+ context.lookup_ident.insert(name, expr_id);
+ match init {
+ Init::Variable(value) => crate::Statement::Store {
+ pointer: expr_id,
+ value,
+ },
+ _ => crate::Statement::Block(Vec::new()),
+ }
+ }
+ "return" => {
+ let value = if lexer.peek() != Token::Separator(';') {
+ Some(self.parse_general_expression(lexer, context.as_expression())?)
+ } else {
+ None
+ };
+ lexer.expect(Token::Separator(';'))?;
+ crate::Statement::Return { value }
+ }
+ "if" => {
+ lexer.expect(Token::Paren('('))?;
+ let condition =
+ self.parse_general_expression(lexer, context.as_expression())?;
+ lexer.expect(Token::Paren(')'))?;
+ let accept = self.parse_block(lexer, context.reborrow())?;
+ let reject = if lexer.skip(Token::Word("else")) {
+ self.parse_block(lexer, context.reborrow())?
+ } else {
+ Vec::new()
+ };
+ crate::Statement::If {
+ condition,
+ accept,
+ reject,
+ }
+ }
+ "loop" => {
+ let mut body = Vec::new();
+ let mut continuing = Vec::new();
+ lexer.expect(Token::Paren('{'))?;
+ loop {
+ if lexer.skip(Token::Word("continuing")) {
+ continuing = self.parse_block(lexer, context.reborrow())?;
+ lexer.expect(Token::Paren('}'))?;
+ break;
+ }
+ if lexer.skip(Token::Paren('}')) {
+ break;
+ }
+ let s = self.parse_statement(lexer, context.reborrow())?;
+ body.push(s);
+ }
+ crate::Statement::Loop { body, continuing }
+ }
+ "break" => crate::Statement::Break,
+ "continue" => crate::Statement::Continue,
+ ident => {
+ // assignment
+ if let Some(&var_expr) = context.lookup_ident.get(ident) {
+ let left =
+ self.parse_postfix(lexer, context.as_expression(), var_expr)?;
+ lexer.expect(Token::Operation('='))?;
+ let value =
+ self.parse_general_expression(lexer, context.as_expression())?;
+ lexer.expect(Token::Separator(';'))?;
+ crate::Statement::Store {
+ pointer: left,
+ value,
+ }
+ } else if let Some((expr, new_lexer)) =
+ self.parse_function_call(&backup, context.as_expression())?
+ {
+ *lexer = new_lexer;
+ context.expressions.append(expr);
+ lexer.expect(Token::Separator(';'))?;
+ crate::Statement::Block(Vec::new())
+ } else {
+ return Err(Error::UnknownIdent(ident));
+ }
+ }
+ };
+ self.scopes.pop();
+ Ok(statement)
+ }
+ other => Err(Error::Unexpected(other)),
+ }
+ }
+
+ fn parse_block<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ mut context: StatementContext<'a, '_, '_>,
+ ) -> Result<Vec<crate::Statement>, Error<'a>> {
+ self.scopes.push(Scope::Block);
+ lexer.expect(Token::Paren('{'))?;
+ let mut statements = Vec::new();
+ while !lexer.skip(Token::Paren('}')) {
+ let s = self.parse_statement(lexer, context.reborrow())?;
+ statements.push(s);
+ }
+ self.scopes.pop();
+ Ok(statements)
+ }
+
+ fn parse_function_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ module: &mut crate::Module,
+ lookup_global_expression: &FastHashMap<&'a str, crate::Expression>,
+ ) -> Result<(crate::Function, &'a str), Error<'a>> {
+ self.scopes.push(Scope::FunctionDecl);
+ // read function name
+ let mut lookup_ident = FastHashMap::default();
+ let fun_name = lexer.next_ident()?;
+ // populare initial expressions
+ let mut expressions = Arena::new();
+ for (&name, expression) in lookup_global_expression.iter() {
+ let expr_handle = expressions.append(expression.clone());
+ lookup_ident.insert(name, expr_handle);
+ }
+ // read parameter list
+ let mut arguments = Vec::new();
+ lexer.expect(Token::Paren('('))?;
+ while !lexer.skip(Token::Paren(')')) {
+ if !arguments.is_empty() {
+ lexer.expect(Token::Separator(','))?;
+ }
+ let (param_name, param_type) =
+ self.parse_variable_ident_decl(lexer, &mut module.types, &mut module.constants)?;
+ let param_index = arguments.len() as u32;
+ let expression_token =
+ expressions.append(crate::Expression::FunctionArgument(param_index));
+ lookup_ident.insert(param_name, expression_token);
+ arguments.push(crate::FunctionArgument {
+ name: Some(param_name.to_string()),
+ ty: param_type,
+ });
+ }
+ // read return type
+ lexer.expect(Token::Arrow)?;
+ let return_type = if lexer.skip(Token::Word("void")) {
+ None
+ } else {
+ Some(self.parse_type_decl(lexer, None, &mut module.types, &mut module.constants)?)
+ };
+
+ let mut fun = crate::Function {
+ name: Some(fun_name.to_string()),
+ arguments,
+ return_type,
+ global_usage: Vec::new(),
+ local_variables: Arena::new(),
+ expressions,
+ body: Vec::new(),
+ };
+
+ // read body
+ let mut typifier = Typifier::new();
+ fun.body = self.parse_block(
+ lexer,
+ StatementContext {
+ lookup_ident: &mut lookup_ident,
+ typifier: &mut typifier,
+ variables: &mut fun.local_variables,
+ expressions: &mut fun.expressions,
+ types: &mut module.types,
+ constants: &mut module.constants,
+ global_vars: &module.global_variables,
+ arguments: &fun.arguments,
+ },
+ )?;
+ // done
+ fun.fill_global_use(&module.global_variables);
+ self.scopes.pop();
+
+ Ok((fun, fun_name))
+ }
+
+ fn parse_global_decl<'a>(
+ &mut self,
+ lexer: &mut Lexer<'a>,
+ module: &mut crate::Module,
+ lookup_global_expression: &mut FastHashMap<&'a str, crate::Expression>,
+ ) -> Result<bool, Error<'a>> {
+ // read decorations
+ let mut binding = None;
+ // Perspective is the default qualifier.
+ let mut interpolation = None;
+ let mut stage = None;
+ let mut workgroup_size = [0u32; 3];
+
+ if lexer.skip(Token::DoubleParen('[')) {
+ let (mut bind_index, mut bind_group) = (None, None);
+ self.scopes.push(Scope::Decoration);
+ loop {
+ match lexer.next_ident()? {
+ "location" => {
+ lexer.expect(Token::Paren('('))?;
+ let loc = lexer.next_uint_literal()?;
+ lexer.expect(Token::Paren(')'))?;
+ binding = Some(crate::Binding::Location(loc));
+ }
+ "builtin" => {
+ lexer.expect(Token::Paren('('))?;
+ let builtin = conv::map_built_in(lexer.next_ident()?)?;
+ lexer.expect(Token::Paren(')'))?;
+ binding = Some(crate::Binding::BuiltIn(builtin));
+ }
+ "binding" => {
+ lexer.expect(Token::Paren('('))?;
+ bind_index = Some(lexer.next_uint_literal()?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "group" => {
+ lexer.expect(Token::Paren('('))?;
+ bind_group = Some(lexer.next_uint_literal()?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "interpolate" => {
+ lexer.expect(Token::Paren('('))?;
+ interpolation = Some(conv::map_interpolation(lexer.next_ident()?)?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "stage" => {
+ lexer.expect(Token::Paren('('))?;
+ stage = Some(conv::map_shader_stage(lexer.next_ident()?)?);
+ lexer.expect(Token::Paren(')'))?;
+ }
+ "workgroup_size" => {
+ lexer.expect(Token::Paren('('))?;
+ for (i, size) in workgroup_size.iter_mut().enumerate() {
+ *size = lexer.next_uint_literal()?;
+ match lexer.next() {
+ Token::Paren(')') => break,
+ Token::Separator(',') if i != 2 => (),
+ other => return Err(Error::Unexpected(other)),
+ }
+ }
+ for size in workgroup_size.iter_mut() {
+ if *size == 0 {
+ *size = 1;
+ }
+ }
+ }
+ word => return Err(Error::UnknownDecoration(word)),
+ }
+ match lexer.next() {
+ Token::DoubleParen(']') => {
+ break;
+ }
+ Token::Separator(',') => {}
+ other => return Err(Error::Unexpected(other)),
+ }
+ }
+ if let (Some(group), Some(index)) = (bind_group, bind_index) {
+ binding = Some(crate::Binding::Resource {
+ group,
+ binding: index,
+ });
+ }
+ self.scopes.pop();
+ }
+ // read items
+ match lexer.next() {
+ Token::Separator(';') => {}
+ Token::Word("import") => {
+ self.scopes.push(Scope::ImportDecl);
+ let path = match lexer.next() {
+ Token::String(path) => path,
+ other => return Err(Error::Unexpected(other)),
+ };
+ lexer.expect(Token::Word("as"))?;
+ let mut namespaces = Vec::new();
+ loop {
+ namespaces.push(lexer.next_ident()?.to_owned());
+ if lexer.skip(Token::Separator(';')) {
+ break;
+ }
+ lexer.expect(Token::DoubleColon)?;
+ }
+ match path {
+ "GLSL.std.450" => self.std_namespace = Some(namespaces),
+ _ => return Err(Error::UnknownImport(path)),
+ }
+ self.scopes.pop();
+ }
+ Token::Word("type") => {
+ let name = lexer.next_ident()?;
+ lexer.expect(Token::Operation('='))?;
+ let ty = self.parse_type_decl(
+ lexer,
+ Some(name),
+ &mut module.types,
+ &mut module.constants,
+ )?;
+ self.lookup_type.insert(name.to_owned(), ty);
+ lexer.expect(Token::Separator(';'))?;
+ }
+ Token::Word("const") => {
+ let (name, ty) = self.parse_variable_ident_decl(
+ lexer,
+ &mut module.types,
+ &mut module.constants,
+ )?;
+ lexer.expect(Token::Operation('='))?;
+ let const_handle = self.parse_const_expression(
+ lexer,
+ ty,
+ &mut module.types,
+ &mut module.constants,
+ )?;
+ lexer.expect(Token::Separator(';'))?;
+ lookup_global_expression.insert(name, crate::Expression::Constant(const_handle));
+ }
+ Token::Word("var") => {
+ let pvar =
+ self.parse_variable_decl(lexer, &mut module.types, &mut module.constants)?;
+ let class = match pvar.class {
+ Some(c) => c,
+ None => match binding {
+ Some(crate::Binding::BuiltIn(builtin)) => match builtin {
+ crate::BuiltIn::GlobalInvocationId => crate::StorageClass::Input,
+ crate::BuiltIn::Position => crate::StorageClass::Output,
+ _ => unimplemented!(),
+ },
+ _ => crate::StorageClass::Handle,
+ },
+ };
+ let var_handle = module.global_variables.append(crate::GlobalVariable {
+ name: Some(pvar.name.to_owned()),
+ class,
+ binding: binding.take(),
+ ty: pvar.ty,
+ init: pvar.init,
+ interpolation,
+ storage_access: pvar.access,
+ });
+ lookup_global_expression
+ .insert(pvar.name, crate::Expression::GlobalVariable(var_handle));
+ }
+ Token::Word("fn") => {
+ let (function, name) =
+ self.parse_function_decl(lexer, module, &lookup_global_expression)?;
+ let already_declared = match stage {
+ Some(stage) => module
+ .entry_points
+ .insert(
+ (stage, name.to_string()),
+ crate::EntryPoint {
+ early_depth_test: None,
+ workgroup_size,
+ function,
+ },
+ )
+ .is_some(),
+ None => {
+ let fun_handle = module.functions.append(function);
+ self.function_lookup
+ .insert(name.to_string(), fun_handle)
+ .is_some()
+ }
+ };
+ if already_declared {
+ return Err(Error::FunctionRedefinition(name));
+ }
+ }
+ Token::End => return Ok(false),
+ token => return Err(Error::Unexpected(token)),
+ }
+ match binding {
+ None => Ok(true),
+ // we had the decoration but no var?
+ Some(_) => Err(Error::Other),
+ }
+ }
+
+ pub fn parse<'a>(&mut self, source: &'a str) -> Result<crate::Module, ParseError<'a>> {
+ self.scopes.clear();
+ self.lookup_type.clear();
+ self.std_namespace = None;
+
+ let mut module = crate::Module::generate_empty();
+ let mut lexer = Lexer::new(source);
+ let mut lookup_global_expression = FastHashMap::default();
+ loop {
+ match self.parse_global_decl(&mut lexer, &mut module, &mut lookup_global_expression) {
+ Err(error) => {
+ let pos = lexer.offset_from(source);
+ let (mut rows, mut cols) = (0, 1);
+ for line in source[..pos].lines() {
+ rows += 1;
+ cols = line.len();
+ }
+ return Err(ParseError {
+ error,
+ scopes: std::mem::replace(&mut self.scopes, Vec::new()),
+ pos: (rows, cols),
+ });
+ }
+ Ok(true) => {}
+ Ok(false) => {
+ if !self.scopes.is_empty() {
+ return Err(ParseError {
+ error: Error::Other,
+ scopes: std::mem::replace(&mut self.scopes, Vec::new()),
+ pos: (0, 0),
+ });
+ };
+ return Ok(module);
+ }
+ }
+ }
+ }
+}
+
+pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> {
+ Parser::new().parse(source)
+}
+
+#[test]
+fn parse_types() {
+ assert!(parse_str("const a : i32 = 2;").is_ok());
+ assert!(parse_str("const a : x32 = 2;").is_err());
+}
diff --git a/third_party/rust/naga/src/lib.rs b/third_party/rust/naga/src/lib.rs
new file mode 100644
index 0000000000..936c3b69f5
--- /dev/null
+++ b/third_party/rust/naga/src/lib.rs
@@ -0,0 +1,788 @@
+//! Universal shader translator.
+//!
+//! The central structure of the crate is [`Module`].
+//!
+//! To improve performance and reduce memory usage, most structures are stored
+//! in an [`Arena`], and can be retrieved using the corresponding [`Handle`].
+#![allow(
+ clippy::new_without_default,
+ clippy::unneeded_field_pattern,
+ clippy::match_like_matches_macro
+)]
+// TODO: use `strip_prefix` instead when Rust 1.45 <= MSRV
+#![allow(clippy::manual_strip, clippy::unknown_clippy_lints)]
+#![deny(clippy::panic)]
+
+mod arena;
+pub mod back;
+pub mod front;
+pub mod proc;
+
+pub use crate::arena::{Arena, Handle};
+
+use std::{
+ collections::{HashMap, HashSet},
+ hash::BuildHasherDefault,
+ num::NonZeroU32,
+};
+
+#[cfg(feature = "deserialize")]
+use serde::Deserialize;
+#[cfg(feature = "serialize")]
+use serde::Serialize;
+
+/// Hash map that is faster but not resilient to DoS attacks.
+pub type FastHashMap<K, T> = HashMap<K, T, BuildHasherDefault<fxhash::FxHasher>>;
+/// Hash set that is faster but not resilient to DoS attacks.
+pub type FastHashSet<K> = HashSet<K, BuildHasherDefault<fxhash::FxHasher>>;
+
+/// Metadata for a given module.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct Header {
+ /// Major, minor and patch version.
+ ///
+ /// Currently used only for the SPIR-V back end.
+ pub version: (u8, u8, u8),
+ /// Magic number identifying the tool that generated the shader code.
+ ///
+ /// Can safely be set to 0.
+ pub generator: u32,
+}
+
+/// Early fragment tests. In a standard situation if a driver determines that it is possible to
+/// switch on early depth test it will. Typical situations when early depth test is switched off:
+/// - Calling ```discard``` in a shader.
+/// - Writing to the depth buffer, unless ConservativeDepth is enabled.
+///
+/// SPIR-V: ExecutionMode EarlyFragmentTests
+/// In GLSL: layout(early_fragment_tests) in;
+/// HLSL: Attribute earlydepthstencil
+///
+/// For more, see:
+/// - https://www.khronos.org/opengl/wiki/Early_Fragment_Test#Explicit_specification
+/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-earlydepthstencil
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct EarlyDepthTest {
+ conservative: Option<ConservativeDepth>,
+}
+/// Enables adjusting depth without disabling early Z.
+///
+/// SPIR-V: ExecutionMode DepthGreater/DepthLess/DepthUnchanged
+/// GLSL: layout (depth_<greater/less/unchanged/any>) out float gl_FragDepth;
+/// - ```depth_any``` option behaves as if the layout qualifier was not present.
+/// HLSL: SV_Depth/SV_DepthGreaterEqual/SV_DepthLessEqual
+///
+/// For more, see:
+/// - https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt
+/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-semantics#system-value-semantics
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum ConservativeDepth {
+ /// Shader may rewrite depth only with a value greater than calculated;
+ GreaterEqual,
+
+ /// Shader may rewrite depth smaller than one that would have been written without the modification.
+ LessEqual,
+
+ /// Shader may not rewrite depth value.
+ Unchanged,
+}
+
+/// Stage of the programmable pipeline.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[allow(missing_docs)] // The names are self evident
+pub enum ShaderStage {
+ Vertex,
+ Fragment,
+ Compute,
+}
+
+/// Class of storage for variables.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[allow(missing_docs)] // The names are self evident
+pub enum StorageClass {
+ /// Function locals.
+ Function,
+ /// Pipeline input, per invocation.
+ Input,
+ /// Pipeline output, per invocation, mutable.
+ Output,
+ /// Private data, per invocation, mutable.
+ Private,
+ /// Workgroup shared data, mutable.
+ WorkGroup,
+ /// Uniform buffer data.
+ Uniform,
+ /// Storage buffer data, potentially mutable.
+ Storage,
+ /// Opaque handles, such as samplers and images.
+ Handle,
+ /// Push constants.
+ PushConstant,
+}
+
+/// Built-in inputs and outputs.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum BuiltIn {
+ // vertex
+ BaseInstance,
+ BaseVertex,
+ ClipDistance,
+ InstanceIndex,
+ Position,
+ VertexIndex,
+ // fragment
+ PointSize,
+ FragCoord,
+ FrontFacing,
+ SampleIndex,
+ FragDepth,
+ // compute
+ GlobalInvocationId,
+ LocalInvocationId,
+ LocalInvocationIndex,
+ WorkGroupId,
+}
+
+/// Number of bytes.
+pub type Bytes = u8;
+
+/// Number of components in a vector.
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum VectorSize {
+ /// 2D vector
+ Bi = 2,
+ /// 3D vector
+ Tri = 3,
+ /// 4D vector
+ Quad = 4,
+}
+
+/// Primitive type for a scalar.
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum ScalarKind {
+ /// Signed integer type.
+ Sint,
+ /// Unsigned integer type.
+ Uint,
+ /// Floating point type.
+ Float,
+ /// Boolean type.
+ Bool,
+}
+
+/// Size of an array.
+#[repr(u8)]
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum ArraySize {
+ /// The array size is constant.
+ Constant(Handle<Constant>),
+ /// The array size can change at runtime.
+ Dynamic,
+}
+
+/// Describes where a struct member is placed.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum MemberOrigin {
+ /// Member is local to the shader.
+ Empty,
+ /// Built-in shader variable.
+ BuiltIn(BuiltIn),
+ /// Offset within the struct.
+ Offset(u32),
+}
+
+/// The interpolation qualifier of a binding or struct field.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Interpolation {
+ /// The value will be interpolated in a perspective-correct fashion.
+ /// Also known as "smooth" in glsl.
+ Perspective,
+ /// Indicates that linear, non-perspective, correct
+ /// interpolation must be used.
+ /// Also known as "no_perspective" in glsl.
+ Linear,
+ /// Indicates that no interpolation will be performed.
+ Flat,
+ /// Indicates a tessellation patch.
+ Patch,
+ /// When used with multi-sampling rasterization, allow
+ /// a single interpolation location for an entire pixel.
+ Centroid,
+ /// When used with multi-sampling rasterization, require
+ /// per-sample interpolation.
+ Sample,
+}
+
+/// Member of a user-defined structure.
+// Clone is used only for error reporting and is not intended for end users
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct StructMember {
+ pub name: Option<String>,
+ pub origin: MemberOrigin,
+ pub ty: Handle<Type>,
+}
+
+/// The number of dimensions an image has.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum ImageDimension {
+ /// 1D image
+ D1,
+ /// 2D image
+ D2,
+ /// 3D image
+ D3,
+ /// Cube map
+ Cube,
+}
+
+bitflags::bitflags! {
+ /// Flags describing an image.
+ #[cfg_attr(feature = "serialize", derive(Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(Deserialize))]
+ pub struct StorageAccess: u32 {
+ /// Storage can be used as a source for load ops.
+ const LOAD = 0x1;
+ /// Storage can be used as a target for store ops.
+ const STORE = 0x2;
+ }
+}
+
+// Storage image format.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum StorageFormat {
+ // 8-bit formats
+ R8Unorm,
+ R8Snorm,
+ R8Uint,
+ R8Sint,
+
+ // 16-bit formats
+ R16Uint,
+ R16Sint,
+ R16Float,
+ Rg8Unorm,
+ Rg8Snorm,
+ Rg8Uint,
+ Rg8Sint,
+
+ // 32-bit formats
+ R32Uint,
+ R32Sint,
+ R32Float,
+ Rg16Uint,
+ Rg16Sint,
+ Rg16Float,
+ Rgba8Unorm,
+ Rgba8Snorm,
+ Rgba8Uint,
+ Rgba8Sint,
+
+ // Packed 32-bit formats
+ Rgb10a2Unorm,
+ Rg11b10Float,
+
+ // 64-bit formats
+ Rg32Uint,
+ Rg32Sint,
+ Rg32Float,
+ Rgba16Uint,
+ Rgba16Sint,
+ Rgba16Float,
+
+ // 128-bit formats
+ Rgba32Uint,
+ Rgba32Sint,
+ Rgba32Float,
+}
+
+/// Sub-class of the image type.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum ImageClass {
+ /// Regular sampled image.
+ Sampled {
+ /// Kind of values to sample.
+ kind: ScalarKind,
+ // Multi-sampled.
+ multi: bool,
+ },
+ /// Depth comparison image.
+ Depth,
+ /// Storage image.
+ Storage(StorageFormat),
+}
+
+/// A data type declared in the module.
+#[derive(Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct Type {
+ /// The name of the type, if any.
+ pub name: Option<String>,
+ /// Inner structure that depends on the kind of the type.
+ pub inner: TypeInner,
+}
+
+/// Enum with additional information, depending on the kind of type.
+#[derive(Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum TypeInner {
+ /// Number of integral or floating-point kind.
+ Scalar { kind: ScalarKind, width: Bytes },
+ /// Vector of numbers.
+ Vector {
+ size: VectorSize,
+ kind: ScalarKind,
+ width: Bytes,
+ },
+ /// Matrix of floats.
+ Matrix {
+ columns: VectorSize,
+ rows: VectorSize,
+ width: Bytes,
+ },
+ /// Pointer to a value.
+ Pointer {
+ base: Handle<Type>,
+ class: StorageClass,
+ },
+ /// Homogenous list of elements.
+ Array {
+ base: Handle<Type>,
+ size: ArraySize,
+ stride: Option<NonZeroU32>,
+ },
+ /// User-defined structure.
+ Struct { members: Vec<StructMember> },
+ /// Possibly multidimensional array of texels.
+ Image {
+ dim: ImageDimension,
+ arrayed: bool,
+ class: ImageClass,
+ },
+ /// Can be used to sample values from images.
+ Sampler { comparison: bool },
+}
+
+/// Constant value.
+#[derive(Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct Constant {
+ pub name: Option<String>,
+ pub specialization: Option<u32>,
+ pub inner: ConstantInner,
+ pub ty: Handle<Type>,
+}
+
+/// Additional information, dependendent on the kind of constant.
+#[derive(Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum ConstantInner {
+ Sint(i64),
+ Uint(u64),
+ Float(f64),
+ Bool(bool),
+ Composite(Vec<Handle<Constant>>),
+}
+
+/// Describes how an input/output variable is to be bound.
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Binding {
+ /// Built-in shader variable.
+ BuiltIn(BuiltIn),
+ /// Indexed location.
+ Location(u32),
+ /// Binding within a resource group.
+ Resource { group: u32, binding: u32 },
+}
+
+bitflags::bitflags! {
+ /// Indicates how a global variable is used.
+ #[cfg_attr(feature = "serialize", derive(Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(Deserialize))]
+ pub struct GlobalUse: u8 {
+ /// Data will be read from the variable.
+ const LOAD = 0x1;
+ /// Data will be written to the variable.
+ const STORE = 0x2;
+ }
+}
+
+/// Variable defined at module level.
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct GlobalVariable {
+ /// Name of the variable, if any.
+ pub name: Option<String>,
+ /// How this variable is to be stored.
+ pub class: StorageClass,
+ /// How this variable is to be bound.
+ pub binding: Option<Binding>,
+ /// The type of this variable.
+ pub ty: Handle<Type>,
+ /// Initial value for this variable.
+ pub init: Option<Handle<Constant>>,
+ /// The interpolation qualifier, if any.
+ /// If the this `GlobalVariable` is a vertex output
+ /// or fragment input, `None` corresponds to the
+ /// `smooth`/`perspective` interpolation qualifier.
+ pub interpolation: Option<Interpolation>,
+ /// Access bit for storage types of images and buffers.
+ pub storage_access: StorageAccess,
+}
+
+/// Variable defined at function level.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct LocalVariable {
+ /// Name of the variable, if any.
+ pub name: Option<String>,
+ /// The type of this variable.
+ pub ty: Handle<Type>,
+ /// Initial value for this variable.
+ pub init: Option<Handle<Constant>>,
+}
+
+/// Operation that can be applied on a single value.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum UnaryOperator {
+ Negate,
+ Not,
+}
+
+/// Operation that can be applied on two values.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum BinaryOperator {
+ Add,
+ Subtract,
+ Multiply,
+ Divide,
+ Modulo,
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+ And,
+ ExclusiveOr,
+ InclusiveOr,
+ LogicalAnd,
+ LogicalOr,
+ ShiftLeft,
+ /// Right shift carries the sign of signed integers only.
+ ShiftRight,
+}
+
+/// Built-in shader function.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum IntrinsicFunction {
+ Any,
+ All,
+ IsNan,
+ IsInf,
+ IsFinite,
+ IsNormal,
+}
+
+/// Axis on which to compute a derivative.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum DerivativeAxis {
+ X,
+ Y,
+ Width,
+}
+
+/// Origin of a function to call.
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum FunctionOrigin {
+ Local(Handle<Function>),
+ // External {
+ // namespace: String, // Maybe this should be a handle to a namespace Arena?
+ // function: String,
+ // },
+ External(String),
+}
+
+/// Sampling modifier to control the level of detail.
+#[derive(Clone, Copy, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum SampleLevel {
+ Auto,
+ Zero,
+ Exact(Handle<Expression>),
+ Bias(Handle<Expression>),
+}
+
+/// An expression that can be evaluated to obtain a value.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Expression {
+ /// Array access with a computed index.
+ Access {
+ base: Handle<Expression>,
+ index: Handle<Expression>, //int
+ },
+ /// Array access with a known index.
+ AccessIndex {
+ base: Handle<Expression>,
+ index: u32,
+ },
+ /// Constant value.
+ Constant(Handle<Constant>),
+ /// Composite expression.
+ Compose {
+ ty: Handle<Type>,
+ components: Vec<Handle<Expression>>,
+ },
+ /// Reference a function parameter, by its index.
+ FunctionArgument(u32),
+ /// Reference a global variable.
+ GlobalVariable(Handle<GlobalVariable>),
+ /// Reference a local variable.
+ LocalVariable(Handle<LocalVariable>),
+ /// Load a value indirectly.
+ Load { pointer: Handle<Expression> },
+ /// Sample a point from a sampled or a depth image.
+ ImageSample {
+ image: Handle<Expression>,
+ sampler: Handle<Expression>,
+ coordinate: Handle<Expression>,
+ level: SampleLevel,
+ depth_ref: Option<Handle<Expression>>,
+ },
+ /// Load a texel from an image.
+ ImageLoad {
+ image: Handle<Expression>,
+ coordinate: Handle<Expression>,
+ /// For storage images, this is None.
+ /// For sampled images, this is the Some(Level).
+ /// For multisampled images, this is Some(Sample).
+ index: Option<Handle<Expression>>,
+ },
+ /// Apply an unary operator.
+ Unary {
+ op: UnaryOperator,
+ expr: Handle<Expression>,
+ },
+ /// Apply a binary operator.
+ Binary {
+ op: BinaryOperator,
+ left: Handle<Expression>,
+ right: Handle<Expression>,
+ },
+ /// Select between two values based on a condition.
+ Select {
+ /// Boolean expression
+ condition: Handle<Expression>,
+ accept: Handle<Expression>,
+ reject: Handle<Expression>,
+ },
+ /// Call an intrinsic function.
+ Intrinsic {
+ fun: IntrinsicFunction,
+ argument: Handle<Expression>,
+ },
+ /// Transpose of a matrix.
+ Transpose(Handle<Expression>),
+ /// Dot product between two vectors.
+ DotProduct(Handle<Expression>, Handle<Expression>),
+ /// Cross product between two vectors.
+ CrossProduct(Handle<Expression>, Handle<Expression>),
+ /// Cast a simply type to another kind.
+ As {
+ /// Source expression, which can only be a scalar or a vector.
+ expr: Handle<Expression>,
+ /// Target scalar kind.
+ kind: ScalarKind,
+ /// True = conversion needs to take place; False = bitcast.
+ convert: bool,
+ },
+ /// Compute the derivative on an axis.
+ Derivative {
+ axis: DerivativeAxis,
+ //modifier,
+ expr: Handle<Expression>,
+ },
+ /// Call another function.
+ Call {
+ origin: FunctionOrigin,
+ arguments: Vec<Handle<Expression>>,
+ },
+ /// Get the length of an array.
+ ArrayLength(Handle<Expression>),
+}
+
+/// A code block is just a vector of statements.
+pub type Block = Vec<Statement>;
+
+/// Marker type, used for falling through in a switch statement.
+// Clone is used only for error reporting and is not intended for end users
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct FallThrough;
+
+/// Instructions which make up an executable block.
+// Clone is used only for error reporting and is not intended for end users
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Statement {
+ /// A block containing more statements, to be executed sequentially.
+ Block(Block),
+ /// Conditionally executes one of two blocks, based on the value of the condition.
+ If {
+ condition: Handle<Expression>, //bool
+ accept: Block,
+ reject: Block,
+ },
+ /// Conditionally executes one of multiple blocks, based on the value of the selector.
+ Switch {
+ selector: Handle<Expression>, //int
+ cases: FastHashMap<i32, (Block, Option<FallThrough>)>,
+ default: Block,
+ },
+ /// Executes a block repeatedly.
+ Loop { body: Block, continuing: Block },
+ //TODO: move terminator variations into a separate enum?
+ /// Exits the loop.
+ Break,
+ /// Skips execution to the next iteration of the loop.
+ Continue,
+ /// Returns from the function (possibly with a value).
+ Return { value: Option<Handle<Expression>> },
+ /// Aborts the current shader execution.
+ Kill,
+ /// Stores a value at an address.
+ Store {
+ pointer: Handle<Expression>,
+ value: Handle<Expression>,
+ },
+}
+
+/// A function argument.
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct FunctionArgument {
+ /// Name of the argument, if any.
+ pub name: Option<String>,
+ /// Type of the argument.
+ pub ty: Handle<Type>,
+}
+
+/// A function defined in the module.
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct Function {
+ /// Name of the function, if any.
+ pub name: Option<String>,
+ /// Information about function argument.
+ pub arguments: Vec<FunctionArgument>,
+ /// The return type of this function, if any.
+ pub return_type: Option<Handle<Type>>,
+ /// Vector of global variable usages.
+ ///
+ /// Each item corresponds to a global variable in the module.
+ pub global_usage: Vec<GlobalUse>,
+ /// Local variables defined and used in the function.
+ pub local_variables: Arena<LocalVariable>,
+ /// Expressions used inside this function.
+ pub expressions: Arena<Expression>,
+ /// Block of instructions comprising the body of the function.
+ pub body: Block,
+}
+
+/// Exported function, to be run at a certain stage in the pipeline.
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct EntryPoint {
+ /// Early depth test for fragment stages.
+ pub early_depth_test: Option<EarlyDepthTest>,
+ /// Workgroup size for compute stages
+ pub workgroup_size: [u32; 3],
+ /// The entrance function.
+ pub function: Function,
+}
+
+/// Shader module.
+///
+/// A module is a set of constants, global variables and functions, as well as
+/// the types required to define them.
+///
+/// Some functions are marked as entry points, to be used in a certain shader stage.
+///
+/// To create a new module, use [`Module::from_header`] or [`Module::generate_empty`].
+/// Alternatively, you can load an existing shader using one of the [available front ends][front].
+///
+/// When finished, you can export modules using one of the [available back ends][back].
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct Module {
+ /// Header containing module metadata.
+ pub header: Header,
+ /// Storage for the types defined in this module.
+ pub types: Arena<Type>,
+ /// Storage for the constants defined in this module.
+ pub constants: Arena<Constant>,
+ /// Storage for the global variables defined in this module.
+ pub global_variables: Arena<GlobalVariable>,
+ /// Storage for the functions defined in this module.
+ pub functions: Arena<Function>,
+ /// Exported entry points.
+ pub entry_points: FastHashMap<(ShaderStage, String), EntryPoint>,
+}
diff --git a/third_party/rust/naga/src/proc/call_graph.rs b/third_party/rust/naga/src/proc/call_graph.rs
new file mode 100644
index 0000000000..1c580d5c15
--- /dev/null
+++ b/third_party/rust/naga/src/proc/call_graph.rs
@@ -0,0 +1,74 @@
+use crate::{
+ arena::{Arena, Handle},
+ proc::{Interface, Visitor},
+ Function,
+};
+use petgraph::{
+ graph::{DefaultIx, NodeIndex},
+ Graph,
+};
+
+pub type CallGraph = Graph<Handle<Function>, ()>;
+
+pub struct CallGraphBuilder<'a> {
+ pub functions: &'a Arena<Function>,
+}
+
+impl<'a> CallGraphBuilder<'a> {
+ pub fn process(&self, func: &Function) -> CallGraph {
+ let mut graph = Graph::new();
+ let mut children = Vec::new();
+
+ let visitor = CallGraphVisitor {
+ children: &mut children,
+ };
+
+ let mut interface = Interface {
+ expressions: &func.expressions,
+ local_variables: &func.local_variables,
+ visitor,
+ };
+
+ interface.traverse(&func.body);
+
+ for handle in children {
+ let id = graph.add_node(handle);
+ self.collect(handle, id, &mut graph);
+ }
+
+ graph
+ }
+
+ fn collect(&self, handle: Handle<Function>, id: NodeIndex<DefaultIx>, graph: &mut CallGraph) {
+ let mut children = Vec::new();
+ let visitor = CallGraphVisitor {
+ children: &mut children,
+ };
+ let func = &self.functions[handle];
+
+ let mut interface = Interface {
+ expressions: &func.expressions,
+ local_variables: &func.local_variables,
+ visitor,
+ };
+
+ interface.traverse(&func.body);
+
+ for handle in children {
+ let child_id = graph.add_node(handle);
+ graph.add_edge(id, child_id, ());
+
+ self.collect(handle, child_id, graph);
+ }
+ }
+}
+
+struct CallGraphVisitor<'a> {
+ children: &'a mut Vec<Handle<Function>>,
+}
+
+impl<'a> Visitor for CallGraphVisitor<'a> {
+ fn visit_fun(&mut self, func: Handle<Function>) {
+ self.children.push(func)
+ }
+}
diff --git a/third_party/rust/naga/src/proc/interface.rs b/third_party/rust/naga/src/proc/interface.rs
new file mode 100644
index 0000000000..b512452fe2
--- /dev/null
+++ b/third_party/rust/naga/src/proc/interface.rs
@@ -0,0 +1,290 @@
+use crate::arena::{Arena, Handle};
+
+pub struct Interface<'a, T> {
+ pub expressions: &'a Arena<crate::Expression>,
+ pub local_variables: &'a Arena<crate::LocalVariable>,
+ pub visitor: T,
+}
+
+pub trait Visitor {
+ fn visit_expr(&mut self, _: &crate::Expression) {}
+ fn visit_lhs_expr(&mut self, _: &crate::Expression) {}
+ fn visit_fun(&mut self, _: Handle<crate::Function>) {}
+}
+
+impl<'a, T> Interface<'a, T>
+where
+ T: Visitor,
+{
+ fn traverse_expr(&mut self, handle: Handle<crate::Expression>) {
+ use crate::Expression as E;
+
+ let expr = &self.expressions[handle];
+
+ self.visitor.visit_expr(expr);
+
+ match *expr {
+ E::Access { base, index } => {
+ self.traverse_expr(base);
+ self.traverse_expr(index);
+ }
+ E::AccessIndex { base, .. } => {
+ self.traverse_expr(base);
+ }
+ E::Constant(_) => {}
+ E::Compose { ref components, .. } => {
+ for &comp in components {
+ self.traverse_expr(comp);
+ }
+ }
+ E::FunctionArgument(_) | E::GlobalVariable(_) | E::LocalVariable(_) => {}
+ E::Load { pointer } => {
+ self.traverse_expr(pointer);
+ }
+ E::ImageSample {
+ image,
+ sampler,
+ coordinate,
+ level,
+ depth_ref,
+ } => {
+ self.traverse_expr(image);
+ self.traverse_expr(sampler);
+ self.traverse_expr(coordinate);
+ match level {
+ crate::SampleLevel::Auto | crate::SampleLevel::Zero => (),
+ crate::SampleLevel::Exact(h) | crate::SampleLevel::Bias(h) => {
+ self.traverse_expr(h)
+ }
+ }
+ if let Some(dref) = depth_ref {
+ self.traverse_expr(dref);
+ }
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ index,
+ } => {
+ self.traverse_expr(image);
+ self.traverse_expr(coordinate);
+ if let Some(index) = index {
+ self.traverse_expr(index);
+ }
+ }
+ E::Unary { expr, .. } => {
+ self.traverse_expr(expr);
+ }
+ E::Binary { left, right, .. } => {
+ self.traverse_expr(left);
+ self.traverse_expr(right);
+ }
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ self.traverse_expr(condition);
+ self.traverse_expr(accept);
+ self.traverse_expr(reject);
+ }
+ E::Intrinsic { argument, .. } => {
+ self.traverse_expr(argument);
+ }
+ E::Transpose(matrix) => {
+ self.traverse_expr(matrix);
+ }
+ E::DotProduct(left, right) => {
+ self.traverse_expr(left);
+ self.traverse_expr(right);
+ }
+ E::CrossProduct(left, right) => {
+ self.traverse_expr(left);
+ self.traverse_expr(right);
+ }
+ E::As { expr, .. } => {
+ self.traverse_expr(expr);
+ }
+ E::Derivative { expr, .. } => {
+ self.traverse_expr(expr);
+ }
+ E::Call {
+ ref origin,
+ ref arguments,
+ } => {
+ for &argument in arguments {
+ self.traverse_expr(argument);
+ }
+ if let crate::FunctionOrigin::Local(fun) = *origin {
+ self.visitor.visit_fun(fun);
+ }
+ }
+ E::ArrayLength(expr) => {
+ self.traverse_expr(expr);
+ }
+ }
+ }
+
+ pub fn traverse(&mut self, block: &[crate::Statement]) {
+ for statement in block {
+ use crate::Statement as S;
+ match *statement {
+ S::Break | S::Continue | S::Kill => (),
+ S::Block(ref b) => {
+ self.traverse(b);
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ self.traverse_expr(condition);
+ self.traverse(accept);
+ self.traverse(reject);
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ ref default,
+ } => {
+ self.traverse_expr(selector);
+ for &(ref case, _) in cases.values() {
+ self.traverse(case);
+ }
+ self.traverse(default);
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ } => {
+ self.traverse(body);
+ self.traverse(continuing);
+ }
+ S::Return { value } => {
+ if let Some(expr) = value {
+ self.traverse_expr(expr);
+ }
+ }
+ S::Store { pointer, value } => {
+ let mut left = pointer;
+ loop {
+ match self.expressions[left] {
+ crate::Expression::Access { base, index } => {
+ self.traverse_expr(index);
+ left = base;
+ }
+ crate::Expression::AccessIndex { base, .. } => {
+ left = base;
+ }
+ _ => break,
+ }
+ }
+ self.visitor.visit_lhs_expr(&self.expressions[left]);
+ self.traverse_expr(value);
+ }
+ }
+ }
+ }
+}
+
+struct GlobalUseVisitor<'a>(&'a mut [crate::GlobalUse]);
+
+impl Visitor for GlobalUseVisitor<'_> {
+ fn visit_expr(&mut self, expr: &crate::Expression) {
+ if let crate::Expression::GlobalVariable(handle) = expr {
+ self.0[handle.index()] |= crate::GlobalUse::LOAD;
+ }
+ }
+
+ fn visit_lhs_expr(&mut self, expr: &crate::Expression) {
+ if let crate::Expression::GlobalVariable(handle) = expr {
+ self.0[handle.index()] |= crate::GlobalUse::STORE;
+ }
+ }
+}
+
+impl crate::Function {
+ pub fn fill_global_use(&mut self, globals: &Arena<crate::GlobalVariable>) {
+ self.global_usage.clear();
+ self.global_usage
+ .resize(globals.len(), crate::GlobalUse::empty());
+
+ let mut io = Interface {
+ expressions: &self.expressions,
+ local_variables: &self.local_variables,
+ visitor: GlobalUseVisitor(&mut self.global_usage),
+ };
+ io.traverse(&self.body);
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ Arena, Expression, GlobalUse, GlobalVariable, Handle, Statement, StorageAccess,
+ StorageClass,
+ };
+
+ #[test]
+ fn global_use_scan() {
+ let test_global = GlobalVariable {
+ name: None,
+ class: StorageClass::Uniform,
+ binding: None,
+ ty: Handle::new(std::num::NonZeroU32::new(1).unwrap()),
+ init: None,
+ interpolation: None,
+ storage_access: StorageAccess::empty(),
+ };
+ let mut test_globals = Arena::new();
+
+ let global_1 = test_globals.append(test_global.clone());
+ let global_2 = test_globals.append(test_global.clone());
+ let global_3 = test_globals.append(test_global.clone());
+ let global_4 = test_globals.append(test_global);
+
+ let mut expressions = Arena::new();
+ let global_1_expr = expressions.append(Expression::GlobalVariable(global_1));
+ let global_2_expr = expressions.append(Expression::GlobalVariable(global_2));
+ let global_3_expr = expressions.append(Expression::GlobalVariable(global_3));
+ let global_4_expr = expressions.append(Expression::GlobalVariable(global_4));
+
+ let test_body = vec![
+ Statement::Return {
+ value: Some(global_1_expr),
+ },
+ Statement::Store {
+ pointer: global_2_expr,
+ value: global_1_expr,
+ },
+ Statement::Store {
+ pointer: expressions.append(Expression::Access {
+ base: global_3_expr,
+ index: global_4_expr,
+ }),
+ value: global_1_expr,
+ },
+ ];
+
+ let mut function = crate::Function {
+ name: None,
+ arguments: Vec::new(),
+ return_type: None,
+ local_variables: Arena::new(),
+ expressions,
+ global_usage: Vec::new(),
+ body: test_body,
+ };
+ function.fill_global_use(&test_globals);
+
+ assert_eq!(
+ &function.global_usage,
+ &[
+ GlobalUse::LOAD,
+ GlobalUse::STORE,
+ GlobalUse::STORE,
+ GlobalUse::LOAD,
+ ],
+ )
+ }
+}
diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs
new file mode 100644
index 0000000000..961e55da7c
--- /dev/null
+++ b/third_party/rust/naga/src/proc/mod.rs
@@ -0,0 +1,67 @@
+//! Module processing functionality.
+
+#[cfg(feature = "petgraph")]
+mod call_graph;
+mod interface;
+mod namer;
+mod typifier;
+mod validator;
+
+#[cfg(feature = "petgraph")]
+pub use call_graph::{CallGraph, CallGraphBuilder};
+pub use interface::{Interface, Visitor};
+pub use namer::{EntryPointIndex, NameKey, Namer};
+pub use typifier::{check_constant_type, ResolveContext, ResolveError, Typifier};
+pub use validator::{ValidationError, Validator};
+
+impl From<super::StorageFormat> for super::ScalarKind {
+ fn from(format: super::StorageFormat) -> Self {
+ use super::{ScalarKind as Sk, StorageFormat as Sf};
+ match format {
+ Sf::R8Unorm => Sk::Float,
+ Sf::R8Snorm => Sk::Float,
+ Sf::R8Uint => Sk::Uint,
+ Sf::R8Sint => Sk::Sint,
+ Sf::R16Uint => Sk::Uint,
+ Sf::R16Sint => Sk::Sint,
+ Sf::R16Float => Sk::Float,
+ Sf::Rg8Unorm => Sk::Float,
+ Sf::Rg8Snorm => Sk::Float,
+ Sf::Rg8Uint => Sk::Uint,
+ Sf::Rg8Sint => Sk::Sint,
+ Sf::R32Uint => Sk::Uint,
+ Sf::R32Sint => Sk::Sint,
+ Sf::R32Float => Sk::Float,
+ Sf::Rg16Uint => Sk::Uint,
+ Sf::Rg16Sint => Sk::Sint,
+ Sf::Rg16Float => Sk::Float,
+ Sf::Rgba8Unorm => Sk::Float,
+ Sf::Rgba8Snorm => Sk::Float,
+ Sf::Rgba8Uint => Sk::Uint,
+ Sf::Rgba8Sint => Sk::Sint,
+ Sf::Rgb10a2Unorm => Sk::Float,
+ Sf::Rg11b10Float => Sk::Float,
+ Sf::Rg32Uint => Sk::Uint,
+ Sf::Rg32Sint => Sk::Sint,
+ Sf::Rg32Float => Sk::Float,
+ Sf::Rgba16Uint => Sk::Uint,
+ Sf::Rgba16Sint => Sk::Sint,
+ Sf::Rgba16Float => Sk::Float,
+ Sf::Rgba32Uint => Sk::Uint,
+ Sf::Rgba32Sint => Sk::Sint,
+ Sf::Rgba32Float => Sk::Float,
+ }
+ }
+}
+
+impl crate::TypeInner {
+ pub fn scalar_kind(&self) -> Option<super::ScalarKind> {
+ match *self {
+ super::TypeInner::Scalar { kind, .. } | super::TypeInner::Vector { kind, .. } => {
+ Some(kind)
+ }
+ super::TypeInner::Matrix { .. } => Some(super::ScalarKind::Float),
+ _ => None,
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/proc/namer.rs b/third_party/rust/naga/src/proc/namer.rs
new file mode 100644
index 0000000000..03b508904b
--- /dev/null
+++ b/third_party/rust/naga/src/proc/namer.rs
@@ -0,0 +1,113 @@
+use crate::{arena::Handle, FastHashMap};
+use std::collections::hash_map::Entry;
+
+pub type EntryPointIndex = u16;
+
+#[derive(Debug, Eq, Hash, PartialEq)]
+pub enum NameKey {
+ GlobalVariable(Handle<crate::GlobalVariable>),
+ Type(Handle<crate::Type>),
+ StructMember(Handle<crate::Type>, u32),
+ Function(Handle<crate::Function>),
+ FunctionArgument(Handle<crate::Function>, u32),
+ FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>),
+ EntryPoint(EntryPointIndex),
+ EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>),
+}
+
+/// This processor assigns names to all the things in a module
+/// that may need identifiers in a textual backend.
+pub struct Namer {
+ unique: FastHashMap<String, u32>,
+}
+
+impl Namer {
+ fn sanitize(string: &str) -> String {
+ let mut base = string
+ .chars()
+ .skip_while(|c| c.is_numeric())
+ .filter(|&c| c.is_ascii_alphanumeric() || c == '_')
+ .collect::<String>();
+ // close the name by '_' if the re is a number, so that
+ // we can have our own number!
+ match base.chars().next_back() {
+ Some(c) if !c.is_numeric() => {}
+ _ => base.push('_'),
+ };
+ base
+ }
+
+ fn call(&mut self, label_raw: &str) -> String {
+ let base = Self::sanitize(label_raw);
+ match self.unique.entry(base) {
+ Entry::Occupied(mut e) => {
+ *e.get_mut() += 1;
+ format!("{}{}", e.key(), e.get())
+ }
+ Entry::Vacant(e) => {
+ let name = e.key().to_string();
+ e.insert(0);
+ name
+ }
+ }
+ }
+
+ fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String {
+ self.call(match *label {
+ Some(ref name) => name,
+ None => fallback,
+ })
+ }
+
+ pub fn process(
+ module: &crate::Module,
+ reserved: &[&str],
+ output: &mut FastHashMap<NameKey, String>,
+ ) {
+ let mut this = Namer {
+ unique: reserved
+ .iter()
+ .map(|string| (string.to_string(), 0))
+ .collect(),
+ };
+
+ for (handle, var) in module.global_variables.iter() {
+ let name = this.call_or(&var.name, "global");
+ output.insert(NameKey::GlobalVariable(handle), name);
+ }
+
+ for (ty_handle, ty) in module.types.iter() {
+ let ty_name = this.call_or(&ty.name, "type");
+ output.insert(NameKey::Type(ty_handle), ty_name);
+
+ if let crate::TypeInner::Struct { ref members } = ty.inner {
+ for (index, member) in members.iter().enumerate() {
+ let name = this.call_or(&member.name, "member");
+ output.insert(NameKey::StructMember(ty_handle, index as u32), name);
+ }
+ }
+ }
+
+ for (fun_handle, fun) in module.functions.iter() {
+ let fun_name = this.call_or(&fun.name, "function");
+ output.insert(NameKey::Function(fun_handle), fun_name);
+ for (index, arg) in fun.arguments.iter().enumerate() {
+ let name = this.call_or(&arg.name, "param");
+ output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name);
+ }
+ for (handle, var) in fun.local_variables.iter() {
+ let name = this.call_or(&var.name, "local");
+ output.insert(NameKey::FunctionLocal(fun_handle, handle), name);
+ }
+ }
+
+ for (ep_index, (&(_, ref base_name), ep)) in module.entry_points.iter().enumerate() {
+ let ep_name = this.call(base_name);
+ output.insert(NameKey::EntryPoint(ep_index as _), ep_name);
+ for (handle, var) in ep.function.local_variables.iter() {
+ let name = this.call_or(&var.name, "local");
+ output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name);
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs
new file mode 100644
index 0000000000..b09893c179
--- /dev/null
+++ b/third_party/rust/naga/src/proc/typifier.rs
@@ -0,0 +1,424 @@
+use crate::arena::{Arena, Handle};
+
+use thiserror::Error;
+
+#[derive(Debug, PartialEq)]
+enum Resolution {
+ Handle(Handle<crate::Type>),
+ Value(crate::TypeInner),
+}
+
+// Clone is only implemented for numeric variants of `TypeInner`.
+impl Clone for Resolution {
+ fn clone(&self) -> Self {
+ match *self {
+ Resolution::Handle(handle) => Resolution::Handle(handle),
+ Resolution::Value(ref v) => Resolution::Value(match *v {
+ crate::TypeInner::Scalar { kind, width } => {
+ crate::TypeInner::Scalar { kind, width }
+ }
+ crate::TypeInner::Vector { size, kind, width } => {
+ crate::TypeInner::Vector { size, kind, width }
+ }
+ crate::TypeInner::Matrix {
+ rows,
+ columns,
+ width,
+ } => crate::TypeInner::Matrix {
+ rows,
+ columns,
+ width,
+ },
+ #[allow(clippy::panic)]
+ _ => panic!("Unepxected clone type: {:?}", v),
+ }),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub struct Typifier {
+ resolutions: Vec<Resolution>,
+}
+
+#[derive(Clone, Debug, Error, PartialEq)]
+pub enum ResolveError {
+ #[error("Invalid index into array")]
+ InvalidAccessIndex,
+ #[error("Function {name} not defined")]
+ FunctionNotDefined { name: String },
+ #[error("Function without return type")]
+ FunctionReturnsVoid,
+ #[error("Type is not found in the given immutable arena")]
+ TypeNotFound,
+ #[error("Incompatible operand: {op} {operand}")]
+ IncompatibleOperand { op: String, operand: String },
+ #[error("Incompatible operands: {left} {op} {right}")]
+ IncompatibleOperands {
+ op: String,
+ left: String,
+ right: String,
+ },
+}
+
+pub struct ResolveContext<'a> {
+ pub constants: &'a Arena<crate::Constant>,
+ pub global_vars: &'a Arena<crate::GlobalVariable>,
+ pub local_vars: &'a Arena<crate::LocalVariable>,
+ pub functions: &'a Arena<crate::Function>,
+ pub arguments: &'a [crate::FunctionArgument],
+}
+
+impl Typifier {
+ pub fn new() -> Self {
+ Typifier {
+ resolutions: Vec::new(),
+ }
+ }
+
+ pub fn clear(&mut self) {
+ self.resolutions.clear()
+ }
+
+ pub fn get<'a>(
+ &'a self,
+ expr_handle: Handle<crate::Expression>,
+ types: &'a Arena<crate::Type>,
+ ) -> &'a crate::TypeInner {
+ match self.resolutions[expr_handle.index()] {
+ Resolution::Handle(ty_handle) => &types[ty_handle].inner,
+ Resolution::Value(ref inner) => inner,
+ }
+ }
+
+ pub fn get_handle(
+ &self,
+ expr_handle: Handle<crate::Expression>,
+ ) -> Option<Handle<crate::Type>> {
+ match self.resolutions[expr_handle.index()] {
+ Resolution::Handle(ty_handle) => Some(ty_handle),
+ Resolution::Value(_) => None,
+ }
+ }
+
+ fn resolve_impl(
+ &self,
+ expr: &crate::Expression,
+ types: &Arena<crate::Type>,
+ ctx: &ResolveContext,
+ ) -> Result<Resolution, ResolveError> {
+ Ok(match *expr {
+ crate::Expression::Access { base, .. } => match *self.get(base, types) {
+ crate::TypeInner::Array { base, .. } => Resolution::Handle(base),
+ crate::TypeInner::Vector {
+ size: _,
+ kind,
+ width,
+ } => Resolution::Value(crate::TypeInner::Scalar { kind, width }),
+ crate::TypeInner::Matrix {
+ rows: size,
+ columns: _,
+ width,
+ } => Resolution::Value(crate::TypeInner::Vector {
+ size,
+ kind: crate::ScalarKind::Float,
+ width,
+ }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: "access".to_string(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ crate::Expression::AccessIndex { base, index } => match *self.get(base, types) {
+ crate::TypeInner::Vector { size, kind, width } => {
+ if index >= size as u32 {
+ return Err(ResolveError::InvalidAccessIndex);
+ }
+ Resolution::Value(crate::TypeInner::Scalar { kind, width })
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ if index >= columns as u32 {
+ return Err(ResolveError::InvalidAccessIndex);
+ }
+ Resolution::Value(crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ })
+ }
+ crate::TypeInner::Array { base, .. } => Resolution::Handle(base),
+ crate::TypeInner::Struct { ref members } => {
+ let member = members
+ .get(index as usize)
+ .ok_or(ResolveError::InvalidAccessIndex)?;
+ Resolution::Handle(member.ty)
+ }
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: "access index".to_string(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ crate::Expression::Constant(h) => Resolution::Handle(ctx.constants[h].ty),
+ crate::Expression::Compose { ty, .. } => Resolution::Handle(ty),
+ crate::Expression::FunctionArgument(index) => {
+ Resolution::Handle(ctx.arguments[index as usize].ty)
+ }
+ crate::Expression::GlobalVariable(h) => Resolution::Handle(ctx.global_vars[h].ty),
+ crate::Expression::LocalVariable(h) => Resolution::Handle(ctx.local_vars[h].ty),
+ crate::Expression::Load { .. } => unimplemented!(),
+ crate::Expression::ImageSample { image, .. }
+ | crate::Expression::ImageLoad { image, .. } => match *self.get(image, types) {
+ crate::TypeInner::Image { class, .. } => Resolution::Value(match class {
+ crate::ImageClass::Depth => crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ crate::ImageClass::Sampled { kind, multi: _ } => crate::TypeInner::Vector {
+ kind,
+ width: 4,
+ size: crate::VectorSize::Quad,
+ },
+ crate::ImageClass::Storage(format) => crate::TypeInner::Vector {
+ kind: format.into(),
+ width: 4,
+ size: crate::VectorSize::Quad,
+ },
+ }),
+ _ => unreachable!(),
+ },
+ crate::Expression::Unary { expr, .. } => self.resolutions[expr.index()].clone(),
+ crate::Expression::Binary { op, left, right } => match op {
+ crate::BinaryOperator::Add
+ | crate::BinaryOperator::Subtract
+ | crate::BinaryOperator::Divide
+ | crate::BinaryOperator::Modulo => self.resolutions[left.index()].clone(),
+ crate::BinaryOperator::Multiply => {
+ let ty_left = self.get(left, types);
+ let ty_right = self.get(right, types);
+ if ty_left == ty_right {
+ self.resolutions[left.index()].clone()
+ } else if let crate::TypeInner::Scalar { .. } = *ty_right {
+ self.resolutions[left.index()].clone()
+ } else {
+ match *ty_left {
+ crate::TypeInner::Scalar { .. } => {
+ self.resolutions[right.index()].clone()
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows: _,
+ width,
+ } => Resolution::Value(crate::TypeInner::Vector {
+ size: columns,
+ kind: crate::ScalarKind::Float,
+ width,
+ }),
+ _ => {
+ return Err(ResolveError::IncompatibleOperands {
+ op: "x".to_string(),
+ left: format!("{:?}", ty_left),
+ right: format!("{:?}", ty_right),
+ })
+ }
+ }
+ }
+ }
+ crate::BinaryOperator::Equal
+ | crate::BinaryOperator::NotEqual
+ | crate::BinaryOperator::Less
+ | crate::BinaryOperator::LessEqual
+ | crate::BinaryOperator::Greater
+ | crate::BinaryOperator::GreaterEqual
+ | crate::BinaryOperator::LogicalAnd
+ | crate::BinaryOperator::LogicalOr => self.resolutions[left.index()].clone(),
+ crate::BinaryOperator::And
+ | crate::BinaryOperator::ExclusiveOr
+ | crate::BinaryOperator::InclusiveOr
+ | crate::BinaryOperator::ShiftLeft
+ | crate::BinaryOperator::ShiftRight => self.resolutions[left.index()].clone(),
+ },
+ crate::Expression::Select { accept, .. } => self.resolutions[accept.index()].clone(),
+ crate::Expression::Intrinsic { .. } => unimplemented!(),
+ crate::Expression::Transpose(expr) => match *self.get(expr, types) {
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ width,
+ } => Resolution::Value(crate::TypeInner::Matrix {
+ columns: rows,
+ rows: columns,
+ width,
+ }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: "transpose".to_string(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ crate::Expression::DotProduct(left_expr, _) => match *self.get(left_expr, types) {
+ crate::TypeInner::Vector {
+ kind,
+ size: _,
+ width,
+ } => Resolution::Value(crate::TypeInner::Scalar { kind, width }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: "dot product".to_string(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ crate::Expression::CrossProduct(_, _) => unimplemented!(),
+ crate::Expression::As {
+ expr,
+ kind,
+ convert: _,
+ } => match *self.get(expr, types) {
+ crate::TypeInner::Scalar { kind: _, width } => {
+ Resolution::Value(crate::TypeInner::Scalar { kind, width })
+ }
+ crate::TypeInner::Vector {
+ kind: _,
+ size,
+ width,
+ } => Resolution::Value(crate::TypeInner::Vector { kind, size, width }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: "as".to_string(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ crate::Expression::Derivative { .. } => unimplemented!(),
+ crate::Expression::Call {
+ origin: crate::FunctionOrigin::External(ref name),
+ ref arguments,
+ } => match name.as_str() {
+ "distance" | "length" => match *self.get(arguments[0], types) {
+ crate::TypeInner::Vector { kind, width, .. }
+ | crate::TypeInner::Scalar { kind, width } => {
+ Resolution::Value(crate::TypeInner::Scalar { kind, width })
+ }
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: name.clone(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ "dot" => match *self.get(arguments[0], types) {
+ crate::TypeInner::Vector { kind, width, .. } => {
+ Resolution::Value(crate::TypeInner::Scalar { kind, width })
+ }
+ ref other => {
+ return Err(ResolveError::IncompatibleOperand {
+ op: name.clone(),
+ operand: format!("{:?}", other),
+ })
+ }
+ },
+ //Note: `cross` is here too, we still need to figure out what to do with it
+ "abs" | "atan2" | "cos" | "sin" | "floor" | "inverse" | "normalize" | "min"
+ | "max" | "reflect" | "pow" | "clamp" | "fclamp" | "mix" | "step"
+ | "smoothstep" | "cross" => self.resolutions[arguments[0].index()].clone(),
+ _ => return Err(ResolveError::FunctionNotDefined { name: name.clone() }),
+ },
+ crate::Expression::Call {
+ origin: crate::FunctionOrigin::Local(handle),
+ arguments: _,
+ } => {
+ let ty = ctx.functions[handle]
+ .return_type
+ .ok_or(ResolveError::FunctionReturnsVoid)?;
+ Resolution::Handle(ty)
+ }
+ crate::Expression::ArrayLength(_) => Resolution::Value(crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }),
+ })
+ }
+
+ pub fn grow(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ types: &mut Arena<crate::Type>,
+ ctx: &ResolveContext,
+ ) -> Result<(), ResolveError> {
+ if self.resolutions.len() <= expr_handle.index() {
+ for (eh, expr) in expressions.iter().skip(self.resolutions.len()) {
+ let resolution = self.resolve_impl(expr, types, ctx)?;
+ log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution);
+
+ let ty_handle = match resolution {
+ Resolution::Handle(h) => h,
+ Resolution::Value(inner) => types
+ .fetch_if_or_append(crate::Type { name: None, inner }, |a, b| {
+ a.inner == b.inner
+ }),
+ };
+ self.resolutions.push(Resolution::Handle(ty_handle));
+ }
+ }
+ Ok(())
+ }
+
+ pub fn resolve_all(
+ &mut self,
+ expressions: &Arena<crate::Expression>,
+ types: &Arena<crate::Type>,
+ ctx: &ResolveContext,
+ ) -> Result<(), ResolveError> {
+ self.clear();
+ for (_, expr) in expressions.iter() {
+ let resolution = self.resolve_impl(expr, types, ctx)?;
+ self.resolutions.push(resolution);
+ }
+ Ok(())
+ }
+}
+
+pub fn check_constant_type(inner: &crate::ConstantInner, type_inner: &crate::TypeInner) -> bool {
+ match (inner, type_inner) {
+ (
+ crate::ConstantInner::Sint(_),
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ },
+ ) => true,
+ (
+ crate::ConstantInner::Uint(_),
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ },
+ ) => true,
+ (
+ crate::ConstantInner::Float(_),
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Float,
+ width: _,
+ },
+ ) => true,
+ (
+ crate::ConstantInner::Bool(_),
+ crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: _,
+ },
+ ) => true,
+ (crate::ConstantInner::Composite(_inner), _) => true, // TODO recursively check composite types
+ (_, _) => false,
+ }
+}
diff --git a/third_party/rust/naga/src/proc/validator.rs b/third_party/rust/naga/src/proc/validator.rs
new file mode 100644
index 0000000000..d9d3eac659
--- /dev/null
+++ b/third_party/rust/naga/src/proc/validator.rs
@@ -0,0 +1,489 @@
+use super::typifier::{ResolveContext, ResolveError, Typifier};
+use crate::arena::{Arena, Handle};
+
+const MAX_BIND_GROUPS: u32 = 8;
+const MAX_LOCATIONS: u32 = 64; // using u64 mask
+const MAX_BIND_INDICES: u32 = 64; // using u64 mask
+const MAX_WORKGROUP_SIZE: u32 = 0x4000;
+
+#[derive(Debug)]
+pub struct Validator {
+ //Note: this is a bit tricky: some of the front-ends as well as backends
+ // already have to use the typifier, so the work here is redundant in a way.
+ typifier: Typifier,
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum GlobalVariableError {
+ #[error("Usage isn't compatible with the storage class")]
+ InvalidUsage,
+ #[error("Type isn't compatible with the storage class")]
+ InvalidType,
+ #[error("Interpolation is not valid")]
+ InvalidInterpolation,
+ #[error("Storage access {seen:?} exceed the allowed {allowed:?}")]
+ InvalidStorageAccess {
+ allowed: crate::StorageAccess,
+ seen: crate::StorageAccess,
+ },
+ #[error("Binding decoration is missing or not applicable")]
+ InvalidBinding,
+ #[error("Binding is out of range")]
+ OutOfRangeBinding,
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum LocalVariableError {
+ #[error("Initializer is not a constant expression")]
+ InitializerConst,
+ #[error("Initializer doesn't match the variable type")]
+ InitializerType,
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum FunctionError {
+ #[error(transparent)]
+ Resolve(#[from] ResolveError),
+ #[error("There are instructions after `return`/`break`/`continue`")]
+ InvalidControlFlowExitTail,
+ #[error("Local variable {handle:?} '{name}' is invalid: {error:?}")]
+ LocalVariable {
+ handle: Handle<crate::LocalVariable>,
+ name: String,
+ error: LocalVariableError,
+ },
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum EntryPointError {
+ #[error("Early depth test is not applicable")]
+ UnexpectedEarlyDepthTest,
+ #[error("Workgroup size is not applicable")]
+ UnexpectedWorkgroupSize,
+ #[error("Workgroup size is out of range")]
+ OutOfRangeWorkgroupSize,
+ #[error("Global variable {0:?} is used incorrectly as {1:?}")]
+ InvalidGlobalUsage(Handle<crate::GlobalVariable>, crate::GlobalUse),
+ #[error("Bindings for {0:?} conflict with other global variables")]
+ BindingCollision(Handle<crate::GlobalVariable>),
+ #[error("Built-in {0:?} is not applicable to this entry point")]
+ InvalidBuiltIn(crate::BuiltIn),
+ #[error("Interpolation of an integer has to be flat")]
+ InvalidIntegerInterpolation,
+ #[error(transparent)]
+ Function(#[from] FunctionError),
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+pub enum ValidationError {
+ #[error("The type {0:?} width {1} is not supported")]
+ InvalidTypeWidth(crate::ScalarKind, crate::Bytes),
+ #[error("The type handle {0:?} can not be resolved")]
+ UnresolvedType(Handle<crate::Type>),
+ #[error("The constant {0:?} can not be used for an array size")]
+ InvalidArraySizeConstant(Handle<crate::Constant>),
+ #[error("Global variable {handle:?} '{name}' is invalid: {error:?}")]
+ GlobalVariable {
+ handle: Handle<crate::GlobalVariable>,
+ name: String,
+ error: GlobalVariableError,
+ },
+ #[error("Function {0:?} is invalid: {1:?}")]
+ Function(Handle<crate::Function>, FunctionError),
+ #[error("Entry point {name} at {stage:?} is invalid: {error:?}")]
+ EntryPoint {
+ stage: crate::ShaderStage,
+ name: String,
+ error: EntryPointError,
+ },
+ #[error("Module is corrupted")]
+ Corrupted,
+}
+
+impl crate::GlobalVariable {
+ fn forbid_interpolation(&self) -> Result<(), GlobalVariableError> {
+ match self.interpolation {
+ Some(_) => Err(GlobalVariableError::InvalidInterpolation),
+ None => Ok(()),
+ }
+ }
+
+ fn check_resource(&self) -> Result<(), GlobalVariableError> {
+ match self.binding {
+ Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point
+ Some(crate::Binding::Resource { group, binding }) => {
+ if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES {
+ return Err(GlobalVariableError::OutOfRangeBinding);
+ }
+ }
+ Some(crate::Binding::Location(_)) | None => {
+ return Err(GlobalVariableError::InvalidBinding)
+ }
+ }
+ self.forbid_interpolation()
+ }
+}
+
+fn storage_usage(access: crate::StorageAccess) -> crate::GlobalUse {
+ let mut storage_usage = crate::GlobalUse::empty();
+ if access.contains(crate::StorageAccess::LOAD) {
+ storage_usage |= crate::GlobalUse::LOAD;
+ }
+ if access.contains(crate::StorageAccess::STORE) {
+ storage_usage |= crate::GlobalUse::STORE;
+ }
+ storage_usage
+}
+
+impl Validator {
+ /// Construct a new validator instance.
+ pub fn new() -> Self {
+ Validator {
+ typifier: Typifier::new(),
+ }
+ }
+
+ fn validate_global_var(
+ &self,
+ var: &crate::GlobalVariable,
+ types: &Arena<crate::Type>,
+ ) -> Result<(), GlobalVariableError> {
+ log::debug!("var {:?}", var);
+ let allowed_storage_access = match var.class {
+ crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage),
+ crate::StorageClass::Input | crate::StorageClass::Output => {
+ match var.binding {
+ Some(crate::Binding::BuiltIn(_)) => {
+ // validated per entry point
+ var.forbid_interpolation()?
+ }
+ Some(crate::Binding::Location(loc)) => {
+ if loc > MAX_LOCATIONS {
+ return Err(GlobalVariableError::OutOfRangeBinding);
+ }
+ match types[var.ty].inner {
+ crate::TypeInner::Scalar { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. } => {}
+ _ => return Err(GlobalVariableError::InvalidType),
+ }
+ }
+ Some(crate::Binding::Resource { .. }) => {
+ return Err(GlobalVariableError::InvalidBinding)
+ }
+ None => {
+ match types[var.ty].inner {
+ //TODO: check the member types
+ crate::TypeInner::Struct { members: _ } => {
+ var.forbid_interpolation()?
+ }
+ _ => return Err(GlobalVariableError::InvalidType),
+ }
+ }
+ }
+ crate::StorageAccess::empty()
+ }
+ crate::StorageClass::Storage => {
+ var.check_resource()?;
+ crate::StorageAccess::all()
+ }
+ crate::StorageClass::Uniform => {
+ var.check_resource()?;
+ crate::StorageAccess::empty()
+ }
+ crate::StorageClass::Handle => {
+ var.check_resource()?;
+ match types[var.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage(_),
+ ..
+ } => crate::StorageAccess::all(),
+ _ => crate::StorageAccess::empty(),
+ }
+ }
+ crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
+ if var.binding.is_some() {
+ return Err(GlobalVariableError::InvalidBinding);
+ }
+ var.forbid_interpolation()?;
+ crate::StorageAccess::empty()
+ }
+ crate::StorageClass::PushConstant => {
+ //TODO
+ return Err(GlobalVariableError::InvalidStorageAccess {
+ allowed: crate::StorageAccess::empty(),
+ seen: crate::StorageAccess::empty(),
+ });
+ }
+ };
+
+ if !allowed_storage_access.contains(var.storage_access) {
+ return Err(GlobalVariableError::InvalidStorageAccess {
+ allowed: allowed_storage_access,
+ seen: var.storage_access,
+ });
+ }
+
+ Ok(())
+ }
+
+ fn validate_local_var(
+ &self,
+ var: &crate::LocalVariable,
+ _fun: &crate::Function,
+ _types: &Arena<crate::Type>,
+ ) -> Result<(), LocalVariableError> {
+ log::debug!("var {:?}", var);
+ if let Some(_expr_handle) = var.init {
+ if false {
+ return Err(LocalVariableError::InitializerConst);
+ }
+ }
+ Ok(())
+ }
+
+ fn validate_function(
+ &mut self,
+ fun: &crate::Function,
+ module: &crate::Module,
+ ) -> Result<(), FunctionError> {
+ let resolve_ctx = ResolveContext {
+ constants: &module.constants,
+ global_vars: &module.global_variables,
+ local_vars: &fun.local_variables,
+ functions: &module.functions,
+ arguments: &fun.arguments,
+ };
+ self.typifier
+ .resolve_all(&fun.expressions, &module.types, &resolve_ctx)?;
+
+ for (var_handle, var) in fun.local_variables.iter() {
+ self.validate_local_var(var, fun, &module.types)
+ .map_err(|error| FunctionError::LocalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ error,
+ })?;
+ }
+ Ok(())
+ }
+
+ fn validate_entry_point(
+ &mut self,
+ ep: &crate::EntryPoint,
+ stage: crate::ShaderStage,
+ module: &crate::Module,
+ ) -> Result<(), EntryPointError> {
+ if ep.early_depth_test.is_some() && stage != crate::ShaderStage::Fragment {
+ return Err(EntryPointError::UnexpectedEarlyDepthTest);
+ }
+ if stage == crate::ShaderStage::Compute {
+ if ep
+ .workgroup_size
+ .iter()
+ .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE)
+ {
+ return Err(EntryPointError::OutOfRangeWorkgroupSize);
+ }
+ } else if ep.workgroup_size != [0; 3] {
+ return Err(EntryPointError::UnexpectedWorkgroupSize);
+ }
+
+ let mut bind_group_masks = [0u64; MAX_BIND_GROUPS as usize];
+ let mut location_in_mask = 0u64;
+ let mut location_out_mask = 0u64;
+ for ((var_handle, var), &usage) in module
+ .global_variables
+ .iter()
+ .zip(&ep.function.global_usage)
+ {
+ if usage.is_empty() {
+ continue;
+ }
+
+ if let Some(crate::Binding::Location(_)) = var.binding {
+ match (stage, var.class) {
+ (crate::ShaderStage::Vertex, crate::StorageClass::Output)
+ | (crate::ShaderStage::Fragment, crate::StorageClass::Input) => {
+ match module.types[var.ty].inner.scalar_kind() {
+ Some(crate::ScalarKind::Float) => {}
+ Some(_) if var.interpolation != Some(crate::Interpolation::Flat) => {
+ return Err(EntryPointError::InvalidIntegerInterpolation);
+ }
+ _ => {}
+ }
+ }
+ _ => {}
+ }
+ }
+
+ let allowed_usage = match var.class {
+ crate::StorageClass::Function => unreachable!(),
+ crate::StorageClass::Input => {
+ let mask = match var.binding {
+ Some(crate::Binding::BuiltIn(built_in)) => match (stage, built_in) {
+ (crate::ShaderStage::Vertex, crate::BuiltIn::BaseInstance)
+ | (crate::ShaderStage::Vertex, crate::BuiltIn::BaseVertex)
+ | (crate::ShaderStage::Vertex, crate::BuiltIn::InstanceIndex)
+ | (crate::ShaderStage::Vertex, crate::BuiltIn::VertexIndex)
+ | (crate::ShaderStage::Fragment, crate::BuiltIn::PointSize)
+ | (crate::ShaderStage::Fragment, crate::BuiltIn::FragCoord)
+ | (crate::ShaderStage::Fragment, crate::BuiltIn::FrontFacing)
+ | (crate::ShaderStage::Fragment, crate::BuiltIn::SampleIndex)
+ | (crate::ShaderStage::Compute, crate::BuiltIn::GlobalInvocationId)
+ | (crate::ShaderStage::Compute, crate::BuiltIn::LocalInvocationId)
+ | (crate::ShaderStage::Compute, crate::BuiltIn::LocalInvocationIndex)
+ | (crate::ShaderStage::Compute, crate::BuiltIn::WorkGroupId) => 0,
+ _ => return Err(EntryPointError::InvalidBuiltIn(built_in)),
+ },
+ Some(crate::Binding::Location(loc)) => 1 << loc,
+ Some(crate::Binding::Resource { .. }) => unreachable!(),
+ None => 0,
+ };
+ if location_in_mask & mask != 0 {
+ return Err(EntryPointError::BindingCollision(var_handle));
+ }
+ location_in_mask |= mask;
+ crate::GlobalUse::LOAD
+ }
+ crate::StorageClass::Output => {
+ let mask = match var.binding {
+ Some(crate::Binding::BuiltIn(built_in)) => match (stage, built_in) {
+ (crate::ShaderStage::Vertex, crate::BuiltIn::Position)
+ | (crate::ShaderStage::Vertex, crate::BuiltIn::PointSize)
+ | (crate::ShaderStage::Vertex, crate::BuiltIn::ClipDistance)
+ | (crate::ShaderStage::Fragment, crate::BuiltIn::FragDepth) => 0,
+ _ => return Err(EntryPointError::InvalidBuiltIn(built_in)),
+ },
+ Some(crate::Binding::Location(loc)) => 1 << loc,
+ Some(crate::Binding::Resource { .. }) => unreachable!(),
+ None => 0,
+ };
+ if location_out_mask & mask != 0 {
+ return Err(EntryPointError::BindingCollision(var_handle));
+ }
+ location_out_mask |= mask;
+ crate::GlobalUse::LOAD | crate::GlobalUse::STORE
+ }
+ crate::StorageClass::Uniform => crate::GlobalUse::LOAD,
+ crate::StorageClass::Storage => storage_usage(var.storage_access),
+ crate::StorageClass::Handle => match module.types[var.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage(_),
+ ..
+ } => storage_usage(var.storage_access),
+ _ => crate::GlobalUse::LOAD,
+ },
+ crate::StorageClass::Private | crate::StorageClass::WorkGroup => {
+ crate::GlobalUse::all()
+ }
+ crate::StorageClass::PushConstant => crate::GlobalUse::LOAD,
+ };
+ if !allowed_usage.contains(usage) {
+ log::warn!("\tUsage error for: {:?}", var);
+ log::warn!(
+ "\tAllowed usage: {:?}, requested: {:?}",
+ allowed_usage,
+ usage
+ );
+ return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage));
+ }
+
+ if let Some(crate::Binding::Resource { group, binding }) = var.binding {
+ let mask = 1 << binding;
+ let group_mask = &mut bind_group_masks[group as usize];
+ if *group_mask & mask != 0 {
+ return Err(EntryPointError::BindingCollision(var_handle));
+ }
+ *group_mask |= mask;
+ }
+ }
+
+ self.validate_function(&ep.function, module)?;
+ Ok(())
+ }
+
+ /// Check the given module to be valid.
+ pub fn validate(&mut self, module: &crate::Module) -> Result<(), ValidationError> {
+ // check the types
+ for (handle, ty) in module.types.iter() {
+ use crate::TypeInner as Ti;
+ match ty.inner {
+ Ti::Scalar { kind, width } | Ti::Vector { kind, width, .. } => {
+ let expected = match kind {
+ crate::ScalarKind::Bool => 1,
+ _ => 4,
+ };
+ if width != expected {
+ return Err(ValidationError::InvalidTypeWidth(kind, width));
+ }
+ }
+ Ti::Matrix { width, .. } => {
+ if width != 4 {
+ return Err(ValidationError::InvalidTypeWidth(
+ crate::ScalarKind::Float,
+ width,
+ ));
+ }
+ }
+ Ti::Pointer { base, class: _ } => {
+ if base >= handle {
+ return Err(ValidationError::UnresolvedType(base));
+ }
+ }
+ Ti::Array { base, size, .. } => {
+ if base >= handle {
+ return Err(ValidationError::UnresolvedType(base));
+ }
+ if let crate::ArraySize::Constant(const_handle) = size {
+ let constant = module
+ .constants
+ .try_get(const_handle)
+ .ok_or(ValidationError::Corrupted)?;
+ match constant.inner {
+ crate::ConstantInner::Uint(_) => {}
+ _ => {
+ return Err(ValidationError::InvalidArraySizeConstant(const_handle))
+ }
+ }
+ }
+ }
+ Ti::Struct { ref members } => {
+ //TODO: check that offsets are not intersecting?
+ for member in members {
+ if member.ty >= handle {
+ return Err(ValidationError::UnresolvedType(member.ty));
+ }
+ }
+ }
+ Ti::Image { .. } => {}
+ Ti::Sampler { comparison: _ } => {}
+ }
+ }
+
+ for (var_handle, var) in module.global_variables.iter() {
+ self.validate_global_var(var, &module.types)
+ .map_err(|error| ValidationError::GlobalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ error,
+ })?;
+ }
+
+ for (fun_handle, fun) in module.functions.iter() {
+ self.validate_function(fun, module)
+ .map_err(|e| ValidationError::Function(fun_handle, e))?;
+ }
+
+ for (&(stage, ref name), entry_point) in module.entry_points.iter() {
+ self.validate_entry_point(entry_point, stage, module)
+ .map_err(|error| ValidationError::EntryPoint {
+ stage,
+ name: name.to_string(),
+ error,
+ })?;
+ }
+
+ Ok(())
+ }
+}