diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 14:29:10 +0000 |
commit | 2aa4a82499d4becd2284cdb482213d541b8804dd (patch) | |
tree | b80bf8bf13c3766139fbacc530efd0dd9d54394c /third_party/rust/naga/src | |
parent | Initial commit. (diff) | |
download | firefox-2aa4a82499d4becd2284cdb482213d541b8804dd.tar.xz firefox-2aa4a82499d4becd2284cdb482213d541b8804dd.zip |
Adding upstream version 86.0.1.upstream/86.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src')
43 files changed, 17114 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/arena.rs b/third_party/rust/naga/src/arena.rs new file mode 100644 index 0000000000..460d909c2d --- /dev/null +++ b/third_party/rust/naga/src/arena.rs @@ -0,0 +1,225 @@ +use std::{cmp::Ordering, fmt, hash, marker::PhantomData, num::NonZeroU32}; + +/// An unique index in the arena array that a handle points to. +/// The "non-zero" part ensures that an `Option<Handle<T>>` has +/// the same size and representation as `Handle<T>`. +type Index = NonZeroU32; + +/// A strongly typed reference to a SPIR-V element. +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr( + any(feature = "serialize", feature = "deserialize"), + serde(transparent) +)] +pub struct Handle<T> { + index: Index, + #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))] + marker: PhantomData<T>, +} + +impl<T> Clone for Handle<T> { + fn clone(&self) -> Self { + Handle { + index: self.index, + marker: self.marker, + } + } +} +impl<T> Copy for Handle<T> {} +impl<T> PartialEq for Handle<T> { + fn eq(&self, other: &Self) -> bool { + self.index == other.index + } +} +impl<T> Eq for Handle<T> {} +impl<T> PartialOrd for Handle<T> { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + self.index.partial_cmp(&other.index) + } +} +impl<T> Ord for Handle<T> { + fn cmp(&self, other: &Self) -> Ordering { + self.index.cmp(&other.index) + } +} +impl<T> fmt::Debug for Handle<T> { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "Handle({})", self.index) + } +} +impl<T> hash::Hash for Handle<T> { + fn hash<H: hash::Hasher>(&self, hasher: &mut H) { + self.index.hash(hasher) + } +} + +impl<T> Handle<T> { + #[cfg(test)] + pub const DUMMY: Self = Handle { + index: unsafe { NonZeroU32::new_unchecked(!0) }, + marker: PhantomData, + }; + + pub(crate) fn new(index: Index) -> Self { + Handle { + index, + marker: PhantomData, + } + } + + /// Returns the zero-based index of this handle. + pub fn index(self) -> usize { + let index = self.index.get() - 1; + index as usize + } +} + +/// An arena holding some kind of component (e.g., type, constant, +/// instruction, etc.) that can be referenced. +/// +/// Adding new items to the arena produces a strongly-typed [`Handle`]. +/// The arena can be indexed using the given handle to obtain +/// a reference to the stored item. +#[derive(Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr( + any(feature = "serialize", feature = "deserialize"), + serde(transparent) +)] +#[cfg_attr(test, derive(PartialEq))] +pub struct Arena<T> { + /// Values of this arena. + data: Vec<T>, +} + +impl<T> Default for Arena<T> { + fn default() -> Self { + Self::new() + } +} + +impl<T> Arena<T> { + /// Create a new arena with no initial capacity allocated. + pub fn new() -> Self { + Arena { data: Vec::new() } + } + + /// Returns the current number of items stored in this arena. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns `true` if the arena contains no elements. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + /// Returns an iterator over the items stored in this arena, returning both + /// the item's handle and a reference to it. + pub fn iter(&self) -> impl Iterator<Item = (Handle<T>, &T)> { + self.data.iter().enumerate().map(|(i, v)| { + let position = i + 1; + let index = unsafe { Index::new_unchecked(position as u32) }; + (Handle::new(index), v) + }) + } + + /// Adds a new value to the arena, returning a typed handle. + /// + /// The value is not linked to any SPIR-V module. + pub fn append(&mut self, value: T) -> Handle<T> { + let position = self.data.len() + 1; + let index = unsafe { Index::new_unchecked(position as u32) }; + self.data.push(value); + Handle::new(index) + } + + /// Fetch a handle to an existing type. + pub fn fetch_if<F: Fn(&T) -> bool>(&self, fun: F) -> Option<Handle<T>> { + self.data + .iter() + .position(fun) + .map(|index| Handle::new(unsafe { Index::new_unchecked((index + 1) as u32) })) + } + + /// Adds a value with a custom check for uniqueness: + /// returns a handle pointing to + /// an existing element if the check succeeds, or adds a new + /// element otherwise. + pub fn fetch_if_or_append<F: Fn(&T, &T) -> bool>(&mut self, value: T, fun: F) -> Handle<T> { + if let Some(index) = self.data.iter().position(|d| fun(d, &value)) { + let index = unsafe { Index::new_unchecked((index + 1) as u32) }; + Handle::new(index) + } else { + self.append(value) + } + } + + /// Adds a value with a check for uniqueness, where the check is plain comparison. + pub fn fetch_or_append(&mut self, value: T) -> Handle<T> + where + T: PartialEq, + { + self.fetch_if_or_append(value, T::eq) + } + + pub fn try_get(&self, handle: Handle<T>) -> Option<&T> { + self.data.get(handle.index.get() as usize - 1) + } + + /// Get a mutable reference to an element in the arena. + pub fn get_mut(&mut self, handle: Handle<T>) -> &mut T { + self.data.get_mut(handle.index.get() as usize - 1).unwrap() + } +} + +impl<T> std::ops::Index<Handle<T>> for Arena<T> { + type Output = T; + fn index(&self, handle: Handle<T>) -> &T { + let index = handle.index.get() - 1; + &self.data[index as usize] + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn append_non_unique() { + let mut arena: Arena<u8> = Arena::new(); + let t1 = arena.append(0); + let t2 = arena.append(0); + assert!(t1 != t2); + assert!(arena[t1] == arena[t2]); + } + + #[test] + fn append_unique() { + let mut arena: Arena<u8> = Arena::new(); + let t1 = arena.append(0); + let t2 = arena.append(1); + assert!(t1 != t2); + assert!(arena[t1] != arena[t2]); + } + + #[test] + fn fetch_or_append_non_unique() { + let mut arena: Arena<u8> = Arena::new(); + let t1 = arena.fetch_or_append(0); + let t2 = arena.fetch_or_append(0); + assert!(t1 == t2); + assert!(arena[t1] == arena[t2]) + } + + #[test] + fn fetch_or_append_unique() { + let mut arena: Arena<u8> = Arena::new(); + let t1 = arena.fetch_or_append(0); + let t2 = arena.fetch_or_append(1); + assert!(t1 != t2); + assert!(arena[t1] != arena[t2]); + } +} diff --git a/third_party/rust/naga/src/back/glsl.rs b/third_party/rust/naga/src/back/glsl.rs new file mode 100644 index 0000000000..6a497a675b --- /dev/null +++ b/third_party/rust/naga/src/back/glsl.rs @@ -0,0 +1,1579 @@ +//! OpenGL shading language backend +//! +//! The main structure is [`Writer`](struct.Writer.html), it maintains internal state that is used +//! to output a `Module` into glsl +//! +//! # Supported versions +//! ### Core +//! - 330 +//! - 400 +//! - 410 +//! - 420 +//! - 430 +//! - 450 +//! - 460 +//! +//! ### ES +//! - 300 +//! - 310 +//! + +use crate::{ + proc::{ + CallGraph, CallGraphBuilder, Interface, NameKey, Namer, ResolveContext, ResolveError, + Typifier, Visitor, + }, + Arena, ArraySize, BinaryOperator, BuiltIn, ConservativeDepth, Constant, ConstantInner, + DerivativeAxis, Expression, FastHashMap, Function, FunctionOrigin, GlobalVariable, Handle, + ImageClass, Interpolation, IntrinsicFunction, LocalVariable, Module, ScalarKind, ShaderStage, + Statement, StorageAccess, StorageClass, StorageFormat, StructMember, Type, TypeInner, + UnaryOperator, +}; +use std::{ + cmp::Ordering, + fmt::{self, Error as FmtError}, + io::{Error as IoError, Write}, +}; + +/// Contains a constant with a slice of all the reserved keywords RESERVED_KEYWORDS +mod keywords; + +const SUPPORTED_CORE_VERSIONS: &[u16] = &[330, 400, 410, 420, 430, 440, 450]; +const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310]; + +#[derive(Debug)] +pub enum Error { + FormatError(FmtError), + IoError(IoError), + Type(ResolveError), + Custom(String), +} + +impl From<FmtError> for Error { + fn from(err: FmtError) -> Self { + Error::FormatError(err) + } +} + +impl From<IoError> for Error { + fn from(err: IoError) -> Self { + Error::IoError(err) + } +} + +impl From<ResolveError> for Error { + fn from(err: ResolveError) -> Self { + Error::Type(err) + } +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Error::FormatError(err) => write!(f, "Formatting error {}", err), + Error::IoError(err) => write!(f, "Io error: {}", err), + Error::Type(err) => write!(f, "Type error: {:?}", err), + Error::Custom(err) => write!(f, "{}", err), + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum Version { + Desktop(u16), + Embedded(u16), +} + +impl Version { + fn is_es(&self) -> bool { + match self { + Version::Desktop(_) => false, + Version::Embedded(_) => true, + } + } + + fn is_supported(&self) -> bool { + match self { + Version::Desktop(v) => SUPPORTED_CORE_VERSIONS.contains(v), + Version::Embedded(v) => SUPPORTED_ES_VERSIONS.contains(v), + } + } +} + +impl PartialOrd for Version { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + match (*self, *other) { + (Version::Desktop(x), Version::Desktop(y)) => Some(x.cmp(&y)), + (Version::Embedded(x), Version::Embedded(y)) => Some(x.cmp(&y)), + _ => None, + } + } +} + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Version::Desktop(v) => write!(f, "{} core", v), + Version::Embedded(v) => write!(f, "{} es", v), + } + } +} + +#[derive(Debug, Clone)] +pub struct Options { + pub version: Version, + pub entry_point: (ShaderStage, String), +} + +#[derive(Debug, Clone)] +pub struct TextureMapping { + pub texture: Handle<GlobalVariable>, + pub sampler: Option<Handle<GlobalVariable>>, +} + +bitflags::bitflags! { + struct Features: u32 { + const BUFFER_STORAGE = 1; + const ARRAY_OF_ARRAYS = 1 << 1; + const DOUBLE_TYPE = 1 << 2; + const FULL_IMAGE_FORMATS = 1 << 3; + const MULTISAMPLED_TEXTURES = 1 << 4; + const MULTISAMPLED_TEXTURE_ARRAYS = 1 << 5; + const CUBE_TEXTURES_ARRAY = 1 << 6; + const COMPUTE_SHADER = 1 << 7; + const IMAGE_LOAD_STORE = 1 << 8; + const CONSERVATIVE_DEPTH = 1 << 9; + const TEXTURE_1D = 1 << 10; + const PUSH_CONSTANT = 1 << 11; + } +} + +struct FeaturesManager(Features); + +impl FeaturesManager { + pub fn new() -> Self { + Self(Features::empty()) + } + + pub fn request(&mut self, features: Features) { + self.0 |= features + } + + #[allow(clippy::collapsible_if)] + pub fn write(&self, version: Version, mut out: impl Write) -> Result<(), Error> { + if self.0.contains(Features::COMPUTE_SHADER) { + if version < Version::Embedded(310) || version < Version::Desktop(420) { + return Err(Error::Custom(format!( + "Version {} doesn't support compute shaders", + version + ))); + } + + if !version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_compute_shader.txt + writeln!(out, "#extension GL_ARB_compute_shader : require")?; + } + } + + if self.0.contains(Features::BUFFER_STORAGE) { + if version < Version::Embedded(310) || version < Version::Desktop(400) { + return Err(Error::Custom(format!( + "Version {} doesn't support buffer storage class", + version + ))); + } + + if let Version::Desktop(_) = version { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_storage_buffer_object.txt + writeln!( + out, + "#extension GL_ARB_shader_storage_buffer_object : require" + )?; + } + } + + if self.0.contains(Features::DOUBLE_TYPE) { + if version.is_es() || version < Version::Desktop(150) { + return Err(Error::Custom(format!( + "Version {} doesn't support doubles", + version + ))); + } + + if version < Version::Desktop(400) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_gpu_shader_fp64.txt + writeln!(out, "#extension GL_ARB_gpu_shader_fp64 : require")?; + } + } + + if self.0.contains(Features::CUBE_TEXTURES_ARRAY) { + if version < Version::Embedded(310) || version < Version::Desktop(130) { + return Err(Error::Custom(format!( + "Version {} doesn't support cube map array textures", + version + ))); + } + + if version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_cube_map_array.txt + writeln!(out, "#extension GL_EXT_texture_cube_map_array : require")?; + } else if version < Version::Desktop(400) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_cube_map_array.txt + writeln!(out, "#extension GL_ARB_texture_cube_map_array : require")?; + } + } + + if self.0.contains(Features::MULTISAMPLED_TEXTURES) { + if version < Version::Embedded(300) { + return Err(Error::Custom(format!( + "Version {} doesn't support multi sampled textures", + version + ))); + } + } + + if self.0.contains(Features::MULTISAMPLED_TEXTURE_ARRAYS) { + if version < Version::Embedded(310) { + return Err(Error::Custom(format!( + "Version {} doesn't support multi sampled texture arrays", + version + ))); + } + + if version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_texture_storage_multisample_2d_array.txt + writeln!( + out, + "#extension GL_OES_texture_storage_multisample_2d_array : require" + )?; + } + } + + if self.0.contains(Features::ARRAY_OF_ARRAYS) { + if version < Version::Embedded(310) || version < Version::Desktop(120) { + return Err(Error::Custom(format!( + "Version {} doesn't arrays of arrays", + version + ))); + } + + if version < Version::Desktop(430) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_arrays_of_arrays.txt + writeln!(out, "#extension ARB_arrays_of_arrays : require")?; + } + } + + if self.0.contains(Features::IMAGE_LOAD_STORE) { + if version < Version::Embedded(310) || version < Version::Desktop(130) { + return Err(Error::Custom(format!( + "Version {} doesn't support images load/stores", + version + ))); + } + + if self.0.contains(Features::FULL_IMAGE_FORMATS) && version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/NV/NV_image_formats.txt + writeln!(out, "#extension GL_NV_image_formats : require")?; + } + + if version < Version::Desktop(420) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_image_load_store.txt + writeln!(out, "#extension GL_ARB_shader_image_load_store : require")?; + } + } + + if self.0.contains(Features::CONSERVATIVE_DEPTH) { + if version < Version::Embedded(300) || version < Version::Desktop(130) { + return Err(Error::Custom(format!( + "Version {} doesn't support conservative depth", + version + ))); + } + + if version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_conservative_depth.txt + writeln!(out, "#extension GL_EXT_conservative_depth : require")?; + } + + if version < Version::Desktop(420) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt + writeln!(out, "#extension GL_ARB_conservative_depth : require")?; + } + } + + if self.0.contains(Features::TEXTURE_1D) { + if version.is_es() { + return Err(Error::Custom(format!( + "Version {} doesn't support 1d textures", + version + ))); + } + } + + Ok(()) + } +} + +enum FunctionType { + Function(Handle<Function>), + EntryPoint(crate::proc::EntryPointIndex), +} + +struct FunctionCtx<'a, 'b> { + func: FunctionType, + expressions: &'a Arena<Expression>, + typifier: &'b Typifier, +} + +impl<'a, 'b> FunctionCtx<'a, 'b> { + fn name_key(&self, local: Handle<LocalVariable>) -> NameKey { + match self.func { + FunctionType::Function(handle) => NameKey::FunctionLocal(handle, local), + FunctionType::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local), + } + } + + fn get_arg<'c>(&self, arg: u32, names: &'c FastHashMap<NameKey, String>) -> &'c str { + match self.func { + FunctionType::Function(handle) => &names[&NameKey::FunctionArgument(handle, arg)], + FunctionType::EntryPoint(_) => unreachable!(), + } + } +} + +/// Helper structure that generates a number +#[derive(Default)] +struct IdGenerator(u32); + +impl IdGenerator { + fn generate(&mut self) -> u32 { + let ret = self.0; + self.0 += 1; + ret + } +} + +/// Main structure of the glsl backend responsible for all code generation +pub struct Writer<'a, W> { + // Inputs + module: &'a Module, + out: W, + options: &'a Options, + + // Internal State + features: FeaturesManager, + names: FastHashMap<NameKey, String>, + entry_point: &'a crate::EntryPoint, + entry_point_idx: crate::proc::EntryPointIndex, + call_graph: CallGraph, + + /// Used to generate a unique number for blocks + block_id: IdGenerator, +} + +impl<'a, W: Write> Writer<'a, W> { + pub fn new(out: W, module: &'a Module, options: &'a Options) -> Result<Self, Error> { + if !options.version.is_supported() { + return Err(Error::Custom(format!( + "Version not supported {}", + options.version + ))); + } + + let (ep_idx, ep) = module + .entry_points + .iter() + .enumerate() + .find_map(|(i, (key, entry_point))| { + Some((i as u16, entry_point)).filter(|_| &options.entry_point == key) + }) + .ok_or_else(|| Error::Custom(String::from("Entry point not found")))?; + + let mut names = FastHashMap::default(); + Namer::process(module, keywords::RESERVED_KEYWORDS, &mut names); + + let call_graph = CallGraphBuilder { + functions: &module.functions, + } + .process(&ep.function); + + let mut this = Self { + module, + out, + options, + + features: FeaturesManager::new(), + names, + entry_point: ep, + entry_point_idx: ep_idx, + call_graph, + + block_id: IdGenerator::default(), + }; + + this.collect_required_features()?; + + Ok(this) + } + + fn collect_required_features(&mut self) -> Result<(), Error> { + let stage = self.options.entry_point.0; + + if let Some(depth_test) = self.entry_point.early_depth_test { + self.features.request(Features::IMAGE_LOAD_STORE); + + if depth_test.conservative.is_some() { + self.features.request(Features::CONSERVATIVE_DEPTH); + } + } + + if let ShaderStage::Compute = stage { + self.features.request(Features::COMPUTE_SHADER) + } + + for (_, ty) in self.module.types.iter() { + match ty.inner { + TypeInner::Scalar { kind, width } => self.scalar_required_features(kind, width), + TypeInner::Vector { kind, width, .. } => self.scalar_required_features(kind, width), + TypeInner::Matrix { .. } => self.scalar_required_features(ScalarKind::Float, 8), + TypeInner::Array { base, .. } => { + if let TypeInner::Array { .. } = self.module.types[base].inner { + self.features.request(Features::ARRAY_OF_ARRAYS) + } + } + TypeInner::Image { + dim, + arrayed, + class, + } => { + if arrayed && dim == crate::ImageDimension::Cube { + self.features.request(Features::CUBE_TEXTURES_ARRAY) + } else if dim == crate::ImageDimension::D1 { + self.features.request(Features::TEXTURE_1D) + } + + match class { + ImageClass::Sampled { multi: true, .. } => { + self.features.request(Features::MULTISAMPLED_TEXTURES); + if arrayed { + self.features.request(Features::MULTISAMPLED_TEXTURE_ARRAYS); + } + } + ImageClass::Storage(format) => match format { + StorageFormat::R8Unorm + | StorageFormat::R8Snorm + | StorageFormat::R8Uint + | StorageFormat::R8Sint + | StorageFormat::R16Uint + | StorageFormat::R16Sint + | StorageFormat::R16Float + | StorageFormat::Rg8Unorm + | StorageFormat::Rg8Snorm + | StorageFormat::Rg8Uint + | StorageFormat::Rg8Sint + | StorageFormat::Rg16Uint + | StorageFormat::Rg16Sint + | StorageFormat::Rg16Float + | StorageFormat::Rgb10a2Unorm + | StorageFormat::Rg11b10Float + | StorageFormat::Rg32Uint + | StorageFormat::Rg32Sint + | StorageFormat::Rg32Float => { + self.features.request(Features::FULL_IMAGE_FORMATS) + } + _ => {} + }, + _ => {} + } + } + _ => {} + } + } + + for (_, global) in self.module.global_variables.iter() { + match global.class { + StorageClass::WorkGroup => self.features.request(Features::COMPUTE_SHADER), + StorageClass::Storage => self.features.request(Features::BUFFER_STORAGE), + StorageClass::PushConstant => self.features.request(Features::PUSH_CONSTANT), + _ => {} + } + } + + Ok(()) + } + + fn scalar_required_features(&mut self, kind: ScalarKind, width: crate::Bytes) { + if kind == ScalarKind::Float && width == 8 { + self.features.request(Features::DOUBLE_TYPE); + } + } + + pub fn write(&mut self) -> Result<FastHashMap<String, TextureMapping>, Error> { + let es = self.options.version.is_es(); + + writeln!(self.out, "#version {}", self.options.version)?; + self.features.write(self.options.version, &mut self.out)?; + writeln!(self.out)?; + + if es { + writeln!(self.out, "precision highp float;\n")?; + } + + if let Some(depth_test) = self.entry_point.early_depth_test { + writeln!(self.out, "layout(early_fragment_tests) in;\n")?; + + if let Some(conservative) = depth_test.conservative { + writeln!( + self.out, + "layout (depth_{}) out float gl_FragDepth;\n", + match conservative { + ConservativeDepth::GreaterEqual => "greater", + ConservativeDepth::LessEqual => "less", + ConservativeDepth::Unchanged => "unchanged", + } + )?; + } + } + + for (handle, ty) in self.module.types.iter() { + if let TypeInner::Struct { ref members } = ty.inner { + self.write_struct(handle, members)? + } + } + + writeln!(self.out)?; + + let texture_mappings = self.collect_texture_mapping( + self.call_graph + .raw_nodes() + .iter() + .map(|node| &self.module.functions[node.weight]) + .chain(std::iter::once(&self.entry_point.function)), + )?; + + for (handle, global) in self + .module + .global_variables + .iter() + .zip(&self.entry_point.function.global_usage) + .filter_map(|(global, usage)| Some(global).filter(|_| !usage.is_empty())) + { + if let Some(crate::Binding::BuiltIn(_)) = global.binding { + continue; + } + + match self.module.types[global.ty].inner { + TypeInner::Image { + dim, + arrayed, + class, + } => { + if let TypeInner::Image { + class: ImageClass::Storage(format), + .. + } = self.module.types[global.ty].inner + { + write!(self.out, "layout({}) ", glsl_storage_format(format))?; + } + + if global.storage_access == StorageAccess::LOAD { + write!(self.out, "readonly ")?; + } else if global.storage_access == StorageAccess::STORE { + write!(self.out, "writeonly ")?; + } + + write!(self.out, "uniform ")?; + + self.write_image_type(dim, arrayed, class)?; + + writeln!( + self.out, + " {};", + self.names[&NameKey::GlobalVariable(handle)] + )? + } + TypeInner::Sampler { .. } => continue, + _ => self.write_global(handle, global)?, + } + } + + writeln!(self.out)?; + + // Sort the graph topologically so that functions calls are valid + // It's impossible for this to panic because the IR forbids cycles + let functions = petgraph::algo::toposort(&self.call_graph, None).unwrap(); + + for node in functions { + let handle = self.call_graph[node]; + let name = self.names[&NameKey::Function(handle)].clone(); + self.write_function( + FunctionType::Function(handle), + &self.module.functions[handle], + name, + )?; + } + + self.write_function( + FunctionType::EntryPoint(self.entry_point_idx), + &self.entry_point.function, + "main", + )?; + + Ok(texture_mappings) + } + + fn write_global( + &mut self, + handle: Handle<GlobalVariable>, + global: &GlobalVariable, + ) -> Result<(), Error> { + if global.storage_access == StorageAccess::LOAD { + write!(self.out, "readonly ")?; + } else if global.storage_access == StorageAccess::STORE { + write!(self.out, "writeonly ")?; + } + + if let Some(interpolation) = global.interpolation { + match (self.options.entry_point.0, global.class) { + (ShaderStage::Fragment, StorageClass::Input) + | (ShaderStage::Vertex, StorageClass::Output) => { + write!(self.out, "{} ", glsl_interpolation(interpolation)?)?; + } + _ => (), + }; + } + + let block = match global.class { + StorageClass::Storage | StorageClass::Uniform => { + let block_name = self.names[&NameKey::Type(global.ty)].clone(); + + Some(block_name) + } + _ => None, + }; + + write!(self.out, "{} ", glsl_storage_class(global.class))?; + + self.write_type(global.ty, block)?; + + let name = &self.names[&NameKey::GlobalVariable(handle)]; + writeln!(self.out, " {};", name)?; + + Ok(()) + } + + fn write_function<N: AsRef<str>>( + &mut self, + ty: FunctionType, + func: &Function, + name: N, + ) -> Result<(), Error> { + let mut typifier = Typifier::new(); + + typifier.resolve_all( + &func.expressions, + &self.module.types, + &ResolveContext { + constants: &self.module.constants, + global_vars: &self.module.global_variables, + local_vars: &func.local_variables, + functions: &self.module.functions, + arguments: &func.arguments, + }, + )?; + + let ctx = FunctionCtx { + func: ty, + expressions: &func.expressions, + typifier: &typifier, + }; + + self.write_fn_header(name.as_ref(), func, &ctx)?; + writeln!(self.out, " {{",)?; + + for (handle, local) in func.local_variables.iter() { + write!(self.out, "\t")?; + self.write_type(local.ty, None)?; + + write!(self.out, " {}", self.names[&ctx.name_key(handle)])?; + + if let Some(init) = local.init { + write!(self.out, " = ",)?; + + self.write_constant(&self.module.constants[init])?; + } + + writeln!(self.out, ";")? + } + + writeln!(self.out)?; + + for sta in func.body.iter() { + self.write_stmt(sta, &ctx, 1)?; + } + + Ok(writeln!(self.out, "}}")?) + } + + fn write_slice<T, F: FnMut(&mut Self, u32, &T) -> Result<(), Error>>( + &mut self, + data: &[T], + mut f: F, + ) -> Result<(), Error> { + for (i, item) in data.iter().enumerate() { + f(self, i as u32, item)?; + + if i != data.len().saturating_sub(1) { + write!(self.out, ",")?; + } + } + + Ok(()) + } + + fn write_fn_header( + &mut self, + name: &str, + func: &Function, + ctx: &FunctionCtx<'_, '_>, + ) -> Result<(), Error> { + if let Some(ty) = func.return_type { + self.write_type(ty, None)?; + } else { + write!(self.out, "void")?; + } + + write!(self.out, " {}(", name)?; + + self.write_slice(&func.arguments, |this, i, arg| { + this.write_type(arg.ty, None)?; + + let name = ctx.get_arg(i, &this.names); + + Ok(write!(this.out, " {}", name)?) + })?; + + write!(self.out, ")")?; + + Ok(()) + } + + fn write_type(&mut self, ty: Handle<Type>, block: Option<String>) -> Result<(), Error> { + match self.module.types[ty].inner { + TypeInner::Scalar { kind, width } => { + write!(self.out, "{}", glsl_scalar(kind, width)?.full)? + } + TypeInner::Vector { size, kind, width } => write!( + self.out, + "{}vec{}", + glsl_scalar(kind, width)?.prefix, + size as u8 + )?, + TypeInner::Matrix { + columns, + rows, + width, + } => write!( + self.out, + "{}mat{}x{}", + glsl_scalar(ScalarKind::Float, width)?.prefix, + columns as u8, + rows as u8 + )?, + TypeInner::Pointer { base, .. } => self.write_type(base, None)?, + TypeInner::Array { base, size, .. } => { + self.write_type(base, None)?; + + write!(self.out, "[")?; + self.write_array_size(size)?; + write!(self.out, "]")? + } + TypeInner::Struct { ref members } => { + if let Some(name) = block { + writeln!(self.out, "{}_block_{} {{", name, self.block_id.generate())?; + + for (idx, member) in members.iter().enumerate() { + self.write_type(member.ty, None)?; + + writeln!( + self.out, + " {};", + &self.names[&NameKey::StructMember(ty, idx as u32)] + )?; + } + + write!(self.out, "}}")? + } else { + write!(self.out, "{}", &self.names[&NameKey::Type(ty)])? + } + } + _ => unreachable!(), + } + + Ok(()) + } + + fn write_image_type( + &mut self, + dim: crate::ImageDimension, + arrayed: bool, + class: ImageClass, + ) -> Result<(), Error> { + let (base, kind, ms, comparison) = match class { + ImageClass::Sampled { kind, multi: true } => ("sampler", kind, "MS", ""), + ImageClass::Sampled { kind, multi: false } => ("sampler", kind, "", ""), + ImageClass::Depth => ("sampler", crate::ScalarKind::Float, "", "Shadow"), + ImageClass::Storage(format) => ("image", format.into(), "", ""), + }; + + Ok(write!( + self.out, + "{}{}{}{}{}{}", + glsl_scalar(kind, 4)?.prefix, + base, + ImageDimension(dim), + ms, + if arrayed { "Array" } else { "" }, + comparison + )?) + } + + fn write_array_size(&mut self, size: ArraySize) -> Result<(), Error> { + match size { + ArraySize::Constant(const_handle) => match self.module.constants[const_handle].inner { + ConstantInner::Uint(size) => write!(self.out, "{}", size)?, + _ => unreachable!(), + }, + ArraySize::Dynamic => (), + } + + Ok(()) + } + + fn collect_texture_mapping( + &self, + functions: impl Iterator<Item = &'a Function>, + ) -> Result<FastHashMap<String, TextureMapping>, Error> { + let mut mappings = FastHashMap::default(); + + for func in functions { + let mut interface = Interface { + expressions: &func.expressions, + local_variables: &func.local_variables, + visitor: TextureMappingVisitor { + names: &self.names, + expressions: &func.expressions, + map: &mut mappings, + error: None, + }, + }; + interface.traverse(&func.body); + + if let Some(error) = interface.visitor.error { + return Err(error); + } + } + + Ok(mappings) + } + + fn write_struct( + &mut self, + handle: Handle<Type>, + members: &[StructMember], + ) -> Result<(), Error> { + writeln!(self.out, "struct {} {{", self.names[&NameKey::Type(handle)])?; + + for (idx, member) in members.iter().enumerate() { + write!(self.out, "\t")?; + self.write_type(member.ty, None)?; + writeln!( + self.out, + " {};", + self.names[&NameKey::StructMember(handle, idx as u32)] + )?; + } + + writeln!(self.out, "}};")?; + Ok(()) + } + + fn write_stmt( + &mut self, + sta: &Statement, + ctx: &FunctionCtx<'_, '_>, + indent: usize, + ) -> Result<(), Error> { + write!(self.out, "{}", "\t".repeat(indent))?; + + match sta { + Statement::Block(block) => { + writeln!(self.out, "{{")?; + for sta in block.iter() { + self.write_stmt(sta, ctx, indent + 1)? + } + writeln!(self.out, "{}}}", "\t".repeat(indent))? + } + Statement::If { + condition, + accept, + reject, + } => { + write!(self.out, "if(")?; + self.write_expr(*condition, ctx)?; + writeln!(self.out, ") {{")?; + + for sta in accept { + self.write_stmt(sta, ctx, indent + 1)?; + } + + if !reject.is_empty() { + writeln!(self.out, "{}}} else {{", "\t".repeat(indent))?; + + for sta in reject { + self.write_stmt(sta, ctx, indent + 1)?; + } + } + + writeln!(self.out, "{}}}", "\t".repeat(indent))? + } + Statement::Switch { + selector, + cases, + default, + } => { + write!(self.out, "switch(")?; + self.write_expr(*selector, ctx)?; + writeln!(self.out, ") {{")?; + + for (label, (block, fallthrough)) in cases { + writeln!(self.out, "{}case {}:", "\t".repeat(indent + 1), label)?; + + for sta in block { + self.write_stmt(sta, ctx, indent + 2)?; + } + + if fallthrough.is_none() { + writeln!(self.out, "{}break;", "\t".repeat(indent + 2))?; + } + } + + if !default.is_empty() { + writeln!(self.out, "{}default:", "\t".repeat(indent + 1))?; + + for sta in default { + self.write_stmt(sta, ctx, indent + 2)?; + } + } + + writeln!(self.out, "{}}}", "\t".repeat(indent))? + } + Statement::Loop { body, continuing } => { + writeln!(self.out, "while(true) {{")?; + + for sta in body.iter().chain(continuing.iter()) { + self.write_stmt(sta, ctx, indent + 1)?; + } + + writeln!(self.out, "{}}}", "\t".repeat(indent))? + } + Statement::Break => writeln!(self.out, "break;")?, + Statement::Continue => writeln!(self.out, "continue;")?, + Statement::Return { value } => { + write!(self.out, "return")?; + if let Some(expr) = value { + write!(self.out, " ")?; + self.write_expr(*expr, ctx)?; + } + writeln!(self.out, ";")?; + } + Statement::Kill => writeln!(self.out, "discard;")?, + Statement::Store { pointer, value } => { + self.write_expr(*pointer, ctx)?; + write!(self.out, " = ")?; + self.write_expr(*value, ctx)?; + writeln!(self.out, ";")? + } + } + + Ok(()) + } + + fn write_expr( + &mut self, + expr: Handle<Expression>, + ctx: &FunctionCtx<'_, '_>, + ) -> Result<(), Error> { + match ctx.expressions[expr] { + Expression::Access { base, index } => { + self.write_expr(base, ctx)?; + write!(self.out, "[")?; + self.write_expr(index, ctx)?; + write!(self.out, "]")? + } + Expression::AccessIndex { base, index } => { + self.write_expr(base, ctx)?; + + match ctx.typifier.get(base, &self.module.types) { + TypeInner::Vector { .. } + | TypeInner::Matrix { .. } + | TypeInner::Array { .. } => write!(self.out, "[{}]", index)?, + TypeInner::Struct { .. } => { + let ty = ctx.typifier.get_handle(base).unwrap(); + + write!( + self.out, + ".{}", + &self.names[&NameKey::StructMember(ty, index)] + )? + } + ref other => return Err(Error::Custom(format!("Cannot index {:?}", other))), + } + } + Expression::Constant(constant) => { + self.write_constant(&self.module.constants[constant])? + } + Expression::Compose { ty, ref components } => { + match self.module.types[ty].inner { + TypeInner::Vector { .. } + | TypeInner::Matrix { .. } + | TypeInner::Array { .. } + | TypeInner::Struct { .. } => self.write_type(ty, None)?, + _ => unreachable!(), + } + + write!(self.out, "(")?; + self.write_slice(components, |this, _, arg| this.write_expr(*arg, ctx))?; + write!(self.out, ")")? + } + Expression::FunctionArgument(pos) => { + write!(self.out, "{}", ctx.get_arg(pos, &self.names))? + } + Expression::GlobalVariable(handle) => { + if let Some(crate::Binding::BuiltIn(built_in)) = + self.module.global_variables[handle].binding + { + write!(self.out, "{}", glsl_built_in(built_in))? + } else { + write!( + self.out, + "{}", + &self.names[&NameKey::GlobalVariable(handle)] + )? + } + } + Expression::LocalVariable(handle) => { + write!(self.out, "{}", self.names[&ctx.name_key(handle)])? + } + Expression::Load { pointer } => self.write_expr(pointer, ctx)?, + Expression::ImageSample { + image, + coordinate, + level, + depth_ref, + .. + } => { + //TODO: handle MS + write!( + self.out, + "{}(", + match level { + crate::SampleLevel::Auto | crate::SampleLevel::Bias(_) => "texture", + crate::SampleLevel::Zero | crate::SampleLevel::Exact(_) => "textureLod", + } + )?; + self.write_expr(image, ctx)?; + write!(self.out, ", ")?; + + let size = match *ctx.typifier.get(coordinate, &self.module.types) { + TypeInner::Vector { size, .. } => size, + ref other => { + return Err(Error::Custom(format!( + "Cannot sample with coordinates of type {:?}", + other + ))) + } + }; + + if let Some(depth_ref) = depth_ref { + write!(self.out, "vec{}(", size as u8 + 1)?; + self.write_expr(coordinate, ctx)?; + write!(self.out, ", ")?; + self.write_expr(depth_ref, ctx)?; + write!(self.out, ")")? + } else { + self.write_expr(coordinate, ctx)? + } + + match level { + crate::SampleLevel::Auto => (), + crate::SampleLevel::Zero => write!(self.out, ", 0")?, + crate::SampleLevel::Exact(expr) | crate::SampleLevel::Bias(expr) => { + write!(self.out, ", ")?; + self.write_expr(expr, ctx)?; + } + } + + write!(self.out, ")")? + } + Expression::ImageLoad { + image, + coordinate, + index, + } => { + let class = match ctx.typifier.get(image, &self.module.types) { + TypeInner::Image { class, .. } => class, + _ => unreachable!(), + }; + + match class { + ImageClass::Sampled { .. } => write!(self.out, "texelFetch(")?, + ImageClass::Storage(_) => write!(self.out, "imageLoad(")?, + ImageClass::Depth => todo!(), + } + + self.write_expr(image, ctx)?; + write!(self.out, ", ")?; + self.write_expr(coordinate, ctx)?; + + match class { + ImageClass::Sampled { .. } => { + write!(self.out, ", ")?; + self.write_expr(index.unwrap(), ctx)?; + write!(self.out, ")")? + } + ImageClass::Storage(_) => write!(self.out, ")")?, + ImageClass::Depth => todo!(), + } + } + Expression::Unary { op, expr } => { + write!( + self.out, + "({} ", + match op { + UnaryOperator::Negate => "-", + UnaryOperator::Not => match *ctx.typifier.get(expr, &self.module.types) { + TypeInner::Scalar { + kind: ScalarKind::Sint, + .. + } => "~", + TypeInner::Scalar { + kind: ScalarKind::Uint, + .. + } => "~", + TypeInner::Scalar { + kind: ScalarKind::Bool, + .. + } => "!", + ref other => + return Err(Error::Custom(format!( + "Cannot apply not to type {:?}", + other + ))), + }, + } + )?; + + self.write_expr(expr, ctx)?; + + write!(self.out, ")")? + } + Expression::Binary { op, left, right } => { + write!(self.out, "(")?; + self.write_expr(left, ctx)?; + + write!( + self.out, + " {} ", + match op { + BinaryOperator::Add => "+", + BinaryOperator::Subtract => "-", + BinaryOperator::Multiply => "*", + BinaryOperator::Divide => "/", + BinaryOperator::Modulo => "%", + BinaryOperator::Equal => "==", + BinaryOperator::NotEqual => "!=", + BinaryOperator::Less => "<", + BinaryOperator::LessEqual => "<=", + BinaryOperator::Greater => ">", + BinaryOperator::GreaterEqual => ">=", + BinaryOperator::And => "&", + BinaryOperator::ExclusiveOr => "^", + BinaryOperator::InclusiveOr => "|", + BinaryOperator::LogicalAnd => "&&", + BinaryOperator::LogicalOr => "||", + BinaryOperator::ShiftLeft => "<<", + BinaryOperator::ShiftRight => ">>", + } + )?; + + self.write_expr(right, ctx)?; + + write!(self.out, ")")? + } + Expression::Select { + condition, + accept, + reject, + } => { + write!(self.out, "(")?; + self.write_expr(condition, ctx)?; + write!(self.out, " ? ")?; + self.write_expr(accept, ctx)?; + write!(self.out, " : ")?; + self.write_expr(reject, ctx)?; + write!(self.out, ")")? + } + Expression::Intrinsic { fun, argument } => { + write!( + self.out, + "{}(", + match fun { + IntrinsicFunction::IsFinite => "!isinf", + IntrinsicFunction::IsInf => "isinf", + IntrinsicFunction::IsNan => "isnan", + IntrinsicFunction::IsNormal => "!isnan", + IntrinsicFunction::All => "all", + IntrinsicFunction::Any => "any", + } + )?; + + self.write_expr(argument, ctx)?; + + write!(self.out, ")")? + } + Expression::Transpose(matrix) => { + write!(self.out, "transpose(")?; + self.write_expr(matrix, ctx)?; + write!(self.out, ")")? + } + Expression::DotProduct(left, right) => { + write!(self.out, "dot(")?; + self.write_expr(left, ctx)?; + write!(self.out, ", ")?; + self.write_expr(right, ctx)?; + write!(self.out, ")")? + } + Expression::CrossProduct(left, right) => { + write!(self.out, "cross(")?; + self.write_expr(left, ctx)?; + write!(self.out, ", ")?; + self.write_expr(right, ctx)?; + write!(self.out, ")")? + } + Expression::As { + expr, + kind, + convert, + } => { + if convert { + self.write_type(ctx.typifier.get_handle(expr).unwrap(), None)?; + } else { + let source_kind = match *ctx.typifier.get(expr, &self.module.types) { + TypeInner::Scalar { + kind: source_kind, .. + } => source_kind, + TypeInner::Vector { + kind: source_kind, .. + } => source_kind, + _ => unreachable!(), + }; + + write!( + self.out, + "{}", + match (source_kind, kind) { + (ScalarKind::Float, ScalarKind::Sint) => "floatBitsToInt", + (ScalarKind::Float, ScalarKind::Uint) => "floatBitsToUInt", + (ScalarKind::Sint, ScalarKind::Float) => "intBitsToFloat", + (ScalarKind::Uint, ScalarKind::Float) => "uintBitsToFloat", + _ => { + return Err(Error::Custom(format!( + "Cannot bitcast {:?} to {:?}", + source_kind, kind + ))); + } + } + )?; + } + + write!(self.out, "(")?; + self.write_expr(expr, ctx)?; + write!(self.out, ")")? + } + Expression::Derivative { axis, expr } => { + write!( + self.out, + "{}(", + match axis { + DerivativeAxis::X => "dFdx", + DerivativeAxis::Y => "dFdy", + DerivativeAxis::Width => "fwidth", + } + )?; + self.write_expr(expr, ctx)?; + write!(self.out, ")")? + } + Expression::Call { + origin: FunctionOrigin::Local(ref function), + ref arguments, + } => { + write!(self.out, "{}(", &self.names[&NameKey::Function(*function)])?; + self.write_slice(arguments, |this, _, arg| this.write_expr(*arg, ctx))?; + write!(self.out, ")")? + } + Expression::Call { + origin: crate::FunctionOrigin::External(ref name), + ref arguments, + } => match name.as_str() { + "cos" | "normalize" | "sin" | "length" | "abs" | "floor" | "inverse" + | "distance" | "dot" | "min" | "max" | "reflect" | "pow" | "step" | "cross" + | "fclamp" | "clamp" | "mix" | "smoothstep" => { + let name = match name.as_str() { + "fclamp" => "clamp", + name => name, + }; + + write!(self.out, "{}(", name)?; + self.write_slice(arguments, |this, _, arg| this.write_expr(*arg, ctx))?; + write!(self.out, ")")? + } + "atan2" => { + write!(self.out, "atan(")?; + self.write_expr(arguments[1], ctx)?; + write!(self.out, ", ")?; + self.write_expr(arguments[0], ctx)?; + write!(self.out, ")")? + } + other => { + return Err(Error::Custom(format!( + "Unsupported function call {}", + other + ))) + } + }, + Expression::ArrayLength(expr) => { + write!(self.out, "uint(")?; + self.write_expr(expr, ctx)?; + write!(self.out, ".length())")? + } + } + + Ok(()) + } + + fn write_constant(&mut self, constant: &Constant) -> Result<(), Error> { + match constant.inner { + ConstantInner::Sint(int) => write!(self.out, "{}", int)?, + ConstantInner::Uint(int) => write!(self.out, "{}u", int)?, + ConstantInner::Float(float) => write!(self.out, "{:?}", float)?, + ConstantInner::Bool(boolean) => write!(self.out, "{}", boolean)?, + ConstantInner::Composite(ref components) => { + self.write_type(constant.ty, None)?; + write!(self.out, "(")?; + self.write_slice(components, |this, _, arg| { + this.write_constant(&this.module.constants[*arg]) + })?; + write!(self.out, ")")? + } + } + + Ok(()) + } +} + +struct ScalarString<'a> { + prefix: &'a str, + full: &'a str, +} + +fn glsl_scalar(kind: ScalarKind, width: crate::Bytes) -> Result<ScalarString<'static>, Error> { + Ok(match kind { + ScalarKind::Sint => ScalarString { + prefix: "i", + full: "int", + }, + ScalarKind::Uint => ScalarString { + prefix: "u", + full: "uint", + }, + ScalarKind::Float => match width { + 4 => ScalarString { + prefix: "", + full: "float", + }, + 8 => ScalarString { + prefix: "d", + full: "double", + }, + _ => { + return Err(Error::Custom(format!( + "Cannot build float of width {}", + width + ))) + } + }, + ScalarKind::Bool => ScalarString { + prefix: "b", + full: "bool", + }, + }) +} + +fn glsl_built_in(built_in: BuiltIn) -> &'static str { + match built_in { + BuiltIn::Position => "gl_Position", + BuiltIn::GlobalInvocationId => "gl_GlobalInvocationID", + BuiltIn::BaseInstance => "gl_BaseInstance", + BuiltIn::BaseVertex => "gl_BaseVertex", + BuiltIn::ClipDistance => "gl_ClipDistance", + BuiltIn::InstanceIndex => "gl_InstanceIndex", + BuiltIn::VertexIndex => "gl_VertexIndex", + BuiltIn::PointSize => "gl_PointSize", + BuiltIn::FragCoord => "gl_FragCoord", + BuiltIn::FrontFacing => "gl_FrontFacing", + BuiltIn::SampleIndex => "gl_SampleID", + BuiltIn::FragDepth => "gl_FragDepth", + BuiltIn::LocalInvocationId => "gl_LocalInvocationID", + BuiltIn::LocalInvocationIndex => "gl_LocalInvocationIndex", + BuiltIn::WorkGroupId => "gl_WorkGroupID", + } +} + +fn glsl_storage_class(class: StorageClass) -> &'static str { + match class { + StorageClass::Function => "", + StorageClass::Input => "in", + StorageClass::Output => "out", + StorageClass::Private => "", + StorageClass::Storage => "buffer", + StorageClass::Uniform => "uniform", + StorageClass::Handle => "uniform", + StorageClass::WorkGroup => "shared", + StorageClass::PushConstant => "", + } +} + +fn glsl_interpolation(interpolation: Interpolation) -> Result<&'static str, Error> { + Ok(match interpolation { + Interpolation::Perspective => "smooth", + Interpolation::Linear => "noperspective", + Interpolation::Flat => "flat", + Interpolation::Centroid => "centroid", + Interpolation::Sample => "sample", + Interpolation::Patch => { + return Err(Error::Custom( + "patch interpolation qualifier not supported".to_string(), + )) + } + }) +} + +struct ImageDimension(crate::ImageDimension); +impl fmt::Display for ImageDimension { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self.0 { + crate::ImageDimension::D1 => "1D", + crate::ImageDimension::D2 => "2D", + crate::ImageDimension::D3 => "3D", + crate::ImageDimension::Cube => "Cube", + } + ) + } +} + +fn glsl_storage_format(format: StorageFormat) -> &'static str { + match format { + StorageFormat::R8Unorm => "r8", + StorageFormat::R8Snorm => "r8_snorm", + StorageFormat::R8Uint => "r8ui", + StorageFormat::R8Sint => "r8i", + StorageFormat::R16Uint => "r16ui", + StorageFormat::R16Sint => "r16i", + StorageFormat::R16Float => "r16f", + StorageFormat::Rg8Unorm => "rg8", + StorageFormat::Rg8Snorm => "rg8_snorm", + StorageFormat::Rg8Uint => "rg8ui", + StorageFormat::Rg8Sint => "rg8i", + StorageFormat::R32Uint => "r32ui", + StorageFormat::R32Sint => "r32i", + StorageFormat::R32Float => "r32f", + StorageFormat::Rg16Uint => "rg16ui", + StorageFormat::Rg16Sint => "rg16i", + StorageFormat::Rg16Float => "rg16f", + StorageFormat::Rgba8Unorm => "rgba8ui", + StorageFormat::Rgba8Snorm => "rgba8_snorm", + StorageFormat::Rgba8Uint => "rgba8ui", + StorageFormat::Rgba8Sint => "rgba8i", + StorageFormat::Rgb10a2Unorm => "rgb10_a2ui", + StorageFormat::Rg11b10Float => "r11f_g11f_b10f", + StorageFormat::Rg32Uint => "rg32ui", + StorageFormat::Rg32Sint => "rg32i", + StorageFormat::Rg32Float => "rg32f", + StorageFormat::Rgba16Uint => "rgba16ui", + StorageFormat::Rgba16Sint => "rgba16i", + StorageFormat::Rgba16Float => "rgba16f", + StorageFormat::Rgba32Uint => "rgba32ui", + StorageFormat::Rgba32Sint => "rgba32i", + StorageFormat::Rgba32Float => "rgba32f", + } +} + +struct TextureMappingVisitor<'a> { + names: &'a FastHashMap<NameKey, String>, + expressions: &'a Arena<Expression>, + map: &'a mut FastHashMap<String, TextureMapping>, + error: Option<Error>, +} + +impl<'a> Visitor for TextureMappingVisitor<'a> { + fn visit_expr(&mut self, expr: &crate::Expression) { + match expr { + Expression::ImageSample { image, sampler, .. } => { + let tex_handle = match self.expressions[*image] { + Expression::GlobalVariable(global) => global, + _ => unreachable!(), + }; + let tex_name = self.names[&NameKey::GlobalVariable(tex_handle)].clone(); + + let sampler_handle = match self.expressions[*sampler] { + Expression::GlobalVariable(global) => global, + _ => unreachable!(), + }; + + let mapping = self.map.entry(tex_name).or_insert(TextureMapping { + texture: tex_handle, + sampler: Some(sampler_handle), + }); + + if mapping.sampler != Some(sampler_handle) { + self.error = Some(Error::Custom(String::from( + "Cannot use texture with two different samplers", + ))); + } + } + Expression::ImageLoad { image, .. } => { + let tex_handle = match self.expressions[*image] { + Expression::GlobalVariable(global) => global, + _ => unreachable!(), + }; + let tex_name = self.names[&NameKey::GlobalVariable(tex_handle)].clone(); + + let mapping = self.map.entry(tex_name).or_insert(TextureMapping { + texture: tex_handle, + sampler: None, + }); + + if mapping.sampler != None { + self.error = Some(Error::Custom(String::from( + "Cannot use texture with two different samplers", + ))); + } + } + _ => {} + } + } +} diff --git a/third_party/rust/naga/src/back/glsl/keywords.rs b/third_party/rust/naga/src/back/glsl/keywords.rs new file mode 100644 index 0000000000..5a2836c189 --- /dev/null +++ b/third_party/rust/naga/src/back/glsl/keywords.rs @@ -0,0 +1,204 @@ +pub const RESERVED_KEYWORDS: &[&str] = &[ + "attribute", + "const", + "uniform", + "varying", + "buffer", + "shared", + "coherent", + "volatile", + "restrict", + "readonly", + "writeonly", + "atomic_uint", + "layout", + "centroid", + "flat", + "smooth", + "noperspective", + "patch", + "sample", + "break", + "continue", + "do", + "for", + "while", + "switch", + "case", + "default", + "if", + "else", + "subroutine", + "in", + "out", + "inout", + "float", + "double", + "int", + "void", + "bool", + "true", + "false", + "invariant", + "precise", + "discard", + "return", + "mat2", + "mat3", + "mat4", + "dmat2", + "dmat3", + "dmat4", + "mat2x2", + "mat2x3", + "mat2x4", + "dmat2x2", + "dmat2x3", + "dmat2x4", + "mat3x2", + "mat3x3", + "mat3x4", + "dmat3x2", + "dmat3x3", + "dmat3x4", + "mat4x2", + "mat4x3", + "mat4x4", + "dmat4x2", + "dmat4x3", + "dmat4x4", + "vec2", + "vec3", + "vec4", + "ivec2", + "ivec3", + "ivec4", + "bvec2", + "bvec3", + "bvec4", + "dvec2", + "dvec3", + "dvec4", + "uint", + "uvec2", + "uvec3", + "uvec4", + "lowp", + "mediump", + "highp", + "precision", + "sampler1D", + "sampler2D", + "sampler3D", + "samplerCube", + "sampler1DShadow", + "sampler2DShadow", + "samplerCubeShadow", + "sampler1DArray", + "sampler2DArray", + "sampler1DArrayShadow", + "sampler2DArrayShadow", + "isampler1D", + "isampler2D", + "isampler3D", + "isamplerCube", + "isampler1DArray", + "isampler2DArray", + "usampler1D", + "usampler2D", + "usampler3D", + "usamplerCube", + "usampler1DArray", + "usampler2DArray", + "sampler2DRect", + "sampler2DRectShadow", + "isampler2D", + "Rect", + "usampler2DRect", + "samplerBuffer", + "isamplerBuffer", + "usamplerBuffer", + "sampler2DMS", + "isampler2DMS", + "usampler2DMS", + "sampler2DMSArray", + "isampler2DMSArray", + "usampler2DMSArray", + "samplerCubeArray", + "samplerCubeArrayShadow", + "isamplerCubeArray", + "usamplerCubeArray", + "image1D", + "iimage1D", + "uimage1D", + "image2D", + "iimage2D", + "uimage2D", + "image3D", + "iimage3D", + "uimage3D", + "image2DRect", + "iimage2DRect", + "uimage2DRect", + "imageCube", + "iimageCube", + "uimageCube", + "imageBuffer", + "iimageBuffer", + "uimageBuffer", + "image1DArray", + "iimage1DArray", + "uimage1DArray", + "image2DArray", + "iimage2DArray", + "uimage2DArray", + "imageCubeArray", + "iimageCubeArray", + "uimageCubeArray", + "image2DMS", + "iimage2DMS", + "uimage2DMS", + "image2DMSArray", + "iimage2DMSArray", + "uimage2DMSArraystruct", + "common", + "partition", + "active", + "asm", + "class", + "union", + "enum", + "typedef", + "template", + "this", + "resource", + "goto", + "inline", + "noinline", + "public", + "static", + "extern", + "external", + "interface", + "long", + "short", + "half", + "fixed", + "unsigned", + "superp", + "input", + "output", + "hvec2", + "hvec3", + "hvec4", + "fvec2", + "fvec3", + "fvec4", + "sampler3DRect", + "filter", + "sizeof", + "cast", + "namespace", + "using", + "main", +]; diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs new file mode 100644 index 0000000000..bc96dd3496 --- /dev/null +++ b/third_party/rust/naga/src/back/mod.rs @@ -0,0 +1,8 @@ +//! Functions which export shader modules into binary and text formats. + +#[cfg(feature = "glsl-out")] +pub mod glsl; +#[cfg(feature = "msl-out")] +pub mod msl; +#[cfg(feature = "spv-out")] +pub mod spv; diff --git a/third_party/rust/naga/src/back/msl/keywords.rs b/third_party/rust/naga/src/back/msl/keywords.rs new file mode 100644 index 0000000000..cd074ab43f --- /dev/null +++ b/third_party/rust/naga/src/back/msl/keywords.rs @@ -0,0 +1,102 @@ +//TODO: find a complete list +pub const RESERVED: &[&str] = &[ + // control flow + "break", + "if", + "else", + "continue", + "goto", + "do", + "while", + "for", + "switch", + "case", + // types and values + "void", + "unsigned", + "signed", + "bool", + "char", + "int", + "long", + "float", + "double", + "char8_t", + "wchar_t", + "true", + "false", + "nullptr", + "union", + "class", + "struct", + "enum", + // other + "main", + "using", + "decltype", + "sizeof", + "typeof", + "typedef", + "explicit", + "export", + "friend", + "namespace", + "operator", + "public", + "template", + "typename", + "typeid", + "co_await", + "co_return", + "co_yield", + "module", + "import", + "ray_data", + "vec_step", + "visible", + "as_type", + // qualifiers + "mutable", + "static", + "volatile", + "restrict", + "const", + "non-temporal", + "dereferenceable", + "invariant", + // exceptions + "throw", + "try", + "catch", + // operators + "const_cast", + "dynamic_cast", + "reinterpret_cast", + "static_cast", + "new", + "delete", + "and", + "and_eq", + "bitand", + "bitor", + "compl", + "not", + "not_eq", + "or", + "or_eq", + "xor", + "xor_eq", + "compl", + // Metal-specific + "constant", + "device", + "threadgroup", + "threadgroup_imageblock", + "kernel", + "compute", + "vertex", + "fragment", + "read_only", + "write_only", + "read_write", +]; diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs new file mode 100644 index 0000000000..493e7d0c85 --- /dev/null +++ b/third_party/rust/naga/src/back/msl/mod.rs @@ -0,0 +1,211 @@ +/*! Metal Shading Language (MSL) backend + +## Binding model + +Metal's bindings are flat per resource. Since there isn't an obvious mapping +from SPIR-V's descriptor sets, we require a separate mapping provided in the options. +This mapping may have one or more resource end points for each descriptor set + index +pair. + +## Outputs + +In Metal, built-in shader outputs can not be nested into structures within +the output struct. If there is a structure in the outputs, and it contains any built-ins, +we move them up to the root output structure that we define ourselves. +!*/ + +use crate::{arena::Handle, proc::ResolveError, FastHashMap}; +use std::{ + io::{Error as IoError, Write}, + string::FromUtf8Error, +}; + +mod keywords; +mod writer; + +pub use writer::Writer; + +#[derive(Clone, Debug, Default, PartialEq)] +pub struct BindTarget { + pub buffer: Option<u8>, + pub texture: Option<u8>, + pub sampler: Option<u8>, + pub mutable: bool, +} + +#[derive(Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub struct BindSource { + pub stage: crate::ShaderStage, + pub group: u32, + pub binding: u32, +} + +pub type BindingMap = FastHashMap<BindSource, BindTarget>; + +enum ResolvedBinding { + BuiltIn(crate::BuiltIn), + Attribute(u32), + Color(u32), + User { prefix: &'static str, index: u32 }, + Resource(BindTarget), +} + +// Note: some of these should be removed in favor of proper IR validation. + +#[derive(Debug)] +pub enum Error { + IO(IoError), + Utf8(FromUtf8Error), + Type(ResolveError), + UnexpectedLocation, + MissingBinding(Handle<crate::GlobalVariable>), + MissingBindTarget(BindSource), + InvalidImageAccess(crate::StorageAccess), + MutabilityViolation(Handle<crate::GlobalVariable>), + BadName(String), + UnexpectedGlobalType(Handle<crate::Type>), + UnimplementedBindTarget(BindTarget), + UnsupportedCompose(Handle<crate::Type>), + UnsupportedBinaryOp(crate::BinaryOperator), + UnexpectedSampleLevel(crate::SampleLevel), + UnsupportedCall(String), + UnsupportedDynamicArrayLength, + UnableToReturnValue(Handle<crate::Expression>), + /// The source IR is not valid. + Validation, +} + +impl From<IoError> for Error { + fn from(e: IoError) -> Self { + Error::IO(e) + } +} + +impl From<FromUtf8Error> for Error { + fn from(e: FromUtf8Error) -> Self { + Error::Utf8(e) + } +} + +impl From<ResolveError> for Error { + fn from(e: ResolveError) -> Self { + Error::Type(e) + } +} + +#[derive(Clone, Copy, Debug)] +enum LocationMode { + VertexInput, + FragmentOutput, + Intermediate, + Uniform, +} + +#[derive(Debug, Default, Clone)] +pub struct Options { + /// (Major, Minor) target version of the Metal Shading Language. + pub lang_version: (u8, u8), + /// Make it possible to link different stages via SPIRV-Cross. + pub spirv_cross_compatibility: bool, + /// Binding model mapping to Metal. + pub binding_map: BindingMap, +} + +impl Options { + fn resolve_binding( + &self, + stage: crate::ShaderStage, + binding: &crate::Binding, + mode: LocationMode, + ) -> Result<ResolvedBinding, Error> { + match *binding { + crate::Binding::BuiltIn(built_in) => Ok(ResolvedBinding::BuiltIn(built_in)), + crate::Binding::Location(index) => match mode { + LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(index)), + LocationMode::FragmentOutput => Ok(ResolvedBinding::Color(index)), + LocationMode::Intermediate => Ok(ResolvedBinding::User { + prefix: if self.spirv_cross_compatibility { + "locn" + } else { + "loc" + }, + index, + }), + LocationMode::Uniform => Err(Error::UnexpectedLocation), + }, + crate::Binding::Resource { group, binding } => { + let source = BindSource { + stage, + group, + binding, + }; + self.binding_map + .get(&source) + .cloned() + .map(ResolvedBinding::Resource) + .ok_or(Error::MissingBindTarget(source)) + } + } + } +} + +impl ResolvedBinding { + fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> { + match *self { + ResolvedBinding::BuiltIn(built_in) => { + use crate::BuiltIn as Bi; + let name = match built_in { + // vertex + Bi::BaseInstance => "base_instance", + Bi::BaseVertex => "base_vertex", + Bi::ClipDistance => "clip_distance", + Bi::InstanceIndex => "instance_id", + Bi::PointSize => "point_size", + Bi::Position => "position", + Bi::VertexIndex => "vertex_id", + // fragment + Bi::FragCoord => "position", + Bi::FragDepth => "depth(any)", + Bi::FrontFacing => "front_facing", + Bi::SampleIndex => "sample_id", + // compute + Bi::GlobalInvocationId => "thread_position_in_grid", + Bi::LocalInvocationId => "thread_position_in_threadgroup", + Bi::LocalInvocationIndex => "thread_index_in_threadgroup", + Bi::WorkGroupId => "threadgroup_position_in_grid", + }; + Ok(write!(out, "{}", name)?) + } + ResolvedBinding::Attribute(index) => Ok(write!(out, "attribute({})", index)?), + ResolvedBinding::Color(index) => Ok(write!(out, "color({})", index)?), + ResolvedBinding::User { prefix, index } => { + Ok(write!(out, "user({}{})", prefix, index)?) + } + ResolvedBinding::Resource(ref target) => { + if let Some(id) = target.buffer { + Ok(write!(out, "buffer({})", id)?) + } else if let Some(id) = target.texture { + Ok(write!(out, "texture({})", id)?) + } else if let Some(id) = target.sampler { + Ok(write!(out, "sampler({})", id)?) + } else { + Err(Error::UnimplementedBindTarget(target.clone())) + } + } + } + } + + fn try_fmt_decorated<W: Write>(&self, out: &mut W, terminator: &str) -> Result<(), Error> { + write!(out, " [[")?; + self.try_fmt(out)?; + write!(out, "]]")?; + write!(out, "{}", terminator)?; + Ok(()) + } +} + +pub fn write_string(module: &crate::Module, options: &Options) -> Result<String, Error> { + let mut w = writer::Writer::new(Vec::new()); + w.write(module, options)?; + Ok(String::from_utf8(w.finish())?) +} diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs new file mode 100644 index 0000000000..e2adbea6b7 --- /dev/null +++ b/third_party/rust/naga/src/back/msl/writer.rs @@ -0,0 +1,990 @@ +use super::{keywords::RESERVED, Error, LocationMode, Options, ResolvedBinding}; +use crate::{ + arena::Handle, + proc::{EntryPointIndex, NameKey, Namer, ResolveContext, Typifier}, + FastHashMap, +}; +use std::{ + fmt::{Display, Error as FmtError, Formatter}, + io::Write, +}; + +struct Level(usize); +impl Level { + fn next(&self) -> Self { + Level(self.0 + 1) + } +} +impl Display for Level { + fn fmt(&self, formatter: &mut Formatter<'_>) -> Result<(), FmtError> { + (0..self.0).map(|_| formatter.write_str("\t")).collect() + } +} + +struct TypedGlobalVariable<'a> { + module: &'a crate::Module, + names: &'a FastHashMap<NameKey, String>, + handle: Handle<crate::GlobalVariable>, + usage: crate::GlobalUse, +} + +impl<'a> TypedGlobalVariable<'a> { + fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> { + let var = &self.module.global_variables[self.handle]; + let name = &self.names[&NameKey::GlobalVariable(self.handle)]; + let ty = &self.module.types[var.ty]; + let ty_name = &self.names[&NameKey::Type(var.ty)]; + + let (space_qualifier, reference) = match ty.inner { + crate::TypeInner::Struct { .. } => match var.class { + crate::StorageClass::Uniform | crate::StorageClass::Storage => { + let space = if self.usage.contains(crate::GlobalUse::STORE) { + "device " + } else { + "constant " + }; + (space, "&") + } + _ => ("", ""), + }, + _ => ("", ""), + }; + Ok(write!( + out, + "{}{}{} {}", + space_qualifier, ty_name, reference, name + )?) + } +} + +pub struct Writer<W> { + out: W, + names: FastHashMap<NameKey, String>, + typifier: Typifier, +} + +fn scalar_kind_string(kind: crate::ScalarKind) -> &'static str { + match kind { + crate::ScalarKind::Float => "float", + crate::ScalarKind::Sint => "int", + crate::ScalarKind::Uint => "uint", + crate::ScalarKind::Bool => "bool", + } +} + +fn vector_size_string(size: crate::VectorSize) -> &'static str { + match size { + crate::VectorSize::Bi => "2", + crate::VectorSize::Tri => "3", + crate::VectorSize::Quad => "4", + } +} + +const OUTPUT_STRUCT_NAME: &str = "output"; +const LOCATION_INPUT_STRUCT_NAME: &str = "input"; +const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; + +fn separate(is_last: bool) -> &'static str { + if is_last { + "" + } else { + "," + } +} + +enum FunctionOrigin { + Handle(Handle<crate::Function>), + EntryPoint(EntryPointIndex), +} + +struct ExpressionContext<'a> { + function: &'a crate::Function, + origin: FunctionOrigin, + module: &'a crate::Module, +} + +impl<W: Write> Writer<W> { + /// Creates a new `Writer` instance. + pub fn new(out: W) -> Self { + Writer { + out, + names: FastHashMap::default(), + typifier: Typifier::new(), + } + } + + /// Finishes writing and returns the output. + pub fn finish(self) -> W { + self.out + } + + fn put_call( + &mut self, + name: &str, + parameters: &[Handle<crate::Expression>], + context: &ExpressionContext, + ) -> Result<(), Error> { + if !name.is_empty() { + write!(self.out, "metal::{}", name)?; + } + write!(self.out, "(")?; + for (i, &handle) in parameters.iter().enumerate() { + if i != 0 { + write!(self.out, ", ")?; + } + self.put_expression(handle, context)?; + } + write!(self.out, ")")?; + Ok(()) + } + + fn put_expression( + &mut self, + expr_handle: Handle<crate::Expression>, + context: &ExpressionContext, + ) -> Result<(), Error> { + let expression = &context.function.expressions[expr_handle]; + log::trace!("expression {:?} = {:?}", expr_handle, expression); + match *expression { + crate::Expression::Access { base, index } => { + self.put_expression(base, context)?; + write!(self.out, "[")?; + self.put_expression(index, context)?; + write!(self.out, "]")?; + } + crate::Expression::AccessIndex { base, index } => { + self.put_expression(base, context)?; + let resolved = self.typifier.get(base, &context.module.types); + match *resolved { + crate::TypeInner::Struct { .. } => { + let base_ty = self.typifier.get_handle(base).unwrap(); + let name = &self.names[&NameKey::StructMember(base_ty, index)]; + write!(self.out, ".{}", name)?; + } + crate::TypeInner::Matrix { .. } | crate::TypeInner::Vector { .. } => { + write!(self.out, ".{}", COMPONENTS[index as usize])?; + } + crate::TypeInner::Array { .. } => { + write!(self.out, "[{}]", index)?; + } + _ => { + // unexpected indexing, should fail validation + } + } + } + crate::Expression::Constant(handle) => self.put_constant(handle, context.module)?, + crate::Expression::Compose { ty, ref components } => { + let inner = &context.module.types[ty].inner; + match *inner { + crate::TypeInner::Vector { size, kind, .. } => { + write!( + self.out, + "{}{}", + scalar_kind_string(kind), + vector_size_string(size) + )?; + self.put_call("", components, context)?; + } + crate::TypeInner::Scalar { width: 4, kind } if components.len() == 1 => { + write!(self.out, "{}", scalar_kind_string(kind),)?; + self.put_call("", components, context)?; + } + _ => return Err(Error::UnsupportedCompose(ty)), + } + } + crate::Expression::FunctionArgument(index) => { + let fun_handle = match context.origin { + FunctionOrigin::Handle(handle) => handle, + FunctionOrigin::EntryPoint(_) => unreachable!(), + }; + let name = &self.names[&NameKey::FunctionArgument(fun_handle, index)]; + write!(self.out, "{}", name)?; + } + crate::Expression::GlobalVariable(handle) => { + let var = &context.module.global_variables[handle]; + match var.class { + crate::StorageClass::Output => { + if let crate::TypeInner::Struct { .. } = context.module.types[var.ty].inner + { + return Ok(()); + } + write!(self.out, "{}.", OUTPUT_STRUCT_NAME)?; + } + crate::StorageClass::Input => { + if let Some(crate::Binding::Location(_)) = var.binding { + write!(self.out, "{}.", LOCATION_INPUT_STRUCT_NAME)?; + } + } + _ => {} + } + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{}", name)?; + } + crate::Expression::LocalVariable(handle) => { + let name_key = match context.origin { + FunctionOrigin::Handle(fun_handle) => { + NameKey::FunctionLocal(fun_handle, handle) + } + FunctionOrigin::EntryPoint(ep_index) => { + NameKey::EntryPointLocal(ep_index, handle) + } + }; + let name = &self.names[&name_key]; + write!(self.out, "{}", name)?; + } + crate::Expression::Load { pointer } => { + //write!(self.out, "*")?; + self.put_expression(pointer, context)?; + } + crate::Expression::ImageSample { + image, + sampler, + coordinate, + level, + depth_ref, + } => { + let op = match depth_ref { + Some(_) => "sample_compare", + None => "sample", + }; + //TODO: handle arrayed images + self.put_expression(image, context)?; + write!(self.out, ".{}(", op)?; + self.put_expression(sampler, context)?; + write!(self.out, ", ")?; + self.put_expression(coordinate, context)?; + if let Some(dref) = depth_ref { + write!(self.out, ", ")?; + self.put_expression(dref, context)?; + } + match level { + crate::SampleLevel::Auto => {} + crate::SampleLevel::Zero => { + write!(self.out, ", level(0)")?; + } + crate::SampleLevel::Exact(h) => { + write!(self.out, ", level(")?; + self.put_expression(h, context)?; + write!(self.out, ")")?; + } + crate::SampleLevel::Bias(h) => { + write!(self.out, ", bias(")?; + self.put_expression(h, context)?; + write!(self.out, ")")?; + } + } + write!(self.out, ")")?; + } + crate::Expression::ImageLoad { + image, + coordinate, + index, + } => { + //TODO: handle arrayed images + self.put_expression(image, context)?; + write!(self.out, ".read(")?; + self.put_expression(coordinate, context)?; + if let Some(index) = index { + write!(self.out, ", ")?; + self.put_expression(index, context)?; + } + write!(self.out, ")")?; + } + crate::Expression::Unary { op, expr } => { + let op_str = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::Not => "!", + }; + write!(self.out, "{}", op_str)?; + self.put_expression(expr, context)?; + } + crate::Expression::Binary { op, left, right } => { + let op_str = match op { + crate::BinaryOperator::Add => "+", + crate::BinaryOperator::Subtract => "-", + crate::BinaryOperator::Multiply => "*", + crate::BinaryOperator::Divide => "/", + crate::BinaryOperator::Modulo => "%", + crate::BinaryOperator::Equal => "==", + crate::BinaryOperator::NotEqual => "!=", + crate::BinaryOperator::Less => "<", + crate::BinaryOperator::LessEqual => "<=", + crate::BinaryOperator::Greater => "==", + crate::BinaryOperator::GreaterEqual => ">=", + crate::BinaryOperator::And => "&", + _ => return Err(Error::UnsupportedBinaryOp(op)), + }; + let kind = self + .typifier + .get(left, &context.module.types) + .scalar_kind() + .ok_or(Error::UnsupportedBinaryOp(op))?; + if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float { + write!(self.out, "fmod(")?; + self.put_expression(left, context)?; + write!(self.out, ", ")?; + self.put_expression(right, context)?; + write!(self.out, ")")?; + } else { + //write!(self.out, "(")?; + self.put_expression(left, context)?; + write!(self.out, " {} ", op_str)?; + self.put_expression(right, context)?; + //write!(self.out, ")")?; + } + } + crate::Expression::Select { + condition, + accept, + reject, + } => { + write!(self.out, "(")?; + self.put_expression(condition, context)?; + write!(self.out, " ? ")?; + self.put_expression(accept, context)?; + write!(self.out, " : ")?; + self.put_expression(reject, context)?; + write!(self.out, ")")?; + } + crate::Expression::Intrinsic { fun, argument } => { + let op = match fun { + crate::IntrinsicFunction::Any => "any", + crate::IntrinsicFunction::All => "all", + crate::IntrinsicFunction::IsNan => "", + crate::IntrinsicFunction::IsInf => "", + crate::IntrinsicFunction::IsFinite => "", + crate::IntrinsicFunction::IsNormal => "", + }; + self.put_call(op, &[argument], context)?; + } + crate::Expression::Transpose(expr) => { + self.put_call("transpose", &[expr], context)?; + } + crate::Expression::DotProduct(a, b) => { + self.put_call("dot", &[a, b], context)?; + } + crate::Expression::CrossProduct(a, b) => { + self.put_call("cross", &[a, b], context)?; + } + crate::Expression::As { + expr, + kind, + convert, + } => { + let scalar = scalar_kind_string(kind); + let size = match *self.typifier.get(expr, &context.module.types) { + crate::TypeInner::Scalar { .. } => "", + crate::TypeInner::Vector { size, .. } => vector_size_string(size), + _ => return Err(Error::Validation), + }; + let op = if convert { "static_cast" } else { "as_type" }; + write!(self.out, "{}<{}{}>(", op, scalar, size)?; + self.put_expression(expr, context)?; + write!(self.out, ")")?; + } + crate::Expression::Derivative { axis, expr } => { + let op = match axis { + crate::DerivativeAxis::X => "dfdx", + crate::DerivativeAxis::Y => "dfdy", + crate::DerivativeAxis::Width => "fwidth", + }; + self.put_call(op, &[expr], context)?; + } + crate::Expression::Call { + origin: crate::FunctionOrigin::Local(handle), + ref arguments, + } => { + let name = &self.names[&NameKey::Function(handle)]; + write!(self.out, "{}", name)?; + self.put_call("", arguments, context)?; + } + crate::Expression::Call { + origin: crate::FunctionOrigin::External(ref name), + ref arguments, + } => match name.as_str() { + "atan2" | "cos" | "distance" | "length" | "mix" | "normalize" | "sin" => { + self.put_call(name, arguments, context)?; + } + "fclamp" => { + self.put_call("clamp", arguments, context)?; + } + other => return Err(Error::UnsupportedCall(other.to_owned())), + }, + crate::Expression::ArrayLength(expr) => match *self + .typifier + .get(expr, &context.module.types) + { + crate::TypeInner::Array { + size: crate::ArraySize::Constant(const_handle), + .. + } => { + self.put_constant(const_handle, context.module)?; + } + crate::TypeInner::Array { .. } => return Err(Error::UnsupportedDynamicArrayLength), + _ => return Err(Error::Validation), + }, + } + Ok(()) + } + + fn put_constant( + &mut self, + handle: Handle<crate::Constant>, + module: &crate::Module, + ) -> Result<(), Error> { + let constant = &module.constants[handle]; + match constant.inner { + crate::ConstantInner::Sint(value) => { + write!(self.out, "{}", value)?; + } + crate::ConstantInner::Uint(value) => { + write!(self.out, "{}", value)?; + } + crate::ConstantInner::Float(value) => { + write!(self.out, "{}", value)?; + if value.fract() == 0.0 { + write!(self.out, ".0")?; + } + } + crate::ConstantInner::Bool(value) => { + write!(self.out, "{}", value)?; + } + crate::ConstantInner::Composite(ref constituents) => { + let ty_name = &self.names[&NameKey::Type(constant.ty)]; + write!(self.out, "{}(", ty_name)?; + for (i, &handle) in constituents.iter().enumerate() { + if i != 0 { + write!(self.out, ", ")?; + } + self.put_constant(handle, module)?; + } + write!(self.out, ")")?; + } + } + Ok(()) + } + + fn put_block( + &mut self, + level: Level, + statements: &[crate::Statement], + context: &ExpressionContext, + return_value: Option<&str>, + ) -> Result<(), Error> { + for statement in statements { + log::trace!("statement[{}] {:?}", level.0, statement); + match *statement { + crate::Statement::Block(ref block) => { + if !block.is_empty() { + writeln!(self.out, "{}{{", level)?; + self.put_block(level.next(), block, context, return_value)?; + writeln!(self.out, "{}}}", level)?; + } + } + crate::Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{}if (", level)?; + self.put_expression(condition, context)?; + writeln!(self.out, ") {{")?; + self.put_block(level.next(), accept, context, return_value)?; + if !reject.is_empty() { + writeln!(self.out, "{}}} else {{", level)?; + self.put_block(level.next(), reject, context, return_value)?; + } + writeln!(self.out, "{}}}", level)?; + } + crate::Statement::Switch { + selector, + ref cases, + ref default, + } => { + write!(self.out, "{}switch(", level)?; + self.put_expression(selector, context)?; + writeln!(self.out, ") {{")?; + let lcase = level.next(); + for (&value, &(ref block, ref fall_through)) in cases.iter() { + writeln!(self.out, "{}case {}: {{", lcase, value)?; + self.put_block(lcase.next(), block, context, return_value)?; + if fall_through.is_none() { + writeln!(self.out, "{}break;", lcase.next())?; + } + writeln!(self.out, "{}}}", lcase)?; + } + writeln!(self.out, "{}default: {{", lcase)?; + self.put_block(lcase.next(), default, context, return_value)?; + writeln!(self.out, "{}}}", lcase)?; + writeln!(self.out, "{}}}", level)?; + } + crate::Statement::Loop { + ref body, + ref continuing, + } => { + writeln!(self.out, "{}while(true) {{", level)?; + self.put_block(level.next(), body, context, return_value)?; + if !continuing.is_empty() { + //TODO + } + writeln!(self.out, "{}}}", level)?; + } + crate::Statement::Break => { + writeln!(self.out, "{}break;", level)?; + } + crate::Statement::Continue => { + writeln!(self.out, "{}continue;", level)?; + } + crate::Statement::Return { + value: Some(expr_handle), + } => { + write!(self.out, "{}return ", level)?; + self.put_expression(expr_handle, context)?; + writeln!(self.out, ";")?; + } + crate::Statement::Return { value: None } => { + if let Some(string) = return_value { + writeln!(self.out, "{}return {};", level, string)?; + } + } + crate::Statement::Kill => { + writeln!(self.out, "{}discard_fragment();", level)?; + } + crate::Statement::Store { pointer, value } => { + //write!(self.out, "\t*")?; + write!(self.out, "{}", level)?; + self.put_expression(pointer, context)?; + write!(self.out, " = ")?; + self.put_expression(value, context)?; + writeln!(self.out, ";")?; + } + } + } + Ok(()) + } + + pub fn write(&mut self, module: &crate::Module, options: &Options) -> Result<(), Error> { + self.names.clear(); + Namer::process(module, RESERVED, &mut self.names); + + writeln!(self.out, "#include <metal_stdlib>")?; + writeln!(self.out, "#include <simd/simd.h>")?; + + writeln!(self.out)?; + self.write_type_defs(module)?; + + writeln!(self.out)?; + self.write_functions(module, options)?; + + Ok(()) + } + + fn write_type_defs(&mut self, module: &crate::Module) -> Result<(), Error> { + for (handle, ty) in module.types.iter() { + let name = &self.names[&NameKey::Type(handle)]; + match ty.inner { + crate::TypeInner::Scalar { kind, .. } => { + write!(self.out, "typedef {} {}", scalar_kind_string(kind), name)?; + } + crate::TypeInner::Vector { size, kind, .. } => { + write!( + self.out, + "typedef {}{} {}", + scalar_kind_string(kind), + vector_size_string(size), + name + )?; + } + crate::TypeInner::Matrix { columns, rows, .. } => { + write!( + self.out, + "typedef {}{}x{} {}", + scalar_kind_string(crate::ScalarKind::Float), + vector_size_string(columns), + vector_size_string(rows), + name + )?; + } + crate::TypeInner::Pointer { base, class } => { + use crate::StorageClass as Sc; + let base_name = &self.names[&NameKey::Type(base)]; + let class_name = match class { + Sc::Input | Sc::Output => continue, + Sc::Uniform => "constant", + Sc::Storage => "device", + Sc::Handle + | Sc::Private + | Sc::Function + | Sc::WorkGroup + | Sc::PushConstant => "", + }; + write!(self.out, "typedef {} {} *{}", class_name, base_name, name)?; + } + crate::TypeInner::Array { + base, + size, + stride: _, + } => { + let base_name = &self.names[&NameKey::Type(base)]; + write!(self.out, "typedef {} {}[", base_name, name)?; + match size { + crate::ArraySize::Constant(const_handle) => { + self.put_constant(const_handle, module)?; + write!(self.out, "]")?; + } + crate::ArraySize::Dynamic => write!(self.out, "1]")?, + } + } + crate::TypeInner::Struct { ref members } => { + writeln!(self.out, "struct {} {{", name)?; + for (index, member) in members.iter().enumerate() { + let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; + let base_name = &self.names[&NameKey::Type(member.ty)]; + write!(self.out, "\t{} {}", base_name, member_name)?; + match member.origin { + crate::MemberOrigin::Empty => {} + crate::MemberOrigin::BuiltIn(built_in) => { + ResolvedBinding::BuiltIn(built_in) + .try_fmt_decorated(&mut self.out, "")?; + } + crate::MemberOrigin::Offset(_) => { + //TODO + } + } + writeln!(self.out, ";")?; + } + write!(self.out, "}}")?; + } + crate::TypeInner::Image { + dim, + arrayed, + class, + } => { + let dim_str = match dim { + crate::ImageDimension::D1 => "1d", + crate::ImageDimension::D2 => "2d", + crate::ImageDimension::D3 => "3d", + crate::ImageDimension::Cube => "Cube", + }; + let (texture_str, msaa_str, kind, access) = match class { + crate::ImageClass::Sampled { kind, multi } => { + ("texture", if multi { "_ms" } else { "" }, kind, "sample") + } + crate::ImageClass::Depth => { + ("depth", "", crate::ScalarKind::Float, "sample") + } + crate::ImageClass::Storage(format) => { + let (_, global) = module + .global_variables + .iter() + .find(|(_, var)| var.ty == handle) + .expect("Unable to find a global variable using the image type"); + let access = if global + .storage_access + .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE) + { + "read_write" + } else if global.storage_access.contains(crate::StorageAccess::STORE) { + "write" + } else if global.storage_access.contains(crate::StorageAccess::LOAD) { + "read" + } else { + return Err(Error::InvalidImageAccess(global.storage_access)); + }; + ("texture", "", format.into(), access) + } + }; + let base_name = scalar_kind_string(kind); + let array_str = if arrayed { "_array" } else { "" }; + write!( + self.out, + "typedef {}{}{}{}<{}, access::{}> {}", + texture_str, dim_str, msaa_str, array_str, base_name, access, name + )?; + } + crate::TypeInner::Sampler { comparison: _ } => { + write!(self.out, "typedef sampler {}", name)?; + } + } + writeln!(self.out, ";")?; + } + Ok(()) + } + + fn write_functions(&mut self, module: &crate::Module, options: &Options) -> Result<(), Error> { + for (fun_handle, fun) in module.functions.iter() { + self.typifier.resolve_all( + &fun.expressions, + &module.types, + &ResolveContext { + constants: &module.constants, + global_vars: &module.global_variables, + local_vars: &fun.local_variables, + functions: &module.functions, + arguments: &fun.arguments, + }, + )?; + + let fun_name = &self.names[&NameKey::Function(fun_handle)]; + let result_type_name = match fun.return_type { + Some(ret_ty) => &self.names[&NameKey::Type(ret_ty)], + None => "void", + }; + writeln!(self.out, "{} {}(", result_type_name, fun_name)?; + + for (index, arg) in fun.arguments.iter().enumerate() { + let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)]; + let param_type_name = &self.names[&NameKey::Type(arg.ty)]; + let separator = separate(index + 1 == fun.arguments.len()); + writeln!(self.out, "\t{} {}{}", param_type_name, name, separator)?; + } + writeln!(self.out, ") {{")?; + + for (local_handle, local) in fun.local_variables.iter() { + let ty_name = &self.names[&NameKey::Type(local.ty)]; + let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)]; + write!(self.out, "\t{} {}", ty_name, local_name)?; + if let Some(value) = local.init { + write!(self.out, " = ")?; + self.put_constant(value, module)?; + } + writeln!(self.out, ";")?; + } + + let context = ExpressionContext { + function: fun, + origin: FunctionOrigin::Handle(fun_handle), + module, + }; + self.put_block(Level(1), &fun.body, &context, None)?; + writeln!(self.out, "}}")?; + } + + for (ep_index, (&(stage, _), ep)) in module.entry_points.iter().enumerate() { + let fun = &ep.function; + self.typifier.resolve_all( + &fun.expressions, + &module.types, + &ResolveContext { + constants: &module.constants, + global_vars: &module.global_variables, + local_vars: &fun.local_variables, + functions: &module.functions, + arguments: &fun.arguments, + }, + )?; + + // find the entry point(s) and inputs/outputs + let mut last_used_global = None; + for ((handle, var), &usage) in module.global_variables.iter().zip(&fun.global_usage) { + match var.class { + crate::StorageClass::Input => { + if let Some(crate::Binding::Location(_)) = var.binding { + continue; + } + } + crate::StorageClass::Output => continue, + _ => {} + } + if !usage.is_empty() { + last_used_global = Some(handle); + } + } + + let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)]; + let output_name = format!("{}Output", fun_name); + let location_input_name = format!("{}Input", fun_name); + + let (em_str, in_mode, out_mode) = match stage { + crate::ShaderStage::Vertex => ( + "vertex", + LocationMode::VertexInput, + LocationMode::Intermediate, + ), + crate::ShaderStage::Fragment { .. } => ( + "fragment", + LocationMode::Intermediate, + LocationMode::FragmentOutput, + ), + crate::ShaderStage::Compute { .. } => { + ("kernel", LocationMode::Uniform, LocationMode::Uniform) + } + }; + + let return_value = match stage { + crate::ShaderStage::Vertex | crate::ShaderStage::Fragment => { + // make dedicated input/output structs + writeln!(self.out, "struct {} {{", location_input_name)?; + + for ((handle, var), &usage) in + module.global_variables.iter().zip(&fun.global_usage) + { + if var.class != crate::StorageClass::Input + || !usage.contains(crate::GlobalUse::LOAD) + { + continue; + } + // if it's a struct, lift all the built-in contents up to the root + if let crate::TypeInner::Struct { ref members } = module.types[var.ty].inner + { + for (index, member) in members.iter().enumerate() { + if let crate::MemberOrigin::BuiltIn(built_in) = member.origin { + let name = + &self.names[&NameKey::StructMember(var.ty, index as u32)]; + let ty_name = &self.names[&NameKey::Type(member.ty)]; + write!(self.out, "\t{} {}", ty_name, name)?; + ResolvedBinding::BuiltIn(built_in) + .try_fmt_decorated(&mut self.out, ";\n")?; + } + } + } else if let Some(ref binding @ crate::Binding::Location(_)) = var.binding + { + let tyvar = TypedGlobalVariable { + module, + names: &self.names, + handle, + usage: crate::GlobalUse::empty(), + }; + let resolved = options.resolve_binding(stage, binding, in_mode)?; + + write!(self.out, "\t")?; + tyvar.try_fmt(&mut self.out)?; + resolved.try_fmt_decorated(&mut self.out, ";\n")?; + } + } + writeln!(self.out, "}};")?; + + writeln!(self.out, "struct {} {{", output_name)?; + for ((handle, var), &usage) in + module.global_variables.iter().zip(&fun.global_usage) + { + if var.class != crate::StorageClass::Output + || !usage.contains(crate::GlobalUse::STORE) + { + continue; + } + // if it's a struct, lift all the built-in contents up to the root + if let crate::TypeInner::Struct { ref members } = module.types[var.ty].inner + { + for (index, member) in members.iter().enumerate() { + let name = + &self.names[&NameKey::StructMember(var.ty, index as u32)]; + let ty_name = &self.names[&NameKey::Type(member.ty)]; + match member.origin { + crate::MemberOrigin::Empty => {} + crate::MemberOrigin::BuiltIn(built_in) => { + write!(self.out, "\t{} {}", ty_name, name)?; + ResolvedBinding::BuiltIn(built_in) + .try_fmt_decorated(&mut self.out, ";\n")?; + } + crate::MemberOrigin::Offset(_) => { + //TODO + } + } + } + } else { + let tyvar = TypedGlobalVariable { + module, + names: &self.names, + handle, + usage: crate::GlobalUse::empty(), + }; + write!(self.out, "\t")?; + tyvar.try_fmt(&mut self.out)?; + if let Some(ref binding) = var.binding { + let resolved = options.resolve_binding(stage, binding, out_mode)?; + resolved.try_fmt_decorated(&mut self.out, "")?; + } + writeln!(self.out, ";")?; + } + } + writeln!(self.out, "}};")?; + + writeln!(self.out, "{} {} {}(", em_str, output_name, fun_name)?; + let separator = separate(last_used_global.is_none()); + writeln!( + self.out, + "\t{} {} [[stage_in]]{}", + location_input_name, LOCATION_INPUT_STRUCT_NAME, separator + )?; + + Some(OUTPUT_STRUCT_NAME) + } + crate::ShaderStage::Compute => { + writeln!(self.out, "{} void {}(", em_str, fun_name)?; + None + } + }; + + for ((handle, var), &usage) in module.global_variables.iter().zip(&fun.global_usage) { + if usage.is_empty() || var.class == crate::StorageClass::Output { + continue; + } + if var.class == crate::StorageClass::Input { + if let Some(crate::Binding::Location(_)) = var.binding { + // location inputs are put into a separate struct + continue; + } + } + let loc_mode = match (stage, var.class) { + (crate::ShaderStage::Vertex, crate::StorageClass::Input) => { + LocationMode::VertexInput + } + (crate::ShaderStage::Vertex, crate::StorageClass::Output) + | (crate::ShaderStage::Fragment { .. }, crate::StorageClass::Input) => { + LocationMode::Intermediate + } + (crate::ShaderStage::Fragment { .. }, crate::StorageClass::Output) => { + LocationMode::FragmentOutput + } + _ => LocationMode::Uniform, + }; + let resolved = + options.resolve_binding(stage, var.binding.as_ref().unwrap(), loc_mode)?; + let tyvar = TypedGlobalVariable { + module, + names: &self.names, + handle, + usage, + }; + let separator = separate(last_used_global == Some(handle)); + write!(self.out, "\t")?; + tyvar.try_fmt(&mut self.out)?; + resolved.try_fmt_decorated(&mut self.out, separator)?; + if let Some(value) = var.init { + write!(self.out, " = ")?; + self.put_constant(value, module)?; + } + writeln!(self.out)?; + } + writeln!(self.out, ") {{")?; + + match stage { + crate::ShaderStage::Vertex | crate::ShaderStage::Fragment => { + writeln!(self.out, "\t{} {};", output_name, OUTPUT_STRUCT_NAME)?; + } + crate::ShaderStage::Compute => {} + } + for (local_handle, local) in fun.local_variables.iter() { + let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)]; + let ty_name = &self.names[&NameKey::Type(local.ty)]; + write!(self.out, "\t{} {}", ty_name, name)?; + if let Some(value) = local.init { + write!(self.out, " = ")?; + self.put_constant(value, module)?; + } + writeln!(self.out, ";")?; + } + + let context = ExpressionContext { + function: fun, + origin: FunctionOrigin::EntryPoint(ep_index as _), + module, + }; + self.put_block(Level(1), &fun.body, &context, return_value)?; + writeln!(self.out, "}}")?; + } + + Ok(()) + } +} diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs new file mode 100644 index 0000000000..5facbe8b69 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/helpers.rs @@ -0,0 +1,20 @@ +use spirv::Word; + +pub(crate) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> { + bytes + .chunks(4) + .map(|chars| chars.iter().rev().fold(0u32, |u, c| (u << 8) | *c as u32)) + .collect() +} + +pub(crate) fn string_to_words(input: &str) -> Vec<Word> { + let bytes = input.as_bytes(); + let mut words = bytes_to_words(bytes); + + if bytes.len() % 4 == 0 { + // nul-termination + words.push(0x0u32); + } + + words +} diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs new file mode 100644 index 0000000000..ab8e56844a --- /dev/null +++ b/third_party/rust/naga/src/back/spv/instructions.rs @@ -0,0 +1,708 @@ +use crate::back::spv::{helpers, Instruction}; +use spirv::{Op, Word}; + +pub(super) enum Signedness { + Unsigned = 0, + Signed = 1, +} + +// +// Debug Instructions +// + +pub(super) fn instruction_source( + source_language: spirv::SourceLanguage, + version: u32, +) -> Instruction { + let mut instruction = Instruction::new(Op::Source); + instruction.add_operand(source_language as u32); + instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes())); + instruction +} + +pub(super) fn instruction_name(target_id: Word, name: &str) -> Instruction { + let mut instruction = Instruction::new(Op::Name); + instruction.add_operand(target_id); + instruction.add_operands(helpers::string_to_words(name)); + instruction +} + +// +// Annotation Instructions +// + +pub(super) fn instruction_decorate( + target_id: Word, + decoration: spirv::Decoration, + operands: &[Word], +) -> Instruction { + let mut instruction = Instruction::new(Op::Decorate); + instruction.add_operand(target_id); + instruction.add_operand(decoration as u32); + + for operand in operands { + instruction.add_operand(*operand) + } + + instruction +} + +// +// Extension Instructions +// + +pub(super) fn instruction_ext_inst_import(id: Word, name: &str) -> Instruction { + let mut instruction = Instruction::new(Op::ExtInstImport); + instruction.set_result(id); + instruction.add_operands(helpers::string_to_words(name)); + instruction +} + +// +// Mode-Setting Instructions +// + +pub(super) fn instruction_memory_model( + addressing_model: spirv::AddressingModel, + memory_model: spirv::MemoryModel, +) -> Instruction { + let mut instruction = Instruction::new(Op::MemoryModel); + instruction.add_operand(addressing_model as u32); + instruction.add_operand(memory_model as u32); + instruction +} + +pub(super) fn instruction_entry_point( + execution_model: spirv::ExecutionModel, + entry_point_id: Word, + name: &str, + interface_ids: &[Word], +) -> Instruction { + let mut instruction = Instruction::new(Op::EntryPoint); + instruction.add_operand(execution_model as u32); + instruction.add_operand(entry_point_id); + instruction.add_operands(helpers::string_to_words(name)); + + for interface_id in interface_ids { + instruction.add_operand(*interface_id); + } + + instruction +} + +pub(super) fn instruction_execution_mode( + entry_point_id: Word, + execution_mode: spirv::ExecutionMode, +) -> Instruction { + let mut instruction = Instruction::new(Op::ExecutionMode); + instruction.add_operand(entry_point_id); + instruction.add_operand(execution_mode as u32); + instruction +} + +pub(super) fn instruction_capability(capability: spirv::Capability) -> Instruction { + let mut instruction = Instruction::new(Op::Capability); + instruction.add_operand(capability as u32); + instruction +} + +// +// Type-Declaration Instructions +// + +pub(super) fn instruction_type_void(id: Word) -> Instruction { + let mut instruction = Instruction::new(Op::TypeVoid); + instruction.set_result(id); + instruction +} + +pub(super) fn instruction_type_bool(id: Word) -> Instruction { + let mut instruction = Instruction::new(Op::TypeBool); + instruction.set_result(id); + instruction +} + +pub(super) fn instruction_type_int(id: Word, width: Word, signedness: Signedness) -> Instruction { + let mut instruction = Instruction::new(Op::TypeInt); + instruction.set_result(id); + instruction.add_operand(width); + instruction.add_operand(signedness as u32); + instruction +} + +pub(super) fn instruction_type_float(id: Word, width: Word) -> Instruction { + let mut instruction = Instruction::new(Op::TypeFloat); + instruction.set_result(id); + instruction.add_operand(width); + instruction +} + +pub(super) fn instruction_type_vector( + id: Word, + component_type_id: Word, + component_count: crate::VectorSize, +) -> Instruction { + let mut instruction = Instruction::new(Op::TypeVector); + instruction.set_result(id); + instruction.add_operand(component_type_id); + instruction.add_operand(component_count as u32); + instruction +} + +pub(super) fn instruction_type_matrix( + id: Word, + column_type_id: Word, + column_count: crate::VectorSize, +) -> Instruction { + let mut instruction = Instruction::new(Op::TypeMatrix); + instruction.set_result(id); + instruction.add_operand(column_type_id); + instruction.add_operand(column_count as u32); + instruction +} + +pub(super) fn instruction_type_image( + id: Word, + sampled_type_id: Word, + dim: spirv::Dim, + arrayed: bool, + image_class: crate::ImageClass, +) -> Instruction { + let mut instruction = Instruction::new(Op::TypeImage); + instruction.set_result(id); + instruction.add_operand(sampled_type_id); + instruction.add_operand(dim as u32); + + instruction.add_operand(match image_class { + crate::ImageClass::Depth => 1, + _ => 0, + }); + instruction.add_operand(arrayed as u32); + instruction.add_operand(match image_class { + crate::ImageClass::Sampled { multi: true, .. } => 1, + _ => 0, + }); + instruction.add_operand(match image_class { + crate::ImageClass::Sampled { .. } => 1, + _ => 0, + }); + + let format = match image_class { + crate::ImageClass::Storage(format) => match format { + crate::StorageFormat::R8Unorm => spirv::ImageFormat::R8, + crate::StorageFormat::R8Snorm => spirv::ImageFormat::R8Snorm, + crate::StorageFormat::R8Uint => spirv::ImageFormat::R8ui, + crate::StorageFormat::R8Sint => spirv::ImageFormat::R8i, + crate::StorageFormat::R16Uint => spirv::ImageFormat::R16ui, + crate::StorageFormat::R16Sint => spirv::ImageFormat::R16i, + crate::StorageFormat::R16Float => spirv::ImageFormat::R16f, + crate::StorageFormat::Rg8Unorm => spirv::ImageFormat::Rg8, + crate::StorageFormat::Rg8Snorm => spirv::ImageFormat::Rg8Snorm, + crate::StorageFormat::Rg8Uint => spirv::ImageFormat::Rg8ui, + crate::StorageFormat::Rg8Sint => spirv::ImageFormat::Rg8i, + crate::StorageFormat::R32Uint => spirv::ImageFormat::R32ui, + crate::StorageFormat::R32Sint => spirv::ImageFormat::R32i, + crate::StorageFormat::R32Float => spirv::ImageFormat::R32f, + crate::StorageFormat::Rg16Uint => spirv::ImageFormat::Rg16ui, + crate::StorageFormat::Rg16Sint => spirv::ImageFormat::Rg16i, + crate::StorageFormat::Rg16Float => spirv::ImageFormat::Rg16f, + crate::StorageFormat::Rgba8Unorm => spirv::ImageFormat::Rgba8, + crate::StorageFormat::Rgba8Snorm => spirv::ImageFormat::Rgba8Snorm, + crate::StorageFormat::Rgba8Uint => spirv::ImageFormat::Rgba8ui, + crate::StorageFormat::Rgba8Sint => spirv::ImageFormat::Rgba8i, + crate::StorageFormat::Rgb10a2Unorm => spirv::ImageFormat::Rgb10a2ui, + crate::StorageFormat::Rg11b10Float => spirv::ImageFormat::R11fG11fB10f, + crate::StorageFormat::Rg32Uint => spirv::ImageFormat::Rg32ui, + crate::StorageFormat::Rg32Sint => spirv::ImageFormat::Rg32i, + crate::StorageFormat::Rg32Float => spirv::ImageFormat::Rg32f, + crate::StorageFormat::Rgba16Uint => spirv::ImageFormat::Rgba16ui, + crate::StorageFormat::Rgba16Sint => spirv::ImageFormat::Rgba16i, + crate::StorageFormat::Rgba16Float => spirv::ImageFormat::Rgba16f, + crate::StorageFormat::Rgba32Uint => spirv::ImageFormat::Rgba32ui, + crate::StorageFormat::Rgba32Sint => spirv::ImageFormat::Rgba32i, + crate::StorageFormat::Rgba32Float => spirv::ImageFormat::Rgba32f, + }, + _ => spirv::ImageFormat::Unknown, + }; + + instruction.add_operand(format as u32); + instruction +} + +pub(super) fn instruction_type_sampler(id: Word) -> Instruction { + let mut instruction = Instruction::new(Op::TypeSampler); + instruction.set_result(id); + instruction +} + +pub(super) fn instruction_type_sampled_image(id: Word, image_type_id: Word) -> Instruction { + let mut instruction = Instruction::new(Op::TypeSampledImage); + instruction.set_result(id); + instruction.add_operand(image_type_id); + instruction +} + +pub(super) fn instruction_type_array( + id: Word, + element_type_id: Word, + length_id: Word, +) -> Instruction { + let mut instruction = Instruction::new(Op::TypeArray); + instruction.set_result(id); + instruction.add_operand(element_type_id); + instruction.add_operand(length_id); + instruction +} + +pub(super) fn instruction_type_runtime_array(id: Word, element_type_id: Word) -> Instruction { + let mut instruction = Instruction::new(Op::TypeRuntimeArray); + instruction.set_result(id); + instruction.add_operand(element_type_id); + instruction +} + +pub(super) fn instruction_type_struct(id: Word, member_ids: &[Word]) -> Instruction { + let mut instruction = Instruction::new(Op::TypeStruct); + instruction.set_result(id); + + for member_id in member_ids { + instruction.add_operand(*member_id) + } + + instruction +} + +pub(super) fn instruction_type_pointer( + id: Word, + storage_class: spirv::StorageClass, + type_id: Word, +) -> Instruction { + let mut instruction = Instruction::new(Op::TypePointer); + instruction.set_result(id); + instruction.add_operand(storage_class as u32); + instruction.add_operand(type_id); + instruction +} + +pub(super) fn instruction_type_function( + id: Word, + return_type_id: Word, + parameter_ids: &[Word], +) -> Instruction { + let mut instruction = Instruction::new(Op::TypeFunction); + instruction.set_result(id); + instruction.add_operand(return_type_id); + + for parameter_id in parameter_ids { + instruction.add_operand(*parameter_id); + } + + instruction +} + +// +// Constant-Creation Instructions +// + +pub(super) fn instruction_constant_true(result_type_id: Word, id: Word) -> Instruction { + let mut instruction = Instruction::new(Op::ConstantTrue); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction +} + +pub(super) fn instruction_constant_false(result_type_id: Word, id: Word) -> Instruction { + let mut instruction = Instruction::new(Op::ConstantFalse); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction +} + +pub(super) fn instruction_constant(result_type_id: Word, id: Word, values: &[Word]) -> Instruction { + let mut instruction = Instruction::new(Op::Constant); + instruction.set_type(result_type_id); + instruction.set_result(id); + + for value in values { + instruction.add_operand(*value); + } + + instruction +} + +pub(super) fn instruction_constant_composite( + result_type_id: Word, + id: Word, + constituent_ids: &[Word], +) -> Instruction { + let mut instruction = Instruction::new(Op::ConstantComposite); + instruction.set_type(result_type_id); + instruction.set_result(id); + + for constituent_id in constituent_ids { + instruction.add_operand(*constituent_id); + } + + instruction +} + +// +// Memory Instructions +// + +pub(super) fn instruction_variable( + result_type_id: Word, + id: Word, + storage_class: spirv::StorageClass, + initializer_id: Option<Word>, +) -> Instruction { + let mut instruction = Instruction::new(Op::Variable); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(storage_class as u32); + + if let Some(initializer_id) = initializer_id { + instruction.add_operand(initializer_id); + } + + instruction +} + +pub(super) fn instruction_load( + result_type_id: Word, + id: Word, + pointer_type_id: Word, + memory_access: Option<spirv::MemoryAccess>, +) -> Instruction { + let mut instruction = Instruction::new(Op::Load); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(pointer_type_id); + + if let Some(memory_access) = memory_access { + instruction.add_operand(memory_access.bits()); + } + + instruction +} + +pub(super) fn instruction_store( + pointer_type_id: Word, + object_id: Word, + memory_access: Option<spirv::MemoryAccess>, +) -> Instruction { + let mut instruction = Instruction::new(Op::Store); + instruction.add_operand(pointer_type_id); + instruction.add_operand(object_id); + + if let Some(memory_access) = memory_access { + instruction.add_operand(memory_access.bits()); + } + + instruction +} + +pub(super) fn instruction_access_chain( + result_type_id: Word, + id: Word, + base_id: Word, + index_ids: &[Word], +) -> Instruction { + let mut instruction = Instruction::new(Op::AccessChain); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(base_id); + + for index_id in index_ids { + instruction.add_operand(*index_id); + } + + instruction +} + +// +// Function Instructions +// + +pub(super) fn instruction_function( + return_type_id: Word, + id: Word, + function_control: spirv::FunctionControl, + function_type_id: Word, +) -> Instruction { + let mut instruction = Instruction::new(Op::Function); + instruction.set_type(return_type_id); + instruction.set_result(id); + instruction.add_operand(function_control.bits()); + instruction.add_operand(function_type_id); + instruction +} + +pub(super) fn instruction_function_parameter(result_type_id: Word, id: Word) -> Instruction { + let mut instruction = Instruction::new(Op::FunctionParameter); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction +} + +pub(super) fn instruction_function_end() -> Instruction { + Instruction::new(Op::FunctionEnd) +} + +pub(super) fn instruction_function_call( + result_type_id: Word, + id: Word, + function_id: Word, + argument_ids: &[Word], +) -> Instruction { + let mut instruction = Instruction::new(Op::FunctionCall); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(function_id); + + for argument_id in argument_ids { + instruction.add_operand(*argument_id); + } + + instruction +} + +// +// Image Instructions +// +pub(super) fn instruction_sampled_image( + result_type_id: Word, + id: Word, + image: Word, + sampler: Word, +) -> Instruction { + let mut instruction = Instruction::new(Op::SampledImage); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(image); + instruction.add_operand(sampler); + instruction +} + +pub(super) fn instruction_image_sample_implicit_lod( + result_type_id: Word, + id: Word, + sampled_image: Word, + coordinates: Word, +) -> Instruction { + let mut instruction = Instruction::new(Op::ImageSampleImplicitLod); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(sampled_image); + instruction.add_operand(coordinates); + instruction +} + +// +// Conversion Instructions +// +pub(super) fn instruction_unary( + op: Op, + result_type_id: Word, + id: Word, + value: Word, +) -> Instruction { + let mut instruction = Instruction::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(value); + instruction +} + +// +// Composite Instructions +// + +pub(super) fn instruction_composite_construct( + result_type_id: Word, + id: Word, + constituent_ids: &[Word], +) -> Instruction { + let mut instruction = Instruction::new(Op::CompositeConstruct); + instruction.set_type(result_type_id); + instruction.set_result(id); + + for constituent_id in constituent_ids { + instruction.add_operand(*constituent_id); + } + + instruction +} + +// +// Arithmetic Instructions +// +fn instruction_binary( + op: Op, + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, +) -> Instruction { + let mut instruction = Instruction::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(operand_1); + instruction.add_operand(operand_2); + instruction +} + +pub(super) fn instruction_i_sub( + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, +) -> Instruction { + instruction_binary(Op::ISub, result_type_id, id, operand_1, operand_2) +} + +pub(super) fn instruction_f_sub( + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, +) -> Instruction { + instruction_binary(Op::FSub, result_type_id, id, operand_1, operand_2) +} + +pub(super) fn instruction_i_mul( + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, +) -> Instruction { + instruction_binary(Op::IMul, result_type_id, id, operand_1, operand_2) +} + +pub(super) fn instruction_f_mul( + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, +) -> Instruction { + instruction_binary(Op::FMul, result_type_id, id, operand_1, operand_2) +} + +pub(super) fn instruction_vector_times_scalar( + result_type_id: Word, + id: Word, + vector_type_id: Word, + scalar_type_id: Word, +) -> Instruction { + let mut instruction = Instruction::new(Op::VectorTimesScalar); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(vector_type_id); + instruction.add_operand(scalar_type_id); + instruction +} + +pub(super) fn instruction_matrix_times_scalar( + result_type_id: Word, + id: Word, + matrix_id: Word, + scalar_id: Word, +) -> Instruction { + let mut instruction = Instruction::new(Op::MatrixTimesScalar); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(matrix_id); + instruction.add_operand(scalar_id); + instruction +} + +pub(super) fn instruction_vector_times_matrix( + result_type_id: Word, + id: Word, + vector_id: Word, + matrix_id: Word, +) -> Instruction { + let mut instruction = Instruction::new(Op::VectorTimesMatrix); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(vector_id); + instruction.add_operand(matrix_id); + instruction +} + +pub(super) fn instruction_matrix_times_vector( + result_type_id: Word, + id: Word, + matrix_id: Word, + vector_id: Word, +) -> Instruction { + let mut instruction = Instruction::new(Op::MatrixTimesVector); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(matrix_id); + instruction.add_operand(vector_id); + instruction +} + +pub(super) fn instruction_matrix_times_matrix( + result_type_id: Word, + id: Word, + left_matrix: Word, + right_matrix: Word, +) -> Instruction { + let mut instruction = Instruction::new(Op::MatrixTimesMatrix); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(left_matrix); + instruction.add_operand(right_matrix); + instruction +} + +// +// Bit Instructions +// + +pub(super) fn instruction_bitwise_and( + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, +) -> Instruction { + instruction_binary(Op::BitwiseAnd, result_type_id, id, operand_1, operand_2) +} + +// +// Relational and Logical Instructions +// + +// +// Derivative Instructions +// + +// +// Control-Flow Instructions +// + +pub(super) fn instruction_label(id: Word) -> Instruction { + let mut instruction = Instruction::new(Op::Label); + instruction.set_result(id); + instruction +} + +pub(super) fn instruction_return() -> Instruction { + Instruction::new(Op::Return) +} + +pub(super) fn instruction_return_value(value_id: Word) -> Instruction { + let mut instruction = Instruction::new(Op::ReturnValue); + instruction.add_operand(value_id); + instruction +} + +// +// Atomic Instructions +// + +// +// Primitive Instructions +// diff --git a/third_party/rust/naga/src/back/spv/layout.rs b/third_party/rust/naga/src/back/spv/layout.rs new file mode 100644 index 0000000000..006e785317 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/layout.rs @@ -0,0 +1,91 @@ +use crate::back::spv::{Instruction, LogicalLayout, PhysicalLayout}; +use spirv::*; +use std::iter; + +impl PhysicalLayout { + pub(super) fn new(header: &crate::Header) -> Self { + let version: Word = ((header.version.0 as u32) << 16) + | ((header.version.1 as u32) << 8) + | header.version.2 as u32; + + PhysicalLayout { + magic_number: MAGIC_NUMBER, + version, + generator: header.generator, + bound: 0, + instruction_schema: 0x0u32, + } + } + + pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) { + sink.extend(iter::once(self.magic_number)); + sink.extend(iter::once(self.version)); + sink.extend(iter::once(self.generator)); + sink.extend(iter::once(self.bound)); + sink.extend(iter::once(self.instruction_schema)); + } + + pub(super) fn supports_storage_buffers(&self) -> bool { + self.version >= 0x10300 + } +} + +impl LogicalLayout { + pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) { + sink.extend(self.capabilities.iter().cloned()); + sink.extend(self.extensions.iter().cloned()); + sink.extend(self.ext_inst_imports.iter().cloned()); + sink.extend(self.memory_model.iter().cloned()); + sink.extend(self.entry_points.iter().cloned()); + sink.extend(self.execution_modes.iter().cloned()); + sink.extend(self.debugs.iter().cloned()); + sink.extend(self.annotations.iter().cloned()); + sink.extend(self.declarations.iter().cloned()); + sink.extend(self.function_declarations.iter().cloned()); + sink.extend(self.function_definitions.iter().cloned()); + } +} + +impl Instruction { + pub(super) fn new(op: Op) -> Self { + Instruction { + op, + wc: 1, // Always start at 1 for the first word (OP + WC), + type_id: None, + result_id: None, + operands: vec![], + } + } + + #[allow(clippy::panic)] + pub(super) fn set_type(&mut self, id: Word) { + assert!(self.type_id.is_none(), "Type can only be set once"); + self.type_id = Some(id); + self.wc += 1; + } + + #[allow(clippy::panic)] + pub(super) fn set_result(&mut self, id: Word) { + assert!(self.result_id.is_none(), "Result can only be set once"); + self.result_id = Some(id); + self.wc += 1; + } + + pub(super) fn add_operand(&mut self, operand: Word) { + self.operands.push(operand); + self.wc += 1; + } + + pub(super) fn add_operands(&mut self, operands: Vec<Word>) { + for operand in operands.into_iter() { + self.add_operand(operand) + } + } + + pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) { + sink.extend(Some((self.wc << 16 | self.op as u32) as u32)); + sink.extend(self.type_id); + sink.extend(self.result_id); + sink.extend(self.operands.iter().cloned()); + } +} diff --git a/third_party/rust/naga/src/back/spv/layout_tests.rs b/third_party/rust/naga/src/back/spv/layout_tests.rs new file mode 100644 index 0000000000..37024b238f --- /dev/null +++ b/third_party/rust/naga/src/back/spv/layout_tests.rs @@ -0,0 +1,166 @@ +use crate::back::spv::test_framework::*; +use crate::back::spv::{helpers, Instruction, LogicalLayout, PhysicalLayout}; +use crate::Header; +use spirv::*; + +#[test] +fn test_physical_layout_in_words() { + let header = Header { + generator: 0, + version: (1, 2, 3), + }; + let bound = 5; + + let mut output = vec![]; + let mut layout = PhysicalLayout::new(&header); + layout.bound = bound; + + layout.in_words(&mut output); + + assert_eq!(output[0], spirv::MAGIC_NUMBER); + assert_eq!( + output[1], + to_word(&[header.version.0, header.version.1, header.version.2, 1]) + ); + assert_eq!(output[2], 0); + assert_eq!(output[3], bound); + assert_eq!(output[4], 0); +} + +#[test] +fn test_logical_layout_in_words() { + let mut output = vec![]; + let mut layout = LogicalLayout::default(); + let layout_vectors = 11; + let mut instructions = Vec::with_capacity(layout_vectors); + + let vector_names = &[ + "Capabilities", + "Extensions", + "External Instruction Imports", + "Memory Model", + "Entry Points", + "Execution Modes", + "Debugs", + "Annotations", + "Declarations", + "Function Declarations", + "Function Definitions", + ]; + + for i in 0..layout_vectors { + let mut dummy_instruction = Instruction::new(Op::Constant); + dummy_instruction.set_type((i + 1) as u32); + dummy_instruction.set_result((i + 2) as u32); + dummy_instruction.add_operand((i + 3) as u32); + dummy_instruction.add_operands(helpers::string_to_words( + format!("This is the vector: {}", vector_names[i]).as_str(), + )); + instructions.push(dummy_instruction); + } + + instructions[0].to_words(&mut layout.capabilities); + instructions[1].to_words(&mut layout.extensions); + instructions[2].to_words(&mut layout.ext_inst_imports); + instructions[3].to_words(&mut layout.memory_model); + instructions[4].to_words(&mut layout.entry_points); + instructions[5].to_words(&mut layout.execution_modes); + instructions[6].to_words(&mut layout.debugs); + instructions[7].to_words(&mut layout.annotations); + instructions[8].to_words(&mut layout.declarations); + instructions[9].to_words(&mut layout.function_declarations); + instructions[10].to_words(&mut layout.function_definitions); + + layout.in_words(&mut output); + + let mut index: usize = 0; + for instruction in instructions { + let wc = instruction.wc as usize; + let instruction_output = &output[index..index + wc]; + validate_instruction(instruction_output, &instruction); + index += wc; + } +} + +#[test] +fn test_instruction_set_type() { + let ty = 1; + let mut instruction = Instruction::new(Op::Constant); + assert_eq!(instruction.wc, 1); + + instruction.set_type(ty); + assert_eq!(instruction.type_id.unwrap(), ty); + assert_eq!(instruction.wc, 2); +} + +#[test] +#[should_panic] +fn test_instruction_set_type_twice() { + let ty = 1; + let mut instruction = Instruction::new(Op::Constant); + instruction.set_type(ty); + instruction.set_type(ty); +} + +#[test] +fn test_instruction_set_result() { + let result = 1; + let mut instruction = Instruction::new(Op::Constant); + assert_eq!(instruction.wc, 1); + + instruction.set_result(result); + assert_eq!(instruction.result_id.unwrap(), result); + assert_eq!(instruction.wc, 2); +} + +#[test] +#[should_panic] +fn test_instruction_set_result_twice() { + let result = 1; + let mut instruction = Instruction::new(Op::Constant); + instruction.set_result(result); + instruction.set_result(result); +} + +#[test] +fn test_instruction_add_operand() { + let operand = 1; + let mut instruction = Instruction::new(Op::Constant); + assert_eq!(instruction.operands.len(), 0); + assert_eq!(instruction.wc, 1); + + instruction.add_operand(operand); + assert_eq!(instruction.operands.len(), 1); + assert_eq!(instruction.wc, 2); +} + +#[test] +fn test_instruction_add_operands() { + let operands = vec![1, 2, 3]; + let mut instruction = Instruction::new(Op::Constant); + assert_eq!(instruction.operands.len(), 0); + assert_eq!(instruction.wc, 1); + + instruction.add_operands(operands); + assert_eq!(instruction.operands.len(), 3); + assert_eq!(instruction.wc, 4); +} + +#[test] +fn test_instruction_to_words() { + let ty = 1; + let result = 2; + let operand = 3; + let mut instruction = Instruction::new(Op::Constant); + instruction.set_type(ty); + instruction.set_result(result); + instruction.add_operand(operand); + + let mut output = vec![]; + instruction.to_words(&mut output); + validate_instruction(output.as_slice(), &instruction); +} + +fn to_word(bytes: &[u8]) -> Word { + ((bytes[0] as u32) << 16) | ((bytes[1] as u32) << 8) | bytes[2] as u32 +} diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs new file mode 100644 index 0000000000..15f1598357 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/mod.rs @@ -0,0 +1,52 @@ +mod helpers; +mod instructions; +mod layout; +mod writer; + +#[cfg(test)] +mod test_framework; + +#[cfg(test)] +mod layout_tests; + +pub use writer::Writer; + +use spirv::*; + +bitflags::bitflags! { + pub struct WriterFlags: u32 { + const NONE = 0x0; + const DEBUG = 0x1; + } +} + +struct PhysicalLayout { + magic_number: Word, + version: Word, + generator: Word, + bound: Word, + instruction_schema: Word, +} + +#[derive(Default)] +struct LogicalLayout { + capabilities: Vec<Word>, + extensions: Vec<Word>, + ext_inst_imports: Vec<Word>, + memory_model: Vec<Word>, + entry_points: Vec<Word>, + execution_modes: Vec<Word>, + debugs: Vec<Word>, + annotations: Vec<Word>, + declarations: Vec<Word>, + function_declarations: Vec<Word>, + function_definitions: Vec<Word>, +} + +pub(self) struct Instruction { + op: Op, + wc: u32, + type_id: Option<Word>, + result_id: Option<Word>, + operands: Vec<Word>, +} diff --git a/third_party/rust/naga/src/back/spv/test_framework.rs b/third_party/rust/naga/src/back/spv/test_framework.rs new file mode 100644 index 0000000000..be2fa74fe1 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/test_framework.rs @@ -0,0 +1,27 @@ +pub(super) fn validate_instruction( + words: &[spirv::Word], + instruction: &crate::back::spv::Instruction, +) { + let mut inst_index = 0; + let (wc, op) = ((words[inst_index] >> 16) as u16, words[inst_index] as u16); + inst_index += 1; + + assert_eq!(wc, words.len() as u16); + assert_eq!(op, instruction.op as u16); + + if instruction.type_id.is_some() { + assert_eq!(words[inst_index], instruction.type_id.unwrap()); + inst_index += 1; + } + + if instruction.result_id.is_some() { + assert_eq!(words[inst_index], instruction.result_id.unwrap()); + inst_index += 1; + } + + let mut op_index = 0; + for i in inst_index..wc as usize { + assert_eq!(words[i as usize], instruction.operands[op_index]); + op_index += 1; + } +} diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs new file mode 100644 index 0000000000..f1b43f5289 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/writer.rs @@ -0,0 +1,1776 @@ +/*! Standard Portable Intermediate Representation (SPIR-V) backend !*/ +use super::{Instruction, LogicalLayout, PhysicalLayout, WriterFlags}; +use spirv::Word; +use std::{collections::hash_map::Entry, ops}; +use thiserror::Error; + +const BITS_PER_BYTE: crate::Bytes = 8; + +#[derive(Clone, Debug, Error)] +pub enum Error { + #[error("can't find local variable: {0:?}")] + UnknownLocalVariable(crate::LocalVariable), + #[error("bad image class for op: {0:?}")] + BadImageClass(crate::ImageClass), + #[error("not an image")] + NotImage, + #[error("empty value")] + FeatureNotImplemented(), +} + +struct Block { + label: Option<Instruction>, + body: Vec<Instruction>, + termination: Option<Instruction>, +} + +impl Block { + pub fn new() -> Self { + Block { + label: None, + body: vec![], + termination: None, + } + } +} + +struct LocalVariable { + id: Word, + name: Option<String>, + instruction: Instruction, +} + +struct Function { + signature: Option<Instruction>, + parameters: Vec<Instruction>, + variables: Vec<LocalVariable>, + blocks: Vec<Block>, +} + +impl Function { + pub fn new() -> Self { + Function { + signature: None, + parameters: vec![], + variables: vec![], + blocks: vec![], + } + } + + fn to_words(&self, sink: &mut impl Extend<Word>) { + self.signature.as_ref().unwrap().to_words(sink); + for instruction in self.parameters.iter() { + instruction.to_words(sink); + } + for (index, block) in self.blocks.iter().enumerate() { + block.label.as_ref().unwrap().to_words(sink); + if index == 0 { + for local_var in self.variables.iter() { + local_var.instruction.to_words(sink); + } + } + for instruction in block.body.iter() { + instruction.to_words(sink); + } + block.termination.as_ref().unwrap().to_words(sink); + } + } +} + +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +enum LocalType { + Void, + Scalar { + kind: crate::ScalarKind, + width: crate::Bytes, + }, + Vector { + size: crate::VectorSize, + kind: crate::ScalarKind, + width: crate::Bytes, + }, + Pointer { + base: crate::Handle<crate::Type>, + class: crate::StorageClass, + }, + SampledImage { + image_type: crate::Handle<crate::Type>, + }, +} + +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +enum LookupType { + Handle(crate::Handle<crate::Type>), + Local(LocalType), +} + +fn map_dim(dim: crate::ImageDimension) -> spirv::Dim { + match dim { + crate::ImageDimension::D1 => spirv::Dim::Dim1D, + crate::ImageDimension::D2 => spirv::Dim::Dim2D, + crate::ImageDimension::D3 => spirv::Dim::Dim2D, + crate::ImageDimension::Cube => spirv::Dim::DimCube, + } +} + +#[derive(Debug, PartialEq, Clone, Hash, Eq)] +struct LookupFunctionType { + parameter_type_ids: Vec<Word>, + return_type_id: Word, +} + +enum MaybeOwned<'a, T> { + Owned(T), + Borrowed(&'a T), +} + +impl<'a, T> ops::Deref for MaybeOwned<'a, T> { + type Target = T; + fn deref(&self) -> &T { + match *self { + MaybeOwned::Owned(ref value) => value, + MaybeOwned::Borrowed(reference) => reference, + } + } +} + +enum Dimension { + Scalar, + Vector, + Matrix, +} + +fn get_dimension(ty_inner: &crate::TypeInner) -> Dimension { + match *ty_inner { + crate::TypeInner::Scalar { .. } => Dimension::Scalar, + crate::TypeInner::Vector { .. } => Dimension::Vector, + crate::TypeInner::Matrix { .. } => Dimension::Matrix, + _ => unreachable!(), + } +} + +pub struct Writer { + physical_layout: PhysicalLayout, + logical_layout: LogicalLayout, + id_count: u32, + capabilities: crate::FastHashSet<spirv::Capability>, + debugs: Vec<Instruction>, + annotations: Vec<Instruction>, + writer_flags: WriterFlags, + void_type: Option<u32>, + lookup_type: crate::FastHashMap<LookupType, Word>, + lookup_function: crate::FastHashMap<crate::Handle<crate::Function>, Word>, + lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>, + lookup_constant: crate::FastHashMap<crate::Handle<crate::Constant>, Word>, + lookup_global_variable: crate::FastHashMap<crate::Handle<crate::GlobalVariable>, Word>, +} + +// type alias, for success return of write_expression +type WriteExpressionOutput = (Word, LookupType); + +impl Writer { + pub fn new(header: &crate::Header, writer_flags: WriterFlags) -> Self { + Writer { + physical_layout: PhysicalLayout::new(header), + logical_layout: LogicalLayout::default(), + id_count: 0, + capabilities: crate::FastHashSet::default(), + debugs: vec![], + annotations: vec![], + writer_flags, + void_type: None, + lookup_type: crate::FastHashMap::default(), + lookup_function: crate::FastHashMap::default(), + lookup_function_type: crate::FastHashMap::default(), + lookup_constant: crate::FastHashMap::default(), + lookup_global_variable: crate::FastHashMap::default(), + } + } + + fn generate_id(&mut self) -> Word { + self.id_count += 1; + self.id_count + } + + fn try_add_capabilities(&mut self, capabilities: &[spirv::Capability]) { + for capability in capabilities.iter() { + self.capabilities.insert(*capability); + } + } + + fn get_type_id(&mut self, arena: &crate::Arena<crate::Type>, lookup_ty: LookupType) -> Word { + if let Entry::Occupied(e) = self.lookup_type.entry(lookup_ty) { + *e.get() + } else { + match lookup_ty { + LookupType::Handle(handle) => match arena[handle].inner { + crate::TypeInner::Scalar { kind, width } => self + .get_type_id(arena, LookupType::Local(LocalType::Scalar { kind, width })), + _ => self.write_type_declaration_arena(arena, handle), + }, + LookupType::Local(local_ty) => self.write_type_declaration_local(arena, local_ty), + } + } + } + + fn get_constant_id( + &mut self, + handle: crate::Handle<crate::Constant>, + ir_module: &crate::Module, + ) -> Word { + match self.lookup_constant.entry(handle) { + Entry::Occupied(e) => *e.get(), + _ => { + let (instruction, id) = self.write_constant_type(handle, ir_module); + instruction.to_words(&mut self.logical_layout.declarations); + id + } + } + } + + fn get_global_variable_id( + &mut self, + ir_module: &crate::Module, + handle: crate::Handle<crate::GlobalVariable>, + ) -> Word { + match self.lookup_global_variable.entry(handle) { + Entry::Occupied(e) => *e.get(), + _ => { + let (instruction, id) = self.write_global_variable(ir_module, handle); + instruction.to_words(&mut self.logical_layout.declarations); + id + } + } + } + + fn get_function_return_type( + &mut self, + ty: Option<crate::Handle<crate::Type>>, + arena: &crate::Arena<crate::Type>, + ) -> Word { + match ty { + Some(handle) => self.get_type_id(arena, LookupType::Handle(handle)), + None => match self.void_type { + Some(id) => id, + None => { + let id = self.generate_id(); + self.void_type = Some(id); + super::instructions::instruction_type_void(id) + .to_words(&mut self.logical_layout.declarations); + id + } + }, + } + } + + fn get_pointer_id( + &mut self, + arena: &crate::Arena<crate::Type>, + handle: crate::Handle<crate::Type>, + class: crate::StorageClass, + ) -> Word { + let ty = &arena[handle]; + let ty_id = self.get_type_id(arena, LookupType::Handle(handle)); + match ty.inner { + crate::TypeInner::Pointer { .. } => ty_id, + _ => { + match self + .lookup_type + .entry(LookupType::Local(LocalType::Pointer { + base: handle, + class, + })) { + Entry::Occupied(e) => *e.get(), + _ => { + let id = + self.create_pointer(ty_id, self.parse_to_spirv_storage_class(class)); + self.lookup_type.insert( + LookupType::Local(LocalType::Pointer { + base: handle, + class, + }), + id, + ); + id + } + } + } + } + } + + fn create_pointer(&mut self, ty_id: Word, class: spirv::StorageClass) -> Word { + let id = self.generate_id(); + let instruction = super::instructions::instruction_type_pointer(id, class, ty_id); + instruction.to_words(&mut self.logical_layout.declarations); + id + } + + fn create_constant(&mut self, type_id: Word, value: &[Word]) -> Word { + let id = self.generate_id(); + let instruction = super::instructions::instruction_constant(type_id, id, value); + instruction.to_words(&mut self.logical_layout.declarations); + id + } + + fn write_function( + &mut self, + ir_function: &crate::Function, + ir_module: &crate::Module, + ) -> spirv::Word { + let mut function = Function::new(); + + for (_, variable) in ir_function.local_variables.iter() { + let id = self.generate_id(); + + let init_word = variable + .init + .map(|constant| self.get_constant_id(constant, ir_module)); + + let pointer_id = + self.get_pointer_id(&ir_module.types, variable.ty, crate::StorageClass::Function); + function.variables.push(LocalVariable { + id, + name: variable.name.clone(), + instruction: super::instructions::instruction_variable( + pointer_id, + id, + spirv::StorageClass::Function, + init_word, + ), + }); + } + + let return_type_id = + self.get_function_return_type(ir_function.return_type, &ir_module.types); + let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len()); + + let mut function_parameter_pointer_ids = vec![]; + + for argument in ir_function.arguments.iter() { + let id = self.generate_id(); + let pointer_id = + self.get_pointer_id(&ir_module.types, argument.ty, crate::StorageClass::Function); + + function_parameter_pointer_ids.push(pointer_id); + parameter_type_ids + .push(self.get_type_id(&ir_module.types, LookupType::Handle(argument.ty))); + function + .parameters + .push(super::instructions::instruction_function_parameter( + pointer_id, id, + )); + } + + let lookup_function_type = LookupFunctionType { + return_type_id, + parameter_type_ids, + }; + + let function_id = self.generate_id(); + let function_type = + self.get_function_type(lookup_function_type, function_parameter_pointer_ids); + function.signature = Some(super::instructions::instruction_function( + return_type_id, + function_id, + spirv::FunctionControl::empty(), + function_type, + )); + + self.write_block(&ir_function.body, ir_module, ir_function, &mut function); + + function.to_words(&mut self.logical_layout.function_definitions); + super::instructions::instruction_function_end() + .to_words(&mut self.logical_layout.function_definitions); + + function_id + } + + // TODO Move to instructions module + fn write_entry_point( + &mut self, + entry_point: &crate::EntryPoint, + stage: crate::ShaderStage, + name: &str, + ir_module: &crate::Module, + ) -> Instruction { + let function_id = self.write_function(&entry_point.function, ir_module); + + let exec_model = match stage { + crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex, + crate::ShaderStage::Fragment { .. } => spirv::ExecutionModel::Fragment, + crate::ShaderStage::Compute { .. } => spirv::ExecutionModel::GLCompute, + }; + + let mut interface_ids = vec![]; + for ((handle, _), &usage) in ir_module + .global_variables + .iter() + .filter(|&(_, var)| { + var.class == crate::StorageClass::Input || var.class == crate::StorageClass::Output + }) + .zip(&entry_point.function.global_usage) + { + if usage.contains(crate::GlobalUse::STORE) || usage.contains(crate::GlobalUse::LOAD) { + let id = self.get_global_variable_id(ir_module, handle); + interface_ids.push(id); + } + } + + self.try_add_capabilities(exec_model.required_capabilities()); + match stage { + crate::ShaderStage::Vertex => {} + crate::ShaderStage::Fragment => { + let execution_mode = spirv::ExecutionMode::OriginUpperLeft; + self.try_add_capabilities(execution_mode.required_capabilities()); + super::instructions::instruction_execution_mode(function_id, execution_mode) + .to_words(&mut self.logical_layout.execution_modes); + } + crate::ShaderStage::Compute => {} + } + + if self.writer_flags.contains(WriterFlags::DEBUG) { + self.debugs + .push(super::instructions::instruction_name(function_id, name)); + } + + super::instructions::instruction_entry_point( + exec_model, + function_id, + name, + interface_ids.as_slice(), + ) + } + + fn write_scalar(&self, id: Word, kind: crate::ScalarKind, width: crate::Bytes) -> Instruction { + let bits = (width * BITS_PER_BYTE) as u32; + match kind { + crate::ScalarKind::Sint => super::instructions::instruction_type_int( + id, + bits, + super::instructions::Signedness::Signed, + ), + crate::ScalarKind::Uint => super::instructions::instruction_type_int( + id, + bits, + super::instructions::Signedness::Unsigned, + ), + crate::ScalarKind::Float => super::instructions::instruction_type_float(id, bits), + crate::ScalarKind::Bool => super::instructions::instruction_type_bool(id), + } + } + + fn parse_to_spirv_storage_class(&self, class: crate::StorageClass) -> spirv::StorageClass { + match class { + crate::StorageClass::Handle => spirv::StorageClass::UniformConstant, + crate::StorageClass::Function => spirv::StorageClass::Function, + crate::StorageClass::Input => spirv::StorageClass::Input, + crate::StorageClass::Output => spirv::StorageClass::Output, + crate::StorageClass::Private => spirv::StorageClass::Private, + crate::StorageClass::Storage if self.physical_layout.supports_storage_buffers() => { + spirv::StorageClass::StorageBuffer + } + crate::StorageClass::Storage | crate::StorageClass::Uniform => { + spirv::StorageClass::Uniform + } + crate::StorageClass::WorkGroup => spirv::StorageClass::Workgroup, + crate::StorageClass::PushConstant => spirv::StorageClass::PushConstant, + } + } + + fn write_type_declaration_local( + &mut self, + arena: &crate::Arena<crate::Type>, + local_ty: LocalType, + ) -> Word { + let id = self.generate_id(); + let instruction = match local_ty { + LocalType::Void => unreachable!(), + LocalType::Scalar { kind, width } => self.write_scalar(id, kind, width), + LocalType::Vector { size, kind, width } => { + let scalar_id = + self.get_type_id(arena, LookupType::Local(LocalType::Scalar { kind, width })); + super::instructions::instruction_type_vector(id, scalar_id, size) + } + LocalType::Pointer { .. } => unimplemented!(), + LocalType::SampledImage { image_type } => { + let image_type_id = self.get_type_id(arena, LookupType::Handle(image_type)); + super::instructions::instruction_type_sampled_image(id, image_type_id) + } + }; + + self.lookup_type.insert(LookupType::Local(local_ty), id); + instruction.to_words(&mut self.logical_layout.declarations); + id + } + + fn write_type_declaration_arena( + &mut self, + arena: &crate::Arena<crate::Type>, + handle: crate::Handle<crate::Type>, + ) -> Word { + let ty = &arena[handle]; + let id = self.generate_id(); + + let instruction = match ty.inner { + crate::TypeInner::Scalar { kind, width } => { + self.lookup_type + .insert(LookupType::Local(LocalType::Scalar { kind, width }), id); + self.write_scalar(id, kind, width) + } + crate::TypeInner::Vector { size, kind, width } => { + let scalar_id = + self.get_type_id(arena, LookupType::Local(LocalType::Scalar { kind, width })); + self.lookup_type.insert( + LookupType::Local(LocalType::Vector { size, kind, width }), + id, + ); + super::instructions::instruction_type_vector(id, scalar_id, size) + } + crate::TypeInner::Matrix { + columns, + rows: _, + width, + } => { + let vector_id = self.get_type_id( + arena, + LookupType::Local(LocalType::Vector { + size: columns, + kind: crate::ScalarKind::Float, + width, + }), + ); + super::instructions::instruction_type_matrix(id, vector_id, columns) + } + crate::TypeInner::Image { + dim, + arrayed, + class, + } => { + let width = 4; + let local_type = match class { + crate::ImageClass::Sampled { kind, multi: _ } => { + LocalType::Scalar { kind, width } + } + crate::ImageClass::Depth => LocalType::Scalar { + kind: crate::ScalarKind::Float, + width, + }, + crate::ImageClass::Storage(format) => LocalType::Scalar { + kind: format.into(), + width, + }, + }; + let type_id = self.get_type_id(arena, LookupType::Local(local_type)); + let dim = map_dim(dim); + self.try_add_capabilities(dim.required_capabilities()); + super::instructions::instruction_type_image(id, type_id, dim, arrayed, class) + } + crate::TypeInner::Sampler { comparison: _ } => { + super::instructions::instruction_type_sampler(id) + } + crate::TypeInner::Array { base, size, stride } => { + if let Some(array_stride) = stride { + self.annotations + .push(super::instructions::instruction_decorate( + id, + spirv::Decoration::ArrayStride, + &[array_stride.get()], + )); + } + + let type_id = self.get_type_id(arena, LookupType::Handle(base)); + match size { + crate::ArraySize::Constant(const_handle) => { + let length_id = self.lookup_constant[&const_handle]; + super::instructions::instruction_type_array(id, type_id, length_id) + } + crate::ArraySize::Dynamic => { + super::instructions::instruction_type_runtime_array(id, type_id) + } + } + } + crate::TypeInner::Struct { ref members } => { + let mut member_ids = Vec::with_capacity(members.len()); + for member in members { + let member_id = self.get_type_id(arena, LookupType::Handle(member.ty)); + member_ids.push(member_id); + } + super::instructions::instruction_type_struct(id, member_ids.as_slice()) + } + crate::TypeInner::Pointer { base, class } => { + let type_id = self.get_type_id(arena, LookupType::Handle(base)); + self.lookup_type + .insert(LookupType::Local(LocalType::Pointer { base, class }), id); + super::instructions::instruction_type_pointer( + id, + self.parse_to_spirv_storage_class(class), + type_id, + ) + } + }; + + self.lookup_type.insert(LookupType::Handle(handle), id); + instruction.to_words(&mut self.logical_layout.declarations); + id + } + + fn write_constant_type( + &mut self, + handle: crate::Handle<crate::Constant>, + ir_module: &crate::Module, + ) -> (Instruction, Word) { + let id = self.generate_id(); + self.lookup_constant.insert(handle, id); + let constant = &ir_module.constants[handle]; + let arena = &ir_module.types; + + match constant.inner { + crate::ConstantInner::Sint(val) => { + let ty = &ir_module.types[constant.ty]; + let type_id = self.get_type_id(arena, LookupType::Handle(constant.ty)); + + let instruction = match ty.inner { + crate::TypeInner::Scalar { kind: _, width } => match width { + 4 => super::instructions::instruction_constant(type_id, id, &[val as u32]), + 8 => { + let (low, high) = ((val >> 32) as u32, val as u32); + super::instructions::instruction_constant(type_id, id, &[low, high]) + } + _ => unreachable!(), + }, + _ => unreachable!(), + }; + (instruction, id) + } + crate::ConstantInner::Uint(val) => { + let ty = &ir_module.types[constant.ty]; + let type_id = self.get_type_id(arena, LookupType::Handle(constant.ty)); + + let instruction = match ty.inner { + crate::TypeInner::Scalar { kind: _, width } => match width { + 4 => super::instructions::instruction_constant(type_id, id, &[val as u32]), + 8 => { + let (low, high) = ((val >> 32) as u32, val as u32); + super::instructions::instruction_constant(type_id, id, &[low, high]) + } + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + (instruction, id) + } + crate::ConstantInner::Float(val) => { + let ty = &ir_module.types[constant.ty]; + let type_id = self.get_type_id(arena, LookupType::Handle(constant.ty)); + + let instruction = match ty.inner { + crate::TypeInner::Scalar { kind: _, width } => match width { + 4 => super::instructions::instruction_constant( + type_id, + id, + &[(val as f32).to_bits()], + ), + 8 => { + let bits = f64::to_bits(val); + let (low, high) = ((bits >> 32) as u32, bits as u32); + super::instructions::instruction_constant(type_id, id, &[low, high]) + } + _ => unreachable!(), + }, + _ => unreachable!(), + }; + (instruction, id) + } + crate::ConstantInner::Bool(val) => { + let type_id = self.get_type_id(arena, LookupType::Handle(constant.ty)); + + let instruction = if val { + super::instructions::instruction_constant_true(type_id, id) + } else { + super::instructions::instruction_constant_false(type_id, id) + }; + + (instruction, id) + } + crate::ConstantInner::Composite(ref constituents) => { + let mut constituent_ids = Vec::with_capacity(constituents.len()); + for constituent in constituents.iter() { + let constituent_id = self.get_constant_id(*constituent, &ir_module); + constituent_ids.push(constituent_id); + } + + let type_id = self.get_type_id(arena, LookupType::Handle(constant.ty)); + let instruction = super::instructions::instruction_constant_composite( + type_id, + id, + constituent_ids.as_slice(), + ); + (instruction, id) + } + } + } + + fn write_global_variable( + &mut self, + ir_module: &crate::Module, + handle: crate::Handle<crate::GlobalVariable>, + ) -> (Instruction, Word) { + let global_variable = &ir_module.global_variables[handle]; + let id = self.generate_id(); + + let class = self.parse_to_spirv_storage_class(global_variable.class); + self.try_add_capabilities(class.required_capabilities()); + + let init_word = global_variable + .init + .map(|constant| self.get_constant_id(constant, ir_module)); + let pointer_id = + self.get_pointer_id(&ir_module.types, global_variable.ty, global_variable.class); + let instruction = + super::instructions::instruction_variable(pointer_id, id, class, init_word); + + if self.writer_flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = global_variable.name { + self.debugs + .push(super::instructions::instruction_name(id, name.as_str())); + } + } + + if let Some(interpolation) = global_variable.interpolation { + let decoration = match interpolation { + crate::Interpolation::Linear => Some(spirv::Decoration::NoPerspective), + crate::Interpolation::Flat => Some(spirv::Decoration::Flat), + crate::Interpolation::Patch => Some(spirv::Decoration::Patch), + crate::Interpolation::Centroid => Some(spirv::Decoration::Centroid), + crate::Interpolation::Sample => Some(spirv::Decoration::Sample), + crate::Interpolation::Perspective => None, + }; + if let Some(decoration) = decoration { + self.annotations + .push(super::instructions::instruction_decorate( + id, + decoration, + &[], + )); + } + } + + match *global_variable.binding.as_ref().unwrap() { + crate::Binding::Location(location) => { + self.annotations + .push(super::instructions::instruction_decorate( + id, + spirv::Decoration::Location, + &[location], + )); + } + crate::Binding::Resource { group, binding } => { + self.annotations + .push(super::instructions::instruction_decorate( + id, + spirv::Decoration::DescriptorSet, + &[group], + )); + self.annotations + .push(super::instructions::instruction_decorate( + id, + spirv::Decoration::Binding, + &[binding], + )); + } + crate::Binding::BuiltIn(built_in) => { + let built_in = match built_in { + crate::BuiltIn::BaseInstance => spirv::BuiltIn::BaseInstance, + crate::BuiltIn::BaseVertex => spirv::BuiltIn::BaseVertex, + crate::BuiltIn::ClipDistance => spirv::BuiltIn::ClipDistance, + crate::BuiltIn::InstanceIndex => spirv::BuiltIn::InstanceIndex, + crate::BuiltIn::Position => spirv::BuiltIn::Position, + crate::BuiltIn::VertexIndex => spirv::BuiltIn::VertexIndex, + crate::BuiltIn::PointSize => spirv::BuiltIn::PointSize, + crate::BuiltIn::FragCoord => spirv::BuiltIn::FragCoord, + crate::BuiltIn::FrontFacing => spirv::BuiltIn::FrontFacing, + crate::BuiltIn::SampleIndex => spirv::BuiltIn::SampleId, + crate::BuiltIn::FragDepth => spirv::BuiltIn::FragDepth, + crate::BuiltIn::GlobalInvocationId => spirv::BuiltIn::GlobalInvocationId, + crate::BuiltIn::LocalInvocationId => spirv::BuiltIn::LocalInvocationId, + crate::BuiltIn::LocalInvocationIndex => spirv::BuiltIn::LocalInvocationIndex, + crate::BuiltIn::WorkGroupId => spirv::BuiltIn::WorkgroupId, + }; + + self.annotations + .push(super::instructions::instruction_decorate( + id, + spirv::Decoration::BuiltIn, + &[built_in as u32], + )); + } + } + + // TODO Initializer is optional and not (yet) included in the IR + + self.lookup_global_variable.insert(handle, id); + (instruction, id) + } + + fn get_function_type( + &mut self, + lookup_function_type: LookupFunctionType, + parameter_pointer_ids: Vec<Word>, + ) -> Word { + match self + .lookup_function_type + .entry(lookup_function_type.clone()) + { + Entry::Occupied(e) => *e.get(), + _ => { + let id = self.generate_id(); + let instruction = super::instructions::instruction_type_function( + id, + lookup_function_type.return_type_id, + parameter_pointer_ids.as_slice(), + ); + instruction.to_words(&mut self.logical_layout.declarations); + self.lookup_function_type.insert(lookup_function_type, id); + id + } + } + } + + fn write_composite_construct( + &mut self, + base_type_id: Word, + constituent_ids: &[Word], + block: &mut Block, + ) -> Word { + let id = self.generate_id(); + block + .body + .push(super::instructions::instruction_composite_construct( + base_type_id, + id, + constituent_ids, + )); + id + } + + fn get_type_inner<'a>( + &self, + ty_arena: &'a crate::Arena<crate::Type>, + lookup_ty: LookupType, + ) -> MaybeOwned<'a, crate::TypeInner> { + match lookup_ty { + LookupType::Handle(handle) => MaybeOwned::Borrowed(&ty_arena[handle].inner), + LookupType::Local(local_ty) => match local_ty { + LocalType::Scalar { kind, width } => { + MaybeOwned::Owned(crate::TypeInner::Scalar { kind, width }) + } + LocalType::Vector { size, kind, width } => { + MaybeOwned::Owned(crate::TypeInner::Vector { size, kind, width }) + } + LocalType::Pointer { base, class } => { + MaybeOwned::Owned(crate::TypeInner::Pointer { base, class }) + } + _ => unreachable!(), + }, + } + } + + fn write_expression<'a>( + &mut self, + ir_module: &'a crate::Module, + ir_function: &crate::Function, + expression: &crate::Expression, + block: &mut Block, + function: &mut Function, + ) -> Result<WriteExpressionOutput, Error> { + match *expression { + crate::Expression::Access { base, index } => { + let id = self.generate_id(); + + let (base_id, base_lookup_ty) = self.write_expression( + ir_module, + ir_function, + &ir_function.expressions[base], + block, + function, + )?; + let (index_id, _) = self.write_expression( + ir_module, + ir_function, + &ir_function.expressions[index], + block, + function, + )?; + + let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty); + + let (pointer_id, type_id, lookup_ty) = match *base_ty_inner { + crate::TypeInner::Vector { kind, width, .. } => { + let scalar_id = self.get_type_id( + &ir_module.types, + LookupType::Local(LocalType::Scalar { kind, width }), + ); + ( + self.create_pointer(scalar_id, spirv::StorageClass::Function), + scalar_id, + LookupType::Local(LocalType::Scalar { kind, width }), + ) + } + _ => unimplemented!(), + }; + + block + .body + .push(super::instructions::instruction_access_chain( + pointer_id, + id, + base_id, + &[index_id], + )); + + let load_id = self.generate_id(); + block.body.push(super::instructions::instruction_load( + type_id, load_id, id, None, + )); + + Ok((load_id, lookup_ty)) + } + crate::Expression::AccessIndex { base, index } => { + let id = self.generate_id(); + let (base_id, base_lookup_ty) = self + .write_expression( + ir_module, + ir_function, + &ir_function.expressions[base], + block, + function, + ) + .unwrap(); + + let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty); + + let (pointer_id, type_id, lookup_ty) = match *base_ty_inner { + crate::TypeInner::Vector { kind, width, .. } => { + let scalar_id = self.get_type_id( + &ir_module.types, + LookupType::Local(LocalType::Scalar { kind, width }), + ); + ( + self.create_pointer(scalar_id, spirv::StorageClass::Function), + scalar_id, + LookupType::Local(LocalType::Scalar { kind, width }), + ) + } + crate::TypeInner::Struct { ref members } => { + let member = &members[index as usize]; + let type_id = + self.get_type_id(&ir_module.types, LookupType::Handle(member.ty)); + ( + self.create_pointer(type_id, spirv::StorageClass::Uniform), + type_id, + LookupType::Handle(member.ty), + ) + } + _ => unimplemented!(), + }; + + let const_ty_id = self.get_type_id( + &ir_module.types, + LookupType::Local(LocalType::Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }), + ); + let const_id = self.create_constant(const_ty_id, &[index]); + + block + .body + .push(super::instructions::instruction_access_chain( + pointer_id, + id, + base_id, + &[const_id], + )); + + let load_id = self.generate_id(); + block.body.push(super::instructions::instruction_load( + type_id, load_id, id, None, + )); + + Ok((load_id, lookup_ty)) + } + crate::Expression::GlobalVariable(handle) => { + let var = &ir_module.global_variables[handle]; + let id = self.get_global_variable_id(&ir_module, handle); + + Ok((id, LookupType::Handle(var.ty))) + } + crate::Expression::Constant(handle) => { + let var = &ir_module.constants[handle]; + let id = self.get_constant_id(handle, ir_module); + Ok((id, LookupType::Handle(var.ty))) + } + crate::Expression::Compose { ty, ref components } => { + let base_type_id = self.get_type_id(&ir_module.types, LookupType::Handle(ty)); + + let mut constituent_ids = Vec::with_capacity(components.len()); + for component in components { + let expression = &ir_function.expressions[*component]; + let (component_id, component_local_ty) = self.write_expression( + ir_module, + &ir_function, + expression, + block, + function, + )?; + + let component_id = match expression { + crate::Expression::LocalVariable(_) + | crate::Expression::GlobalVariable(_) => { + let load_id = self.generate_id(); + block.body.push(super::instructions::instruction_load( + self.get_type_id(&ir_module.types, component_local_ty), + load_id, + component_id, + None, + )); + load_id + } + _ => component_id, + }; + + constituent_ids.push(component_id); + } + let constituent_ids_slice = constituent_ids.as_slice(); + + let id = match ir_module.types[ty].inner { + crate::TypeInner::Vector { .. } => { + self.write_composite_construct(base_type_id, constituent_ids_slice, block) + } + crate::TypeInner::Matrix { + rows, + columns, + width, + } => { + let vector_type_id = self.get_type_id( + &ir_module.types, + LookupType::Local(LocalType::Vector { + width, + kind: crate::ScalarKind::Float, + size: columns, + }), + ); + + let capacity = match rows { + crate::VectorSize::Bi => 2, + crate::VectorSize::Tri => 3, + crate::VectorSize::Quad => 4, + }; + + let mut vector_ids = Vec::with_capacity(capacity); + + for _ in 0..capacity { + let vector_id = self.write_composite_construct( + vector_type_id, + constituent_ids_slice, + block, + ); + vector_ids.push(vector_id); + } + + self.write_composite_construct(base_type_id, vector_ids.as_slice(), block) + } + _ => unreachable!(), + }; + + Ok((id, LookupType::Handle(ty))) + } + crate::Expression::Binary { op, left, right } => { + let id = self.generate_id(); + let left_expression = &ir_function.expressions[left]; + let right_expression = &ir_function.expressions[right]; + let (left_id, left_lookup_ty) = self.write_expression( + ir_module, + ir_function, + left_expression, + block, + function, + )?; + let (right_id, right_lookup_ty) = self.write_expression( + ir_module, + ir_function, + right_expression, + block, + function, + )?; + + let left_lookup_ty = left_lookup_ty; + let right_lookup_ty = right_lookup_ty; + + let left_ty_inner = self.get_type_inner(&ir_module.types, left_lookup_ty); + let right_ty_inner = self.get_type_inner(&ir_module.types, right_lookup_ty); + + let left_result_type_id = self.get_type_id(&ir_module.types, left_lookup_ty); + + let right_result_type_id = self.get_type_id(&ir_module.types, right_lookup_ty); + + let left_id = match *left_expression { + crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => { + let load_id = self.generate_id(); + block.body.push(super::instructions::instruction_load( + left_result_type_id, + load_id, + left_id, + None, + )); + load_id + } + _ => left_id, + }; + + let right_id = match *right_expression { + crate::Expression::LocalVariable(..) + | crate::Expression::GlobalVariable(..) => { + let load_id = self.generate_id(); + block.body.push(super::instructions::instruction_load( + right_result_type_id, + load_id, + right_id, + None, + )); + load_id + } + _ => right_id, + }; + + let left_dimension = get_dimension(&left_ty_inner); + let right_dimension = get_dimension(&right_ty_inner); + + let (instruction, lookup_ty) = match op { + crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) { + (Dimension::Vector, Dimension::Scalar { .. }) => ( + super::instructions::instruction_vector_times_scalar( + left_result_type_id, + id, + left_id, + right_id, + ), + left_lookup_ty, + ), + (Dimension::Vector, Dimension::Matrix) => ( + super::instructions::instruction_vector_times_matrix( + left_result_type_id, + id, + left_id, + right_id, + ), + left_lookup_ty, + ), + (Dimension::Matrix, Dimension::Scalar { .. }) => ( + super::instructions::instruction_matrix_times_scalar( + left_result_type_id, + id, + left_id, + right_id, + ), + left_lookup_ty, + ), + (Dimension::Matrix, Dimension::Vector) => ( + super::instructions::instruction_matrix_times_vector( + right_result_type_id, + id, + left_id, + right_id, + ), + right_lookup_ty, + ), + (Dimension::Matrix, Dimension::Matrix) => ( + super::instructions::instruction_matrix_times_matrix( + left_result_type_id, + id, + left_id, + right_id, + ), + left_lookup_ty, + ), + (Dimension::Vector, Dimension::Vector) + | (Dimension::Scalar, Dimension::Scalar) + if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) => + { + ( + super::instructions::instruction_f_mul( + left_result_type_id, + id, + left_id, + right_id, + ), + left_lookup_ty, + ) + } + (Dimension::Vector, Dimension::Vector) + | (Dimension::Scalar, Dimension::Scalar) => ( + super::instructions::instruction_i_mul( + left_result_type_id, + id, + left_id, + right_id, + ), + left_lookup_ty, + ), + _ => unreachable!(), + }, + crate::BinaryOperator::Subtract => match *left_ty_inner { + crate::TypeInner::Scalar { kind, .. } => match kind { + crate::ScalarKind::Sint | crate::ScalarKind::Uint => ( + super::instructions::instruction_i_sub( + left_result_type_id, + id, + left_id, + right_id, + ), + left_lookup_ty, + ), + crate::ScalarKind::Float => ( + super::instructions::instruction_f_sub( + left_result_type_id, + id, + left_id, + right_id, + ), + left_lookup_ty, + ), + _ => unreachable!(), + }, + _ => unreachable!(), + }, + crate::BinaryOperator::And => ( + super::instructions::instruction_bitwise_and( + left_result_type_id, + id, + left_id, + right_id, + ), + left_lookup_ty, + ), + _ => unimplemented!("{:?}", op), + }; + + block.body.push(instruction); + Ok((id, lookup_ty)) + } + crate::Expression::LocalVariable(variable) => { + let var = &ir_function.local_variables[variable]; + function + .variables + .iter() + .find(|&v| v.name.as_ref().unwrap() == var.name.as_ref().unwrap()) + .map(|local_var| (local_var.id, LookupType::Handle(var.ty))) + .ok_or_else(|| Error::UnknownLocalVariable(var.clone())) + } + crate::Expression::FunctionArgument(index) => { + let handle = ir_function.arguments[index as usize].ty; + let type_id = self.get_type_id(&ir_module.types, LookupType::Handle(handle)); + let load_id = self.generate_id(); + + block.body.push(super::instructions::instruction_load( + type_id, + load_id, + function.parameters[index as usize].result_id.unwrap(), + None, + )); + Ok((load_id, LookupType::Handle(handle))) + } + crate::Expression::Call { + ref origin, + ref arguments, + } => match *origin { + crate::FunctionOrigin::Local(local_function) => { + let origin_function = &ir_module.functions[local_function]; + let id = self.generate_id(); + let mut argument_ids = vec![]; + + for argument in arguments { + let expression = &ir_function.expressions[*argument]; + let (id, lookup_ty) = self.write_expression( + ir_module, + ir_function, + expression, + block, + function, + )?; + + // Create variable - OpVariable + // Store value to variable - OpStore + // Use id of variable + + let handle = match lookup_ty { + LookupType::Handle(handle) => handle, + LookupType::Local(_) => unreachable!(), + }; + + let pointer_id = self.get_pointer_id( + &ir_module.types, + handle, + crate::StorageClass::Function, + ); + + let variable_id = self.generate_id(); + function.variables.push(LocalVariable { + id: variable_id, + name: None, + instruction: super::instructions::instruction_variable( + pointer_id, + variable_id, + spirv::StorageClass::Function, + None, + ), + }); + block.body.push(super::instructions::instruction_store( + variable_id, + id, + None, + )); + argument_ids.push(variable_id); + } + + let return_type_id = self + .get_function_return_type(origin_function.return_type, &ir_module.types); + + block + .body + .push(super::instructions::instruction_function_call( + return_type_id, + id, + *self.lookup_function.get(&local_function).unwrap(), + argument_ids.as_slice(), + )); + + let result_type = match origin_function.return_type { + Some(ty_handle) => LookupType::Handle(ty_handle), + None => LookupType::Local(LocalType::Void), + }; + Ok((id, result_type)) + } + _ => unimplemented!("{:?}", origin), + }, + crate::Expression::As { + expr, + kind, + convert, + } => { + if !convert { + return Err(Error::FeatureNotImplemented()); + } + + let (expr_id, expr_type) = self.write_expression( + ir_module, + ir_function, + &ir_function.expressions[expr], + block, + function, + )?; + + let expr_type_inner = self.get_type_inner(&ir_module.types, expr_type); + + let (expr_kind, local_type) = match *expr_type_inner { + crate::TypeInner::Scalar { + kind: expr_kind, + width, + } => (expr_kind, LocalType::Scalar { kind, width }), + crate::TypeInner::Vector { + size, + kind: expr_kind, + width, + } => (expr_kind, LocalType::Vector { size, kind, width }), + _ => unreachable!(), + }; + + let lookup_type = LookupType::Local(local_type); + let op = match (expr_kind, kind) { + _ if !convert => spirv::Op::Bitcast, + (crate::ScalarKind::Float, crate::ScalarKind::Uint) => spirv::Op::ConvertFToU, + (crate::ScalarKind::Float, crate::ScalarKind::Sint) => spirv::Op::ConvertFToS, + (crate::ScalarKind::Sint, crate::ScalarKind::Float) => spirv::Op::ConvertSToF, + (crate::ScalarKind::Uint, crate::ScalarKind::Float) => spirv::Op::ConvertUToF, + // We assume it's either an identity cast, or int-uint. + // In both cases no SPIR-V instructions need to be generated. + _ => { + let id = match ir_function.expressions[expr] { + crate::Expression::LocalVariable(_) + | crate::Expression::GlobalVariable(_) => { + let load_id = self.generate_id(); + let kind_type_id = self.get_type_id(&ir_module.types, expr_type); + block.body.push(super::instructions::instruction_load( + kind_type_id, + load_id, + expr_id, + None, + )); + load_id + } + _ => expr_id, + }; + return Ok((id, lookup_type)); + } + }; + + let id = self.generate_id(); + let kind_type_id = self.get_type_id(&ir_module.types, lookup_type); + let instruction = + super::instructions::instruction_unary(op, kind_type_id, id, expr_id); + block.body.push(instruction); + + Ok((id, lookup_type)) + } + crate::Expression::ImageSample { + image, + sampler, + coordinate, + level: _, + depth_ref: _, + } => { + // image + let image_expression = &ir_function.expressions[image]; + let (image_id, image_lookup_ty) = self.write_expression( + ir_module, + ir_function, + image_expression, + block, + function, + )?; + + let image_result_type_id = self.get_type_id(&ir_module.types, image_lookup_ty); + let image_id = match *image_expression { + crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => { + let load_id = self.generate_id(); + block.body.push(super::instructions::instruction_load( + image_result_type_id, + load_id, + image_id, + None, + )); + load_id + } + _ => image_id, + }; + + let image_ty = match image_lookup_ty { + LookupType::Handle(handle) => handle, + LookupType::Local(_) => unreachable!(), + }; + + // OpTypeSampledImage + let sampled_image_type_id = self.get_type_id( + &ir_module.types, + LookupType::Local(LocalType::SampledImage { + image_type: image_ty, + }), + ); + + // sampler + let sampler_expression = &ir_function.expressions[sampler]; + let (sampler_id, sampler_lookup_ty) = self.write_expression( + ir_module, + ir_function, + sampler_expression, + block, + function, + )?; + + let sampler_result_type_id = self.get_type_id(&ir_module.types, sampler_lookup_ty); + let sampler_id = match *sampler_expression { + crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => { + let load_id = self.generate_id(); + block.body.push(super::instructions::instruction_load( + sampler_result_type_id, + load_id, + sampler_id, + None, + )); + load_id + } + _ => sampler_id, + }; + + // coordinate + let coordinate_expression = &ir_function.expressions[coordinate]; + let (coordinate_id, coordinate_lookup_ty) = self.write_expression( + ir_module, + ir_function, + coordinate_expression, + block, + function, + )?; + + let coordinate_result_type_id = + self.get_type_id(&ir_module.types, coordinate_lookup_ty); + let coordinate_id = match *coordinate_expression { + crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => { + let load_id = self.generate_id(); + block.body.push(super::instructions::instruction_load( + coordinate_result_type_id, + load_id, + coordinate_id, + None, + )); + load_id + } + _ => coordinate_id, + }; + + // component kind + let image_type = &ir_module.types[image_ty]; + let image_sample_result_type = + if let crate::TypeInner::Image { class, .. } = image_type.inner { + let width = 4; + LookupType::Local(match class { + crate::ImageClass::Sampled { kind, multi: _ } => LocalType::Vector { + kind, + width, + size: crate::VectorSize::Quad, + }, + crate::ImageClass::Depth => LocalType::Scalar { + kind: crate::ScalarKind::Float, + width, + }, + _ => return Err(Error::BadImageClass(class)), + }) + } else { + return Err(Error::NotImage); + }; + + let sampled_image_id = self.generate_id(); + block + .body + .push(super::instructions::instruction_sampled_image( + sampled_image_type_id, + sampled_image_id, + image_id, + sampler_id, + )); + let id = self.generate_id(); + let image_sample_result_type_id = + self.get_type_id(&ir_module.types, image_sample_result_type); + block + .body + .push(super::instructions::instruction_image_sample_implicit_lod( + image_sample_result_type_id, + id, + sampled_image_id, + coordinate_id, + )); + Ok((id, image_sample_result_type)) + } + _ => unimplemented!("{:?}", expression), + } + } + + fn write_block( + &mut self, + statements: &[crate::Statement], + ir_module: &crate::Module, + ir_function: &crate::Function, + function: &mut Function, + ) -> spirv::Word { + let mut block = Block::new(); + let id = self.generate_id(); + block.label = Some(super::instructions::instruction_label(id)); + + for statement in statements { + match *statement { + crate::Statement::Block(ref ir_block) => { + if !ir_block.is_empty() { + //TODO: link the block with `OpBranch` + self.write_block(ir_block, ir_module, ir_function, function); + } + } + crate::Statement::Return { value } => { + block.termination = Some(match ir_function.return_type { + Some(_) => { + let expression = &ir_function.expressions[value.unwrap()]; + let (id, lookup_ty) = self + .write_expression( + ir_module, + ir_function, + expression, + &mut block, + function, + ) + .unwrap(); + + let id = match *expression { + crate::Expression::LocalVariable(_) + | crate::Expression::GlobalVariable(_) => { + let load_id = self.generate_id(); + let value_ty_id = self.get_type_id(&ir_module.types, lookup_ty); + block.body.push(super::instructions::instruction_load( + value_ty_id, + load_id, + id, + None, + )); + load_id + } + + _ => id, + }; + super::instructions::instruction_return_value(id) + } + None => super::instructions::instruction_return(), + }); + } + crate::Statement::Store { pointer, value } => { + let pointer_expression = &ir_function.expressions[pointer]; + let value_expression = &ir_function.expressions[value]; + let (pointer_id, _) = self + .write_expression( + ir_module, + ir_function, + pointer_expression, + &mut block, + function, + ) + .unwrap(); + let (value_id, value_lookup_ty) = self + .write_expression( + ir_module, + ir_function, + value_expression, + &mut block, + function, + ) + .unwrap(); + + let value_id = match value_expression { + crate::Expression::LocalVariable(_) + | crate::Expression::GlobalVariable(_) => { + let load_id = self.generate_id(); + let value_ty_id = self.get_type_id(&ir_module.types, value_lookup_ty); + block.body.push(super::instructions::instruction_load( + value_ty_id, + load_id, + value_id, + None, + )); + load_id + } + _ => value_id, + }; + + block.body.push(super::instructions::instruction_store( + pointer_id, value_id, None, + )); + } + _ => unimplemented!("{:?}", statement), + } + } + + function.blocks.push(block); + id + } + + fn write_physical_layout(&mut self) { + self.physical_layout.bound = self.id_count + 1; + } + + fn write_logical_layout(&mut self, ir_module: &crate::Module) { + let id = self.generate_id(); + super::instructions::instruction_ext_inst_import(id, "GLSL.std.450") + .to_words(&mut self.logical_layout.ext_inst_imports); + + if self.writer_flags.contains(WriterFlags::DEBUG) { + self.debugs.push(super::instructions::instruction_source( + spirv::SourceLanguage::GLSL, + 450, + )); + } + + for (handle, ir_function) in ir_module.functions.iter() { + let id = self.write_function(ir_function, ir_module); + self.lookup_function.insert(handle, id); + } + + for (&(stage, ref name), ir_ep) in ir_module.entry_points.iter() { + let entry_point_instruction = self.write_entry_point(ir_ep, stage, name, ir_module); + entry_point_instruction.to_words(&mut self.logical_layout.entry_points); + } + + for capability in self.capabilities.iter() { + super::instructions::instruction_capability(*capability) + .to_words(&mut self.logical_layout.capabilities); + } + + let addressing_model = spirv::AddressingModel::Logical; + let memory_model = spirv::MemoryModel::GLSL450; + self.try_add_capabilities(addressing_model.required_capabilities()); + self.try_add_capabilities(memory_model.required_capabilities()); + + super::instructions::instruction_memory_model(addressing_model, memory_model) + .to_words(&mut self.logical_layout.memory_model); + + if self.writer_flags.contains(WriterFlags::DEBUG) { + for debug in self.debugs.iter() { + debug.to_words(&mut self.logical_layout.debugs); + } + } + + for annotation in self.annotations.iter() { + annotation.to_words(&mut self.logical_layout.annotations); + } + } + + pub fn write(&mut self, ir_module: &crate::Module) -> Vec<Word> { + let mut words: Vec<Word> = vec![]; + + self.write_logical_layout(ir_module); + self.write_physical_layout(); + + self.physical_layout.in_words(&mut words); + self.logical_layout.in_words(&mut words); + words + } +} + +#[cfg(test)] +mod tests { + use crate::back::spv::{Writer, WriterFlags}; + use crate::Header; + + #[test] + fn test_writer_generate_id() { + let mut writer = create_writer(); + + assert_eq!(writer.id_count, 0); + writer.generate_id(); + assert_eq!(writer.id_count, 1); + } + + #[test] + fn test_try_add_capabilities() { + let mut writer = create_writer(); + + assert_eq!(writer.capabilities.len(), 0); + writer.try_add_capabilities(&[spirv::Capability::Shader]); + assert_eq!(writer.capabilities.len(), 1); + + writer.try_add_capabilities(&[spirv::Capability::Shader]); + assert_eq!(writer.capabilities.len(), 1); + } + + #[test] + fn test_write_physical_layout() { + let mut writer = create_writer(); + assert_eq!(writer.physical_layout.bound, 0); + writer.write_physical_layout(); + assert_eq!(writer.physical_layout.bound, 1); + } + + fn create_writer() -> Writer { + let header = Header { + generator: 0, + version: (1, 0, 0), + }; + Writer::new(&header, WriterFlags::NONE) + } +} diff --git a/third_party/rust/naga/src/front/glsl/ast.rs b/third_party/rust/naga/src/front/glsl/ast.rs new file mode 100644 index 0000000000..58161ad7be --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/ast.rs @@ -0,0 +1,178 @@ +use super::error::ErrorKind; +use crate::{ + proc::{ResolveContext, Typifier}, + Arena, BinaryOperator, Binding, Expression, FastHashMap, Function, GlobalVariable, Handle, + Interpolation, LocalVariable, Module, ShaderStage, Statement, StorageClass, Type, +}; + +#[derive(Debug)] +pub struct Program { + pub version: u16, + pub profile: Profile, + pub shader_stage: ShaderStage, + pub entry: Option<String>, + pub lookup_function: FastHashMap<String, Handle<Function>>, + pub lookup_type: FastHashMap<String, Handle<Type>>, + pub lookup_global_variables: FastHashMap<String, Handle<GlobalVariable>>, + pub context: Context, + pub module: Module, +} + +impl Program { + pub fn new(shader_stage: ShaderStage, entry: &str) -> Program { + Program { + version: 0, + profile: Profile::Core, + shader_stage, + entry: Some(entry.to_string()), + lookup_function: FastHashMap::default(), + lookup_type: FastHashMap::default(), + lookup_global_variables: FastHashMap::default(), + context: Context { + expressions: Arena::<Expression>::new(), + local_variables: Arena::<LocalVariable>::new(), + scopes: vec![FastHashMap::default()], + lookup_global_var_exps: FastHashMap::default(), + typifier: Typifier::new(), + }, + module: Module::generate_empty(), + } + } + + pub fn binary_expr( + &mut self, + op: BinaryOperator, + left: &ExpressionRule, + right: &ExpressionRule, + ) -> ExpressionRule { + ExpressionRule::from_expression(self.context.expressions.append(Expression::Binary { + op, + left: left.expression, + right: right.expression, + })) + } + + pub fn resolve_type( + &mut self, + handle: Handle<crate::Expression>, + ) -> Result<&crate::TypeInner, ErrorKind> { + let functions = Arena::new(); //TODO + let arguments = Vec::new(); //TODO + let resolve_ctx = ResolveContext { + constants: &self.module.constants, + global_vars: &self.module.global_variables, + local_vars: &self.context.local_variables, + functions: &functions, + arguments: &arguments, + }; + match self.context.typifier.grow( + handle, + &self.context.expressions, + &mut self.module.types, + &resolve_ctx, + ) { + //TODO: better error report + Err(_) => Err(ErrorKind::SemanticError("Can't resolve type")), + Ok(()) => Ok(self.context.typifier.get(handle, &self.module.types)), + } + } +} + +#[derive(Debug)] +pub enum Profile { + Core, +} + +#[derive(Debug)] +pub struct Context { + pub expressions: Arena<Expression>, + pub local_variables: Arena<LocalVariable>, + //TODO: Find less allocation heavy representation + pub scopes: Vec<FastHashMap<String, Handle<Expression>>>, + pub lookup_global_var_exps: FastHashMap<String, Handle<Expression>>, + pub typifier: Typifier, +} + +impl Context { + pub fn lookup_local_var(&self, name: &str) -> Option<Handle<Expression>> { + for scope in self.scopes.iter().rev() { + if let Some(var) = scope.get(name) { + return Some(*var); + } + } + None + } + + #[cfg(feature = "glsl-validate")] + pub fn lookup_local_var_current_scope(&self, name: &str) -> Option<Handle<Expression>> { + if let Some(current) = self.scopes.last() { + current.get(name).cloned() + } else { + None + } + } + + pub fn clear_scopes(&mut self) { + self.scopes.clear(); + self.scopes.push(FastHashMap::default()); + } + + /// Add variable to current scope + pub fn add_local_var(&mut self, name: String, handle: Handle<Expression>) { + if let Some(current) = self.scopes.last_mut() { + (*current).insert(name, handle); + } + } + + /// Add new empty scope + pub fn push_scope(&mut self) { + self.scopes.push(FastHashMap::default()); + } + + pub fn remove_current_scope(&mut self) { + self.scopes.pop(); + } +} + +#[derive(Debug)] +pub struct ExpressionRule { + pub expression: Handle<Expression>, + pub statements: Vec<Statement>, + pub sampler: Option<Handle<Expression>>, +} + +impl ExpressionRule { + pub fn from_expression(expression: Handle<Expression>) -> ExpressionRule { + ExpressionRule { + expression, + statements: vec![], + sampler: None, + } + } +} + +#[derive(Debug)] +pub enum TypeQualifier { + StorageClass(StorageClass), + Binding(Binding), + Interpolation(Interpolation), +} + +#[derive(Debug)] +pub struct VarDeclaration { + pub type_qualifiers: Vec<TypeQualifier>, + pub ids_initializers: Vec<(Option<String>, Option<ExpressionRule>)>, + pub ty: Handle<Type>, +} + +#[derive(Debug)] +pub enum FunctionCallKind { + TypeConstructor(Handle<Type>), + Function(String), +} + +#[derive(Debug)] +pub struct FunctionCall { + pub kind: FunctionCallKind, + pub args: Vec<ExpressionRule>, +} diff --git a/third_party/rust/naga/src/front/glsl/error.rs b/third_party/rust/naga/src/front/glsl/error.rs new file mode 100644 index 0000000000..db5ffbe9f2 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/error.rs @@ -0,0 +1,87 @@ +use super::parser::Token; +use super::token::TokenMetadata; +use std::{fmt, io}; + +#[derive(Debug)] +pub enum ErrorKind { + EndOfFile, + InvalidInput, + InvalidProfile(TokenMetadata, String), + InvalidToken(Token), + InvalidVersion(TokenMetadata, i64), + IoError(io::Error), + ParserFail, + ParserStackOverflow, + NotImplemented(&'static str), + UnknownVariable(TokenMetadata, String), + UnknownField(TokenMetadata, String), + #[cfg(feature = "glsl-validate")] + VariableAlreadyDeclared(String), + #[cfg(feature = "glsl-validate")] + VariableNotAvailable(String), + ExpectedConstant, + SemanticError(&'static str), + PreprocessorError(String), +} + +impl fmt::Display for ErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ErrorKind::EndOfFile => write!(f, "Unexpected end of file"), + ErrorKind::InvalidInput => write!(f, "InvalidInput"), + ErrorKind::InvalidProfile(meta, val) => { + write!(f, "Invalid profile {} at {:?}", val, meta) + } + ErrorKind::InvalidToken(token) => write!(f, "Invalid Token {:?}", token), + ErrorKind::InvalidVersion(meta, val) => { + write!(f, "Invalid version {} at {:?}", val, meta) + } + ErrorKind::IoError(error) => write!(f, "IO Error {}", error), + ErrorKind::ParserFail => write!(f, "Parser failed"), + ErrorKind::ParserStackOverflow => write!(f, "Parser stack overflow"), + ErrorKind::NotImplemented(msg) => write!(f, "Not implemented: {}", msg), + ErrorKind::UnknownVariable(meta, val) => { + write!(f, "Unknown variable {} at {:?}", val, meta) + } + ErrorKind::UnknownField(meta, val) => write!(f, "Unknown field {} at {:?}", val, meta), + #[cfg(feature = "glsl-validate")] + ErrorKind::VariableAlreadyDeclared(val) => { + write!(f, "Variable {} already decalred in current scope", val) + } + #[cfg(feature = "glsl-validate")] + ErrorKind::VariableNotAvailable(val) => { + write!(f, "Variable {} not available in this stage", val) + } + ErrorKind::ExpectedConstant => write!(f, "Expected constant"), + ErrorKind::SemanticError(msg) => write!(f, "Semantic error: {}", msg), + ErrorKind::PreprocessorError(val) => write!(f, "Preprocessor error: {}", val), + } + } +} + +#[derive(Debug)] +pub struct ParseError { + pub kind: ErrorKind, +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +impl From<io::Error> for ParseError { + fn from(error: io::Error) -> Self { + ParseError { + kind: ErrorKind::IoError(error), + } + } +} + +impl From<ErrorKind> for ParseError { + fn from(kind: ErrorKind) -> Self { + ParseError { kind } + } +} + +impl std::error::Error for ParseError {} diff --git a/third_party/rust/naga/src/front/glsl/lex.rs b/third_party/rust/naga/src/front/glsl/lex.rs new file mode 100644 index 0000000000..57e86a4c2c --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/lex.rs @@ -0,0 +1,380 @@ +use super::parser::Token; +use super::{preprocess::LinePreProcessor, token::TokenMetadata, types::parse_type}; +use std::{iter::Enumerate, str::Lines}; + +fn _consume_str<'a>(input: &'a str, what: &str) -> Option<&'a str> { + if input.starts_with(what) { + Some(&input[what.len()..]) + } else { + None + } +} + +fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str, usize) { + let pos = input.find(|c| !what(c)).unwrap_or_else(|| input.len()); + let (o, i) = input.split_at(pos); + (o, i, pos) +} + +#[derive(Clone, Debug)] +pub struct Lexer<'a> { + lines: Enumerate<Lines<'a>>, + input: String, + line: usize, + offset: usize, + inside_comment: bool, + pub pp: LinePreProcessor, +} + +impl<'a> Lexer<'a> { + pub fn consume_token(&mut self) -> Option<Token> { + let start = self + .input + .find(|c: char| !c.is_whitespace()) + .unwrap_or_else(|| self.input.chars().count()); + let input = &self.input[start..]; + + let mut chars = input.chars(); + let cur = match chars.next() { + Some(c) => c, + None => { + self.input = self.input[start..].into(); + return None; + } + }; + let mut meta = TokenMetadata { + line: 0, + chars: start..start + 1, + }; + let mut consume_all = false; + let token = match cur { + ':' => Some(Token::Colon(meta)), + ';' => Some(Token::Semicolon(meta)), + ',' => Some(Token::Comma(meta)), + '.' => Some(Token::Dot(meta)), + + '(' => Some(Token::LeftParen(meta)), + ')' => Some(Token::RightParen(meta)), + '{' => Some(Token::LeftBrace(meta)), + '}' => Some(Token::RightBrace(meta)), + '[' => Some(Token::LeftBracket(meta)), + ']' => Some(Token::RightBracket(meta)), + '<' | '>' => { + let n1 = chars.next(); + let n2 = chars.next(); + match (cur, n1, n2) { + ('<', Some('<'), Some('=')) => { + meta.chars.end = start + 3; + Some(Token::LeftAssign(meta)) + } + ('>', Some('>'), Some('=')) => { + meta.chars.end = start + 3; + Some(Token::RightAssign(meta)) + } + ('<', Some('<'), _) => { + meta.chars.end = start + 2; + Some(Token::LeftOp(meta)) + } + ('>', Some('>'), _) => { + meta.chars.end = start + 2; + Some(Token::RightOp(meta)) + } + ('<', Some('='), _) => { + meta.chars.end = start + 2; + Some(Token::LeOp(meta)) + } + ('>', Some('='), _) => { + meta.chars.end = start + 2; + Some(Token::GeOp(meta)) + } + ('<', _, _) => Some(Token::LeftAngle(meta)), + ('>', _, _) => Some(Token::RightAngle(meta)), + _ => None, + } + } + '0'..='9' => { + let (number, _, pos) = consume_any(input, |c| (c >= '0' && c <= '9' || c == '.')); + if number.find('.').is_some() { + if ( + chars.next().map(|c| c.to_lowercase().next().unwrap()), + chars.next().map(|c| c.to_lowercase().next().unwrap()), + ) == (Some('l'), Some('f')) + { + meta.chars.end = start + pos + 2; + Some(Token::DoubleConstant((meta, number.parse().unwrap()))) + } else { + meta.chars.end = start + pos; + + Some(Token::FloatConstant((meta, number.parse().unwrap()))) + } + } else { + meta.chars.end = start + pos; + Some(Token::IntConstant((meta, number.parse().unwrap()))) + } + } + 'a'..='z' | 'A'..='Z' | '_' => { + let (word, _, pos) = consume_any(input, |c| c.is_ascii_alphanumeric() || c == '_'); + meta.chars.end = start + pos; + match word { + "layout" => Some(Token::Layout(meta)), + "in" => Some(Token::In(meta)), + "out" => Some(Token::Out(meta)), + "uniform" => Some(Token::Uniform(meta)), + "flat" => Some(Token::Interpolation((meta, crate::Interpolation::Flat))), + "noperspective" => { + Some(Token::Interpolation((meta, crate::Interpolation::Linear))) + } + "smooth" => Some(Token::Interpolation(( + meta, + crate::Interpolation::Perspective, + ))), + "centroid" => { + Some(Token::Interpolation((meta, crate::Interpolation::Centroid))) + } + "sample" => Some(Token::Interpolation((meta, crate::Interpolation::Sample))), + // values + "true" => Some(Token::BoolConstant((meta, true))), + "false" => Some(Token::BoolConstant((meta, false))), + // jump statements + "continue" => Some(Token::Continue(meta)), + "break" => Some(Token::Break(meta)), + "return" => Some(Token::Return(meta)), + "discard" => Some(Token::Discard(meta)), + // selection statements + "if" => Some(Token::If(meta)), + "else" => Some(Token::Else(meta)), + "switch" => Some(Token::Switch(meta)), + "case" => Some(Token::Case(meta)), + "default" => Some(Token::Default(meta)), + // iteration statements + "while" => Some(Token::While(meta)), + "do" => Some(Token::Do(meta)), + "for" => Some(Token::For(meta)), + // types + "void" => Some(Token::Void(meta)), + word => { + let token = match parse_type(word) { + Some(t) => Token::TypeName((meta, t)), + None => Token::Identifier((meta, String::from(word))), + }; + Some(token) + } + } + } + '+' | '-' | '&' | '|' | '^' => { + let next = chars.next(); + if next == Some(cur) { + meta.chars.end = start + 2; + match cur { + '+' => Some(Token::IncOp(meta)), + '-' => Some(Token::DecOp(meta)), + '&' => Some(Token::AndOp(meta)), + '|' => Some(Token::OrOp(meta)), + '^' => Some(Token::XorOp(meta)), + _ => None, + } + } else { + match next { + Some('=') => { + meta.chars.end = start + 2; + match cur { + '+' => Some(Token::AddAssign(meta)), + '-' => Some(Token::SubAssign(meta)), + '&' => Some(Token::AndAssign(meta)), + '|' => Some(Token::OrAssign(meta)), + '^' => Some(Token::XorAssign(meta)), + _ => None, + } + } + _ => match cur { + '+' => Some(Token::Plus(meta)), + '-' => Some(Token::Dash(meta)), + '&' => Some(Token::Ampersand(meta)), + '|' => Some(Token::VerticalBar(meta)), + '^' => Some(Token::Caret(meta)), + _ => None, + }, + } + } + } + + '%' | '!' | '=' => match chars.next() { + Some('=') => { + meta.chars.end = start + 2; + match cur { + '%' => Some(Token::ModAssign(meta)), + '!' => Some(Token::NeOp(meta)), + '=' => Some(Token::EqOp(meta)), + _ => None, + } + } + _ => match cur { + '%' => Some(Token::Percent(meta)), + '!' => Some(Token::Bang(meta)), + '=' => Some(Token::Equal(meta)), + _ => None, + }, + }, + + '*' => match chars.next() { + Some('=') => { + meta.chars.end = start + 2; + Some(Token::MulAssign(meta)) + } + Some('/') => { + meta.chars.end = start + 2; + Some(Token::CommentEnd((meta, ()))) + } + _ => Some(Token::Star(meta)), + }, + '/' => { + match chars.next() { + Some('=') => { + meta.chars.end = start + 2; + Some(Token::DivAssign(meta)) + } + Some('/') => { + // consume rest of line + consume_all = true; + None + } + Some('*') => { + meta.chars.end = start + 2; + Some(Token::CommentStart((meta, ()))) + } + _ => Some(Token::Slash(meta)), + } + } + '#' => { + if self.offset == 0 { + let mut input = chars.as_str(); + + // skip whitespace + let word_start = input + .find(|c: char| !c.is_whitespace()) + .unwrap_or_else(|| input.chars().count()); + input = &input[word_start..]; + + let (word, _, pos) = consume_any(input, |c| c.is_alphanumeric() || c == '_'); + meta.chars.end = start + word_start + 1 + pos; + match word { + "version" => Some(Token::Version(meta)), + w => Some(Token::Unknown((meta, w.into()))), + } + + //TODO: preprocessor + // if chars.next() == Some(cur) { + // (Token::TokenPasting, chars.as_str(), start, start + 2) + // } else { + // (Token::Preprocessor, input, start, start + 1) + // } + } else { + Some(Token::Unknown((meta, '#'.to_string()))) + } + } + '~' => Some(Token::Tilde(meta)), + '?' => Some(Token::Question(meta)), + ch => Some(Token::Unknown((meta, ch.to_string()))), + }; + if let Some(token) = token { + let skip_bytes = input + .chars() + .take(token.extra().chars.end - start) + .fold(0, |acc, c| acc + c.len_utf8()); + self.input = input[skip_bytes..].into(); + Some(token) + } else { + if consume_all { + self.input = "".into(); + } else { + self.input = self.input[start..].into(); + } + None + } + } + + pub fn new(input: &'a str) -> Self { + let mut lexer = Lexer { + lines: input.lines().enumerate(), + input: "".to_string(), + line: 0, + offset: 0, + inside_comment: false, + pp: LinePreProcessor::new(), + }; + lexer.next_line(); + lexer + } + + fn next_line(&mut self) -> bool { + if let Some((line, input)) = self.lines.next() { + let mut input = String::from(input); + + while input.ends_with('\\') { + if let Some((_, next)) = self.lines.next() { + input.pop(); + input.push_str(next); + } else { + break; + } + } + + if let Ok(processed) = self.pp.process_line(&input) { + self.input = processed.unwrap_or_default(); + self.line = line; + self.offset = 0; + true + } else { + //TODO: handle preprocessor error + false + } + } else { + false + } + } + + #[must_use] + pub fn next(&mut self) -> Option<Token> { + let token = self.consume_token(); + + if let Some(mut token) = token { + let meta = token.extra_mut(); + let end = meta.chars.end; + meta.line = self.line; + meta.chars.start += self.offset; + meta.chars.end += self.offset; + self.offset += end; + if !self.inside_comment { + match token { + Token::CommentStart(_) => { + self.inside_comment = true; + self.next() + } + _ => Some(token), + } + } else { + if let Token::CommentEnd(_) = token { + self.inside_comment = false; + } + self.next() + } + } else { + if !self.next_line() { + return None; + } + self.next() + } + } + + // #[must_use] + // pub fn peek(&mut self) -> Option<Token> { + // self.clone().next() + // } +} + +impl<'a> Iterator for Lexer<'a> { + type Item = Token; + fn next(&mut self) -> Option<Self::Item> { + self.next() + } +} diff --git a/third_party/rust/naga/src/front/glsl/lex_tests.rs b/third_party/rust/naga/src/front/glsl/lex_tests.rs new file mode 100644 index 0000000000..fde7fc9790 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/lex_tests.rs @@ -0,0 +1,346 @@ +use super::{lex::Lexer, parser::Token::*, token::TokenMetadata}; + +#[test] +fn tokens() { + // line comments + let mut lex = Lexer::new("void main // myfunction\n//()\n{}"); + assert_eq!( + lex.next().unwrap(), + Void(TokenMetadata { + line: 0, + chars: 0..4 + }) + ); + assert_eq!( + lex.next().unwrap(), + Identifier(( + TokenMetadata { + line: 0, + chars: 5..9 + }, + "main".into() + )) + ); + assert_eq!( + lex.next().unwrap(), + LeftBrace(TokenMetadata { + line: 2, + chars: 0..1 + }) + ); + assert_eq!( + lex.next().unwrap(), + RightBrace(TokenMetadata { + line: 2, + chars: 1..2 + }) + ); + assert_eq!(lex.next(), None); + + // multi line comment + let mut lex = Lexer::new("void main /* comment [] {}\n/**\n{}*/{}"); + assert_eq!( + lex.next().unwrap(), + Void(TokenMetadata { + line: 0, + chars: 0..4 + }) + ); + assert_eq!( + lex.next().unwrap(), + Identifier(( + TokenMetadata { + line: 0, + chars: 5..9 + }, + "main".into() + )) + ); + assert_eq!( + lex.next().unwrap(), + LeftBrace(TokenMetadata { + line: 2, + chars: 4..5 + }) + ); + assert_eq!( + lex.next().unwrap(), + RightBrace(TokenMetadata { + line: 2, + chars: 5..6 + }) + ); + assert_eq!(lex.next(), None); + + // identifiers + let mut lex = Lexer::new("id123_OK 92No æNoø No¾ No好"); + assert_eq!( + lex.next().unwrap(), + Identifier(( + TokenMetadata { + line: 0, + chars: 0..8 + }, + "id123_OK".into() + )) + ); + assert_eq!( + lex.next().unwrap(), + IntConstant(( + TokenMetadata { + line: 0, + chars: 9..11 + }, + 92 + )) + ); + assert_eq!( + lex.next().unwrap(), + Identifier(( + TokenMetadata { + line: 0, + chars: 11..13 + }, + "No".into() + )) + ); + assert_eq!( + lex.next().unwrap(), + Unknown(( + TokenMetadata { + line: 0, + chars: 14..15 + }, + 'æ'.to_string() + )) + ); + assert_eq!( + lex.next().unwrap(), + Identifier(( + TokenMetadata { + line: 0, + chars: 15..17 + }, + "No".into() + )) + ); + assert_eq!( + lex.next().unwrap(), + Unknown(( + TokenMetadata { + line: 0, + chars: 17..18 + }, + 'ø'.to_string() + )) + ); + assert_eq!( + lex.next().unwrap(), + Identifier(( + TokenMetadata { + line: 0, + chars: 19..21 + }, + "No".into() + )) + ); + assert_eq!( + lex.next().unwrap(), + Unknown(( + TokenMetadata { + line: 0, + chars: 21..22 + }, + '¾'.to_string() + )) + ); + assert_eq!( + lex.next().unwrap(), + Identifier(( + TokenMetadata { + line: 0, + chars: 23..25 + }, + "No".into() + )) + ); + assert_eq!( + lex.next().unwrap(), + Unknown(( + TokenMetadata { + line: 0, + chars: 25..26 + }, + '好'.to_string() + )) + ); + assert_eq!(lex.next(), None); + + // version + let mut lex = Lexer::new("#version 890 core"); + assert_eq!( + lex.next().unwrap(), + Version(TokenMetadata { + line: 0, + chars: 0..8 + }) + ); + assert_eq!( + lex.next().unwrap(), + IntConstant(( + TokenMetadata { + line: 0, + chars: 9..12 + }, + 890 + )) + ); + assert_eq!( + lex.next().unwrap(), + Identifier(( + TokenMetadata { + line: 0, + chars: 13..17 + }, + "core".into() + )) + ); + assert_eq!(lex.next(), None); + + // operators + let mut lex = Lexer::new("+ - * | & % / += -= *= |= &= %= /= ++ -- || && ^^"); + assert_eq!( + lex.next().unwrap(), + Plus(TokenMetadata { + line: 0, + chars: 0..1 + }) + ); + assert_eq!( + lex.next().unwrap(), + Dash(TokenMetadata { + line: 0, + chars: 2..3 + }) + ); + assert_eq!( + lex.next().unwrap(), + Star(TokenMetadata { + line: 0, + chars: 4..5 + }) + ); + assert_eq!( + lex.next().unwrap(), + VerticalBar(TokenMetadata { + line: 0, + chars: 6..7 + }) + ); + assert_eq!( + lex.next().unwrap(), + Ampersand(TokenMetadata { + line: 0, + chars: 8..9 + }) + ); + assert_eq!( + lex.next().unwrap(), + Percent(TokenMetadata { + line: 0, + chars: 10..11 + }) + ); + assert_eq!( + lex.next().unwrap(), + Slash(TokenMetadata { + line: 0, + chars: 12..13 + }) + ); + assert_eq!( + lex.next().unwrap(), + AddAssign(TokenMetadata { + line: 0, + chars: 14..16 + }) + ); + assert_eq!( + lex.next().unwrap(), + SubAssign(TokenMetadata { + line: 0, + chars: 17..19 + }) + ); + assert_eq!( + lex.next().unwrap(), + MulAssign(TokenMetadata { + line: 0, + chars: 20..22 + }) + ); + assert_eq!( + lex.next().unwrap(), + OrAssign(TokenMetadata { + line: 0, + chars: 23..25 + }) + ); + assert_eq!( + lex.next().unwrap(), + AndAssign(TokenMetadata { + line: 0, + chars: 26..28 + }) + ); + assert_eq!( + lex.next().unwrap(), + ModAssign(TokenMetadata { + line: 0, + chars: 29..31 + }) + ); + assert_eq!( + lex.next().unwrap(), + DivAssign(TokenMetadata { + line: 0, + chars: 32..34 + }) + ); + assert_eq!( + lex.next().unwrap(), + IncOp(TokenMetadata { + line: 0, + chars: 35..37 + }) + ); + assert_eq!( + lex.next().unwrap(), + DecOp(TokenMetadata { + line: 0, + chars: 38..40 + }) + ); + assert_eq!( + lex.next().unwrap(), + OrOp(TokenMetadata { + line: 0, + chars: 41..43 + }) + ); + assert_eq!( + lex.next().unwrap(), + AndOp(TokenMetadata { + line: 0, + chars: 44..46 + }) + ); + assert_eq!( + lex.next().unwrap(), + XorOp(TokenMetadata { + line: 0, + chars: 47..49 + }) + ); + assert_eq!(lex.next(), None); +} diff --git a/third_party/rust/naga/src/front/glsl/mod.rs b/third_party/rust/naga/src/front/glsl/mod.rs new file mode 100644 index 0000000000..4661011828 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/mod.rs @@ -0,0 +1,43 @@ +use crate::{FastHashMap, Module, ShaderStage}; + +mod lex; +#[cfg(test)] +mod lex_tests; + +mod preprocess; +#[cfg(test)] +mod preprocess_tests; + +mod ast; +use ast::Program; + +use lex::Lexer; +mod error; +use error::ParseError; +mod parser; +#[cfg(test)] +mod parser_tests; +mod token; +mod types; +mod variables; + +pub fn parse_str( + source: &str, + entry: &str, + stage: ShaderStage, + defines: FastHashMap<String, String>, +) -> Result<Module, ParseError> { + let mut program = Program::new(stage, entry); + + let mut lex = Lexer::new(source); + lex.pp.defines = defines; + + let mut parser = parser::Parser::new(&mut program); + + for token in lex { + parser.parse(token)?; + } + parser.end_of_input()?; + + Ok(program.module) +} diff --git a/third_party/rust/naga/src/front/glsl/parser.rs b/third_party/rust/naga/src/front/glsl/parser.rs new file mode 100644 index 0000000000..db4935b0b4 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser.rs @@ -0,0 +1,1131 @@ +#![allow(clippy::panic)] +use pomelo::pomelo; + +pomelo! { + //%verbose; + %include { + use super::super::{error::ErrorKind, token::*, ast::*}; + use crate::{proc::Typifier, Arena, BinaryOperator, Binding, Block, Constant, + ConstantInner, EntryPoint, Expression, FallThrough, FastHashMap, Function, GlobalVariable, Handle, Interpolation, + LocalVariable, MemberOrigin, SampleLevel, ScalarKind, Statement, StorageAccess, + StorageClass, StructMember, Type, TypeInner, UnaryOperator}; + } + %token #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] pub enum Token {}; + %parser pub struct Parser<'a> {}; + %extra_argument &'a mut Program; + %extra_token TokenMetadata; + %error ErrorKind; + %syntax_error { + match token { + Some(token) => Err(ErrorKind::InvalidToken(token)), + None => Err(ErrorKind::EndOfFile), + } + } + %parse_fail { + ErrorKind::ParserFail + } + %stack_overflow { + ErrorKind::ParserStackOverflow + } + + %type Unknown String; + %type CommentStart (); + %type CommentEnd (); + + %type Identifier String; + // constants + %type IntConstant i64; + %type UintConstant u64; + %type FloatConstant f32; + %type BoolConstant bool; + %type DoubleConstant f64; + %type String String; + // function + %type function_prototype Function; + %type function_declarator Function; + %type function_header Function; + %type function_definition Function; + + // statements + %type compound_statement Block; + %type compound_statement_no_new_scope Block; + %type statement_list Block; + %type statement Statement; + %type simple_statement Statement; + %type expression_statement Statement; + %type declaration_statement Statement; + %type jump_statement Statement; + %type iteration_statement Statement; + %type selection_statement Statement; + %type switch_statement_list Vec<(Option<i32>, Block, Option<FallThrough>)>; + %type switch_statement (Option<i32>, Block, Option<FallThrough>); + %type for_init_statement Statement; + %type for_rest_statement (Option<ExpressionRule>, Option<ExpressionRule>); + %type condition_opt Option<ExpressionRule>; + + // expressions + %type unary_expression ExpressionRule; + %type postfix_expression ExpressionRule; + %type primary_expression ExpressionRule; + %type variable_identifier ExpressionRule; + + %type function_call ExpressionRule; + %type function_call_or_method FunctionCall; + %type function_call_generic FunctionCall; + %type function_call_header_no_parameters FunctionCall; + %type function_call_header_with_parameters FunctionCall; + %type function_call_header FunctionCall; + %type function_identifier FunctionCallKind; + + %type multiplicative_expression ExpressionRule; + %type additive_expression ExpressionRule; + %type shift_expression ExpressionRule; + %type relational_expression ExpressionRule; + %type equality_expression ExpressionRule; + %type and_expression ExpressionRule; + %type exclusive_or_expression ExpressionRule; + %type inclusive_or_expression ExpressionRule; + %type logical_and_expression ExpressionRule; + %type logical_xor_expression ExpressionRule; + %type logical_or_expression ExpressionRule; + %type conditional_expression ExpressionRule; + + %type assignment_expression ExpressionRule; + %type assignment_operator BinaryOperator; + %type expression ExpressionRule; + %type constant_expression Handle<Constant>; + + %type initializer ExpressionRule; + + // declarations + %type declaration Option<VarDeclaration>; + %type init_declarator_list VarDeclaration; + %type single_declaration VarDeclaration; + %type layout_qualifier Binding; + %type layout_qualifier_id_list Vec<(String, u32)>; + %type layout_qualifier_id (String, u32); + %type type_qualifier Vec<TypeQualifier>; + %type single_type_qualifier TypeQualifier; + %type storage_qualifier StorageClass; + %type interpolation_qualifier Interpolation; + %type Interpolation Interpolation; + + // types + %type fully_specified_type (Vec<TypeQualifier>, Option<Handle<Type>>); + %type type_specifier Option<Handle<Type>>; + %type type_specifier_nonarray Option<Type>; + %type struct_specifier Type; + %type struct_declaration_list Vec<StructMember>; + %type struct_declaration Vec<StructMember>; + %type struct_declarator_list Vec<String>; + %type struct_declarator String; + + %type TypeName Type; + + // precedence + %right Else; + + root ::= version_pragma translation_unit; + version_pragma ::= Version IntConstant(V) Identifier?(P) { + match V.1 { + 440 => (), + 450 => (), + 460 => (), + _ => return Err(ErrorKind::InvalidVersion(V.0, V.1)) + } + extra.version = V.1 as u16; + extra.profile = match P { + Some((meta, profile)) => { + match profile.as_str() { + "core" => Profile::Core, + _ => return Err(ErrorKind::InvalidProfile(meta, profile)) + } + }, + None => Profile::Core, + } + }; + + // expression + variable_identifier ::= Identifier(v) { + let var = extra.lookup_variable(&v.1)?; + match var { + Some(expression) => { + ExpressionRule::from_expression(expression) + }, + None => { + return Err(ErrorKind::UnknownVariable(v.0, v.1)); + } + } + } + + primary_expression ::= variable_identifier; + primary_expression ::= IntConstant(i) { + let ty = extra.module.types.fetch_or_append(Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Sint, + width: 4, + } + }); + let ch = extra.module.constants.fetch_or_append(Constant { + name: None, + specialization: None, + ty, + inner: ConstantInner::Sint(i.1) + }); + ExpressionRule::from_expression( + extra.context.expressions.append(Expression::Constant(ch)) + ) + } + // primary_expression ::= UintConstant; + primary_expression ::= FloatConstant(f) { + let ty = extra.module.types.fetch_or_append(Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Float, + width: 4, + } + }); + let ch = extra.module.constants.fetch_or_append(Constant { + name: None, + specialization: None, + ty, + inner: ConstantInner::Float(f.1 as f64) + }); + ExpressionRule::from_expression( + extra.context.expressions.append(Expression::Constant(ch)) + ) + } + primary_expression ::= BoolConstant(b) { + let ty = extra.module.types.fetch_or_append(Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Bool, + width: 4, + } + }); + let ch = extra.module.constants.fetch_or_append(Constant { + name: None, + specialization: None, + ty, + inner: ConstantInner::Bool(b.1) + }); + ExpressionRule::from_expression( + extra.context.expressions.append(Expression::Constant(ch)) + ) + } + // primary_expression ::= DoubleConstant; + primary_expression ::= LeftParen expression(e) RightParen { + e + } + + postfix_expression ::= primary_expression; + postfix_expression ::= postfix_expression LeftBracket integer_expression RightBracket { + //TODO + return Err(ErrorKind::NotImplemented("[]")) + } + postfix_expression ::= function_call; + postfix_expression ::= postfix_expression(e) Dot Identifier(i) /* FieldSelection in spec */ { + //TODO: how will this work as l-value? + let expression = extra.field_selection(e.expression, &*i.1, i.0)?; + ExpressionRule { expression, statements: e.statements, sampler: None } + } + postfix_expression ::= postfix_expression(pe) IncOp { + //TODO + return Err(ErrorKind::NotImplemented("post++")) + } + postfix_expression ::= postfix_expression(pe) DecOp { + //TODO + return Err(ErrorKind::NotImplemented("post--")) + } + + integer_expression ::= expression; + + function_call ::= function_call_or_method(fc) { + match fc.kind { + FunctionCallKind::TypeConstructor(ty) => { + let h = if fc.args.len() == 1 { + let kind = extra.module.types[ty].inner + .scalar_kind() + .ok_or(ErrorKind::SemanticError("Can only cast to scalar or vector"))?; + extra.context.expressions.append(Expression::As { + kind, + expr: fc.args[0].expression, + convert: true, + }) + } else { + extra.context.expressions.append(Expression::Compose { + ty, + components: fc.args.iter().map(|a| a.expression).collect(), + }) + }; + ExpressionRule { + expression: h, + statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(), + sampler: None + } + } + FunctionCallKind::Function(name) => { + match name.as_str() { + "sampler2D" => { + //TODO: check args len + ExpressionRule{ + expression: fc.args[0].expression, + sampler: Some(fc.args[1].expression), + statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(), + } + } + "texture" => { + //TODO: check args len + if let Some(sampler) = fc.args[0].sampler { + ExpressionRule{ + expression: extra.context.expressions.append(Expression::ImageSample { + image: fc.args[0].expression, + sampler, + coordinate: fc.args[1].expression, + level: SampleLevel::Auto, + depth_ref: None, + }), + sampler: None, + statements: fc.args.into_iter().map(|a| a.statements).flatten().collect(), + } + } else { + return Err(ErrorKind::SemanticError("Bad call to texture")); + } + } + _ => { return Err(ErrorKind::NotImplemented("Function call")); } + } + } + } + } + function_call_or_method ::= function_call_generic; + function_call_generic ::= function_call_header_with_parameters(h) RightParen { + h + } + function_call_generic ::= function_call_header_no_parameters(h) RightParen { + h + } + function_call_header_no_parameters ::= function_call_header(h) Void { + h + } + function_call_header_no_parameters ::= function_call_header; + function_call_header_with_parameters ::= function_call_header(mut h) assignment_expression(ae) { + h.args.push(ae); + h + } + function_call_header_with_parameters ::= function_call_header_with_parameters(mut h) Comma assignment_expression(ae) { + h.args.push(ae); + h + } + function_call_header ::= function_identifier(i) LeftParen { + FunctionCall { + kind: i, + args: vec![], + } + } + + // Grammar Note: Constructors look like functions, but lexical analysis recognized most of them as + // keywords. They are now recognized through “type_specifier”. + function_identifier ::= type_specifier(t) { + if let Some(ty) = t { + FunctionCallKind::TypeConstructor(ty) + } else { + return Err(ErrorKind::NotImplemented("bad type ctor")) + } + } + + //TODO + // Methods (.length), subroutine array calls, and identifiers are recognized through postfix_expression. + // function_identifier ::= postfix_expression(e) { + // FunctionCallKind::Function(e.expression) + // } + + // Simplification of above + function_identifier ::= Identifier(i) { + FunctionCallKind::Function(i.1) + } + + + unary_expression ::= postfix_expression; + + unary_expression ::= IncOp unary_expression { + //TODO + return Err(ErrorKind::NotImplemented("++pre")) + } + unary_expression ::= DecOp unary_expression { + //TODO + return Err(ErrorKind::NotImplemented("--pre")) + } + unary_expression ::= unary_operator unary_expression { + //TODO + return Err(ErrorKind::NotImplemented("unary_op")) + } + + unary_operator ::= Plus; + unary_operator ::= Dash; + unary_operator ::= Bang; + unary_operator ::= Tilde; + multiplicative_expression ::= unary_expression; + multiplicative_expression ::= multiplicative_expression(left) Star unary_expression(right) { + extra.binary_expr(BinaryOperator::Multiply, &left, &right) + } + multiplicative_expression ::= multiplicative_expression(left) Slash unary_expression(right) { + extra.binary_expr(BinaryOperator::Divide, &left, &right) + } + multiplicative_expression ::= multiplicative_expression(left) Percent unary_expression(right) { + extra.binary_expr(BinaryOperator::Modulo, &left, &right) + } + additive_expression ::= multiplicative_expression; + additive_expression ::= additive_expression(left) Plus multiplicative_expression(right) { + extra.binary_expr(BinaryOperator::Add, &left, &right) + } + additive_expression ::= additive_expression(left) Dash multiplicative_expression(right) { + extra.binary_expr(BinaryOperator::Subtract, &left, &right) + } + shift_expression ::= additive_expression; + shift_expression ::= shift_expression(left) LeftOp additive_expression(right) { + extra.binary_expr(BinaryOperator::ShiftLeft, &left, &right) + } + shift_expression ::= shift_expression(left) RightOp additive_expression(right) { + extra.binary_expr(BinaryOperator::ShiftRight, &left, &right) + } + relational_expression ::= shift_expression; + relational_expression ::= relational_expression(left) LeftAngle shift_expression(right) { + extra.binary_expr(BinaryOperator::Less, &left, &right) + } + relational_expression ::= relational_expression(left) RightAngle shift_expression(right) { + extra.binary_expr(BinaryOperator::Greater, &left, &right) + } + relational_expression ::= relational_expression(left) LeOp shift_expression(right) { + extra.binary_expr(BinaryOperator::LessEqual, &left, &right) + } + relational_expression ::= relational_expression(left) GeOp shift_expression(right) { + extra.binary_expr(BinaryOperator::GreaterEqual, &left, &right) + } + equality_expression ::= relational_expression; + equality_expression ::= equality_expression(left) EqOp relational_expression(right) { + extra.binary_expr(BinaryOperator::Equal, &left, &right) + } + equality_expression ::= equality_expression(left) NeOp relational_expression(right) { + extra.binary_expr(BinaryOperator::NotEqual, &left, &right) + } + and_expression ::= equality_expression; + and_expression ::= and_expression(left) Ampersand equality_expression(right) { + extra.binary_expr(BinaryOperator::And, &left, &right) + } + exclusive_or_expression ::= and_expression; + exclusive_or_expression ::= exclusive_or_expression(left) Caret and_expression(right) { + extra.binary_expr(BinaryOperator::ExclusiveOr, &left, &right) + } + inclusive_or_expression ::= exclusive_or_expression; + inclusive_or_expression ::= inclusive_or_expression(left) VerticalBar exclusive_or_expression(right) { + extra.binary_expr(BinaryOperator::InclusiveOr, &left, &right) + } + logical_and_expression ::= inclusive_or_expression; + logical_and_expression ::= logical_and_expression(left) AndOp inclusive_or_expression(right) { + extra.binary_expr(BinaryOperator::LogicalAnd, &left, &right) + } + logical_xor_expression ::= logical_and_expression; + logical_xor_expression ::= logical_xor_expression(left) XorOp logical_and_expression(right) { + let exp1 = extra.binary_expr(BinaryOperator::LogicalOr, &left, &right); + let exp2 = { + let tmp = extra.binary_expr(BinaryOperator::LogicalAnd, &left, &right).expression; + ExpressionRule::from_expression(extra.context.expressions.append(Expression::Unary { op: UnaryOperator::Not, expr: tmp })) + }; + extra.binary_expr(BinaryOperator::LogicalAnd, &exp1, &exp2) + } + logical_or_expression ::= logical_xor_expression; + logical_or_expression ::= logical_or_expression(left) OrOp logical_xor_expression(right) { + extra.binary_expr(BinaryOperator::LogicalOr, &left, &right) + } + + conditional_expression ::= logical_or_expression; + conditional_expression ::= logical_or_expression Question expression Colon assignment_expression(ae) { + //TODO: how to do ternary here in naga? + return Err(ErrorKind::NotImplemented("ternary exp")) + } + + assignment_expression ::= conditional_expression; + assignment_expression ::= unary_expression(mut pointer) assignment_operator(op) assignment_expression(value) { + pointer.statements.extend(value.statements); + match op { + BinaryOperator::Equal => { + pointer.statements.push(Statement::Store{ + pointer: pointer.expression, + value: value.expression + }); + pointer + }, + _ => { + let h = extra.context.expressions.append( + Expression::Binary{ + op, + left: pointer.expression, + right: value.expression, + } + ); + pointer.statements.push(Statement::Store{ + pointer: pointer.expression, + value: h, + }); + pointer + } + } + } + + assignment_operator ::= Equal { + BinaryOperator::Equal + } + assignment_operator ::= MulAssign { + BinaryOperator::Multiply + } + assignment_operator ::= DivAssign { + BinaryOperator::Divide + } + assignment_operator ::= ModAssign { + BinaryOperator::Modulo + } + assignment_operator ::= AddAssign { + BinaryOperator::Add + } + assignment_operator ::= SubAssign { + BinaryOperator::Subtract + } + assignment_operator ::= LeftAssign { + BinaryOperator::ShiftLeft + } + assignment_operator ::= RightAssign { + BinaryOperator::ShiftRight + } + assignment_operator ::= AndAssign { + BinaryOperator::And + } + assignment_operator ::= XorAssign { + BinaryOperator::ExclusiveOr + } + assignment_operator ::= OrAssign { + BinaryOperator::InclusiveOr + } + + expression ::= assignment_expression; + expression ::= expression(e) Comma assignment_expression(mut ae) { + ae.statements.extend(e.statements); + ExpressionRule { + expression: e.expression, + statements: ae.statements, + sampler: None, + } + } + + //TODO: properly handle constant expressions + // constant_expression ::= conditional_expression(e) { + // if let Expression::Constant(h) = extra.context.expressions[e] { + // h + // } else { + // return Err(ErrorKind::ExpectedConstant) + // } + // } + + // declaration + declaration ::= init_declarator_list(idl) Semicolon { + Some(idl) + } + + declaration ::= type_qualifier(t) Identifier(i) LeftBrace + struct_declaration_list(sdl) RightBrace Semicolon { + if i.1 == "gl_PerVertex" { + None + } else { + Some(VarDeclaration { + type_qualifiers: t, + ids_initializers: vec![(None, None)], + ty: extra.module.types.fetch_or_append(Type{ + name: Some(i.1), + inner: TypeInner::Struct { + members: sdl + } + }), + }) + } + } + + declaration ::= type_qualifier(t) Identifier(i1) LeftBrace + struct_declaration_list(sdl) RightBrace Identifier(i2) Semicolon { + Some(VarDeclaration { + type_qualifiers: t, + ids_initializers: vec![(Some(i2.1), None)], + ty: extra.module.types.fetch_or_append(Type{ + name: Some(i1.1), + inner: TypeInner::Struct { + members: sdl + } + }), + }) + } + + // declaration ::= type_qualifier(t) Identifier(i1) LeftBrace + // struct_declaration_list RightBrace Identifier(i2) array_specifier Semicolon; + + init_declarator_list ::= single_declaration; + init_declarator_list ::= init_declarator_list(mut idl) Comma Identifier(i) { + idl.ids_initializers.push((Some(i.1), None)); + idl + } + // init_declarator_list ::= init_declarator_list Comma Identifier array_specifier; + // init_declarator_list ::= init_declarator_list Comma Identifier array_specifier Equal initializer; + init_declarator_list ::= init_declarator_list(mut idl) Comma Identifier(i) Equal initializer(init) { + idl.ids_initializers.push((Some(i.1), Some(init))); + idl + } + + single_declaration ::= fully_specified_type(t) { + let ty = t.1.ok_or(ErrorKind::SemanticError("Empty type for declaration"))?; + + VarDeclaration { + type_qualifiers: t.0, + ids_initializers: vec![], + ty, + } + } + single_declaration ::= fully_specified_type(t) Identifier(i) { + let ty = t.1.ok_or(ErrorKind::SemanticError("Empty type for declaration"))?; + + VarDeclaration { + type_qualifiers: t.0, + ids_initializers: vec![(Some(i.1), None)], + ty, + } + } + // single_declaration ::= fully_specified_type Identifier array_specifier; + // single_declaration ::= fully_specified_type Identifier array_specifier Equal initializer; + single_declaration ::= fully_specified_type(t) Identifier(i) Equal initializer(init) { + let ty = t.1.ok_or(ErrorKind::SemanticError("Empty type for declaration"))?; + + VarDeclaration { + type_qualifiers: t.0, + ids_initializers: vec![(Some(i.1), Some(init))], + ty, + } + } + + fully_specified_type ::= type_specifier(t) { + (vec![], t) + } + fully_specified_type ::= type_qualifier(q) type_specifier(t) { + (q,t) + } + + interpolation_qualifier ::= Interpolation((_, i)) { + i + } + + layout_qualifier ::= Layout LeftParen layout_qualifier_id_list(l) RightParen { + if let Some(&(_, loc)) = l.iter().find(|&q| q.0.as_str() == "location") { + Binding::Location(loc) + } else if let Some(&(_, binding)) = l.iter().find(|&q| q.0.as_str() == "binding") { + let group = if let Some(&(_, set)) = l.iter().find(|&q| q.0.as_str() == "set") { + set + } else { + 0 + }; + Binding::Resource{ group, binding } + } else { + return Err(ErrorKind::NotImplemented("unsupported layout qualifier(s)")); + } + } + layout_qualifier_id_list ::= layout_qualifier_id(lqi) { + vec![lqi] + } + layout_qualifier_id_list ::= layout_qualifier_id_list(mut l) Comma layout_qualifier_id(lqi) { + l.push(lqi); + l + } + layout_qualifier_id ::= Identifier(i) { + (i.1, 0) + } + //TODO: handle full constant_expression instead of IntConstant + layout_qualifier_id ::= Identifier(i) Equal IntConstant(ic) { + (i.1, ic.1 as u32) + } + // layout_qualifier_id ::= Shared; + + // precise_qualifier ::= Precise; + + type_qualifier ::= single_type_qualifier(t) { + vec![t] + } + type_qualifier ::= type_qualifier(mut l) single_type_qualifier(t) { + l.push(t); + l + } + + single_type_qualifier ::= storage_qualifier(s) { + TypeQualifier::StorageClass(s) + } + single_type_qualifier ::= layout_qualifier(l) { + TypeQualifier::Binding(l) + } + // single_type_qualifier ::= precision_qualifier; + single_type_qualifier ::= interpolation_qualifier(i) { + TypeQualifier::Interpolation(i) + } + // single_type_qualifier ::= invariant_qualifier; + // single_type_qualifier ::= precise_qualifier; + + // storage_qualifier ::= Const + // storage_qualifier ::= InOut; + storage_qualifier ::= In { + StorageClass::Input + } + storage_qualifier ::= Out { + StorageClass::Output + } + // storage_qualifier ::= Centroid; + // storage_qualifier ::= Patch; + // storage_qualifier ::= Sample; + storage_qualifier ::= Uniform { + StorageClass::Uniform + } + //TODO: other storage qualifiers + + type_specifier ::= type_specifier_nonarray(t) { + t.map(|t| { + let name = t.name.clone(); + let handle = extra.module.types.fetch_or_append(t); + if let Some(name) = name { + extra.lookup_type.insert(name, handle); + } + handle + }) + } + //TODO: array + + type_specifier_nonarray ::= Void { + None + } + type_specifier_nonarray ::= TypeName(t) { + Some(t.1) + }; + type_specifier_nonarray ::= struct_specifier(s) { + Some(s) + } + + // struct + struct_specifier ::= Struct Identifier(i) LeftBrace struct_declaration_list RightBrace { + Type{ + name: Some(i.1), + inner: TypeInner::Struct { + members: vec![] + } + } + } + //struct_specifier ::= Struct LeftBrace struct_declaration_list RightBrace; + + struct_declaration_list ::= struct_declaration(sd) { + sd + } + struct_declaration_list ::= struct_declaration_list(mut sdl) struct_declaration(sd) { + sdl.extend(sd); + sdl + } + + struct_declaration ::= type_specifier(t) struct_declarator_list(sdl) Semicolon { + if let Some(ty) = t { + sdl.iter().map(|name| StructMember { + name: Some(name.clone()), + origin: MemberOrigin::Empty, + ty, + }).collect() + } else { + return Err(ErrorKind::SemanticError("Struct member can't be void")) + } + } + //struct_declaration ::= type_qualifier type_specifier struct_declarator_list Semicolon; + + struct_declarator_list ::= struct_declarator(sd) { + vec![sd] + } + struct_declarator_list ::= struct_declarator_list(mut sdl) Comma struct_declarator(sd) { + sdl.push(sd); + sdl + } + + struct_declarator ::= Identifier(i) { + i.1 + } + //struct_declarator ::= Identifier array_specifier; + + + initializer ::= assignment_expression; + // initializer ::= LeftBrace initializer_list RightBrace; + // initializer ::= LeftBrace initializer_list Comma RightBrace; + + // initializer_list ::= initializer; + // initializer_list ::= initializer_list Comma initializer; + + declaration_statement ::= declaration(d) { + let mut statements = Vec::<Statement>::new(); + // local variables + if let Some(d) = d { + for (id, initializer) in d.ids_initializers { + let id = id.ok_or(ErrorKind::SemanticError("local var must be named"))?; + // check if already declared in current scope + #[cfg(feature = "glsl-validate")] + { + if extra.context.lookup_local_var_current_scope(&id).is_some() { + return Err(ErrorKind::VariableAlreadyDeclared(id)) + } + } + let mut init_exp: Option<Handle<Expression>> = None; + let localVar = extra.context.local_variables.append( + LocalVariable { + name: Some(id.clone()), + ty: d.ty, + init: initializer.map(|i| { + statements.extend(i.statements); + if let Expression::Constant(constant) = extra.context.expressions[i.expression] { + Some(constant) + } else { + init_exp = Some(i.expression); + None + } + }).flatten(), + } + ); + let exp = extra.context.expressions.append(Expression::LocalVariable(localVar)); + extra.context.add_local_var(id, exp); + + if let Some(value) = init_exp { + statements.push( + Statement::Store { + pointer: exp, + value, + } + ); + } + } + }; + match statements.len() { + 1 => statements.remove(0), + _ => Statement::Block(statements), + } + } + + // statement + statement ::= compound_statement(cs) { + Statement::Block(cs) + } + statement ::= simple_statement; + + simple_statement ::= declaration_statement; + simple_statement ::= expression_statement; + simple_statement ::= selection_statement; + simple_statement ::= jump_statement; + simple_statement ::= iteration_statement; + + + selection_statement ::= If LeftParen expression(e) RightParen statement(s1) Else statement(s2) { + Statement::If { + condition: e.expression, + accept: vec![s1], + reject: vec![s2], + } + } + + selection_statement ::= If LeftParen expression(e) RightParen statement(s) [Else] { + Statement::If { + condition: e.expression, + accept: vec![s], + reject: vec![], + } + } + + selection_statement ::= Switch LeftParen expression(e) RightParen LeftBrace switch_statement_list(ls) RightBrace { + let mut default = Vec::new(); + let mut cases = FastHashMap::default(); + for (v, s, ft) in ls { + if let Some(v) = v { + cases.insert(v, (s, ft)); + } else { + default.extend_from_slice(&s); + } + } + Statement::Switch { + selector: e.expression, + cases, + default, + } + } + + switch_statement_list ::= { + vec![] + } + switch_statement_list ::= switch_statement_list(mut ssl) switch_statement((v, sl, ft)) { + ssl.push((v, sl, ft)); + ssl + } + switch_statement ::= Case IntConstant(v) Colon statement_list(sl) { + let fallthrough = match sl.last() { + Some(Statement::Break) => None, + _ => Some(FallThrough), + }; + (Some(v.1 as i32), sl, fallthrough) + } + switch_statement ::= Default Colon statement_list(sl) { + let fallthrough = match sl.last() { + Some(Statement::Break) => Some(FallThrough), + _ => None, + }; + (None, sl, fallthrough) + } + + iteration_statement ::= While LeftParen expression(e) RightParen compound_statement_no_new_scope(sl) { + let mut body = Vec::with_capacity(sl.len() + 1); + body.push( + Statement::If { + condition: e.expression, + accept: vec![Statement::Break], + reject: vec![], + } + ); + body.extend_from_slice(&sl); + Statement::Loop { + body, + continuing: vec![], + } + } + + iteration_statement ::= Do compound_statement(sl) While LeftParen expression(e) RightParen { + let mut body = sl; + body.push( + Statement::If { + condition: e.expression, + accept: vec![Statement::Break], + reject: vec![], + } + ); + Statement::Loop { + body, + continuing: vec![], + } + } + + iteration_statement ::= For LeftParen for_init_statement(s_init) for_rest_statement((cond_e, loop_e)) RightParen compound_statement_no_new_scope(sl) { + let mut body = Vec::with_capacity(sl.len() + 2); + if let Some(cond_e) = cond_e { + body.push( + Statement::If { + condition: cond_e.expression, + accept: vec![Statement::Break], + reject: vec![], + } + ); + } + body.extend_from_slice(&sl); + if let Some(loop_e) = loop_e { + body.extend_from_slice(&loop_e.statements); + } + Statement::Block(vec![ + s_init, + Statement::Loop { + body, + continuing: vec![], + } + ]) + } + + for_init_statement ::= expression_statement; + for_init_statement ::= declaration_statement; + for_rest_statement ::= condition_opt(c) Semicolon { + (c, None) + } + for_rest_statement ::= condition_opt(c) Semicolon expression(e) { + (c, Some(e)) + } + + condition_opt ::= { + None + } + condition_opt ::= conditional_expression(c) { + Some(c) + } + + compound_statement ::= LeftBrace RightBrace { + vec![] + } + compound_statement ::= left_brace_scope statement_list(sl) RightBrace { + extra.context.remove_current_scope(); + sl + } + + // extra rule to add scope before statement_list + left_brace_scope ::= LeftBrace { + extra.context.push_scope(); + } + + + compound_statement_no_new_scope ::= LeftBrace RightBrace { + vec![] + } + compound_statement_no_new_scope ::= LeftBrace statement_list(sl) RightBrace { + sl + } + + statement_list ::= statement(s) { + vec![s] + } + statement_list ::= statement_list(mut ss) statement(s) { ss.push(s); ss } + + expression_statement ::= Semicolon { + Statement::Block(Vec::new()) + } + expression_statement ::= expression(mut e) Semicolon { + match e.statements.len() { + 1 => e.statements.remove(0), + _ => Statement::Block(e.statements), + } + } + + + + // function + function_prototype ::= function_declarator(f) RightParen { + // prelude, add global var expressions + for (var_handle, var) in extra.module.global_variables.iter() { + if let Some(name) = var.name.as_ref() { + let exp = extra.context.expressions.append( + Expression::GlobalVariable(var_handle) + ); + extra.context.lookup_global_var_exps.insert(name.clone(), exp); + } else { + let ty = &extra.module.types[var.ty]; + // anonymous structs + if let TypeInner::Struct { members } = &ty.inner { + let base = extra.context.expressions.append( + Expression::GlobalVariable(var_handle) + ); + for (idx, member) in members.iter().enumerate() { + if let Some(name) = member.name.as_ref() { + let exp = extra.context.expressions.append( + Expression::AccessIndex{ + base, + index: idx as u32, + } + ); + extra.context.lookup_global_var_exps.insert(name.clone(), exp); + } + } + } + } + } + f + } + function_declarator ::= function_header; + function_header ::= fully_specified_type(t) Identifier(n) LeftParen { + Function { + name: Some(n.1), + arguments: vec![], + return_type: t.1, + global_usage: vec![], + local_variables: Arena::<LocalVariable>::new(), + expressions: Arena::<Expression>::new(), + body: vec![], + } + } + + jump_statement ::= Continue Semicolon { + Statement::Continue + } + jump_statement ::= Break Semicolon { + Statement::Break + } + jump_statement ::= Return Semicolon { + Statement::Return { value: None } + } + jump_statement ::= Return expression(mut e) Semicolon { + let ret = Statement::Return{ value: Some(e.expression) }; + if !e.statements.is_empty() { + e.statements.push(ret); + Statement::Block(e.statements) + } else { + ret + } + } + jump_statement ::= Discard Semicolon { + Statement::Kill + } // Fragment shader only + + // Grammar Note: No 'goto'. Gotos are not supported. + + // misc + translation_unit ::= external_declaration; + translation_unit ::= translation_unit external_declaration; + + external_declaration ::= function_definition(f) { + if f.name == extra.entry { + let name = extra.entry.take().unwrap(); + extra.module.entry_points.insert( + (extra.shader_stage, name), + EntryPoint { + early_depth_test: None, + workgroup_size: [0; 3], //TODO + function: f, + }, + ); + } else { + let name = f.name.clone().unwrap(); + let handle = extra.module.functions.append(f); + extra.lookup_function.insert(name, handle); + } + } + external_declaration ::= declaration(d) { + if let Some(d) = d { + let class = d.type_qualifiers.iter().find_map(|tq| { + if let TypeQualifier::StorageClass(sc) = tq { Some(*sc) } else { None } + }).ok_or(ErrorKind::SemanticError("Missing storage class for global var"))?; + + let binding = d.type_qualifiers.iter().find_map(|tq| { + if let TypeQualifier::Binding(b) = tq { Some(b.clone()) } else { None } + }); + + let interpolation = d.type_qualifiers.iter().find_map(|tq| { + if let TypeQualifier::Interpolation(i) = tq { Some(*i) } else { None } + }); + + for (id, initializer) in d.ids_initializers { + let h = extra.module.global_variables.fetch_or_append( + GlobalVariable { + name: id.clone(), + class, + binding: binding.clone(), + ty: d.ty, + init: None, + interpolation, + storage_access: StorageAccess::empty(), //TODO + }, + ); + if let Some(id) = id { + extra.lookup_global_variables.insert(id, h); + } + } + } + } + + function_definition ::= function_prototype(mut f) compound_statement_no_new_scope(mut cs) { + std::mem::swap(&mut f.expressions, &mut extra.context.expressions); + std::mem::swap(&mut f.local_variables, &mut extra.context.local_variables); + extra.context.clear_scopes(); + extra.context.lookup_global_var_exps.clear(); + extra.context.typifier = Typifier::new(); + // make sure function ends with return + match cs.last() { + Some(Statement::Return {..}) => {} + _ => {cs.push(Statement::Return { value:None });} + } + f.body = cs; + f.fill_global_use(&extra.module.global_variables); + f + }; +} + +pub use parser::*; diff --git a/third_party/rust/naga/src/front/glsl/parser_tests.rs b/third_party/rust/naga/src/front/glsl/parser_tests.rs new file mode 100644 index 0000000000..ccdc657a78 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/parser_tests.rs @@ -0,0 +1,182 @@ +use super::ast::Program; +use super::error::ErrorKind; +use super::lex::Lexer; +use super::parser; +use crate::ShaderStage; + +fn parse_program(source: &str, stage: ShaderStage) -> Result<Program, ErrorKind> { + let mut program = Program::new(stage, ""); + let lex = Lexer::new(source); + let mut parser = parser::Parser::new(&mut program); + + for token in lex { + parser.parse(token)?; + } + parser.end_of_input()?; + Ok(program) +} + +#[test] +fn version() { + // invalid versions + assert_eq!( + format!( + "{:?}", + parse_program("#version 99000", ShaderStage::Vertex) + .err() + .unwrap() + ), + "InvalidVersion(TokenMetadata { line: 0, chars: 9..14 }, 99000)" + ); + + assert_eq!( + format!( + "{:?}", + parse_program("#version 449", ShaderStage::Vertex) + .err() + .unwrap() + ), + "InvalidVersion(TokenMetadata { line: 0, chars: 9..12 }, 449)" + ); + + assert_eq!( + format!( + "{:?}", + parse_program("#version 450 smart", ShaderStage::Vertex) + .err() + .unwrap() + ), + "InvalidProfile(TokenMetadata { line: 0, chars: 13..18 }, \"smart\")" + ); + + assert_eq!( + format!( + "{:?}", + parse_program("#version 450\nvoid f(){} #version 450", ShaderStage::Vertex) + .err() + .unwrap() + ), + "InvalidToken(Unknown((TokenMetadata { line: 1, chars: 11..12 }, \"#\")))" + ); + + // valid versions + let program = parse_program(" # version 450\nvoid main() {}", ShaderStage::Vertex).unwrap(); + assert_eq!( + format!("{:?}", (program.version, program.profile)), + "(450, Core)" + ); + + let program = parse_program("#version 450\nvoid main() {}", ShaderStage::Vertex).unwrap(); + assert_eq!( + format!("{:?}", (program.version, program.profile)), + "(450, Core)" + ); + + let program = parse_program("#version 450 core\nvoid main() {}", ShaderStage::Vertex).unwrap(); + assert_eq!( + format!("{:?}", (program.version, program.profile)), + "(450, Core)" + ); +} + +#[test] +fn control_flow() { + let _program = parse_program( + r#" + # version 450 + void main() { + if (true) { + return 1; + } else { + return 2; + } + } + "#, + ShaderStage::Vertex, + ) + .unwrap(); + + let _program = parse_program( + r#" + # version 450 + void main() { + if (true) { + return 1; + } + } + "#, + ShaderStage::Vertex, + ) + .unwrap(); + + let _program = parse_program( + r#" + # version 450 + void main() { + int x; + int y = 3; + switch (5) { + case 2: + x = 2; + case 5: + x = 5; + y = 2; + break; + default: + x = 0; + } + } + "#, + ShaderStage::Vertex, + ) + .unwrap(); + let _program = parse_program( + r#" + # version 450 + void main() { + int x = 0; + while(x < 5) { + x = x + 1; + } + do { + x = x - 1; + } while(x >= 4) + } + "#, + ShaderStage::Vertex, + ) + .unwrap(); + + let _program = parse_program( + r#" + # version 450 + void main() { + int x = 0; + for(int i = 0; i < 10;) { + x = x + 2; + } + return x; + } + "#, + ShaderStage::Vertex, + ) + .unwrap(); +} + +#[test] +fn textures() { + let _program = parse_program( + r#" + #version 450 + layout(location = 0) in vec2 v_uv; + layout(location = 0) out vec4 o_color; + layout(set = 1, binding = 1) uniform texture2D tex; + layout(set = 1, binding = 2) uniform sampler tex_sampler; + void main() { + o_color = texture(sampler2D(tex, tex_sampler), v_uv); + } + "#, + ShaderStage::Fragment, + ) + .unwrap(); +} diff --git a/third_party/rust/naga/src/front/glsl/preprocess.rs b/third_party/rust/naga/src/front/glsl/preprocess.rs new file mode 100644 index 0000000000..6050594719 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/preprocess.rs @@ -0,0 +1,152 @@ +use crate::FastHashMap; +use thiserror::Error; + +#[derive(Clone, Debug, Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum Error { + #[error("unmatched else")] + UnmatchedElse, + #[error("unmatched endif")] + UnmatchedEndif, + #[error("missing macro name")] + MissingMacro, +} + +#[derive(Clone, Debug)] +pub struct IfState { + true_branch: bool, + else_seen: bool, +} + +#[derive(Clone, Debug)] +pub struct LinePreProcessor { + pub defines: FastHashMap<String, String>, + if_stack: Vec<IfState>, + inside_comment: bool, + in_preprocess: bool, +} + +impl LinePreProcessor { + pub fn new() -> Self { + LinePreProcessor { + defines: FastHashMap::default(), + if_stack: vec![], + inside_comment: false, + in_preprocess: false, + } + } + + fn subst_defines(&self, input: &str) -> String { + //TODO: don't subst in commments, strings literals? + self.defines + .iter() + .fold(input.to_string(), |acc, (k, v)| acc.replace(k, v)) + } + + pub fn process_line(&mut self, line: &str) -> Result<Option<String>, Error> { + let mut skip = !self.if_stack.last().map(|i| i.true_branch).unwrap_or(true); + let mut inside_comment = self.inside_comment; + let mut in_preprocess = inside_comment && self.in_preprocess; + // single-line comment + let mut processed = line; + if let Some(pos) = line.find("//") { + processed = line.split_at(pos).0; + } + // multi-line comment + let mut processed_string: String; + loop { + if inside_comment { + if let Some(pos) = processed.find("*/") { + processed = processed.split_at(pos + 2).1; + inside_comment = false; + self.inside_comment = false; + continue; + } + } else if let Some(pos) = processed.find("/*") { + if let Some(end_pos) = processed[pos + 2..].find("*/") { + // comment ends during this line + processed_string = processed.to_string(); + processed_string.replace_range(pos..pos + end_pos + 4, ""); + processed = &processed_string; + } else { + processed = processed.split_at(pos).0; + inside_comment = true; + } + continue; + } + break; + } + // strip leading whitespace + processed = processed.trim_start(); + if processed.starts_with('#') && !self.inside_comment { + let mut iter = processed[1..] + .trim_start() + .splitn(2, |c: char| c.is_whitespace()); + if let Some(directive) = iter.next() { + skip = true; + in_preprocess = true; + match directive { + "version" => { + skip = false; + } + "define" => { + let rest = iter.next().ok_or(Error::MissingMacro)?; + let pos = rest + .find(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '(') + .unwrap_or_else(|| rest.len()); + let (key, mut value) = rest.split_at(pos); + value = value.trim(); + self.defines.insert(key.into(), self.subst_defines(value)); + } + "undef" => { + let rest = iter.next().ok_or(Error::MissingMacro)?; + let key = rest.trim(); + self.defines.remove(key); + } + "ifdef" => { + let rest = iter.next().ok_or(Error::MissingMacro)?; + let key = rest.trim(); + self.if_stack.push(IfState { + true_branch: self.defines.contains_key(key), + else_seen: false, + }); + } + "ifndef" => { + let rest = iter.next().ok_or(Error::MissingMacro)?; + let key = rest.trim(); + self.if_stack.push(IfState { + true_branch: !self.defines.contains_key(key), + else_seen: false, + }); + } + "else" => { + let if_state = self.if_stack.last_mut().ok_or(Error::UnmatchedElse)?; + if !if_state.else_seen { + // this is first else + if_state.true_branch = !if_state.true_branch; + if_state.else_seen = true; + } else { + return Err(Error::UnmatchedElse); + } + } + "endif" => { + self.if_stack.pop().ok_or(Error::UnmatchedEndif)?; + } + _ => {} + } + } + } + let res = if !skip && !self.inside_comment { + Ok(Some(self.subst_defines(&line))) + } else { + Ok(if in_preprocess && !self.in_preprocess { + Some("".to_string()) + } else { + None + }) + }; + self.in_preprocess = in_preprocess || skip; + self.inside_comment = inside_comment; + res + } +} diff --git a/third_party/rust/naga/src/front/glsl/preprocess_tests.rs b/third_party/rust/naga/src/front/glsl/preprocess_tests.rs new file mode 100644 index 0000000000..253a99935e --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/preprocess_tests.rs @@ -0,0 +1,218 @@ +use super::preprocess::{Error, LinePreProcessor}; +use std::{iter::Enumerate, str::Lines}; + +#[derive(Clone, Debug)] +pub struct PreProcessor<'a> { + lines: Enumerate<Lines<'a>>, + input: String, + line: usize, + offset: usize, + line_pp: LinePreProcessor, +} + +impl<'a> PreProcessor<'a> { + pub fn new(input: &'a str) -> Self { + let mut lexer = PreProcessor { + lines: input.lines().enumerate(), + input: "".to_string(), + line: 0, + offset: 0, + line_pp: LinePreProcessor::new(), + }; + lexer.next_line(); + lexer + } + + fn next_line(&mut self) -> bool { + if let Some((line, input)) = self.lines.next() { + let mut input = String::from(input); + + while input.ends_with('\\') { + if let Some((_, next)) = self.lines.next() { + input.pop(); + input.push_str(next); + } else { + break; + } + } + + self.input = input; + self.line = line; + self.offset = 0; + true + } else { + false + } + } + + pub fn process(&mut self) -> Result<String, Error> { + let mut res = String::new(); + loop { + let line = &self.line_pp.process_line(&self.input)?; + if let Some(line) = line { + res.push_str(line); + } + if !self.next_line() { + break; + } + if line.is_some() { + res.push_str("\n"); + } + } + Ok(res) + } +} + +#[test] +fn preprocess() { + // line continuation + let mut pp = PreProcessor::new( + "void main my_\ + func", + ); + assert_eq!(pp.process().unwrap(), "void main my_func"); + + // preserve #version + let mut pp = PreProcessor::new( + "#version 450 core\n\ + void main()", + ); + assert_eq!(pp.process().unwrap(), "#version 450 core\nvoid main()"); + + // simple define + let mut pp = PreProcessor::new( + "#define FOO 42 \n\ + fun=FOO", + ); + assert_eq!(pp.process().unwrap(), "\nfun=42"); + + // ifdef with else + let mut pp = PreProcessor::new( + "#define FOO\n\ + #ifdef FOO\n\ + foo=42\n\ + #endif\n\ + some=17\n\ + #ifdef BAR\n\ + bar=88\n\ + #else\n\ + mm=49\n\ + #endif\n\ + done=1", + ); + assert_eq!( + pp.process().unwrap(), + "\n\ + foo=42\n\ + \n\ + some=17\n\ + \n\ + mm=49\n\ + \n\ + done=1" + ); + + // nested ifdef/ifndef + let mut pp = PreProcessor::new( + "#define FOO\n\ + #define BOO\n\ + #ifdef FOO\n\ + foo=42\n\ + #ifdef BOO\n\ + boo=44\n\ + #endif\n\ + ifd=0\n\ + #ifndef XYZ\n\ + nxyz=8\n\ + #endif\n\ + #endif\n\ + some=17\n\ + #ifdef BAR\n\ + bar=88\n\ + #else\n\ + mm=49\n\ + #endif\n\ + done=1", + ); + assert_eq!( + pp.process().unwrap(), + "\n\ + foo=42\n\ + \n\ + boo=44\n\ + \n\ + ifd=0\n\ + \n\ + nxyz=8\n\ + \n\ + some=17\n\ + \n\ + mm=49\n\ + \n\ + done=1" + ); + + // undef + let mut pp = PreProcessor::new( + "#define FOO\n\ + #ifdef FOO\n\ + foo=42\n\ + #endif\n\ + some=17\n\ + #undef FOO\n\ + #ifdef FOO\n\ + foo=88\n\ + #else\n\ + nofoo=66\n\ + #endif\n\ + done=1", + ); + assert_eq!( + pp.process().unwrap(), + "\n\ + foo=42\n\ + \n\ + some=17\n\ + \n\ + nofoo=66\n\ + \n\ + done=1" + ); + + // single-line comment + let mut pp = PreProcessor::new( + "#define FOO 42//1234\n\ + fun=FOO", + ); + assert_eq!(pp.process().unwrap(), "\nfun=42"); + + // multi-line comments + let mut pp = PreProcessor::new( + "#define FOO 52/*/1234\n\ + #define FOO 88\n\ + end of comment*/ /* one more comment */ #define FOO 56\n\ + fun=FOO", + ); + assert_eq!(pp.process().unwrap(), "\nfun=56"); + + // unmatched endif + let mut pp = PreProcessor::new( + "#ifdef FOO\n\ + foo=42\n\ + #endif\n\ + #endif", + ); + assert_eq!(pp.process(), Err(Error::UnmatchedEndif)); + + // unmatched else + let mut pp = PreProcessor::new( + "#ifdef FOO\n\ + foo=42\n\ + #else\n\ + bar=88\n\ + #else\n\ + bad=true\n\ + #endif", + ); + assert_eq!(pp.process(), Err(Error::UnmatchedElse)); +} diff --git a/third_party/rust/naga/src/front/glsl/token.rs b/third_party/rust/naga/src/front/glsl/token.rs new file mode 100644 index 0000000000..7b849d361f --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/token.rs @@ -0,0 +1,8 @@ +use std::ops::Range; + +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub struct TokenMetadata { + pub line: usize, + pub chars: Range<usize>, +} diff --git a/third_party/rust/naga/src/front/glsl/types.rs b/third_party/rust/naga/src/front/glsl/types.rs new file mode 100644 index 0000000000..715c0c7430 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/types.rs @@ -0,0 +1,120 @@ +use crate::{ScalarKind, Type, TypeInner, VectorSize}; + +pub fn parse_type(type_name: &str) -> Option<Type> { + match type_name { + "bool" => Some(Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Bool, + width: 4, // https://stackoverflow.com/questions/9419781/what-is-the-size-of-glsl-boolean + }, + }), + "float" => Some(Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Float, + width: 4, + }, + }), + "double" => Some(Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Float, + width: 8, + }, + }), + "int" => Some(Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Sint, + width: 4, + }, + }), + "uint" => Some(Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Uint, + width: 4, + }, + }), + "texture2D" => Some(Type { + name: None, + inner: TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { + kind: ScalarKind::Float, + multi: false, + }, + }, + }), + "sampler" => Some(Type { + name: None, + inner: TypeInner::Sampler { comparison: false }, + }), + word => { + fn kind_width_parse(ty: &str) -> Option<(ScalarKind, u8)> { + Some(match ty { + "" => (ScalarKind::Float, 4), + "b" => (ScalarKind::Bool, 4), + "i" => (ScalarKind::Sint, 4), + "u" => (ScalarKind::Uint, 4), + "d" => (ScalarKind::Float, 8), + _ => return None, + }) + } + + fn size_parse(n: &str) -> Option<VectorSize> { + Some(match n { + "2" => VectorSize::Bi, + "3" => VectorSize::Tri, + "4" => VectorSize::Quad, + _ => return None, + }) + } + + let vec_parse = |word: &str| { + let mut iter = word.split("vec"); + + let kind = iter.next()?; + let size = iter.next()?; + let (kind, width) = kind_width_parse(kind)?; + let size = size_parse(size)?; + + Some(Type { + name: None, + inner: TypeInner::Vector { size, kind, width }, + }) + }; + + let mat_parse = |word: &str| { + let mut iter = word.split("mat"); + + let kind = iter.next()?; + let size = iter.next()?; + let (_, width) = kind_width_parse(kind)?; + + let (columns, rows) = if let Some(size) = size_parse(size) { + (size, size) + } else { + let mut iter = size.split('x'); + match (iter.next()?, iter.next()?, iter.next()) { + (col, row, None) => (size_parse(col)?, size_parse(row)?), + _ => return None, + } + }; + + Some(Type { + name: None, + inner: TypeInner::Matrix { + columns, + rows, + width, + }, + }) + }; + + vec_parse(word).or_else(|| mat_parse(word)) + } + } +} diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs new file mode 100644 index 0000000000..9f76335685 --- /dev/null +++ b/third_party/rust/naga/src/front/glsl/variables.rs @@ -0,0 +1,185 @@ +use crate::{ + Binding, BuiltIn, Expression, GlobalVariable, Handle, ScalarKind, ShaderStage, StorageAccess, + StorageClass, Type, TypeInner, VectorSize, +}; + +use super::ast::*; +use super::error::ErrorKind; +use super::token::TokenMetadata; + +impl Program { + pub fn lookup_variable(&mut self, name: &str) -> Result<Option<Handle<Expression>>, ErrorKind> { + let mut expression: Option<Handle<Expression>> = None; + match name { + "gl_Position" => { + #[cfg(feature = "glsl-validate")] + match self.shader_stage { + ShaderStage::Vertex | ShaderStage::Fragment { .. } => {} + _ => { + return Err(ErrorKind::VariableNotAvailable(name.into())); + } + }; + let h = self + .module + .global_variables + .fetch_or_append(GlobalVariable { + name: Some(name.into()), + class: if self.shader_stage == ShaderStage::Vertex { + StorageClass::Output + } else { + StorageClass::Input + }, + binding: Some(Binding::BuiltIn(BuiltIn::Position)), + ty: self.module.types.fetch_or_append(Type { + name: None, + inner: TypeInner::Vector { + size: VectorSize::Quad, + kind: ScalarKind::Float, + width: 4, + }, + }), + init: None, + interpolation: None, + storage_access: StorageAccess::empty(), + }); + self.lookup_global_variables.insert(name.into(), h); + let exp = self + .context + .expressions + .append(Expression::GlobalVariable(h)); + self.context.lookup_global_var_exps.insert(name.into(), exp); + + expression = Some(exp); + } + "gl_VertexIndex" => { + #[cfg(feature = "glsl-validate")] + match self.shader_stage { + ShaderStage::Vertex => {} + _ => { + return Err(ErrorKind::VariableNotAvailable(name.into())); + } + }; + let h = self + .module + .global_variables + .fetch_or_append(GlobalVariable { + name: Some(name.into()), + class: StorageClass::Input, + binding: Some(Binding::BuiltIn(BuiltIn::VertexIndex)), + ty: self.module.types.fetch_or_append(Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Uint, + width: 4, + }, + }), + init: None, + interpolation: None, + storage_access: StorageAccess::empty(), + }); + self.lookup_global_variables.insert(name.into(), h); + let exp = self + .context + .expressions + .append(Expression::GlobalVariable(h)); + self.context.lookup_global_var_exps.insert(name.into(), exp); + + expression = Some(exp); + } + _ => {} + } + + if let Some(expression) = expression { + Ok(Some(expression)) + } else if let Some(local_var) = self.context.lookup_local_var(name) { + Ok(Some(local_var)) + } else if let Some(global_var) = self.context.lookup_global_var_exps.get(name) { + Ok(Some(*global_var)) + } else { + Ok(None) + } + } + + pub fn field_selection( + &mut self, + expression: Handle<Expression>, + name: &str, + meta: TokenMetadata, + ) -> Result<Handle<Expression>, ErrorKind> { + match *self.resolve_type(expression)? { + TypeInner::Struct { ref members } => { + let index = members + .iter() + .position(|m| m.name == Some(name.into())) + .ok_or_else(|| ErrorKind::UnknownField(meta, name.into()))?; + Ok(self.context.expressions.append(Expression::AccessIndex { + base: expression, + index: index as u32, + })) + } + // swizzles (xyzw, rgba, stpq) + TypeInner::Vector { size, kind, width } => { + let check_swizzle_components = |comps: &str| { + name.chars() + .map(|c| { + comps + .find(c) + .and_then(|i| if i < size as usize { Some(i) } else { None }) + }) + .fold(Some(Vec::<usize>::new()), |acc, cur| { + cur.and_then(|i| { + acc.map(|mut v| { + v.push(i); + v + }) + }) + }) + }; + + let indices = check_swizzle_components("xyzw") + .or_else(|| check_swizzle_components("rgba")) + .or_else(|| check_swizzle_components("stpq")); + + if let Some(v) = indices { + let components: Vec<Handle<Expression>> = v + .iter() + .map(|idx| { + self.context.expressions.append(Expression::AccessIndex { + base: expression, + index: *idx as u32, + }) + }) + .collect(); + if components.len() == 1 { + // only single element swizzle, like pos.y, just return that component + Ok(components[0]) + } else { + Ok(self.context.expressions.append(Expression::Compose { + ty: self.module.types.fetch_or_append(Type { + name: None, + inner: TypeInner::Vector { + kind, + width, + size: match components.len() { + 2 => VectorSize::Bi, + 3 => VectorSize::Tri, + 4 => VectorSize::Quad, + _ => { + return Err(ErrorKind::SemanticError( + "Bad swizzle size", + )); + } + }, + }, + }), + components, + })) + } + } else { + Err(ErrorKind::SemanticError("Invalid swizzle for vector")) + } + } + _ => Err(ErrorKind::SemanticError("Can't lookup field on this type")), + } + } +} diff --git a/third_party/rust/naga/src/front/mod.rs b/third_party/rust/naga/src/front/mod.rs new file mode 100644 index 0000000000..cb543c72a9 --- /dev/null +++ b/third_party/rust/naga/src/front/mod.rs @@ -0,0 +1,32 @@ +//! Parsers which load shaders into memory. + +#[cfg(feature = "glsl-in")] +pub mod glsl; +#[cfg(feature = "spv-in")] +pub mod spv; +#[cfg(feature = "wgsl-in")] +pub mod wgsl; + +use crate::arena::Arena; + +pub const GENERATOR: u32 = 0; + +impl crate::Module { + pub fn from_header(header: crate::Header) -> Self { + crate::Module { + header, + types: Arena::new(), + constants: Arena::new(), + global_variables: Arena::new(), + functions: Arena::new(), + entry_points: crate::FastHashMap::default(), + } + } + + pub fn generate_empty() -> Self { + Self::from_header(crate::Header { + version: (1, 0, 0), + generator: GENERATOR, + }) + } +} diff --git a/third_party/rust/naga/src/front/spv/convert.rs b/third_party/rust/naga/src/front/spv/convert.rs new file mode 100644 index 0000000000..6036d08946 --- /dev/null +++ b/third_party/rust/naga/src/front/spv/convert.rs @@ -0,0 +1,123 @@ +use super::error::Error; +use num_traits::cast::FromPrimitive; +use std::convert::TryInto; + +pub fn map_binary_operator(word: spirv::Op) -> Result<crate::BinaryOperator, Error> { + use crate::BinaryOperator; + use spirv::Op; + + match word { + // Arithmetic Instructions +, -, *, /, % + Op::IAdd | Op::FAdd => Ok(BinaryOperator::Add), + Op::ISub | Op::FSub => Ok(BinaryOperator::Subtract), + Op::IMul | Op::FMul => Ok(BinaryOperator::Multiply), + Op::UDiv | Op::SDiv | Op::FDiv => Ok(BinaryOperator::Divide), + Op::UMod | Op::SMod | Op::FMod => Ok(BinaryOperator::Modulo), + // Relational and Logical Instructions + Op::IEqual | Op::FOrdEqual | Op::FUnordEqual => Ok(BinaryOperator::Equal), + Op::INotEqual | Op::FOrdNotEqual | Op::FUnordNotEqual => Ok(BinaryOperator::NotEqual), + Op::ULessThan | Op::SLessThan | Op::FOrdLessThan | Op::FUnordLessThan => { + Ok(BinaryOperator::Less) + } + Op::ULessThanEqual + | Op::SLessThanEqual + | Op::FOrdLessThanEqual + | Op::FUnordLessThanEqual => Ok(BinaryOperator::LessEqual), + Op::UGreaterThan | Op::SGreaterThan | Op::FOrdGreaterThan | Op::FUnordGreaterThan => { + Ok(BinaryOperator::Greater) + } + Op::UGreaterThanEqual + | Op::SGreaterThanEqual + | Op::FOrdGreaterThanEqual + | Op::FUnordGreaterThanEqual => Ok(BinaryOperator::GreaterEqual), + _ => Err(Error::UnknownInstruction(word as u16)), + } +} + +pub fn map_vector_size(word: spirv::Word) -> Result<crate::VectorSize, Error> { + match word { + 2 => Ok(crate::VectorSize::Bi), + 3 => Ok(crate::VectorSize::Tri), + 4 => Ok(crate::VectorSize::Quad), + _ => Err(Error::InvalidVectorSize(word)), + } +} + +pub fn map_image_dim(word: spirv::Word) -> Result<crate::ImageDimension, Error> { + use spirv::Dim as D; + match D::from_u32(word) { + Some(D::Dim1D) => Ok(crate::ImageDimension::D1), + Some(D::Dim2D) => Ok(crate::ImageDimension::D2), + Some(D::Dim3D) => Ok(crate::ImageDimension::D3), + Some(D::DimCube) => Ok(crate::ImageDimension::Cube), + _ => Err(Error::UnsupportedImageDim(word)), + } +} + +pub fn map_image_format(word: spirv::Word) -> Result<crate::StorageFormat, Error> { + match spirv::ImageFormat::from_u32(word) { + Some(spirv::ImageFormat::R8) => Ok(crate::StorageFormat::R8Unorm), + Some(spirv::ImageFormat::R8Snorm) => Ok(crate::StorageFormat::R8Snorm), + Some(spirv::ImageFormat::R8ui) => Ok(crate::StorageFormat::R8Uint), + Some(spirv::ImageFormat::R8i) => Ok(crate::StorageFormat::R8Sint), + Some(spirv::ImageFormat::R16ui) => Ok(crate::StorageFormat::R16Uint), + Some(spirv::ImageFormat::R16i) => Ok(crate::StorageFormat::R16Sint), + Some(spirv::ImageFormat::R16f) => Ok(crate::StorageFormat::R16Float), + Some(spirv::ImageFormat::Rg8) => Ok(crate::StorageFormat::Rg8Unorm), + Some(spirv::ImageFormat::Rg8Snorm) => Ok(crate::StorageFormat::Rg8Snorm), + Some(spirv::ImageFormat::Rg8ui) => Ok(crate::StorageFormat::Rg8Uint), + Some(spirv::ImageFormat::Rg8i) => Ok(crate::StorageFormat::Rg8Sint), + Some(spirv::ImageFormat::R32ui) => Ok(crate::StorageFormat::R32Uint), + Some(spirv::ImageFormat::R32i) => Ok(crate::StorageFormat::R32Sint), + Some(spirv::ImageFormat::R32f) => Ok(crate::StorageFormat::R32Float), + Some(spirv::ImageFormat::Rg16ui) => Ok(crate::StorageFormat::Rg16Uint), + Some(spirv::ImageFormat::Rg16i) => Ok(crate::StorageFormat::Rg16Sint), + Some(spirv::ImageFormat::Rg16f) => Ok(crate::StorageFormat::Rg16Float), + Some(spirv::ImageFormat::Rgba8) => Ok(crate::StorageFormat::Rgba8Unorm), + Some(spirv::ImageFormat::Rgba8Snorm) => Ok(crate::StorageFormat::Rgba8Snorm), + Some(spirv::ImageFormat::Rgba8ui) => Ok(crate::StorageFormat::Rgba8Uint), + Some(spirv::ImageFormat::Rgba8i) => Ok(crate::StorageFormat::Rgba8Sint), + Some(spirv::ImageFormat::Rgb10a2ui) => Ok(crate::StorageFormat::Rgb10a2Unorm), + Some(spirv::ImageFormat::R11fG11fB10f) => Ok(crate::StorageFormat::Rg11b10Float), + Some(spirv::ImageFormat::Rg32ui) => Ok(crate::StorageFormat::Rg32Uint), + Some(spirv::ImageFormat::Rg32i) => Ok(crate::StorageFormat::Rg32Sint), + Some(spirv::ImageFormat::Rg32f) => Ok(crate::StorageFormat::Rg32Float), + Some(spirv::ImageFormat::Rgba16ui) => Ok(crate::StorageFormat::Rgba16Uint), + Some(spirv::ImageFormat::Rgba16i) => Ok(crate::StorageFormat::Rgba16Sint), + Some(spirv::ImageFormat::Rgba16f) => Ok(crate::StorageFormat::Rgba16Float), + Some(spirv::ImageFormat::Rgba32ui) => Ok(crate::StorageFormat::Rgba32Uint), + Some(spirv::ImageFormat::Rgba32i) => Ok(crate::StorageFormat::Rgba32Sint), + Some(spirv::ImageFormat::Rgba32f) => Ok(crate::StorageFormat::Rgba32Float), + _ => Err(Error::UnsupportedImageFormat(word)), + } +} + +pub fn map_width(word: spirv::Word) -> Result<crate::Bytes, Error> { + (word >> 3) // bits to bytes + .try_into() + .map_err(|_| Error::InvalidTypeWidth(word)) +} + +pub fn map_builtin(word: spirv::Word) -> Result<crate::BuiltIn, Error> { + use spirv::BuiltIn as Bi; + Ok(match spirv::BuiltIn::from_u32(word) { + Some(Bi::BaseInstance) => crate::BuiltIn::BaseInstance, + Some(Bi::BaseVertex) => crate::BuiltIn::BaseVertex, + Some(Bi::ClipDistance) => crate::BuiltIn::ClipDistance, + Some(Bi::InstanceIndex) => crate::BuiltIn::InstanceIndex, + Some(Bi::Position) => crate::BuiltIn::Position, + Some(Bi::VertexIndex) => crate::BuiltIn::VertexIndex, + // fragment + Some(Bi::PointSize) => crate::BuiltIn::PointSize, + Some(Bi::FragCoord) => crate::BuiltIn::FragCoord, + Some(Bi::FrontFacing) => crate::BuiltIn::FrontFacing, + Some(Bi::SampleId) => crate::BuiltIn::SampleIndex, + Some(Bi::FragDepth) => crate::BuiltIn::FragDepth, + // compute + Some(Bi::GlobalInvocationId) => crate::BuiltIn::GlobalInvocationId, + Some(Bi::LocalInvocationId) => crate::BuiltIn::LocalInvocationId, + Some(Bi::LocalInvocationIndex) => crate::BuiltIn::LocalInvocationIndex, + Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId, + _ => return Err(Error::UnsupportedBuiltIn(word)), + }) +} diff --git a/third_party/rust/naga/src/front/spv/error.rs b/third_party/rust/naga/src/front/spv/error.rs new file mode 100644 index 0000000000..0ebf603912 --- /dev/null +++ b/third_party/rust/naga/src/front/spv/error.rs @@ -0,0 +1,56 @@ +use super::ModuleState; +use crate::arena::Handle; + +#[derive(Debug)] +pub enum Error { + InvalidHeader, + InvalidWordCount, + UnknownInstruction(u16), + UnknownCapability(spirv::Word), + UnsupportedInstruction(ModuleState, spirv::Op), + UnsupportedCapability(spirv::Capability), + UnsupportedExtension(String), + UnsupportedExtSet(String), + UnsupportedExtInstSet(spirv::Word), + UnsupportedExtInst(spirv::Word), + UnsupportedType(Handle<crate::Type>), + UnsupportedExecutionModel(spirv::Word), + UnsupportedExecutionMode(spirv::Word), + UnsupportedStorageClass(spirv::Word), + UnsupportedImageDim(spirv::Word), + UnsupportedImageFormat(spirv::Word), + UnsupportedBuiltIn(spirv::Word), + UnsupportedControlFlow(spirv::Word), + UnsupportedBinaryOperator(spirv::Word), + InvalidParameter(spirv::Op), + InvalidOperandCount(spirv::Op, u16), + InvalidOperand, + InvalidId(spirv::Word), + InvalidDecoration(spirv::Word), + InvalidTypeWidth(spirv::Word), + InvalidSign(spirv::Word), + InvalidInnerType(spirv::Word), + InvalidVectorSize(spirv::Word), + InvalidVariableClass(spirv::StorageClass), + InvalidAccessType(spirv::Word), + InvalidAccess(Handle<crate::Expression>), + InvalidAccessIndex(spirv::Word), + InvalidBinding(spirv::Word), + InvalidImageExpression(Handle<crate::Expression>), + InvalidImageBaseType(Handle<crate::Type>), + InvalidSamplerExpression(Handle<crate::Expression>), + InvalidSampleImage(Handle<crate::Type>), + InvalidSampleSampler(Handle<crate::Type>), + InvalidSampleCoordinates(Handle<crate::Type>), + InvalidDepthReference(Handle<crate::Type>), + InvalidAsType(Handle<crate::Type>), + InconsistentComparisonSampling(Handle<crate::Type>), + WrongFunctionResultType(spirv::Word), + WrongFunctionArgumentType(spirv::Word), + MissingDecoration(spirv::Decoration), + BadString, + IncompleteData, + InvalidTerminator, + InvalidEdgeClassification, + UnexpectedComparisonType(Handle<crate::Type>), +} diff --git a/third_party/rust/naga/src/front/spv/flow.rs b/third_party/rust/naga/src/front/spv/flow.rs new file mode 100644 index 0000000000..32eac66941 --- /dev/null +++ b/third_party/rust/naga/src/front/spv/flow.rs @@ -0,0 +1,569 @@ +#![allow(dead_code)] + +use super::error::Error; +///! see https://en.wikipedia.org/wiki/Control-flow_graph +///! see https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_structuredcontrolflow_a_structured_control_flow +use super::{ + function::{BlockId, MergeInstruction, Terminator}, + LookupExpression, PhiInstruction, +}; + +use crate::FastHashMap; + +use petgraph::{ + algo::has_path_connecting, + graph::{node_index, NodeIndex}, + visit::EdgeRef, + Directed, Direction, +}; + +use std::fmt::Write; + +/// Index of a block node in the `ControlFlowGraph`. +type BlockNodeIndex = NodeIndex<u32>; + +/// Internal representation of a CFG constisting of function's basic blocks. +type ControlFlowGraph = petgraph::Graph<ControlFlowNode, ControlFlowEdgeType, Directed, u32>; + +/// Control flow graph (CFG) containing relationships between blocks. +pub(super) struct FlowGraph { + /// + flow: ControlFlowGraph, + + /// Block ID to Node index mapping. Internal helper to speed up the classification. + block_to_node: FastHashMap<BlockId, BlockNodeIndex>, +} + +impl FlowGraph { + /// Creates empty flow graph. + pub(super) fn new() -> Self { + Self { + flow: ControlFlowGraph::default(), + block_to_node: FastHashMap::default(), + } + } + + /// Add a control flow node. + pub(super) fn add_node(&mut self, node: ControlFlowNode) { + let block_id = node.id; + let node_index = self.flow.add_node(node); + self.block_to_node.insert(block_id, node_index); + } + + /// + /// 1. Creates edges in the CFG. + /// 2. Classifies types of blocks and edges in the CFG. + pub(super) fn classify(&mut self) { + let block_to_node = &mut self.block_to_node; + + // 1. + // Add all edges + // Classify Nodes as one of [Header, Loop, Kill, Return] + for source_node_index in self.flow.node_indices() { + // Merge edges + if let Some(merge) = self.flow[source_node_index].merge { + let merge_block_index = block_to_node[&merge.merge_block_id]; + + self.flow[source_node_index].ty = Some(ControlFlowNodeType::Header); + self.flow[merge_block_index].ty = Some(ControlFlowNodeType::Merge); + self.flow.add_edge( + source_node_index, + merge_block_index, + ControlFlowEdgeType::ForwardMerge, + ); + + if let Some(continue_block_id) = merge.continue_block_id { + let continue_block_index = block_to_node[&continue_block_id]; + + self.flow[source_node_index].ty = Some(ControlFlowNodeType::Loop); + self.flow.add_edge( + source_node_index, + continue_block_index, + ControlFlowEdgeType::ForwardContinue, + ); + } + } + + // Branch Edges + let terminator = self.flow[source_node_index].terminator.clone(); + match terminator { + Terminator::Branch { target_id } => { + let target_node_index = block_to_node[&target_id]; + + self.flow.add_edge( + source_node_index, + target_node_index, + ControlFlowEdgeType::Forward, + ); + } + Terminator::BranchConditional { + true_id, false_id, .. + } => { + let true_node_index = block_to_node[&true_id]; + let false_node_index = block_to_node[&false_id]; + + self.flow.add_edge( + source_node_index, + true_node_index, + ControlFlowEdgeType::IfTrue, + ); + self.flow.add_edge( + source_node_index, + false_node_index, + ControlFlowEdgeType::IfFalse, + ); + } + Terminator::Switch { + selector: _, + default, + ref targets, + } => { + let default_node_index = block_to_node[&default]; + + self.flow.add_edge( + source_node_index, + default_node_index, + ControlFlowEdgeType::Forward, + ); + + for (_, target_block_id) in targets.iter() { + let target_node_index = block_to_node[&target_block_id]; + + self.flow.add_edge( + source_node_index, + target_node_index, + ControlFlowEdgeType::Forward, + ); + } + } + Terminator::Return { .. } => { + self.flow[source_node_index].ty = Some(ControlFlowNodeType::Return) + } + Terminator::Kill => { + self.flow[source_node_index].ty = Some(ControlFlowNodeType::Kill) + } + _ => {} + }; + } + + // 2. + // Classify Nodes/Edges as one of [Break, Continue, Back] + for edge_index in self.flow.edge_indices() { + let (node_source_index, node_target_index) = + self.flow.edge_endpoints(edge_index).unwrap(); + + if self.flow[node_source_index].ty == Some(ControlFlowNodeType::Header) + || self.flow[node_source_index].ty == Some(ControlFlowNodeType::Loop) + { + continue; + } + + // Back + if self.flow[node_target_index].ty == Some(ControlFlowNodeType::Loop) + && self.flow[node_source_index].id > self.flow[node_target_index].id + { + self.flow[node_source_index].ty = Some(ControlFlowNodeType::Back); + self.flow[edge_index] = ControlFlowEdgeType::Back; + } + + let mut target_incoming_edges = self + .flow + .neighbors_directed(node_target_index, Direction::Incoming) + .detach(); + while let Some((incoming_edge, incoming_source)) = + target_incoming_edges.next(&self.flow) + { + // Loop continue + if self.flow[incoming_edge] == ControlFlowEdgeType::ForwardContinue { + self.flow[node_source_index].ty = Some(ControlFlowNodeType::Continue); + self.flow[edge_index] = ControlFlowEdgeType::LoopContinue; + } + // Loop break + if self.flow[incoming_source].ty == Some(ControlFlowNodeType::Loop) + && self.flow[incoming_edge] == ControlFlowEdgeType::ForwardMerge + { + self.flow[node_source_index].ty = Some(ControlFlowNodeType::Break); + self.flow[edge_index] = ControlFlowEdgeType::LoopBreak; + } + } + } + } + + /// Removes OpPhi instructions from the control flow graph and turns them into ordinary variables. + /// + /// Phi instructions are not supported inside Naga nor do they exist as instructions on CPUs. It is neccessary + /// to remove them and turn into ordinary variables before converting to Naga's IR and shader code. + pub(super) fn remove_phi_instructions( + &mut self, + lookup_expression: &FastHashMap<spirv::Word, LookupExpression>, + ) { + for node_index in self.flow.node_indices() { + let phis = std::mem::replace(&mut self.flow[node_index].phis, Vec::new()); + for phi in phis.iter() { + let phi_var = &lookup_expression[&phi.id]; + for (variable_id, parent_id) in phi.variables.iter() { + let variable = &lookup_expression[&variable_id]; + let parent_node = &mut self.flow[self.block_to_node[&parent_id]]; + + parent_node.block.push(crate::Statement::Store { + pointer: phi_var.handle, + value: variable.handle, + }); + } + } + self.flow[node_index].phis = phis; + } + } + + /// Traverses the flow graph and returns a list of Naga's statements. + pub(super) fn to_naga(&self) -> Result<crate::Block, Error> { + self.naga_traverse(node_index(0), None) + } + + fn naga_traverse( + &self, + node_index: BlockNodeIndex, + stop_node_index: Option<BlockNodeIndex>, + ) -> Result<crate::Block, Error> { + if stop_node_index == Some(node_index) { + return Ok(vec![]); + } + + let node = &self.flow[node_index]; + + match node.ty { + Some(ControlFlowNodeType::Header) => match node.terminator { + Terminator::BranchConditional { + condition, + true_id, + false_id, + } => { + let true_node_index = self.block_to_node[&true_id]; + let false_node_index = self.block_to_node[&false_id]; + let merge_node_index = self.block_to_node[&node.merge.unwrap().merge_block_id]; + + let mut result = node.block.clone(); + + if false_node_index != merge_node_index { + result.push(crate::Statement::If { + condition, + accept: self.naga_traverse(true_node_index, Some(merge_node_index))?, + reject: self.naga_traverse(false_node_index, Some(merge_node_index))?, + }); + result.extend(self.naga_traverse(merge_node_index, stop_node_index)?); + } else { + result.push(crate::Statement::If { + condition, + accept: self.naga_traverse( + self.block_to_node[&true_id], + Some(merge_node_index), + )?, + reject: self.naga_traverse(merge_node_index, stop_node_index)?, + }); + } + + Ok(result) + } + Terminator::Switch { + selector, + default, + ref targets, + } => { + let merge_node_index = self.block_to_node[&node.merge.unwrap().merge_block_id]; + let mut result = node.block.clone(); + + let mut cases = FastHashMap::default(); + + for i in 0..targets.len() { + let left_target_node_index = self.block_to_node[&targets[i].1]; + + let fallthrough: Option<crate::FallThrough> = if i < targets.len() - 1 { + let right_target_node_index = self.block_to_node[&targets[i + 1].1]; + if has_path_connecting( + &self.flow, + left_target_node_index, + right_target_node_index, + None, + ) { + Some(crate::FallThrough {}) + } else { + None + } + } else { + None + }; + + cases.insert( + targets[i].0, + ( + self.naga_traverse(left_target_node_index, Some(merge_node_index))?, + fallthrough, + ), + ); + } + + result.push(crate::Statement::Switch { + selector, + cases, + default: self + .naga_traverse(self.block_to_node[&default], Some(merge_node_index))?, + }); + + result.extend(self.naga_traverse(merge_node_index, stop_node_index)?); + + Ok(result) + } + _ => Err(Error::InvalidTerminator), + }, + Some(ControlFlowNodeType::Loop) => { + let merge_node_index = self.block_to_node[&node.merge.unwrap().merge_block_id]; + let continuing: crate::Block = { + let continue_edge = self + .flow + .edges_directed(node_index, Direction::Outgoing) + .find(|&ty| *ty.weight() == ControlFlowEdgeType::ForwardContinue) + .unwrap(); + + self.flow[continue_edge.target()].block.clone() + }; + + let mut body = node.block.clone(); + match node.terminator { + Terminator::BranchConditional { + condition, + true_id, + false_id, + } => body.push(crate::Statement::If { + condition, + accept: self + .naga_traverse(self.block_to_node[&true_id], Some(merge_node_index))?, + reject: self + .naga_traverse(self.block_to_node[&false_id], Some(merge_node_index))?, + }), + Terminator::Branch { target_id } => body.extend( + self.naga_traverse(self.block_to_node[&target_id], Some(merge_node_index))?, + ), + _ => return Err(Error::InvalidTerminator), + }; + + let mut result = vec![crate::Statement::Loop { body, continuing }]; + result.extend(self.naga_traverse(merge_node_index, stop_node_index)?); + + Ok(result) + } + Some(ControlFlowNodeType::Break) => { + let mut result = node.block.clone(); + match node.terminator { + Terminator::BranchConditional { + condition, + true_id, + false_id, + } => { + let true_node_id = self.block_to_node[&true_id]; + let false_node_id = self.block_to_node[&false_id]; + + let true_edge = + self.flow[self.flow.find_edge(node_index, true_node_id).unwrap()]; + let false_edge = + self.flow[self.flow.find_edge(node_index, false_node_id).unwrap()]; + + if true_edge == ControlFlowEdgeType::LoopBreak { + result.push(crate::Statement::If { + condition, + accept: vec![crate::Statement::Break], + reject: self.naga_traverse(false_node_id, stop_node_index)?, + }); + } else if false_edge == ControlFlowEdgeType::LoopBreak { + result.push(crate::Statement::If { + condition, + accept: self.naga_traverse(true_node_id, stop_node_index)?, + reject: vec![crate::Statement::Break], + }); + } else { + return Err(Error::InvalidEdgeClassification); + } + } + Terminator::Branch { .. } => { + result.push(crate::Statement::Break); + } + _ => return Err(Error::InvalidTerminator), + }; + Ok(result) + } + Some(ControlFlowNodeType::Continue) => { + let back_block = match node.terminator { + Terminator::Branch { target_id } => { + self.naga_traverse(self.block_to_node[&target_id], None)? + } + _ => return Err(Error::InvalidTerminator), + }; + + let mut result = node.block.clone(); + result.extend(back_block); + result.push(crate::Statement::Continue); + Ok(result) + } + Some(ControlFlowNodeType::Back) => Ok(node.block.clone()), + Some(ControlFlowNodeType::Kill) => { + let mut result = node.block.clone(); + result.push(crate::Statement::Kill); + Ok(result) + } + Some(ControlFlowNodeType::Return) => { + let value = match node.terminator { + Terminator::Return { value } => value, + _ => return Err(Error::InvalidTerminator), + }; + let mut result = node.block.clone(); + result.push(crate::Statement::Return { value }); + Ok(result) + } + Some(ControlFlowNodeType::Merge) | None => match node.terminator { + Terminator::Branch { target_id } => { + let mut result = node.block.clone(); + result.extend( + self.naga_traverse(self.block_to_node[&target_id], stop_node_index)?, + ); + Ok(result) + } + _ => Ok(node.block.clone()), + }, + } + } + + /// Get the entire graph in a graphviz dot format for visualization. Useful for debugging purposes. + pub(super) fn to_graphviz(&self) -> Result<String, std::fmt::Error> { + let mut output = String::new(); + + output += "digraph ControlFlowGraph {\n"; + + for node_index in self.flow.node_indices() { + let node = &self.flow[node_index]; + writeln!( + output, + "{} [ label = \"%{} {:?}\" ]", + node_index.index(), + node.id, + node.ty + )?; + } + + for edge in self.flow.raw_edges() { + let source = edge.source(); + let target = edge.target(); + + let style = match edge.weight { + ControlFlowEdgeType::Forward => "", + ControlFlowEdgeType::ForwardMerge => "style=dotted", + ControlFlowEdgeType::ForwardContinue => "color=green", + ControlFlowEdgeType::Back => "style=dashed", + ControlFlowEdgeType::LoopBreak => "color=yellow", + ControlFlowEdgeType::LoopContinue => "color=green", + ControlFlowEdgeType::IfTrue => "color=blue", + ControlFlowEdgeType::IfFalse => "color=red", + ControlFlowEdgeType::SwitchBreak => "color=yellow", + ControlFlowEdgeType::CaseFallThrough => "style=dotted", + }; + + writeln!( + &mut output, + "{} -> {} [ {} ]", + source.index(), + target.index(), + style + )?; + } + + output += "}\n"; + + Ok(output) + } +} + +/// Type of an edge(flow) in the `ControlFlowGraph`. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub(super) enum ControlFlowEdgeType { + /// Default + Forward, + + /// Forward edge to a merge block. + ForwardMerge, + + /// Forward edge to a OpLoopMerge continue's instruction. + ForwardContinue, + + /// A back-edge: An edge from a node to one of its ancestors in a depth-first + /// search from the entry block. + /// Can only be to a ControlFlowNodeType::Loop. + Back, + + /// An edge from a node to the merge block of the nearest enclosing loop, where + /// there is no intervening switch. + /// The source block is a "break block" as defined by SPIR-V. + LoopBreak, + + /// An edge from a node in a loop body to the associated continue target, where + /// there are no other intervening loops or switches. + /// The source block is a "continue block" as defined by SPIR-V. + LoopContinue, + + /// An edge from a node with OpBranchConditional to the block of true operand. + IfTrue, + + /// An edge from a node with OpBranchConditional to the block of false operand. + IfFalse, + + /// An edge from a node to the merge block of the nearest enclosing switch, + /// where there is no intervening loop. + SwitchBreak, + + /// An edge from one switch case to the next sibling switch case. + CaseFallThrough, +} +/// Type of a node(block) in the `ControlFlowGraph`. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub(super) enum ControlFlowNodeType { + /// A block whose merge instruction is an OpSelectionMerge. + Header, + + /// A header block whose merge instruction is an OpLoopMerge. + Loop, + + /// A block declared by the Merge Block operand of a merge instruction. + Merge, + + /// A block containing a branch to the Merge Block of a loop header’s merge instruction. + Break, + + /// A block containing a branch to an OpLoopMerge instruction’s Continue Target. + Continue, + + /// A block containing an OpBranch to a Loop block. + Back, + + /// A block containing an OpKill instruction. + Kill, + + /// A block containing an OpReturn or OpReturnValue branch. + Return, +} +/// ControlFlowGraph's node representing a block in the control flow. +pub(super) struct ControlFlowNode { + /// SPIR-V ID. + pub id: BlockId, + + /// Type of the node. See *ControlFlowNodeType*. + pub ty: Option<ControlFlowNodeType>, + + /// Phi instructions. + pub phis: Vec<PhiInstruction>, + + /// Naga's statements inside this block. + pub block: crate::Block, + + /// Termination instruction of the block. + pub terminator: Terminator, + + /// Merge Instruction + pub merge: Option<MergeInstruction>, +} diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs new file mode 100644 index 0000000000..d2cb0551a1 --- /dev/null +++ b/third_party/rust/naga/src/front/spv/function.rs @@ -0,0 +1,202 @@ +use crate::arena::Handle; + +use super::flow::*; +use super::*; + +pub type BlockId = u32; + +#[derive(Copy, Clone, Debug)] +pub struct MergeInstruction { + pub merge_block_id: BlockId, + pub continue_block_id: Option<BlockId>, +} +/// Terminator instruction of a SPIR-V's block. +#[derive(Clone, Debug)] +#[allow(dead_code)] +pub enum Terminator { + /// + Return { + value: Option<Handle<crate::Expression>>, + }, + /// + Branch { target_id: BlockId }, + /// + BranchConditional { + condition: Handle<crate::Expression>, + true_id: BlockId, + false_id: BlockId, + }, + /// + /// switch(SELECTOR) { + /// case TARGET_LITERAL#: { + /// TARGET_BLOCK# + /// } + /// default: { + /// DEFAULT + /// } + /// } + Switch { + /// + selector: Handle<crate::Expression>, + /// Default block of the switch case. + default: BlockId, + /// Tuples of (literal, target block) + targets: Vec<(i32, BlockId)>, + }, + /// Fragment shader discard + Kill, + /// + Unreachable, +} + +impl<I: Iterator<Item = u32>> super::Parser<I> { + pub fn parse_function( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Function, inst.op)?; + inst.expect(5)?; + let result_type = self.next()?; + let fun_id = self.next()?; + let _fun_control = self.next()?; + let fun_type = self.next()?; + + let mut fun = { + let ft = self.lookup_function_type.lookup(fun_type)?; + if ft.return_type_id != result_type { + return Err(Error::WrongFunctionResultType(result_type)); + } + crate::Function { + name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name), + arguments: Vec::with_capacity(ft.parameter_type_ids.len()), + return_type: if self.lookup_void_type.contains(&result_type) { + None + } else { + Some(self.lookup_type.lookup(result_type)?.handle) + }, + global_usage: Vec::new(), + local_variables: Arena::new(), + expressions: self.make_expression_storage(), + body: Vec::new(), + } + }; + + // read parameters + for i in 0..fun.arguments.capacity() { + match self.next_inst()? { + Instruction { + op: spirv::Op::FunctionParameter, + wc: 3, + } => { + let type_id = self.next()?; + let id = self.next()?; + let handle = fun + .expressions + .append(crate::Expression::FunctionArgument(i as u32)); + self.lookup_expression + .insert(id, LookupExpression { type_id, handle }); + //Note: we redo the lookup in order to work around `self` borrowing + + if type_id + != self + .lookup_function_type + .lookup(fun_type)? + .parameter_type_ids[i] + { + return Err(Error::WrongFunctionArgumentType(type_id)); + } + let ty = self.lookup_type.lookup(type_id)?.handle; + fun.arguments + .push(crate::FunctionArgument { name: None, ty }); + } + Instruction { op, .. } => return Err(Error::InvalidParameter(op)), + } + } + + // Read body + let mut local_function_calls = FastHashMap::default(); + let mut flow_graph = FlowGraph::new(); + + // Scan the blocks and add them as nodes + loop { + let fun_inst = self.next_inst()?; + log::debug!("{:?}", fun_inst.op); + match fun_inst.op { + spirv::Op::Label => { + // Read the label ID + fun_inst.expect(2)?; + let block_id = self.next()?; + + let node = self.next_block( + block_id, + &mut fun.expressions, + &mut fun.local_variables, + &module.types, + &module.constants, + &module.global_variables, + &mut local_function_calls, + )?; + + flow_graph.add_node(node); + } + spirv::Op::FunctionEnd => { + fun_inst.expect(1)?; + break; + } + _ => { + return Err(Error::UnsupportedInstruction(self.state, fun_inst.op)); + } + } + } + + flow_graph.classify(); + flow_graph.remove_phi_instructions(&self.lookup_expression); + fun.body = flow_graph.to_naga()?; + + // done + fun.fill_global_use(&module.global_variables); + + let source = match self.lookup_entry_point.remove(&fun_id) { + Some(ep) => { + module.entry_points.insert( + (ep.stage, ep.name.clone()), + crate::EntryPoint { + early_depth_test: ep.early_depth_test, + workgroup_size: ep.workgroup_size, + function: fun, + }, + ); + DeferredSource::EntryPoint(ep.stage, ep.name) + } + None => { + let handle = module.functions.append(fun); + self.lookup_function.insert(fun_id, handle); + DeferredSource::Function(handle) + } + }; + + if let Some(ref prefix) = self.options.flow_graph_dump_prefix { + let dump = flow_graph.to_graphviz().unwrap_or_default(); + let suffix = match source { + DeferredSource::EntryPoint(stage, ref name) => { + format!("flow.{:?}-{}.dot", stage, name) + } + DeferredSource::Function(handle) => format!("flow.Fun-{}.dot", handle.index()), + }; + let _ = std::fs::write(prefix.join(suffix), dump); + } + + for (expr_handle, dst_id) in local_function_calls { + self.deferred_function_calls.push(DeferredFunctionCall { + source: source.clone(), + expr_handle, + dst_id, + }); + } + + self.lookup_expression.clear(); + self.lookup_sampled_image.clear(); + Ok(()) + } +} diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs new file mode 100644 index 0000000000..f13c3c7577 --- /dev/null +++ b/third_party/rust/naga/src/front/spv/mod.rs @@ -0,0 +1,2416 @@ +/*! SPIR-V frontend + +## ID lookups + +Our IR links to everything with `Handle`, while SPIR-V uses IDs. +In order to keep track of the associations, the parser has many lookup tables. +There map `spv::Word` into a specific IR handle, plus potentially a bit of +extra info, such as the related SPIR-V type ID. +TODO: would be nice to find ways that avoid looking up as much + +!*/ +#![allow(dead_code)] + +mod convert; +mod error; +mod flow; +mod function; +#[cfg(all(test, feature = "serialize"))] +mod rosetta; + +use convert::*; +use error::Error; +use flow::*; +use function::*; + +use crate::{ + arena::{Arena, Handle}, + FastHashMap, FastHashSet, +}; + +use num_traits::cast::FromPrimitive; +use std::{convert::TryInto, num::NonZeroU32, path::PathBuf}; + +pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[ + spirv::Capability::Shader, + spirv::Capability::CullDistance, + spirv::Capability::StorageImageExtendedFormats, +]; +pub const SUPPORTED_EXTENSIONS: &[&str] = &[]; +pub const SUPPORTED_EXT_SETS: &[&str] = &["GLSL.std.450"]; + +#[derive(Copy, Clone)] +pub struct Instruction { + op: spirv::Op, + wc: u16, +} + +impl Instruction { + fn expect(self, count: u16) -> Result<(), Error> { + if self.wc == count { + Ok(()) + } else { + Err(Error::InvalidOperandCount(self.op, self.wc)) + } + } + + fn expect_at_least(self, count: u16) -> Result<(), Error> { + if self.wc >= count { + Ok(()) + } else { + Err(Error::InvalidOperandCount(self.op, self.wc)) + } + } +} +/// OpPhi instruction. +#[derive(Clone, Default, Debug)] +struct PhiInstruction { + /// SPIR-V's ID. + id: u32, + + /// Tuples of (variable, parent). + variables: Vec<(u32, u32)>, +} +#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] +pub enum ModuleState { + Empty, + Capability, + Extension, + ExtInstImport, + MemoryModel, + EntryPoint, + ExecutionMode, + Source, + Name, + ModuleProcessed, + Annotation, + Type, + Function, +} + +trait LookupHelper { + type Target; + fn lookup(&self, key: spirv::Word) -> Result<&Self::Target, Error>; +} + +impl<T> LookupHelper for FastHashMap<spirv::Word, T> { + type Target = T; + fn lookup(&self, key: spirv::Word) -> Result<&T, Error> { + self.get(&key).ok_or(Error::InvalidId(key)) + } +} + +//TODO: this method may need to be gone, depending on whether +// WGSL allows treating images and samplers as expressions and pass them around. +fn reach_global_type( + mut expr_handle: Handle<crate::Expression>, + expressions: &Arena<crate::Expression>, + globals: &Arena<crate::GlobalVariable>, +) -> Option<Handle<crate::Type>> { + loop { + expr_handle = match expressions[expr_handle] { + crate::Expression::Load { pointer } => pointer, + crate::Expression::GlobalVariable(var) => return Some(globals[var].ty), + _ => return None, + }; + } +} + +fn check_sample_coordinates( + ty: &crate::Type, + expect_kind: crate::ScalarKind, + dim: crate::ImageDimension, + is_array: bool, +) -> bool { + let base_count = match dim { + crate::ImageDimension::D1 => 1, + crate::ImageDimension::D2 => 2, + crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3, + }; + let extra_count = if is_array { 1 } else { 0 }; + let count = base_count + extra_count; + match ty.inner { + crate::TypeInner::Scalar { kind, width: _ } => count == 1 && kind == expect_kind, + crate::TypeInner::Vector { + size, + kind, + width: _, + } => size as u8 == count && kind == expect_kind, + _ => false, + } +} + +type MemberIndex = u32; + +#[derive(Debug, Default)] +struct Block { + buffer: bool, +} + +bitflags::bitflags! { + #[derive(Default)] + struct DecorationFlags: u32 { + const NON_READABLE = 0x1; + const NON_WRITABLE = 0x2; + } +} + +#[derive(Debug, Default)] +struct Decoration { + name: Option<String>, + built_in: Option<crate::BuiltIn>, + location: Option<spirv::Word>, + desc_set: Option<spirv::Word>, + desc_index: Option<spirv::Word>, + block: Option<Block>, + offset: Option<spirv::Word>, + array_stride: Option<NonZeroU32>, + interpolation: Option<crate::Interpolation>, + flags: DecorationFlags, +} + +impl Decoration { + fn debug_name(&self) -> &str { + match self.name { + Some(ref name) => name.as_str(), + None => "?", + } + } + + fn get_binding(&self) -> Option<crate::Binding> { + //TODO: validate this better + match *self { + Decoration { + built_in: Some(built_in), + location: None, + desc_set: None, + desc_index: None, + .. + } => Some(crate::Binding::BuiltIn(built_in)), + Decoration { + built_in: None, + location: Some(loc), + desc_set: None, + desc_index: None, + .. + } => Some(crate::Binding::Location(loc)), + Decoration { + built_in: None, + location: None, + desc_set: Some(group), + desc_index: Some(binding), + .. + } => Some(crate::Binding::Resource { group, binding }), + _ => None, + } + } + + fn get_origin(&self) -> Result<crate::MemberOrigin, Error> { + match *self { + Decoration { + location: Some(_), .. + } + | Decoration { + desc_set: Some(_), .. + } + | Decoration { + desc_index: Some(_), + .. + } => Err(Error::MissingDecoration(spirv::Decoration::Offset)), + Decoration { + built_in: Some(built_in), + offset: None, + .. + } => Ok(crate::MemberOrigin::BuiltIn(built_in)), + Decoration { + built_in: None, + offset: Some(offset), + .. + } => Ok(crate::MemberOrigin::Offset(offset)), + _ => Ok(crate::MemberOrigin::Empty), + } + } +} + +bitflags::bitflags! { + /// Flags describing sampling method. + pub struct SamplingFlags: u32 { + /// Regular sampling. + const REGULAR = 0x1; + /// Comparison sampling. + const COMPARISON = 0x2; + } +} + +#[derive(Debug)] +struct LookupFunctionType { + parameter_type_ids: Vec<spirv::Word>, + return_type_id: spirv::Word, +} + +#[derive(Debug)] +struct EntryPoint { + stage: crate::ShaderStage, + name: String, + early_depth_test: Option<crate::EarlyDepthTest>, + workgroup_size: [u32; 3], + function_id: spirv::Word, + variable_ids: Vec<spirv::Word>, +} + +#[derive(Clone, Debug)] +struct LookupType { + handle: Handle<crate::Type>, + base_id: Option<spirv::Word>, +} + +#[derive(Debug)] +struct LookupConstant { + handle: Handle<crate::Constant>, + type_id: spirv::Word, +} + +#[derive(Debug)] +struct LookupVariable { + handle: Handle<crate::GlobalVariable>, + type_id: spirv::Word, +} + +#[derive(Clone, Debug)] +struct LookupExpression { + handle: Handle<crate::Expression>, + type_id: spirv::Word, +} + +#[derive(Clone, Debug)] +struct LookupSampledImage { + image: Handle<crate::Expression>, + sampler: Handle<crate::Expression>, +} +#[derive(Clone, Debug)] +enum DeferredSource { + EntryPoint(crate::ShaderStage, String), + Function(Handle<crate::Function>), +} +struct DeferredFunctionCall { + source: DeferredSource, + expr_handle: Handle<crate::Expression>, + dst_id: spirv::Word, +} + +#[derive(Clone, Debug)] +pub struct Assignment { + to: Handle<crate::Expression>, + value: Handle<crate::Expression>, +} + +#[derive(Clone, Debug, Default)] +pub struct Options { + pub flow_graph_dump_prefix: Option<PathBuf>, +} + +pub struct Parser<I> { + data: I, + state: ModuleState, + temp_bytes: Vec<u8>, + ext_glsl_id: Option<spirv::Word>, + future_decor: FastHashMap<spirv::Word, Decoration>, + future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>, + lookup_member_type_id: FastHashMap<(Handle<crate::Type>, MemberIndex), spirv::Word>, + handle_sampling: FastHashMap<Handle<crate::Type>, SamplingFlags>, + lookup_type: FastHashMap<spirv::Word, LookupType>, + lookup_void_type: FastHashSet<spirv::Word>, + lookup_storage_buffer_types: FastHashSet<Handle<crate::Type>>, + // Lookup for samplers and sampled images, storing flags on how they are used. + lookup_constant: FastHashMap<spirv::Word, LookupConstant>, + lookup_variable: FastHashMap<spirv::Word, LookupVariable>, + lookup_expression: FastHashMap<spirv::Word, LookupExpression>, + lookup_sampled_image: FastHashMap<spirv::Word, LookupSampledImage>, + lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>, + lookup_function: FastHashMap<spirv::Word, Handle<crate::Function>>, + lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>, + deferred_function_calls: Vec<DeferredFunctionCall>, + options: Options, +} + +impl<I: Iterator<Item = u32>> Parser<I> { + pub fn new(data: I, options: &Options) -> Self { + Parser { + data, + state: ModuleState::Empty, + temp_bytes: Vec::new(), + ext_glsl_id: None, + future_decor: FastHashMap::default(), + future_member_decor: FastHashMap::default(), + handle_sampling: FastHashMap::default(), + lookup_member_type_id: FastHashMap::default(), + lookup_type: FastHashMap::default(), + lookup_void_type: FastHashSet::default(), + lookup_storage_buffer_types: FastHashSet::default(), + lookup_constant: FastHashMap::default(), + lookup_variable: FastHashMap::default(), + lookup_expression: FastHashMap::default(), + lookup_sampled_image: FastHashMap::default(), + lookup_function_type: FastHashMap::default(), + lookup_function: FastHashMap::default(), + lookup_entry_point: FastHashMap::default(), + deferred_function_calls: Vec::new(), + options: options.clone(), + } + } + + fn next(&mut self) -> Result<u32, Error> { + self.data.next().ok_or(Error::IncompleteData) + } + + fn next_inst(&mut self) -> Result<Instruction, Error> { + let word = self.next()?; + let (wc, opcode) = ((word >> 16) as u16, (word & 0xffff) as u16); + if wc == 0 { + return Err(Error::InvalidWordCount); + } + let op = spirv::Op::from_u16(opcode).ok_or(Error::UnknownInstruction(opcode))?; + + Ok(Instruction { op, wc }) + } + + fn next_string(&mut self, mut count: u16) -> Result<(String, u16), Error> { + self.temp_bytes.clear(); + loop { + if count == 0 { + return Err(Error::BadString); + } + count -= 1; + let chars = self.next()?.to_le_bytes(); + let pos = chars.iter().position(|&c| c == 0).unwrap_or(4); + self.temp_bytes.extend_from_slice(&chars[..pos]); + if pos < 4 { + break; + } + } + std::str::from_utf8(&self.temp_bytes) + .map(|s| (s.to_owned(), count)) + .map_err(|_| Error::BadString) + } + + fn next_decoration( + &mut self, + inst: Instruction, + base_words: u16, + dec: &mut Decoration, + ) -> Result<(), Error> { + let raw = self.next()?; + let dec_typed = spirv::Decoration::from_u32(raw).ok_or(Error::InvalidDecoration(raw))?; + log::trace!("\t\t{}: {:?}", dec.debug_name(), dec_typed); + match dec_typed { + spirv::Decoration::BuiltIn => { + inst.expect(base_words + 2)?; + let raw = self.next()?; + match map_builtin(raw) { + Ok(built_in) => dec.built_in = Some(built_in), + Err(_e) => log::warn!("Unsupported builtin {}", raw), + }; + } + spirv::Decoration::Location => { + inst.expect(base_words + 2)?; + dec.location = Some(self.next()?); + } + spirv::Decoration::DescriptorSet => { + inst.expect(base_words + 2)?; + dec.desc_set = Some(self.next()?); + } + spirv::Decoration::Binding => { + inst.expect(base_words + 2)?; + dec.desc_index = Some(self.next()?); + } + spirv::Decoration::Block => { + dec.block = Some(Block { buffer: false }); + } + spirv::Decoration::BufferBlock => { + dec.block = Some(Block { buffer: true }); + } + spirv::Decoration::Offset => { + inst.expect(base_words + 2)?; + dec.offset = Some(self.next()?); + } + spirv::Decoration::ArrayStride => { + inst.expect(base_words + 2)?; + dec.array_stride = NonZeroU32::new(self.next()?); + } + spirv::Decoration::NoPerspective => { + dec.interpolation = Some(crate::Interpolation::Linear); + } + spirv::Decoration::Flat => { + dec.interpolation = Some(crate::Interpolation::Flat); + } + spirv::Decoration::Patch => { + dec.interpolation = Some(crate::Interpolation::Patch); + } + spirv::Decoration::Centroid => { + dec.interpolation = Some(crate::Interpolation::Centroid); + } + spirv::Decoration::Sample => { + dec.interpolation = Some(crate::Interpolation::Sample); + } + spirv::Decoration::NonReadable => { + dec.flags |= DecorationFlags::NON_READABLE; + } + spirv::Decoration::NonWritable => { + dec.flags |= DecorationFlags::NON_WRITABLE; + } + other => { + log::warn!("Unknown decoration {:?}", other); + for _ in base_words + 1..inst.wc { + let _var = self.next()?; + } + } + } + Ok(()) + } + + fn parse_expr_unary_op( + &mut self, + expressions: &mut Arena<crate::Expression>, + op: crate::UnaryOperator, + ) -> Result<(), Error> { + let result_type_id = self.next()?; + let result_id = self.next()?; + let p_id = self.next()?; + + let p_lexp = self.lookup_expression.lookup(p_id)?; + + let expr = crate::Expression::Unary { + op, + expr: p_lexp.handle, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions.append(expr), + type_id: result_type_id, + }, + ); + Ok(()) + } + + fn parse_expr_binary_op( + &mut self, + expressions: &mut Arena<crate::Expression>, + op: crate::BinaryOperator, + ) -> Result<(), Error> { + let result_type_id = self.next()?; + let result_id = self.next()?; + let p1_id = self.next()?; + let p2_id = self.next()?; + + let p1_lexp = self.lookup_expression.lookup(p1_id)?; + let p2_lexp = self.lookup_expression.lookup(p2_id)?; + + let expr = crate::Expression::Binary { + op, + left: p1_lexp.handle, + right: p2_lexp.handle, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions.append(expr), + type_id: result_type_id, + }, + ); + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + fn next_block( + &mut self, + block_id: spirv::Word, + expressions: &mut Arena<crate::Expression>, + local_arena: &mut Arena<crate::LocalVariable>, + type_arena: &Arena<crate::Type>, + const_arena: &Arena<crate::Constant>, + global_arena: &Arena<crate::GlobalVariable>, + local_function_calls: &mut FastHashMap<Handle<crate::Expression>, spirv::Word>, + ) -> Result<ControlFlowNode, Error> { + let mut assignments = Vec::new(); + let mut phis = Vec::new(); + let mut merge = None; + let terminator = loop { + use spirv::Op; + let inst = self.next_inst()?; + log::debug!("\t\t{:?} [{}]", inst.op, inst.wc); + + match inst.op { + Op::Variable => { + inst.expect_at_least(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let storage = self.next()?; + match spirv::StorageClass::from_u32(storage) { + Some(spirv::StorageClass::Function) => (), + Some(class) => return Err(Error::InvalidVariableClass(class)), + None => return Err(Error::UnsupportedStorageClass(storage)), + } + let init = if inst.wc > 4 { + inst.expect(5)?; + let init_id = self.next()?; + let lconst = self.lookup_constant.lookup(init_id)?; + Some(lconst.handle) + } else { + None + }; + let name = self + .future_decor + .remove(&result_id) + .and_then(|decor| decor.name); + if let Some(ref name) = name { + log::debug!("\t\t\tid={} name={}", result_id, name); + } + let var_handle = local_arena.append(crate::LocalVariable { + name, + ty: self.lookup_type.lookup(result_type_id)?.handle, + init, + }); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions + .append(crate::Expression::LocalVariable(var_handle)), + type_id: result_type_id, + }, + ); + } + Op::Phi => { + inst.expect_at_least(3)?; + + let result_type_id = self.next()?; + let result_id = self.next()?; + + let name = format!("phi_{}", result_id); + let var_handle = local_arena.append(crate::LocalVariable { + name: Some(name), + ty: self.lookup_type.lookup(result_type_id)?.handle, + init: None, + }); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions + .append(crate::Expression::LocalVariable(var_handle)), + type_id: result_type_id, + }, + ); + + let mut phi = PhiInstruction::default(); + phi.id = result_id; + for _ in 0..(inst.wc - 3) / 2 { + phi.variables.push((self.next()?, self.next()?)); + } + + phis.push(phi); + } + Op::AccessChain => { + struct AccessExpression { + base_handle: Handle<crate::Expression>, + type_id: spirv::Word, + } + inst.expect_at_least(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + log::trace!("\t\t\tlooking up expr {:?}", base_id); + let mut acex = { + let expr = self.lookup_expression.lookup(base_id)?; + AccessExpression { + base_handle: expr.handle, + type_id: expr.type_id, + } + }; + for _ in 4..inst.wc { + let access_id = self.next()?; + log::trace!("\t\t\tlooking up index expr {:?}", access_id); + let index_expr = self.lookup_expression.lookup(access_id)?.clone(); + let index_type_handle = self.lookup_type.lookup(index_expr.type_id)?.handle; + match type_arena[index_type_handle].inner { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } + | crate::TypeInner::Scalar { + kind: crate::ScalarKind::Sint, + .. + } => (), + _ => return Err(Error::UnsupportedType(index_type_handle)), + } + log::trace!("\t\t\tlooking up type {:?}", acex.type_id); + let type_lookup = self.lookup_type.lookup(acex.type_id)?; + acex = match type_arena[type_lookup.handle].inner { + crate::TypeInner::Struct { .. } => { + let index = match expressions[index_expr.handle] { + crate::Expression::Constant(const_handle) => { + match const_arena[const_handle].inner { + crate::ConstantInner::Uint(v) => v as u32, + crate::ConstantInner::Sint(v) => v as u32, + _ => { + return Err(Error::InvalidAccess(index_expr.handle)) + } + } + } + _ => return Err(Error::InvalidAccess(index_expr.handle)), + }; + AccessExpression { + base_handle: expressions.append( + crate::Expression::AccessIndex { + base: acex.base_handle, + index, + }, + ), + type_id: *self + .lookup_member_type_id + .get(&(type_lookup.handle, index)) + .ok_or(Error::InvalidAccessType(acex.type_id))?, + } + } + crate::TypeInner::Array { .. } + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } => AccessExpression { + base_handle: expressions.append(crate::Expression::Access { + base: acex.base_handle, + index: index_expr.handle, + }), + type_id: type_lookup + .base_id + .ok_or(Error::InvalidAccessType(acex.type_id))?, + }, + _ => return Err(Error::UnsupportedType(type_lookup.handle)), + }; + } + + let lookup_expression = LookupExpression { + handle: acex.base_handle, + type_id: result_type_id, + }; + self.lookup_expression.insert(result_id, lookup_expression); + } + Op::CompositeExtract => { + inst.expect_at_least(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let base_id = self.next()?; + log::trace!("\t\t\tlooking up expr {:?}", base_id); + let mut lexp = { + let expr = self.lookup_expression.lookup(base_id)?; + LookupExpression { + handle: expr.handle, + type_id: expr.type_id, + } + }; + for _ in 4..inst.wc { + let index = self.next()?; + log::trace!("\t\t\tlooking up type {:?}", lexp.type_id); + let type_lookup = self.lookup_type.lookup(lexp.type_id)?; + let type_id = match type_arena[type_lookup.handle].inner { + crate::TypeInner::Struct { .. } => *self + .lookup_member_type_id + .get(&(type_lookup.handle, index)) + .ok_or(Error::InvalidAccessType(lexp.type_id))?, + crate::TypeInner::Array { .. } + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } => type_lookup + .base_id + .ok_or(Error::InvalidAccessType(lexp.type_id))?, + _ => return Err(Error::UnsupportedType(type_lookup.handle)), + }; + lexp = LookupExpression { + handle: expressions.append(crate::Expression::AccessIndex { + base: lexp.handle, + index, + }), + type_id, + }; + } + + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: lexp.handle, + type_id: result_type_id, + }, + ); + } + Op::CompositeConstruct => { + inst.expect_at_least(3)?; + let result_type_id = self.next()?; + let id = self.next()?; + let mut components = Vec::with_capacity(inst.wc as usize - 2); + for _ in 3..inst.wc { + let comp_id = self.next()?; + log::trace!("\t\t\tlooking up expr {:?}", comp_id); + let lexp = self.lookup_expression.lookup(comp_id)?; + components.push(lexp.handle); + } + let expr = crate::Expression::Compose { + ty: self.lookup_type.lookup(result_type_id)?.handle, + components, + }; + self.lookup_expression.insert( + id, + LookupExpression { + handle: expressions.append(expr), + type_id: result_type_id, + }, + ); + } + Op::Load => { + inst.expect_at_least(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let pointer_id = self.next()?; + if inst.wc != 4 { + inst.expect(5)?; + let _memory_access = self.next()?; + } + let base_expr = self.lookup_expression.lookup(pointer_id)?.clone(); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: base_expr.handle, // pass-through pointers + type_id: result_type_id, + }, + ); + } + Op::Store => { + inst.expect_at_least(3)?; + let pointer_id = self.next()?; + let value_id = self.next()?; + if inst.wc != 3 { + inst.expect(4)?; + let _memory_access = self.next()?; + } + let base_expr = self.lookup_expression.lookup(pointer_id)?; + let value_expr = self.lookup_expression.lookup(value_id)?; + assignments.push(Assignment { + to: base_expr.handle, + value: value_expr.handle, + }); + } + // Arithmetic Instructions +, -, *, /, % + Op::SNegate | Op::FNegate => { + inst.expect(4)?; + self.parse_expr_unary_op(expressions, crate::UnaryOperator::Negate)?; + } + Op::IAdd | Op::FAdd => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, crate::BinaryOperator::Add)?; + } + Op::ISub | Op::FSub => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, crate::BinaryOperator::Subtract)?; + } + Op::IMul | Op::FMul => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, crate::BinaryOperator::Multiply)?; + } + Op::SDiv | Op::UDiv | Op::FDiv => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, crate::BinaryOperator::Divide)?; + } + Op::UMod | Op::FMod | Op::SRem | Op::FRem => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, crate::BinaryOperator::Modulo)?; + } + Op::VectorTimesScalar + | Op::VectorTimesMatrix + | Op::MatrixTimesScalar + | Op::MatrixTimesVector + | Op::MatrixTimesMatrix => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, crate::BinaryOperator::Multiply)?; + } + Op::Transpose => { + inst.expect(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let matrix_id = self.next()?; + let matrix_lexp = self.lookup_expression.lookup(matrix_id)?; + let expr = crate::Expression::Transpose(matrix_lexp.handle); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions.append(expr), + type_id: result_type_id, + }, + ); + } + Op::Dot => { + inst.expect(5)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let left_id = self.next()?; + let right_id = self.next()?; + let left_lexp = self.lookup_expression.lookup(left_id)?; + let right_lexp = self.lookup_expression.lookup(right_id)?; + let expr = crate::Expression::DotProduct(left_lexp.handle, right_lexp.handle); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions.append(expr), + type_id: result_type_id, + }, + ); + } + // Bitwise instructions + Op::Not => { + inst.expect(4)?; + self.parse_expr_unary_op(expressions, crate::UnaryOperator::Not)?; + } + Op::BitwiseOr => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, crate::BinaryOperator::InclusiveOr)?; + } + Op::BitwiseXor => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, crate::BinaryOperator::ExclusiveOr)?; + } + Op::BitwiseAnd => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, crate::BinaryOperator::And)?; + } + Op::ShiftRightLogical => { + inst.expect(5)?; + //TODO: convert input and result to usigned + self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftRight)?; + } + Op::ShiftRightArithmetic => { + inst.expect(5)?; + //TODO: convert input and result to signed + self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftRight)?; + } + Op::ShiftLeftLogical => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, crate::BinaryOperator::ShiftLeft)?; + } + // Sampling + Op::SampledImage => { + inst.expect(5)?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let image_id = self.next()?; + let sampler_id = self.next()?; + let image_lexp = self.lookup_expression.lookup(image_id)?; + let sampler_lexp = self.lookup_expression.lookup(sampler_id)?; + //TODO: compare the result type + self.lookup_sampled_image.insert( + result_id, + LookupSampledImage { + image: image_lexp.handle, + sampler: sampler_lexp.handle, + }, + ); + } + Op::ImageSampleImplicitLod | Op::ImageSampleExplicitLod => { + inst.expect_at_least(5)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let sampled_image_id = self.next()?; + let coordinate_id = self.next()?; + let si_lexp = self.lookup_sampled_image.lookup(sampled_image_id)?.clone(); + let coord_lexp = self.lookup_expression.lookup(coordinate_id)?.clone(); + let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle; + + let sampler_type_handle = + reach_global_type(si_lexp.sampler, &expressions, global_arena) + .ok_or(Error::InvalidSamplerExpression(si_lexp.sampler))?; + let image_type_handle = + reach_global_type(si_lexp.image, &expressions, global_arena) + .ok_or(Error::InvalidImageExpression(si_lexp.image))?; + log::debug!( + "\t\t\tImage {:?} with sampler {:?}", + image_type_handle, + sampler_type_handle + ); + *self.handle_sampling.get_mut(&sampler_type_handle).unwrap() |= + SamplingFlags::REGULAR; + *self.handle_sampling.get_mut(&image_type_handle).unwrap() |= + SamplingFlags::REGULAR; + match type_arena[sampler_type_handle].inner { + crate::TypeInner::Sampler { comparison: false } => (), + _ => return Err(Error::InvalidSampleSampler(sampler_type_handle)), + }; + match type_arena[image_type_handle].inner { + //TODO: compare the result type + crate::TypeInner::Image { + dim, + arrayed, + class: + crate::ImageClass::Sampled { + kind: _, + multi: false, + }, + } => { + if !check_sample_coordinates( + &type_arena[coord_type_handle], + crate::ScalarKind::Float, + dim, + arrayed, + ) { + return Err(Error::InvalidSampleCoordinates(coord_type_handle)); + } + } + _ => return Err(Error::InvalidSampleImage(image_type_handle)), + }; + + let mut level = crate::SampleLevel::Auto; + let mut base_wc = 5; + if base_wc < inst.wc { + let image_ops = self.next()?; + base_wc += 1; + let mask = spirv::ImageOperands::from_bits_truncate(image_ops); + if mask.contains(spirv::ImageOperands::BIAS) { + let bias_expr = self.next()?; + let bias_handle = self.lookup_expression.lookup(bias_expr)?.handle; + level = crate::SampleLevel::Bias(bias_handle); + base_wc += 1; + } + if mask.contains(spirv::ImageOperands::LOD) { + let lod_expr = self.next()?; + let lod_handle = self.lookup_expression.lookup(lod_expr)?.handle; + level = crate::SampleLevel::Exact(lod_handle); + base_wc += 1; + } + for _ in base_wc..inst.wc { + self.next()?; + } + } + + let expr = crate::Expression::ImageSample { + image: si_lexp.image, + sampler: si_lexp.sampler, + coordinate: coord_lexp.handle, + level, + depth_ref: None, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions.append(expr), + type_id: result_type_id, + }, + ); + } + Op::ImageSampleDrefImplicitLod => { + inst.expect_at_least(6)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let sampled_image_id = self.next()?; + let coordinate_id = self.next()?; + let dref_id = self.next()?; + + let si_lexp = self.lookup_sampled_image.lookup(sampled_image_id)?; + let coord_lexp = self.lookup_expression.lookup(coordinate_id)?; + let coord_type_handle = self.lookup_type.lookup(coord_lexp.type_id)?.handle; + let sampler_type_handle = + reach_global_type(si_lexp.sampler, &expressions, global_arena) + .ok_or(Error::InvalidSamplerExpression(si_lexp.sampler))?; + let image_type_handle = + reach_global_type(si_lexp.image, &expressions, global_arena) + .ok_or(Error::InvalidImageExpression(si_lexp.image))?; + *self.handle_sampling.get_mut(&sampler_type_handle).unwrap() |= + SamplingFlags::COMPARISON; + *self.handle_sampling.get_mut(&image_type_handle).unwrap() |= + SamplingFlags::COMPARISON; + match type_arena[sampler_type_handle].inner { + crate::TypeInner::Sampler { comparison: true } => (), + _ => return Err(Error::InvalidSampleSampler(sampler_type_handle)), + }; + match type_arena[image_type_handle].inner { + //TODO: compare the result type + crate::TypeInner::Image { + dim, + arrayed, + class: crate::ImageClass::Depth, + } => { + if !check_sample_coordinates( + &type_arena[coord_type_handle], + crate::ScalarKind::Float, + dim, + arrayed, + ) { + return Err(Error::InvalidSampleCoordinates(coord_type_handle)); + } + } + _ => return Err(Error::InvalidSampleImage(image_type_handle)), + }; + + let dref_lexp = self.lookup_expression.lookup(dref_id)?; + let dref_type_handle = self.lookup_type.lookup(dref_lexp.type_id)?.handle; + match type_arena[dref_type_handle].inner { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Float, + width: _, + } => (), + _ => return Err(Error::InvalidDepthReference(dref_type_handle)), + } + + let expr = crate::Expression::ImageSample { + image: si_lexp.image, + sampler: si_lexp.sampler, + coordinate: coord_lexp.handle, + level: crate::SampleLevel::Auto, + depth_ref: Some(dref_lexp.handle), + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions.append(expr), + type_id: result_type_id, + }, + ); + } + Op::Select => { + inst.expect(6)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let condition = self.next()?; + let o1_id = self.next()?; + let o2_id = self.next()?; + + let cond_lexp = self.lookup_expression.lookup(condition)?; + let o1_lexp = self.lookup_expression.lookup(o1_id)?; + let o2_lexp = self.lookup_expression.lookup(o2_id)?; + + let expr = crate::Expression::Select { + condition: cond_lexp.handle, + accept: o1_lexp.handle, + reject: o2_lexp.handle, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions.append(expr), + type_id: result_type_id, + }, + ); + } + Op::VectorShuffle => { + inst.expect_at_least(5)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let v1_id = self.next()?; + let v2_id = self.next()?; + + let v1_lexp = self.lookup_expression.lookup(v1_id)?; + let v1_lty = self.lookup_type.lookup(v1_lexp.type_id)?; + let n1 = match type_arena[v1_lty.handle].inner { + crate::TypeInner::Vector { size, .. } => size as u8, + _ => return Err(Error::InvalidInnerType(v1_lexp.type_id)), + }; + let v1_handle = v1_lexp.handle; + let v2_lexp = self.lookup_expression.lookup(v2_id)?; + let v2_lty = self.lookup_type.lookup(v2_lexp.type_id)?; + let n2 = match type_arena[v2_lty.handle].inner { + crate::TypeInner::Vector { size, .. } => size as u8, + _ => return Err(Error::InvalidInnerType(v2_lexp.type_id)), + }; + let v2_handle = v2_lexp.handle; + + let mut components = Vec::with_capacity(inst.wc as usize - 5); + for _ in 0..components.capacity() { + let index = self.next()?; + let expr = if index < n1 as u32 { + crate::Expression::AccessIndex { + base: v1_handle, + index, + } + } else if index < n1 as u32 + n2 as u32 { + crate::Expression::AccessIndex { + base: v2_handle, + index: index - n1 as u32, + } + } else { + return Err(Error::InvalidAccessIndex(index)); + }; + components.push(expressions.append(expr)); + } + let expr = crate::Expression::Compose { + ty: self.lookup_type.lookup(result_type_id)?.handle, + components, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions.append(expr), + type_id: result_type_id, + }, + ); + } + Op::Bitcast + | Op::ConvertSToF + | Op::ConvertUToF + | Op::ConvertFToU + | Op::ConvertFToS => { + inst.expect_at_least(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let value_id = self.next()?; + + let value_lexp = self.lookup_expression.lookup(value_id)?; + let ty_lookup = self.lookup_type.lookup(result_type_id)?; + let kind = type_arena[ty_lookup.handle] + .inner + .scalar_kind() + .ok_or(Error::InvalidAsType(ty_lookup.handle))?; + + let expr = crate::Expression::As { + expr: value_lexp.handle, + kind, + convert: inst.op != Op::Bitcast, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions.append(expr), + type_id: result_type_id, + }, + ); + } + Op::FunctionCall => { + inst.expect_at_least(4)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let func_id = self.next()?; + + let mut arguments = Vec::with_capacity(inst.wc as usize - 4); + for _ in 0..arguments.capacity() { + let arg_id = self.next()?; + arguments.push(self.lookup_expression.lookup(arg_id)?.handle); + } + let expr = crate::Expression::Call { + // will be replaced by `Local()` after all the functions are parsed + origin: crate::FunctionOrigin::External(String::new()), + arguments, + }; + let expr_handle = expressions.append(expr); + local_function_calls.insert(expr_handle, func_id); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expr_handle, + type_id: result_type_id, + }, + ); + } + Op::ExtInst => { + let base_wc = 5; + inst.expect_at_least(base_wc)?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let set_id = self.next()?; + if Some(set_id) != self.ext_glsl_id { + return Err(Error::UnsupportedExtInstSet(set_id)); + } + let inst_id = self.next()?; + let name = match spirv::GLOp::from_u32(inst_id) { + Some(spirv::GLOp::FAbs) | Some(spirv::GLOp::SAbs) => { + inst.expect(base_wc + 1)?; + "abs" + } + Some(spirv::GLOp::FSign) | Some(spirv::GLOp::SSign) => { + inst.expect(base_wc + 1)?; + "sign" + } + Some(spirv::GLOp::Floor) => { + inst.expect(base_wc + 1)?; + "floor" + } + Some(spirv::GLOp::Ceil) => { + inst.expect(base_wc + 1)?; + "ceil" + } + Some(spirv::GLOp::Fract) => { + inst.expect(base_wc + 1)?; + "fract" + } + Some(spirv::GLOp::Sin) => { + inst.expect(base_wc + 1)?; + "sin" + } + Some(spirv::GLOp::Cos) => { + inst.expect(base_wc + 1)?; + "cos" + } + Some(spirv::GLOp::Tan) => { + inst.expect(base_wc + 1)?; + "tan" + } + Some(spirv::GLOp::Atan2) => { + inst.expect(base_wc + 2)?; + "atan2" + } + Some(spirv::GLOp::Pow) => { + inst.expect(base_wc + 2)?; + "pow" + } + Some(spirv::GLOp::MatrixInverse) => { + inst.expect(base_wc + 1)?; + "inverse" + } + Some(spirv::GLOp::FMix) => { + inst.expect(base_wc + 3)?; + "mix" + } + Some(spirv::GLOp::Step) => { + inst.expect(base_wc + 2)?; + "step" + } + Some(spirv::GLOp::SmoothStep) => { + inst.expect(base_wc + 3)?; + "smoothstep" + } + Some(spirv::GLOp::FMin) => { + inst.expect(base_wc + 2)?; + "min" + } + Some(spirv::GLOp::FMax) => { + inst.expect(base_wc + 2)?; + "max" + } + Some(spirv::GLOp::FClamp) => { + inst.expect(base_wc + 3)?; + "clamp" + } + Some(spirv::GLOp::Length) => { + inst.expect(base_wc + 1)?; + "length" + } + Some(spirv::GLOp::Distance) => { + inst.expect(base_wc + 2)?; + "distance" + } + Some(spirv::GLOp::Cross) => { + inst.expect(base_wc + 2)?; + "cross" + } + Some(spirv::GLOp::Normalize) => { + inst.expect(base_wc + 1)?; + "normalize" + } + Some(spirv::GLOp::Reflect) => { + inst.expect(base_wc + 2)?; + "reflect" + } + _ => return Err(Error::UnsupportedExtInst(inst_id)), + }; + + let mut arguments = Vec::with_capacity((inst.wc - base_wc) as usize); + for _ in 0..arguments.capacity() { + let arg_id = self.next()?; + arguments.push(self.lookup_expression.lookup(arg_id)?.handle); + } + let expr = crate::Expression::Call { + origin: crate::FunctionOrigin::External(name.to_string()), + arguments, + }; + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expressions.append(expr), + type_id: result_type_id, + }, + ); + } + // Relational and Logical Instructions + Op::LogicalNot => { + inst.expect(4)?; + self.parse_expr_unary_op(expressions, crate::UnaryOperator::Not)?; + } + op if inst.op >= Op::IEqual && inst.op <= Op::FUnordGreaterThanEqual => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, map_binary_operator(op)?)?; + } + Op::Kill => { + inst.expect(1)?; + break Terminator::Kill; + } + Op::Unreachable => { + inst.expect(1)?; + break Terminator::Unreachable; + } + Op::Return => { + inst.expect(1)?; + break Terminator::Return { value: None }; + } + Op::ReturnValue => { + inst.expect(2)?; + let value_id = self.next()?; + let value_lexp = self.lookup_expression.lookup(value_id)?; + break Terminator::Return { + value: Some(value_lexp.handle), + }; + } + Op::Branch => { + inst.expect(2)?; + let target_id = self.next()?; + break Terminator::Branch { target_id }; + } + Op::BranchConditional => { + inst.expect_at_least(4)?; + + let condition_id = self.next()?; + let condition = self.lookup_expression.lookup(condition_id)?.handle; + + let true_id = self.next()?; + let false_id = self.next()?; + + break Terminator::BranchConditional { + condition, + true_id, + false_id, + }; + } + Op::Switch => { + inst.expect_at_least(3)?; + + let selector = self.next()?; + let selector = self.lookup_expression[&selector].handle; + let default = self.next()?; + + let mut targets = Vec::new(); + for _ in 0..(inst.wc - 3) / 2 { + let literal = self.next()?; + let target = self.next()?; + targets.push((literal as i32, target)); + } + + break Terminator::Switch { + selector, + default, + targets, + }; + } + Op::SelectionMerge => { + inst.expect(3)?; + let merge_block_id = self.next()?; + // TODO: Selection Control Mask + let _selection_control = self.next()?; + let continue_block_id = None; + merge = Some(MergeInstruction { + merge_block_id, + continue_block_id, + }); + } + Op::LoopMerge => { + inst.expect_at_least(4)?; + let merge_block_id = self.next()?; + let continue_block_id = Some(self.next()?); + + // TODO: Loop Control Parameters + for _ in 0..inst.wc - 3 { + self.next()?; + } + + merge = Some(MergeInstruction { + merge_block_id, + continue_block_id, + }); + } + _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), + } + }; + + let mut block = Vec::new(); + for assignment in assignments.iter() { + block.push(crate::Statement::Store { + pointer: assignment.to, + value: assignment.value, + }); + } + + Ok(ControlFlowNode { + id: block_id, + ty: None, + phis, + block, + terminator, + merge, + }) + } + + fn make_expression_storage(&mut self) -> Arena<crate::Expression> { + let mut expressions = Arena::new(); + #[allow(clippy::panic)] + { + assert!(self.lookup_expression.is_empty()); + } + // register global variables + for (&id, var) in self.lookup_variable.iter() { + let handle = expressions.append(crate::Expression::GlobalVariable(var.handle)); + self.lookup_expression.insert( + id, + LookupExpression { + type_id: var.type_id, + handle, + }, + ); + } + // register constants + for (&id, con) in self.lookup_constant.iter() { + let handle = expressions.append(crate::Expression::Constant(con.handle)); + self.lookup_expression.insert( + id, + LookupExpression { + type_id: con.type_id, + handle, + }, + ); + } + // done + expressions + } + + fn switch(&mut self, state: ModuleState, op: spirv::Op) -> Result<(), Error> { + if state < self.state { + Err(Error::UnsupportedInstruction(self.state, op)) + } else { + self.state = state; + Ok(()) + } + } + + pub fn parse(mut self) -> Result<crate::Module, Error> { + let mut module = { + if self.next()? != spirv::MAGIC_NUMBER { + return Err(Error::InvalidHeader); + } + let _version_raw = self.next()?.to_le_bytes(); + let _generator = self.next()?; + let _bound = self.next()?; + let _schema = self.next()?; + crate::Module::generate_empty() + }; + + while let Ok(inst) = self.next_inst() { + use spirv::Op; + log::debug!("\t{:?} [{}]", inst.op, inst.wc); + match inst.op { + Op::Capability => self.parse_capability(inst), + Op::Extension => self.parse_extension(inst), + Op::ExtInstImport => self.parse_ext_inst_import(inst), + Op::MemoryModel => self.parse_memory_model(inst), + Op::EntryPoint => self.parse_entry_point(inst), + Op::ExecutionMode => self.parse_execution_mode(inst), + Op::Source => self.parse_source(inst), + Op::SourceExtension => self.parse_source_extension(inst), + Op::Name => self.parse_name(inst), + Op::MemberName => self.parse_member_name(inst), + Op::Decorate => self.parse_decorate(inst), + Op::MemberDecorate => self.parse_member_decorate(inst), + Op::TypeVoid => self.parse_type_void(inst), + Op::TypeBool => self.parse_type_bool(inst, &mut module), + Op::TypeInt => self.parse_type_int(inst, &mut module), + Op::TypeFloat => self.parse_type_float(inst, &mut module), + Op::TypeVector => self.parse_type_vector(inst, &mut module), + Op::TypeMatrix => self.parse_type_matrix(inst, &mut module), + Op::TypeFunction => self.parse_type_function(inst), + Op::TypePointer => self.parse_type_pointer(inst, &mut module), + Op::TypeArray => self.parse_type_array(inst, &mut module), + Op::TypeRuntimeArray => self.parse_type_runtime_array(inst, &mut module), + Op::TypeStruct => self.parse_type_struct(inst, &mut module), + Op::TypeImage => self.parse_type_image(inst, &mut module), + Op::TypeSampledImage => self.parse_type_sampled_image(inst), + Op::TypeSampler => self.parse_type_sampler(inst, &mut module), + Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), + Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), + Op::Variable => self.parse_global_variable(inst, &mut module), + Op::Function => self.parse_function(inst, &mut module), + _ => Err(Error::UnsupportedInstruction(self.state, inst.op)), //TODO + }?; + } + + // Check all the images and samplers to have consistent comparison property. + for (handle, flags) in self.handle_sampling.drain() { + if !flags.contains(SamplingFlags::COMPARISON) { + continue; + } + if flags == SamplingFlags::all() { + return Err(Error::InconsistentComparisonSampling(handle)); + } + let ty = module.types.get_mut(handle); + match ty.inner { + crate::TypeInner::Sampler { ref mut comparison } => { + #[allow(clippy::panic)] + { + assert!(!*comparison) + }; + *comparison = true; + } + _ => { + return Err(Error::UnexpectedComparisonType(handle)); + } + } + } + + for dfc in self.deferred_function_calls.drain(..) { + let dst_handle = *self.lookup_function.lookup(dfc.dst_id)?; + let fun = match dfc.source { + DeferredSource::Function(fun_handle) => module.functions.get_mut(fun_handle), + DeferredSource::EntryPoint(stage, name) => { + &mut module + .entry_points + .get_mut(&(stage, name)) + .unwrap() + .function + } + }; + match *fun.expressions.get_mut(dfc.expr_handle) { + crate::Expression::Call { + ref mut origin, + arguments: _, + } => *origin = crate::FunctionOrigin::Local(dst_handle), + _ => unreachable!(), + } + } + + if !self.future_decor.is_empty() { + log::warn!("Unused item decorations: {:?}", self.future_decor); + self.future_decor.clear(); + } + if !self.future_member_decor.is_empty() { + log::warn!("Unused member decorations: {:?}", self.future_member_decor); + self.future_member_decor.clear(); + } + + Ok(module) + } + + fn parse_capability(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Capability, inst.op)?; + inst.expect(2)?; + let capability = self.next()?; + let cap = + spirv::Capability::from_u32(capability).ok_or(Error::UnknownCapability(capability))?; + if !SUPPORTED_CAPABILITIES.contains(&cap) { + return Err(Error::UnsupportedCapability(cap)); + } + Ok(()) + } + + fn parse_extension(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Extension, inst.op)?; + inst.expect_at_least(2)?; + let (name, left) = self.next_string(inst.wc - 1)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + if !SUPPORTED_EXTENSIONS.contains(&name.as_str()) { + return Err(Error::UnsupportedExtension(name)); + } + Ok(()) + } + + fn parse_ext_inst_import(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Extension, inst.op)?; + inst.expect_at_least(3)?; + let result_id = self.next()?; + let (name, left) = self.next_string(inst.wc - 2)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + if !SUPPORTED_EXT_SETS.contains(&name.as_str()) { + return Err(Error::UnsupportedExtSet(name)); + } + self.ext_glsl_id = Some(result_id); + Ok(()) + } + + fn parse_memory_model(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::MemoryModel, inst.op)?; + inst.expect(3)?; + let _addressing_model = self.next()?; + let _memory_model = self.next()?; + Ok(()) + } + + fn parse_entry_point(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::EntryPoint, inst.op)?; + inst.expect_at_least(4)?; + let exec_model = self.next()?; + let exec_model = spirv::ExecutionModel::from_u32(exec_model) + .ok_or(Error::UnsupportedExecutionModel(exec_model))?; + let function_id = self.next()?; + let (name, left) = self.next_string(inst.wc - 3)?; + let ep = EntryPoint { + stage: match exec_model { + spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex, + spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment, + spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute, + _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)), + }, + name, + early_depth_test: None, + workgroup_size: [0; 3], + function_id, + variable_ids: self.data.by_ref().take(left as usize).collect(), + }; + self.lookup_entry_point.insert(function_id, ep); + Ok(()) + } + + fn parse_execution_mode(&mut self, inst: Instruction) -> Result<(), Error> { + use spirv::ExecutionMode; + + self.switch(ModuleState::ExecutionMode, inst.op)?; + inst.expect_at_least(3)?; + + let ep_id = self.next()?; + let mode_id = self.next()?; + let args: Vec<spirv::Word> = self.data.by_ref().take(inst.wc as usize - 3).collect(); + + let ep = self + .lookup_entry_point + .get_mut(&ep_id) + .ok_or(Error::InvalidId(ep_id))?; + let mode = spirv::ExecutionMode::from_u32(mode_id) + .ok_or(Error::UnsupportedExecutionMode(mode_id))?; + + match mode { + ExecutionMode::EarlyFragmentTests => { + if ep.early_depth_test.is_none() { + ep.early_depth_test = Some(crate::EarlyDepthTest { conservative: None }); + } + } + ExecutionMode::DepthUnchanged => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::Unchanged), + }); + } + ExecutionMode::DepthGreater => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::GreaterEqual), + }); + } + ExecutionMode::DepthLess => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::LessEqual), + }); + } + ExecutionMode::DepthReplacing => { + // Ignored because it can be deduced from the IR. + } + ExecutionMode::OriginUpperLeft => { + // Ignored because the other option (OriginLowerLeft) is not valid in Vulkan mode. + } + ExecutionMode::LocalSize => { + ep.workgroup_size = [args[0], args[1], args[2]]; + } + _ => { + return Err(Error::UnsupportedExecutionMode(mode_id)); + } + } + + Ok(()) + } + + fn parse_source(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Source, inst.op)?; + for _ in 1..inst.wc { + let _ = self.next()?; + } + Ok(()) + } + + fn parse_source_extension(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Source, inst.op)?; + inst.expect_at_least(2)?; + let (_name, _) = self.next_string(inst.wc - 1)?; + Ok(()) + } + + fn parse_name(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Name, inst.op)?; + inst.expect_at_least(3)?; + let id = self.next()?; + let (name, left) = self.next_string(inst.wc - 2)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + self.future_decor.entry(id).or_default().name = Some(name); + Ok(()) + } + + fn parse_member_name(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Name, inst.op)?; + inst.expect_at_least(4)?; + let id = self.next()?; + let member = self.next()?; + let (name, left) = self.next_string(inst.wc - 3)?; + if left != 0 { + return Err(Error::InvalidOperand); + } + + self.future_member_decor + .entry((id, member)) + .or_default() + .name = Some(name); + Ok(()) + } + + fn parse_decorate(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Annotation, inst.op)?; + inst.expect_at_least(3)?; + let id = self.next()?; + let mut dec = self.future_decor.remove(&id).unwrap_or_default(); + self.next_decoration(inst, 2, &mut dec)?; + self.future_decor.insert(id, dec); + Ok(()) + } + + fn parse_member_decorate(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Annotation, inst.op)?; + inst.expect_at_least(4)?; + let id = self.next()?; + let member = self.next()?; + + let mut dec = self + .future_member_decor + .remove(&(id, member)) + .unwrap_or_default(); + self.next_decoration(inst, 3, &mut dec)?; + self.future_member_decor.insert((id, member), dec); + Ok(()) + } + + fn parse_type_void(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(2)?; + let id = self.next()?; + self.lookup_void_type.insert(id); + Ok(()) + } + + fn parse_type_bool( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(2)?; + let id = self.next()?; + let inner = crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: 1, + }; + self.lookup_type.insert( + id, + LookupType { + handle: module.types.append(crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }), + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_int( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let width = self.next()?; + let sign = self.next()?; + let inner = crate::TypeInner::Scalar { + kind: match sign { + 0 => crate::ScalarKind::Uint, + 1 => crate::ScalarKind::Sint, + _ => return Err(Error::InvalidSign(sign)), + }, + width: map_width(width)?, + }; + self.lookup_type.insert( + id, + LookupType { + handle: module.types.append(crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }), + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_float( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let id = self.next()?; + let width = self.next()?; + let inner = crate::TypeInner::Scalar { + kind: crate::ScalarKind::Float, + width: map_width(width)?, + }; + self.lookup_type.insert( + id, + LookupType { + handle: module.types.append(crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }), + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_vector( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let type_id = self.next()?; + let type_lookup = self.lookup_type.lookup(type_id)?; + let (kind, width) = match module.types[type_lookup.handle].inner { + crate::TypeInner::Scalar { kind, width } => (kind, width), + _ => return Err(Error::InvalidInnerType(type_id)), + }; + let component_count = self.next()?; + let inner = crate::TypeInner::Vector { + size: map_vector_size(component_count)?, + kind, + width, + }; + self.lookup_type.insert( + id, + LookupType { + handle: module.types.append(crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }), + base_id: Some(type_id), + }, + ); + Ok(()) + } + + fn parse_type_matrix( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let vector_type_id = self.next()?; + let num_columns = self.next()?; + let vector_type_lookup = self.lookup_type.lookup(vector_type_id)?; + let inner = match module.types[vector_type_lookup.handle].inner { + crate::TypeInner::Vector { size, width, .. } => crate::TypeInner::Matrix { + columns: map_vector_size(num_columns)?, + rows: size, + width, + }, + _ => return Err(Error::InvalidInnerType(vector_type_id)), + }; + self.lookup_type.insert( + id, + LookupType { + handle: module.types.append(crate::Type { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + inner, + }), + base_id: Some(vector_type_id), + }, + ); + Ok(()) + } + + fn parse_type_function(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(3)?; + let id = self.next()?; + let return_type_id = self.next()?; + let parameter_type_ids = self.data.by_ref().take(inst.wc as usize - 3).collect(); + self.lookup_function_type.insert( + id, + LookupFunctionType { + parameter_type_ids, + return_type_id, + }, + ); + Ok(()) + } + + fn parse_type_pointer( + &mut self, + inst: Instruction, + _module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let _storage = self.next()?; + let type_id = self.next()?; + let type_lookup = self.lookup_type.lookup(type_id)?.clone(); + self.lookup_type.insert(id, type_lookup); // don't register pointers in the IR + Ok(()) + } + + fn parse_type_array( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(4)?; + let id = self.next()?; + let type_id = self.next()?; + let length_id = self.next()?; + let length_const = self.lookup_constant.lookup(length_id)?; + + let decor = self.future_decor.remove(&id); + let inner = crate::TypeInner::Array { + base: self.lookup_type.lookup(type_id)?.handle, + size: crate::ArraySize::Constant(length_const.handle), + stride: decor.as_ref().and_then(|dec| dec.array_stride), + }; + self.lookup_type.insert( + id, + LookupType { + handle: module.types.append(crate::Type { + name: decor.and_then(|dec| dec.name), + inner, + }), + base_id: Some(type_id), + }, + ); + Ok(()) + } + + fn parse_type_runtime_array( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let id = self.next()?; + let type_id = self.next()?; + + let decor = self.future_decor.remove(&id); + let inner = crate::TypeInner::Array { + base: self.lookup_type.lookup(type_id)?.handle, + size: crate::ArraySize::Dynamic, + stride: decor.as_ref().and_then(|dec| dec.array_stride), + }; + self.lookup_type.insert( + id, + LookupType { + handle: module.types.append(crate::Type { + name: decor.and_then(|dec| dec.name), + inner, + }), + base_id: Some(type_id), + }, + ); + Ok(()) + } + + fn parse_type_struct( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(2)?; + let id = self.next()?; + let parent_decor = self.future_decor.remove(&id); + let is_buffer_block = parent_decor + .as_ref() + .map_or(false, |decor| match decor.block { + Some(Block { buffer }) => buffer, + _ => false, + }); + + let mut members = Vec::with_capacity(inst.wc as usize - 2); + let mut member_type_ids = Vec::with_capacity(members.capacity()); + for i in 0..u32::from(inst.wc) - 2 { + let type_id = self.next()?; + member_type_ids.push(type_id); + let ty = self.lookup_type.lookup(type_id)?.handle; + let decor = self + .future_member_decor + .remove(&(id, i)) + .unwrap_or_default(); + let origin = decor.get_origin()?; + members.push(crate::StructMember { + name: decor.name, + origin, + ty, + }); + } + let inner = crate::TypeInner::Struct { members }; + let ty_handle = module.types.append(crate::Type { + name: parent_decor.and_then(|dec| dec.name), + inner, + }); + + if is_buffer_block { + self.lookup_storage_buffer_types.insert(ty_handle); + } + for (i, type_id) in member_type_ids.into_iter().enumerate() { + self.lookup_member_type_id + .insert((ty_handle, i as u32), type_id); + } + self.lookup_type.insert( + id, + LookupType { + handle: ty_handle, + base_id: None, + }, + ); + Ok(()) + } + + fn parse_type_image( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(9)?; + + let id = self.next()?; + let sample_type_id = self.next()?; + let dim = self.next()?; + let _is_depth = self.next()?; + let is_array = self.next()? != 0; + let is_msaa = self.next()? != 0; + let _is_sampled = self.next()?; + let format = self.next()?; + + let base_handle = self.lookup_type.lookup(sample_type_id)?.handle; + let kind = module.types[base_handle] + .inner + .scalar_kind() + .ok_or(Error::InvalidImageBaseType(base_handle))?; + + let class = if format != 0 { + crate::ImageClass::Storage(map_image_format(format)?) + } else { + crate::ImageClass::Sampled { + kind, + multi: is_msaa, + } + }; + + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let inner = crate::TypeInner::Image { + class, + dim: map_image_dim(dim)?, + arrayed: is_array, + }; + let handle = module.types.append(crate::Type { + name: decor.name, + inner, + }); + log::debug!("\t\ttracking {:?} for sampling properties", handle); + self.handle_sampling.insert(handle, SamplingFlags::empty()); + self.lookup_type.insert( + id, + LookupType { + handle, + base_id: Some(sample_type_id), + }, + ); + Ok(()) + } + + fn parse_type_sampled_image(&mut self, inst: Instruction) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(3)?; + let id = self.next()?; + let image_id = self.next()?; + self.lookup_type.insert( + id, + LookupType { + handle: self.lookup_type.lookup(image_id)?.handle, + base_id: Some(image_id), + }, + ); + Ok(()) + } + + fn parse_type_sampler( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect(2)?; + let id = self.next()?; + let decor = self.future_decor.remove(&id).unwrap_or_default(); + // The comparison bit is temporary, will be overwritten based on the + // accumulated sampling flags at the end. + let inner = crate::TypeInner::Sampler { comparison: false }; + let handle = module.types.append(crate::Type { + name: decor.name, + inner, + }); + log::debug!("\t\ttracking {:?} for sampling properties", handle); + self.handle_sampling.insert(handle, SamplingFlags::empty()); + self.lookup_type.insert( + id, + LookupType { + handle, + base_id: None, + }, + ); + Ok(()) + } + + fn parse_constant( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(3)?; + let type_id = self.next()?; + let id = self.next()?; + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + let inner = match module.types[ty].inner { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width, + } => { + let low = self.next()?; + let high = if width > 4 { + inst.expect(4)?; + self.next()? + } else { + 0 + }; + crate::ConstantInner::Uint((u64::from(high) << 32) | u64::from(low)) + } + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Sint, + width, + } => { + use std::cmp::Ordering; + let low = self.next()?; + let high = match width.cmp(&4) { + Ordering::Less => return Err(Error::InvalidTypeWidth(u32::from(width))), + Ordering::Greater => { + inst.expect(4)?; + self.next()? + } + Ordering::Equal => 0, + }; + crate::ConstantInner::Sint(((u64::from(high) << 32) | u64::from(low)) as i64) + } + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Float, + width, + } => { + let low = self.next()?; + let extended = match width { + 4 => f64::from(f32::from_bits(low)), + 8 => { + inst.expect(4)?; + let high = self.next()?; + f64::from_bits((u64::from(high) << 32) | u64::from(low)) + } + _ => return Err(Error::InvalidTypeWidth(u32::from(width))), + }; + crate::ConstantInner::Float(extended) + } + _ => return Err(Error::UnsupportedType(type_lookup.handle)), + }; + self.lookup_constant.insert( + id, + LookupConstant { + handle: module.constants.append(crate::Constant { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + specialization: None, //TODO + inner, + ty, + }), + type_id, + }, + ); + Ok(()) + } + + fn parse_composite_constant( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(3)?; + let type_id = self.next()?; + let type_lookup = self.lookup_type.lookup(type_id)?; + let ty = type_lookup.handle; + + let id = self.next()?; + + let constituents_count = inst.wc - 3; + let mut constituents = Vec::with_capacity(constituents_count as usize); + for _ in 0..constituents_count { + let constituent_id = self.next()?; + let constant = self.lookup_constant.lookup(constituent_id)?; + constituents.push(constant.handle); + } + + self.lookup_constant.insert( + id, + LookupConstant { + handle: module.constants.append(crate::Constant { + name: self.future_decor.remove(&id).and_then(|dec| dec.name), + specialization: None, + inner: crate::ConstantInner::Composite(constituents), + ty, + }), + type_id, + }, + ); + + Ok(()) + } + + fn parse_global_variable( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Type, inst.op)?; + inst.expect_at_least(4)?; + let type_id = self.next()?; + let id = self.next()?; + let storage_class = self.next()?; + let init = if inst.wc > 4 { + inst.expect(5)?; + let init_id = self.next()?; + let lconst = self.lookup_constant.lookup(init_id)?; + Some(lconst.handle) + } else { + None + }; + let lookup_type = self.lookup_type.lookup(type_id)?; + let dec = self + .future_decor + .remove(&id) + .ok_or(Error::InvalidBinding(id))?; + + let class = { + use spirv::StorageClass as Sc; + match Sc::from_u32(storage_class) { + Some(Sc::Function) => crate::StorageClass::Function, + Some(Sc::Input) => crate::StorageClass::Input, + Some(Sc::Output) => crate::StorageClass::Output, + Some(Sc::Private) => crate::StorageClass::Private, + Some(Sc::UniformConstant) => crate::StorageClass::Handle, + Some(Sc::StorageBuffer) => crate::StorageClass::Storage, + Some(Sc::Uniform) => { + if self + .lookup_storage_buffer_types + .contains(&lookup_type.handle) + { + crate::StorageClass::Storage + } else { + crate::StorageClass::Uniform + } + } + Some(Sc::Workgroup) => crate::StorageClass::WorkGroup, + Some(Sc::PushConstant) => crate::StorageClass::PushConstant, + _ => return Err(Error::UnsupportedStorageClass(storage_class)), + } + }; + + let binding = match (class, &module.types[lookup_type.handle].inner) { + (crate::StorageClass::Input, &crate::TypeInner::Struct { .. }) + | (crate::StorageClass::Output, &crate::TypeInner::Struct { .. }) => None, + _ => Some(dec.get_binding().ok_or(Error::InvalidBinding(id))?), + }; + let is_storage = match module.types[lookup_type.handle].inner { + crate::TypeInner::Struct { .. } => class == crate::StorageClass::Storage, + crate::TypeInner::Image { + class: crate::ImageClass::Storage(_), + .. + } => true, + _ => false, + }; + + let storage_access = if is_storage { + let mut access = crate::StorageAccess::all(); + if dec.flags.contains(DecorationFlags::NON_READABLE) { + access ^= crate::StorageAccess::LOAD; + } + if dec.flags.contains(DecorationFlags::NON_WRITABLE) { + access ^= crate::StorageAccess::STORE; + } + access + } else { + crate::StorageAccess::empty() + }; + + let var = crate::GlobalVariable { + name: dec.name, + class, + binding, + ty: lookup_type.handle, + init, + interpolation: dec.interpolation, + storage_access, + }; + self.lookup_variable.insert( + id, + LookupVariable { + handle: module.global_variables.append(var), + type_id, + }, + ); + Ok(()) + } +} + +pub fn parse_u8_slice(data: &[u8], options: &Options) -> Result<crate::Module, Error> { + if data.len() % 4 != 0 { + return Err(Error::IncompleteData); + } + + let words = data + .chunks(4) + .map(|c| u32::from_le_bytes(c.try_into().unwrap())); + Parser::new(words, options).parse() +} + +#[cfg(test)] +mod test { + #[test] + fn parse() { + let bin = vec![ + // Magic number. Version number: 1.0. + 0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00, + // Generator number: 0. Bound: 0. + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Reserved word: 0. + 0x00, 0x00, 0x00, 0x00, // OpMemoryModel. Logical. + 0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, // GLSL450. + 0x01, 0x00, 0x00, 0x00, + ]; + let _ = super::parse_u8_slice(&bin, &Default::default()).unwrap(); + } +} diff --git a/third_party/rust/naga/src/front/spv/rosetta.rs b/third_party/rust/naga/src/front/spv/rosetta.rs new file mode 100644 index 0000000000..027c0d0adc --- /dev/null +++ b/third_party/rust/naga/src/front/spv/rosetta.rs @@ -0,0 +1,23 @@ +use std::{fs, path::Path}; + +const TEST_PATH: &str = "test-data"; + +fn rosetta_test(file_name: &str) { + if true { + return; //TODO: fix this test + } + let file_path = Path::new(TEST_PATH).join(file_name); + let input = fs::read(&file_path).unwrap(); + + let module = super::parse_u8_slice(&input, &Default::default()).unwrap(); + let output = ron::ser::to_string_pretty(&module, Default::default()).unwrap(); + + let expected = fs::read_to_string(file_path.with_extension("expected.ron")).unwrap(); + + difference::assert_diff!(output.as_str(), expected.as_str(), "", 0); +} + +#[test] +fn simple() { + rosetta_test("simple/simple.spv") +} diff --git a/third_party/rust/naga/src/front/wgsl/conv.rs b/third_party/rust/naga/src/front/wgsl/conv.rs new file mode 100644 index 0000000000..6fdea89e7a --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/conv.rs @@ -0,0 +1,117 @@ +use super::Error; + +pub fn map_storage_class(word: &str) -> Result<crate::StorageClass, Error<'_>> { + match word { + "in" => Ok(crate::StorageClass::Input), + "out" => Ok(crate::StorageClass::Output), + "private" => Ok(crate::StorageClass::Private), + "uniform" => Ok(crate::StorageClass::Uniform), + "storage" => Ok(crate::StorageClass::Storage), + _ => Err(Error::UnknownStorageClass(word)), + } +} + +pub fn map_built_in(word: &str) -> Result<crate::BuiltIn, Error<'_>> { + Ok(match word { + // vertex + "position" => crate::BuiltIn::Position, + "vertex_idx" => crate::BuiltIn::VertexIndex, + "instance_idx" => crate::BuiltIn::InstanceIndex, + // fragment + "front_facing" => crate::BuiltIn::FrontFacing, + "frag_coord" => crate::BuiltIn::FragCoord, + "frag_depth" => crate::BuiltIn::FragDepth, + // compute + "global_invocation_id" => crate::BuiltIn::GlobalInvocationId, + "local_invocation_id" => crate::BuiltIn::LocalInvocationId, + "local_invocation_idx" => crate::BuiltIn::LocalInvocationIndex, + _ => return Err(Error::UnknownBuiltin(word)), + }) +} + +pub fn map_shader_stage(word: &str) -> Result<crate::ShaderStage, Error<'_>> { + match word { + "vertex" => Ok(crate::ShaderStage::Vertex), + "fragment" => Ok(crate::ShaderStage::Fragment), + "compute" => Ok(crate::ShaderStage::Compute), + _ => Err(Error::UnknownShaderStage(word)), + } +} + +pub fn map_interpolation(word: &str) -> Result<crate::Interpolation, Error<'_>> { + match word { + "linear" => Ok(crate::Interpolation::Linear), + "flat" => Ok(crate::Interpolation::Flat), + "centroid" => Ok(crate::Interpolation::Centroid), + "sample" => Ok(crate::Interpolation::Sample), + "perspective" => Ok(crate::Interpolation::Perspective), + _ => Err(Error::UnknownDecoration(word)), + } +} + +pub fn map_storage_format(word: &str) -> Result<crate::StorageFormat, Error<'_>> { + use crate::StorageFormat as Sf; + Ok(match word { + "r8unorm" => Sf::R8Unorm, + "r8snorm" => Sf::R8Snorm, + "r8uint" => Sf::R8Uint, + "r8sint" => Sf::R8Sint, + "r16uint" => Sf::R16Uint, + "r16sint" => Sf::R16Sint, + "r16float" => Sf::R16Float, + "rg8unorm" => Sf::Rg8Unorm, + "rg8snorm" => Sf::Rg8Snorm, + "rg8uint" => Sf::Rg8Uint, + "rg8sint" => Sf::Rg8Sint, + "r32uint" => Sf::R32Uint, + "r32sint" => Sf::R32Sint, + "r32float" => Sf::R32Float, + "rg16uint" => Sf::Rg16Uint, + "rg16sint" => Sf::Rg16Sint, + "rg16float" => Sf::Rg16Float, + "rgba8unorm" => Sf::Rgba8Unorm, + "rgba8snorm" => Sf::Rgba8Snorm, + "rgba8uint" => Sf::Rgba8Uint, + "rgba8sint" => Sf::Rgba8Sint, + "rgb10a2unorm" => Sf::Rgb10a2Unorm, + "rg11b10float" => Sf::Rg11b10Float, + "rg32uint" => Sf::Rg32Uint, + "rg32sint" => Sf::Rg32Sint, + "rg32float" => Sf::Rg32Float, + "rgba16uint" => Sf::Rgba16Uint, + "rgba16sint" => Sf::Rgba16Sint, + "rgba16float" => Sf::Rgba16Float, + "rgba32uint" => Sf::Rgba32Uint, + "rgba32sint" => Sf::Rgba32Sint, + "rgba32float" => Sf::Rgba32Float, + _ => return Err(Error::UnknownStorageFormat(word)), + }) +} + +pub fn get_scalar_type(word: &str) -> Option<(crate::ScalarKind, crate::Bytes)> { + match word { + "f32" => Some((crate::ScalarKind::Float, 4)), + "i32" => Some((crate::ScalarKind::Sint, 4)), + "u32" => Some((crate::ScalarKind::Uint, 4)), + _ => None, + } +} + +pub fn get_intrinsic(word: &str) -> Option<crate::IntrinsicFunction> { + match word { + "any" => Some(crate::IntrinsicFunction::Any), + "all" => Some(crate::IntrinsicFunction::All), + "is_nan" => Some(crate::IntrinsicFunction::IsNan), + "is_inf" => Some(crate::IntrinsicFunction::IsInf), + "is_normal" => Some(crate::IntrinsicFunction::IsNormal), + _ => None, + } +} +pub fn get_derivative(word: &str) -> Option<crate::DerivativeAxis> { + match word { + "dpdx" => Some(crate::DerivativeAxis::X), + "dpdy" => Some(crate::DerivativeAxis::Y), + "dwidth" => Some(crate::DerivativeAxis::Width), + _ => None, + } +} diff --git a/third_party/rust/naga/src/front/wgsl/lexer.rs b/third_party/rust/naga/src/front/wgsl/lexer.rs new file mode 100644 index 0000000000..b991ecf619 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/lexer.rs @@ -0,0 +1,292 @@ +use super::{conv, Error, Token}; + +fn _consume_str<'a>(input: &'a str, what: &str) -> Option<&'a str> { + if input.starts_with(what) { + Some(&input[what.len()..]) + } else { + None + } +} + +fn consume_any(input: &str, what: impl Fn(char) -> bool) -> (&str, &str) { + let pos = input.find(|c| !what(c)).unwrap_or_else(|| input.len()); + input.split_at(pos) +} + +fn consume_number(input: &str) -> (&str, &str) { + let mut is_first_char = true; + let mut right_after_exponent = false; + + let mut what = |c| { + if is_first_char { + is_first_char = false; + c == '-' || c >= '0' && c <= '9' || c == '.' + } else if c == 'e' || c == 'E' { + right_after_exponent = true; + true + } else if right_after_exponent { + right_after_exponent = false; + c >= '0' && c <= '9' || c == '-' + } else { + c >= '0' && c <= '9' || c == '.' + } + }; + let pos = input.find(|c| !what(c)).unwrap_or_else(|| input.len()); + input.split_at(pos) +} + +fn consume_token(mut input: &str) -> (Token<'_>, &str) { + input = input.trim_start(); + let mut chars = input.chars(); + let cur = match chars.next() { + Some(c) => c, + None => return (Token::End, input), + }; + match cur { + ':' => { + input = chars.as_str(); + if chars.next() == Some(':') { + (Token::DoubleColon, chars.as_str()) + } else { + (Token::Separator(cur), input) + } + } + ';' | ',' => (Token::Separator(cur), chars.as_str()), + '.' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('0'..='9') => { + let (number, rest) = consume_number(input); + (Token::Number(number), rest) + } + _ => (Token::Separator(cur), og_chars), + } + } + '(' | ')' | '{' | '}' => (Token::Paren(cur), chars.as_str()), + '<' | '>' => { + input = chars.as_str(); + let next = chars.next(); + if next == Some('=') { + (Token::LogicalOperation(cur), chars.as_str()) + } else if next == Some(cur) { + (Token::ShiftOperation(cur), chars.as_str()) + } else { + (Token::Paren(cur), input) + } + } + '[' | ']' => { + input = chars.as_str(); + if chars.next() == Some(cur) { + (Token::DoubleParen(cur), chars.as_str()) + } else { + (Token::Paren(cur), input) + } + } + '0'..='9' => { + let (number, rest) = consume_number(input); + (Token::Number(number), rest) + } + 'a'..='z' | 'A'..='Z' | '_' => { + let (word, rest) = consume_any(input, |c| c.is_ascii_alphanumeric() || c == '_'); + (Token::Word(word), rest) + } + '"' => { + let mut iter = chars.as_str().splitn(2, '"'); + + // splitn returns an iterator with at least one element, so unwrapping is fine + let quote_content = iter.next().unwrap(); + if let Some(rest) = iter.next() { + (Token::String(quote_content), rest) + } else { + (Token::UnterminatedString, quote_content) + } + } + '-' => { + let og_chars = chars.as_str(); + match chars.next() { + Some('>') => (Token::Arrow, chars.as_str()), + Some('0'..='9') | Some('.') => { + let (number, rest) = consume_number(input); + (Token::Number(number), rest) + } + _ => (Token::Operation(cur), og_chars), + } + } + '+' | '*' | '/' | '%' | '^' => (Token::Operation(cur), chars.as_str()), + '!' => { + if chars.next() == Some('=') { + (Token::LogicalOperation(cur), chars.as_str()) + } else { + (Token::Operation(cur), input) + } + } + '=' | '&' | '|' => { + input = chars.as_str(); + if chars.next() == Some(cur) { + (Token::LogicalOperation(cur), chars.as_str()) + } else { + (Token::Operation(cur), input) + } + } + '#' => match chars.position(|c| c == '\n' || c == '\r') { + Some(_) => consume_token(chars.as_str()), + None => (Token::End, chars.as_str()), + }, + _ => (Token::Unknown(cur), chars.as_str()), + } +} + +#[derive(Clone)] +pub(super) struct Lexer<'a> { + input: &'a str, +} + +impl<'a> Lexer<'a> { + pub(super) fn new(input: &'a str) -> Self { + Lexer { input } + } + + #[must_use] + pub(super) fn next(&mut self) -> Token<'a> { + let (token, rest) = consume_token(self.input); + self.input = rest; + token + } + + #[must_use] + pub(super) fn peek(&mut self) -> Token<'a> { + self.clone().next() + } + + pub(super) fn expect(&mut self, expected: Token<'_>) -> Result<(), Error<'a>> { + let token = self.next(); + if token == expected { + Ok(()) + } else { + Err(Error::Unexpected(token)) + } + } + + pub(super) fn skip(&mut self, what: Token<'_>) -> bool { + let (token, rest) = consume_token(self.input); + if token == what { + self.input = rest; + true + } else { + false + } + } + + pub(super) fn next_ident(&mut self) -> Result<&'a str, Error<'a>> { + match self.next() { + Token::Word(word) => Ok(word), + other => Err(Error::Unexpected(other)), + } + } + + fn _next_float_literal(&mut self) -> Result<f32, Error<'a>> { + match self.next() { + Token::Number(word) => word.parse().map_err(|err| Error::BadFloat(word, err)), + other => Err(Error::Unexpected(other)), + } + } + + pub(super) fn next_uint_literal(&mut self) -> Result<u32, Error<'a>> { + match self.next() { + Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)), + other => Err(Error::Unexpected(other)), + } + } + + fn _next_sint_literal(&mut self) -> Result<i32, Error<'a>> { + match self.next() { + Token::Number(word) => word.parse().map_err(|err| Error::BadInteger(word, err)), + other => Err(Error::Unexpected(other)), + } + } + + pub(super) fn next_scalar_generic( + &mut self, + ) -> Result<(crate::ScalarKind, crate::Bytes), Error<'a>> { + self.expect(Token::Paren('<'))?; + let word = self.next_ident()?; + let pair = conv::get_scalar_type(word).ok_or(Error::UnknownScalarType(word))?; + self.expect(Token::Paren('>'))?; + Ok(pair) + } + + pub(super) fn next_format_generic(&mut self) -> Result<crate::StorageFormat, Error<'a>> { + self.expect(Token::Paren('<'))?; + let format = conv::map_storage_format(self.next_ident()?)?; + self.expect(Token::Paren('>'))?; + Ok(format) + } + + pub(super) fn take_until(&mut self, what: Token<'_>) -> Result<Lexer<'a>, Error<'a>> { + let original_input = self.input; + let initial_len = self.input.len(); + let mut used_len = 0; + loop { + if self.next() == what { + break; + } + used_len = initial_len - self.input.len(); + } + + Ok(Lexer { + input: &original_input[..used_len], + }) + } + + pub(super) fn offset_from(&self, source: &'a str) -> usize { + source.len() - self.input.len() + } +} + +#[cfg(test)] +fn sub_test(source: &str, expected_tokens: &[Token]) { + let mut lex = Lexer::new(source); + for &token in expected_tokens { + assert_eq!(lex.next(), token); + } + assert_eq!(lex.next(), Token::End); +} + +#[test] +fn test_tokens() { + sub_test("id123_OK", &[Token::Word("id123_OK")]); + sub_test("92No", &[Token::Number("92"), Token::Word("No")]); + sub_test( + "æNoø", + &[Token::Unknown('æ'), Token::Word("No"), Token::Unknown('ø')], + ); + sub_test("No¾", &[Token::Word("No"), Token::Unknown('¾')]); + sub_test("No好", &[Token::Word("No"), Token::Unknown('好')]); + sub_test("\"\u{2}ПЀ\u{0}\"", &[Token::String("\u{2}ПЀ\u{0}")]); // https://github.com/gfx-rs/naga/issues/90 +} + +#[test] +fn test_variable_decl() { + sub_test( + "[[ group(0 )]] var< uniform> texture: texture_multisampled_2d <f32 >;", + &[ + Token::DoubleParen('['), + Token::Word("group"), + Token::Paren('('), + Token::Number("0"), + Token::Paren(')'), + Token::DoubleParen(']'), + Token::Word("var"), + Token::Paren('<'), + Token::Word("uniform"), + Token::Paren('>'), + Token::Word("texture"), + Token::Separator(':'), + Token::Word("texture_multisampled_2d"), + Token::Paren('<'), + Token::Word("f32"), + Token::Paren('>'), + Token::Separator(';'), + ], + ) +} diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs new file mode 100644 index 0000000000..3b17719006 --- /dev/null +++ b/third_party/rust/naga/src/front/wgsl/mod.rs @@ -0,0 +1,1850 @@ +//! Front end for consuming [WebGPU Shading Language][wgsl]. +//! +//! [wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html + +mod conv; +mod lexer; + +use crate::{ + arena::{Arena, Handle}, + proc::{ResolveContext, ResolveError, Typifier}, + FastHashMap, +}; + +use self::lexer::Lexer; +use thiserror::Error; + +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Token<'a> { + Separator(char), + DoubleColon, + Paren(char), + DoubleParen(char), + Number(&'a str), + String(&'a str), + Word(&'a str), + Operation(char), + LogicalOperation(char), + ShiftOperation(char), + Arrow, + Unknown(char), + UnterminatedString, + End, +} + +#[derive(Clone, Debug, Error)] +pub enum Error<'a> { + #[error("unexpected token: {0:?}")] + Unexpected(Token<'a>), + #[error("unable to parse `{0}` as integer: {1}")] + BadInteger(&'a str, std::num::ParseIntError), + #[error("unable to parse `{1}` as float: {1}")] + BadFloat(&'a str, std::num::ParseFloatError), + #[error("bad field accessor `{0}`")] + BadAccessor(&'a str), + #[error(transparent)] + InvalidResolve(ResolveError), + #[error("unknown import: `{0}`")] + UnknownImport(&'a str), + #[error("unknown storage class: `{0}`")] + UnknownStorageClass(&'a str), + #[error("unknown decoration: `{0}`")] + UnknownDecoration(&'a str), + #[error("unknown scalar kind: `{0}`")] + UnknownScalarKind(&'a str), + #[error("unknown builtin: `{0}`")] + UnknownBuiltin(&'a str), + #[error("unknown shader stage: `{0}`")] + UnknownShaderStage(&'a str), + #[error("unknown identifier: `{0}`")] + UnknownIdent(&'a str), + #[error("unknown scalar type: `{0}`")] + UnknownScalarType(&'a str), + #[error("unknown type: `{0}`")] + UnknownType(&'a str), + #[error("unknown function: `{0}`")] + UnknownFunction(&'a str), + #[error("unknown storage format: `{0}`")] + UnknownStorageFormat(&'a str), + #[error("missing offset for structure member `{0}`")] + MissingMemberOffset(&'a str), + #[error("array stride must not be 0")] + ZeroStride, + #[error("not a composite type: {0:?}")] + NotCompositeType(Handle<crate::Type>), + #[error("function redefinition: `{0}`")] + FunctionRedefinition(&'a str), + #[error("other error")] + Other, +} + +trait StringValueLookup<'a> { + type Value; + fn lookup(&self, key: &'a str) -> Result<Self::Value, Error<'a>>; +} +impl<'a> StringValueLookup<'a> for FastHashMap<&'a str, Handle<crate::Expression>> { + type Value = Handle<crate::Expression>; + fn lookup(&self, key: &'a str) -> Result<Self::Value, Error<'a>> { + self.get(key).cloned().ok_or(Error::UnknownIdent(key)) + } +} + +struct StatementContext<'input, 'temp, 'out> { + lookup_ident: &'temp mut FastHashMap<&'input str, Handle<crate::Expression>>, + typifier: &'temp mut Typifier, + variables: &'out mut Arena<crate::LocalVariable>, + expressions: &'out mut Arena<crate::Expression>, + types: &'out mut Arena<crate::Type>, + constants: &'out mut Arena<crate::Constant>, + global_vars: &'out Arena<crate::GlobalVariable>, + arguments: &'out [crate::FunctionArgument], +} + +impl<'a> StatementContext<'a, '_, '_> { + fn reborrow(&mut self) -> StatementContext<'a, '_, '_> { + StatementContext { + lookup_ident: self.lookup_ident, + typifier: self.typifier, + variables: self.variables, + expressions: self.expressions, + types: self.types, + constants: self.constants, + global_vars: self.global_vars, + arguments: self.arguments, + } + } + + fn as_expression(&mut self) -> ExpressionContext<'a, '_, '_> { + ExpressionContext { + lookup_ident: self.lookup_ident, + typifier: self.typifier, + expressions: self.expressions, + types: self.types, + constants: self.constants, + global_vars: self.global_vars, + local_vars: self.variables, + arguments: self.arguments, + } + } +} + +struct ExpressionContext<'input, 'temp, 'out> { + lookup_ident: &'temp FastHashMap<&'input str, Handle<crate::Expression>>, + typifier: &'temp mut Typifier, + expressions: &'out mut Arena<crate::Expression>, + types: &'out mut Arena<crate::Type>, + constants: &'out mut Arena<crate::Constant>, + global_vars: &'out Arena<crate::GlobalVariable>, + local_vars: &'out Arena<crate::LocalVariable>, + arguments: &'out [crate::FunctionArgument], +} + +impl<'a> ExpressionContext<'a, '_, '_> { + fn reborrow(&mut self) -> ExpressionContext<'a, '_, '_> { + ExpressionContext { + lookup_ident: self.lookup_ident, + typifier: self.typifier, + expressions: self.expressions, + types: self.types, + constants: self.constants, + global_vars: self.global_vars, + local_vars: self.local_vars, + arguments: self.arguments, + } + } + + fn resolve_type( + &mut self, + handle: Handle<crate::Expression>, + ) -> Result<&crate::TypeInner, Error<'a>> { + let functions = Arena::new(); //TODO + let resolve_ctx = ResolveContext { + constants: self.constants, + global_vars: self.global_vars, + local_vars: self.local_vars, + functions: &functions, + arguments: self.arguments, + }; + match self + .typifier + .grow(handle, self.expressions, self.types, &resolve_ctx) + { + Err(e) => Err(Error::InvalidResolve(e)), + Ok(()) => Ok(self.typifier.get(handle, self.types)), + } + } + + fn parse_binary_op( + &mut self, + lexer: &mut Lexer<'a>, + classifier: impl Fn(Token<'a>) -> Option<crate::BinaryOperator>, + mut parser: impl FnMut( + &mut Lexer<'a>, + ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'a>>, + ) -> Result<Handle<crate::Expression>, Error<'a>> { + let mut left = parser(lexer, self.reborrow())?; + while let Some(op) = classifier(lexer.peek()) { + let _ = lexer.next(); + let expression = crate::Expression::Binary { + op, + left, + right: parser(lexer, self.reborrow())?, + }; + left = self.expressions.append(expression); + } + Ok(left) + } +} + +enum Composition { + Single(crate::Expression), + Multi(crate::VectorSize, Vec<Handle<crate::Expression>>), +} + +impl Composition { + fn make<'a>( + base: Handle<crate::Expression>, + base_size: crate::VectorSize, + name: &'a str, + expressions: &mut Arena<crate::Expression>, + ) -> Result<Self, Error<'a>> { + const MEMBERS: [char; 4] = ['x', 'y', 'z', 'w']; + + Ok(if name.len() > 1 { + let mut components = Vec::with_capacity(name.len()); + for ch in name.chars() { + let expr = crate::Expression::AccessIndex { + base, + index: MEMBERS[..base_size as usize] + .iter() + .position(|&m| m == ch) + .ok_or(Error::BadAccessor(name))? as u32, + }; + components.push(expressions.append(expr)); + } + + let size = match name.len() { + 2 => crate::VectorSize::Bi, + 3 => crate::VectorSize::Tri, + 4 => crate::VectorSize::Quad, + _ => return Err(Error::BadAccessor(name)), + }; + Composition::Multi(size, components) + } else { + let ch = name.chars().next().ok_or(Error::BadAccessor(name))?; + let index = MEMBERS[..base_size as usize] + .iter() + .position(|&m| m == ch) + .ok_or(Error::BadAccessor(name))? as u32; + Composition::Single(crate::Expression::AccessIndex { base, index }) + }) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Scope { + Decoration, + ImportDecl, + VariableDecl, + TypeDecl, + FunctionDecl, + Block, + Statement, + ConstantExpr, + PrimaryExpr, + SingularExpr, + GeneralExpr, +} + +struct ParsedVariable<'a> { + name: &'a str, + class: Option<crate::StorageClass>, + ty: Handle<crate::Type>, + access: crate::StorageAccess, + init: Option<Handle<crate::Constant>>, +} + +#[derive(Clone, Debug, Error)] +#[error("error while parsing WGSL in scopes {scopes:?} at position {pos:?}: {error}")] +pub struct ParseError<'a> { + pub error: Error<'a>, + pub scopes: Vec<Scope>, + pub pos: (usize, usize), +} + +pub struct Parser { + scopes: Vec<Scope>, + lookup_type: FastHashMap<String, Handle<crate::Type>>, + function_lookup: FastHashMap<String, Handle<crate::Function>>, + std_namespace: Option<Vec<String>>, +} + +impl Parser { + pub fn new() -> Self { + Parser { + scopes: Vec::new(), + lookup_type: FastHashMap::default(), + function_lookup: FastHashMap::default(), + std_namespace: None, + } + } + + fn deconstruct_composite_type( + type_arena: &mut Arena<crate::Type>, + ty: Handle<crate::Type>, + index: usize, + ) -> Result<Handle<crate::Type>, Error<'static>> { + match type_arena[ty].inner { + crate::TypeInner::Vector { kind, width, .. } => { + let inner = crate::TypeInner::Scalar { kind, width }; + Ok(type_arena.fetch_or_append(crate::Type { name: None, inner })) + } + crate::TypeInner::Matrix { width, .. } => { + let inner = crate::TypeInner::Scalar { + kind: crate::ScalarKind::Float, + width, + }; + Ok(type_arena.fetch_or_append(crate::Type { name: None, inner })) + } + crate::TypeInner::Array { base, .. } => Ok(base), + crate::TypeInner::Struct { ref members } => Ok(members[index].ty), + _ => Err(Error::NotCompositeType(ty)), + } + } + + fn get_constant_inner( + word: &str, + ) -> Result<(crate::ConstantInner, crate::ScalarKind), Error<'_>> { + if word.contains('.') { + word.parse() + .map(|f| (crate::ConstantInner::Float(f), crate::ScalarKind::Float)) + .map_err(|err| Error::BadFloat(word, err)) + } else { + word.parse() + .map(|i| (crate::ConstantInner::Sint(i), crate::ScalarKind::Sint)) + .map_err(|err| Error::BadInteger(word, err)) + } + } + + fn parse_function_call<'a>( + &mut self, + lexer: &Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result<Option<(crate::Expression, Lexer<'a>)>, Error<'a>> { + let mut lexer = lexer.clone(); + + let external_function = if let Some(std_namespaces) = self.std_namespace.as_deref() { + std_namespaces.iter().all(|namespace| { + lexer.skip(Token::Word(namespace)) && lexer.skip(Token::DoubleColon) + }) + } else { + false + }; + + let origin = if external_function { + let function = lexer.next_ident()?; + crate::FunctionOrigin::External(function.to_string()) + } else if let Ok(function) = lexer.next_ident() { + if let Some(&function) = self.function_lookup.get(function) { + crate::FunctionOrigin::Local(function) + } else { + return Ok(None); + } + } else { + return Ok(None); + }; + + if !lexer.skip(Token::Paren('(')) { + return Ok(None); + } + + let mut arguments = Vec::new(); + while !lexer.skip(Token::Paren(')')) { + if !arguments.is_empty() { + lexer.expect(Token::Separator(','))?; + } + let arg = self.parse_general_expression(&mut lexer, ctx.reborrow())?; + arguments.push(arg); + } + Ok(Some((crate::Expression::Call { origin, arguments }, lexer))) + } + + fn parse_const_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + self_ty: Handle<crate::Type>, + type_arena: &mut Arena<crate::Type>, + const_arena: &mut Arena<crate::Constant>, + ) -> Result<Handle<crate::Constant>, Error<'a>> { + self.scopes.push(Scope::ConstantExpr); + let inner = match lexer.peek() { + Token::Word("true") => { + let _ = lexer.next(); + crate::ConstantInner::Bool(true) + } + Token::Word("false") => { + let _ = lexer.next(); + crate::ConstantInner::Bool(false) + } + Token::Number(word) => { + let _ = lexer.next(); + let (inner, _) = Self::get_constant_inner(word)?; + inner + } + _ => { + let composite_ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + lexer.expect(Token::Paren('('))?; + let mut components = Vec::new(); + while !lexer.skip(Token::Paren(')')) { + if !components.is_empty() { + lexer.expect(Token::Separator(','))?; + } + let ty = Self::deconstruct_composite_type( + type_arena, + composite_ty, + components.len(), + )?; + let component = + self.parse_const_expression(lexer, ty, type_arena, const_arena)?; + components.push(component); + } + crate::ConstantInner::Composite(components) + } + }; + let handle = const_arena.fetch_or_append(crate::Constant { + name: None, + specialization: None, + inner, + ty: self_ty, + }); + self.scopes.pop(); + Ok(handle) + } + + fn parse_primary_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'a>> { + self.scopes.push(Scope::PrimaryExpr); + let backup = lexer.clone(); + let expression = match lexer.next() { + Token::Paren('(') => { + let expr = self.parse_general_expression(lexer, ctx)?; + lexer.expect(Token::Paren(')'))?; + self.scopes.pop(); + return Ok(expr); + } + Token::Word("true") => { + let handle = ctx.constants.fetch_or_append(crate::Constant { + name: None, + specialization: None, + inner: crate::ConstantInner::Bool(true), + ty: ctx.types.fetch_or_append(crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: 1, + }, + }), + }); + crate::Expression::Constant(handle) + } + Token::Word("false") => { + let handle = ctx.constants.fetch_or_append(crate::Constant { + name: None, + specialization: None, + inner: crate::ConstantInner::Bool(false), + ty: ctx.types.fetch_or_append(crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: 1, + }, + }), + }); + crate::Expression::Constant(handle) + } + Token::Number(word) => { + let (inner, kind) = Self::get_constant_inner(word)?; + let handle = ctx.constants.fetch_or_append(crate::Constant { + name: None, + specialization: None, + inner, + ty: ctx.types.fetch_or_append(crate::Type { + name: None, + inner: crate::TypeInner::Scalar { kind, width: 4 }, + }), + }); + crate::Expression::Constant(handle) + } + Token::Word(word) => { + if let Some(handle) = ctx.lookup_ident.get(word) { + self.scopes.pop(); + return Ok(*handle); + } + if let Some((expr, new_lexer)) = + self.parse_function_call(&backup, ctx.reborrow())? + { + *lexer = new_lexer; + expr + } else { + *lexer = backup; + let ty = self.parse_type_decl(lexer, None, ctx.types, ctx.constants)?; + lexer.expect(Token::Paren('('))?; + let mut components = Vec::new(); + while !lexer.skip(Token::Paren(')')) { + if !components.is_empty() { + lexer.expect(Token::Separator(','))?; + } + let sub_expr = self.parse_general_expression(lexer, ctx.reborrow())?; + components.push(sub_expr); + } + crate::Expression::Compose { ty, components } + } + } + other => return Err(Error::Unexpected(other)), + }; + self.scopes.pop(); + Ok(ctx.expressions.append(expression)) + } + + fn parse_postfix<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + mut handle: Handle<crate::Expression>, + ) -> Result<Handle<crate::Expression>, Error<'a>> { + loop { + match lexer.peek() { + Token::Separator('.') => { + let _ = lexer.next(); + let name = lexer.next_ident()?; + let expression = match *ctx.resolve_type(handle)? { + crate::TypeInner::Struct { ref members } => { + let index = members + .iter() + .position(|m| m.name.as_deref() == Some(name)) + .ok_or(Error::BadAccessor(name))? + as u32; + crate::Expression::AccessIndex { + base: handle, + index, + } + } + crate::TypeInner::Vector { size, kind, width } => { + match Composition::make(handle, size, name, ctx.expressions)? { + Composition::Multi(size, components) => { + let inner = crate::TypeInner::Vector { size, kind, width }; + crate::Expression::Compose { + ty: ctx + .types + .fetch_or_append(crate::Type { name: None, inner }), + components, + } + } + Composition::Single(expr) => expr, + } + } + crate::TypeInner::Matrix { + rows, + columns, + width, + } => match Composition::make(handle, columns, name, ctx.expressions)? { + Composition::Multi(columns, components) => { + let inner = crate::TypeInner::Matrix { + columns, + rows, + width, + }; + crate::Expression::Compose { + ty: ctx + .types + .fetch_or_append(crate::Type { name: None, inner }), + components, + } + } + Composition::Single(expr) => expr, + }, + _ => return Err(Error::BadAccessor(name)), + }; + handle = ctx.expressions.append(expression); + } + Token::Paren('[') => { + let _ = lexer.next(); + let index = self.parse_general_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(']'))?; + let expr = crate::Expression::Access { + base: handle, + index, + }; + handle = ctx.expressions.append(expr); + } + _ => return Ok(handle), + } + } + } + + fn parse_singular_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut ctx: ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'a>> { + self.scopes.push(Scope::SingularExpr); + let backup = lexer.clone(); + let expression = match lexer.next() { + Token::Operation('-') => Some(crate::Expression::Unary { + op: crate::UnaryOperator::Negate, + expr: self.parse_singular_expression(lexer, ctx.reborrow())?, + }), + Token::Operation('!') => Some(crate::Expression::Unary { + op: crate::UnaryOperator::Not, + expr: self.parse_singular_expression(lexer, ctx.reborrow())?, + }), + Token::Word(word) => { + if let Some(fun) = conv::get_intrinsic(word) { + lexer.expect(Token::Paren('('))?; + let argument = self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::Intrinsic { fun, argument }) + } else if let Some(axis) = conv::get_derivative(word) { + lexer.expect(Token::Paren('('))?; + let expr = self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::Derivative { axis, expr }) + } else if let Some((kind, _width)) = conv::get_scalar_type(word) { + lexer.expect(Token::Paren('('))?; + let expr = self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::As { + expr, + kind, + convert: true, + }) + } else { + match word { + "dot" => { + lexer.expect(Token::Paren('('))?; + let a = self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let b = self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::DotProduct(a, b)) + } + "cross" => { + lexer.expect(Token::Paren('('))?; + let a = self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let b = self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::CrossProduct(a, b)) + } + "textureSample" => { + lexer.expect(Token::Paren('('))?; + let image_name = lexer.next_ident()?; + lexer.expect(Token::Separator(','))?; + let sampler_name = lexer.next_ident()?; + lexer.expect(Token::Separator(','))?; + let coordinate = + self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::ImageSample { + image: ctx.lookup_ident.lookup(image_name)?, + sampler: ctx.lookup_ident.lookup(sampler_name)?, + coordinate, + level: crate::SampleLevel::Auto, + depth_ref: None, + }) + } + "textureSampleLevel" => { + lexer.expect(Token::Paren('('))?; + let image_name = lexer.next_ident()?; + lexer.expect(Token::Separator(','))?; + let sampler_name = lexer.next_ident()?; + lexer.expect(Token::Separator(','))?; + let coordinate = + self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let level = self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::ImageSample { + image: ctx.lookup_ident.lookup(image_name)?, + sampler: ctx.lookup_ident.lookup(sampler_name)?, + coordinate, + level: crate::SampleLevel::Exact(level), + depth_ref: None, + }) + } + "textureSampleBias" => { + lexer.expect(Token::Paren('('))?; + let image_name = lexer.next_ident()?; + lexer.expect(Token::Separator(','))?; + let sampler_name = lexer.next_ident()?; + lexer.expect(Token::Separator(','))?; + let coordinate = + self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let bias = self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::ImageSample { + image: ctx.lookup_ident.lookup(image_name)?, + sampler: ctx.lookup_ident.lookup(sampler_name)?, + coordinate, + level: crate::SampleLevel::Bias(bias), + depth_ref: None, + }) + } + "textureSampleCompare" => { + lexer.expect(Token::Paren('('))?; + let image_name = lexer.next_ident()?; + lexer.expect(Token::Separator(','))?; + let sampler_name = lexer.next_ident()?; + lexer.expect(Token::Separator(','))?; + let coordinate = + self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let reference = self.parse_primary_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::ImageSample { + image: ctx.lookup_ident.lookup(image_name)?, + sampler: ctx.lookup_ident.lookup(sampler_name)?, + coordinate, + level: crate::SampleLevel::Zero, + depth_ref: Some(reference), + }) + } + "textureLoad" => { + lexer.expect(Token::Paren('('))?; + let image_name = lexer.next_ident()?; + let image = ctx.lookup_ident.lookup(image_name)?; + lexer.expect(Token::Separator(','))?; + let coordinate = + self.parse_primary_expression(lexer, ctx.reborrow())?; + let is_storage = match *ctx.resolve_type(image)? { + crate::TypeInner::Image { + class: crate::ImageClass::Storage(_), + .. + } => true, + _ => false, + }; + let index = if is_storage { + None + } else { + lexer.expect(Token::Separator(','))?; + let index_name = lexer.next_ident()?; + Some(ctx.lookup_ident.lookup(index_name)?) + }; + lexer.expect(Token::Paren(')'))?; + Some(crate::Expression::ImageLoad { + image, + coordinate, + index, + }) + } + _ => None, + } + } + } + _ => None, + }; + + let handle = match expression { + Some(expr) => ctx.expressions.append(expr), + None => { + *lexer = backup; + let handle = self.parse_primary_expression(lexer, ctx.reborrow())?; + self.parse_postfix(lexer, ctx, handle)? + } + }; + self.scopes.pop(); + Ok(handle) + } + + fn parse_equality_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'a>> { + // equality_expression + context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('=') => Some(crate::BinaryOperator::Equal), + Token::LogicalOperation('!') => Some(crate::BinaryOperator::NotEqual), + _ => None, + }, + // relational_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Paren('<') => Some(crate::BinaryOperator::Less), + Token::Paren('>') => Some(crate::BinaryOperator::Greater), + Token::LogicalOperation('<') => Some(crate::BinaryOperator::LessEqual), + Token::LogicalOperation('>') => Some(crate::BinaryOperator::GreaterEqual), + _ => None, + }, + // shift_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::ShiftOperation('<') => { + Some(crate::BinaryOperator::ShiftLeft) + } + Token::ShiftOperation('>') => { + Some(crate::BinaryOperator::ShiftRight) + } + _ => None, + }, + // additive_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('+') => Some(crate::BinaryOperator::Add), + Token::Operation('-') => { + Some(crate::BinaryOperator::Subtract) + } + _ => None, + }, + // multiplicative_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('*') => { + Some(crate::BinaryOperator::Multiply) + } + Token::Operation('/') => { + Some(crate::BinaryOperator::Divide) + } + Token::Operation('%') => { + Some(crate::BinaryOperator::Modulo) + } + _ => None, + }, + |lexer, context| { + self.parse_singular_expression(lexer, context) + }, + ) + }, + ) + }, + ) + }, + ) + }, + ) + } + + fn parse_general_expression<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: ExpressionContext<'a, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'a>> { + self.scopes.push(Scope::GeneralExpr); + // logical_or_expression + let handle = context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('|') => Some(crate::BinaryOperator::LogicalOr), + _ => None, + }, + // logical_and_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::LogicalOperation('&') => Some(crate::BinaryOperator::LogicalAnd), + _ => None, + }, + // inclusive_or_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('|') => Some(crate::BinaryOperator::InclusiveOr), + _ => None, + }, + // exclusive_or_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('^') => { + Some(crate::BinaryOperator::ExclusiveOr) + } + _ => None, + }, + // and_expression + |lexer, mut context| { + context.parse_binary_op( + lexer, + |token| match token { + Token::Operation('&') => { + Some(crate::BinaryOperator::And) + } + _ => None, + }, + |lexer, context| { + self.parse_equality_expression(lexer, context) + }, + ) + }, + ) + }, + ) + }, + ) + }, + )?; + self.scopes.pop(); + Ok(handle) + } + + fn parse_variable_ident_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + type_arena: &mut Arena<crate::Type>, + const_arena: &mut Arena<crate::Constant>, + ) -> Result<(&'a str, Handle<crate::Type>), Error<'a>> { + let name = lexer.next_ident()?; + lexer.expect(Token::Separator(':'))?; + let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + Ok((name, ty)) + } + + fn parse_variable_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + type_arena: &mut Arena<crate::Type>, + const_arena: &mut Arena<crate::Constant>, + ) -> Result<ParsedVariable<'a>, Error<'a>> { + self.scopes.push(Scope::VariableDecl); + let mut class = None; + if lexer.skip(Token::Paren('<')) { + let class_str = lexer.next_ident()?; + class = Some(conv::map_storage_class(class_str)?); + lexer.expect(Token::Paren('>'))?; + } + let name = lexer.next_ident()?; + lexer.expect(Token::Separator(':'))?; + let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + let access = match class { + Some(crate::StorageClass::Storage) => crate::StorageAccess::all(), + Some(crate::StorageClass::Handle) => { + match type_arena[ty].inner { + //TODO: RW textures + crate::TypeInner::Image { + class: crate::ImageClass::Storage(_), + .. + } => crate::StorageAccess::LOAD, + _ => crate::StorageAccess::empty(), + } + } + _ => crate::StorageAccess::empty(), + }; + let init = if lexer.skip(Token::Operation('=')) { + let handle = self.parse_const_expression(lexer, ty, type_arena, const_arena)?; + Some(handle) + } else { + None + }; + lexer.expect(Token::Separator(';'))?; + self.scopes.pop(); + Ok(ParsedVariable { + name, + class, + ty, + access, + init, + }) + } + + fn parse_struct_body<'a>( + &mut self, + lexer: &mut Lexer<'a>, + type_arena: &mut Arena<crate::Type>, + const_arena: &mut Arena<crate::Constant>, + ) -> Result<Vec<crate::StructMember>, Error<'a>> { + let mut members = Vec::new(); + lexer.expect(Token::Paren('{'))?; + loop { + let mut offset = !0; + if lexer.skip(Token::DoubleParen('[')) { + self.scopes.push(Scope::Decoration); + let mut ready = true; + loop { + match lexer.next() { + Token::DoubleParen(']') => { + break; + } + Token::Separator(',') if !ready => { + ready = true; + } + Token::Word("offset") if ready => { + lexer.expect(Token::Paren('('))?; + offset = lexer.next_uint_literal()?; + lexer.expect(Token::Paren(')'))?; + ready = false; + } + other => return Err(Error::Unexpected(other)), + } + } + self.scopes.pop(); + } + let name = match lexer.next() { + Token::Word(word) => word, + Token::Paren('}') => return Ok(members), + other => return Err(Error::Unexpected(other)), + }; + if offset == !0 { + return Err(Error::MissingMemberOffset(name)); + } + lexer.expect(Token::Separator(':'))?; + let ty = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + lexer.expect(Token::Separator(';'))?; + members.push(crate::StructMember { + name: Some(name.to_owned()), + origin: crate::MemberOrigin::Offset(offset), + ty, + }); + } + } + + fn parse_type_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + self_name: Option<&'a str>, + type_arena: &mut Arena<crate::Type>, + const_arena: &mut Arena<crate::Constant>, + ) -> Result<Handle<crate::Type>, Error<'a>> { + self.scopes.push(Scope::TypeDecl); + let decoration_lexer = if lexer.skip(Token::DoubleParen('[')) { + Some(lexer.take_until(Token::DoubleParen(']'))?) + } else { + None + }; + + let inner = match lexer.next() { + Token::Word("f32") => crate::TypeInner::Scalar { + kind: crate::ScalarKind::Float, + width: 4, + }, + Token::Word("i32") => crate::TypeInner::Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }, + Token::Word("u32") => crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + Token::Word("vec2") => { + let (kind, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Vector { + size: crate::VectorSize::Bi, + kind, + width, + } + } + Token::Word("vec3") => { + let (kind, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + kind, + width, + } + } + Token::Word("vec4") => { + let (kind, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Vector { + size: crate::VectorSize::Quad, + kind, + width, + } + } + Token::Word("mat2x2") => { + let (_, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Bi, + width, + } + } + Token::Word("mat2x3") => { + let (_, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Tri, + width, + } + } + Token::Word("mat2x4") => { + let (_, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Matrix { + columns: crate::VectorSize::Bi, + rows: crate::VectorSize::Quad, + width, + } + } + Token::Word("mat3x2") => { + let (_, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Bi, + width, + } + } + Token::Word("mat3x3") => { + let (_, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Tri, + width, + } + } + Token::Word("mat3x4") => { + let (_, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Quad, + width, + } + } + Token::Word("mat4x2") => { + let (_, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Bi, + width, + } + } + Token::Word("mat4x3") => { + let (_, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width, + } + } + Token::Word("mat4x4") => { + let (_, width) = lexer.next_scalar_generic()?; + crate::TypeInner::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Quad, + width, + } + } + Token::Word("ptr") => { + lexer.expect(Token::Paren('<'))?; + let class = conv::map_storage_class(lexer.next_ident()?)?; + lexer.expect(Token::Separator(','))?; + let base = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + lexer.expect(Token::Paren('>'))?; + crate::TypeInner::Pointer { base, class } + } + Token::Word("array") => { + lexer.expect(Token::Paren('<'))?; + let base = self.parse_type_decl(lexer, None, type_arena, const_arena)?; + let size = match lexer.next() { + Token::Separator(',') => { + let value = lexer.next_uint_literal()?; + lexer.expect(Token::Paren('>'))?; + let const_handle = const_arena.fetch_or_append(crate::Constant { + name: None, + specialization: None, + inner: crate::ConstantInner::Uint(value as u64), + ty: type_arena.fetch_or_append(crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }, + }), + }); + crate::ArraySize::Constant(const_handle) + } + Token::Paren('>') => crate::ArraySize::Dynamic, + other => return Err(Error::Unexpected(other)), + }; + + let mut stride = None; + if let Some(mut lexer) = decoration_lexer { + self.scopes.push(Scope::Decoration); + loop { + match lexer.next() { + Token::Word("stride") => { + use std::num::NonZeroU32; + stride = Some( + NonZeroU32::new(lexer.next_uint_literal()?) + .ok_or(Error::ZeroStride)?, + ); + } + Token::End => break, + other => return Err(Error::Unexpected(other)), + } + } + self.scopes.pop(); + } + + crate::TypeInner::Array { base, size, stride } + } + Token::Word("struct") => { + let members = self.parse_struct_body(lexer, type_arena, const_arena)?; + crate::TypeInner::Struct { members } + } + Token::Word("sampler") => crate::TypeInner::Sampler { comparison: false }, + Token::Word("sampler_comparison") => crate::TypeInner::Sampler { comparison: true }, + Token::Word("texture_sampled_1d") => { + let (kind, _) = lexer.next_scalar_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D1, + arrayed: false, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + Token::Word("texture_sampled_1d_array") => { + let (kind, _) = lexer.next_scalar_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D1, + arrayed: true, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + Token::Word("texture_sampled_2d") => { + let (kind, _) = lexer.next_scalar_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + Token::Word("texture_sampled_2d_array") => { + let (kind, _) = lexer.next_scalar_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + Token::Word("texture_sampled_3d") => { + let (kind, _) = lexer.next_scalar_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D3, + arrayed: false, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + Token::Word("texture_sampled_cube") => { + let (kind, _) = lexer.next_scalar_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + arrayed: false, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + Token::Word("texture_sampled_cube_array") => { + let (kind, _) = lexer.next_scalar_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + arrayed: true, + class: crate::ImageClass::Sampled { kind, multi: false }, + } + } + Token::Word("texture_multisampled_2d") => { + let (kind, _) = lexer.next_scalar_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Sampled { kind, multi: true }, + } + } + Token::Word("texture_depth_2d") => crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Depth, + }, + Token::Word("texture_depth_2d_array") => crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Depth, + }, + Token::Word("texture_depth_cube") => crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + arrayed: false, + class: crate::ImageClass::Depth, + }, + Token::Word("texture_depth_cube_array") => crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + arrayed: true, + class: crate::ImageClass::Depth, + }, + Token::Word("texture_ro_1d") => { + let format = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D1, + arrayed: false, + class: crate::ImageClass::Storage(format), + } + } + Token::Word("texture_ro_1d_array") => { + let format = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D1, + arrayed: true, + class: crate::ImageClass::Storage(format), + } + } + Token::Word("texture_ro_2d") => { + let format = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Storage(format), + } + } + Token::Word("texture_ro_2d_array") => { + let format = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Storage(format), + } + } + Token::Word("texture_ro_3d") => { + let format = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D3, + arrayed: false, + class: crate::ImageClass::Storage(format), + } + } + Token::Word("texture_wo_1d") => { + let format = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D1, + arrayed: false, + class: crate::ImageClass::Storage(format), + } + } + Token::Word("texture_wo_1d_array") => { + let format = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D1, + arrayed: true, + class: crate::ImageClass::Storage(format), + } + } + Token::Word("texture_wo_2d") => { + let format = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: false, + class: crate::ImageClass::Storage(format), + } + } + Token::Word("texture_wo_2d_array") => { + let format = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D2, + arrayed: true, + class: crate::ImageClass::Storage(format), + } + } + Token::Word("texture_wo_3d") => { + let format = lexer.next_format_generic()?; + crate::TypeInner::Image { + dim: crate::ImageDimension::D3, + arrayed: false, + class: crate::ImageClass::Storage(format), + } + } + Token::Word(name) => { + self.scopes.pop(); + return match self.lookup_type.get(name) { + Some(&handle) => Ok(handle), + None => Err(Error::UnknownType(name)), + }; + } + other => return Err(Error::Unexpected(other)), + }; + self.scopes.pop(); + + let handle = type_arena.fetch_or_append(crate::Type { + name: self_name.map(|s| s.to_string()), + inner, + }); + Ok(handle) + } + + fn parse_statement<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: StatementContext<'a, '_, '_>, + ) -> Result<crate::Statement, Error<'a>> { + let backup = lexer.clone(); + match lexer.next() { + Token::Separator(';') => Ok(crate::Statement::Block(Vec::new())), + Token::Word(word) => { + self.scopes.push(Scope::Statement); + let statement = match word { + "var" => { + enum Init { + Empty, + Constant(Handle<crate::Constant>), + Variable(Handle<crate::Expression>), + } + let (name, ty) = self.parse_variable_ident_decl( + lexer, + context.types, + context.constants, + )?; + let init = if lexer.skip(Token::Operation('=')) { + let value = + self.parse_general_expression(lexer, context.as_expression())?; + if let crate::Expression::Constant(handle) = context.expressions[value] + { + Init::Constant(handle) + } else { + Init::Variable(value) + } + } else { + Init::Empty + }; + lexer.expect(Token::Separator(';'))?; + let var_id = context.variables.append(crate::LocalVariable { + name: Some(name.to_owned()), + ty, + init: match init { + Init::Constant(value) => Some(value), + _ => None, + }, + }); + let expr_id = context + .expressions + .append(crate::Expression::LocalVariable(var_id)); + context.lookup_ident.insert(name, expr_id); + match init { + Init::Variable(value) => crate::Statement::Store { + pointer: expr_id, + value, + }, + _ => crate::Statement::Block(Vec::new()), + } + } + "return" => { + let value = if lexer.peek() != Token::Separator(';') { + Some(self.parse_general_expression(lexer, context.as_expression())?) + } else { + None + }; + lexer.expect(Token::Separator(';'))?; + crate::Statement::Return { value } + } + "if" => { + lexer.expect(Token::Paren('('))?; + let condition = + self.parse_general_expression(lexer, context.as_expression())?; + lexer.expect(Token::Paren(')'))?; + let accept = self.parse_block(lexer, context.reborrow())?; + let reject = if lexer.skip(Token::Word("else")) { + self.parse_block(lexer, context.reborrow())? + } else { + Vec::new() + }; + crate::Statement::If { + condition, + accept, + reject, + } + } + "loop" => { + let mut body = Vec::new(); + let mut continuing = Vec::new(); + lexer.expect(Token::Paren('{'))?; + loop { + if lexer.skip(Token::Word("continuing")) { + continuing = self.parse_block(lexer, context.reborrow())?; + lexer.expect(Token::Paren('}'))?; + break; + } + if lexer.skip(Token::Paren('}')) { + break; + } + let s = self.parse_statement(lexer, context.reborrow())?; + body.push(s); + } + crate::Statement::Loop { body, continuing } + } + "break" => crate::Statement::Break, + "continue" => crate::Statement::Continue, + ident => { + // assignment + if let Some(&var_expr) = context.lookup_ident.get(ident) { + let left = + self.parse_postfix(lexer, context.as_expression(), var_expr)?; + lexer.expect(Token::Operation('='))?; + let value = + self.parse_general_expression(lexer, context.as_expression())?; + lexer.expect(Token::Separator(';'))?; + crate::Statement::Store { + pointer: left, + value, + } + } else if let Some((expr, new_lexer)) = + self.parse_function_call(&backup, context.as_expression())? + { + *lexer = new_lexer; + context.expressions.append(expr); + lexer.expect(Token::Separator(';'))?; + crate::Statement::Block(Vec::new()) + } else { + return Err(Error::UnknownIdent(ident)); + } + } + }; + self.scopes.pop(); + Ok(statement) + } + other => Err(Error::Unexpected(other)), + } + } + + fn parse_block<'a>( + &mut self, + lexer: &mut Lexer<'a>, + mut context: StatementContext<'a, '_, '_>, + ) -> Result<Vec<crate::Statement>, Error<'a>> { + self.scopes.push(Scope::Block); + lexer.expect(Token::Paren('{'))?; + let mut statements = Vec::new(); + while !lexer.skip(Token::Paren('}')) { + let s = self.parse_statement(lexer, context.reborrow())?; + statements.push(s); + } + self.scopes.pop(); + Ok(statements) + } + + fn parse_function_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + module: &mut crate::Module, + lookup_global_expression: &FastHashMap<&'a str, crate::Expression>, + ) -> Result<(crate::Function, &'a str), Error<'a>> { + self.scopes.push(Scope::FunctionDecl); + // read function name + let mut lookup_ident = FastHashMap::default(); + let fun_name = lexer.next_ident()?; + // populare initial expressions + let mut expressions = Arena::new(); + for (&name, expression) in lookup_global_expression.iter() { + let expr_handle = expressions.append(expression.clone()); + lookup_ident.insert(name, expr_handle); + } + // read parameter list + let mut arguments = Vec::new(); + lexer.expect(Token::Paren('('))?; + while !lexer.skip(Token::Paren(')')) { + if !arguments.is_empty() { + lexer.expect(Token::Separator(','))?; + } + let (param_name, param_type) = + self.parse_variable_ident_decl(lexer, &mut module.types, &mut module.constants)?; + let param_index = arguments.len() as u32; + let expression_token = + expressions.append(crate::Expression::FunctionArgument(param_index)); + lookup_ident.insert(param_name, expression_token); + arguments.push(crate::FunctionArgument { + name: Some(param_name.to_string()), + ty: param_type, + }); + } + // read return type + lexer.expect(Token::Arrow)?; + let return_type = if lexer.skip(Token::Word("void")) { + None + } else { + Some(self.parse_type_decl(lexer, None, &mut module.types, &mut module.constants)?) + }; + + let mut fun = crate::Function { + name: Some(fun_name.to_string()), + arguments, + return_type, + global_usage: Vec::new(), + local_variables: Arena::new(), + expressions, + body: Vec::new(), + }; + + // read body + let mut typifier = Typifier::new(); + fun.body = self.parse_block( + lexer, + StatementContext { + lookup_ident: &mut lookup_ident, + typifier: &mut typifier, + variables: &mut fun.local_variables, + expressions: &mut fun.expressions, + types: &mut module.types, + constants: &mut module.constants, + global_vars: &module.global_variables, + arguments: &fun.arguments, + }, + )?; + // done + fun.fill_global_use(&module.global_variables); + self.scopes.pop(); + + Ok((fun, fun_name)) + } + + fn parse_global_decl<'a>( + &mut self, + lexer: &mut Lexer<'a>, + module: &mut crate::Module, + lookup_global_expression: &mut FastHashMap<&'a str, crate::Expression>, + ) -> Result<bool, Error<'a>> { + // read decorations + let mut binding = None; + // Perspective is the default qualifier. + let mut interpolation = None; + let mut stage = None; + let mut workgroup_size = [0u32; 3]; + + if lexer.skip(Token::DoubleParen('[')) { + let (mut bind_index, mut bind_group) = (None, None); + self.scopes.push(Scope::Decoration); + loop { + match lexer.next_ident()? { + "location" => { + lexer.expect(Token::Paren('('))?; + let loc = lexer.next_uint_literal()?; + lexer.expect(Token::Paren(')'))?; + binding = Some(crate::Binding::Location(loc)); + } + "builtin" => { + lexer.expect(Token::Paren('('))?; + let builtin = conv::map_built_in(lexer.next_ident()?)?; + lexer.expect(Token::Paren(')'))?; + binding = Some(crate::Binding::BuiltIn(builtin)); + } + "binding" => { + lexer.expect(Token::Paren('('))?; + bind_index = Some(lexer.next_uint_literal()?); + lexer.expect(Token::Paren(')'))?; + } + "group" => { + lexer.expect(Token::Paren('('))?; + bind_group = Some(lexer.next_uint_literal()?); + lexer.expect(Token::Paren(')'))?; + } + "interpolate" => { + lexer.expect(Token::Paren('('))?; + interpolation = Some(conv::map_interpolation(lexer.next_ident()?)?); + lexer.expect(Token::Paren(')'))?; + } + "stage" => { + lexer.expect(Token::Paren('('))?; + stage = Some(conv::map_shader_stage(lexer.next_ident()?)?); + lexer.expect(Token::Paren(')'))?; + } + "workgroup_size" => { + lexer.expect(Token::Paren('('))?; + for (i, size) in workgroup_size.iter_mut().enumerate() { + *size = lexer.next_uint_literal()?; + match lexer.next() { + Token::Paren(')') => break, + Token::Separator(',') if i != 2 => (), + other => return Err(Error::Unexpected(other)), + } + } + for size in workgroup_size.iter_mut() { + if *size == 0 { + *size = 1; + } + } + } + word => return Err(Error::UnknownDecoration(word)), + } + match lexer.next() { + Token::DoubleParen(']') => { + break; + } + Token::Separator(',') => {} + other => return Err(Error::Unexpected(other)), + } + } + if let (Some(group), Some(index)) = (bind_group, bind_index) { + binding = Some(crate::Binding::Resource { + group, + binding: index, + }); + } + self.scopes.pop(); + } + // read items + match lexer.next() { + Token::Separator(';') => {} + Token::Word("import") => { + self.scopes.push(Scope::ImportDecl); + let path = match lexer.next() { + Token::String(path) => path, + other => return Err(Error::Unexpected(other)), + }; + lexer.expect(Token::Word("as"))?; + let mut namespaces = Vec::new(); + loop { + namespaces.push(lexer.next_ident()?.to_owned()); + if lexer.skip(Token::Separator(';')) { + break; + } + lexer.expect(Token::DoubleColon)?; + } + match path { + "GLSL.std.450" => self.std_namespace = Some(namespaces), + _ => return Err(Error::UnknownImport(path)), + } + self.scopes.pop(); + } + Token::Word("type") => { + let name = lexer.next_ident()?; + lexer.expect(Token::Operation('='))?; + let ty = self.parse_type_decl( + lexer, + Some(name), + &mut module.types, + &mut module.constants, + )?; + self.lookup_type.insert(name.to_owned(), ty); + lexer.expect(Token::Separator(';'))?; + } + Token::Word("const") => { + let (name, ty) = self.parse_variable_ident_decl( + lexer, + &mut module.types, + &mut module.constants, + )?; + lexer.expect(Token::Operation('='))?; + let const_handle = self.parse_const_expression( + lexer, + ty, + &mut module.types, + &mut module.constants, + )?; + lexer.expect(Token::Separator(';'))?; + lookup_global_expression.insert(name, crate::Expression::Constant(const_handle)); + } + Token::Word("var") => { + let pvar = + self.parse_variable_decl(lexer, &mut module.types, &mut module.constants)?; + let class = match pvar.class { + Some(c) => c, + None => match binding { + Some(crate::Binding::BuiltIn(builtin)) => match builtin { + crate::BuiltIn::GlobalInvocationId => crate::StorageClass::Input, + crate::BuiltIn::Position => crate::StorageClass::Output, + _ => unimplemented!(), + }, + _ => crate::StorageClass::Handle, + }, + }; + let var_handle = module.global_variables.append(crate::GlobalVariable { + name: Some(pvar.name.to_owned()), + class, + binding: binding.take(), + ty: pvar.ty, + init: pvar.init, + interpolation, + storage_access: pvar.access, + }); + lookup_global_expression + .insert(pvar.name, crate::Expression::GlobalVariable(var_handle)); + } + Token::Word("fn") => { + let (function, name) = + self.parse_function_decl(lexer, module, &lookup_global_expression)?; + let already_declared = match stage { + Some(stage) => module + .entry_points + .insert( + (stage, name.to_string()), + crate::EntryPoint { + early_depth_test: None, + workgroup_size, + function, + }, + ) + .is_some(), + None => { + let fun_handle = module.functions.append(function); + self.function_lookup + .insert(name.to_string(), fun_handle) + .is_some() + } + }; + if already_declared { + return Err(Error::FunctionRedefinition(name)); + } + } + Token::End => return Ok(false), + token => return Err(Error::Unexpected(token)), + } + match binding { + None => Ok(true), + // we had the decoration but no var? + Some(_) => Err(Error::Other), + } + } + + pub fn parse<'a>(&mut self, source: &'a str) -> Result<crate::Module, ParseError<'a>> { + self.scopes.clear(); + self.lookup_type.clear(); + self.std_namespace = None; + + let mut module = crate::Module::generate_empty(); + let mut lexer = Lexer::new(source); + let mut lookup_global_expression = FastHashMap::default(); + loop { + match self.parse_global_decl(&mut lexer, &mut module, &mut lookup_global_expression) { + Err(error) => { + let pos = lexer.offset_from(source); + let (mut rows, mut cols) = (0, 1); + for line in source[..pos].lines() { + rows += 1; + cols = line.len(); + } + return Err(ParseError { + error, + scopes: std::mem::replace(&mut self.scopes, Vec::new()), + pos: (rows, cols), + }); + } + Ok(true) => {} + Ok(false) => { + if !self.scopes.is_empty() { + return Err(ParseError { + error: Error::Other, + scopes: std::mem::replace(&mut self.scopes, Vec::new()), + pos: (0, 0), + }); + }; + return Ok(module); + } + } + } + } +} + +pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> { + Parser::new().parse(source) +} + +#[test] +fn parse_types() { + assert!(parse_str("const a : i32 = 2;").is_ok()); + assert!(parse_str("const a : x32 = 2;").is_err()); +} diff --git a/third_party/rust/naga/src/lib.rs b/third_party/rust/naga/src/lib.rs new file mode 100644 index 0000000000..936c3b69f5 --- /dev/null +++ b/third_party/rust/naga/src/lib.rs @@ -0,0 +1,788 @@ +//! Universal shader translator. +//! +//! The central structure of the crate is [`Module`]. +//! +//! To improve performance and reduce memory usage, most structures are stored +//! in an [`Arena`], and can be retrieved using the corresponding [`Handle`]. +#![allow( + clippy::new_without_default, + clippy::unneeded_field_pattern, + clippy::match_like_matches_macro +)] +// TODO: use `strip_prefix` instead when Rust 1.45 <= MSRV +#![allow(clippy::manual_strip, clippy::unknown_clippy_lints)] +#![deny(clippy::panic)] + +mod arena; +pub mod back; +pub mod front; +pub mod proc; + +pub use crate::arena::{Arena, Handle}; + +use std::{ + collections::{HashMap, HashSet}, + hash::BuildHasherDefault, + num::NonZeroU32, +}; + +#[cfg(feature = "deserialize")] +use serde::Deserialize; +#[cfg(feature = "serialize")] +use serde::Serialize; + +/// Hash map that is faster but not resilient to DoS attacks. +pub type FastHashMap<K, T> = HashMap<K, T, BuildHasherDefault<fxhash::FxHasher>>; +/// Hash set that is faster but not resilient to DoS attacks. +pub type FastHashSet<K> = HashSet<K, BuildHasherDefault<fxhash::FxHasher>>; + +/// Metadata for a given module. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct Header { + /// Major, minor and patch version. + /// + /// Currently used only for the SPIR-V back end. + pub version: (u8, u8, u8), + /// Magic number identifying the tool that generated the shader code. + /// + /// Can safely be set to 0. + pub generator: u32, +} + +/// Early fragment tests. In a standard situation if a driver determines that it is possible to +/// switch on early depth test it will. Typical situations when early depth test is switched off: +/// - Calling ```discard``` in a shader. +/// - Writing to the depth buffer, unless ConservativeDepth is enabled. +/// +/// SPIR-V: ExecutionMode EarlyFragmentTests +/// In GLSL: layout(early_fragment_tests) in; +/// HLSL: Attribute earlydepthstencil +/// +/// For more, see: +/// - https://www.khronos.org/opengl/wiki/Early_Fragment_Test#Explicit_specification +/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-earlydepthstencil +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct EarlyDepthTest { + conservative: Option<ConservativeDepth>, +} +/// Enables adjusting depth without disabling early Z. +/// +/// SPIR-V: ExecutionMode DepthGreater/DepthLess/DepthUnchanged +/// GLSL: layout (depth_<greater/less/unchanged/any>) out float gl_FragDepth; +/// - ```depth_any``` option behaves as if the layout qualifier was not present. +/// HLSL: SV_Depth/SV_DepthGreaterEqual/SV_DepthLessEqual +/// +/// For more, see: +/// - https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt +/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-semantics#system-value-semantics +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum ConservativeDepth { + /// Shader may rewrite depth only with a value greater than calculated; + GreaterEqual, + + /// Shader may rewrite depth smaller than one that would have been written without the modification. + LessEqual, + + /// Shader may not rewrite depth value. + Unchanged, +} + +/// Stage of the programmable pipeline. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[allow(missing_docs)] // The names are self evident +pub enum ShaderStage { + Vertex, + Fragment, + Compute, +} + +/// Class of storage for variables. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[allow(missing_docs)] // The names are self evident +pub enum StorageClass { + /// Function locals. + Function, + /// Pipeline input, per invocation. + Input, + /// Pipeline output, per invocation, mutable. + Output, + /// Private data, per invocation, mutable. + Private, + /// Workgroup shared data, mutable. + WorkGroup, + /// Uniform buffer data. + Uniform, + /// Storage buffer data, potentially mutable. + Storage, + /// Opaque handles, such as samplers and images. + Handle, + /// Push constants. + PushConstant, +} + +/// Built-in inputs and outputs. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum BuiltIn { + // vertex + BaseInstance, + BaseVertex, + ClipDistance, + InstanceIndex, + Position, + VertexIndex, + // fragment + PointSize, + FragCoord, + FrontFacing, + SampleIndex, + FragDepth, + // compute + GlobalInvocationId, + LocalInvocationId, + LocalInvocationIndex, + WorkGroupId, +} + +/// Number of bytes. +pub type Bytes = u8; + +/// Number of components in a vector. +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum VectorSize { + /// 2D vector + Bi = 2, + /// 3D vector + Tri = 3, + /// 4D vector + Quad = 4, +} + +/// Primitive type for a scalar. +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum ScalarKind { + /// Signed integer type. + Sint, + /// Unsigned integer type. + Uint, + /// Floating point type. + Float, + /// Boolean type. + Bool, +} + +/// Size of an array. +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum ArraySize { + /// The array size is constant. + Constant(Handle<Constant>), + /// The array size can change at runtime. + Dynamic, +} + +/// Describes where a struct member is placed. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum MemberOrigin { + /// Member is local to the shader. + Empty, + /// Built-in shader variable. + BuiltIn(BuiltIn), + /// Offset within the struct. + Offset(u32), +} + +/// The interpolation qualifier of a binding or struct field. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum Interpolation { + /// The value will be interpolated in a perspective-correct fashion. + /// Also known as "smooth" in glsl. + Perspective, + /// Indicates that linear, non-perspective, correct + /// interpolation must be used. + /// Also known as "no_perspective" in glsl. + Linear, + /// Indicates that no interpolation will be performed. + Flat, + /// Indicates a tessellation patch. + Patch, + /// When used with multi-sampling rasterization, allow + /// a single interpolation location for an entire pixel. + Centroid, + /// When used with multi-sampling rasterization, require + /// per-sample interpolation. + Sample, +} + +/// Member of a user-defined structure. +// Clone is used only for error reporting and is not intended for end users +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct StructMember { + pub name: Option<String>, + pub origin: MemberOrigin, + pub ty: Handle<Type>, +} + +/// The number of dimensions an image has. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum ImageDimension { + /// 1D image + D1, + /// 2D image + D2, + /// 3D image + D3, + /// Cube map + Cube, +} + +bitflags::bitflags! { + /// Flags describing an image. + #[cfg_attr(feature = "serialize", derive(Serialize))] + #[cfg_attr(feature = "deserialize", derive(Deserialize))] + pub struct StorageAccess: u32 { + /// Storage can be used as a source for load ops. + const LOAD = 0x1; + /// Storage can be used as a target for store ops. + const STORE = 0x2; + } +} + +// Storage image format. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum StorageFormat { + // 8-bit formats + R8Unorm, + R8Snorm, + R8Uint, + R8Sint, + + // 16-bit formats + R16Uint, + R16Sint, + R16Float, + Rg8Unorm, + Rg8Snorm, + Rg8Uint, + Rg8Sint, + + // 32-bit formats + R32Uint, + R32Sint, + R32Float, + Rg16Uint, + Rg16Sint, + Rg16Float, + Rgba8Unorm, + Rgba8Snorm, + Rgba8Uint, + Rgba8Sint, + + // Packed 32-bit formats + Rgb10a2Unorm, + Rg11b10Float, + + // 64-bit formats + Rg32Uint, + Rg32Sint, + Rg32Float, + Rgba16Uint, + Rgba16Sint, + Rgba16Float, + + // 128-bit formats + Rgba32Uint, + Rgba32Sint, + Rgba32Float, +} + +/// Sub-class of the image type. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum ImageClass { + /// Regular sampled image. + Sampled { + /// Kind of values to sample. + kind: ScalarKind, + // Multi-sampled. + multi: bool, + }, + /// Depth comparison image. + Depth, + /// Storage image. + Storage(StorageFormat), +} + +/// A data type declared in the module. +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct Type { + /// The name of the type, if any. + pub name: Option<String>, + /// Inner structure that depends on the kind of the type. + pub inner: TypeInner, +} + +/// Enum with additional information, depending on the kind of type. +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum TypeInner { + /// Number of integral or floating-point kind. + Scalar { kind: ScalarKind, width: Bytes }, + /// Vector of numbers. + Vector { + size: VectorSize, + kind: ScalarKind, + width: Bytes, + }, + /// Matrix of floats. + Matrix { + columns: VectorSize, + rows: VectorSize, + width: Bytes, + }, + /// Pointer to a value. + Pointer { + base: Handle<Type>, + class: StorageClass, + }, + /// Homogenous list of elements. + Array { + base: Handle<Type>, + size: ArraySize, + stride: Option<NonZeroU32>, + }, + /// User-defined structure. + Struct { members: Vec<StructMember> }, + /// Possibly multidimensional array of texels. + Image { + dim: ImageDimension, + arrayed: bool, + class: ImageClass, + }, + /// Can be used to sample values from images. + Sampler { comparison: bool }, +} + +/// Constant value. +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct Constant { + pub name: Option<String>, + pub specialization: Option<u32>, + pub inner: ConstantInner, + pub ty: Handle<Type>, +} + +/// Additional information, dependendent on the kind of constant. +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum ConstantInner { + Sint(i64), + Uint(u64), + Float(f64), + Bool(bool), + Composite(Vec<Handle<Constant>>), +} + +/// Describes how an input/output variable is to be bound. +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum Binding { + /// Built-in shader variable. + BuiltIn(BuiltIn), + /// Indexed location. + Location(u32), + /// Binding within a resource group. + Resource { group: u32, binding: u32 }, +} + +bitflags::bitflags! { + /// Indicates how a global variable is used. + #[cfg_attr(feature = "serialize", derive(Serialize))] + #[cfg_attr(feature = "deserialize", derive(Deserialize))] + pub struct GlobalUse: u8 { + /// Data will be read from the variable. + const LOAD = 0x1; + /// Data will be written to the variable. + const STORE = 0x2; + } +} + +/// Variable defined at module level. +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct GlobalVariable { + /// Name of the variable, if any. + pub name: Option<String>, + /// How this variable is to be stored. + pub class: StorageClass, + /// How this variable is to be bound. + pub binding: Option<Binding>, + /// The type of this variable. + pub ty: Handle<Type>, + /// Initial value for this variable. + pub init: Option<Handle<Constant>>, + /// The interpolation qualifier, if any. + /// If the this `GlobalVariable` is a vertex output + /// or fragment input, `None` corresponds to the + /// `smooth`/`perspective` interpolation qualifier. + pub interpolation: Option<Interpolation>, + /// Access bit for storage types of images and buffers. + pub storage_access: StorageAccess, +} + +/// Variable defined at function level. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct LocalVariable { + /// Name of the variable, if any. + pub name: Option<String>, + /// The type of this variable. + pub ty: Handle<Type>, + /// Initial value for this variable. + pub init: Option<Handle<Constant>>, +} + +/// Operation that can be applied on a single value. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum UnaryOperator { + Negate, + Not, +} + +/// Operation that can be applied on two values. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum BinaryOperator { + Add, + Subtract, + Multiply, + Divide, + Modulo, + Equal, + NotEqual, + Less, + LessEqual, + Greater, + GreaterEqual, + And, + ExclusiveOr, + InclusiveOr, + LogicalAnd, + LogicalOr, + ShiftLeft, + /// Right shift carries the sign of signed integers only. + ShiftRight, +} + +/// Built-in shader function. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum IntrinsicFunction { + Any, + All, + IsNan, + IsInf, + IsFinite, + IsNormal, +} + +/// Axis on which to compute a derivative. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum DerivativeAxis { + X, + Y, + Width, +} + +/// Origin of a function to call. +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum FunctionOrigin { + Local(Handle<Function>), + // External { + // namespace: String, // Maybe this should be a handle to a namespace Arena? + // function: String, + // }, + External(String), +} + +/// Sampling modifier to control the level of detail. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum SampleLevel { + Auto, + Zero, + Exact(Handle<Expression>), + Bias(Handle<Expression>), +} + +/// An expression that can be evaluated to obtain a value. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum Expression { + /// Array access with a computed index. + Access { + base: Handle<Expression>, + index: Handle<Expression>, //int + }, + /// Array access with a known index. + AccessIndex { + base: Handle<Expression>, + index: u32, + }, + /// Constant value. + Constant(Handle<Constant>), + /// Composite expression. + Compose { + ty: Handle<Type>, + components: Vec<Handle<Expression>>, + }, + /// Reference a function parameter, by its index. + FunctionArgument(u32), + /// Reference a global variable. + GlobalVariable(Handle<GlobalVariable>), + /// Reference a local variable. + LocalVariable(Handle<LocalVariable>), + /// Load a value indirectly. + Load { pointer: Handle<Expression> }, + /// Sample a point from a sampled or a depth image. + ImageSample { + image: Handle<Expression>, + sampler: Handle<Expression>, + coordinate: Handle<Expression>, + level: SampleLevel, + depth_ref: Option<Handle<Expression>>, + }, + /// Load a texel from an image. + ImageLoad { + image: Handle<Expression>, + coordinate: Handle<Expression>, + /// For storage images, this is None. + /// For sampled images, this is the Some(Level). + /// For multisampled images, this is Some(Sample). + index: Option<Handle<Expression>>, + }, + /// Apply an unary operator. + Unary { + op: UnaryOperator, + expr: Handle<Expression>, + }, + /// Apply a binary operator. + Binary { + op: BinaryOperator, + left: Handle<Expression>, + right: Handle<Expression>, + }, + /// Select between two values based on a condition. + Select { + /// Boolean expression + condition: Handle<Expression>, + accept: Handle<Expression>, + reject: Handle<Expression>, + }, + /// Call an intrinsic function. + Intrinsic { + fun: IntrinsicFunction, + argument: Handle<Expression>, + }, + /// Transpose of a matrix. + Transpose(Handle<Expression>), + /// Dot product between two vectors. + DotProduct(Handle<Expression>, Handle<Expression>), + /// Cross product between two vectors. + CrossProduct(Handle<Expression>, Handle<Expression>), + /// Cast a simply type to another kind. + As { + /// Source expression, which can only be a scalar or a vector. + expr: Handle<Expression>, + /// Target scalar kind. + kind: ScalarKind, + /// True = conversion needs to take place; False = bitcast. + convert: bool, + }, + /// Compute the derivative on an axis. + Derivative { + axis: DerivativeAxis, + //modifier, + expr: Handle<Expression>, + }, + /// Call another function. + Call { + origin: FunctionOrigin, + arguments: Vec<Handle<Expression>>, + }, + /// Get the length of an array. + ArrayLength(Handle<Expression>), +} + +/// A code block is just a vector of statements. +pub type Block = Vec<Statement>; + +/// Marker type, used for falling through in a switch statement. +// Clone is used only for error reporting and is not intended for end users +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct FallThrough; + +/// Instructions which make up an executable block. +// Clone is used only for error reporting and is not intended for end users +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum Statement { + /// A block containing more statements, to be executed sequentially. + Block(Block), + /// Conditionally executes one of two blocks, based on the value of the condition. + If { + condition: Handle<Expression>, //bool + accept: Block, + reject: Block, + }, + /// Conditionally executes one of multiple blocks, based on the value of the selector. + Switch { + selector: Handle<Expression>, //int + cases: FastHashMap<i32, (Block, Option<FallThrough>)>, + default: Block, + }, + /// Executes a block repeatedly. + Loop { body: Block, continuing: Block }, + //TODO: move terminator variations into a separate enum? + /// Exits the loop. + Break, + /// Skips execution to the next iteration of the loop. + Continue, + /// Returns from the function (possibly with a value). + Return { value: Option<Handle<Expression>> }, + /// Aborts the current shader execution. + Kill, + /// Stores a value at an address. + Store { + pointer: Handle<Expression>, + value: Handle<Expression>, + }, +} + +/// A function argument. +#[derive(Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct FunctionArgument { + /// Name of the argument, if any. + pub name: Option<String>, + /// Type of the argument. + pub ty: Handle<Type>, +} + +/// A function defined in the module. +#[derive(Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct Function { + /// Name of the function, if any. + pub name: Option<String>, + /// Information about function argument. + pub arguments: Vec<FunctionArgument>, + /// The return type of this function, if any. + pub return_type: Option<Handle<Type>>, + /// Vector of global variable usages. + /// + /// Each item corresponds to a global variable in the module. + pub global_usage: Vec<GlobalUse>, + /// Local variables defined and used in the function. + pub local_variables: Arena<LocalVariable>, + /// Expressions used inside this function. + pub expressions: Arena<Expression>, + /// Block of instructions comprising the body of the function. + pub body: Block, +} + +/// Exported function, to be run at a certain stage in the pipeline. +#[derive(Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct EntryPoint { + /// Early depth test for fragment stages. + pub early_depth_test: Option<EarlyDepthTest>, + /// Workgroup size for compute stages + pub workgroup_size: [u32; 3], + /// The entrance function. + pub function: Function, +} + +/// Shader module. +/// +/// A module is a set of constants, global variables and functions, as well as +/// the types required to define them. +/// +/// Some functions are marked as entry points, to be used in a certain shader stage. +/// +/// To create a new module, use [`Module::from_header`] or [`Module::generate_empty`]. +/// Alternatively, you can load an existing shader using one of the [available front ends][front]. +/// +/// When finished, you can export modules using one of the [available back ends][back]. +#[derive(Debug)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct Module { + /// Header containing module metadata. + pub header: Header, + /// Storage for the types defined in this module. + pub types: Arena<Type>, + /// Storage for the constants defined in this module. + pub constants: Arena<Constant>, + /// Storage for the global variables defined in this module. + pub global_variables: Arena<GlobalVariable>, + /// Storage for the functions defined in this module. + pub functions: Arena<Function>, + /// Exported entry points. + pub entry_points: FastHashMap<(ShaderStage, String), EntryPoint>, +} diff --git a/third_party/rust/naga/src/proc/call_graph.rs b/third_party/rust/naga/src/proc/call_graph.rs new file mode 100644 index 0000000000..1c580d5c15 --- /dev/null +++ b/third_party/rust/naga/src/proc/call_graph.rs @@ -0,0 +1,74 @@ +use crate::{ + arena::{Arena, Handle}, + proc::{Interface, Visitor}, + Function, +}; +use petgraph::{ + graph::{DefaultIx, NodeIndex}, + Graph, +}; + +pub type CallGraph = Graph<Handle<Function>, ()>; + +pub struct CallGraphBuilder<'a> { + pub functions: &'a Arena<Function>, +} + +impl<'a> CallGraphBuilder<'a> { + pub fn process(&self, func: &Function) -> CallGraph { + let mut graph = Graph::new(); + let mut children = Vec::new(); + + let visitor = CallGraphVisitor { + children: &mut children, + }; + + let mut interface = Interface { + expressions: &func.expressions, + local_variables: &func.local_variables, + visitor, + }; + + interface.traverse(&func.body); + + for handle in children { + let id = graph.add_node(handle); + self.collect(handle, id, &mut graph); + } + + graph + } + + fn collect(&self, handle: Handle<Function>, id: NodeIndex<DefaultIx>, graph: &mut CallGraph) { + let mut children = Vec::new(); + let visitor = CallGraphVisitor { + children: &mut children, + }; + let func = &self.functions[handle]; + + let mut interface = Interface { + expressions: &func.expressions, + local_variables: &func.local_variables, + visitor, + }; + + interface.traverse(&func.body); + + for handle in children { + let child_id = graph.add_node(handle); + graph.add_edge(id, child_id, ()); + + self.collect(handle, child_id, graph); + } + } +} + +struct CallGraphVisitor<'a> { + children: &'a mut Vec<Handle<Function>>, +} + +impl<'a> Visitor for CallGraphVisitor<'a> { + fn visit_fun(&mut self, func: Handle<Function>) { + self.children.push(func) + } +} diff --git a/third_party/rust/naga/src/proc/interface.rs b/third_party/rust/naga/src/proc/interface.rs new file mode 100644 index 0000000000..b512452fe2 --- /dev/null +++ b/third_party/rust/naga/src/proc/interface.rs @@ -0,0 +1,290 @@ +use crate::arena::{Arena, Handle}; + +pub struct Interface<'a, T> { + pub expressions: &'a Arena<crate::Expression>, + pub local_variables: &'a Arena<crate::LocalVariable>, + pub visitor: T, +} + +pub trait Visitor { + fn visit_expr(&mut self, _: &crate::Expression) {} + fn visit_lhs_expr(&mut self, _: &crate::Expression) {} + fn visit_fun(&mut self, _: Handle<crate::Function>) {} +} + +impl<'a, T> Interface<'a, T> +where + T: Visitor, +{ + fn traverse_expr(&mut self, handle: Handle<crate::Expression>) { + use crate::Expression as E; + + let expr = &self.expressions[handle]; + + self.visitor.visit_expr(expr); + + match *expr { + E::Access { base, index } => { + self.traverse_expr(base); + self.traverse_expr(index); + } + E::AccessIndex { base, .. } => { + self.traverse_expr(base); + } + E::Constant(_) => {} + E::Compose { ref components, .. } => { + for &comp in components { + self.traverse_expr(comp); + } + } + E::FunctionArgument(_) | E::GlobalVariable(_) | E::LocalVariable(_) => {} + E::Load { pointer } => { + self.traverse_expr(pointer); + } + E::ImageSample { + image, + sampler, + coordinate, + level, + depth_ref, + } => { + self.traverse_expr(image); + self.traverse_expr(sampler); + self.traverse_expr(coordinate); + match level { + crate::SampleLevel::Auto | crate::SampleLevel::Zero => (), + crate::SampleLevel::Exact(h) | crate::SampleLevel::Bias(h) => { + self.traverse_expr(h) + } + } + if let Some(dref) = depth_ref { + self.traverse_expr(dref); + } + } + E::ImageLoad { + image, + coordinate, + index, + } => { + self.traverse_expr(image); + self.traverse_expr(coordinate); + if let Some(index) = index { + self.traverse_expr(index); + } + } + E::Unary { expr, .. } => { + self.traverse_expr(expr); + } + E::Binary { left, right, .. } => { + self.traverse_expr(left); + self.traverse_expr(right); + } + E::Select { + condition, + accept, + reject, + } => { + self.traverse_expr(condition); + self.traverse_expr(accept); + self.traverse_expr(reject); + } + E::Intrinsic { argument, .. } => { + self.traverse_expr(argument); + } + E::Transpose(matrix) => { + self.traverse_expr(matrix); + } + E::DotProduct(left, right) => { + self.traverse_expr(left); + self.traverse_expr(right); + } + E::CrossProduct(left, right) => { + self.traverse_expr(left); + self.traverse_expr(right); + } + E::As { expr, .. } => { + self.traverse_expr(expr); + } + E::Derivative { expr, .. } => { + self.traverse_expr(expr); + } + E::Call { + ref origin, + ref arguments, + } => { + for &argument in arguments { + self.traverse_expr(argument); + } + if let crate::FunctionOrigin::Local(fun) = *origin { + self.visitor.visit_fun(fun); + } + } + E::ArrayLength(expr) => { + self.traverse_expr(expr); + } + } + } + + pub fn traverse(&mut self, block: &[crate::Statement]) { + for statement in block { + use crate::Statement as S; + match *statement { + S::Break | S::Continue | S::Kill => (), + S::Block(ref b) => { + self.traverse(b); + } + S::If { + condition, + ref accept, + ref reject, + } => { + self.traverse_expr(condition); + self.traverse(accept); + self.traverse(reject); + } + S::Switch { + selector, + ref cases, + ref default, + } => { + self.traverse_expr(selector); + for &(ref case, _) in cases.values() { + self.traverse(case); + } + self.traverse(default); + } + S::Loop { + ref body, + ref continuing, + } => { + self.traverse(body); + self.traverse(continuing); + } + S::Return { value } => { + if let Some(expr) = value { + self.traverse_expr(expr); + } + } + S::Store { pointer, value } => { + let mut left = pointer; + loop { + match self.expressions[left] { + crate::Expression::Access { base, index } => { + self.traverse_expr(index); + left = base; + } + crate::Expression::AccessIndex { base, .. } => { + left = base; + } + _ => break, + } + } + self.visitor.visit_lhs_expr(&self.expressions[left]); + self.traverse_expr(value); + } + } + } + } +} + +struct GlobalUseVisitor<'a>(&'a mut [crate::GlobalUse]); + +impl Visitor for GlobalUseVisitor<'_> { + fn visit_expr(&mut self, expr: &crate::Expression) { + if let crate::Expression::GlobalVariable(handle) = expr { + self.0[handle.index()] |= crate::GlobalUse::LOAD; + } + } + + fn visit_lhs_expr(&mut self, expr: &crate::Expression) { + if let crate::Expression::GlobalVariable(handle) = expr { + self.0[handle.index()] |= crate::GlobalUse::STORE; + } + } +} + +impl crate::Function { + pub fn fill_global_use(&mut self, globals: &Arena<crate::GlobalVariable>) { + self.global_usage.clear(); + self.global_usage + .resize(globals.len(), crate::GlobalUse::empty()); + + let mut io = Interface { + expressions: &self.expressions, + local_variables: &self.local_variables, + visitor: GlobalUseVisitor(&mut self.global_usage), + }; + io.traverse(&self.body); + } +} + +#[cfg(test)] +mod tests { + use crate::{ + Arena, Expression, GlobalUse, GlobalVariable, Handle, Statement, StorageAccess, + StorageClass, + }; + + #[test] + fn global_use_scan() { + let test_global = GlobalVariable { + name: None, + class: StorageClass::Uniform, + binding: None, + ty: Handle::new(std::num::NonZeroU32::new(1).unwrap()), + init: None, + interpolation: None, + storage_access: StorageAccess::empty(), + }; + let mut test_globals = Arena::new(); + + let global_1 = test_globals.append(test_global.clone()); + let global_2 = test_globals.append(test_global.clone()); + let global_3 = test_globals.append(test_global.clone()); + let global_4 = test_globals.append(test_global); + + let mut expressions = Arena::new(); + let global_1_expr = expressions.append(Expression::GlobalVariable(global_1)); + let global_2_expr = expressions.append(Expression::GlobalVariable(global_2)); + let global_3_expr = expressions.append(Expression::GlobalVariable(global_3)); + let global_4_expr = expressions.append(Expression::GlobalVariable(global_4)); + + let test_body = vec![ + Statement::Return { + value: Some(global_1_expr), + }, + Statement::Store { + pointer: global_2_expr, + value: global_1_expr, + }, + Statement::Store { + pointer: expressions.append(Expression::Access { + base: global_3_expr, + index: global_4_expr, + }), + value: global_1_expr, + }, + ]; + + let mut function = crate::Function { + name: None, + arguments: Vec::new(), + return_type: None, + local_variables: Arena::new(), + expressions, + global_usage: Vec::new(), + body: test_body, + }; + function.fill_global_use(&test_globals); + + assert_eq!( + &function.global_usage, + &[ + GlobalUse::LOAD, + GlobalUse::STORE, + GlobalUse::STORE, + GlobalUse::LOAD, + ], + ) + } +} diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs new file mode 100644 index 0000000000..961e55da7c --- /dev/null +++ b/third_party/rust/naga/src/proc/mod.rs @@ -0,0 +1,67 @@ +//! Module processing functionality. + +#[cfg(feature = "petgraph")] +mod call_graph; +mod interface; +mod namer; +mod typifier; +mod validator; + +#[cfg(feature = "petgraph")] +pub use call_graph::{CallGraph, CallGraphBuilder}; +pub use interface::{Interface, Visitor}; +pub use namer::{EntryPointIndex, NameKey, Namer}; +pub use typifier::{check_constant_type, ResolveContext, ResolveError, Typifier}; +pub use validator::{ValidationError, Validator}; + +impl From<super::StorageFormat> for super::ScalarKind { + fn from(format: super::StorageFormat) -> Self { + use super::{ScalarKind as Sk, StorageFormat as Sf}; + match format { + Sf::R8Unorm => Sk::Float, + Sf::R8Snorm => Sk::Float, + Sf::R8Uint => Sk::Uint, + Sf::R8Sint => Sk::Sint, + Sf::R16Uint => Sk::Uint, + Sf::R16Sint => Sk::Sint, + Sf::R16Float => Sk::Float, + Sf::Rg8Unorm => Sk::Float, + Sf::Rg8Snorm => Sk::Float, + Sf::Rg8Uint => Sk::Uint, + Sf::Rg8Sint => Sk::Sint, + Sf::R32Uint => Sk::Uint, + Sf::R32Sint => Sk::Sint, + Sf::R32Float => Sk::Float, + Sf::Rg16Uint => Sk::Uint, + Sf::Rg16Sint => Sk::Sint, + Sf::Rg16Float => Sk::Float, + Sf::Rgba8Unorm => Sk::Float, + Sf::Rgba8Snorm => Sk::Float, + Sf::Rgba8Uint => Sk::Uint, + Sf::Rgba8Sint => Sk::Sint, + Sf::Rgb10a2Unorm => Sk::Float, + Sf::Rg11b10Float => Sk::Float, + Sf::Rg32Uint => Sk::Uint, + Sf::Rg32Sint => Sk::Sint, + Sf::Rg32Float => Sk::Float, + Sf::Rgba16Uint => Sk::Uint, + Sf::Rgba16Sint => Sk::Sint, + Sf::Rgba16Float => Sk::Float, + Sf::Rgba32Uint => Sk::Uint, + Sf::Rgba32Sint => Sk::Sint, + Sf::Rgba32Float => Sk::Float, + } + } +} + +impl crate::TypeInner { + pub fn scalar_kind(&self) -> Option<super::ScalarKind> { + match *self { + super::TypeInner::Scalar { kind, .. } | super::TypeInner::Vector { kind, .. } => { + Some(kind) + } + super::TypeInner::Matrix { .. } => Some(super::ScalarKind::Float), + _ => None, + } + } +} diff --git a/third_party/rust/naga/src/proc/namer.rs b/third_party/rust/naga/src/proc/namer.rs new file mode 100644 index 0000000000..03b508904b --- /dev/null +++ b/third_party/rust/naga/src/proc/namer.rs @@ -0,0 +1,113 @@ +use crate::{arena::Handle, FastHashMap}; +use std::collections::hash_map::Entry; + +pub type EntryPointIndex = u16; + +#[derive(Debug, Eq, Hash, PartialEq)] +pub enum NameKey { + GlobalVariable(Handle<crate::GlobalVariable>), + Type(Handle<crate::Type>), + StructMember(Handle<crate::Type>, u32), + Function(Handle<crate::Function>), + FunctionArgument(Handle<crate::Function>, u32), + FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>), + EntryPoint(EntryPointIndex), + EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>), +} + +/// This processor assigns names to all the things in a module +/// that may need identifiers in a textual backend. +pub struct Namer { + unique: FastHashMap<String, u32>, +} + +impl Namer { + fn sanitize(string: &str) -> String { + let mut base = string + .chars() + .skip_while(|c| c.is_numeric()) + .filter(|&c| c.is_ascii_alphanumeric() || c == '_') + .collect::<String>(); + // close the name by '_' if the re is a number, so that + // we can have our own number! + match base.chars().next_back() { + Some(c) if !c.is_numeric() => {} + _ => base.push('_'), + }; + base + } + + fn call(&mut self, label_raw: &str) -> String { + let base = Self::sanitize(label_raw); + match self.unique.entry(base) { + Entry::Occupied(mut e) => { + *e.get_mut() += 1; + format!("{}{}", e.key(), e.get()) + } + Entry::Vacant(e) => { + let name = e.key().to_string(); + e.insert(0); + name + } + } + } + + fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String { + self.call(match *label { + Some(ref name) => name, + None => fallback, + }) + } + + pub fn process( + module: &crate::Module, + reserved: &[&str], + output: &mut FastHashMap<NameKey, String>, + ) { + let mut this = Namer { + unique: reserved + .iter() + .map(|string| (string.to_string(), 0)) + .collect(), + }; + + for (handle, var) in module.global_variables.iter() { + let name = this.call_or(&var.name, "global"); + output.insert(NameKey::GlobalVariable(handle), name); + } + + for (ty_handle, ty) in module.types.iter() { + let ty_name = this.call_or(&ty.name, "type"); + output.insert(NameKey::Type(ty_handle), ty_name); + + if let crate::TypeInner::Struct { ref members } = ty.inner { + for (index, member) in members.iter().enumerate() { + let name = this.call_or(&member.name, "member"); + output.insert(NameKey::StructMember(ty_handle, index as u32), name); + } + } + } + + for (fun_handle, fun) in module.functions.iter() { + let fun_name = this.call_or(&fun.name, "function"); + output.insert(NameKey::Function(fun_handle), fun_name); + for (index, arg) in fun.arguments.iter().enumerate() { + let name = this.call_or(&arg.name, "param"); + output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name); + } + for (handle, var) in fun.local_variables.iter() { + let name = this.call_or(&var.name, "local"); + output.insert(NameKey::FunctionLocal(fun_handle, handle), name); + } + } + + for (ep_index, (&(_, ref base_name), ep)) in module.entry_points.iter().enumerate() { + let ep_name = this.call(base_name); + output.insert(NameKey::EntryPoint(ep_index as _), ep_name); + for (handle, var) in ep.function.local_variables.iter() { + let name = this.call_or(&var.name, "local"); + output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name); + } + } + } +} diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs new file mode 100644 index 0000000000..b09893c179 --- /dev/null +++ b/third_party/rust/naga/src/proc/typifier.rs @@ -0,0 +1,424 @@ +use crate::arena::{Arena, Handle}; + +use thiserror::Error; + +#[derive(Debug, PartialEq)] +enum Resolution { + Handle(Handle<crate::Type>), + Value(crate::TypeInner), +} + +// Clone is only implemented for numeric variants of `TypeInner`. +impl Clone for Resolution { + fn clone(&self) -> Self { + match *self { + Resolution::Handle(handle) => Resolution::Handle(handle), + Resolution::Value(ref v) => Resolution::Value(match *v { + crate::TypeInner::Scalar { kind, width } => { + crate::TypeInner::Scalar { kind, width } + } + crate::TypeInner::Vector { size, kind, width } => { + crate::TypeInner::Vector { size, kind, width } + } + crate::TypeInner::Matrix { + rows, + columns, + width, + } => crate::TypeInner::Matrix { + rows, + columns, + width, + }, + #[allow(clippy::panic)] + _ => panic!("Unepxected clone type: {:?}", v), + }), + } + } +} + +#[derive(Debug)] +pub struct Typifier { + resolutions: Vec<Resolution>, +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ResolveError { + #[error("Invalid index into array")] + InvalidAccessIndex, + #[error("Function {name} not defined")] + FunctionNotDefined { name: String }, + #[error("Function without return type")] + FunctionReturnsVoid, + #[error("Type is not found in the given immutable arena")] + TypeNotFound, + #[error("Incompatible operand: {op} {operand}")] + IncompatibleOperand { op: String, operand: String }, + #[error("Incompatible operands: {left} {op} {right}")] + IncompatibleOperands { + op: String, + left: String, + right: String, + }, +} + +pub struct ResolveContext<'a> { + pub constants: &'a Arena<crate::Constant>, + pub global_vars: &'a Arena<crate::GlobalVariable>, + pub local_vars: &'a Arena<crate::LocalVariable>, + pub functions: &'a Arena<crate::Function>, + pub arguments: &'a [crate::FunctionArgument], +} + +impl Typifier { + pub fn new() -> Self { + Typifier { + resolutions: Vec::new(), + } + } + + pub fn clear(&mut self) { + self.resolutions.clear() + } + + pub fn get<'a>( + &'a self, + expr_handle: Handle<crate::Expression>, + types: &'a Arena<crate::Type>, + ) -> &'a crate::TypeInner { + match self.resolutions[expr_handle.index()] { + Resolution::Handle(ty_handle) => &types[ty_handle].inner, + Resolution::Value(ref inner) => inner, + } + } + + pub fn get_handle( + &self, + expr_handle: Handle<crate::Expression>, + ) -> Option<Handle<crate::Type>> { + match self.resolutions[expr_handle.index()] { + Resolution::Handle(ty_handle) => Some(ty_handle), + Resolution::Value(_) => None, + } + } + + fn resolve_impl( + &self, + expr: &crate::Expression, + types: &Arena<crate::Type>, + ctx: &ResolveContext, + ) -> Result<Resolution, ResolveError> { + Ok(match *expr { + crate::Expression::Access { base, .. } => match *self.get(base, types) { + crate::TypeInner::Array { base, .. } => Resolution::Handle(base), + crate::TypeInner::Vector { + size: _, + kind, + width, + } => Resolution::Value(crate::TypeInner::Scalar { kind, width }), + crate::TypeInner::Matrix { + rows: size, + columns: _, + width, + } => Resolution::Value(crate::TypeInner::Vector { + size, + kind: crate::ScalarKind::Float, + width, + }), + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "access".to_string(), + operand: format!("{:?}", other), + }) + } + }, + crate::Expression::AccessIndex { base, index } => match *self.get(base, types) { + crate::TypeInner::Vector { size, kind, width } => { + if index >= size as u32 { + return Err(ResolveError::InvalidAccessIndex); + } + Resolution::Value(crate::TypeInner::Scalar { kind, width }) + } + crate::TypeInner::Matrix { + columns, + rows, + width, + } => { + if index >= columns as u32 { + return Err(ResolveError::InvalidAccessIndex); + } + Resolution::Value(crate::TypeInner::Vector { + size: rows, + kind: crate::ScalarKind::Float, + width, + }) + } + crate::TypeInner::Array { base, .. } => Resolution::Handle(base), + crate::TypeInner::Struct { ref members } => { + let member = members + .get(index as usize) + .ok_or(ResolveError::InvalidAccessIndex)?; + Resolution::Handle(member.ty) + } + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "access index".to_string(), + operand: format!("{:?}", other), + }) + } + }, + crate::Expression::Constant(h) => Resolution::Handle(ctx.constants[h].ty), + crate::Expression::Compose { ty, .. } => Resolution::Handle(ty), + crate::Expression::FunctionArgument(index) => { + Resolution::Handle(ctx.arguments[index as usize].ty) + } + crate::Expression::GlobalVariable(h) => Resolution::Handle(ctx.global_vars[h].ty), + crate::Expression::LocalVariable(h) => Resolution::Handle(ctx.local_vars[h].ty), + crate::Expression::Load { .. } => unimplemented!(), + crate::Expression::ImageSample { image, .. } + | crate::Expression::ImageLoad { image, .. } => match *self.get(image, types) { + crate::TypeInner::Image { class, .. } => Resolution::Value(match class { + crate::ImageClass::Depth => crate::TypeInner::Scalar { + kind: crate::ScalarKind::Float, + width: 4, + }, + crate::ImageClass::Sampled { kind, multi: _ } => crate::TypeInner::Vector { + kind, + width: 4, + size: crate::VectorSize::Quad, + }, + crate::ImageClass::Storage(format) => crate::TypeInner::Vector { + kind: format.into(), + width: 4, + size: crate::VectorSize::Quad, + }, + }), + _ => unreachable!(), + }, + crate::Expression::Unary { expr, .. } => self.resolutions[expr.index()].clone(), + crate::Expression::Binary { op, left, right } => match op { + crate::BinaryOperator::Add + | crate::BinaryOperator::Subtract + | crate::BinaryOperator::Divide + | crate::BinaryOperator::Modulo => self.resolutions[left.index()].clone(), + crate::BinaryOperator::Multiply => { + let ty_left = self.get(left, types); + let ty_right = self.get(right, types); + if ty_left == ty_right { + self.resolutions[left.index()].clone() + } else if let crate::TypeInner::Scalar { .. } = *ty_right { + self.resolutions[left.index()].clone() + } else { + match *ty_left { + crate::TypeInner::Scalar { .. } => { + self.resolutions[right.index()].clone() + } + crate::TypeInner::Matrix { + columns, + rows: _, + width, + } => Resolution::Value(crate::TypeInner::Vector { + size: columns, + kind: crate::ScalarKind::Float, + width, + }), + _ => { + return Err(ResolveError::IncompatibleOperands { + op: "x".to_string(), + left: format!("{:?}", ty_left), + right: format!("{:?}", ty_right), + }) + } + } + } + } + crate::BinaryOperator::Equal + | crate::BinaryOperator::NotEqual + | crate::BinaryOperator::Less + | crate::BinaryOperator::LessEqual + | crate::BinaryOperator::Greater + | crate::BinaryOperator::GreaterEqual + | crate::BinaryOperator::LogicalAnd + | crate::BinaryOperator::LogicalOr => self.resolutions[left.index()].clone(), + crate::BinaryOperator::And + | crate::BinaryOperator::ExclusiveOr + | crate::BinaryOperator::InclusiveOr + | crate::BinaryOperator::ShiftLeft + | crate::BinaryOperator::ShiftRight => self.resolutions[left.index()].clone(), + }, + crate::Expression::Select { accept, .. } => self.resolutions[accept.index()].clone(), + crate::Expression::Intrinsic { .. } => unimplemented!(), + crate::Expression::Transpose(expr) => match *self.get(expr, types) { + crate::TypeInner::Matrix { + columns, + rows, + width, + } => Resolution::Value(crate::TypeInner::Matrix { + columns: rows, + rows: columns, + width, + }), + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "transpose".to_string(), + operand: format!("{:?}", other), + }) + } + }, + crate::Expression::DotProduct(left_expr, _) => match *self.get(left_expr, types) { + crate::TypeInner::Vector { + kind, + size: _, + width, + } => Resolution::Value(crate::TypeInner::Scalar { kind, width }), + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "dot product".to_string(), + operand: format!("{:?}", other), + }) + } + }, + crate::Expression::CrossProduct(_, _) => unimplemented!(), + crate::Expression::As { + expr, + kind, + convert: _, + } => match *self.get(expr, types) { + crate::TypeInner::Scalar { kind: _, width } => { + Resolution::Value(crate::TypeInner::Scalar { kind, width }) + } + crate::TypeInner::Vector { + kind: _, + size, + width, + } => Resolution::Value(crate::TypeInner::Vector { kind, size, width }), + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: "as".to_string(), + operand: format!("{:?}", other), + }) + } + }, + crate::Expression::Derivative { .. } => unimplemented!(), + crate::Expression::Call { + origin: crate::FunctionOrigin::External(ref name), + ref arguments, + } => match name.as_str() { + "distance" | "length" => match *self.get(arguments[0], types) { + crate::TypeInner::Vector { kind, width, .. } + | crate::TypeInner::Scalar { kind, width } => { + Resolution::Value(crate::TypeInner::Scalar { kind, width }) + } + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: name.clone(), + operand: format!("{:?}", other), + }) + } + }, + "dot" => match *self.get(arguments[0], types) { + crate::TypeInner::Vector { kind, width, .. } => { + Resolution::Value(crate::TypeInner::Scalar { kind, width }) + } + ref other => { + return Err(ResolveError::IncompatibleOperand { + op: name.clone(), + operand: format!("{:?}", other), + }) + } + }, + //Note: `cross` is here too, we still need to figure out what to do with it + "abs" | "atan2" | "cos" | "sin" | "floor" | "inverse" | "normalize" | "min" + | "max" | "reflect" | "pow" | "clamp" | "fclamp" | "mix" | "step" + | "smoothstep" | "cross" => self.resolutions[arguments[0].index()].clone(), + _ => return Err(ResolveError::FunctionNotDefined { name: name.clone() }), + }, + crate::Expression::Call { + origin: crate::FunctionOrigin::Local(handle), + arguments: _, + } => { + let ty = ctx.functions[handle] + .return_type + .ok_or(ResolveError::FunctionReturnsVoid)?; + Resolution::Handle(ty) + } + crate::Expression::ArrayLength(_) => Resolution::Value(crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }), + }) + } + + pub fn grow( + &mut self, + expr_handle: Handle<crate::Expression>, + expressions: &Arena<crate::Expression>, + types: &mut Arena<crate::Type>, + ctx: &ResolveContext, + ) -> Result<(), ResolveError> { + if self.resolutions.len() <= expr_handle.index() { + for (eh, expr) in expressions.iter().skip(self.resolutions.len()) { + let resolution = self.resolve_impl(expr, types, ctx)?; + log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution); + + let ty_handle = match resolution { + Resolution::Handle(h) => h, + Resolution::Value(inner) => types + .fetch_if_or_append(crate::Type { name: None, inner }, |a, b| { + a.inner == b.inner + }), + }; + self.resolutions.push(Resolution::Handle(ty_handle)); + } + } + Ok(()) + } + + pub fn resolve_all( + &mut self, + expressions: &Arena<crate::Expression>, + types: &Arena<crate::Type>, + ctx: &ResolveContext, + ) -> Result<(), ResolveError> { + self.clear(); + for (_, expr) in expressions.iter() { + let resolution = self.resolve_impl(expr, types, ctx)?; + self.resolutions.push(resolution); + } + Ok(()) + } +} + +pub fn check_constant_type(inner: &crate::ConstantInner, type_inner: &crate::TypeInner) -> bool { + match (inner, type_inner) { + ( + crate::ConstantInner::Sint(_), + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + }, + ) => true, + ( + crate::ConstantInner::Uint(_), + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + }, + ) => true, + ( + crate::ConstantInner::Float(_), + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Float, + width: _, + }, + ) => true, + ( + crate::ConstantInner::Bool(_), + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: _, + }, + ) => true, + (crate::ConstantInner::Composite(_inner), _) => true, // TODO recursively check composite types + (_, _) => false, + } +} diff --git a/third_party/rust/naga/src/proc/validator.rs b/third_party/rust/naga/src/proc/validator.rs new file mode 100644 index 0000000000..d9d3eac659 --- /dev/null +++ b/third_party/rust/naga/src/proc/validator.rs @@ -0,0 +1,489 @@ +use super::typifier::{ResolveContext, ResolveError, Typifier}; +use crate::arena::{Arena, Handle}; + +const MAX_BIND_GROUPS: u32 = 8; +const MAX_LOCATIONS: u32 = 64; // using u64 mask +const MAX_BIND_INDICES: u32 = 64; // using u64 mask +const MAX_WORKGROUP_SIZE: u32 = 0x4000; + +#[derive(Debug)] +pub struct Validator { + //Note: this is a bit tricky: some of the front-ends as well as backends + // already have to use the typifier, so the work here is redundant in a way. + typifier: Typifier, +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum GlobalVariableError { + #[error("Usage isn't compatible with the storage class")] + InvalidUsage, + #[error("Type isn't compatible with the storage class")] + InvalidType, + #[error("Interpolation is not valid")] + InvalidInterpolation, + #[error("Storage access {seen:?} exceed the allowed {allowed:?}")] + InvalidStorageAccess { + allowed: crate::StorageAccess, + seen: crate::StorageAccess, + }, + #[error("Binding decoration is missing or not applicable")] + InvalidBinding, + #[error("Binding is out of range")] + OutOfRangeBinding, +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum LocalVariableError { + #[error("Initializer is not a constant expression")] + InitializerConst, + #[error("Initializer doesn't match the variable type")] + InitializerType, +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum FunctionError { + #[error(transparent)] + Resolve(#[from] ResolveError), + #[error("There are instructions after `return`/`break`/`continue`")] + InvalidControlFlowExitTail, + #[error("Local variable {handle:?} '{name}' is invalid: {error:?}")] + LocalVariable { + handle: Handle<crate::LocalVariable>, + name: String, + error: LocalVariableError, + }, +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum EntryPointError { + #[error("Early depth test is not applicable")] + UnexpectedEarlyDepthTest, + #[error("Workgroup size is not applicable")] + UnexpectedWorkgroupSize, + #[error("Workgroup size is out of range")] + OutOfRangeWorkgroupSize, + #[error("Global variable {0:?} is used incorrectly as {1:?}")] + InvalidGlobalUsage(Handle<crate::GlobalVariable>, crate::GlobalUse), + #[error("Bindings for {0:?} conflict with other global variables")] + BindingCollision(Handle<crate::GlobalVariable>), + #[error("Built-in {0:?} is not applicable to this entry point")] + InvalidBuiltIn(crate::BuiltIn), + #[error("Interpolation of an integer has to be flat")] + InvalidIntegerInterpolation, + #[error(transparent)] + Function(#[from] FunctionError), +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum ValidationError { + #[error("The type {0:?} width {1} is not supported")] + InvalidTypeWidth(crate::ScalarKind, crate::Bytes), + #[error("The type handle {0:?} can not be resolved")] + UnresolvedType(Handle<crate::Type>), + #[error("The constant {0:?} can not be used for an array size")] + InvalidArraySizeConstant(Handle<crate::Constant>), + #[error("Global variable {handle:?} '{name}' is invalid: {error:?}")] + GlobalVariable { + handle: Handle<crate::GlobalVariable>, + name: String, + error: GlobalVariableError, + }, + #[error("Function {0:?} is invalid: {1:?}")] + Function(Handle<crate::Function>, FunctionError), + #[error("Entry point {name} at {stage:?} is invalid: {error:?}")] + EntryPoint { + stage: crate::ShaderStage, + name: String, + error: EntryPointError, + }, + #[error("Module is corrupted")] + Corrupted, +} + +impl crate::GlobalVariable { + fn forbid_interpolation(&self) -> Result<(), GlobalVariableError> { + match self.interpolation { + Some(_) => Err(GlobalVariableError::InvalidInterpolation), + None => Ok(()), + } + } + + fn check_resource(&self) -> Result<(), GlobalVariableError> { + match self.binding { + Some(crate::Binding::BuiltIn(_)) => {} // validated per entry point + Some(crate::Binding::Resource { group, binding }) => { + if group > MAX_BIND_GROUPS || binding > MAX_BIND_INDICES { + return Err(GlobalVariableError::OutOfRangeBinding); + } + } + Some(crate::Binding::Location(_)) | None => { + return Err(GlobalVariableError::InvalidBinding) + } + } + self.forbid_interpolation() + } +} + +fn storage_usage(access: crate::StorageAccess) -> crate::GlobalUse { + let mut storage_usage = crate::GlobalUse::empty(); + if access.contains(crate::StorageAccess::LOAD) { + storage_usage |= crate::GlobalUse::LOAD; + } + if access.contains(crate::StorageAccess::STORE) { + storage_usage |= crate::GlobalUse::STORE; + } + storage_usage +} + +impl Validator { + /// Construct a new validator instance. + pub fn new() -> Self { + Validator { + typifier: Typifier::new(), + } + } + + fn validate_global_var( + &self, + var: &crate::GlobalVariable, + types: &Arena<crate::Type>, + ) -> Result<(), GlobalVariableError> { + log::debug!("var {:?}", var); + let allowed_storage_access = match var.class { + crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage), + crate::StorageClass::Input | crate::StorageClass::Output => { + match var.binding { + Some(crate::Binding::BuiltIn(_)) => { + // validated per entry point + var.forbid_interpolation()? + } + Some(crate::Binding::Location(loc)) => { + if loc > MAX_LOCATIONS { + return Err(GlobalVariableError::OutOfRangeBinding); + } + match types[var.ty].inner { + crate::TypeInner::Scalar { .. } + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } => {} + _ => return Err(GlobalVariableError::InvalidType), + } + } + Some(crate::Binding::Resource { .. }) => { + return Err(GlobalVariableError::InvalidBinding) + } + None => { + match types[var.ty].inner { + //TODO: check the member types + crate::TypeInner::Struct { members: _ } => { + var.forbid_interpolation()? + } + _ => return Err(GlobalVariableError::InvalidType), + } + } + } + crate::StorageAccess::empty() + } + crate::StorageClass::Storage => { + var.check_resource()?; + crate::StorageAccess::all() + } + crate::StorageClass::Uniform => { + var.check_resource()?; + crate::StorageAccess::empty() + } + crate::StorageClass::Handle => { + var.check_resource()?; + match types[var.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage(_), + .. + } => crate::StorageAccess::all(), + _ => crate::StorageAccess::empty(), + } + } + crate::StorageClass::Private | crate::StorageClass::WorkGroup => { + if var.binding.is_some() { + return Err(GlobalVariableError::InvalidBinding); + } + var.forbid_interpolation()?; + crate::StorageAccess::empty() + } + crate::StorageClass::PushConstant => { + //TODO + return Err(GlobalVariableError::InvalidStorageAccess { + allowed: crate::StorageAccess::empty(), + seen: crate::StorageAccess::empty(), + }); + } + }; + + if !allowed_storage_access.contains(var.storage_access) { + return Err(GlobalVariableError::InvalidStorageAccess { + allowed: allowed_storage_access, + seen: var.storage_access, + }); + } + + Ok(()) + } + + fn validate_local_var( + &self, + var: &crate::LocalVariable, + _fun: &crate::Function, + _types: &Arena<crate::Type>, + ) -> Result<(), LocalVariableError> { + log::debug!("var {:?}", var); + if let Some(_expr_handle) = var.init { + if false { + return Err(LocalVariableError::InitializerConst); + } + } + Ok(()) + } + + fn validate_function( + &mut self, + fun: &crate::Function, + module: &crate::Module, + ) -> Result<(), FunctionError> { + let resolve_ctx = ResolveContext { + constants: &module.constants, + global_vars: &module.global_variables, + local_vars: &fun.local_variables, + functions: &module.functions, + arguments: &fun.arguments, + }; + self.typifier + .resolve_all(&fun.expressions, &module.types, &resolve_ctx)?; + + for (var_handle, var) in fun.local_variables.iter() { + self.validate_local_var(var, fun, &module.types) + .map_err(|error| FunctionError::LocalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + error, + })?; + } + Ok(()) + } + + fn validate_entry_point( + &mut self, + ep: &crate::EntryPoint, + stage: crate::ShaderStage, + module: &crate::Module, + ) -> Result<(), EntryPointError> { + if ep.early_depth_test.is_some() && stage != crate::ShaderStage::Fragment { + return Err(EntryPointError::UnexpectedEarlyDepthTest); + } + if stage == crate::ShaderStage::Compute { + if ep + .workgroup_size + .iter() + .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE) + { + return Err(EntryPointError::OutOfRangeWorkgroupSize); + } + } else if ep.workgroup_size != [0; 3] { + return Err(EntryPointError::UnexpectedWorkgroupSize); + } + + let mut bind_group_masks = [0u64; MAX_BIND_GROUPS as usize]; + let mut location_in_mask = 0u64; + let mut location_out_mask = 0u64; + for ((var_handle, var), &usage) in module + .global_variables + .iter() + .zip(&ep.function.global_usage) + { + if usage.is_empty() { + continue; + } + + if let Some(crate::Binding::Location(_)) = var.binding { + match (stage, var.class) { + (crate::ShaderStage::Vertex, crate::StorageClass::Output) + | (crate::ShaderStage::Fragment, crate::StorageClass::Input) => { + match module.types[var.ty].inner.scalar_kind() { + Some(crate::ScalarKind::Float) => {} + Some(_) if var.interpolation != Some(crate::Interpolation::Flat) => { + return Err(EntryPointError::InvalidIntegerInterpolation); + } + _ => {} + } + } + _ => {} + } + } + + let allowed_usage = match var.class { + crate::StorageClass::Function => unreachable!(), + crate::StorageClass::Input => { + let mask = match var.binding { + Some(crate::Binding::BuiltIn(built_in)) => match (stage, built_in) { + (crate::ShaderStage::Vertex, crate::BuiltIn::BaseInstance) + | (crate::ShaderStage::Vertex, crate::BuiltIn::BaseVertex) + | (crate::ShaderStage::Vertex, crate::BuiltIn::InstanceIndex) + | (crate::ShaderStage::Vertex, crate::BuiltIn::VertexIndex) + | (crate::ShaderStage::Fragment, crate::BuiltIn::PointSize) + | (crate::ShaderStage::Fragment, crate::BuiltIn::FragCoord) + | (crate::ShaderStage::Fragment, crate::BuiltIn::FrontFacing) + | (crate::ShaderStage::Fragment, crate::BuiltIn::SampleIndex) + | (crate::ShaderStage::Compute, crate::BuiltIn::GlobalInvocationId) + | (crate::ShaderStage::Compute, crate::BuiltIn::LocalInvocationId) + | (crate::ShaderStage::Compute, crate::BuiltIn::LocalInvocationIndex) + | (crate::ShaderStage::Compute, crate::BuiltIn::WorkGroupId) => 0, + _ => return Err(EntryPointError::InvalidBuiltIn(built_in)), + }, + Some(crate::Binding::Location(loc)) => 1 << loc, + Some(crate::Binding::Resource { .. }) => unreachable!(), + None => 0, + }; + if location_in_mask & mask != 0 { + return Err(EntryPointError::BindingCollision(var_handle)); + } + location_in_mask |= mask; + crate::GlobalUse::LOAD + } + crate::StorageClass::Output => { + let mask = match var.binding { + Some(crate::Binding::BuiltIn(built_in)) => match (stage, built_in) { + (crate::ShaderStage::Vertex, crate::BuiltIn::Position) + | (crate::ShaderStage::Vertex, crate::BuiltIn::PointSize) + | (crate::ShaderStage::Vertex, crate::BuiltIn::ClipDistance) + | (crate::ShaderStage::Fragment, crate::BuiltIn::FragDepth) => 0, + _ => return Err(EntryPointError::InvalidBuiltIn(built_in)), + }, + Some(crate::Binding::Location(loc)) => 1 << loc, + Some(crate::Binding::Resource { .. }) => unreachable!(), + None => 0, + }; + if location_out_mask & mask != 0 { + return Err(EntryPointError::BindingCollision(var_handle)); + } + location_out_mask |= mask; + crate::GlobalUse::LOAD | crate::GlobalUse::STORE + } + crate::StorageClass::Uniform => crate::GlobalUse::LOAD, + crate::StorageClass::Storage => storage_usage(var.storage_access), + crate::StorageClass::Handle => match module.types[var.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage(_), + .. + } => storage_usage(var.storage_access), + _ => crate::GlobalUse::LOAD, + }, + crate::StorageClass::Private | crate::StorageClass::WorkGroup => { + crate::GlobalUse::all() + } + crate::StorageClass::PushConstant => crate::GlobalUse::LOAD, + }; + if !allowed_usage.contains(usage) { + log::warn!("\tUsage error for: {:?}", var); + log::warn!( + "\tAllowed usage: {:?}, requested: {:?}", + allowed_usage, + usage + ); + return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage)); + } + + if let Some(crate::Binding::Resource { group, binding }) = var.binding { + let mask = 1 << binding; + let group_mask = &mut bind_group_masks[group as usize]; + if *group_mask & mask != 0 { + return Err(EntryPointError::BindingCollision(var_handle)); + } + *group_mask |= mask; + } + } + + self.validate_function(&ep.function, module)?; + Ok(()) + } + + /// Check the given module to be valid. + pub fn validate(&mut self, module: &crate::Module) -> Result<(), ValidationError> { + // check the types + for (handle, ty) in module.types.iter() { + use crate::TypeInner as Ti; + match ty.inner { + Ti::Scalar { kind, width } | Ti::Vector { kind, width, .. } => { + let expected = match kind { + crate::ScalarKind::Bool => 1, + _ => 4, + }; + if width != expected { + return Err(ValidationError::InvalidTypeWidth(kind, width)); + } + } + Ti::Matrix { width, .. } => { + if width != 4 { + return Err(ValidationError::InvalidTypeWidth( + crate::ScalarKind::Float, + width, + )); + } + } + Ti::Pointer { base, class: _ } => { + if base >= handle { + return Err(ValidationError::UnresolvedType(base)); + } + } + Ti::Array { base, size, .. } => { + if base >= handle { + return Err(ValidationError::UnresolvedType(base)); + } + if let crate::ArraySize::Constant(const_handle) = size { + let constant = module + .constants + .try_get(const_handle) + .ok_or(ValidationError::Corrupted)?; + match constant.inner { + crate::ConstantInner::Uint(_) => {} + _ => { + return Err(ValidationError::InvalidArraySizeConstant(const_handle)) + } + } + } + } + Ti::Struct { ref members } => { + //TODO: check that offsets are not intersecting? + for member in members { + if member.ty >= handle { + return Err(ValidationError::UnresolvedType(member.ty)); + } + } + } + Ti::Image { .. } => {} + Ti::Sampler { comparison: _ } => {} + } + } + + for (var_handle, var) in module.global_variables.iter() { + self.validate_global_var(var, &module.types) + .map_err(|error| ValidationError::GlobalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + error, + })?; + } + + for (fun_handle, fun) in module.functions.iter() { + self.validate_function(fun, module) + .map_err(|e| ValidationError::Function(fun_handle, e))?; + } + + for (&(stage, ref name), entry_point) in module.entry_points.iter() { + self.validate_entry_point(entry_point, stage, module) + .map_err(|error| ValidationError::EntryPoint { + stage, + name: name.to_string(), + error, + })?; + } + + Ok(()) + } +} |