diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/naga/src/back | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/back')
28 files changed, 29499 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/back/dot/mod.rs b/third_party/rust/naga/src/back/dot/mod.rs new file mode 100644 index 0000000000..1556371df1 --- /dev/null +++ b/third_party/rust/naga/src/back/dot/mod.rs @@ -0,0 +1,703 @@ +/*! +Backend for [DOT][dot] (Graphviz). + +This backend writes a graph in the DOT language, for the ease +of IR inspection and debugging. + +[dot]: https://graphviz.org/doc/info/lang.html +*/ + +use crate::{ + arena::Handle, + valid::{FunctionInfo, ModuleInfo}, +}; + +use std::{ + borrow::Cow, + fmt::{Error as FmtError, Write as _}, +}; + +/// Configuration options for the dot backend +#[derive(Clone, Default)] +pub struct Options { + /// Only emit function bodies + pub cfg_only: bool, +} + +/// Identifier used to address a graph node +type NodeId = usize; + +/// Stores the target nodes for control flow statements +#[derive(Default, Clone, Copy)] +struct Targets { + /// The node, if some, where continue operations will land + continue_target: Option<usize>, + /// The node, if some, where break operations will land + break_target: Option<usize>, +} + +/// Stores information about the graph of statements +#[derive(Default)] +struct StatementGraph { + /// List of node names + nodes: Vec<&'static str>, + /// List of edges of the control flow, the items are defined as + /// (from, to, label) + flow: Vec<(NodeId, NodeId, &'static str)>, + /// List of implicit edges of the control flow, used for jump + /// operations such as continue or break, the items are defined as + /// (from, to, label, color_id) + jumps: Vec<(NodeId, NodeId, &'static str, usize)>, + /// List of dependency relationships between a statement node and + /// expressions + dependencies: Vec<(NodeId, Handle<crate::Expression>, &'static str)>, + /// List of expression emitted by statement node + emits: Vec<(NodeId, Handle<crate::Expression>)>, + /// List of function call by statement node + calls: Vec<(NodeId, Handle<crate::Function>)>, +} + +impl StatementGraph { + /// Adds a new block to the statement graph, returning the first and last node, respectively + fn add(&mut self, block: &[crate::Statement], targets: Targets) -> (NodeId, NodeId) { + use crate::Statement as S; + + // The first node of the block isn't a statement but a virtual node + let root = self.nodes.len(); + self.nodes.push(if root == 0 { "Root" } else { "Node" }); + // Track the last placed node, this will be returned to the caller and + // will also be used to generate the control flow edges + let mut last_node = root; + for statement in block { + // Reserve a new node for the current statement and link it to the + // node of the previous statement + let id = self.nodes.len(); + self.flow.push((last_node, id, "")); + self.nodes.push(""); // reserve space + + // Track the node identifier for the merge node, the merge node is + // the last node of a statement, normally this is the node itself, + // but for control flow statements such as `if`s and `switch`s this + // is a virtual node where all branches merge back. + let mut merge_id = id; + + self.nodes[id] = match *statement { + S::Emit(ref range) => { + for handle in range.clone() { + self.emits.push((id, handle)); + } + "Emit" + } + S::Kill => "Kill", //TODO: link to the beginning + S::Break => { + // Try to link to the break target, otherwise produce + // a broken connection + if let Some(target) = targets.break_target { + self.jumps.push((id, target, "Break", 5)) + } else { + self.jumps.push((id, root, "Broken", 7)) + } + "Break" + } + S::Continue => { + // Try to link to the continue target, otherwise produce + // a broken connection + if let Some(target) = targets.continue_target { + self.jumps.push((id, target, "Continue", 5)) + } else { + self.jumps.push((id, root, "Broken", 7)) + } + "Continue" + } + S::Barrier(_flags) => "Barrier", + S::Block(ref b) => { + let (other, last) = self.add(b, targets); + self.flow.push((id, other, "")); + // All following nodes should connect to the end of the block + // statement so change the merge id to it. + merge_id = last; + "Block" + } + S::If { + condition, + ref accept, + ref reject, + } => { + self.dependencies.push((id, condition, "condition")); + let (accept_id, accept_last) = self.add(accept, targets); + self.flow.push((id, accept_id, "accept")); + let (reject_id, reject_last) = self.add(reject, targets); + self.flow.push((id, reject_id, "reject")); + + // Create a merge node, link the branches to it and set it + // as the merge node to make the next statement node link to it + merge_id = self.nodes.len(); + self.nodes.push("Merge"); + self.flow.push((accept_last, merge_id, "")); + self.flow.push((reject_last, merge_id, "")); + + "If" + } + S::Switch { + selector, + ref cases, + } => { + self.dependencies.push((id, selector, "selector")); + + // Create a merge node and set it as the merge node to make + // the next statement node link to it + merge_id = self.nodes.len(); + self.nodes.push("Merge"); + + // Create a new targets structure and set the break target + // to the merge node + let mut targets = targets; + targets.break_target = Some(merge_id); + + for case in cases { + let (case_id, case_last) = self.add(&case.body, targets); + let label = match case.value { + crate::SwitchValue::Default => "default", + _ => "case", + }; + self.flow.push((id, case_id, label)); + // Link the last node of the branch to the merge node + self.flow.push((case_last, merge_id, "")); + } + "Switch" + } + S::Loop { + ref body, + ref continuing, + break_if, + } => { + // Create a new targets structure and set the break target + // to the merge node, this must happen before generating the + // continuing block since it can break. + let mut targets = targets; + targets.break_target = Some(id); + + let (continuing_id, continuing_last) = self.add(continuing, targets); + + // Set the the continue target to the beginning + // of the newly generated continuing block + targets.continue_target = Some(continuing_id); + + let (body_id, body_last) = self.add(body, targets); + + self.flow.push((id, body_id, "body")); + + // Link the last node of the body to the continuing block + self.flow.push((body_last, continuing_id, "continuing")); + // Link the last node of the continuing block back to the + // beginning of the loop body + self.flow.push((continuing_last, body_id, "continuing")); + + if let Some(expr) = break_if { + self.dependencies.push((continuing_id, expr, "break if")); + } + + "Loop" + } + S::Return { value } => { + if let Some(expr) = value { + self.dependencies.push((id, expr, "value")); + } + "Return" + } + S::Store { pointer, value } => { + self.dependencies.push((id, value, "value")); + self.emits.push((id, pointer)); + "Store" + } + S::ImageStore { + image, + coordinate, + array_index, + value, + } => { + self.dependencies.push((id, image, "image")); + self.dependencies.push((id, coordinate, "coordinate")); + if let Some(expr) = array_index { + self.dependencies.push((id, expr, "array_index")); + } + self.dependencies.push((id, value, "value")); + "ImageStore" + } + S::Call { + function, + ref arguments, + result, + } => { + for &arg in arguments { + self.dependencies.push((id, arg, "arg")); + } + if let Some(expr) = result { + self.emits.push((id, expr)); + } + self.calls.push((id, function)); + "Call" + } + S::Atomic { + pointer, + ref fun, + value, + result, + } => { + self.emits.push((id, result)); + self.dependencies.push((id, pointer, "pointer")); + self.dependencies.push((id, value, "value")); + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + self.dependencies.push((id, cmp, "cmp")); + } + "Atomic" + } + S::WorkGroupUniformLoad { pointer, result } => { + self.emits.push((id, result)); + self.dependencies.push((id, pointer, "pointer")); + "WorkGroupUniformLoad" + } + S::RayQuery { query, ref fun } => { + self.dependencies.push((id, query, "query")); + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + self.dependencies.push(( + id, + acceleration_structure, + "acceleration_structure", + )); + self.dependencies.push((id, descriptor, "descriptor")); + "RayQueryInitialize" + } + crate::RayQueryFunction::Proceed { result } => { + self.emits.push((id, result)); + "RayQueryProceed" + } + crate::RayQueryFunction::Terminate => "RayQueryTerminate", + } + } + }; + // Set the last node to the merge node + last_node = merge_id; + } + (root, last_node) + } +} + +#[allow(clippy::manual_unwrap_or)] +fn name(option: &Option<String>) -> &str { + match *option { + Some(ref name) => name, + None => "", + } +} + +/// set39 color scheme from <https://graphviz.org/doc/info/colors.html> +const COLORS: &[&str] = &[ + "white", // pattern starts at 1 + "#8dd3c7", "#ffffb3", "#bebada", "#fb8072", "#80b1d3", "#fdb462", "#b3de69", "#fccde5", + "#d9d9d9", +]; + +fn write_fun( + output: &mut String, + prefix: String, + fun: &crate::Function, + info: Option<&FunctionInfo>, + options: &Options, +) -> Result<(), FmtError> { + writeln!(output, "\t\tnode [ style=filled ]")?; + + if !options.cfg_only { + for (handle, var) in fun.local_variables.iter() { + writeln!( + output, + "\t\t{}_l{} [ shape=hexagon label=\"{:?} '{}'\" ]", + prefix, + handle.index(), + handle, + name(&var.name), + )?; + } + + write_function_expressions(output, &prefix, fun, info)?; + } + + let mut sg = StatementGraph::default(); + sg.add(&fun.body, Targets::default()); + for (index, label) in sg.nodes.into_iter().enumerate() { + writeln!( + output, + "\t\t{prefix}_s{index} [ shape=square label=\"{label}\" ]", + )?; + } + for (from, to, label) in sg.flow { + writeln!( + output, + "\t\t{prefix}_s{from} -> {prefix}_s{to} [ arrowhead=tee label=\"{label}\" ]", + )?; + } + for (from, to, label, color_id) in sg.jumps { + writeln!( + output, + "\t\t{}_s{} -> {}_s{} [ arrowhead=tee style=dashed color=\"{}\" label=\"{}\" ]", + prefix, from, prefix, to, COLORS[color_id], label, + )?; + } + + if !options.cfg_only { + for (to, expr, label) in sg.dependencies { + writeln!( + output, + "\t\t{}_e{} -> {}_s{} [ label=\"{}\" ]", + prefix, + expr.index(), + prefix, + to, + label, + )?; + } + for (from, to) in sg.emits { + writeln!( + output, + "\t\t{}_s{} -> {}_e{} [ style=dotted ]", + prefix, + from, + prefix, + to.index(), + )?; + } + } + + for (from, function) in sg.calls { + writeln!( + output, + "\t\t{}_s{} -> f{}_s0", + prefix, + from, + function.index(), + )?; + } + + Ok(()) +} + +fn write_function_expressions( + output: &mut String, + prefix: &str, + fun: &crate::Function, + info: Option<&FunctionInfo>, +) -> Result<(), FmtError> { + enum Payload<'a> { + Arguments(&'a [Handle<crate::Expression>]), + Local(Handle<crate::LocalVariable>), + Global(Handle<crate::GlobalVariable>), + } + + let mut edges = crate::FastHashMap::<&str, _>::default(); + let mut payload = None; + for (handle, expression) in fun.expressions.iter() { + use crate::Expression as E; + let (label, color_id) = match *expression { + E::Literal(_) => ("Literal".into(), 2), + E::Constant(_) => ("Constant".into(), 2), + E::ZeroValue(_) => ("ZeroValue".into(), 2), + E::Compose { ref components, .. } => { + payload = Some(Payload::Arguments(components)); + ("Compose".into(), 3) + } + E::Access { base, index } => { + edges.insert("base", base); + edges.insert("index", index); + ("Access".into(), 1) + } + E::AccessIndex { base, index } => { + edges.insert("base", base); + (format!("AccessIndex[{index}]").into(), 1) + } + E::Splat { size, value } => { + edges.insert("value", value); + (format!("Splat{size:?}").into(), 3) + } + E::Swizzle { + size, + vector, + pattern, + } => { + edges.insert("vector", vector); + (format!("Swizzle{:?}", &pattern[..size as usize]).into(), 3) + } + E::FunctionArgument(index) => (format!("Argument[{index}]").into(), 1), + E::GlobalVariable(h) => { + payload = Some(Payload::Global(h)); + ("Global".into(), 2) + } + E::LocalVariable(h) => { + payload = Some(Payload::Local(h)); + ("Local".into(), 1) + } + E::Load { pointer } => { + edges.insert("pointer", pointer); + ("Load".into(), 4) + } + E::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset: _, + level, + depth_ref, + } => { + edges.insert("image", image); + edges.insert("sampler", sampler); + edges.insert("coordinate", coordinate); + if let Some(expr) = array_index { + edges.insert("array_index", expr); + } + match level { + crate::SampleLevel::Auto => {} + crate::SampleLevel::Zero => {} + crate::SampleLevel::Exact(expr) => { + edges.insert("level", expr); + } + crate::SampleLevel::Bias(expr) => { + edges.insert("bias", expr); + } + crate::SampleLevel::Gradient { x, y } => { + edges.insert("grad_x", x); + edges.insert("grad_y", y); + } + } + if let Some(expr) = depth_ref { + edges.insert("depth_ref", expr); + } + let string = match gather { + Some(component) => Cow::Owned(format!("ImageGather{component:?}")), + _ => Cow::Borrowed("ImageSample"), + }; + (string, 5) + } + E::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + edges.insert("image", image); + edges.insert("coordinate", coordinate); + if let Some(expr) = array_index { + edges.insert("array_index", expr); + } + if let Some(sample) = sample { + edges.insert("sample", sample); + } + if let Some(level) = level { + edges.insert("level", level); + } + ("ImageLoad".into(), 5) + } + E::ImageQuery { image, query } => { + edges.insert("image", image); + let args = match query { + crate::ImageQuery::Size { level } => { + if let Some(expr) = level { + edges.insert("level", expr); + } + Cow::from("ImageSize") + } + _ => Cow::Owned(format!("{query:?}")), + }; + (args, 7) + } + E::Unary { op, expr } => { + edges.insert("expr", expr); + (format!("{op:?}").into(), 6) + } + E::Binary { op, left, right } => { + edges.insert("left", left); + edges.insert("right", right); + (format!("{op:?}").into(), 6) + } + E::Select { + condition, + accept, + reject, + } => { + edges.insert("condition", condition); + edges.insert("accept", accept); + edges.insert("reject", reject); + ("Select".into(), 3) + } + E::Derivative { axis, ctrl, expr } => { + edges.insert("", expr); + (format!("d{axis:?}{ctrl:?}").into(), 8) + } + E::Relational { fun, argument } => { + edges.insert("arg", argument); + (format!("{fun:?}").into(), 6) + } + E::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + edges.insert("arg", arg); + if let Some(expr) = arg1 { + edges.insert("arg1", expr); + } + if let Some(expr) = arg2 { + edges.insert("arg2", expr); + } + if let Some(expr) = arg3 { + edges.insert("arg3", expr); + } + (format!("{fun:?}").into(), 7) + } + E::As { + kind, + expr, + convert, + } => { + edges.insert("", expr); + let string = match convert { + Some(width) => format!("Convert<{kind:?},{width}>"), + None => format!("Bitcast<{kind:?}>"), + }; + (string.into(), 3) + } + E::CallResult(_function) => ("CallResult".into(), 4), + E::AtomicResult { .. } => ("AtomicResult".into(), 4), + E::WorkGroupUniformLoadResult { .. } => ("WorkGroupUniformLoadResult".into(), 4), + E::ArrayLength(expr) => { + edges.insert("", expr); + ("ArrayLength".into(), 7) + } + E::RayQueryProceedResult => ("rayQueryProceedResult".into(), 4), + E::RayQueryGetIntersection { query, committed } => { + edges.insert("", query); + let ty = if committed { "Committed" } else { "Candidate" }; + (format!("rayQueryGet{}Intersection", ty).into(), 4) + } + }; + + // give uniform expressions an outline + let color_attr = match info { + Some(info) if info[handle].uniformity.non_uniform_result.is_none() => "fillcolor", + _ => "color", + }; + writeln!( + output, + "\t\t{}_e{} [ {}=\"{}\" label=\"{:?} {}\" ]", + prefix, + handle.index(), + color_attr, + COLORS[color_id], + handle, + label, + )?; + + for (key, edge) in edges.drain() { + writeln!( + output, + "\t\t{}_e{} -> {}_e{} [ label=\"{}\" ]", + prefix, + edge.index(), + prefix, + handle.index(), + key, + )?; + } + match payload.take() { + Some(Payload::Arguments(list)) => { + write!(output, "\t\t{{")?; + for &comp in list { + write!(output, " {}_e{}", prefix, comp.index())?; + } + writeln!(output, " }} -> {}_e{}", prefix, handle.index())?; + } + Some(Payload::Local(h)) => { + writeln!( + output, + "\t\t{}_l{} -> {}_e{}", + prefix, + h.index(), + prefix, + handle.index(), + )?; + } + Some(Payload::Global(h)) => { + writeln!( + output, + "\t\tg{} -> {}_e{} [fillcolor=gray]", + h.index(), + prefix, + handle.index(), + )?; + } + None => {} + } + } + + Ok(()) +} + +/// Write shader module to a [`String`]. +pub fn write( + module: &crate::Module, + mod_info: Option<&ModuleInfo>, + options: Options, +) -> Result<String, FmtError> { + use std::fmt::Write as _; + + let mut output = String::new(); + output += "digraph Module {\n"; + + if !options.cfg_only { + writeln!(output, "\tsubgraph cluster_globals {{")?; + writeln!(output, "\t\tlabel=\"Globals\"")?; + for (handle, var) in module.global_variables.iter() { + writeln!( + output, + "\t\tg{} [ shape=hexagon label=\"{:?} {:?}/'{}'\" ]", + handle.index(), + handle, + var.space, + name(&var.name), + )?; + } + writeln!(output, "\t}}")?; + } + + for (handle, fun) in module.functions.iter() { + let prefix = format!("f{}", handle.index()); + writeln!(output, "\tsubgraph cluster_{prefix} {{")?; + writeln!( + output, + "\t\tlabel=\"Function{:?}/'{}'\"", + handle, + name(&fun.name) + )?; + let info = mod_info.map(|a| &a[handle]); + write_fun(&mut output, prefix, fun, info, &options)?; + writeln!(output, "\t}}")?; + } + for (ep_index, ep) in module.entry_points.iter().enumerate() { + let prefix = format!("ep{ep_index}"); + writeln!(output, "\tsubgraph cluster_{prefix} {{")?; + writeln!(output, "\t\tlabel=\"{:?}/'{}'\"", ep.stage, ep.name)?; + let info = mod_info.map(|a| a.get_entry_point(ep_index)); + write_fun(&mut output, prefix, &ep.function, info, &options)?; + writeln!(output, "\t}}")?; + } + + output += "}\n"; + Ok(output) +} diff --git a/third_party/rust/naga/src/back/glsl/features.rs b/third_party/rust/naga/src/back/glsl/features.rs new file mode 100644 index 0000000000..e7de05f695 --- /dev/null +++ b/third_party/rust/naga/src/back/glsl/features.rs @@ -0,0 +1,536 @@ +use super::{BackendResult, Error, Version, Writer}; +use crate::{ + back::glsl::{Options, WriterFlags}, + AddressSpace, Binding, Expression, Handle, ImageClass, ImageDimension, Interpolation, Sampling, + Scalar, ScalarKind, ShaderStage, StorageFormat, Type, TypeInner, +}; +use std::fmt::Write; + +bitflags::bitflags! { + /// Structure used to encode additions to GLSL that aren't supported by all versions. + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct Features: u32 { + /// Buffer address space support. + const BUFFER_STORAGE = 1; + const ARRAY_OF_ARRAYS = 1 << 1; + /// 8 byte floats. + const DOUBLE_TYPE = 1 << 2; + /// More image formats. + const FULL_IMAGE_FORMATS = 1 << 3; + const MULTISAMPLED_TEXTURES = 1 << 4; + const MULTISAMPLED_TEXTURE_ARRAYS = 1 << 5; + const CUBE_TEXTURES_ARRAY = 1 << 6; + const COMPUTE_SHADER = 1 << 7; + /// Image load and early depth tests. + const IMAGE_LOAD_STORE = 1 << 8; + const CONSERVATIVE_DEPTH = 1 << 9; + /// Interpolation and auxiliary qualifiers. + /// + /// Perspective, Flat, and Centroid are available in all GLSL versions we support. + const NOPERSPECTIVE_QUALIFIER = 1 << 11; + const SAMPLE_QUALIFIER = 1 << 12; + const CLIP_DISTANCE = 1 << 13; + const CULL_DISTANCE = 1 << 14; + /// Sample ID. + const SAMPLE_VARIABLES = 1 << 15; + /// Arrays with a dynamic length. + const DYNAMIC_ARRAY_SIZE = 1 << 16; + const MULTI_VIEW = 1 << 17; + /// Texture samples query + const TEXTURE_SAMPLES = 1 << 18; + /// Texture levels query + const TEXTURE_LEVELS = 1 << 19; + /// Image size query + const IMAGE_SIZE = 1 << 20; + /// Dual source blending + const DUAL_SOURCE_BLENDING = 1 << 21; + /// Instance index + /// + /// We can always support this, either through the language or a polyfill + const INSTANCE_INDEX = 1 << 22; + } +} + +/// Helper structure used to store the required [`Features`] needed to output a +/// [`Module`](crate::Module) +/// +/// Provides helper methods to check for availability and writing required extensions +pub struct FeaturesManager(Features); + +impl FeaturesManager { + /// Creates a new [`FeaturesManager`] instance + pub const fn new() -> Self { + Self(Features::empty()) + } + + /// Adds to the list of required [`Features`] + pub fn request(&mut self, features: Features) { + self.0 |= features + } + + /// Checks if the list of features [`Features`] contains the specified [`Features`] + pub fn contains(&mut self, features: Features) -> bool { + self.0.contains(features) + } + + /// Checks that all required [`Features`] are available for the specified + /// [`Version`] otherwise returns an [`Error::MissingFeatures`]. + pub fn check_availability(&self, version: Version) -> BackendResult { + // Will store all the features that are unavailable + let mut missing = Features::empty(); + + // Helper macro to check for feature availability + macro_rules! check_feature { + // Used when only core glsl supports the feature + ($feature:ident, $core:literal) => { + if self.0.contains(Features::$feature) + && (version < Version::Desktop($core) || version.is_es()) + { + missing |= Features::$feature; + } + }; + // Used when both core and es support the feature + ($feature:ident, $core:literal, $es:literal) => { + if self.0.contains(Features::$feature) + && (version < Version::Desktop($core) || version < Version::new_gles($es)) + { + missing |= Features::$feature; + } + }; + } + + check_feature!(COMPUTE_SHADER, 420, 310); + check_feature!(BUFFER_STORAGE, 400, 310); + check_feature!(DOUBLE_TYPE, 150); + check_feature!(CUBE_TEXTURES_ARRAY, 130, 310); + check_feature!(MULTISAMPLED_TEXTURES, 150, 300); + check_feature!(MULTISAMPLED_TEXTURE_ARRAYS, 150, 310); + check_feature!(ARRAY_OF_ARRAYS, 120, 310); + check_feature!(IMAGE_LOAD_STORE, 130, 310); + check_feature!(CONSERVATIVE_DEPTH, 130, 300); + check_feature!(NOPERSPECTIVE_QUALIFIER, 130); + check_feature!(SAMPLE_QUALIFIER, 400, 320); + check_feature!(CLIP_DISTANCE, 130, 300 /* with extension */); + check_feature!(CULL_DISTANCE, 450, 300 /* with extension */); + check_feature!(SAMPLE_VARIABLES, 400, 300); + check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310); + check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */); + match version { + Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300), + _ => check_feature!(MULTI_VIEW, 140, 310), + }; + // Only available on glsl core, this means that opengl es can't query the number + // of samples nor levels in a image and neither do bound checks on the sample nor + // the level argument of texelFecth + check_feature!(TEXTURE_SAMPLES, 150); + check_feature!(TEXTURE_LEVELS, 130); + check_feature!(IMAGE_SIZE, 430, 310); + + // Return an error if there are missing features + if missing.is_empty() { + Ok(()) + } else { + Err(Error::MissingFeatures(missing)) + } + } + + /// Helper method used to write all needed extensions + /// + /// # Notes + /// This won't check for feature availability so it might output extensions that aren't even + /// supported.[`check_availability`](Self::check_availability) will check feature availability + pub fn write(&self, options: &Options, mut out: impl Write) -> BackendResult { + if self.0.contains(Features::COMPUTE_SHADER) && !options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_compute_shader.txt + writeln!(out, "#extension GL_ARB_compute_shader : require")?; + } + + if self.0.contains(Features::BUFFER_STORAGE) && !options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_storage_buffer_object.txt + writeln!( + out, + "#extension GL_ARB_shader_storage_buffer_object : require" + )?; + } + + if self.0.contains(Features::DOUBLE_TYPE) && options.version < Version::Desktop(400) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_gpu_shader_fp64.txt + writeln!(out, "#extension GL_ARB_gpu_shader_fp64 : require")?; + } + + if self.0.contains(Features::CUBE_TEXTURES_ARRAY) { + if options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_cube_map_array.txt + writeln!(out, "#extension GL_EXT_texture_cube_map_array : require")?; + } else if options.version < Version::Desktop(400) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_cube_map_array.txt + writeln!(out, "#extension GL_ARB_texture_cube_map_array : require")?; + } + } + + if self.0.contains(Features::MULTISAMPLED_TEXTURE_ARRAYS) && options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_texture_storage_multisample_2d_array.txt + writeln!( + out, + "#extension GL_OES_texture_storage_multisample_2d_array : require" + )?; + } + + if self.0.contains(Features::ARRAY_OF_ARRAYS) && options.version < Version::Desktop(430) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_arrays_of_arrays.txt + writeln!(out, "#extension ARB_arrays_of_arrays : require")?; + } + + if self.0.contains(Features::IMAGE_LOAD_STORE) { + if self.0.contains(Features::FULL_IMAGE_FORMATS) && options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/NV/NV_image_formats.txt + writeln!(out, "#extension GL_NV_image_formats : require")?; + } + + if options.version < Version::Desktop(420) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_image_load_store.txt + writeln!(out, "#extension GL_ARB_shader_image_load_store : require")?; + } + } + + if self.0.contains(Features::CONSERVATIVE_DEPTH) { + if options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_conservative_depth.txt + writeln!(out, "#extension GL_EXT_conservative_depth : require")?; + } + + if options.version < Version::Desktop(420) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt + writeln!(out, "#extension GL_ARB_conservative_depth : require")?; + } + } + + if (self.0.contains(Features::CLIP_DISTANCE) || self.0.contains(Features::CULL_DISTANCE)) + && options.version.is_es() + { + // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_clip_cull_distance.txt + writeln!(out, "#extension GL_EXT_clip_cull_distance : require")?; + } + + if self.0.contains(Features::SAMPLE_VARIABLES) && options.version.is_es() { + // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_sample_variables.txt + writeln!(out, "#extension GL_OES_sample_variables : require")?; + } + + if self.0.contains(Features::MULTI_VIEW) { + if let Version::Embedded { is_webgl: true, .. } = options.version { + // https://www.khronos.org/registry/OpenGL/extensions/OVR/OVR_multiview2.txt + writeln!(out, "#extension GL_OVR_multiview2 : require")?; + } else { + // https://github.com/KhronosGroup/GLSL/blob/master/extensions/ext/GL_EXT_multiview.txt + writeln!(out, "#extension GL_EXT_multiview : require")?; + } + } + + if self.0.contains(Features::TEXTURE_SAMPLES) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_texture_image_samples.txt + writeln!( + out, + "#extension GL_ARB_shader_texture_image_samples : require" + )?; + } + + if self.0.contains(Features::TEXTURE_LEVELS) && options.version < Version::Desktop(430) { + // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_query_levels.txt + writeln!(out, "#extension GL_ARB_texture_query_levels : require")?; + } + if self.0.contains(Features::DUAL_SOURCE_BLENDING) && options.version.is_es() { + // https://registry.khronos.org/OpenGL/extensions/EXT/EXT_blend_func_extended.txt + writeln!(out, "#extension GL_EXT_blend_func_extended : require")?; + } + + if self.0.contains(Features::INSTANCE_INDEX) { + if options.writer_flags.contains(WriterFlags::DRAW_PARAMETERS) { + // https://registry.khronos.org/OpenGL/extensions/ARB/ARB_shader_draw_parameters.txt + writeln!(out, "#extension GL_ARB_shader_draw_parameters : require")?; + } + } + + Ok(()) + } +} + +impl<'a, W> Writer<'a, W> { + /// Helper method that searches the module for all the needed [`Features`] + /// + /// # Errors + /// If the version doesn't support any of the needed [`Features`] a + /// [`Error::MissingFeatures`] will be returned + pub(super) fn collect_required_features(&mut self) -> BackendResult { + let ep_info = self.info.get_entry_point(self.entry_point_idx as usize); + + if let Some(depth_test) = self.entry_point.early_depth_test { + // If IMAGE_LOAD_STORE is supported for this version of GLSL + if self.options.version.supports_early_depth_test() { + self.features.request(Features::IMAGE_LOAD_STORE); + } + + if depth_test.conservative.is_some() { + self.features.request(Features::CONSERVATIVE_DEPTH); + } + } + + for arg in self.entry_point.function.arguments.iter() { + self.varying_required_features(arg.binding.as_ref(), arg.ty); + } + if let Some(ref result) = self.entry_point.function.result { + self.varying_required_features(result.binding.as_ref(), result.ty); + } + + if let ShaderStage::Compute = self.entry_point.stage { + self.features.request(Features::COMPUTE_SHADER) + } + + if self.multiview.is_some() { + self.features.request(Features::MULTI_VIEW); + } + + for (ty_handle, ty) in self.module.types.iter() { + match ty.inner { + TypeInner::Scalar(scalar) + | TypeInner::Vector { scalar, .. } + | TypeInner::Matrix { scalar, .. } => self.scalar_required_features(scalar), + TypeInner::Array { base, size, .. } => { + if let TypeInner::Array { .. } = self.module.types[base].inner { + self.features.request(Features::ARRAY_OF_ARRAYS) + } + + // If the array is dynamically sized + if size == crate::ArraySize::Dynamic { + let mut is_used = false; + + // Check if this type is used in a global that is needed by the current entrypoint + for (global_handle, global) in self.module.global_variables.iter() { + // Skip unused globals + if ep_info[global_handle].is_empty() { + continue; + } + + // If this array is the type of a global, then this array is used + if global.ty == ty_handle { + is_used = true; + break; + } + + // If the type of this global is a struct + if let crate::TypeInner::Struct { ref members, .. } = + self.module.types[global.ty].inner + { + // Check the last element of the struct to see if it's type uses + // this array + if let Some(last) = members.last() { + if last.ty == ty_handle { + is_used = true; + break; + } + } + } + } + + // If this dynamically size array is used, we need dynamic array size support + if is_used { + self.features.request(Features::DYNAMIC_ARRAY_SIZE); + } + } + } + TypeInner::Image { + dim, + arrayed, + class, + } => { + if arrayed && dim == ImageDimension::Cube { + self.features.request(Features::CUBE_TEXTURES_ARRAY) + } + + match class { + ImageClass::Sampled { multi: true, .. } + | ImageClass::Depth { multi: true } => { + self.features.request(Features::MULTISAMPLED_TEXTURES); + if arrayed { + self.features.request(Features::MULTISAMPLED_TEXTURE_ARRAYS); + } + } + ImageClass::Storage { format, .. } => match format { + StorageFormat::R8Unorm + | StorageFormat::R8Snorm + | StorageFormat::R8Uint + | StorageFormat::R8Sint + | StorageFormat::R16Uint + | StorageFormat::R16Sint + | StorageFormat::R16Float + | StorageFormat::Rg8Unorm + | StorageFormat::Rg8Snorm + | StorageFormat::Rg8Uint + | StorageFormat::Rg8Sint + | StorageFormat::Rg16Uint + | StorageFormat::Rg16Sint + | StorageFormat::Rg16Float + | StorageFormat::Rgb10a2Uint + | StorageFormat::Rgb10a2Unorm + | StorageFormat::Rg11b10Float + | StorageFormat::Rg32Uint + | StorageFormat::Rg32Sint + | StorageFormat::Rg32Float => { + self.features.request(Features::FULL_IMAGE_FORMATS) + } + _ => {} + }, + ImageClass::Sampled { multi: false, .. } + | ImageClass::Depth { multi: false } => {} + } + } + _ => {} + } + } + + let mut push_constant_used = false; + + for (handle, global) in self.module.global_variables.iter() { + if ep_info[handle].is_empty() { + continue; + } + match global.space { + AddressSpace::WorkGroup => self.features.request(Features::COMPUTE_SHADER), + AddressSpace::Storage { .. } => self.features.request(Features::BUFFER_STORAGE), + AddressSpace::PushConstant => { + if push_constant_used { + return Err(Error::MultiplePushConstants); + } + push_constant_used = true; + } + _ => {} + } + } + + // We will need to pass some of the members to a closure, so we need + // to separate them otherwise the borrow checker will complain, this + // shouldn't be needed in rust 2021 + let &mut Self { + module, + info, + ref mut features, + entry_point, + entry_point_idx, + ref policies, + .. + } = self; + + // Loop trough all expressions in both functions and the entry point + // to check for needed features + for (expressions, info) in module + .functions + .iter() + .map(|(h, f)| (&f.expressions, &info[h])) + .chain(std::iter::once(( + &entry_point.function.expressions, + info.get_entry_point(entry_point_idx as usize), + ))) + { + for (_, expr) in expressions.iter() { + match *expr { + // Check for queries that need aditonal features + Expression::ImageQuery { + image, + query, + .. + } => match query { + // Storage images use `imageSize` which is only available + // in glsl > 420 + // + // layers queries are also implemented as size queries + crate::ImageQuery::Size { .. } | crate::ImageQuery::NumLayers => { + if let TypeInner::Image { + class: crate::ImageClass::Storage { .. }, .. + } = *info[image].ty.inner_with(&module.types) { + features.request(Features::IMAGE_SIZE) + } + }, + crate::ImageQuery::NumLevels => features.request(Features::TEXTURE_LEVELS), + crate::ImageQuery::NumSamples => features.request(Features::TEXTURE_SAMPLES), + } + , + // Check for image loads that needs bound checking on the sample + // or level argument since this requires a feature + Expression::ImageLoad { + sample, level, .. + } => { + if policies.image_load != crate::proc::BoundsCheckPolicy::Unchecked { + if sample.is_some() { + features.request(Features::TEXTURE_SAMPLES) + } + + if level.is_some() { + features.request(Features::TEXTURE_LEVELS) + } + } + } + _ => {} + } + } + } + + self.features.check_availability(self.options.version) + } + + /// Helper method that checks the [`Features`] needed by a scalar + fn scalar_required_features(&mut self, scalar: Scalar) { + if scalar.kind == ScalarKind::Float && scalar.width == 8 { + self.features.request(Features::DOUBLE_TYPE); + } + } + + fn varying_required_features(&mut self, binding: Option<&Binding>, ty: Handle<Type>) { + match self.module.types[ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + for member in members { + self.varying_required_features(member.binding.as_ref(), member.ty); + } + } + _ => { + if let Some(binding) = binding { + match *binding { + Binding::BuiltIn(built_in) => match built_in { + crate::BuiltIn::ClipDistance => { + self.features.request(Features::CLIP_DISTANCE) + } + crate::BuiltIn::CullDistance => { + self.features.request(Features::CULL_DISTANCE) + } + crate::BuiltIn::SampleIndex => { + self.features.request(Features::SAMPLE_VARIABLES) + } + crate::BuiltIn::ViewIndex => { + self.features.request(Features::MULTI_VIEW) + } + crate::BuiltIn::InstanceIndex => { + self.features.request(Features::INSTANCE_INDEX) + } + _ => {} + }, + Binding::Location { + location: _, + interpolation, + sampling, + second_blend_source, + } => { + if interpolation == Some(Interpolation::Linear) { + self.features.request(Features::NOPERSPECTIVE_QUALIFIER); + } + if sampling == Some(Sampling::Sample) { + self.features.request(Features::SAMPLE_QUALIFIER); + } + if second_blend_source { + self.features.request(Features::DUAL_SOURCE_BLENDING); + } + } + } + } + } + } + } +} diff --git a/third_party/rust/naga/src/back/glsl/keywords.rs b/third_party/rust/naga/src/back/glsl/keywords.rs new file mode 100644 index 0000000000..857c935e68 --- /dev/null +++ b/third_party/rust/naga/src/back/glsl/keywords.rs @@ -0,0 +1,484 @@ +pub const RESERVED_KEYWORDS: &[&str] = &[ + // + // GLSL 4.6 keywords, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L2004-L2322 + // GLSL ES 3.2 keywords, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/es/3.2/GLSL_ES_Specification_3.20.html#L2166-L2478 + // + // Note: The GLSL ES 3.2 keywords are the same as GLSL 4.6 keywords with some residing in the reserved section. + // The only exception are the missing Vulkan keywords which I think is an oversight (see https://github.com/KhronosGroup/OpenGL-Registry/issues/585). + // + "const", + "uniform", + "buffer", + "shared", + "attribute", + "varying", + "coherent", + "volatile", + "restrict", + "readonly", + "writeonly", + "atomic_uint", + "layout", + "centroid", + "flat", + "smooth", + "noperspective", + "patch", + "sample", + "invariant", + "precise", + "break", + "continue", + "do", + "for", + "while", + "switch", + "case", + "default", + "if", + "else", + "subroutine", + "in", + "out", + "inout", + "int", + "void", + "bool", + "true", + "false", + "float", + "double", + "discard", + "return", + "vec2", + "vec3", + "vec4", + "ivec2", + "ivec3", + "ivec4", + "bvec2", + "bvec3", + "bvec4", + "uint", + "uvec2", + "uvec3", + "uvec4", + "dvec2", + "dvec3", + "dvec4", + "mat2", + "mat3", + "mat4", + "mat2x2", + "mat2x3", + "mat2x4", + "mat3x2", + "mat3x3", + "mat3x4", + "mat4x2", + "mat4x3", + "mat4x4", + "dmat2", + "dmat3", + "dmat4", + "dmat2x2", + "dmat2x3", + "dmat2x4", + "dmat3x2", + "dmat3x3", + "dmat3x4", + "dmat4x2", + "dmat4x3", + "dmat4x4", + "lowp", + "mediump", + "highp", + "precision", + "sampler1D", + "sampler1DShadow", + "sampler1DArray", + "sampler1DArrayShadow", + "isampler1D", + "isampler1DArray", + "usampler1D", + "usampler1DArray", + "sampler2D", + "sampler2DShadow", + "sampler2DArray", + "sampler2DArrayShadow", + "isampler2D", + "isampler2DArray", + "usampler2D", + "usampler2DArray", + "sampler2DRect", + "sampler2DRectShadow", + "isampler2DRect", + "usampler2DRect", + "sampler2DMS", + "isampler2DMS", + "usampler2DMS", + "sampler2DMSArray", + "isampler2DMSArray", + "usampler2DMSArray", + "sampler3D", + "isampler3D", + "usampler3D", + "samplerCube", + "samplerCubeShadow", + "isamplerCube", + "usamplerCube", + "samplerCubeArray", + "samplerCubeArrayShadow", + "isamplerCubeArray", + "usamplerCubeArray", + "samplerBuffer", + "isamplerBuffer", + "usamplerBuffer", + "image1D", + "iimage1D", + "uimage1D", + "image1DArray", + "iimage1DArray", + "uimage1DArray", + "image2D", + "iimage2D", + "uimage2D", + "image2DArray", + "iimage2DArray", + "uimage2DArray", + "image2DRect", + "iimage2DRect", + "uimage2DRect", + "image2DMS", + "iimage2DMS", + "uimage2DMS", + "image2DMSArray", + "iimage2DMSArray", + "uimage2DMSArray", + "image3D", + "iimage3D", + "uimage3D", + "imageCube", + "iimageCube", + "uimageCube", + "imageCubeArray", + "iimageCubeArray", + "uimageCubeArray", + "imageBuffer", + "iimageBuffer", + "uimageBuffer", + "struct", + // Vulkan keywords + "texture1D", + "texture1DArray", + "itexture1D", + "itexture1DArray", + "utexture1D", + "utexture1DArray", + "texture2D", + "texture2DArray", + "itexture2D", + "itexture2DArray", + "utexture2D", + "utexture2DArray", + "texture2DRect", + "itexture2DRect", + "utexture2DRect", + "texture2DMS", + "itexture2DMS", + "utexture2DMS", + "texture2DMSArray", + "itexture2DMSArray", + "utexture2DMSArray", + "texture3D", + "itexture3D", + "utexture3D", + "textureCube", + "itextureCube", + "utextureCube", + "textureCubeArray", + "itextureCubeArray", + "utextureCubeArray", + "textureBuffer", + "itextureBuffer", + "utextureBuffer", + "sampler", + "samplerShadow", + "subpassInput", + "isubpassInput", + "usubpassInput", + "subpassInputMS", + "isubpassInputMS", + "usubpassInputMS", + // Reserved keywords + "common", + "partition", + "active", + "asm", + "class", + "union", + "enum", + "typedef", + "template", + "this", + "resource", + "goto", + "inline", + "noinline", + "public", + "static", + "extern", + "external", + "interface", + "long", + "short", + "half", + "fixed", + "unsigned", + "superp", + "input", + "output", + "hvec2", + "hvec3", + "hvec4", + "fvec2", + "fvec3", + "fvec4", + "filter", + "sizeof", + "cast", + "namespace", + "using", + "sampler3DRect", + // + // GLSL 4.6 Built-In Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13314 + // + // Angle and Trigonometry Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13469-L13561C5 + "radians", + "degrees", + "sin", + "cos", + "tan", + "asin", + "acos", + "atan", + "sinh", + "cosh", + "tanh", + "asinh", + "acosh", + "atanh", + // Exponential Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13569-L13620 + "pow", + "exp", + "log", + "exp2", + "log2", + "sqrt", + "inversesqrt", + // Common Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13628-L13908 + "abs", + "sign", + "floor", + "trunc", + "round", + "roundEven", + "ceil", + "fract", + "mod", + "modf", + "min", + "max", + "clamp", + "mix", + "step", + "smoothstep", + "isnan", + "isinf", + "floatBitsToInt", + "floatBitsToUint", + "intBitsToFloat", + "uintBitsToFloat", + "fma", + "frexp", + "ldexp", + // Floating-Point Pack and Unpack Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13916-L14007 + "packUnorm2x16", + "packSnorm2x16", + "packUnorm4x8", + "packSnorm4x8", + "unpackUnorm2x16", + "unpackSnorm2x16", + "unpackUnorm4x8", + "unpackSnorm4x8", + "packHalf2x16", + "unpackHalf2x16", + "packDouble2x32", + "unpackDouble2x32", + // Geometric Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14014-L14121 + "length", + "distance", + "dot", + "cross", + "normalize", + "ftransform", + "faceforward", + "reflect", + "refract", + // Matrix Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14151-L14215 + "matrixCompMult", + "outerProduct", + "transpose", + "determinant", + "inverse", + // Vector Relational Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14259-L14322 + "lessThan", + "lessThanEqual", + "greaterThan", + "greaterThanEqual", + "equal", + "notEqual", + "any", + "all", + "not", + // Integer Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14335-L14432 + "uaddCarry", + "usubBorrow", + "umulExtended", + "imulExtended", + "bitfieldExtract", + "bitfieldInsert", + "bitfieldReverse", + "bitCount", + "findLSB", + "findMSB", + // Texture Query Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14645-L14732 + "textureSize", + "textureQueryLod", + "textureQueryLevels", + "textureSamples", + // Texel Lookup Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14736-L14997 + "texture", + "textureProj", + "textureLod", + "textureOffset", + "texelFetch", + "texelFetchOffset", + "textureProjOffset", + "textureLodOffset", + "textureProjLod", + "textureProjLodOffset", + "textureGrad", + "textureGradOffset", + "textureProjGrad", + "textureProjGradOffset", + // Texture Gather Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15077-L15154 + "textureGather", + "textureGatherOffset", + "textureGatherOffsets", + // Compatibility Profile Texture Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15161-L15220 + "texture1D", + "texture1DProj", + "texture1DLod", + "texture1DProjLod", + "texture2D", + "texture2DProj", + "texture2DLod", + "texture2DProjLod", + "texture3D", + "texture3DProj", + "texture3DLod", + "texture3DProjLod", + "textureCube", + "textureCubeLod", + "shadow1D", + "shadow2D", + "shadow1DProj", + "shadow2DProj", + "shadow1DLod", + "shadow2DLod", + "shadow1DProjLod", + "shadow2DProjLod", + // Atomic Counter Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15241-L15531 + "atomicCounterIncrement", + "atomicCounterDecrement", + "atomicCounter", + "atomicCounterAdd", + "atomicCounterSubtract", + "atomicCounterMin", + "atomicCounterMax", + "atomicCounterAnd", + "atomicCounterOr", + "atomicCounterXor", + "atomicCounterExchange", + "atomicCounterCompSwap", + // Atomic Memory Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15563-L15624 + "atomicAdd", + "atomicMin", + "atomicMax", + "atomicAnd", + "atomicOr", + "atomicXor", + "atomicExchange", + "atomicCompSwap", + // Image Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15763-L15878 + "imageSize", + "imageSamples", + "imageLoad", + "imageStore", + "imageAtomicAdd", + "imageAtomicMin", + "imageAtomicMax", + "imageAtomicAnd", + "imageAtomicOr", + "imageAtomicXor", + "imageAtomicExchange", + "imageAtomicCompSwap", + // Geometry Shader Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15886-L15932 + "EmitStreamVertex", + "EndStreamPrimitive", + "EmitVertex", + "EndPrimitive", + // Fragment Processing Functions, Derivative Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16041-L16114 + "dFdx", + "dFdy", + "dFdxFine", + "dFdyFine", + "dFdxCoarse", + "dFdyCoarse", + "fwidth", + "fwidthFine", + "fwidthCoarse", + // Fragment Processing Functions, Interpolation Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16150-L16198 + "interpolateAtCentroid", + "interpolateAtSample", + "interpolateAtOffset", + // Noise Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16214-L16243 + "noise1", + "noise2", + "noise3", + "noise4", + // Shader Invocation Control Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16255-L16276 + "barrier", + // Shader Memory Control Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16336-L16382 + "memoryBarrier", + "memoryBarrierAtomicCounter", + "memoryBarrierBuffer", + "memoryBarrierShared", + "memoryBarrierImage", + "groupMemoryBarrier", + // Subpass-Input Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16451-L16470 + "subpassLoad", + // Shader Invocation Group Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16483-L16511 + "anyInvocation", + "allInvocations", + "allInvocationsEqual", + // + // entry point name (should not be shadowed) + // + "main", + // Naga utilities: + super::MODF_FUNCTION, + super::FREXP_FUNCTION, + super::FIRST_INSTANCE_BINDING, +]; diff --git a/third_party/rust/naga/src/back/glsl/mod.rs b/third_party/rust/naga/src/back/glsl/mod.rs new file mode 100644 index 0000000000..e346d43257 --- /dev/null +++ b/third_party/rust/naga/src/back/glsl/mod.rs @@ -0,0 +1,4532 @@ +/*! +Backend for [GLSL][glsl] (OpenGL Shading Language). + +The main structure is [`Writer`], it maintains internal state that is used +to output a [`Module`](crate::Module) into glsl + +# Supported versions +### Core +- 330 +- 400 +- 410 +- 420 +- 430 +- 450 + +### ES +- 300 +- 310 + +[glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php +*/ + +// GLSL is mostly a superset of C but it also removes some parts of it this is a list of relevant +// aspects for this backend. +// +// The most notable change is the introduction of the version preprocessor directive that must +// always be the first line of a glsl file and is written as +// `#version number profile` +// `number` is the version itself (i.e. 300) and `profile` is the +// shader profile we only support "core" and "es", the former is used in desktop applications and +// the later is used in embedded contexts, mobile devices and browsers. Each one as it's own +// versions (at the time of writing this the latest version for "core" is 460 and for "es" is 320) +// +// Other important preprocessor addition is the extension directive which is written as +// `#extension name: behaviour` +// Extensions provide increased features in a plugin fashion but they aren't required to be +// supported hence why they are called extensions, that's why `behaviour` is used it specifies +// whether the extension is strictly required or if it should only be enabled if needed. In our case +// when we use extensions we set behaviour to `require` always. +// +// The only thing that glsl removes that makes a difference are pointers. +// +// Additions that are relevant for the backend are the discard keyword, the introduction of +// vector, matrices, samplers, image types and functions that provide common shader operations + +pub use features::Features; + +use crate::{ + back, + proc::{self, NameKey}, + valid, Handle, ShaderStage, TypeInner, +}; +use features::FeaturesManager; +use std::{ + cmp::Ordering, + fmt, + fmt::{Error as FmtError, Write}, + mem, +}; +use thiserror::Error; + +/// Contains the features related code and the features querying method +mod features; +/// Contains a constant with a slice of all the reserved keywords RESERVED_KEYWORDS +mod keywords; + +/// List of supported `core` GLSL versions. +pub const SUPPORTED_CORE_VERSIONS: &[u16] = &[140, 150, 330, 400, 410, 420, 430, 440, 450, 460]; +/// List of supported `es` GLSL versions. +pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320]; + +/// The suffix of the variable that will hold the calculated clamped level +/// of detail for bounds checking in `ImageLoad` +const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod"; + +pub(crate) const MODF_FUNCTION: &str = "naga_modf"; +pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; + +// Must match code in glsl_built_in +pub const FIRST_INSTANCE_BINDING: &str = "naga_vs_first_instance"; + +/// Mapping between resources and bindings. +pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, u8>; + +impl crate::AtomicFunction { + const fn to_glsl(self) -> &'static str { + match self { + Self::Add | Self::Subtract => "Add", + Self::And => "And", + Self::InclusiveOr => "Or", + Self::ExclusiveOr => "Xor", + Self::Min => "Min", + Self::Max => "Max", + Self::Exchange { compare: None } => "Exchange", + Self::Exchange { compare: Some(_) } => "", //TODO + } + } +} + +impl crate::AddressSpace { + const fn is_buffer(&self) -> bool { + match *self { + crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => true, + _ => false, + } + } + + /// Whether a variable with this address space can be initialized + const fn initializable(&self) -> bool { + match *self { + crate::AddressSpace::Function | crate::AddressSpace::Private => true, + crate::AddressSpace::WorkGroup + | crate::AddressSpace::Uniform + | crate::AddressSpace::Storage { .. } + | crate::AddressSpace::Handle + | crate::AddressSpace::PushConstant => false, + } + } +} + +/// A GLSL version. +#[derive(Debug, Copy, Clone, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum Version { + /// `core` GLSL. + Desktop(u16), + /// `es` GLSL. + Embedded { version: u16, is_webgl: bool }, +} + +impl Version { + /// Create a new gles version + pub const fn new_gles(version: u16) -> Self { + Self::Embedded { + version, + is_webgl: false, + } + } + + /// Returns true if self is `Version::Embedded` (i.e. is a es version) + const fn is_es(&self) -> bool { + match *self { + Version::Desktop(_) => false, + Version::Embedded { .. } => true, + } + } + + /// Returns true if targeting WebGL + const fn is_webgl(&self) -> bool { + match *self { + Version::Desktop(_) => false, + Version::Embedded { is_webgl, .. } => is_webgl, + } + } + + /// Checks the list of currently supported versions and returns true if it contains the + /// specified version + /// + /// # Notes + /// As an invalid version number will never be added to the supported version list + /// so this also checks for version validity + fn is_supported(&self) -> bool { + match *self { + Version::Desktop(v) => SUPPORTED_CORE_VERSIONS.contains(&v), + Version::Embedded { version: v, .. } => SUPPORTED_ES_VERSIONS.contains(&v), + } + } + + fn supports_io_locations(&self) -> bool { + *self >= Version::Desktop(330) || *self >= Version::new_gles(300) + } + + /// Checks if the version supports all of the explicit layouts: + /// - `location=` qualifiers for bindings + /// - `binding=` qualifiers for resources + /// + /// Note: `location=` for vertex inputs and fragment outputs is supported + /// unconditionally for GLES 300. + fn supports_explicit_locations(&self) -> bool { + *self >= Version::Desktop(410) || *self >= Version::new_gles(310) + } + + fn supports_early_depth_test(&self) -> bool { + *self >= Version::Desktop(130) || *self >= Version::new_gles(310) + } + + fn supports_std430_layout(&self) -> bool { + *self >= Version::Desktop(430) || *self >= Version::new_gles(310) + } + + fn supports_fma_function(&self) -> bool { + *self >= Version::Desktop(400) || *self >= Version::new_gles(320) + } + + fn supports_integer_functions(&self) -> bool { + *self >= Version::Desktop(400) || *self >= Version::new_gles(310) + } + + fn supports_frexp_function(&self) -> bool { + *self >= Version::Desktop(400) || *self >= Version::new_gles(310) + } + + fn supports_derivative_control(&self) -> bool { + *self >= Version::Desktop(450) + } +} + +impl PartialOrd for Version { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + match (*self, *other) { + (Version::Desktop(x), Version::Desktop(y)) => Some(x.cmp(&y)), + (Version::Embedded { version: x, .. }, Version::Embedded { version: y, .. }) => { + Some(x.cmp(&y)) + } + _ => None, + } + } +} + +impl fmt::Display for Version { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Version::Desktop(v) => write!(f, "{v} core"), + Version::Embedded { version: v, .. } => write!(f, "{v} es"), + } + } +} + +bitflags::bitflags! { + /// Configuration flags for the [`Writer`]. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct WriterFlags: u32 { + /// Flip output Y and extend Z from (0, 1) to (-1, 1). + const ADJUST_COORDINATE_SPACE = 0x1; + /// Supports GL_EXT_texture_shadow_lod on the host, which provides + /// additional functions on shadows and arrays of shadows. + const TEXTURE_SHADOW_LOD = 0x2; + /// Supports ARB_shader_draw_parameters on the host, which provides + /// support for `gl_BaseInstanceARB`, `gl_BaseVertexARB`, and `gl_DrawIDARB`. + const DRAW_PARAMETERS = 0x4; + /// Include unused global variables, constants and functions. By default the output will exclude + /// global variables that are not used in the specified entrypoint (including indirect use), + /// all constant declarations, and functions that use excluded global variables. + const INCLUDE_UNUSED_ITEMS = 0x10; + /// Emit `PointSize` output builtin to vertex shaders, which is + /// required for drawing with `PointList` topology. + /// + /// https://registry.khronos.org/OpenGL/specs/es/3.2/GLSL_ES_Specification_3.20.html#built-in-language-variables + /// The variable gl_PointSize is intended for a shader to write the size of the point to be rasterized. It is measured in pixels. + /// If gl_PointSize is not written to, its value is undefined in subsequent pipe stages. + const FORCE_POINT_SIZE = 0x20; + } +} + +/// Configuration used in the [`Writer`]. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Options { + /// The GLSL version to be used. + pub version: Version, + /// Configuration flags for the [`Writer`]. + pub writer_flags: WriterFlags, + /// Map of resources association to binding locations. + pub binding_map: BindingMap, + /// Should workgroup variables be zero initialized (by polyfilling)? + pub zero_initialize_workgroup_memory: bool, +} + +impl Default for Options { + fn default() -> Self { + Options { + version: Version::new_gles(310), + writer_flags: WriterFlags::ADJUST_COORDINATE_SPACE, + binding_map: BindingMap::default(), + zero_initialize_workgroup_memory: true, + } + } +} + +/// A subset of options meant to be changed per pipeline. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct PipelineOptions { + /// The stage of the entry point. + pub shader_stage: ShaderStage, + /// The name of the entry point. + /// + /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown. + pub entry_point: String, + /// How many views to render to, if doing multiview rendering. + pub multiview: Option<std::num::NonZeroU32>, +} + +#[derive(Debug)] +pub struct VaryingLocation { + /// The location of the global. + /// This corresponds to `layout(location = ..)` in GLSL. + pub location: u32, + /// The index which can be used for dual source blending. + /// This corresponds to `layout(index = ..)` in GLSL. + pub index: u32, +} + +/// Reflection info for texture mappings and uniforms. +#[derive(Debug)] +pub struct ReflectionInfo { + /// Mapping between texture names and variables/samplers. + pub texture_mapping: crate::FastHashMap<String, TextureMapping>, + /// Mapping between uniform variables and names. + pub uniforms: crate::FastHashMap<Handle<crate::GlobalVariable>, String>, + /// Mapping between names and attribute locations. + pub varying: crate::FastHashMap<String, VaryingLocation>, + /// List of push constant items in the shader. + pub push_constant_items: Vec<PushConstantItem>, +} + +/// Mapping between a texture and its sampler, if it exists. +/// +/// GLSL pre-Vulkan has no concept of separate textures and samplers. Instead, everything is a +/// `gsamplerN` where `g` is the scalar type and `N` is the dimension. But naga uses separate textures +/// and samplers in the IR, so the backend produces a [`FastHashMap`](crate::FastHashMap) with the texture name +/// as a key and a [`TextureMapping`] as a value. This way, the user knows where to bind. +/// +/// [`Storage`](crate::ImageClass::Storage) images produce `gimageN` and don't have an associated sampler, +/// so the [`sampler`](Self::sampler) field will be [`None`]. +#[derive(Debug, Clone)] +pub struct TextureMapping { + /// Handle to the image global variable. + pub texture: Handle<crate::GlobalVariable>, + /// Handle to the associated sampler global variable, if it exists. + pub sampler: Option<Handle<crate::GlobalVariable>>, +} + +/// All information to bind a single uniform value to the shader. +/// +/// Push constants are emulated using traditional uniforms in OpenGL. +/// +/// These are composed of a set of primitives (scalar, vector, matrix) that +/// are given names. Because they are not backed by the concept of a buffer, +/// we must do the work of calculating the offset of each primitive in the +/// push constant block. +#[derive(Debug, Clone)] +pub struct PushConstantItem { + /// GL uniform name for the item. This name is the same as if you were + /// to access it directly from a GLSL shader. + /// + /// The with the following example, the following names will be generated, + /// one name per GLSL uniform. + /// + /// ```glsl + /// struct InnerStruct { + /// value: f32, + /// } + /// + /// struct PushConstant { + /// InnerStruct inner; + /// vec4 array[2]; + /// } + /// + /// uniform PushConstants _push_constant_binding_cs; + /// ``` + /// + /// ```text + /// - _push_constant_binding_cs.inner.value + /// - _push_constant_binding_cs.array[0] + /// - _push_constant_binding_cs.array[1] + /// ``` + /// + pub access_path: String, + /// Type of the uniform. This will only ever be a scalar, vector, or matrix. + pub ty: Handle<crate::Type>, + /// The offset in the push constant memory block this uniform maps to. + /// + /// The size of the uniform can be derived from the type. + pub offset: u32, +} + +/// Helper structure that generates a number +#[derive(Default)] +struct IdGenerator(u32); + +impl IdGenerator { + /// Generates a number that's guaranteed to be unique for this `IdGenerator` + fn generate(&mut self) -> u32 { + // It's just an increasing number but it does the job + let ret = self.0; + self.0 += 1; + ret + } +} + +/// Assorted options needed for generating varyings. +#[derive(Clone, Copy)] +struct VaryingOptions { + output: bool, + targeting_webgl: bool, + draw_parameters: bool, +} + +impl VaryingOptions { + const fn from_writer_options(options: &Options, output: bool) -> Self { + Self { + output, + targeting_webgl: options.version.is_webgl(), + draw_parameters: options.writer_flags.contains(WriterFlags::DRAW_PARAMETERS), + } + } +} + +/// Helper wrapper used to get a name for a varying +/// +/// Varying have different naming schemes depending on their binding: +/// - Varyings with builtin bindings get the from [`glsl_built_in`]. +/// - Varyings with location bindings are named `_S_location_X` where `S` is a +/// prefix identifying which pipeline stage the varying connects, and `X` is +/// the location. +struct VaryingName<'a> { + binding: &'a crate::Binding, + stage: ShaderStage, + options: VaryingOptions, +} +impl fmt::Display for VaryingName<'_> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self.binding { + crate::Binding::Location { + second_blend_source: true, + .. + } => { + write!(f, "_fs2p_location1",) + } + crate::Binding::Location { location, .. } => { + let prefix = match (self.stage, self.options.output) { + (ShaderStage::Compute, _) => unreachable!(), + // pipeline to vertex + (ShaderStage::Vertex, false) => "p2vs", + // vertex to fragment + (ShaderStage::Vertex, true) | (ShaderStage::Fragment, false) => "vs2fs", + // fragment to pipeline + (ShaderStage::Fragment, true) => "fs2p", + }; + write!(f, "_{prefix}_location{location}",) + } + crate::Binding::BuiltIn(built_in) => { + write!(f, "{}", glsl_built_in(built_in, self.options)) + } + } + } +} + +impl ShaderStage { + const fn to_str(self) -> &'static str { + match self { + ShaderStage::Compute => "cs", + ShaderStage::Fragment => "fs", + ShaderStage::Vertex => "vs", + } + } +} + +/// Shorthand result used internally by the backend +type BackendResult<T = ()> = Result<T, Error>; + +/// A GLSL compilation error. +#[derive(Debug, Error)] +pub enum Error { + /// A error occurred while writing to the output. + #[error("Format error")] + FmtError(#[from] FmtError), + /// The specified [`Version`] doesn't have all required [`Features`]. + /// + /// Contains the missing [`Features`]. + #[error("The selected version doesn't support {0:?}")] + MissingFeatures(Features), + /// [`AddressSpace::PushConstant`](crate::AddressSpace::PushConstant) was used more than + /// once in the entry point, which isn't supported. + #[error("Multiple push constants aren't supported")] + MultiplePushConstants, + /// The specified [`Version`] isn't supported. + #[error("The specified version isn't supported")] + VersionNotSupported, + /// The entry point couldn't be found. + #[error("The requested entry point couldn't be found")] + EntryPointNotFound, + /// A call was made to an unsupported external. + #[error("A call was made to an unsupported external: {0}")] + UnsupportedExternal(String), + /// A scalar with an unsupported width was requested. + #[error("A scalar with an unsupported width was requested: {0:?}")] + UnsupportedScalar(crate::Scalar), + /// A image was used with multiple samplers, which isn't supported. + #[error("A image was used with multiple samplers")] + ImageMultipleSamplers, + #[error("{0}")] + Custom(String), +} + +/// Binary operation with a different logic on the GLSL side. +enum BinaryOperation { + /// Vector comparison should use the function like `greaterThan()`, etc. + VectorCompare, + /// Vector component wise operation; used to polyfill unsupported ops like `|` and `&` for `bvecN`'s + VectorComponentWise, + /// GLSL `%` is SPIR-V `OpUMod/OpSMod` and `mod()` is `OpFMod`, but [`BinaryOperator::Modulo`](crate::BinaryOperator::Modulo) is `OpFRem`. + Modulo, + /// Any plain operation. No additional logic required. + Other, +} + +/// Writer responsible for all code generation. +pub struct Writer<'a, W> { + // Inputs + /// The module being written. + module: &'a crate::Module, + /// The module analysis. + info: &'a valid::ModuleInfo, + /// The output writer. + out: W, + /// User defined configuration to be used. + options: &'a Options, + /// The bound checking policies to be used + policies: proc::BoundsCheckPolicies, + + // Internal State + /// Features manager used to store all the needed features and write them. + features: FeaturesManager, + namer: proc::Namer, + /// A map with all the names needed for writing the module + /// (generated by a [`Namer`](crate::proc::Namer)). + names: crate::FastHashMap<NameKey, String>, + /// A map with the names of global variables needed for reflections. + reflection_names_globals: crate::FastHashMap<Handle<crate::GlobalVariable>, String>, + /// The selected entry point. + entry_point: &'a crate::EntryPoint, + /// The index of the selected entry point. + entry_point_idx: proc::EntryPointIndex, + /// A generator for unique block numbers. + block_id: IdGenerator, + /// Set of expressions that have associated temporary variables. + named_expressions: crate::NamedExpressions, + /// Set of expressions that need to be baked to avoid unnecessary repetition in output + need_bake_expressions: back::NeedBakeExpressions, + /// How many views to render to, if doing multiview rendering. + multiview: Option<std::num::NonZeroU32>, + /// Mapping of varying variables to their location. Needed for reflections. + varying: crate::FastHashMap<String, VaryingLocation>, +} + +impl<'a, W: Write> Writer<'a, W> { + /// Creates a new [`Writer`] instance. + /// + /// # Errors + /// - If the version specified is invalid or supported. + /// - If the entry point couldn't be found in the module. + /// - If the version specified doesn't support some used features. + pub fn new( + out: W, + module: &'a crate::Module, + info: &'a valid::ModuleInfo, + options: &'a Options, + pipeline_options: &'a PipelineOptions, + policies: proc::BoundsCheckPolicies, + ) -> Result<Self, Error> { + // Check if the requested version is supported + if !options.version.is_supported() { + log::error!("Version {}", options.version); + return Err(Error::VersionNotSupported); + } + + // Try to find the entry point and corresponding index + let ep_idx = module + .entry_points + .iter() + .position(|ep| { + pipeline_options.shader_stage == ep.stage && pipeline_options.entry_point == ep.name + }) + .ok_or(Error::EntryPointNotFound)?; + + // Generate a map with names required to write the module + let mut names = crate::FastHashMap::default(); + let mut namer = proc::Namer::default(); + namer.reset( + module, + keywords::RESERVED_KEYWORDS, + &[], + &[], + &[ + "gl_", // all GL built-in variables + "_group", // all normal bindings + "_push_constant_binding_", // all push constant bindings + ], + &mut names, + ); + + // Build the instance + let mut this = Self { + module, + info, + out, + options, + policies, + + namer, + features: FeaturesManager::new(), + names, + reflection_names_globals: crate::FastHashMap::default(), + entry_point: &module.entry_points[ep_idx], + entry_point_idx: ep_idx as u16, + multiview: pipeline_options.multiview, + block_id: IdGenerator::default(), + named_expressions: Default::default(), + need_bake_expressions: Default::default(), + varying: Default::default(), + }; + + // Find all features required to print this module + this.collect_required_features()?; + + Ok(this) + } + + /// Writes the [`Module`](crate::Module) as glsl to the output + /// + /// # Notes + /// If an error occurs while writing, the output might have been written partially + /// + /// # Panics + /// Might panic if the module is invalid + pub fn write(&mut self) -> Result<ReflectionInfo, Error> { + // We use `writeln!(self.out)` throughout the write to add newlines + // to make the output more readable + + let es = self.options.version.is_es(); + + // Write the version (It must be the first thing or it isn't a valid glsl output) + writeln!(self.out, "#version {}", self.options.version)?; + // Write all the needed extensions + // + // This used to be the last thing being written as it allowed to search for features while + // writing the module saving some loops but some older versions (420 or less) required the + // extensions to appear before being used, even though extensions are part of the + // preprocessor not the processor ¯\_(ツ)_/¯ + self.features.write(self.options, &mut self.out)?; + + // Write the additional extensions + if self + .options + .writer_flags + .contains(WriterFlags::TEXTURE_SHADOW_LOD) + { + // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_shadow_lod.txt + writeln!(self.out, "#extension GL_EXT_texture_shadow_lod : require")?; + } + + // glsl es requires a precision to be specified for floats and ints + // TODO: Should this be user configurable? + if es { + writeln!(self.out)?; + writeln!(self.out, "precision highp float;")?; + writeln!(self.out, "precision highp int;")?; + writeln!(self.out)?; + } + + if self.entry_point.stage == ShaderStage::Compute { + let workgroup_size = self.entry_point.workgroup_size; + writeln!( + self.out, + "layout(local_size_x = {}, local_size_y = {}, local_size_z = {}) in;", + workgroup_size[0], workgroup_size[1], workgroup_size[2] + )?; + writeln!(self.out)?; + } + + if self.entry_point.stage == ShaderStage::Vertex + && !self + .options + .writer_flags + .contains(WriterFlags::DRAW_PARAMETERS) + && self.features.contains(Features::INSTANCE_INDEX) + { + writeln!(self.out, "uniform uint {FIRST_INSTANCE_BINDING};")?; + writeln!(self.out)?; + } + + // Enable early depth tests if needed + if let Some(depth_test) = self.entry_point.early_depth_test { + // If early depth test is supported for this version of GLSL + if self.options.version.supports_early_depth_test() { + writeln!(self.out, "layout(early_fragment_tests) in;")?; + + if let Some(conservative) = depth_test.conservative { + use crate::ConservativeDepth as Cd; + + let depth = match conservative { + Cd::GreaterEqual => "greater", + Cd::LessEqual => "less", + Cd::Unchanged => "unchanged", + }; + writeln!(self.out, "layout (depth_{depth}) out float gl_FragDepth;")?; + } + writeln!(self.out)?; + } else { + log::warn!( + "Early depth testing is not supported for this version of GLSL: {}", + self.options.version + ); + } + } + + if self.entry_point.stage == ShaderStage::Vertex && self.options.version.is_webgl() { + if let Some(multiview) = self.multiview.as_ref() { + writeln!(self.out, "layout(num_views = {multiview}) in;")?; + writeln!(self.out)?; + } + } + + // Write struct types. + // + // This are always ordered because the IR is structured in a way that + // you can't make a struct without adding all of its members first. + for (handle, ty) in self.module.types.iter() { + if let TypeInner::Struct { ref members, .. } = ty.inner { + // Structures ending with runtime-sized arrays can only be + // rendered as shader storage blocks in GLSL, not stand-alone + // struct types. + if !self.module.types[members.last().unwrap().ty] + .inner + .is_dynamically_sized(&self.module.types) + { + let name = &self.names[&NameKey::Type(handle)]; + write!(self.out, "struct {name} ")?; + self.write_struct_body(handle, members)?; + writeln!(self.out, ";")?; + } + } + } + + // Write functions to create special types. + for (type_key, struct_ty) in self.module.special_types.predeclared_types.iter() { + match type_key { + &crate::PredeclaredType::ModfResult { size, width } + | &crate::PredeclaredType::FrexpResult { size, width } => { + let arg_type_name_owner; + let arg_type_name = if let Some(size) = size { + arg_type_name_owner = + format!("{}vec{}", if width == 8 { "d" } else { "" }, size as u8); + &arg_type_name_owner + } else if width == 8 { + "double" + } else { + "float" + }; + + let other_type_name_owner; + let (defined_func_name, called_func_name, other_type_name) = + if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) { + (MODF_FUNCTION, "modf", arg_type_name) + } else { + let other_type_name = if let Some(size) = size { + other_type_name_owner = format!("ivec{}", size as u8); + &other_type_name_owner + } else { + "int" + }; + (FREXP_FUNCTION, "frexp", other_type_name) + }; + + let struct_name = &self.names[&NameKey::Type(*struct_ty)]; + + writeln!(self.out)?; + if !self.options.version.supports_frexp_function() + && matches!(type_key, &crate::PredeclaredType::FrexpResult { .. }) + { + writeln!( + self.out, + "{struct_name} {defined_func_name}({arg_type_name} arg) {{ + {other_type_name} other = arg == {arg_type_name}(0) ? {other_type_name}(0) : {other_type_name}({arg_type_name}(1) + log2(arg)); + {arg_type_name} fract = arg * exp2({arg_type_name}(-other)); + return {struct_name}(fract, other); +}}", + )?; + } else { + writeln!( + self.out, + "{struct_name} {defined_func_name}({arg_type_name} arg) {{ + {other_type_name} other; + {arg_type_name} fract = {called_func_name}(arg, other); + return {struct_name}(fract, other); +}}", + )?; + } + } + &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {} + } + } + + // Write all named constants + let mut constants = self + .module + .constants + .iter() + .filter(|&(_, c)| c.name.is_some()) + .peekable(); + while let Some((handle, _)) = constants.next() { + self.write_global_constant(handle)?; + // Add extra newline for readability on last iteration + if constants.peek().is_none() { + writeln!(self.out)?; + } + } + + let ep_info = self.info.get_entry_point(self.entry_point_idx as usize); + + // Write the globals + // + // Unless explicitly disabled with WriterFlags::INCLUDE_UNUSED_ITEMS, + // we filter all globals that aren't used by the selected entry point as they might be + // interfere with each other (i.e. two globals with the same location but different with + // different classes) + let include_unused = self + .options + .writer_flags + .contains(WriterFlags::INCLUDE_UNUSED_ITEMS); + for (handle, global) in self.module.global_variables.iter() { + let is_unused = ep_info[handle].is_empty(); + if !include_unused && is_unused { + continue; + } + + match self.module.types[global.ty].inner { + // We treat images separately because they might require + // writing the storage format + TypeInner::Image { + mut dim, + arrayed, + class, + } => { + // Gather the storage format if needed + let storage_format_access = match self.module.types[global.ty].inner { + TypeInner::Image { + class: crate::ImageClass::Storage { format, access }, + .. + } => Some((format, access)), + _ => None, + }; + + if dim == crate::ImageDimension::D1 && es { + dim = crate::ImageDimension::D2 + } + + // Gether the location if needed + let layout_binding = if self.options.version.supports_explicit_locations() { + let br = global.binding.as_ref().unwrap(); + self.options.binding_map.get(br).cloned() + } else { + None + }; + + // Write all the layout qualifiers + if layout_binding.is_some() || storage_format_access.is_some() { + write!(self.out, "layout(")?; + if let Some(binding) = layout_binding { + write!(self.out, "binding = {binding}")?; + } + if let Some((format, _)) = storage_format_access { + let format_str = glsl_storage_format(format)?; + let separator = match layout_binding { + Some(_) => ",", + None => "", + }; + write!(self.out, "{separator}{format_str}")?; + } + write!(self.out, ") ")?; + } + + if let Some((_, access)) = storage_format_access { + self.write_storage_access(access)?; + } + + // All images in glsl are `uniform` + // The trailing space is important + write!(self.out, "uniform ")?; + + // write the type + // + // This is way we need the leading space because `write_image_type` doesn't add + // any spaces at the beginning or end + self.write_image_type(dim, arrayed, class)?; + + // Finally write the name and end the global with a `;` + // The leading space is important + let global_name = self.get_global_name(handle, global); + writeln!(self.out, " {global_name};")?; + writeln!(self.out)?; + + self.reflection_names_globals.insert(handle, global_name); + } + // glsl has no concept of samplers so we just ignore it + TypeInner::Sampler { .. } => continue, + // All other globals are written by `write_global` + _ => { + self.write_global(handle, global)?; + // Add a newline (only for readability) + writeln!(self.out)?; + } + } + } + + for arg in self.entry_point.function.arguments.iter() { + self.write_varying(arg.binding.as_ref(), arg.ty, false)?; + } + if let Some(ref result) = self.entry_point.function.result { + self.write_varying(result.binding.as_ref(), result.ty, true)?; + } + writeln!(self.out)?; + + // Write all regular functions + for (handle, function) in self.module.functions.iter() { + // Check that the function doesn't use globals that aren't supported + // by the current entry point + if !include_unused && !ep_info.dominates_global_use(&self.info[handle]) { + continue; + } + + let fun_info = &self.info[handle]; + + // Skip functions that that are not compatible with this entry point's stage. + // + // When validation is enabled, it rejects modules whose entry points try to call + // incompatible functions, so if we got this far, then any functions incompatible + // with our selected entry point must not be used. + // + // When validation is disabled, `fun_info.available_stages` is always just + // `ShaderStages::all()`, so this will write all functions in the module, and + // the downstream GLSL compiler will catch any problems. + if !fun_info.available_stages.contains(ep_info.available_stages) { + continue; + } + + // Write the function + self.write_function(back::FunctionType::Function(handle), function, fun_info)?; + + writeln!(self.out)?; + } + + self.write_function( + back::FunctionType::EntryPoint(self.entry_point_idx), + &self.entry_point.function, + ep_info, + )?; + + // Add newline at the end of file + writeln!(self.out)?; + + // Collect all reflection info and return it to the user + self.collect_reflection_info() + } + + fn write_array_size( + &mut self, + base: Handle<crate::Type>, + size: crate::ArraySize, + ) -> BackendResult { + write!(self.out, "[")?; + + // Write the array size + // Writes nothing if `ArraySize::Dynamic` + match size { + crate::ArraySize::Constant(size) => { + write!(self.out, "{size}")?; + } + crate::ArraySize::Dynamic => (), + } + + write!(self.out, "]")?; + + if let TypeInner::Array { + base: next_base, + size: next_size, + .. + } = self.module.types[base].inner + { + self.write_array_size(next_base, next_size)?; + } + + Ok(()) + } + + /// Helper method used to write value types + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_value_type(&mut self, inner: &TypeInner) -> BackendResult { + match *inner { + // Scalars are simple we just get the full name from `glsl_scalar` + TypeInner::Scalar(scalar) + | TypeInner::Atomic(scalar) + | TypeInner::ValuePointer { + size: None, + scalar, + space: _, + } => write!(self.out, "{}", glsl_scalar(scalar)?.full)?, + // Vectors are just `gvecN` where `g` is the scalar prefix and `N` is the vector size + TypeInner::Vector { size, scalar } + | TypeInner::ValuePointer { + size: Some(size), + scalar, + space: _, + } => write!(self.out, "{}vec{}", glsl_scalar(scalar)?.prefix, size as u8)?, + // Matrices are written with `gmatMxN` where `g` is the scalar prefix (only floats and + // doubles are allowed), `M` is the columns count and `N` is the rows count + // + // glsl supports a matrix shorthand `gmatN` where `N` = `M` but it doesn't justify the + // extra branch to write matrices this way + TypeInner::Matrix { + columns, + rows, + scalar, + } => write!( + self.out, + "{}mat{}x{}", + glsl_scalar(scalar)?.prefix, + columns as u8, + rows as u8 + )?, + // GLSL arrays are written as `type name[size]` + // Here we only write the size of the array i.e. `[size]` + // Base `type` and `name` should be written outside + TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?, + // Write all variants instead of `_` so that if new variants are added a + // no exhaustiveness error is thrown + TypeInner::Pointer { .. } + | TypeInner::Struct { .. } + | TypeInner::Image { .. } + | TypeInner::Sampler { .. } + | TypeInner::AccelerationStructure + | TypeInner::RayQuery + | TypeInner::BindingArray { .. } => { + return Err(Error::Custom(format!("Unable to write type {inner:?}"))) + } + } + + Ok(()) + } + + /// Helper method used to write non image/sampler types + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_type(&mut self, ty: Handle<crate::Type>) -> BackendResult { + match self.module.types[ty].inner { + // glsl has no pointer types so just write types as normal and loads are skipped + TypeInner::Pointer { base, .. } => self.write_type(base), + // glsl structs are written as just the struct name + TypeInner::Struct { .. } => { + // Get the struct name + let name = &self.names[&NameKey::Type(ty)]; + write!(self.out, "{name}")?; + Ok(()) + } + // glsl array has the size separated from the base type + TypeInner::Array { base, .. } => self.write_type(base), + ref other => self.write_value_type(other), + } + } + + /// Helper method to write a image type + /// + /// # Notes + /// Adds no leading or trailing whitespace + fn write_image_type( + &mut self, + dim: crate::ImageDimension, + arrayed: bool, + class: crate::ImageClass, + ) -> BackendResult { + // glsl images consist of four parts the scalar prefix, the image "type", the dimensions + // and modifiers + // + // There exists two image types + // - sampler - for sampled images + // - image - for storage images + // + // There are three possible modifiers that can be used together and must be written in + // this order to be valid + // - MS - used if it's a multisampled image + // - Array - used if it's an image array + // - Shadow - used if it's a depth image + use crate::ImageClass as Ic; + + let (base, kind, ms, comparison) = match class { + Ic::Sampled { kind, multi: true } => ("sampler", kind, "MS", ""), + Ic::Sampled { kind, multi: false } => ("sampler", kind, "", ""), + Ic::Depth { multi: true } => ("sampler", crate::ScalarKind::Float, "MS", ""), + Ic::Depth { multi: false } => ("sampler", crate::ScalarKind::Float, "", "Shadow"), + Ic::Storage { format, .. } => ("image", format.into(), "", ""), + }; + + let precision = if self.options.version.is_es() { + "highp " + } else { + "" + }; + + write!( + self.out, + "{}{}{}{}{}{}{}", + precision, + glsl_scalar(crate::Scalar { kind, width: 4 })?.prefix, + base, + glsl_dimension(dim), + ms, + if arrayed { "Array" } else { "" }, + comparison + )?; + + Ok(()) + } + + /// Helper method used to write non images/sampler globals + /// + /// # Notes + /// Adds a newline + /// + /// # Panics + /// If the global has type sampler + fn write_global( + &mut self, + handle: Handle<crate::GlobalVariable>, + global: &crate::GlobalVariable, + ) -> BackendResult { + if self.options.version.supports_explicit_locations() { + if let Some(ref br) = global.binding { + match self.options.binding_map.get(br) { + Some(binding) => { + let layout = match global.space { + crate::AddressSpace::Storage { .. } => { + if self.options.version.supports_std430_layout() { + "std430, " + } else { + "std140, " + } + } + crate::AddressSpace::Uniform => "std140, ", + _ => "", + }; + write!(self.out, "layout({layout}binding = {binding}) ")? + } + None => { + log::debug!("unassigned binding for {:?}", global.name); + if let crate::AddressSpace::Storage { .. } = global.space { + if self.options.version.supports_std430_layout() { + write!(self.out, "layout(std430) ")? + } + } + } + } + } + } + + if let crate::AddressSpace::Storage { access } = global.space { + self.write_storage_access(access)?; + } + + if let Some(storage_qualifier) = glsl_storage_qualifier(global.space) { + write!(self.out, "{storage_qualifier} ")?; + } + + match global.space { + crate::AddressSpace::Private => { + self.write_simple_global(handle, global)?; + } + crate::AddressSpace::WorkGroup => { + self.write_simple_global(handle, global)?; + } + crate::AddressSpace::PushConstant => { + self.write_simple_global(handle, global)?; + } + crate::AddressSpace::Uniform => { + self.write_interface_block(handle, global)?; + } + crate::AddressSpace::Storage { .. } => { + self.write_interface_block(handle, global)?; + } + // A global variable in the `Function` address space is a + // contradiction in terms. + crate::AddressSpace::Function => unreachable!(), + // Textures and samplers are handled directly in `Writer::write`. + crate::AddressSpace::Handle => unreachable!(), + } + + Ok(()) + } + + fn write_simple_global( + &mut self, + handle: Handle<crate::GlobalVariable>, + global: &crate::GlobalVariable, + ) -> BackendResult { + self.write_type(global.ty)?; + write!(self.out, " ")?; + self.write_global_name(handle, global)?; + + if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner { + self.write_array_size(base, size)?; + } + + if global.space.initializable() && is_value_init_supported(self.module, global.ty) { + write!(self.out, " = ")?; + if let Some(init) = global.init { + self.write_const_expr(init)?; + } else { + self.write_zero_init_value(global.ty)?; + } + } + + writeln!(self.out, ";")?; + + if let crate::AddressSpace::PushConstant = global.space { + let global_name = self.get_global_name(handle, global); + self.reflection_names_globals.insert(handle, global_name); + } + + Ok(()) + } + + /// Write an interface block for a single Naga global. + /// + /// Write `block_name { members }`. Since `block_name` must be unique + /// between blocks and structs, we add `_block_ID` where `ID` is a + /// `IdGenerator` generated number. Write `members` in the same way we write + /// a struct's members. + fn write_interface_block( + &mut self, + handle: Handle<crate::GlobalVariable>, + global: &crate::GlobalVariable, + ) -> BackendResult { + // Write the block name, it's just the struct name appended with `_block_ID` + let ty_name = &self.names[&NameKey::Type(global.ty)]; + let block_name = format!( + "{}_block_{}{:?}", + // avoid double underscores as they are reserved in GLSL + ty_name.trim_end_matches('_'), + self.block_id.generate(), + self.entry_point.stage, + ); + write!(self.out, "{block_name} ")?; + self.reflection_names_globals.insert(handle, block_name); + + match self.module.types[global.ty].inner { + crate::TypeInner::Struct { ref members, .. } + if self.module.types[members.last().unwrap().ty] + .inner + .is_dynamically_sized(&self.module.types) => + { + // Structs with dynamically sized arrays must have their + // members lifted up as members of the interface block. GLSL + // can't write such struct types anyway. + self.write_struct_body(global.ty, members)?; + write!(self.out, " ")?; + self.write_global_name(handle, global)?; + } + _ => { + // A global of any other type is written as the sole member + // of the interface block. Since the interface block is + // anonymous, this becomes visible in the global scope. + write!(self.out, "{{ ")?; + self.write_type(global.ty)?; + write!(self.out, " ")?; + self.write_global_name(handle, global)?; + if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner { + self.write_array_size(base, size)?; + } + write!(self.out, "; }}")?; + } + } + + writeln!(self.out, ";")?; + + Ok(()) + } + + /// Helper method used to find which expressions of a given function require baking + /// + /// # Notes + /// Clears `need_bake_expressions` set before adding to it + fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) { + use crate::Expression; + self.need_bake_expressions.clear(); + for (fun_handle, expr) in func.expressions.iter() { + let expr_info = &info[fun_handle]; + let min_ref_count = func.expressions[fun_handle].bake_ref_count(); + if min_ref_count <= expr_info.ref_count { + self.need_bake_expressions.insert(fun_handle); + } + + let inner = expr_info.ty.inner_with(&self.module.types); + + if let Expression::Math { fun, arg, arg1, .. } = *expr { + match fun { + crate::MathFunction::Dot => { + // if the expression is a Dot product with integer arguments, + // then the args needs baking as well + if let TypeInner::Scalar(crate::Scalar { kind, .. }) = *inner { + match kind { + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + self.need_bake_expressions.insert(arg); + self.need_bake_expressions.insert(arg1.unwrap()); + } + _ => {} + } + } + } + crate::MathFunction::CountLeadingZeros => { + if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { + self.need_bake_expressions.insert(arg); + } + } + _ => {} + } + } + } + } + + /// Helper method used to get a name for a global + /// + /// Globals have different naming schemes depending on their binding: + /// - Globals without bindings use the name from the [`Namer`](crate::proc::Namer) + /// - Globals with resource binding are named `_group_X_binding_Y` where `X` + /// is the group and `Y` is the binding + fn get_global_name( + &self, + handle: Handle<crate::GlobalVariable>, + global: &crate::GlobalVariable, + ) -> String { + match (&global.binding, global.space) { + (&Some(ref br), _) => { + format!( + "_group_{}_binding_{}_{}", + br.group, + br.binding, + self.entry_point.stage.to_str() + ) + } + (&None, crate::AddressSpace::PushConstant) => { + format!("_push_constant_binding_{}", self.entry_point.stage.to_str()) + } + (&None, _) => self.names[&NameKey::GlobalVariable(handle)].clone(), + } + } + + /// Helper method used to write a name for a global without additional heap allocation + fn write_global_name( + &mut self, + handle: Handle<crate::GlobalVariable>, + global: &crate::GlobalVariable, + ) -> BackendResult { + match (&global.binding, global.space) { + (&Some(ref br), _) => write!( + self.out, + "_group_{}_binding_{}_{}", + br.group, + br.binding, + self.entry_point.stage.to_str() + )?, + (&None, crate::AddressSpace::PushConstant) => write!( + self.out, + "_push_constant_binding_{}", + self.entry_point.stage.to_str() + )?, + (&None, _) => write!( + self.out, + "{}", + &self.names[&NameKey::GlobalVariable(handle)] + )?, + } + + Ok(()) + } + + /// Write a GLSL global that will carry a Naga entry point's argument or return value. + /// + /// A Naga entry point's arguments and return value are rendered in GLSL as + /// variables at global scope with the `in` and `out` storage qualifiers. + /// The code we generate for `main` loads from all the `in` globals into + /// appropriately named locals. Before it returns, `main` assigns the + /// components of its return value into all the `out` globals. + /// + /// This function writes a declaration for one such GLSL global, + /// representing a value passed into or returned from [`self.entry_point`] + /// that has a [`Location`] binding. The global's name is generated based on + /// the location index and the shader stages being connected; see + /// [`VaryingName`]. This means we don't need to know the names of + /// arguments, just their types and bindings. + /// + /// Emit nothing for entry point arguments or return values with [`BuiltIn`] + /// bindings; `main` will read from or assign to the appropriate GLSL + /// special variable; these are pre-declared. As an exception, we do declare + /// `gl_Position` or `gl_FragCoord` with the `invariant` qualifier if + /// needed. + /// + /// Use `output` together with [`self.entry_point.stage`] to determine which + /// shader stages are being connected, and choose the `in` or `out` storage + /// qualifier. + /// + /// [`self.entry_point`]: Writer::entry_point + /// [`self.entry_point.stage`]: crate::EntryPoint::stage + /// [`Location`]: crate::Binding::Location + /// [`BuiltIn`]: crate::Binding::BuiltIn + fn write_varying( + &mut self, + binding: Option<&crate::Binding>, + ty: Handle<crate::Type>, + output: bool, + ) -> Result<(), Error> { + // For a struct, emit a separate global for each member with a binding. + if let crate::TypeInner::Struct { ref members, .. } = self.module.types[ty].inner { + for member in members { + self.write_varying(member.binding.as_ref(), member.ty, output)?; + } + return Ok(()); + } + + let binding = match binding { + None => return Ok(()), + Some(binding) => binding, + }; + + let (location, interpolation, sampling, second_blend_source) = match *binding { + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source, + } => (location, interpolation, sampling, second_blend_source), + crate::Binding::BuiltIn(built_in) => { + if let crate::BuiltIn::Position { invariant: true } = built_in { + match (self.options.version, self.entry_point.stage) { + ( + Version::Embedded { + version: 300, + is_webgl: true, + }, + ShaderStage::Fragment, + ) => { + // `invariant gl_FragCoord` is not allowed in WebGL2 and possibly + // OpenGL ES in general (waiting on confirmation). + // + // See https://github.com/KhronosGroup/WebGL/issues/3518 + } + _ => { + writeln!( + self.out, + "invariant {};", + glsl_built_in( + built_in, + VaryingOptions::from_writer_options(self.options, output) + ) + )?; + } + } + } + return Ok(()); + } + }; + + // Write the interpolation modifier if needed + // + // We ignore all interpolation and auxiliary modifiers that aren't used in fragment + // shaders' input globals or vertex shaders' output globals. + let emit_interpolation_and_auxiliary = match self.entry_point.stage { + ShaderStage::Vertex => output, + ShaderStage::Fragment => !output, + ShaderStage::Compute => false, + }; + + // Write the I/O locations, if allowed + let io_location = if self.options.version.supports_explicit_locations() + || !emit_interpolation_and_auxiliary + { + if self.options.version.supports_io_locations() { + if second_blend_source { + write!(self.out, "layout(location = {location}, index = 1) ")?; + } else { + write!(self.out, "layout(location = {location}) ")?; + } + None + } else { + Some(VaryingLocation { + location, + index: second_blend_source as u32, + }) + } + } else { + None + }; + + // Write the interpolation qualifier. + if let Some(interp) = interpolation { + if emit_interpolation_and_auxiliary { + write!(self.out, "{} ", glsl_interpolation(interp))?; + } + } + + // Write the sampling auxiliary qualifier. + // + // Before GLSL 4.2, the `centroid` and `sample` qualifiers were required to appear + // immediately before the `in` / `out` qualifier, so we'll just follow that rule + // here, regardless of the version. + if let Some(sampling) = sampling { + if emit_interpolation_and_auxiliary { + if let Some(qualifier) = glsl_sampling(sampling) { + write!(self.out, "{qualifier} ")?; + } + } + } + + // Write the input/output qualifier. + write!(self.out, "{} ", if output { "out" } else { "in" })?; + + // Write the type + // `write_type` adds no leading or trailing spaces + self.write_type(ty)?; + + // Finally write the global name and end the global with a `;` and a newline + // Leading space is important + let vname = VaryingName { + binding: &crate::Binding::Location { + location, + interpolation: None, + sampling: None, + second_blend_source, + }, + stage: self.entry_point.stage, + options: VaryingOptions::from_writer_options(self.options, output), + }; + writeln!(self.out, " {vname};")?; + + if let Some(location) = io_location { + self.varying.insert(vname.to_string(), location); + } + + Ok(()) + } + + /// Helper method used to write functions (both entry points and regular functions) + /// + /// # Notes + /// Adds a newline + fn write_function( + &mut self, + ty: back::FunctionType, + func: &crate::Function, + info: &valid::FunctionInfo, + ) -> BackendResult { + // Create a function context for the function being written + let ctx = back::FunctionCtx { + ty, + info, + expressions: &func.expressions, + named_expressions: &func.named_expressions, + }; + + self.named_expressions.clear(); + self.update_expressions_to_bake(func, info); + + // Write the function header + // + // glsl headers are the same as in c: + // `ret_type name(args)` + // `ret_type` is the return type + // `name` is the function name + // `args` is a comma separated list of `type name` + // | - `type` is the argument type + // | - `name` is the argument name + + // Start by writing the return type if any otherwise write void + // This is the only place where `void` is a valid type + // (though it's more a keyword than a type) + if let back::FunctionType::EntryPoint(_) = ctx.ty { + write!(self.out, "void")?; + } else if let Some(ref result) = func.result { + self.write_type(result.ty)?; + if let TypeInner::Array { base, size, .. } = self.module.types[result.ty].inner { + self.write_array_size(base, size)? + } + } else { + write!(self.out, "void")?; + } + + // Write the function name and open parentheses for the argument list + let function_name = match ctx.ty { + back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)], + back::FunctionType::EntryPoint(_) => "main", + }; + write!(self.out, " {function_name}(")?; + + // Write the comma separated argument list + // + // We need access to `Self` here so we use the reference passed to the closure as an + // argument instead of capturing as that would cause a borrow checker error + let arguments = match ctx.ty { + back::FunctionType::EntryPoint(_) => &[][..], + back::FunctionType::Function(_) => &func.arguments, + }; + let arguments: Vec<_> = arguments + .iter() + .enumerate() + .filter(|&(_, arg)| match self.module.types[arg.ty].inner { + TypeInner::Sampler { .. } => false, + _ => true, + }) + .collect(); + self.write_slice(&arguments, |this, _, &(i, arg)| { + // Write the argument type + match this.module.types[arg.ty].inner { + // We treat images separately because they might require + // writing the storage format + TypeInner::Image { + dim, + arrayed, + class, + } => { + // Write the storage format if needed + if let TypeInner::Image { + class: crate::ImageClass::Storage { format, .. }, + .. + } = this.module.types[arg.ty].inner + { + write!(this.out, "layout({}) ", glsl_storage_format(format)?)?; + } + + // write the type + // + // This is way we need the leading space because `write_image_type` doesn't add + // any spaces at the beginning or end + this.write_image_type(dim, arrayed, class)?; + } + TypeInner::Pointer { base, .. } => { + // write parameter qualifiers + write!(this.out, "inout ")?; + this.write_type(base)?; + } + // All other types are written by `write_type` + _ => { + this.write_type(arg.ty)?; + } + } + + // Write the argument name + // The leading space is important + write!(this.out, " {}", &this.names[&ctx.argument_key(i as u32)])?; + + // Write array size + match this.module.types[arg.ty].inner { + TypeInner::Array { base, size, .. } => { + this.write_array_size(base, size)?; + } + TypeInner::Pointer { base, .. } => { + if let TypeInner::Array { base, size, .. } = this.module.types[base].inner { + this.write_array_size(base, size)?; + } + } + _ => {} + } + + Ok(()) + })?; + + // Close the parentheses and open braces to start the function body + writeln!(self.out, ") {{")?; + + if self.options.zero_initialize_workgroup_memory + && ctx.ty.is_compute_entry_point(self.module) + { + self.write_workgroup_variables_initialization(&ctx)?; + } + + // Compose the function arguments from globals, in case of an entry point. + if let back::FunctionType::EntryPoint(ep_index) = ctx.ty { + let stage = self.module.entry_points[ep_index as usize].stage; + for (index, arg) in func.arguments.iter().enumerate() { + write!(self.out, "{}", back::INDENT)?; + self.write_type(arg.ty)?; + let name = &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]; + write!(self.out, " {name}")?; + write!(self.out, " = ")?; + match self.module.types[arg.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + self.write_type(arg.ty)?; + write!(self.out, "(")?; + for (index, member) in members.iter().enumerate() { + let varying_name = VaryingName { + binding: member.binding.as_ref().unwrap(), + stage, + options: VaryingOptions::from_writer_options(self.options, false), + }; + if index != 0 { + write!(self.out, ", ")?; + } + write!(self.out, "{varying_name}")?; + } + writeln!(self.out, ");")?; + } + _ => { + let varying_name = VaryingName { + binding: arg.binding.as_ref().unwrap(), + stage, + options: VaryingOptions::from_writer_options(self.options, false), + }; + writeln!(self.out, "{varying_name};")?; + } + } + } + } + + // Write all function locals + // Locals are `type name (= init)?;` where the init part (including the =) are optional + // + // Always adds a newline + for (handle, local) in func.local_variables.iter() { + // Write indentation (only for readability) and the type + // `write_type` adds no trailing space + write!(self.out, "{}", back::INDENT)?; + self.write_type(local.ty)?; + + // Write the local name + // The leading space is important + write!(self.out, " {}", self.names[&ctx.name_key(handle)])?; + // Write size for array type + if let TypeInner::Array { base, size, .. } = self.module.types[local.ty].inner { + self.write_array_size(base, size)?; + } + // Write the local initializer if needed + if let Some(init) = local.init { + // Put the equal signal only if there's a initializer + // The leading and trailing spaces aren't needed but help with readability + write!(self.out, " = ")?; + + // Write the constant + // `write_constant` adds no trailing or leading space/newline + self.write_expr(init, &ctx)?; + } else if is_value_init_supported(self.module, local.ty) { + write!(self.out, " = ")?; + self.write_zero_init_value(local.ty)?; + } + + // Finish the local with `;` and add a newline (only for readability) + writeln!(self.out, ";")? + } + + // Write the function body (statement list) + for sta in func.body.iter() { + // Write a statement, the indentation should always be 1 when writing the function body + // `write_stmt` adds a newline + self.write_stmt(sta, &ctx, back::Level(1))?; + } + + // Close braces and add a newline + writeln!(self.out, "}}")?; + + Ok(()) + } + + fn write_workgroup_variables_initialization( + &mut self, + ctx: &back::FunctionCtx, + ) -> BackendResult { + let mut vars = self + .module + .global_variables + .iter() + .filter(|&(handle, var)| { + !ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + .peekable(); + + if vars.peek().is_some() { + let level = back::Level(1); + + writeln!(self.out, "{level}if (gl_LocalInvocationID == uvec3(0u)) {{")?; + + for (handle, var) in vars { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{}{} = ", level.next(), name)?; + self.write_zero_init_value(var.ty)?; + writeln!(self.out, ";")?; + } + + writeln!(self.out, "{level}}}")?; + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + } + + Ok(()) + } + + /// Write a list of comma separated `T` values using a writer function `F`. + /// + /// The writer function `F` receives a mutable reference to `self` that if needed won't cause + /// borrow checker issues (using for example a closure with `self` will cause issues), the + /// second argument is the 0 based index of the element on the list, and the last element is + /// a reference to the element `T` being written + /// + /// # Notes + /// - Adds no newlines or leading/trailing whitespace + /// - The last element won't have a trailing `,` + fn write_slice<T, F: FnMut(&mut Self, u32, &T) -> BackendResult>( + &mut self, + data: &[T], + mut f: F, + ) -> BackendResult { + // Loop through `data` invoking `f` for each element + for (index, item) in data.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + f(self, index as u32, item)?; + } + + Ok(()) + } + + /// Helper method used to write global constants + fn write_global_constant(&mut self, handle: Handle<crate::Constant>) -> BackendResult { + write!(self.out, "const ")?; + let constant = &self.module.constants[handle]; + self.write_type(constant.ty)?; + let name = &self.names[&NameKey::Constant(handle)]; + write!(self.out, " {name}")?; + if let TypeInner::Array { base, size, .. } = self.module.types[constant.ty].inner { + self.write_array_size(base, size)?; + } + write!(self.out, " = ")?; + self.write_const_expr(constant.init)?; + writeln!(self.out, ";")?; + Ok(()) + } + + /// Helper method used to output a dot product as an arithmetic expression + /// + fn write_dot_product( + &mut self, + arg: Handle<crate::Expression>, + arg1: Handle<crate::Expression>, + size: usize, + ctx: &back::FunctionCtx, + ) -> BackendResult { + // Write parentheses around the dot product expression to prevent operators + // with different precedences from applying earlier. + write!(self.out, "(")?; + + // Cycle trough all the components of the vector + for index in 0..size { + let component = back::COMPONENTS[index]; + // Write the addition to the previous product + // This will print an extra '+' at the beginning but that is fine in glsl + write!(self.out, " + ")?; + // Write the first vector expression, this expression is marked to be + // cached so unless it can't be cached (for example, it's a Constant) + // it shouldn't produce large expressions. + self.write_expr(arg, ctx)?; + // Access the current component on the first vector + write!(self.out, ".{component} * ")?; + // Write the second vector expression, this expression is marked to be + // cached so unless it can't be cached (for example, it's a Constant) + // it shouldn't produce large expressions. + self.write_expr(arg1, ctx)?; + // Access the current component on the second vector + write!(self.out, ".{component}")?; + } + + write!(self.out, ")")?; + Ok(()) + } + + /// Helper method used to write structs + /// + /// # Notes + /// Ends in a newline + fn write_struct_body( + &mut self, + handle: Handle<crate::Type>, + members: &[crate::StructMember], + ) -> BackendResult { + // glsl structs are written as in C + // `struct name() { members };` + // | `struct` is a keyword + // | `name` is the struct name + // | `members` is a semicolon separated list of `type name` + // | `type` is the member type + // | `name` is the member name + writeln!(self.out, "{{")?; + + for (idx, member) in members.iter().enumerate() { + // The indentation is only for readability + write!(self.out, "{}", back::INDENT)?; + + match self.module.types[member.ty].inner { + TypeInner::Array { + base, + size, + stride: _, + } => { + self.write_type(base)?; + write!( + self.out, + " {}", + &self.names[&NameKey::StructMember(handle, idx as u32)] + )?; + // Write [size] + self.write_array_size(base, size)?; + // Newline is important + writeln!(self.out, ";")?; + } + _ => { + // Write the member type + // Adds no trailing space + self.write_type(member.ty)?; + + // Write the member name and put a semicolon + // The leading space is important + // All members must have a semicolon even the last one + writeln!( + self.out, + " {};", + &self.names[&NameKey::StructMember(handle, idx as u32)] + )?; + } + } + } + + write!(self.out, "}}")?; + Ok(()) + } + + /// Helper method used to write statements + /// + /// # Notes + /// Always adds a newline + fn write_stmt( + &mut self, + sta: &crate::Statement, + ctx: &back::FunctionCtx, + level: back::Level, + ) -> BackendResult { + use crate::Statement; + + match *sta { + // This is where we can generate intermediate constants for some expression types. + Statement::Emit(ref range) => { + for handle in range.clone() { + let ptr_class = ctx.resolve_type(handle, &self.module.types).pointer_space(); + let expr_name = if ptr_class.is_some() { + // GLSL can't save a pointer-valued expression in a variable, + // but we shouldn't ever need to: they should never be named expressions, + // and none of the expression types flagged by bake_ref_count can be pointer-valued. + None + } else if let Some(name) = ctx.named_expressions.get(&handle) { + // Front end provides names for all variables at the start of writing. + // But we write them to step by step. We need to recache them + // Otherwise, we could accidentally write variable name instead of full expression. + // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords. + Some(self.namer.call(name)) + } else if self.need_bake_expressions.contains(&handle) { + Some(format!("{}{}", back::BAKE_PREFIX, handle.index())) + } else { + None + }; + + // If we are going to write an `ImageLoad` next and the target image + // is sampled and we are using the `Restrict` policy for bounds + // checking images we need to write a local holding the clamped lod. + if let crate::Expression::ImageLoad { + image, + level: Some(level_expr), + .. + } = ctx.expressions[handle] + { + if let TypeInner::Image { + class: crate::ImageClass::Sampled { .. }, + .. + } = *ctx.resolve_type(image, &self.module.types) + { + if let proc::BoundsCheckPolicy::Restrict = self.policies.image_load { + write!(self.out, "{level}")?; + self.write_clamped_lod(ctx, handle, image, level_expr)? + } + } + } + + if let Some(name) = expr_name { + write!(self.out, "{level}")?; + self.write_named_expr(handle, name, handle, ctx)?; + } + } + } + // Blocks are simple we just need to write the block statements between braces + // We could also just print the statements but this is more readable and maps more + // closely to the IR + Statement::Block(ref block) => { + write!(self.out, "{level}")?; + writeln!(self.out, "{{")?; + for sta in block.iter() { + // Increase the indentation to help with readability + self.write_stmt(sta, ctx, level.next())? + } + writeln!(self.out, "{level}}}")? + } + // Ifs are written as in C: + // ``` + // if(condition) { + // accept + // } else { + // reject + // } + // ``` + Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{level}")?; + write!(self.out, "if (")?; + self.write_expr(condition, ctx)?; + writeln!(self.out, ") {{")?; + + for sta in accept { + // Increase indentation to help with readability + self.write_stmt(sta, ctx, level.next())?; + } + + // If there are no statements in the reject block we skip writing it + // This is only for readability + if !reject.is_empty() { + writeln!(self.out, "{level}}} else {{")?; + + for sta in reject { + // Increase indentation to help with readability + self.write_stmt(sta, ctx, level.next())?; + } + } + + writeln!(self.out, "{level}}}")? + } + // Switch are written as in C: + // ``` + // switch (selector) { + // // Fallthrough + // case label: + // block + // // Non fallthrough + // case label: + // block + // break; + // default: + // block + // } + // ``` + // Where the `default` case happens isn't important but we put it last + // so that we don't need to print a `break` for it + Statement::Switch { + selector, + ref cases, + } => { + // Start the switch + write!(self.out, "{level}")?; + write!(self.out, "switch(")?; + self.write_expr(selector, ctx)?; + writeln!(self.out, ") {{")?; + + // Write all cases + let l2 = level.next(); + for case in cases { + match case.value { + crate::SwitchValue::I32(value) => write!(self.out, "{l2}case {value}:")?, + crate::SwitchValue::U32(value) => write!(self.out, "{l2}case {value}u:")?, + crate::SwitchValue::Default => write!(self.out, "{l2}default:")?, + } + + let write_block_braces = !(case.fall_through && case.body.is_empty()); + if write_block_braces { + writeln!(self.out, " {{")?; + } else { + writeln!(self.out)?; + } + + for sta in case.body.iter() { + self.write_stmt(sta, ctx, l2.next())?; + } + + if !case.fall_through && case.body.last().map_or(true, |s| !s.is_terminator()) { + writeln!(self.out, "{}break;", l2.next())?; + } + + if write_block_braces { + writeln!(self.out, "{l2}}}")?; + } + } + + writeln!(self.out, "{level}}}")? + } + // Loops in naga IR are based on wgsl loops, glsl can emulate the behaviour by using a + // while true loop and appending the continuing block to the body resulting on: + // ``` + // bool loop_init = true; + // while(true) { + // if (!loop_init) { <continuing> } + // loop_init = false; + // <body> + // } + // ``` + Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + if !continuing.is_empty() || break_if.is_some() { + let gate_name = self.namer.call("loop_init"); + writeln!(self.out, "{level}bool {gate_name} = true;")?; + writeln!(self.out, "{level}while(true) {{")?; + let l2 = level.next(); + let l3 = l2.next(); + writeln!(self.out, "{l2}if (!{gate_name}) {{")?; + for sta in continuing { + self.write_stmt(sta, ctx, l3)?; + } + if let Some(condition) = break_if { + write!(self.out, "{l3}if (")?; + self.write_expr(condition, ctx)?; + writeln!(self.out, ") {{")?; + writeln!(self.out, "{}break;", l3.next())?; + writeln!(self.out, "{l3}}}")?; + } + writeln!(self.out, "{l2}}}")?; + writeln!(self.out, "{}{} = false;", level.next(), gate_name)?; + } else { + writeln!(self.out, "{level}while(true) {{")?; + } + for sta in body { + self.write_stmt(sta, ctx, level.next())?; + } + writeln!(self.out, "{level}}}")? + } + // Break, continue and return as written as in C + // `break;` + Statement::Break => { + write!(self.out, "{level}")?; + writeln!(self.out, "break;")? + } + // `continue;` + Statement::Continue => { + write!(self.out, "{level}")?; + writeln!(self.out, "continue;")? + } + // `return expr;`, `expr` is optional + Statement::Return { value } => { + write!(self.out, "{level}")?; + match ctx.ty { + back::FunctionType::Function(_) => { + write!(self.out, "return")?; + // Write the expression to be returned if needed + if let Some(expr) = value { + write!(self.out, " ")?; + self.write_expr(expr, ctx)?; + } + writeln!(self.out, ";")?; + } + back::FunctionType::EntryPoint(ep_index) => { + let mut has_point_size = false; + let ep = &self.module.entry_points[ep_index as usize]; + if let Some(ref result) = ep.function.result { + let value = value.unwrap(); + match self.module.types[result.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + let temp_struct_name = match ctx.expressions[value] { + crate::Expression::Compose { .. } => { + let return_struct = "_tmp_return"; + write!( + self.out, + "{} {} = ", + &self.names[&NameKey::Type(result.ty)], + return_struct + )?; + self.write_expr(value, ctx)?; + writeln!(self.out, ";")?; + write!(self.out, "{level}")?; + Some(return_struct) + } + _ => None, + }; + + for (index, member) in members.iter().enumerate() { + if let Some(crate::Binding::BuiltIn( + crate::BuiltIn::PointSize, + )) = member.binding + { + has_point_size = true; + } + + let varying_name = VaryingName { + binding: member.binding.as_ref().unwrap(), + stage: ep.stage, + options: VaryingOptions::from_writer_options( + self.options, + true, + ), + }; + write!(self.out, "{varying_name} = ")?; + + if let Some(struct_name) = temp_struct_name { + write!(self.out, "{struct_name}")?; + } else { + self.write_expr(value, ctx)?; + } + + // Write field name + writeln!( + self.out, + ".{};", + &self.names + [&NameKey::StructMember(result.ty, index as u32)] + )?; + write!(self.out, "{level}")?; + } + } + _ => { + let name = VaryingName { + binding: result.binding.as_ref().unwrap(), + stage: ep.stage, + options: VaryingOptions::from_writer_options( + self.options, + true, + ), + }; + write!(self.out, "{name} = ")?; + self.write_expr(value, ctx)?; + writeln!(self.out, ";")?; + write!(self.out, "{level}")?; + } + } + } + + let is_vertex_stage = self.module.entry_points[ep_index as usize].stage + == ShaderStage::Vertex; + if is_vertex_stage + && self + .options + .writer_flags + .contains(WriterFlags::ADJUST_COORDINATE_SPACE) + { + writeln!( + self.out, + "gl_Position.yz = vec2(-gl_Position.y, gl_Position.z * 2.0 - gl_Position.w);", + )?; + write!(self.out, "{level}")?; + } + + if is_vertex_stage + && self + .options + .writer_flags + .contains(WriterFlags::FORCE_POINT_SIZE) + && !has_point_size + { + writeln!(self.out, "gl_PointSize = 1.0;")?; + write!(self.out, "{level}")?; + } + writeln!(self.out, "return;")?; + } + } + } + // This is one of the places were glsl adds to the syntax of C in this case the discard + // keyword which ceases all further processing in a fragment shader, it's called OpKill + // in spir-v that's why it's called `Statement::Kill` + Statement::Kill => writeln!(self.out, "{level}discard;")?, + Statement::Barrier(flags) => { + self.write_barrier(flags, level)?; + } + // Stores in glsl are just variable assignments written as `pointer = value;` + Statement::Store { pointer, value } => { + write!(self.out, "{level}")?; + self.write_expr(pointer, ctx)?; + write!(self.out, " = ")?; + self.write_expr(value, ctx)?; + writeln!(self.out, ";")? + } + Statement::WorkGroupUniformLoad { pointer, result } => { + // GLSL doesn't have pointers, which means that this backend needs to ensure that + // the actual "loading" is happening between the two barriers. + // This is done in `Emit` by never emitting a variable name for pointer variables + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + + let result_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + write!(self.out, "{level}")?; + // Expressions cannot have side effects, so just writing the expression here is fine. + self.write_named_expr(pointer, result_name, result, ctx)?; + + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + } + // Stores a value into an image. + Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + write!(self.out, "{level}")?; + self.write_image_store(ctx, image, coordinate, array_index, value)? + } + // A `Call` is written `name(arguments)` where `arguments` is a comma separated expressions list + Statement::Call { + function, + ref arguments, + result, + } => { + write!(self.out, "{level}")?; + if let Some(expr) = result { + let name = format!("{}{}", back::BAKE_PREFIX, expr.index()); + let result = self.module.functions[function].result.as_ref().unwrap(); + self.write_type(result.ty)?; + write!(self.out, " {name}")?; + if let TypeInner::Array { base, size, .. } = self.module.types[result.ty].inner + { + self.write_array_size(base, size)? + } + write!(self.out, " = ")?; + self.named_expressions.insert(expr, name); + } + write!(self.out, "{}(", &self.names[&NameKey::Function(function)])?; + let arguments: Vec<_> = arguments + .iter() + .enumerate() + .filter_map(|(i, arg)| { + let arg_ty = self.module.functions[function].arguments[i].ty; + match self.module.types[arg_ty].inner { + TypeInner::Sampler { .. } => None, + _ => Some(*arg), + } + }) + .collect(); + self.write_slice(&arguments, |this, _, arg| this.write_expr(*arg, ctx))?; + writeln!(self.out, ");")? + } + Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.resolve_type(result, &self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + let fun_str = fun.to_glsl(); + write!(self.out, "atomic{fun_str}(")?; + self.write_expr(pointer, ctx)?; + write!(self.out, ", ")?; + // handle the special cases + match *fun { + crate::AtomicFunction::Subtract => { + // we just wrote `InterlockedAdd`, so negate the argument + write!(self.out, "-")?; + } + crate::AtomicFunction::Exchange { compare: Some(_) } => { + return Err(Error::Custom( + "atomic CompareExchange is not implemented".to_string(), + )); + } + _ => {} + } + self.write_expr(value, ctx)?; + writeln!(self.out, ");")?; + } + Statement::RayQuery { .. } => unreachable!(), + } + + Ok(()) + } + + /// Write a const expression. + /// + /// Write `expr`, a handle to an [`Expression`] in the current [`Module`]'s + /// constant expression arena, as GLSL expression. + /// + /// # Notes + /// Adds no newlines or leading/trailing whitespace + /// + /// [`Expression`]: crate::Expression + /// [`Module`]: crate::Module + fn write_const_expr(&mut self, expr: Handle<crate::Expression>) -> BackendResult { + self.write_possibly_const_expr( + expr, + &self.module.const_expressions, + |expr| &self.info[expr], + |writer, expr| writer.write_const_expr(expr), + ) + } + + /// Write [`Expression`] variants that can occur in both runtime and const expressions. + /// + /// Write `expr`, a handle to an [`Expression`] in the arena `expressions`, + /// as as GLSL expression. This must be one of the [`Expression`] variants + /// that is allowed to occur in constant expressions. + /// + /// Use `write_expression` to write subexpressions. + /// + /// This is the common code for `write_expr`, which handles arbitrary + /// runtime expressions, and `write_const_expr`, which only handles + /// const-expressions. Each of those callers passes itself (essentially) as + /// the `write_expression` callback, so that subexpressions are restricted + /// to the appropriate variants. + /// + /// # Notes + /// Adds no newlines or leading/trailing whitespace + /// + /// [`Expression`]: crate::Expression + fn write_possibly_const_expr<'w, I, E>( + &'w mut self, + expr: Handle<crate::Expression>, + expressions: &crate::Arena<crate::Expression>, + info: I, + write_expression: E, + ) -> BackendResult + where + I: Fn(Handle<crate::Expression>) -> &'w proc::TypeResolution, + E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult, + { + use crate::Expression; + + match expressions[expr] { + Expression::Literal(literal) => { + match literal { + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero which is needed for a valid glsl float constant + crate::Literal::F64(value) => write!(self.out, "{:?}LF", value)?, + crate::Literal::F32(value) => write!(self.out, "{:?}", value)?, + // Unsigned integers need a `u` at the end + // + // While `core` doesn't necessarily need it, it's allowed and since `es` needs it we + // always write it as the extra branch wouldn't have any benefit in readability + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + crate::Literal::I64(_) => { + return Err(Error::Custom("GLSL has no 64-bit integer type".into())); + } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } + } + } + Expression::Constant(handle) => { + let constant = &self.module.constants[handle]; + if constant.name.is_some() { + write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; + } else { + self.write_const_expr(constant.init)?; + } + } + Expression::ZeroValue(ty) => { + self.write_zero_init_value(ty)?; + } + Expression::Compose { ty, ref components } => { + self.write_type(ty)?; + + if let TypeInner::Array { base, size, .. } = self.module.types[ty].inner { + self.write_array_size(base, size)?; + } + + write!(self.out, "(")?; + for (index, component) in components.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + write_expression(self, *component)?; + } + write!(self.out, ")")? + } + // `Splat` needs to actually write down a vector, it's not always inferred in GLSL. + Expression::Splat { size: _, value } => { + let resolved = info(expr).inner_with(&self.module.types); + self.write_value_type(resolved)?; + write!(self.out, "(")?; + write_expression(self, value)?; + write!(self.out, ")")? + } + _ => unreachable!(), + } + + Ok(()) + } + + /// Helper method to write expressions + /// + /// # Notes + /// Doesn't add any newlines or leading/trailing spaces + fn write_expr( + &mut self, + expr: Handle<crate::Expression>, + ctx: &back::FunctionCtx, + ) -> BackendResult { + use crate::Expression; + + if let Some(name) = self.named_expressions.get(&expr) { + write!(self.out, "{name}")?; + return Ok(()); + } + + match ctx.expressions[expr] { + Expression::Literal(_) + | Expression::Constant(_) + | Expression::ZeroValue(_) + | Expression::Compose { .. } + | Expression::Splat { .. } => { + self.write_possibly_const_expr( + expr, + ctx.expressions, + |expr| &ctx.info[expr].ty, + |writer, expr| writer.write_expr(expr, ctx), + )?; + } + // `Access` is applied to arrays, vectors and matrices and is written as indexing + Expression::Access { base, index } => { + self.write_expr(base, ctx)?; + write!(self.out, "[")?; + self.write_expr(index, ctx)?; + write!(self.out, "]")? + } + // `AccessIndex` is the same as `Access` except that the index is a constant and it can + // be applied to structs, in this case we need to find the name of the field at that + // index and write `base.field_name` + Expression::AccessIndex { base, index } => { + self.write_expr(base, ctx)?; + + let base_ty_res = &ctx.info[base].ty; + let mut resolved = base_ty_res.inner_with(&self.module.types); + let base_ty_handle = match *resolved { + TypeInner::Pointer { base, space: _ } => { + resolved = &self.module.types[base].inner; + Some(base) + } + _ => base_ty_res.handle(), + }; + + match *resolved { + TypeInner::Vector { .. } => { + // Write vector access as a swizzle + write!(self.out, ".{}", back::COMPONENTS[index as usize])? + } + TypeInner::Matrix { .. } + | TypeInner::Array { .. } + | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?, + TypeInner::Struct { .. } => { + // This will never panic in case the type is a `Struct`, this is not true + // for other types so we can only check while inside this match arm + let ty = base_ty_handle.unwrap(); + + write!( + self.out, + ".{}", + &self.names[&NameKey::StructMember(ty, index)] + )? + } + ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))), + } + } + // `Swizzle` adds a few letters behind the dot. + Expression::Swizzle { + size, + vector, + pattern, + } => { + self.write_expr(vector, ctx)?; + write!(self.out, ".")?; + for &sc in pattern[..size as usize].iter() { + self.out.write_char(back::COMPONENTS[sc as usize])?; + } + } + // Function arguments are written as the argument name + Expression::FunctionArgument(pos) => { + write!(self.out, "{}", &self.names[&ctx.argument_key(pos)])? + } + // Global variables need some special work for their name but + // `get_global_name` does the work for us + Expression::GlobalVariable(handle) => { + let global = &self.module.global_variables[handle]; + self.write_global_name(handle, global)? + } + // A local is written as it's name + Expression::LocalVariable(handle) => { + write!(self.out, "{}", self.names[&ctx.name_key(handle)])? + } + // glsl has no pointers so there's no load operation, just write the pointer expression + Expression::Load { pointer } => self.write_expr(pointer, ctx)?, + // `ImageSample` is a bit complicated compared to the rest of the IR. + // + // First there are three variations depending whether the sample level is explicitly set, + // if it's automatic or it it's bias: + // `texture(image, coordinate)` - Automatic sample level + // `texture(image, coordinate, bias)` - Bias sample level + // `textureLod(image, coordinate, level)` - Zero or Exact sample level + // + // Furthermore if `depth_ref` is some we need to append it to the coordinate vector + Expression::ImageSample { + image, + sampler: _, //TODO? + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + let dim = match *ctx.resolve_type(image, &self.module.types) { + TypeInner::Image { dim, .. } => dim, + _ => unreachable!(), + }; + + if dim == crate::ImageDimension::Cube + && array_index.is_some() + && depth_ref.is_some() + { + match level { + crate::SampleLevel::Zero + | crate::SampleLevel::Exact(_) + | crate::SampleLevel::Gradient { .. } + | crate::SampleLevel::Bias(_) => { + return Err(Error::Custom(String::from( + "gsamplerCubeArrayShadow isn't supported in textureGrad, \ + textureLod or texture with bias", + ))) + } + crate::SampleLevel::Auto => {} + } + } + + // textureLod on sampler2DArrayShadow and samplerCubeShadow does not exist in GLSL. + // To emulate this, we will have to use textureGrad with a constant gradient of 0. + let workaround_lod_array_shadow_as_grad = (array_index.is_some() + || dim == crate::ImageDimension::Cube) + && depth_ref.is_some() + && gather.is_none() + && !self + .options + .writer_flags + .contains(WriterFlags::TEXTURE_SHADOW_LOD); + + //Write the function to be used depending on the sample level + let fun_name = match level { + crate::SampleLevel::Zero if gather.is_some() => "textureGather", + crate::SampleLevel::Auto | crate::SampleLevel::Bias(_) => "texture", + crate::SampleLevel::Zero | crate::SampleLevel::Exact(_) => { + if workaround_lod_array_shadow_as_grad { + "textureGrad" + } else { + "textureLod" + } + } + crate::SampleLevel::Gradient { .. } => "textureGrad", + }; + let offset_name = match offset { + Some(_) => "Offset", + None => "", + }; + + write!(self.out, "{fun_name}{offset_name}(")?; + + // Write the image that will be used + self.write_expr(image, ctx)?; + // The space here isn't required but it helps with readability + write!(self.out, ", ")?; + + // We need to get the coordinates vector size to later build a vector that's `size + 1` + // if `depth_ref` is some, if it isn't a vector we panic as that's not a valid expression + let mut coord_dim = match *ctx.resolve_type(coordinate, &self.module.types) { + TypeInner::Vector { size, .. } => size as u8, + TypeInner::Scalar { .. } => 1, + _ => unreachable!(), + }; + + if array_index.is_some() { + coord_dim += 1; + } + let merge_depth_ref = depth_ref.is_some() && gather.is_none() && coord_dim < 4; + if merge_depth_ref { + coord_dim += 1; + } + + let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es(); + let is_vec = tex_1d_hack || coord_dim != 1; + // Compose a new texture coordinates vector + if is_vec { + write!(self.out, "vec{}(", coord_dim + tex_1d_hack as u8)?; + } + self.write_expr(coordinate, ctx)?; + if tex_1d_hack { + write!(self.out, ", 0.0")?; + } + if let Some(expr) = array_index { + write!(self.out, ", ")?; + self.write_expr(expr, ctx)?; + } + if merge_depth_ref { + write!(self.out, ", ")?; + self.write_expr(depth_ref.unwrap(), ctx)?; + } + if is_vec { + write!(self.out, ")")?; + } + + if let (Some(expr), false) = (depth_ref, merge_depth_ref) { + write!(self.out, ", ")?; + self.write_expr(expr, ctx)?; + } + + match level { + // Auto needs no more arguments + crate::SampleLevel::Auto => (), + // Zero needs level set to 0 + crate::SampleLevel::Zero => { + if workaround_lod_array_shadow_as_grad { + let vec_dim = match dim { + crate::ImageDimension::Cube => 3, + _ => 2, + }; + write!(self.out, ", vec{vec_dim}(0.0), vec{vec_dim}(0.0)")?; + } else if gather.is_none() { + write!(self.out, ", 0.0")?; + } + } + // Exact and bias require another argument + crate::SampleLevel::Exact(expr) => { + if workaround_lod_array_shadow_as_grad { + log::warn!("Unable to `textureLod` a shadow array, ignoring the LOD"); + write!(self.out, ", vec2(0,0), vec2(0,0)")?; + } else { + write!(self.out, ", ")?; + self.write_expr(expr, ctx)?; + } + } + crate::SampleLevel::Bias(_) => { + // This needs to be done after the offset writing + } + crate::SampleLevel::Gradient { x, y } => { + // If we are using sampler2D to replace sampler1D, we also + // need to make sure to use vec2 gradients + if tex_1d_hack { + write!(self.out, ", vec2(")?; + self.write_expr(x, ctx)?; + write!(self.out, ", 0.0)")?; + write!(self.out, ", vec2(")?; + self.write_expr(y, ctx)?; + write!(self.out, ", 0.0)")?; + } else { + write!(self.out, ", ")?; + self.write_expr(x, ctx)?; + write!(self.out, ", ")?; + self.write_expr(y, ctx)?; + } + } + } + + if let Some(constant) = offset { + write!(self.out, ", ")?; + if tex_1d_hack { + write!(self.out, "ivec2(")?; + } + self.write_const_expr(constant)?; + if tex_1d_hack { + write!(self.out, ", 0)")?; + } + } + + // Bias is always the last argument + if let crate::SampleLevel::Bias(expr) = level { + write!(self.out, ", ")?; + self.write_expr(expr, ctx)?; + } + + if let (Some(component), None) = (gather, depth_ref) { + write!(self.out, ", {}", component as usize)?; + } + + // End the function + write!(self.out, ")")? + } + Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => self.write_image_load(expr, ctx, image, coordinate, array_index, sample, level)?, + // Query translates into one of the: + // - textureSize/imageSize + // - textureQueryLevels + // - textureSamples/imageSamples + Expression::ImageQuery { image, query } => { + use crate::ImageClass; + + // This will only panic if the module is invalid + let (dim, class) = match *ctx.resolve_type(image, &self.module.types) { + TypeInner::Image { + dim, + arrayed: _, + class, + } => (dim, class), + _ => unreachable!(), + }; + let components = match dim { + crate::ImageDimension::D1 => 1, + crate::ImageDimension::D2 => 2, + crate::ImageDimension::D3 => 3, + crate::ImageDimension::Cube => 2, + }; + + if let crate::ImageQuery::Size { .. } = query { + match components { + 1 => write!(self.out, "uint(")?, + _ => write!(self.out, "uvec{components}(")?, + } + } else { + write!(self.out, "uint(")?; + } + + match query { + crate::ImageQuery::Size { level } => { + match class { + ImageClass::Sampled { multi, .. } | ImageClass::Depth { multi } => { + write!(self.out, "textureSize(")?; + self.write_expr(image, ctx)?; + if let Some(expr) = level { + let cast_to_int = matches!( + *ctx.resolve_type(expr, &self.module.types), + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + .. + }) + ); + + write!(self.out, ", ")?; + + if cast_to_int { + write!(self.out, "int(")?; + } + + self.write_expr(expr, ctx)?; + + if cast_to_int { + write!(self.out, ")")?; + } + } else if !multi { + // All textureSize calls requires an lod argument + // except for multisampled samplers + write!(self.out, ", 0")?; + } + } + ImageClass::Storage { .. } => { + write!(self.out, "imageSize(")?; + self.write_expr(image, ctx)?; + } + } + write!(self.out, ")")?; + if components != 1 || self.options.version.is_es() { + write!(self.out, ".{}", &"xyz"[..components])?; + } + } + crate::ImageQuery::NumLevels => { + write!(self.out, "textureQueryLevels(",)?; + self.write_expr(image, ctx)?; + write!(self.out, ")",)?; + } + crate::ImageQuery::NumLayers => { + let fun_name = match class { + ImageClass::Sampled { .. } | ImageClass::Depth { .. } => "textureSize", + ImageClass::Storage { .. } => "imageSize", + }; + write!(self.out, "{fun_name}(")?; + self.write_expr(image, ctx)?; + // All textureSize calls requires an lod argument + // except for multisampled samplers + if class.is_multisampled() { + write!(self.out, ", 0")?; + } + write!(self.out, ")")?; + if components != 1 || self.options.version.is_es() { + write!(self.out, ".{}", back::COMPONENTS[components])?; + } + } + crate::ImageQuery::NumSamples => { + let fun_name = match class { + ImageClass::Sampled { .. } | ImageClass::Depth { .. } => { + "textureSamples" + } + ImageClass::Storage { .. } => "imageSamples", + }; + write!(self.out, "{fun_name}(")?; + self.write_expr(image, ctx)?; + write!(self.out, ")",)?; + } + } + + write!(self.out, ")")?; + } + Expression::Unary { op, expr } => { + let operator_or_fn = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::LogicalNot => { + match *ctx.resolve_type(expr, &self.module.types) { + TypeInner::Vector { .. } => "not", + _ => "!", + } + } + crate::UnaryOperator::BitwiseNot => "~", + }; + write!(self.out, "{operator_or_fn}(")?; + + self.write_expr(expr, ctx)?; + + write!(self.out, ")")? + } + // `Binary` we just write `left op right`, except when dealing with + // comparison operations on vectors as they are implemented with + // builtin functions. + // Once again we wrap everything in parentheses to avoid precedence issues + Expression::Binary { + mut op, + left, + right, + } => { + // Holds `Some(function_name)` if the binary operation is + // implemented as a function call + use crate::{BinaryOperator as Bo, ScalarKind as Sk, TypeInner as Ti}; + + let left_inner = ctx.resolve_type(left, &self.module.types); + let right_inner = ctx.resolve_type(right, &self.module.types); + + let function = match (left_inner, right_inner) { + (&Ti::Vector { scalar, .. }, &Ti::Vector { .. }) => match op { + Bo::Less + | Bo::LessEqual + | Bo::Greater + | Bo::GreaterEqual + | Bo::Equal + | Bo::NotEqual => BinaryOperation::VectorCompare, + Bo::Modulo if scalar.kind == Sk::Float => BinaryOperation::Modulo, + Bo::And if scalar.kind == Sk::Bool => { + op = crate::BinaryOperator::LogicalAnd; + BinaryOperation::VectorComponentWise + } + Bo::InclusiveOr if scalar.kind == Sk::Bool => { + op = crate::BinaryOperator::LogicalOr; + BinaryOperation::VectorComponentWise + } + _ => BinaryOperation::Other, + }, + _ => match (left_inner.scalar_kind(), right_inner.scalar_kind()) { + (Some(Sk::Float), _) | (_, Some(Sk::Float)) => match op { + Bo::Modulo => BinaryOperation::Modulo, + _ => BinaryOperation::Other, + }, + (Some(Sk::Bool), Some(Sk::Bool)) => match op { + Bo::InclusiveOr => { + op = crate::BinaryOperator::LogicalOr; + BinaryOperation::Other + } + Bo::And => { + op = crate::BinaryOperator::LogicalAnd; + BinaryOperation::Other + } + _ => BinaryOperation::Other, + }, + _ => BinaryOperation::Other, + }, + }; + + match function { + BinaryOperation::VectorCompare => { + let op_str = match op { + Bo::Less => "lessThan(", + Bo::LessEqual => "lessThanEqual(", + Bo::Greater => "greaterThan(", + Bo::GreaterEqual => "greaterThanEqual(", + Bo::Equal => "equal(", + Bo::NotEqual => "notEqual(", + _ => unreachable!(), + }; + write!(self.out, "{op_str}")?; + self.write_expr(left, ctx)?; + write!(self.out, ", ")?; + self.write_expr(right, ctx)?; + write!(self.out, ")")?; + } + BinaryOperation::VectorComponentWise => { + self.write_value_type(left_inner)?; + write!(self.out, "(")?; + + let size = match *left_inner { + Ti::Vector { size, .. } => size, + _ => unreachable!(), + }; + + for i in 0..size as usize { + if i != 0 { + write!(self.out, ", ")?; + } + + self.write_expr(left, ctx)?; + write!(self.out, ".{}", back::COMPONENTS[i])?; + + write!(self.out, " {} ", back::binary_operation_str(op))?; + + self.write_expr(right, ctx)?; + write!(self.out, ".{}", back::COMPONENTS[i])?; + } + + write!(self.out, ")")?; + } + // TODO: handle undefined behavior of BinaryOperator::Modulo + // + // sint: + // if right == 0 return 0 + // if left == min(type_of(left)) && right == -1 return 0 + // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL + // + // uint: + // if right == 0 return 0 + // + // float: + // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798 + BinaryOperation::Modulo => { + write!(self.out, "(")?; + + // write `e1 - e2 * trunc(e1 / e2)` + self.write_expr(left, ctx)?; + write!(self.out, " - ")?; + self.write_expr(right, ctx)?; + write!(self.out, " * ")?; + write!(self.out, "trunc(")?; + self.write_expr(left, ctx)?; + write!(self.out, " / ")?; + self.write_expr(right, ctx)?; + write!(self.out, ")")?; + + write!(self.out, ")")?; + } + BinaryOperation::Other => { + write!(self.out, "(")?; + + self.write_expr(left, ctx)?; + write!(self.out, " {} ", back::binary_operation_str(op))?; + self.write_expr(right, ctx)?; + + write!(self.out, ")")?; + } + } + } + // `Select` is written as `condition ? accept : reject` + // We wrap everything in parentheses to avoid precedence issues + Expression::Select { + condition, + accept, + reject, + } => { + let cond_ty = ctx.resolve_type(condition, &self.module.types); + let vec_select = if let TypeInner::Vector { .. } = *cond_ty { + true + } else { + false + }; + + // TODO: Boolean mix on desktop required GL_EXT_shader_integer_mix + if vec_select { + // Glsl defines that for mix when the condition is a boolean the first element + // is picked if condition is false and the second if condition is true + write!(self.out, "mix(")?; + self.write_expr(reject, ctx)?; + write!(self.out, ", ")?; + self.write_expr(accept, ctx)?; + write!(self.out, ", ")?; + self.write_expr(condition, ctx)?; + } else { + write!(self.out, "(")?; + self.write_expr(condition, ctx)?; + write!(self.out, " ? ")?; + self.write_expr(accept, ctx)?; + write!(self.out, " : ")?; + self.write_expr(reject, ctx)?; + } + + write!(self.out, ")")? + } + // `Derivative` is a function call to a glsl provided function + Expression::Derivative { axis, ctrl, expr } => { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + let fun_name = if self.options.version.supports_derivative_control() { + match (axis, ctrl) { + (Axis::X, Ctrl::Coarse) => "dFdxCoarse", + (Axis::X, Ctrl::Fine) => "dFdxFine", + (Axis::X, Ctrl::None) => "dFdx", + (Axis::Y, Ctrl::Coarse) => "dFdyCoarse", + (Axis::Y, Ctrl::Fine) => "dFdyFine", + (Axis::Y, Ctrl::None) => "dFdy", + (Axis::Width, Ctrl::Coarse) => "fwidthCoarse", + (Axis::Width, Ctrl::Fine) => "fwidthFine", + (Axis::Width, Ctrl::None) => "fwidth", + } + } else { + match axis { + Axis::X => "dFdx", + Axis::Y => "dFdy", + Axis::Width => "fwidth", + } + }; + write!(self.out, "{fun_name}(")?; + self.write_expr(expr, ctx)?; + write!(self.out, ")")? + } + // `Relational` is a normal function call to some glsl provided functions + Expression::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + + let fun_name = match fun { + Rf::IsInf => "isinf", + Rf::IsNan => "isnan", + Rf::All => "all", + Rf::Any => "any", + }; + write!(self.out, "{fun_name}(")?; + + self.write_expr(argument, ctx)?; + + write!(self.out, ")")? + } + Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + + let fun_name = match fun { + // comparison + Mf::Abs => "abs", + Mf::Min => "min", + Mf::Max => "max", + Mf::Clamp => "clamp", + Mf::Saturate => { + write!(self.out, "clamp(")?; + + self.write_expr(arg, ctx)?; + + match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Vector { size, .. } => write!( + self.out, + ", vec{}(0.0), vec{0}(1.0)", + back::vector_size_str(size) + )?, + _ => write!(self.out, ", 0.0, 1.0")?, + } + + write!(self.out, ")")?; + + return Ok(()); + } + // trigonometry + Mf::Cos => "cos", + Mf::Cosh => "cosh", + Mf::Sin => "sin", + Mf::Sinh => "sinh", + Mf::Tan => "tan", + Mf::Tanh => "tanh", + Mf::Acos => "acos", + Mf::Asin => "asin", + Mf::Atan => "atan", + Mf::Asinh => "asinh", + Mf::Acosh => "acosh", + Mf::Atanh => "atanh", + Mf::Radians => "radians", + Mf::Degrees => "degrees", + // glsl doesn't have atan2 function + // use two-argument variation of the atan function + Mf::Atan2 => "atan", + // decomposition + Mf::Ceil => "ceil", + Mf::Floor => "floor", + Mf::Round => "roundEven", + Mf::Fract => "fract", + Mf::Trunc => "trunc", + Mf::Modf => MODF_FUNCTION, + Mf::Frexp => FREXP_FUNCTION, + Mf::Ldexp => "ldexp", + // exponent + Mf::Exp => "exp", + Mf::Exp2 => "exp2", + Mf::Log => "log", + Mf::Log2 => "log2", + Mf::Pow => "pow", + // geometry + Mf::Dot => match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Vector { + scalar: + crate::Scalar { + kind: crate::ScalarKind::Float, + .. + }, + .. + } => "dot", + crate::TypeInner::Vector { size, .. } => { + return self.write_dot_product(arg, arg1.unwrap(), size as usize, ctx) + } + _ => unreachable!( + "Correct TypeInner for dot product should be already validated" + ), + }, + Mf::Outer => "outerProduct", + Mf::Cross => "cross", + Mf::Distance => "distance", + Mf::Length => "length", + Mf::Normalize => "normalize", + Mf::FaceForward => "faceforward", + Mf::Reflect => "reflect", + Mf::Refract => "refract", + // computational + Mf::Sign => "sign", + Mf::Fma => { + if self.options.version.supports_fma_function() { + // Use the fma function when available + "fma" + } else { + // No fma support. Transform the function call into an arithmetic expression + write!(self.out, "(")?; + + self.write_expr(arg, ctx)?; + write!(self.out, " * ")?; + + let arg1 = + arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?; + self.write_expr(arg1, ctx)?; + write!(self.out, " + ")?; + + let arg2 = + arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?; + self.write_expr(arg2, ctx)?; + write!(self.out, ")")?; + + return Ok(()); + } + } + Mf::Mix => "mix", + Mf::Step => "step", + Mf::SmoothStep => "smoothstep", + Mf::Sqrt => "sqrt", + Mf::InverseSqrt => "inversesqrt", + Mf::Inverse => "inverse", + Mf::Transpose => "transpose", + Mf::Determinant => "determinant", + // bits + Mf::CountTrailingZeros => { + match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Vector { size, scalar, .. } => { + let s = back::vector_size_str(size); + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "min(uvec{s}(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), uvec{s}(32u))")?; + } else { + write!(self.out, "ivec{s}(min(uvec{s}(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), uvec{s}(32u)))")?; + } + } + crate::TypeInner::Scalar(scalar) => { + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "min(uint(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), 32u)")?; + } else { + write!(self.out, "int(min(uint(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), 32u))")?; + } + } + _ => unreachable!(), + }; + return Ok(()); + } + Mf::CountLeadingZeros => { + if self.options.version.supports_integer_functions() { + match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Vector { size, scalar } => { + let s = back::vector_size_str(size); + + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "uvec{s}(ivec{s}(31) - findMSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "mix(ivec{s}(31) - findMSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, "), ivec{s}(0), lessThan(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ", ivec{s}(0)))")?; + } + } + crate::TypeInner::Scalar(scalar) => { + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "uint(31 - findMSB(")?; + } else { + write!(self.out, "(")?; + self.write_expr(arg, ctx)?; + write!(self.out, " < 0 ? 0 : 31 - findMSB(")?; + } + + self.write_expr(arg, ctx)?; + write!(self.out, "))")?; + } + _ => unreachable!(), + }; + } else { + match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Vector { size, scalar } => { + let s = back::vector_size_str(size); + + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "uvec{s}(")?; + write!(self.out, "vec{s}(31.0) - floor(log2(vec{s}(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5)))")?; + } else { + write!(self.out, "ivec{s}(")?; + write!(self.out, "mix(vec{s}(31.0) - floor(log2(vec{s}(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5)), ")?; + write!(self.out, "vec{s}(0.0), lessThan(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ", ivec{s}(0u))))")?; + } + } + crate::TypeInner::Scalar(scalar) => { + if let crate::ScalarKind::Uint = scalar.kind { + write!(self.out, "uint(31.0 - floor(log2(float(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5)))")?; + } else { + write!(self.out, "(")?; + self.write_expr(arg, ctx)?; + write!(self.out, " < 0 ? 0 : int(")?; + write!(self.out, "31.0 - floor(log2(float(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ") + 0.5))))")?; + } + } + _ => unreachable!(), + }; + } + + return Ok(()); + } + Mf::CountOneBits => "bitCount", + Mf::ReverseBits => "bitfieldReverse", + Mf::ExtractBits => "bitfieldExtract", + Mf::InsertBits => "bitfieldInsert", + Mf::FindLsb => "findLSB", + Mf::FindMsb => "findMSB", + // data packing + Mf::Pack4x8snorm => "packSnorm4x8", + Mf::Pack4x8unorm => "packUnorm4x8", + Mf::Pack2x16snorm => "packSnorm2x16", + Mf::Pack2x16unorm => "packUnorm2x16", + Mf::Pack2x16float => "packHalf2x16", + // data unpacking + Mf::Unpack4x8snorm => "unpackSnorm4x8", + Mf::Unpack4x8unorm => "unpackUnorm4x8", + Mf::Unpack2x16snorm => "unpackSnorm2x16", + Mf::Unpack2x16unorm => "unpackUnorm2x16", + Mf::Unpack2x16float => "unpackHalf2x16", + }; + + let extract_bits = fun == Mf::ExtractBits; + let insert_bits = fun == Mf::InsertBits; + + // Some GLSL functions always return signed integers (like findMSB), + // so they need to be cast to uint if the argument is also an uint. + let ret_might_need_int_to_uint = + matches!(fun, Mf::FindLsb | Mf::FindMsb | Mf::CountOneBits | Mf::Abs); + + // Some GLSL functions only accept signed integers (like abs), + // so they need their argument cast from uint to int. + let arg_might_need_uint_to_int = matches!(fun, Mf::Abs); + + // Check if the argument is an unsigned integer and return the vector size + // in case it's a vector + let maybe_uint_size = match *ctx.resolve_type(arg, &self.module.types) { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + .. + }) => Some(None), + crate::TypeInner::Vector { + scalar: + crate::Scalar { + kind: crate::ScalarKind::Uint, + .. + }, + size, + } => Some(Some(size)), + _ => None, + }; + + // Cast to uint if the function needs it + if ret_might_need_int_to_uint { + if let Some(maybe_size) = maybe_uint_size { + match maybe_size { + Some(size) => write!(self.out, "uvec{}(", size as u8)?, + None => write!(self.out, "uint(")?, + } + } + } + + write!(self.out, "{fun_name}(")?; + + // Cast to int if the function needs it + if arg_might_need_uint_to_int { + if let Some(maybe_size) = maybe_uint_size { + match maybe_size { + Some(size) => write!(self.out, "ivec{}(", size as u8)?, + None => write!(self.out, "int(")?, + } + } + } + + self.write_expr(arg, ctx)?; + + // Close the cast from uint to int + if arg_might_need_uint_to_int && maybe_uint_size.is_some() { + write!(self.out, ")")? + } + + if let Some(arg) = arg1 { + write!(self.out, ", ")?; + if extract_bits { + write!(self.out, "int(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr(arg, ctx)?; + } + } + if let Some(arg) = arg2 { + write!(self.out, ", ")?; + if extract_bits || insert_bits { + write!(self.out, "int(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr(arg, ctx)?; + } + } + if let Some(arg) = arg3 { + write!(self.out, ", ")?; + if insert_bits { + write!(self.out, "int(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr(arg, ctx)?; + } + } + write!(self.out, ")")?; + + // Close the cast from int to uint + if ret_might_need_int_to_uint && maybe_uint_size.is_some() { + write!(self.out, ")")? + } + } + // `As` is always a call. + // If `convert` is true the function name is the type + // Else the function name is one of the glsl provided bitcast functions + Expression::As { + expr, + kind: target_kind, + convert, + } => { + let inner = ctx.resolve_type(expr, &self.module.types); + match convert { + Some(width) => { + // this is similar to `write_type`, but with the target kind + let scalar = glsl_scalar(crate::Scalar { + kind: target_kind, + width, + })?; + match *inner { + TypeInner::Matrix { columns, rows, .. } => write!( + self.out, + "{}mat{}x{}", + scalar.prefix, columns as u8, rows as u8 + )?, + TypeInner::Vector { size, .. } => { + write!(self.out, "{}vec{}", scalar.prefix, size as u8)? + } + _ => write!(self.out, "{}", scalar.full)?, + } + + write!(self.out, "(")?; + self.write_expr(expr, ctx)?; + write!(self.out, ")")? + } + None => { + use crate::ScalarKind as Sk; + + let target_vector_type = match *inner { + TypeInner::Vector { size, scalar } => Some(TypeInner::Vector { + size, + scalar: crate::Scalar { + kind: target_kind, + width: scalar.width, + }, + }), + _ => None, + }; + + let source_kind = inner.scalar_kind().unwrap(); + + match (source_kind, target_kind, target_vector_type) { + // No conversion needed + (Sk::Sint, Sk::Sint, _) + | (Sk::Uint, Sk::Uint, _) + | (Sk::Float, Sk::Float, _) + | (Sk::Bool, Sk::Bool, _) => { + self.write_expr(expr, ctx)?; + return Ok(()); + } + + // Cast to/from floats + (Sk::Float, Sk::Sint, _) => write!(self.out, "floatBitsToInt")?, + (Sk::Float, Sk::Uint, _) => write!(self.out, "floatBitsToUint")?, + (Sk::Sint, Sk::Float, _) => write!(self.out, "intBitsToFloat")?, + (Sk::Uint, Sk::Float, _) => write!(self.out, "uintBitsToFloat")?, + + // Cast between vector types + (_, _, Some(vector)) => { + self.write_value_type(&vector)?; + } + + // There is no way to bitcast between Uint/Sint in glsl. Use constructor conversion + (Sk::Uint | Sk::Bool, Sk::Sint, None) => write!(self.out, "int")?, + (Sk::Sint | Sk::Bool, Sk::Uint, None) => write!(self.out, "uint")?, + (Sk::Bool, Sk::Float, None) => write!(self.out, "float")?, + (Sk::Sint | Sk::Uint | Sk::Float, Sk::Bool, None) => { + write!(self.out, "bool")? + } + + (Sk::AbstractInt | Sk::AbstractFloat, _, _) + | (_, Sk::AbstractInt | Sk::AbstractFloat, _) => unreachable!(), + }; + + write!(self.out, "(")?; + self.write_expr(expr, ctx)?; + write!(self.out, ")")?; + } + } + } + // These expressions never show up in `Emit`. + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::RayQueryProceedResult + | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(), + // `ArrayLength` is written as `expr.length()` and we convert it to a uint + Expression::ArrayLength(expr) => { + write!(self.out, "uint(")?; + self.write_expr(expr, ctx)?; + write!(self.out, ".length())")? + } + // not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), + } + + Ok(()) + } + + /// Helper function to write the local holding the clamped lod + fn write_clamped_lod( + &mut self, + ctx: &back::FunctionCtx, + expr: Handle<crate::Expression>, + image: Handle<crate::Expression>, + level_expr: Handle<crate::Expression>, + ) -> Result<(), Error> { + // Define our local and start a call to `clamp` + write!( + self.out, + "int {}{}{} = clamp(", + back::BAKE_PREFIX, + expr.index(), + CLAMPED_LOD_SUFFIX + )?; + // Write the lod that will be clamped + self.write_expr(level_expr, ctx)?; + // Set the min value to 0 and start a call to `textureQueryLevels` to get + // the maximum value + write!(self.out, ", 0, textureQueryLevels(")?; + // Write the target image as an argument to `textureQueryLevels` + self.write_expr(image, ctx)?; + // Close the call to `textureQueryLevels` subtract 1 from it since + // the lod argument is 0 based, close the `clamp` call and end the + // local declaration statement. + writeln!(self.out, ") - 1);")?; + + Ok(()) + } + + // Helper method used to retrieve how many elements a coordinate vector + // for the images operations need. + fn get_coordinate_vector_size(&self, dim: crate::ImageDimension, arrayed: bool) -> u8 { + // openGL es doesn't have 1D images so we need workaround it + let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es(); + // Get how many components the coordinate vector needs for the dimensions only + let tex_coord_size = match dim { + crate::ImageDimension::D1 => 1, + crate::ImageDimension::D2 => 2, + crate::ImageDimension::D3 => 3, + crate::ImageDimension::Cube => 2, + }; + // Calculate the true size of the coordinate vector by adding 1 for arrayed images + // and another 1 if we need to workaround 1D images by making them 2D + tex_coord_size + tex_1d_hack as u8 + arrayed as u8 + } + + /// Helper method to write the coordinate vector for image operations + fn write_texture_coord( + &mut self, + ctx: &back::FunctionCtx, + vector_size: u8, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + // Emulate 1D images as 2D for profiles that don't support it (glsl es) + tex_1d_hack: bool, + ) -> Result<(), Error> { + match array_index { + // If the image needs an array indice we need to add it to the end of our + // coordinate vector, to do so we will use the `ivec(ivec, scalar)` + // constructor notation (NOTE: the inner `ivec` can also be a scalar, this + // is important for 1D arrayed images). + Some(layer_expr) => { + write!(self.out, "ivec{vector_size}(")?; + self.write_expr(coordinate, ctx)?; + write!(self.out, ", ")?; + // If we are replacing sampler1D with sampler2D we also need + // to add another zero to the coordinates vector for the y component + if tex_1d_hack { + write!(self.out, "0, ")?; + } + self.write_expr(layer_expr, ctx)?; + write!(self.out, ")")?; + } + // Otherwise write just the expression (and the 1D hack if needed) + None => { + let uvec_size = match *ctx.resolve_type(coordinate, &self.module.types) { + TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + .. + }) => Some(None), + TypeInner::Vector { + size, + scalar: + crate::Scalar { + kind: crate::ScalarKind::Uint, + .. + }, + } => Some(Some(size as u32)), + _ => None, + }; + if tex_1d_hack { + write!(self.out, "ivec2(")?; + } else if uvec_size.is_some() { + match uvec_size { + Some(None) => write!(self.out, "int(")?, + Some(Some(size)) => write!(self.out, "ivec{size}(")?, + _ => {} + } + } + self.write_expr(coordinate, ctx)?; + if tex_1d_hack { + write!(self.out, ", 0)")?; + } else if uvec_size.is_some() { + write!(self.out, ")")?; + } + } + } + + Ok(()) + } + + /// Helper method to write the `ImageStore` statement + fn write_image_store( + &mut self, + ctx: &back::FunctionCtx, + image: Handle<crate::Expression>, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + value: Handle<crate::Expression>, + ) -> Result<(), Error> { + use crate::ImageDimension as IDim; + + // NOTE: openGL requires that `imageStore`s have no effets when the texel is invalid + // so we don't need to generate bounds checks (OpenGL 4.2 Core §3.9.20) + + // This will only panic if the module is invalid + let dim = match *ctx.resolve_type(image, &self.module.types) { + TypeInner::Image { dim, .. } => dim, + _ => unreachable!(), + }; + + // Begin our call to `imageStore` + write!(self.out, "imageStore(")?; + self.write_expr(image, ctx)?; + // Separate the image argument from the coordinates + write!(self.out, ", ")?; + + // openGL es doesn't have 1D images so we need workaround it + let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es(); + // Write the coordinate vector + self.write_texture_coord( + ctx, + // Get the size of the coordinate vector + self.get_coordinate_vector_size(dim, array_index.is_some()), + coordinate, + array_index, + tex_1d_hack, + )?; + + // Separate the coordinate from the value to write and write the expression + // of the value to write. + write!(self.out, ", ")?; + self.write_expr(value, ctx)?; + // End the call to `imageStore` and the statement. + writeln!(self.out, ");")?; + + Ok(()) + } + + /// Helper method for writing an `ImageLoad` expression. + #[allow(clippy::too_many_arguments)] + fn write_image_load( + &mut self, + handle: Handle<crate::Expression>, + ctx: &back::FunctionCtx, + image: Handle<crate::Expression>, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + sample: Option<Handle<crate::Expression>>, + level: Option<Handle<crate::Expression>>, + ) -> Result<(), Error> { + use crate::ImageDimension as IDim; + + // `ImageLoad` is a bit complicated. + // There are two functions one for sampled + // images another for storage images, the former uses `texelFetch` and the + // latter uses `imageLoad`. + // + // Furthermore we have `level` which is always `Some` for sampled images + // and `None` for storage images, so we end up with two functions: + // - `texelFetch(image, coordinate, level)` for sampled images + // - `imageLoad(image, coordinate)` for storage images + // + // Finally we also have to consider bounds checking, for storage images + // this is easy since openGL requires that invalid texels always return + // 0, for sampled images we need to either verify that all arguments are + // in bounds (`ReadZeroSkipWrite`) or make them a valid texel (`Restrict`). + + // This will only panic if the module is invalid + let (dim, class) = match *ctx.resolve_type(image, &self.module.types) { + TypeInner::Image { + dim, + arrayed: _, + class, + } => (dim, class), + _ => unreachable!(), + }; + + // Get the name of the function to be used for the load operation + // and the policy to be used with it. + let (fun_name, policy) = match class { + // Sampled images inherit the policy from the user passed policies + crate::ImageClass::Sampled { .. } => ("texelFetch", self.policies.image_load), + crate::ImageClass::Storage { .. } => { + // OpenGL ES 3.1 mentions in Chapter "8.22 Texture Image Loads and Stores" that: + // "Invalid image loads will return a vector where the value of R, G, and B components + // is 0 and the value of the A component is undefined." + // + // OpenGL 4.2 Core mentions in Chapter "3.9.20 Texture Image Loads and Stores" that: + // "Invalid image loads will return zero." + // + // So, we only inject bounds checks for ES + let policy = if self.options.version.is_es() { + self.policies.image_load + } else { + proc::BoundsCheckPolicy::Unchecked + }; + ("imageLoad", policy) + } + // TODO: Is there even a function for this? + crate::ImageClass::Depth { multi: _ } => { + return Err(Error::Custom( + "WGSL `textureLoad` from depth textures is not supported in GLSL".to_string(), + )) + } + }; + + // openGL es doesn't have 1D images so we need workaround it + let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es(); + // Get the size of the coordinate vector + let vector_size = self.get_coordinate_vector_size(dim, array_index.is_some()); + + if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy { + // To write the bounds checks for `ReadZeroSkipWrite` we will use a + // ternary operator since we are in the middle of an expression and + // need to return a value. + // + // NOTE: glsl does short circuit when evaluating logical + // expressions so we can be sure that after we test a + // condition it will be true for the next ones + + // Write parentheses around the ternary operator to prevent problems with + // expressions emitted before or after it having more precedence + write!(self.out, "(",)?; + + // The lod check needs to precede the size check since we need + // to use the lod to get the size of the image at that level. + if let Some(level_expr) = level { + self.write_expr(level_expr, ctx)?; + write!(self.out, " < textureQueryLevels(",)?; + self.write_expr(image, ctx)?; + // Chain the next check + write!(self.out, ") && ")?; + } + + // Check that the sample arguments doesn't exceed the number of samples + if let Some(sample_expr) = sample { + self.write_expr(sample_expr, ctx)?; + write!(self.out, " < textureSamples(",)?; + self.write_expr(image, ctx)?; + // Chain the next check + write!(self.out, ") && ")?; + } + + // We now need to write the size checks for the coordinates and array index + // first we write the comparison function in case the image is 1D non arrayed + // (and no 1D to 2D hack was needed) we are comparing scalars so the less than + // operator will suffice, but otherwise we'll be comparing two vectors so we'll + // need to use the `lessThan` function but it returns a vector of booleans (one + // for each comparison) so we need to fold it all in one scalar boolean, since + // we want all comparisons to pass we use the `all` function which will only + // return `true` if all the elements of the boolean vector are also `true`. + // + // So we'll end with one of the following forms + // - `coord < textureSize(image, lod)` for 1D images + // - `all(lessThan(coord, textureSize(image, lod)))` for normal images + // - `all(lessThan(ivec(coord, array_index), textureSize(image, lod)))` + // for arrayed images + // - `all(lessThan(coord, textureSize(image)))` for multi sampled images + + if vector_size != 1 { + write!(self.out, "all(lessThan(")?; + } + + // Write the coordinate vector + self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?; + + if vector_size != 1 { + // If we used the `lessThan` function we need to separate the + // coordinates from the image size. + write!(self.out, ", ")?; + } else { + // If we didn't use it (ie. 1D images) we perform the comparison + // using the less than operator. + write!(self.out, " < ")?; + } + + // Call `textureSize` to get our image size + write!(self.out, "textureSize(")?; + self.write_expr(image, ctx)?; + // `textureSize` uses the lod as a second argument for mipmapped images + if let Some(level_expr) = level { + // Separate the image from the lod + write!(self.out, ", ")?; + self.write_expr(level_expr, ctx)?; + } + // Close the `textureSize` call + write!(self.out, ")")?; + + if vector_size != 1 { + // Close the `all` and `lessThan` calls + write!(self.out, "))")?; + } + + // Finally end the condition part of the ternary operator + write!(self.out, " ? ")?; + } + + // Begin the call to the function used to load the texel + write!(self.out, "{fun_name}(")?; + self.write_expr(image, ctx)?; + write!(self.out, ", ")?; + + // If we are using `Restrict` bounds checking we need to pass valid texel + // coordinates, to do so we use the `clamp` function to get a value between + // 0 and the image size - 1 (indexing begins at 0) + if let proc::BoundsCheckPolicy::Restrict = policy { + write!(self.out, "clamp(")?; + } + + // Write the coordinate vector + self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?; + + // If we are using `Restrict` bounds checking we need to write the rest of the + // clamp we initiated before writing the coordinates. + if let proc::BoundsCheckPolicy::Restrict = policy { + // Write the min value 0 + if vector_size == 1 { + write!(self.out, ", 0")?; + } else { + write!(self.out, ", ivec{vector_size}(0)")?; + } + // Start the `textureSize` call to use as the max value. + write!(self.out, ", textureSize(")?; + self.write_expr(image, ctx)?; + // If the image is mipmapped we need to add the lod argument to the + // `textureSize` call, but this needs to be the clamped lod, this should + // have been generated earlier and put in a local. + if class.is_mipmapped() { + write!( + self.out, + ", {}{}{}", + back::BAKE_PREFIX, + handle.index(), + CLAMPED_LOD_SUFFIX + )?; + } + // Close the `textureSize` call + write!(self.out, ")")?; + + // Subtract 1 from the `textureSize` call since the coordinates are zero based. + if vector_size == 1 { + write!(self.out, " - 1")?; + } else { + write!(self.out, " - ivec{vector_size}(1)")?; + } + + // Close the `clamp` call + write!(self.out, ")")?; + + // Add the clamped lod (if present) as the second argument to the + // image load function. + if level.is_some() { + write!( + self.out, + ", {}{}{}", + back::BAKE_PREFIX, + handle.index(), + CLAMPED_LOD_SUFFIX + )?; + } + + // If a sample argument is needed we need to clamp it between 0 and + // the number of samples the image has. + if let Some(sample_expr) = sample { + write!(self.out, ", clamp(")?; + self.write_expr(sample_expr, ctx)?; + // Set the min value to 0 and start the call to `textureSamples` + write!(self.out, ", 0, textureSamples(")?; + self.write_expr(image, ctx)?; + // Close the `textureSamples` call, subtract 1 from it since the sample + // argument is zero based, and close the `clamp` call + writeln!(self.out, ") - 1)")?; + } + } else if let Some(sample_or_level) = sample.or(level) { + // If no bounds checking is need just add the sample or level argument + // after the coordinates + write!(self.out, ", ")?; + self.write_expr(sample_or_level, ctx)?; + } + + // Close the image load function. + write!(self.out, ")")?; + + // If we were using the `ReadZeroSkipWrite` policy we need to end the first branch + // (which is taken if the condition is `true`) with a colon (`:`) and write the + // second branch which is just a 0 value. + if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy { + // Get the kind of the output value. + let kind = match class { + // Only sampled images can reach here since storage images + // don't need bounds checks and depth images aren't implemented + crate::ImageClass::Sampled { kind, .. } => kind, + _ => unreachable!(), + }; + + // End the first branch + write!(self.out, " : ")?; + // Write the 0 value + write!( + self.out, + "{}vec4(", + glsl_scalar(crate::Scalar { kind, width: 4 })?.prefix, + )?; + self.write_zero_init_scalar(kind)?; + // Close the zero value constructor + write!(self.out, ")")?; + // Close the parentheses surrounding our ternary + write!(self.out, ")")?; + } + + Ok(()) + } + + fn write_named_expr( + &mut self, + handle: Handle<crate::Expression>, + name: String, + // The expression which is being named. + // Generally, this is the same as handle, except in WorkGroupUniformLoad + named: Handle<crate::Expression>, + ctx: &back::FunctionCtx, + ) -> BackendResult { + match ctx.info[named].ty { + proc::TypeResolution::Handle(ty_handle) => match self.module.types[ty_handle].inner { + TypeInner::Struct { .. } => { + let ty_name = &self.names[&NameKey::Type(ty_handle)]; + write!(self.out, "{ty_name}")?; + } + _ => { + self.write_type(ty_handle)?; + } + }, + proc::TypeResolution::Value(ref inner) => { + self.write_value_type(inner)?; + } + } + + let resolved = ctx.resolve_type(named, &self.module.types); + + write!(self.out, " {name}")?; + if let TypeInner::Array { base, size, .. } = *resolved { + self.write_array_size(base, size)?; + } + write!(self.out, " = ")?; + self.write_expr(handle, ctx)?; + writeln!(self.out, ";")?; + self.named_expressions.insert(named, name); + + Ok(()) + } + + /// Helper function that write string with default zero initialization for supported types + fn write_zero_init_value(&mut self, ty: Handle<crate::Type>) -> BackendResult { + let inner = &self.module.types[ty].inner; + match *inner { + TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => { + self.write_zero_init_scalar(scalar.kind)?; + } + TypeInner::Vector { scalar, .. } => { + self.write_value_type(inner)?; + write!(self.out, "(")?; + self.write_zero_init_scalar(scalar.kind)?; + write!(self.out, ")")?; + } + TypeInner::Matrix { .. } => { + self.write_value_type(inner)?; + write!(self.out, "(")?; + self.write_zero_init_scalar(crate::ScalarKind::Float)?; + write!(self.out, ")")?; + } + TypeInner::Array { base, size, .. } => { + let count = match size + .to_indexable_length(self.module) + .expect("Bad array size") + { + proc::IndexableLength::Known(count) => count, + proc::IndexableLength::Dynamic => return Ok(()), + }; + self.write_type(base)?; + self.write_array_size(base, size)?; + write!(self.out, "(")?; + for _ in 1..count { + self.write_zero_init_value(base)?; + write!(self.out, ", ")?; + } + // write last parameter without comma and space + self.write_zero_init_value(base)?; + write!(self.out, ")")?; + } + TypeInner::Struct { ref members, .. } => { + let name = &self.names[&NameKey::Type(ty)]; + write!(self.out, "{name}(")?; + for (index, member) in members.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + self.write_zero_init_value(member.ty)?; + } + write!(self.out, ")")?; + } + _ => unreachable!(), + } + + Ok(()) + } + + /// Helper function that write string with zero initialization for scalar + fn write_zero_init_scalar(&mut self, kind: crate::ScalarKind) -> BackendResult { + match kind { + crate::ScalarKind::Bool => write!(self.out, "false")?, + crate::ScalarKind::Uint => write!(self.out, "0u")?, + crate::ScalarKind::Float => write!(self.out, "0.0")?, + crate::ScalarKind::Sint => write!(self.out, "0")?, + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".to_string(), + )) + } + } + + Ok(()) + } + + /// Issue a memory barrier. Please note that to ensure visibility, + /// OpenGL always requires a call to the `barrier()` function after a `memoryBarrier*()` + fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult { + if flags.contains(crate::Barrier::STORAGE) { + writeln!(self.out, "{level}memoryBarrierBuffer();")?; + } + if flags.contains(crate::Barrier::WORK_GROUP) { + writeln!(self.out, "{level}memoryBarrierShared();")?; + } + writeln!(self.out, "{level}barrier();")?; + Ok(()) + } + + /// Helper function that return the glsl storage access string of [`StorageAccess`](crate::StorageAccess) + /// + /// glsl allows adding both `readonly` and `writeonly` but this means that + /// they can only be used to query information about the resource which isn't what + /// we want here so when storage access is both `LOAD` and `STORE` add no modifiers + fn write_storage_access(&mut self, storage_access: crate::StorageAccess) -> BackendResult { + if !storage_access.contains(crate::StorageAccess::STORE) { + write!(self.out, "readonly ")?; + } + if !storage_access.contains(crate::StorageAccess::LOAD) { + write!(self.out, "writeonly ")?; + } + Ok(()) + } + + /// Helper method used to produce the reflection info that's returned to the user + fn collect_reflection_info(&mut self) -> Result<ReflectionInfo, Error> { + use std::collections::hash_map::Entry; + let info = self.info.get_entry_point(self.entry_point_idx as usize); + let mut texture_mapping = crate::FastHashMap::default(); + let mut uniforms = crate::FastHashMap::default(); + + for sampling in info.sampling_set.iter() { + let tex_name = self.reflection_names_globals[&sampling.image].clone(); + + match texture_mapping.entry(tex_name) { + Entry::Vacant(v) => { + v.insert(TextureMapping { + texture: sampling.image, + sampler: Some(sampling.sampler), + }); + } + Entry::Occupied(e) => { + if e.get().sampler != Some(sampling.sampler) { + log::error!("Conflicting samplers for {}", e.key()); + return Err(Error::ImageMultipleSamplers); + } + } + } + } + + let mut push_constant_info = None; + for (handle, var) in self.module.global_variables.iter() { + if info[handle].is_empty() { + continue; + } + match self.module.types[var.ty].inner { + crate::TypeInner::Image { .. } => { + let tex_name = self.reflection_names_globals[&handle].clone(); + match texture_mapping.entry(tex_name) { + Entry::Vacant(v) => { + v.insert(TextureMapping { + texture: handle, + sampler: None, + }); + } + Entry::Occupied(_) => { + // already used with a sampler, do nothing + } + } + } + _ => match var.space { + crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => { + let name = self.reflection_names_globals[&handle].clone(); + uniforms.insert(handle, name); + } + crate::AddressSpace::PushConstant => { + let name = self.reflection_names_globals[&handle].clone(); + push_constant_info = Some((name, var.ty)); + } + _ => (), + }, + } + } + + let mut push_constant_segments = Vec::new(); + let mut push_constant_items = vec![]; + + if let Some((name, ty)) = push_constant_info { + // We don't have a layouter available to us, so we need to create one. + // + // This is potentially a bit wasteful, but the set of types in the program + // shouldn't be too large. + let mut layouter = crate::proc::Layouter::default(); + layouter.update(self.module.to_ctx()).unwrap(); + + // We start with the name of the binding itself. + push_constant_segments.push(name); + + // We then recursively collect all the uniform fields of the push constant. + self.collect_push_constant_items( + ty, + &mut push_constant_segments, + &layouter, + &mut 0, + &mut push_constant_items, + ); + } + + Ok(ReflectionInfo { + texture_mapping, + uniforms, + varying: mem::take(&mut self.varying), + push_constant_items, + }) + } + + fn collect_push_constant_items( + &mut self, + ty: Handle<crate::Type>, + segments: &mut Vec<String>, + layouter: &crate::proc::Layouter, + offset: &mut u32, + items: &mut Vec<PushConstantItem>, + ) { + // At this point in the recursion, `segments` contains the path + // needed to access `ty` from the root. + + let layout = &layouter[ty]; + *offset = layout.alignment.round_up(*offset); + match self.module.types[ty].inner { + // All these types map directly to GL uniforms. + TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => { + // Build the full name, by combining all current segments. + let name: String = segments.iter().map(String::as_str).collect(); + items.push(PushConstantItem { + access_path: name, + offset: *offset, + ty, + }); + *offset += layout.size; + } + // Arrays are recursed into. + TypeInner::Array { base, size, .. } => { + let crate::ArraySize::Constant(count) = size else { + unreachable!("Cannot have dynamic arrays in push constants"); + }; + + for i in 0..count.get() { + // Add the array accessor and recurse. + segments.push(format!("[{}]", i)); + self.collect_push_constant_items(base, segments, layouter, offset, items); + segments.pop(); + } + + // Ensure the stride is kept by rounding up to the alignment. + *offset = layout.alignment.round_up(*offset) + } + TypeInner::Struct { ref members, .. } => { + for (index, member) in members.iter().enumerate() { + // Add struct accessor and recurse. + segments.push(format!( + ".{}", + self.names[&NameKey::StructMember(ty, index as u32)] + )); + self.collect_push_constant_items(member.ty, segments, layouter, offset, items); + segments.pop(); + } + + // Ensure ending padding is kept by rounding up to the alignment. + *offset = layout.alignment.round_up(*offset) + } + _ => unreachable!(), + } + } +} + +/// Structure returned by [`glsl_scalar`] +/// +/// It contains both a prefix used in other types and the full type name +struct ScalarString<'a> { + /// The prefix used to compose other types + prefix: &'a str, + /// The name of the scalar type + full: &'a str, +} + +/// Helper function that returns scalar related strings +/// +/// Check [`ScalarString`] for the information provided +/// +/// # Errors +/// If a [`Float`](crate::ScalarKind::Float) with an width that isn't 4 or 8 +const fn glsl_scalar(scalar: crate::Scalar) -> Result<ScalarString<'static>, Error> { + use crate::ScalarKind as Sk; + + Ok(match scalar.kind { + Sk::Sint => ScalarString { + prefix: "i", + full: "int", + }, + Sk::Uint => ScalarString { + prefix: "u", + full: "uint", + }, + Sk::Float => match scalar.width { + 4 => ScalarString { + prefix: "", + full: "float", + }, + 8 => ScalarString { + prefix: "d", + full: "double", + }, + _ => return Err(Error::UnsupportedScalar(scalar)), + }, + Sk::Bool => ScalarString { + prefix: "b", + full: "bool", + }, + Sk::AbstractInt | Sk::AbstractFloat => { + return Err(Error::UnsupportedScalar(scalar)); + } + }) +} + +/// Helper function that returns the glsl variable name for a builtin +const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'static str { + use crate::BuiltIn as Bi; + + match built_in { + Bi::Position { .. } => { + if options.output { + "gl_Position" + } else { + "gl_FragCoord" + } + } + Bi::ViewIndex if options.targeting_webgl => "int(gl_ViewID_OVR)", + Bi::ViewIndex => "gl_ViewIndex", + // vertex + Bi::BaseInstance => "uint(gl_BaseInstance)", + Bi::BaseVertex => "uint(gl_BaseVertex)", + Bi::ClipDistance => "gl_ClipDistance", + Bi::CullDistance => "gl_CullDistance", + Bi::InstanceIndex => { + if options.draw_parameters { + "(uint(gl_InstanceID) + uint(gl_BaseInstanceARB))" + } else { + // Must match FIRST_INSTANCE_BINDING + "(uint(gl_InstanceID) + naga_vs_first_instance)" + } + } + Bi::PointSize => "gl_PointSize", + Bi::VertexIndex => "uint(gl_VertexID)", + // fragment + Bi::FragDepth => "gl_FragDepth", + Bi::PointCoord => "gl_PointCoord", + Bi::FrontFacing => "gl_FrontFacing", + Bi::PrimitiveIndex => "uint(gl_PrimitiveID)", + Bi::SampleIndex => "gl_SampleID", + Bi::SampleMask => { + if options.output { + "gl_SampleMask" + } else { + "gl_SampleMaskIn" + } + } + // compute + Bi::GlobalInvocationId => "gl_GlobalInvocationID", + Bi::LocalInvocationId => "gl_LocalInvocationID", + Bi::LocalInvocationIndex => "gl_LocalInvocationIndex", + Bi::WorkGroupId => "gl_WorkGroupID", + Bi::WorkGroupSize => "gl_WorkGroupSize", + Bi::NumWorkGroups => "gl_NumWorkGroups", + } +} + +/// Helper function that returns the string corresponding to the address space +const fn glsl_storage_qualifier(space: crate::AddressSpace) -> Option<&'static str> { + use crate::AddressSpace as As; + + match space { + As::Function => None, + As::Private => None, + As::Storage { .. } => Some("buffer"), + As::Uniform => Some("uniform"), + As::Handle => Some("uniform"), + As::WorkGroup => Some("shared"), + As::PushConstant => Some("uniform"), + } +} + +/// Helper function that returns the string corresponding to the glsl interpolation qualifier +const fn glsl_interpolation(interpolation: crate::Interpolation) -> &'static str { + use crate::Interpolation as I; + + match interpolation { + I::Perspective => "smooth", + I::Linear => "noperspective", + I::Flat => "flat", + } +} + +/// Return the GLSL auxiliary qualifier for the given sampling value. +const fn glsl_sampling(sampling: crate::Sampling) -> Option<&'static str> { + use crate::Sampling as S; + + match sampling { + S::Center => None, + S::Centroid => Some("centroid"), + S::Sample => Some("sample"), + } +} + +/// Helper function that returns the glsl dimension string of [`ImageDimension`](crate::ImageDimension) +const fn glsl_dimension(dim: crate::ImageDimension) -> &'static str { + use crate::ImageDimension as IDim; + + match dim { + IDim::D1 => "1D", + IDim::D2 => "2D", + IDim::D3 => "3D", + IDim::Cube => "Cube", + } +} + +/// Helper function that returns the glsl storage format string of [`StorageFormat`](crate::StorageFormat) +fn glsl_storage_format(format: crate::StorageFormat) -> Result<&'static str, Error> { + use crate::StorageFormat as Sf; + + Ok(match format { + Sf::R8Unorm => "r8", + Sf::R8Snorm => "r8_snorm", + Sf::R8Uint => "r8ui", + Sf::R8Sint => "r8i", + Sf::R16Uint => "r16ui", + Sf::R16Sint => "r16i", + Sf::R16Float => "r16f", + Sf::Rg8Unorm => "rg8", + Sf::Rg8Snorm => "rg8_snorm", + Sf::Rg8Uint => "rg8ui", + Sf::Rg8Sint => "rg8i", + Sf::R32Uint => "r32ui", + Sf::R32Sint => "r32i", + Sf::R32Float => "r32f", + Sf::Rg16Uint => "rg16ui", + Sf::Rg16Sint => "rg16i", + Sf::Rg16Float => "rg16f", + Sf::Rgba8Unorm => "rgba8", + Sf::Rgba8Snorm => "rgba8_snorm", + Sf::Rgba8Uint => "rgba8ui", + Sf::Rgba8Sint => "rgba8i", + Sf::Rgb10a2Uint => "rgb10_a2ui", + Sf::Rgb10a2Unorm => "rgb10_a2", + Sf::Rg11b10Float => "r11f_g11f_b10f", + Sf::Rg32Uint => "rg32ui", + Sf::Rg32Sint => "rg32i", + Sf::Rg32Float => "rg32f", + Sf::Rgba16Uint => "rgba16ui", + Sf::Rgba16Sint => "rgba16i", + Sf::Rgba16Float => "rgba16f", + Sf::Rgba32Uint => "rgba32ui", + Sf::Rgba32Sint => "rgba32i", + Sf::Rgba32Float => "rgba32f", + Sf::R16Unorm => "r16", + Sf::R16Snorm => "r16_snorm", + Sf::Rg16Unorm => "rg16", + Sf::Rg16Snorm => "rg16_snorm", + Sf::Rgba16Unorm => "rgba16", + Sf::Rgba16Snorm => "rgba16_snorm", + + Sf::Bgra8Unorm => { + return Err(Error::Custom( + "Support format BGRA8 is not implemented".into(), + )) + } + }) +} + +fn is_value_init_supported(module: &crate::Module, ty: Handle<crate::Type>) -> bool { + match module.types[ty].inner { + TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => true, + TypeInner::Array { base, size, .. } => { + size != crate::ArraySize::Dynamic && is_value_init_supported(module, base) + } + TypeInner::Struct { ref members, .. } => members + .iter() + .all(|member| is_value_init_supported(module, member.ty)), + _ => false, + } +} diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs new file mode 100644 index 0000000000..b6918ddc42 --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/conv.rs @@ -0,0 +1,222 @@ +use std::borrow::Cow; + +use crate::proc::Alignment; + +use super::Error; + +impl crate::ScalarKind { + pub(super) fn to_hlsl_cast(self) -> &'static str { + match self { + Self::Float => "asfloat", + Self::Sint => "asint", + Self::Uint => "asuint", + Self::Bool | Self::AbstractInt | Self::AbstractFloat => unreachable!(), + } + } +} + +impl crate::Scalar { + /// Helper function that returns scalar related strings + /// + /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar> + pub(super) const fn to_hlsl_str(self) -> Result<&'static str, Error> { + match self.kind { + crate::ScalarKind::Sint => Ok("int"), + crate::ScalarKind::Uint => Ok("uint"), + crate::ScalarKind::Float => match self.width { + 2 => Ok("half"), + 4 => Ok("float"), + 8 => Ok("double"), + _ => Err(Error::UnsupportedScalar(self)), + }, + crate::ScalarKind::Bool => Ok("bool"), + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + Err(Error::UnsupportedScalar(self)) + } + } + } +} + +impl crate::TypeInner { + pub(super) const fn is_matrix(&self) -> bool { + match *self { + Self::Matrix { .. } => true, + _ => false, + } + } + + pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> u32 { + match *self { + Self::Matrix { + columns, + rows, + scalar, + } => { + let stride = Alignment::from(rows) * scalar.width as u32; + let last_row_size = rows as u32 * scalar.width as u32; + ((columns as u32 - 1) * stride) + last_row_size + } + Self::Array { base, size, stride } => { + let count = match size { + crate::ArraySize::Constant(size) => size.get(), + // A dynamically-sized array has to have at least one element + crate::ArraySize::Dynamic => 1, + }; + let last_el_size = gctx.types[base].inner.size_hlsl(gctx); + ((count - 1) * stride) + last_el_size + } + _ => self.size(gctx), + } + } + + /// Used to generate the name of the wrapped type constructor + pub(super) fn hlsl_type_id<'a>( + base: crate::Handle<crate::Type>, + gctx: crate::proc::GlobalCtx, + names: &'a crate::FastHashMap<crate::proc::NameKey, String>, + ) -> Result<Cow<'a, str>, Error> { + Ok(match gctx.types[base].inner { + crate::TypeInner::Scalar(scalar) => Cow::Borrowed(scalar.to_hlsl_str()?), + crate::TypeInner::Vector { size, scalar } => Cow::Owned(format!( + "{}{}", + scalar.to_hlsl_str()?, + crate::back::vector_size_str(size) + )), + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => Cow::Owned(format!( + "{}{}x{}", + scalar.to_hlsl_str()?, + crate::back::vector_size_str(columns), + crate::back::vector_size_str(rows), + )), + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + .. + } => Cow::Owned(format!( + "array{size}_{}_", + Self::hlsl_type_id(base, gctx, names)? + )), + crate::TypeInner::Struct { .. } => { + Cow::Borrowed(&names[&crate::proc::NameKey::Type(base)]) + } + _ => unreachable!(), + }) + } +} + +impl crate::StorageFormat { + pub(super) const fn to_hlsl_str(self) -> &'static str { + match self { + Self::R16Float => "float", + Self::R8Unorm | Self::R16Unorm => "unorm float", + Self::R8Snorm | Self::R16Snorm => "snorm float", + Self::R8Uint | Self::R16Uint => "uint", + Self::R8Sint | Self::R16Sint => "int", + + Self::Rg16Float => "float2", + Self::Rg8Unorm | Self::Rg16Unorm => "unorm float2", + Self::Rg8Snorm | Self::Rg16Snorm => "snorm float2", + + Self::Rg8Sint | Self::Rg16Sint => "int2", + Self::Rg8Uint | Self::Rg16Uint => "uint2", + + Self::Rg11b10Float => "float3", + + Self::Rgba16Float | Self::R32Float | Self::Rg32Float | Self::Rgba32Float => "float4", + Self::Rgba8Unorm | Self::Bgra8Unorm | Self::Rgba16Unorm | Self::Rgb10a2Unorm => { + "unorm float4" + } + Self::Rgba8Snorm | Self::Rgba16Snorm => "snorm float4", + + Self::Rgba8Uint + | Self::Rgba16Uint + | Self::R32Uint + | Self::Rg32Uint + | Self::Rgba32Uint + | Self::Rgb10a2Uint => "uint4", + Self::Rgba8Sint + | Self::Rgba16Sint + | Self::R32Sint + | Self::Rg32Sint + | Self::Rgba32Sint => "int4", + } + } +} + +impl crate::BuiltIn { + pub(super) fn to_hlsl_str(self) -> Result<&'static str, Error> { + Ok(match self { + Self::Position { .. } => "SV_Position", + // vertex + Self::ClipDistance => "SV_ClipDistance", + Self::CullDistance => "SV_CullDistance", + Self::InstanceIndex => "SV_InstanceID", + Self::VertexIndex => "SV_VertexID", + // fragment + Self::FragDepth => "SV_Depth", + Self::FrontFacing => "SV_IsFrontFace", + Self::PrimitiveIndex => "SV_PrimitiveID", + Self::SampleIndex => "SV_SampleIndex", + Self::SampleMask => "SV_Coverage", + // compute + Self::GlobalInvocationId => "SV_DispatchThreadID", + Self::LocalInvocationId => "SV_GroupThreadID", + Self::LocalInvocationIndex => "SV_GroupIndex", + Self::WorkGroupId => "SV_GroupID", + // The specific semantic we use here doesn't matter, because references + // to this field will get replaced with references to `SPECIAL_CBUF_VAR` + // in `Writer::write_expr`. + Self::NumWorkGroups => "SV_GroupID", + Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { + return Err(Error::Unimplemented(format!("builtin {self:?}"))) + } + Self::PointSize | Self::ViewIndex | Self::PointCoord => { + return Err(Error::Custom(format!("Unsupported builtin {self:?}"))) + } + }) + } +} + +impl crate::Interpolation { + /// Return the string corresponding to the HLSL interpolation qualifier. + pub(super) const fn to_hlsl_str(self) -> Option<&'static str> { + match self { + // Would be "linear", but it's the default interpolation in SM4 and up + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-struct#interpolation-modifiers-introduced-in-shader-model-4 + Self::Perspective => None, + Self::Linear => Some("noperspective"), + Self::Flat => Some("nointerpolation"), + } + } +} + +impl crate::Sampling { + /// Return the HLSL auxiliary qualifier for the given sampling value. + pub(super) const fn to_hlsl_str(self) -> Option<&'static str> { + match self { + Self::Center => None, + Self::Centroid => Some("centroid"), + Self::Sample => Some("sample"), + } + } +} + +impl crate::AtomicFunction { + /// Return the HLSL suffix for the `InterlockedXxx` method. + pub(super) const fn to_hlsl_suffix(self) -> &'static str { + match self { + Self::Add | Self::Subtract => "Add", + Self::And => "And", + Self::InclusiveOr => "Or", + Self::ExclusiveOr => "Xor", + Self::Min => "Min", + Self::Max => "Max", + Self::Exchange { compare: None } => "Exchange", + Self::Exchange { .. } => "", //TODO + } + } +} diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs new file mode 100644 index 0000000000..fa6062a1ad --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/help.rs @@ -0,0 +1,1138 @@ +/*! +Helpers for the hlsl backend + +Important note about `Expression::ImageQuery`/`Expression::ArrayLength` and hlsl backend: + +Due to implementation of `GetDimensions` function in hlsl (<https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>) +backend can't work with it as an expression. +Instead, it generates a unique wrapped function per `Expression::ImageQuery`, based on texture info and query function. +See `WrappedImageQuery` struct that represents a unique function and will be generated before writing all statements and expressions. +This allowed to works with `Expression::ImageQuery` as expression and write wrapped function. + +For example: +```wgsl +let dim_1d = textureDimensions(image_1d); +``` + +```hlsl +int NagaDimensions1D(Texture1D<float4>) +{ + uint4 ret; + image_1d.GetDimensions(ret.x); + return ret.x; +} + +int dim_1d = NagaDimensions1D(image_1d); +``` +*/ + +use super::{super::FunctionCtx, BackendResult}; +use crate::{arena::Handle, proc::NameKey}; +use std::fmt::Write; + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedArrayLength { + pub(super) writable: bool, +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedImageQuery { + pub(super) dim: crate::ImageDimension, + pub(super) arrayed: bool, + pub(super) class: crate::ImageClass, + pub(super) query: ImageQuery, +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedConstructor { + pub(super) ty: Handle<crate::Type>, +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedStructMatrixAccess { + pub(super) ty: Handle<crate::Type>, + pub(super) index: u32, +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedMatCx2 { + pub(super) columns: crate::VectorSize, +} + +/// HLSL backend requires its own `ImageQuery` enum. +/// +/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function. +/// IR version can't be unique per function, because it's store mipmap level as an expression. +/// +/// For example: +/// ```wgsl +/// let dim_cube_array_lod = textureDimensions(image_cube_array, 1); +/// let dim_cube_array_lod2 = textureDimensions(image_cube_array, 1); +/// ``` +/// +/// ```ir +/// ImageQuery { +/// image: [1], +/// query: Size { +/// level: Some( +/// [1], +/// ), +/// }, +/// }, +/// ImageQuery { +/// image: [1], +/// query: Size { +/// level: Some( +/// [2], +/// ), +/// }, +/// }, +/// ``` +/// +/// HLSL should generate only 1 function for this case. +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) enum ImageQuery { + Size, + SizeLevel, + NumLevels, + NumLayers, + NumSamples, +} + +impl From<crate::ImageQuery> for ImageQuery { + fn from(q: crate::ImageQuery) -> Self { + use crate::ImageQuery as Iq; + match q { + Iq::Size { level: Some(_) } => ImageQuery::SizeLevel, + Iq::Size { level: None } => ImageQuery::Size, + Iq::NumLevels => ImageQuery::NumLevels, + Iq::NumLayers => ImageQuery::NumLayers, + Iq::NumSamples => ImageQuery::NumSamples, + } + } +} + +impl<'a, W: Write> super::Writer<'a, W> { + pub(super) fn write_image_type( + &mut self, + dim: crate::ImageDimension, + arrayed: bool, + class: crate::ImageClass, + ) -> BackendResult { + let access_str = match class { + crate::ImageClass::Storage { .. } => "RW", + _ => "", + }; + let dim_str = dim.to_hlsl_str(); + let arrayed_str = if arrayed { "Array" } else { "" }; + write!(self.out, "{access_str}Texture{dim_str}{arrayed_str}")?; + match class { + crate::ImageClass::Depth { multi } => { + let multi_str = if multi { "MS" } else { "" }; + write!(self.out, "{multi_str}<float>")? + } + crate::ImageClass::Sampled { kind, multi } => { + let multi_str = if multi { "MS" } else { "" }; + let scalar_kind_str = crate::Scalar { kind, width: 4 }.to_hlsl_str()?; + write!(self.out, "{multi_str}<{scalar_kind_str}4>")? + } + crate::ImageClass::Storage { format, .. } => { + let storage_format_str = format.to_hlsl_str(); + write!(self.out, "<{storage_format_str}>")? + } + } + Ok(()) + } + + pub(super) fn write_wrapped_array_length_function_name( + &mut self, + query: WrappedArrayLength, + ) -> BackendResult { + let access_str = if query.writable { "RW" } else { "" }; + write!(self.out, "NagaBufferLength{access_str}",)?; + + Ok(()) + } + + /// Helper function that write wrapped function for `Expression::ArrayLength` + /// + /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer-getdimensions> + pub(super) fn write_wrapped_array_length_function( + &mut self, + wal: WrappedArrayLength, + ) -> BackendResult { + use crate::back::INDENT; + + const ARGUMENT_VARIABLE_NAME: &str = "buffer"; + const RETURN_VARIABLE_NAME: &str = "ret"; + + // Write function return type and name + write!(self.out, "uint ")?; + self.write_wrapped_array_length_function_name(wal)?; + + // Write function parameters + write!(self.out, "(")?; + let access_str = if wal.writable { "RW" } else { "" }; + writeln!( + self.out, + "{access_str}ByteAddressBuffer {ARGUMENT_VARIABLE_NAME})" + )?; + // Write function body + writeln!(self.out, "{{")?; + + // Write `GetDimensions` function. + writeln!(self.out, "{INDENT}uint {RETURN_VARIABLE_NAME};")?; + writeln!( + self.out, + "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions({RETURN_VARIABLE_NAME});" + )?; + + // Write return value + writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_image_query_function_name( + &mut self, + query: WrappedImageQuery, + ) -> BackendResult { + let dim_str = query.dim.to_hlsl_str(); + let class_str = match query.class { + crate::ImageClass::Sampled { multi: true, .. } => "MS", + crate::ImageClass::Depth { multi: true } => "DepthMS", + crate::ImageClass::Depth { multi: false } => "Depth", + crate::ImageClass::Sampled { multi: false, .. } => "", + crate::ImageClass::Storage { .. } => "RW", + }; + let arrayed_str = if query.arrayed { "Array" } else { "" }; + let query_str = match query.query { + ImageQuery::Size => "Dimensions", + ImageQuery::SizeLevel => "MipDimensions", + ImageQuery::NumLevels => "NumLevels", + ImageQuery::NumLayers => "NumLayers", + ImageQuery::NumSamples => "NumSamples", + }; + + write!(self.out, "Naga{class_str}{query_str}{dim_str}{arrayed_str}")?; + + Ok(()) + } + + /// Helper function that write wrapped function for `Expression::ImageQuery` + /// + /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions> + pub(super) fn write_wrapped_image_query_function( + &mut self, + module: &crate::Module, + wiq: WrappedImageQuery, + expr_handle: Handle<crate::Expression>, + func_ctx: &FunctionCtx, + ) -> BackendResult { + use crate::{ + back::{COMPONENTS, INDENT}, + ImageDimension as IDim, + }; + + const ARGUMENT_VARIABLE_NAME: &str = "tex"; + const RETURN_VARIABLE_NAME: &str = "ret"; + const MIP_LEVEL_PARAM: &str = "mip_level"; + + // Write function return type and name + let ret_ty = func_ctx.resolve_type(expr_handle, &module.types); + self.write_value_type(module, ret_ty)?; + write!(self.out, " ")?; + self.write_wrapped_image_query_function_name(wiq)?; + + // Write function parameters + write!(self.out, "(")?; + // Texture always first parameter + self.write_image_type(wiq.dim, wiq.arrayed, wiq.class)?; + write!(self.out, " {ARGUMENT_VARIABLE_NAME}")?; + // Mipmap is a second parameter if exists + if let ImageQuery::SizeLevel = wiq.query { + write!(self.out, ", uint {MIP_LEVEL_PARAM}")?; + } + writeln!(self.out, ")")?; + + // Write function body + writeln!(self.out, "{{")?; + + let array_coords = usize::from(wiq.arrayed); + // extra parameter is the mip level count or the sample count + let extra_coords = match wiq.class { + crate::ImageClass::Storage { .. } => 0, + crate::ImageClass::Sampled { .. } | crate::ImageClass::Depth { .. } => 1, + }; + + // GetDimensions Overloaded Methods + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions#overloaded-methods + let (ret_swizzle, number_of_params) = match wiq.query { + ImageQuery::Size | ImageQuery::SizeLevel => { + let ret = match wiq.dim { + IDim::D1 => "x", + IDim::D2 => "xy", + IDim::D3 => "xyz", + IDim::Cube => "xy", + }; + (ret, ret.len() + array_coords + extra_coords) + } + ImageQuery::NumLevels | ImageQuery::NumSamples | ImageQuery::NumLayers => { + if wiq.arrayed || wiq.dim == IDim::D3 { + ("w", 4) + } else { + ("z", 3) + } + } + }; + + // Write `GetDimensions` function. + writeln!(self.out, "{INDENT}uint4 {RETURN_VARIABLE_NAME};")?; + write!(self.out, "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions(")?; + match wiq.query { + ImageQuery::SizeLevel => { + write!(self.out, "{MIP_LEVEL_PARAM}, ")?; + } + _ => match wiq.class { + crate::ImageClass::Sampled { multi: true, .. } + | crate::ImageClass::Depth { multi: true } + | crate::ImageClass::Storage { .. } => {} + _ => { + // Write zero mipmap level for supported types + write!(self.out, "0, ")?; + } + }, + } + + for component in COMPONENTS[..number_of_params - 1].iter() { + write!(self.out, "{RETURN_VARIABLE_NAME}.{component}, ")?; + } + + // write last parameter without comma and space for last parameter + write!( + self.out, + "{}.{}", + RETURN_VARIABLE_NAME, + COMPONENTS[number_of_params - 1] + )?; + + writeln!(self.out, ");")?; + + // Write return value + writeln!( + self.out, + "{INDENT}return {RETURN_VARIABLE_NAME}.{ret_swizzle};" + )?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_constructor_function_name( + &mut self, + module: &crate::Module, + constructor: WrappedConstructor, + ) -> BackendResult { + let name = crate::TypeInner::hlsl_type_id(constructor.ty, module.to_ctx(), &self.names)?; + write!(self.out, "Construct{name}")?; + Ok(()) + } + + /// Helper function that write wrapped function for `Expression::Compose` for structures. + pub(super) fn write_wrapped_constructor_function( + &mut self, + module: &crate::Module, + constructor: WrappedConstructor, + ) -> BackendResult { + use crate::back::INDENT; + + const ARGUMENT_VARIABLE_NAME: &str = "arg"; + const RETURN_VARIABLE_NAME: &str = "ret"; + + // Write function return type and name + if let crate::TypeInner::Array { base, size, .. } = module.types[constructor.ty].inner { + write!(self.out, "typedef ")?; + self.write_type(module, constructor.ty)?; + write!(self.out, " ret_")?; + self.write_wrapped_constructor_function_name(module, constructor)?; + self.write_array_size(module, base, size)?; + writeln!(self.out, ";")?; + + write!(self.out, "ret_")?; + self.write_wrapped_constructor_function_name(module, constructor)?; + } else { + self.write_type(module, constructor.ty)?; + } + write!(self.out, " ")?; + self.write_wrapped_constructor_function_name(module, constructor)?; + + // Write function parameters + write!(self.out, "(")?; + + let mut write_arg = |i, ty| -> BackendResult { + if i != 0 { + write!(self.out, ", ")?; + } + self.write_type(module, ty)?; + write!(self.out, " {ARGUMENT_VARIABLE_NAME}{i}")?; + if let crate::TypeInner::Array { base, size, .. } = module.types[ty].inner { + self.write_array_size(module, base, size)?; + } + Ok(()) + }; + + match module.types[constructor.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + for (i, member) in members.iter().enumerate() { + write_arg(i, member.ty)?; + } + } + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + .. + } => { + for i in 0..size.get() as usize { + write_arg(i, base)?; + } + } + _ => unreachable!(), + }; + + write!(self.out, ")")?; + + // Write function body + writeln!(self.out, " {{")?; + + match module.types[constructor.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + let struct_name = &self.names[&NameKey::Type(constructor.ty)]; + writeln!( + self.out, + "{INDENT}{struct_name} {RETURN_VARIABLE_NAME} = ({struct_name})0;" + )?; + for (i, member) in members.iter().enumerate() { + let field_name = &self.names[&NameKey::StructMember(constructor.ty, i as u32)]; + + match module.types[member.ty].inner { + crate::TypeInner::Matrix { + columns, + rows: crate::VectorSize::Bi, + .. + } if member.binding.is_none() => { + for j in 0..columns as u8 { + writeln!( + self.out, + "{INDENT}{RETURN_VARIABLE_NAME}.{field_name}_{j} = {ARGUMENT_VARIABLE_NAME}{i}[{j}];" + )?; + } + } + ref other => { + // We cast arrays of native HLSL `floatCx2`s to arrays of `matCx2`s + // (where the inner matrix is represented by a struct with C `float2` members). + // See the module-level block comment in mod.rs for details. + if let Some(super::writer::MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = super::writer::get_inner_matrix_data(module, member.ty) + { + write!( + self.out, + "{}{}.{} = (__mat{}x2", + INDENT, RETURN_VARIABLE_NAME, field_name, columns as u8 + )?; + if let crate::TypeInner::Array { base, size, .. } = *other { + self.write_array_size(module, base, size)?; + } + writeln!(self.out, "){ARGUMENT_VARIABLE_NAME}{i};",)?; + } else { + writeln!( + self.out, + "{INDENT}{RETURN_VARIABLE_NAME}.{field_name} = {ARGUMENT_VARIABLE_NAME}{i};", + )?; + } + } + } + } + } + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + .. + } => { + write!(self.out, "{INDENT}")?; + self.write_type(module, base)?; + write!(self.out, " {RETURN_VARIABLE_NAME}")?; + self.write_array_size(module, base, crate::ArraySize::Constant(size))?; + write!(self.out, " = {{ ")?; + for i in 0..size.get() { + if i != 0 { + write!(self.out, ", ")?; + } + write!(self.out, "{ARGUMENT_VARIABLE_NAME}{i}")?; + } + writeln!(self.out, " }};",)?; + } + _ => unreachable!(), + } + + // Write return value + writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_struct_matrix_get_function_name( + &mut self, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + let name = &self.names[&NameKey::Type(access.ty)]; + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + write!(self.out, "GetMat{field_name}On{name}")?; + Ok(()) + } + + /// Writes a function used to get a matCx2 from within a structure. + pub(super) fn write_wrapped_struct_matrix_get_function( + &mut self, + module: &crate::Module, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + use crate::back::INDENT; + + const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj"; + + // Write function return type and name + let member = match module.types[access.ty].inner { + crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize], + _ => unreachable!(), + }; + let ret_ty = &module.types[member.ty].inner; + self.write_value_type(module, ret_ty)?; + write!(self.out, " ")?; + self.write_wrapped_struct_matrix_get_function_name(access)?; + + // Write function parameters + write!(self.out, "(")?; + let struct_name = &self.names[&NameKey::Type(access.ty)]; + write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}")?; + + // Write function body + writeln!(self.out, ") {{")?; + + // Write return value + write!(self.out, "{INDENT}return ")?; + self.write_value_type(module, ret_ty)?; + write!(self.out, "(")?; + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + match module.types[member.ty].inner { + crate::TypeInner::Matrix { columns, .. } => { + for i in 0..columns as u8 { + if i != 0 { + write!(self.out, ", ")?; + } + write!(self.out, "{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}")?; + } + } + _ => unreachable!(), + } + writeln!(self.out, ");")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_struct_matrix_set_function_name( + &mut self, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + let name = &self.names[&NameKey::Type(access.ty)]; + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + write!(self.out, "SetMat{field_name}On{name}")?; + Ok(()) + } + + /// Writes a function used to set a matCx2 from within a structure. + pub(super) fn write_wrapped_struct_matrix_set_function( + &mut self, + module: &crate::Module, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + use crate::back::INDENT; + + const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj"; + const MATRIX_ARGUMENT_VARIABLE_NAME: &str = "mat"; + + // Write function return type and name + write!(self.out, "void ")?; + self.write_wrapped_struct_matrix_set_function_name(access)?; + + // Write function parameters + write!(self.out, "(")?; + let struct_name = &self.names[&NameKey::Type(access.ty)]; + write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?; + let member = match module.types[access.ty].inner { + crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize], + _ => unreachable!(), + }; + self.write_type(module, member.ty)?; + write!(self.out, " {MATRIX_ARGUMENT_VARIABLE_NAME}")?; + // Write function body + writeln!(self.out, ") {{")?; + + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + + match module.types[member.ty].inner { + crate::TypeInner::Matrix { columns, .. } => { + for i in 0..columns as u8 { + writeln!( + self.out, + "{INDENT}{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {MATRIX_ARGUMENT_VARIABLE_NAME}[{i}];" + )?; + } + } + _ => unreachable!(), + } + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_struct_matrix_set_vec_function_name( + &mut self, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + let name = &self.names[&NameKey::Type(access.ty)]; + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + write!(self.out, "SetMatVec{field_name}On{name}")?; + Ok(()) + } + + /// Writes a function used to set a vec2 on a matCx2 from within a structure. + pub(super) fn write_wrapped_struct_matrix_set_vec_function( + &mut self, + module: &crate::Module, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + use crate::back::INDENT; + + const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj"; + const VECTOR_ARGUMENT_VARIABLE_NAME: &str = "vec"; + const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx"; + + // Write function return type and name + write!(self.out, "void ")?; + self.write_wrapped_struct_matrix_set_vec_function_name(access)?; + + // Write function parameters + write!(self.out, "(")?; + let struct_name = &self.names[&NameKey::Type(access.ty)]; + write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?; + let member = match module.types[access.ty].inner { + crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize], + _ => unreachable!(), + }; + let vec_ty = match module.types[member.ty].inner { + crate::TypeInner::Matrix { rows, scalar, .. } => { + crate::TypeInner::Vector { size: rows, scalar } + } + _ => unreachable!(), + }; + self.write_value_type(module, &vec_ty)?; + write!( + self.out, + " {VECTOR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}" + )?; + + // Write function body + writeln!(self.out, ") {{")?; + + writeln!( + self.out, + "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{" + )?; + + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + + match module.types[member.ty].inner { + crate::TypeInner::Matrix { columns, .. } => { + for i in 0..columns as u8 { + writeln!( + self.out, + "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {VECTOR_ARGUMENT_VARIABLE_NAME}; break; }}" + )?; + } + } + _ => unreachable!(), + } + + writeln!(self.out, "{INDENT}}}")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_wrapped_struct_matrix_set_scalar_function_name( + &mut self, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + let name = &self.names[&NameKey::Type(access.ty)]; + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + write!(self.out, "SetMatScalar{field_name}On{name}")?; + Ok(()) + } + + /// Writes a function used to set a float on a matCx2 from within a structure. + pub(super) fn write_wrapped_struct_matrix_set_scalar_function( + &mut self, + module: &crate::Module, + access: WrappedStructMatrixAccess, + ) -> BackendResult { + use crate::back::INDENT; + + const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj"; + const SCALAR_ARGUMENT_VARIABLE_NAME: &str = "scalar"; + const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx"; + const VECTOR_INDEX_ARGUMENT_VARIABLE_NAME: &str = "vec_idx"; + + // Write function return type and name + write!(self.out, "void ")?; + self.write_wrapped_struct_matrix_set_scalar_function_name(access)?; + + // Write function parameters + write!(self.out, "(")?; + let struct_name = &self.names[&NameKey::Type(access.ty)]; + write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?; + let member = match module.types[access.ty].inner { + crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize], + _ => unreachable!(), + }; + let scalar_ty = match module.types[member.ty].inner { + crate::TypeInner::Matrix { scalar, .. } => crate::TypeInner::Scalar(scalar), + _ => unreachable!(), + }; + self.write_value_type(module, &scalar_ty)?; + write!( + self.out, + " {SCALAR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}, uint {VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}" + )?; + + // Write function body + writeln!(self.out, ") {{")?; + + writeln!( + self.out, + "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{" + )?; + + let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)]; + + match module.types[member.ty].inner { + crate::TypeInner::Matrix { columns, .. } => { + for i in 0..columns as u8 { + writeln!( + self.out, + "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}[{VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}] = {SCALAR_ARGUMENT_VARIABLE_NAME}; break; }}" + )?; + } + } + _ => unreachable!(), + } + + writeln!(self.out, "{INDENT}}}")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } + + /// Write functions to create special types. + pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult { + for (type_key, struct_ty) in module.special_types.predeclared_types.iter() { + match type_key { + &crate::PredeclaredType::ModfResult { size, width } + | &crate::PredeclaredType::FrexpResult { size, width } => { + let arg_type_name_owner; + let arg_type_name = if let Some(size) = size { + arg_type_name_owner = format!( + "{}{}", + if width == 8 { "double" } else { "float" }, + size as u8 + ); + &arg_type_name_owner + } else if width == 8 { + "double" + } else { + "float" + }; + + let (defined_func_name, called_func_name, second_field_name, sign_multiplier) = + if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) { + (super::writer::MODF_FUNCTION, "modf", "whole", "") + } else { + ( + super::writer::FREXP_FUNCTION, + "frexp", + "exp_", + "sign(arg) * ", + ) + }; + + let struct_name = &self.names[&NameKey::Type(*struct_ty)]; + + writeln!( + self.out, + "{struct_name} {defined_func_name}({arg_type_name} arg) {{ + {arg_type_name} other; + {struct_name} result; + result.fract = {sign_multiplier}{called_func_name}(arg, other); + result.{second_field_name} = other; + return result; +}}" + )?; + writeln!(self.out)?; + } + &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {} + } + } + + Ok(()) + } + + /// Helper function that writes compose wrapped functions + pub(super) fn write_wrapped_compose_functions( + &mut self, + module: &crate::Module, + expressions: &crate::Arena<crate::Expression>, + ) -> BackendResult { + for (handle, _) in expressions.iter() { + if let crate::Expression::Compose { ty, .. } = expressions[handle] { + match module.types[ty].inner { + crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => { + let constructor = WrappedConstructor { ty }; + if self.wrapped.constructors.insert(constructor) { + self.write_wrapped_constructor_function(module, constructor)?; + } + } + _ => {} + }; + } + } + Ok(()) + } + + /// Helper function that writes various wrapped functions + pub(super) fn write_wrapped_functions( + &mut self, + module: &crate::Module, + func_ctx: &FunctionCtx, + ) -> BackendResult { + self.write_wrapped_compose_functions(module, func_ctx.expressions)?; + + for (handle, _) in func_ctx.expressions.iter() { + match func_ctx.expressions[handle] { + crate::Expression::ArrayLength(expr) => { + let global_expr = match func_ctx.expressions[expr] { + crate::Expression::GlobalVariable(_) => expr, + crate::Expression::AccessIndex { base, index: _ } => base, + ref other => unreachable!("Array length of {:?}", other), + }; + let global_var = match func_ctx.expressions[global_expr] { + crate::Expression::GlobalVariable(var_handle) => { + &module.global_variables[var_handle] + } + ref other => unreachable!("Array length of base {:?}", other), + }; + let storage_access = match global_var.space { + crate::AddressSpace::Storage { access } => access, + _ => crate::StorageAccess::default(), + }; + let wal = WrappedArrayLength { + writable: storage_access.contains(crate::StorageAccess::STORE), + }; + + if self.wrapped.array_lengths.insert(wal) { + self.write_wrapped_array_length_function(wal)?; + } + } + crate::Expression::ImageQuery { image, query } => { + let wiq = match *func_ctx.resolve_type(image, &module.types) { + crate::TypeInner::Image { + dim, + arrayed, + class, + } => WrappedImageQuery { + dim, + arrayed, + class, + query: query.into(), + }, + _ => unreachable!("we only query images"), + }; + + if self.wrapped.image_queries.insert(wiq) { + self.write_wrapped_image_query_function(module, wiq, handle, func_ctx)?; + } + } + // Write `WrappedConstructor` for structs that are loaded from `AddressSpace::Storage` + // since they will later be used by the fn `write_storage_load` + crate::Expression::Load { pointer } => { + let pointer_space = func_ctx + .resolve_type(pointer, &module.types) + .pointer_space(); + + if let Some(crate::AddressSpace::Storage { .. }) = pointer_space { + if let Some(ty) = func_ctx.info[handle].ty.handle() { + write_wrapped_constructor(self, ty, module)?; + } + } + + fn write_wrapped_constructor<W: Write>( + writer: &mut super::Writer<'_, W>, + ty: Handle<crate::Type>, + module: &crate::Module, + ) -> BackendResult { + match module.types[ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + for member in members { + write_wrapped_constructor(writer, member.ty, module)?; + } + + let constructor = WrappedConstructor { ty }; + if writer.wrapped.constructors.insert(constructor) { + writer + .write_wrapped_constructor_function(module, constructor)?; + } + } + crate::TypeInner::Array { base, .. } => { + write_wrapped_constructor(writer, base, module)?; + + let constructor = WrappedConstructor { ty }; + if writer.wrapped.constructors.insert(constructor) { + writer + .write_wrapped_constructor_function(module, constructor)?; + } + } + _ => {} + }; + + Ok(()) + } + } + // We treat matrices of the form `matCx2` as a sequence of C `vec2`s + // (see top level module docs for details). + // + // The functions injected here are required to get the matrix accesses working. + crate::Expression::AccessIndex { base, index } => { + let base_ty_res = &func_ctx.info[base].ty; + let mut resolved = base_ty_res.inner_with(&module.types); + let base_ty_handle = match *resolved { + crate::TypeInner::Pointer { base, .. } => { + resolved = &module.types[base].inner; + Some(base) + } + _ => base_ty_res.handle(), + }; + if let crate::TypeInner::Struct { ref members, .. } = *resolved { + let member = &members[index as usize]; + + match module.types[member.ty].inner { + crate::TypeInner::Matrix { + rows: crate::VectorSize::Bi, + .. + } if member.binding.is_none() => { + let ty = base_ty_handle.unwrap(); + let access = WrappedStructMatrixAccess { ty, index }; + + if self.wrapped.struct_matrix_access.insert(access) { + self.write_wrapped_struct_matrix_get_function(module, access)?; + self.write_wrapped_struct_matrix_set_function(module, access)?; + self.write_wrapped_struct_matrix_set_vec_function( + module, access, + )?; + self.write_wrapped_struct_matrix_set_scalar_function( + module, access, + )?; + } + } + _ => {} + } + } + } + _ => {} + }; + } + + Ok(()) + } + + pub(super) fn write_texture_coordinates( + &mut self, + kind: &str, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + mip_level: Option<Handle<crate::Expression>>, + module: &crate::Module, + func_ctx: &FunctionCtx, + ) -> BackendResult { + // HLSL expects the array index to be merged with the coordinate + let extra = array_index.is_some() as usize + (mip_level.is_some()) as usize; + if extra == 0 { + self.write_expr(module, coordinate, func_ctx)?; + } else { + let num_coords = match *func_ctx.resolve_type(coordinate, &module.types) { + crate::TypeInner::Scalar { .. } => 1, + crate::TypeInner::Vector { size, .. } => size as usize, + _ => unreachable!(), + }; + write!(self.out, "{}{}(", kind, num_coords + extra)?; + self.write_expr(module, coordinate, func_ctx)?; + if let Some(expr) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + if let Some(expr) = mip_level { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + write!(self.out, ")")?; + } + Ok(()) + } + + pub(super) fn write_mat_cx2_typedef_and_functions( + &mut self, + WrappedMatCx2 { columns }: WrappedMatCx2, + ) -> BackendResult { + use crate::back::INDENT; + + // typedef + write!(self.out, "typedef struct {{ ")?; + for i in 0..columns as u8 { + write!(self.out, "float2 _{i}; ")?; + } + writeln!(self.out, "}} __mat{}x2;", columns as u8)?; + + // __get_col_of_mat + writeln!( + self.out, + "float2 __get_col_of_mat{}x2(__mat{}x2 mat, uint idx) {{", + columns as u8, columns as u8 + )?; + writeln!(self.out, "{INDENT}switch(idx) {{")?; + for i in 0..columns as u8 { + writeln!(self.out, "{INDENT}case {i}: {{ return mat._{i}; }}")?; + } + writeln!(self.out, "{INDENT}default: {{ return (float2)0; }}")?; + writeln!(self.out, "{INDENT}}}")?; + writeln!(self.out, "}}")?; + + // __set_col_of_mat + writeln!( + self.out, + "void __set_col_of_mat{}x2(__mat{}x2 mat, uint idx, float2 value) {{", + columns as u8, columns as u8 + )?; + writeln!(self.out, "{INDENT}switch(idx) {{")?; + for i in 0..columns as u8 { + writeln!(self.out, "{INDENT}case {i}: {{ mat._{i} = value; break; }}")?; + } + writeln!(self.out, "{INDENT}}}")?; + writeln!(self.out, "}}")?; + + // __set_el_of_mat + writeln!( + self.out, + "void __set_el_of_mat{}x2(__mat{}x2 mat, uint idx, uint vec_idx, float value) {{", + columns as u8, columns as u8 + )?; + writeln!(self.out, "{INDENT}switch(idx) {{")?; + for i in 0..columns as u8 { + writeln!( + self.out, + "{INDENT}case {i}: {{ mat._{i}[vec_idx] = value; break; }}" + )?; + } + writeln!(self.out, "{INDENT}}}")?; + writeln!(self.out, "}}")?; + + writeln!(self.out)?; + + Ok(()) + } + + pub(super) fn write_all_mat_cx2_typedefs_and_functions( + &mut self, + module: &crate::Module, + ) -> BackendResult { + for (handle, _) in module.global_variables.iter() { + let global = &module.global_variables[handle]; + + if global.space == crate::AddressSpace::Uniform { + if let Some(super::writer::MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = super::writer::get_inner_matrix_data(module, global.ty) + { + let entry = WrappedMatCx2 { columns }; + if self.wrapped.mat_cx2s.insert(entry) { + self.write_mat_cx2_typedef_and_functions(entry)?; + } + } + } + } + + for (_, ty) in module.types.iter() { + if let crate::TypeInner::Struct { ref members, .. } = ty.inner { + for member in members.iter() { + if let crate::TypeInner::Array { .. } = module.types[member.ty].inner { + if let Some(super::writer::MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = super::writer::get_inner_matrix_data(module, member.ty) + { + let entry = WrappedMatCx2 { columns }; + if self.wrapped.mat_cx2s.insert(entry) { + self.write_mat_cx2_typedef_and_functions(entry)?; + } + } + } + } + } + } + + Ok(()) + } +} diff --git a/third_party/rust/naga/src/back/hlsl/keywords.rs b/third_party/rust/naga/src/back/hlsl/keywords.rs new file mode 100644 index 0000000000..059e533ff7 --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/keywords.rs @@ -0,0 +1,904 @@ +// When compiling with FXC without strict mode, these keywords are actually case insensitive. +// If you compile with strict mode and specify a different casing like "Pass" instead in an identifier, FXC will give this error: +// "error X3086: alternate cases for 'pass' are deprecated in strict mode" +// This behavior is not documented anywhere, but as far as I can tell this is the full list. +pub const RESERVED_CASE_INSENSITIVE: &[&str] = &[ + "asm", + "decl", + "pass", + "technique", + "Texture1D", + "Texture2D", + "Texture3D", + "TextureCube", +]; + +pub const RESERVED: &[&str] = &[ + // FXC keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-keywords.md?plain=1#L99-L118 + "AppendStructuredBuffer", + "asm", + "asm_fragment", + "BlendState", + "bool", + "break", + "Buffer", + "ByteAddressBuffer", + "case", + "cbuffer", + "centroid", + "class", + "column_major", + "compile", + "compile_fragment", + "CompileShader", + "const", + "continue", + "ComputeShader", + "ConsumeStructuredBuffer", + "default", + "DepthStencilState", + "DepthStencilView", + "discard", + "do", + "double", + "DomainShader", + "dword", + "else", + "export", + "extern", + "false", + "float", + "for", + "fxgroup", + "GeometryShader", + "groupshared", + "half", + "Hullshader", + "if", + "in", + "inline", + "inout", + "InputPatch", + "int", + "interface", + "line", + "lineadj", + "linear", + "LineStream", + "matrix", + "min16float", + "min10float", + "min16int", + "min12int", + "min16uint", + "namespace", + "nointerpolation", + "noperspective", + "NULL", + "out", + "OutputPatch", + "packoffset", + "pass", + "pixelfragment", + "PixelShader", + "point", + "PointStream", + "precise", + "RasterizerState", + "RenderTargetView", + "return", + "register", + "row_major", + "RWBuffer", + "RWByteAddressBuffer", + "RWStructuredBuffer", + "RWTexture1D", + "RWTexture1DArray", + "RWTexture2D", + "RWTexture2DArray", + "RWTexture3D", + "sample", + "sampler", + "SamplerState", + "SamplerComparisonState", + "shared", + "snorm", + "stateblock", + "stateblock_state", + "static", + "string", + "struct", + "switch", + "StructuredBuffer", + "tbuffer", + "technique", + "technique10", + "technique11", + "texture", + "Texture1D", + "Texture1DArray", + "Texture2D", + "Texture2DArray", + "Texture2DMS", + "Texture2DMSArray", + "Texture3D", + "TextureCube", + "TextureCubeArray", + "true", + "typedef", + "triangle", + "triangleadj", + "TriangleStream", + "uint", + "uniform", + "unorm", + "unsigned", + "vector", + "vertexfragment", + "VertexShader", + "void", + "volatile", + "while", + // FXC reserved keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-reserved-words.md?plain=1#L19-L38 + "auto", + "case", + "catch", + "char", + "class", + "const_cast", + "default", + "delete", + "dynamic_cast", + "enum", + "explicit", + "friend", + "goto", + "long", + "mutable", + "new", + "operator", + "private", + "protected", + "public", + "reinterpret_cast", + "short", + "signed", + "sizeof", + "static_cast", + "template", + "this", + "throw", + "try", + "typename", + "union", + "unsigned", + "using", + "virtual", + // FXC intrinsics, from https://github.com/MicrosoftDocs/win32/blob/1682b99e203708f6f5eda972d966e30f3c1588de/desktop-src/direct3dhlsl/dx-graphics-hlsl-intrinsic-functions.md?plain=1#L26-L165 + "abort", + "abs", + "acos", + "all", + "AllMemoryBarrier", + "AllMemoryBarrierWithGroupSync", + "any", + "asdouble", + "asfloat", + "asin", + "asint", + "asuint", + "atan", + "atan2", + "ceil", + "CheckAccessFullyMapped", + "clamp", + "clip", + "cos", + "cosh", + "countbits", + "cross", + "D3DCOLORtoUBYTE4", + "ddx", + "ddx_coarse", + "ddx_fine", + "ddy", + "ddy_coarse", + "ddy_fine", + "degrees", + "determinant", + "DeviceMemoryBarrier", + "DeviceMemoryBarrierWithGroupSync", + "distance", + "dot", + "dst", + "errorf", + "EvaluateAttributeCentroid", + "EvaluateAttributeAtSample", + "EvaluateAttributeSnapped", + "exp", + "exp2", + "f16tof32", + "f32tof16", + "faceforward", + "firstbithigh", + "firstbitlow", + "floor", + "fma", + "fmod", + "frac", + "frexp", + "fwidth", + "GetRenderTargetSampleCount", + "GetRenderTargetSamplePosition", + "GroupMemoryBarrier", + "GroupMemoryBarrierWithGroupSync", + "InterlockedAdd", + "InterlockedAnd", + "InterlockedCompareExchange", + "InterlockedCompareStore", + "InterlockedExchange", + "InterlockedMax", + "InterlockedMin", + "InterlockedOr", + "InterlockedXor", + "isfinite", + "isinf", + "isnan", + "ldexp", + "length", + "lerp", + "lit", + "log", + "log10", + "log2", + "mad", + "max", + "min", + "modf", + "msad4", + "mul", + "noise", + "normalize", + "pow", + "printf", + "Process2DQuadTessFactorsAvg", + "Process2DQuadTessFactorsMax", + "Process2DQuadTessFactorsMin", + "ProcessIsolineTessFactors", + "ProcessQuadTessFactorsAvg", + "ProcessQuadTessFactorsMax", + "ProcessQuadTessFactorsMin", + "ProcessTriTessFactorsAvg", + "ProcessTriTessFactorsMax", + "ProcessTriTessFactorsMin", + "radians", + "rcp", + "reflect", + "refract", + "reversebits", + "round", + "rsqrt", + "saturate", + "sign", + "sin", + "sincos", + "sinh", + "smoothstep", + "sqrt", + "step", + "tan", + "tanh", + "tex1D", + "tex1Dbias", + "tex1Dgrad", + "tex1Dlod", + "tex1Dproj", + "tex2D", + "tex2Dbias", + "tex2Dgrad", + "tex2Dlod", + "tex2Dproj", + "tex3D", + "tex3Dbias", + "tex3Dgrad", + "tex3Dlod", + "tex3Dproj", + "texCUBE", + "texCUBEbias", + "texCUBEgrad", + "texCUBElod", + "texCUBEproj", + "transpose", + "trunc", + // DXC (reserved) keywords, from https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/include/clang/Basic/TokenKinds.def#L222-L648 + // with the KEYALL, KEYCXX, BOOLSUPPORT, WCHARSUPPORT, KEYHLSL options enabled (see https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/lib/Frontend/CompilerInvocation.cpp#L1199) + "auto", + "break", + "case", + "char", + "const", + "continue", + "default", + "do", + "double", + "else", + "enum", + "extern", + "float", + "for", + "goto", + "if", + "inline", + "int", + "long", + "register", + "return", + "short", + "signed", + "sizeof", + "static", + "struct", + "switch", + "typedef", + "union", + "unsigned", + "void", + "volatile", + "while", + "_Alignas", + "_Alignof", + "_Atomic", + "_Complex", + "_Generic", + "_Imaginary", + "_Noreturn", + "_Static_assert", + "_Thread_local", + "__func__", + "__objc_yes", + "__objc_no", + "asm", + "bool", + "catch", + "class", + "const_cast", + "delete", + "dynamic_cast", + "explicit", + "export", + "false", + "friend", + "mutable", + "namespace", + "new", + "operator", + "private", + "protected", + "public", + "reinterpret_cast", + "static_cast", + "template", + "this", + "throw", + "true", + "try", + "typename", + "typeid", + "using", + "virtual", + "wchar_t", + "_Decimal32", + "_Decimal64", + "_Decimal128", + "__null", + "__alignof", + "__attribute", + "__builtin_choose_expr", + "__builtin_offsetof", + "__builtin_va_arg", + "__extension__", + "__imag", + "__int128", + "__label__", + "__real", + "__thread", + "__FUNCTION__", + "__PRETTY_FUNCTION__", + "__is_nothrow_assignable", + "__is_constructible", + "__is_nothrow_constructible", + "__has_nothrow_assign", + "__has_nothrow_move_assign", + "__has_nothrow_copy", + "__has_nothrow_constructor", + "__has_trivial_assign", + "__has_trivial_move_assign", + "__has_trivial_copy", + "__has_trivial_constructor", + "__has_trivial_move_constructor", + "__has_trivial_destructor", + "__has_virtual_destructor", + "__is_abstract", + "__is_base_of", + "__is_class", + "__is_convertible_to", + "__is_empty", + "__is_enum", + "__is_final", + "__is_literal", + "__is_literal_type", + "__is_pod", + "__is_polymorphic", + "__is_trivial", + "__is_union", + "__is_trivially_constructible", + "__is_trivially_copyable", + "__is_trivially_assignable", + "__underlying_type", + "__is_lvalue_expr", + "__is_rvalue_expr", + "__is_arithmetic", + "__is_floating_point", + "__is_integral", + "__is_complete_type", + "__is_void", + "__is_array", + "__is_function", + "__is_reference", + "__is_lvalue_reference", + "__is_rvalue_reference", + "__is_fundamental", + "__is_object", + "__is_scalar", + "__is_compound", + "__is_pointer", + "__is_member_object_pointer", + "__is_member_function_pointer", + "__is_member_pointer", + "__is_const", + "__is_volatile", + "__is_standard_layout", + "__is_signed", + "__is_unsigned", + "__is_same", + "__is_convertible", + "__array_rank", + "__array_extent", + "__private_extern__", + "__module_private__", + "__declspec", + "__cdecl", + "__stdcall", + "__fastcall", + "__thiscall", + "__vectorcall", + "cbuffer", + "tbuffer", + "packoffset", + "linear", + "centroid", + "nointerpolation", + "noperspective", + "sample", + "column_major", + "row_major", + "in", + "out", + "inout", + "uniform", + "precise", + "center", + "shared", + "groupshared", + "discard", + "snorm", + "unorm", + "point", + "line", + "lineadj", + "triangle", + "triangleadj", + "globallycoherent", + "interface", + "sampler_state", + "technique", + "indices", + "vertices", + "primitives", + "payload", + "Technique", + "technique10", + "technique11", + "__builtin_omp_required_simd_align", + "__pascal", + "__fp16", + "__alignof__", + "__asm", + "__asm__", + "__attribute__", + "__complex", + "__complex__", + "__const", + "__const__", + "__decltype", + "__imag__", + "__inline", + "__inline__", + "__nullptr", + "__real__", + "__restrict", + "__restrict__", + "__signed", + "__signed__", + "__typeof", + "__typeof__", + "__volatile", + "__volatile__", + "_Nonnull", + "_Nullable", + "_Null_unspecified", + "__builtin_convertvector", + "__char16_t", + "__char32_t", + // DXC intrinsics, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/utils/hct/gen_intrin_main.txt#L86-L376 + "D3DCOLORtoUBYTE4", + "GetRenderTargetSampleCount", + "GetRenderTargetSamplePosition", + "abort", + "abs", + "acos", + "all", + "AllMemoryBarrier", + "AllMemoryBarrierWithGroupSync", + "any", + "asdouble", + "asfloat", + "asfloat16", + "asint16", + "asin", + "asint", + "asuint", + "asuint16", + "atan", + "atan2", + "ceil", + "clamp", + "clip", + "cos", + "cosh", + "countbits", + "cross", + "ddx", + "ddx_coarse", + "ddx_fine", + "ddy", + "ddy_coarse", + "ddy_fine", + "degrees", + "determinant", + "DeviceMemoryBarrier", + "DeviceMemoryBarrierWithGroupSync", + "distance", + "dot", + "dst", + "EvaluateAttributeAtSample", + "EvaluateAttributeCentroid", + "EvaluateAttributeSnapped", + "GetAttributeAtVertex", + "exp", + "exp2", + "f16tof32", + "f32tof16", + "faceforward", + "firstbithigh", + "firstbitlow", + "floor", + "fma", + "fmod", + "frac", + "frexp", + "fwidth", + "GroupMemoryBarrier", + "GroupMemoryBarrierWithGroupSync", + "InterlockedAdd", + "InterlockedMin", + "InterlockedMax", + "InterlockedAnd", + "InterlockedOr", + "InterlockedXor", + "InterlockedCompareStore", + "InterlockedExchange", + "InterlockedCompareExchange", + "InterlockedCompareStoreFloatBitwise", + "InterlockedCompareExchangeFloatBitwise", + "isfinite", + "isinf", + "isnan", + "ldexp", + "length", + "lerp", + "lit", + "log", + "log10", + "log2", + "mad", + "max", + "min", + "modf", + "msad4", + "mul", + "normalize", + "pow", + "printf", + "Process2DQuadTessFactorsAvg", + "Process2DQuadTessFactorsMax", + "Process2DQuadTessFactorsMin", + "ProcessIsolineTessFactors", + "ProcessQuadTessFactorsAvg", + "ProcessQuadTessFactorsMax", + "ProcessQuadTessFactorsMin", + "ProcessTriTessFactorsAvg", + "ProcessTriTessFactorsMax", + "ProcessTriTessFactorsMin", + "radians", + "rcp", + "reflect", + "refract", + "reversebits", + "round", + "rsqrt", + "saturate", + "sign", + "sin", + "sincos", + "sinh", + "smoothstep", + "source_mark", + "sqrt", + "step", + "tan", + "tanh", + "tex1D", + "tex1Dbias", + "tex1Dgrad", + "tex1Dlod", + "tex1Dproj", + "tex2D", + "tex2Dbias", + "tex2Dgrad", + "tex2Dlod", + "tex2Dproj", + "tex3D", + "tex3Dbias", + "tex3Dgrad", + "tex3Dlod", + "tex3Dproj", + "texCUBE", + "texCUBEbias", + "texCUBEgrad", + "texCUBElod", + "texCUBEproj", + "transpose", + "trunc", + "CheckAccessFullyMapped", + "AddUint64", + "NonUniformResourceIndex", + "WaveIsFirstLane", + "WaveGetLaneIndex", + "WaveGetLaneCount", + "WaveActiveAnyTrue", + "WaveActiveAllTrue", + "WaveActiveAllEqual", + "WaveActiveBallot", + "WaveReadLaneAt", + "WaveReadLaneFirst", + "WaveActiveCountBits", + "WaveActiveSum", + "WaveActiveProduct", + "WaveActiveBitAnd", + "WaveActiveBitOr", + "WaveActiveBitXor", + "WaveActiveMin", + "WaveActiveMax", + "WavePrefixCountBits", + "WavePrefixSum", + "WavePrefixProduct", + "WaveMatch", + "WaveMultiPrefixBitAnd", + "WaveMultiPrefixBitOr", + "WaveMultiPrefixBitXor", + "WaveMultiPrefixCountBits", + "WaveMultiPrefixProduct", + "WaveMultiPrefixSum", + "QuadReadLaneAt", + "QuadReadAcrossX", + "QuadReadAcrossY", + "QuadReadAcrossDiagonal", + "QuadAny", + "QuadAll", + "TraceRay", + "ReportHit", + "CallShader", + "IgnoreHit", + "AcceptHitAndEndSearch", + "DispatchRaysIndex", + "DispatchRaysDimensions", + "WorldRayOrigin", + "WorldRayDirection", + "ObjectRayOrigin", + "ObjectRayDirection", + "RayTMin", + "RayTCurrent", + "PrimitiveIndex", + "InstanceID", + "InstanceIndex", + "GeometryIndex", + "HitKind", + "RayFlags", + "ObjectToWorld", + "WorldToObject", + "ObjectToWorld3x4", + "WorldToObject3x4", + "ObjectToWorld4x3", + "WorldToObject4x3", + "dot4add_u8packed", + "dot4add_i8packed", + "dot2add", + "unpack_s8s16", + "unpack_u8u16", + "unpack_s8s32", + "unpack_u8u32", + "pack_s8", + "pack_u8", + "pack_clamp_s8", + "pack_clamp_u8", + "SetMeshOutputCounts", + "DispatchMesh", + "IsHelperLane", + "AllocateRayQuery", + "CreateResourceFromHeap", + "and", + "or", + "select", + // DXC resource and other types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/HlslTypes.cpp#L441-#L572 + "InputPatch", + "OutputPatch", + "PointStream", + "LineStream", + "TriangleStream", + "Texture1D", + "RWTexture1D", + "Texture2D", + "RWTexture2D", + "Texture2DMS", + "RWTexture2DMS", + "Texture3D", + "RWTexture3D", + "TextureCube", + "RWTextureCube", + "Texture1DArray", + "RWTexture1DArray", + "Texture2DArray", + "RWTexture2DArray", + "Texture2DMSArray", + "RWTexture2DMSArray", + "TextureCubeArray", + "RWTextureCubeArray", + "FeedbackTexture2D", + "FeedbackTexture2DArray", + "RasterizerOrderedTexture1D", + "RasterizerOrderedTexture2D", + "RasterizerOrderedTexture3D", + "RasterizerOrderedTexture1DArray", + "RasterizerOrderedTexture2DArray", + "RasterizerOrderedBuffer", + "RasterizerOrderedByteAddressBuffer", + "RasterizerOrderedStructuredBuffer", + "ByteAddressBuffer", + "RWByteAddressBuffer", + "StructuredBuffer", + "RWStructuredBuffer", + "AppendStructuredBuffer", + "ConsumeStructuredBuffer", + "Buffer", + "RWBuffer", + "SamplerState", + "SamplerComparisonState", + "ConstantBuffer", + "TextureBuffer", + "RaytracingAccelerationStructure", + // DXC templated types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp + // look for `BuiltinTypeDeclBuilder` + "matrix", + "vector", + "TextureBuffer", + "ConstantBuffer", + "RayQuery", + // Naga utilities + super::writer::MODF_FUNCTION, + super::writer::FREXP_FUNCTION, +]; + +// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254 +// + vector and matrix shorthands +pub const TYPES: &[&str] = &{ + const L: usize = 23 * (1 + 4 + 4 * 4); + let mut res = [""; L]; + let mut c = 0; + + /// For each scalar type, it will additionally generate vector and matrix shorthands + macro_rules! generate { + ([$($roots:literal),*], $x:tt) => { + $( + generate!(@inner push $roots); + generate!(@inner $roots, $x); + )* + }; + + (@inner $root:literal, [$($x:literal),*]) => { + generate!(@inner vector $root, $($x)*); + generate!(@inner matrix $root, $($x)*); + }; + + (@inner vector $root:literal, $($x:literal)*) => { + $( + generate!(@inner push concat!($root, $x)); + )* + }; + + (@inner matrix $root:literal, $($x:literal)*) => { + // Duplicate the list + generate!(@inner matrix $root, $($x)*; $($x)*); + }; + + // The head/tail recursion: pick the first element of the first list and recursively do it for the tail. + (@inner matrix $root:literal, $head:literal $($tail:literal)*; $($x:literal)*) => { + $( + generate!(@inner push concat!($root, $head, "x", $x)); + )* + generate!(@inner matrix $root, $($tail)*; $($x)*); + + }; + + // The end of iteration: we exhausted the list + (@inner matrix $root:literal, ; $($x:literal)*) => {}; + + (@inner push $v:expr) => { + res[c] = $v; + c += 1; + }; + } + + generate!( + [ + "bool", + "int", + "uint", + "dword", + "half", + "float", + "double", + "min10float", + "min16float", + "min12int", + "min16int", + "min16uint", + "int16_t", + "int32_t", + "int64_t", + "uint16_t", + "uint32_t", + "uint64_t", + "float16_t", + "float32_t", + "float64_t", + "int8_t4_packed", + "uint8_t4_packed" + ], + ["1", "2", "3", "4"] + ); + + debug_assert!(c == L); + + res +}; diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs new file mode 100644 index 0000000000..37ddbd3d67 --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/mod.rs @@ -0,0 +1,302 @@ +/*! +Backend for [HLSL][hlsl] (High-Level Shading Language). + +# Supported shader model versions: +- 5.0 +- 5.1 +- 6.0 + +# Layout of values in `uniform` buffers + +WGSL's ["Internal Layout of Values"][ilov] rules specify how each WGSL +type should be stored in `uniform` and `storage` buffers. The HLSL we +generate must access values in that form, even when it is not what +HLSL would use normally. + +The rules described here only apply to WGSL `uniform` variables. WGSL +`storage` buffers are translated as HLSL `ByteAddressBuffers`, for +which we generate `Load` and `Store` method calls with explicit byte +offsets. WGSL pipeline inputs must be scalars or vectors; they cannot +be matrices, which is where the interesting problems arise. + +## Row- and column-major ordering for matrices + +WGSL specifies that matrices in uniform buffers are stored in +column-major order. This matches HLSL's default, so one might expect +things to be straightforward. Unfortunately, WGSL and HLSL disagree on +what indexing a matrix means: in WGSL, `m[i]` retrieves the `i`'th +*column* of `m`, whereas in HLSL it retrieves the `i`'th *row*. We +want to avoid translating `m[i]` into some complicated reassembly of a +vector from individually fetched components, so this is a problem. + +However, with a bit of trickery, it is possible to use HLSL's `m[i]` +as the translation of WGSL's `m[i]`: + +- We declare all matrices in uniform buffers in HLSL with the + `row_major` qualifier, and transpose the row and column counts: a + WGSL `mat3x4<f32>`, say, becomes an HLSL `row_major float3x4`. (Note + that WGSL and HLSL type names put the row and column in reverse + order.) Since the HLSL type is the transpose of how WebGPU directs + the user to store the data, HLSL will load all matrices transposed. + +- Since matrices are transposed, an HLSL indexing expression retrieves + the "columns" of the intended WGSL value, as desired. + +- For vector-matrix multiplication, since `mul(transpose(m), v)` is + equivalent to `mul(v, m)` (note the reversal of the arguments), and + `mul(v, transpose(m))` is equivalent to `mul(m, v)`, we can + translate WGSL `m * v` and `v * m` to HLSL by simply reversing the + arguments to `mul`. + +## Padding in two-row matrices + +An HLSL `row_major floatKx2` matrix has padding between its rows that +the WGSL `matKx2<f32>` matrix it represents does not. HLSL stores all +matrix rows [aligned on 16-byte boundaries][16bb], whereas WGSL says +that the columns of a `matKx2<f32>` need only be [aligned as required +for `vec2<f32>`][ilov], which is [eight-byte alignment][8bb]. + +To compensate for this, any time a `matKx2<f32>` appears in a WGSL +`uniform` variable, whether directly as the variable's type or as part +of a struct/array, we actually emit `K` separate `float2` members, and +assemble/disassemble the matrix from its columns (in WGSL; rows in +HLSL) upon load and store. + +For example, the following WGSL struct type: + +```ignore +struct Baz { + m: mat3x2<f32>, +} +``` + +is rendered as the HLSL struct type: + +```ignore +struct Baz { + float2 m_0; float2 m_1; float2 m_2; +}; +``` + +The `wrapped_struct_matrix` functions in `help.rs` generate HLSL +helper functions to access such members, converting between the stored +form and the HLSL matrix types appropriately. For example, for reading +the member `m` of the `Baz` struct above, we emit: + +```ignore +float3x2 GetMatmOnBaz(Baz obj) { + return float3x2(obj.m_0, obj.m_1, obj.m_2); +} +``` + +We also emit an analogous `Set` function, as well as functions for +accessing individual columns by dynamic index. + +[hlsl]: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl +[ilov]: https://gpuweb.github.io/gpuweb/wgsl/#internal-value-layout +[16bb]: https://github.com/microsoft/DirectXShaderCompiler/wiki/Buffer-Packing#constant-buffer-packing +[8bb]: https://gpuweb.github.io/gpuweb/wgsl/#alignment-and-size +*/ + +mod conv; +mod help; +mod keywords; +mod storage; +mod writer; + +use std::fmt::Error as FmtError; +use thiserror::Error; + +use crate::{back, proc}; + +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct BindTarget { + pub space: u8, + pub register: u32, + /// If the binding is an unsized binding array, this overrides the size. + pub binding_array_size: Option<u32>, +} + +// Using `BTreeMap` instead of `HashMap` so that we can hash itself. +pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>; + +/// A HLSL shader model version. +#[allow(non_snake_case, non_camel_case_types)] +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum ShaderModel { + V5_0, + V5_1, + V6_0, +} + +impl ShaderModel { + pub const fn to_str(self) -> &'static str { + match self { + Self::V5_0 => "5_0", + Self::V5_1 => "5_1", + Self::V6_0 => "6_0", + } + } +} + +impl crate::ShaderStage { + pub const fn to_hlsl_str(self) -> &'static str { + match self { + Self::Vertex => "vs", + Self::Fragment => "ps", + Self::Compute => "cs", + } + } +} + +impl crate::ImageDimension { + const fn to_hlsl_str(self) -> &'static str { + match self { + Self::D1 => "1D", + Self::D2 => "2D", + Self::D3 => "3D", + Self::Cube => "Cube", + } + } +} + +/// Shorthand result used internally by the backend +type BackendResult = Result<(), Error>; + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum EntryPointError { + #[error("mapping of {0:?} is missing")] + MissingBinding(crate::ResourceBinding), +} + +/// Configuration used in the [`Writer`]. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Options { + /// The hlsl shader model to be used + pub shader_model: ShaderModel, + /// Map of resources association to binding locations. + pub binding_map: BindingMap, + /// Don't panic on missing bindings, instead generate any HLSL. + pub fake_missing_bindings: bool, + /// Add special constants to `SV_VertexIndex` and `SV_InstanceIndex`, + /// to make them work like in Vulkan/Metal, with help of the host. + pub special_constants_binding: Option<BindTarget>, + /// Bind target of the push constant buffer + pub push_constants_target: Option<BindTarget>, + /// Should workgroup variables be zero initialized (by polyfilling)? + pub zero_initialize_workgroup_memory: bool, +} + +impl Default for Options { + fn default() -> Self { + Options { + shader_model: ShaderModel::V5_1, + binding_map: BindingMap::default(), + fake_missing_bindings: true, + special_constants_binding: None, + push_constants_target: None, + zero_initialize_workgroup_memory: true, + } + } +} + +impl Options { + fn resolve_resource_binding( + &self, + res_binding: &crate::ResourceBinding, + ) -> Result<BindTarget, EntryPointError> { + match self.binding_map.get(res_binding) { + Some(target) => Ok(target.clone()), + None if self.fake_missing_bindings => Ok(BindTarget { + space: res_binding.group as u8, + register: res_binding.binding, + binding_array_size: None, + }), + None => Err(EntryPointError::MissingBinding(res_binding.clone())), + } + } +} + +/// Reflection info for entry point names. +#[derive(Default)] +pub struct ReflectionInfo { + /// Mapping of the entry point names. + /// + /// Each item in the array corresponds to an entry point index. The real entry point name may be different if one of the + /// reserved words are used. + /// + /// Note: Some entry points may fail translation because of missing bindings. + pub entry_point_names: Vec<Result<String, EntryPointError>>, +} + +#[derive(Error, Debug)] +pub enum Error { + #[error(transparent)] + IoError(#[from] FmtError), + #[error("A scalar with an unsupported width was requested: {0:?}")] + UnsupportedScalar(crate::Scalar), + #[error("{0}")] + Unimplemented(String), // TODO: Error used only during development + #[error("{0}")] + Custom(String), +} + +#[derive(Default)] +struct Wrapped { + array_lengths: crate::FastHashSet<help::WrappedArrayLength>, + image_queries: crate::FastHashSet<help::WrappedImageQuery>, + constructors: crate::FastHashSet<help::WrappedConstructor>, + struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>, + mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>, +} + +impl Wrapped { + fn clear(&mut self) { + self.array_lengths.clear(); + self.image_queries.clear(); + self.constructors.clear(); + self.struct_matrix_access.clear(); + self.mat_cx2s.clear(); + } +} + +pub struct Writer<'a, W> { + out: W, + names: crate::FastHashMap<proc::NameKey, String>, + namer: proc::Namer, + /// HLSL backend options + options: &'a Options, + /// Information about entry point arguments and result types. + entry_point_io: Vec<writer::EntryPointInterface>, + /// Set of expressions that have associated temporary variables + named_expressions: crate::NamedExpressions, + wrapped: Wrapped, + + /// A reference to some part of a global variable, lowered to a series of + /// byte offset calculations. + /// + /// See the [`storage`] module for background on why we need this. + /// + /// Each [`SubAccess`] in the vector is a lowering of some [`Access`] or + /// [`AccessIndex`] expression to the level of byte strides and offsets. See + /// [`SubAccess`] for details. + /// + /// This field is a member of [`Writer`] solely to allow re-use of + /// the `Vec`'s dynamic allocation. The value is no longer needed + /// once HLSL for the access has been generated. + /// + /// [`Storage`]: crate::AddressSpace::Storage + /// [`SubAccess`]: storage::SubAccess + /// [`Access`]: crate::Expression::Access + /// [`AccessIndex`]: crate::Expression::AccessIndex + temp_access_chain: Vec<storage::SubAccess>, + need_bake_expressions: back::NeedBakeExpressions, +} diff --git a/third_party/rust/naga/src/back/hlsl/storage.rs b/third_party/rust/naga/src/back/hlsl/storage.rs new file mode 100644 index 0000000000..1b8a6ec12d --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/storage.rs @@ -0,0 +1,494 @@ +/*! +Generating accesses to [`ByteAddressBuffer`] contents. + +Naga IR globals in the [`Storage`] address space are rendered as +[`ByteAddressBuffer`]s or [`RWByteAddressBuffer`]s in HLSL. These +buffers don't have HLSL types (structs, arrays, etc.); instead, they +are just raw blocks of bytes, with methods to load and store values of +specific types at particular byte offsets. This means that Naga must +translate chains of [`Access`] and [`AccessIndex`] expressions into +HLSL expressions that compute byte offsets into the buffer. + +To generate code for a [`Storage`] access: + +- Call [`Writer::fill_access_chain`] on the expression referring to + the value. This populates [`Writer::temp_access_chain`] with the + appropriate byte offset calculations, as a vector of [`SubAccess`] + values. + +- Call [`Writer::write_storage_address`] to emit an HLSL expression + for a given slice of [`SubAccess`] values. + +Naga IR expressions can operate on composite values of any type, but +[`ByteAddressBuffer`] and [`RWByteAddressBuffer`] have only a fixed +set of `Load` and `Store` methods, to access one through four +consecutive 32-bit values. To synthesize a Naga access, you can +initialize [`temp_access_chain`] to refer to the composite, and then +temporarily push and pop additional steps on +[`Writer::temp_access_chain`] to generate accesses to the individual +elements/members. + +The [`temp_access_chain`] field is a member of [`Writer`] solely to +allow re-use of the `Vec`'s dynamic allocation. Its value is no longer +needed once HLSL for the access has been generated. + +[`Storage`]: crate::AddressSpace::Storage +[`ByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-byteaddressbuffer +[`RWByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer +[`Access`]: crate::Expression::Access +[`AccessIndex`]: crate::Expression::AccessIndex +[`Writer::fill_access_chain`]: super::Writer::fill_access_chain +[`Writer::write_storage_address`]: super::Writer::write_storage_address +[`Writer::temp_access_chain`]: super::Writer::temp_access_chain +[`temp_access_chain`]: super::Writer::temp_access_chain +[`Writer`]: super::Writer +*/ + +use super::{super::FunctionCtx, BackendResult, Error}; +use crate::{ + proc::{Alignment, NameKey, TypeResolution}, + Handle, +}; + +use std::{fmt, mem}; + +const STORE_TEMP_NAME: &str = "_value"; + +/// One step in accessing a [`Storage`] global's component or element. +/// +/// [`Writer::temp_access_chain`] holds a series of these structures, +/// describing how to compute the byte offset of a particular element +/// or member of some global variable in the [`Storage`] address +/// space. +/// +/// [`Writer::temp_access_chain`]: super::Writer::temp_access_chain +/// [`Storage`]: crate::AddressSpace::Storage +#[derive(Debug)] +pub(super) enum SubAccess { + /// Add the given byte offset. This is used for struct members, or + /// known components of a vector or matrix. In all those cases, + /// the byte offset is a compile-time constant. + Offset(u32), + + /// Scale `value` by `stride`, and add that to the current byte + /// offset. This is used to compute the offset of an array element + /// whose index is computed at runtime. + Index { + value: Handle<crate::Expression>, + stride: u32, + }, +} + +pub(super) enum StoreValue { + Expression(Handle<crate::Expression>), + TempIndex { + depth: usize, + index: u32, + ty: TypeResolution, + }, + TempAccess { + depth: usize, + base: Handle<crate::Type>, + member_index: u32, + }, +} + +impl<W: fmt::Write> super::Writer<'_, W> { + pub(super) fn write_storage_address( + &mut self, + module: &crate::Module, + chain: &[SubAccess], + func_ctx: &FunctionCtx, + ) -> BackendResult { + if chain.is_empty() { + write!(self.out, "0")?; + } + for (i, access) in chain.iter().enumerate() { + if i != 0 { + write!(self.out, "+")?; + } + match *access { + SubAccess::Offset(offset) => { + write!(self.out, "{offset}")?; + } + SubAccess::Index { value, stride } => { + self.write_expr(module, value, func_ctx)?; + write!(self.out, "*{stride}")?; + } + } + } + Ok(()) + } + + fn write_storage_load_sequence<I: Iterator<Item = (TypeResolution, u32)>>( + &mut self, + module: &crate::Module, + var_handle: Handle<crate::GlobalVariable>, + sequence: I, + func_ctx: &FunctionCtx, + ) -> BackendResult { + for (i, (ty_resolution, offset)) in sequence.enumerate() { + // add the index temporarily + self.temp_access_chain.push(SubAccess::Offset(offset)); + if i != 0 { + write!(self.out, ", ")?; + }; + self.write_storage_load(module, var_handle, ty_resolution, func_ctx)?; + self.temp_access_chain.pop(); + } + Ok(()) + } + + /// Emit code to access a [`Storage`] global's component. + /// + /// Emit HLSL to access the component of `var_handle`, a global + /// variable in the [`Storage`] address space, whose type is + /// `result_ty` and whose location within the global is given by + /// [`self.temp_access_chain`]. See the [`storage`] module's + /// documentation for background. + /// + /// [`Storage`]: crate::AddressSpace::Storage + /// [`self.temp_access_chain`]: super::Writer::temp_access_chain + pub(super) fn write_storage_load( + &mut self, + module: &crate::Module, + var_handle: Handle<crate::GlobalVariable>, + result_ty: TypeResolution, + func_ctx: &FunctionCtx, + ) -> BackendResult { + match *result_ty.inner_with(&module.types) { + crate::TypeInner::Scalar(scalar) => { + // working around the borrow checker in `self.write_expr` + let chain = mem::take(&mut self.temp_access_chain); + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + let cast = scalar.kind.to_hlsl_cast(); + write!(self.out, "{cast}({var_name}.Load(")?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, "))")?; + self.temp_access_chain = chain; + } + crate::TypeInner::Vector { size, scalar } => { + // working around the borrow checker in `self.write_expr` + let chain = mem::take(&mut self.temp_access_chain); + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + let cast = scalar.kind.to_hlsl_cast(); + write!(self.out, "{}({}.Load{}(", cast, var_name, size as u8)?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, "))")?; + self.temp_access_chain = chain; + } + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => { + write!( + self.out, + "{}{}x{}(", + scalar.to_hlsl_str()?, + columns as u8, + rows as u8, + )?; + + // Note: Matrices containing vec3s, due to padding, act like they contain vec4s. + let row_stride = Alignment::from(rows) * scalar.width as u32; + let iter = (0..columns as u32).map(|i| { + let ty_inner = crate::TypeInner::Vector { size: rows, scalar }; + (TypeResolution::Value(ty_inner), i * row_stride) + }); + self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; + write!(self.out, ")")?; + } + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + stride, + } => { + let constructor = super::help::WrappedConstructor { + ty: result_ty.handle().unwrap(), + }; + self.write_wrapped_constructor_function_name(module, constructor)?; + write!(self.out, "(")?; + let iter = (0..size.get()).map(|i| (TypeResolution::Handle(base), stride * i)); + self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; + write!(self.out, ")")?; + } + crate::TypeInner::Struct { ref members, .. } => { + let constructor = super::help::WrappedConstructor { + ty: result_ty.handle().unwrap(), + }; + self.write_wrapped_constructor_function_name(module, constructor)?; + write!(self.out, "(")?; + let iter = members + .iter() + .map(|m| (TypeResolution::Handle(m.ty), m.offset)); + self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; + write!(self.out, ")")?; + } + _ => unreachable!(), + } + Ok(()) + } + + fn write_store_value( + &mut self, + module: &crate::Module, + value: &StoreValue, + func_ctx: &FunctionCtx, + ) -> BackendResult { + match *value { + StoreValue::Expression(expr) => self.write_expr(module, expr, func_ctx)?, + StoreValue::TempIndex { + depth, + index, + ty: _, + } => write!(self.out, "{STORE_TEMP_NAME}{depth}[{index}]")?, + StoreValue::TempAccess { + depth, + base, + member_index, + } => { + let name = &self.names[&NameKey::StructMember(base, member_index)]; + write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")? + } + } + Ok(()) + } + + /// Helper function to write down the Store operation on a `ByteAddressBuffer`. + pub(super) fn write_storage_store( + &mut self, + module: &crate::Module, + var_handle: Handle<crate::GlobalVariable>, + value: StoreValue, + func_ctx: &FunctionCtx, + level: crate::back::Level, + ) -> BackendResult { + let temp_resolution; + let ty_resolution = match value { + StoreValue::Expression(expr) => &func_ctx.info[expr].ty, + StoreValue::TempIndex { + depth: _, + index: _, + ref ty, + } => ty, + StoreValue::TempAccess { + depth: _, + base, + member_index, + } => { + let ty_handle = match module.types[base].inner { + crate::TypeInner::Struct { ref members, .. } => { + members[member_index as usize].ty + } + _ => unreachable!(), + }; + temp_resolution = TypeResolution::Handle(ty_handle); + &temp_resolution + } + }; + match *ty_resolution.inner_with(&module.types) { + crate::TypeInner::Scalar(_) => { + // working around the borrow checker in `self.write_expr` + let chain = mem::take(&mut self.temp_access_chain); + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + write!(self.out, "{level}{var_name}.Store(")?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", asuint(")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, "));")?; + self.temp_access_chain = chain; + } + crate::TypeInner::Vector { size, .. } => { + // working around the borrow checker in `self.write_expr` + let chain = mem::take(&mut self.temp_access_chain); + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", asuint(")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, "));")?; + self.temp_access_chain = chain; + } + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => { + // first, assign the value to a temporary + writeln!(self.out, "{level}{{")?; + let depth = level.0 + 1; + write!( + self.out, + "{}{}{}x{} {}{} = ", + level.next(), + scalar.to_hlsl_str()?, + columns as u8, + rows as u8, + STORE_TEMP_NAME, + depth, + )?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, ";")?; + + // Note: Matrices containing vec3s, due to padding, act like they contain vec4s. + let row_stride = Alignment::from(rows) * scalar.width as u32; + + // then iterate the stores + for i in 0..columns as u32 { + self.temp_access_chain + .push(SubAccess::Offset(i * row_stride)); + let ty_inner = crate::TypeInner::Vector { size: rows, scalar }; + let sv = StoreValue::TempIndex { + depth, + index: i, + ty: TypeResolution::Value(ty_inner), + }; + self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?; + self.temp_access_chain.pop(); + } + // done + writeln!(self.out, "{level}}}")?; + } + crate::TypeInner::Array { + base, + size: crate::ArraySize::Constant(size), + stride, + } => { + // first, assign the value to a temporary + writeln!(self.out, "{level}{{")?; + write!(self.out, "{}", level.next())?; + self.write_value_type(module, &module.types[base].inner)?; + let depth = level.next().0; + write!(self.out, " {STORE_TEMP_NAME}{depth}")?; + self.write_array_size(module, base, crate::ArraySize::Constant(size))?; + write!(self.out, " = ")?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, ";")?; + // then iterate the stores + for i in 0..size.get() { + self.temp_access_chain.push(SubAccess::Offset(i * stride)); + let sv = StoreValue::TempIndex { + depth, + index: i, + ty: TypeResolution::Handle(base), + }; + self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?; + self.temp_access_chain.pop(); + } + // done + writeln!(self.out, "{level}}}")?; + } + crate::TypeInner::Struct { ref members, .. } => { + // first, assign the value to a temporary + writeln!(self.out, "{level}{{")?; + let depth = level.next().0; + let struct_ty = ty_resolution.handle().unwrap(); + let struct_name = &self.names[&NameKey::Type(struct_ty)]; + write!( + self.out, + "{}{} {}{} = ", + level.next(), + struct_name, + STORE_TEMP_NAME, + depth + )?; + self.write_store_value(module, &value, func_ctx)?; + writeln!(self.out, ";")?; + // then iterate the stores + for (i, member) in members.iter().enumerate() { + self.temp_access_chain + .push(SubAccess::Offset(member.offset)); + let sv = StoreValue::TempAccess { + depth, + base: struct_ty, + member_index: i as u32, + }; + self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?; + self.temp_access_chain.pop(); + } + // done + writeln!(self.out, "{level}}}")?; + } + _ => unreachable!(), + } + Ok(()) + } + + /// Set [`temp_access_chain`] to compute the byte offset of `cur_expr`. + /// + /// The `cur_expr` expression must be a reference to a global + /// variable in the [`Storage`] address space, or a chain of + /// [`Access`] and [`AccessIndex`] expressions referring to some + /// component of such a global. + /// + /// [`temp_access_chain`]: super::Writer::temp_access_chain + /// [`Storage`]: crate::AddressSpace::Storage + /// [`Access`]: crate::Expression::Access + /// [`AccessIndex`]: crate::Expression::AccessIndex + pub(super) fn fill_access_chain( + &mut self, + module: &crate::Module, + mut cur_expr: Handle<crate::Expression>, + func_ctx: &FunctionCtx, + ) -> Result<Handle<crate::GlobalVariable>, Error> { + enum AccessIndex { + Expression(Handle<crate::Expression>), + Constant(u32), + } + enum Parent<'a> { + Array { stride: u32 }, + Struct(&'a [crate::StructMember]), + } + self.temp_access_chain.clear(); + + loop { + let (next_expr, access_index) = match func_ctx.expressions[cur_expr] { + crate::Expression::GlobalVariable(handle) => return Ok(handle), + crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)), + crate::Expression::AccessIndex { base, index } => { + (base, AccessIndex::Constant(index)) + } + ref other => { + return Err(Error::Unimplemented(format!("Pointer access of {other:?}"))) + } + }; + + let parent = match *func_ctx.resolve_type(next_expr, &module.types) { + crate::TypeInner::Pointer { base, .. } => match module.types[base].inner { + crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members), + crate::TypeInner::Array { stride, .. } => Parent::Array { stride }, + crate::TypeInner::Vector { scalar, .. } => Parent::Array { + stride: scalar.width as u32, + }, + crate::TypeInner::Matrix { rows, scalar, .. } => Parent::Array { + // The stride between matrices is the count of rows as this is how + // long each column is. + stride: Alignment::from(rows) * scalar.width as u32, + }, + _ => unreachable!(), + }, + crate::TypeInner::ValuePointer { scalar, .. } => Parent::Array { + stride: scalar.width as u32, + }, + _ => unreachable!(), + }; + + let sub = match (parent, access_index) { + (Parent::Array { stride }, AccessIndex::Expression(value)) => { + SubAccess::Index { value, stride } + } + (Parent::Array { stride }, AccessIndex::Constant(index)) => { + SubAccess::Offset(stride * index) + } + (Parent::Struct(members), AccessIndex::Constant(index)) => { + SubAccess::Offset(members[index as usize].offset) + } + (Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(), + }; + + self.temp_access_chain.push(sub); + cur_expr = next_expr; + } + } +} diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs new file mode 100644 index 0000000000..43f7212837 --- /dev/null +++ b/third_party/rust/naga/src/back/hlsl/writer.rs @@ -0,0 +1,3366 @@ +use super::{ + help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess}, + storage::StoreValue, + BackendResult, Error, Options, +}; +use crate::{ + back, + proc::{self, NameKey}, + valid, Handle, Module, ScalarKind, ShaderStage, TypeInner, +}; +use std::{fmt, mem}; + +const LOCATION_SEMANTIC: &str = "LOC"; +const SPECIAL_CBUF_TYPE: &str = "NagaConstants"; +const SPECIAL_CBUF_VAR: &str = "_NagaConstants"; +const SPECIAL_FIRST_VERTEX: &str = "first_vertex"; +const SPECIAL_FIRST_INSTANCE: &str = "first_instance"; +const SPECIAL_OTHER: &str = "other"; + +pub(crate) const MODF_FUNCTION: &str = "naga_modf"; +pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; + +struct EpStructMember { + name: String, + ty: Handle<crate::Type>, + // technically, this should always be `Some` + binding: Option<crate::Binding>, + index: u32, +} + +/// Structure contains information required for generating +/// wrapped structure of all entry points arguments +struct EntryPointBinding { + /// Name of the fake EP argument that contains the struct + /// with all the flattened input data. + arg_name: String, + /// Generated structure name + ty_name: String, + /// Members of generated structure + members: Vec<EpStructMember>, +} + +pub(super) struct EntryPointInterface { + /// If `Some`, the input of an entry point is gathered in a special + /// struct with members sorted by binding. + /// The `EntryPointBinding::members` array is sorted by index, + /// so that we can walk it in `write_ep_arguments_initialization`. + input: Option<EntryPointBinding>, + /// If `Some`, the output of an entry point is flattened. + /// The `EntryPointBinding::members` array is sorted by binding, + /// So that we can walk it in `Statement::Return` handler. + output: Option<EntryPointBinding>, +} + +#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)] +enum InterfaceKey { + Location(u32), + BuiltIn(crate::BuiltIn), + Other, +} + +impl InterfaceKey { + const fn new(binding: Option<&crate::Binding>) -> Self { + match binding { + Some(&crate::Binding::Location { location, .. }) => Self::Location(location), + Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in), + None => Self::Other, + } + } +} + +#[derive(Copy, Clone, PartialEq)] +enum Io { + Input, + Output, +} + +impl<'a, W: fmt::Write> super::Writer<'a, W> { + pub fn new(out: W, options: &'a Options) -> Self { + Self { + out, + names: crate::FastHashMap::default(), + namer: proc::Namer::default(), + options, + entry_point_io: Vec::new(), + named_expressions: crate::NamedExpressions::default(), + wrapped: super::Wrapped::default(), + temp_access_chain: Vec::new(), + need_bake_expressions: Default::default(), + } + } + + fn reset(&mut self, module: &Module) { + self.names.clear(); + self.namer.reset( + module, + super::keywords::RESERVED, + super::keywords::TYPES, + super::keywords::RESERVED_CASE_INSENSITIVE, + &[], + &mut self.names, + ); + self.entry_point_io.clear(); + self.named_expressions.clear(); + self.wrapped.clear(); + self.need_bake_expressions.clear(); + } + + /// Helper method used to find which expressions of a given function require baking + /// + /// # Notes + /// Clears `need_bake_expressions` set before adding to it + fn update_expressions_to_bake( + &mut self, + module: &Module, + func: &crate::Function, + info: &valid::FunctionInfo, + ) { + use crate::Expression; + self.need_bake_expressions.clear(); + for (fun_handle, expr) in func.expressions.iter() { + let expr_info = &info[fun_handle]; + let min_ref_count = func.expressions[fun_handle].bake_ref_count(); + if min_ref_count <= expr_info.ref_count { + self.need_bake_expressions.insert(fun_handle); + } + + if let Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } = *expr + { + match fun { + crate::MathFunction::Asinh + | crate::MathFunction::Acosh + | crate::MathFunction::Atanh + | crate::MathFunction::Unpack2x16float + | crate::MathFunction::Unpack2x16snorm + | crate::MathFunction::Unpack2x16unorm + | crate::MathFunction::Unpack4x8snorm + | crate::MathFunction::Unpack4x8unorm + | crate::MathFunction::Pack2x16float + | crate::MathFunction::Pack2x16snorm + | crate::MathFunction::Pack2x16unorm + | crate::MathFunction::Pack4x8snorm + | crate::MathFunction::Pack4x8unorm => { + self.need_bake_expressions.insert(arg); + } + crate::MathFunction::ExtractBits => { + self.need_bake_expressions.insert(arg); + self.need_bake_expressions.insert(arg1.unwrap()); + self.need_bake_expressions.insert(arg2.unwrap()); + } + crate::MathFunction::InsertBits => { + self.need_bake_expressions.insert(arg); + self.need_bake_expressions.insert(arg1.unwrap()); + self.need_bake_expressions.insert(arg2.unwrap()); + self.need_bake_expressions.insert(arg3.unwrap()); + } + crate::MathFunction::CountLeadingZeros => { + let inner = info[fun_handle].ty.inner_with(&module.types); + if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { + self.need_bake_expressions.insert(arg); + } + } + _ => {} + } + } + + if let Expression::Derivative { axis, ctrl, expr } = *expr { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) { + self.need_bake_expressions.insert(expr); + } + } + } + } + + pub fn write( + &mut self, + module: &Module, + module_info: &valid::ModuleInfo, + ) -> Result<super::ReflectionInfo, Error> { + self.reset(module); + + // Write special constants, if needed + if let Some(ref bt) = self.options.special_constants_binding { + writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?; + writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?; + writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?; + writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?; + writeln!(self.out, "}};")?; + write!( + self.out, + "ConstantBuffer<{}> {}: register(b{}", + SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register + )?; + if bt.space != 0 { + write!(self.out, ", space{}", bt.space)?; + } + writeln!(self.out, ");")?; + + // Extra newline for readability + writeln!(self.out)?; + } + + // Save all entry point output types + let ep_results = module + .entry_points + .iter() + .map(|ep| (ep.stage, ep.function.result.clone())) + .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>(); + + self.write_all_mat_cx2_typedefs_and_functions(module)?; + + // Write all structs + for (handle, ty) in module.types.iter() { + if let TypeInner::Struct { ref members, span } = ty.inner { + if module.types[members.last().unwrap().ty] + .inner + .is_dynamically_sized(&module.types) + { + // unsized arrays can only be in storage buffers, + // for which we use `ByteAddressBuffer` anyway. + continue; + } + + let ep_result = ep_results.iter().find(|e| { + if let Some(ref result) = e.1 { + result.ty == handle + } else { + false + } + }); + + self.write_struct( + module, + handle, + members, + span, + ep_result.map(|r| (r.0, Io::Output)), + )?; + writeln!(self.out)?; + } + } + + self.write_special_functions(module)?; + + self.write_wrapped_compose_functions(module, &module.const_expressions)?; + + // Write all named constants + let mut constants = module + .constants + .iter() + .filter(|&(_, c)| c.name.is_some()) + .peekable(); + while let Some((handle, _)) = constants.next() { + self.write_global_constant(module, handle)?; + // Add extra newline for readability on last iteration + if constants.peek().is_none() { + writeln!(self.out)?; + } + } + + // Write all globals + for (ty, _) in module.global_variables.iter() { + self.write_global(module, ty)?; + } + + if !module.global_variables.is_empty() { + // Add extra newline for readability + writeln!(self.out)?; + } + + // Write all entry points wrapped structs + for (index, ep) in module.entry_points.iter().enumerate() { + let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone(); + let ep_io = self.write_ep_interface(module, &ep.function, ep.stage, &ep_name)?; + self.entry_point_io.push(ep_io); + } + + // Write all regular functions + for (handle, function) in module.functions.iter() { + let info = &module_info[handle]; + + // Check if all of the globals are accessible + if !self.options.fake_missing_bindings { + if let Some((var_handle, _)) = + module + .global_variables + .iter() + .find(|&(var_handle, var)| match var.binding { + Some(ref binding) if !info[var_handle].is_empty() => { + self.options.resolve_resource_binding(binding).is_err() + } + _ => false, + }) + { + log::info!( + "Skipping function {:?} (name {:?}) because global {:?} is inaccessible", + handle, + function.name, + var_handle + ); + continue; + } + } + + let ctx = back::FunctionCtx { + ty: back::FunctionType::Function(handle), + info, + expressions: &function.expressions, + named_expressions: &function.named_expressions, + }; + let name = self.names[&NameKey::Function(handle)].clone(); + + self.write_wrapped_functions(module, &ctx)?; + + self.write_function(module, name.as_str(), function, &ctx, info)?; + + writeln!(self.out)?; + } + + let mut entry_point_names = Vec::with_capacity(module.entry_points.len()); + + // Write all entry points + for (index, ep) in module.entry_points.iter().enumerate() { + let info = module_info.get_entry_point(index); + + if !self.options.fake_missing_bindings { + let mut ep_error = None; + for (var_handle, var) in module.global_variables.iter() { + match var.binding { + Some(ref binding) if !info[var_handle].is_empty() => { + if let Err(err) = self.options.resolve_resource_binding(binding) { + ep_error = Some(err); + break; + } + } + _ => {} + } + } + if let Some(err) = ep_error { + entry_point_names.push(Err(err)); + continue; + } + } + + let ctx = back::FunctionCtx { + ty: back::FunctionType::EntryPoint(index as u16), + info, + expressions: &ep.function.expressions, + named_expressions: &ep.function.named_expressions, + }; + + self.write_wrapped_functions(module, &ctx)?; + + if ep.stage == ShaderStage::Compute { + // HLSL is calling workgroup size "num threads" + let num_threads = ep.workgroup_size; + writeln!( + self.out, + "[numthreads({}, {}, {})]", + num_threads[0], num_threads[1], num_threads[2] + )?; + } + + let name = self.names[&NameKey::EntryPoint(index as u16)].clone(); + self.write_function(module, &name, &ep.function, &ctx, info)?; + + if index < module.entry_points.len() - 1 { + writeln!(self.out)?; + } + + entry_point_names.push(Ok(name)); + } + + Ok(super::ReflectionInfo { entry_point_names }) + } + + fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult { + match *binding { + crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => { + write!(self.out, "precise ")?; + } + crate::Binding::Location { + interpolation, + sampling, + .. + } => { + if let Some(interpolation) = interpolation { + if let Some(string) = interpolation.to_hlsl_str() { + write!(self.out, "{string} ")? + } + } + + if let Some(sampling) = sampling { + if let Some(string) = sampling.to_hlsl_str() { + write!(self.out, "{string} ")? + } + } + } + crate::Binding::BuiltIn(_) => {} + } + + Ok(()) + } + + //TODO: we could force fragment outputs to always go through `entry_point_io.output` path + // if they are struct, so that the `stage` argument here could be omitted. + fn write_semantic( + &mut self, + binding: &crate::Binding, + stage: Option<(ShaderStage, Io)>, + ) -> BackendResult { + match *binding { + crate::Binding::BuiltIn(builtin) => { + let builtin_str = builtin.to_hlsl_str()?; + write!(self.out, " : {builtin_str}")?; + } + crate::Binding::Location { + second_blend_source: true, + .. + } => { + write!(self.out, " : SV_Target1")?; + } + crate::Binding::Location { + location, + second_blend_source: false, + .. + } => { + if stage == Some((crate::ShaderStage::Fragment, Io::Output)) { + write!(self.out, " : SV_Target{location}")?; + } else { + write!(self.out, " : {LOCATION_SEMANTIC}{location}")?; + } + } + } + + Ok(()) + } + + fn write_interface_struct( + &mut self, + module: &Module, + shader_stage: (ShaderStage, Io), + struct_name: String, + mut members: Vec<EpStructMember>, + ) -> Result<EntryPointBinding, Error> { + // Sort the members so that first come the user-defined varyings + // in ascending locations, and then built-ins. This allows VS and FS + // interfaces to match with regards to order. + members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref())); + + write!(self.out, "struct {struct_name}")?; + writeln!(self.out, " {{")?; + for m in members.iter() { + write!(self.out, "{}", back::INDENT)?; + if let Some(ref binding) = m.binding { + self.write_modifier(binding)?; + } + self.write_type(module, m.ty)?; + write!(self.out, " {}", &m.name)?; + if let Some(ref binding) = m.binding { + self.write_semantic(binding, Some(shader_stage))?; + } + writeln!(self.out, ";")?; + } + writeln!(self.out, "}};")?; + writeln!(self.out)?; + + match shader_stage.1 { + Io::Input => { + // bring back the original order + members.sort_by_key(|m| m.index); + } + Io::Output => { + // keep it sorted by binding + } + } + + Ok(EntryPointBinding { + arg_name: self.namer.call(struct_name.to_lowercase().as_str()), + ty_name: struct_name, + members, + }) + } + + /// Flatten all entry point arguments into a single struct. + /// This is needed since we need to re-order them: first placing user locations, + /// then built-ins. + fn write_ep_input_struct( + &mut self, + module: &Module, + func: &crate::Function, + stage: ShaderStage, + entry_point_name: &str, + ) -> Result<EntryPointBinding, Error> { + let struct_name = format!("{stage:?}Input_{entry_point_name}"); + + let mut fake_members = Vec::new(); + for arg in func.arguments.iter() { + match module.types[arg.ty].inner { + TypeInner::Struct { ref members, .. } => { + for member in members.iter() { + let name = self.namer.call_or(&member.name, "member"); + let index = fake_members.len() as u32; + fake_members.push(EpStructMember { + name, + ty: member.ty, + binding: member.binding.clone(), + index, + }); + } + } + _ => { + let member_name = self.namer.call_or(&arg.name, "member"); + let index = fake_members.len() as u32; + fake_members.push(EpStructMember { + name: member_name, + ty: arg.ty, + binding: arg.binding.clone(), + index, + }); + } + } + } + + self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members) + } + + /// Flatten all entry point results into a single struct. + /// This is needed since we need to re-order them: first placing user locations, + /// then built-ins. + fn write_ep_output_struct( + &mut self, + module: &Module, + result: &crate::FunctionResult, + stage: ShaderStage, + entry_point_name: &str, + ) -> Result<EntryPointBinding, Error> { + let struct_name = format!("{stage:?}Output_{entry_point_name}"); + + let mut fake_members = Vec::new(); + let empty = []; + let members = match module.types[result.ty].inner { + TypeInner::Struct { ref members, .. } => members, + ref other => { + log::error!("Unexpected {:?} output type without a binding", other); + &empty[..] + } + }; + + for member in members.iter() { + let member_name = self.namer.call_or(&member.name, "member"); + let index = fake_members.len() as u32; + fake_members.push(EpStructMember { + name: member_name, + ty: member.ty, + binding: member.binding.clone(), + index, + }); + } + + self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members) + } + + /// Writes special interface structures for an entry point. The special structures have + /// all the fields flattened into them and sorted by binding. They are only needed for + /// VS outputs and FS inputs, so that these interfaces match. + fn write_ep_interface( + &mut self, + module: &Module, + func: &crate::Function, + stage: ShaderStage, + ep_name: &str, + ) -> Result<EntryPointInterface, Error> { + Ok(EntryPointInterface { + input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment { + Some(self.write_ep_input_struct(module, func, stage, ep_name)?) + } else { + None + }, + output: match func.result { + Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => { + Some(self.write_ep_output_struct(module, fr, stage, ep_name)?) + } + _ => None, + }, + }) + } + + /// Write an entry point preface that initializes the arguments as specified in IR. + fn write_ep_arguments_initialization( + &mut self, + module: &Module, + func: &crate::Function, + ep_index: u16, + ) -> BackendResult { + let ep_input = match self.entry_point_io[ep_index as usize].input.take() { + Some(ep_input) => ep_input, + None => return Ok(()), + }; + let mut fake_iter = ep_input.members.iter(); + for (arg_index, arg) in func.arguments.iter().enumerate() { + write!(self.out, "{}", back::INDENT)?; + self.write_type(module, arg.ty)?; + let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)]; + write!(self.out, " {arg_name}")?; + match module.types[arg.ty].inner { + TypeInner::Array { base, size, .. } => { + self.write_array_size(module, base, size)?; + let fake_member = fake_iter.next().unwrap(); + writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?; + } + TypeInner::Struct { ref members, .. } => { + write!(self.out, " = {{ ")?; + for index in 0..members.len() { + if index != 0 { + write!(self.out, ", ")?; + } + let fake_member = fake_iter.next().unwrap(); + write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?; + } + writeln!(self.out, " }};")?; + } + _ => { + let fake_member = fake_iter.next().unwrap(); + writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?; + } + } + } + assert!(fake_iter.next().is_none()); + Ok(()) + } + + /// Helper method used to write global variables + /// # Notes + /// Always adds a newline + fn write_global( + &mut self, + module: &Module, + handle: Handle<crate::GlobalVariable>, + ) -> BackendResult { + let global = &module.global_variables[handle]; + let inner = &module.types[global.ty].inner; + + if let Some(ref binding) = global.binding { + if let Err(err) = self.options.resolve_resource_binding(binding) { + log::info!( + "Skipping global {:?} (name {:?}) for being inaccessible: {}", + handle, + global.name, + err, + ); + return Ok(()); + } + } + + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register + let register_ty = match global.space { + crate::AddressSpace::Function => unreachable!("Function address space"), + crate::AddressSpace::Private => { + write!(self.out, "static ")?; + self.write_type(module, global.ty)?; + "" + } + crate::AddressSpace::WorkGroup => { + write!(self.out, "groupshared ")?; + self.write_type(module, global.ty)?; + "" + } + crate::AddressSpace::Uniform => { + // constant buffer declarations are expected to be inlined, e.g. + // `cbuffer foo: register(b0) { field1: type1; }` + write!(self.out, "cbuffer")?; + "b" + } + crate::AddressSpace::Storage { access } => { + let (prefix, register) = if access.contains(crate::StorageAccess::STORE) { + ("RW", "u") + } else { + ("", "t") + }; + write!(self.out, "{prefix}ByteAddressBuffer")?; + register + } + crate::AddressSpace::Handle => { + let handle_ty = match *inner { + TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner, + _ => inner, + }; + + let register = match *handle_ty { + TypeInner::Sampler { .. } => "s", + // all storage textures are UAV, unconditionally + TypeInner::Image { + class: crate::ImageClass::Storage { .. }, + .. + } => "u", + _ => "t", + }; + self.write_type(module, global.ty)?; + register + } + crate::AddressSpace::PushConstant => { + // The type of the push constants will be wrapped in `ConstantBuffer` + write!(self.out, "ConstantBuffer<")?; + "b" + } + }; + + // If the global is a push constant write the type now because it will be a + // generic argument to `ConstantBuffer` + if global.space == crate::AddressSpace::PushConstant { + self.write_global_type(module, global.ty)?; + + // need to write the array size if the type was emitted with `write_type` + if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner { + self.write_array_size(module, base, size)?; + } + + // Close the angled brackets for the generic argument + write!(self.out, ">")?; + } + + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, " {name}")?; + + // Push constants need to be assigned a binding explicitly by the consumer + // since naga has no way to know the binding from the shader alone + if global.space == crate::AddressSpace::PushConstant { + let target = self + .options + .push_constants_target + .as_ref() + .expect("No bind target was defined for the push constants block"); + write!(self.out, ": register(b{}", target.register)?; + if target.space != 0 { + write!(self.out, ", space{}", target.space)?; + } + write!(self.out, ")")?; + } + + if let Some(ref binding) = global.binding { + // this was already resolved earlier when we started evaluating an entry point. + let bt = self.options.resolve_resource_binding(binding).unwrap(); + + // need to write the binding array size if the type was emitted with `write_type` + if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner { + if let Some(overridden_size) = bt.binding_array_size { + write!(self.out, "[{overridden_size}]")?; + } else { + self.write_array_size(module, base, size)?; + } + } + + write!(self.out, " : register({}{}", register_ty, bt.register)?; + if bt.space != 0 { + write!(self.out, ", space{}", bt.space)?; + } + write!(self.out, ")")?; + } else { + // need to write the array size if the type was emitted with `write_type` + if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner { + self.write_array_size(module, base, size)?; + } + if global.space == crate::AddressSpace::Private { + write!(self.out, " = ")?; + if let Some(init) = global.init { + self.write_const_expression(module, init)?; + } else { + self.write_default_init(module, global.ty)?; + } + } + } + + if global.space == crate::AddressSpace::Uniform { + write!(self.out, " {{ ")?; + + self.write_global_type(module, global.ty)?; + + write!( + self.out, + " {}", + &self.names[&NameKey::GlobalVariable(handle)] + )?; + + // need to write the array size if the type was emitted with `write_type` + if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner { + self.write_array_size(module, base, size)?; + } + + writeln!(self.out, "; }}")?; + } else { + writeln!(self.out, ";")?; + } + + Ok(()) + } + + /// Helper method used to write global constants + /// + /// # Notes + /// Ends in a newline + fn write_global_constant( + &mut self, + module: &Module, + handle: Handle<crate::Constant>, + ) -> BackendResult { + write!(self.out, "static const ")?; + let constant = &module.constants[handle]; + self.write_type(module, constant.ty)?; + let name = &self.names[&NameKey::Constant(handle)]; + write!(self.out, " {}", name)?; + // Write size for array type + if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner { + self.write_array_size(module, base, size)?; + } + write!(self.out, " = ")?; + self.write_const_expression(module, constant.init)?; + writeln!(self.out, ";")?; + Ok(()) + } + + pub(super) fn write_array_size( + &mut self, + module: &Module, + base: Handle<crate::Type>, + size: crate::ArraySize, + ) -> BackendResult { + write!(self.out, "[")?; + + match size { + crate::ArraySize::Constant(size) => { + write!(self.out, "{size}")?; + } + crate::ArraySize::Dynamic => unreachable!(), + } + + write!(self.out, "]")?; + + if let TypeInner::Array { + base: next_base, + size: next_size, + .. + } = module.types[base].inner + { + self.write_array_size(module, next_base, next_size)?; + } + + Ok(()) + } + + /// Helper method used to write structs + /// + /// # Notes + /// Ends in a newline + fn write_struct( + &mut self, + module: &Module, + handle: Handle<crate::Type>, + members: &[crate::StructMember], + span: u32, + shader_stage: Option<(ShaderStage, Io)>, + ) -> BackendResult { + // Write struct name + let struct_name = &self.names[&NameKey::Type(handle)]; + writeln!(self.out, "struct {struct_name} {{")?; + + let mut last_offset = 0; + for (index, member) in members.iter().enumerate() { + if member.binding.is_none() && member.offset > last_offset { + // using int as padding should work as long as the backend + // doesn't support a type that's less than 4 bytes in size + // (Error::UnsupportedScalar catches this) + let padding = (member.offset - last_offset) / 4; + for i in 0..padding { + writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?; + } + } + let ty_inner = &module.types[member.ty].inner; + last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx()); + + // The indentation is only for readability + write!(self.out, "{}", back::INDENT)?; + + match module.types[member.ty].inner { + TypeInner::Array { base, size, .. } => { + // HLSL arrays are written as `type name[size]` + + self.write_global_type(module, member.ty)?; + + // Write `name` + write!( + self.out, + " {}", + &self.names[&NameKey::StructMember(handle, index as u32)] + )?; + // Write [size] + self.write_array_size(module, base, size)?; + } + // We treat matrices of the form `matCx2` as a sequence of C `vec2`s. + // See the module-level block comment in mod.rs for details. + TypeInner::Matrix { + rows, + columns, + scalar, + } if member.binding.is_none() && rows == crate::VectorSize::Bi => { + let vec_ty = crate::TypeInner::Vector { size: rows, scalar }; + let field_name_key = NameKey::StructMember(handle, index as u32); + + for i in 0..columns as u8 { + if i != 0 { + write!(self.out, "; ")?; + } + self.write_value_type(module, &vec_ty)?; + write!(self.out, " {}_{}", &self.names[&field_name_key], i)?; + } + } + _ => { + // Write modifier before type + if let Some(ref binding) = member.binding { + self.write_modifier(binding)?; + } + + // Even though Naga IR matrices are column-major, we must describe + // matrices passed from the CPU as being in row-major order. + // See the module-level block comment in mod.rs for details. + if let TypeInner::Matrix { .. } = module.types[member.ty].inner { + write!(self.out, "row_major ")?; + } + + // Write the member type and name + self.write_type(module, member.ty)?; + write!( + self.out, + " {}", + &self.names[&NameKey::StructMember(handle, index as u32)] + )?; + } + } + + if let Some(ref binding) = member.binding { + self.write_semantic(binding, shader_stage)?; + }; + writeln!(self.out, ";")?; + } + + // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL + if members.last().unwrap().binding.is_none() && span > last_offset { + let padding = (span - last_offset) / 4; + for i in 0..padding { + writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?; + } + } + + writeln!(self.out, "}};")?; + Ok(()) + } + + /// Helper method used to write global/structs non image/sampler types + /// + /// # Notes + /// Adds no trailing or leading whitespace + pub(super) fn write_global_type( + &mut self, + module: &Module, + ty: Handle<crate::Type>, + ) -> BackendResult { + let matrix_data = get_inner_matrix_data(module, ty); + + // We treat matrices of the form `matCx2` as a sequence of C `vec2`s. + // See the module-level block comment in mod.rs for details. + if let Some(MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = matrix_data + { + write!(self.out, "__mat{}x2", columns as u8)?; + } else { + // Even though Naga IR matrices are column-major, we must describe + // matrices passed from the CPU as being in row-major order. + // See the module-level block comment in mod.rs for details. + if matrix_data.is_some() { + write!(self.out, "row_major ")?; + } + + self.write_type(module, ty)?; + } + + Ok(()) + } + + /// Helper method used to write non image/sampler types + /// + /// # Notes + /// Adds no trailing or leading whitespace + pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult { + let inner = &module.types[ty].inner; + match *inner { + TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?, + // hlsl array has the size separated from the base type + TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => { + self.write_type(module, base)? + } + ref other => self.write_value_type(module, other)?, + } + + Ok(()) + } + + /// Helper method used to write value types + /// + /// # Notes + /// Adds no trailing or leading whitespace + pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult { + match *inner { + TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => { + write!(self.out, "{}", scalar.to_hlsl_str()?)?; + } + TypeInner::Vector { size, scalar } => { + write!( + self.out, + "{}{}", + scalar.to_hlsl_str()?, + back::vector_size_str(size) + )?; + } + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + // The IR supports only float matrix + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix + + // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well. + write!( + self.out, + "{}{}x{}", + scalar.to_hlsl_str()?, + back::vector_size_str(columns), + back::vector_size_str(rows), + )?; + } + TypeInner::Image { + dim, + arrayed, + class, + } => { + self.write_image_type(dim, arrayed, class)?; + } + TypeInner::Sampler { comparison } => { + let sampler = if comparison { + "SamplerComparisonState" + } else { + "SamplerState" + }; + write!(self.out, "{sampler}")?; + } + // HLSL arrays are written as `type name[size]` + // Current code is written arrays only as `[size]` + // Base `type` and `name` should be written outside + TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => { + self.write_array_size(module, base, size)?; + } + _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))), + } + + Ok(()) + } + + /// Helper method used to write functions + /// # Notes + /// Ends in a newline + fn write_function( + &mut self, + module: &Module, + name: &str, + func: &crate::Function, + func_ctx: &back::FunctionCtx<'_>, + info: &valid::FunctionInfo, + ) -> BackendResult { + // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax + + self.update_expressions_to_bake(module, func, info); + + // Write modifier + if let Some(crate::FunctionResult { + binding: + Some( + ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position { + invariant: true, + }), + ), + .. + }) = func.result + { + self.write_modifier(binding)?; + } + + // Write return type + if let Some(ref result) = func.result { + match func_ctx.ty { + back::FunctionType::Function(_) => { + self.write_type(module, result.ty)?; + } + back::FunctionType::EntryPoint(index) => { + if let Some(ref ep_output) = self.entry_point_io[index as usize].output { + write!(self.out, "{}", ep_output.ty_name)?; + } else { + self.write_type(module, result.ty)?; + } + } + } + } else { + write!(self.out, "void")?; + } + + // Write function name + write!(self.out, " {name}(")?; + + let need_workgroup_variables_initialization = + self.need_workgroup_variables_initialization(func_ctx, module); + + // Write function arguments for non entry point functions + match func_ctx.ty { + back::FunctionType::Function(handle) => { + for (index, arg) in func.arguments.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + // Write argument type + let arg_ty = match module.types[arg.ty].inner { + // pointers in function arguments are expected and resolve to `inout` + TypeInner::Pointer { base, .. } => { + //TODO: can we narrow this down to just `in` when possible? + write!(self.out, "inout ")?; + base + } + _ => arg.ty, + }; + self.write_type(module, arg_ty)?; + + let argument_name = + &self.names[&NameKey::FunctionArgument(handle, index as u32)]; + + // Write argument name. Space is important. + write!(self.out, " {argument_name}")?; + if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner { + self.write_array_size(module, base, size)?; + } + } + } + back::FunctionType::EntryPoint(ep_index) => { + if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input { + write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?; + } else { + let stage = module.entry_points[ep_index as usize].stage; + for (index, arg) in func.arguments.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + self.write_type(module, arg.ty)?; + + let argument_name = + &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]; + + write!(self.out, " {argument_name}")?; + if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner { + self.write_array_size(module, base, size)?; + } + + if let Some(ref binding) = arg.binding { + self.write_semantic(binding, Some((stage, Io::Input)))?; + } + } + + if need_workgroup_variables_initialization { + if !func.arguments.is_empty() { + write!(self.out, ", ")?; + } + write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?; + } + } + } + } + // Ends of arguments + write!(self.out, ")")?; + + // Write semantic if it present + if let back::FunctionType::EntryPoint(index) = func_ctx.ty { + let stage = module.entry_points[index as usize].stage; + if let Some(crate::FunctionResult { + binding: Some(ref binding), + .. + }) = func.result + { + self.write_semantic(binding, Some((stage, Io::Output)))?; + } + } + + // Function body start + writeln!(self.out)?; + writeln!(self.out, "{{")?; + + if need_workgroup_variables_initialization { + self.write_workgroup_variables_initialization(func_ctx, module)?; + } + + if let back::FunctionType::EntryPoint(index) = func_ctx.ty { + self.write_ep_arguments_initialization(module, func, index)?; + } + + // Write function local variables + for (handle, local) in func.local_variables.iter() { + // Write indentation (only for readability) + write!(self.out, "{}", back::INDENT)?; + + // Write the local name + // The leading space is important + self.write_type(module, local.ty)?; + write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?; + // Write size for array type + if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner { + self.write_array_size(module, base, size)?; + } + + write!(self.out, " = ")?; + // Write the local initializer if needed + if let Some(init) = local.init { + self.write_expr(module, init, func_ctx)?; + } else { + // Zero initialize local variables + self.write_default_init(module, local.ty)?; + } + + // Finish the local with `;` and add a newline (only for readability) + writeln!(self.out, ";")? + } + + if !func.local_variables.is_empty() { + writeln!(self.out)?; + } + + // Write the function body (statement list) + for sta in func.body.iter() { + // The indentation should always be 1 when writing the function body + self.write_stmt(module, sta, func_ctx, back::Level(1))?; + } + + writeln!(self.out, "}}")?; + + self.named_expressions.clear(); + + Ok(()) + } + + fn need_workgroup_variables_initialization( + &mut self, + func_ctx: &back::FunctionCtx, + module: &Module, + ) -> bool { + self.options.zero_initialize_workgroup_memory + && func_ctx.ty.is_compute_entry_point(module) + && module.global_variables.iter().any(|(handle, var)| { + !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + } + + fn write_workgroup_variables_initialization( + &mut self, + func_ctx: &back::FunctionCtx, + module: &Module, + ) -> BackendResult { + let level = back::Level(1); + + writeln!( + self.out, + "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{" + )?; + + let vars = module.global_variables.iter().filter(|&(handle, var)| { + !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }); + + for (handle, var) in vars { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{}{} = ", level.next(), name)?; + self.write_default_init(module, var.ty)?; + writeln!(self.out, ";")?; + } + + writeln!(self.out, "{level}}}")?; + self.write_barrier(crate::Barrier::WORK_GROUP, level) + } + + /// Helper method used to write statements + /// + /// # Notes + /// Always adds a newline + fn write_stmt( + &mut self, + module: &Module, + stmt: &crate::Statement, + func_ctx: &back::FunctionCtx<'_>, + level: back::Level, + ) -> BackendResult { + use crate::Statement; + + match *stmt { + Statement::Emit(ref range) => { + for handle in range.clone() { + let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space(); + let expr_name = if ptr_class.is_some() { + // HLSL can't save a pointer-valued expression in a variable, + // but we shouldn't ever need to: they should never be named expressions, + // and none of the expression types flagged by bake_ref_count can be pointer-valued. + None + } else if let Some(name) = func_ctx.named_expressions.get(&handle) { + // Front end provides names for all variables at the start of writing. + // But we write them to step by step. We need to recache them + // Otherwise, we could accidentally write variable name instead of full expression. + // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords. + Some(self.namer.call(name)) + } else if self.need_bake_expressions.contains(&handle) { + Some(format!("_expr{}", handle.index())) + } else { + None + }; + + if let Some(name) = expr_name { + write!(self.out, "{level}")?; + self.write_named_expr(module, handle, name, handle, func_ctx)?; + } + } + } + // TODO: copy-paste from glsl-out + Statement::Block(ref block) => { + write!(self.out, "{level}")?; + writeln!(self.out, "{{")?; + for sta in block.iter() { + // Increase the indentation to help with readability + self.write_stmt(module, sta, func_ctx, level.next())? + } + writeln!(self.out, "{level}}}")? + } + // TODO: copy-paste from glsl-out + Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{level}")?; + write!(self.out, "if (")?; + self.write_expr(module, condition, func_ctx)?; + writeln!(self.out, ") {{")?; + + let l2 = level.next(); + for sta in accept { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, l2)?; + } + + // If there are no statements in the reject block we skip writing it + // This is only for readability + if !reject.is_empty() { + writeln!(self.out, "{level}}} else {{")?; + + for sta in reject { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, l2)?; + } + } + + writeln!(self.out, "{level}}}")? + } + // TODO: copy-paste from glsl-out + Statement::Kill => writeln!(self.out, "{level}discard;")?, + Statement::Return { value: None } => { + writeln!(self.out, "{level}return;")?; + } + Statement::Return { value: Some(expr) } => { + let base_ty_res = &func_ctx.info[expr].ty; + let mut resolved = base_ty_res.inner_with(&module.types); + if let TypeInner::Pointer { base, space: _ } = *resolved { + resolved = &module.types[base].inner; + } + + if let TypeInner::Struct { .. } = *resolved { + // We can safely unwrap here, since we now we working with struct + let ty = base_ty_res.handle().unwrap(); + let struct_name = &self.names[&NameKey::Type(ty)]; + let variable_name = self.namer.call(&struct_name.to_lowercase()); + write!(self.out, "{level}const {struct_name} {variable_name} = ",)?; + self.write_expr(module, expr, func_ctx)?; + writeln!(self.out, ";")?; + + // for entry point returns, we may need to reshuffle the outputs into a different struct + let ep_output = match func_ctx.ty { + back::FunctionType::Function(_) => None, + back::FunctionType::EntryPoint(index) => { + self.entry_point_io[index as usize].output.as_ref() + } + }; + let final_name = match ep_output { + Some(ep_output) => { + let final_name = self.namer.call(&variable_name); + write!( + self.out, + "{}const {} {} = {{ ", + level, ep_output.ty_name, final_name, + )?; + for (index, m) in ep_output.members.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + let member_name = &self.names[&NameKey::StructMember(ty, m.index)]; + write!(self.out, "{variable_name}.{member_name}")?; + } + writeln!(self.out, " }};")?; + final_name + } + None => variable_name, + }; + writeln!(self.out, "{level}return {final_name};")?; + } else { + write!(self.out, "{level}return ")?; + self.write_expr(module, expr, func_ctx)?; + writeln!(self.out, ";")? + } + } + Statement::Store { pointer, value } => { + let ty_inner = func_ctx.resolve_type(pointer, &module.types); + if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() { + let var_handle = self.fill_access_chain(module, pointer, func_ctx)?; + self.write_storage_store( + module, + var_handle, + StoreValue::Expression(value), + func_ctx, + level, + )?; + } else { + // We treat matrices of the form `matCx2` as a sequence of C `vec2`s. + // See the module-level block comment in mod.rs for details. + // + // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars). + // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads). + struct MatrixAccess { + base: Handle<crate::Expression>, + index: u32, + } + enum Index { + Expression(Handle<crate::Expression>), + Static(u32), + } + + let get_members = |expr: Handle<crate::Expression>| { + let resolved = func_ctx.resolve_type(expr, &module.types); + match *resolved { + TypeInner::Pointer { base, .. } => match module.types[base].inner { + TypeInner::Struct { ref members, .. } => Some(members), + _ => None, + }, + _ => None, + } + }; + + let mut matrix = None; + let mut vector = None; + let mut scalar = None; + + let mut current_expr = pointer; + for _ in 0..3 { + let resolved = func_ctx.resolve_type(current_expr, &module.types); + + match (resolved, &func_ctx.expressions[current_expr]) { + ( + &TypeInner::Pointer { base: ty, .. }, + &crate::Expression::AccessIndex { base, index }, + ) if matches!( + module.types[ty].inner, + TypeInner::Matrix { + rows: crate::VectorSize::Bi, + .. + } + ) && get_members(base) + .map(|members| members[index as usize].binding.is_none()) + == Some(true) => + { + matrix = Some(MatrixAccess { base, index }); + break; + } + ( + &TypeInner::ValuePointer { + size: Some(crate::VectorSize::Bi), + .. + }, + &crate::Expression::Access { base, index }, + ) => { + vector = Some(Index::Expression(index)); + current_expr = base; + } + ( + &TypeInner::ValuePointer { + size: Some(crate::VectorSize::Bi), + .. + }, + &crate::Expression::AccessIndex { base, index }, + ) => { + vector = Some(Index::Static(index)); + current_expr = base; + } + ( + &TypeInner::ValuePointer { size: None, .. }, + &crate::Expression::Access { base, index }, + ) => { + scalar = Some(Index::Expression(index)); + current_expr = base; + } + ( + &TypeInner::ValuePointer { size: None, .. }, + &crate::Expression::AccessIndex { base, index }, + ) => { + scalar = Some(Index::Static(index)); + current_expr = base; + } + _ => break, + } + } + + write!(self.out, "{level}")?; + + if let Some(MatrixAccess { index, base }) = matrix { + let base_ty_res = &func_ctx.info[base].ty; + let resolved = base_ty_res.inner_with(&module.types); + let ty = match *resolved { + TypeInner::Pointer { base, .. } => base, + _ => base_ty_res.handle().unwrap(), + }; + + if let Some(Index::Static(vec_index)) = vector { + self.write_expr(module, base, func_ctx)?; + write!( + self.out, + ".{}_{}", + &self.names[&NameKey::StructMember(ty, index)], + vec_index + )?; + + if let Some(scalar_index) = scalar { + write!(self.out, "[")?; + match scalar_index { + Index::Static(index) => { + write!(self.out, "{index}")?; + } + Index::Expression(index) => { + self.write_expr(module, index, func_ctx)?; + } + } + write!(self.out, "]")?; + } + + write!(self.out, " = ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ";")?; + } else { + let access = WrappedStructMatrixAccess { ty, index }; + match (&vector, &scalar) { + (&Some(_), &Some(_)) => { + self.write_wrapped_struct_matrix_set_scalar_function_name( + access, + )?; + } + (&Some(_), &None) => { + self.write_wrapped_struct_matrix_set_vec_function_name(access)?; + } + (&None, _) => { + self.write_wrapped_struct_matrix_set_function_name(access)?; + } + } + + write!(self.out, "(")?; + self.write_expr(module, base, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + + if let Some(Index::Expression(vec_index)) = vector { + write!(self.out, ", ")?; + self.write_expr(module, vec_index, func_ctx)?; + + if let Some(scalar_index) = scalar { + write!(self.out, ", ")?; + match scalar_index { + Index::Static(index) => { + write!(self.out, "{index}")?; + } + Index::Expression(index) => { + self.write_expr(module, index, func_ctx)?; + } + } + } + } + writeln!(self.out, ");")?; + } + } else { + // We handle `Store`s to __matCx2 column vectors and scalar elements via + // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2. + struct MatrixData { + columns: crate::VectorSize, + base: Handle<crate::Expression>, + } + + enum Index { + Expression(Handle<crate::Expression>), + Static(u32), + } + + let mut matrix = None; + let mut vector = None; + let mut scalar = None; + + let mut current_expr = pointer; + for _ in 0..3 { + let resolved = func_ctx.resolve_type(current_expr, &module.types); + match (resolved, &func_ctx.expressions[current_expr]) { + ( + &TypeInner::ValuePointer { + size: Some(crate::VectorSize::Bi), + .. + }, + &crate::Expression::Access { base, index }, + ) => { + vector = Some(index); + current_expr = base; + } + ( + &TypeInner::ValuePointer { size: None, .. }, + &crate::Expression::Access { base, index }, + ) => { + scalar = Some(Index::Expression(index)); + current_expr = base; + } + ( + &TypeInner::ValuePointer { size: None, .. }, + &crate::Expression::AccessIndex { base, index }, + ) => { + scalar = Some(Index::Static(index)); + current_expr = base; + } + _ => { + if let Some(MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = get_inner_matrix_of_struct_array_member( + module, + current_expr, + func_ctx, + true, + ) { + matrix = Some(MatrixData { + columns, + base: current_expr, + }); + } + + break; + } + } + } + + if let (Some(MatrixData { columns, base }), Some(vec_index)) = + (matrix, vector) + { + if scalar.is_some() { + write!(self.out, "__set_el_of_mat{}x2", columns as u8)?; + } else { + write!(self.out, "__set_col_of_mat{}x2", columns as u8)?; + } + write!(self.out, "(")?; + self.write_expr(module, base, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, vec_index, func_ctx)?; + + if let Some(scalar_index) = scalar { + write!(self.out, ", ")?; + match scalar_index { + Index::Static(index) => { + write!(self.out, "{index}")?; + } + Index::Expression(index) => { + self.write_expr(module, index, func_ctx)?; + } + } + } + + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + + writeln!(self.out, ");")?; + } else { + self.write_expr(module, pointer, func_ctx)?; + write!(self.out, " = ")?; + + // We cast the RHS of this store in cases where the LHS + // is a struct member with type: + // - matCx2 or + // - a (possibly nested) array of matCx2's + if let Some(MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = get_inner_matrix_of_struct_array_member( + module, pointer, func_ctx, false, + ) { + let mut resolved = func_ctx.resolve_type(pointer, &module.types); + if let TypeInner::Pointer { base, .. } = *resolved { + resolved = &module.types[base].inner; + } + + write!(self.out, "(__mat{}x2", columns as u8)?; + if let TypeInner::Array { base, size, .. } = *resolved { + self.write_array_size(module, base, size)?; + } + write!(self.out, ")")?; + } + + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ";")? + } + } + } + } + Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + let l2 = level.next(); + if !continuing.is_empty() || break_if.is_some() { + let gate_name = self.namer.call("loop_init"); + writeln!(self.out, "{level}bool {gate_name} = true;")?; + writeln!(self.out, "{level}while(true) {{")?; + writeln!(self.out, "{l2}if (!{gate_name}) {{")?; + let l3 = l2.next(); + for sta in continuing.iter() { + self.write_stmt(module, sta, func_ctx, l3)?; + } + if let Some(condition) = break_if { + write!(self.out, "{l3}if (")?; + self.write_expr(module, condition, func_ctx)?; + writeln!(self.out, ") {{")?; + writeln!(self.out, "{}break;", l3.next())?; + writeln!(self.out, "{l3}}}")?; + } + writeln!(self.out, "{l2}}}")?; + writeln!(self.out, "{l2}{gate_name} = false;")?; + } else { + writeln!(self.out, "{level}while(true) {{")?; + } + + for sta in body.iter() { + self.write_stmt(module, sta, func_ctx, l2)?; + } + writeln!(self.out, "{level}}}")? + } + Statement::Break => writeln!(self.out, "{level}break;")?, + Statement::Continue => writeln!(self.out, "{level}continue;")?, + Statement::Barrier(barrier) => { + self.write_barrier(barrier, level)?; + } + Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + write!(self.out, "{level}")?; + self.write_expr(module, image, func_ctx)?; + + write!(self.out, "[")?; + if let Some(index) = array_index { + // Array index accepted only for texture_storage_2d_array, so we can safety use int3(coordinate, array_index) here + write!(self.out, "int3(")?; + self.write_expr(module, coordinate, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr(module, coordinate, func_ctx)?; + } + write!(self.out, "]")?; + + write!(self.out, " = ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ";")?; + } + Statement::Call { + function, + ref arguments, + result, + } => { + write!(self.out, "{level}")?; + if let Some(expr) = result { + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, expr.index()); + let expr_ty = &func_ctx.info[expr].ty; + match *expr_ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(expr, name); + } + let func_name = &self.names[&NameKey::Function(function)]; + write!(self.out, "{func_name}(")?; + for (index, argument) in arguments.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + self.write_expr(module, *argument, func_ctx)?; + } + writeln!(self.out, ");")? + } + Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + + // Validation ensures that `pointer` has a `Pointer` type. + let pointer_space = func_ctx + .resolve_type(pointer, &module.types) + .pointer_space() + .unwrap(); + + let fun_str = fun.to_hlsl_suffix(); + write!(self.out, " {res_name}; ")?; + match pointer_space { + crate::AddressSpace::WorkGroup => { + write!(self.out, "Interlocked{fun_str}(")?; + self.write_expr(module, pointer, func_ctx)?; + } + crate::AddressSpace::Storage { .. } => { + let var_handle = self.fill_access_chain(module, pointer, func_ctx)?; + // The call to `self.write_storage_address` wants + // mutable access to all of `self`, so temporarily take + // ownership of our reusable access chain buffer. + let chain = mem::take(&mut self.temp_access_chain); + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + write!(self.out, "{var_name}.Interlocked{fun_str}(")?; + self.write_storage_address(module, &chain, func_ctx)?; + self.temp_access_chain = chain; + } + ref other => { + return Err(Error::Custom(format!( + "invalid address space {other:?} for atomic statement" + ))) + } + } + write!(self.out, ", ")?; + // handle the special cases + match *fun { + crate::AtomicFunction::Subtract => { + // we just wrote `InterlockedAdd`, so negate the argument + write!(self.out, "-")?; + } + crate::AtomicFunction::Exchange { compare: Some(_) } => { + return Err(Error::Unimplemented("atomic CompareExchange".to_string())); + } + _ => {} + } + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ", {res_name});")?; + self.named_expressions.insert(result, res_name); + } + Statement::WorkGroupUniformLoad { pointer, result } => { + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + write!(self.out, "{level}")?; + let name = format!("_expr{}", result.index()); + self.write_named_expr(module, pointer, name, result, func_ctx)?; + + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + } + Statement::Switch { + selector, + ref cases, + } => { + // Start the switch + write!(self.out, "{level}")?; + write!(self.out, "switch(")?; + self.write_expr(module, selector, func_ctx)?; + writeln!(self.out, ") {{")?; + + // Write all cases + let indent_level_1 = level.next(); + let indent_level_2 = indent_level_1.next(); + + for (i, case) in cases.iter().enumerate() { + match case.value { + crate::SwitchValue::I32(value) => { + write!(self.out, "{indent_level_1}case {value}:")? + } + crate::SwitchValue::U32(value) => { + write!(self.out, "{indent_level_1}case {value}u:")? + } + crate::SwitchValue::Default => { + write!(self.out, "{indent_level_1}default:")? + } + } + + // The new block is not only stylistic, it plays a role here: + // We might end up having to write the same case body + // multiple times due to FXC not supporting fallthrough. + // Therefore, some `Expression`s written by `Statement::Emit` + // will end up having the same name (`_expr<handle_index>`). + // So we need to put each case in its own scope. + let write_block_braces = !(case.fall_through && case.body.is_empty()); + if write_block_braces { + writeln!(self.out, " {{")?; + } else { + writeln!(self.out)?; + } + + // Although FXC does support a series of case clauses before + // a block[^yes], it does not support fallthrough from a + // non-empty case block to the next[^no]. If this case has a + // non-empty body with a fallthrough, emulate that by + // duplicating the bodies of all the cases it would fall + // into as extensions of this case's own body. This makes + // the HLSL output potentially quadratic in the size of the + // Naga IR. + // + // [^yes]: ```hlsl + // case 1: + // case 2: do_stuff() + // ``` + // [^no]: ```hlsl + // case 1: do_this(); + // case 2: do_that(); + // ``` + if case.fall_through && !case.body.is_empty() { + let curr_len = i + 1; + let end_case_idx = curr_len + + cases + .iter() + .skip(curr_len) + .position(|case| !case.fall_through) + .unwrap(); + let indent_level_3 = indent_level_2.next(); + for case in &cases[i..=end_case_idx] { + writeln!(self.out, "{indent_level_2}{{")?; + let prev_len = self.named_expressions.len(); + for sta in case.body.iter() { + self.write_stmt(module, sta, func_ctx, indent_level_3)?; + } + // Clear all named expressions that were previously inserted by the statements in the block + self.named_expressions.truncate(prev_len); + writeln!(self.out, "{indent_level_2}}}")?; + } + + let last_case = &cases[end_case_idx]; + if last_case.body.last().map_or(true, |s| !s.is_terminator()) { + writeln!(self.out, "{indent_level_2}break;")?; + } + } else { + for sta in case.body.iter() { + self.write_stmt(module, sta, func_ctx, indent_level_2)?; + } + if !case.fall_through + && case.body.last().map_or(true, |s| !s.is_terminator()) + { + writeln!(self.out, "{indent_level_2}break;")?; + } + } + + if write_block_braces { + writeln!(self.out, "{indent_level_1}}}")?; + } + } + + writeln!(self.out, "{level}}}")? + } + Statement::RayQuery { .. } => unreachable!(), + } + + Ok(()) + } + + fn write_const_expression( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + ) -> BackendResult { + self.write_possibly_const_expression( + module, + expr, + &module.const_expressions, + |writer, expr| writer.write_const_expression(module, expr), + ) + } + + fn write_possibly_const_expression<E>( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + expressions: &crate::Arena<crate::Expression>, + write_expression: E, + ) -> BackendResult + where + E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult, + { + use crate::Expression; + + match expressions[expr] { + Expression::Literal(literal) => match literal { + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero + crate::Literal::F64(value) => write!(self.out, "{value:?}L")?, + crate::Literal::F32(value) => write!(self.out, "{value:?}")?, + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::I64(value) => write!(self.out, "{}L", value)?, + crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } + }, + Expression::Constant(handle) => { + let constant = &module.constants[handle]; + if constant.name.is_some() { + write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; + } else { + self.write_const_expression(module, constant.init)?; + } + } + Expression::ZeroValue(ty) => self.write_default_init(module, ty)?, + Expression::Compose { ty, ref components } => { + match module.types[ty].inner { + TypeInner::Struct { .. } | TypeInner::Array { .. } => { + self.write_wrapped_constructor_function_name( + module, + WrappedConstructor { ty }, + )?; + } + _ => { + self.write_type(module, ty)?; + } + }; + write!(self.out, "(")?; + for (index, component) in components.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + write_expression(self, *component)?; + } + write!(self.out, ")")?; + } + Expression::Splat { size, value } => { + // hlsl is not supported one value constructor + // if we write, for example, int4(0), dxc returns error: + // error: too few elements in vector initialization (expected 4 elements, have 1) + let number_of_components = match size { + crate::VectorSize::Bi => "xx", + crate::VectorSize::Tri => "xxx", + crate::VectorSize::Quad => "xxxx", + }; + write!(self.out, "(")?; + write_expression(self, value)?; + write!(self.out, ").{number_of_components}")? + } + _ => unreachable!(), + } + + Ok(()) + } + + /// Helper method to write expressions + /// + /// # Notes + /// Doesn't add any newlines or leading/trailing spaces + pub(super) fn write_expr( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + ) -> BackendResult { + use crate::Expression; + + // Handle the special semantics of vertex_index/instance_index + let ff_input = if self.options.special_constants_binding.is_some() { + func_ctx.is_fixed_function_input(expr, module) + } else { + None + }; + let closing_bracket = match ff_input { + Some(crate::BuiltIn::VertexIndex) => { + write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?; + ")" + } + Some(crate::BuiltIn::InstanceIndex) => { + write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?; + ")" + } + Some(crate::BuiltIn::NumWorkGroups) => { + // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`), + // in compute shaders the special constants contain the number + // of workgroups, which we are using here. + write!( + self.out, + "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})", + )?; + return Ok(()); + } + _ => "", + }; + + if let Some(name) = self.named_expressions.get(&expr) { + write!(self.out, "{name}{closing_bracket}")?; + return Ok(()); + } + + let expression = &func_ctx.expressions[expr]; + + match *expression { + Expression::Literal(_) + | Expression::Constant(_) + | Expression::ZeroValue(_) + | Expression::Compose { .. } + | Expression::Splat { .. } => { + self.write_possibly_const_expression( + module, + expr, + func_ctx.expressions, + |writer, expr| writer.write_expr(module, expr, func_ctx), + )?; + } + // All of the multiplication can be expressed as `mul`, + // except vector * vector, which needs to use the "*" operator. + Expression::Binary { + op: crate::BinaryOperator::Multiply, + left, + right, + } if func_ctx.resolve_type(left, &module.types).is_matrix() + || func_ctx.resolve_type(right, &module.types).is_matrix() => + { + // We intentionally flip the order of multiplication as our matrices are implicitly transposed. + write!(self.out, "mul(")?; + self.write_expr(module, right, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, left, func_ctx)?; + write!(self.out, ")")?; + } + + // TODO: handle undefined behavior of BinaryOperator::Modulo + // + // sint: + // if right == 0 return 0 + // if left == min(type_of(left)) && right == -1 return 0 + // if sign(left) != sign(right) return result as defined by WGSL + // + // uint: + // if right == 0 return 0 + // + // float: + // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798 + + // While HLSL supports float operands with the % operator it is only + // defined in cases where both sides are either positive or negative. + Expression::Binary { + op: crate::BinaryOperator::Modulo, + left, + right, + } if func_ctx.resolve_type(left, &module.types).scalar_kind() + == Some(crate::ScalarKind::Float) => + { + write!(self.out, "fmod(")?; + self.write_expr(module, left, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, right, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Binary { op, left, right } => { + write!(self.out, "(")?; + self.write_expr(module, left, func_ctx)?; + write!(self.out, " {} ", crate::back::binary_operation_str(op))?; + self.write_expr(module, right, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Access { base, index } => { + if let Some(crate::AddressSpace::Storage { .. }) = + func_ctx.resolve_type(expr, &module.types).pointer_space() + { + // do nothing, the chain is written on `Load`/`Store` + } else { + // We use the function __get_col_of_matCx2 here in cases + // where `base`s type resolves to a matCx2 and is part of a + // struct member with type of (possibly nested) array of matCx2's. + // + // Note that this only works for `Load`s and we handle + // `Store`s differently in `Statement::Store`. + if let Some(MatrixType { + columns, + rows: crate::VectorSize::Bi, + width: 4, + }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true) + { + write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?; + self.write_expr(module, base, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + write!(self.out, ")")?; + return Ok(()); + } + + let resolved = func_ctx.resolve_type(base, &module.types); + + let non_uniform_qualifier = match *resolved { + TypeInner::BindingArray { .. } => { + let uniformity = &func_ctx.info[index].uniformity; + + uniformity.non_uniform_result.is_some() + } + _ => false, + }; + + self.write_expr(module, base, func_ctx)?; + write!(self.out, "[")?; + if non_uniform_qualifier { + write!(self.out, "NonUniformResourceIndex(")?; + } + self.write_expr(module, index, func_ctx)?; + if non_uniform_qualifier { + write!(self.out, ")")?; + } + write!(self.out, "]")?; + } + } + Expression::AccessIndex { base, index } => { + if let Some(crate::AddressSpace::Storage { .. }) = + func_ctx.resolve_type(expr, &module.types).pointer_space() + { + // do nothing, the chain is written on `Load`/`Store` + } else { + fn write_access<W: fmt::Write>( + writer: &mut super::Writer<'_, W>, + resolved: &TypeInner, + base_ty_handle: Option<Handle<crate::Type>>, + index: u32, + ) -> BackendResult { + match *resolved { + // We specifically lift the ValuePointer to this case. While `[0]` is valid + // HLSL for any vector behind a value pointer, FXC completely miscompiles + // it and generates completely nonsensical DXBC. + // + // See https://github.com/gfx-rs/naga/issues/2095 for more details. + TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => { + // Write vector access as a swizzle + write!(writer.out, ".{}", back::COMPONENTS[index as usize])? + } + TypeInner::Matrix { .. } + | TypeInner::Array { .. } + | TypeInner::BindingArray { .. } => write!(writer.out, "[{index}]")?, + TypeInner::Struct { .. } => { + // This will never panic in case the type is a `Struct`, this is not true + // for other types so we can only check while inside this match arm + let ty = base_ty_handle.unwrap(); + + write!( + writer.out, + ".{}", + &writer.names[&NameKey::StructMember(ty, index)] + )? + } + ref other => { + return Err(Error::Custom(format!("Cannot index {other:?}"))) + } + } + Ok(()) + } + + // We write the matrix column access in a special way since + // the type of `base` is our special __matCx2 struct. + if let Some(MatrixType { + rows: crate::VectorSize::Bi, + width: 4, + .. + }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true) + { + self.write_expr(module, base, func_ctx)?; + write!(self.out, "._{index}")?; + return Ok(()); + } + + let base_ty_res = &func_ctx.info[base].ty; + let mut resolved = base_ty_res.inner_with(&module.types); + let base_ty_handle = match *resolved { + TypeInner::Pointer { base, .. } => { + resolved = &module.types[base].inner; + Some(base) + } + _ => base_ty_res.handle(), + }; + + // We treat matrices of the form `matCx2` as a sequence of C `vec2`s. + // See the module-level block comment in mod.rs for details. + // + // We handle matrix reconstruction here for Loads. + // Stores are handled directly by `Statement::Store`. + if let TypeInner::Struct { ref members, .. } = *resolved { + let member = &members[index as usize]; + + match module.types[member.ty].inner { + TypeInner::Matrix { + rows: crate::VectorSize::Bi, + .. + } if member.binding.is_none() => { + let ty = base_ty_handle.unwrap(); + self.write_wrapped_struct_matrix_get_function_name( + WrappedStructMatrixAccess { ty, index }, + )?; + write!(self.out, "(")?; + self.write_expr(module, base, func_ctx)?; + write!(self.out, ")")?; + return Ok(()); + } + _ => {} + } + } + + self.write_expr(module, base, func_ctx)?; + write_access(self, resolved, base_ty_handle, index)?; + } + } + Expression::FunctionArgument(pos) => { + let key = func_ctx.argument_key(pos); + let name = &self.names[&key]; + write!(self.out, "{name}")?; + } + Expression::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + use crate::SampleLevel as Sl; + const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"]; + + let (base_str, component_str) = match gather { + Some(component) => ("Gather", COMPONENTS[component as usize]), + None => ("Sample", ""), + }; + let cmp_str = match depth_ref { + Some(_) => "Cmp", + None => "", + }; + let level_str = match level { + Sl::Zero if gather.is_none() => "LevelZero", + Sl::Auto | Sl::Zero => "", + Sl::Exact(_) => "Level", + Sl::Bias(_) => "Bias", + Sl::Gradient { .. } => "Grad", + }; + + self.write_expr(module, image, func_ctx)?; + write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?; + self.write_expr(module, sampler, func_ctx)?; + write!(self.out, ", ")?; + self.write_texture_coordinates( + "float", + coordinate, + array_index, + None, + module, + func_ctx, + )?; + + if let Some(depth_ref) = depth_ref { + write!(self.out, ", ")?; + self.write_expr(module, depth_ref, func_ctx)?; + } + + match level { + Sl::Auto | Sl::Zero => {} + Sl::Exact(expr) => { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + Sl::Bias(expr) => { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + Sl::Gradient { x, y } => { + write!(self.out, ", ")?; + self.write_expr(module, x, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, y, func_ctx)?; + } + } + + if let Some(offset) = offset { + write!(self.out, ", ")?; + write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807 + self.write_const_expression(module, offset)?; + write!(self.out, ")")?; + } + + write!(self.out, ")")?; + } + Expression::ImageQuery { image, query } => { + // use wrapped image query function + if let TypeInner::Image { + dim, + arrayed, + class, + } = *func_ctx.resolve_type(image, &module.types) + { + let wrapped_image_query = WrappedImageQuery { + dim, + arrayed, + class, + query: query.into(), + }; + + self.write_wrapped_image_query_function_name(wrapped_image_query)?; + write!(self.out, "(")?; + // Image always first param + self.write_expr(module, image, func_ctx)?; + if let crate::ImageQuery::Size { level: Some(level) } = query { + write!(self.out, ", ")?; + self.write_expr(module, level, func_ctx)?; + } + write!(self.out, ")")?; + } + } + Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load + self.write_expr(module, image, func_ctx)?; + write!(self.out, ".Load(")?; + + self.write_texture_coordinates( + "int", + coordinate, + array_index, + level, + module, + func_ctx, + )?; + + if let Some(sample) = sample { + write!(self.out, ", ")?; + self.write_expr(module, sample, func_ctx)?; + } + + // close bracket for Load function + write!(self.out, ")")?; + + // return x component if return type is scalar + if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) { + write!(self.out, ".x")?; + } + } + Expression::GlobalVariable(handle) => match module.global_variables[handle].space { + crate::AddressSpace::Storage { .. } => {} + _ => { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{name}")?; + } + }, + Expression::LocalVariable(handle) => { + write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])? + } + Expression::Load { pointer } => { + match func_ctx + .resolve_type(pointer, &module.types) + .pointer_space() + { + Some(crate::AddressSpace::Storage { .. }) => { + let var_handle = self.fill_access_chain(module, pointer, func_ctx)?; + let result_ty = func_ctx.info[expr].ty.clone(); + self.write_storage_load(module, var_handle, result_ty, func_ctx)?; + } + _ => { + let mut close_paren = false; + + // We cast the value loaded to a native HLSL floatCx2 + // in cases where it is of type: + // - __matCx2 or + // - a (possibly nested) array of __matCx2's + if let Some(MatrixType { + rows: crate::VectorSize::Bi, + width: 4, + .. + }) = get_inner_matrix_of_struct_array_member( + module, pointer, func_ctx, false, + ) + .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx)) + { + let mut resolved = func_ctx.resolve_type(pointer, &module.types); + if let TypeInner::Pointer { base, .. } = *resolved { + resolved = &module.types[base].inner; + } + + write!(self.out, "((")?; + if let TypeInner::Array { base, size, .. } = *resolved { + self.write_type(module, base)?; + self.write_array_size(module, base, size)?; + } else { + self.write_value_type(module, resolved)?; + } + write!(self.out, ")")?; + close_paren = true; + } + + self.write_expr(module, pointer, func_ctx)?; + + if close_paren { + write!(self.out, ")")?; + } + } + } + } + Expression::Unary { op, expr } => { + // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators + let op_str = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::LogicalNot => "!", + crate::UnaryOperator::BitwiseNot => "~", + }; + write!(self.out, "{op_str}(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + Expression::As { + expr, + kind, + convert, + } => { + let inner = func_ctx.resolve_type(expr, &module.types); + match convert { + Some(dst_width) => { + let scalar = crate::Scalar { + kind, + width: dst_width, + }; + match *inner { + TypeInner::Vector { size, .. } => { + write!( + self.out, + "{}{}(", + scalar.to_hlsl_str()?, + back::vector_size_str(size) + )?; + } + TypeInner::Scalar(_) => { + write!(self.out, "{}(", scalar.to_hlsl_str()?,)?; + } + TypeInner::Matrix { columns, rows, .. } => { + write!( + self.out, + "{}{}x{}(", + scalar.to_hlsl_str()?, + back::vector_size_str(columns), + back::vector_size_str(rows) + )?; + } + _ => { + return Err(Error::Unimplemented(format!( + "write_expr expression::as {inner:?}" + ))); + } + }; + } + None => { + write!(self.out, "{}(", kind.to_hlsl_cast(),)?; + } + } + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + + enum Function { + Asincosh { is_sin: bool }, + Atanh, + ExtractBits, + InsertBits, + Pack2x16float, + Pack2x16snorm, + Pack2x16unorm, + Pack4x8snorm, + Pack4x8unorm, + Unpack2x16float, + Unpack2x16snorm, + Unpack2x16unorm, + Unpack4x8snorm, + Unpack4x8unorm, + Regular(&'static str), + MissingIntOverload(&'static str), + MissingIntReturnType(&'static str), + CountTrailingZeros, + CountLeadingZeros, + } + + let fun = match fun { + // comparison + Mf::Abs => Function::Regular("abs"), + Mf::Min => Function::Regular("min"), + Mf::Max => Function::Regular("max"), + Mf::Clamp => Function::Regular("clamp"), + Mf::Saturate => Function::Regular("saturate"), + // trigonometry + Mf::Cos => Function::Regular("cos"), + Mf::Cosh => Function::Regular("cosh"), + Mf::Sin => Function::Regular("sin"), + Mf::Sinh => Function::Regular("sinh"), + Mf::Tan => Function::Regular("tan"), + Mf::Tanh => Function::Regular("tanh"), + Mf::Acos => Function::Regular("acos"), + Mf::Asin => Function::Regular("asin"), + Mf::Atan => Function::Regular("atan"), + Mf::Atan2 => Function::Regular("atan2"), + Mf::Asinh => Function::Asincosh { is_sin: true }, + Mf::Acosh => Function::Asincosh { is_sin: false }, + Mf::Atanh => Function::Atanh, + Mf::Radians => Function::Regular("radians"), + Mf::Degrees => Function::Regular("degrees"), + // decomposition + Mf::Ceil => Function::Regular("ceil"), + Mf::Floor => Function::Regular("floor"), + Mf::Round => Function::Regular("round"), + Mf::Fract => Function::Regular("frac"), + Mf::Trunc => Function::Regular("trunc"), + Mf::Modf => Function::Regular(MODF_FUNCTION), + Mf::Frexp => Function::Regular(FREXP_FUNCTION), + Mf::Ldexp => Function::Regular("ldexp"), + // exponent + Mf::Exp => Function::Regular("exp"), + Mf::Exp2 => Function::Regular("exp2"), + Mf::Log => Function::Regular("log"), + Mf::Log2 => Function::Regular("log2"), + Mf::Pow => Function::Regular("pow"), + // geometry + Mf::Dot => Function::Regular("dot"), + //Mf::Outer => , + Mf::Cross => Function::Regular("cross"), + Mf::Distance => Function::Regular("distance"), + Mf::Length => Function::Regular("length"), + Mf::Normalize => Function::Regular("normalize"), + Mf::FaceForward => Function::Regular("faceforward"), + Mf::Reflect => Function::Regular("reflect"), + Mf::Refract => Function::Regular("refract"), + // computational + Mf::Sign => Function::Regular("sign"), + Mf::Fma => Function::Regular("mad"), + Mf::Mix => Function::Regular("lerp"), + Mf::Step => Function::Regular("step"), + Mf::SmoothStep => Function::Regular("smoothstep"), + Mf::Sqrt => Function::Regular("sqrt"), + Mf::InverseSqrt => Function::Regular("rsqrt"), + //Mf::Inverse =>, + Mf::Transpose => Function::Regular("transpose"), + Mf::Determinant => Function::Regular("determinant"), + // bits + Mf::CountTrailingZeros => Function::CountTrailingZeros, + Mf::CountLeadingZeros => Function::CountLeadingZeros, + Mf::CountOneBits => Function::MissingIntOverload("countbits"), + Mf::ReverseBits => Function::MissingIntOverload("reversebits"), + Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"), + Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"), + Mf::ExtractBits => Function::ExtractBits, + Mf::InsertBits => Function::InsertBits, + // Data Packing + Mf::Pack2x16float => Function::Pack2x16float, + Mf::Pack2x16snorm => Function::Pack2x16snorm, + Mf::Pack2x16unorm => Function::Pack2x16unorm, + Mf::Pack4x8snorm => Function::Pack4x8snorm, + Mf::Pack4x8unorm => Function::Pack4x8unorm, + // Data Unpacking + Mf::Unpack2x16float => Function::Unpack2x16float, + Mf::Unpack2x16snorm => Function::Unpack2x16snorm, + Mf::Unpack2x16unorm => Function::Unpack2x16unorm, + Mf::Unpack4x8snorm => Function::Unpack4x8snorm, + Mf::Unpack4x8unorm => Function::Unpack4x8unorm, + _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))), + }; + + match fun { + Function::Asincosh { is_sin } => { + write!(self.out, "log(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " + sqrt(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " * ")?; + self.write_expr(module, arg, func_ctx)?; + match is_sin { + true => write!(self.out, " + 1.0))")?, + false => write!(self.out, " - 1.0))")?, + } + } + Function::Atanh => { + write!(self.out, "0.5 * log((1.0 + ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") / (1.0 - ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } + Function::ExtractBits => { + // e: T, + // offset: u32, + // count: u32 + // T is u32 or i32 or vecN<u32> or vecN<i32> + if let (Some(offset), Some(count)) = (arg1, arg2) { + let scalar_width: u8 = 32; + // Works for signed and unsigned + // (count == 0 ? 0 : (e << (32 - count - offset)) >> (32 - count)) + write!(self.out, "(")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, " == 0 ? 0 : (")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " << ({scalar_width} - ")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, " - ")?; + self.write_expr(module, offset, func_ctx)?; + write!(self.out, ")) >> ({scalar_width} - ")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, "))")?; + } + } + Function::InsertBits => { + // e: T, + // newbits: T, + // offset: u32, + // count: u32 + // returns T + // T is i32, u32, vecN<i32>, or vecN<u32> + if let (Some(newbits), Some(offset), Some(count)) = (arg1, arg2, arg3) { + let scalar_width: u8 = 32; + let scalar_max: u32 = 0xFFFFFFFF; + // mask = ((0xFFFFFFFFu >> (32 - count)) << offset) + // (count == 0 ? e : ((e & ~mask) | ((newbits << offset) & mask))) + write!(self.out, "(")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, " == 0 ? ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " : ")?; + write!(self.out, "(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " & ~")?; + // mask + write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, ")) << ")?; + self.write_expr(module, offset, func_ctx)?; + write!(self.out, ")")?; + // end mask + write!(self.out, ") | ((")?; + self.write_expr(module, newbits, func_ctx)?; + write!(self.out, " << ")?; + self.write_expr(module, offset, func_ctx)?; + write!(self.out, ") & ")?; + // // mask + write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?; + self.write_expr(module, count, func_ctx)?; + write!(self.out, ")) << ")?; + self.write_expr(module, offset, func_ctx)?; + write!(self.out, ")")?; + // // end mask + write!(self.out, "))")?; + } + } + Function::Pack2x16float => { + write!(self.out, "(f32tof16(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[0]) | f32tof16(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[1]) << 16)")?; + } + Function::Pack2x16snorm => { + let scale = 32767; + + write!(self.out, "uint((int(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?; + } + Function::Pack2x16unorm => { + let scale = 65535; + + write!(self.out, "(uint(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?; + } + Function::Pack4x8snorm => { + let scale = 127; + + write!(self.out, "uint((int(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?; + } + Function::Pack4x8unorm => { + let scale = 255; + + write!(self.out, "(uint(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?; + } + + Function::Unpack2x16float => { + write!(self.out, "float2(f16tof32(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "), f16tof32((")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") >> 16))")?; + } + Function::Unpack2x16snorm => { + let scale = 32767; + + write!(self.out, "(float2(int2(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " << 16, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") >> 16) / {scale}.0)")?; + } + Function::Unpack2x16unorm => { + let scale = 65535; + + write!(self.out, "(float2(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " & 0xFFFF, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " >> 16) / {scale}.0)")?; + } + Function::Unpack4x8snorm => { + let scale = 127; + + write!(self.out, "(float4(int4(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " << 24, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " << 16, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " << 8, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") >> 24) / {scale}.0)")?; + } + Function::Unpack4x8unorm => { + let scale = 255; + + write!(self.out, "(float4(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " & 0xFF, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " >> 8 & 0xFF, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " >> 16 & 0xFF, ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " >> 24) / {scale}.0)")?; + } + Function::Regular(fun_name) => { + write!(self.out, "{fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + if let Some(arg) = arg1 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + if let Some(arg) = arg2 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + if let Some(arg) = arg3 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + write!(self.out, ")")? + } + Function::MissingIntOverload(fun_name) => { + let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind(); + if let Some(ScalarKind::Sint) = scalar_kind { + write!(self.out, "asint({fun_name}(asuint(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; + } else { + write!(self.out, "{fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")")?; + } + } + Function::MissingIntReturnType(fun_name) => { + let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind(); + if let Some(ScalarKind::Sint) = scalar_kind { + write!(self.out, "asint({fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "{fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")")?; + } + } + Function::CountTrailingZeros => { + match *func_ctx.resolve_type(arg, &module.types) { + TypeInner::Vector { size, scalar } => { + let s = match size { + crate::VectorSize::Bi => ".xx", + crate::VectorSize::Tri => ".xxx", + crate::VectorSize::Quad => ".xxxx", + }; + + if let ScalarKind::Uint = scalar.kind { + write!(self.out, "min((32u){s}, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "asint(min((32u){s}, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; + } + } + TypeInner::Scalar(scalar) => { + if let ScalarKind::Uint = scalar.kind { + write!(self.out, "min(32u, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "asint(min(32u, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; + } + } + _ => unreachable!(), + } + + return Ok(()); + } + Function::CountLeadingZeros => { + match *func_ctx.resolve_type(arg, &module.types) { + TypeInner::Vector { size, scalar } => { + let s = match size { + crate::VectorSize::Bi => ".xx", + crate::VectorSize::Tri => ".xxx", + crate::VectorSize::Quad => ".xxxx", + }; + + if let ScalarKind::Uint = scalar.kind { + write!(self.out, "((31u){s} - firstbithigh(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "(")?; + self.write_expr(module, arg, func_ctx)?; + write!( + self.out, + " < (0){s} ? (0){s} : (31){s} - asint(firstbithigh(" + )?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; + } + } + TypeInner::Scalar(scalar) => { + if let ScalarKind::Uint = scalar.kind { + write!(self.out, "(31u - firstbithigh(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " < 0 ? 0 : 31 - asint(firstbithigh(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ")))")?; + } + } + _ => unreachable!(), + } + + return Ok(()); + } + } + } + Expression::Swizzle { + size, + vector, + pattern, + } => { + self.write_expr(module, vector, func_ctx)?; + write!(self.out, ".")?; + for &sc in pattern[..size as usize].iter() { + self.out.write_char(back::COMPONENTS[sc as usize])?; + } + } + Expression::ArrayLength(expr) => { + let var_handle = match func_ctx.expressions[expr] { + Expression::AccessIndex { base, index: _ } => { + match func_ctx.expressions[base] { + Expression::GlobalVariable(handle) => handle, + _ => unreachable!(), + } + } + Expression::GlobalVariable(handle) => handle, + _ => unreachable!(), + }; + + let var = &module.global_variables[var_handle]; + let (offset, stride) = match module.types[var.ty].inner { + TypeInner::Array { stride, .. } => (0, stride), + TypeInner::Struct { ref members, .. } => { + let last = members.last().unwrap(); + let stride = match module.types[last.ty].inner { + TypeInner::Array { stride, .. } => stride, + _ => unreachable!(), + }; + (last.offset, stride) + } + _ => unreachable!(), + }; + + let storage_access = match var.space { + crate::AddressSpace::Storage { access } => access, + _ => crate::StorageAccess::default(), + }; + let wrapped_array_length = WrappedArrayLength { + writable: storage_access.contains(crate::StorageAccess::STORE), + }; + + write!(self.out, "((")?; + self.write_wrapped_array_length_function_name(wrapped_array_length)?; + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + write!(self.out, "({var_name}) - {offset}) / {stride})")? + } + Expression::Derivative { axis, ctrl, expr } => { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) { + let tail = match ctrl { + Ctrl::Coarse => "coarse", + Ctrl::Fine => "fine", + Ctrl::None => unreachable!(), + }; + write!(self.out, "abs(ddx_{tail}(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")) + abs(ddy_{tail}(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, "))")? + } else { + let fun_str = match (axis, ctrl) { + (Axis::X, Ctrl::Coarse) => "ddx_coarse", + (Axis::X, Ctrl::Fine) => "ddx_fine", + (Axis::X, Ctrl::None) => "ddx", + (Axis::Y, Ctrl::Coarse) => "ddy_coarse", + (Axis::Y, Ctrl::Fine) => "ddy_fine", + (Axis::Y, Ctrl::None) => "ddy", + (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(), + (Axis::Width, Ctrl::None) => "fwidth", + }; + write!(self.out, "{fun_str}(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")? + } + } + Expression::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + + let fun_str = match fun { + Rf::All => "all", + Rf::Any => "any", + Rf::IsNan => "isnan", + Rf::IsInf => "isinf", + }; + write!(self.out, "{fun_str}(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ")")? + } + Expression::Select { + condition, + accept, + reject, + } => { + write!(self.out, "(")?; + self.write_expr(module, condition, func_ctx)?; + write!(self.out, " ? ")?; + self.write_expr(module, accept, func_ctx)?; + write!(self.out, " : ")?; + self.write_expr(module, reject, func_ctx)?; + write!(self.out, ")")? + } + // Not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), + // Nothing to do here, since call expression already cached + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::WorkGroupUniformLoadResult { .. } + | Expression::RayQueryProceedResult => {} + } + + if !closing_bracket.is_empty() { + write!(self.out, "{closing_bracket}")?; + } + Ok(()) + } + + fn write_named_expr( + &mut self, + module: &Module, + handle: Handle<crate::Expression>, + name: String, + // The expression which is being named. + // Generally, this is the same as handle, except in WorkGroupUniformLoad + named: Handle<crate::Expression>, + ctx: &back::FunctionCtx, + ) -> BackendResult { + match ctx.info[named].ty { + proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner { + TypeInner::Struct { .. } => { + let ty_name = &self.names[&NameKey::Type(ty_handle)]; + write!(self.out, "{ty_name}")?; + } + _ => { + self.write_type(module, ty_handle)?; + } + }, + proc::TypeResolution::Value(ref inner) => { + self.write_value_type(module, inner)?; + } + } + + let resolved = ctx.resolve_type(named, &module.types); + + write!(self.out, " {name}")?; + // If rhs is a array type, we should write array size + if let TypeInner::Array { base, size, .. } = *resolved { + self.write_array_size(module, base, size)?; + } + write!(self.out, " = ")?; + self.write_expr(module, handle, ctx)?; + writeln!(self.out, ";")?; + self.named_expressions.insert(named, name); + + Ok(()) + } + + /// Helper function that write default zero initialization + fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult { + write!(self.out, "(")?; + self.write_type(module, ty)?; + if let TypeInner::Array { base, size, .. } = module.types[ty].inner { + self.write_array_size(module, base, size)?; + } + write!(self.out, ")0")?; + Ok(()) + } + + fn write_barrier(&mut self, barrier: crate::Barrier, level: back::Level) -> BackendResult { + if barrier.contains(crate::Barrier::STORAGE) { + writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?; + } + if barrier.contains(crate::Barrier::WORK_GROUP) { + writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; + } + Ok(()) + } +} + +pub(super) struct MatrixType { + pub(super) columns: crate::VectorSize, + pub(super) rows: crate::VectorSize, + pub(super) width: crate::Bytes, +} + +pub(super) fn get_inner_matrix_data( + module: &Module, + handle: Handle<crate::Type>, +) -> Option<MatrixType> { + match module.types[handle].inner { + TypeInner::Matrix { + columns, + rows, + scalar, + } => Some(MatrixType { + columns, + rows, + width: scalar.width, + }), + TypeInner::Array { base, .. } => get_inner_matrix_data(module, base), + _ => None, + } +} + +/// Returns the matrix data if the access chain starting at `base`: +/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true` +/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`] +/// - ends at an expression with resolved type of [`TypeInner::Struct`] +pub(super) fn get_inner_matrix_of_struct_array_member( + module: &Module, + base: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + direct: bool, +) -> Option<MatrixType> { + let mut mat_data = None; + let mut array_base = None; + + let mut current_base = base; + loop { + let mut resolved = func_ctx.resolve_type(current_base, &module.types); + if let TypeInner::Pointer { base, .. } = *resolved { + resolved = &module.types[base].inner; + }; + + match *resolved { + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + mat_data = Some(MatrixType { + columns, + rows, + width: scalar.width, + }) + } + TypeInner::Array { base, .. } => { + array_base = Some(base); + } + TypeInner::Struct { .. } => { + if let Some(array_base) = array_base { + if direct { + return mat_data; + } else { + return get_inner_matrix_data(module, array_base); + } + } + + break; + } + _ => break, + } + + current_base = match func_ctx.expressions[current_base] { + crate::Expression::Access { base, .. } => base, + crate::Expression::AccessIndex { base, .. } => base, + _ => break, + }; + } + None +} + +/// Returns the matrix data if the access chain starting at `base`: +/// - starts with an expression with resolved type of [`TypeInner::Matrix`] +/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`] +/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform) +fn get_inner_matrix_of_global_uniform( + module: &Module, + base: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, +) -> Option<MatrixType> { + let mut mat_data = None; + let mut array_base = None; + + let mut current_base = base; + loop { + let mut resolved = func_ctx.resolve_type(current_base, &module.types); + if let TypeInner::Pointer { base, .. } = *resolved { + resolved = &module.types[base].inner; + }; + + match *resolved { + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + mat_data = Some(MatrixType { + columns, + rows, + width: scalar.width, + }) + } + TypeInner::Array { base, .. } => { + array_base = Some(base); + } + _ => break, + } + + current_base = match func_ctx.expressions[current_base] { + crate::Expression::Access { base, .. } => base, + crate::Expression::AccessIndex { base, .. } => base, + crate::Expression::GlobalVariable(handle) + if module.global_variables[handle].space == crate::AddressSpace::Uniform => + { + return mat_data.or_else(|| { + array_base.and_then(|array_base| get_inner_matrix_data(module, array_base)) + }) + } + _ => break, + }; + } + None +} diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs new file mode 100644 index 0000000000..8100b930e9 --- /dev/null +++ b/third_party/rust/naga/src/back/mod.rs @@ -0,0 +1,273 @@ +/*! +Backend functions that export shader [`Module`](super::Module)s into binary and text formats. +*/ +#![allow(dead_code)] // can be dead if none of the enabled backends need it + +#[cfg(feature = "dot-out")] +pub mod dot; +#[cfg(feature = "glsl-out")] +pub mod glsl; +#[cfg(feature = "hlsl-out")] +pub mod hlsl; +#[cfg(feature = "msl-out")] +pub mod msl; +#[cfg(feature = "spv-out")] +pub mod spv; +#[cfg(feature = "wgsl-out")] +pub mod wgsl; + +const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; +const INDENT: &str = " "; +const BAKE_PREFIX: &str = "_e"; + +type NeedBakeExpressions = crate::FastHashSet<crate::Handle<crate::Expression>>; + +#[derive(Clone, Copy)] +struct Level(usize); + +impl Level { + const fn next(&self) -> Self { + Level(self.0 + 1) + } +} + +impl std::fmt::Display for Level { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + (0..self.0).try_for_each(|_| formatter.write_str(INDENT)) + } +} + +/// Whether we're generating an entry point or a regular function. +/// +/// Backend languages often require different code for a [`Function`] +/// depending on whether it represents an [`EntryPoint`] or not. +/// Backends can pass common code one of these values to select the +/// right behavior. +/// +/// These values also carry enough information to find the `Function` +/// in the [`Module`]: the `Handle` for a regular function, or the +/// index into [`Module::entry_points`] for an entry point. +/// +/// [`Function`]: crate::Function +/// [`EntryPoint`]: crate::EntryPoint +/// [`Module`]: crate::Module +/// [`Module::entry_points`]: crate::Module::entry_points +enum FunctionType { + /// A regular function. + Function(crate::Handle<crate::Function>), + /// An [`EntryPoint`], and its index in [`Module::entry_points`]. + /// + /// [`EntryPoint`]: crate::EntryPoint + /// [`Module::entry_points`]: crate::Module::entry_points + EntryPoint(crate::proc::EntryPointIndex), +} + +impl FunctionType { + fn is_compute_entry_point(&self, module: &crate::Module) -> bool { + match *self { + FunctionType::EntryPoint(index) => { + module.entry_points[index as usize].stage == crate::ShaderStage::Compute + } + FunctionType::Function(_) => false, + } + } +} + +/// Helper structure that stores data needed when writing the function +struct FunctionCtx<'a> { + /// The current function being written + ty: FunctionType, + /// Analysis about the function + info: &'a crate::valid::FunctionInfo, + /// The expression arena of the current function being written + expressions: &'a crate::Arena<crate::Expression>, + /// Map of expressions that have associated variable names + named_expressions: &'a crate::NamedExpressions, +} + +impl FunctionCtx<'_> { + fn resolve_type<'a>( + &'a self, + handle: crate::Handle<crate::Expression>, + types: &'a crate::UniqueArena<crate::Type>, + ) -> &'a crate::TypeInner { + self.info[handle].ty.inner_with(types) + } + + /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a local in the current function + const fn name_key(&self, local: crate::Handle<crate::LocalVariable>) -> crate::proc::NameKey { + match self.ty { + FunctionType::Function(handle) => crate::proc::NameKey::FunctionLocal(handle, local), + FunctionType::EntryPoint(idx) => crate::proc::NameKey::EntryPointLocal(idx, local), + } + } + + /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a function argument. + /// + /// # Panics + /// - If the function arguments are less or equal to `arg` + const fn argument_key(&self, arg: u32) -> crate::proc::NameKey { + match self.ty { + FunctionType::Function(handle) => crate::proc::NameKey::FunctionArgument(handle, arg), + FunctionType::EntryPoint(ep_index) => { + crate::proc::NameKey::EntryPointArgument(ep_index, arg) + } + } + } + + // Returns true if the given expression points to a fixed-function pipeline input. + fn is_fixed_function_input( + &self, + mut expression: crate::Handle<crate::Expression>, + module: &crate::Module, + ) -> Option<crate::BuiltIn> { + let ep_function = match self.ty { + FunctionType::Function(_) => return None, + FunctionType::EntryPoint(ep_index) => &module.entry_points[ep_index as usize].function, + }; + let mut built_in = None; + loop { + match self.expressions[expression] { + crate::Expression::FunctionArgument(arg_index) => { + return match ep_function.arguments[arg_index as usize].binding { + Some(crate::Binding::BuiltIn(bi)) => Some(bi), + _ => built_in, + }; + } + crate::Expression::AccessIndex { base, index } => { + match *self.resolve_type(base, &module.types) { + crate::TypeInner::Struct { ref members, .. } => { + if let Some(crate::Binding::BuiltIn(bi)) = + members[index as usize].binding + { + built_in = Some(bi); + } + } + _ => return None, + } + expression = base; + } + _ => return None, + } + } + } +} + +impl crate::Expression { + /// Returns the ref count, upon reaching which this expression + /// should be considered for baking. + /// + /// Note: we have to cache any expressions that depend on the control flow, + /// or otherwise they may be moved into a non-uniform control flow, accidentally. + /// See the [module-level documentation][emit] for details. + /// + /// [emit]: index.html#expression-evaluation-time + const fn bake_ref_count(&self) -> usize { + match *self { + // accesses are never cached, only loads are + crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => usize::MAX, + // sampling may use the control flow, and image ops look better by themselves + crate::Expression::ImageSample { .. } | crate::Expression::ImageLoad { .. } => 1, + // derivatives use the control flow + crate::Expression::Derivative { .. } => 1, + // TODO: We need a better fix for named `Load` expressions + // More info - https://github.com/gfx-rs/naga/pull/914 + // And https://github.com/gfx-rs/naga/issues/910 + crate::Expression::Load { .. } => 1, + // cache expressions that are referenced multiple times + _ => 2, + } + } +} + +/// Helper function that returns the string corresponding to the [`BinaryOperator`](crate::BinaryOperator) +/// # Notes +/// Used by `glsl-out`, `msl-out`, `wgsl-out`, `hlsl-out`. +const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { + use crate::BinaryOperator as Bo; + match op { + Bo::Add => "+", + Bo::Subtract => "-", + Bo::Multiply => "*", + Bo::Divide => "/", + Bo::Modulo => "%", + Bo::Equal => "==", + Bo::NotEqual => "!=", + Bo::Less => "<", + Bo::LessEqual => "<=", + Bo::Greater => ">", + Bo::GreaterEqual => ">=", + Bo::And => "&", + Bo::ExclusiveOr => "^", + Bo::InclusiveOr => "|", + Bo::LogicalAnd => "&&", + Bo::LogicalOr => "||", + Bo::ShiftLeft => "<<", + Bo::ShiftRight => ">>", + } +} + +/// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize) +/// # Notes +/// Used by `msl-out`, `wgsl-out`, `hlsl-out`. +const fn vector_size_str(size: crate::VectorSize) -> &'static str { + match size { + crate::VectorSize::Bi => "2", + crate::VectorSize::Tri => "3", + crate::VectorSize::Quad => "4", + } +} + +impl crate::TypeInner { + const fn is_handle(&self) -> bool { + match *self { + crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => true, + _ => false, + } + } +} + +impl crate::Statement { + /// Returns true if the statement directly terminates the current block. + /// + /// Used to decide whether case blocks require a explicit `break`. + pub const fn is_terminator(&self) -> bool { + match *self { + crate::Statement::Break + | crate::Statement::Continue + | crate::Statement::Return { .. } + | crate::Statement::Kill => true, + _ => false, + } + } +} + +bitflags::bitflags! { + /// Ray flags, for a [`RayDesc`]'s `flags` field. + /// + /// Note that these exactly correspond to the SPIR-V "Ray Flags" mask, and + /// the SPIR-V backend passes them directly through to the + /// `OpRayQueryInitializeKHR` instruction. (We have to choose something, so + /// we might as well make one back end's life easier.) + /// + /// [`RayDesc`]: crate::Module::generate_ray_desc_type + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] + pub struct RayFlag: u32 { + const OPAQUE = 0x01; + const NO_OPAQUE = 0x02; + const TERMINATE_ON_FIRST_HIT = 0x04; + const SKIP_CLOSEST_HIT_SHADER = 0x08; + const CULL_BACK_FACING = 0x10; + const CULL_FRONT_FACING = 0x20; + const CULL_OPAQUE = 0x40; + const CULL_NO_OPAQUE = 0x80; + const SKIP_TRIANGLES = 0x100; + const SKIP_AABBS = 0x200; + } +} + +#[repr(u32)] +enum RayIntersectionType { + Triangle = 1, + BoundingBox = 4, +} diff --git a/third_party/rust/naga/src/back/msl/keywords.rs b/third_party/rust/naga/src/back/msl/keywords.rs new file mode 100644 index 0000000000..f0025bf239 --- /dev/null +++ b/third_party/rust/naga/src/back/msl/keywords.rs @@ -0,0 +1,342 @@ +// MSLS - Metal Shading Language Specification: +// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf +// +// C++ - Standard for Programming Language C++ (N4431) +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf +pub const RESERVED: &[&str] = &[ + // Standard for Programming Language C++ (N4431): 2.5 Alternative tokens + "and", + "bitor", + "or", + "xor", + "compl", + "bitand", + "and_eq", + "or_eq", + "xor_eq", + "not", + "not_eq", + // Standard for Programming Language C++ (N4431): 2.11 Keywords + "alignas", + "alignof", + "asm", + "auto", + "bool", + "break", + "case", + "catch", + "char", + "char16_t", + "char32_t", + "class", + "const", + "constexpr", + "const_cast", + "continue", + "decltype", + "default", + "delete", + "do", + "double", + "dynamic_cast", + "else", + "enum", + "explicit", + "export", + "extern", + "false", + "float", + "for", + "friend", + "goto", + "if", + "inline", + "int", + "long", + "mutable", + "namespace", + "new", + "noexcept", + "nullptr", + "operator", + "private", + "protected", + "public", + "register", + "reinterpret_cast", + "return", + "short", + "signed", + "sizeof", + "static", + "static_assert", + "static_cast", + "struct", + "switch", + "template", + "this", + "thread_local", + "throw", + "true", + "try", + "typedef", + "typeid", + "typename", + "union", + "unsigned", + "using", + "virtual", + "void", + "volatile", + "wchar_t", + "while", + // Metal Shading Language Specification: 1.4.4 Restrictions + "main", + // Metal Shading Language Specification: 2.1 Scalar Data Types + "int8_t", + "uchar", + "uint8_t", + "int16_t", + "ushort", + "uint16_t", + "int32_t", + "uint", + "uint32_t", + "int64_t", + "uint64_t", + "half", + "bfloat", + "size_t", + "ptrdiff_t", + // Metal Shading Language Specification: 2.2 Vector Data Types + "bool2", + "bool3", + "bool4", + "char2", + "char3", + "char4", + "short2", + "short3", + "short4", + "int2", + "int3", + "int4", + "long2", + "long3", + "long4", + "uchar2", + "uchar3", + "uchar4", + "ushort2", + "ushort3", + "ushort4", + "uint2", + "uint3", + "uint4", + "ulong2", + "ulong3", + "ulong4", + "half2", + "half3", + "half4", + "bfloat2", + "bfloat3", + "bfloat4", + "float2", + "float3", + "float4", + "vec", + // Metal Shading Language Specification: 2.2.3 Packed Vector Types + "packed_bool2", + "packed_bool3", + "packed_bool4", + "packed_char2", + "packed_char3", + "packed_char4", + "packed_short2", + "packed_short3", + "packed_short4", + "packed_int2", + "packed_int3", + "packed_int4", + "packed_uchar2", + "packed_uchar3", + "packed_uchar4", + "packed_ushort2", + "packed_ushort3", + "packed_ushort4", + "packed_uint2", + "packed_uint3", + "packed_uint4", + "packed_half2", + "packed_half3", + "packed_half4", + "packed_bfloat2", + "packed_bfloat3", + "packed_bfloat4", + "packed_float2", + "packed_float3", + "packed_float4", + "packed_long2", + "packed_long3", + "packed_long4", + "packed_vec", + // Metal Shading Language Specification: 2.3 Matrix Data Types + "half2x2", + "half2x3", + "half2x4", + "half3x2", + "half3x3", + "half3x4", + "half4x2", + "half4x3", + "half4x4", + "float2x2", + "float2x3", + "float2x4", + "float3x2", + "float3x3", + "float3x4", + "float4x2", + "float4x3", + "float4x4", + "matrix", + // Metal Shading Language Specification: 2.6 Atomic Data Types + "atomic", + "atomic_int", + "atomic_uint", + "atomic_bool", + "atomic_ulong", + "atomic_float", + // Metal Shading Language Specification: 2.20 Type Conversions and Re-interpreting Data + "as_type", + // Metal Shading Language Specification: 4 Address Spaces + "device", + "constant", + "thread", + "threadgroup", + "threadgroup_imageblock", + "ray_data", + "object_data", + // Metal Shading Language Specification: 5.1 Functions + "vertex", + "fragment", + "kernel", + // Metal Shading Language Specification: 6.1 Namespace and Header Files + "metal", + // C99 / C++ extension: + "restrict", + // Metal reserved types in <metal_types>: + "llong", + "ullong", + "quad", + "complex", + "imaginary", + // Constants in <metal_types>: + "CHAR_BIT", + "SCHAR_MAX", + "SCHAR_MIN", + "UCHAR_MAX", + "CHAR_MAX", + "CHAR_MIN", + "USHRT_MAX", + "SHRT_MAX", + "SHRT_MIN", + "UINT_MAX", + "INT_MAX", + "INT_MIN", + "ULONG_MAX", + "LONG_MAX", + "LONG_MIN", + "ULLONG_MAX", + "LLONG_MAX", + "LLONG_MIN", + "FLT_DIG", + "FLT_MANT_DIG", + "FLT_MAX_10_EXP", + "FLT_MAX_EXP", + "FLT_MIN_10_EXP", + "FLT_MIN_EXP", + "FLT_RADIX", + "FLT_MAX", + "FLT_MIN", + "FLT_EPSILON", + "FLT_DECIMAL_DIG", + "FP_ILOGB0", + "FP_ILOGB0", + "FP_ILOGBNAN", + "FP_ILOGBNAN", + "MAXFLOAT", + "HUGE_VALF", + "INFINITY", + "NAN", + "M_E_F", + "M_LOG2E_F", + "M_LOG10E_F", + "M_LN2_F", + "M_LN10_F", + "M_PI_F", + "M_PI_2_F", + "M_PI_4_F", + "M_1_PI_F", + "M_2_PI_F", + "M_2_SQRTPI_F", + "M_SQRT2_F", + "M_SQRT1_2_F", + "HALF_DIG", + "HALF_MANT_DIG", + "HALF_MAX_10_EXP", + "HALF_MAX_EXP", + "HALF_MIN_10_EXP", + "HALF_MIN_EXP", + "HALF_RADIX", + "HALF_MAX", + "HALF_MIN", + "HALF_EPSILON", + "HALF_DECIMAL_DIG", + "MAXHALF", + "HUGE_VALH", + "M_E_H", + "M_LOG2E_H", + "M_LOG10E_H", + "M_LN2_H", + "M_LN10_H", + "M_PI_H", + "M_PI_2_H", + "M_PI_4_H", + "M_1_PI_H", + "M_2_PI_H", + "M_2_SQRTPI_H", + "M_SQRT2_H", + "M_SQRT1_2_H", + "DBL_DIG", + "DBL_MANT_DIG", + "DBL_MAX_10_EXP", + "DBL_MAX_EXP", + "DBL_MIN_10_EXP", + "DBL_MIN_EXP", + "DBL_RADIX", + "DBL_MAX", + "DBL_MIN", + "DBL_EPSILON", + "DBL_DECIMAL_DIG", + "MAXDOUBLE", + "HUGE_VAL", + "M_E", + "M_LOG2E", + "M_LOG10E", + "M_LN2", + "M_LN10", + "M_PI", + "M_PI_2", + "M_PI_4", + "M_1_PI", + "M_2_PI", + "M_2_SQRTPI", + "M_SQRT2", + "M_SQRT1_2", + // Naga utilities + "DefaultConstructible", + super::writer::FREXP_FUNCTION, + super::writer::MODF_FUNCTION, +]; diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs new file mode 100644 index 0000000000..5ef18730c9 --- /dev/null +++ b/third_party/rust/naga/src/back/msl/mod.rs @@ -0,0 +1,541 @@ +/*! +Backend for [MSL][msl] (Metal Shading Language). + +## Binding model + +Metal's bindings are flat per resource. Since there isn't an obvious mapping +from SPIR-V's descriptor sets, we require a separate mapping provided in the options. +This mapping may have one or more resource end points for each descriptor set + index +pair. + +## Entry points + +Even though MSL and our IR appear to be similar in that the entry points in both can +accept arguments and return values, the restrictions are different. +MSL allows the varyings to be either in separate arguments, or inside a single +`[[stage_in]]` struct. We gather input varyings and form this artificial structure. +We also add all the (non-Private) globals into the arguments. + +At the beginning of the entry point, we assign the local constants and re-compose +the arguments as they are declared on IR side, so that the rest of the logic can +pretend that MSL doesn't have all the restrictions it has. + +For the result type, if it's a structure, we re-compose it with a temporary value +holding the result. + +[msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf +*/ + +use crate::{arena::Handle, proc::index, valid::ModuleInfo}; +use std::fmt::{Error as FmtError, Write}; + +mod keywords; +pub mod sampler; +mod writer; + +pub use writer::Writer; + +pub type Slot = u8; +pub type InlineSamplerIndex = u8; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum BindSamplerTarget { + Resource(Slot), + Inline(InlineSamplerIndex), +} + +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))] +pub struct BindTarget { + pub buffer: Option<Slot>, + pub texture: Option<Slot>, + pub sampler: Option<BindSamplerTarget>, + /// If the binding is an unsized binding array, this overrides the size. + pub binding_array_size: Option<u32>, + pub mutable: bool, +} + +// Using `BTreeMap` instead of `HashMap` so that we can hash itself. +pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>; + +#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))] +pub struct EntryPointResources { + pub resources: BindingMap, + + pub push_constant_buffer: Option<Slot>, + + /// The slot of a buffer that contains an array of `u32`, + /// one for the size of each bound buffer that contains a runtime array, + /// in order of [`crate::GlobalVariable`] declarations. + pub sizes_buffer: Option<Slot>, +} + +pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>; + +enum ResolvedBinding { + BuiltIn(crate::BuiltIn), + Attribute(u32), + Color { + location: u32, + second_blend_source: bool, + }, + User { + prefix: &'static str, + index: u32, + interpolation: Option<ResolvedInterpolation>, + }, + Resource(BindTarget), +} + +#[derive(Copy, Clone)] +enum ResolvedInterpolation { + CenterPerspective, + CenterNoPerspective, + CentroidPerspective, + CentroidNoPerspective, + SamplePerspective, + SampleNoPerspective, + Flat, +} + +// Note: some of these should be removed in favor of proper IR validation. + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error(transparent)] + Format(#[from] FmtError), + #[error("bind target {0:?} is empty")] + UnimplementedBindTarget(BindTarget), + #[error("composing of {0:?} is not implemented yet")] + UnsupportedCompose(Handle<crate::Type>), + #[error("operation {0:?} is not implemented yet")] + UnsupportedBinaryOp(crate::BinaryOperator), + #[error("standard function '{0}' is not implemented yet")] + UnsupportedCall(String), + #[error("feature '{0}' is not implemented yet")] + FeatureNotImplemented(String), + #[error("module is not valid")] + Validation, + #[error("BuiltIn {0:?} is not supported")] + UnsupportedBuiltIn(crate::BuiltIn), + #[error("capability {0:?} is not supported")] + CapabilityNotSupported(crate::valid::Capabilities), + #[error("attribute '{0}' is not supported for target MSL version")] + UnsupportedAttribute(String), + #[error("function '{0}' is not supported for target MSL version")] + UnsupportedFunction(String), + #[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")] + UnsupportedWriteableStorageBuffer, + #[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")] + UnsupportedWriteableStorageTexture(crate::ShaderStage), + #[error("can not use read-write storage textures prior to MSL 1.2")] + UnsupportedRWStorageTexture, + #[error("array of '{0}' is not supported for target MSL version")] + UnsupportedArrayOf(String), + #[error("array of type '{0:?}' is not supported")] + UnsupportedArrayOfType(Handle<crate::Type>), + #[error("ray tracing is not supported prior to MSL 2.3")] + UnsupportedRayTracing, +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum EntryPointError { + #[error("global '{0}' doesn't have a binding")] + MissingBinding(String), + #[error("mapping of {0:?} is missing")] + MissingBindTarget(crate::ResourceBinding), + #[error("mapping for push constants is missing")] + MissingPushConstants, + #[error("mapping for sizes buffer is missing")] + MissingSizesBuffer, +} + +/// Points in the MSL code where we might emit a pipeline input or output. +/// +/// Note that, even though vertex shaders' outputs are always fragment +/// shaders' inputs, we still need to distinguish `VertexOutput` and +/// `FragmentInput`, since there are certain differences in the way +/// [`ResolvedBinding`s] are represented on either side. +/// +/// [`ResolvedBinding`s]: ResolvedBinding +#[derive(Clone, Copy, Debug)] +enum LocationMode { + /// Input to the vertex shader. + VertexInput, + + /// Output from the vertex shader. + VertexOutput, + + /// Input to the fragment shader. + FragmentInput, + + /// Output from the fragment shader. + FragmentOutput, + + /// Compute shader input or output. + Uniform, +} + +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Options { + /// (Major, Minor) target version of the Metal Shading Language. + pub lang_version: (u8, u8), + /// Map of entry-point resources, indexed by entry point function name, to slots. + pub per_entry_point_map: EntryPointResourceMap, + /// Samplers to be inlined into the code. + pub inline_samplers: Vec<sampler::InlineSampler>, + /// Make it possible to link different stages via SPIRV-Cross. + pub spirv_cross_compatibility: bool, + /// Don't panic on missing bindings, instead generate invalid MSL. + pub fake_missing_bindings: bool, + /// Bounds checking policies. + #[cfg_attr(feature = "deserialize", serde(default))] + pub bounds_check_policies: index::BoundsCheckPolicies, + /// Should workgroup variables be zero initialized (by polyfilling)? + pub zero_initialize_workgroup_memory: bool, +} + +impl Default for Options { + fn default() -> Self { + Options { + lang_version: (1, 0), + per_entry_point_map: EntryPointResourceMap::default(), + inline_samplers: Vec::new(), + spirv_cross_compatibility: false, + fake_missing_bindings: true, + bounds_check_policies: index::BoundsCheckPolicies::default(), + zero_initialize_workgroup_memory: true, + } + } +} + +/// A subset of options that are meant to be changed per pipeline. +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct PipelineOptions { + /// Allow `BuiltIn::PointSize` and inject it if doesn't exist. + /// + /// Metal doesn't like this for non-point primitive topologies and requires it for + /// point primitive topologies. + /// + /// Enable this for vertex shaders with point primitive topologies. + pub allow_and_force_point_size: bool, +} + +impl Options { + fn resolve_local_binding( + &self, + binding: &crate::Binding, + mode: LocationMode, + ) -> Result<ResolvedBinding, Error> { + match *binding { + crate::Binding::BuiltIn(mut built_in) => { + match built_in { + crate::BuiltIn::Position { ref mut invariant } => { + if *invariant && self.lang_version < (2, 1) { + return Err(Error::UnsupportedAttribute("invariant".to_string())); + } + + // The 'invariant' attribute may only appear on vertex + // shader outputs, not fragment shader inputs. + if !matches!(mode, LocationMode::VertexOutput) { + *invariant = false; + } + } + crate::BuiltIn::BaseInstance if self.lang_version < (1, 2) => { + return Err(Error::UnsupportedAttribute("base_instance".to_string())); + } + crate::BuiltIn::InstanceIndex if self.lang_version < (1, 2) => { + return Err(Error::UnsupportedAttribute("instance_id".to_string())); + } + // macOS: Since Metal 2.2 + // iOS: Since Metal 2.3 (check depends on https://github.com/gfx-rs/naga/issues/2164) + crate::BuiltIn::PrimitiveIndex if self.lang_version < (2, 2) => { + return Err(Error::UnsupportedAttribute("primitive_id".to_string())); + } + _ => {} + } + + Ok(ResolvedBinding::BuiltIn(built_in)) + } + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source, + } => match mode { + LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)), + LocationMode::FragmentOutput => { + if second_blend_source && self.lang_version < (1, 2) { + return Err(Error::UnsupportedAttribute( + "second_blend_source".to_string(), + )); + } + Ok(ResolvedBinding::Color { + location, + second_blend_source, + }) + } + LocationMode::VertexOutput | LocationMode::FragmentInput => { + Ok(ResolvedBinding::User { + prefix: if self.spirv_cross_compatibility { + "locn" + } else { + "loc" + }, + index: location, + interpolation: { + // unwrap: The verifier ensures that vertex shader outputs and fragment + // shader inputs always have fully specified interpolation, and that + // sampling is `None` only for Flat interpolation. + let interpolation = interpolation.unwrap(); + let sampling = sampling.unwrap_or(crate::Sampling::Center); + Some(ResolvedInterpolation::from_binding(interpolation, sampling)) + }, + }) + } + LocationMode::Uniform => { + log::error!( + "Unexpected Binding::Location({}) for the Uniform mode", + location + ); + Err(Error::Validation) + } + }, + } + } + + fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> { + self.per_entry_point_map.get(&ep.name) + } + + fn get_resource_binding_target( + &self, + ep: &crate::EntryPoint, + res_binding: &crate::ResourceBinding, + ) -> Option<&BindTarget> { + self.get_entry_point_resources(ep) + .and_then(|res| res.resources.get(res_binding)) + } + + fn resolve_resource_binding( + &self, + ep: &crate::EntryPoint, + res_binding: &crate::ResourceBinding, + ) -> Result<ResolvedBinding, EntryPointError> { + let target = self.get_resource_binding_target(ep, res_binding); + match target { + Some(target) => Ok(ResolvedBinding::Resource(target.clone())), + None if self.fake_missing_bindings => Ok(ResolvedBinding::User { + prefix: "fake", + index: 0, + interpolation: None, + }), + None => Err(EntryPointError::MissingBindTarget(res_binding.clone())), + } + } + + fn resolve_push_constants( + &self, + ep: &crate::EntryPoint, + ) -> Result<ResolvedBinding, EntryPointError> { + let slot = self + .get_entry_point_resources(ep) + .and_then(|res| res.push_constant_buffer); + match slot { + Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { + buffer: Some(slot), + ..Default::default() + })), + None if self.fake_missing_bindings => Ok(ResolvedBinding::User { + prefix: "fake", + index: 0, + interpolation: None, + }), + None => Err(EntryPointError::MissingPushConstants), + } + } + + fn resolve_sizes_buffer( + &self, + ep: &crate::EntryPoint, + ) -> Result<ResolvedBinding, EntryPointError> { + let slot = self + .get_entry_point_resources(ep) + .and_then(|res| res.sizes_buffer); + match slot { + Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { + buffer: Some(slot), + ..Default::default() + })), + None if self.fake_missing_bindings => Ok(ResolvedBinding::User { + prefix: "fake", + index: 0, + interpolation: None, + }), + None => Err(EntryPointError::MissingSizesBuffer), + } + } +} + +impl ResolvedBinding { + fn as_inline_sampler<'a>(&self, options: &'a Options) -> Option<&'a sampler::InlineSampler> { + match *self { + Self::Resource(BindTarget { + sampler: Some(BindSamplerTarget::Inline(index)), + .. + }) => Some(&options.inline_samplers[index as usize]), + _ => None, + } + } + + const fn as_bind_target(&self) -> Option<&BindTarget> { + match *self { + Self::Resource(ref target) => Some(target), + _ => None, + } + } + + fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> { + write!(out, " [[")?; + match *self { + Self::BuiltIn(built_in) => { + use crate::BuiltIn as Bi; + let name = match built_in { + Bi::Position { invariant: false } => "position", + Bi::Position { invariant: true } => "position, invariant", + // vertex + Bi::BaseInstance => "base_instance", + Bi::BaseVertex => "base_vertex", + Bi::ClipDistance => "clip_distance", + Bi::InstanceIndex => "instance_id", + Bi::PointSize => "point_size", + Bi::VertexIndex => "vertex_id", + // fragment + Bi::FragDepth => "depth(any)", + Bi::PointCoord => "point_coord", + Bi::FrontFacing => "front_facing", + Bi::PrimitiveIndex => "primitive_id", + Bi::SampleIndex => "sample_id", + Bi::SampleMask => "sample_mask", + // compute + Bi::GlobalInvocationId => "thread_position_in_grid", + Bi::LocalInvocationId => "thread_position_in_threadgroup", + Bi::LocalInvocationIndex => "thread_index_in_threadgroup", + Bi::WorkGroupId => "threadgroup_position_in_grid", + Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", + Bi::NumWorkGroups => "threadgroups_per_grid", + Bi::CullDistance | Bi::ViewIndex => { + return Err(Error::UnsupportedBuiltIn(built_in)) + } + }; + write!(out, "{name}")?; + } + Self::Attribute(index) => write!(out, "attribute({index})")?, + Self::Color { + location, + second_blend_source, + } => { + if second_blend_source { + write!(out, "color({location}) index(1)")? + } else { + write!(out, "color({location})")? + } + } + Self::User { + prefix, + index, + interpolation, + } => { + write!(out, "user({prefix}{index})")?; + if let Some(interpolation) = interpolation { + write!(out, ", ")?; + interpolation.try_fmt(out)?; + } + } + Self::Resource(ref target) => { + if let Some(id) = target.buffer { + write!(out, "buffer({id})")?; + } else if let Some(id) = target.texture { + write!(out, "texture({id})")?; + } else if let Some(BindSamplerTarget::Resource(id)) = target.sampler { + write!(out, "sampler({id})")?; + } else { + return Err(Error::UnimplementedBindTarget(target.clone())); + } + } + } + write!(out, "]]")?; + Ok(()) + } +} + +impl ResolvedInterpolation { + const fn from_binding(interpolation: crate::Interpolation, sampling: crate::Sampling) -> Self { + use crate::Interpolation as I; + use crate::Sampling as S; + + match (interpolation, sampling) { + (I::Perspective, S::Center) => Self::CenterPerspective, + (I::Perspective, S::Centroid) => Self::CentroidPerspective, + (I::Perspective, S::Sample) => Self::SamplePerspective, + (I::Linear, S::Center) => Self::CenterNoPerspective, + (I::Linear, S::Centroid) => Self::CentroidNoPerspective, + (I::Linear, S::Sample) => Self::SampleNoPerspective, + (I::Flat, _) => Self::Flat, + } + } + + fn try_fmt<W: Write>(self, out: &mut W) -> Result<(), Error> { + let identifier = match self { + Self::CenterPerspective => "center_perspective", + Self::CenterNoPerspective => "center_no_perspective", + Self::CentroidPerspective => "centroid_perspective", + Self::CentroidNoPerspective => "centroid_no_perspective", + Self::SamplePerspective => "sample_perspective", + Self::SampleNoPerspective => "sample_no_perspective", + Self::Flat => "flat", + }; + out.write_str(identifier)?; + Ok(()) + } +} + +/// Information about a translated module that is required +/// for the use of the result. +pub struct TranslationInfo { + /// Mapping of the entry point names. Each item in the array + /// corresponds to an entry point index. + /// + ///Note: Some entry points may fail translation because of missing bindings. + pub entry_point_names: Vec<Result<String, EntryPointError>>, +} + +pub fn write_string( + module: &crate::Module, + info: &ModuleInfo, + options: &Options, + pipeline_options: &PipelineOptions, +) -> Result<(String, TranslationInfo), Error> { + let mut w = writer::Writer::new(String::new()); + let info = w.write(module, info, options, pipeline_options)?; + Ok((w.finish(), info)) +} + +#[test] +fn test_error_size() { + use std::mem::size_of; + assert_eq!(size_of::<Error>(), 32); +} diff --git a/third_party/rust/naga/src/back/msl/sampler.rs b/third_party/rust/naga/src/back/msl/sampler.rs new file mode 100644 index 0000000000..0bf987076d --- /dev/null +++ b/third_party/rust/naga/src/back/msl/sampler.rs @@ -0,0 +1,176 @@ +#[cfg(feature = "deserialize")] +use serde::Deserialize; +#[cfg(feature = "serialize")] +use serde::Serialize; +use std::{num::NonZeroU32, ops::Range}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum Coord { + Normalized, + Pixel, +} + +impl Default for Coord { + fn default() -> Self { + Self::Normalized + } +} + +impl Coord { + pub const fn as_str(&self) -> &'static str { + match *self { + Self::Normalized => "normalized", + Self::Pixel => "pixel", + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum Address { + Repeat, + MirroredRepeat, + ClampToEdge, + ClampToZero, + ClampToBorder, +} + +impl Default for Address { + fn default() -> Self { + Self::ClampToEdge + } +} + +impl Address { + pub const fn as_str(&self) -> &'static str { + match *self { + Self::Repeat => "repeat", + Self::MirroredRepeat => "mirrored_repeat", + Self::ClampToEdge => "clamp_to_edge", + Self::ClampToZero => "clamp_to_zero", + Self::ClampToBorder => "clamp_to_border", + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum BorderColor { + TransparentBlack, + OpaqueBlack, + OpaqueWhite, +} + +impl Default for BorderColor { + fn default() -> Self { + Self::TransparentBlack + } +} + +impl BorderColor { + pub const fn as_str(&self) -> &'static str { + match *self { + Self::TransparentBlack => "transparent_black", + Self::OpaqueBlack => "opaque_black", + Self::OpaqueWhite => "opaque_white", + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum Filter { + Nearest, + Linear, +} + +impl Filter { + pub const fn as_str(&self) -> &'static str { + match *self { + Self::Nearest => "nearest", + Self::Linear => "linear", + } + } +} + +impl Default for Filter { + fn default() -> Self { + Self::Nearest + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub enum CompareFunc { + Never, + Less, + LessEqual, + Greater, + GreaterEqual, + Equal, + NotEqual, + Always, +} + +impl Default for CompareFunc { + fn default() -> Self { + Self::Never + } +} + +impl CompareFunc { + pub const fn as_str(&self) -> &'static str { + match *self { + Self::Never => "never", + Self::Less => "less", + Self::LessEqual => "less_equal", + Self::Greater => "greater", + Self::GreaterEqual => "greater_equal", + Self::Equal => "equal", + Self::NotEqual => "not_equal", + Self::Always => "always", + } + } +} + +#[derive(Clone, Debug, Default, PartialEq)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +pub struct InlineSampler { + pub coord: Coord, + pub address: [Address; 3], + pub border_color: BorderColor, + pub mag_filter: Filter, + pub min_filter: Filter, + pub mip_filter: Option<Filter>, + pub lod_clamp: Option<Range<f32>>, + pub max_anisotropy: Option<NonZeroU32>, + pub compare_func: CompareFunc, +} + +impl Eq for InlineSampler {} + +#[allow(renamed_and_removed_lints)] +#[allow(clippy::derive_hash_xor_eq)] +impl std::hash::Hash for InlineSampler { + fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) { + self.coord.hash(hasher); + self.address.hash(hasher); + self.border_color.hash(hasher); + self.mag_filter.hash(hasher); + self.min_filter.hash(hasher); + self.mip_filter.hash(hasher); + self.lod_clamp + .as_ref() + .map(|range| (range.start.to_bits(), range.end.to_bits())) + .hash(hasher); + self.max_anisotropy.hash(hasher); + self.compare_func.hash(hasher); + } +} diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs new file mode 100644 index 0000000000..1e496b5f50 --- /dev/null +++ b/third_party/rust/naga/src/back/msl/writer.rs @@ -0,0 +1,4659 @@ +use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo}; +use crate::{ + arena::Handle, + back, + proc::index, + proc::{self, NameKey, TypeResolution}, + valid, FastHashMap, FastHashSet, +}; +use bit_set::BitSet; +use std::{ + fmt::{Display, Error as FmtError, Formatter, Write}, + iter, +}; + +/// Shorthand result used internally by the backend +type BackendResult = Result<(), Error>; + +const NAMESPACE: &str = "metal"; +// The name of the array member of the Metal struct types we generate to +// represent Naga `Array` types. See the comments in `Writer::write_type_defs` +// for details. +const WRAPPED_ARRAY_FIELD: &str = "inner"; +// This is a hack: we need to pass a pointer to an atomic, +// but generally the backend isn't putting "&" in front of every pointer. +// Some more general handling of pointers is needed to be implemented here. +const ATOMIC_REFERENCE: &str = "&"; + +const RT_NAMESPACE: &str = "metal::raytracing"; +const RAY_QUERY_TYPE: &str = "_RayQuery"; +const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector"; +const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection"; +const RAY_QUERY_FIELD_READY: &str = "ready"; +const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type"; + +pub(crate) const MODF_FUNCTION: &str = "naga_modf"; +pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; + +/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix. +/// +/// The `sizes` slice determines whether this function writes a +/// scalar, vector, or matrix type: +/// +/// - An empty slice produces a scalar type. +/// - A one-element slice produces a vector type. +/// - A two element slice `[ROWS COLUMNS]` produces a matrix of the given size. +fn put_numeric_type( + out: &mut impl Write, + scalar: crate::Scalar, + sizes: &[crate::VectorSize], +) -> Result<(), FmtError> { + match (scalar, sizes) { + (scalar, &[]) => { + write!(out, "{}", scalar.to_msl_name()) + } + (scalar, &[rows]) => { + write!( + out, + "{}::{}{}", + NAMESPACE, + scalar.to_msl_name(), + back::vector_size_str(rows) + ) + } + (scalar, &[rows, columns]) => { + write!( + out, + "{}::{}{}x{}", + NAMESPACE, + scalar.to_msl_name(), + back::vector_size_str(columns), + back::vector_size_str(rows) + ) + } + (_, _) => Ok(()), // not meaningful + } +} + +/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions. +const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e"; + +struct TypeContext<'a> { + handle: Handle<crate::Type>, + gctx: proc::GlobalCtx<'a>, + names: &'a FastHashMap<NameKey, String>, + access: crate::StorageAccess, + binding: Option<&'a super::ResolvedBinding>, + first_time: bool, +} + +impl<'a> Display for TypeContext<'a> { + fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> { + let ty = &self.gctx.types[self.handle]; + if ty.needs_alias() && !self.first_time { + let name = &self.names[&NameKey::Type(self.handle)]; + return write!(out, "{name}"); + } + + match ty.inner { + crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]), + crate::TypeInner::Atomic(scalar) => { + write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name()) + } + crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]), + crate::TypeInner::Matrix { columns, rows, .. } => { + put_numeric_type(out, crate::Scalar::F32, &[rows, columns]) + } + crate::TypeInner::Pointer { base, space } => { + let sub = Self { + handle: base, + first_time: false, + ..*self + }; + let space_name = match space.to_msl_name() { + Some(name) => name, + None => return Ok(()), + }; + write!(out, "{space_name} {sub}&") + } + crate::TypeInner::ValuePointer { + size, + scalar, + space, + } => { + match space.to_msl_name() { + Some(name) => write!(out, "{name} ")?, + None => return Ok(()), + }; + match size { + Some(rows) => put_numeric_type(out, scalar, &[rows])?, + None => put_numeric_type(out, scalar, &[])?, + }; + + write!(out, "&") + } + crate::TypeInner::Array { base, .. } => { + let sub = Self { + handle: base, + first_time: false, + ..*self + }; + // Array lengths go at the end of the type definition, + // so just print the element type here. + write!(out, "{sub}") + } + crate::TypeInner::Struct { .. } => unreachable!(), + crate::TypeInner::Image { + dim, + arrayed, + class, + } => { + let dim_str = match dim { + crate::ImageDimension::D1 => "1d", + crate::ImageDimension::D2 => "2d", + crate::ImageDimension::D3 => "3d", + crate::ImageDimension::Cube => "cube", + }; + let (texture_str, msaa_str, kind, access) = match class { + crate::ImageClass::Sampled { kind, multi } => { + let (msaa_str, access) = if multi { + ("_ms", "read") + } else { + ("", "sample") + }; + ("texture", msaa_str, kind, access) + } + crate::ImageClass::Depth { multi } => { + let (msaa_str, access) = if multi { + ("_ms", "read") + } else { + ("", "sample") + }; + ("depth", msaa_str, crate::ScalarKind::Float, access) + } + crate::ImageClass::Storage { format, .. } => { + let access = if self + .access + .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE) + { + "read_write" + } else if self.access.contains(crate::StorageAccess::STORE) { + "write" + } else if self.access.contains(crate::StorageAccess::LOAD) { + "read" + } else { + log::warn!( + "Storage access for {:?} (name '{}'): {:?}", + self.handle, + ty.name.as_deref().unwrap_or_default(), + self.access + ); + unreachable!("module is not valid"); + }; + ("texture", "", format.into(), access) + } + }; + let base_name = crate::Scalar { kind, width: 4 }.to_msl_name(); + let array_str = if arrayed { "_array" } else { "" }; + write!( + out, + "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>", + ) + } + crate::TypeInner::Sampler { comparison: _ } => { + write!(out, "{NAMESPACE}::sampler") + } + crate::TypeInner::AccelerationStructure => { + write!(out, "{RT_NAMESPACE}::instance_acceleration_structure") + } + crate::TypeInner::RayQuery => { + write!(out, "{RAY_QUERY_TYPE}") + } + crate::TypeInner::BindingArray { base, size } => { + let base_tyname = Self { + handle: base, + first_time: false, + ..*self + }; + + if let Some(&super::ResolvedBinding::Resource(super::BindTarget { + binding_array_size: Some(override_size), + .. + })) = self.binding + { + write!(out, "{NAMESPACE}::array<{base_tyname}, {override_size}>") + } else if let crate::ArraySize::Constant(size) = size { + write!(out, "{NAMESPACE}::array<{base_tyname}, {size}>") + } else { + unreachable!("metal requires all arrays be constant sized"); + } + } + } + } +} + +struct TypedGlobalVariable<'a> { + module: &'a crate::Module, + names: &'a FastHashMap<NameKey, String>, + handle: Handle<crate::GlobalVariable>, + usage: valid::GlobalUse, + binding: Option<&'a super::ResolvedBinding>, + reference: bool, +} + +impl<'a> TypedGlobalVariable<'a> { + fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult { + let var = &self.module.global_variables[self.handle]; + let name = &self.names[&NameKey::GlobalVariable(self.handle)]; + + let storage_access = match var.space { + crate::AddressSpace::Storage { access } => access, + _ => match self.module.types[var.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. }, + .. + } => access, + crate::TypeInner::BindingArray { base, .. } => { + match self.module.types[base].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. }, + .. + } => access, + _ => crate::StorageAccess::default(), + } + } + _ => crate::StorageAccess::default(), + }, + }; + let ty_name = TypeContext { + handle: var.ty, + gctx: self.module.to_ctx(), + names: self.names, + access: storage_access, + binding: self.binding, + first_time: false, + }; + + let (space, access, reference) = match var.space.to_msl_name() { + Some(space) if self.reference => { + let access = if var.space.needs_access_qualifier() + && !self.usage.contains(valid::GlobalUse::WRITE) + { + "const" + } else { + "" + }; + (space, access, "&") + } + _ => ("", "", ""), + }; + + Ok(write!( + out, + "{}{}{}{}{}{} {}", + space, + if space.is_empty() { "" } else { " " }, + ty_name, + if access.is_empty() { "" } else { " " }, + access, + reference, + name, + )?) + } +} + +pub struct Writer<W> { + out: W, + names: FastHashMap<NameKey, String>, + named_expressions: crate::NamedExpressions, + /// Set of expressions that need to be baked to avoid unnecessary repetition in output + need_bake_expressions: back::NeedBakeExpressions, + namer: proc::Namer, + #[cfg(test)] + put_expression_stack_pointers: FastHashSet<*const ()>, + #[cfg(test)] + put_block_stack_pointers: FastHashSet<*const ()>, + /// Set of (struct type, struct field index) denoting which fields require + /// padding inserted **before** them (i.e. between fields at index - 1 and index) + struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>, +} + +impl crate::Scalar { + const fn to_msl_name(self) -> &'static str { + use crate::ScalarKind as Sk; + match self { + Self { + kind: Sk::Float, + width: _, + } => "float", + Self { + kind: Sk::Sint, + width: _, + } => "int", + Self { + kind: Sk::Uint, + width: _, + } => "uint", + Self { + kind: Sk::Bool, + width: _, + } => "bool", + Self { + kind: Sk::AbstractInt | Sk::AbstractFloat, + width: _, + } => unreachable!(), + } + } +} + +const fn separate(need_separator: bool) -> &'static str { + if need_separator { + "," + } else { + "" + } +} + +fn should_pack_struct_member( + members: &[crate::StructMember], + span: u32, + index: usize, + module: &crate::Module, +) -> Option<crate::Scalar> { + let member = &members[index]; + + let ty_inner = &module.types[member.ty].inner; + let last_offset = member.offset + ty_inner.size(module.to_ctx()); + let next_offset = match members.get(index + 1) { + Some(next) => next.offset, + None => span, + }; + let is_tight = next_offset == last_offset; + + match *ty_inner { + crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + scalar: scalar @ crate::Scalar { width: 4, .. }, + } if is_tight => Some(scalar), + _ => None, + } +} + +fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool { + match arena[ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + if let Some(member) = members.last() { + if let crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } = arena[member.ty].inner + { + return true; + } + } + false + } + crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } => true, + _ => false, + } +} + +impl crate::AddressSpace { + /// Returns true if global variables in this address space are + /// passed in function arguments. These arguments need to be + /// passed through any functions called from the entry point. + const fn needs_pass_through(&self) -> bool { + match *self { + Self::Uniform + | Self::Storage { .. } + | Self::Private + | Self::WorkGroup + | Self::PushConstant + | Self::Handle => true, + Self::Function => false, + } + } + + /// Returns true if the address space may need a "const" qualifier. + const fn needs_access_qualifier(&self) -> bool { + match *self { + //Note: we are ignoring the storage access here, and instead + // rely on the actual use of a global by functions. This means we + // may end up with "const" even if the binding is read-write, + // and that should be OK. + Self::Storage { .. } => true, + // These should always be read-write. + Self::Private | Self::WorkGroup => false, + // These translate to `constant` address space, no need for qualifiers. + Self::Uniform | Self::PushConstant => false, + // Not applicable. + Self::Handle | Self::Function => false, + } + } + + const fn to_msl_name(self) -> Option<&'static str> { + match self { + Self::Handle => None, + Self::Uniform | Self::PushConstant => Some("constant"), + Self::Storage { .. } => Some("device"), + Self::Private | Self::Function => Some("thread"), + Self::WorkGroup => Some("threadgroup"), + } + } +} + +impl crate::Type { + // Returns `true` if we need to emit an alias for this type. + const fn needs_alias(&self) -> bool { + use crate::TypeInner as Ti; + + match self.inner { + // value types are concise enough, we only alias them if they are named + Ti::Scalar(_) + | Ti::Vector { .. } + | Ti::Matrix { .. } + | Ti::Atomic(_) + | Ti::Pointer { .. } + | Ti::ValuePointer { .. } => self.name.is_some(), + // composite types are better to be aliased, regardless of the name + Ti::Struct { .. } | Ti::Array { .. } => true, + // handle types may be different, depending on the global var access, so we always inline them + Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => false, + } + } +} + +enum FunctionOrigin { + Handle(Handle<crate::Function>), + EntryPoint(proc::EntryPointIndex), +} + +/// A level of detail argument. +/// +/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we +/// save the clamped level of detail in a temporary variable whose name is based +/// on the handle of the `ImageLoad` expression. But for other policies, we just +/// use the expression directly. +/// +/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict +/// [`ImageLoad`]: crate::Expression::ImageLoad +#[derive(Clone, Copy)] +enum LevelOfDetail { + Direct(Handle<crate::Expression>), + Restricted(Handle<crate::Expression>), +} + +/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`]. +/// +/// When this is used in code paths unconcerned with the `Restrict` bounds check +/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level` +/// will always be either `None` or `Some(Direct(_))`. But this turns out not to +/// be too awkward. If that changes, we can revisit. +/// +/// [`ImageLoad`]: crate::Expression::ImageLoad +/// [`ImageStore`]: crate::Statement::ImageStore +struct TexelAddress { + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + sample: Option<Handle<crate::Expression>>, + level: Option<LevelOfDetail>, +} + +struct ExpressionContext<'a> { + function: &'a crate::Function, + origin: FunctionOrigin, + info: &'a valid::FunctionInfo, + module: &'a crate::Module, + mod_info: &'a valid::ModuleInfo, + pipeline_options: &'a PipelineOptions, + lang_version: (u8, u8), + policies: index::BoundsCheckPolicies, + + /// A bitset containing the `Expression` handle indexes of expressions used + /// as indices in `ReadZeroSkipWrite`-policy accesses. These may need to be + /// cached in temporary variables. See `index::find_checked_indexes` for + /// details. + guarded_indices: BitSet, +} + +impl<'a> ExpressionContext<'a> { + fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner { + self.info[handle].ty.inner_with(&self.module.types) + } + + /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail. + /// + /// Only mipmapped images need to specify a level of detail. Since 1D + /// textures cannot have mipmaps, MSL requires that the level argument to + /// texture1d queries and accesses must be a constexpr 0. It's easiest + /// just to omit the level entirely for 1D textures. + fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool { + let image_ty = self.resolve_type(image); + if let crate::TypeInner::Image { dim, class, .. } = *image_ty { + class.is_mipmapped() && dim != crate::ImageDimension::D1 + } else { + false + } + } + + fn choose_bounds_check_policy( + &self, + pointer: Handle<crate::Expression>, + ) -> index::BoundsCheckPolicy { + self.policies + .choose_policy(pointer, &self.module.types, self.info) + } + + fn access_needs_check( + &self, + base: Handle<crate::Expression>, + index: index::GuardedIndex, + ) -> Option<index::IndexableLength> { + index::access_needs_check(base, index, self.module, self.function, self.info) + } + + fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> { + match self.function.expressions[expr_handle] { + crate::Expression::AccessIndex { base, index } => { + let ty = match *self.resolve_type(base) { + crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner, + ref ty => ty, + }; + match *ty { + crate::TypeInner::Struct { + ref members, span, .. + } => should_pack_struct_member(members, span, index as usize, self.module), + _ => None, + } + } + _ => None, + } + } +} + +struct StatementContext<'a> { + expression: ExpressionContext<'a>, + result_struct: Option<&'a str>, +} + +impl<W: Write> Writer<W> { + /// Creates a new `Writer` instance. + pub fn new(out: W) -> Self { + Writer { + out, + names: FastHashMap::default(), + named_expressions: Default::default(), + need_bake_expressions: Default::default(), + namer: proc::Namer::default(), + #[cfg(test)] + put_expression_stack_pointers: Default::default(), + #[cfg(test)] + put_block_stack_pointers: Default::default(), + struct_member_pads: FastHashSet::default(), + } + } + + /// Finishes writing and returns the output. + // See https://github.com/rust-lang/rust-clippy/issues/4979. + #[allow(clippy::missing_const_for_fn)] + pub fn finish(self) -> W { + self.out + } + + fn put_call_parameters( + &mut self, + parameters: impl Iterator<Item = Handle<crate::Expression>>, + context: &ExpressionContext, + ) -> BackendResult { + self.put_call_parameters_impl(parameters, context, |writer, context, expr| { + writer.put_expression(expr, context, true) + }) + } + + fn put_call_parameters_impl<C, E>( + &mut self, + parameters: impl Iterator<Item = Handle<crate::Expression>>, + ctx: &C, + put_expression: E, + ) -> BackendResult + where + E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult, + { + write!(self.out, "(")?; + for (i, handle) in parameters.enumerate() { + if i != 0 { + write!(self.out, ", ")?; + } + put_expression(self, ctx, handle)?; + } + write!(self.out, ")")?; + Ok(()) + } + + fn put_level_of_detail( + &mut self, + level: LevelOfDetail, + context: &ExpressionContext, + ) -> BackendResult { + match level { + LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?, + LevelOfDetail::Restricted(load) => { + write!(self.out, "{}{}", CLAMPED_LOD_LOAD_PREFIX, load.index())? + } + } + Ok(()) + } + + fn put_image_query( + &mut self, + image: Handle<crate::Expression>, + query: &str, + level: Option<LevelOfDetail>, + context: &ExpressionContext, + ) -> BackendResult { + self.put_expression(image, context, false)?; + write!(self.out, ".get_{query}(")?; + if let Some(level) = level { + self.put_level_of_detail(level, context)?; + } + write!(self.out, ")")?; + Ok(()) + } + + fn put_image_size_query( + &mut self, + image: Handle<crate::Expression>, + level: Option<LevelOfDetail>, + kind: crate::ScalarKind, + context: &ExpressionContext, + ) -> BackendResult { + //Note: MSL only has separate width/height/depth queries, + // so compose the result of them. + let dim = match *context.resolve_type(image) { + crate::TypeInner::Image { dim, .. } => dim, + ref other => unreachable!("Unexpected type {:?}", other), + }; + let scalar = crate::Scalar { kind, width: 4 }; + let coordinate_type = scalar.to_msl_name(); + match dim { + crate::ImageDimension::D1 => { + // Since 1D textures never have mipmaps, MSL requires that the + // `level` argument be a constexpr 0. It's simplest for us just + // to pass `None` and omit the level entirely. + if kind == crate::ScalarKind::Uint { + // No need to construct a vector. No cast needed. + self.put_image_query(image, "width", None, context)?; + } else { + // There's no definition for `int` in the `metal` namespace. + write!(self.out, "int(")?; + self.put_image_query(image, "width", None, context)?; + write!(self.out, ")")?; + } + } + crate::ImageDimension::D2 => { + write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?; + self.put_image_query(image, "width", level, context)?; + write!(self.out, ", ")?; + self.put_image_query(image, "height", level, context)?; + write!(self.out, ")")?; + } + crate::ImageDimension::D3 => { + write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?; + self.put_image_query(image, "width", level, context)?; + write!(self.out, ", ")?; + self.put_image_query(image, "height", level, context)?; + write!(self.out, ", ")?; + self.put_image_query(image, "depth", level, context)?; + write!(self.out, ")")?; + } + crate::ImageDimension::Cube => { + write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?; + self.put_image_query(image, "width", level, context)?; + write!(self.out, ")")?; + } + } + Ok(()) + } + + fn put_cast_to_uint_scalar_or_vector( + &mut self, + expr: Handle<crate::Expression>, + context: &ExpressionContext, + ) -> BackendResult { + // coordinates in IR are int, but Metal expects uint + match *context.resolve_type(expr) { + crate::TypeInner::Scalar(_) => { + put_numeric_type(&mut self.out, crate::Scalar::U32, &[])? + } + crate::TypeInner::Vector { size, .. } => { + put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])? + } + _ => return Err(Error::Validation), + }; + + write!(self.out, "(")?; + self.put_expression(expr, context, true)?; + write!(self.out, ")")?; + Ok(()) + } + + fn put_image_sample_level( + &mut self, + image: Handle<crate::Expression>, + level: crate::SampleLevel, + context: &ExpressionContext, + ) -> BackendResult { + let has_levels = context.image_needs_lod(image); + match level { + crate::SampleLevel::Auto => {} + crate::SampleLevel::Zero => { + //TODO: do we support Zero on `Sampled` image classes? + } + _ if !has_levels => { + log::warn!("1D image can't be sampled with level {:?}", level); + } + crate::SampleLevel::Exact(h) => { + write!(self.out, ", {NAMESPACE}::level(")?; + self.put_expression(h, context, true)?; + write!(self.out, ")")?; + } + crate::SampleLevel::Bias(h) => { + write!(self.out, ", {NAMESPACE}::bias(")?; + self.put_expression(h, context, true)?; + write!(self.out, ")")?; + } + crate::SampleLevel::Gradient { x, y } => { + write!(self.out, ", {NAMESPACE}::gradient2d(")?; + self.put_expression(x, context, true)?; + write!(self.out, ", ")?; + self.put_expression(y, context, true)?; + write!(self.out, ")")?; + } + } + Ok(()) + } + + fn put_image_coordinate_limits( + &mut self, + image: Handle<crate::Expression>, + level: Option<LevelOfDetail>, + context: &ExpressionContext, + ) -> BackendResult { + self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?; + write!(self.out, " - 1")?; + Ok(()) + } + + /// General function for writing restricted image indexes. + /// + /// This is used to produce restricted mip levels, array indices, and sample + /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the + /// [`Restrict`] bounds check policy. + /// + /// This function writes an expression of the form: + /// + /// ```ignore + /// + /// metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1) + /// + /// ``` + /// + /// [`ImageLoad`]: crate::Expression::ImageLoad + /// [`ImageStore`]: crate::Statement::ImageStore + /// [`Restrict`]: index::BoundsCheckPolicy::Restrict + fn put_restricted_scalar_image_index( + &mut self, + image: Handle<crate::Expression>, + index: Handle<crate::Expression>, + limit_method: &str, + context: &ExpressionContext, + ) -> BackendResult { + write!(self.out, "{NAMESPACE}::min(uint(")?; + self.put_expression(index, context, true)?; + write!(self.out, "), ")?; + self.put_expression(image, context, false)?; + write!(self.out, ".{limit_method}() - 1)")?; + Ok(()) + } + + fn put_restricted_texel_address( + &mut self, + image: Handle<crate::Expression>, + address: &TexelAddress, + context: &ExpressionContext, + ) -> BackendResult { + // Write the coordinate. + write!(self.out, "{NAMESPACE}::min(")?; + self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?; + write!(self.out, ", ")?; + self.put_image_coordinate_limits(image, address.level, context)?; + write!(self.out, ")")?; + + // Write the array index, if present. + if let Some(array_index) = address.array_index { + write!(self.out, ", ")?; + self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?; + } + + // Write the sample index, if present. + if let Some(sample) = address.sample { + write!(self.out, ", ")?; + self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?; + } + + // The level of detail should be clamped and cached by + // `put_cache_restricted_level`, so we don't need to clamp it here. + if let Some(level) = address.level { + write!(self.out, ", ")?; + self.put_level_of_detail(level, context)?; + } + + Ok(()) + } + + /// Write an expression that is true if the given image access is in bounds. + fn put_image_access_bounds_check( + &mut self, + image: Handle<crate::Expression>, + address: &TexelAddress, + context: &ExpressionContext, + ) -> BackendResult { + let mut conjunction = ""; + + // First, check the level of detail. Only if that is in bounds can we + // use it to find the appropriate bounds for the coordinates. + let level = if let Some(level) = address.level { + write!(self.out, "uint(")?; + self.put_level_of_detail(level, context)?; + write!(self.out, ") < ")?; + self.put_expression(image, context, true)?; + write!(self.out, ".get_num_mip_levels()")?; + conjunction = " && "; + Some(level) + } else { + None + }; + + // Check sample index, if present. + if let Some(sample) = address.sample { + write!(self.out, "uint(")?; + self.put_expression(sample, context, true)?; + write!(self.out, ") < ")?; + self.put_expression(image, context, true)?; + write!(self.out, ".get_num_samples()")?; + conjunction = " && "; + } + + // Check array index, if present. + if let Some(array_index) = address.array_index { + write!(self.out, "{conjunction}uint(")?; + self.put_expression(array_index, context, true)?; + write!(self.out, ") < ")?; + self.put_expression(image, context, true)?; + write!(self.out, ".get_array_size()")?; + conjunction = " && "; + } + + // Finally, check if the coordinates are within bounds. + let coord_is_vector = match *context.resolve_type(address.coordinate) { + crate::TypeInner::Vector { .. } => true, + _ => false, + }; + write!(self.out, "{conjunction}")?; + if coord_is_vector { + write!(self.out, "{NAMESPACE}::all(")?; + } + self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?; + write!(self.out, " < ")?; + self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?; + if coord_is_vector { + write!(self.out, ")")?; + } + + Ok(()) + } + + fn put_image_load( + &mut self, + load: Handle<crate::Expression>, + image: Handle<crate::Expression>, + mut address: TexelAddress, + context: &ExpressionContext, + ) -> BackendResult { + match context.policies.image_load { + proc::BoundsCheckPolicy::Restrict => { + // Use the cached restricted level of detail, if any. Omit the + // level altogether for 1D textures. + if address.level.is_some() { + address.level = if context.image_needs_lod(image) { + Some(LevelOfDetail::Restricted(load)) + } else { + None + } + } + + self.put_expression(image, context, false)?; + write!(self.out, ".read(")?; + self.put_restricted_texel_address(image, &address, context)?; + write!(self.out, ")")?; + } + proc::BoundsCheckPolicy::ReadZeroSkipWrite => { + write!(self.out, "(")?; + self.put_image_access_bounds_check(image, &address, context)?; + write!(self.out, " ? ")?; + self.put_unchecked_image_load(image, &address, context)?; + write!(self.out, ": DefaultConstructible())")?; + } + proc::BoundsCheckPolicy::Unchecked => { + self.put_unchecked_image_load(image, &address, context)?; + } + } + + Ok(()) + } + + fn put_unchecked_image_load( + &mut self, + image: Handle<crate::Expression>, + address: &TexelAddress, + context: &ExpressionContext, + ) -> BackendResult { + self.put_expression(image, context, false)?; + write!(self.out, ".read(")?; + // coordinates in IR are int, but Metal expects uint + self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?; + if let Some(expr) = address.array_index { + write!(self.out, ", ")?; + self.put_expression(expr, context, true)?; + } + if let Some(sample) = address.sample { + write!(self.out, ", ")?; + self.put_expression(sample, context, true)?; + } + if let Some(level) = address.level { + if context.image_needs_lod(image) { + write!(self.out, ", ")?; + self.put_level_of_detail(level, context)?; + } + } + write!(self.out, ")")?; + + Ok(()) + } + + fn put_image_store( + &mut self, + level: back::Level, + image: Handle<crate::Expression>, + address: &TexelAddress, + value: Handle<crate::Expression>, + context: &StatementContext, + ) -> BackendResult { + match context.expression.policies.image_store { + proc::BoundsCheckPolicy::Restrict => { + // We don't have a restricted level value, because we don't + // support writes to mipmapped textures. + debug_assert!(address.level.is_none()); + + write!(self.out, "{level}")?; + self.put_expression(image, &context.expression, false)?; + write!(self.out, ".write(")?; + self.put_expression(value, &context.expression, true)?; + write!(self.out, ", ")?; + self.put_restricted_texel_address(image, address, &context.expression)?; + writeln!(self.out, ");")?; + } + proc::BoundsCheckPolicy::ReadZeroSkipWrite => { + write!(self.out, "{level}if (")?; + self.put_image_access_bounds_check(image, address, &context.expression)?; + writeln!(self.out, ") {{")?; + self.put_unchecked_image_store(level.next(), image, address, value, context)?; + writeln!(self.out, "{level}}}")?; + } + proc::BoundsCheckPolicy::Unchecked => { + self.put_unchecked_image_store(level, image, address, value, context)?; + } + } + + Ok(()) + } + + fn put_unchecked_image_store( + &mut self, + level: back::Level, + image: Handle<crate::Expression>, + address: &TexelAddress, + value: Handle<crate::Expression>, + context: &StatementContext, + ) -> BackendResult { + write!(self.out, "{level}")?; + self.put_expression(image, &context.expression, false)?; + write!(self.out, ".write(")?; + self.put_expression(value, &context.expression, true)?; + write!(self.out, ", ")?; + // coordinates in IR are int, but Metal expects uint + self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?; + if let Some(expr) = address.array_index { + write!(self.out, ", ")?; + self.put_expression(expr, &context.expression, true)?; + } + writeln!(self.out, ");")?; + + Ok(()) + } + + /// Write the maximum valid index of the dynamically sized array at the end of `handle`. + /// + /// The 'maximum valid index' is simply one less than the array's length. + /// + /// This emits an expression of the form `a / b`, so the caller must + /// parenthesize its output if it will be applying operators of higher + /// precedence. + /// + /// `handle` must be the handle of a global variable whose final member is a + /// dynamically sized array. + fn put_dynamic_array_max_index( + &mut self, + handle: Handle<crate::GlobalVariable>, + context: &ExpressionContext, + ) -> BackendResult { + let global = &context.module.global_variables[handle]; + let (offset, array_ty) = match context.module.types[global.ty].inner { + crate::TypeInner::Struct { ref members, .. } => match members.last() { + Some(&crate::StructMember { offset, ty, .. }) => (offset, ty), + None => return Err(Error::Validation), + }, + crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } => (0, global.ty), + _ => return Err(Error::Validation), + }; + + let (size, stride) = match context.module.types[array_ty].inner { + crate::TypeInner::Array { base, stride, .. } => ( + context.module.types[base] + .inner + .size(context.module.to_ctx()), + stride, + ), + _ => return Err(Error::Validation), + }; + + // When the stride length is larger than the size, the final element's stride of + // bytes would have padding following the value. But the buffer size in + // `buffer_sizes.sizeN` may not include this padding - it only needs to be large + // enough to hold the actual values' bytes. + // + // So subtract off the size to get a byte size that falls at the start or within + // the final element. Then divide by the stride size, to get one less than the + // length, and then add one. This works even if the buffer size does include the + // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail + // if there are zero elements in the array, but the WebGPU `validating shader binding` + // rules, together with draw-time validation when `minBindingSize` is zero, + // prevent that. + write!( + self.out, + "(_buffer_sizes.size{idx} - {offset} - {size}) / {stride}", + idx = handle.index(), + offset = offset, + size = size, + stride = stride, + )?; + Ok(()) + } + + fn put_atomic_fetch( + &mut self, + pointer: Handle<crate::Expression>, + key: &str, + value: Handle<crate::Expression>, + context: &ExpressionContext, + ) -> BackendResult { + self.put_atomic_operation(pointer, "fetch_", key, value, context) + } + + fn put_atomic_operation( + &mut self, + pointer: Handle<crate::Expression>, + key1: &str, + key2: &str, + value: Handle<crate::Expression>, + context: &ExpressionContext, + ) -> BackendResult { + // If the pointer we're passing to the atomic operation needs to be conditional + // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and + // the pointer operand should be unchecked. + let policy = context.choose_bounds_check_policy(pointer); + let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite + && self.put_bounds_checks(pointer, context, back::Level(0), "")?; + + // If requested and successfully put bounds checks, continue the ternary expression. + if checked { + write!(self.out, " ? ")?; + } + + write!( + self.out, + "{NAMESPACE}::atomic_{key1}{key2}_explicit({ATOMIC_REFERENCE}" + )?; + self.put_access_chain(pointer, policy, context)?; + write!(self.out, ", ")?; + self.put_expression(value, context, true)?; + write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?; + + // Finish the ternary expression. + if checked { + write!(self.out, " : DefaultConstructible()")?; + } + + Ok(()) + } + + /// Emit code for the arithmetic expression of the dot product. + /// + fn put_dot_product( + &mut self, + arg: Handle<crate::Expression>, + arg1: Handle<crate::Expression>, + size: usize, + context: &ExpressionContext, + ) -> BackendResult { + // Write parentheses around the dot product expression to prevent operators + // with different precedences from applying earlier. + write!(self.out, "(")?; + + // Cycle trough all the components of the vector + for index in 0..size { + let component = back::COMPONENTS[index]; + // Write the addition to the previous product + // This will print an extra '+' at the beginning but that is fine in msl + write!(self.out, " + ")?; + // Write the first vector expression, this expression is marked to be + // cached so unless it can't be cached (for example, it's a Constant) + // it shouldn't produce large expressions. + self.put_expression(arg, context, true)?; + // Access the current component on the first vector + write!(self.out, ".{component} * ")?; + // Write the second vector expression, this expression is marked to be + // cached so unless it can't be cached (for example, it's a Constant) + // it shouldn't produce large expressions. + self.put_expression(arg1, context, true)?; + // Access the current component on the second vector + write!(self.out, ".{component}")?; + } + + write!(self.out, ")")?; + Ok(()) + } + + /// Emit code for the sign(i32) expression. + /// + fn put_isign( + &mut self, + arg: Handle<crate::Expression>, + context: &ExpressionContext, + ) -> BackendResult { + write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?; + match context.resolve_type(arg) { + &crate::TypeInner::Vector { size, .. } => { + let size = back::vector_size_str(size); + write!(self.out, "int{size}(-1), int{size}(1)")?; + } + _ => { + write!(self.out, "-1, 1")?; + } + } + write!(self.out, ", (")?; + self.put_expression(arg, context, true)?; + write!(self.out, " > 0)), 0, (")?; + self.put_expression(arg, context, true)?; + write!(self.out, " == 0))")?; + Ok(()) + } + + fn put_const_expression( + &mut self, + expr_handle: Handle<crate::Expression>, + module: &crate::Module, + mod_info: &valid::ModuleInfo, + ) -> BackendResult { + self.put_possibly_const_expression( + expr_handle, + &module.const_expressions, + module, + mod_info, + &(module, mod_info), + |&(_, mod_info), expr| &mod_info[expr], + |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info), + ) + } + + #[allow(clippy::too_many_arguments)] + fn put_possibly_const_expression<C, I, E>( + &mut self, + expr_handle: Handle<crate::Expression>, + expressions: &crate::Arena<crate::Expression>, + module: &crate::Module, + mod_info: &valid::ModuleInfo, + ctx: &C, + get_expr_ty: I, + put_expression: E, + ) -> BackendResult + where + I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution, + E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult, + { + match expressions[expr_handle] { + crate::Expression::Literal(literal) => match literal { + crate::Literal::F64(_) => { + return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64)) + } + crate::Literal::F32(value) => { + if value.is_infinite() { + let sign = if value.is_sign_negative() { "-" } else { "" }; + write!(self.out, "{sign}INFINITY")?; + } else if value.is_nan() { + write!(self.out, "NAN")?; + } else { + let suffix = if value.fract() == 0.0 { ".0" } else { "" }; + write!(self.out, "{value}{suffix}")?; + } + } + crate::Literal::U32(value) => { + write!(self.out, "{value}u")?; + } + crate::Literal::I32(value) => { + write!(self.out, "{value}")?; + } + crate::Literal::I64(value) => { + write!(self.out, "{value}L")?; + } + crate::Literal::Bool(value) => { + write!(self.out, "{value}")?; + } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Validation); + } + }, + crate::Expression::Constant(handle) => { + let constant = &module.constants[handle]; + if constant.name.is_some() { + write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; + } else { + self.put_const_expression(constant.init, module, mod_info)?; + } + } + crate::Expression::ZeroValue(ty) => { + let ty_name = TypeContext { + handle: ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + write!(self.out, "{ty_name} {{}}")?; + } + crate::Expression::Compose { ty, ref components } => { + let ty_name = TypeContext { + handle: ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + write!(self.out, "{ty_name}")?; + match module.types[ty].inner { + crate::TypeInner::Scalar(_) + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } => { + self.put_call_parameters_impl( + components.iter().copied(), + ctx, + put_expression, + )?; + } + crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => { + write!(self.out, " {{")?; + for (index, &component) in components.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + // insert padding initialization, if needed + if self.struct_member_pads.contains(&(ty, index as u32)) { + write!(self.out, "{{}}, ")?; + } + put_expression(self, ctx, component)?; + } + write!(self.out, "}}")?; + } + _ => return Err(Error::UnsupportedCompose(ty)), + } + } + crate::Expression::Splat { size, value } => { + let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) { + crate::TypeInner::Scalar(scalar) => scalar, + _ => return Err(Error::Validation), + }; + put_numeric_type(&mut self.out, scalar, &[size])?; + write!(self.out, "(")?; + put_expression(self, ctx, value)?; + write!(self.out, ")")?; + } + _ => unreachable!(), + } + + Ok(()) + } + + /// Emit code for the expression `expr_handle`. + /// + /// The `is_scoped` argument is true if the surrounding operators have the + /// precedence of the comma operator, or lower. So, for example: + /// + /// - Pass `true` for `is_scoped` when writing function arguments, an + /// expression statement, an initializer expression, or anything already + /// wrapped in parenthesis. + /// + /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really + /// almost anything else. + fn put_expression( + &mut self, + expr_handle: Handle<crate::Expression>, + context: &ExpressionContext, + is_scoped: bool, + ) -> BackendResult { + // Add to the set in order to track the stack size. + #[cfg(test)] + #[allow(trivial_casts)] + self.put_expression_stack_pointers + .insert(&expr_handle as *const _ as *const ()); + + if let Some(name) = self.named_expressions.get(&expr_handle) { + write!(self.out, "{name}")?; + return Ok(()); + } + + let expression = &context.function.expressions[expr_handle]; + log::trace!("expression {:?} = {:?}", expr_handle, expression); + match *expression { + crate::Expression::Literal(_) + | crate::Expression::Constant(_) + | crate::Expression::ZeroValue(_) + | crate::Expression::Compose { .. } + | crate::Expression::Splat { .. } => { + self.put_possibly_const_expression( + expr_handle, + &context.function.expressions, + context.module, + context.mod_info, + context, + |context, expr: Handle<crate::Expression>| &context.info[expr].ty, + |writer, context, expr| writer.put_expression(expr, context, true), + )?; + } + crate::Expression::Access { base, .. } + | crate::Expression::AccessIndex { base, .. } => { + // This is an acceptable place to generate a `ReadZeroSkipWrite` check. + // Since `put_bounds_checks` and `put_access_chain` handle an entire + // access chain at a time, recursing back through `put_expression` only + // for index expressions and the base object, we will never see intermediate + // `Access` or `AccessIndex` expressions here. + let policy = context.choose_bounds_check_policy(base); + if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite + && self.put_bounds_checks( + expr_handle, + context, + back::Level(0), + if is_scoped { "" } else { "(" }, + )? + { + write!(self.out, " ? ")?; + self.put_access_chain(expr_handle, policy, context)?; + write!(self.out, " : DefaultConstructible()")?; + + if !is_scoped { + write!(self.out, ")")?; + } + } else { + self.put_access_chain(expr_handle, policy, context)?; + } + } + crate::Expression::Swizzle { + size, + vector, + pattern, + } => { + self.put_wrapped_expression_for_packed_vec3_access(vector, context, false)?; + write!(self.out, ".")?; + for &sc in pattern[..size as usize].iter() { + write!(self.out, "{}", back::COMPONENTS[sc as usize])?; + } + } + crate::Expression::FunctionArgument(index) => { + let name_key = match context.origin { + FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index), + FunctionOrigin::EntryPoint(ep_index) => { + NameKey::EntryPointArgument(ep_index, index) + } + }; + let name = &self.names[&name_key]; + write!(self.out, "{name}")?; + } + crate::Expression::GlobalVariable(handle) => { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{name}")?; + } + crate::Expression::LocalVariable(handle) => { + let name_key = match context.origin { + FunctionOrigin::Handle(fun_handle) => { + NameKey::FunctionLocal(fun_handle, handle) + } + FunctionOrigin::EntryPoint(ep_index) => { + NameKey::EntryPointLocal(ep_index, handle) + } + }; + let name = &self.names[&name_key]; + write!(self.out, "{name}")?; + } + crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?, + crate::Expression::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + let main_op = match gather { + Some(_) => "gather", + None => "sample", + }; + let comparison_op = match depth_ref { + Some(_) => "_compare", + None => "", + }; + self.put_expression(image, context, false)?; + write!(self.out, ".{main_op}{comparison_op}(")?; + self.put_expression(sampler, context, true)?; + write!(self.out, ", ")?; + self.put_expression(coordinate, context, true)?; + if let Some(expr) = array_index { + write!(self.out, ", ")?; + self.put_expression(expr, context, true)?; + } + if let Some(dref) = depth_ref { + write!(self.out, ", ")?; + self.put_expression(dref, context, true)?; + } + + self.put_image_sample_level(image, level, context)?; + + if let Some(offset) = offset { + write!(self.out, ", ")?; + self.put_const_expression(offset, context.module, context.mod_info)?; + } + + match gather { + None | Some(crate::SwizzleComponent::X) => {} + Some(component) => { + let is_cube_map = match *context.resolve_type(image) { + crate::TypeInner::Image { + dim: crate::ImageDimension::Cube, + .. + } => true, + _ => false, + }; + // Offset always comes before the gather, except + // in cube maps where it's not applicable + if offset.is_none() && !is_cube_map { + write!(self.out, ", {NAMESPACE}::int2(0)")?; + } + let letter = back::COMPONENTS[component as usize]; + write!(self.out, ", {NAMESPACE}::component::{letter}")?; + } + } + write!(self.out, ")")?; + } + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + let address = TexelAddress { + coordinate, + array_index, + sample, + level: level.map(LevelOfDetail::Direct), + }; + self.put_image_load(expr_handle, image, address, context)?; + } + //Note: for all the queries, the signed integers are expected, + // so a conversion is needed. + crate::Expression::ImageQuery { image, query } => match query { + crate::ImageQuery::Size { level } => { + self.put_image_size_query( + image, + level.map(LevelOfDetail::Direct), + crate::ScalarKind::Uint, + context, + )?; + } + crate::ImageQuery::NumLevels => { + self.put_expression(image, context, false)?; + write!(self.out, ".get_num_mip_levels()")?; + } + crate::ImageQuery::NumLayers => { + self.put_expression(image, context, false)?; + write!(self.out, ".get_array_size()")?; + } + crate::ImageQuery::NumSamples => { + self.put_expression(image, context, false)?; + write!(self.out, ".get_num_samples()")?; + } + }, + crate::Expression::Unary { op, expr } => { + let op_str = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::LogicalNot => "!", + crate::UnaryOperator::BitwiseNot => "~", + }; + write!(self.out, "{op_str}(")?; + self.put_expression(expr, context, false)?; + write!(self.out, ")")?; + } + crate::Expression::Binary { op, left, right } => { + let op_str = crate::back::binary_operation_str(op); + let kind = context + .resolve_type(left) + .scalar_kind() + .ok_or(Error::UnsupportedBinaryOp(op))?; + + // TODO: handle undefined behavior of BinaryOperator::Modulo + // + // sint: + // if right == 0 return 0 + // if left == min(type_of(left)) && right == -1 return 0 + // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL + // + // uint: + // if right == 0 return 0 + // + // float: + // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798 + + if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float { + write!(self.out, "{NAMESPACE}::fmod(")?; + self.put_expression(left, context, true)?; + write!(self.out, ", ")?; + self.put_expression(right, context, true)?; + write!(self.out, ")")?; + } else { + if !is_scoped { + write!(self.out, "(")?; + } + + // Cast packed vector if necessary + // Packed vector - matrix multiplications are not supported in MSL + if op == crate::BinaryOperator::Multiply + && matches!( + context.resolve_type(right), + &crate::TypeInner::Matrix { .. } + ) + { + self.put_wrapped_expression_for_packed_vec3_access(left, context, false)?; + } else { + self.put_expression(left, context, false)?; + } + + write!(self.out, " {op_str} ")?; + + // See comment above + if op == crate::BinaryOperator::Multiply + && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. }) + { + self.put_wrapped_expression_for_packed_vec3_access(right, context, false)?; + } else { + self.put_expression(right, context, false)?; + } + + if !is_scoped { + write!(self.out, ")")?; + } + } + } + crate::Expression::Select { + condition, + accept, + reject, + } => match *context.resolve_type(condition) { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + .. + }) => { + if !is_scoped { + write!(self.out, "(")?; + } + self.put_expression(condition, context, false)?; + write!(self.out, " ? ")?; + self.put_expression(accept, context, false)?; + write!(self.out, " : ")?; + self.put_expression(reject, context, false)?; + if !is_scoped { + write!(self.out, ")")?; + } + } + crate::TypeInner::Vector { + scalar: + crate::Scalar { + kind: crate::ScalarKind::Bool, + .. + }, + .. + } => { + write!(self.out, "{NAMESPACE}::select(")?; + self.put_expression(reject, context, true)?; + write!(self.out, ", ")?; + self.put_expression(accept, context, true)?; + write!(self.out, ", ")?; + self.put_expression(condition, context, true)?; + write!(self.out, ")")?; + } + _ => return Err(Error::Validation), + }, + crate::Expression::Derivative { axis, expr, .. } => { + use crate::DerivativeAxis as Axis; + let op = match axis { + Axis::X => "dfdx", + Axis::Y => "dfdy", + Axis::Width => "fwidth", + }; + write!(self.out, "{NAMESPACE}::{op}")?; + self.put_call_parameters(iter::once(expr), context)?; + } + crate::Expression::Relational { fun, argument } => { + let op = match fun { + crate::RelationalFunction::Any => "any", + crate::RelationalFunction::All => "all", + crate::RelationalFunction::IsNan => "isnan", + crate::RelationalFunction::IsInf => "isinf", + }; + write!(self.out, "{NAMESPACE}::{op}")?; + self.put_call_parameters(iter::once(argument), context)?; + } + crate::Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + + let arg_type = context.resolve_type(arg); + let scalar_argument = match arg_type { + &crate::TypeInner::Scalar(_) => true, + _ => false, + }; + + let fun_name = match fun { + // comparison + Mf::Abs => "abs", + Mf::Min => "min", + Mf::Max => "max", + Mf::Clamp => "clamp", + Mf::Saturate => "saturate", + // trigonometry + Mf::Cos => "cos", + Mf::Cosh => "cosh", + Mf::Sin => "sin", + Mf::Sinh => "sinh", + Mf::Tan => "tan", + Mf::Tanh => "tanh", + Mf::Acos => "acos", + Mf::Asin => "asin", + Mf::Atan => "atan", + Mf::Atan2 => "atan2", + Mf::Asinh => "asinh", + Mf::Acosh => "acosh", + Mf::Atanh => "atanh", + Mf::Radians => "", + Mf::Degrees => "", + // decomposition + Mf::Ceil => "ceil", + Mf::Floor => "floor", + Mf::Round => "rint", + Mf::Fract => "fract", + Mf::Trunc => "trunc", + Mf::Modf => MODF_FUNCTION, + Mf::Frexp => FREXP_FUNCTION, + Mf::Ldexp => "ldexp", + // exponent + Mf::Exp => "exp", + Mf::Exp2 => "exp2", + Mf::Log => "log", + Mf::Log2 => "log2", + Mf::Pow => "pow", + // geometry + Mf::Dot => match *context.resolve_type(arg) { + crate::TypeInner::Vector { + scalar: + crate::Scalar { + kind: crate::ScalarKind::Float, + .. + }, + .. + } => "dot", + crate::TypeInner::Vector { size, .. } => { + return self.put_dot_product(arg, arg1.unwrap(), size as usize, context) + } + _ => unreachable!( + "Correct TypeInner for dot product should be already validated" + ), + }, + Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))), + Mf::Cross => "cross", + Mf::Distance => "distance", + Mf::Length if scalar_argument => "abs", + Mf::Length => "length", + Mf::Normalize => "normalize", + Mf::FaceForward => "faceforward", + Mf::Reflect => "reflect", + Mf::Refract => "refract", + // computational + Mf::Sign => match arg_type.scalar_kind() { + Some(crate::ScalarKind::Sint) => { + return self.put_isign(arg, context); + } + _ => "sign", + }, + Mf::Fma => "fma", + Mf::Mix => "mix", + Mf::Step => "step", + Mf::SmoothStep => "smoothstep", + Mf::Sqrt => "sqrt", + Mf::InverseSqrt => "rsqrt", + Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))), + Mf::Transpose => "transpose", + Mf::Determinant => "determinant", + // bits + Mf::CountTrailingZeros => "ctz", + Mf::CountLeadingZeros => "clz", + Mf::CountOneBits => "popcount", + Mf::ReverseBits => "reverse_bits", + Mf::ExtractBits => "extract_bits", + Mf::InsertBits => "insert_bits", + Mf::FindLsb => "", + Mf::FindMsb => "", + // data packing + Mf::Pack4x8snorm => "pack_float_to_snorm4x8", + Mf::Pack4x8unorm => "pack_float_to_unorm4x8", + Mf::Pack2x16snorm => "pack_float_to_snorm2x16", + Mf::Pack2x16unorm => "pack_float_to_unorm2x16", + Mf::Pack2x16float => "", + // data unpacking + Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float", + Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float", + Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float", + Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float", + Mf::Unpack2x16float => "", + }; + + match fun { + Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => { + // reverse_bits is listed as requiring MSL 2.1 but that + // is a copy/paste error. Looking at previous snapshots + // on web.archive.org it's present in MSL 1.2. + // + // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html + // also talks about MSL 1.2 adding "New integer + // functions to extract, insert, and reverse bits, as + // described in Integer Functions." + if context.lang_version < (1, 2) { + return Err(Error::UnsupportedFunction(fun_name.to_string())); + } + } + _ => {} + } + + if fun == Mf::Distance && scalar_argument { + write!(self.out, "{NAMESPACE}::abs(")?; + self.put_expression(arg, context, false)?; + write!(self.out, " - ")?; + self.put_expression(arg1.unwrap(), context, false)?; + write!(self.out, ")")?; + } else if fun == Mf::FindLsb { + write!(self.out, "((({NAMESPACE}::ctz(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ") + 1) % 33) - 1)")?; + } else if fun == Mf::FindMsb { + let inner = context.resolve_type(arg); + + write!(self.out, "{NAMESPACE}::select(31 - {NAMESPACE}::clz(")?; + + if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { + write!(self.out, "{NAMESPACE}::select(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ~")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " < 0)")?; + } else { + self.put_expression(arg, context, true)?; + } + + write!(self.out, "), ")?; + + // or metal will complain that select is ambiguous + match *inner { + crate::TypeInner::Vector { size, scalar } => { + let size = back::vector_size_str(size); + if let crate::ScalarKind::Sint = scalar.kind { + write!(self.out, "int{size}")?; + } else { + write!(self.out, "uint{size}")?; + } + } + crate::TypeInner::Scalar(scalar) => { + if let crate::ScalarKind::Sint = scalar.kind { + write!(self.out, "int")?; + } else { + write!(self.out, "uint")?; + } + } + _ => (), + } + + write!(self.out, "(-1), ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " == 0 || ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " == -1)")?; + } else if fun == Mf::Unpack2x16float { + write!(self.out, "float2(as_type<half2>(")?; + self.put_expression(arg, context, false)?; + write!(self.out, "))")?; + } else if fun == Mf::Pack2x16float { + write!(self.out, "as_type<uint>(half2(")?; + self.put_expression(arg, context, false)?; + write!(self.out, "))")?; + } else if fun == Mf::Radians { + write!(self.out, "((")?; + self.put_expression(arg, context, false)?; + write!(self.out, ") * 0.017453292519943295474)")?; + } else if fun == Mf::Degrees { + write!(self.out, "((")?; + self.put_expression(arg, context, false)?; + write!(self.out, ") * 57.295779513082322865)")?; + } else if fun == Mf::Modf || fun == Mf::Frexp { + write!(self.out, "{fun_name}")?; + self.put_call_parameters(iter::once(arg), context)?; + } else { + write!(self.out, "{NAMESPACE}::{fun_name}")?; + self.put_call_parameters( + iter::once(arg).chain(arg1).chain(arg2).chain(arg3), + context, + )?; + } + } + crate::Expression::As { + expr, + kind, + convert, + } => match *context.resolve_type(expr) { + crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => { + let target_scalar = crate::Scalar { + kind, + width: convert.unwrap_or(src.width), + }; + let is_bool_cast = + kind == crate::ScalarKind::Bool || src.kind == crate::ScalarKind::Bool; + let op = match convert { + Some(w) if w == src.width || is_bool_cast => "static_cast", + Some(8) if kind == crate::ScalarKind::Float => { + return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64)) + } + Some(_) => return Err(Error::Validation), + None => "as_type", + }; + write!(self.out, "{op}<")?; + match *context.resolve_type(expr) { + crate::TypeInner::Vector { size, .. } => { + put_numeric_type(&mut self.out, target_scalar, &[size])? + } + _ => put_numeric_type(&mut self.out, target_scalar, &[])?, + }; + write!(self.out, ">(")?; + self.put_expression(expr, context, true)?; + write!(self.out, ")")?; + } + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => { + let target_scalar = crate::Scalar { + kind, + width: convert.unwrap_or(scalar.width), + }; + put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?; + write!(self.out, "(")?; + self.put_expression(expr, context, true)?; + write!(self.out, ")")?; + } + _ => return Err(Error::Validation), + }, + // has to be a named expression + crate::Expression::CallResult(_) + | crate::Expression::AtomicResult { .. } + | crate::Expression::WorkGroupUniformLoadResult { .. } + | crate::Expression::RayQueryProceedResult => { + unreachable!() + } + crate::Expression::ArrayLength(expr) => { + // Find the global to which the array belongs. + let global = match context.function.expressions[expr] { + crate::Expression::AccessIndex { base, .. } => { + match context.function.expressions[base] { + crate::Expression::GlobalVariable(handle) => handle, + _ => return Err(Error::Validation), + } + } + crate::Expression::GlobalVariable(handle) => handle, + _ => return Err(Error::Validation), + }; + + if !is_scoped { + write!(self.out, "(")?; + } + write!(self.out, "1 + ")?; + self.put_dynamic_array_max_index(global, context)?; + if !is_scoped { + write!(self.out, ")")?; + } + } + crate::Expression::RayQueryGetIntersection { query, committed } => { + if context.lang_version < (2, 4) { + return Err(Error::UnsupportedRayTracing); + } + + if !committed { + unimplemented!() + } + let ty = context.module.special_types.ray_intersection.unwrap(); + let type_name = &self.names[&NameKey::Type(ty)]; + write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?; + self.put_expression(query, context, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?; + let fields = [ + "distance", + "user_instance_id", // req Metal 2.4 + "instance_id", + "", // SBT offset + "geometry_id", + "primitive_id", + "triangle_barycentric_coord", + "triangle_front_facing", + "", // padding + "object_to_world_transform", // req Metal 2.4 + "world_to_object_transform", // req Metal 2.4 + ]; + for field in fields { + write!(self.out, ", ")?; + if field.is_empty() { + write!(self.out, "{{}}")?; + } else { + self.put_expression(query, context, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?; + } + } + write!(self.out, "}}")?; + } + } + Ok(()) + } + + /// Used by expressions like Swizzle and Binary since they need packed_vec3's to be casted to a vec3 + fn put_wrapped_expression_for_packed_vec3_access( + &mut self, + expr_handle: Handle<crate::Expression>, + context: &ExpressionContext, + is_scoped: bool, + ) -> BackendResult { + if let Some(scalar) = context.get_packed_vec_kind(expr_handle) { + write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?; + self.put_expression(expr_handle, context, is_scoped)?; + write!(self.out, ")")?; + } else { + self.put_expression(expr_handle, context, is_scoped)?; + } + Ok(()) + } + + /// Write a `GuardedIndex` as a Metal expression. + fn put_index( + &mut self, + index: index::GuardedIndex, + context: &ExpressionContext, + is_scoped: bool, + ) -> BackendResult { + match index { + index::GuardedIndex::Expression(expr) => { + self.put_expression(expr, context, is_scoped)? + } + index::GuardedIndex::Known(value) => write!(self.out, "{value}")?, + } + Ok(()) + } + + /// Emit an index bounds check condition for `chain`, if required. + /// + /// `chain` is a subtree of `Access` and `AccessIndex` expressions, + /// operating either on a pointer to a value, or on a value directly. If we cannot + /// statically determine that all indexing operations in `chain` are within + /// bounds, then write a conditional expression to check them dynamically, + /// and return true. All accesses in the chain are checked by the generated + /// expression. + /// + /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`]. + /// + /// The text written is of the form: + /// + /// ```ignore + /// {level}{prefix}uint(i) < 4 && uint(j) < 10 + /// ``` + /// + /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`] + /// statements, presumably these arguments start an indented `if` statement; for + /// [`Load`] expressions, the caller is probably building up a ternary `?:` + /// expression. In either case, what is written is not a complete syntactic structure + /// in its own right, and the caller will have to finish it off if we return `true`. + /// + /// If no expression is written, return false. + /// + /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy + /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite + /// [`Store`]: crate::Statement::Store + /// [`Load`]: crate::Expression::Load + #[allow(unused_variables)] + fn put_bounds_checks( + &mut self, + mut chain: Handle<crate::Expression>, + context: &ExpressionContext, + level: back::Level, + prefix: &'static str, + ) -> Result<bool, Error> { + let mut check_written = false; + + // Iterate over the access chain, handling each expression. + loop { + // Produce a `GuardedIndex`, so we can shared code between the + // `Access` and `AccessIndex` cases. + let (base, guarded_index) = match context.function.expressions[chain] { + crate::Expression::Access { base, index } => { + (base, Some(index::GuardedIndex::Expression(index))) + } + crate::Expression::AccessIndex { base, index } => { + // Don't try to check indices into structs. Validation already took + // care of them, and index::needs_guard doesn't handle that case. + let mut base_inner = context.resolve_type(base); + if let crate::TypeInner::Pointer { base, .. } = *base_inner { + base_inner = &context.module.types[base].inner; + } + match *base_inner { + crate::TypeInner::Struct { .. } => (base, None), + _ => (base, Some(index::GuardedIndex::Known(index))), + } + } + _ => break, + }; + + if let Some(index) = guarded_index { + if let Some(length) = context.access_needs_check(base, index) { + if check_written { + write!(self.out, " && ")?; + } else { + write!(self.out, "{level}{prefix}")?; + check_written = true; + } + + // Check that the index falls within bounds. Do this with a single + // comparison, by casting the index to `uint` first, so that negative + // indices become large positive values. + write!(self.out, "uint(")?; + self.put_index(index, context, true)?; + self.out.write_str(") < ")?; + match length { + index::IndexableLength::Known(value) => write!(self.out, "{value}")?, + index::IndexableLength::Dynamic => { + let global = context + .function + .originating_global(base) + .ok_or(Error::Validation)?; + write!(self.out, "1 + ")?; + self.put_dynamic_array_max_index(global, context)? + } + } + } + } + + chain = base + } + + Ok(check_written) + } + + /// Write the access chain `chain`. + /// + /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions, + /// operating either on a pointer to a value, or on a value directly. + /// + /// Generate bounds checks code only if `policy` is [`Restrict`]. The + /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so + /// that must be handled in the caller. + /// + /// Handle the entire chain, recursing back into `put_expression` only for index + /// expressions and the base expression that originates the pointer or composite value + /// being accessed. This allows `put_expression` to assume that any `Access` or + /// `AccessIndex` expressions it sees are the top of a chain, so it can emit + /// `ReadZeroSkipWrite` checks. + /// + /// [`Access`]: crate::Expression::Access + /// [`AccessIndex`]: crate::Expression::AccessIndex + /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict + /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite + fn put_access_chain( + &mut self, + chain: Handle<crate::Expression>, + policy: index::BoundsCheckPolicy, + context: &ExpressionContext, + ) -> BackendResult { + match context.function.expressions[chain] { + crate::Expression::Access { base, index } => { + let mut base_ty = context.resolve_type(base); + + // Look through any pointers to see what we're really indexing. + if let crate::TypeInner::Pointer { base, space: _ } = *base_ty { + base_ty = &context.module.types[base].inner; + } + + self.put_subscripted_access_chain( + base, + base_ty, + index::GuardedIndex::Expression(index), + policy, + context, + )?; + } + crate::Expression::AccessIndex { base, index } => { + let base_resolution = &context.info[base].ty; + let mut base_ty = base_resolution.inner_with(&context.module.types); + let mut base_ty_handle = base_resolution.handle(); + + // Look through any pointers to see what we're really indexing. + if let crate::TypeInner::Pointer { base, space: _ } = *base_ty { + base_ty = &context.module.types[base].inner; + base_ty_handle = Some(base); + } + + // Handle structs and anything else that can use `.x` syntax here, so + // `put_subscripted_access_chain` won't have to handle the absurd case of + // indexing a struct with an expression. + match *base_ty { + crate::TypeInner::Struct { .. } => { + let base_ty = base_ty_handle.unwrap(); + self.put_access_chain(base, policy, context)?; + let name = &self.names[&NameKey::StructMember(base_ty, index)]; + write!(self.out, ".{name}")?; + } + crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => { + self.put_access_chain(base, policy, context)?; + // Prior to Metal v2.1 component access for packed vectors wasn't available + // however array indexing is + if context.get_packed_vec_kind(base).is_some() { + write!(self.out, "[{index}]")?; + } else { + write!(self.out, ".{}", back::COMPONENTS[index as usize])?; + } + } + _ => { + self.put_subscripted_access_chain( + base, + base_ty, + index::GuardedIndex::Known(index), + policy, + context, + )?; + } + } + } + _ => self.put_expression(chain, context, false)?, + } + + Ok(()) + } + + /// Write a `[]`-style access of `base` by `index`. + /// + /// If `policy` is [`Restrict`], then generate code as needed to force all index + /// values within bounds. + /// + /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or + /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`] + /// removed. Our callers often already have this handy. + /// + /// This only emits `[]` expressions; it doesn't handle struct member accesses or + /// referencing vector components by name. + /// + /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict + /// [`Array`]: crate::TypeInner::Array + /// [`Vector`]: crate::TypeInner::Vector + /// [`Pointer`]: crate::TypeInner::Pointer + fn put_subscripted_access_chain( + &mut self, + base: Handle<crate::Expression>, + base_ty: &crate::TypeInner, + index: index::GuardedIndex, + policy: index::BoundsCheckPolicy, + context: &ExpressionContext, + ) -> BackendResult { + let accessing_wrapped_array = match *base_ty { + crate::TypeInner::Array { + size: crate::ArraySize::Constant(_), + .. + } => true, + _ => false, + }; + + self.put_access_chain(base, policy, context)?; + if accessing_wrapped_array { + write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?; + } + write!(self.out, "[")?; + + // Decide whether this index needs to be clamped to fall within range. + let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict { + context.access_needs_check(base, index) + } else { + None + }; + if let Some(limit) = restriction_needed { + write!(self.out, "{NAMESPACE}::min(unsigned(")?; + self.put_index(index, context, true)?; + write!(self.out, "), ")?; + match limit { + index::IndexableLength::Known(limit) => { + write!(self.out, "{}u", limit - 1)?; + } + index::IndexableLength::Dynamic => { + let global = context + .function + .originating_global(base) + .ok_or(Error::Validation)?; + self.put_dynamic_array_max_index(global, context)?; + } + } + write!(self.out, ")")?; + } else { + self.put_index(index, context, true)?; + } + + write!(self.out, "]")?; + + Ok(()) + } + + fn put_load( + &mut self, + pointer: Handle<crate::Expression>, + context: &ExpressionContext, + is_scoped: bool, + ) -> BackendResult { + // Since access chains never cross between address spaces, we can just + // check the index bounds check policy once at the top. + let policy = context.choose_bounds_check_policy(pointer); + if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite + && self.put_bounds_checks( + pointer, + context, + back::Level(0), + if is_scoped { "" } else { "(" }, + )? + { + write!(self.out, " ? ")?; + self.put_unchecked_load(pointer, policy, context)?; + write!(self.out, " : DefaultConstructible()")?; + + if !is_scoped { + write!(self.out, ")")?; + } + } else { + self.put_unchecked_load(pointer, policy, context)?; + } + + Ok(()) + } + + fn put_unchecked_load( + &mut self, + pointer: Handle<crate::Expression>, + policy: index::BoundsCheckPolicy, + context: &ExpressionContext, + ) -> BackendResult { + let is_atomic_pointer = context + .resolve_type(pointer) + .is_atomic_pointer(&context.module.types); + + if is_atomic_pointer { + write!( + self.out, + "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}" + )?; + self.put_access_chain(pointer, policy, context)?; + write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?; + } else { + // We don't do any dereferencing with `*` here as pointer arguments to functions + // are done by `&` references and not `*` pointers. These do not need to be + // dereferenced. + self.put_access_chain(pointer, policy, context)?; + } + + Ok(()) + } + + fn put_return_value( + &mut self, + level: back::Level, + expr_handle: Handle<crate::Expression>, + result_struct: Option<&str>, + context: &ExpressionContext, + ) -> BackendResult { + match result_struct { + Some(struct_name) => { + let mut has_point_size = false; + let result_ty = context.function.result.as_ref().unwrap().ty; + match context.module.types[result_ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + let tmp = "_tmp"; + write!(self.out, "{level}const auto {tmp} = ")?; + self.put_expression(expr_handle, context, true)?; + writeln!(self.out, ";")?; + write!(self.out, "{level}return {struct_name} {{")?; + + let mut is_first = true; + + for (index, member) in members.iter().enumerate() { + if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) = + member.binding + { + has_point_size = true; + if !context.pipeline_options.allow_and_force_point_size { + continue; + } + } + + let comma = if is_first { "" } else { "," }; + is_first = false; + let name = &self.names[&NameKey::StructMember(result_ty, index as u32)]; + // HACK: we are forcefully deduplicating the expression here + // to convert from a wrapped struct to a raw array, e.g. + // `float gl_ClipDistance1 [[clip_distance]] [1];`. + if let crate::TypeInner::Array { + size: crate::ArraySize::Constant(size), + .. + } = context.module.types[member.ty].inner + { + write!(self.out, "{comma} {{")?; + for j in 0..size.get() { + if j != 0 { + write!(self.out, ",")?; + } + write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?; + } + write!(self.out, "}}")?; + } else { + write!(self.out, "{comma} {tmp}.{name}")?; + } + } + } + _ => { + write!(self.out, "{level}return {struct_name} {{ ")?; + self.put_expression(expr_handle, context, true)?; + } + } + + if let FunctionOrigin::EntryPoint(ep_index) = context.origin { + let stage = context.module.entry_points[ep_index as usize].stage; + if context.pipeline_options.allow_and_force_point_size + && stage == crate::ShaderStage::Vertex + && !has_point_size + { + // point size was injected and comes last + write!(self.out, ", 1.0")?; + } + } + write!(self.out, " }}")?; + } + None => { + write!(self.out, "{level}return ")?; + self.put_expression(expr_handle, context, true)?; + } + } + writeln!(self.out, ";")?; + Ok(()) + } + + /// Helper method used to find which expressions of a given function require baking + /// + /// # Notes + /// This function overwrites the contents of `self.need_bake_expressions` + fn update_expressions_to_bake( + &mut self, + func: &crate::Function, + info: &valid::FunctionInfo, + context: &ExpressionContext, + ) { + use crate::Expression; + self.need_bake_expressions.clear(); + + for (expr_handle, expr) in func.expressions.iter() { + // Expressions whose reference count is above the + // threshold should always be stored in temporaries. + let expr_info = &info[expr_handle]; + let min_ref_count = func.expressions[expr_handle].bake_ref_count(); + if min_ref_count <= expr_info.ref_count { + self.need_bake_expressions.insert(expr_handle); + } else { + match expr_info.ty { + // force ray desc to be baked: it's used multiple times internally + TypeResolution::Handle(h) + if Some(h) == context.module.special_types.ray_desc => + { + self.need_bake_expressions.insert(expr_handle); + } + _ => {} + } + } + + if let Expression::Math { fun, arg, arg1, .. } = *expr { + match fun { + crate::MathFunction::Dot => { + // WGSL's `dot` function works on any `vecN` type, but Metal's only + // works on floating-point vectors, so we emit inline code for + // integer vector `dot` calls. But that code uses each argument `N` + // times, once for each component (see `put_dot_product`), so to + // avoid duplicated evaluation, we must bake integer operands. + + // check what kind of product this is depending + // on the resolve type of the Dot function itself + let inner = context.resolve_type(expr_handle); + if let crate::TypeInner::Scalar(scalar) = *inner { + match scalar.kind { + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + self.need_bake_expressions.insert(arg); + self.need_bake_expressions.insert(arg1.unwrap()); + } + _ => {} + } + } + } + crate::MathFunction::FindMsb => { + self.need_bake_expressions.insert(arg); + } + crate::MathFunction::Sign => { + // WGSL's `sign` function works also on signed ints, but Metal's only + // works on floating points, so we emit inline code for integer `sign` + // calls. But that code uses each argument 2 times (see `put_isign`), + // so to avoid duplicated evaluation, we must bake the argument. + let inner = context.resolve_type(expr_handle); + if inner.scalar_kind() == Some(crate::ScalarKind::Sint) { + self.need_bake_expressions.insert(arg); + } + } + _ => {} + } + } + } + } + + fn start_baking_expression( + &mut self, + handle: Handle<crate::Expression>, + context: &ExpressionContext, + name: &str, + ) -> BackendResult { + match context.info[handle].ty { + TypeResolution::Handle(ty_handle) => { + let ty_name = TypeContext { + handle: ty_handle, + gctx: context.module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + write!(self.out, "{ty_name}")?; + } + TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => { + put_numeric_type(&mut self.out, scalar, &[])?; + } + TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => { + put_numeric_type(&mut self.out, scalar, &[size])?; + } + TypeResolution::Value(crate::TypeInner::Matrix { + columns, + rows, + scalar, + }) => { + put_numeric_type(&mut self.out, scalar, &[rows, columns])?; + } + TypeResolution::Value(ref other) => { + log::warn!("Type {:?} isn't a known local", other); //TEMP! + return Err(Error::FeatureNotImplemented("weird local type".to_string())); + } + } + + //TODO: figure out the naming scheme that wouldn't collide with user names. + write!(self.out, " {name} = ")?; + + Ok(()) + } + + /// Cache a clamped level of detail value, if necessary. + /// + /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a + /// properly clamped level of detail value both in the access itself, and + /// for fetching the size of the requested MIP level, needed to clamp the + /// coordinates. To avoid recomputing this clamped level of detail, we cache + /// it in a temporary variable, as part of the [`Emit`] statement covering + /// the [`ImageLoad`] expression. + /// + /// [`ImageLoad`]: crate::Expression::ImageLoad + /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict + /// [`Emit`]: crate::Statement::Emit + fn put_cache_restricted_level( + &mut self, + load: Handle<crate::Expression>, + image: Handle<crate::Expression>, + mip_level: Option<Handle<crate::Expression>>, + indent: back::Level, + context: &StatementContext, + ) -> BackendResult { + // Does this image access actually require (or even permit) a + // level-of-detail, and does the policy require us to restrict it? + let level_of_detail = match mip_level { + Some(level) => level, + None => return Ok(()), + }; + + if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict + || !context.expression.image_needs_lod(image) + { + return Ok(()); + } + + write!( + self.out, + "{}uint {}{} = ", + indent, + CLAMPED_LOD_LOAD_PREFIX, + load.index(), + )?; + self.put_restricted_scalar_image_index( + image, + level_of_detail, + "get_num_mip_levels", + &context.expression, + )?; + writeln!(self.out, ";")?; + + Ok(()) + } + + fn put_block( + &mut self, + level: back::Level, + statements: &[crate::Statement], + context: &StatementContext, + ) -> BackendResult { + // Add to the set in order to track the stack size. + #[cfg(test)] + #[allow(trivial_casts)] + self.put_block_stack_pointers + .insert(&level as *const _ as *const ()); + + for statement in statements { + log::trace!("statement[{}] {:?}", level.0, statement); + match *statement { + crate::Statement::Emit(ref range) => { + for handle in range.clone() { + // `ImageLoad` expressions covered by the `Restrict` bounds check policy + // may need to cache a clamped version of their level-of-detail argument. + if let crate::Expression::ImageLoad { + image, + level: mip_level, + .. + } = context.expression.function.expressions[handle] + { + self.put_cache_restricted_level( + handle, image, mip_level, level, context, + )?; + } + + let ptr_class = context.expression.resolve_type(handle).pointer_space(); + let expr_name = if ptr_class.is_some() { + None // don't bake pointer expressions (just yet) + } else if let Some(name) = + context.expression.function.named_expressions.get(&handle) + { + // The `crate::Function::named_expressions` table holds + // expressions that should be saved in temporaries once they + // are `Emit`ted. We only add them to `self.named_expressions` + // when we reach the `Emit` that covers them, so that we don't + // try to use their names before we've actually initialized + // the temporary that holds them. + // + // Don't assume the names in `named_expressions` are unique, + // or even valid. Use the `Namer`. + Some(self.namer.call(name)) + } else { + // If this expression is an index that we're going to first compare + // against a limit, and then actually use as an index, then we may + // want to cache it in a temporary, to avoid evaluating it twice. + let bake = + if context.expression.guarded_indices.contains(handle.index()) { + true + } else { + self.need_bake_expressions.contains(&handle) + }; + + if bake { + Some(format!("{}{}", back::BAKE_PREFIX, handle.index())) + } else { + None + } + }; + + if let Some(name) = expr_name { + write!(self.out, "{level}")?; + self.start_baking_expression(handle, &context.expression, &name)?; + self.put_expression(handle, &context.expression, true)?; + self.named_expressions.insert(handle, name); + writeln!(self.out, ";")?; + } + } + } + crate::Statement::Block(ref block) => { + if !block.is_empty() { + writeln!(self.out, "{level}{{")?; + self.put_block(level.next(), block, context)?; + writeln!(self.out, "{level}}}")?; + } + } + crate::Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{level}if (")?; + self.put_expression(condition, &context.expression, true)?; + writeln!(self.out, ") {{")?; + self.put_block(level.next(), accept, context)?; + if !reject.is_empty() { + writeln!(self.out, "{level}}} else {{")?; + self.put_block(level.next(), reject, context)?; + } + writeln!(self.out, "{level}}}")?; + } + crate::Statement::Switch { + selector, + ref cases, + } => { + write!(self.out, "{level}switch(")?; + self.put_expression(selector, &context.expression, true)?; + writeln!(self.out, ") {{")?; + let lcase = level.next(); + for case in cases.iter() { + match case.value { + crate::SwitchValue::I32(value) => { + write!(self.out, "{lcase}case {value}:")?; + } + crate::SwitchValue::U32(value) => { + write!(self.out, "{lcase}case {value}u:")?; + } + crate::SwitchValue::Default => { + write!(self.out, "{lcase}default:")?; + } + } + + let write_block_braces = !(case.fall_through && case.body.is_empty()); + if write_block_braces { + writeln!(self.out, " {{")?; + } else { + writeln!(self.out)?; + } + + self.put_block(lcase.next(), &case.body, context)?; + if !case.fall_through + && case.body.last().map_or(true, |s| !s.is_terminator()) + { + writeln!(self.out, "{}break;", lcase.next())?; + } + + if write_block_braces { + writeln!(self.out, "{lcase}}}")?; + } + } + writeln!(self.out, "{level}}}")?; + } + crate::Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + if !continuing.is_empty() || break_if.is_some() { + let gate_name = self.namer.call("loop_init"); + writeln!(self.out, "{level}bool {gate_name} = true;")?; + writeln!(self.out, "{level}while(true) {{")?; + let lif = level.next(); + let lcontinuing = lif.next(); + writeln!(self.out, "{lif}if (!{gate_name}) {{")?; + self.put_block(lcontinuing, continuing, context)?; + if let Some(condition) = break_if { + write!(self.out, "{lcontinuing}if (")?; + self.put_expression(condition, &context.expression, true)?; + writeln!(self.out, ") {{")?; + writeln!(self.out, "{}break;", lcontinuing.next())?; + writeln!(self.out, "{lcontinuing}}}")?; + } + writeln!(self.out, "{lif}}}")?; + writeln!(self.out, "{lif}{gate_name} = false;")?; + } else { + writeln!(self.out, "{level}while(true) {{")?; + } + self.put_block(level.next(), body, context)?; + writeln!(self.out, "{level}}}")?; + } + crate::Statement::Break => { + writeln!(self.out, "{level}break;")?; + } + crate::Statement::Continue => { + writeln!(self.out, "{level}continue;")?; + } + crate::Statement::Return { + value: Some(expr_handle), + } => { + self.put_return_value( + level, + expr_handle, + context.result_struct, + &context.expression, + )?; + } + crate::Statement::Return { value: None } => { + writeln!(self.out, "{level}return;")?; + } + crate::Statement::Kill => { + writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?; + } + crate::Statement::Barrier(flags) => { + self.write_barrier(flags, level)?; + } + crate::Statement::Store { pointer, value } => { + self.put_store(pointer, value, level, context)? + } + crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + let address = TexelAddress { + coordinate, + array_index, + sample: None, + level: None, + }; + self.put_image_store(level, image, &address, value, context)? + } + crate::Statement::Call { + function, + ref arguments, + result, + } => { + write!(self.out, "{level}")?; + if let Some(expr) = result { + let name = format!("{}{}", back::BAKE_PREFIX, expr.index()); + self.start_baking_expression(expr, &context.expression, &name)?; + self.named_expressions.insert(expr, name); + } + let fun_name = &self.names[&NameKey::Function(function)]; + write!(self.out, "{fun_name}(")?; + // first, write down the actual arguments + for (i, &handle) in arguments.iter().enumerate() { + if i != 0 { + write!(self.out, ", ")?; + } + self.put_expression(handle, &context.expression, true)?; + } + // follow-up with any global resources used + let mut separate = !arguments.is_empty(); + let fun_info = &context.expression.mod_info[function]; + let mut supports_array_length = false; + for (handle, var) in context.expression.module.global_variables.iter() { + if fun_info[handle].is_empty() { + continue; + } + if var.space.needs_pass_through() { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + if separate { + write!(self.out, ", ")?; + } else { + separate = true; + } + write!(self.out, "{name}")?; + } + supports_array_length |= + needs_array_length(var.ty, &context.expression.module.types); + } + if supports_array_length { + if separate { + write!(self.out, ", ")?; + } + write!(self.out, "_buffer_sizes")?; + } + + // done + writeln!(self.out, ");")?; + } + crate::Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_baking_expression(result, &context.expression, &res_name)?; + self.named_expressions.insert(result, res_name); + match *fun { + crate::AtomicFunction::Add => { + self.put_atomic_fetch(pointer, "add", value, &context.expression)?; + } + crate::AtomicFunction::Subtract => { + self.put_atomic_fetch(pointer, "sub", value, &context.expression)?; + } + crate::AtomicFunction::And => { + self.put_atomic_fetch(pointer, "and", value, &context.expression)?; + } + crate::AtomicFunction::InclusiveOr => { + self.put_atomic_fetch(pointer, "or", value, &context.expression)?; + } + crate::AtomicFunction::ExclusiveOr => { + self.put_atomic_fetch(pointer, "xor", value, &context.expression)?; + } + crate::AtomicFunction::Min => { + self.put_atomic_fetch(pointer, "min", value, &context.expression)?; + } + crate::AtomicFunction::Max => { + self.put_atomic_fetch(pointer, "max", value, &context.expression)?; + } + crate::AtomicFunction::Exchange { compare: None } => { + self.put_atomic_operation( + pointer, + "exchange", + "", + value, + &context.expression, + )?; + } + crate::AtomicFunction::Exchange { .. } => { + return Err(Error::FeatureNotImplemented( + "atomic CompareExchange".to_string(), + )); + } + } + // done + writeln!(self.out, ";")?; + } + crate::Statement::WorkGroupUniformLoad { pointer, result } => { + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.put_load(pointer, &context.expression, true)?; + self.named_expressions.insert(result, name); + + writeln!(self.out, ";")?; + self.write_barrier(crate::Barrier::WORK_GROUP, level)?; + } + crate::Statement::RayQuery { query, ref fun } => { + if context.expression.lang_version < (2, 4) { + return Err(Error::UnsupportedRayTracing); + } + + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + //TODO: how to deal with winding? + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?; + { + let f_opaque = back::RayFlag::CULL_OPAQUE.bits(); + let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode((" + )?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?; + writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?; + } + { + let f_opaque = back::RayFlag::OPAQUE.bits(); + let f_no_opaque = back::RayFlag::NO_OPAQUE.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?; + writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?; + } + { + let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection((" + )?; + self.put_expression(descriptor, &context.expression, true)?; + writeln!(self.out, ".flags & {flag}) != 0);")?; + } + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray(" + )?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".origin, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".dir, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".tmin, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".tmax), ")?; + self.put_expression(acceleration_structure, &context.expression, true)?; + write!(self.out, ", ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".cull_mask);")?; + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?; + } + crate::RayQueryFunction::Proceed { result } => { + write!(self.out, "{level}")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?; + //TODO: actually proceed? + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?; + } + crate::RayQueryFunction::Terminate => { + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?; + } + } + } + } + } + + // un-emit expressions + //TODO: take care of loop/continuing? + for statement in statements { + if let crate::Statement::Emit(ref range) = *statement { + for handle in range.clone() { + self.named_expressions.remove(&handle); + } + } + } + Ok(()) + } + + fn put_store( + &mut self, + pointer: Handle<crate::Expression>, + value: Handle<crate::Expression>, + level: back::Level, + context: &StatementContext, + ) -> BackendResult { + let policy = context.expression.choose_bounds_check_policy(pointer); + if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite + && self.put_bounds_checks(pointer, &context.expression, level, "if (")? + { + writeln!(self.out, ") {{")?; + self.put_unchecked_store(pointer, value, policy, level.next(), context)?; + writeln!(self.out, "{level}}}")?; + } else { + self.put_unchecked_store(pointer, value, policy, level, context)?; + } + + Ok(()) + } + + fn put_unchecked_store( + &mut self, + pointer: Handle<crate::Expression>, + value: Handle<crate::Expression>, + policy: index::BoundsCheckPolicy, + level: back::Level, + context: &StatementContext, + ) -> BackendResult { + let is_atomic_pointer = context + .expression + .resolve_type(pointer) + .is_atomic_pointer(&context.expression.module.types); + + if is_atomic_pointer { + write!( + self.out, + "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}" + )?; + self.put_access_chain(pointer, policy, &context.expression)?; + write!(self.out, ", ")?; + self.put_expression(value, &context.expression, true)?; + writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?; + } else { + write!(self.out, "{level}")?; + self.put_access_chain(pointer, policy, &context.expression)?; + write!(self.out, " = ")?; + self.put_expression(value, &context.expression, true)?; + writeln!(self.out, ";")?; + } + + Ok(()) + } + + pub fn write( + &mut self, + module: &crate::Module, + info: &valid::ModuleInfo, + options: &Options, + pipeline_options: &PipelineOptions, + ) -> Result<TranslationInfo, Error> { + self.names.clear(); + self.namer.reset( + module, + super::keywords::RESERVED, + &[], + &[], + &[CLAMPED_LOD_LOAD_PREFIX], + &mut self.names, + ); + self.struct_member_pads.clear(); + + writeln!( + self.out, + "// language: metal{}.{}", + options.lang_version.0, options.lang_version.1 + )?; + writeln!(self.out, "#include <metal_stdlib>")?; + writeln!(self.out, "#include <simd/simd.h>")?; + writeln!(self.out)?; + // Work around Metal bug where `uint` is not available by default + writeln!(self.out, "using {NAMESPACE}::uint;")?; + + let mut uses_ray_query = false; + for (_, ty) in module.types.iter() { + match ty.inner { + crate::TypeInner::AccelerationStructure => { + if options.lang_version < (2, 4) { + return Err(Error::UnsupportedRayTracing); + } + } + crate::TypeInner::RayQuery => { + if options.lang_version < (2, 4) { + return Err(Error::UnsupportedRayTracing); + } + uses_ray_query = true; + } + _ => (), + } + } + + if module.special_types.ray_desc.is_some() + || module.special_types.ray_intersection.is_some() + { + if options.lang_version < (2, 4) { + return Err(Error::UnsupportedRayTracing); + } + } + + if uses_ray_query { + self.put_ray_query_type()?; + } + + if options + .bounds_check_policies + .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite) + { + self.put_default_constructible()?; + } + writeln!(self.out)?; + + { + let mut indices = vec![]; + for (handle, var) in module.global_variables.iter() { + if needs_array_length(var.ty, &module.types) { + let idx = handle.index(); + indices.push(idx); + } + } + + if !indices.is_empty() { + writeln!(self.out, "struct _mslBufferSizes {{")?; + + for idx in indices { + writeln!(self.out, "{}uint size{};", back::INDENT, idx)?; + } + + writeln!(self.out, "}};")?; + writeln!(self.out)?; + } + }; + + self.write_type_defs(module)?; + self.write_global_constants(module, info)?; + self.write_functions(module, info, options, pipeline_options) + } + + /// Write the definition for the `DefaultConstructible` class. + /// + /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to + /// produce 'zero' values for any type, including structs, arrays, and so + /// on. We could do this by emitting default constructor applications, but + /// that would entail printing the name of the type, which is more trouble + /// than you'd think. Instead, we just construct this magic C++14 class that + /// can be converted to any type that can be default constructed, using + /// template parameter inference to detect which type is needed, so we don't + /// have to figure out the name. + /// + /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite + fn put_default_constructible(&mut self) -> BackendResult { + let tab = back::INDENT; + writeln!(self.out, "struct DefaultConstructible {{")?; + writeln!(self.out, "{tab}template<typename T>")?; + writeln!(self.out, "{tab}operator T() && {{")?; + writeln!(self.out, "{tab}{tab}return T {{}};")?; + writeln!(self.out, "{tab}}}")?; + writeln!(self.out, "}};")?; + Ok(()) + } + + fn put_ray_query_type(&mut self) -> BackendResult { + let tab = back::INDENT; + writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?; + let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>"); + writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?; + writeln!( + self.out, + "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};" + )?; + writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?; + writeln!(self.out, "}};")?; + writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?; + let v_triangle = back::RayIntersectionType::Triangle as u32; + let v_bbox = back::RayIntersectionType::BoundingBox as u32; + writeln!( + self.out, + "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : " + )?; + writeln!( + self.out, + "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;" + )?; + writeln!(self.out, "}}")?; + Ok(()) + } + + fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult { + for (handle, ty) in module.types.iter() { + if !ty.needs_alias() { + continue; + } + let name = &self.names[&NameKey::Type(handle)]; + match ty.inner { + // Naga IR can pass around arrays by value, but Metal, following + // C++, performs an array-to-pointer conversion (C++ [conv.array]) + // on expressions of array type, so assigning the array by value + // isn't possible. However, Metal *does* assign structs by + // value. So in our Metal output, we wrap all array types in + // synthetic struct types: + // + // struct type1 { + // float inner[10] + // }; + // + // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in + // any expression that actually wants access to the array. + crate::TypeInner::Array { + base, + size, + stride: _, + } => { + let base_name = TypeContext { + handle: base, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + + match size { + crate::ArraySize::Constant(size) => { + writeln!(self.out, "struct {name} {{")?; + writeln!( + self.out, + "{}{} {}[{}];", + back::INDENT, + base_name, + WRAPPED_ARRAY_FIELD, + size + )?; + writeln!(self.out, "}};")?; + } + crate::ArraySize::Dynamic => { + writeln!(self.out, "typedef {base_name} {name}[1];")?; + } + } + } + crate::TypeInner::Struct { + ref members, span, .. + } => { + writeln!(self.out, "struct {name} {{")?; + let mut last_offset = 0; + for (index, member) in members.iter().enumerate() { + if member.offset > last_offset { + self.struct_member_pads.insert((handle, index as u32)); + let pad = member.offset - last_offset; + writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?; + } + let ty_inner = &module.types[member.ty].inner; + last_offset = member.offset + ty_inner.size(module.to_ctx()); + + let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; + + // If the member should be packed (as is the case for a misaligned vec3) issue a packed vector + match should_pack_struct_member(members, span, index, module) { + Some(scalar) => { + writeln!( + self.out, + "{}{}::packed_{}3 {};", + back::INDENT, + NAMESPACE, + scalar.to_msl_name(), + member_name + )?; + } + None => { + let base_name = TypeContext { + handle: member.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + writeln!( + self.out, + "{}{} {};", + back::INDENT, + base_name, + member_name + )?; + + // for 3-component vectors, add one component + if let crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + scalar, + } = *ty_inner + { + last_offset += scalar.width as u32; + } + } + } + } + writeln!(self.out, "}};")?; + } + _ => { + let ty_name = TypeContext { + handle, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: true, + }; + writeln!(self.out, "typedef {ty_name} {name};")?; + } + } + } + + // Write functions to create special types. + for (type_key, struct_ty) in module.special_types.predeclared_types.iter() { + match type_key { + &crate::PredeclaredType::ModfResult { size, width } + | &crate::PredeclaredType::FrexpResult { size, width } => { + let arg_type_name_owner; + let arg_type_name = if let Some(size) = size { + arg_type_name_owner = format!( + "{NAMESPACE}::{}{}", + if width == 8 { "double" } else { "float" }, + size as u8 + ); + &arg_type_name_owner + } else if width == 8 { + "double" + } else { + "float" + }; + + let other_type_name_owner; + let (defined_func_name, called_func_name, other_type_name) = + if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) { + (MODF_FUNCTION, "modf", arg_type_name) + } else { + let other_type_name = if let Some(size) = size { + other_type_name_owner = format!("int{}", size as u8); + &other_type_name_owner + } else { + "int" + }; + (FREXP_FUNCTION, "frexp", other_type_name) + }; + + let struct_name = &self.names[&NameKey::Type(*struct_ty)]; + + writeln!(self.out)?; + writeln!( + self.out, + "{} {defined_func_name}({arg_type_name} arg) {{ + {other_type_name} other; + {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other); + return {}{{ fract, other }}; +}}", + struct_name, struct_name + )?; + } + &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {} + } + } + + Ok(()) + } + + /// Writes all named constants + fn write_global_constants( + &mut self, + module: &crate::Module, + mod_info: &valid::ModuleInfo, + ) -> BackendResult { + let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some()); + + for (handle, constant) in constants { + let ty_name = TypeContext { + handle: constant.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + let name = &self.names[&NameKey::Constant(handle)]; + write!(self.out, "constant {ty_name} {name} = ")?; + self.put_const_expression(constant.init, module, mod_info)?; + writeln!(self.out, ";")?; + } + + Ok(()) + } + + fn put_inline_sampler_properties( + &mut self, + level: back::Level, + sampler: &sm::InlineSampler, + ) -> BackendResult { + for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) { + writeln!( + self.out, + "{}{}::{}_address::{},", + level, + NAMESPACE, + letter, + address.as_str(), + )?; + } + writeln!( + self.out, + "{}{}::mag_filter::{},", + level, + NAMESPACE, + sampler.mag_filter.as_str(), + )?; + writeln!( + self.out, + "{}{}::min_filter::{},", + level, + NAMESPACE, + sampler.min_filter.as_str(), + )?; + if let Some(filter) = sampler.mip_filter { + writeln!( + self.out, + "{}{}::mip_filter::{},", + level, + NAMESPACE, + filter.as_str(), + )?; + } + // avoid setting it on platforms that don't support it + if sampler.border_color != sm::BorderColor::TransparentBlack { + writeln!( + self.out, + "{}{}::border_color::{},", + level, + NAMESPACE, + sampler.border_color.as_str(), + )?; + } + //TODO: I'm not able to feed this in a way that MSL likes: + //>error: use of undeclared identifier 'lod_clamp' + //>error: no member named 'max_anisotropy' in namespace 'metal' + if false { + if let Some(ref lod) = sampler.lod_clamp { + writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?; + } + if let Some(aniso) = sampler.max_anisotropy { + writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?; + } + } + if sampler.compare_func != sm::CompareFunc::Never { + writeln!( + self.out, + "{}{}::compare_func::{},", + level, + NAMESPACE, + sampler.compare_func.as_str(), + )?; + } + writeln!( + self.out, + "{}{}::coord::{}", + level, + NAMESPACE, + sampler.coord.as_str() + )?; + Ok(()) + } + + // Returns the array of mapped entry point names. + fn write_functions( + &mut self, + module: &crate::Module, + mod_info: &valid::ModuleInfo, + options: &Options, + pipeline_options: &PipelineOptions, + ) -> Result<TranslationInfo, Error> { + let mut pass_through_globals = Vec::new(); + for (fun_handle, fun) in module.functions.iter() { + log::trace!( + "function {:?}, handle {:?}", + fun.name.as_deref().unwrap_or("(anonymous)"), + fun_handle + ); + + let fun_info = &mod_info[fun_handle]; + pass_through_globals.clear(); + let mut supports_array_length = false; + for (handle, var) in module.global_variables.iter() { + if !fun_info[handle].is_empty() { + if var.space.needs_pass_through() { + pass_through_globals.push(handle); + } + supports_array_length |= needs_array_length(var.ty, &module.types); + } + } + + writeln!(self.out)?; + let fun_name = &self.names[&NameKey::Function(fun_handle)]; + match fun.result { + Some(ref result) => { + let ty_name = TypeContext { + handle: result.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + write!(self.out, "{ty_name}")?; + } + None => { + write!(self.out, "void")?; + } + } + writeln!(self.out, " {fun_name}(")?; + + for (index, arg) in fun.arguments.iter().enumerate() { + let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)]; + let param_type_name = TypeContext { + handle: arg.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + let separator = separate( + !pass_through_globals.is_empty() + || index + 1 != fun.arguments.len() + || supports_array_length, + ); + writeln!( + self.out, + "{}{} {}{}", + back::INDENT, + param_type_name, + name, + separator + )?; + } + for (index, &handle) in pass_through_globals.iter().enumerate() { + let tyvar = TypedGlobalVariable { + module, + names: &self.names, + handle, + usage: fun_info[handle], + binding: None, + reference: true, + }; + let separator = + separate(index + 1 != pass_through_globals.len() || supports_array_length); + write!(self.out, "{}", back::INDENT)?; + tyvar.try_fmt(&mut self.out)?; + writeln!(self.out, "{separator}")?; + } + + if supports_array_length { + writeln!( + self.out, + "{}constant _mslBufferSizes& _buffer_sizes", + back::INDENT + )?; + } + + writeln!(self.out, ") {{")?; + + let guarded_indices = + index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies); + + let context = StatementContext { + expression: ExpressionContext { + function: fun, + origin: FunctionOrigin::Handle(fun_handle), + info: fun_info, + lang_version: options.lang_version, + policies: options.bounds_check_policies, + guarded_indices, + module, + mod_info, + pipeline_options, + }, + result_struct: None, + }; + + for (local_handle, local) in fun.local_variables.iter() { + let ty_name = TypeContext { + handle: local.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)]; + write!(self.out, "{}{} {}", back::INDENT, ty_name, local_name)?; + match local.init { + Some(value) => { + write!(self.out, " = ")?; + self.put_expression(value, &context.expression, true)?; + } + None => { + write!(self.out, " = {{}}")?; + } + }; + writeln!(self.out, ";")?; + } + + self.update_expressions_to_bake(fun, fun_info, &context.expression); + self.put_block(back::Level(1), &fun.body, &context)?; + writeln!(self.out, "}}")?; + self.named_expressions.clear(); + } + + let mut info = TranslationInfo { + entry_point_names: Vec::with_capacity(module.entry_points.len()), + }; + for (ep_index, ep) in module.entry_points.iter().enumerate() { + let fun = &ep.function; + let fun_info = mod_info.get_entry_point(ep_index); + let mut ep_error = None; + + log::trace!( + "entry point {:?}, index {:?}", + fun.name.as_deref().unwrap_or("(anonymous)"), + ep_index + ); + + // Is any global variable used by this entry point dynamically sized? + let supports_array_length = module + .global_variables + .iter() + .filter(|&(handle, _)| !fun_info[handle].is_empty()) + .any(|(_, var)| needs_array_length(var.ty, &module.types)); + + // skip this entry point if any global bindings are missing, + // or their types are incompatible. + if !options.fake_missing_bindings { + for (var_handle, var) in module.global_variables.iter() { + if fun_info[var_handle].is_empty() { + continue; + } + match var.space { + crate::AddressSpace::Uniform + | crate::AddressSpace::Storage { .. } + | crate::AddressSpace::Handle => { + let br = match var.binding { + Some(ref br) => br, + None => { + let var_name = var.name.clone().unwrap_or_default(); + ep_error = + Some(super::EntryPointError::MissingBinding(var_name)); + break; + } + }; + let target = options.get_resource_binding_target(ep, br); + let good = match target { + Some(target) => { + let binding_ty = match module.types[var.ty].inner { + crate::TypeInner::BindingArray { base, .. } => { + &module.types[base].inner + } + ref ty => ty, + }; + match *binding_ty { + crate::TypeInner::Image { .. } => target.texture.is_some(), + crate::TypeInner::Sampler { .. } => { + target.sampler.is_some() + } + _ => target.buffer.is_some(), + } + } + None => false, + }; + if !good { + ep_error = + Some(super::EntryPointError::MissingBindTarget(br.clone())); + break; + } + } + crate::AddressSpace::PushConstant => { + if let Err(e) = options.resolve_push_constants(ep) { + ep_error = Some(e); + break; + } + } + crate::AddressSpace::Function + | crate::AddressSpace::Private + | crate::AddressSpace::WorkGroup => {} + } + } + if supports_array_length { + if let Err(err) = options.resolve_sizes_buffer(ep) { + ep_error = Some(err); + } + } + } + + if let Some(err) = ep_error { + info.entry_point_names.push(Err(err)); + continue; + } + let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)]; + info.entry_point_names.push(Ok(fun_name.clone())); + + writeln!(self.out)?; + + let (em_str, in_mode, out_mode) = match ep.stage { + crate::ShaderStage::Vertex => ( + "vertex", + LocationMode::VertexInput, + LocationMode::VertexOutput, + ), + crate::ShaderStage::Fragment { .. } => ( + "fragment", + LocationMode::FragmentInput, + LocationMode::FragmentOutput, + ), + crate::ShaderStage::Compute { .. } => { + ("kernel", LocationMode::Uniform, LocationMode::Uniform) + } + }; + + // Since `Namer.reset` wasn't expecting struct members to be + // suddenly injected into another namespace like this, + // `self.names` doesn't keep them distinct from other variables. + // Generate fresh names for these arguments, and remember the + // mapping. + let mut flattened_member_names = FastHashMap::default(); + // Varyings' members get their own namespace + let mut varyings_namer = crate::proc::Namer::default(); + + // List all the Naga `EntryPoint`'s `Function`'s arguments, + // flattening structs into their members. In Metal, we will pass + // each of these values to the entry point as a separate argument— + // except for the varyings, handled next. + let mut flattened_arguments = Vec::new(); + for (arg_index, arg) in fun.arguments.iter().enumerate() { + match module.types[arg.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + for (member_index, member) in members.iter().enumerate() { + let member_index = member_index as u32; + flattened_arguments.push(( + NameKey::StructMember(arg.ty, member_index), + member.ty, + member.binding.as_ref(), + )); + let name_key = NameKey::StructMember(arg.ty, member_index); + let name = match member.binding { + Some(crate::Binding::Location { .. }) => { + varyings_namer.call(&self.names[&name_key]) + } + _ => self.namer.call(&self.names[&name_key]), + }; + flattened_member_names.insert(name_key, name); + } + } + _ => flattened_arguments.push(( + NameKey::EntryPointArgument(ep_index as _, arg_index as u32), + arg.ty, + arg.binding.as_ref(), + )), + } + } + + // Identify the varyings among the argument values, and emit a + // struct type named `<fun>Input` to hold them. + let stage_in_name = format!("{fun_name}Input"); + let varyings_member_name = self.namer.call("varyings"); + let mut has_varyings = false; + if !flattened_arguments.is_empty() { + writeln!(self.out, "struct {stage_in_name} {{")?; + for &(ref name_key, ty, binding) in flattened_arguments.iter() { + let binding = match binding { + Some(ref binding @ &crate::Binding::Location { .. }) => binding, + _ => continue, + }; + has_varyings = true; + let name = match *name_key { + NameKey::StructMember(..) => &flattened_member_names[name_key], + _ => &self.names[name_key], + }; + let ty_name = TypeContext { + handle: ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + let resolved = options.resolve_local_binding(binding, in_mode)?; + write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?; + resolved.try_fmt(&mut self.out)?; + writeln!(self.out, ";")?; + } + writeln!(self.out, "}};")?; + } + + // Define a struct type named for the return value, if any, named + // `<fun>Output`. + let stage_out_name = format!("{fun_name}Output"); + let result_member_name = self.namer.call("member"); + let result_type_name = match fun.result { + Some(ref result) => { + let mut result_members = Vec::new(); + if let crate::TypeInner::Struct { ref members, .. } = + module.types[result.ty].inner + { + for (member_index, member) in members.iter().enumerate() { + result_members.push(( + &self.names[&NameKey::StructMember(result.ty, member_index as u32)], + member.ty, + member.binding.as_ref(), + )); + } + } else { + result_members.push(( + &result_member_name, + result.ty, + result.binding.as_ref(), + )); + } + + writeln!(self.out, "struct {stage_out_name} {{")?; + let mut has_point_size = false; + for (name, ty, binding) in result_members { + let ty_name = TypeContext { + handle: ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: true, + }; + let binding = binding.ok_or(Error::Validation)?; + + if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding { + has_point_size = true; + if !pipeline_options.allow_and_force_point_size { + continue; + } + } + + let array_len = match module.types[ty].inner { + crate::TypeInner::Array { + size: crate::ArraySize::Constant(size), + .. + } => Some(size), + _ => None, + }; + let resolved = options.resolve_local_binding(binding, out_mode)?; + write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?; + if let Some(array_len) = array_len { + write!(self.out, " [{array_len}]")?; + } + resolved.try_fmt(&mut self.out)?; + writeln!(self.out, ";")?; + } + + if pipeline_options.allow_and_force_point_size + && ep.stage == crate::ShaderStage::Vertex + && !has_point_size + { + // inject the point size output last + writeln!( + self.out, + "{}float _point_size [[point_size]];", + back::INDENT + )?; + } + writeln!(self.out, "}};")?; + &stage_out_name + } + None => "void", + }; + + // Write the entry point function's name, and begin its argument list. + writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?; + let mut is_first_argument = true; + + // If we have produced a struct holding the `EntryPoint`'s + // `Function`'s arguments' varyings, pass that struct first. + if has_varyings { + writeln!( + self.out, + " {stage_in_name} {varyings_member_name} [[stage_in]]" + )?; + is_first_argument = false; + } + + let mut local_invocation_id = None; + + // Then pass the remaining arguments not included in the varyings + // struct. + for &(ref name_key, ty, binding) in flattened_arguments.iter() { + let binding = match binding { + Some(binding @ &crate::Binding::BuiltIn { .. }) => binding, + _ => continue, + }; + let name = match *name_key { + NameKey::StructMember(..) => &flattened_member_names[name_key], + _ => &self.names[name_key], + }; + + if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) { + local_invocation_id = Some(name_key); + } + + let ty_name = TypeContext { + handle: ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + let resolved = options.resolve_local_binding(binding, in_mode)?; + let separator = if is_first_argument { + is_first_argument = false; + ' ' + } else { + ',' + }; + write!(self.out, "{separator} {ty_name} {name}")?; + resolved.try_fmt(&mut self.out)?; + writeln!(self.out)?; + } + + let need_workgroup_variables_initialization = + self.need_workgroup_variables_initialization(options, ep, module, fun_info); + + if need_workgroup_variables_initialization && local_invocation_id.is_none() { + let separator = if is_first_argument { + is_first_argument = false; + ' ' + } else { + ',' + }; + writeln!( + self.out, + "{separator} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]" + )?; + } + + // Those global variables used by this entry point and its callees + // get passed as arguments. `Private` globals are an exception, they + // don't outlive this invocation, so we declare them below as locals + // within the entry point. + for (handle, var) in module.global_variables.iter() { + let usage = fun_info[handle]; + if usage.is_empty() || var.space == crate::AddressSpace::Private { + continue; + } + + if options.lang_version < (1, 2) { + match var.space { + // This restriction is not documented in the MSL spec + // but validation will fail if it is not upheld. + // + // We infer the required version from the "Function + // Buffer Read-Writes" section of [what's new], where + // the feature sets listed correspond with the ones + // supporting MSL 1.2. + // + // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html + crate::AddressSpace::Storage { access } + if access.contains(crate::StorageAccess::STORE) + && ep.stage == crate::ShaderStage::Fragment => + { + return Err(Error::UnsupportedWriteableStorageBuffer) + } + crate::AddressSpace::Handle => { + match module.types[var.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. }, + .. + } => { + // This restriction is not documented in the MSL spec + // but validation will fail if it is not upheld. + // + // We infer the required version from the "Function + // Texture Read-Writes" section of [what's new], where + // the feature sets listed correspond with the ones + // supporting MSL 1.2. + // + // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html + if access.contains(crate::StorageAccess::STORE) + && (ep.stage == crate::ShaderStage::Vertex + || ep.stage == crate::ShaderStage::Fragment) + { + return Err(Error::UnsupportedWriteableStorageTexture( + ep.stage, + )); + } + + if access.contains( + crate::StorageAccess::LOAD | crate::StorageAccess::STORE, + ) { + return Err(Error::UnsupportedRWStorageTexture); + } + } + _ => {} + } + } + _ => {} + } + } + + // Check min MSL version for binding arrays + match var.space { + crate::AddressSpace::Handle => match module.types[var.ty].inner { + crate::TypeInner::BindingArray { base, .. } => { + match module.types[base].inner { + crate::TypeInner::Sampler { .. } => { + if options.lang_version < (2, 0) { + return Err(Error::UnsupportedArrayOf( + "samplers".to_string(), + )); + } + } + crate::TypeInner::Image { class, .. } => match class { + crate::ImageClass::Sampled { .. } + | crate::ImageClass::Depth { .. } + | crate::ImageClass::Storage { + access: crate::StorageAccess::LOAD, + .. + } => { + // Array of textures since: + // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164) + // - macOS: Metal 2 + + if options.lang_version < (2, 0) { + return Err(Error::UnsupportedArrayOf( + "textures".to_string(), + )); + } + } + crate::ImageClass::Storage { + access: crate::StorageAccess::STORE, + .. + } => { + // Array of write-only textures since: + // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164) + // - macOS: Metal 2 + + if options.lang_version < (2, 0) { + return Err(Error::UnsupportedArrayOf( + "write-only textures".to_string(), + )); + } + } + crate::ImageClass::Storage { .. } => { + return Err(Error::UnsupportedArrayOf( + "read-write textures".to_string(), + )); + } + }, + _ => { + return Err(Error::UnsupportedArrayOfType(base)); + } + } + } + _ => {} + }, + _ => {} + } + + // the resolves have already been checked for `!fake_missing_bindings` case + let resolved = match var.space { + crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(), + crate::AddressSpace::WorkGroup => None, + _ => options + .resolve_resource_binding(ep, var.binding.as_ref().unwrap()) + .ok(), + }; + if let Some(ref resolved) = resolved { + // Inline samplers are be defined in the EP body + if resolved.as_inline_sampler(options).is_some() { + continue; + } + } + + let tyvar = TypedGlobalVariable { + module, + names: &self.names, + handle, + usage, + binding: resolved.as_ref(), + reference: true, + }; + let separator = if is_first_argument { + is_first_argument = false; + ' ' + } else { + ',' + }; + write!(self.out, "{separator} ")?; + tyvar.try_fmt(&mut self.out)?; + if let Some(resolved) = resolved { + resolved.try_fmt(&mut self.out)?; + } + if let Some(value) = var.init { + write!(self.out, " = ")?; + self.put_const_expression(value, module, mod_info)?; + } + writeln!(self.out)?; + } + + // If this entry uses any variable-length arrays, their sizes are + // passed as a final struct-typed argument. + if supports_array_length { + // this is checked earlier + let resolved = options.resolve_sizes_buffer(ep).unwrap(); + let separator = if module.global_variables.is_empty() { + ' ' + } else { + ',' + }; + write!( + self.out, + "{separator} constant _mslBufferSizes& _buffer_sizes", + )?; + resolved.try_fmt(&mut self.out)?; + writeln!(self.out)?; + } + + // end of the entry point argument list + writeln!(self.out, ") {{")?; + + if need_workgroup_variables_initialization { + self.write_workgroup_variables_initialization( + module, + mod_info, + fun_info, + local_invocation_id, + )?; + } + + // Metal doesn't support private mutable variables outside of functions, + // so we put them here, just like the locals. + for (handle, var) in module.global_variables.iter() { + let usage = fun_info[handle]; + if usage.is_empty() { + continue; + } + if var.space == crate::AddressSpace::Private { + let tyvar = TypedGlobalVariable { + module, + names: &self.names, + handle, + usage, + binding: None, + reference: false, + }; + write!(self.out, "{}", back::INDENT)?; + tyvar.try_fmt(&mut self.out)?; + match var.init { + Some(value) => { + write!(self.out, " = ")?; + self.put_const_expression(value, module, mod_info)?; + writeln!(self.out, ";")?; + } + None => { + writeln!(self.out, " = {{}};")?; + } + }; + } else if let Some(ref binding) = var.binding { + // write an inline sampler + let resolved = options.resolve_resource_binding(ep, binding).unwrap(); + if let Some(sampler) = resolved.as_inline_sampler(options) { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + writeln!( + self.out, + "{}constexpr {}::sampler {}(", + back::INDENT, + NAMESPACE, + name + )?; + self.put_inline_sampler_properties(back::Level(2), sampler)?; + writeln!(self.out, "{});", back::INDENT)?; + } + } + } + + // Now take the arguments that we gathered into structs, and the + // structs that we flattened into arguments, and emit local + // variables with initializers that put everything back the way the + // body code expects. + // + // If we had to generate fresh names for struct members passed as + // arguments, be sure to use those names when rebuilding the struct. + // + // "Each day, I change some zeros to ones, and some ones to zeros. + // The rest, I leave alone." + for (arg_index, arg) in fun.arguments.iter().enumerate() { + let arg_name = + &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)]; + match module.types[arg.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + let struct_name = &self.names[&NameKey::Type(arg.ty)]; + write!( + self.out, + "{}const {} {} = {{ ", + back::INDENT, + struct_name, + arg_name + )?; + for (member_index, member) in members.iter().enumerate() { + let key = NameKey::StructMember(arg.ty, member_index as u32); + let name = &flattened_member_names[&key]; + if member_index != 0 { + write!(self.out, ", ")?; + } + // insert padding initialization, if needed + if self + .struct_member_pads + .contains(&(arg.ty, member_index as u32)) + { + write!(self.out, "{{}}, ")?; + } + if let Some(crate::Binding::Location { .. }) = member.binding { + write!(self.out, "{varyings_member_name}.")?; + } + write!(self.out, "{name}")?; + } + writeln!(self.out, " }};")?; + } + _ => { + if let Some(crate::Binding::Location { .. }) = arg.binding { + writeln!( + self.out, + "{}const auto {} = {}.{};", + back::INDENT, + arg_name, + varyings_member_name, + arg_name + )?; + } + } + } + } + + let guarded_indices = + index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies); + + let context = StatementContext { + expression: ExpressionContext { + function: fun, + origin: FunctionOrigin::EntryPoint(ep_index as _), + info: fun_info, + lang_version: options.lang_version, + policies: options.bounds_check_policies, + guarded_indices, + module, + mod_info, + pipeline_options, + }, + result_struct: Some(&stage_out_name), + }; + + // Finally, declare all the local variables that we need + //TODO: we can postpone this till the relevant expressions are emitted + for (local_handle, local) in fun.local_variables.iter() { + let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)]; + let ty_name = TypeContext { + handle: local.ty, + gctx: module.to_ctx(), + names: &self.names, + access: crate::StorageAccess::empty(), + binding: None, + first_time: false, + }; + write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?; + match local.init { + Some(value) => { + write!(self.out, " = ")?; + self.put_expression(value, &context.expression, true)?; + } + None => { + write!(self.out, " = {{}}")?; + } + }; + writeln!(self.out, ";")?; + } + + self.update_expressions_to_bake(fun, fun_info, &context.expression); + self.put_block(back::Level(1), &fun.body, &context)?; + writeln!(self.out, "}}")?; + if ep_index + 1 != module.entry_points.len() { + writeln!(self.out)?; + } + self.named_expressions.clear(); + } + + Ok(info) + } + + fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult { + // Note: OR-ring bitflags requires `__HAVE_MEMFLAG_OPERATORS__`, + // so we try to avoid it here. + if flags.is_empty() { + writeln!( + self.out, + "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);", + )?; + } + if flags.contains(crate::Barrier::STORAGE) { + writeln!( + self.out, + "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);", + )?; + } + if flags.contains(crate::Barrier::WORK_GROUP) { + writeln!( + self.out, + "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", + )?; + } + Ok(()) + } +} + +/// Initializing workgroup variables is more tricky for Metal because we have to deal +/// with atomics at the type-level (which don't have a copy constructor). +mod workgroup_mem_init { + use crate::EntryPoint; + + use super::*; + + enum Access { + GlobalVariable(Handle<crate::GlobalVariable>), + StructMember(Handle<crate::Type>, u32), + Array(usize), + } + + impl Access { + fn write<W: Write>( + &self, + writer: &mut W, + names: &FastHashMap<NameKey, String>, + ) -> Result<(), core::fmt::Error> { + match *self { + Access::GlobalVariable(handle) => { + write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)]) + } + Access::StructMember(handle, index) => { + write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)]) + } + Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"), + } + } + } + + struct AccessStack { + stack: Vec<Access>, + array_depth: usize, + } + + impl AccessStack { + const fn new() -> Self { + Self { + stack: Vec::new(), + array_depth: 0, + } + } + + fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R { + let array_depth = self.array_depth; + self.stack.push(Access::Array(array_depth)); + self.array_depth += 1; + let res = cb(self, array_depth); + self.stack.pop(); + self.array_depth -= 1; + res + } + + fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R { + self.stack.push(new); + let res = cb(self); + self.stack.pop(); + res + } + + fn write<W: Write>( + &self, + writer: &mut W, + names: &FastHashMap<NameKey, String>, + ) -> Result<(), core::fmt::Error> { + for next in self.stack.iter() { + next.write(writer, names)?; + } + Ok(()) + } + } + + impl<W: Write> Writer<W> { + pub(super) fn need_workgroup_variables_initialization( + &mut self, + options: &Options, + ep: &EntryPoint, + module: &crate::Module, + fun_info: &valid::FunctionInfo, + ) -> bool { + options.zero_initialize_workgroup_memory + && ep.stage == crate::ShaderStage::Compute + && module.global_variables.iter().any(|(handle, var)| { + !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + } + + pub(super) fn write_workgroup_variables_initialization( + &mut self, + module: &crate::Module, + module_info: &valid::ModuleInfo, + fun_info: &valid::FunctionInfo, + local_invocation_id: Option<&NameKey>, + ) -> BackendResult { + let level = back::Level(1); + + writeln!( + self.out, + "{}if ({}::all({} == {}::uint3(0u))) {{", + level, + NAMESPACE, + local_invocation_id + .map(|name_key| self.names[name_key].as_str()) + .unwrap_or("__local_invocation_id"), + NAMESPACE, + )?; + + let mut access_stack = AccessStack::new(); + + let vars = module.global_variables.iter().filter(|&(handle, var)| { + !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }); + + for (handle, var) in vars { + access_stack.enter(Access::GlobalVariable(handle), |access_stack| { + self.write_workgroup_variable_initialization( + module, + module_info, + var.ty, + access_stack, + level.next(), + ) + })?; + } + + writeln!(self.out, "{level}}}")?; + self.write_barrier(crate::Barrier::WORK_GROUP, level) + } + + fn write_workgroup_variable_initialization( + &mut self, + module: &crate::Module, + module_info: &valid::ModuleInfo, + ty: Handle<crate::Type>, + access_stack: &mut AccessStack, + level: back::Level, + ) -> BackendResult { + if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) { + write!(self.out, "{level}")?; + access_stack.write(&mut self.out, &self.names)?; + writeln!(self.out, " = {{}};")?; + } else { + match module.types[ty].inner { + crate::TypeInner::Atomic { .. } => { + write!( + self.out, + "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}" + )?; + access_stack.write(&mut self.out, &self.names)?; + writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?; + } + crate::TypeInner::Array { base, size, .. } => { + let count = match size.to_indexable_length(module).expect("Bad array size") + { + proc::IndexableLength::Known(count) => count, + proc::IndexableLength::Dynamic => unreachable!(), + }; + + access_stack.enter_array(|access_stack, array_depth| { + writeln!( + self.out, + "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{" + )?; + self.write_workgroup_variable_initialization( + module, + module_info, + base, + access_stack, + level.next(), + )?; + writeln!(self.out, "{level}}}")?; + BackendResult::Ok(()) + })?; + } + crate::TypeInner::Struct { ref members, .. } => { + for (index, member) in members.iter().enumerate() { + access_stack.enter( + Access::StructMember(ty, index as u32), + |access_stack| { + self.write_workgroup_variable_initialization( + module, + module_info, + member.ty, + access_stack, + level, + ) + }, + )?; + } + } + _ => unreachable!(), + } + } + + Ok(()) + } + } +} + +#[test] +fn test_stack_size() { + use crate::valid::{Capabilities, ValidationFlags}; + // create a module with at least one expression nested + let mut module = crate::Module::default(); + let mut fun = crate::Function::default(); + let const_expr = fun.expressions.append( + crate::Expression::Literal(crate::Literal::F32(1.0)), + Default::default(), + ); + let nested_expr = fun.expressions.append( + crate::Expression::Unary { + op: crate::UnaryOperator::Negate, + expr: const_expr, + }, + Default::default(), + ); + fun.body.push( + crate::Statement::Emit(fun.expressions.range_from(1)), + Default::default(), + ); + fun.body.push( + crate::Statement::If { + condition: nested_expr, + accept: crate::Block::new(), + reject: crate::Block::new(), + }, + Default::default(), + ); + let _ = module.functions.append(fun, Default::default()); + // analyse the module + let info = crate::valid::Validator::new(ValidationFlags::empty(), Capabilities::empty()) + .validate(&module) + .unwrap(); + // process the module + let mut writer = Writer::new(String::new()); + writer + .write(&module, &info, &Default::default(), &Default::default()) + .unwrap(); + + { + // check expression stack + let mut addresses_start = usize::MAX; + let mut addresses_end = 0usize; + for pointer in writer.put_expression_stack_pointers { + addresses_start = addresses_start.min(pointer as usize); + addresses_end = addresses_end.max(pointer as usize); + } + let stack_size = addresses_end - addresses_start; + // check the size (in debug only) + // last observed macOS value: 20528 (CI) + if !(11000..=25000).contains(&stack_size) { + panic!("`put_expression` stack size {stack_size} has changed!"); + } + } + + { + // check block stack + let mut addresses_start = usize::MAX; + let mut addresses_end = 0usize; + for pointer in writer.put_block_stack_pointers { + addresses_start = addresses_start.min(pointer as usize); + addresses_end = addresses_end.max(pointer as usize); + } + let stack_size = addresses_end - addresses_start; + // check the size (in debug only) + // last observed macOS value: 19152 (CI) + if !(9000..=20000).contains(&stack_size) { + panic!("`put_block` stack size {stack_size} has changed!"); + } + } +} diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs new file mode 100644 index 0000000000..6c96fa09e3 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/block.rs @@ -0,0 +1,2368 @@ +/*! +Implementations for `BlockContext` methods. +*/ + +use super::{ + helpers, index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, + Dimension, Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer, + WriterFlags, +}; +use crate::{arena::Handle, proc::TypeResolution, Statement}; +use spirv::Word; + +fn get_dimension(type_inner: &crate::TypeInner) -> Dimension { + match *type_inner { + crate::TypeInner::Scalar(_) => Dimension::Scalar, + crate::TypeInner::Vector { .. } => Dimension::Vector, + crate::TypeInner::Matrix { .. } => Dimension::Matrix, + _ => unreachable!(), + } +} + +/// The results of emitting code for a left-hand-side expression. +/// +/// On success, `write_expression_pointer` returns one of these. +enum ExpressionPointer { + /// The pointer to the expression's value is available, as the value of the + /// expression with the given id. + Ready { pointer_id: Word }, + + /// The access expression must be conditional on the value of `condition`, a boolean + /// expression that is true if all indices are in bounds. If `condition` is true, then + /// `access` is an `OpAccessChain` instruction that will compute a pointer to the + /// expression's value. If `condition` is false, then executing `access` would be + /// undefined behavior. + Conditional { + condition: Word, + access: Instruction, + }, +} + +/// The termination statement to be added to the end of the block +pub enum BlockExit { + /// Generates an OpReturn (void return) + Return, + /// Generates an OpBranch to the specified block + Branch { + /// The branch target block + target: Word, + }, + /// Translates a loop `break if` into an `OpBranchConditional` to the + /// merge block if true (the merge block is passed through [`LoopContext::break_id`] + /// or else to the loop header (passed through [`preamble_id`]) + /// + /// [`preamble_id`]: Self::BreakIf::preamble_id + BreakIf { + /// The condition of the `break if` + condition: Handle<crate::Expression>, + /// The loop header block id + preamble_id: Word, + }, +} + +#[derive(Debug)] +pub(crate) struct DebugInfoInner<'a> { + pub source_code: &'a str, + pub source_file_id: Word, +} + +impl Writer { + // Flip Y coordinate to adjust for coordinate space difference + // between SPIR-V and our IR. + // The `position_id` argument is a pointer to a `vecN<f32>`, + // whose `y` component we will negate. + fn write_epilogue_position_y_flip( + &mut self, + position_id: Word, + body: &mut Vec<Instruction>, + ) -> Result<(), Error> { + let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: Some(spirv::StorageClass::Output), + })); + let index_y_id = self.get_index_constant(1); + let access_id = self.id_gen.next(); + body.push(Instruction::access_chain( + float_ptr_type_id, + access_id, + position_id, + &[index_y_id], + )); + + let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let load_id = self.id_gen.next(); + body.push(Instruction::load(float_type_id, load_id, access_id, None)); + + let neg_id = self.id_gen.next(); + body.push(Instruction::unary( + spirv::Op::FNegate, + float_type_id, + neg_id, + load_id, + )); + + body.push(Instruction::store(access_id, neg_id, None)); + Ok(()) + } + + // Clamp fragment depth between 0 and 1. + fn write_epilogue_frag_depth_clamp( + &mut self, + frag_depth_id: Word, + body: &mut Vec<Instruction>, + ) -> Result<(), Error> { + let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0)); + let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0)); + + let original_id = self.id_gen.next(); + body.push(Instruction::load( + float_type_id, + original_id, + frag_depth_id, + None, + )); + + let clamp_id = self.id_gen.next(); + body.push(Instruction::ext_inst( + self.gl450_ext_inst_id, + spirv::GLOp::FClamp, + float_type_id, + clamp_id, + &[original_id, zero_scalar_id, one_scalar_id], + )); + + body.push(Instruction::store(frag_depth_id, clamp_id, None)); + Ok(()) + } + + fn write_entry_point_return( + &mut self, + value_id: Word, + ir_result: &crate::FunctionResult, + result_members: &[ResultMember], + body: &mut Vec<Instruction>, + ) -> Result<(), Error> { + for (index, res_member) in result_members.iter().enumerate() { + let member_value_id = match ir_result.binding { + Some(_) => value_id, + None => { + let member_value_id = self.id_gen.next(); + body.push(Instruction::composite_extract( + res_member.type_id, + member_value_id, + value_id, + &[index as u32], + )); + member_value_id + } + }; + + body.push(Instruction::store(res_member.id, member_value_id, None)); + + match res_member.built_in { + Some(crate::BuiltIn::Position { .. }) + if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) => + { + self.write_epilogue_position_y_flip(res_member.id, body)?; + } + Some(crate::BuiltIn::FragDepth) + if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) => + { + self.write_epilogue_frag_depth_clamp(res_member.id, body)?; + } + _ => {} + } + } + Ok(()) + } +} + +impl<'w> BlockContext<'w> { + /// Decide whether to put off emitting instructions for `expr_handle`. + /// + /// We would like to gather together chains of `Access` and `AccessIndex` + /// Naga expressions into a single `OpAccessChain` SPIR-V instruction. To do + /// this, we don't generate instructions for these exprs when we first + /// encounter them. Their ids in `self.writer.cached.ids` are left as zero. Then, + /// once we encounter a `Load` or `Store` expression that actually needs the + /// chain's value, we call `write_expression_pointer` to handle the whole + /// thing in one fell swoop. + fn is_intermediate(&self, expr_handle: Handle<crate::Expression>) -> bool { + match self.ir_function.expressions[expr_handle] { + crate::Expression::GlobalVariable(handle) => { + match self.ir_module.global_variables[handle].space { + crate::AddressSpace::Handle => false, + _ => true, + } + } + crate::Expression::LocalVariable(_) => true, + crate::Expression::FunctionArgument(index) => { + let arg = &self.ir_function.arguments[index as usize]; + self.ir_module.types[arg.ty].inner.pointer_space().is_some() + } + + // The chain rule: if this `Access...`'s `base` operand was + // previously omitted, then omit this one, too. + _ => self.cached.ids[expr_handle.index()] == 0, + } + } + + /// Cache an expression for a value. + pub(super) fn cache_expression_value( + &mut self, + expr_handle: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<(), Error> { + let is_named_expression = self + .ir_function + .named_expressions + .contains_key(&expr_handle); + + if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression { + return Ok(()); + } + + let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty); + let id = match self.ir_function.expressions[expr_handle] { + crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal), + crate::Expression::Constant(handle) => { + let init = self.ir_module.constants[handle].init; + self.writer.constant_ids[init.index()] + } + crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), + crate::Expression::Compose { ty, ref components } => { + self.temp_list.clear(); + if self.expression_constness.is_const(expr_handle) { + self.temp_list.extend( + crate::proc::flatten_compose( + ty, + components, + &self.ir_function.expressions, + &self.ir_module.types, + ) + .map(|component| self.cached[component]), + ); + self.writer + .get_constant_composite(LookupType::Handle(ty), &self.temp_list) + } else { + self.temp_list + .extend(components.iter().map(|&component| self.cached[component])); + + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + result_type_id, + id, + &self.temp_list, + )); + id + } + } + crate::Expression::Splat { size, value } => { + let value_id = self.cached[value]; + let components = &[value_id; 4][..size as usize]; + + if self.expression_constness.is_const(expr_handle) { + let ty = self + .writer + .get_expression_lookup_type(&self.fun_info[expr_handle].ty); + self.writer.get_constant_composite(ty, components) + } else { + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + result_type_id, + id, + components, + )); + id + } + } + crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => { + // See `is_intermediate`; we'll handle this later in + // `write_expression_pointer`. + 0 + } + crate::Expression::Access { base, index } => { + let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types); + match *base_ty_inner { + crate::TypeInner::Vector { .. } => { + self.write_vector_access(expr_handle, base, index, block)? + } + // Only binding arrays in the Handle address space will take this path (due to `is_intermediate`) + crate::TypeInner::BindingArray { + base: binding_type, .. + } => { + let space = match self.ir_function.expressions[base] { + crate::Expression::GlobalVariable(gvar) => { + self.ir_module.global_variables[gvar].space + } + _ => unreachable!(), + }; + let binding_array_false_pointer = LookupType::Local(LocalType::Pointer { + base: binding_type, + class: helpers::map_storage_class(space), + }); + + let result_id = match self.write_expression_pointer( + expr_handle, + block, + Some(binding_array_false_pointer), + )? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. } => { + return Err(Error::FeatureNotImplemented( + "Texture array out-of-bounds handling", + )); + } + }; + + let binding_type_id = self.get_type_id(LookupType::Handle(binding_type)); + + let load_id = self.gen_id(); + block.body.push(Instruction::load( + binding_type_id, + load_id, + result_id, + None, + )); + + // Subsequent image operations require the image/sampler to be decorated as NonUniform + // if the image/sampler binding array was accessed with a non-uniform index + // see VUID-RuntimeSpirv-NonUniform-06274 + if self.fun_info[index].uniformity.non_uniform_result.is_some() { + self.writer + .decorate_non_uniform_binding_array_access(load_id)?; + } + + load_id + } + ref other => { + log::error!( + "Unable to access base {:?} of type {:?}", + self.ir_function.expressions[base], + other + ); + return Err(Error::Validation( + "only vectors may be dynamically indexed by value", + )); + } + } + } + crate::Expression::AccessIndex { base, index: _ } if self.is_intermediate(base) => { + // See `is_intermediate`; we'll handle this later in + // `write_expression_pointer`. + 0 + } + crate::Expression::AccessIndex { base, index } => { + match *self.fun_info[base].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } + | crate::TypeInner::Array { .. } + | crate::TypeInner::Struct { .. } => { + // We never need bounds checks here: dynamically sized arrays can + // only appear behind pointers, and are thus handled by the + // `is_intermediate` case above. Everything else's size is + // statically known and checked in validation. + let id = self.gen_id(); + let base_id = self.cached[base]; + block.body.push(Instruction::composite_extract( + result_type_id, + id, + base_id, + &[index], + )); + id + } + // Only binding arrays in the Handle address space will take this path (due to `is_intermediate`) + crate::TypeInner::BindingArray { + base: binding_type, .. + } => { + let space = match self.ir_function.expressions[base] { + crate::Expression::GlobalVariable(gvar) => { + self.ir_module.global_variables[gvar].space + } + _ => unreachable!(), + }; + let binding_array_false_pointer = LookupType::Local(LocalType::Pointer { + base: binding_type, + class: helpers::map_storage_class(space), + }); + + let result_id = match self.write_expression_pointer( + expr_handle, + block, + Some(binding_array_false_pointer), + )? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. } => { + return Err(Error::FeatureNotImplemented( + "Texture array out-of-bounds handling", + )); + } + }; + + let binding_type_id = self.get_type_id(LookupType::Handle(binding_type)); + + let load_id = self.gen_id(); + block.body.push(Instruction::load( + binding_type_id, + load_id, + result_id, + None, + )); + + load_id + } + ref other => { + log::error!("Unable to access index of {:?}", other); + return Err(Error::FeatureNotImplemented("access index for type")); + } + } + } + crate::Expression::GlobalVariable(handle) => { + self.writer.global_variables[handle.index()].access_id + } + crate::Expression::Swizzle { + size, + vector, + pattern, + } => { + let vector_id = self.cached[vector]; + self.temp_list.clear(); + for &sc in pattern[..size as usize].iter() { + self.temp_list.push(sc as Word); + } + let id = self.gen_id(); + block.body.push(Instruction::vector_shuffle( + result_type_id, + id, + vector_id, + vector_id, + &self.temp_list, + )); + id + } + crate::Expression::Unary { op, expr } => { + let id = self.gen_id(); + let expr_id = self.cached[expr]; + let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types); + + let spirv_op = match op { + crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Float) => spirv::Op::FNegate, + Some(crate::ScalarKind::Sint) => spirv::Op::SNegate, + _ => return Err(Error::Validation("Unexpected kind for negation")), + }, + crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot, + crate::UnaryOperator::BitwiseNot => spirv::Op::Not, + }; + + block + .body + .push(Instruction::unary(spirv_op, result_type_id, id, expr_id)); + id + } + crate::Expression::Binary { op, left, right } => { + let id = self.gen_id(); + let left_id = self.cached[left]; + let right_id = self.cached[right]; + + let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types); + let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types); + + let left_dimension = get_dimension(left_ty_inner); + let right_dimension = get_dimension(right_ty_inner); + + let mut reverse_operands = false; + + let spirv_op = match op { + crate::BinaryOperator::Add => match *left_ty_inner { + crate::TypeInner::Scalar(scalar) + | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { + crate::ScalarKind::Float => spirv::Op::FAdd, + _ => spirv::Op::IAdd, + }, + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => { + self.write_matrix_matrix_column_op( + block, + id, + result_type_id, + left_id, + right_id, + columns, + rows, + scalar.width, + spirv::Op::FAdd, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + _ => unimplemented!(), + }, + crate::BinaryOperator::Subtract => match *left_ty_inner { + crate::TypeInner::Scalar(scalar) + | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { + crate::ScalarKind::Float => spirv::Op::FSub, + _ => spirv::Op::ISub, + }, + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => { + self.write_matrix_matrix_column_op( + block, + id, + result_type_id, + left_id, + right_id, + columns, + rows, + scalar.width, + spirv::Op::FSub, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + _ => unimplemented!(), + }, + crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) { + (Dimension::Scalar, Dimension::Vector) => { + self.write_vector_scalar_mult( + block, + id, + result_type_id, + right_id, + left_id, + right_ty_inner, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + (Dimension::Vector, Dimension::Scalar) => { + self.write_vector_scalar_mult( + block, + id, + result_type_id, + left_id, + right_id, + left_ty_inner, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + (Dimension::Vector, Dimension::Matrix) => spirv::Op::VectorTimesMatrix, + (Dimension::Matrix, Dimension::Scalar) => spirv::Op::MatrixTimesScalar, + (Dimension::Scalar, Dimension::Matrix) => { + reverse_operands = true; + spirv::Op::MatrixTimesScalar + } + (Dimension::Matrix, Dimension::Vector) => spirv::Op::MatrixTimesVector, + (Dimension::Matrix, Dimension::Matrix) => spirv::Op::MatrixTimesMatrix, + (Dimension::Vector, Dimension::Vector) + | (Dimension::Scalar, Dimension::Scalar) + if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) => + { + spirv::Op::FMul + } + (Dimension::Vector, Dimension::Vector) + | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul, + }, + crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SDiv, + Some(crate::ScalarKind::Uint) => spirv::Op::UDiv, + Some(crate::ScalarKind::Float) => spirv::Op::FDiv, + _ => unimplemented!(), + }, + crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() { + // TODO: handle undefined behavior + // if right == 0 return 0 + // if left == min(type_of(left)) && right == -1 return 0 + Some(crate::ScalarKind::Sint) => spirv::Op::SRem, + // TODO: handle undefined behavior + // if right == 0 return 0 + Some(crate::ScalarKind::Uint) => spirv::Op::UMod, + // TODO: handle undefined behavior + // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798 + Some(crate::ScalarKind::Float) => spirv::Op::FRem, + _ => unimplemented!(), + }, + crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { + spirv::Op::IEqual + } + Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual, + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { + spirv::Op::INotEqual + } + Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual, + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan, + Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan, + _ => unimplemented!(), + }, + crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual, + Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan, + Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan, + _ => unimplemented!(), + }, + crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual, + Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::And => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd, + _ => spirv::Op::BitwiseAnd, + }, + crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor, + crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr, + _ => spirv::Op::BitwiseOr, + }, + crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd, + crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr, + crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical, + crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic, + Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical, + _ => unimplemented!(), + }, + }; + + block.body.push(Instruction::binary( + spirv_op, + result_type_id, + id, + if reverse_operands { right_id } else { left_id }, + if reverse_operands { left_id } else { right_id }, + )); + id + } + crate::Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + enum MathOp { + Ext(spirv::GLOp), + Custom(Instruction), + } + + let arg0_id = self.cached[arg]; + let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types); + let arg_scalar_kind = arg_ty.scalar_kind(); + let arg1_id = match arg1 { + Some(handle) => self.cached[handle], + None => 0, + }; + let arg2_id = match arg2 { + Some(handle) => self.cached[handle], + None => 0, + }; + let arg3_id = match arg3 { + Some(handle) => self.cached[handle], + None => 0, + }; + + let id = self.gen_id(); + let math_op = match fun { + // comparison + Mf::Abs => { + match arg_scalar_kind { + Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs), + Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs), + Some(crate::ScalarKind::Uint) => { + MathOp::Custom(Instruction::unary( + spirv::Op::CopyObject, // do nothing + result_type_id, + id, + arg0_id, + )) + } + other => unimplemented!("Unexpected abs({:?})", other), + } + } + Mf::Min => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FMin, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin, + Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin, + other => unimplemented!("Unexpected min({:?})", other), + }), + Mf::Max => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FMax, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax, + Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax, + other => unimplemented!("Unexpected max({:?})", other), + }), + Mf::Clamp => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FClamp, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SClamp, + Some(crate::ScalarKind::Uint) => spirv::GLOp::UClamp, + other => unimplemented!("Unexpected max({:?})", other), + }), + Mf::Saturate => { + let (maybe_size, scalar) = match *arg_ty { + crate::TypeInner::Vector { size, scalar } => (Some(size), scalar), + crate::TypeInner::Scalar(scalar) => (None, scalar), + ref other => unimplemented!("Unexpected saturate({:?})", other), + }; + let scalar = crate::Scalar::float(scalar.width); + let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?; + let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?; + + if let Some(size) = maybe_size { + let ty = LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + } + .into(); + + self.temp_list.clear(); + self.temp_list.resize(size as _, arg1_id); + + arg1_id = self.writer.get_constant_composite(ty, &self.temp_list); + + self.temp_list.fill(arg2_id); + + arg2_id = self.writer.get_constant_composite(ty, &self.temp_list); + } + + MathOp::Custom(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FClamp, + result_type_id, + id, + &[arg0_id, arg1_id, arg2_id], + )) + } + // trigonometry + Mf::Sin => MathOp::Ext(spirv::GLOp::Sin), + Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh), + Mf::Asin => MathOp::Ext(spirv::GLOp::Asin), + Mf::Cos => MathOp::Ext(spirv::GLOp::Cos), + Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh), + Mf::Acos => MathOp::Ext(spirv::GLOp::Acos), + Mf::Tan => MathOp::Ext(spirv::GLOp::Tan), + Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh), + Mf::Atan => MathOp::Ext(spirv::GLOp::Atan), + Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2), + Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh), + Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh), + Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh), + Mf::Radians => MathOp::Ext(spirv::GLOp::Radians), + Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees), + // decomposition + Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil), + Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven), + Mf::Floor => MathOp::Ext(spirv::GLOp::Floor), + Mf::Fract => MathOp::Ext(spirv::GLOp::Fract), + Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc), + Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct), + Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct), + Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp), + // geometry + Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Vector { + scalar: + crate::Scalar { + kind: crate::ScalarKind::Float, + .. + }, + .. + } => MathOp::Custom(Instruction::binary( + spirv::Op::Dot, + result_type_id, + id, + arg0_id, + arg1_id, + )), + // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available + crate::TypeInner::Vector { size, .. } => { + self.write_dot_product( + id, + result_type_id, + arg0_id, + arg1_id, + size as u32, + block, + ); + self.cached[expr_handle] = id; + return Ok(()); + } + _ => unreachable!( + "Correct TypeInner for dot product should be already validated" + ), + }, + Mf::Outer => MathOp::Custom(Instruction::binary( + spirv::Op::OuterProduct, + result_type_id, + id, + arg0_id, + arg1_id, + )), + Mf::Cross => MathOp::Ext(spirv::GLOp::Cross), + Mf::Distance => MathOp::Ext(spirv::GLOp::Distance), + Mf::Length => MathOp::Ext(spirv::GLOp::Length), + Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize), + Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward), + Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect), + Mf::Refract => MathOp::Ext(spirv::GLOp::Refract), + // exponent + Mf::Exp => MathOp::Ext(spirv::GLOp::Exp), + Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2), + Mf::Log => MathOp::Ext(spirv::GLOp::Log), + Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2), + Mf::Pow => MathOp::Ext(spirv::GLOp::Pow), + // computational + Mf::Sign => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FSign, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign, + other => unimplemented!("Unexpected sign({:?})", other), + }), + Mf::Fma => MathOp::Ext(spirv::GLOp::Fma), + Mf::Mix => { + let selector = arg2.unwrap(); + let selector_ty = + self.fun_info[selector].ty.inner_with(&self.ir_module.types); + match (arg_ty, selector_ty) { + // if the selector is a scalar, we need to splat it + ( + &crate::TypeInner::Vector { size, .. }, + &crate::TypeInner::Scalar(scalar), + ) => { + let selector_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + })); + self.temp_list.clear(); + self.temp_list.resize(size as usize, arg2_id); + + let selector_id = self.gen_id(); + block.body.push(Instruction::composite_construct( + selector_type_id, + selector_id, + &self.temp_list, + )); + + MathOp::Custom(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FMix, + result_type_id, + id, + &[arg0_id, arg1_id, selector_id], + )) + } + _ => MathOp::Ext(spirv::GLOp::FMix), + } + } + Mf::Step => MathOp::Ext(spirv::GLOp::Step), + Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep), + Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt), + Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt), + Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse), + Mf::Transpose => MathOp::Custom(Instruction::unary( + spirv::Op::Transpose, + result_type_id, + id, + arg0_id, + )), + Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant), + Mf::ReverseBits => MathOp::Custom(Instruction::unary( + spirv::Op::BitReverse, + result_type_id, + id, + arg0_id, + )), + Mf::CountTrailingZeros => { + let uint_id = match *arg_ty { + crate::TypeInner::Vector { size, mut scalar } => { + scalar.kind = crate::ScalarKind::Uint; + let ty = LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + } + .into(); + + self.temp_list.clear(); + self.temp_list.resize( + size as _, + self.writer.get_constant_scalar_with(32, scalar)?, + ); + + self.writer.get_constant_composite(ty, &self.temp_list) + } + crate::TypeInner::Scalar(mut scalar) => { + scalar.kind = crate::ScalarKind::Uint; + self.writer.get_constant_scalar_with(32, scalar)? + } + _ => unreachable!(), + }; + + let lsb_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FindILsb, + result_type_id, + lsb_id, + &[arg0_id], + )); + + MathOp::Custom(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + result_type_id, + id, + &[uint_id, lsb_id], + )) + } + Mf::CountLeadingZeros => { + let (int_type_id, int_id) = match *arg_ty { + crate::TypeInner::Vector { size, mut scalar } => { + scalar.kind = crate::ScalarKind::Sint; + let ty = LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + } + .into(); + + self.temp_list.clear(); + self.temp_list.resize( + size as _, + self.writer.get_constant_scalar_with(31, scalar)?, + ); + + ( + self.get_type_id(ty), + self.writer.get_constant_composite(ty, &self.temp_list), + ) + } + crate::TypeInner::Scalar(mut scalar) => { + scalar.kind = crate::ScalarKind::Sint; + ( + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + })), + self.writer.get_constant_scalar_with(31, scalar)?, + ) + } + _ => unreachable!(), + }; + + let msb_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FindUMsb, + int_type_id, + msb_id, + &[arg0_id], + )); + + MathOp::Custom(Instruction::binary( + spirv::Op::ISub, + result_type_id, + id, + int_id, + msb_id, + )) + } + Mf::CountOneBits => MathOp::Custom(Instruction::unary( + spirv::Op::BitCount, + result_type_id, + id, + arg0_id, + )), + Mf::ExtractBits => { + let op = match arg_scalar_kind { + Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract, + Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract, + other => unimplemented!("Unexpected sign({:?})", other), + }; + MathOp::Custom(Instruction::ternary( + op, + result_type_id, + id, + arg0_id, + arg1_id, + arg2_id, + )) + } + Mf::InsertBits => MathOp::Custom(Instruction::quaternary( + spirv::Op::BitFieldInsert, + result_type_id, + id, + arg0_id, + arg1_id, + arg2_id, + arg3_id, + )), + Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb), + Mf::FindMsb => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb, + Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb, + other => unimplemented!("Unexpected findMSB({:?})", other), + }), + Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8), + Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8), + Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16), + Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16), + Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16), + Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8), + Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8), + Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16), + Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16), + Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16), + }; + + block.body.push(match math_op { + MathOp::Ext(op) => Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + op, + result_type_id, + id, + &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()], + ), + MathOp::Custom(inst) => inst, + }); + id + } + crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id, + crate::Expression::Load { pointer } => { + match self.write_expression_pointer(pointer, block, None)? { + ExpressionPointer::Ready { pointer_id } => { + let id = self.gen_id(); + let atomic_space = + match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Pointer { base, space } => { + match self.ir_module.types[base].inner { + crate::TypeInner::Atomic { .. } => Some(space), + _ => None, + } + } + _ => None, + }; + let instruction = if let Some(space) = atomic_space { + let (semantics, scope) = space.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + Instruction::atomic_load( + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + ) + } else { + Instruction::load(result_type_id, id, pointer_id, None) + }; + block.body.push(instruction); + id + } + ExpressionPointer::Conditional { condition, access } => { + //TODO: support atomics? + self.write_conditional_indexed_load( + result_type_id, + condition, + block, + move |id_gen, block| { + // The in-bounds path. Perform the access and the load. + let pointer_id = access.result_id.unwrap(); + let value_id = id_gen.next(); + block.body.push(access); + block.body.push(Instruction::load( + result_type_id, + value_id, + pointer_id, + None, + )); + value_id + }, + ) + } + } + } + crate::Expression::FunctionArgument(index) => self.function.parameter_id(index), + crate::Expression::CallResult(_) + | crate::Expression::AtomicResult { .. } + | crate::Expression::WorkGroupUniformLoadResult { .. } + | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], + crate::Expression::As { + expr, + kind, + convert, + } => { + use crate::ScalarKind as Sk; + + let expr_id = self.cached[expr]; + let (src_scalar, src_size, is_matrix) = + match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Scalar(scalar) => (scalar, None, false), + crate::TypeInner::Vector { scalar, size } => (scalar, Some(size), false), + crate::TypeInner::Matrix { scalar, .. } => (scalar, None, true), + ref other => { + log::error!("As source {:?}", other); + return Err(Error::Validation("Unexpected Expression::As source")); + } + }; + + enum Cast { + Identity, + Unary(spirv::Op), + Binary(spirv::Op, Word), + Ternary(spirv::Op, Word, Word), + } + + let cast = if is_matrix { + // we only support identity casts for matrices + Cast::Unary(spirv::Op::CopyObject) + } else { + match (src_scalar.kind, kind, convert) { + // Filter out identity casts. Some Adreno drivers are + // confused by no-op OpBitCast instructions. + (src_kind, kind, convert) + if src_kind == kind + && convert.filter(|&width| width != src_scalar.width).is_none() => + { + Cast::Identity + } + (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject), + (_, _, None) => Cast::Unary(spirv::Op::Bitcast), + // casting to a bool - generate `OpXxxNotEqual` + (_, Sk::Bool, Some(_)) => { + let op = match src_scalar.kind { + Sk::Sint | Sk::Uint => spirv::Op::INotEqual, + Sk::Float => spirv::Op::FUnordNotEqual, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(), + }; + let zero_scalar_id = + self.writer.get_constant_scalar_with(0, src_scalar)?; + let zero_id = match src_size { + Some(size) => { + let ty = LocalType::Value { + vector_size: Some(size), + scalar: src_scalar, + pointer_space: None, + } + .into(); + + self.temp_list.clear(); + self.temp_list.resize(size as _, zero_scalar_id); + + self.writer.get_constant_composite(ty, &self.temp_list) + } + None => zero_scalar_id, + }; + + Cast::Binary(op, zero_id) + } + // casting from a bool - generate `OpSelect` + (Sk::Bool, _, Some(dst_width)) => { + let dst_scalar = crate::Scalar { + kind, + width: dst_width, + }; + let zero_scalar_id = + self.writer.get_constant_scalar_with(0, dst_scalar)?; + let one_scalar_id = + self.writer.get_constant_scalar_with(1, dst_scalar)?; + let (accept_id, reject_id) = match src_size { + Some(size) => { + let ty = LocalType::Value { + vector_size: Some(size), + scalar: dst_scalar, + pointer_space: None, + } + .into(); + + self.temp_list.clear(); + self.temp_list.resize(size as _, zero_scalar_id); + + let vec0_id = + self.writer.get_constant_composite(ty, &self.temp_list); + + self.temp_list.fill(one_scalar_id); + + let vec1_id = + self.writer.get_constant_composite(ty, &self.temp_list); + + (vec1_id, vec0_id) + } + None => (one_scalar_id, zero_scalar_id), + }; + + Cast::Ternary(spirv::Op::Select, accept_id, reject_id) + } + (Sk::Float, Sk::Uint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToU), + (Sk::Float, Sk::Sint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToS), + (Sk::Float, Sk::Float, Some(dst_width)) + if src_scalar.width != dst_width => + { + Cast::Unary(spirv::Op::FConvert) + } + (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF), + (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => { + Cast::Unary(spirv::Op::SConvert) + } + (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF), + (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => { + Cast::Unary(spirv::Op::UConvert) + } + // We assume it's either an identity cast, or int-uint. + _ => Cast::Unary(spirv::Op::Bitcast), + } + }; + + let id = self.gen_id(); + let instruction = match cast { + Cast::Identity => None, + Cast::Unary(op) => Some(Instruction::unary(op, result_type_id, id, expr_id)), + Cast::Binary(op, operand) => Some(Instruction::binary( + op, + result_type_id, + id, + expr_id, + operand, + )), + Cast::Ternary(op, op1, op2) => Some(Instruction::ternary( + op, + result_type_id, + id, + expr_id, + op1, + op2, + )), + }; + if let Some(instruction) = instruction { + block.body.push(instruction); + id + } else { + expr_id + } + } + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => self.write_image_load( + result_type_id, + image, + coordinate, + array_index, + level, + sample, + block, + )?, + crate::Expression::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + } => self.write_image_sample( + result_type_id, + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + block, + )?, + crate::Expression::Select { + condition, + accept, + reject, + } => { + let id = self.gen_id(); + let mut condition_id = self.cached[condition]; + let accept_id = self.cached[accept]; + let reject_id = self.cached[reject]; + + let condition_ty = self.fun_info[condition] + .ty + .inner_with(&self.ir_module.types); + let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types); + + if let ( + &crate::TypeInner::Scalar( + condition_scalar @ crate::Scalar { + kind: crate::ScalarKind::Bool, + .. + }, + ), + &crate::TypeInner::Vector { size, .. }, + ) = (condition_ty, object_ty) + { + self.temp_list.clear(); + self.temp_list.resize(size as usize, condition_id); + + let bool_vector_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(size), + scalar: condition_scalar, + pointer_space: None, + })); + + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + bool_vector_type_id, + id, + &self.temp_list, + )); + condition_id = id + } + + let instruction = + Instruction::select(result_type_id, id, condition_id, accept_id, reject_id); + block.body.push(instruction); + id + } + crate::Expression::Derivative { axis, ctrl, expr } => { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + match ctrl { + Ctrl::Coarse | Ctrl::Fine => { + self.writer.require_any( + "DerivativeControl", + &[spirv::Capability::DerivativeControl], + )?; + } + Ctrl::None => {} + } + let id = self.gen_id(); + let expr_id = self.cached[expr]; + let op = match (axis, ctrl) { + (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse, + (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine, + (Axis::X, Ctrl::None) => spirv::Op::DPdx, + (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse, + (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine, + (Axis::Y, Ctrl::None) => spirv::Op::DPdy, + (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse, + (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine, + (Axis::Width, Ctrl::None) => spirv::Op::Fwidth, + }; + block + .body + .push(Instruction::derivative(op, result_type_id, id, expr_id)); + id + } + crate::Expression::ImageQuery { image, query } => { + self.write_image_query(result_type_id, image, query, block)? + } + crate::Expression::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + let arg_id = self.cached[argument]; + let op = match fun { + Rf::All => spirv::Op::All, + Rf::Any => spirv::Op::Any, + Rf::IsNan => spirv::Op::IsNan, + Rf::IsInf => spirv::Op::IsInf, + }; + let id = self.gen_id(); + block + .body + .push(Instruction::relational(op, result_type_id, id, arg_id)); + id + } + crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, + crate::Expression::RayQueryGetIntersection { query, committed } => { + if !committed { + return Err(Error::FeatureNotImplemented("candidate intersection")); + } + self.write_ray_query_get_intersection(query, block) + } + }; + + self.cached[expr_handle] = id; + Ok(()) + } + + /// Build an `OpAccessChain` instruction. + /// + /// Emit any needed bounds-checking expressions to `block`. + /// + /// Some cases we need to generate a different return type than what the IR gives us. + /// This is because pointers to binding arrays of handles (such as images or samplers) + /// don't exist in the IR, but we need to create them to create an access chain in SPIRV. + /// + /// On success, the return value is an [`ExpressionPointer`] value; see the + /// documentation for that type. + fn write_expression_pointer( + &mut self, + mut expr_handle: Handle<crate::Expression>, + block: &mut Block, + return_type_override: Option<LookupType>, + ) -> Result<ExpressionPointer, Error> { + let result_lookup_ty = match self.fun_info[expr_handle].ty { + TypeResolution::Handle(ty_handle) => match return_type_override { + // We use the return type override as a special case for handle binding arrays as the OpAccessChain + // needs to return a pointer, but indexing into a handle binding array just gives you the type of + // the binding in the IR. + Some(ty) => ty, + None => LookupType::Handle(ty_handle), + }, + TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()), + }; + let result_type_id = self.get_type_id(result_lookup_ty); + + // The id of the boolean `and` of all dynamic bounds checks up to this point. If + // `None`, then we haven't done any dynamic bounds checks yet. + // + // When we have a chain of bounds checks, we combine them with `OpLogicalAnd`, not + // a short-circuit branch. This means we might do comparisons we don't need to, + // but we expect these checks to almost always succeed, and keeping branches to a + // minimum is essential. + let mut accumulated_checks = None; + // Is true if we are accessing into a binding array with a non-uniform index. + let mut is_non_uniform_binding_array = false; + + self.temp_list.clear(); + let root_id = loop { + expr_handle = match self.ir_function.expressions[expr_handle] { + crate::Expression::Access { base, index } => { + if let crate::Expression::GlobalVariable(var_handle) = + self.ir_function.expressions[base] + { + // The access chain needs to be decorated as NonUniform + // see VUID-RuntimeSpirv-NonUniform-06274 + let gvar = &self.ir_module.global_variables[var_handle]; + if let crate::TypeInner::BindingArray { .. } = + self.ir_module.types[gvar.ty].inner + { + is_non_uniform_binding_array = + self.fun_info[index].uniformity.non_uniform_result.is_some(); + } + } + + let index_id = match self.write_bounds_check(base, index, block)? { + BoundsCheckResult::KnownInBounds(known_index) => { + // Even if the index is known, `OpAccessIndex` + // requires expression operands, not literals. + let scalar = crate::Literal::U32(known_index); + self.writer.get_constant_scalar(scalar) + } + BoundsCheckResult::Computed(computed_index_id) => computed_index_id, + BoundsCheckResult::Conditional(comparison_id) => { + match accumulated_checks { + Some(prior_checks) => { + let combined = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + self.writer.get_bool_type_id(), + combined, + prior_checks, + comparison_id, + )); + accumulated_checks = Some(combined); + } + None => { + // Start a fresh chain of checks. + accumulated_checks = Some(comparison_id); + } + } + + // Either way, the index to use is unchanged. + self.cached[index] + } + }; + self.temp_list.push(index_id); + base + } + crate::Expression::AccessIndex { base, index } => { + let const_id = self.get_index_constant(index); + self.temp_list.push(const_id); + base + } + crate::Expression::GlobalVariable(handle) => { + let gv = &self.writer.global_variables[handle.index()]; + break gv.access_id; + } + crate::Expression::LocalVariable(variable) => { + let local_var = &self.function.variables[&variable]; + break local_var.id; + } + crate::Expression::FunctionArgument(index) => { + break self.function.parameter_id(index); + } + ref other => unimplemented!("Unexpected pointer expression {:?}", other), + } + }; + + let (pointer_id, expr_pointer) = if self.temp_list.is_empty() { + ( + root_id, + ExpressionPointer::Ready { + pointer_id: root_id, + }, + ) + } else { + self.temp_list.reverse(); + let pointer_id = self.gen_id(); + let access = + Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list); + + // If we generated some bounds checks, we need to leave it to our + // caller to generate the branch, the access, the load or store, and + // the zero value (for loads). Otherwise, we can emit the access + // ourselves, and just hand them the id of the pointer. + let expr_pointer = match accumulated_checks { + Some(condition) => ExpressionPointer::Conditional { condition, access }, + None => { + block.body.push(access); + ExpressionPointer::Ready { pointer_id } + } + }; + (pointer_id, expr_pointer) + }; + // Subsequent load, store and atomic operations require the pointer to be decorated as NonUniform + // if the binding array was accessed with a non-uniform index + // see VUID-RuntimeSpirv-NonUniform-06274 + if is_non_uniform_binding_array { + self.writer + .decorate_non_uniform_binding_array_access(pointer_id)?; + } + + Ok(expr_pointer) + } + + /// Build the instructions for matrix - matrix column operations + #[allow(clippy::too_many_arguments)] + fn write_matrix_matrix_column_op( + &mut self, + block: &mut Block, + result_id: Word, + result_type_id: Word, + left_id: Word, + right_id: Word, + columns: crate::VectorSize, + rows: crate::VectorSize, + width: u8, + op: spirv::Op, + ) { + self.temp_list.clear(); + + let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(rows), + scalar: crate::Scalar::float(width), + pointer_space: None, + })); + + for index in 0..columns as u32 { + let column_id_left = self.gen_id(); + let column_id_right = self.gen_id(); + let column_id_res = self.gen_id(); + + block.body.push(Instruction::composite_extract( + vector_type_id, + column_id_left, + left_id, + &[index], + )); + block.body.push(Instruction::composite_extract( + vector_type_id, + column_id_right, + right_id, + &[index], + )); + block.body.push(Instruction::binary( + op, + vector_type_id, + column_id_res, + column_id_left, + column_id_right, + )); + + self.temp_list.push(column_id_res); + } + + block.body.push(Instruction::composite_construct( + result_type_id, + result_id, + &self.temp_list, + )); + } + + /// Build the instructions for vector - scalar multiplication + fn write_vector_scalar_mult( + &mut self, + block: &mut Block, + result_id: Word, + result_type_id: Word, + vector_id: Word, + scalar_id: Word, + vector: &crate::TypeInner, + ) { + let (size, kind) = match *vector { + crate::TypeInner::Vector { + size, + scalar: crate::Scalar { kind, .. }, + } => (size, kind), + _ => unreachable!(), + }; + + let (op, operand_id) = match kind { + crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id), + _ => { + let operand_id = self.gen_id(); + self.temp_list.clear(); + self.temp_list.resize(size as usize, scalar_id); + block.body.push(Instruction::composite_construct( + result_type_id, + operand_id, + &self.temp_list, + )); + (spirv::Op::IMul, operand_id) + } + }; + + block.body.push(Instruction::binary( + op, + result_type_id, + result_id, + vector_id, + operand_id, + )); + } + + /// Build the instructions for the arithmetic expression of a dot product + fn write_dot_product( + &mut self, + result_id: Word, + result_type_id: Word, + arg0_id: Word, + arg1_id: Word, + size: u32, + block: &mut Block, + ) { + let mut partial_sum = self.writer.get_constant_null(result_type_id); + let last_component = size - 1; + for index in 0..=last_component { + // compute the product of the current components + let a_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + a_id, + arg0_id, + &[index], + )); + let b_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + b_id, + arg1_id, + &[index], + )); + let prod_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::IMul, + result_type_id, + prod_id, + a_id, + b_id, + )); + + // choose the id for the next sum, depending on current index + let id = if index == last_component { + result_id + } else { + self.gen_id() + }; + + // sum the computed product with the partial sum + block.body.push(Instruction::binary( + spirv::Op::IAdd, + result_type_id, + id, + partial_sum, + prod_id, + )); + // set the id of the result as the previous partial sum + partial_sum = id; + } + } + + pub(super) fn write_block( + &mut self, + label_id: Word, + naga_block: &crate::Block, + exit: BlockExit, + loop_context: LoopContext, + debug_info: Option<&DebugInfoInner>, + ) -> Result<(), Error> { + let mut block = Block::new(label_id); + for (statement, span) in naga_block.span_iter() { + if let (Some(debug_info), false) = ( + debug_info, + matches!( + statement, + &(Statement::Block(..) + | Statement::Break + | Statement::Continue + | Statement::Kill + | Statement::Return { .. } + | Statement::Loop { .. }) + ), + ) { + let loc: crate::SourceLocation = span.location(debug_info.source_code); + block.body.push(Instruction::line( + debug_info.source_file_id, + loc.line_number, + loc.line_position, + )); + }; + match *statement { + crate::Statement::Emit(ref range) => { + for handle in range.clone() { + // omit const expressions as we've already cached those + if !self.expression_constness.is_const(handle) { + self.cache_expression_value(handle, &mut block)?; + } + } + } + crate::Statement::Block(ref block_statements) => { + let scope_id = self.gen_id(); + self.function.consume(block, Instruction::branch(scope_id)); + + let merge_id = self.gen_id(); + self.write_block( + scope_id, + block_statements, + BlockExit::Branch { target: merge_id }, + loop_context, + debug_info, + )?; + + block = Block::new(merge_id); + } + crate::Statement::If { + condition, + ref accept, + ref reject, + } => { + let condition_id = self.cached[condition]; + + let merge_id = self.gen_id(); + block.body.push(Instruction::selection_merge( + merge_id, + spirv::SelectionControl::NONE, + )); + + let accept_id = if accept.is_empty() { + None + } else { + Some(self.gen_id()) + }; + let reject_id = if reject.is_empty() { + None + } else { + Some(self.gen_id()) + }; + + self.function.consume( + block, + Instruction::branch_conditional( + condition_id, + accept_id.unwrap_or(merge_id), + reject_id.unwrap_or(merge_id), + ), + ); + + if let Some(block_id) = accept_id { + self.write_block( + block_id, + accept, + BlockExit::Branch { target: merge_id }, + loop_context, + debug_info, + )?; + } + if let Some(block_id) = reject_id { + self.write_block( + block_id, + reject, + BlockExit::Branch { target: merge_id }, + loop_context, + debug_info, + )?; + } + + block = Block::new(merge_id); + } + crate::Statement::Switch { + selector, + ref cases, + } => { + let selector_id = self.cached[selector]; + + let merge_id = self.gen_id(); + block.body.push(Instruction::selection_merge( + merge_id, + spirv::SelectionControl::NONE, + )); + + let mut default_id = None; + // id of previous empty fall-through case + let mut last_id = None; + + let mut raw_cases = Vec::with_capacity(cases.len()); + let mut case_ids = Vec::with_capacity(cases.len()); + for case in cases.iter() { + // take id of previous empty fall-through case or generate a new one + let label_id = last_id.take().unwrap_or_else(|| self.gen_id()); + + if case.fall_through && case.body.is_empty() { + last_id = Some(label_id); + } + + case_ids.push(label_id); + + match case.value { + crate::SwitchValue::I32(value) => { + raw_cases.push(super::instructions::Case { + value: value as Word, + label_id, + }); + } + crate::SwitchValue::U32(value) => { + raw_cases.push(super::instructions::Case { value, label_id }); + } + crate::SwitchValue::Default => { + default_id = Some(label_id); + } + } + } + + let default_id = default_id.unwrap(); + + self.function.consume( + block, + Instruction::switch(selector_id, default_id, &raw_cases), + ); + + let inner_context = LoopContext { + break_id: Some(merge_id), + ..loop_context + }; + + for (i, (case, label_id)) in cases + .iter() + .zip(case_ids.iter()) + .filter(|&(case, _)| !(case.fall_through && case.body.is_empty())) + .enumerate() + { + let case_finish_id = if case.fall_through { + case_ids[i + 1] + } else { + merge_id + }; + self.write_block( + *label_id, + &case.body, + BlockExit::Branch { + target: case_finish_id, + }, + inner_context, + debug_info, + )?; + } + + block = Block::new(merge_id); + } + crate::Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + let preamble_id = self.gen_id(); + self.function + .consume(block, Instruction::branch(preamble_id)); + + let merge_id = self.gen_id(); + let body_id = self.gen_id(); + let continuing_id = self.gen_id(); + + // SPIR-V requires the continuing to the `OpLoopMerge`, + // so we have to start a new block with it. + block = Block::new(preamble_id); + // HACK the loop statement is begin with branch instruction, + // so we need to put `OpLine` debug info before merge instruction + if let Some(debug_info) = debug_info { + let loc: crate::SourceLocation = span.location(debug_info.source_code); + block.body.push(Instruction::line( + debug_info.source_file_id, + loc.line_number, + loc.line_position, + )) + } + block.body.push(Instruction::loop_merge( + merge_id, + continuing_id, + spirv::SelectionControl::NONE, + )); + self.function.consume(block, Instruction::branch(body_id)); + + self.write_block( + body_id, + body, + BlockExit::Branch { + target: continuing_id, + }, + LoopContext { + continuing_id: Some(continuing_id), + break_id: Some(merge_id), + }, + debug_info, + )?; + + let exit = match break_if { + Some(condition) => BlockExit::BreakIf { + condition, + preamble_id, + }, + None => BlockExit::Branch { + target: preamble_id, + }, + }; + + self.write_block( + continuing_id, + continuing, + exit, + LoopContext { + continuing_id: None, + break_id: Some(merge_id), + }, + debug_info, + )?; + + block = Block::new(merge_id); + } + crate::Statement::Break => { + self.function + .consume(block, Instruction::branch(loop_context.break_id.unwrap())); + return Ok(()); + } + crate::Statement::Continue => { + self.function.consume( + block, + Instruction::branch(loop_context.continuing_id.unwrap()), + ); + return Ok(()); + } + crate::Statement::Return { value: Some(value) } => { + let value_id = self.cached[value]; + let instruction = match self.function.entry_point_context { + // If this is an entry point, and we need to return anything, + // let's instead store the output variables and return `void`. + Some(ref context) => { + self.writer.write_entry_point_return( + value_id, + self.ir_function.result.as_ref().unwrap(), + &context.results, + &mut block.body, + )?; + Instruction::return_void() + } + None => Instruction::return_value(value_id), + }; + self.function.consume(block, instruction); + return Ok(()); + } + crate::Statement::Return { value: None } => { + self.function.consume(block, Instruction::return_void()); + return Ok(()); + } + crate::Statement::Kill => { + self.function.consume(block, Instruction::kill()); + return Ok(()); + } + crate::Statement::Barrier(flags) => { + self.writer.write_barrier(flags, &mut block); + } + crate::Statement::Store { pointer, value } => { + let value_id = self.cached[value]; + match self.write_expression_pointer(pointer, &mut block, None)? { + ExpressionPointer::Ready { pointer_id } => { + let atomic_space = match *self.fun_info[pointer] + .ty + .inner_with(&self.ir_module.types) + { + crate::TypeInner::Pointer { base, space } => { + match self.ir_module.types[base].inner { + crate::TypeInner::Atomic { .. } => Some(space), + _ => None, + } + } + _ => None, + }; + let instruction = if let Some(space) = atomic_space { + let (semantics, scope) = space.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + Instruction::atomic_store( + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } else { + Instruction::store(pointer_id, value_id, None) + }; + block.body.push(instruction); + } + ExpressionPointer::Conditional { condition, access } => { + let mut selection = Selection::start(&mut block, ()); + selection.if_true(self, condition, ()); + + // The in-bounds path. Perform the access and the store. + let pointer_id = access.result_id.unwrap(); + selection.block().body.push(access); + selection + .block() + .body + .push(Instruction::store(pointer_id, value_id, None)); + + // Finish the in-bounds block and start the merge block. This + // is the block we'll leave current on return. + selection.finish(self, ()); + } + }; + } + crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => self.write_image_store(image, coordinate, array_index, value, &mut block)?, + crate::Statement::Call { + function: local_function, + ref arguments, + result, + } => { + let id = self.gen_id(); + self.temp_list.clear(); + for &argument in arguments { + self.temp_list.push(self.cached[argument]); + } + + let type_id = match result { + Some(expr) => { + self.cached[expr] = id; + self.get_expression_type_id(&self.fun_info[expr].ty) + } + None => self.writer.void_type, + }; + + block.body.push(Instruction::function_call( + type_id, + id, + self.writer.lookup_function[&local_function], + &self.temp_list, + )); + } + crate::Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + let id = self.gen_id(); + let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); + + self.cached[result] = id; + + let pointer_id = + match self.write_expression_pointer(pointer, &mut block, None)? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. } => { + return Err(Error::FeatureNotImplemented( + "Atomics out-of-bounds handling", + )); + } + }; + + let space = self.fun_info[pointer] + .ty + .inner_with(&self.ir_module.types) + .pointer_space() + .unwrap(); + let (semantics, scope) = space.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + let value_id = self.cached[value]; + let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types); + + let instruction = match *fun { + crate::AtomicFunction::Add => Instruction::atomic_binary( + spirv::Op::AtomicIAdd, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::Subtract => Instruction::atomic_binary( + spirv::Op::AtomicISub, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::And => Instruction::atomic_binary( + spirv::Op::AtomicAnd, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary( + spirv::Op::AtomicOr, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary( + spirv::Op::AtomicXor, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::Min => { + let spirv_op = match *value_inner { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + }) => spirv::Op::AtomicSMin, + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + }) => spirv::Op::AtomicUMin, + _ => unimplemented!(), + }; + Instruction::atomic_binary( + spirv_op, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } + crate::AtomicFunction::Max => { + let spirv_op = match *value_inner { + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + }) => spirv::Op::AtomicSMax, + crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + }) => spirv::Op::AtomicUMax, + _ => unimplemented!(), + }; + Instruction::atomic_binary( + spirv_op, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } + crate::AtomicFunction::Exchange { compare: None } => { + Instruction::atomic_binary( + spirv::Op::AtomicExchange, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } + crate::AtomicFunction::Exchange { compare: Some(cmp) } => { + let scalar_type_id = match *value_inner { + crate::TypeInner::Scalar(scalar) => { + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + })) + } + _ => unimplemented!(), + }; + let bool_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::BOOL, + pointer_space: None, + })); + + let cas_result_id = self.gen_id(); + let equality_result_id = self.gen_id(); + let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange); + cas_instr.set_type(scalar_type_id); + cas_instr.set_result(cas_result_id); + cas_instr.add_operand(pointer_id); + cas_instr.add_operand(scope_constant_id); + cas_instr.add_operand(semantics_id); // semantics if equal + cas_instr.add_operand(semantics_id); // semantics if not equal + cas_instr.add_operand(value_id); + cas_instr.add_operand(self.cached[cmp]); + block.body.push(cas_instr); + block.body.push(Instruction::binary( + spirv::Op::IEqual, + bool_type_id, + equality_result_id, + cas_result_id, + self.cached[cmp], + )); + Instruction::composite_construct( + result_type_id, + id, + &[cas_result_id, equality_result_id], + ) + } + }; + + block.body.push(instruction); + } + crate::Statement::WorkGroupUniformLoad { pointer, result } => { + self.writer + .write_barrier(crate::Barrier::WORK_GROUP, &mut block); + let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); + // Embed the body of + match self.write_expression_pointer(pointer, &mut block, None)? { + ExpressionPointer::Ready { pointer_id } => { + let id = self.gen_id(); + block.body.push(Instruction::load( + result_type_id, + id, + pointer_id, + None, + )); + self.cached[result] = id; + } + ExpressionPointer::Conditional { condition, access } => { + self.cached[result] = self.write_conditional_indexed_load( + result_type_id, + condition, + &mut block, + move |id_gen, block| { + // The in-bounds path. Perform the access and the load. + let pointer_id = access.result_id.unwrap(); + let value_id = id_gen.next(); + block.body.push(access); + block.body.push(Instruction::load( + result_type_id, + value_id, + pointer_id, + None, + )); + value_id + }, + ) + } + } + self.writer + .write_barrier(crate::Barrier::WORK_GROUP, &mut block); + } + crate::Statement::RayQuery { query, ref fun } => { + self.write_ray_query_function(query, fun, &mut block); + } + } + } + + let termination = match exit { + // We're generating code for the top-level Block of the function, so we + // need to end it with some kind of return instruction. + BlockExit::Return => match self.ir_function.result { + Some(ref result) if self.function.entry_point_context.is_none() => { + let type_id = self.get_type_id(LookupType::Handle(result.ty)); + let null_id = self.writer.get_constant_null(type_id); + Instruction::return_value(null_id) + } + _ => Instruction::return_void(), + }, + BlockExit::Branch { target } => Instruction::branch(target), + BlockExit::BreakIf { + condition, + preamble_id, + } => { + let condition_id = self.cached[condition]; + + Instruction::branch_conditional( + condition_id, + loop_context.break_id.unwrap(), + preamble_id, + ) + } + }; + + self.function.consume(block, termination); + Ok(()) + } +} diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs new file mode 100644 index 0000000000..5b6226db85 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/helpers.rs @@ -0,0 +1,109 @@ +use crate::{Handle, UniqueArena}; +use spirv::Word; + +pub(super) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> { + bytes + .chunks(4) + .map(|chars| chars.iter().rev().fold(0u32, |u, c| (u << 8) | *c as u32)) + .collect() +} + +pub(super) fn string_to_words(input: &str) -> Vec<Word> { + let bytes = input.as_bytes(); + let mut words = bytes_to_words(bytes); + + if bytes.len() % 4 == 0 { + // nul-termination + words.push(0x0u32); + } + + words +} + +pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::StorageClass { + match space { + crate::AddressSpace::Handle => spirv::StorageClass::UniformConstant, + crate::AddressSpace::Function => spirv::StorageClass::Function, + crate::AddressSpace::Private => spirv::StorageClass::Private, + crate::AddressSpace::Storage { .. } => spirv::StorageClass::StorageBuffer, + crate::AddressSpace::Uniform => spirv::StorageClass::Uniform, + crate::AddressSpace::WorkGroup => spirv::StorageClass::Workgroup, + crate::AddressSpace::PushConstant => spirv::StorageClass::PushConstant, + } +} + +pub(super) fn contains_builtin( + binding: Option<&crate::Binding>, + ty: Handle<crate::Type>, + arena: &UniqueArena<crate::Type>, + built_in: crate::BuiltIn, +) -> bool { + if let Some(&crate::Binding::BuiltIn(bi)) = binding { + bi == built_in + } else if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner { + members + .iter() + .any(|member| contains_builtin(member.binding.as_ref(), member.ty, arena, built_in)) + } else { + false // unreachable + } +} + +impl crate::AddressSpace { + pub(super) const fn to_spirv_semantics_and_scope( + self, + ) -> (spirv::MemorySemantics, spirv::Scope) { + match self { + Self::Storage { .. } => (spirv::MemorySemantics::UNIFORM_MEMORY, spirv::Scope::Device), + Self::WorkGroup => ( + spirv::MemorySemantics::WORKGROUP_MEMORY, + spirv::Scope::Workgroup, + ), + _ => (spirv::MemorySemantics::empty(), spirv::Scope::Invocation), + } + } +} + +/// Return true if the global requires a type decorated with `Block`. +/// +/// Vulkan spec v1.3 §15.6.2, "Descriptor Set Interface", says: +/// +/// > Variables identified with the `Uniform` storage class are used to +/// > access transparent buffer backed resources. Such variables must +/// > be: +/// > +/// > - typed as `OpTypeStruct`, or an array of this type, +/// > +/// > - identified with a `Block` or `BufferBlock` decoration, and +/// > +/// > - laid out explicitly using the `Offset`, `ArrayStride`, and +/// > `MatrixStride` decorations as specified in §15.6.4, "Offset +/// > and Stride Assignment." +// See `back::spv::GlobalVariable::access_id` for details. +pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariable) -> bool { + match var.space { + crate::AddressSpace::Uniform + | crate::AddressSpace::Storage { .. } + | crate::AddressSpace::PushConstant => {} + _ => return false, + }; + match ir_module.types[var.ty].inner { + crate::TypeInner::Struct { + ref members, + span: _, + } => match members.last() { + Some(member) => match ir_module.types[member.ty].inner { + // Structs with dynamically sized arrays can't be copied and can't be wrapped. + crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } => false, + _ => true, + }, + None => false, + }, + crate::TypeInner::BindingArray { .. } => false, + // if it's not a structure or a binding array, let's wrap it to be able to put "Block" + _ => true, + } +} diff --git a/third_party/rust/naga/src/back/spv/image.rs b/third_party/rust/naga/src/back/spv/image.rs new file mode 100644 index 0000000000..c0fc41cbb6 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/image.rs @@ -0,0 +1,1210 @@ +/*! +Generating SPIR-V for image operations. +*/ + +use super::{ + selection::{MergeTuple, Selection}, + Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType, +}; +use crate::arena::Handle; +use spirv::Word; + +/// Information about a vector of coordinates. +/// +/// The coordinate vectors expected by SPIR-V `OpImageRead` and `OpImageFetch` +/// supply the array index for arrayed images as an additional component at +/// the end, whereas Naga's `ImageLoad`, `ImageStore`, and `ImageSample` carry +/// the array index as a separate field. +/// +/// In the process of generating code to compute the combined vector, we also +/// produce SPIR-V types and vector lengths that are useful elsewhere. This +/// struct gathers that information into one place, with standard names. +struct ImageCoordinates { + /// The SPIR-V id of the combined coordinate/index vector value. + /// + /// Note: when indexing a non-arrayed 1D image, this will be a scalar. + value_id: Word, + + /// The SPIR-V id of the type of `value`. + type_id: Word, + + /// The number of components in `value`, if it is a vector, or `None` if it + /// is a scalar. + size: Option<crate::VectorSize>, +} + +/// A trait for image access (load or store) code generators. +/// +/// Types implementing this trait hold information about an `ImageStore` or +/// `ImageLoad` operation that is not affected by the bounds check policy. The +/// `generate` method emits code for the access, given the results of bounds +/// checking. +/// +/// The [`image`] bounds checks policy affects access coordinates, level of +/// detail, and sample index, but never the image id, result type (if any), or +/// the specific SPIR-V instruction used. Types that implement this trait gather +/// together the latter category, so we don't have to plumb them through the +/// bounds-checking code. +/// +/// [`image`]: crate::proc::BoundsCheckPolicies::index +trait Access { + /// The Rust type that represents SPIR-V values and types for this access. + /// + /// For operations like loads, this is `Word`. For operations like stores, + /// this is `()`. + /// + /// For `ReadZeroSkipWrite`, this will be the type of the selection + /// construct that performs the bounds checks, so it must implement + /// `MergeTuple`. + type Output: MergeTuple + Copy + Clone; + + /// Write an image access to `block`. + /// + /// Access the texel at `coordinates_id`. The optional `level_id` indicates + /// the level of detail, and `sample_id` is the index of the sample to + /// access in a multisampled texel. + /// + /// This method assumes that `coordinates_id` has already had the image array + /// index, if any, folded in, as done by `write_image_coordinates`. + /// + /// Return the value id produced by the instruction, if any. + /// + /// Use `id_gen` to generate SPIR-V ids as necessary. + fn generate( + &self, + id_gen: &mut IdGenerator, + coordinates_id: Word, + level_id: Option<Word>, + sample_id: Option<Word>, + block: &mut Block, + ) -> Self::Output; + + /// Return the SPIR-V type of the value produced by the code written by + /// `generate`. If the access does not produce a value, `Self::Output` + /// should be `()`. + fn result_type(&self) -> Self::Output; + + /// Construct the SPIR-V 'zero' value to be returned for an out-of-bounds + /// access under the `ReadZeroSkipWrite` policy. If the access does not + /// produce a value, `Self::Output` should be `()`. + fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Self::Output; +} + +/// Texel access information for an [`ImageLoad`] expression. +/// +/// [`ImageLoad`]: crate::Expression::ImageLoad +struct Load { + /// The specific opcode we'll use to perform the fetch. Storage images + /// require `OpImageRead`, while sampled images require `OpImageFetch`. + opcode: spirv::Op, + + /// The type id produced by the actual image access instruction. + type_id: Word, + + /// The id of the image being accessed. + image_id: Word, +} + +impl Load { + fn from_image_expr( + ctx: &mut BlockContext<'_>, + image_id: Word, + image_class: crate::ImageClass, + result_type_id: Word, + ) -> Result<Load, Error> { + let opcode = match image_class { + crate::ImageClass::Storage { .. } => spirv::Op::ImageRead, + crate::ImageClass::Depth { .. } | crate::ImageClass::Sampled { .. } => { + spirv::Op::ImageFetch + } + }; + + // `OpImageRead` and `OpImageFetch` instructions produce vec4<f32> + // values. Most of the time, we can just use `result_type_id` for + // this. The exception is that `Expression::ImageLoad` from a depth + // image produces a scalar `f32`, so in that case we need to find + // the right SPIR-V type for the access instruction here. + let type_id = match image_class { + crate::ImageClass::Depth { .. } => { + ctx.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + scalar: crate::Scalar::F32, + pointer_space: None, + })) + } + _ => result_type_id, + }; + + Ok(Load { + opcode, + type_id, + image_id, + }) + } +} + +impl Access for Load { + type Output = Word; + + /// Write an instruction to access a given texel of this image. + fn generate( + &self, + id_gen: &mut IdGenerator, + coordinates_id: Word, + level_id: Option<Word>, + sample_id: Option<Word>, + block: &mut Block, + ) -> Word { + let texel_id = id_gen.next(); + let mut instruction = Instruction::image_fetch_or_read( + self.opcode, + self.type_id, + texel_id, + self.image_id, + coordinates_id, + ); + + match (level_id, sample_id) { + (None, None) => {} + (Some(level_id), None) => { + instruction.add_operand(spirv::ImageOperands::LOD.bits()); + instruction.add_operand(level_id); + } + (None, Some(sample_id)) => { + instruction.add_operand(spirv::ImageOperands::SAMPLE.bits()); + instruction.add_operand(sample_id); + } + // There's no such thing as a multi-sampled mipmap. + (Some(_), Some(_)) => unreachable!(), + } + + block.body.push(instruction); + + texel_id + } + + fn result_type(&self) -> Word { + self.type_id + } + + fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Word { + ctx.writer.get_constant_null(self.type_id) + } +} + +/// Texel access information for a [`Store`] statement. +/// +/// [`Store`]: crate::Statement::Store +struct Store { + /// The id of the image being written to. + image_id: Word, + + /// The value we're going to write to the texel. + value_id: Word, +} + +impl Access for Store { + /// Stores don't generate any value. + type Output = (); + + fn generate( + &self, + _id_gen: &mut IdGenerator, + coordinates_id: Word, + _level_id: Option<Word>, + _sample_id: Option<Word>, + block: &mut Block, + ) { + block.body.push(Instruction::image_write( + self.image_id, + coordinates_id, + self.value_id, + )); + } + + /// Stores don't generate any value, so this just returns `()`. + fn result_type(&self) {} + + /// Stores don't generate any value, so this just returns `()`. + fn out_of_bounds_value(&self, _ctx: &mut BlockContext<'_>) {} +} + +impl<'w> BlockContext<'w> { + /// Extend image coordinates with an array index, if necessary. + /// + /// Whereas [`Expression::ImageLoad`] and [`ImageSample`] treat the array + /// index as a separate operand from the coordinates, SPIR-V image access + /// instructions include the array index in the `coordinates` operand. This + /// function builds a SPIR-V coordinate vector from a Naga coordinate vector + /// and array index, if one is supplied, and returns a `ImageCoordinates` + /// struct describing what it built. + /// + /// If `array_index` is `Some(expr)`, then this function constructs a new + /// vector that is `coordinates` with `array_index` concatenated onto the + /// end: a `vec2` becomes a `vec3`, a scalar becomes a `vec2`, and so on. + /// + /// If `array_index` is `None`, then the return value uses `coordinates` + /// unchanged. Note that, when indexing a non-arrayed 1D image, this will be + /// a scalar value. + /// + /// If needed, this function generates code to convert the array index, + /// always an integer scalar, to match the component type of `coordinates`. + /// Naga's `ImageLoad` and SPIR-V's `OpImageRead`, `OpImageFetch`, and + /// `OpImageWrite` all use integer coordinates, while Naga's `ImageSample` + /// and SPIR-V's `OpImageSample...` instructions all take floating-point + /// coordinate vectors. + /// + /// [`Expression::ImageLoad`]: crate::Expression::ImageLoad + /// [`ImageSample`]: crate::Expression::ImageSample + fn write_image_coordinates( + &mut self, + coordinates: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + block: &mut Block, + ) -> Result<ImageCoordinates, Error> { + use crate::TypeInner as Ti; + use crate::VectorSize as Vs; + + let coordinates_id = self.cached[coordinates]; + let ty = &self.fun_info[coordinates].ty; + let inner_ty = ty.inner_with(&self.ir_module.types); + + // If there's no array index, the image coordinates are exactly the + // `coordinate` field of the `Expression::ImageLoad`. No work is needed. + let array_index = match array_index { + None => { + let value_id = coordinates_id; + let type_id = self.get_expression_type_id(ty); + let size = match *inner_ty { + Ti::Scalar { .. } => None, + Ti::Vector { size, .. } => Some(size), + _ => return Err(Error::Validation("coordinate type")), + }; + return Ok(ImageCoordinates { + value_id, + type_id, + size, + }); + } + Some(ix) => ix, + }; + + // Find the component type of `coordinates`, and figure out the size the + // combined coordinate vector will have. + let (component_scalar, size) = match *inner_ty { + Ti::Scalar(scalar @ crate::Scalar { width: 4, .. }) => (scalar, Some(Vs::Bi)), + Ti::Vector { + scalar: scalar @ crate::Scalar { width: 4, .. }, + size: Vs::Bi, + } => (scalar, Some(Vs::Tri)), + Ti::Vector { + scalar: scalar @ crate::Scalar { width: 4, .. }, + size: Vs::Tri, + } => (scalar, Some(Vs::Quad)), + Ti::Vector { size: Vs::Quad, .. } => { + return Err(Error::Validation("extending vec4 coordinate")); + } + ref other => { + log::error!("wrong coordinate type {:?}", other); + return Err(Error::Validation("coordinate type")); + } + }; + + // Convert the index to the coordinate component type, if necessary. + let array_index_id = self.cached[array_index]; + let ty = &self.fun_info[array_index].ty; + let inner_ty = ty.inner_with(&self.ir_module.types); + let array_index_scalar = match *inner_ty { + Ti::Scalar( + scalar @ crate::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + width: 4, + }, + ) => scalar, + _ => unreachable!("we only allow i32 and u32"), + }; + let cast = match (component_scalar.kind, array_index_scalar.kind) { + (crate::ScalarKind::Sint, crate::ScalarKind::Sint) + | (crate::ScalarKind::Uint, crate::ScalarKind::Uint) => None, + (crate::ScalarKind::Sint, crate::ScalarKind::Uint) + | (crate::ScalarKind::Uint, crate::ScalarKind::Sint) => Some(spirv::Op::Bitcast), + (crate::ScalarKind::Float, crate::ScalarKind::Sint) => Some(spirv::Op::ConvertSToF), + (crate::ScalarKind::Float, crate::ScalarKind::Uint) => Some(spirv::Op::ConvertUToF), + (crate::ScalarKind::Bool, _) => unreachable!("we don't allow bool for component"), + (_, crate::ScalarKind::Bool | crate::ScalarKind::Float) => { + unreachable!("we don't allow bool or float for array index") + } + (crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat, _) + | (_, crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat) => { + unreachable!("abstract types should never reach backends") + } + }; + let reconciled_array_index_id = if let Some(cast) = cast { + let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: component_scalar, + pointer_space: None, + })); + let reconciled_id = self.gen_id(); + block.body.push(Instruction::unary( + cast, + component_ty_id, + reconciled_id, + array_index_id, + )); + reconciled_id + } else { + array_index_id + }; + + // Find the SPIR-V type for the combined coordinates/index vector. + let type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: size, + scalar: component_scalar, + pointer_space: None, + })); + + // Schmear the coordinates and index together. + let value_id = self.gen_id(); + block.body.push(Instruction::composite_construct( + type_id, + value_id, + &[coordinates_id, reconciled_array_index_id], + )); + Ok(ImageCoordinates { + value_id, + type_id, + size, + }) + } + + pub(super) fn get_handle_id(&mut self, expr_handle: Handle<crate::Expression>) -> Word { + let id = match self.ir_function.expressions[expr_handle] { + crate::Expression::GlobalVariable(handle) => { + self.writer.global_variables[handle.index()].handle_id + } + crate::Expression::FunctionArgument(i) => { + self.function.parameters[i as usize].handle_id + } + crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => { + self.cached[expr_handle] + } + ref other => unreachable!("Unexpected image expression {:?}", other), + }; + + if id == 0 { + unreachable!( + "Image expression {:?} doesn't have a handle ID", + expr_handle + ); + } + + id + } + + /// Generate a vector or scalar 'one' for arithmetic on `coordinates`. + /// + /// If `coordinates` is a scalar, return a scalar one. Otherwise, return + /// a vector of ones. + fn write_coordinate_one(&mut self, coordinates: &ImageCoordinates) -> Result<Word, Error> { + let one = self.get_scope_constant(1); + match coordinates.size { + None => Ok(one), + Some(vector_size) => { + let ones = [one; 4]; + let id = self.gen_id(); + Instruction::constant_composite( + coordinates.type_id, + id, + &ones[..vector_size as usize], + ) + .to_words(&mut self.writer.logical_layout.declarations); + Ok(id) + } + } + } + + /// Generate code to restrict `input` to fall between zero and one less than + /// `size_id`. + /// + /// Both must be 32-bit scalar integer values, whose type is given by + /// `type_id`. The computed value is also of type `type_id`. + fn restrict_scalar( + &mut self, + type_id: Word, + input_id: Word, + size_id: Word, + block: &mut Block, + ) -> Result<Word, Error> { + let i32_one_id = self.get_scope_constant(1); + + // Subtract one from `size` to get the largest valid value. + let limit_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::ISub, + type_id, + limit_id, + size_id, + i32_one_id, + )); + + // Use an unsigned minimum, to handle both positive out-of-range values + // and negative values in a single instruction: negative values of + // `input_id` get treated as very large positive values. + let restricted_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + type_id, + restricted_id, + &[input_id, limit_id], + )); + + Ok(restricted_id) + } + + /// Write instructions to query the size of an image. + /// + /// This takes care of selecting the right instruction depending on whether + /// a level of detail parameter is present. + fn write_coordinate_bounds( + &mut self, + type_id: Word, + image_id: Word, + level_id: Option<Word>, + block: &mut Block, + ) -> Word { + let coordinate_bounds_id = self.gen_id(); + match level_id { + Some(level_id) => { + // A level of detail was provided, so fetch the image size for + // that level. + let mut inst = Instruction::image_query( + spirv::Op::ImageQuerySizeLod, + type_id, + coordinate_bounds_id, + image_id, + ); + inst.add_operand(level_id); + block.body.push(inst); + } + _ => { + // No level of detail was given. + block.body.push(Instruction::image_query( + spirv::Op::ImageQuerySize, + type_id, + coordinate_bounds_id, + image_id, + )); + } + } + + coordinate_bounds_id + } + + /// Write code to restrict coordinates for an image reference. + /// + /// First, clamp the level of detail or sample index to fall within bounds. + /// Then, obtain the image size, possibly using the clamped level of detail. + /// Finally, use an unsigned minimum instruction to force all coordinates + /// into range. + /// + /// Return a triple `(COORDS, LEVEL, SAMPLE)`, where `COORDS` is a coordinate + /// vector (including the array index, if any), `LEVEL` is an optional level + /// of detail, and `SAMPLE` is an optional sample index, all guaranteed to + /// be in-bounds for `image_id`. + /// + /// The result is usually a vector, but it is a scalar when indexing + /// non-arrayed 1D images. + fn write_restricted_coordinates( + &mut self, + image_id: Word, + coordinates: ImageCoordinates, + level_id: Option<Word>, + sample_id: Option<Word>, + block: &mut Block, + ) -> Result<(Word, Option<Word>, Option<Word>), Error> { + self.writer.require_any( + "the `Restrict` image bounds check policy", + &[spirv::Capability::ImageQuery], + )?; + + let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::I32, + pointer_space: None, + })); + + // If `level` is `Some`, clamp it to fall within bounds. This must + // happen first, because we'll use it to query the image size for + // clamping the actual coordinates. + let level_id = level_id + .map(|level_id| { + // Find the number of mipmap levels in this image. + let num_levels_id = self.gen_id(); + block.body.push(Instruction::image_query( + spirv::Op::ImageQueryLevels, + i32_type_id, + num_levels_id, + image_id, + )); + + self.restrict_scalar(i32_type_id, level_id, num_levels_id, block) + }) + .transpose()?; + + // If `sample_id` is `Some`, clamp it to fall within bounds. + let sample_id = sample_id + .map(|sample_id| { + // Find the number of samples per texel. + let num_samples_id = self.gen_id(); + block.body.push(Instruction::image_query( + spirv::Op::ImageQuerySamples, + i32_type_id, + num_samples_id, + image_id, + )); + + self.restrict_scalar(i32_type_id, sample_id, num_samples_id, block) + }) + .transpose()?; + + // Obtain the image bounds, including the array element count. + let coordinate_bounds_id = + self.write_coordinate_bounds(coordinates.type_id, image_id, level_id, block); + + // Compute maximum valid values from the bounds. + let ones = self.write_coordinate_one(&coordinates)?; + let coordinate_limit_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::ISub, + coordinates.type_id, + coordinate_limit_id, + coordinate_bounds_id, + ones, + )); + + // Restrict the coordinates to fall within those bounds. + // + // Use an unsigned minimum, to handle both positive out-of-range values + // and negative values in a single instruction: negative values of + // `coordinates` get treated as very large positive values. + let restricted_coordinates_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + coordinates.type_id, + restricted_coordinates_id, + &[coordinates.value_id, coordinate_limit_id], + )); + + Ok((restricted_coordinates_id, level_id, sample_id)) + } + + fn write_conditional_image_access<A: Access>( + &mut self, + image_id: Word, + coordinates: ImageCoordinates, + level_id: Option<Word>, + sample_id: Option<Word>, + block: &mut Block, + access: &A, + ) -> Result<A::Output, Error> { + self.writer.require_any( + "the `ReadZeroSkipWrite` image bounds check policy", + &[spirv::Capability::ImageQuery], + )?; + + let bool_type_id = self.writer.get_bool_type_id(); + let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::I32, + pointer_space: None, + })); + + let null_id = access.out_of_bounds_value(self); + + let mut selection = Selection::start(block, access.result_type()); + + // If `level_id` is `Some`, check whether it is within bounds. This must + // happen first, because we'll be supplying this as an argument when we + // query the image size. + if let Some(level_id) = level_id { + // Find the number of mipmap levels in this image. + let num_levels_id = self.gen_id(); + selection.block().body.push(Instruction::image_query( + spirv::Op::ImageQueryLevels, + i32_type_id, + num_levels_id, + image_id, + )); + + let lod_cond_id = self.gen_id(); + selection.block().body.push(Instruction::binary( + spirv::Op::ULessThan, + bool_type_id, + lod_cond_id, + level_id, + num_levels_id, + )); + + selection.if_true(self, lod_cond_id, null_id); + } + + // If `sample_id` is `Some`, check whether it is in bounds. + if let Some(sample_id) = sample_id { + // Find the number of samples per texel. + let num_samples_id = self.gen_id(); + selection.block().body.push(Instruction::image_query( + spirv::Op::ImageQuerySamples, + i32_type_id, + num_samples_id, + image_id, + )); + + let samples_cond_id = self.gen_id(); + selection.block().body.push(Instruction::binary( + spirv::Op::ULessThan, + bool_type_id, + samples_cond_id, + sample_id, + num_samples_id, + )); + + selection.if_true(self, samples_cond_id, null_id); + } + + // Obtain the image bounds, including any array element count. + let coordinate_bounds_id = self.write_coordinate_bounds( + coordinates.type_id, + image_id, + level_id, + selection.block(), + ); + + // Compare the coordinates against the bounds. + let coords_bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: coordinates.size, + scalar: crate::Scalar::BOOL, + pointer_space: None, + })); + let coords_conds_id = self.gen_id(); + selection.block().body.push(Instruction::binary( + spirv::Op::ULessThan, + coords_bool_type_id, + coords_conds_id, + coordinates.value_id, + coordinate_bounds_id, + )); + + // If the comparison above was a vector comparison, then we need to + // check that all components of the comparison are true. + let coords_cond_id = if coords_bool_type_id != bool_type_id { + let id = self.gen_id(); + selection.block().body.push(Instruction::relational( + spirv::Op::All, + bool_type_id, + id, + coords_conds_id, + )); + id + } else { + coords_conds_id + }; + + selection.if_true(self, coords_cond_id, null_id); + + // All conditions are met. We can carry out the access. + let texel_id = access.generate( + &mut self.writer.id_gen, + coordinates.value_id, + level_id, + sample_id, + selection.block(), + ); + + // This, then, is the value of the 'true' branch. + Ok(selection.finish(self, texel_id)) + } + + /// Generate code for an `ImageLoad` expression. + /// + /// The arguments are the components of an `Expression::ImageLoad` variant. + #[allow(clippy::too_many_arguments)] + pub(super) fn write_image_load( + &mut self, + result_type_id: Word, + image: Handle<crate::Expression>, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + level: Option<Handle<crate::Expression>>, + sample: Option<Handle<crate::Expression>>, + block: &mut Block, + ) -> Result<Word, Error> { + let image_id = self.get_handle_id(image); + let image_type = self.fun_info[image].ty.inner_with(&self.ir_module.types); + let image_class = match *image_type { + crate::TypeInner::Image { class, .. } => class, + _ => return Err(Error::Validation("image type")), + }; + + let access = Load::from_image_expr(self, image_id, image_class, result_type_id)?; + let coordinates = self.write_image_coordinates(coordinate, array_index, block)?; + + let level_id = level.map(|expr| self.cached[expr]); + let sample_id = sample.map(|expr| self.cached[expr]); + + // Perform the access, according to the bounds check policy. + let access_id = match self.writer.bounds_check_policies.image_load { + crate::proc::BoundsCheckPolicy::Restrict => { + let (coords, level_id, sample_id) = self.write_restricted_coordinates( + image_id, + coordinates, + level_id, + sample_id, + block, + )?; + access.generate(&mut self.writer.id_gen, coords, level_id, sample_id, block) + } + crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => self + .write_conditional_image_access( + image_id, + coordinates, + level_id, + sample_id, + block, + &access, + )?, + crate::proc::BoundsCheckPolicy::Unchecked => access.generate( + &mut self.writer.id_gen, + coordinates.value_id, + level_id, + sample_id, + block, + ), + }; + + // For depth images, `ImageLoad` expressions produce a single f32, + // whereas the SPIR-V instructions always produce a vec4. So we may have + // to pull out the component we need. + let result_id = if result_type_id == access.result_type() { + // The instruction produced the type we expected. We can use + // its result as-is. + access_id + } else { + // For `ImageClass::Depth` images, SPIR-V gave us four components, + // but we only want the first one. + let component_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + component_id, + access_id, + &[0], + )); + component_id + }; + + Ok(result_id) + } + + /// Generate code for an `ImageSample` expression. + /// + /// The arguments are the components of an `Expression::ImageSample` variant. + #[allow(clippy::too_many_arguments)] + pub(super) fn write_image_sample( + &mut self, + result_type_id: Word, + image: Handle<crate::Expression>, + sampler: Handle<crate::Expression>, + gather: Option<crate::SwizzleComponent>, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + offset: Option<Handle<crate::Expression>>, + level: crate::SampleLevel, + depth_ref: Option<Handle<crate::Expression>>, + block: &mut Block, + ) -> Result<Word, Error> { + use super::instructions::SampleLod; + // image + let image_id = self.get_handle_id(image); + let image_type = self.fun_info[image].ty.handle().unwrap(); + // SPIR-V doesn't know about our `Depth` class, and it returns + // `vec4<f32>`, so we need to grab the first component out of it. + let needs_sub_access = match self.ir_module.types[image_type].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Depth { .. }, + .. + } => depth_ref.is_none() && gather.is_none(), + _ => false, + }; + let sample_result_type_id = if needs_sub_access { + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + scalar: crate::Scalar::F32, + pointer_space: None, + })) + } else { + result_type_id + }; + + // OpTypeSampledImage + let image_type_id = self.get_type_id(LookupType::Handle(image_type)); + let sampled_image_type_id = + self.get_type_id(LookupType::Local(LocalType::SampledImage { image_type_id })); + + let sampler_id = self.get_handle_id(sampler); + let coordinates_id = self + .write_image_coordinates(coordinate, array_index, block)? + .value_id; + + let sampled_image_id = self.gen_id(); + block.body.push(Instruction::sampled_image( + sampled_image_type_id, + sampled_image_id, + image_id, + sampler_id, + )); + let id = self.gen_id(); + + let depth_id = depth_ref.map(|handle| self.cached[handle]); + let mut mask = spirv::ImageOperands::empty(); + mask.set(spirv::ImageOperands::CONST_OFFSET, offset.is_some()); + + let mut main_instruction = match (level, gather) { + (_, Some(component)) => { + let component_id = self.get_index_constant(component as u32); + let mut inst = Instruction::image_gather( + sample_result_type_id, + id, + sampled_image_id, + coordinates_id, + component_id, + depth_id, + ); + if !mask.is_empty() { + inst.add_operand(mask.bits()); + } + inst + } + (crate::SampleLevel::Zero, None) => { + let mut inst = Instruction::image_sample( + sample_result_type_id, + id, + SampleLod::Explicit, + sampled_image_id, + coordinates_id, + depth_id, + ); + + let zero_id = self.writer.get_constant_scalar(crate::Literal::F32(0.0)); + + mask |= spirv::ImageOperands::LOD; + inst.add_operand(mask.bits()); + inst.add_operand(zero_id); + + inst + } + (crate::SampleLevel::Auto, None) => { + let mut inst = Instruction::image_sample( + sample_result_type_id, + id, + SampleLod::Implicit, + sampled_image_id, + coordinates_id, + depth_id, + ); + if !mask.is_empty() { + inst.add_operand(mask.bits()); + } + inst + } + (crate::SampleLevel::Exact(lod_handle), None) => { + let mut inst = Instruction::image_sample( + sample_result_type_id, + id, + SampleLod::Explicit, + sampled_image_id, + coordinates_id, + depth_id, + ); + + let lod_id = self.cached[lod_handle]; + mask |= spirv::ImageOperands::LOD; + inst.add_operand(mask.bits()); + inst.add_operand(lod_id); + + inst + } + (crate::SampleLevel::Bias(bias_handle), None) => { + let mut inst = Instruction::image_sample( + sample_result_type_id, + id, + SampleLod::Implicit, + sampled_image_id, + coordinates_id, + depth_id, + ); + + let bias_id = self.cached[bias_handle]; + mask |= spirv::ImageOperands::BIAS; + inst.add_operand(mask.bits()); + inst.add_operand(bias_id); + + inst + } + (crate::SampleLevel::Gradient { x, y }, None) => { + let mut inst = Instruction::image_sample( + sample_result_type_id, + id, + SampleLod::Explicit, + sampled_image_id, + coordinates_id, + depth_id, + ); + + let x_id = self.cached[x]; + let y_id = self.cached[y]; + mask |= spirv::ImageOperands::GRAD; + inst.add_operand(mask.bits()); + inst.add_operand(x_id); + inst.add_operand(y_id); + + inst + } + }; + + if let Some(offset_const) = offset { + let offset_id = self.writer.constant_ids[offset_const.index()]; + main_instruction.add_operand(offset_id); + } + + block.body.push(main_instruction); + + let id = if needs_sub_access { + let sub_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + sub_id, + id, + &[0], + )); + sub_id + } else { + id + }; + + Ok(id) + } + + /// Generate code for an `ImageQuery` expression. + /// + /// The arguments are the components of an `Expression::ImageQuery` variant. + pub(super) fn write_image_query( + &mut self, + result_type_id: Word, + image: Handle<crate::Expression>, + query: crate::ImageQuery, + block: &mut Block, + ) -> Result<Word, Error> { + use crate::{ImageClass as Ic, ImageDimension as Id, ImageQuery as Iq}; + + let image_id = self.get_handle_id(image); + let image_type = self.fun_info[image].ty.handle().unwrap(); + let (dim, arrayed, class) = match self.ir_module.types[image_type].inner { + crate::TypeInner::Image { + dim, + arrayed, + class, + } => (dim, arrayed, class), + _ => { + return Err(Error::Validation("image type")); + } + }; + + self.writer + .require_any("image queries", &[spirv::Capability::ImageQuery])?; + + let id = match query { + Iq::Size { level } => { + let dim_coords = match dim { + Id::D1 => 1, + Id::D2 | Id::Cube => 2, + Id::D3 => 3, + }; + let array_coords = usize::from(arrayed); + let vector_size = match dim_coords + array_coords { + 2 => Some(crate::VectorSize::Bi), + 3 => Some(crate::VectorSize::Tri), + 4 => Some(crate::VectorSize::Quad), + _ => None, + }; + let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size, + scalar: crate::Scalar::U32, + pointer_space: None, + })); + + let (query_op, level_id) = match class { + Ic::Sampled { multi: true, .. } + | Ic::Depth { multi: true } + | Ic::Storage { .. } => (spirv::Op::ImageQuerySize, None), + _ => { + let level_id = match level { + Some(expr) => self.cached[expr], + None => self.get_index_constant(0), + }; + (spirv::Op::ImageQuerySizeLod, Some(level_id)) + } + }; + + // The ID of the vector returned by SPIR-V, which contains the dimensions + // as well as the layer count. + let id_extended = self.gen_id(); + let mut inst = Instruction::image_query( + query_op, + extended_size_type_id, + id_extended, + image_id, + ); + if let Some(expr_id) = level_id { + inst.add_operand(expr_id); + } + block.body.push(inst); + + if result_type_id != extended_size_type_id { + let id = self.gen_id(); + let components = match dim { + // always pick the first component, and duplicate it for all 3 dimensions + Id::Cube => &[0u32, 0][..], + _ => &[0u32, 1, 2, 3][..dim_coords], + }; + block.body.push(Instruction::vector_shuffle( + result_type_id, + id, + id_extended, + id_extended, + components, + )); + + id + } else { + id_extended + } + } + Iq::NumLevels => { + let query_id = self.gen_id(); + block.body.push(Instruction::image_query( + spirv::Op::ImageQueryLevels, + result_type_id, + query_id, + image_id, + )); + + query_id + } + Iq::NumLayers => { + let vec_size = match dim { + Id::D1 => crate::VectorSize::Bi, + Id::D2 | Id::Cube => crate::VectorSize::Tri, + Id::D3 => crate::VectorSize::Quad, + }; + let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(vec_size), + scalar: crate::Scalar::U32, + pointer_space: None, + })); + let id_extended = self.gen_id(); + let mut inst = Instruction::image_query( + spirv::Op::ImageQuerySizeLod, + extended_size_type_id, + id_extended, + image_id, + ); + inst.add_operand(self.get_index_constant(0)); + block.body.push(inst); + + let extract_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + extract_id, + id_extended, + &[vec_size as u32 - 1], + )); + + extract_id + } + Iq::NumSamples => { + let query_id = self.gen_id(); + block.body.push(Instruction::image_query( + spirv::Op::ImageQuerySamples, + result_type_id, + query_id, + image_id, + )); + + query_id + } + }; + + Ok(id) + } + + pub(super) fn write_image_store( + &mut self, + image: Handle<crate::Expression>, + coordinate: Handle<crate::Expression>, + array_index: Option<Handle<crate::Expression>>, + value: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<(), Error> { + let image_id = self.get_handle_id(image); + let coordinates = self.write_image_coordinates(coordinate, array_index, block)?; + let value_id = self.cached[value]; + + let write = Store { image_id, value_id }; + + match *self.fun_info[image].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Image { + class: + crate::ImageClass::Storage { + format: crate::StorageFormat::Bgra8Unorm, + .. + }, + .. + } => self.writer.require_any( + "Bgra8Unorm storage write", + &[spirv::Capability::StorageImageWriteWithoutFormat], + )?, + _ => {} + } + + match self.writer.bounds_check_policies.image_store { + crate::proc::BoundsCheckPolicy::Restrict => { + let (coords, _, _) = + self.write_restricted_coordinates(image_id, coordinates, None, None, block)?; + write.generate(&mut self.writer.id_gen, coords, None, None, block); + } + crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => { + self.write_conditional_image_access( + image_id, + coordinates, + None, + None, + block, + &write, + )?; + } + crate::proc::BoundsCheckPolicy::Unchecked => { + write.generate( + &mut self.writer.id_gen, + coordinates.value_id, + None, + None, + block, + ); + } + } + + Ok(()) + } +} diff --git a/third_party/rust/naga/src/back/spv/index.rs b/third_party/rust/naga/src/back/spv/index.rs new file mode 100644 index 0000000000..92e0f88d9a --- /dev/null +++ b/third_party/rust/naga/src/back/spv/index.rs @@ -0,0 +1,421 @@ +/*! +Bounds-checking for SPIR-V output. +*/ + +use super::{ + helpers::global_needs_wrapper, selection::Selection, Block, BlockContext, Error, IdGenerator, + Instruction, Word, +}; +use crate::{arena::Handle, proc::BoundsCheckPolicy}; + +/// The results of performing a bounds check. +/// +/// On success, `write_bounds_check` returns a value of this type. +pub(super) enum BoundsCheckResult { + /// The index is statically known and in bounds, with the given value. + KnownInBounds(u32), + + /// The given instruction computes the index to be used. + Computed(Word), + + /// The given instruction computes a boolean condition which is true + /// if the index is in bounds. + Conditional(Word), +} + +/// A value that we either know at translation time, or need to compute at runtime. +pub(super) enum MaybeKnown<T> { + /// The value is known at shader translation time. + Known(T), + + /// The value is computed by the instruction with the given id. + Computed(Word), +} + +impl<'w> BlockContext<'w> { + /// Emit code to compute the length of a run-time array. + /// + /// Given `array`, an expression referring a runtime-sized array, return the + /// instruction id for the array's length. + pub(super) fn write_runtime_array_length( + &mut self, + array: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<Word, Error> { + // Naga IR permits runtime-sized arrays as global variables or as the + // final member of a struct that is a global variable. SPIR-V permits + // only the latter, so this back end wraps bare runtime-sized arrays + // in a made-up struct; see `helpers::global_needs_wrapper` and its uses. + // This code must handle both cases. + let (structure_id, last_member_index) = match self.ir_function.expressions[array] { + crate::Expression::AccessIndex { base, index } => { + match self.ir_function.expressions[base] { + crate::Expression::GlobalVariable(handle) => ( + self.writer.global_variables[handle.index()].access_id, + index, + ), + _ => return Err(Error::Validation("array length expression")), + } + } + crate::Expression::GlobalVariable(handle) => { + let global = &self.ir_module.global_variables[handle]; + if !global_needs_wrapper(self.ir_module, global) { + return Err(Error::Validation("array length expression")); + } + + (self.writer.global_variables[handle.index()].var_id, 0) + } + _ => return Err(Error::Validation("array length expression")), + }; + + let length_id = self.gen_id(); + block.body.push(Instruction::array_length( + self.writer.get_uint_type_id(), + length_id, + structure_id, + last_member_index, + )); + + Ok(length_id) + } + + /// Compute the length of a subscriptable value. + /// + /// Given `sequence`, an expression referring to some indexable type, return + /// its length. The result may either be computed by SPIR-V instructions, or + /// known at shader translation time. + /// + /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any + /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically + /// sized, or use a specializable constant as its length. + fn write_sequence_length( + &mut self, + sequence: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<MaybeKnown<u32>, Error> { + let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types); + match sequence_ty.indexable_length(self.ir_module) { + Ok(crate::proc::IndexableLength::Known(known_length)) => { + Ok(MaybeKnown::Known(known_length)) + } + Ok(crate::proc::IndexableLength::Dynamic) => { + let length_id = self.write_runtime_array_length(sequence, block)?; + Ok(MaybeKnown::Computed(length_id)) + } + Err(err) => { + log::error!("Sequence length for {:?} failed: {}", sequence, err); + Err(Error::Validation("indexable length")) + } + } + } + + /// Compute the maximum valid index of a subscriptable value. + /// + /// Given `sequence`, an expression referring to some indexable type, return + /// its maximum valid index - one less than its length. The result may + /// either be computed, or known at shader translation time. + /// + /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any + /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically + /// sized, or use a specializable constant as its length. + fn write_sequence_max_index( + &mut self, + sequence: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<MaybeKnown<u32>, Error> { + match self.write_sequence_length(sequence, block)? { + MaybeKnown::Known(known_length) => { + // We should have thrown out all attempts to subscript zero-length + // sequences during validation, so the following subtraction should never + // underflow. + assert!(known_length > 0); + // Compute the max index from the length now. + Ok(MaybeKnown::Known(known_length - 1)) + } + MaybeKnown::Computed(length_id) => { + // Emit code to compute the max index from the length. + let const_one_id = self.get_index_constant(1); + let max_index_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::ISub, + self.writer.get_uint_type_id(), + max_index_id, + length_id, + const_one_id, + )); + Ok(MaybeKnown::Computed(max_index_id)) + } + } + } + + /// Restrict an index to be in range for a vector, matrix, or array. + /// + /// This is used to implement `BoundsCheckPolicy::Restrict`. An in-bounds + /// index is left unchanged. An out-of-bounds index is replaced with some + /// arbitrary in-bounds index. Note,this is not necessarily clamping; for + /// example, negative indices might be changed to refer to the last element + /// of the sequence, not the first, as clamping would do. + /// + /// Either return the restricted index value, if known, or add instructions + /// to `block` to compute it, and return the id of the result. See the + /// documentation for `BoundsCheckResult` for details. + /// + /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a + /// `Pointer` to any of those, or a `ValuePointer`. An array may be + /// fixed-size, dynamically sized, or use a specializable constant as its + /// length. + pub(super) fn write_restricted_index( + &mut self, + sequence: Handle<crate::Expression>, + index: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<BoundsCheckResult, Error> { + let index_id = self.cached[index]; + + // Get the sequence's maximum valid index. Return early if we've already + // done the bounds check. + let max_index_id = match self.write_sequence_max_index(sequence, block)? { + MaybeKnown::Known(known_max_index) => { + if let Ok(known_index) = self + .ir_module + .to_ctx() + .eval_expr_to_u32_from(index, &self.ir_function.expressions) + { + // Both the index and length are known at compile time. + // + // In strict WGSL compliance mode, out-of-bounds indices cannot be + // reported at shader translation time, and must be replaced with + // in-bounds indices at run time. So we cannot assume that + // validation ensured the index was in bounds. Restrict now. + let restricted = std::cmp::min(known_index, known_max_index); + return Ok(BoundsCheckResult::KnownInBounds(restricted)); + } + + self.get_index_constant(known_max_index) + } + MaybeKnown::Computed(max_index_id) => max_index_id, + }; + + // One or the other of the index or length is dynamic, so emit code for + // BoundsCheckPolicy::Restrict. + let restricted_index_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + self.writer.get_uint_type_id(), + restricted_index_id, + &[index_id, max_index_id], + )); + Ok(BoundsCheckResult::Computed(restricted_index_id)) + } + + /// Write an index bounds comparison to `block`, if needed. + /// + /// If we're able to determine statically that `index` is in bounds for + /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual + /// value of the index. (In principle, one could know that the index is in + /// bounds without knowing its specific value, but in our simple-minded + /// situation, we always know it.) + /// + /// If instead we must generate code to perform the comparison at run time, + /// return `Conditional(comparison_id)`, where `comparison_id` is an + /// instruction producing a boolean value that is true if `index` is in + /// bounds for `sequence`. + /// + /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a + /// `Pointer` to any of those, or a `ValuePointer`. An array may be + /// fixed-size, dynamically sized, or use a specializable constant as its + /// length. + fn write_index_comparison( + &mut self, + sequence: Handle<crate::Expression>, + index: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<BoundsCheckResult, Error> { + let index_id = self.cached[index]; + + // Get the sequence's length. Return early if we've already done the + // bounds check. + let length_id = match self.write_sequence_length(sequence, block)? { + MaybeKnown::Known(known_length) => { + if let Ok(known_index) = self + .ir_module + .to_ctx() + .eval_expr_to_u32_from(index, &self.ir_function.expressions) + { + // Both the index and length are known at compile time. + // + // It would be nice to assume that, since we are using the + // `ReadZeroSkipWrite` policy, we are not in strict WGSL + // compliance mode, and thus we can count on the validator to have + // rejected any programs with known out-of-bounds indices, and + // thus just return `KnownInBounds` here without actually + // checking. + // + // But it's also reasonable to expect that bounds check policies + // and error reporting policies should be able to vary + // independently without introducing security holes. So, we should + // support the case where bad indices do not cause validation + // errors, and are handled via `ReadZeroSkipWrite`. + // + // In theory, when `known_index` is bad, we could return a new + // `KnownOutOfBounds` variant here. But it's simpler just to fall + // through and let the bounds check take place. The shader is + // broken anyway, so it doesn't make sense to invest in emitting + // the ideal code for it. + if known_index < known_length { + return Ok(BoundsCheckResult::KnownInBounds(known_index)); + } + } + + self.get_index_constant(known_length) + } + MaybeKnown::Computed(length_id) => length_id, + }; + + // Compare the index against the length. + let condition_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::ULessThan, + self.writer.get_bool_type_id(), + condition_id, + index_id, + length_id, + )); + + // Indicate that we did generate the check. + Ok(BoundsCheckResult::Conditional(condition_id)) + } + + /// Emit a conditional load for `BoundsCheckPolicy::ReadZeroSkipWrite`. + /// + /// Generate code to load a value of `result_type` if `condition` is true, + /// and generate a null value of that type if it is false. Call `emit_load` + /// to emit the instructions to perform the load. Return the id of the + /// merged value of the two branches. + pub(super) fn write_conditional_indexed_load<F>( + &mut self, + result_type: Word, + condition: Word, + block: &mut Block, + emit_load: F, + ) -> Word + where + F: FnOnce(&mut IdGenerator, &mut Block) -> Word, + { + // For the out-of-bounds case, we produce a zero value. + let null_id = self.writer.get_constant_null(result_type); + + let mut selection = Selection::start(block, result_type); + + // As it turns out, we don't actually need a full 'if-then-else' + // structure for this: SPIR-V constants are declared up front, so the + // 'else' block would have no instructions. Instead we emit something + // like this: + // + // result = zero; + // if in_bounds { + // result = do the load; + // } + // use result; + + // Continue only if the index was in bounds. Otherwise, branch to the + // merge block. + selection.if_true(self, condition, null_id); + + // The in-bounds path. Perform the access and the load. + let loaded_value = emit_load(&mut self.writer.id_gen, selection.block()); + + selection.finish(self, loaded_value) + } + + /// Emit code for bounds checks for an array, vector, or matrix access. + /// + /// This implements either `index_bounds_check_policy` or + /// `buffer_bounds_check_policy`, depending on the address space of the + /// pointer being accessed. + /// + /// Return a `BoundsCheckResult` indicating how the index should be + /// consumed. See that type's documentation for details. + pub(super) fn write_bounds_check( + &mut self, + base: Handle<crate::Expression>, + index: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<BoundsCheckResult, Error> { + let policy = self.writer.bounds_check_policies.choose_policy( + base, + &self.ir_module.types, + self.fun_info, + ); + + Ok(match policy { + BoundsCheckPolicy::Restrict => self.write_restricted_index(base, index, block)?, + BoundsCheckPolicy::ReadZeroSkipWrite => { + self.write_index_comparison(base, index, block)? + } + BoundsCheckPolicy::Unchecked => BoundsCheckResult::Computed(self.cached[index]), + }) + } + + /// Emit code to subscript a vector by value with a computed index. + /// + /// Return the id of the element value. + pub(super) fn write_vector_access( + &mut self, + expr_handle: Handle<crate::Expression>, + base: Handle<crate::Expression>, + index: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<Word, Error> { + let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty); + + let base_id = self.cached[base]; + let index_id = self.cached[index]; + + let result_id = match self.write_bounds_check(base, index, block)? { + BoundsCheckResult::KnownInBounds(known_index) => { + let result_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + result_id, + base_id, + &[known_index], + )); + result_id + } + BoundsCheckResult::Computed(computed_index_id) => { + let result_id = self.gen_id(); + block.body.push(Instruction::vector_extract_dynamic( + result_type_id, + result_id, + base_id, + computed_index_id, + )); + result_id + } + BoundsCheckResult::Conditional(comparison_id) => { + // Run-time bounds checks were required. Emit + // conditional load. + self.write_conditional_indexed_load( + result_type_id, + comparison_id, + block, + |id_gen, block| { + // The in-bounds path. Generate the access. + let element_id = id_gen.next(); + block.body.push(Instruction::vector_extract_dynamic( + result_type_id, + element_id, + base_id, + index_id, + )); + element_id + }, + ) + } + }; + + Ok(result_id) + } +} diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs new file mode 100644 index 0000000000..b963793ad3 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/instructions.rs @@ -0,0 +1,1100 @@ +use super::{block::DebugInfoInner, helpers}; +use spirv::{Op, Word}; + +pub(super) enum Signedness { + Unsigned = 0, + Signed = 1, +} + +pub(super) enum SampleLod { + Explicit, + Implicit, +} + +pub(super) struct Case { + pub value: Word, + pub label_id: Word, +} + +impl super::Instruction { + // + // Debug Instructions + // + + pub(super) fn string(name: &str, id: Word) -> Self { + let mut instruction = Self::new(Op::String); + instruction.set_result(id); + instruction.add_operands(helpers::string_to_words(name)); + instruction + } + + pub(super) fn source( + source_language: spirv::SourceLanguage, + version: u32, + source: &Option<DebugInfoInner>, + ) -> Self { + let mut instruction = Self::new(Op::Source); + instruction.add_operand(source_language as u32); + instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes())); + if let Some(source) = source.as_ref() { + instruction.add_operand(source.source_file_id); + instruction.add_operands(helpers::string_to_words(source.source_code)); + } + instruction + } + + pub(super) fn name(target_id: Word, name: &str) -> Self { + let mut instruction = Self::new(Op::Name); + instruction.add_operand(target_id); + instruction.add_operands(helpers::string_to_words(name)); + instruction + } + + pub(super) fn member_name(target_id: Word, member: Word, name: &str) -> Self { + let mut instruction = Self::new(Op::MemberName); + instruction.add_operand(target_id); + instruction.add_operand(member); + instruction.add_operands(helpers::string_to_words(name)); + instruction + } + + pub(super) fn line(file: Word, line: Word, column: Word) -> Self { + let mut instruction = Self::new(Op::Line); + instruction.add_operand(file); + instruction.add_operand(line); + instruction.add_operand(column); + instruction + } + + pub(super) const fn no_line() -> Self { + Self::new(Op::NoLine) + } + + // + // Annotation Instructions + // + + pub(super) fn decorate( + target_id: Word, + decoration: spirv::Decoration, + operands: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::Decorate); + instruction.add_operand(target_id); + instruction.add_operand(decoration as u32); + for operand in operands { + instruction.add_operand(*operand) + } + instruction + } + + pub(super) fn member_decorate( + target_id: Word, + member_index: Word, + decoration: spirv::Decoration, + operands: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::MemberDecorate); + instruction.add_operand(target_id); + instruction.add_operand(member_index); + instruction.add_operand(decoration as u32); + for operand in operands { + instruction.add_operand(*operand) + } + instruction + } + + // + // Extension Instructions + // + + pub(super) fn extension(name: &str) -> Self { + let mut instruction = Self::new(Op::Extension); + instruction.add_operands(helpers::string_to_words(name)); + instruction + } + + pub(super) fn ext_inst_import(id: Word, name: &str) -> Self { + let mut instruction = Self::new(Op::ExtInstImport); + instruction.set_result(id); + instruction.add_operands(helpers::string_to_words(name)); + instruction + } + + pub(super) fn ext_inst( + set_id: Word, + op: spirv::GLOp, + result_type_id: Word, + id: Word, + operands: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::ExtInst); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(set_id); + instruction.add_operand(op as u32); + for operand in operands { + instruction.add_operand(*operand) + } + instruction + } + + // + // Mode-Setting Instructions + // + + pub(super) fn memory_model( + addressing_model: spirv::AddressingModel, + memory_model: spirv::MemoryModel, + ) -> Self { + let mut instruction = Self::new(Op::MemoryModel); + instruction.add_operand(addressing_model as u32); + instruction.add_operand(memory_model as u32); + instruction + } + + pub(super) fn entry_point( + execution_model: spirv::ExecutionModel, + entry_point_id: Word, + name: &str, + interface_ids: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::EntryPoint); + instruction.add_operand(execution_model as u32); + instruction.add_operand(entry_point_id); + instruction.add_operands(helpers::string_to_words(name)); + + for interface_id in interface_ids { + instruction.add_operand(*interface_id); + } + + instruction + } + + pub(super) fn execution_mode( + entry_point_id: Word, + execution_mode: spirv::ExecutionMode, + args: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::ExecutionMode); + instruction.add_operand(entry_point_id); + instruction.add_operand(execution_mode as u32); + for arg in args { + instruction.add_operand(*arg); + } + instruction + } + + pub(super) fn capability(capability: spirv::Capability) -> Self { + let mut instruction = Self::new(Op::Capability); + instruction.add_operand(capability as u32); + instruction + } + + // + // Type-Declaration Instructions + // + + pub(super) fn type_void(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeVoid); + instruction.set_result(id); + instruction + } + + pub(super) fn type_bool(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeBool); + instruction.set_result(id); + instruction + } + + pub(super) fn type_int(id: Word, width: Word, signedness: Signedness) -> Self { + let mut instruction = Self::new(Op::TypeInt); + instruction.set_result(id); + instruction.add_operand(width); + instruction.add_operand(signedness as u32); + instruction + } + + pub(super) fn type_float(id: Word, width: Word) -> Self { + let mut instruction = Self::new(Op::TypeFloat); + instruction.set_result(id); + instruction.add_operand(width); + instruction + } + + pub(super) fn type_vector( + id: Word, + component_type_id: Word, + component_count: crate::VectorSize, + ) -> Self { + let mut instruction = Self::new(Op::TypeVector); + instruction.set_result(id); + instruction.add_operand(component_type_id); + instruction.add_operand(component_count as u32); + instruction + } + + pub(super) fn type_matrix( + id: Word, + column_type_id: Word, + column_count: crate::VectorSize, + ) -> Self { + let mut instruction = Self::new(Op::TypeMatrix); + instruction.set_result(id); + instruction.add_operand(column_type_id); + instruction.add_operand(column_count as u32); + instruction + } + + #[allow(clippy::too_many_arguments)] + pub(super) fn type_image( + id: Word, + sampled_type_id: Word, + dim: spirv::Dim, + flags: super::ImageTypeFlags, + image_format: spirv::ImageFormat, + ) -> Self { + let mut instruction = Self::new(Op::TypeImage); + instruction.set_result(id); + instruction.add_operand(sampled_type_id); + instruction.add_operand(dim as u32); + instruction.add_operand(flags.contains(super::ImageTypeFlags::DEPTH) as u32); + instruction.add_operand(flags.contains(super::ImageTypeFlags::ARRAYED) as u32); + instruction.add_operand(flags.contains(super::ImageTypeFlags::MULTISAMPLED) as u32); + instruction.add_operand(if flags.contains(super::ImageTypeFlags::SAMPLED) { + 1 + } else { + 2 + }); + instruction.add_operand(image_format as u32); + instruction + } + + pub(super) fn type_sampler(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeSampler); + instruction.set_result(id); + instruction + } + + pub(super) fn type_acceleration_structure(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeAccelerationStructureKHR); + instruction.set_result(id); + instruction + } + + pub(super) fn type_ray_query(id: Word) -> Self { + let mut instruction = Self::new(Op::TypeRayQueryKHR); + instruction.set_result(id); + instruction + } + + pub(super) fn type_sampled_image(id: Word, image_type_id: Word) -> Self { + let mut instruction = Self::new(Op::TypeSampledImage); + instruction.set_result(id); + instruction.add_operand(image_type_id); + instruction + } + + pub(super) fn type_array(id: Word, element_type_id: Word, length_id: Word) -> Self { + let mut instruction = Self::new(Op::TypeArray); + instruction.set_result(id); + instruction.add_operand(element_type_id); + instruction.add_operand(length_id); + instruction + } + + pub(super) fn type_runtime_array(id: Word, element_type_id: Word) -> Self { + let mut instruction = Self::new(Op::TypeRuntimeArray); + instruction.set_result(id); + instruction.add_operand(element_type_id); + instruction + } + + pub(super) fn type_struct(id: Word, member_ids: &[Word]) -> Self { + let mut instruction = Self::new(Op::TypeStruct); + instruction.set_result(id); + + for member_id in member_ids { + instruction.add_operand(*member_id) + } + + instruction + } + + pub(super) fn type_pointer( + id: Word, + storage_class: spirv::StorageClass, + type_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::TypePointer); + instruction.set_result(id); + instruction.add_operand(storage_class as u32); + instruction.add_operand(type_id); + instruction + } + + pub(super) fn type_function(id: Word, return_type_id: Word, parameter_ids: &[Word]) -> Self { + let mut instruction = Self::new(Op::TypeFunction); + instruction.set_result(id); + instruction.add_operand(return_type_id); + + for parameter_id in parameter_ids { + instruction.add_operand(*parameter_id); + } + + instruction + } + + // + // Constant-Creation Instructions + // + + pub(super) fn constant_null(result_type_id: Word, id: Word) -> Self { + let mut instruction = Self::new(Op::ConstantNull); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction + } + + pub(super) fn constant_true(result_type_id: Word, id: Word) -> Self { + let mut instruction = Self::new(Op::ConstantTrue); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction + } + + pub(super) fn constant_false(result_type_id: Word, id: Word) -> Self { + let mut instruction = Self::new(Op::ConstantFalse); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction + } + + pub(super) fn constant_32bit(result_type_id: Word, id: Word, value: Word) -> Self { + Self::constant(result_type_id, id, &[value]) + } + + pub(super) fn constant_64bit(result_type_id: Word, id: Word, low: Word, high: Word) -> Self { + Self::constant(result_type_id, id, &[low, high]) + } + + pub(super) fn constant(result_type_id: Word, id: Word, values: &[Word]) -> Self { + let mut instruction = Self::new(Op::Constant); + instruction.set_type(result_type_id); + instruction.set_result(id); + + for value in values { + instruction.add_operand(*value); + } + + instruction + } + + pub(super) fn constant_composite( + result_type_id: Word, + id: Word, + constituent_ids: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::ConstantComposite); + instruction.set_type(result_type_id); + instruction.set_result(id); + + for constituent_id in constituent_ids { + instruction.add_operand(*constituent_id); + } + + instruction + } + + // + // Memory Instructions + // + + pub(super) fn variable( + result_type_id: Word, + id: Word, + storage_class: spirv::StorageClass, + initializer_id: Option<Word>, + ) -> Self { + let mut instruction = Self::new(Op::Variable); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(storage_class as u32); + + if let Some(initializer_id) = initializer_id { + instruction.add_operand(initializer_id); + } + + instruction + } + + pub(super) fn load( + result_type_id: Word, + id: Word, + pointer_id: Word, + memory_access: Option<spirv::MemoryAccess>, + ) -> Self { + let mut instruction = Self::new(Op::Load); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(pointer_id); + + if let Some(memory_access) = memory_access { + instruction.add_operand(memory_access.bits()); + } + + instruction + } + + pub(super) fn atomic_load( + result_type_id: Word, + id: Word, + pointer_id: Word, + scope_id: Word, + semantics_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::AtomicLoad); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(pointer_id); + instruction.add_operand(scope_id); + instruction.add_operand(semantics_id); + instruction + } + + pub(super) fn store( + pointer_id: Word, + value_id: Word, + memory_access: Option<spirv::MemoryAccess>, + ) -> Self { + let mut instruction = Self::new(Op::Store); + instruction.add_operand(pointer_id); + instruction.add_operand(value_id); + + if let Some(memory_access) = memory_access { + instruction.add_operand(memory_access.bits()); + } + + instruction + } + + pub(super) fn atomic_store( + pointer_id: Word, + scope_id: Word, + semantics_id: Word, + value_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::AtomicStore); + instruction.add_operand(pointer_id); + instruction.add_operand(scope_id); + instruction.add_operand(semantics_id); + instruction.add_operand(value_id); + instruction + } + + pub(super) fn access_chain( + result_type_id: Word, + id: Word, + base_id: Word, + index_ids: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::AccessChain); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(base_id); + + for index_id in index_ids { + instruction.add_operand(*index_id); + } + + instruction + } + + pub(super) fn array_length( + result_type_id: Word, + id: Word, + structure_id: Word, + array_member: Word, + ) -> Self { + let mut instruction = Self::new(Op::ArrayLength); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(structure_id); + instruction.add_operand(array_member); + instruction + } + + // + // Function Instructions + // + + pub(super) fn function( + return_type_id: Word, + id: Word, + function_control: spirv::FunctionControl, + function_type_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::Function); + instruction.set_type(return_type_id); + instruction.set_result(id); + instruction.add_operand(function_control.bits()); + instruction.add_operand(function_type_id); + instruction + } + + pub(super) fn function_parameter(result_type_id: Word, id: Word) -> Self { + let mut instruction = Self::new(Op::FunctionParameter); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction + } + + pub(super) const fn function_end() -> Self { + Self::new(Op::FunctionEnd) + } + + pub(super) fn function_call( + result_type_id: Word, + id: Word, + function_id: Word, + argument_ids: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::FunctionCall); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(function_id); + + for argument_id in argument_ids { + instruction.add_operand(*argument_id); + } + + instruction + } + + // + // Image Instructions + // + + pub(super) fn sampled_image( + result_type_id: Word, + id: Word, + image: Word, + sampler: Word, + ) -> Self { + let mut instruction = Self::new(Op::SampledImage); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(image); + instruction.add_operand(sampler); + instruction + } + + pub(super) fn image_sample( + result_type_id: Word, + id: Word, + lod: SampleLod, + sampled_image: Word, + coordinates: Word, + depth_ref: Option<Word>, + ) -> Self { + let op = match (lod, depth_ref) { + (SampleLod::Explicit, None) => Op::ImageSampleExplicitLod, + (SampleLod::Implicit, None) => Op::ImageSampleImplicitLod, + (SampleLod::Explicit, Some(_)) => Op::ImageSampleDrefExplicitLod, + (SampleLod::Implicit, Some(_)) => Op::ImageSampleDrefImplicitLod, + }; + + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(sampled_image); + instruction.add_operand(coordinates); + if let Some(dref) = depth_ref { + instruction.add_operand(dref); + } + + instruction + } + + pub(super) fn image_gather( + result_type_id: Word, + id: Word, + sampled_image: Word, + coordinates: Word, + component_id: Word, + depth_ref: Option<Word>, + ) -> Self { + let op = match depth_ref { + None => Op::ImageGather, + Some(_) => Op::ImageDrefGather, + }; + + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(sampled_image); + instruction.add_operand(coordinates); + if let Some(dref) = depth_ref { + instruction.add_operand(dref); + } else { + instruction.add_operand(component_id); + } + + instruction + } + + pub(super) fn image_fetch_or_read( + op: Op, + result_type_id: Word, + id: Word, + image: Word, + coordinates: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(image); + instruction.add_operand(coordinates); + instruction + } + + pub(super) fn image_write(image: Word, coordinates: Word, value: Word) -> Self { + let mut instruction = Self::new(Op::ImageWrite); + instruction.add_operand(image); + instruction.add_operand(coordinates); + instruction.add_operand(value); + instruction + } + + pub(super) fn image_query(op: Op, result_type_id: Word, id: Word, image: Word) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(image); + instruction + } + + // + // Ray Query Instructions + // + #[allow(clippy::too_many_arguments)] + pub(super) fn ray_query_initialize( + query: Word, + acceleration_structure: Word, + ray_flags: Word, + cull_mask: Word, + ray_origin: Word, + ray_tmin: Word, + ray_dir: Word, + ray_tmax: Word, + ) -> Self { + let mut instruction = Self::new(Op::RayQueryInitializeKHR); + instruction.add_operand(query); + instruction.add_operand(acceleration_structure); + instruction.add_operand(ray_flags); + instruction.add_operand(cull_mask); + instruction.add_operand(ray_origin); + instruction.add_operand(ray_tmin); + instruction.add_operand(ray_dir); + instruction.add_operand(ray_tmax); + instruction + } + + pub(super) fn ray_query_proceed(result_type_id: Word, id: Word, query: Word) -> Self { + let mut instruction = Self::new(Op::RayQueryProceedKHR); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(query); + instruction + } + + pub(super) fn ray_query_get_intersection( + op: Op, + result_type_id: Word, + id: Word, + query: Word, + intersection: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(query); + instruction.add_operand(intersection); + instruction + } + + // + // Conversion Instructions + // + pub(super) fn unary(op: Op, result_type_id: Word, id: Word, value: Word) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(value); + instruction + } + + // + // Composite Instructions + // + + pub(super) fn composite_construct( + result_type_id: Word, + id: Word, + constituent_ids: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::CompositeConstruct); + instruction.set_type(result_type_id); + instruction.set_result(id); + + for constituent_id in constituent_ids { + instruction.add_operand(*constituent_id); + } + + instruction + } + + pub(super) fn composite_extract( + result_type_id: Word, + id: Word, + composite_id: Word, + indices: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::CompositeExtract); + instruction.set_type(result_type_id); + instruction.set_result(id); + + instruction.add_operand(composite_id); + for index in indices { + instruction.add_operand(*index); + } + + instruction + } + + pub(super) fn vector_extract_dynamic( + result_type_id: Word, + id: Word, + vector_id: Word, + index_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::VectorExtractDynamic); + instruction.set_type(result_type_id); + instruction.set_result(id); + + instruction.add_operand(vector_id); + instruction.add_operand(index_id); + + instruction + } + + pub(super) fn vector_shuffle( + result_type_id: Word, + id: Word, + v1_id: Word, + v2_id: Word, + components: &[Word], + ) -> Self { + let mut instruction = Self::new(Op::VectorShuffle); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(v1_id); + instruction.add_operand(v2_id); + + for &component in components { + instruction.add_operand(component); + } + + instruction + } + + // + // Arithmetic Instructions + // + pub(super) fn binary( + op: Op, + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(operand_1); + instruction.add_operand(operand_2); + instruction + } + + pub(super) fn ternary( + op: Op, + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, + operand_3: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(operand_1); + instruction.add_operand(operand_2); + instruction.add_operand(operand_3); + instruction + } + + pub(super) fn quaternary( + op: Op, + result_type_id: Word, + id: Word, + operand_1: Word, + operand_2: Word, + operand_3: Word, + operand_4: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(operand_1); + instruction.add_operand(operand_2); + instruction.add_operand(operand_3); + instruction.add_operand(operand_4); + instruction + } + + pub(super) fn relational(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(expr_id); + instruction + } + + pub(super) fn atomic_binary( + op: Op, + result_type_id: Word, + id: Word, + pointer: Word, + scope_id: Word, + semantics_id: Word, + value: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(pointer); + instruction.add_operand(scope_id); + instruction.add_operand(semantics_id); + instruction.add_operand(value); + instruction + } + + // + // Bit Instructions + // + + // + // Relational and Logical Instructions + // + + // + // Derivative Instructions + // + + pub(super) fn derivative(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(expr_id); + instruction + } + + // + // Control-Flow Instructions + // + + pub(super) fn phi( + result_type_id: Word, + result_id: Word, + var_parent_pairs: &[(Word, Word)], + ) -> Self { + let mut instruction = Self::new(Op::Phi); + instruction.add_operand(result_type_id); + instruction.add_operand(result_id); + for &(variable, parent) in var_parent_pairs { + instruction.add_operand(variable); + instruction.add_operand(parent); + } + instruction + } + + pub(super) fn selection_merge( + merge_id: Word, + selection_control: spirv::SelectionControl, + ) -> Self { + let mut instruction = Self::new(Op::SelectionMerge); + instruction.add_operand(merge_id); + instruction.add_operand(selection_control.bits()); + instruction + } + + pub(super) fn loop_merge( + merge_id: Word, + continuing_id: Word, + selection_control: spirv::SelectionControl, + ) -> Self { + let mut instruction = Self::new(Op::LoopMerge); + instruction.add_operand(merge_id); + instruction.add_operand(continuing_id); + instruction.add_operand(selection_control.bits()); + instruction + } + + pub(super) fn label(id: Word) -> Self { + let mut instruction = Self::new(Op::Label); + instruction.set_result(id); + instruction + } + + pub(super) fn branch(id: Word) -> Self { + let mut instruction = Self::new(Op::Branch); + instruction.add_operand(id); + instruction + } + + // TODO Branch Weights not implemented. + pub(super) fn branch_conditional( + condition_id: Word, + true_label: Word, + false_label: Word, + ) -> Self { + let mut instruction = Self::new(Op::BranchConditional); + instruction.add_operand(condition_id); + instruction.add_operand(true_label); + instruction.add_operand(false_label); + instruction + } + + pub(super) fn switch(selector_id: Word, default_id: Word, cases: &[Case]) -> Self { + let mut instruction = Self::new(Op::Switch); + instruction.add_operand(selector_id); + instruction.add_operand(default_id); + for case in cases { + instruction.add_operand(case.value); + instruction.add_operand(case.label_id); + } + instruction + } + + pub(super) fn select( + result_type_id: Word, + id: Word, + condition_id: Word, + accept_id: Word, + reject_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::Select); + instruction.add_operand(result_type_id); + instruction.add_operand(id); + instruction.add_operand(condition_id); + instruction.add_operand(accept_id); + instruction.add_operand(reject_id); + instruction + } + + pub(super) const fn kill() -> Self { + Self::new(Op::Kill) + } + + pub(super) const fn return_void() -> Self { + Self::new(Op::Return) + } + + pub(super) fn return_value(value_id: Word) -> Self { + let mut instruction = Self::new(Op::ReturnValue); + instruction.add_operand(value_id); + instruction + } + + // + // Atomic Instructions + // + + // + // Primitive Instructions + // + + // Barriers + + pub(super) fn control_barrier( + exec_scope_id: Word, + mem_scope_id: Word, + semantics_id: Word, + ) -> Self { + let mut instruction = Self::new(Op::ControlBarrier); + instruction.add_operand(exec_scope_id); + instruction.add_operand(mem_scope_id); + instruction.add_operand(semantics_id); + instruction + } +} + +impl From<crate::StorageFormat> for spirv::ImageFormat { + fn from(format: crate::StorageFormat) -> Self { + use crate::StorageFormat as Sf; + match format { + Sf::R8Unorm => Self::R8, + Sf::R8Snorm => Self::R8Snorm, + Sf::R8Uint => Self::R8ui, + Sf::R8Sint => Self::R8i, + Sf::R16Uint => Self::R16ui, + Sf::R16Sint => Self::R16i, + Sf::R16Float => Self::R16f, + Sf::Rg8Unorm => Self::Rg8, + Sf::Rg8Snorm => Self::Rg8Snorm, + Sf::Rg8Uint => Self::Rg8ui, + Sf::Rg8Sint => Self::Rg8i, + Sf::R32Uint => Self::R32ui, + Sf::R32Sint => Self::R32i, + Sf::R32Float => Self::R32f, + Sf::Rg16Uint => Self::Rg16ui, + Sf::Rg16Sint => Self::Rg16i, + Sf::Rg16Float => Self::Rg16f, + Sf::Rgba8Unorm => Self::Rgba8, + Sf::Rgba8Snorm => Self::Rgba8Snorm, + Sf::Rgba8Uint => Self::Rgba8ui, + Sf::Rgba8Sint => Self::Rgba8i, + Sf::Bgra8Unorm => Self::Unknown, + Sf::Rgb10a2Uint => Self::Rgb10a2ui, + Sf::Rgb10a2Unorm => Self::Rgb10A2, + Sf::Rg11b10Float => Self::R11fG11fB10f, + Sf::Rg32Uint => Self::Rg32ui, + Sf::Rg32Sint => Self::Rg32i, + Sf::Rg32Float => Self::Rg32f, + Sf::Rgba16Uint => Self::Rgba16ui, + Sf::Rgba16Sint => Self::Rgba16i, + Sf::Rgba16Float => Self::Rgba16f, + Sf::Rgba32Uint => Self::Rgba32ui, + Sf::Rgba32Sint => Self::Rgba32i, + Sf::Rgba32Float => Self::Rgba32f, + Sf::R16Unorm => Self::R16, + Sf::R16Snorm => Self::R16Snorm, + Sf::Rg16Unorm => Self::Rg16, + Sf::Rg16Snorm => Self::Rg16Snorm, + Sf::Rgba16Unorm => Self::Rgba16, + Sf::Rgba16Snorm => Self::Rgba16Snorm, + } + } +} + +impl From<crate::ImageDimension> for spirv::Dim { + fn from(dim: crate::ImageDimension) -> Self { + use crate::ImageDimension as Id; + match dim { + Id::D1 => Self::Dim1D, + Id::D2 => Self::Dim2D, + Id::D3 => Self::Dim3D, + Id::Cube => Self::DimCube, + } + } +} diff --git a/third_party/rust/naga/src/back/spv/layout.rs b/third_party/rust/naga/src/back/spv/layout.rs new file mode 100644 index 0000000000..39117a3d2a --- /dev/null +++ b/third_party/rust/naga/src/back/spv/layout.rs @@ -0,0 +1,210 @@ +use super::{Instruction, LogicalLayout, PhysicalLayout}; +use spirv::{Op, Word, MAGIC_NUMBER}; +use std::iter; + +// https://github.com/KhronosGroup/SPIRV-Headers/pull/195 +const GENERATOR: Word = 28; + +impl PhysicalLayout { + pub(super) const fn new(version: Word) -> Self { + PhysicalLayout { + magic_number: MAGIC_NUMBER, + version, + generator: GENERATOR, + bound: 0, + instruction_schema: 0x0u32, + } + } + + pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) { + sink.extend(iter::once(self.magic_number)); + sink.extend(iter::once(self.version)); + sink.extend(iter::once(self.generator)); + sink.extend(iter::once(self.bound)); + sink.extend(iter::once(self.instruction_schema)); + } +} + +impl super::recyclable::Recyclable for PhysicalLayout { + fn recycle(self) -> Self { + PhysicalLayout { + magic_number: self.magic_number, + version: self.version, + generator: self.generator, + instruction_schema: self.instruction_schema, + bound: 0, + } + } +} + +impl LogicalLayout { + pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) { + sink.extend(self.capabilities.iter().cloned()); + sink.extend(self.extensions.iter().cloned()); + sink.extend(self.ext_inst_imports.iter().cloned()); + sink.extend(self.memory_model.iter().cloned()); + sink.extend(self.entry_points.iter().cloned()); + sink.extend(self.execution_modes.iter().cloned()); + sink.extend(self.debugs.iter().cloned()); + sink.extend(self.annotations.iter().cloned()); + sink.extend(self.declarations.iter().cloned()); + sink.extend(self.function_declarations.iter().cloned()); + sink.extend(self.function_definitions.iter().cloned()); + } +} + +impl super::recyclable::Recyclable for LogicalLayout { + fn recycle(self) -> Self { + Self { + capabilities: self.capabilities.recycle(), + extensions: self.extensions.recycle(), + ext_inst_imports: self.ext_inst_imports.recycle(), + memory_model: self.memory_model.recycle(), + entry_points: self.entry_points.recycle(), + execution_modes: self.execution_modes.recycle(), + debugs: self.debugs.recycle(), + annotations: self.annotations.recycle(), + declarations: self.declarations.recycle(), + function_declarations: self.function_declarations.recycle(), + function_definitions: self.function_definitions.recycle(), + } + } +} + +impl Instruction { + pub(super) const fn new(op: Op) -> Self { + Instruction { + op, + wc: 1, // Always start at 1 for the first word (OP + WC), + type_id: None, + result_id: None, + operands: vec![], + } + } + + #[allow(clippy::panic)] + pub(super) fn set_type(&mut self, id: Word) { + assert!(self.type_id.is_none(), "Type can only be set once"); + self.type_id = Some(id); + self.wc += 1; + } + + #[allow(clippy::panic)] + pub(super) fn set_result(&mut self, id: Word) { + assert!(self.result_id.is_none(), "Result can only be set once"); + self.result_id = Some(id); + self.wc += 1; + } + + pub(super) fn add_operand(&mut self, operand: Word) { + self.operands.push(operand); + self.wc += 1; + } + + pub(super) fn add_operands(&mut self, operands: Vec<Word>) { + for operand in operands.into_iter() { + self.add_operand(operand) + } + } + + pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) { + sink.extend(Some(self.wc << 16 | self.op as u32)); + sink.extend(self.type_id); + sink.extend(self.result_id); + sink.extend(self.operands.iter().cloned()); + } +} + +impl Instruction { + #[cfg(test)] + fn validate(&self, words: &[Word]) { + let mut inst_index = 0; + let (wc, op) = ((words[inst_index] >> 16) as u16, words[inst_index] as u16); + inst_index += 1; + + assert_eq!(wc, words.len() as u16); + assert_eq!(op, self.op as u16); + + if self.type_id.is_some() { + assert_eq!(words[inst_index], self.type_id.unwrap()); + inst_index += 1; + } + + if self.result_id.is_some() { + assert_eq!(words[inst_index], self.result_id.unwrap()); + inst_index += 1; + } + + for (op_index, i) in (inst_index..wc as usize).enumerate() { + assert_eq!(words[i], self.operands[op_index]); + } + } +} + +#[test] +fn test_physical_layout_in_words() { + let bound = 5; + let version = 0x10203; + + let mut output = vec![]; + let mut layout = PhysicalLayout::new(version); + layout.bound = bound; + + layout.in_words(&mut output); + + assert_eq!(&output, &[MAGIC_NUMBER, version, GENERATOR, bound, 0,]); +} + +#[test] +fn test_logical_layout_in_words() { + let mut output = vec![]; + let mut layout = LogicalLayout::default(); + let layout_vectors = 11; + let mut instructions = Vec::with_capacity(layout_vectors); + + let vector_names = &[ + "Capabilities", + "Extensions", + "External Instruction Imports", + "Memory Model", + "Entry Points", + "Execution Modes", + "Debugs", + "Annotations", + "Declarations", + "Function Declarations", + "Function Definitions", + ]; + + for (i, _) in vector_names.iter().enumerate().take(layout_vectors) { + let mut dummy_instruction = Instruction::new(Op::Constant); + dummy_instruction.set_type((i + 1) as u32); + dummy_instruction.set_result((i + 2) as u32); + dummy_instruction.add_operand((i + 3) as u32); + dummy_instruction.add_operands(super::helpers::string_to_words( + format!("This is the vector: {}", vector_names[i]).as_str(), + )); + instructions.push(dummy_instruction); + } + + instructions[0].to_words(&mut layout.capabilities); + instructions[1].to_words(&mut layout.extensions); + instructions[2].to_words(&mut layout.ext_inst_imports); + instructions[3].to_words(&mut layout.memory_model); + instructions[4].to_words(&mut layout.entry_points); + instructions[5].to_words(&mut layout.execution_modes); + instructions[6].to_words(&mut layout.debugs); + instructions[7].to_words(&mut layout.annotations); + instructions[8].to_words(&mut layout.declarations); + instructions[9].to_words(&mut layout.function_declarations); + instructions[10].to_words(&mut layout.function_definitions); + + layout.in_words(&mut output); + + let mut index: usize = 0; + for instruction in instructions { + let wc = instruction.wc as usize; + instruction.validate(&output[index..index + wc]); + index += wc; + } +} diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs new file mode 100644 index 0000000000..b7d57be0d4 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/mod.rs @@ -0,0 +1,748 @@ +/*! +Backend for [SPIR-V][spv] (Standard Portable Intermediate Representation). + +[spv]: https://www.khronos.org/registry/SPIR-V/ +*/ + +mod block; +mod helpers; +mod image; +mod index; +mod instructions; +mod layout; +mod ray; +mod recyclable; +mod selection; +mod writer; + +pub use spirv::Capability; + +use crate::arena::Handle; +use crate::proc::{BoundsCheckPolicies, TypeResolution}; + +use spirv::Word; +use std::ops; +use thiserror::Error; + +#[derive(Clone)] +struct PhysicalLayout { + magic_number: Word, + version: Word, + generator: Word, + bound: Word, + instruction_schema: Word, +} + +#[derive(Default)] +struct LogicalLayout { + capabilities: Vec<Word>, + extensions: Vec<Word>, + ext_inst_imports: Vec<Word>, + memory_model: Vec<Word>, + entry_points: Vec<Word>, + execution_modes: Vec<Word>, + debugs: Vec<Word>, + annotations: Vec<Word>, + declarations: Vec<Word>, + function_declarations: Vec<Word>, + function_definitions: Vec<Word>, +} + +struct Instruction { + op: spirv::Op, + wc: u32, + type_id: Option<Word>, + result_id: Option<Word>, + operands: Vec<Word>, +} + +const BITS_PER_BYTE: crate::Bytes = 8; + +#[derive(Clone, Debug, Error)] +pub enum Error { + #[error("The requested entry point couldn't be found")] + EntryPointNotFound, + #[error("target SPIRV-{0}.{1} is not supported")] + UnsupportedVersion(u8, u8), + #[error("using {0} requires at least one of the capabilities {1:?}, but none are available")] + MissingCapabilities(&'static str, Vec<Capability>), + #[error("unimplemented {0}")] + FeatureNotImplemented(&'static str), + #[error("module is not validated properly: {0}")] + Validation(&'static str), +} + +#[derive(Default)] +struct IdGenerator(Word); + +impl IdGenerator { + fn next(&mut self) -> Word { + self.0 += 1; + self.0 + } +} + +#[derive(Debug, Clone)] +pub struct DebugInfo<'a> { + pub source_code: &'a str, + pub file_name: &'a std::path::Path, +} + +/// A SPIR-V block to which we are still adding instructions. +/// +/// A `Block` represents a SPIR-V block that does not yet have a termination +/// instruction like `OpBranch` or `OpReturn`. +/// +/// The `OpLabel` that starts the block is implicit. It will be emitted based on +/// `label_id` when we write the block to a `LogicalLayout`. +/// +/// To terminate a `Block`, pass the block and the termination instruction to +/// `Function::consume`. This takes ownership of the `Block` and transforms it +/// into a `TerminatedBlock`. +struct Block { + label_id: Word, + body: Vec<Instruction>, +} + +/// A SPIR-V block that ends with a termination instruction. +struct TerminatedBlock { + label_id: Word, + body: Vec<Instruction>, +} + +impl Block { + const fn new(label_id: Word) -> Self { + Block { + label_id, + body: Vec::new(), + } + } +} + +struct LocalVariable { + id: Word, + instruction: Instruction, +} + +struct ResultMember { + id: Word, + type_id: Word, + built_in: Option<crate::BuiltIn>, +} + +struct EntryPointContext { + argument_ids: Vec<Word>, + results: Vec<ResultMember>, +} + +#[derive(Default)] +struct Function { + signature: Option<Instruction>, + parameters: Vec<FunctionArgument>, + variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>, + blocks: Vec<TerminatedBlock>, + entry_point_context: Option<EntryPointContext>, +} + +impl Function { + fn consume(&mut self, mut block: Block, termination: Instruction) { + block.body.push(termination); + self.blocks.push(TerminatedBlock { + label_id: block.label_id, + body: block.body, + }) + } + + fn parameter_id(&self, index: u32) -> Word { + match self.entry_point_context { + Some(ref context) => context.argument_ids[index as usize], + None => self.parameters[index as usize] + .instruction + .result_id + .unwrap(), + } + } +} + +/// Characteristics of a SPIR-V `OpTypeImage` type. +/// +/// SPIR-V requires non-composite types to be unique, including images. Since we +/// use `LocalType` for this deduplication, it's essential that `LocalImageType` +/// be equal whenever the corresponding `OpTypeImage`s would be. To reduce the +/// likelihood of mistakes, we use fields that correspond exactly to the +/// operands of an `OpTypeImage` instruction, using the actual SPIR-V types +/// where practical. +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +struct LocalImageType { + sampled_type: crate::ScalarKind, + dim: spirv::Dim, + flags: ImageTypeFlags, + image_format: spirv::ImageFormat, +} + +bitflags::bitflags! { + /// Flags corresponding to the boolean(-ish) parameters to OpTypeImage. + #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] + pub struct ImageTypeFlags: u8 { + const DEPTH = 0x1; + const ARRAYED = 0x2; + const MULTISAMPLED = 0x4; + const SAMPLED = 0x8; + } +} + +impl LocalImageType { + /// Construct a `LocalImageType` from the fields of a `TypeInner::Image`. + fn from_inner(dim: crate::ImageDimension, arrayed: bool, class: crate::ImageClass) -> Self { + let make_flags = |multi: bool, other: ImageTypeFlags| -> ImageTypeFlags { + let mut flags = other; + flags.set(ImageTypeFlags::ARRAYED, arrayed); + flags.set(ImageTypeFlags::MULTISAMPLED, multi); + flags + }; + + let dim = spirv::Dim::from(dim); + + match class { + crate::ImageClass::Sampled { kind, multi } => LocalImageType { + sampled_type: kind, + dim, + flags: make_flags(multi, ImageTypeFlags::SAMPLED), + image_format: spirv::ImageFormat::Unknown, + }, + crate::ImageClass::Depth { multi } => LocalImageType { + sampled_type: crate::ScalarKind::Float, + dim, + flags: make_flags(multi, ImageTypeFlags::DEPTH | ImageTypeFlags::SAMPLED), + image_format: spirv::ImageFormat::Unknown, + }, + crate::ImageClass::Storage { format, access: _ } => LocalImageType { + sampled_type: crate::ScalarKind::from(format), + dim, + flags: make_flags(false, ImageTypeFlags::empty()), + image_format: format.into(), + }, + } + } +} + +/// A SPIR-V type constructed during code generation. +/// +/// This is the variant of [`LookupType`] used to represent types that might not +/// be available in the arena. Variants are present here for one of two reasons: +/// +/// - They represent types synthesized during code generation, as explained +/// in the documentation for [`LookupType`]. +/// +/// - They represent types for which SPIR-V forbids duplicate `OpType...` +/// instructions, requiring deduplication. +/// +/// This is not a complete copy of [`TypeInner`]: for example, SPIR-V generation +/// never synthesizes new struct types, so `LocalType` has nothing for that. +/// +/// Each `LocalType` variant should be handled identically to its analogous +/// `TypeInner` variant. You can use the [`make_local`] function to help with +/// this, by converting everything possible to a `LocalType` before inspecting +/// it. +/// +/// ## `Localtype` equality and SPIR-V `OpType` uniqueness +/// +/// The definition of `Eq` on `LocalType` is carefully chosen to help us follow +/// certain SPIR-V rules. SPIR-V §2.8 requires some classes of `OpType...` +/// instructions to be unique; for example, you can't have two `OpTypeInt 32 1` +/// instructions in the same module. All 32-bit signed integers must use the +/// same type id. +/// +/// All SPIR-V types that must be unique can be represented as a `LocalType`, +/// and two `LocalType`s are always `Eq` if SPIR-V would require them to use the +/// same `OpType...` instruction. This lets us avoid duplicates by recording the +/// ids of the type instructions we've already generated in a hash table, +/// [`Writer::lookup_type`], keyed by `LocalType`. +/// +/// As another example, [`LocalImageType`], stored in the `LocalType::Image` +/// variant, is designed to help us deduplicate `OpTypeImage` instructions. See +/// its documentation for details. +/// +/// `LocalType` also includes variants like `Pointer` that do not need to be +/// unique - but it is harmless to avoid the duplication. +/// +/// As it always must, the `Hash` implementation respects the `Eq` relation. +/// +/// [`TypeInner`]: crate::TypeInner +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +enum LocalType { + /// A scalar, vector, or pointer to one of those. + Value { + /// If `None`, this represents a scalar type. If `Some`, this represents + /// a vector type of the given size. + vector_size: Option<crate::VectorSize>, + scalar: crate::Scalar, + pointer_space: Option<spirv::StorageClass>, + }, + /// A matrix of floating-point values. + Matrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + width: crate::Bytes, + }, + Pointer { + base: Handle<crate::Type>, + class: spirv::StorageClass, + }, + Image(LocalImageType), + SampledImage { + image_type_id: Word, + }, + Sampler, + /// Equivalent to a [`LocalType::Pointer`] whose `base` is a Naga IR [`BindingArray`]. SPIR-V + /// permits duplicated `OpTypePointer` ids, so it's fine to have two different [`LocalType`] + /// representations for pointer types. + /// + /// [`BindingArray`]: crate::TypeInner::BindingArray + PointerToBindingArray { + base: Handle<crate::Type>, + size: u32, + space: crate::AddressSpace, + }, + BindingArray { + base: Handle<crate::Type>, + size: u32, + }, + AccelerationStructure, + RayQuery, +} + +/// A type encountered during SPIR-V generation. +/// +/// In the process of writing SPIR-V, we need to synthesize various types for +/// intermediate results and such: pointer types, vector/matrix component types, +/// or even booleans, which usually appear in SPIR-V code even when they're not +/// used by the module source. +/// +/// However, we can't use `crate::Type` or `crate::TypeInner` for these, as the +/// type arena may not contain what we need (it only contains types used +/// directly by other parts of the IR), and the IR module is immutable, so we +/// can't add anything to it. +/// +/// So for local use in the SPIR-V writer, we use this type, which holds either +/// a handle into the arena, or a [`LocalType`] containing something synthesized +/// locally. +/// +/// This is very similar to the [`proc::TypeResolution`] enum, with `LocalType` +/// playing the role of `TypeInner`. However, `LocalType` also has other +/// properties needed for SPIR-V generation; see the description of +/// [`LocalType`] for details. +/// +/// [`proc::TypeResolution`]: crate::proc::TypeResolution +#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] +enum LookupType { + Handle(Handle<crate::Type>), + Local(LocalType), +} + +impl From<LocalType> for LookupType { + fn from(local: LocalType) -> Self { + Self::Local(local) + } +} + +#[derive(Debug, PartialEq, Clone, Hash, Eq)] +struct LookupFunctionType { + parameter_type_ids: Vec<Word>, + return_type_id: Word, +} + +fn make_local(inner: &crate::TypeInner) -> Option<LocalType> { + Some(match *inner { + crate::TypeInner::Scalar(scalar) | crate::TypeInner::Atomic(scalar) => LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + }, + crate::TypeInner::Vector { size, scalar } => LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + }, + crate::TypeInner::Matrix { + columns, + rows, + scalar, + } => LocalType::Matrix { + columns, + rows, + width: scalar.width, + }, + crate::TypeInner::Pointer { base, space } => LocalType::Pointer { + base, + class: helpers::map_storage_class(space), + }, + crate::TypeInner::ValuePointer { + size, + scalar, + space, + } => LocalType::Value { + vector_size: size, + scalar, + pointer_space: Some(helpers::map_storage_class(space)), + }, + crate::TypeInner::Image { + dim, + arrayed, + class, + } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)), + crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler, + crate::TypeInner::AccelerationStructure => LocalType::AccelerationStructure, + crate::TypeInner::RayQuery => LocalType::RayQuery, + crate::TypeInner::Array { .. } + | crate::TypeInner::Struct { .. } + | crate::TypeInner::BindingArray { .. } => return None, + }) +} + +#[derive(Debug)] +enum Dimension { + Scalar, + Vector, + Matrix, +} + +/// A map from evaluated [`Expression`](crate::Expression)s to their SPIR-V ids. +/// +/// When we emit code to evaluate a given `Expression`, we record the +/// SPIR-V id of its value here, under its `Handle<Expression>` index. +/// +/// A `CachedExpressions` value can be indexed by a `Handle<Expression>` value. +/// +/// [emit]: index.html#expression-evaluation-time-and-scope +#[derive(Default)] +struct CachedExpressions { + ids: Vec<Word>, +} +impl CachedExpressions { + fn reset(&mut self, length: usize) { + self.ids.clear(); + self.ids.resize(length, 0); + } +} +impl ops::Index<Handle<crate::Expression>> for CachedExpressions { + type Output = Word; + fn index(&self, h: Handle<crate::Expression>) -> &Word { + let id = &self.ids[h.index()]; + if *id == 0 { + unreachable!("Expression {:?} is not cached!", h); + } + id + } +} +impl ops::IndexMut<Handle<crate::Expression>> for CachedExpressions { + fn index_mut(&mut self, h: Handle<crate::Expression>) -> &mut Word { + let id = &mut self.ids[h.index()]; + if *id != 0 { + unreachable!("Expression {:?} is already cached!", h); + } + id + } +} +impl recyclable::Recyclable for CachedExpressions { + fn recycle(self) -> Self { + CachedExpressions { + ids: self.ids.recycle(), + } + } +} + +#[derive(Eq, Hash, PartialEq)] +enum CachedConstant { + Literal(crate::Literal), + Composite { + ty: LookupType, + constituent_ids: Vec<Word>, + }, + ZeroValue(Word), +} + +#[derive(Clone)] +struct GlobalVariable { + /// ID of the OpVariable that declares the global. + /// + /// If you need the variable's value, use [`access_id`] instead of this + /// field. If we wrapped the Naga IR `GlobalVariable`'s type in a struct to + /// comply with Vulkan's requirements, then this points to the `OpVariable` + /// with the synthesized struct type, whereas `access_id` points to the + /// field of said struct that holds the variable's actual value. + /// + /// This is used to compute the `access_id` pointer in function prologues, + /// and used for `ArrayLength` expressions, which do need the struct. + /// + /// [`access_id`]: GlobalVariable::access_id + var_id: Word, + + /// For `AddressSpace::Handle` variables, this ID is recorded in the function + /// prelude block (and reset before every function) as `OpLoad` of the variable. + /// It is then used for all the global ops, such as `OpImageSample`. + handle_id: Word, + + /// Actual ID used to access this variable. + /// For wrapped buffer variables, this ID is `OpAccessChain` into the + /// wrapper. Otherwise, the same as `var_id`. + /// + /// Vulkan requires that globals in the `StorageBuffer` and `Uniform` storage + /// classes must be structs with the `Block` decoration, but WGSL and Naga IR + /// make no such requirement. So for such variables, we generate a wrapper struct + /// type with a single element of the type given by Naga, generate an + /// `OpAccessChain` for that member in the function prelude, and use that pointer + /// to refer to the global in the function body. This is the id of that access, + /// updated for each function in `write_function`. + access_id: Word, +} + +impl GlobalVariable { + const fn dummy() -> Self { + Self { + var_id: 0, + handle_id: 0, + access_id: 0, + } + } + + const fn new(id: Word) -> Self { + Self { + var_id: id, + handle_id: 0, + access_id: 0, + } + } + + /// Prepare `self` for use within a single function. + fn reset_for_function(&mut self) { + self.handle_id = 0; + self.access_id = 0; + } +} + +struct FunctionArgument { + /// Actual instruction of the argument. + instruction: Instruction, + handle_id: Word, +} + +/// General information needed to emit SPIR-V for Naga statements. +struct BlockContext<'w> { + /// The writer handling the module to which this code belongs. + writer: &'w mut Writer, + + /// The [`Module`](crate::Module) for which we're generating code. + ir_module: &'w crate::Module, + + /// The [`Function`](crate::Function) for which we're generating code. + ir_function: &'w crate::Function, + + /// Information module validation produced about + /// [`ir_function`](BlockContext::ir_function). + fun_info: &'w crate::valid::FunctionInfo, + + /// The [`spv::Function`](Function) to which we are contributing SPIR-V instructions. + function: &'w mut Function, + + /// SPIR-V ids for expressions we've evaluated. + cached: CachedExpressions, + + /// The `Writer`'s temporary vector, for convenience. + temp_list: Vec<Word>, + + /// Tracks the constness of `Expression`s residing in `self.ir_function.expressions` + expression_constness: crate::proc::ExpressionConstnessTracker, +} + +impl BlockContext<'_> { + fn gen_id(&mut self) -> Word { + self.writer.id_gen.next() + } + + fn get_type_id(&mut self, lookup_type: LookupType) -> Word { + self.writer.get_type_id(lookup_type) + } + + fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word { + self.writer.get_expression_type_id(tr) + } + + fn get_index_constant(&mut self, index: Word) -> Word { + self.writer.get_constant_scalar(crate::Literal::U32(index)) + } + + fn get_scope_constant(&mut self, scope: Word) -> Word { + self.writer + .get_constant_scalar(crate::Literal::I32(scope as _)) + } +} + +#[derive(Clone, Copy, Default)] +struct LoopContext { + continuing_id: Option<Word>, + break_id: Option<Word>, +} + +pub struct Writer { + physical_layout: PhysicalLayout, + logical_layout: LogicalLayout, + id_gen: IdGenerator, + + /// The set of capabilities modules are permitted to use. + /// + /// This is initialized from `Options::capabilities`. + capabilities_available: Option<crate::FastHashSet<Capability>>, + + /// The set of capabilities used by this module. + /// + /// If `capabilities_available` is `Some`, then this is always a subset of + /// that. + capabilities_used: crate::FastIndexSet<Capability>, + + /// The set of spirv extensions used. + extensions_used: crate::FastIndexSet<&'static str>, + + debugs: Vec<Instruction>, + annotations: Vec<Instruction>, + flags: WriterFlags, + bounds_check_policies: BoundsCheckPolicies, + zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode, + void_type: Word, + //TODO: convert most of these into vectors, addressable by handle indices + lookup_type: crate::FastHashMap<LookupType, Word>, + lookup_function: crate::FastHashMap<Handle<crate::Function>, Word>, + lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>, + /// Indexed by const-expression handle indexes + constant_ids: Vec<Word>, + cached_constants: crate::FastHashMap<CachedConstant, Word>, + global_variables: Vec<GlobalVariable>, + binding_map: BindingMap, + + // Cached expressions are only meaningful within a BlockContext, but we + // retain the table here between functions to save heap allocations. + saved_cached: CachedExpressions, + + gl450_ext_inst_id: Word, + + // Just a temporary list of SPIR-V ids + temp_list: Vec<Word>, +} + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct WriterFlags: u32 { + /// Include debug labels for everything. + const DEBUG = 0x1; + /// Flip Y coordinate of `BuiltIn::Position` output. + const ADJUST_COORDINATE_SPACE = 0x2; + /// Emit `OpName` for input/output locations. + /// Contrary to spec, some drivers treat it as semantic, not allowing + /// any conflicts. + const LABEL_VARYINGS = 0x4; + /// Emit `PointSize` output builtin to vertex shaders, which is + /// required for drawing with `PointList` topology. + const FORCE_POINT_SIZE = 0x8; + /// Clamp `BuiltIn::FragDepth` output between 0 and 1. + const CLAMP_FRAG_DEPTH = 0x10; + } +} + +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct BindingInfo { + /// If the binding is an unsized binding array, this overrides the size. + pub binding_array_size: Option<u32>, +} + +// Using `BTreeMap` instead of `HashMap` so that we can hash itself. +pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindingInfo>; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ZeroInitializeWorkgroupMemoryMode { + /// Via `VK_KHR_zero_initialize_workgroup_memory` or Vulkan 1.3 + Native, + /// Via assignments + barrier + Polyfill, + None, +} + +#[derive(Debug, Clone)] +pub struct Options<'a> { + /// (Major, Minor) target version of the SPIR-V. + pub lang_version: (u8, u8), + + /// Configuration flags for the writer. + pub flags: WriterFlags, + + /// Map of resources to information about the binding. + pub binding_map: BindingMap, + + /// If given, the set of capabilities modules are allowed to use. Code that + /// requires capabilities beyond these is rejected with an error. + /// + /// If this is `None`, all capabilities are permitted. + pub capabilities: Option<crate::FastHashSet<Capability>>, + + /// How should generate code handle array, vector, matrix, or image texel + /// indices that are out of range? + pub bounds_check_policies: BoundsCheckPolicies, + + /// Dictates the way workgroup variables should be zero initialized + pub zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode, + + pub debug_info: Option<DebugInfo<'a>>, +} + +impl<'a> Default for Options<'a> { + fn default() -> Self { + let mut flags = WriterFlags::ADJUST_COORDINATE_SPACE + | WriterFlags::LABEL_VARYINGS + | WriterFlags::CLAMP_FRAG_DEPTH; + if cfg!(debug_assertions) { + flags |= WriterFlags::DEBUG; + } + Options { + lang_version: (1, 0), + flags, + binding_map: BindingMap::default(), + capabilities: None, + bounds_check_policies: crate::proc::BoundsCheckPolicies::default(), + zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill, + debug_info: None, + } + } +} + +// A subset of options meant to be changed per pipeline. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct PipelineOptions { + /// The stage of the entry point. + pub shader_stage: crate::ShaderStage, + /// The name of the entry point. + /// + /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown. + pub entry_point: String, +} + +pub fn write_vec( + module: &crate::Module, + info: &crate::valid::ModuleInfo, + options: &Options, + pipeline_options: Option<&PipelineOptions>, +) -> Result<Vec<u32>, Error> { + let mut words: Vec<u32> = Vec::new(); + let mut w = Writer::new(options)?; + + w.write( + module, + info, + pipeline_options, + &options.debug_info, + &mut words, + )?; + Ok(words) +} diff --git a/third_party/rust/naga/src/back/spv/ray.rs b/third_party/rust/naga/src/back/spv/ray.rs new file mode 100644 index 0000000000..bc2c4ce3c6 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/ray.rs @@ -0,0 +1,261 @@ +/*! +Generating SPIR-V for ray query operations. +*/ + +use super::{Block, BlockContext, Instruction, LocalType, LookupType}; +use crate::arena::Handle; + +impl<'w> BlockContext<'w> { + pub(super) fn write_ray_query_function( + &mut self, + query: Handle<crate::Expression>, + function: &crate::RayQueryFunction, + block: &mut Block, + ) { + let query_id = self.cached[query]; + match *function { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + //Note: composite extract indices and types must match `generate_ray_desc_type` + let desc_id = self.cached[descriptor]; + let acc_struct_id = self.get_handle_id(acceleration_structure); + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::U32, + pointer_space: None, + })); + let ray_flags_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + ray_flags_id, + desc_id, + &[0], + )); + let cull_mask_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + flag_type_id, + cull_mask_id, + desc_id, + &[1], + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let tmin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmin_id, + desc_id, + &[2], + )); + let tmax_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + scalar_type_id, + tmax_id, + desc_id, + &[3], + )); + + let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let ray_origin_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_origin_id, + desc_id, + &[4], + )); + let ray_dir_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + vector_type_id, + ray_dir_id, + desc_id, + &[5], + )); + + block.body.push(Instruction::ray_query_initialize( + query_id, + acc_struct_id, + ray_flags_id, + cull_mask_id, + ray_origin_id, + tmin_id, + ray_dir_id, + tmax_id, + )); + } + crate::RayQueryFunction::Proceed { result } => { + let id = self.gen_id(); + self.cached[result] = id; + let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); + + block + .body + .push(Instruction::ray_query_proceed(result_type_id, id, query_id)); + } + crate::RayQueryFunction::Terminate => {} + } + } + + pub(super) fn write_ray_query_get_intersection( + &mut self, + query: Handle<crate::Expression>, + block: &mut Block, + ) -> spirv::Word { + let query_id = self.cached[query]; + let intersection_id = self.writer.get_constant_scalar(crate::Literal::U32( + spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _, + )); + + let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::U32, + pointer_space: None, + })); + let kind_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTypeKHR, + flag_type_id, + kind_id, + query_id, + intersection_id, + )); + let instance_custom_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR, + flag_type_id, + instance_custom_index_id, + query_id, + intersection_id, + )); + let instance_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceIdKHR, + flag_type_id, + instance_id, + query_id, + intersection_id, + )); + let sbt_record_offset_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR, + flag_type_id, + sbt_record_offset_id, + query_id, + intersection_id, + )); + let geometry_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionGeometryIndexKHR, + flag_type_id, + geometry_index_id, + query_id, + intersection_id, + )); + let primitive_index_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR, + flag_type_id, + primitive_index_id, + query_id, + intersection_id, + )); + + let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let t_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionTKHR, + scalar_type_id, + t_id, + query_id, + intersection_id, + )); + + let barycentrics_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Bi), + scalar: crate::Scalar::F32, + pointer_space: None, + })); + let barycentrics_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionBarycentricsKHR, + barycentrics_type_id, + barycentrics_id, + query_id, + intersection_id, + )); + + let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::BOOL, + pointer_space: None, + })); + let front_face_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionFrontFaceKHR, + bool_type_id, + front_face_id, + query_id, + intersection_id, + )); + + let transform_type_id = self.get_type_id(LookupType::Local(LocalType::Matrix { + columns: crate::VectorSize::Quad, + rows: crate::VectorSize::Tri, + width: 4, + })); + let object_to_world_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionObjectToWorldKHR, + transform_type_id, + object_to_world_id, + query_id, + intersection_id, + )); + let world_to_object_id = self.gen_id(); + block.body.push(Instruction::ray_query_get_intersection( + spirv::Op::RayQueryGetIntersectionWorldToObjectKHR, + transform_type_id, + world_to_object_id, + query_id, + intersection_id, + )); + + let id = self.gen_id(); + let intersection_type_id = self.get_type_id(LookupType::Handle( + self.ir_module.special_types.ray_intersection.unwrap(), + )); + //Note: the arguments must match `generate_ray_intersection_type` layout + block.body.push(Instruction::composite_construct( + intersection_type_id, + id, + &[ + kind_id, + t_id, + instance_custom_index_id, + instance_id, + sbt_record_offset_id, + geometry_index_id, + primitive_index_id, + barycentrics_id, + front_face_id, + object_to_world_id, + world_to_object_id, + ], + )); + id + } +} diff --git a/third_party/rust/naga/src/back/spv/recyclable.rs b/third_party/rust/naga/src/back/spv/recyclable.rs new file mode 100644 index 0000000000..cd1466e3c7 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/recyclable.rs @@ -0,0 +1,67 @@ +/*! +Reusing collections' previous allocations. +*/ + +/// A value that can be reset to its initial state, retaining its current allocations. +/// +/// Naga attempts to lower the cost of SPIR-V generation by allowing clients to +/// reuse the same `Writer` for multiple Module translations. Reusing a `Writer` +/// means that the `Vec`s, `HashMap`s, and other heap-allocated structures the +/// `Writer` uses internally begin the translation with heap-allocated buffers +/// ready to use. +/// +/// But this approach introduces the risk of `Writer` state leaking from one +/// module to the next. When a developer adds fields to `Writer` or its internal +/// types, they must remember to reset their contents between modules. +/// +/// One trick to ensure that every field has been accounted for is to use Rust's +/// struct literal syntax to construct a new, reset value. If a developer adds a +/// field, but neglects to update the reset code, the compiler will complain +/// that a field is missing from the literal. This trait's `recycle` method +/// takes `self` by value, and returns `Self` by value, encouraging the use of +/// struct literal expressions in its implementation. +pub trait Recyclable { + /// Clear `self`, retaining its current memory allocations. + /// + /// Shrink the buffer if it's currently much larger than was actually used. + /// This prevents a module with exceptionally large allocations from causing + /// the `Writer` to retain more memory than it needs indefinitely. + fn recycle(self) -> Self; +} + +// Stock values for various collections. + +impl<T> Recyclable for Vec<T> { + fn recycle(mut self) -> Self { + self.clear(); + self + } +} + +impl<K, V, S: Clone> Recyclable for std::collections::HashMap<K, V, S> { + fn recycle(mut self) -> Self { + self.clear(); + self + } +} + +impl<K, S: Clone> Recyclable for std::collections::HashSet<K, S> { + fn recycle(mut self) -> Self { + self.clear(); + self + } +} + +impl<K, S: Clone> Recyclable for indexmap::IndexSet<K, S> { + fn recycle(mut self) -> Self { + self.clear(); + self + } +} + +impl<K: Ord, V> Recyclable for std::collections::BTreeMap<K, V> { + fn recycle(mut self) -> Self { + self.clear(); + self + } +} diff --git a/third_party/rust/naga/src/back/spv/selection.rs b/third_party/rust/naga/src/back/spv/selection.rs new file mode 100644 index 0000000000..788b1f10ab --- /dev/null +++ b/third_party/rust/naga/src/back/spv/selection.rs @@ -0,0 +1,257 @@ +/*! +Generate SPIR-V conditional structures. + +Builders for `if` structures with `and`s. + +The types in this module track the information needed to emit SPIR-V code +for complex conditional structures, like those whose conditions involve +short-circuiting 'and' and 'or' structures. These track labels and can emit +`OpPhi` instructions to merge values produced along different paths. + +This currently only supports exactly the forms Naga uses, so it doesn't +support `or` or `else`, and only supports zero or one merged values. + +Naga needs to emit code roughly like this: + +```ignore + + value = DEFAULT; + if COND1 && COND2 { + value = THEN_VALUE; + } + // use value + +``` + +Assuming `ctx` and `block` are a mutable references to a [`BlockContext`] +and the current [`Block`], and `merge_type` is the SPIR-V type for the +merged value `value`, we can build SPIR-V for the code above like so: + +```ignore + + let cond = Selection::start(block, merge_type); + // ... compute `cond1` ... + cond.if_true(ctx, cond1, DEFAULT); + // ... compute `cond2` ... + cond.if_true(ctx, cond2, DEFAULT); + // ... compute THEN_VALUE + let merged_value = cond.finish(ctx, THEN_VALUE); + +``` + +After this, `merged_value` is either `DEFAULT` or `THEN_VALUE`, depending on +the path by which the merged block was reached. + +This takes care of writing all branch instructions, including an +`OpSelectionMerge` annotation in the header block; starting new blocks and +assigning them labels; and emitting the `OpPhi` that gathers together the +right sources for the merged values, for every path through the selection +construct. + +When there is no merged value to produce, you can pass `()` for `merge_type` +and the merge values. In this case no `OpPhi` instructions are produced, and +the `finish` method returns `()`. + +To enforce proper nesting, a `Selection` takes ownership of the `&mut Block` +pointer for the duration of its lifetime. To obtain the block for generating +code in the selection's body, call the `Selection::block` method. +*/ + +use super::{Block, BlockContext, Instruction}; +use spirv::Word; + +/// A private struct recording what we know about the selection construct so far. +pub(super) struct Selection<'b, M: MergeTuple> { + /// The block pointer we're emitting code into. + block: &'b mut Block, + + /// The label of the selection construct's merge block, or `None` if we + /// haven't yet written the `OpSelectionMerge` merge instruction. + merge_label: Option<Word>, + + /// A set of `(VALUES, PARENT)` pairs, used to build `OpPhi` instructions in + /// the merge block. Each `PARENT` is the label of a predecessor block of + /// the merge block. The corresponding `VALUES` holds the ids of the values + /// that `PARENT` contributes to the merged values. + /// + /// We emit all branches to the merge block, so we know all its + /// predecessors. And we refuse to emit a branch unless we're given the + /// values the branching block contributes to the merge, so we always have + /// everything we need to emit the correct phis, by construction. + values: Vec<(M, Word)>, + + /// The types of the values in each element of `values`. + merge_types: M, +} + +impl<'b, M: MergeTuple> Selection<'b, M> { + /// Start a new selection construct. + /// + /// The `block` argument indicates the selection's header block. + /// + /// The `merge_types` argument should be a `Word` or tuple of `Word`s, each + /// value being the SPIR-V result type id of an `OpPhi` instruction that + /// will be written to the selection's merge block when this selection's + /// [`finish`] method is called. This argument may also be `()`, for + /// selections that produce no values. + /// + /// (This function writes no code to `block` itself; it simply constructs a + /// fresh `Selection`.) + /// + /// [`finish`]: Selection::finish + pub(super) fn start(block: &'b mut Block, merge_types: M) -> Self { + Selection { + block, + merge_label: None, + values: vec![], + merge_types, + } + } + + pub(super) fn block(&mut self) -> &mut Block { + self.block + } + + /// Branch to a successor block if `cond` is true, otherwise merge. + /// + /// If `cond` is false, branch to the merge block, using `values` as the + /// merged values. Otherwise, proceed to a new block. + /// + /// The `values` argument must be the same shape as the `merge_types` + /// argument passed to `Selection::start`. + pub(super) fn if_true(&mut self, ctx: &mut BlockContext, cond: Word, values: M) { + self.values.push((values, self.block.label_id)); + + let merge_label = self.make_merge_label(ctx); + let next_label = ctx.gen_id(); + ctx.function.consume( + std::mem::replace(self.block, Block::new(next_label)), + Instruction::branch_conditional(cond, next_label, merge_label), + ); + } + + /// Emit an unconditional branch to the merge block, and compute merged + /// values. + /// + /// Use `final_values` as the merged values contributed by the current + /// block, and transition to the merge block, emitting `OpPhi` instructions + /// to produce the merged values. This must be the same shape as the + /// `merge_types` argument passed to [`Selection::start`]. + /// + /// Return the SPIR-V ids of the merged values. This value has the same + /// shape as the `merge_types` argument passed to `Selection::start`. + pub(super) fn finish(self, ctx: &mut BlockContext, final_values: M) -> M { + match self { + Selection { + merge_label: None, .. + } => { + // We didn't actually emit any branches, so `self.values` must + // be empty, and `final_values` are the only sources we have for + // the merged values. Easy peasy. + final_values + } + + Selection { + block, + merge_label: Some(merge_label), + mut values, + merge_types, + } => { + // Emit the final branch and transition to the merge block. + values.push((final_values, block.label_id)); + ctx.function.consume( + std::mem::replace(block, Block::new(merge_label)), + Instruction::branch(merge_label), + ); + + // Now that we're in the merge block, build the phi instructions. + merge_types.write_phis(ctx, block, &values) + } + } + } + + /// Return the id of the merge block, writing a merge instruction if needed. + fn make_merge_label(&mut self, ctx: &mut BlockContext) -> Word { + match self.merge_label { + None => { + let merge_label = ctx.gen_id(); + self.block.body.push(Instruction::selection_merge( + merge_label, + spirv::SelectionControl::NONE, + )); + self.merge_label = Some(merge_label); + merge_label + } + Some(merge_label) => merge_label, + } + } +} + +/// A trait to help `Selection` manage any number of merged values. +/// +/// Some selection constructs, like a `ReadZeroSkipWrite` bounds check on a +/// [`Load`] expression, produce a single merged value. Others produce no merged +/// value, like a bounds check on a [`Store`] statement. +/// +/// To let `Selection` work nicely with both cases, we let the merge type +/// argument passed to [`Selection::start`] be any type that implements this +/// `MergeTuple` trait. `MergeTuple` is then implemented for `()`, `Word`, +/// `(Word, Word)`, and so on. +/// +/// A `MergeTuple` type can represent either a bunch of SPIR-V types or values; +/// the `merge_types` argument to `Selection::start` are type ids, whereas the +/// `values` arguments to the [`if_true`] and [`finish`] methods are value ids. +/// The set of merged value returned by `finish` is a tuple of value ids. +/// +/// In fact, since Naga only uses zero- and single-valued selection constructs +/// at present, we only implement `MergeTuple` for `()` and `Word`. But if you +/// add more cases, feel free to add more implementations. Once const generics +/// are available, we could have a single implementation of `MergeTuple` for all +/// lengths of arrays, and be done with it. +/// +/// [`Load`]: crate::Expression::Load +/// [`Store`]: crate::Statement::Store +/// [`if_true`]: Selection::if_true +/// [`finish`]: Selection::finish +pub(super) trait MergeTuple: Sized { + /// Write OpPhi instructions for the given set of predecessors. + /// + /// The `predecessors` vector should be a vector of `(LABEL, VALUES)` pairs, + /// where each `VALUES` holds the values contributed by the branch from + /// `LABEL`, which should be one of the current block's predecessors. + fn write_phis( + self, + ctx: &mut BlockContext, + block: &mut Block, + predecessors: &[(Self, Word)], + ) -> Self; +} + +/// Selections that produce a single merged value. +/// +/// For example, `ImageLoad` with `BoundsCheckPolicy::ReadZeroSkipWrite` either +/// returns a texel value or zeros. +impl MergeTuple for Word { + fn write_phis( + self, + ctx: &mut BlockContext, + block: &mut Block, + predecessors: &[(Word, Word)], + ) -> Word { + let merged_value = ctx.gen_id(); + block + .body + .push(Instruction::phi(self, merged_value, predecessors)); + merged_value + } +} + +/// Selections that produce no merged values. +/// +/// For example, `ImageStore` under `BoundsCheckPolicy::ReadZeroSkipWrite` +/// either does the store or skips it, but in neither case does it produce a +/// value. +impl MergeTuple for () { + /// No phis need to be generated. + fn write_phis(self, _: &mut BlockContext, _: &mut Block, _: &[((), Word)]) {} +} diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs new file mode 100644 index 0000000000..4db86c93a7 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/writer.rs @@ -0,0 +1,2063 @@ +use super::{ + block::DebugInfoInner, + helpers::{contains_builtin, global_needs_wrapper, map_storage_class}, + make_local, Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo, + EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction, + LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, LoopContext, Options, + PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE, +}; +use crate::{ + arena::{Handle, UniqueArena}, + back::spv::BindingInfo, + proc::{Alignment, TypeResolution}, + valid::{FunctionInfo, ModuleInfo}, +}; +use spirv::Word; +use std::collections::hash_map::Entry; + +struct FunctionInterface<'a> { + varying_ids: &'a mut Vec<Word>, + stage: crate::ShaderStage, +} + +impl Function { + fn to_words(&self, sink: &mut impl Extend<Word>) { + self.signature.as_ref().unwrap().to_words(sink); + for argument in self.parameters.iter() { + argument.instruction.to_words(sink); + } + for (index, block) in self.blocks.iter().enumerate() { + Instruction::label(block.label_id).to_words(sink); + if index == 0 { + for local_var in self.variables.values() { + local_var.instruction.to_words(sink); + } + } + for instruction in block.body.iter() { + instruction.to_words(sink); + } + } + } +} + +impl Writer { + pub fn new(options: &Options) -> Result<Self, Error> { + let (major, minor) = options.lang_version; + if major != 1 { + return Err(Error::UnsupportedVersion(major, minor)); + } + let raw_version = ((major as u32) << 16) | ((minor as u32) << 8); + + let mut capabilities_used = crate::FastIndexSet::default(); + capabilities_used.insert(spirv::Capability::Shader); + + let mut id_gen = IdGenerator::default(); + let gl450_ext_inst_id = id_gen.next(); + let void_type = id_gen.next(); + + Ok(Writer { + physical_layout: PhysicalLayout::new(raw_version), + logical_layout: LogicalLayout::default(), + id_gen, + capabilities_available: options.capabilities.clone(), + capabilities_used, + extensions_used: crate::FastIndexSet::default(), + debugs: vec![], + annotations: vec![], + flags: options.flags, + bounds_check_policies: options.bounds_check_policies, + zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory, + void_type, + lookup_type: crate::FastHashMap::default(), + lookup_function: crate::FastHashMap::default(), + lookup_function_type: crate::FastHashMap::default(), + constant_ids: Vec::new(), + cached_constants: crate::FastHashMap::default(), + global_variables: Vec::new(), + binding_map: options.binding_map.clone(), + saved_cached: CachedExpressions::default(), + gl450_ext_inst_id, + temp_list: Vec::new(), + }) + } + + /// Reset `Writer` to its initial state, retaining any allocations. + /// + /// Why not just implement `Recyclable` for `Writer`? By design, + /// `Recyclable::recycle` requires ownership of the value, not just + /// `&mut`; see the trait documentation. But we need to use this method + /// from functions like `Writer::write`, which only have `&mut Writer`. + /// Workarounds include unsafe code (`std::ptr::read`, then `write`, ugh) + /// or something like a `Default` impl that returns an oddly-initialized + /// `Writer`, which is worse. + fn reset(&mut self) { + use super::recyclable::Recyclable; + use std::mem::take; + + let mut id_gen = IdGenerator::default(); + let gl450_ext_inst_id = id_gen.next(); + let void_type = id_gen.next(); + + // Every field of the old writer that is not determined by the `Options` + // passed to `Writer::new` should be reset somehow. + let fresh = Writer { + // Copied from the old Writer: + flags: self.flags, + bounds_check_policies: self.bounds_check_policies, + zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory, + capabilities_available: take(&mut self.capabilities_available), + binding_map: take(&mut self.binding_map), + + // Initialized afresh: + id_gen, + void_type, + gl450_ext_inst_id, + + // Recycled: + capabilities_used: take(&mut self.capabilities_used).recycle(), + extensions_used: take(&mut self.extensions_used).recycle(), + physical_layout: self.physical_layout.clone().recycle(), + logical_layout: take(&mut self.logical_layout).recycle(), + debugs: take(&mut self.debugs).recycle(), + annotations: take(&mut self.annotations).recycle(), + lookup_type: take(&mut self.lookup_type).recycle(), + lookup_function: take(&mut self.lookup_function).recycle(), + lookup_function_type: take(&mut self.lookup_function_type).recycle(), + constant_ids: take(&mut self.constant_ids).recycle(), + cached_constants: take(&mut self.cached_constants).recycle(), + global_variables: take(&mut self.global_variables).recycle(), + saved_cached: take(&mut self.saved_cached).recycle(), + temp_list: take(&mut self.temp_list).recycle(), + }; + + *self = fresh; + + self.capabilities_used.insert(spirv::Capability::Shader); + } + + /// Indicate that the code requires any one of the listed capabilities. + /// + /// If nothing in `capabilities` appears in the available capabilities + /// specified in the [`Options`] from which this `Writer` was created, + /// return an error. The `what` string is used in the error message to + /// explain what provoked the requirement. (If no available capabilities were + /// given, assume everything is available.) + /// + /// The first acceptable capability will be added to this `Writer`'s + /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the + /// result. For this reason, more specific capabilities should be listed + /// before more general. + /// + /// [`capabilities_used`]: Writer::capabilities_used + pub(super) fn require_any( + &mut self, + what: &'static str, + capabilities: &[spirv::Capability], + ) -> Result<(), Error> { + match *capabilities { + [] => Ok(()), + [first, ..] => { + // Find the first acceptable capability, or return an error if + // there is none. + let selected = match self.capabilities_available { + None => first, + Some(ref available) => { + match capabilities.iter().find(|cap| available.contains(cap)) { + Some(&cap) => cap, + None => { + return Err(Error::MissingCapabilities(what, capabilities.to_vec())) + } + } + } + }; + self.capabilities_used.insert(selected); + Ok(()) + } + } + } + + /// Indicate that the code uses the given extension. + pub(super) fn use_extension(&mut self, extension: &'static str) { + self.extensions_used.insert(extension); + } + + pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word { + match self.lookup_type.entry(lookup_ty) { + Entry::Occupied(e) => *e.get(), + Entry::Vacant(e) => { + let local = match lookup_ty { + LookupType::Handle(_handle) => unreachable!("Handles are populated at start"), + LookupType::Local(local) => local, + }; + + let id = self.id_gen.next(); + e.insert(id); + self.write_type_declaration_local(id, local); + id + } + } + } + + pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType { + match *tr { + TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle), + TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()), + } + } + + pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word { + let lookup_ty = self.get_expression_lookup_type(tr); + self.get_type_id(lookup_ty) + } + + pub(super) fn get_pointer_id( + &mut self, + arena: &UniqueArena<crate::Type>, + handle: Handle<crate::Type>, + class: spirv::StorageClass, + ) -> Result<Word, Error> { + let ty_id = self.get_type_id(LookupType::Handle(handle)); + if let crate::TypeInner::Pointer { .. } = arena[handle].inner { + return Ok(ty_id); + } + let lookup_type = LookupType::Local(LocalType::Pointer { + base: handle, + class, + }); + Ok(if let Some(&id) = self.lookup_type.get(&lookup_type) { + id + } else { + let id = self.id_gen.next(); + let instruction = Instruction::type_pointer(id, class, ty_id); + instruction.to_words(&mut self.logical_layout.declarations); + self.lookup_type.insert(lookup_type, id); + id + }) + } + + pub(super) fn get_uint_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: None, + scalar: crate::Scalar::U32, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + + pub(super) fn get_float_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + + pub(super) fn get_uint3_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + scalar: crate::Scalar::U32, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + + pub(super) fn get_float_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { + let lookup_type = LookupType::Local(LocalType::Value { + vector_size: None, + scalar: crate::Scalar::F32, + pointer_space: Some(class), + }); + if let Some(&id) = self.lookup_type.get(&lookup_type) { + id + } else { + let id = self.id_gen.next(); + let ty_id = self.get_float_type_id(); + let instruction = Instruction::type_pointer(id, class, ty_id); + instruction.to_words(&mut self.logical_layout.declarations); + self.lookup_type.insert(lookup_type, id); + id + } + } + + pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word { + let lookup_type = LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + scalar: crate::Scalar::U32, + pointer_space: Some(class), + }); + if let Some(&id) = self.lookup_type.get(&lookup_type) { + id + } else { + let id = self.id_gen.next(); + let ty_id = self.get_uint3_type_id(); + let instruction = Instruction::type_pointer(id, class, ty_id); + instruction.to_words(&mut self.logical_layout.declarations); + self.lookup_type.insert(lookup_type, id); + id + } + } + + pub(super) fn get_bool_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: None, + scalar: crate::Scalar::BOOL, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + + pub(super) fn get_bool3_type_id(&mut self) -> Word { + let local_type = LocalType::Value { + vector_size: Some(crate::VectorSize::Tri), + scalar: crate::Scalar::BOOL, + pointer_space: None, + }; + self.get_type_id(local_type.into()) + } + + pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) { + self.annotations + .push(Instruction::decorate(id, decoration, operands)); + } + + fn write_function( + &mut self, + ir_function: &crate::Function, + info: &FunctionInfo, + ir_module: &crate::Module, + mut interface: Option<FunctionInterface>, + debug_info: &Option<DebugInfoInner>, + ) -> Result<Word, Error> { + let mut function = Function::default(); + + let prelude_id = self.id_gen.next(); + let mut prelude = Block::new(prelude_id); + let mut ep_context = EntryPointContext { + argument_ids: Vec::new(), + results: Vec::new(), + }; + + let mut local_invocation_id = None; + + let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len()); + for argument in ir_function.arguments.iter() { + let class = spirv::StorageClass::Input; + let handle_ty = ir_module.types[argument.ty].inner.is_handle(); + let argument_type_id = match handle_ty { + true => self.get_pointer_id( + &ir_module.types, + argument.ty, + spirv::StorageClass::UniformConstant, + )?, + false => self.get_type_id(LookupType::Handle(argument.ty)), + }; + + if let Some(ref mut iface) = interface { + let id = if let Some(ref binding) = argument.binding { + let name = argument.name.as_deref(); + + let varying_id = self.write_varying( + ir_module, + iface.stage, + class, + name, + argument.ty, + binding, + )?; + iface.varying_ids.push(varying_id); + let id = self.id_gen.next(); + prelude + .body + .push(Instruction::load(argument_type_id, id, varying_id, None)); + + if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) { + local_invocation_id = Some(id); + } + + id + } else if let crate::TypeInner::Struct { ref members, .. } = + ir_module.types[argument.ty].inner + { + let struct_id = self.id_gen.next(); + let mut constituent_ids = Vec::with_capacity(members.len()); + for member in members { + let type_id = self.get_type_id(LookupType::Handle(member.ty)); + let name = member.name.as_deref(); + let binding = member.binding.as_ref().unwrap(); + let varying_id = self.write_varying( + ir_module, + iface.stage, + class, + name, + member.ty, + binding, + )?; + iface.varying_ids.push(varying_id); + let id = self.id_gen.next(); + prelude + .body + .push(Instruction::load(type_id, id, varying_id, None)); + constituent_ids.push(id); + + if binding == &crate::Binding::BuiltIn(crate::BuiltIn::GlobalInvocationId) { + local_invocation_id = Some(id); + } + } + prelude.body.push(Instruction::composite_construct( + argument_type_id, + struct_id, + &constituent_ids, + )); + struct_id + } else { + unreachable!("Missing argument binding on an entry point"); + }; + ep_context.argument_ids.push(id); + } else { + let argument_id = self.id_gen.next(); + let instruction = Instruction::function_parameter(argument_type_id, argument_id); + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = argument.name { + self.debugs.push(Instruction::name(argument_id, name)); + } + } + function.parameters.push(FunctionArgument { + instruction, + handle_id: if handle_ty { + let id = self.id_gen.next(); + prelude.body.push(Instruction::load( + self.get_type_id(LookupType::Handle(argument.ty)), + id, + argument_id, + None, + )); + id + } else { + 0 + }, + }); + parameter_type_ids.push(argument_type_id); + }; + } + + let return_type_id = match ir_function.result { + Some(ref result) => { + if let Some(ref mut iface) = interface { + let mut has_point_size = false; + let class = spirv::StorageClass::Output; + if let Some(ref binding) = result.binding { + has_point_size |= + *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize); + let type_id = self.get_type_id(LookupType::Handle(result.ty)); + let varying_id = self.write_varying( + ir_module, + iface.stage, + class, + None, + result.ty, + binding, + )?; + iface.varying_ids.push(varying_id); + ep_context.results.push(ResultMember { + id: varying_id, + type_id, + built_in: binding.to_built_in(), + }); + } else if let crate::TypeInner::Struct { ref members, .. } = + ir_module.types[result.ty].inner + { + for member in members { + let type_id = self.get_type_id(LookupType::Handle(member.ty)); + let name = member.name.as_deref(); + let binding = member.binding.as_ref().unwrap(); + has_point_size |= + *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize); + let varying_id = self.write_varying( + ir_module, + iface.stage, + class, + name, + member.ty, + binding, + )?; + iface.varying_ids.push(varying_id); + ep_context.results.push(ResultMember { + id: varying_id, + type_id, + built_in: binding.to_built_in(), + }); + } + } else { + unreachable!("Missing result binding on an entry point"); + } + + if self.flags.contains(WriterFlags::FORCE_POINT_SIZE) + && iface.stage == crate::ShaderStage::Vertex + && !has_point_size + { + // add point size artificially + let varying_id = self.id_gen.next(); + let pointer_type_id = self.get_float_pointer_type_id(class); + Instruction::variable(pointer_type_id, varying_id, class, None) + .to_words(&mut self.logical_layout.declarations); + self.decorate( + varying_id, + spirv::Decoration::BuiltIn, + &[spirv::BuiltIn::PointSize as u32], + ); + iface.varying_ids.push(varying_id); + + let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0)); + prelude + .body + .push(Instruction::store(varying_id, default_value_id, None)); + } + self.void_type + } else { + self.get_type_id(LookupType::Handle(result.ty)) + } + } + None => self.void_type, + }; + + let lookup_function_type = LookupFunctionType { + parameter_type_ids, + return_type_id, + }; + + let function_id = self.id_gen.next(); + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = ir_function.name { + self.debugs.push(Instruction::name(function_id, name)); + } + } + + let function_type = self.get_function_type(lookup_function_type); + function.signature = Some(Instruction::function( + return_type_id, + function_id, + spirv::FunctionControl::empty(), + function_type, + )); + + if interface.is_some() { + function.entry_point_context = Some(ep_context); + } + + // fill up the `GlobalVariable::access_id` + for gv in self.global_variables.iter_mut() { + gv.reset_for_function(); + } + for (handle, var) in ir_module.global_variables.iter() { + if info[handle].is_empty() { + continue; + } + + let mut gv = self.global_variables[handle.index()].clone(); + if let Some(ref mut iface) = interface { + // Have to include global variables in the interface + if self.physical_layout.version >= 0x10400 { + iface.varying_ids.push(gv.var_id); + } + } + + // Handle globals are pre-emitted and should be loaded automatically. + // + // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing. + let is_binding_array = match ir_module.types[var.ty].inner { + crate::TypeInner::BindingArray { .. } => true, + _ => false, + }; + + if var.space == crate::AddressSpace::Handle && !is_binding_array { + let var_type_id = self.get_type_id(LookupType::Handle(var.ty)); + let id = self.id_gen.next(); + prelude + .body + .push(Instruction::load(var_type_id, id, gv.var_id, None)); + gv.access_id = gv.var_id; + gv.handle_id = id; + } else if global_needs_wrapper(ir_module, var) { + let class = map_storage_class(var.space); + let pointer_type_id = self.get_pointer_id(&ir_module.types, var.ty, class)?; + let index_id = self.get_index_constant(0); + + let id = self.id_gen.next(); + prelude.body.push(Instruction::access_chain( + pointer_type_id, + id, + gv.var_id, + &[index_id], + )); + gv.access_id = id; + } else { + // by default, the variable ID is accessed as is + gv.access_id = gv.var_id; + }; + + // work around borrow checking in the presence of `self.xxx()` calls + self.global_variables[handle.index()] = gv; + } + + // Create a `BlockContext` for generating SPIR-V for the function's + // body. + let mut context = BlockContext { + ir_module, + ir_function, + fun_info: info, + function: &mut function, + // Re-use the cached expression table from prior functions. + cached: std::mem::take(&mut self.saved_cached), + + // Steal the Writer's temp list for a bit. + temp_list: std::mem::take(&mut self.temp_list), + writer: self, + expression_constness: crate::proc::ExpressionConstnessTracker::from_arena( + &ir_function.expressions, + ), + }; + + // fill up the pre-emitted and const expressions + context.cached.reset(ir_function.expressions.len()); + for (handle, expr) in ir_function.expressions.iter() { + if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_))) + || context.expression_constness.is_const(handle) + { + context.cache_expression_value(handle, &mut prelude)?; + } + } + + for (handle, variable) in ir_function.local_variables.iter() { + let id = context.gen_id(); + + if context.writer.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = variable.name { + context.writer.debugs.push(Instruction::name(id, name)); + } + } + + let init_word = variable.init.map(|constant| context.cached[constant]); + let pointer_type_id = context.writer.get_pointer_id( + &ir_module.types, + variable.ty, + spirv::StorageClass::Function, + )?; + let instruction = Instruction::variable( + pointer_type_id, + id, + spirv::StorageClass::Function, + init_word.or_else(|| match ir_module.types[variable.ty].inner { + crate::TypeInner::RayQuery => None, + _ => { + let type_id = context.get_type_id(LookupType::Handle(variable.ty)); + Some(context.writer.write_constant_null(type_id)) + } + }), + ); + context + .function + .variables + .insert(handle, LocalVariable { id, instruction }); + } + + // cache local variable expressions + for (handle, expr) in ir_function.expressions.iter() { + if matches!(*expr, crate::Expression::LocalVariable(_)) { + context.cache_expression_value(handle, &mut prelude)?; + } + } + + let next_id = context.gen_id(); + + context + .function + .consume(prelude, Instruction::branch(next_id)); + + let workgroup_vars_init_exit_block_id = + match (context.writer.zero_initialize_workgroup_memory, interface) { + ( + super::ZeroInitializeWorkgroupMemoryMode::Polyfill, + Some( + ref mut interface @ FunctionInterface { + stage: crate::ShaderStage::Compute, + .. + }, + ), + ) => context.writer.generate_workgroup_vars_init_block( + next_id, + ir_module, + info, + local_invocation_id, + interface, + context.function, + ), + _ => None, + }; + + let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id { + exit_id + } else { + next_id + }; + + context.write_block( + main_id, + &ir_function.body, + super::block::BlockExit::Return, + LoopContext::default(), + debug_info.as_ref(), + )?; + + // Consume the `BlockContext`, ending its borrows and letting the + // `Writer` steal back its cached expression table and temp_list. + let BlockContext { + cached, temp_list, .. + } = context; + self.saved_cached = cached; + self.temp_list = temp_list; + + function.to_words(&mut self.logical_layout.function_definitions); + Instruction::function_end().to_words(&mut self.logical_layout.function_definitions); + + Ok(function_id) + } + + fn write_execution_mode( + &mut self, + function_id: Word, + mode: spirv::ExecutionMode, + ) -> Result<(), Error> { + //self.check(mode.required_capabilities())?; + Instruction::execution_mode(function_id, mode, &[]) + .to_words(&mut self.logical_layout.execution_modes); + Ok(()) + } + + // TODO Move to instructions module + fn write_entry_point( + &mut self, + entry_point: &crate::EntryPoint, + info: &FunctionInfo, + ir_module: &crate::Module, + debug_info: &Option<DebugInfoInner>, + ) -> Result<Instruction, Error> { + let mut interface_ids = Vec::new(); + let function_id = self.write_function( + &entry_point.function, + info, + ir_module, + Some(FunctionInterface { + varying_ids: &mut interface_ids, + stage: entry_point.stage, + }), + debug_info, + )?; + + let exec_model = match entry_point.stage { + crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex, + crate::ShaderStage::Fragment => { + self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?; + if let Some(ref result) = entry_point.function.result { + if contains_builtin( + result.binding.as_ref(), + result.ty, + &ir_module.types, + crate::BuiltIn::FragDepth, + ) { + self.write_execution_mode( + function_id, + spirv::ExecutionMode::DepthReplacing, + )?; + } + } + spirv::ExecutionModel::Fragment + } + crate::ShaderStage::Compute => { + let execution_mode = spirv::ExecutionMode::LocalSize; + //self.check(execution_mode.required_capabilities())?; + Instruction::execution_mode( + function_id, + execution_mode, + &entry_point.workgroup_size, + ) + .to_words(&mut self.logical_layout.execution_modes); + spirv::ExecutionModel::GLCompute + } + }; + //self.check(exec_model.required_capabilities())?; + + Ok(Instruction::entry_point( + exec_model, + function_id, + &entry_point.name, + interface_ids.as_slice(), + )) + } + + fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction { + use crate::ScalarKind as Sk; + + let bits = (scalar.width * BITS_PER_BYTE) as u32; + match scalar.kind { + Sk::Sint | Sk::Uint => { + let signedness = if scalar.kind == Sk::Sint { + super::instructions::Signedness::Signed + } else { + super::instructions::Signedness::Unsigned + }; + let cap = match bits { + 8 => Some(spirv::Capability::Int8), + 16 => Some(spirv::Capability::Int16), + 64 => Some(spirv::Capability::Int64), + _ => None, + }; + if let Some(cap) = cap { + self.capabilities_used.insert(cap); + } + Instruction::type_int(id, bits, signedness) + } + Sk::Float => { + if bits == 64 { + self.capabilities_used.insert(spirv::Capability::Float64); + } + Instruction::type_float(id, bits) + } + Sk::Bool => Instruction::type_bool(id), + Sk::AbstractInt | Sk::AbstractFloat => { + unreachable!("abstract types should never reach the backend"); + } + } + } + + fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> { + match *inner { + crate::TypeInner::Image { + dim, + arrayed, + class, + } => { + let sampled = match class { + crate::ImageClass::Sampled { .. } => true, + crate::ImageClass::Depth { .. } => true, + crate::ImageClass::Storage { format, .. } => { + self.request_image_format_capabilities(format.into())?; + false + } + }; + + match dim { + crate::ImageDimension::D1 => { + if sampled { + self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?; + } else { + self.require_any("1D storage images", &[spirv::Capability::Image1D])?; + } + } + crate::ImageDimension::Cube if arrayed => { + if sampled { + self.require_any( + "sampled cube array images", + &[spirv::Capability::SampledCubeArray], + )?; + } else { + self.require_any( + "cube array storage images", + &[spirv::Capability::ImageCubeArray], + )?; + } + } + _ => {} + } + } + crate::TypeInner::AccelerationStructure => { + self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?; + } + crate::TypeInner::RayQuery => { + self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?; + } + _ => {} + } + Ok(()) + } + + fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) { + let instruction = match local_ty { + LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + } => self.make_scalar(id, scalar), + LocalType::Value { + vector_size: Some(size), + scalar, + pointer_space: None, + } => { + let scalar_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar, + pointer_space: None, + })); + Instruction::type_vector(id, scalar_id, size) + } + LocalType::Matrix { + columns, + rows, + width, + } => { + let vector_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(rows), + scalar: crate::Scalar::float(width), + pointer_space: None, + })); + Instruction::type_matrix(id, vector_id, columns) + } + LocalType::Pointer { base, class } => { + let type_id = self.get_type_id(LookupType::Handle(base)); + Instruction::type_pointer(id, class, type_id) + } + LocalType::Value { + vector_size, + scalar, + pointer_space: Some(class), + } => { + let type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size, + scalar, + pointer_space: None, + })); + Instruction::type_pointer(id, class, type_id) + } + LocalType::Image(image) => { + let local_type = LocalType::Value { + vector_size: None, + scalar: crate::Scalar { + kind: image.sampled_type, + width: 4, + }, + pointer_space: None, + }; + let type_id = self.get_type_id(LookupType::Local(local_type)); + Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format) + } + LocalType::Sampler => Instruction::type_sampler(id), + LocalType::SampledImage { image_type_id } => { + Instruction::type_sampled_image(id, image_type_id) + } + LocalType::BindingArray { base, size } => { + let inner_ty = self.get_type_id(LookupType::Handle(base)); + let scalar_id = self.get_constant_scalar(crate::Literal::U32(size)); + Instruction::type_array(id, inner_ty, scalar_id) + } + LocalType::PointerToBindingArray { base, size, space } => { + let inner_ty = + self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size })); + let class = map_storage_class(space); + Instruction::type_pointer(id, class, inner_ty) + } + LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id), + LocalType::RayQuery => Instruction::type_ray_query(id), + }; + + instruction.to_words(&mut self.logical_layout.declarations); + } + + fn write_type_declaration_arena( + &mut self, + arena: &UniqueArena<crate::Type>, + handle: Handle<crate::Type>, + ) -> Result<Word, Error> { + let ty = &arena[handle]; + let id = if let Some(local) = make_local(&ty.inner) { + // This type can be represented as a `LocalType`, so check if we've + // already written an instruction for it. If not, do so now, with + // `write_type_declaration_local`. + match self.lookup_type.entry(LookupType::Local(local)) { + // We already have an id for this `LocalType`. + Entry::Occupied(e) => *e.get(), + + // It's a type we haven't seen before. + Entry::Vacant(e) => { + let id = self.id_gen.next(); + e.insert(id); + + self.write_type_declaration_local(id, local); + + // If it's a type that needs SPIR-V capabilities, request them now, + // so write_type_declaration_local can stay infallible. + self.request_type_capabilities(&ty.inner)?; + + id + } + } + } else { + use spirv::Decoration; + + let id = self.id_gen.next(); + let instruction = match ty.inner { + crate::TypeInner::Array { base, size, stride } => { + self.decorate(id, Decoration::ArrayStride, &[stride]); + + let type_id = self.get_type_id(LookupType::Handle(base)); + match size { + crate::ArraySize::Constant(length) => { + let length_id = self.get_index_constant(length.get()); + Instruction::type_array(id, type_id, length_id) + } + crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id), + } + } + crate::TypeInner::BindingArray { base, size } => { + let type_id = self.get_type_id(LookupType::Handle(base)); + match size { + crate::ArraySize::Constant(length) => { + let length_id = self.get_index_constant(length.get()); + Instruction::type_array(id, type_id, length_id) + } + crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id), + } + } + crate::TypeInner::Struct { + ref members, + span: _, + } => { + let mut has_runtime_array = false; + let mut member_ids = Vec::with_capacity(members.len()); + for (index, member) in members.iter().enumerate() { + let member_ty = &arena[member.ty]; + match member_ty.inner { + crate::TypeInner::Array { + base: _, + size: crate::ArraySize::Dynamic, + stride: _, + } => { + has_runtime_array = true; + } + _ => (), + } + self.decorate_struct_member(id, index, member, arena)?; + let member_id = self.get_type_id(LookupType::Handle(member.ty)); + member_ids.push(member_id); + } + if has_runtime_array { + self.decorate(id, Decoration::Block, &[]); + } + Instruction::type_struct(id, member_ids.as_slice()) + } + + // These all have TypeLocal representations, so they should have been + // handled by `write_type_declaration_local` above. + crate::TypeInner::Scalar(_) + | crate::TypeInner::Atomic(_) + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } + | crate::TypeInner::Pointer { .. } + | crate::TypeInner::ValuePointer { .. } + | crate::TypeInner::Image { .. } + | crate::TypeInner::Sampler { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => unreachable!(), + }; + + instruction.to_words(&mut self.logical_layout.declarations); + id + }; + + // Add this handle as a new alias for that type. + self.lookup_type.insert(LookupType::Handle(handle), id); + + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = ty.name { + self.debugs.push(Instruction::name(id, name)); + } + } + + Ok(id) + } + + fn request_image_format_capabilities( + &mut self, + format: spirv::ImageFormat, + ) -> Result<(), Error> { + use spirv::ImageFormat as If; + match format { + If::Rg32f + | If::Rg16f + | If::R11fG11fB10f + | If::R16f + | If::Rgba16 + | If::Rgb10A2 + | If::Rg16 + | If::Rg8 + | If::R16 + | If::R8 + | If::Rgba16Snorm + | If::Rg16Snorm + | If::Rg8Snorm + | If::R16Snorm + | If::R8Snorm + | If::Rg32i + | If::Rg16i + | If::Rg8i + | If::R16i + | If::R8i + | If::Rgb10a2ui + | If::Rg32ui + | If::Rg16ui + | If::Rg8ui + | If::R16ui + | If::R8ui => self.require_any( + "storage image format", + &[spirv::Capability::StorageImageExtendedFormats], + ), + If::R64ui | If::R64i => self.require_any( + "64-bit integer storage image format", + &[spirv::Capability::Int64ImageEXT], + ), + If::Unknown + | If::Rgba32f + | If::Rgba16f + | If::R32f + | If::Rgba8 + | If::Rgba8Snorm + | If::Rgba32i + | If::Rgba16i + | If::Rgba8i + | If::R32i + | If::Rgba32ui + | If::Rgba16ui + | If::Rgba8ui + | If::R32ui => Ok(()), + } + } + + pub(super) fn get_index_constant(&mut self, index: Word) -> Word { + self.get_constant_scalar(crate::Literal::U32(index)) + } + + pub(super) fn get_constant_scalar_with( + &mut self, + value: u8, + scalar: crate::Scalar, + ) -> Result<Word, Error> { + Ok( + self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or( + Error::Validation("Unexpected kind and/or width for Literal"), + )?), + ) + } + + pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word { + let scalar = CachedConstant::Literal(value); + if let Some(&id) = self.cached_constants.get(&scalar) { + return id; + } + let id = self.id_gen.next(); + self.write_constant_scalar(id, &value, None); + self.cached_constants.insert(scalar, id); + id + } + + fn write_constant_scalar( + &mut self, + id: Word, + value: &crate::Literal, + debug_name: Option<&String>, + ) { + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(name) = debug_name { + self.debugs.push(Instruction::name(id, name)); + } + } + let type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + scalar: value.scalar(), + pointer_space: None, + })); + let instruction = match *value { + crate::Literal::F64(value) => { + let bits = value.to_bits(); + Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32) + } + crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()), + crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value), + crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32), + crate::Literal::I64(value) => { + Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32) + } + crate::Literal::Bool(true) => Instruction::constant_true(type_id, id), + crate::Literal::Bool(false) => Instruction::constant_false(type_id, id), + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + unreachable!("Abstract types should not appear in IR presented to backends"); + } + }; + + instruction.to_words(&mut self.logical_layout.declarations); + } + + pub(super) fn get_constant_composite( + &mut self, + ty: LookupType, + constituent_ids: &[Word], + ) -> Word { + let composite = CachedConstant::Composite { + ty, + constituent_ids: constituent_ids.to_vec(), + }; + if let Some(&id) = self.cached_constants.get(&composite) { + return id; + } + let id = self.id_gen.next(); + self.write_constant_composite(id, ty, constituent_ids, None); + self.cached_constants.insert(composite, id); + id + } + + fn write_constant_composite( + &mut self, + id: Word, + ty: LookupType, + constituent_ids: &[Word], + debug_name: Option<&String>, + ) { + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(name) = debug_name { + self.debugs.push(Instruction::name(id, name)); + } + } + let type_id = self.get_type_id(ty); + Instruction::constant_composite(type_id, id, constituent_ids) + .to_words(&mut self.logical_layout.declarations); + } + + pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word { + let null = CachedConstant::ZeroValue(type_id); + if let Some(&id) = self.cached_constants.get(&null) { + return id; + } + let id = self.write_constant_null(type_id); + self.cached_constants.insert(null, id); + id + } + + pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word { + let null_id = self.id_gen.next(); + Instruction::constant_null(type_id, null_id) + .to_words(&mut self.logical_layout.declarations); + null_id + } + + fn write_constant_expr( + &mut self, + handle: Handle<crate::Expression>, + ir_module: &crate::Module, + mod_info: &ModuleInfo, + ) -> Result<Word, Error> { + let id = match ir_module.const_expressions[handle] { + crate::Expression::Literal(literal) => self.get_constant_scalar(literal), + crate::Expression::Constant(constant) => { + let constant = &ir_module.constants[constant]; + self.constant_ids[constant.init.index()] + } + crate::Expression::ZeroValue(ty) => { + let type_id = self.get_type_id(LookupType::Handle(ty)); + self.get_constant_null(type_id) + } + crate::Expression::Compose { ty, ref components } => { + let component_ids: Vec<_> = crate::proc::flatten_compose( + ty, + components, + &ir_module.const_expressions, + &ir_module.types, + ) + .map(|component| self.constant_ids[component.index()]) + .collect(); + self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice()) + } + crate::Expression::Splat { size, value } => { + let value_id = self.constant_ids[value.index()]; + let component_ids = &[value_id; 4][..size as usize]; + + let ty = self.get_expression_lookup_type(&mod_info[handle]); + + self.get_constant_composite(ty, component_ids) + } + _ => unreachable!(), + }; + + self.constant_ids[handle.index()] = id; + + Ok(id) + } + + pub(super) fn write_barrier(&mut self, flags: crate::Barrier, block: &mut Block) { + let memory_scope = if flags.contains(crate::Barrier::STORAGE) { + spirv::Scope::Device + } else { + spirv::Scope::Workgroup + }; + let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE; + semantics.set( + spirv::MemorySemantics::UNIFORM_MEMORY, + flags.contains(crate::Barrier::STORAGE), + ); + semantics.set( + spirv::MemorySemantics::WORKGROUP_MEMORY, + flags.contains(crate::Barrier::WORK_GROUP), + ); + let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32); + let mem_scope_id = self.get_index_constant(memory_scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + block.body.push(Instruction::control_barrier( + exec_scope_id, + mem_scope_id, + semantics_id, + )); + } + + fn generate_workgroup_vars_init_block( + &mut self, + entry_id: Word, + ir_module: &crate::Module, + info: &FunctionInfo, + local_invocation_id: Option<Word>, + interface: &mut FunctionInterface, + function: &mut Function, + ) -> Option<Word> { + let body = ir_module + .global_variables + .iter() + .filter(|&(handle, var)| { + !info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup + }) + .map(|(handle, var)| { + // It's safe to use `var_id` here, not `access_id`, because only + // variables in the `Uniform` and `StorageBuffer` address spaces + // get wrapped, and we're initializing `WorkGroup` variables. + let var_id = self.global_variables[handle.index()].var_id; + let var_type_id = self.get_type_id(LookupType::Handle(var.ty)); + let init_word = self.get_constant_null(var_type_id); + Instruction::store(var_id, init_word, None) + }) + .collect::<Vec<_>>(); + + if body.is_empty() { + return None; + } + + let uint3_type_id = self.get_uint3_type_id(); + + let mut pre_if_block = Block::new(entry_id); + + let local_invocation_id = if let Some(local_invocation_id) = local_invocation_id { + local_invocation_id + } else { + let varying_id = self.id_gen.next(); + let class = spirv::StorageClass::Input; + let pointer_type_id = self.get_uint3_pointer_type_id(class); + + Instruction::variable(pointer_type_id, varying_id, class, None) + .to_words(&mut self.logical_layout.declarations); + + self.decorate( + varying_id, + spirv::Decoration::BuiltIn, + &[spirv::BuiltIn::LocalInvocationId as u32], + ); + + interface.varying_ids.push(varying_id); + let id = self.id_gen.next(); + pre_if_block + .body + .push(Instruction::load(uint3_type_id, id, varying_id, None)); + + id + }; + + let zero_id = self.get_constant_null(uint3_type_id); + let bool3_type_id = self.get_bool3_type_id(); + + let eq_id = self.id_gen.next(); + pre_if_block.body.push(Instruction::binary( + spirv::Op::IEqual, + bool3_type_id, + eq_id, + local_invocation_id, + zero_id, + )); + + let condition_id = self.id_gen.next(); + let bool_type_id = self.get_bool_type_id(); + pre_if_block.body.push(Instruction::relational( + spirv::Op::All, + bool_type_id, + condition_id, + eq_id, + )); + + let merge_id = self.id_gen.next(); + pre_if_block.body.push(Instruction::selection_merge( + merge_id, + spirv::SelectionControl::NONE, + )); + + let accept_id = self.id_gen.next(); + function.consume( + pre_if_block, + Instruction::branch_conditional(condition_id, accept_id, merge_id), + ); + + let accept_block = Block { + label_id: accept_id, + body, + }; + function.consume(accept_block, Instruction::branch(merge_id)); + + let mut post_if_block = Block::new(merge_id); + + self.write_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block); + + let next_id = self.id_gen.next(); + function.consume(post_if_block, Instruction::branch(next_id)); + Some(next_id) + } + + /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface. + /// + /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s + /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the + /// interface is represented by global variables in the `Input` and `Output` + /// storage classes, with decorations indicating which builtin or location + /// each variable corresponds to. + /// + /// This function emits a single global `OpVariable` for a single value from + /// the interface, and adds appropriate decorations to indicate which + /// builtin or location it represents, how it should be interpolated, and so + /// on. The `class` argument gives the variable's SPIR-V storage class, + /// which should be either [`Input`] or [`Output`]. + /// + /// [`Binding`]: crate::Binding + /// [`Function`]: crate::Function + /// [`EntryPoint`]: crate::EntryPoint + /// [`Input`]: spirv::StorageClass::Input + /// [`Output`]: spirv::StorageClass::Output + fn write_varying( + &mut self, + ir_module: &crate::Module, + stage: crate::ShaderStage, + class: spirv::StorageClass, + debug_name: Option<&str>, + ty: Handle<crate::Type>, + binding: &crate::Binding, + ) -> Result<Word, Error> { + let id = self.id_gen.next(); + let pointer_type_id = self.get_pointer_id(&ir_module.types, ty, class)?; + Instruction::variable(pointer_type_id, id, class, None) + .to_words(&mut self.logical_layout.declarations); + + if self + .flags + .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS) + { + if let Some(name) = debug_name { + self.debugs.push(Instruction::name(id, name)); + } + } + + use spirv::{BuiltIn, Decoration}; + + match *binding { + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source, + } => { + self.decorate(id, Decoration::Location, &[location]); + + let no_decorations = + // VUID-StandaloneSpirv-Flat-06202 + // > The Flat, NoPerspective, Sample, and Centroid decorations + // > must not be used on variables with the Input storage class in a vertex shader + (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) || + // VUID-StandaloneSpirv-Flat-06201 + // > The Flat, NoPerspective, Sample, and Centroid decorations + // > must not be used on variables with the Output storage class in a fragment shader + (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment); + + if !no_decorations { + match interpolation { + // Perspective-correct interpolation is the default in SPIR-V. + None | Some(crate::Interpolation::Perspective) => (), + Some(crate::Interpolation::Flat) => { + self.decorate(id, Decoration::Flat, &[]); + } + Some(crate::Interpolation::Linear) => { + self.decorate(id, Decoration::NoPerspective, &[]); + } + } + match sampling { + // Center sampling is the default in SPIR-V. + None | Some(crate::Sampling::Center) => (), + Some(crate::Sampling::Centroid) => { + self.decorate(id, Decoration::Centroid, &[]); + } + Some(crate::Sampling::Sample) => { + self.require_any( + "per-sample interpolation", + &[spirv::Capability::SampleRateShading], + )?; + self.decorate(id, Decoration::Sample, &[]); + } + } + } + if second_blend_source { + self.decorate(id, Decoration::Index, &[1]); + } + } + crate::Binding::BuiltIn(built_in) => { + use crate::BuiltIn as Bi; + let built_in = match built_in { + Bi::Position { invariant } => { + if invariant { + self.decorate(id, Decoration::Invariant, &[]); + } + + if class == spirv::StorageClass::Output { + BuiltIn::Position + } else { + BuiltIn::FragCoord + } + } + Bi::ViewIndex => { + self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?; + BuiltIn::ViewIndex + } + // vertex + Bi::BaseInstance => BuiltIn::BaseInstance, + Bi::BaseVertex => BuiltIn::BaseVertex, + Bi::ClipDistance => { + self.require_any( + "`clip_distance` built-in", + &[spirv::Capability::ClipDistance], + )?; + BuiltIn::ClipDistance + } + Bi::CullDistance => { + self.require_any( + "`cull_distance` built-in", + &[spirv::Capability::CullDistance], + )?; + BuiltIn::CullDistance + } + Bi::InstanceIndex => BuiltIn::InstanceIndex, + Bi::PointSize => BuiltIn::PointSize, + Bi::VertexIndex => BuiltIn::VertexIndex, + // fragment + Bi::FragDepth => BuiltIn::FragDepth, + Bi::PointCoord => BuiltIn::PointCoord, + Bi::FrontFacing => BuiltIn::FrontFacing, + Bi::PrimitiveIndex => { + self.require_any( + "`primitive_index` built-in", + &[spirv::Capability::Geometry], + )?; + BuiltIn::PrimitiveId + } + Bi::SampleIndex => { + self.require_any( + "`sample_index` built-in", + &[spirv::Capability::SampleRateShading], + )?; + + BuiltIn::SampleId + } + Bi::SampleMask => BuiltIn::SampleMask, + // compute + Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId, + Bi::LocalInvocationId => BuiltIn::LocalInvocationId, + Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex, + Bi::WorkGroupId => BuiltIn::WorkgroupId, + Bi::WorkGroupSize => BuiltIn::WorkgroupSize, + Bi::NumWorkGroups => BuiltIn::NumWorkgroups, + }; + + self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); + + use crate::ScalarKind as Sk; + + // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`: + // + // > Any variable with integer or double-precision floating- + // > point type and with Input storage class in a fragment + // > shader, must be decorated Flat + if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment { + let is_flat = match ir_module.types[ty].inner { + crate::TypeInner::Scalar(scalar) + | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { + Sk::Uint | Sk::Sint | Sk::Bool => true, + Sk::Float => false, + Sk::AbstractInt | Sk::AbstractFloat => { + return Err(Error::Validation( + "Abstract types should not appear in IR presented to backends", + )) + } + }, + _ => false, + }; + + if is_flat { + self.decorate(id, Decoration::Flat, &[]); + } + } + } + } + + Ok(id) + } + + fn write_global_variable( + &mut self, + ir_module: &crate::Module, + global_variable: &crate::GlobalVariable, + ) -> Result<Word, Error> { + use spirv::Decoration; + + let id = self.id_gen.next(); + let class = map_storage_class(global_variable.space); + + //self.check(class.required_capabilities())?; + + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = global_variable.name { + self.debugs.push(Instruction::name(id, name)); + } + } + + let storage_access = match global_variable.space { + crate::AddressSpace::Storage { access } => Some(access), + _ => match ir_module.types[global_variable.ty].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. }, + .. + } => Some(access), + _ => None, + }, + }; + if let Some(storage_access) = storage_access { + if !storage_access.contains(crate::StorageAccess::LOAD) { + self.decorate(id, Decoration::NonReadable, &[]); + } + if !storage_access.contains(crate::StorageAccess::STORE) { + self.decorate(id, Decoration::NonWritable, &[]); + } + } + + // Note: we should be able to substitute `binding_array<Foo, 0>`, + // but there is still code that tries to register the pre-substituted type, + // and it is failing on 0. + let mut substitute_inner_type_lookup = None; + if let Some(ref res_binding) = global_variable.binding { + self.decorate(id, Decoration::DescriptorSet, &[res_binding.group]); + self.decorate(id, Decoration::Binding, &[res_binding.binding]); + + if let Some(&BindingInfo { + binding_array_size: Some(remapped_binding_array_size), + }) = self.binding_map.get(res_binding) + { + if let crate::TypeInner::BindingArray { base, .. } = + ir_module.types[global_variable.ty].inner + { + substitute_inner_type_lookup = + Some(LookupType::Local(LocalType::PointerToBindingArray { + base, + size: remapped_binding_array_size, + space: global_variable.space, + })) + } + } + }; + + let init_word = global_variable + .init + .map(|constant| self.constant_ids[constant.index()]); + let inner_type_id = self.get_type_id( + substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)), + ); + + // generate the wrapping structure if needed + let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) { + let wrapper_type_id = self.id_gen.next(); + + self.decorate(wrapper_type_id, Decoration::Block, &[]); + let member = crate::StructMember { + name: None, + ty: global_variable.ty, + binding: None, + offset: 0, + }; + self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?; + + Instruction::type_struct(wrapper_type_id, &[inner_type_id]) + .to_words(&mut self.logical_layout.declarations); + + let pointer_type_id = self.id_gen.next(); + Instruction::type_pointer(pointer_type_id, class, wrapper_type_id) + .to_words(&mut self.logical_layout.declarations); + + pointer_type_id + } else { + // This is a global variable in the Storage address space. The only + // way it could have `global_needs_wrapper() == false` is if it has + // a runtime-sized or binding array. + // Runtime-sized arrays were decorated when iterating through struct content. + // Now binding arrays require Block decorating. + if let crate::AddressSpace::Storage { .. } = global_variable.space { + match ir_module.types[global_variable.ty].inner { + crate::TypeInner::BindingArray { base, .. } => { + let decorated_id = self.get_type_id(LookupType::Handle(base)); + self.decorate(decorated_id, Decoration::Block, &[]); + } + _ => (), + }; + } + if substitute_inner_type_lookup.is_some() { + inner_type_id + } else { + self.get_pointer_id(&ir_module.types, global_variable.ty, class)? + } + }; + + let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) { + (crate::AddressSpace::Private, _) + | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => { + init_word.or_else(|| Some(self.get_constant_null(inner_type_id))) + } + _ => init_word, + }; + + Instruction::variable(pointer_type_id, id, class, init_word) + .to_words(&mut self.logical_layout.declarations); + Ok(id) + } + + /// Write the necessary decorations for a struct member. + /// + /// Emit decorations for the `index`'th member of the struct type + /// designated by `struct_id`, described by `member`. + fn decorate_struct_member( + &mut self, + struct_id: Word, + index: usize, + member: &crate::StructMember, + arena: &UniqueArena<crate::Type>, + ) -> Result<(), Error> { + use spirv::Decoration; + + self.annotations.push(Instruction::member_decorate( + struct_id, + index as u32, + Decoration::Offset, + &[member.offset], + )); + + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(ref name) = member.name { + self.debugs + .push(Instruction::member_name(struct_id, index as u32, name)); + } + } + + // Matrices and arrays of matrices both require decorations, + // so "see through" an array to determine if they're needed. + let member_array_subty_inner = match arena[member.ty].inner { + crate::TypeInner::Array { base, .. } => &arena[base].inner, + ref other => other, + }; + if let crate::TypeInner::Matrix { + columns: _, + rows, + scalar, + } = *member_array_subty_inner + { + let byte_stride = Alignment::from(rows) * scalar.width as u32; + self.annotations.push(Instruction::member_decorate( + struct_id, + index as u32, + Decoration::ColMajor, + &[], + )); + self.annotations.push(Instruction::member_decorate( + struct_id, + index as u32, + Decoration::MatrixStride, + &[byte_stride], + )); + } + + Ok(()) + } + + fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word { + match self + .lookup_function_type + .entry(lookup_function_type.clone()) + { + Entry::Occupied(e) => *e.get(), + Entry::Vacant(_) => { + let id = self.id_gen.next(); + let instruction = Instruction::type_function( + id, + lookup_function_type.return_type_id, + &lookup_function_type.parameter_type_ids, + ); + instruction.to_words(&mut self.logical_layout.declarations); + self.lookup_function_type.insert(lookup_function_type, id); + id + } + } + } + + fn write_physical_layout(&mut self) { + self.physical_layout.bound = self.id_gen.0 + 1; + } + + fn write_logical_layout( + &mut self, + ir_module: &crate::Module, + mod_info: &ModuleInfo, + ep_index: Option<usize>, + debug_info: &Option<DebugInfo>, + ) -> Result<(), Error> { + fn has_view_index_check( + ir_module: &crate::Module, + binding: Option<&crate::Binding>, + ty: Handle<crate::Type>, + ) -> bool { + match ir_module.types[ty].inner { + crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| { + has_view_index_check(ir_module, member.binding.as_ref(), member.ty) + }), + _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)), + } + } + + let has_storage_buffers = + ir_module + .global_variables + .iter() + .any(|(_, var)| match var.space { + crate::AddressSpace::Storage { .. } => true, + _ => false, + }); + let has_view_index = ir_module + .entry_points + .iter() + .flat_map(|entry| entry.function.arguments.iter()) + .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty)); + let has_ray_query = ir_module.special_types.ray_desc.is_some() + | ir_module.special_types.ray_intersection.is_some(); + + if self.physical_layout.version < 0x10300 && has_storage_buffers { + // enable the storage buffer class on < SPV-1.3 + Instruction::extension("SPV_KHR_storage_buffer_storage_class") + .to_words(&mut self.logical_layout.extensions); + } + if has_view_index { + Instruction::extension("SPV_KHR_multiview") + .to_words(&mut self.logical_layout.extensions) + } + if has_ray_query { + Instruction::extension("SPV_KHR_ray_query") + .to_words(&mut self.logical_layout.extensions) + } + Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations); + Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450") + .to_words(&mut self.logical_layout.ext_inst_imports); + + let mut debug_info_inner = None; + if self.flags.contains(WriterFlags::DEBUG) { + if let Some(debug_info) = debug_info.as_ref() { + let source_file_id = self.id_gen.next(); + self.debugs.push(Instruction::string( + &debug_info.file_name.display().to_string(), + source_file_id, + )); + + debug_info_inner = Some(DebugInfoInner { + source_code: debug_info.source_code, + source_file_id, + }); + self.debugs.push(Instruction::source( + spirv::SourceLanguage::Unknown, + 0, + &debug_info_inner, + )); + } + } + + // write all types + for (handle, _) in ir_module.types.iter() { + self.write_type_declaration_arena(&ir_module.types, handle)?; + } + + // write all const-expressions as constants + self.constant_ids + .resize(ir_module.const_expressions.len(), 0); + for (handle, _) in ir_module.const_expressions.iter() { + self.write_constant_expr(handle, ir_module, mod_info)?; + } + debug_assert!(self.constant_ids.iter().all(|&id| id != 0)); + + // write the name of constants on their respective const-expression initializer + if self.flags.contains(WriterFlags::DEBUG) { + for (_, constant) in ir_module.constants.iter() { + if let Some(ref name) = constant.name { + let id = self.constant_ids[constant.init.index()]; + self.debugs.push(Instruction::name(id, name)); + } + } + } + + // write all global variables + for (handle, var) in ir_module.global_variables.iter() { + // If a single entry point was specified, only write `OpVariable` instructions + // for the globals it actually uses. Emit dummies for the others, + // to preserve the indices in `global_variables`. + let gvar = match ep_index { + Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => { + GlobalVariable::dummy() + } + _ => { + let id = self.write_global_variable(ir_module, var)?; + GlobalVariable::new(id) + } + }; + self.global_variables.push(gvar); + } + + // write all functions + for (handle, ir_function) in ir_module.functions.iter() { + let info = &mod_info[handle]; + if let Some(index) = ep_index { + let ep_info = mod_info.get_entry_point(index); + // If this function uses globals that we omitted from the SPIR-V + // because the entry point and its callees didn't use them, + // then we must skip it. + if !ep_info.dominates_global_use(info) { + log::info!("Skip function {:?}", ir_function.name); + continue; + } + + // Skip functions that that are not compatible with this entry point's stage. + // + // When validation is enabled, it rejects modules whose entry points try to call + // incompatible functions, so if we got this far, then any functions incompatible + // with our selected entry point must not be used. + // + // When validation is disabled, `fun_info.available_stages` is always just + // `ShaderStages::all()`, so this will write all functions in the module, and + // the downstream GLSL compiler will catch any problems. + if !info.available_stages.contains(ep_info.available_stages) { + continue; + } + } + let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?; + self.lookup_function.insert(handle, id); + } + + // write all or one entry points + for (index, ir_ep) in ir_module.entry_points.iter().enumerate() { + if ep_index.is_some() && ep_index != Some(index) { + continue; + } + let info = mod_info.get_entry_point(index); + let ep_instruction = + self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?; + ep_instruction.to_words(&mut self.logical_layout.entry_points); + } + + for capability in self.capabilities_used.iter() { + Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities); + } + for extension in self.extensions_used.iter() { + Instruction::extension(extension).to_words(&mut self.logical_layout.extensions); + } + if ir_module.entry_points.is_empty() { + // SPIR-V doesn't like modules without entry points + Instruction::capability(spirv::Capability::Linkage) + .to_words(&mut self.logical_layout.capabilities); + } + + let addressing_model = spirv::AddressingModel::Logical; + let memory_model = spirv::MemoryModel::GLSL450; + //self.check(addressing_model.required_capabilities())?; + //self.check(memory_model.required_capabilities())?; + + Instruction::memory_model(addressing_model, memory_model) + .to_words(&mut self.logical_layout.memory_model); + + if self.flags.contains(WriterFlags::DEBUG) { + for debug in self.debugs.iter() { + debug.to_words(&mut self.logical_layout.debugs); + } + } + + for annotation in self.annotations.iter() { + annotation.to_words(&mut self.logical_layout.annotations); + } + + Ok(()) + } + + pub fn write( + &mut self, + ir_module: &crate::Module, + info: &ModuleInfo, + pipeline_options: Option<&PipelineOptions>, + debug_info: &Option<DebugInfo>, + words: &mut Vec<Word>, + ) -> Result<(), Error> { + self.reset(); + + // Try to find the entry point and corresponding index + let ep_index = match pipeline_options { + Some(po) => { + let index = ir_module + .entry_points + .iter() + .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name) + .ok_or(Error::EntryPointNotFound)?; + Some(index) + } + None => None, + }; + + self.write_logical_layout(ir_module, info, ep_index, debug_info)?; + self.write_physical_layout(); + + self.physical_layout.in_words(words); + self.logical_layout.in_words(words); + Ok(()) + } + + /// Return the set of capabilities the last module written used. + pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> { + &self.capabilities_used + } + + pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> { + self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?; + self.use_extension("SPV_EXT_descriptor_indexing"); + self.decorate(id, spirv::Decoration::NonUniform, &[]); + Ok(()) + } +} + +#[test] +fn test_write_physical_layout() { + let mut writer = Writer::new(&Options::default()).unwrap(); + assert_eq!(writer.physical_layout.bound, 0); + writer.write_physical_layout(); + assert_eq!(writer.physical_layout.bound, 3); +} diff --git a/third_party/rust/naga/src/back/wgsl/mod.rs b/third_party/rust/naga/src/back/wgsl/mod.rs new file mode 100644 index 0000000000..d731b1ca0c --- /dev/null +++ b/third_party/rust/naga/src/back/wgsl/mod.rs @@ -0,0 +1,52 @@ +/*! +Backend for [WGSL][wgsl] (WebGPU Shading Language). + +[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html +*/ + +mod writer; + +use thiserror::Error; + +pub use writer::{Writer, WriterFlags}; + +#[derive(Error, Debug)] +pub enum Error { + #[error(transparent)] + FmtError(#[from] std::fmt::Error), + #[error("{0}")] + Custom(String), + #[error("{0}")] + Unimplemented(String), // TODO: Error used only during development + #[error("Unsupported math function: {0:?}")] + UnsupportedMathFunction(crate::MathFunction), + #[error("Unsupported relational function: {0:?}")] + UnsupportedRelationalFunction(crate::RelationalFunction), +} + +pub fn write_string( + module: &crate::Module, + info: &crate::valid::ModuleInfo, + flags: WriterFlags, +) -> Result<String, Error> { + let mut w = Writer::new(String::new(), flags); + w.write(module, info)?; + let output = w.finish(); + Ok(output) +} + +impl crate::AtomicFunction { + const fn to_wgsl(self) -> &'static str { + match self { + Self::Add => "Add", + Self::Subtract => "Sub", + Self::And => "And", + Self::InclusiveOr => "Or", + Self::ExclusiveOr => "Xor", + Self::Min => "Min", + Self::Max => "Max", + Self::Exchange { compare: None } => "Exchange", + Self::Exchange { .. } => "CompareExchangeWeak", + } + } +} diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs new file mode 100644 index 0000000000..c737934f5e --- /dev/null +++ b/third_party/rust/naga/src/back/wgsl/writer.rs @@ -0,0 +1,1961 @@ +use super::Error; +use crate::{ + back, + proc::{self, NameKey}, + valid, Handle, Module, ShaderStage, TypeInner, +}; +use std::fmt::Write; + +/// Shorthand result used internally by the backend +type BackendResult = Result<(), Error>; + +/// WGSL [attribute](https://gpuweb.github.io/gpuweb/wgsl/#attributes) +enum Attribute { + Binding(u32), + BuiltIn(crate::BuiltIn), + Group(u32), + Invariant, + Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>), + Location(u32), + SecondBlendSource, + Stage(ShaderStage), + WorkGroupSize([u32; 3]), +} + +/// The WGSL form that `write_expr_with_indirection` should use to render a Naga +/// expression. +/// +/// Sometimes a Naga `Expression` alone doesn't provide enough information to +/// choose the right rendering for it in WGSL. For example, one natural WGSL +/// rendering of a Naga `LocalVariable(x)` expression might be `&x`, since +/// `LocalVariable` produces a pointer to the local variable's storage. But when +/// rendering a `Store` statement, the `pointer` operand must be the left hand +/// side of a WGSL assignment, so the proper rendering is `x`. +/// +/// The caller of `write_expr_with_indirection` must provide an `Expected` value +/// to indicate how ambiguous expressions should be rendered. +#[derive(Clone, Copy, Debug)] +enum Indirection { + /// Render pointer-construction expressions as WGSL `ptr`-typed expressions. + /// + /// This is the right choice for most cases. Whenever a Naga pointer + /// expression is not the `pointer` operand of a `Load` or `Store`, it + /// must be a WGSL pointer expression. + Ordinary, + + /// Render pointer-construction expressions as WGSL reference-typed + /// expressions. + /// + /// For example, this is the right choice for the `pointer` operand when + /// rendering a `Store` statement as a WGSL assignment. + Reference, +} + +bitflags::bitflags! { + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct WriterFlags: u32 { + /// Always annotate the type information instead of inferring. + const EXPLICIT_TYPES = 0x1; + } +} + +pub struct Writer<W> { + out: W, + flags: WriterFlags, + names: crate::FastHashMap<NameKey, String>, + namer: proc::Namer, + named_expressions: crate::NamedExpressions, + ep_results: Vec<(ShaderStage, Handle<crate::Type>)>, +} + +impl<W: Write> Writer<W> { + pub fn new(out: W, flags: WriterFlags) -> Self { + Writer { + out, + flags, + names: crate::FastHashMap::default(), + namer: proc::Namer::default(), + named_expressions: crate::NamedExpressions::default(), + ep_results: vec![], + } + } + + fn reset(&mut self, module: &Module) { + self.names.clear(); + self.namer.reset( + module, + crate::keywords::wgsl::RESERVED, + // an identifier must not start with two underscore + &[], + &[], + &["__"], + &mut self.names, + ); + self.named_expressions.clear(); + self.ep_results.clear(); + } + + fn is_builtin_wgsl_struct(&self, module: &Module, handle: Handle<crate::Type>) -> bool { + module + .special_types + .predeclared_types + .values() + .any(|t| *t == handle) + } + + pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult { + self.reset(module); + + // Save all ep result types + for (_, ep) in module.entry_points.iter().enumerate() { + if let Some(ref result) = ep.function.result { + self.ep_results.push((ep.stage, result.ty)); + } + } + + // Write all structs + for (handle, ty) in module.types.iter() { + if let TypeInner::Struct { ref members, .. } = ty.inner { + { + if !self.is_builtin_wgsl_struct(module, handle) { + self.write_struct(module, handle, members)?; + writeln!(self.out)?; + } + } + } + } + + // Write all named constants + let mut constants = module + .constants + .iter() + .filter(|&(_, c)| c.name.is_some()) + .peekable(); + while let Some((handle, _)) = constants.next() { + self.write_global_constant(module, handle)?; + // Add extra newline for readability on last iteration + if constants.peek().is_none() { + writeln!(self.out)?; + } + } + + // Write all globals + for (ty, global) in module.global_variables.iter() { + self.write_global(module, global, ty)?; + } + + if !module.global_variables.is_empty() { + // Add extra newline for readability + writeln!(self.out)?; + } + + // Write all regular functions + for (handle, function) in module.functions.iter() { + let fun_info = &info[handle]; + + let func_ctx = back::FunctionCtx { + ty: back::FunctionType::Function(handle), + info: fun_info, + expressions: &function.expressions, + named_expressions: &function.named_expressions, + }; + + // Write the function + self.write_function(module, function, &func_ctx)?; + + writeln!(self.out)?; + } + + // Write all entry points + for (index, ep) in module.entry_points.iter().enumerate() { + let attributes = match ep.stage { + ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)], + ShaderStage::Compute => vec![ + Attribute::Stage(ShaderStage::Compute), + Attribute::WorkGroupSize(ep.workgroup_size), + ], + }; + + self.write_attributes(&attributes)?; + // Add a newline after attribute + writeln!(self.out)?; + + let func_ctx = back::FunctionCtx { + ty: back::FunctionType::EntryPoint(index as u16), + info: info.get_entry_point(index), + expressions: &ep.function.expressions, + named_expressions: &ep.function.named_expressions, + }; + self.write_function(module, &ep.function, &func_ctx)?; + + if index < module.entry_points.len() - 1 { + writeln!(self.out)?; + } + } + + Ok(()) + } + + /// Helper method used to write struct name + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_struct_name(&mut self, module: &Module, handle: Handle<crate::Type>) -> BackendResult { + if module.types[handle].name.is_none() { + if let Some(&(stage, _)) = self.ep_results.iter().find(|&&(_, ty)| ty == handle) { + let name = match stage { + ShaderStage::Compute => "ComputeOutput", + ShaderStage::Fragment => "FragmentOutput", + ShaderStage::Vertex => "VertexOutput", + }; + + write!(self.out, "{name}")?; + return Ok(()); + } + } + + write!(self.out, "{}", self.names[&NameKey::Type(handle)])?; + + Ok(()) + } + + /// Helper method used to write + /// [functions](https://gpuweb.github.io/gpuweb/wgsl/#functions) + /// + /// # Notes + /// Ends in a newline + fn write_function( + &mut self, + module: &Module, + func: &crate::Function, + func_ctx: &back::FunctionCtx<'_>, + ) -> BackendResult { + let func_name = match func_ctx.ty { + back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)], + back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)], + }; + + // Write function name + write!(self.out, "fn {func_name}(")?; + + // Write function arguments + for (index, arg) in func.arguments.iter().enumerate() { + // Write argument attribute if a binding is present + if let Some(ref binding) = arg.binding { + self.write_attributes(&map_binding_to_attribute(binding))?; + } + // Write argument name + let argument_name = &self.names[&func_ctx.argument_key(index as u32)]; + + write!(self.out, "{argument_name}: ")?; + // Write argument type + self.write_type(module, arg.ty)?; + if index < func.arguments.len() - 1 { + // Add a separator between args + write!(self.out, ", ")?; + } + } + + write!(self.out, ")")?; + + // Write function return type + if let Some(ref result) = func.result { + write!(self.out, " -> ")?; + if let Some(ref binding) = result.binding { + self.write_attributes(&map_binding_to_attribute(binding))?; + } + self.write_type(module, result.ty)?; + } + + write!(self.out, " {{")?; + writeln!(self.out)?; + + // Write function local variables + for (handle, local) in func.local_variables.iter() { + // Write indentation (only for readability) + write!(self.out, "{}", back::INDENT)?; + + // Write the local name + // The leading space is important + write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?; + + // Write the local type + self.write_type(module, local.ty)?; + + // Write the local initializer if needed + if let Some(init) = local.init { + // Put the equal signal only if there's a initializer + // The leading and trailing spaces aren't needed but help with readability + write!(self.out, " = ")?; + + // Write the constant + // `write_constant` adds no trailing or leading space/newline + self.write_expr(module, init, func_ctx)?; + } + + // Finish the local with `;` and add a newline (only for readability) + writeln!(self.out, ";")? + } + + if !func.local_variables.is_empty() { + writeln!(self.out)?; + } + + // Write the function body (statement list) + for sta in func.body.iter() { + // The indentation should always be 1 when writing the function body + self.write_stmt(module, sta, func_ctx, back::Level(1))?; + } + + writeln!(self.out, "}}")?; + + self.named_expressions.clear(); + + Ok(()) + } + + /// Helper method to write a attribute + fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult { + for attribute in attributes { + match *attribute { + Attribute::Location(id) => write!(self.out, "@location({id}) ")?, + Attribute::SecondBlendSource => write!(self.out, "@second_blend_source ")?, + Attribute::BuiltIn(builtin_attrib) => { + let builtin = builtin_str(builtin_attrib)?; + write!(self.out, "@builtin({builtin}) ")?; + } + Attribute::Stage(shader_stage) => { + let stage_str = match shader_stage { + ShaderStage::Vertex => "vertex", + ShaderStage::Fragment => "fragment", + ShaderStage::Compute => "compute", + }; + write!(self.out, "@{stage_str} ")?; + } + Attribute::WorkGroupSize(size) => { + write!( + self.out, + "@workgroup_size({}, {}, {}) ", + size[0], size[1], size[2] + )?; + } + Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?, + Attribute::Group(id) => write!(self.out, "@group({id}) ")?, + Attribute::Invariant => write!(self.out, "@invariant ")?, + Attribute::Interpolate(interpolation, sampling) => { + if sampling.is_some() && sampling != Some(crate::Sampling::Center) { + write!( + self.out, + "@interpolate({}, {}) ", + interpolation_str( + interpolation.unwrap_or(crate::Interpolation::Perspective) + ), + sampling_str(sampling.unwrap_or(crate::Sampling::Center)) + )?; + } else if interpolation.is_some() + && interpolation != Some(crate::Interpolation::Perspective) + { + write!( + self.out, + "@interpolate({}) ", + interpolation_str( + interpolation.unwrap_or(crate::Interpolation::Perspective) + ) + )?; + } + } + }; + } + Ok(()) + } + + /// Helper method used to write structs + /// + /// # Notes + /// Ends in a newline + fn write_struct( + &mut self, + module: &Module, + handle: Handle<crate::Type>, + members: &[crate::StructMember], + ) -> BackendResult { + write!(self.out, "struct ")?; + self.write_struct_name(module, handle)?; + write!(self.out, " {{")?; + writeln!(self.out)?; + for (index, member) in members.iter().enumerate() { + // The indentation is only for readability + write!(self.out, "{}", back::INDENT)?; + if let Some(ref binding) = member.binding { + self.write_attributes(&map_binding_to_attribute(binding))?; + } + // Write struct member name and type + let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; + write!(self.out, "{member_name}: ")?; + self.write_type(module, member.ty)?; + write!(self.out, ",")?; + writeln!(self.out)?; + } + + write!(self.out, "}}")?; + + writeln!(self.out)?; + + Ok(()) + } + + /// Helper method used to write non image/sampler types + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult { + let inner = &module.types[ty].inner; + match *inner { + TypeInner::Struct { .. } => self.write_struct_name(module, ty)?, + ref other => self.write_value_type(module, other)?, + } + + Ok(()) + } + + /// Helper method used to write value types + /// + /// # Notes + /// Adds no trailing or leading whitespace + fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult { + match *inner { + TypeInner::Vector { size, scalar } => write!( + self.out, + "vec{}<{}>", + back::vector_size_str(size), + scalar_kind_str(scalar), + )?, + TypeInner::Sampler { comparison: false } => { + write!(self.out, "sampler")?; + } + TypeInner::Sampler { comparison: true } => { + write!(self.out, "sampler_comparison")?; + } + TypeInner::Image { + dim, + arrayed, + class, + } => { + // More about texture types: https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type + use crate::ImageClass as Ic; + + let dim_str = image_dimension_str(dim); + let arrayed_str = if arrayed { "_array" } else { "" }; + let (class_str, multisampled_str, format_str, storage_str) = match class { + Ic::Sampled { kind, multi } => ( + "", + if multi { "multisampled_" } else { "" }, + scalar_kind_str(crate::Scalar { kind, width: 4 }), + "", + ), + Ic::Depth { multi } => { + ("depth_", if multi { "multisampled_" } else { "" }, "", "") + } + Ic::Storage { format, access } => ( + "storage_", + "", + storage_format_str(format), + if access.contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE) + { + ",read_write" + } else if access.contains(crate::StorageAccess::LOAD) { + ",read" + } else { + ",write" + }, + ), + }; + write!( + self.out, + "texture_{class_str}{multisampled_str}{dim_str}{arrayed_str}" + )?; + + if !format_str.is_empty() { + write!(self.out, "<{format_str}{storage_str}>")?; + } + } + TypeInner::Scalar(scalar) => { + write!(self.out, "{}", scalar_kind_str(scalar))?; + } + TypeInner::Atomic(scalar) => { + write!(self.out, "atomic<{}>", scalar_kind_str(scalar))?; + } + TypeInner::Array { + base, + size, + stride: _, + } => { + // More info https://gpuweb.github.io/gpuweb/wgsl/#array-types + // array<A, 3> -- Constant array + // array<A> -- Dynamic array + write!(self.out, "array<")?; + match size { + crate::ArraySize::Constant(len) => { + self.write_type(module, base)?; + write!(self.out, ", {len}")?; + } + crate::ArraySize::Dynamic => { + self.write_type(module, base)?; + } + } + write!(self.out, ">")?; + } + TypeInner::BindingArray { base, size } => { + // More info https://github.com/gpuweb/gpuweb/issues/2105 + write!(self.out, "binding_array<")?; + match size { + crate::ArraySize::Constant(len) => { + self.write_type(module, base)?; + write!(self.out, ", {len}")?; + } + crate::ArraySize::Dynamic => { + self.write_type(module, base)?; + } + } + write!(self.out, ">")?; + } + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + write!( + self.out, + "mat{}x{}<{}>", + back::vector_size_str(columns), + back::vector_size_str(rows), + scalar_kind_str(scalar) + )?; + } + TypeInner::Pointer { base, space } => { + let (address, maybe_access) = address_space_str(space); + // Everything but `AddressSpace::Handle` gives us a `address` name, but + // Naga IR never produces pointers to handles, so it doesn't matter much + // how we write such a type. Just write it as the base type alone. + if let Some(space) = address { + write!(self.out, "ptr<{space}, ")?; + } + self.write_type(module, base)?; + if address.is_some() { + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } + } + TypeInner::ValuePointer { + size: None, + scalar, + space, + } => { + let (address, maybe_access) = address_space_str(space); + if let Some(space) = address { + write!(self.out, "ptr<{}, {}", space, scalar_kind_str(scalar))?; + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } else { + return Err(Error::Unimplemented(format!( + "ValuePointer to AddressSpace::Handle {inner:?}" + ))); + } + } + TypeInner::ValuePointer { + size: Some(size), + scalar, + space, + } => { + let (address, maybe_access) = address_space_str(space); + if let Some(space) = address { + write!( + self.out, + "ptr<{}, vec{}<{}>", + space, + back::vector_size_str(size), + scalar_kind_str(scalar) + )?; + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } else { + return Err(Error::Unimplemented(format!( + "ValuePointer to AddressSpace::Handle {inner:?}" + ))); + } + write!(self.out, ">")?; + } + _ => { + return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))); + } + } + + Ok(()) + } + /// Helper method used to write statements + /// + /// # Notes + /// Always adds a newline + fn write_stmt( + &mut self, + module: &Module, + stmt: &crate::Statement, + func_ctx: &back::FunctionCtx<'_>, + level: back::Level, + ) -> BackendResult { + use crate::{Expression, Statement}; + + match *stmt { + Statement::Emit(ref range) => { + for handle in range.clone() { + let info = &func_ctx.info[handle]; + let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) { + // Front end provides names for all variables at the start of writing. + // But we write them to step by step. We need to recache them + // Otherwise, we could accidentally write variable name instead of full expression. + // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords. + Some(self.namer.call(name)) + } else { + let expr = &func_ctx.expressions[handle]; + let min_ref_count = expr.bake_ref_count(); + // Forcefully creating baking expressions in some cases to help with readability + let required_baking_expr = match *expr { + Expression::ImageLoad { .. } + | Expression::ImageQuery { .. } + | Expression::ImageSample { .. } => true, + _ => false, + }; + if min_ref_count <= info.ref_count || required_baking_expr { + Some(format!("{}{}", back::BAKE_PREFIX, handle.index())) + } else { + None + } + }; + + if let Some(name) = expr_name { + write!(self.out, "{level}")?; + self.start_named_expr(module, handle, func_ctx, &name)?; + self.write_expr(module, handle, func_ctx)?; + self.named_expressions.insert(handle, name); + writeln!(self.out, ";")?; + } + } + } + // TODO: copy-paste from glsl-out + Statement::If { + condition, + ref accept, + ref reject, + } => { + write!(self.out, "{level}")?; + write!(self.out, "if ")?; + self.write_expr(module, condition, func_ctx)?; + writeln!(self.out, " {{")?; + + let l2 = level.next(); + for sta in accept { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, l2)?; + } + + // If there are no statements in the reject block we skip writing it + // This is only for readability + if !reject.is_empty() { + writeln!(self.out, "{level}}} else {{")?; + + for sta in reject { + // Increase indentation to help with readability + self.write_stmt(module, sta, func_ctx, l2)?; + } + } + + writeln!(self.out, "{level}}}")? + } + Statement::Return { value } => { + write!(self.out, "{level}")?; + write!(self.out, "return")?; + if let Some(return_value) = value { + // The leading space is important + write!(self.out, " ")?; + self.write_expr(module, return_value, func_ctx)?; + } + writeln!(self.out, ";")?; + } + // TODO: copy-paste from glsl-out + Statement::Kill => { + write!(self.out, "{level}")?; + writeln!(self.out, "discard;")? + } + Statement::Store { pointer, value } => { + write!(self.out, "{level}")?; + + let is_atomic_pointer = func_ctx + .resolve_type(pointer, &module.types) + .is_atomic_pointer(&module.types); + + if is_atomic_pointer { + write!(self.out, "atomicStore(")?; + self.write_expr(module, pointer, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr_with_indirection( + module, + pointer, + func_ctx, + Indirection::Reference, + )?; + write!(self.out, " = ")?; + self.write_expr(module, value, func_ctx)?; + } + writeln!(self.out, ";")? + } + Statement::Call { + function, + ref arguments, + result, + } => { + write!(self.out, "{level}")?; + if let Some(expr) = result { + let name = format!("{}{}", back::BAKE_PREFIX, expr.index()); + self.start_named_expr(module, expr, func_ctx, &name)?; + self.named_expressions.insert(expr, name); + } + let func_name = &self.names[&NameKey::Function(function)]; + write!(self.out, "{func_name}(")?; + for (index, &argument) in arguments.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + self.write_expr(module, argument, func_ctx)?; + } + writeln!(self.out, ");")? + } + Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + let fun_str = fun.to_wgsl(); + write!(self.out, "atomic{fun_str}(")?; + self.write_expr(module, pointer, func_ctx)?; + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + write!(self.out, ", ")?; + self.write_expr(module, cmp, func_ctx)?; + } + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ");")? + } + Statement::WorkGroupUniformLoad { pointer, result } => { + write!(self.out, "{level}")?; + // TODO: Obey named expressions here. + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + write!(self.out, "workgroupUniformLoad(")?; + self.write_expr(module, pointer, func_ctx)?; + writeln!(self.out, ");")?; + } + Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + write!(self.out, "{level}")?; + write!(self.out, "textureStore(")?; + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + if let Some(array_index_expr) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index_expr, func_ctx)?; + } + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + writeln!(self.out, ");")?; + } + // TODO: copy-paste from glsl-out + Statement::Block(ref block) => { + write!(self.out, "{level}")?; + writeln!(self.out, "{{")?; + for sta in block.iter() { + // Increase the indentation to help with readability + self.write_stmt(module, sta, func_ctx, level.next())? + } + writeln!(self.out, "{level}}}")? + } + Statement::Switch { + selector, + ref cases, + } => { + // Start the switch + write!(self.out, "{level}")?; + write!(self.out, "switch ")?; + self.write_expr(module, selector, func_ctx)?; + writeln!(self.out, " {{")?; + + let l2 = level.next(); + let mut new_case = true; + for case in cases { + if case.fall_through && !case.body.is_empty() { + // TODO: we could do the same workaround as we did for the HLSL backend + return Err(Error::Unimplemented( + "fall-through switch case block".into(), + )); + } + + match case.value { + crate::SwitchValue::I32(value) => { + if new_case { + write!(self.out, "{l2}case ")?; + } + write!(self.out, "{value}")?; + } + crate::SwitchValue::U32(value) => { + if new_case { + write!(self.out, "{l2}case ")?; + } + write!(self.out, "{value}u")?; + } + crate::SwitchValue::Default => { + if new_case { + if case.fall_through { + write!(self.out, "{l2}case ")?; + } else { + write!(self.out, "{l2}")?; + } + } + write!(self.out, "default")?; + } + } + + new_case = !case.fall_through; + + if case.fall_through { + write!(self.out, ", ")?; + } else { + writeln!(self.out, ": {{")?; + } + + for sta in case.body.iter() { + self.write_stmt(module, sta, func_ctx, l2.next())?; + } + + if !case.fall_through { + writeln!(self.out, "{l2}}}")?; + } + } + + writeln!(self.out, "{level}}}")? + } + Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + write!(self.out, "{level}")?; + writeln!(self.out, "loop {{")?; + + let l2 = level.next(); + for sta in body.iter() { + self.write_stmt(module, sta, func_ctx, l2)?; + } + + // The continuing is optional so we don't need to write it if + // it is empty, but the `break if` counts as a continuing statement + // so even if `continuing` is empty we must generate it if a + // `break if` exists + if !continuing.is_empty() || break_if.is_some() { + writeln!(self.out, "{l2}continuing {{")?; + for sta in continuing.iter() { + self.write_stmt(module, sta, func_ctx, l2.next())?; + } + + // The `break if` is always the last + // statement of the `continuing` block + if let Some(condition) = break_if { + // The trailing space is important + write!(self.out, "{}break if ", l2.next())?; + self.write_expr(module, condition, func_ctx)?; + // Close the `break if` statement + writeln!(self.out, ";")?; + } + + writeln!(self.out, "{l2}}}")?; + } + + writeln!(self.out, "{level}}}")? + } + Statement::Break => { + writeln!(self.out, "{level}break;")?; + } + Statement::Continue => { + writeln!(self.out, "{level}continue;")?; + } + Statement::Barrier(barrier) => { + if barrier.contains(crate::Barrier::STORAGE) { + writeln!(self.out, "{level}storageBarrier();")?; + } + + if barrier.contains(crate::Barrier::WORK_GROUP) { + writeln!(self.out, "{level}workgroupBarrier();")?; + } + } + Statement::RayQuery { .. } => unreachable!(), + } + + Ok(()) + } + + /// Return the sort of indirection that `expr`'s plain form evaluates to. + /// + /// An expression's 'plain form' is the most general rendition of that + /// expression into WGSL, lacking `&` or `*` operators: + /// + /// - The plain form of `LocalVariable(x)` is simply `x`, which is a reference + /// to the local variable's storage. + /// + /// - The plain form of `GlobalVariable(g)` is simply `g`, which is usually a + /// reference to the global variable's storage. However, globals in the + /// `Handle` address space are immutable, and `GlobalVariable` expressions for + /// those produce the value directly, not a pointer to it. Such + /// `GlobalVariable` expressions are `Ordinary`. + /// + /// - `Access` and `AccessIndex` are `Reference` when their `base` operand is a + /// pointer. If they are applied directly to a composite value, they are + /// `Ordinary`. + /// + /// Note that `FunctionArgument` expressions are never `Reference`, even when + /// the argument's type is `Pointer`. `FunctionArgument` always evaluates to the + /// argument's value directly, so any pointer it produces is merely the value + /// passed by the caller. + fn plain_form_indirection( + &self, + expr: Handle<crate::Expression>, + module: &Module, + func_ctx: &back::FunctionCtx<'_>, + ) -> Indirection { + use crate::Expression as Ex; + + // Named expressions are `let` expressions, which apply the Load Rule, + // so if their type is a Naga pointer, then that must be a WGSL pointer + // as well. + if self.named_expressions.contains_key(&expr) { + return Indirection::Ordinary; + } + + match func_ctx.expressions[expr] { + Ex::LocalVariable(_) => Indirection::Reference, + Ex::GlobalVariable(handle) => { + let global = &module.global_variables[handle]; + match global.space { + crate::AddressSpace::Handle => Indirection::Ordinary, + _ => Indirection::Reference, + } + } + Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => { + let base_ty = func_ctx.resolve_type(base, &module.types); + match *base_ty { + crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => { + Indirection::Reference + } + _ => Indirection::Ordinary, + } + } + _ => Indirection::Ordinary, + } + } + + fn start_named_expr( + &mut self, + module: &Module, + handle: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx, + name: &str, + ) -> BackendResult { + // Write variable name + write!(self.out, "let {name}")?; + if self.flags.contains(WriterFlags::EXPLICIT_TYPES) { + write!(self.out, ": ")?; + let ty = &func_ctx.info[handle].ty; + // Write variable type + match *ty { + proc::TypeResolution::Handle(handle) => { + self.write_type(module, handle)?; + } + proc::TypeResolution::Value(ref inner) => { + self.write_value_type(module, inner)?; + } + } + } + + write!(self.out, " = ")?; + Ok(()) + } + + /// Write the ordinary WGSL form of `expr`. + /// + /// See `write_expr_with_indirection` for details. + fn write_expr( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + ) -> BackendResult { + self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary) + } + + /// Write `expr` as a WGSL expression with the requested indirection. + /// + /// In terms of the WGSL grammar, the resulting expression is a + /// `singular_expression`. It may be parenthesized. This makes it suitable + /// for use as the operand of a unary or binary operator without worrying + /// about precedence. + /// + /// This does not produce newlines or indentation. + /// + /// The `requested` argument indicates (roughly) whether Naga + /// `Pointer`-valued expressions represent WGSL references or pointers. See + /// `Indirection` for details. + fn write_expr_with_indirection( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + requested: Indirection, + ) -> BackendResult { + // If the plain form of the expression is not what we need, emit the + // operator necessary to correct that. + let plain = self.plain_form_indirection(expr, module, func_ctx); + match (requested, plain) { + (Indirection::Ordinary, Indirection::Reference) => { + write!(self.out, "(&")?; + self.write_expr_plain_form(module, expr, func_ctx, plain)?; + write!(self.out, ")")?; + } + (Indirection::Reference, Indirection::Ordinary) => { + write!(self.out, "(*")?; + self.write_expr_plain_form(module, expr, func_ctx, plain)?; + write!(self.out, ")")?; + } + (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?, + } + + Ok(()) + } + + fn write_const_expression( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + ) -> BackendResult { + self.write_possibly_const_expression( + module, + expr, + &module.const_expressions, + |writer, expr| writer.write_const_expression(module, expr), + ) + } + + fn write_possibly_const_expression<E>( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + expressions: &crate::Arena<crate::Expression>, + write_expression: E, + ) -> BackendResult + where + E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult, + { + use crate::Expression; + + match expressions[expr] { + Expression::Literal(literal) => match literal { + crate::Literal::F32(value) => write!(self.out, "{}f", value)?, + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + crate::Literal::I32(value) => { + // `-2147483648i` is not valid WGSL. The most negative `i32` + // value can only be expressed in WGSL using AbstractInt and + // a unary negation operator. + if value == i32::MIN { + write!(self.out, "i32(-2147483648)")?; + } else { + write!(self.out, "{}i", value)?; + } + } + crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + crate::Literal::F64(value) => write!(self.out, "{:?}lf", value)?, + crate::Literal::I64(_) => { + return Err(Error::Custom("unsupported i64 literal".to_string())); + } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } + }, + Expression::Constant(handle) => { + let constant = &module.constants[handle]; + if constant.name.is_some() { + write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?; + } else { + self.write_const_expression(module, constant.init)?; + } + } + Expression::ZeroValue(ty) => { + self.write_type(module, ty)?; + write!(self.out, "()")?; + } + Expression::Compose { ty, ref components } => { + self.write_type(module, ty)?; + write!(self.out, "(")?; + for (index, component) in components.iter().enumerate() { + if index != 0 { + write!(self.out, ", ")?; + } + write_expression(self, *component)?; + } + write!(self.out, ")")? + } + Expression::Splat { size, value } => { + let size = back::vector_size_str(size); + write!(self.out, "vec{size}(")?; + write_expression(self, value)?; + write!(self.out, ")")?; + } + _ => unreachable!(), + } + + Ok(()) + } + + /// Write the 'plain form' of `expr`. + /// + /// An expression's 'plain form' is the most general rendition of that + /// expression into WGSL, lacking `&` or `*` operators. The plain forms of + /// `LocalVariable(x)` and `GlobalVariable(g)` are simply `x` and `g`. Such + /// Naga expressions represent both WGSL pointers and references; it's the + /// caller's responsibility to distinguish those cases appropriately. + fn write_expr_plain_form( + &mut self, + module: &Module, + expr: Handle<crate::Expression>, + func_ctx: &back::FunctionCtx<'_>, + indirection: Indirection, + ) -> BackendResult { + use crate::Expression; + + if let Some(name) = self.named_expressions.get(&expr) { + write!(self.out, "{name}")?; + return Ok(()); + } + + let expression = &func_ctx.expressions[expr]; + + // Write the plain WGSL form of a Naga expression. + // + // The plain form of `LocalVariable` and `GlobalVariable` expressions is + // simply the variable name; `*` and `&` operators are never emitted. + // + // The plain form of `Access` and `AccessIndex` expressions are WGSL + // `postfix_expression` forms for member/component access and + // subscripting. + match *expression { + Expression::Literal(_) + | Expression::Constant(_) + | Expression::ZeroValue(_) + | Expression::Compose { .. } + | Expression::Splat { .. } => { + self.write_possibly_const_expression( + module, + expr, + func_ctx.expressions, + |writer, expr| writer.write_expr(module, expr, func_ctx), + )?; + } + Expression::FunctionArgument(pos) => { + let name_key = func_ctx.argument_key(pos); + let name = &self.names[&name_key]; + write!(self.out, "{name}")?; + } + Expression::Binary { op, left, right } => { + write!(self.out, "(")?; + self.write_expr(module, left, func_ctx)?; + write!(self.out, " {} ", back::binary_operation_str(op))?; + self.write_expr(module, right, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Access { base, index } => { + self.write_expr_with_indirection(module, base, func_ctx, indirection)?; + write!(self.out, "[")?; + self.write_expr(module, index, func_ctx)?; + write!(self.out, "]")? + } + Expression::AccessIndex { base, index } => { + let base_ty_res = &func_ctx.info[base].ty; + let mut resolved = base_ty_res.inner_with(&module.types); + + self.write_expr_with_indirection(module, base, func_ctx, indirection)?; + + let base_ty_handle = match *resolved { + TypeInner::Pointer { base, space: _ } => { + resolved = &module.types[base].inner; + Some(base) + } + _ => base_ty_res.handle(), + }; + + match *resolved { + TypeInner::Vector { .. } => { + // Write vector access as a swizzle + write!(self.out, ".{}", back::COMPONENTS[index as usize])? + } + TypeInner::Matrix { .. } + | TypeInner::Array { .. } + | TypeInner::BindingArray { .. } + | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?, + TypeInner::Struct { .. } => { + // This will never panic in case the type is a `Struct`, this is not true + // for other types so we can only check while inside this match arm + let ty = base_ty_handle.unwrap(); + + write!( + self.out, + ".{}", + &self.names[&NameKey::StructMember(ty, index)] + )? + } + ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))), + } + } + Expression::ImageSample { + image, + sampler, + gather: None, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + use crate::SampleLevel as Sl; + + let suffix_cmp = match depth_ref { + Some(_) => "Compare", + None => "", + }; + let suffix_level = match level { + Sl::Auto => "", + Sl::Zero | Sl::Exact(_) => "Level", + Sl::Bias(_) => "Bias", + Sl::Gradient { .. } => "Grad", + }; + + write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?; + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, sampler, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + + if let Some(array_index) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index, func_ctx)?; + } + + if let Some(depth_ref) = depth_ref { + write!(self.out, ", ")?; + self.write_expr(module, depth_ref, func_ctx)?; + } + + match level { + Sl::Auto => {} + Sl::Zero => { + // Level 0 is implied for depth comparison + if depth_ref.is_none() { + write!(self.out, ", 0.0")?; + } + } + Sl::Exact(expr) => { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + Sl::Bias(expr) => { + write!(self.out, ", ")?; + self.write_expr(module, expr, func_ctx)?; + } + Sl::Gradient { x, y } => { + write!(self.out, ", ")?; + self.write_expr(module, x, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, y, func_ctx)?; + } + } + + if let Some(offset) = offset { + write!(self.out, ", ")?; + self.write_const_expression(module, offset)?; + } + + write!(self.out, ")")?; + } + + Expression::ImageSample { + image, + sampler, + gather: Some(component), + coordinate, + array_index, + offset, + level: _, + depth_ref, + } => { + let suffix_cmp = match depth_ref { + Some(_) => "Compare", + None => "", + }; + + write!(self.out, "textureGather{suffix_cmp}(")?; + match *func_ctx.resolve_type(image, &module.types) { + TypeInner::Image { + class: crate::ImageClass::Depth { multi: _ }, + .. + } => {} + _ => { + write!(self.out, "{}, ", component as u8)?; + } + } + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, sampler, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + + if let Some(array_index) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index, func_ctx)?; + } + + if let Some(depth_ref) = depth_ref { + write!(self.out, ", ")?; + self.write_expr(module, depth_ref, func_ctx)?; + } + + if let Some(offset) = offset { + write!(self.out, ", ")?; + self.write_const_expression(module, offset)?; + } + + write!(self.out, ")")?; + } + Expression::ImageQuery { image, query } => { + use crate::ImageQuery as Iq; + + let texture_function = match query { + Iq::Size { .. } => "textureDimensions", + Iq::NumLevels => "textureNumLevels", + Iq::NumLayers => "textureNumLayers", + Iq::NumSamples => "textureNumSamples", + }; + + write!(self.out, "{texture_function}(")?; + self.write_expr(module, image, func_ctx)?; + if let Iq::Size { level: Some(level) } = query { + write!(self.out, ", ")?; + self.write_expr(module, level, func_ctx)?; + }; + write!(self.out, ")")?; + } + + Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + write!(self.out, "textureLoad(")?; + self.write_expr(module, image, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, coordinate, func_ctx)?; + if let Some(array_index) = array_index { + write!(self.out, ", ")?; + self.write_expr(module, array_index, func_ctx)?; + } + if let Some(index) = sample.or(level) { + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + write!(self.out, ")")?; + } + Expression::GlobalVariable(handle) => { + let name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, "{name}")?; + } + + Expression::As { + expr, + kind, + convert, + } => { + let inner = func_ctx.resolve_type(expr, &module.types); + match *inner { + TypeInner::Matrix { + columns, + rows, + scalar, + } => { + let scalar = crate::Scalar { + kind, + width: convert.unwrap_or(scalar.width), + }; + let scalar_kind_str = scalar_kind_str(scalar); + write!( + self.out, + "mat{}x{}<{}>", + back::vector_size_str(columns), + back::vector_size_str(rows), + scalar_kind_str + )?; + } + TypeInner::Vector { + size, + scalar: crate::Scalar { width, .. }, + } => { + let scalar = crate::Scalar { + kind, + width: convert.unwrap_or(width), + }; + let vector_size_str = back::vector_size_str(size); + let scalar_kind_str = scalar_kind_str(scalar); + if convert.is_some() { + write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?; + } else { + write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?; + } + } + TypeInner::Scalar(crate::Scalar { width, .. }) => { + let scalar = crate::Scalar { + kind, + width: convert.unwrap_or(width), + }; + let scalar_kind_str = scalar_kind_str(scalar); + if convert.is_some() { + write!(self.out, "{scalar_kind_str}")? + } else { + write!(self.out, "bitcast<{scalar_kind_str}>")? + } + } + _ => { + return Err(Error::Unimplemented(format!( + "write_expr expression::as {inner:?}" + ))); + } + }; + write!(self.out, "(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + Expression::Load { pointer } => { + let is_atomic_pointer = func_ctx + .resolve_type(pointer, &module.types) + .is_atomic_pointer(&module.types); + + if is_atomic_pointer { + write!(self.out, "atomicLoad(")?; + self.write_expr(module, pointer, func_ctx)?; + write!(self.out, ")")?; + } else { + self.write_expr_with_indirection( + module, + pointer, + func_ctx, + Indirection::Reference, + )?; + } + } + Expression::LocalVariable(handle) => { + write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])? + } + Expression::ArrayLength(expr) => { + write!(self.out, "arrayLength(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")?; + } + + Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + + enum Function { + Regular(&'static str), + } + + let function = match fun { + Mf::Abs => Function::Regular("abs"), + Mf::Min => Function::Regular("min"), + Mf::Max => Function::Regular("max"), + Mf::Clamp => Function::Regular("clamp"), + Mf::Saturate => Function::Regular("saturate"), + // trigonometry + Mf::Cos => Function::Regular("cos"), + Mf::Cosh => Function::Regular("cosh"), + Mf::Sin => Function::Regular("sin"), + Mf::Sinh => Function::Regular("sinh"), + Mf::Tan => Function::Regular("tan"), + Mf::Tanh => Function::Regular("tanh"), + Mf::Acos => Function::Regular("acos"), + Mf::Asin => Function::Regular("asin"), + Mf::Atan => Function::Regular("atan"), + Mf::Atan2 => Function::Regular("atan2"), + Mf::Asinh => Function::Regular("asinh"), + Mf::Acosh => Function::Regular("acosh"), + Mf::Atanh => Function::Regular("atanh"), + Mf::Radians => Function::Regular("radians"), + Mf::Degrees => Function::Regular("degrees"), + // decomposition + Mf::Ceil => Function::Regular("ceil"), + Mf::Floor => Function::Regular("floor"), + Mf::Round => Function::Regular("round"), + Mf::Fract => Function::Regular("fract"), + Mf::Trunc => Function::Regular("trunc"), + Mf::Modf => Function::Regular("modf"), + Mf::Frexp => Function::Regular("frexp"), + Mf::Ldexp => Function::Regular("ldexp"), + // exponent + Mf::Exp => Function::Regular("exp"), + Mf::Exp2 => Function::Regular("exp2"), + Mf::Log => Function::Regular("log"), + Mf::Log2 => Function::Regular("log2"), + Mf::Pow => Function::Regular("pow"), + // geometry + Mf::Dot => Function::Regular("dot"), + Mf::Cross => Function::Regular("cross"), + Mf::Distance => Function::Regular("distance"), + Mf::Length => Function::Regular("length"), + Mf::Normalize => Function::Regular("normalize"), + Mf::FaceForward => Function::Regular("faceForward"), + Mf::Reflect => Function::Regular("reflect"), + Mf::Refract => Function::Regular("refract"), + // computational + Mf::Sign => Function::Regular("sign"), + Mf::Fma => Function::Regular("fma"), + Mf::Mix => Function::Regular("mix"), + Mf::Step => Function::Regular("step"), + Mf::SmoothStep => Function::Regular("smoothstep"), + Mf::Sqrt => Function::Regular("sqrt"), + Mf::InverseSqrt => Function::Regular("inverseSqrt"), + Mf::Transpose => Function::Regular("transpose"), + Mf::Determinant => Function::Regular("determinant"), + // bits + Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"), + Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"), + Mf::CountOneBits => Function::Regular("countOneBits"), + Mf::ReverseBits => Function::Regular("reverseBits"), + Mf::ExtractBits => Function::Regular("extractBits"), + Mf::InsertBits => Function::Regular("insertBits"), + Mf::FindLsb => Function::Regular("firstTrailingBit"), + Mf::FindMsb => Function::Regular("firstLeadingBit"), + // data packing + Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"), + Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"), + Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"), + Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"), + Mf::Pack2x16float => Function::Regular("pack2x16float"), + // data unpacking + Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"), + Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"), + Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"), + Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"), + Mf::Unpack2x16float => Function::Regular("unpack2x16float"), + Mf::Inverse | Mf::Outer => { + return Err(Error::UnsupportedMathFunction(fun)); + } + }; + + match function { + Function::Regular(fun_name) => { + write!(self.out, "{fun_name}(")?; + self.write_expr(module, arg, func_ctx)?; + for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + write!(self.out, ")")? + } + } + } + + Expression::Swizzle { + size, + vector, + pattern, + } => { + self.write_expr(module, vector, func_ctx)?; + write!(self.out, ".")?; + for &sc in pattern[..size as usize].iter() { + self.out.write_char(back::COMPONENTS[sc as usize])?; + } + } + Expression::Unary { op, expr } => { + let unary = match op { + crate::UnaryOperator::Negate => "-", + crate::UnaryOperator::LogicalNot => "!", + crate::UnaryOperator::BitwiseNot => "~", + }; + + write!(self.out, "{unary}(")?; + self.write_expr(module, expr, func_ctx)?; + + write!(self.out, ")")? + } + + Expression::Select { + condition, + accept, + reject, + } => { + write!(self.out, "select(")?; + self.write_expr(module, reject, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, accept, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, condition, func_ctx)?; + write!(self.out, ")")? + } + Expression::Derivative { axis, ctrl, expr } => { + use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl}; + let op = match (axis, ctrl) { + (Axis::X, Ctrl::Coarse) => "dpdxCoarse", + (Axis::X, Ctrl::Fine) => "dpdxFine", + (Axis::X, Ctrl::None) => "dpdx", + (Axis::Y, Ctrl::Coarse) => "dpdyCoarse", + (Axis::Y, Ctrl::Fine) => "dpdyFine", + (Axis::Y, Ctrl::None) => "dpdy", + (Axis::Width, Ctrl::Coarse) => "fwidthCoarse", + (Axis::Width, Ctrl::Fine) => "fwidthFine", + (Axis::Width, Ctrl::None) => "fwidth", + }; + write!(self.out, "{op}(")?; + self.write_expr(module, expr, func_ctx)?; + write!(self.out, ")")? + } + Expression::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + + let fun_name = match fun { + Rf::All => "all", + Rf::Any => "any", + _ => return Err(Error::UnsupportedRelationalFunction(fun)), + }; + write!(self.out, "{fun_name}(")?; + + self.write_expr(module, argument, func_ctx)?; + + write!(self.out, ")")? + } + // Not supported yet + Expression::RayQueryGetIntersection { .. } => unreachable!(), + // Nothing to do here, since call expression already cached + Expression::CallResult(_) + | Expression::AtomicResult { .. } + | Expression::RayQueryProceedResult + | Expression::WorkGroupUniformLoadResult { .. } => {} + } + + Ok(()) + } + + /// Helper method used to write global variables + /// # Notes + /// Always adds a newline + fn write_global( + &mut self, + module: &Module, + global: &crate::GlobalVariable, + handle: Handle<crate::GlobalVariable>, + ) -> BackendResult { + // Write group and binding attributes if present + if let Some(ref binding) = global.binding { + self.write_attributes(&[ + Attribute::Group(binding.group), + Attribute::Binding(binding.binding), + ])?; + writeln!(self.out)?; + } + + // First write global name and address space if supported + write!(self.out, "var")?; + let (address, maybe_access) = address_space_str(global.space); + if let Some(space) = address { + write!(self.out, "<{space}")?; + if let Some(access) = maybe_access { + write!(self.out, ", {access}")?; + } + write!(self.out, ">")?; + } + write!( + self.out, + " {}: ", + &self.names[&NameKey::GlobalVariable(handle)] + )?; + + // Write global type + self.write_type(module, global.ty)?; + + // Write initializer + if let Some(init) = global.init { + write!(self.out, " = ")?; + self.write_const_expression(module, init)?; + } + + // End with semicolon + writeln!(self.out, ";")?; + + Ok(()) + } + + /// Helper method used to write global constants + /// + /// # Notes + /// Ends in a newline + fn write_global_constant( + &mut self, + module: &Module, + handle: Handle<crate::Constant>, + ) -> BackendResult { + let name = &self.names[&NameKey::Constant(handle)]; + // First write only constant name + write!(self.out, "const {name}: ")?; + self.write_type(module, module.constants[handle].ty)?; + write!(self.out, " = ")?; + let init = module.constants[handle].init; + self.write_const_expression(module, init)?; + writeln!(self.out, ";")?; + + Ok(()) + } + + // See https://github.com/rust-lang/rust-clippy/issues/4979. + #[allow(clippy::missing_const_for_fn)] + pub fn finish(self) -> W { + self.out + } +} + +fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { + use crate::BuiltIn as Bi; + + Ok(match built_in { + Bi::VertexIndex => "vertex_index", + Bi::InstanceIndex => "instance_index", + Bi::Position { .. } => "position", + Bi::FrontFacing => "front_facing", + Bi::FragDepth => "frag_depth", + Bi::LocalInvocationId => "local_invocation_id", + Bi::LocalInvocationIndex => "local_invocation_index", + Bi::GlobalInvocationId => "global_invocation_id", + Bi::WorkGroupId => "workgroup_id", + Bi::NumWorkGroups => "num_workgroups", + Bi::SampleIndex => "sample_index", + Bi::SampleMask => "sample_mask", + Bi::PrimitiveIndex => "primitive_index", + Bi::ViewIndex => "view_index", + Bi::BaseInstance + | Bi::BaseVertex + | Bi::ClipDistance + | Bi::CullDistance + | Bi::PointSize + | Bi::PointCoord + | Bi::WorkGroupSize => { + return Err(Error::Custom(format!("Unsupported builtin {built_in:?}"))) + } + }) +} + +const fn image_dimension_str(dim: crate::ImageDimension) -> &'static str { + use crate::ImageDimension as IDim; + + match dim { + IDim::D1 => "1d", + IDim::D2 => "2d", + IDim::D3 => "3d", + IDim::Cube => "cube", + } +} + +const fn scalar_kind_str(scalar: crate::Scalar) -> &'static str { + use crate::Scalar; + use crate::ScalarKind as Sk; + + match scalar { + Scalar { + kind: Sk::Float, + width: 8, + } => "f64", + Scalar { + kind: Sk::Float, + width: 4, + } => "f32", + Scalar { + kind: Sk::Sint, + width: 4, + } => "i32", + Scalar { + kind: Sk::Uint, + width: 4, + } => "u32", + Scalar { + kind: Sk::Bool, + width: 1, + } => "bool", + _ => unreachable!(), + } +} + +const fn storage_format_str(format: crate::StorageFormat) -> &'static str { + use crate::StorageFormat as Sf; + + match format { + Sf::R8Unorm => "r8unorm", + Sf::R8Snorm => "r8snorm", + Sf::R8Uint => "r8uint", + Sf::R8Sint => "r8sint", + Sf::R16Uint => "r16uint", + Sf::R16Sint => "r16sint", + Sf::R16Float => "r16float", + Sf::Rg8Unorm => "rg8unorm", + Sf::Rg8Snorm => "rg8snorm", + Sf::Rg8Uint => "rg8uint", + Sf::Rg8Sint => "rg8sint", + Sf::R32Uint => "r32uint", + Sf::R32Sint => "r32sint", + Sf::R32Float => "r32float", + Sf::Rg16Uint => "rg16uint", + Sf::Rg16Sint => "rg16sint", + Sf::Rg16Float => "rg16float", + Sf::Rgba8Unorm => "rgba8unorm", + Sf::Rgba8Snorm => "rgba8snorm", + Sf::Rgba8Uint => "rgba8uint", + Sf::Rgba8Sint => "rgba8sint", + Sf::Bgra8Unorm => "bgra8unorm", + Sf::Rgb10a2Uint => "rgb10a2uint", + Sf::Rgb10a2Unorm => "rgb10a2unorm", + Sf::Rg11b10Float => "rg11b10float", + Sf::Rg32Uint => "rg32uint", + Sf::Rg32Sint => "rg32sint", + Sf::Rg32Float => "rg32float", + Sf::Rgba16Uint => "rgba16uint", + Sf::Rgba16Sint => "rgba16sint", + Sf::Rgba16Float => "rgba16float", + Sf::Rgba32Uint => "rgba32uint", + Sf::Rgba32Sint => "rgba32sint", + Sf::Rgba32Float => "rgba32float", + Sf::R16Unorm => "r16unorm", + Sf::R16Snorm => "r16snorm", + Sf::Rg16Unorm => "rg16unorm", + Sf::Rg16Snorm => "rg16snorm", + Sf::Rgba16Unorm => "rgba16unorm", + Sf::Rgba16Snorm => "rgba16snorm", + } +} + +/// Helper function that returns the string corresponding to the WGSL interpolation qualifier +const fn interpolation_str(interpolation: crate::Interpolation) -> &'static str { + use crate::Interpolation as I; + + match interpolation { + I::Perspective => "perspective", + I::Linear => "linear", + I::Flat => "flat", + } +} + +/// Return the WGSL auxiliary qualifier for the given sampling value. +const fn sampling_str(sampling: crate::Sampling) -> &'static str { + use crate::Sampling as S; + + match sampling { + S::Center => "", + S::Centroid => "centroid", + S::Sample => "sample", + } +} + +const fn address_space_str( + space: crate::AddressSpace, +) -> (Option<&'static str>, Option<&'static str>) { + use crate::AddressSpace as As; + + ( + Some(match space { + As::Private => "private", + As::Uniform => "uniform", + As::Storage { access } => { + if access.contains(crate::StorageAccess::STORE) { + return (Some("storage"), Some("read_write")); + } else { + "storage" + } + } + As::PushConstant => "push_constant", + As::WorkGroup => "workgroup", + As::Handle => return (None, None), + As::Function => "function", + }), + None, + ) +} + +fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> { + match *binding { + crate::Binding::BuiltIn(built_in) => { + if let crate::BuiltIn::Position { invariant: true } = built_in { + vec![Attribute::BuiltIn(built_in), Attribute::Invariant] + } else { + vec![Attribute::BuiltIn(built_in)] + } + } + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source: false, + } => vec![ + Attribute::Location(location), + Attribute::Interpolate(interpolation, sampling), + ], + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source: true, + } => vec![ + Attribute::Location(location), + Attribute::SecondBlendSource, + Attribute::Interpolate(interpolation, sampling), + ], + } +} |