Diffstat (limited to 'third_party/rust/naga/src/back')
-rw-r--r--  third_party/rust/naga/src/back/dot/mod.rs            703
-rw-r--r--  third_party/rust/naga/src/back/glsl/features.rs      536
-rw-r--r--  third_party/rust/naga/src/back/glsl/keywords.rs      484
-rw-r--r--  third_party/rust/naga/src/back/glsl/mod.rs          4532
-rw-r--r--  third_party/rust/naga/src/back/hlsl/conv.rs          222
-rw-r--r--  third_party/rust/naga/src/back/hlsl/help.rs         1138
-rw-r--r--  third_party/rust/naga/src/back/hlsl/keywords.rs      904
-rw-r--r--  third_party/rust/naga/src/back/hlsl/mod.rs           302
-rw-r--r--  third_party/rust/naga/src/back/hlsl/storage.rs       494
-rw-r--r--  third_party/rust/naga/src/back/hlsl/writer.rs       3366
-rw-r--r--  third_party/rust/naga/src/back/mod.rs                273
-rw-r--r--  third_party/rust/naga/src/back/msl/keywords.rs       342
-rw-r--r--  third_party/rust/naga/src/back/msl/mod.rs            541
-rw-r--r--  third_party/rust/naga/src/back/msl/sampler.rs        176
-rw-r--r--  third_party/rust/naga/src/back/msl/writer.rs        4659
-rw-r--r--  third_party/rust/naga/src/back/spv/block.rs         2368
-rw-r--r--  third_party/rust/naga/src/back/spv/helpers.rs        109
-rw-r--r--  third_party/rust/naga/src/back/spv/image.rs         1210
-rw-r--r--  third_party/rust/naga/src/back/spv/index.rs          421
-rw-r--r--  third_party/rust/naga/src/back/spv/instructions.rs  1100
-rw-r--r--  third_party/rust/naga/src/back/spv/layout.rs         210
-rw-r--r--  third_party/rust/naga/src/back/spv/mod.rs            748
-rw-r--r--  third_party/rust/naga/src/back/spv/ray.rs            261
-rw-r--r--  third_party/rust/naga/src/back/spv/recyclable.rs      67
-rw-r--r--  third_party/rust/naga/src/back/spv/selection.rs      257
-rw-r--r--  third_party/rust/naga/src/back/spv/writer.rs        2063
-rw-r--r--  third_party/rust/naga/src/back/wgsl/mod.rs            52
-rw-r--r--  third_party/rust/naga/src/back/wgsl/writer.rs       1961
28 files changed, 29499 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/back/dot/mod.rs b/third_party/rust/naga/src/back/dot/mod.rs
new file mode 100644
index 0000000000..1556371df1
--- /dev/null
+++ b/third_party/rust/naga/src/back/dot/mod.rs
@@ -0,0 +1,703 @@
+/*!
+Backend for [DOT][dot] (Graphviz).
+
+This backend writes a graph in the DOT language, for ease
+of IR inspection and debugging.
+
+[dot]: https://graphviz.org/doc/info/lang.html
+*/
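+
+// A minimal usage sketch (assuming the `wgsl-in` and `dot-out` features are
+// enabled; the WGSL source here is illustrative):
+//
+// let module = naga::front::wgsl::parse_str("fn main() {}").unwrap();
+// let info = naga::valid::Validator::new(
+//     naga::valid::ValidationFlags::all(),
+//     naga::valid::Capabilities::all(),
+// )
+// .validate(&module)
+// .unwrap();
+// let dot = naga::back::dot::write(&module, Some(&info), Default::default()).unwrap();
+// println!("{dot}");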
+
+use crate::{
+ arena::Handle,
+ valid::{FunctionInfo, ModuleInfo},
+};
+
+use std::{
+ borrow::Cow,
+ fmt::{Error as FmtError, Write as _},
+};
+
+/// Configuration options for the dot backend
+#[derive(Clone, Default)]
+pub struct Options {
+ /// Only emit function bodies
+ pub cfg_only: bool,
+}
+
+/// Identifier used to address a graph node
+type NodeId = usize;
+
+/// Stores the target nodes for control flow statements
+#[derive(Default, Clone, Copy)]
+struct Targets {
+ /// The node, if some, where continue operations will land
+ continue_target: Option<usize>,
+ /// The node, if some, where break operations will land
+ break_target: Option<usize>,
+}
+
+/// Stores information about the graph of statements
+#[derive(Default)]
+struct StatementGraph {
+ /// List of node names
+ nodes: Vec<&'static str>,
+    /// List of control flow edges; each item is defined as
+    /// (from, to, label)
+ flow: Vec<(NodeId, NodeId, &'static str)>,
+    /// List of implicit control flow edges, used for jump
+    /// operations such as continue or break; each item is defined as
+    /// (from, to, label, color_id)
+ jumps: Vec<(NodeId, NodeId, &'static str, usize)>,
+ /// List of dependency relationships between a statement node and
+ /// expressions
+ dependencies: Vec<(NodeId, Handle<crate::Expression>, &'static str)>,
+    /// List of expressions emitted by statement nodes
+ emits: Vec<(NodeId, Handle<crate::Expression>)>,
+    /// List of function calls made by statement nodes
+ calls: Vec<(NodeId, Handle<crate::Function>)>,
+}
+
+impl StatementGraph {
+ /// Adds a new block to the statement graph, returning the first and last node, respectively
+ fn add(&mut self, block: &[crate::Statement], targets: Targets) -> (NodeId, NodeId) {
+ use crate::Statement as S;
+
+ // The first node of the block isn't a statement but a virtual node
+ let root = self.nodes.len();
+ self.nodes.push(if root == 0 { "Root" } else { "Node" });
+        // Track the last placed node; it will be returned to the caller and
+        // will also be used to generate the control flow edges
+ let mut last_node = root;
+ for statement in block {
+ // Reserve a new node for the current statement and link it to the
+ // node of the previous statement
+ let id = self.nodes.len();
+ self.flow.push((last_node, id, ""));
+ self.nodes.push(""); // reserve space
+
+            // Track the identifier of the merge node: the last node of a
+            // statement. Normally this is the statement node itself, but for
+            // control flow statements such as `if`s and `switch`es it is a
+            // virtual node where all branches merge back.
+ let mut merge_id = id;
+
+ self.nodes[id] = match *statement {
+ S::Emit(ref range) => {
+ for handle in range.clone() {
+ self.emits.push((id, handle));
+ }
+ "Emit"
+ }
+ S::Kill => "Kill", //TODO: link to the beginning
+ S::Break => {
+ // Try to link to the break target, otherwise produce
+ // a broken connection
+ if let Some(target) = targets.break_target {
+ self.jumps.push((id, target, "Break", 5))
+ } else {
+ self.jumps.push((id, root, "Broken", 7))
+ }
+ "Break"
+ }
+ S::Continue => {
+ // Try to link to the continue target, otherwise produce
+ // a broken connection
+ if let Some(target) = targets.continue_target {
+ self.jumps.push((id, target, "Continue", 5))
+ } else {
+ self.jumps.push((id, root, "Broken", 7))
+ }
+ "Continue"
+ }
+ S::Barrier(_flags) => "Barrier",
+ S::Block(ref b) => {
+ let (other, last) = self.add(b, targets);
+ self.flow.push((id, other, ""));
+ // All following nodes should connect to the end of the block
+ // statement so change the merge id to it.
+ merge_id = last;
+ "Block"
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ self.dependencies.push((id, condition, "condition"));
+ let (accept_id, accept_last) = self.add(accept, targets);
+ self.flow.push((id, accept_id, "accept"));
+ let (reject_id, reject_last) = self.add(reject, targets);
+ self.flow.push((id, reject_id, "reject"));
+
+ // Create a merge node, link the branches to it and set it
+ // as the merge node to make the next statement node link to it
+ merge_id = self.nodes.len();
+ self.nodes.push("Merge");
+ self.flow.push((accept_last, merge_id, ""));
+ self.flow.push((reject_last, merge_id, ""));
+
+ "If"
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ self.dependencies.push((id, selector, "selector"));
+
+ // Create a merge node and set it as the merge node to make
+ // the next statement node link to it
+ merge_id = self.nodes.len();
+ self.nodes.push("Merge");
+
+ // Create a new targets structure and set the break target
+ // to the merge node
+ let mut targets = targets;
+ targets.break_target = Some(merge_id);
+
+ for case in cases {
+ let (case_id, case_last) = self.add(&case.body, targets);
+ let label = match case.value {
+ crate::SwitchValue::Default => "default",
+ _ => "case",
+ };
+ self.flow.push((id, case_id, label));
+ // Link the last node of the branch to the merge node
+ self.flow.push((case_last, merge_id, ""));
+ }
+ "Switch"
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+                    // Create a new targets structure and set the break target
+                    // to the loop node; this must happen before generating the
+                    // continuing block, since it can break.
+ let mut targets = targets;
+ targets.break_target = Some(id);
+
+ let (continuing_id, continuing_last) = self.add(continuing, targets);
+
+                    // Set the continue target to the beginning
+ // of the newly generated continuing block
+ targets.continue_target = Some(continuing_id);
+
+ let (body_id, body_last) = self.add(body, targets);
+
+ self.flow.push((id, body_id, "body"));
+
+ // Link the last node of the body to the continuing block
+ self.flow.push((body_last, continuing_id, "continuing"));
+ // Link the last node of the continuing block back to the
+ // beginning of the loop body
+ self.flow.push((continuing_last, body_id, "continuing"));
+
+ if let Some(expr) = break_if {
+ self.dependencies.push((continuing_id, expr, "break if"));
+ }
+
+ "Loop"
+ }
+ S::Return { value } => {
+ if let Some(expr) = value {
+ self.dependencies.push((id, expr, "value"));
+ }
+ "Return"
+ }
+ S::Store { pointer, value } => {
+ self.dependencies.push((id, value, "value"));
+ self.emits.push((id, pointer));
+ "Store"
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ self.dependencies.push((id, image, "image"));
+ self.dependencies.push((id, coordinate, "coordinate"));
+ if let Some(expr) = array_index {
+ self.dependencies.push((id, expr, "array_index"));
+ }
+ self.dependencies.push((id, value, "value"));
+ "ImageStore"
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ for &arg in arguments {
+ self.dependencies.push((id, arg, "arg"));
+ }
+ if let Some(expr) = result {
+ self.emits.push((id, expr));
+ }
+ self.calls.push((id, function));
+ "Call"
+ }
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ self.emits.push((id, result));
+ self.dependencies.push((id, pointer, "pointer"));
+ self.dependencies.push((id, value, "value"));
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ self.dependencies.push((id, cmp, "cmp"));
+ }
+ "Atomic"
+ }
+ S::WorkGroupUniformLoad { pointer, result } => {
+ self.emits.push((id, result));
+ self.dependencies.push((id, pointer, "pointer"));
+ "WorkGroupUniformLoad"
+ }
+ S::RayQuery { query, ref fun } => {
+ self.dependencies.push((id, query, "query"));
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ self.dependencies.push((
+ id,
+ acceleration_structure,
+ "acceleration_structure",
+ ));
+ self.dependencies.push((id, descriptor, "descriptor"));
+ "RayQueryInitialize"
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ self.emits.push((id, result));
+ "RayQueryProceed"
+ }
+ crate::RayQueryFunction::Terminate => "RayQueryTerminate",
+ }
+ }
+ };
+ // Set the last node to the merge node
+ last_node = merge_id;
+ }
+ (root, last_node)
+ }
+}
+
+#[allow(clippy::manual_unwrap_or)]
+fn name(option: &Option<String>) -> &str {
+ match *option {
+ Some(ref name) => name,
+ None => "",
+ }
+}
+
+/// set39 color scheme from <https://graphviz.org/doc/info/colors.html>
+const COLORS: &[&str] = &[
+ "white", // pattern starts at 1
+ "#8dd3c7", "#ffffb3", "#bebada", "#fb8072", "#80b1d3", "#fdb462", "#b3de69", "#fccde5",
+ "#d9d9d9",
+];
+
+fn write_fun(
+ output: &mut String,
+ prefix: String,
+ fun: &crate::Function,
+ info: Option<&FunctionInfo>,
+ options: &Options,
+) -> Result<(), FmtError> {
+ writeln!(output, "\t\tnode [ style=filled ]")?;
+
+ if !options.cfg_only {
+ for (handle, var) in fun.local_variables.iter() {
+ writeln!(
+ output,
+ "\t\t{}_l{} [ shape=hexagon label=\"{:?} '{}'\" ]",
+ prefix,
+ handle.index(),
+ handle,
+ name(&var.name),
+ )?;
+ }
+
+ write_function_expressions(output, &prefix, fun, info)?;
+ }
+
+ let mut sg = StatementGraph::default();
+ sg.add(&fun.body, Targets::default());
+ for (index, label) in sg.nodes.into_iter().enumerate() {
+ writeln!(
+ output,
+ "\t\t{prefix}_s{index} [ shape=square label=\"{label}\" ]",
+ )?;
+ }
+ for (from, to, label) in sg.flow {
+ writeln!(
+ output,
+ "\t\t{prefix}_s{from} -> {prefix}_s{to} [ arrowhead=tee label=\"{label}\" ]",
+ )?;
+ }
+ for (from, to, label, color_id) in sg.jumps {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> {}_s{} [ arrowhead=tee style=dashed color=\"{}\" label=\"{}\" ]",
+ prefix, from, prefix, to, COLORS[color_id], label,
+ )?;
+ }
+
+ if !options.cfg_only {
+ for (to, expr, label) in sg.dependencies {
+ writeln!(
+ output,
+ "\t\t{}_e{} -> {}_s{} [ label=\"{}\" ]",
+ prefix,
+ expr.index(),
+ prefix,
+ to,
+ label,
+ )?;
+ }
+ for (from, to) in sg.emits {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> {}_e{} [ style=dotted ]",
+ prefix,
+ from,
+ prefix,
+ to.index(),
+ )?;
+ }
+ }
+
+ for (from, function) in sg.calls {
+ writeln!(
+ output,
+ "\t\t{}_s{} -> f{}_s0",
+ prefix,
+ from,
+ function.index(),
+ )?;
+ }
+
+ Ok(())
+}
+
+fn write_function_expressions(
+ output: &mut String,
+ prefix: &str,
+ fun: &crate::Function,
+ info: Option<&FunctionInfo>,
+) -> Result<(), FmtError> {
+ enum Payload<'a> {
+ Arguments(&'a [Handle<crate::Expression>]),
+ Local(Handle<crate::LocalVariable>),
+ Global(Handle<crate::GlobalVariable>),
+ }
+
+ let mut edges = crate::FastHashMap::<&str, _>::default();
+ let mut payload = None;
+ for (handle, expression) in fun.expressions.iter() {
+ use crate::Expression as E;
+ let (label, color_id) = match *expression {
+ E::Literal(_) => ("Literal".into(), 2),
+ E::Constant(_) => ("Constant".into(), 2),
+ E::ZeroValue(_) => ("ZeroValue".into(), 2),
+ E::Compose { ref components, .. } => {
+ payload = Some(Payload::Arguments(components));
+ ("Compose".into(), 3)
+ }
+ E::Access { base, index } => {
+ edges.insert("base", base);
+ edges.insert("index", index);
+ ("Access".into(), 1)
+ }
+ E::AccessIndex { base, index } => {
+ edges.insert("base", base);
+ (format!("AccessIndex[{index}]").into(), 1)
+ }
+ E::Splat { size, value } => {
+ edges.insert("value", value);
+ (format!("Splat{size:?}").into(), 3)
+ }
+ E::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ edges.insert("vector", vector);
+ (format!("Swizzle{:?}", &pattern[..size as usize]).into(), 3)
+ }
+ E::FunctionArgument(index) => (format!("Argument[{index}]").into(), 1),
+ E::GlobalVariable(h) => {
+ payload = Some(Payload::Global(h));
+ ("Global".into(), 2)
+ }
+ E::LocalVariable(h) => {
+ payload = Some(Payload::Local(h));
+ ("Local".into(), 1)
+ }
+ E::Load { pointer } => {
+ edges.insert("pointer", pointer);
+ ("Load".into(), 4)
+ }
+ E::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset: _,
+ level,
+ depth_ref,
+ } => {
+ edges.insert("image", image);
+ edges.insert("sampler", sampler);
+ edges.insert("coordinate", coordinate);
+ if let Some(expr) = array_index {
+ edges.insert("array_index", expr);
+ }
+ match level {
+ crate::SampleLevel::Auto => {}
+ crate::SampleLevel::Zero => {}
+ crate::SampleLevel::Exact(expr) => {
+ edges.insert("level", expr);
+ }
+ crate::SampleLevel::Bias(expr) => {
+ edges.insert("bias", expr);
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ edges.insert("grad_x", x);
+ edges.insert("grad_y", y);
+ }
+ }
+ if let Some(expr) = depth_ref {
+ edges.insert("depth_ref", expr);
+ }
+ let string = match gather {
+ Some(component) => Cow::Owned(format!("ImageGather{component:?}")),
+ _ => Cow::Borrowed("ImageSample"),
+ };
+ (string, 5)
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ edges.insert("image", image);
+ edges.insert("coordinate", coordinate);
+ if let Some(expr) = array_index {
+ edges.insert("array_index", expr);
+ }
+ if let Some(sample) = sample {
+ edges.insert("sample", sample);
+ }
+ if let Some(level) = level {
+ edges.insert("level", level);
+ }
+ ("ImageLoad".into(), 5)
+ }
+ E::ImageQuery { image, query } => {
+ edges.insert("image", image);
+ let args = match query {
+ crate::ImageQuery::Size { level } => {
+ if let Some(expr) = level {
+ edges.insert("level", expr);
+ }
+ Cow::from("ImageSize")
+ }
+ _ => Cow::Owned(format!("{query:?}")),
+ };
+ (args, 7)
+ }
+ E::Unary { op, expr } => {
+ edges.insert("expr", expr);
+ (format!("{op:?}").into(), 6)
+ }
+ E::Binary { op, left, right } => {
+ edges.insert("left", left);
+ edges.insert("right", right);
+ (format!("{op:?}").into(), 6)
+ }
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ edges.insert("condition", condition);
+ edges.insert("accept", accept);
+ edges.insert("reject", reject);
+ ("Select".into(), 3)
+ }
+ E::Derivative { axis, ctrl, expr } => {
+ edges.insert("", expr);
+ (format!("d{axis:?}{ctrl:?}").into(), 8)
+ }
+ E::Relational { fun, argument } => {
+ edges.insert("arg", argument);
+ (format!("{fun:?}").into(), 6)
+ }
+ E::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ edges.insert("arg", arg);
+ if let Some(expr) = arg1 {
+ edges.insert("arg1", expr);
+ }
+ if let Some(expr) = arg2 {
+ edges.insert("arg2", expr);
+ }
+ if let Some(expr) = arg3 {
+ edges.insert("arg3", expr);
+ }
+ (format!("{fun:?}").into(), 7)
+ }
+ E::As {
+ kind,
+ expr,
+ convert,
+ } => {
+ edges.insert("", expr);
+ let string = match convert {
+ Some(width) => format!("Convert<{kind:?},{width}>"),
+ None => format!("Bitcast<{kind:?}>"),
+ };
+ (string.into(), 3)
+ }
+ E::CallResult(_function) => ("CallResult".into(), 4),
+ E::AtomicResult { .. } => ("AtomicResult".into(), 4),
+ E::WorkGroupUniformLoadResult { .. } => ("WorkGroupUniformLoadResult".into(), 4),
+ E::ArrayLength(expr) => {
+ edges.insert("", expr);
+ ("ArrayLength".into(), 7)
+ }
+ E::RayQueryProceedResult => ("rayQueryProceedResult".into(), 4),
+ E::RayQueryGetIntersection { query, committed } => {
+ edges.insert("", query);
+ let ty = if committed { "Committed" } else { "Candidate" };
+ (format!("rayQueryGet{}Intersection", ty).into(), 4)
+ }
+ };
+
+ // give uniform expressions an outline
+ let color_attr = match info {
+ Some(info) if info[handle].uniformity.non_uniform_result.is_none() => "fillcolor",
+ _ => "color",
+ };
+ writeln!(
+ output,
+ "\t\t{}_e{} [ {}=\"{}\" label=\"{:?} {}\" ]",
+ prefix,
+ handle.index(),
+ color_attr,
+ COLORS[color_id],
+ handle,
+ label,
+ )?;
+
+ for (key, edge) in edges.drain() {
+ writeln!(
+ output,
+ "\t\t{}_e{} -> {}_e{} [ label=\"{}\" ]",
+ prefix,
+ edge.index(),
+ prefix,
+ handle.index(),
+ key,
+ )?;
+ }
+ match payload.take() {
+ Some(Payload::Arguments(list)) => {
+ write!(output, "\t\t{{")?;
+ for &comp in list {
+ write!(output, " {}_e{}", prefix, comp.index())?;
+ }
+ writeln!(output, " }} -> {}_e{}", prefix, handle.index())?;
+ }
+ Some(Payload::Local(h)) => {
+ writeln!(
+ output,
+ "\t\t{}_l{} -> {}_e{}",
+ prefix,
+ h.index(),
+ prefix,
+ handle.index(),
+ )?;
+ }
+ Some(Payload::Global(h)) => {
+ writeln!(
+ output,
+ "\t\tg{} -> {}_e{} [fillcolor=gray]",
+ h.index(),
+ prefix,
+ handle.index(),
+ )?;
+ }
+ None => {}
+ }
+ }
+
+ Ok(())
+}
+
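+// For reference, each function is grouped into a DOT `subgraph cluster_*`;
+// a trimmed sketch of the output shape (illustrative, not verbatim output):
+//
+//   digraph Module {
+//       subgraph cluster_f0 {
+//           label="Function[1]/'main'"
+//           f0_s0 [ shape=square label="Root" ]
+//           f0_s1 [ shape=square label="Return" ]
+//           f0_s0 -> f0_s1 [ arrowhead=tee label="" ]
+//       }
+//   }
+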
+/// Write shader module to a [`String`].
+pub fn write(
+ module: &crate::Module,
+ mod_info: Option<&ModuleInfo>,
+ options: Options,
+) -> Result<String, FmtError> {
+ use std::fmt::Write as _;
+
+ let mut output = String::new();
+ output += "digraph Module {\n";
+
+ if !options.cfg_only {
+ writeln!(output, "\tsubgraph cluster_globals {{")?;
+ writeln!(output, "\t\tlabel=\"Globals\"")?;
+ for (handle, var) in module.global_variables.iter() {
+ writeln!(
+ output,
+ "\t\tg{} [ shape=hexagon label=\"{:?} {:?}/'{}'\" ]",
+ handle.index(),
+ handle,
+ var.space,
+ name(&var.name),
+ )?;
+ }
+ writeln!(output, "\t}}")?;
+ }
+
+ for (handle, fun) in module.functions.iter() {
+ let prefix = format!("f{}", handle.index());
+ writeln!(output, "\tsubgraph cluster_{prefix} {{")?;
+ writeln!(
+ output,
+ "\t\tlabel=\"Function{:?}/'{}'\"",
+ handle,
+ name(&fun.name)
+ )?;
+ let info = mod_info.map(|a| &a[handle]);
+ write_fun(&mut output, prefix, fun, info, &options)?;
+ writeln!(output, "\t}}")?;
+ }
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let prefix = format!("ep{ep_index}");
+ writeln!(output, "\tsubgraph cluster_{prefix} {{")?;
+ writeln!(output, "\t\tlabel=\"{:?}/'{}'\"", ep.stage, ep.name)?;
+ let info = mod_info.map(|a| a.get_entry_point(ep_index));
+ write_fun(&mut output, prefix, &ep.function, info, &options)?;
+ writeln!(output, "\t}}")?;
+ }
+
+ output += "}\n";
+ Ok(output)
+}
diff --git a/third_party/rust/naga/src/back/glsl/features.rs b/third_party/rust/naga/src/back/glsl/features.rs
new file mode 100644
index 0000000000..e7de05f695
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/features.rs
@@ -0,0 +1,536 @@
+use super::{BackendResult, Error, Version, Writer};
+use crate::{
+ back::glsl::{Options, WriterFlags},
+ AddressSpace, Binding, Expression, Handle, ImageClass, ImageDimension, Interpolation, Sampling,
+ Scalar, ScalarKind, ShaderStage, StorageFormat, Type, TypeInner,
+};
+use std::fmt::Write;
+
+bitflags::bitflags! {
+ /// Structure used to encode additions to GLSL that aren't supported by all versions.
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct Features: u32 {
+ /// Buffer address space support.
+ const BUFFER_STORAGE = 1;
+ const ARRAY_OF_ARRAYS = 1 << 1;
+ /// 8 byte floats.
+ const DOUBLE_TYPE = 1 << 2;
+ /// More image formats.
+ const FULL_IMAGE_FORMATS = 1 << 3;
+ const MULTISAMPLED_TEXTURES = 1 << 4;
+ const MULTISAMPLED_TEXTURE_ARRAYS = 1 << 5;
+ const CUBE_TEXTURES_ARRAY = 1 << 6;
+ const COMPUTE_SHADER = 1 << 7;
+ /// Image load and early depth tests.
+ const IMAGE_LOAD_STORE = 1 << 8;
+ const CONSERVATIVE_DEPTH = 1 << 9;
+ /// Interpolation and auxiliary qualifiers.
+ ///
+ /// Perspective, Flat, and Centroid are available in all GLSL versions we support.
+ const NOPERSPECTIVE_QUALIFIER = 1 << 11;
+ const SAMPLE_QUALIFIER = 1 << 12;
+ const CLIP_DISTANCE = 1 << 13;
+ const CULL_DISTANCE = 1 << 14;
+ /// Sample ID.
+ const SAMPLE_VARIABLES = 1 << 15;
+ /// Arrays with a dynamic length.
+ const DYNAMIC_ARRAY_SIZE = 1 << 16;
+ const MULTI_VIEW = 1 << 17;
+ /// Texture samples query
+ const TEXTURE_SAMPLES = 1 << 18;
+ /// Texture levels query
+ const TEXTURE_LEVELS = 1 << 19;
+ /// Image size query
+ const IMAGE_SIZE = 1 << 20;
+ /// Dual source blending
+ const DUAL_SOURCE_BLENDING = 1 << 21;
+ /// Instance index
+ ///
+ /// We can always support this, either through the language or a polyfill
+ const INSTANCE_INDEX = 1 << 22;
+ }
+}
+
+/// Helper structure used to store the [`Features`] required to output a
+/// [`Module`](crate::Module)
+///
+/// Provides helper methods to check for availability and writing required extensions
+pub struct FeaturesManager(Features);
+
+impl FeaturesManager {
+ /// Creates a new [`FeaturesManager`] instance
+ pub const fn new() -> Self {
+ Self(Features::empty())
+ }
+
+ /// Adds to the list of required [`Features`]
+ pub fn request(&mut self, features: Features) {
+ self.0 |= features
+ }
+
+    /// Checks if the set of requested [`Features`] contains the specified [`Features`]
+ pub fn contains(&mut self, features: Features) -> bool {
+ self.0.contains(features)
+ }
+
+ /// Checks that all required [`Features`] are available for the specified
+ /// [`Version`] otherwise returns an [`Error::MissingFeatures`].
+ pub fn check_availability(&self, version: Version) -> BackendResult {
+ // Will store all the features that are unavailable
+ let mut missing = Features::empty();
+
+ // Helper macro to check for feature availability
+ macro_rules! check_feature {
+ // Used when only core glsl supports the feature
+ ($feature:ident, $core:literal) => {
+ if self.0.contains(Features::$feature)
+ && (version < Version::Desktop($core) || version.is_es())
+ {
+ missing |= Features::$feature;
+ }
+ };
+ // Used when both core and es support the feature
+ ($feature:ident, $core:literal, $es:literal) => {
+ if self.0.contains(Features::$feature)
+ && (version < Version::Desktop($core) || version < Version::new_gles($es))
+ {
+ missing |= Features::$feature;
+ }
+ };
+ }
+
+ check_feature!(COMPUTE_SHADER, 420, 310);
+ check_feature!(BUFFER_STORAGE, 400, 310);
+ check_feature!(DOUBLE_TYPE, 150);
+ check_feature!(CUBE_TEXTURES_ARRAY, 130, 310);
+ check_feature!(MULTISAMPLED_TEXTURES, 150, 300);
+ check_feature!(MULTISAMPLED_TEXTURE_ARRAYS, 150, 310);
+ check_feature!(ARRAY_OF_ARRAYS, 120, 310);
+ check_feature!(IMAGE_LOAD_STORE, 130, 310);
+ check_feature!(CONSERVATIVE_DEPTH, 130, 300);
+ check_feature!(NOPERSPECTIVE_QUALIFIER, 130);
+ check_feature!(SAMPLE_QUALIFIER, 400, 320);
+ check_feature!(CLIP_DISTANCE, 130, 300 /* with extension */);
+ check_feature!(CULL_DISTANCE, 450, 300 /* with extension */);
+ check_feature!(SAMPLE_VARIABLES, 400, 300);
+ check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
+ check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */);
+ match version {
+ Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
+ _ => check_feature!(MULTI_VIEW, 140, 310),
+ };
+        // Only available on glsl core; this means that OpenGL ES can't query
+        // the number of samples or levels in an image, and can't do bound
+        // checks on the sample or level arguments of texelFetch
+ check_feature!(TEXTURE_SAMPLES, 150);
+ check_feature!(TEXTURE_LEVELS, 130);
+ check_feature!(IMAGE_SIZE, 430, 310);
+
+ // Return an error if there are missing features
+ if missing.is_empty() {
+ Ok(())
+ } else {
+ Err(Error::MissingFeatures(missing))
+ }
+ }
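+
+    // A usage sketch: a compute shader using storage buffers would accumulate
+    // its requirements, then validate them against the target version
+    // (illustrative; `Version::Desktop(330)` predates both features below):
+    //
+    // let mut manager = FeaturesManager::new();
+    // manager.request(Features::COMPUTE_SHADER | Features::BUFFER_STORAGE);
+    // assert!(manager.check_availability(Version::Desktop(330)).is_err());
+    // assert!(manager.check_availability(Version::Desktop(430)).is_ok());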
+
+ /// Helper method used to write all needed extensions
+ ///
+    /// # Notes
+    /// This won't check for feature availability, so it might output extensions that aren't
+    /// even supported. Use [`check_availability`](Self::check_availability) for that check.
+ pub fn write(&self, options: &Options, mut out: impl Write) -> BackendResult {
+ if self.0.contains(Features::COMPUTE_SHADER) && !options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_compute_shader.txt
+ writeln!(out, "#extension GL_ARB_compute_shader : require")?;
+ }
+
+ if self.0.contains(Features::BUFFER_STORAGE) && !options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_storage_buffer_object.txt
+ writeln!(
+ out,
+ "#extension GL_ARB_shader_storage_buffer_object : require"
+ )?;
+ }
+
+ if self.0.contains(Features::DOUBLE_TYPE) && options.version < Version::Desktop(400) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_gpu_shader_fp64.txt
+ writeln!(out, "#extension GL_ARB_gpu_shader_fp64 : require")?;
+ }
+
+ if self.0.contains(Features::CUBE_TEXTURES_ARRAY) {
+ if options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_cube_map_array.txt
+ writeln!(out, "#extension GL_EXT_texture_cube_map_array : require")?;
+ } else if options.version < Version::Desktop(400) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_cube_map_array.txt
+ writeln!(out, "#extension GL_ARB_texture_cube_map_array : require")?;
+ }
+ }
+
+ if self.0.contains(Features::MULTISAMPLED_TEXTURE_ARRAYS) && options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_texture_storage_multisample_2d_array.txt
+ writeln!(
+ out,
+ "#extension GL_OES_texture_storage_multisample_2d_array : require"
+ )?;
+ }
+
+ if self.0.contains(Features::ARRAY_OF_ARRAYS) && options.version < Version::Desktop(430) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_arrays_of_arrays.txt
+            writeln!(out, "#extension GL_ARB_arrays_of_arrays : require")?;
+ }
+
+ if self.0.contains(Features::IMAGE_LOAD_STORE) {
+ if self.0.contains(Features::FULL_IMAGE_FORMATS) && options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/NV/NV_image_formats.txt
+ writeln!(out, "#extension GL_NV_image_formats : require")?;
+ }
+
+ if options.version < Version::Desktop(420) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_image_load_store.txt
+ writeln!(out, "#extension GL_ARB_shader_image_load_store : require")?;
+ }
+ }
+
+ if self.0.contains(Features::CONSERVATIVE_DEPTH) {
+ if options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_conservative_depth.txt
+ writeln!(out, "#extension GL_EXT_conservative_depth : require")?;
+ }
+
+ if options.version < Version::Desktop(420) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt
+ writeln!(out, "#extension GL_ARB_conservative_depth : require")?;
+ }
+ }
+
+ if (self.0.contains(Features::CLIP_DISTANCE) || self.0.contains(Features::CULL_DISTANCE))
+ && options.version.is_es()
+ {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_clip_cull_distance.txt
+ writeln!(out, "#extension GL_EXT_clip_cull_distance : require")?;
+ }
+
+ if self.0.contains(Features::SAMPLE_VARIABLES) && options.version.is_es() {
+ // https://www.khronos.org/registry/OpenGL/extensions/OES/OES_sample_variables.txt
+ writeln!(out, "#extension GL_OES_sample_variables : require")?;
+ }
+
+ if self.0.contains(Features::MULTI_VIEW) {
+ if let Version::Embedded { is_webgl: true, .. } = options.version {
+ // https://www.khronos.org/registry/OpenGL/extensions/OVR/OVR_multiview2.txt
+ writeln!(out, "#extension GL_OVR_multiview2 : require")?;
+ } else {
+ // https://github.com/KhronosGroup/GLSL/blob/master/extensions/ext/GL_EXT_multiview.txt
+ writeln!(out, "#extension GL_EXT_multiview : require")?;
+ }
+ }
+
+ if self.0.contains(Features::TEXTURE_SAMPLES) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_shader_texture_image_samples.txt
+ writeln!(
+ out,
+ "#extension GL_ARB_shader_texture_image_samples : require"
+ )?;
+ }
+
+ if self.0.contains(Features::TEXTURE_LEVELS) && options.version < Version::Desktop(430) {
+ // https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_texture_query_levels.txt
+ writeln!(out, "#extension GL_ARB_texture_query_levels : require")?;
+ }
+ if self.0.contains(Features::DUAL_SOURCE_BLENDING) && options.version.is_es() {
+ // https://registry.khronos.org/OpenGL/extensions/EXT/EXT_blend_func_extended.txt
+ writeln!(out, "#extension GL_EXT_blend_func_extended : require")?;
+ }
+
+ if self.0.contains(Features::INSTANCE_INDEX) {
+ if options.writer_flags.contains(WriterFlags::DRAW_PARAMETERS) {
+ // https://registry.khronos.org/OpenGL/extensions/ARB/ARB_shader_draw_parameters.txt
+ writeln!(out, "#extension GL_ARB_shader_draw_parameters : require")?;
+ }
+ }
+
+ Ok(())
+ }
+}
+
+impl<'a, W> Writer<'a, W> {
+ /// Helper method that searches the module for all the needed [`Features`]
+ ///
+ /// # Errors
+    /// If the version doesn't support all of the needed [`Features`], an
+    /// [`Error::MissingFeatures`] will be returned
+ pub(super) fn collect_required_features(&mut self) -> BackendResult {
+ let ep_info = self.info.get_entry_point(self.entry_point_idx as usize);
+
+ if let Some(depth_test) = self.entry_point.early_depth_test {
+            // If early depth tests are supported by this version of GLSL,
+            // request IMAGE_LOAD_STORE, which is needed to write them
+ if self.options.version.supports_early_depth_test() {
+ self.features.request(Features::IMAGE_LOAD_STORE);
+ }
+
+ if depth_test.conservative.is_some() {
+ self.features.request(Features::CONSERVATIVE_DEPTH);
+ }
+ }
+
+ for arg in self.entry_point.function.arguments.iter() {
+ self.varying_required_features(arg.binding.as_ref(), arg.ty);
+ }
+ if let Some(ref result) = self.entry_point.function.result {
+ self.varying_required_features(result.binding.as_ref(), result.ty);
+ }
+
+ if let ShaderStage::Compute = self.entry_point.stage {
+ self.features.request(Features::COMPUTE_SHADER)
+ }
+
+ if self.multiview.is_some() {
+ self.features.request(Features::MULTI_VIEW);
+ }
+
+ for (ty_handle, ty) in self.module.types.iter() {
+ match ty.inner {
+ TypeInner::Scalar(scalar)
+ | TypeInner::Vector { scalar, .. }
+ | TypeInner::Matrix { scalar, .. } => self.scalar_required_features(scalar),
+ TypeInner::Array { base, size, .. } => {
+ if let TypeInner::Array { .. } = self.module.types[base].inner {
+ self.features.request(Features::ARRAY_OF_ARRAYS)
+ }
+
+ // If the array is dynamically sized
+ if size == crate::ArraySize::Dynamic {
+ let mut is_used = false;
+
+ // Check if this type is used in a global that is needed by the current entrypoint
+ for (global_handle, global) in self.module.global_variables.iter() {
+ // Skip unused globals
+ if ep_info[global_handle].is_empty() {
+ continue;
+ }
+
+ // If this array is the type of a global, then this array is used
+ if global.ty == ty_handle {
+ is_used = true;
+ break;
+ }
+
+ // If the type of this global is a struct
+ if let crate::TypeInner::Struct { ref members, .. } =
+ self.module.types[global.ty].inner
+ {
+                                // Check the last element of the struct to see
+                                // if its type uses this array
+ if let Some(last) = members.last() {
+ if last.ty == ty_handle {
+ is_used = true;
+ break;
+ }
+ }
+ }
+ }
+
+                    // If this dynamically sized array is used, we need dynamic array size support
+ if is_used {
+ self.features.request(Features::DYNAMIC_ARRAY_SIZE);
+ }
+ }
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ if arrayed && dim == ImageDimension::Cube {
+ self.features.request(Features::CUBE_TEXTURES_ARRAY)
+ }
+
+ match class {
+ ImageClass::Sampled { multi: true, .. }
+ | ImageClass::Depth { multi: true } => {
+ self.features.request(Features::MULTISAMPLED_TEXTURES);
+ if arrayed {
+ self.features.request(Features::MULTISAMPLED_TEXTURE_ARRAYS);
+ }
+ }
+ ImageClass::Storage { format, .. } => match format {
+ StorageFormat::R8Unorm
+ | StorageFormat::R8Snorm
+ | StorageFormat::R8Uint
+ | StorageFormat::R8Sint
+ | StorageFormat::R16Uint
+ | StorageFormat::R16Sint
+ | StorageFormat::R16Float
+ | StorageFormat::Rg8Unorm
+ | StorageFormat::Rg8Snorm
+ | StorageFormat::Rg8Uint
+ | StorageFormat::Rg8Sint
+ | StorageFormat::Rg16Uint
+ | StorageFormat::Rg16Sint
+ | StorageFormat::Rg16Float
+ | StorageFormat::Rgb10a2Uint
+ | StorageFormat::Rgb10a2Unorm
+ | StorageFormat::Rg11b10Float
+ | StorageFormat::Rg32Uint
+ | StorageFormat::Rg32Sint
+ | StorageFormat::Rg32Float => {
+ self.features.request(Features::FULL_IMAGE_FORMATS)
+ }
+ _ => {}
+ },
+ ImageClass::Sampled { multi: false, .. }
+ | ImageClass::Depth { multi: false } => {}
+ }
+ }
+ _ => {}
+ }
+ }
+
+ let mut push_constant_used = false;
+
+ for (handle, global) in self.module.global_variables.iter() {
+ if ep_info[handle].is_empty() {
+ continue;
+ }
+ match global.space {
+ AddressSpace::WorkGroup => self.features.request(Features::COMPUTE_SHADER),
+ AddressSpace::Storage { .. } => self.features.request(Features::BUFFER_STORAGE),
+ AddressSpace::PushConstant => {
+ if push_constant_used {
+ return Err(Error::MultiplePushConstants);
+ }
+ push_constant_used = true;
+ }
+ _ => {}
+ }
+ }
+
+        // We will need to pass some of the members to a closure, so we need
+        // to separate them, otherwise the borrow checker will complain; this
+        // shouldn't be needed in rust 2021
+ let &mut Self {
+ module,
+ info,
+ ref mut features,
+ entry_point,
+ entry_point_idx,
+ ref policies,
+ ..
+ } = self;
+
+        // Loop through all expressions in both functions and the entry point
+ // to check for needed features
+ for (expressions, info) in module
+ .functions
+ .iter()
+ .map(|(h, f)| (&f.expressions, &info[h]))
+ .chain(std::iter::once((
+ &entry_point.function.expressions,
+ info.get_entry_point(entry_point_idx as usize),
+ )))
+ {
+ for (_, expr) in expressions.iter() {
+ match *expr {
+                    // Check for queries that need additional features
+ Expression::ImageQuery {
+ image,
+ query,
+ ..
+ } => match query {
+ // Storage images use `imageSize` which is only available
+ // in glsl > 420
+ //
+ // layers queries are also implemented as size queries
+ crate::ImageQuery::Size { .. } | crate::ImageQuery::NumLayers => {
+ if let TypeInner::Image {
+ class: crate::ImageClass::Storage { .. }, ..
+ } = *info[image].ty.inner_with(&module.types) {
+ features.request(Features::IMAGE_SIZE)
+ }
+ },
+ crate::ImageQuery::NumLevels => features.request(Features::TEXTURE_LEVELS),
+ crate::ImageQuery::NumSamples => features.request(Features::TEXTURE_SAMPLES),
+                    },
+                    // Check for image loads that need bound checking on the
+                    // sample or level argument, since this requires a feature
+ Expression::ImageLoad {
+ sample, level, ..
+ } => {
+ if policies.image_load != crate::proc::BoundsCheckPolicy::Unchecked {
+ if sample.is_some() {
+ features.request(Features::TEXTURE_SAMPLES)
+ }
+
+ if level.is_some() {
+ features.request(Features::TEXTURE_LEVELS)
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+
+ self.features.check_availability(self.options.version)
+ }
+
+ /// Helper method that checks the [`Features`] needed by a scalar
+ fn scalar_required_features(&mut self, scalar: Scalar) {
+ if scalar.kind == ScalarKind::Float && scalar.width == 8 {
+ self.features.request(Features::DOUBLE_TYPE);
+ }
+ }
+
+ fn varying_required_features(&mut self, binding: Option<&Binding>, ty: Handle<Type>) {
+ match self.module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for member in members {
+ self.varying_required_features(member.binding.as_ref(), member.ty);
+ }
+ }
+ _ => {
+ if let Some(binding) = binding {
+ match *binding {
+ Binding::BuiltIn(built_in) => match built_in {
+ crate::BuiltIn::ClipDistance => {
+ self.features.request(Features::CLIP_DISTANCE)
+ }
+ crate::BuiltIn::CullDistance => {
+ self.features.request(Features::CULL_DISTANCE)
+ }
+ crate::BuiltIn::SampleIndex => {
+ self.features.request(Features::SAMPLE_VARIABLES)
+ }
+ crate::BuiltIn::ViewIndex => {
+ self.features.request(Features::MULTI_VIEW)
+ }
+ crate::BuiltIn::InstanceIndex => {
+ self.features.request(Features::INSTANCE_INDEX)
+ }
+ _ => {}
+ },
+ Binding::Location {
+ location: _,
+ interpolation,
+ sampling,
+ second_blend_source,
+ } => {
+ if interpolation == Some(Interpolation::Linear) {
+ self.features.request(Features::NOPERSPECTIVE_QUALIFIER);
+ }
+ if sampling == Some(Sampling::Sample) {
+ self.features.request(Features::SAMPLE_QUALIFIER);
+ }
+ if second_blend_source {
+ self.features.request(Features::DUAL_SOURCE_BLENDING);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/glsl/keywords.rs b/third_party/rust/naga/src/back/glsl/keywords.rs
new file mode 100644
index 0000000000..857c935e68
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/keywords.rs
@@ -0,0 +1,484 @@
+pub const RESERVED_KEYWORDS: &[&str] = &[
+ //
+ // GLSL 4.6 keywords, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L2004-L2322
+ // GLSL ES 3.2 keywords, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/es/3.2/GLSL_ES_Specification_3.20.html#L2166-L2478
+ //
+ // Note: The GLSL ES 3.2 keywords are the same as GLSL 4.6 keywords with some residing in the reserved section.
+ // The only exception are the missing Vulkan keywords which I think is an oversight (see https://github.com/KhronosGroup/OpenGL-Registry/issues/585).
+ //
+ "const",
+ "uniform",
+ "buffer",
+ "shared",
+ "attribute",
+ "varying",
+ "coherent",
+ "volatile",
+ "restrict",
+ "readonly",
+ "writeonly",
+ "atomic_uint",
+ "layout",
+ "centroid",
+ "flat",
+ "smooth",
+ "noperspective",
+ "patch",
+ "sample",
+ "invariant",
+ "precise",
+ "break",
+ "continue",
+ "do",
+ "for",
+ "while",
+ "switch",
+ "case",
+ "default",
+ "if",
+ "else",
+ "subroutine",
+ "in",
+ "out",
+ "inout",
+ "int",
+ "void",
+ "bool",
+ "true",
+ "false",
+ "float",
+ "double",
+ "discard",
+ "return",
+ "vec2",
+ "vec3",
+ "vec4",
+ "ivec2",
+ "ivec3",
+ "ivec4",
+ "bvec2",
+ "bvec3",
+ "bvec4",
+ "uint",
+ "uvec2",
+ "uvec3",
+ "uvec4",
+ "dvec2",
+ "dvec3",
+ "dvec4",
+ "mat2",
+ "mat3",
+ "mat4",
+ "mat2x2",
+ "mat2x3",
+ "mat2x4",
+ "mat3x2",
+ "mat3x3",
+ "mat3x4",
+ "mat4x2",
+ "mat4x3",
+ "mat4x4",
+ "dmat2",
+ "dmat3",
+ "dmat4",
+ "dmat2x2",
+ "dmat2x3",
+ "dmat2x4",
+ "dmat3x2",
+ "dmat3x3",
+ "dmat3x4",
+ "dmat4x2",
+ "dmat4x3",
+ "dmat4x4",
+ "lowp",
+ "mediump",
+ "highp",
+ "precision",
+ "sampler1D",
+ "sampler1DShadow",
+ "sampler1DArray",
+ "sampler1DArrayShadow",
+ "isampler1D",
+ "isampler1DArray",
+ "usampler1D",
+ "usampler1DArray",
+ "sampler2D",
+ "sampler2DShadow",
+ "sampler2DArray",
+ "sampler2DArrayShadow",
+ "isampler2D",
+ "isampler2DArray",
+ "usampler2D",
+ "usampler2DArray",
+ "sampler2DRect",
+ "sampler2DRectShadow",
+ "isampler2DRect",
+ "usampler2DRect",
+ "sampler2DMS",
+ "isampler2DMS",
+ "usampler2DMS",
+ "sampler2DMSArray",
+ "isampler2DMSArray",
+ "usampler2DMSArray",
+ "sampler3D",
+ "isampler3D",
+ "usampler3D",
+ "samplerCube",
+ "samplerCubeShadow",
+ "isamplerCube",
+ "usamplerCube",
+ "samplerCubeArray",
+ "samplerCubeArrayShadow",
+ "isamplerCubeArray",
+ "usamplerCubeArray",
+ "samplerBuffer",
+ "isamplerBuffer",
+ "usamplerBuffer",
+ "image1D",
+ "iimage1D",
+ "uimage1D",
+ "image1DArray",
+ "iimage1DArray",
+ "uimage1DArray",
+ "image2D",
+ "iimage2D",
+ "uimage2D",
+ "image2DArray",
+ "iimage2DArray",
+ "uimage2DArray",
+ "image2DRect",
+ "iimage2DRect",
+ "uimage2DRect",
+ "image2DMS",
+ "iimage2DMS",
+ "uimage2DMS",
+ "image2DMSArray",
+ "iimage2DMSArray",
+ "uimage2DMSArray",
+ "image3D",
+ "iimage3D",
+ "uimage3D",
+ "imageCube",
+ "iimageCube",
+ "uimageCube",
+ "imageCubeArray",
+ "iimageCubeArray",
+ "uimageCubeArray",
+ "imageBuffer",
+ "iimageBuffer",
+ "uimageBuffer",
+ "struct",
+ // Vulkan keywords
+ "texture1D",
+ "texture1DArray",
+ "itexture1D",
+ "itexture1DArray",
+ "utexture1D",
+ "utexture1DArray",
+ "texture2D",
+ "texture2DArray",
+ "itexture2D",
+ "itexture2DArray",
+ "utexture2D",
+ "utexture2DArray",
+ "texture2DRect",
+ "itexture2DRect",
+ "utexture2DRect",
+ "texture2DMS",
+ "itexture2DMS",
+ "utexture2DMS",
+ "texture2DMSArray",
+ "itexture2DMSArray",
+ "utexture2DMSArray",
+ "texture3D",
+ "itexture3D",
+ "utexture3D",
+ "textureCube",
+ "itextureCube",
+ "utextureCube",
+ "textureCubeArray",
+ "itextureCubeArray",
+ "utextureCubeArray",
+ "textureBuffer",
+ "itextureBuffer",
+ "utextureBuffer",
+ "sampler",
+ "samplerShadow",
+ "subpassInput",
+ "isubpassInput",
+ "usubpassInput",
+ "subpassInputMS",
+ "isubpassInputMS",
+ "usubpassInputMS",
+ // Reserved keywords
+ "common",
+ "partition",
+ "active",
+ "asm",
+ "class",
+ "union",
+ "enum",
+ "typedef",
+ "template",
+ "this",
+ "resource",
+ "goto",
+ "inline",
+ "noinline",
+ "public",
+ "static",
+ "extern",
+ "external",
+ "interface",
+ "long",
+ "short",
+ "half",
+ "fixed",
+ "unsigned",
+ "superp",
+ "input",
+ "output",
+ "hvec2",
+ "hvec3",
+ "hvec4",
+ "fvec2",
+ "fvec3",
+ "fvec4",
+ "filter",
+ "sizeof",
+ "cast",
+ "namespace",
+ "using",
+ "sampler3DRect",
+ //
+ // GLSL 4.6 Built-In Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13314
+ //
+ // Angle and Trigonometry Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13469-L13561C5
+ "radians",
+ "degrees",
+ "sin",
+ "cos",
+ "tan",
+ "asin",
+ "acos",
+ "atan",
+ "sinh",
+ "cosh",
+ "tanh",
+ "asinh",
+ "acosh",
+ "atanh",
+ // Exponential Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13569-L13620
+ "pow",
+ "exp",
+ "log",
+ "exp2",
+ "log2",
+ "sqrt",
+ "inversesqrt",
+ // Common Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13628-L13908
+ "abs",
+ "sign",
+ "floor",
+ "trunc",
+ "round",
+ "roundEven",
+ "ceil",
+ "fract",
+ "mod",
+ "modf",
+ "min",
+ "max",
+ "clamp",
+ "mix",
+ "step",
+ "smoothstep",
+ "isnan",
+ "isinf",
+ "floatBitsToInt",
+ "floatBitsToUint",
+ "intBitsToFloat",
+ "uintBitsToFloat",
+ "fma",
+ "frexp",
+ "ldexp",
+ // Floating-Point Pack and Unpack Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L13916-L14007
+ "packUnorm2x16",
+ "packSnorm2x16",
+ "packUnorm4x8",
+ "packSnorm4x8",
+ "unpackUnorm2x16",
+ "unpackSnorm2x16",
+ "unpackUnorm4x8",
+ "unpackSnorm4x8",
+ "packHalf2x16",
+ "unpackHalf2x16",
+ "packDouble2x32",
+ "unpackDouble2x32",
+ // Geometric Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14014-L14121
+ "length",
+ "distance",
+ "dot",
+ "cross",
+ "normalize",
+ "ftransform",
+ "faceforward",
+ "reflect",
+ "refract",
+ // Matrix Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14151-L14215
+ "matrixCompMult",
+ "outerProduct",
+ "transpose",
+ "determinant",
+ "inverse",
+ // Vector Relational Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14259-L14322
+ "lessThan",
+ "lessThanEqual",
+ "greaterThan",
+ "greaterThanEqual",
+ "equal",
+ "notEqual",
+ "any",
+ "all",
+ "not",
+ // Integer Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14335-L14432
+ "uaddCarry",
+ "usubBorrow",
+ "umulExtended",
+ "imulExtended",
+ "bitfieldExtract",
+ "bitfieldInsert",
+ "bitfieldReverse",
+ "bitCount",
+ "findLSB",
+ "findMSB",
+ // Texture Query Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14645-L14732
+ "textureSize",
+ "textureQueryLod",
+ "textureQueryLevels",
+ "textureSamples",
+ // Texel Lookup Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L14736-L14997
+ "texture",
+ "textureProj",
+ "textureLod",
+ "textureOffset",
+ "texelFetch",
+ "texelFetchOffset",
+ "textureProjOffset",
+ "textureLodOffset",
+ "textureProjLod",
+ "textureProjLodOffset",
+ "textureGrad",
+ "textureGradOffset",
+ "textureProjGrad",
+ "textureProjGradOffset",
+ // Texture Gather Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15077-L15154
+ "textureGather",
+ "textureGatherOffset",
+ "textureGatherOffsets",
+ // Compatibility Profile Texture Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15161-L15220
+ "texture1D",
+ "texture1DProj",
+ "texture1DLod",
+ "texture1DProjLod",
+ "texture2D",
+ "texture2DProj",
+ "texture2DLod",
+ "texture2DProjLod",
+ "texture3D",
+ "texture3DProj",
+ "texture3DLod",
+ "texture3DProjLod",
+ "textureCube",
+ "textureCubeLod",
+ "shadow1D",
+ "shadow2D",
+ "shadow1DProj",
+ "shadow2DProj",
+ "shadow1DLod",
+ "shadow2DLod",
+ "shadow1DProjLod",
+ "shadow2DProjLod",
+ // Atomic Counter Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15241-L15531
+ "atomicCounterIncrement",
+ "atomicCounterDecrement",
+ "atomicCounter",
+ "atomicCounterAdd",
+ "atomicCounterSubtract",
+ "atomicCounterMin",
+ "atomicCounterMax",
+ "atomicCounterAnd",
+ "atomicCounterOr",
+ "atomicCounterXor",
+ "atomicCounterExchange",
+ "atomicCounterCompSwap",
+ // Atomic Memory Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15563-L15624
+ "atomicAdd",
+ "atomicMin",
+ "atomicMax",
+ "atomicAnd",
+ "atomicOr",
+ "atomicXor",
+ "atomicExchange",
+ "atomicCompSwap",
+ // Image Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15763-L15878
+ "imageSize",
+ "imageSamples",
+ "imageLoad",
+ "imageStore",
+ "imageAtomicAdd",
+ "imageAtomicMin",
+ "imageAtomicMax",
+ "imageAtomicAnd",
+ "imageAtomicOr",
+ "imageAtomicXor",
+ "imageAtomicExchange",
+ "imageAtomicCompSwap",
+ // Geometry Shader Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L15886-L15932
+ "EmitStreamVertex",
+ "EndStreamPrimitive",
+ "EmitVertex",
+ "EndPrimitive",
+ // Fragment Processing Functions, Derivative Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16041-L16114
+ "dFdx",
+ "dFdy",
+ "dFdxFine",
+ "dFdyFine",
+ "dFdxCoarse",
+ "dFdyCoarse",
+ "fwidth",
+ "fwidthFine",
+ "fwidthCoarse",
+ // Fragment Processing Functions, Interpolation Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16150-L16198
+ "interpolateAtCentroid",
+ "interpolateAtSample",
+ "interpolateAtOffset",
+ // Noise Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16214-L16243
+ "noise1",
+ "noise2",
+ "noise3",
+ "noise4",
+ // Shader Invocation Control Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16255-L16276
+ "barrier",
+ // Shader Memory Control Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16336-L16382
+ "memoryBarrier",
+ "memoryBarrierAtomicCounter",
+ "memoryBarrierBuffer",
+ "memoryBarrierShared",
+ "memoryBarrierImage",
+ "groupMemoryBarrier",
+ // Subpass-Input Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16451-L16470
+ "subpassLoad",
+ // Shader Invocation Group Functions, from https://github.com/KhronosGroup/OpenGL-Registry/blob/d00e11dc1a1ffba581d633f21f70202051248d5c/specs/gl/GLSLangSpec.4.60.html#L16483-L16511
+ "anyInvocation",
+ "allInvocations",
+ "allInvocationsEqual",
+ //
+ // entry point name (should not be shadowed)
+ //
+ "main",
+ // Naga utilities:
+ super::MODF_FUNCTION,
+ super::FREXP_FUNCTION,
+ super::FIRST_INSTANCE_BINDING,
+];
diff --git a/third_party/rust/naga/src/back/glsl/mod.rs b/third_party/rust/naga/src/back/glsl/mod.rs
new file mode 100644
index 0000000000..e346d43257
--- /dev/null
+++ b/third_party/rust/naga/src/back/glsl/mod.rs
@@ -0,0 +1,4532 @@
+/*!
+Backend for [GLSL][glsl] (OpenGL Shading Language).
+
+The main structure is [`Writer`]; it maintains internal state that is used
+to output a [`Module`](crate::Module) into glsl.
+
+# Supported versions
+### Core
+- 140
+- 150
+- 330
+- 400
+- 410
+- 420
+- 430
+- 440
+- 450
+- 460
+
+### ES
+- 300
+- 310
+- 320
+
+[glsl]: https://www.khronos.org/registry/OpenGL/index_gl.php
+*/
+
+// GLSL is mostly a superset of C, but it also removes some parts of it. This is a list of the
+// aspects relevant to this backend.
+//
+// The most notable change is the introduction of the version preprocessor directive, which must
+// always be the first line of a glsl file and is written as
+// `#version number profile`
+// `number` is the version itself (i.e. 300) and `profile` is the shader profile. We only support
+// "core" and "es": the former is used in desktop applications and the latter in embedded
+// contexts, mobile devices and browsers. Each one has its own versions (at the time of writing,
+// the latest version for "core" is 460 and for "es" is 320)
+//
+// The other important preprocessor addition is the extension directive, which is written as
+// `#extension name: behaviour`
+// Extensions provide increased features in a plugin fashion, but they aren't required to be
+// supported (hence why they are called extensions). `behaviour` specifies whether the extension
+// is strictly required or should only be enabled if needed. In our case, when we use extensions
+// we always set behaviour to `require`.
+//
+// The only thing that glsl removes that makes a difference is pointers.
+//
+// Additions that are relevant for the backend are the discard keyword; vector, matrix, sampler
+// and image types; and functions that provide common shader operations
+
+pub use features::Features;
+
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, ShaderStage, TypeInner,
+};
+use features::FeaturesManager;
+use std::{
+ cmp::Ordering,
+ fmt,
+ fmt::{Error as FmtError, Write},
+ mem,
+};
+use thiserror::Error;
+
+/// Contains the feature-related code and the feature querying method
+mod features;
+/// Contains [`RESERVED_KEYWORDS`], a slice of all the reserved keywords
+mod keywords;
+
+/// List of supported `core` GLSL versions.
+pub const SUPPORTED_CORE_VERSIONS: &[u16] = &[140, 150, 330, 400, 410, 420, 430, 440, 450, 460];
+/// List of supported `es` GLSL versions.
+pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320];
+
+/// The suffix of the variable that will hold the calculated clamped level
+/// of detail for bounds checking in `ImageLoad`
+const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod";
+
+pub(crate) const MODF_FUNCTION: &str = "naga_modf";
+pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
+
+// Must match code in glsl_built_in
+pub const FIRST_INSTANCE_BINDING: &str = "naga_vs_first_instance";
+
+/// Mapping between resources and bindings.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, u8>;
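+
+// For example, mapping the resource at `@group(0) @binding(1)` to GLSL
+// binding slot 2 (a sketch; `ResourceBinding` has `group` and `binding` fields):
+//
+// let mut binding_map = BindingMap::new();
+// binding_map.insert(crate::ResourceBinding { group: 0, binding: 1 }, 2);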
+
+impl crate::AtomicFunction {
+ const fn to_glsl(self) -> &'static str {
+ match self {
+ Self::Add | Self::Subtract => "Add",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { compare: Some(_) } => "", //TODO
+ }
+ }
+}
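+
+// The string returned by `to_glsl` is the suffix of the GLSL atomic built-in:
+// e.g. `AtomicFunction::ExclusiveOr.to_glsl()` yields "Xor", from which the
+// writer forms an `atomicXor(...)` call (a sketch of the naming scheme).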
+
+impl crate::AddressSpace {
+ const fn is_buffer(&self) -> bool {
+ match *self {
+ crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => true,
+ _ => false,
+ }
+ }
+
+ /// Whether a variable with this address space can be initialized
+ const fn initializable(&self) -> bool {
+ match *self {
+ crate::AddressSpace::Function | crate::AddressSpace::Private => true,
+ crate::AddressSpace::WorkGroup
+ | crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::Handle
+ | crate::AddressSpace::PushConstant => false,
+ }
+ }
+}
+
+/// A GLSL version.
+#[derive(Debug, Copy, Clone, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum Version {
+ /// `core` GLSL.
+ Desktop(u16),
+ /// `es` GLSL.
+ Embedded { version: u16, is_webgl: bool },
+}
+
+impl Version {
+ /// Create a new gles version
+ pub const fn new_gles(version: u16) -> Self {
+ Self::Embedded {
+ version,
+ is_webgl: false,
+ }
+ }
+
+    /// Returns true if self is `Version::Embedded` (i.e. is an es version)
+ const fn is_es(&self) -> bool {
+ match *self {
+ Version::Desktop(_) => false,
+ Version::Embedded { .. } => true,
+ }
+ }
+
+ /// Returns true if targeting WebGL
+ const fn is_webgl(&self) -> bool {
+ match *self {
+ Version::Desktop(_) => false,
+ Version::Embedded { is_webgl, .. } => is_webgl,
+ }
+ }
+
+ /// Checks the list of currently supported versions and returns true if it contains the
+ /// specified version
+ ///
+ /// # Notes
+    /// Since an invalid version number will never be added to the supported version list,
+    /// this also checks for version validity
+ fn is_supported(&self) -> bool {
+ match *self {
+ Version::Desktop(v) => SUPPORTED_CORE_VERSIONS.contains(&v),
+ Version::Embedded { version: v, .. } => SUPPORTED_ES_VERSIONS.contains(&v),
+ }
+ }
+
+ fn supports_io_locations(&self) -> bool {
+ *self >= Version::Desktop(330) || *self >= Version::new_gles(300)
+ }
+
+ /// Checks if the version supports all of the explicit layouts:
+ /// - `location=` qualifiers for bindings
+ /// - `binding=` qualifiers for resources
+ ///
+ /// Note: `location=` for vertex inputs and fragment outputs is supported
+ /// unconditionally for GLES 300.
+ fn supports_explicit_locations(&self) -> bool {
+ *self >= Version::Desktop(410) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_early_depth_test(&self) -> bool {
+ *self >= Version::Desktop(130) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_std430_layout(&self) -> bool {
+ *self >= Version::Desktop(430) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_fma_function(&self) -> bool {
+ *self >= Version::Desktop(400) || *self >= Version::new_gles(320)
+ }
+
+ fn supports_integer_functions(&self) -> bool {
+ *self >= Version::Desktop(400) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_frexp_function(&self) -> bool {
+ *self >= Version::Desktop(400) || *self >= Version::new_gles(310)
+ }
+
+ fn supports_derivative_control(&self) -> bool {
+ *self >= Version::Desktop(450)
+ }
+}
+
+impl PartialOrd for Version {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ match (*self, *other) {
+ (Version::Desktop(x), Version::Desktop(y)) => Some(x.cmp(&y)),
+ (Version::Embedded { version: x, .. }, Version::Embedded { version: y, .. }) => {
+ Some(x.cmp(&y))
+ }
+ _ => None,
+ }
+ }
+}
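+
+ // A minimal sketch (not part of the original source) of the ordering
+ // semantics above: versions are only comparable within the same variant,
+ // so `Desktop` and `Embedded` are never ordered relative to each other.
+ // This is why each `supports_*` method checks both a Desktop and an
+ // Embedded threshold with `||`:
+ //
+ // assert!(Version::new_gles(310) >= Version::new_gles(300));
+ // assert_eq!(
+ //     Version::Desktop(330).partial_cmp(&Version::new_gles(300)),
+ //     None,
+ // );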
+
+impl fmt::Display for Version {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match *self {
+ Version::Desktop(v) => write!(f, "{v} core"),
+ Version::Embedded { version: v, .. } => write!(f, "{v} es"),
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Configuration flags for the [`Writer`].
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct WriterFlags: u32 {
+ /// Flip output Y and extend Z from (0, 1) to (-1, 1).
+ const ADJUST_COORDINATE_SPACE = 0x1;
+ /// Supports GL_EXT_texture_shadow_lod on the host, which provides
+ /// additional functions on shadows and arrays of shadows.
+ const TEXTURE_SHADOW_LOD = 0x2;
+ /// Supports ARB_shader_draw_parameters on the host, which provides
+ /// support for `gl_BaseInstanceARB`, `gl_BaseVertexARB`, and `gl_DrawIDARB`.
+ const DRAW_PARAMETERS = 0x4;
+ /// Include unused global variables, constants and functions. By default the output will exclude
+ /// global variables that are not used in the specified entry point (including indirect use),
+ /// all constant declarations, and functions that use excluded global variables.
+ const INCLUDE_UNUSED_ITEMS = 0x10;
+ /// Emit `PointSize` output builtin to vertex shaders, which is
+ /// required for drawing with `PointList` topology.
+ ///
+ /// https://registry.khronos.org/OpenGL/specs/es/3.2/GLSL_ES_Specification_3.20.html#built-in-language-variables
+ /// The variable gl_PointSize is intended for a shader to write the size of the point to be rasterized. It is measured in pixels.
+ /// If gl_PointSize is not written to, its value is undefined in subsequent pipe stages.
+ const FORCE_POINT_SIZE = 0x20;
+ }
+}
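+
+ // An illustrative sketch (not part of the original source): these flags
+ // combine like any other bitflags value, e.g. to request both
+ // coordinate-space adjustment and a forced point size:
+ //
+ // let flags = WriterFlags::ADJUST_COORDINATE_SPACE | WriterFlags::FORCE_POINT_SIZE;
+ // assert!(flags.contains(WriterFlags::FORCE_POINT_SIZE));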
+
+/// Configuration used in the [`Writer`].
+#[derive(Debug, Clone)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// The GLSL version to be used.
+ pub version: Version,
+ /// Configuration flags for the [`Writer`].
+ pub writer_flags: WriterFlags,
+ /// Map of resource associations to binding locations.
+ pub binding_map: BindingMap,
+ /// Should workgroup variables be zero initialized (by polyfilling)?
+ pub zero_initialize_workgroup_memory: bool,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ version: Version::new_gles(310),
+ writer_flags: WriterFlags::ADJUST_COORDINATE_SPACE,
+ binding_map: BindingMap::default(),
+ zero_initialize_workgroup_memory: true,
+ }
+ }
+}
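+
+ // Illustrative usage (a sketch, not part of the original source): struct
+ // update syntax keeps the remaining fields at their defaults, e.g. to
+ // target GLSL 4.30 core while leaving everything else as-is:
+ //
+ // let options = Options {
+ //     version: Version::Desktop(430),
+ //     ..Options::default()
+ // };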
+
+/// A subset of options meant to be changed per pipeline.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// The stage of the entry point.
+ pub shader_stage: ShaderStage,
+ /// The name of the entry point.
+ ///
+ /// If no entry point that matches is found while creating a [`Writer`], an error will be returned.
+ pub entry_point: String,
+ /// How many views to render to, if doing multiview rendering.
+ pub multiview: Option<std::num::NonZeroU32>,
+}
+
+#[derive(Debug)]
+pub struct VaryingLocation {
+ /// The location of the global.
+ /// This corresponds to `layout(location = ..)` in GLSL.
+ pub location: u32,
+ /// The index which can be used for dual source blending.
+ /// This corresponds to `layout(index = ..)` in GLSL.
+ pub index: u32,
+}
+
+/// Reflection info for texture mappings and uniforms.
+#[derive(Debug)]
+pub struct ReflectionInfo {
+ /// Mapping between texture names and variables/samplers.
+ pub texture_mapping: crate::FastHashMap<String, TextureMapping>,
+ /// Mapping between uniform variables and names.
+ pub uniforms: crate::FastHashMap<Handle<crate::GlobalVariable>, String>,
+ /// Mapping between names and attribute locations.
+ pub varying: crate::FastHashMap<String, VaryingLocation>,
+ /// List of push constant items in the shader.
+ pub push_constant_items: Vec<PushConstantItem>,
+}
+
+/// Mapping between a texture and its sampler, if it exists.
+///
+/// GLSL pre-Vulkan has no concept of separate textures and samplers. Instead, everything is a
+/// `gsamplerN` where `g` is the scalar type and `N` is the dimension. But naga uses separate textures
+/// and samplers in the IR, so the backend produces a [`FastHashMap`](crate::FastHashMap) with the texture name
+/// as a key and a [`TextureMapping`] as a value. This way, the user knows where to bind.
+///
+/// [`Storage`](crate::ImageClass::Storage) images produce `gimageN` and don't have an associated sampler,
+/// so the [`sampler`](Self::sampler) field will be [`None`].
+#[derive(Debug, Clone)]
+pub struct TextureMapping {
+ /// Handle to the image global variable.
+ pub texture: Handle<crate::GlobalVariable>,
+ /// Handle to the associated sampler global variable, if it exists.
+ pub sampler: Option<Handle<crate::GlobalVariable>>,
+}
+
+/// All information to bind a single uniform value to the shader.
+///
+/// Push constants are emulated using traditional uniforms in OpenGL.
+///
+/// These are composed of a set of primitives (scalar, vector, matrix) that
+/// are given names. Because they are not backed by the concept of a buffer,
+/// we must do the work of calculating the offset of each primitive in the
+/// push constant block.
+#[derive(Debug, Clone)]
+pub struct PushConstantItem {
+ /// GL uniform name for the item. This name is the same as if you were
+ /// to access it directly from a GLSL shader.
+ ///
+ /// Given the example below, the following names will be generated,
+ /// one name per GLSL uniform.
+ ///
+ /// ```glsl
+ /// struct InnerStruct {
+ /// float value;
+ /// }
+ ///
+ /// struct PushConstant {
+ /// InnerStruct inner;
+ /// vec4 array[2];
+ /// }
+ ///
+ /// uniform PushConstants _push_constant_binding_cs;
+ /// ```
+ ///
+ /// ```text
+ /// - _push_constant_binding_cs.inner.value
+ /// - _push_constant_binding_cs.array[0]
+ /// - _push_constant_binding_cs.array[1]
+ /// ```
+ ///
+ pub access_path: String,
+ /// Type of the uniform. This will only ever be a scalar, vector, or matrix.
+ pub ty: Handle<crate::Type>,
+ /// The offset in the push constant memory block this uniform maps to.
+ ///
+ /// The size of the uniform can be derived from the type.
+ pub offset: u32,
+}
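+
+ // A hypothetical illustration (not part of the original source) of the
+ // offsets described above: assuming std430-style alignment (vec4 aligned
+ // to 16 bytes), the doc-comment example would place `inner.value` at
+ // offset 0, `array[0]` at 16 and `array[1]` at 32. The exact rules live
+ // in the layout code, not in this struct.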
+
+ /// Helper structure that generates unique numbers
+#[derive(Default)]
+struct IdGenerator(u32);
+
+impl IdGenerator {
+ /// Generates a number that's guaranteed to be unique for this `IdGenerator`
+ fn generate(&mut self) -> u32 {
+ // It's just an increasing number but it does the job
+ let ret = self.0;
+ self.0 += 1;
+ ret
+ }
+}
+
+/// Assorted options needed for generating varyings.
+#[derive(Clone, Copy)]
+struct VaryingOptions {
+ output: bool,
+ targeting_webgl: bool,
+ draw_parameters: bool,
+}
+
+impl VaryingOptions {
+ const fn from_writer_options(options: &Options, output: bool) -> Self {
+ Self {
+ output,
+ targeting_webgl: options.version.is_webgl(),
+ draw_parameters: options.writer_flags.contains(WriterFlags::DRAW_PARAMETERS),
+ }
+ }
+}
+
+/// Helper wrapper used to get a name for a varying
+///
+ /// Varyings have different naming schemes depending on their binding:
+ /// - Varyings with builtin bindings get their name from [`glsl_built_in`].
+/// - Varyings with location bindings are named `_S_location_X` where `S` is a
+/// prefix identifying which pipeline stage the varying connects, and `X` is
+/// the location.
+struct VaryingName<'a> {
+ binding: &'a crate::Binding,
+ stage: ShaderStage,
+ options: VaryingOptions,
+}
+impl fmt::Display for VaryingName<'_> {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match *self.binding {
+ crate::Binding::Location {
+ second_blend_source: true,
+ ..
+ } => {
+ write!(f, "_fs2p_location1",)
+ }
+ crate::Binding::Location { location, .. } => {
+ let prefix = match (self.stage, self.options.output) {
+ (ShaderStage::Compute, _) => unreachable!(),
+ // pipeline to vertex
+ (ShaderStage::Vertex, false) => "p2vs",
+ // vertex to fragment
+ (ShaderStage::Vertex, true) | (ShaderStage::Fragment, false) => "vs2fs",
+ // fragment to pipeline
+ (ShaderStage::Fragment, true) => "fs2p",
+ };
+ write!(f, "_{prefix}_location{location}",)
+ }
+ crate::Binding::BuiltIn(built_in) => {
+ write!(f, "{}", glsl_built_in(built_in, self.options))
+ }
+ }
+ }
+}
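+
+ // A sketch of the names produced above: a `Location { location: 0, .. }`
+ // binding between the vertex and fragment stages is written as
+ // `_vs2fs_location0`, while a fragment output using the second blend
+ // source is always `_fs2p_location1`.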
+
+impl ShaderStage {
+ const fn to_str(self) -> &'static str {
+ match self {
+ ShaderStage::Compute => "cs",
+ ShaderStage::Fragment => "fs",
+ ShaderStage::Vertex => "vs",
+ }
+ }
+}
+
+/// Shorthand result used internally by the backend
+type BackendResult<T = ()> = Result<T, Error>;
+
+/// A GLSL compilation error.
+#[derive(Debug, Error)]
+pub enum Error {
+ /// An error occurred while writing to the output.
+ #[error("Format error")]
+ FmtError(#[from] FmtError),
+ /// The specified [`Version`] doesn't have all required [`Features`].
+ ///
+ /// Contains the missing [`Features`].
+ #[error("The selected version doesn't support {0:?}")]
+ MissingFeatures(Features),
+ /// [`AddressSpace::PushConstant`](crate::AddressSpace::PushConstant) was used more than
+ /// once in the entry point, which isn't supported.
+ #[error("Multiple push constants aren't supported")]
+ MultiplePushConstants,
+ /// The specified [`Version`] isn't supported.
+ #[error("The specified version isn't supported")]
+ VersionNotSupported,
+ /// The entry point couldn't be found.
+ #[error("The requested entry point couldn't be found")]
+ EntryPointNotFound,
+ /// A call was made to an unsupported external.
+ #[error("A call was made to an unsupported external: {0}")]
+ UnsupportedExternal(String),
+ /// A scalar with an unsupported width was requested.
+ #[error("A scalar with an unsupported width was requested: {0:?}")]
+ UnsupportedScalar(crate::Scalar),
+ /// An image was used with multiple samplers, which isn't supported.
+ #[error("An image was used with multiple samplers")]
+ ImageMultipleSamplers,
+ #[error("{0}")]
+ Custom(String),
+}
+
+/// Binary operation with a different logic on the GLSL side.
+enum BinaryOperation {
+ /// Vector comparison should use the function like `greaterThan()`, etc.
+ VectorCompare,
+ /// Vector component wise operation; used to polyfill unsupported ops like `|` and `&` for `bvecN`'s
+ VectorComponentWise,
+ /// GLSL `%` is SPIR-V `OpUMod/OpSMod` and `mod()` is `OpFMod`, but [`BinaryOperator::Modulo`](crate::BinaryOperator::Modulo) is `OpFRem`.
+ Modulo,
+ /// Any plain operation. No additional logic required.
+ Other,
+}
+
+/// Writer responsible for all code generation.
+pub struct Writer<'a, W> {
+ // Inputs
+ /// The module being written.
+ module: &'a crate::Module,
+ /// The module analysis.
+ info: &'a valid::ModuleInfo,
+ /// The output writer.
+ out: W,
+ /// User defined configuration to be used.
+ options: &'a Options,
+ /// The bounds checking policies to be used
+ policies: proc::BoundsCheckPolicies,
+
+ // Internal State
+ /// Features manager used to store all the needed features and write them.
+ features: FeaturesManager,
+ namer: proc::Namer,
+ /// A map with all the names needed for writing the module
+ /// (generated by a [`Namer`](crate::proc::Namer)).
+ names: crate::FastHashMap<NameKey, String>,
+ /// A map with the names of global variables needed for reflections.
+ reflection_names_globals: crate::FastHashMap<Handle<crate::GlobalVariable>, String>,
+ /// The selected entry point.
+ entry_point: &'a crate::EntryPoint,
+ /// The index of the selected entry point.
+ entry_point_idx: proc::EntryPointIndex,
+ /// A generator for unique block numbers.
+ block_id: IdGenerator,
+ /// Set of expressions that have associated temporary variables.
+ named_expressions: crate::NamedExpressions,
+ /// Set of expressions that need to be baked to avoid unnecessary repetition in output
+ need_bake_expressions: back::NeedBakeExpressions,
+ /// How many views to render to, if doing multiview rendering.
+ multiview: Option<std::num::NonZeroU32>,
+ /// Mapping of varying variables to their location. Needed for reflections.
+ varying: crate::FastHashMap<String, VaryingLocation>,
+}
+
+impl<'a, W: Write> Writer<'a, W> {
+ /// Creates a new [`Writer`] instance.
+ ///
+ /// # Errors
+ /// - If the version specified is invalid or not supported.
+ /// - If the entry point couldn't be found in the module.
+ /// - If the version specified doesn't support some used features.
+ pub fn new(
+ out: W,
+ module: &'a crate::Module,
+ info: &'a valid::ModuleInfo,
+ options: &'a Options,
+ pipeline_options: &'a PipelineOptions,
+ policies: proc::BoundsCheckPolicies,
+ ) -> Result<Self, Error> {
+ // Check if the requested version is supported
+ if !options.version.is_supported() {
+ log::error!("Version {}", options.version);
+ return Err(Error::VersionNotSupported);
+ }
+
+ // Try to find the entry point and corresponding index
+ let ep_idx = module
+ .entry_points
+ .iter()
+ .position(|ep| {
+ pipeline_options.shader_stage == ep.stage && pipeline_options.entry_point == ep.name
+ })
+ .ok_or(Error::EntryPointNotFound)?;
+
+ // Generate a map with names required to write the module
+ let mut names = crate::FastHashMap::default();
+ let mut namer = proc::Namer::default();
+ namer.reset(
+ module,
+ keywords::RESERVED_KEYWORDS,
+ &[],
+ &[],
+ &[
+ "gl_", // all GL built-in variables
+ "_group", // all normal bindings
+ "_push_constant_binding_", // all push constant bindings
+ ],
+ &mut names,
+ );
+
+ // Build the instance
+ let mut this = Self {
+ module,
+ info,
+ out,
+ options,
+ policies,
+
+ namer,
+ features: FeaturesManager::new(),
+ names,
+ reflection_names_globals: crate::FastHashMap::default(),
+ entry_point: &module.entry_points[ep_idx],
+ entry_point_idx: ep_idx as u16,
+ multiview: pipeline_options.multiview,
+ block_id: IdGenerator::default(),
+ named_expressions: Default::default(),
+ need_bake_expressions: Default::default(),
+ varying: Default::default(),
+ };
+
+ // Find all features required to print this module
+ this.collect_required_features()?;
+
+ Ok(this)
+ }
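+
+ // Illustrative usage (a sketch; assumes `module` has been validated and
+ // `info` produced by naga's validator, and that `BoundsCheckPolicies`
+ // implements `Default`):
+ //
+ // let mut output = String::new();
+ // let mut writer = Writer::new(
+ //     &mut output,
+ //     &module,
+ //     &info,
+ //     &Options::default(),
+ //     &PipelineOptions {
+ //         shader_stage: ShaderStage::Fragment,
+ //         entry_point: "main".to_string(),
+ //         multiview: None,
+ //     },
+ //     proc::BoundsCheckPolicies::default(),
+ // )?;
+ // let reflection_info = writer.write()?;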
+
+ /// Writes the [`Module`](crate::Module) as glsl to the output
+ ///
+ /// # Notes
+ /// If an error occurs while writing, the output might have been written partially
+ ///
+ /// # Panics
+ /// Might panic if the module is invalid
+ pub fn write(&mut self) -> Result<ReflectionInfo, Error> {
+ // We use `writeln!(self.out)` throughout the write to add newlines
+ // to make the output more readable
+
+ let es = self.options.version.is_es();
+
+ // Write the version (It must be the first thing or it isn't a valid glsl output)
+ writeln!(self.out, "#version {}", self.options.version)?;
+ // Write all the needed extensions
+ //
+ // This used to be the last thing written, since that allowed us to search for features
+ // while writing the module, saving some loops. But some older versions (420 or less)
+ // require the extensions to appear before being used, even though extensions are part
+ // of the preprocessor, not the processor ¯\_(ツ)_/¯
+ self.features.write(self.options, &mut self.out)?;
+
+ // Write the additional extensions
+ if self
+ .options
+ .writer_flags
+ .contains(WriterFlags::TEXTURE_SHADOW_LOD)
+ {
+ // https://www.khronos.org/registry/OpenGL/extensions/EXT/EXT_texture_shadow_lod.txt
+ writeln!(self.out, "#extension GL_EXT_texture_shadow_lod : require")?;
+ }
+
+ // glsl es requires a precision to be specified for floats and ints
+ // TODO: Should this be user configurable?
+ if es {
+ writeln!(self.out)?;
+ writeln!(self.out, "precision highp float;")?;
+ writeln!(self.out, "precision highp int;")?;
+ writeln!(self.out)?;
+ }
+
+ if self.entry_point.stage == ShaderStage::Compute {
+ let workgroup_size = self.entry_point.workgroup_size;
+ writeln!(
+ self.out,
+ "layout(local_size_x = {}, local_size_y = {}, local_size_z = {}) in;",
+ workgroup_size[0], workgroup_size[1], workgroup_size[2]
+ )?;
+ writeln!(self.out)?;
+ }
+
+ if self.entry_point.stage == ShaderStage::Vertex
+ && !self
+ .options
+ .writer_flags
+ .contains(WriterFlags::DRAW_PARAMETERS)
+ && self.features.contains(Features::INSTANCE_INDEX)
+ {
+ writeln!(self.out, "uniform uint {FIRST_INSTANCE_BINDING};")?;
+ writeln!(self.out)?;
+ }
+
+ // Enable early depth tests if needed
+ if let Some(depth_test) = self.entry_point.early_depth_test {
+ // If early depth test is supported for this version of GLSL
+ if self.options.version.supports_early_depth_test() {
+ writeln!(self.out, "layout(early_fragment_tests) in;")?;
+
+ if let Some(conservative) = depth_test.conservative {
+ use crate::ConservativeDepth as Cd;
+
+ let depth = match conservative {
+ Cd::GreaterEqual => "greater",
+ Cd::LessEqual => "less",
+ Cd::Unchanged => "unchanged",
+ };
+ writeln!(self.out, "layout (depth_{depth}) out float gl_FragDepth;")?;
+ }
+ writeln!(self.out)?;
+ } else {
+ log::warn!(
+ "Early depth testing is not supported for this version of GLSL: {}",
+ self.options.version
+ );
+ }
+ }
+
+ if self.entry_point.stage == ShaderStage::Vertex && self.options.version.is_webgl() {
+ if let Some(multiview) = self.multiview.as_ref() {
+ writeln!(self.out, "layout(num_views = {multiview}) in;")?;
+ writeln!(self.out)?;
+ }
+ }
+
+ // Write struct types.
+ //
+ // These are always ordered because the IR is structured in a way that
+ // you can't make a struct without adding all of its members first.
+ for (handle, ty) in self.module.types.iter() {
+ if let TypeInner::Struct { ref members, .. } = ty.inner {
+ // Structures ending with runtime-sized arrays can only be
+ // rendered as shader storage blocks in GLSL, not stand-alone
+ // struct types.
+ if !self.module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&self.module.types)
+ {
+ let name = &self.names[&NameKey::Type(handle)];
+ write!(self.out, "struct {name} ")?;
+ self.write_struct_body(handle, members)?;
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+
+ // Write functions to create special types.
+ for (type_key, struct_ty) in self.module.special_types.predeclared_types.iter() {
+ match type_key {
+ &crate::PredeclaredType::ModfResult { size, width }
+ | &crate::PredeclaredType::FrexpResult { size, width } => {
+ let arg_type_name_owner;
+ let arg_type_name = if let Some(size) = size {
+ arg_type_name_owner =
+ format!("{}vec{}", if width == 8 { "d" } else { "" }, size as u8);
+ &arg_type_name_owner
+ } else if width == 8 {
+ "double"
+ } else {
+ "float"
+ };
+
+ let other_type_name_owner;
+ let (defined_func_name, called_func_name, other_type_name) =
+ if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
+ (MODF_FUNCTION, "modf", arg_type_name)
+ } else {
+ let other_type_name = if let Some(size) = size {
+ other_type_name_owner = format!("ivec{}", size as u8);
+ &other_type_name_owner
+ } else {
+ "int"
+ };
+ (FREXP_FUNCTION, "frexp", other_type_name)
+ };
+
+ let struct_name = &self.names[&NameKey::Type(*struct_ty)];
+
+ writeln!(self.out)?;
+ if !self.options.version.supports_frexp_function()
+ && matches!(type_key, &crate::PredeclaredType::FrexpResult { .. })
+ {
+ writeln!(
+ self.out,
+ "{struct_name} {defined_func_name}({arg_type_name} arg) {{
+ {other_type_name} other = arg == {arg_type_name}(0) ? {other_type_name}(0) : {other_type_name}({arg_type_name}(1) + log2(arg));
+ {arg_type_name} fract = arg * exp2({arg_type_name}(-other));
+ return {struct_name}(fract, other);
+}}",
+ )?;
+ } else {
+ writeln!(
+ self.out,
+ "{struct_name} {defined_func_name}({arg_type_name} arg) {{
+ {other_type_name} other;
+ {arg_type_name} fract = {called_func_name}(arg, other);
+ return {struct_name}(fract, other);
+}}",
+ )?;
+ }
+ }
+ &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
+ }
+ }
+
+ // Write all named constants
+ let mut constants = self
+ .module
+ .constants
+ .iter()
+ .filter(|&(_, c)| c.name.is_some())
+ .peekable();
+ while let Some((handle, _)) = constants.next() {
+ self.write_global_constant(handle)?;
+ // Add extra newline for readability on last iteration
+ if constants.peek().is_none() {
+ writeln!(self.out)?;
+ }
+ }
+
+ let ep_info = self.info.get_entry_point(self.entry_point_idx as usize);
+
+ // Write the globals
+ //
+ // Unless explicitly disabled with WriterFlags::INCLUDE_UNUSED_ITEMS,
+ // we filter out all globals that aren't used by the selected entry point, as they might
+ // interfere with each other (i.e. two globals with the same location but
+ // different classes)
+ let include_unused = self
+ .options
+ .writer_flags
+ .contains(WriterFlags::INCLUDE_UNUSED_ITEMS);
+ for (handle, global) in self.module.global_variables.iter() {
+ let is_unused = ep_info[handle].is_empty();
+ if !include_unused && is_unused {
+ continue;
+ }
+
+ match self.module.types[global.ty].inner {
+ // We treat images separately because they might require
+ // writing the storage format
+ TypeInner::Image {
+ mut dim,
+ arrayed,
+ class,
+ } => {
+ // Gather the storage format if needed
+ let storage_format_access = match self.module.types[global.ty].inner {
+ TypeInner::Image {
+ class: crate::ImageClass::Storage { format, access },
+ ..
+ } => Some((format, access)),
+ _ => None,
+ };
+
+ if dim == crate::ImageDimension::D1 && es {
+ dim = crate::ImageDimension::D2
+ }
+
+ // Gather the location if needed
+ let layout_binding = if self.options.version.supports_explicit_locations() {
+ let br = global.binding.as_ref().unwrap();
+ self.options.binding_map.get(br).cloned()
+ } else {
+ None
+ };
+
+ // Write all the layout qualifiers
+ if layout_binding.is_some() || storage_format_access.is_some() {
+ write!(self.out, "layout(")?;
+ if let Some(binding) = layout_binding {
+ write!(self.out, "binding = {binding}")?;
+ }
+ if let Some((format, _)) = storage_format_access {
+ let format_str = glsl_storage_format(format)?;
+ let separator = match layout_binding {
+ Some(_) => ",",
+ None => "",
+ };
+ write!(self.out, "{separator}{format_str}")?;
+ }
+ write!(self.out, ") ")?;
+ }
+
+ if let Some((_, access)) = storage_format_access {
+ self.write_storage_access(access)?;
+ }
+
+ // All images in glsl are `uniform`
+ // The trailing space is important
+ write!(self.out, "uniform ")?;
+
+ // write the type
+ //
+ // This is why we need the leading space because `write_image_type` doesn't add
+ // any spaces at the beginning or end
+ self.write_image_type(dim, arrayed, class)?;
+
+ // Finally write the name and end the global with a `;`
+ // The leading space is important
+ let global_name = self.get_global_name(handle, global);
+ writeln!(self.out, " {global_name};")?;
+ writeln!(self.out)?;
+
+ self.reflection_names_globals.insert(handle, global_name);
+ }
+ // glsl has no concept of separate samplers so we just ignore them
+ TypeInner::Sampler { .. } => continue,
+ // All other globals are written by `write_global`
+ _ => {
+ self.write_global(handle, global)?;
+ // Add a newline (only for readability)
+ writeln!(self.out)?;
+ }
+ }
+ }
+
+ for arg in self.entry_point.function.arguments.iter() {
+ self.write_varying(arg.binding.as_ref(), arg.ty, false)?;
+ }
+ if let Some(ref result) = self.entry_point.function.result {
+ self.write_varying(result.binding.as_ref(), result.ty, true)?;
+ }
+ writeln!(self.out)?;
+
+ // Write all regular functions
+ for (handle, function) in self.module.functions.iter() {
+ // Check that the function doesn't use globals that aren't supported
+ // by the current entry point
+ if !include_unused && !ep_info.dominates_global_use(&self.info[handle]) {
+ continue;
+ }
+
+ let fun_info = &self.info[handle];
+
+ // Skip functions that are not compatible with this entry point's stage.
+ //
+ // When validation is enabled, it rejects modules whose entry points try to call
+ // incompatible functions, so if we got this far, then any functions incompatible
+ // with our selected entry point must not be used.
+ //
+ // When validation is disabled, `fun_info.available_stages` is always just
+ // `ShaderStages::all()`, so this will write all functions in the module, and
+ // the downstream GLSL compiler will catch any problems.
+ if !fun_info.available_stages.contains(ep_info.available_stages) {
+ continue;
+ }
+
+ // Write the function
+ self.write_function(back::FunctionType::Function(handle), function, fun_info)?;
+
+ writeln!(self.out)?;
+ }
+
+ self.write_function(
+ back::FunctionType::EntryPoint(self.entry_point_idx),
+ &self.entry_point.function,
+ ep_info,
+ )?;
+
+ // Add newline at the end of file
+ writeln!(self.out)?;
+
+ // Collect all reflection info and return it to the user
+ self.collect_reflection_info()
+ }
+
+ fn write_array_size(
+ &mut self,
+ base: Handle<crate::Type>,
+ size: crate::ArraySize,
+ ) -> BackendResult {
+ write!(self.out, "[")?;
+
+ // Write the array size
+ // Writes nothing if `ArraySize::Dynamic`
+ match size {
+ crate::ArraySize::Constant(size) => {
+ write!(self.out, "{size}")?;
+ }
+ crate::ArraySize::Dynamic => (),
+ }
+
+ write!(self.out, "]")?;
+
+ if let TypeInner::Array {
+ base: next_base,
+ size: next_size,
+ ..
+ } = self.module.types[base].inner
+ {
+ self.write_array_size(next_base, next_size)?;
+ }
+
+ Ok(())
+ }
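+
+ // A sketch of the recursion above: for a type like
+ // `array<array<f32, 2>, 4>` this writes `[4]` and then recurses into the
+ // base type, producing `[4][2]` overall, while a runtime-sized array
+ // contributes an empty `[]`.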
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_value_type(&mut self, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ // Scalars are simple: we just get the full name from `glsl_scalar`
+ TypeInner::Scalar(scalar)
+ | TypeInner::Atomic(scalar)
+ | TypeInner::ValuePointer {
+ size: None,
+ scalar,
+ space: _,
+ } => write!(self.out, "{}", glsl_scalar(scalar)?.full)?,
+ // Vectors are just `gvecN` where `g` is the scalar prefix and `N` is the vector size
+ TypeInner::Vector { size, scalar }
+ | TypeInner::ValuePointer {
+ size: Some(size),
+ scalar,
+ space: _,
+ } => write!(self.out, "{}vec{}", glsl_scalar(scalar)?.prefix, size as u8)?,
+ // Matrices are written with `gmatMxN` where `g` is the scalar prefix (only floats and
+ // doubles are allowed), `M` is the columns count and `N` is the rows count
+ //
+ // glsl supports a matrix shorthand `gmatN` where `N` = `M` but it doesn't justify the
+ // extra branch to write matrices this way
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => write!(
+ self.out,
+ "{}mat{}x{}",
+ glsl_scalar(scalar)?.prefix,
+ columns as u8,
+ rows as u8
+ )?,
+ // GLSL arrays are written as `type name[size]`
+ // Here we only write the size of the array i.e. `[size]`
+ // Base `type` and `name` should be written outside
+ TypeInner::Array { base, size, .. } => self.write_array_size(base, size)?,
+ // Write all variants instead of `_` so that a non-exhaustiveness error is
+ // thrown if new variants are added
+ TypeInner::Pointer { .. }
+ | TypeInner::Struct { .. }
+ | TypeInner::Image { .. }
+ | TypeInner::Sampler { .. }
+ | TypeInner::AccelerationStructure
+ | TypeInner::RayQuery
+ | TypeInner::BindingArray { .. } => {
+ return Err(Error::Custom(format!("Unable to write type {inner:?}")))
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_type(&mut self, ty: Handle<crate::Type>) -> BackendResult {
+ match self.module.types[ty].inner {
+ // glsl has no pointer types so just write types as normal and loads are skipped
+ TypeInner::Pointer { base, .. } => self.write_type(base),
+ // glsl structs are written as just the struct name
+ TypeInner::Struct { .. } => {
+ // Get the struct name
+ let name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{name}")?;
+ Ok(())
+ }
+ // glsl array has the size separated from the base type
+ TypeInner::Array { base, .. } => self.write_type(base),
+ ref other => self.write_value_type(other),
+ }
+ }
+
+ /// Helper method to write an image type
+ ///
+ /// # Notes
+ /// Adds no leading or trailing whitespace
+ fn write_image_type(
+ &mut self,
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ ) -> BackendResult {
+ // glsl images consist of four parts: the scalar prefix, the image "type", the dimensions,
+ // and modifiers
+ //
+ // There exist two image types
+ // - sampler - for sampled images
+ // - image - for storage images
+ //
+ // There are three possible modifiers that can be used together and must be written in
+ // this order to be valid
+ // - MS - used if it's a multisampled image
+ // - Array - used if it's an image array
+ // - Shadow - used if it's a depth image
+ use crate::ImageClass as Ic;
+
+ let (base, kind, ms, comparison) = match class {
+ Ic::Sampled { kind, multi: true } => ("sampler", kind, "MS", ""),
+ Ic::Sampled { kind, multi: false } => ("sampler", kind, "", ""),
+ Ic::Depth { multi: true } => ("sampler", crate::ScalarKind::Float, "MS", ""),
+ Ic::Depth { multi: false } => ("sampler", crate::ScalarKind::Float, "", "Shadow"),
+ Ic::Storage { format, .. } => ("image", format.into(), "", ""),
+ };
+
+ let precision = if self.options.version.is_es() {
+ "highp "
+ } else {
+ ""
+ };
+
+ write!(
+ self.out,
+ "{}{}{}{}{}{}{}",
+ precision,
+ glsl_scalar(crate::Scalar { kind, width: 4 })?.prefix,
+ base,
+ glsl_dimension(dim),
+ ms,
+ if arrayed { "Array" } else { "" },
+ comparison
+ )?;
+
+ Ok(())
+ }
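+
+ // A sketch of the outputs assembled above (prefix + base + dimension +
+ // modifiers): a non-multisampled arrayed depth image on an ES target is
+ // written as `highp sampler2DArrayShadow`, while a multisampled sampled
+ // float image on desktop is `sampler2DMS`.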
+
+ /// Helper method used to write non images/sampler globals
+ ///
+ /// # Notes
+ /// Adds a newline
+ ///
+ /// # Panics
+ /// If the global has type sampler
+ fn write_global(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ if self.options.version.supports_explicit_locations() {
+ if let Some(ref br) = global.binding {
+ match self.options.binding_map.get(br) {
+ Some(binding) => {
+ let layout = match global.space {
+ crate::AddressSpace::Storage { .. } => {
+ if self.options.version.supports_std430_layout() {
+ "std430, "
+ } else {
+ "std140, "
+ }
+ }
+ crate::AddressSpace::Uniform => "std140, ",
+ _ => "",
+ };
+ write!(self.out, "layout({layout}binding = {binding}) ")?
+ }
+ None => {
+ log::debug!("unassigned binding for {:?}", global.name);
+ if let crate::AddressSpace::Storage { .. } = global.space {
+ if self.options.version.supports_std430_layout() {
+ write!(self.out, "layout(std430) ")?
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if let crate::AddressSpace::Storage { access } = global.space {
+ self.write_storage_access(access)?;
+ }
+
+ if let Some(storage_qualifier) = glsl_storage_qualifier(global.space) {
+ write!(self.out, "{storage_qualifier} ")?;
+ }
+
+ match global.space {
+ crate::AddressSpace::Private
+ | crate::AddressSpace::WorkGroup
+ | crate::AddressSpace::PushConstant => {
+ self.write_simple_global(handle, global)?;
+ }
+ crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => {
+ self.write_interface_block(handle, global)?;
+ }
+ // A global variable in the `Function` address space is a
+ // contradiction in terms.
+ crate::AddressSpace::Function => unreachable!(),
+ // Textures and samplers are handled directly in `Writer::write`.
+ crate::AddressSpace::Handle => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ fn write_simple_global(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ self.write_type(global.ty)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+
+ if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+
+ if global.space.initializable() && is_value_init_supported(self.module, global.ty) {
+ write!(self.out, " = ")?;
+ if let Some(init) = global.init {
+ self.write_const_expr(init)?;
+ } else {
+ self.write_zero_init_value(global.ty)?;
+ }
+ }
+
+ writeln!(self.out, ";")?;
+
+ if let crate::AddressSpace::PushConstant = global.space {
+ let global_name = self.get_global_name(handle, global);
+ self.reflection_names_globals.insert(handle, global_name);
+ }
+
+ Ok(())
+ }
+
+ /// Write an interface block for a single Naga global.
+ ///
+ /// Write `block_name { members }`. Since `block_name` must be unique
+ /// between blocks and structs, we add `_block_ID` where `ID` is an
+ /// `IdGenerator`-generated number. Write `members` in the same way we write
+ /// a struct's members.
+ fn write_interface_block(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ // Write the block name, it's just the struct name appended with `_block_ID`
+ let ty_name = &self.names[&NameKey::Type(global.ty)];
+ let block_name = format!(
+ "{}_block_{}{:?}",
+ // avoid double underscores as they are reserved in GLSL
+ ty_name.trim_end_matches('_'),
+ self.block_id.generate(),
+ self.entry_point.stage,
+ );
+ write!(self.out, "{block_name} ")?;
+ self.reflection_names_globals.insert(handle, block_name);
+
+ match self.module.types[global.ty].inner {
+ crate::TypeInner::Struct { ref members, .. }
+ if self.module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&self.module.types) =>
+ {
+ // Structs with dynamically sized arrays must have their
+ // members lifted up as members of the interface block. GLSL
+ // can't write such struct types anyway.
+ self.write_struct_body(global.ty, members)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+ }
+ _ => {
+ // A global of any other type is written as the sole member
+ // of the interface block. Since the interface block is
+ // anonymous, this becomes visible in the global scope.
+ write!(self.out, "{{ ")?;
+ self.write_type(global.ty)?;
+ write!(self.out, " ")?;
+ self.write_global_name(handle, global)?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[global.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, "; }}")?;
+ }
+ }
+
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// Clears the `need_bake_expressions` set before adding to it
+ fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+ for (fun_handle, expr) in func.expressions.iter() {
+ let expr_info = &info[fun_handle];
+ let min_ref_count = func.expressions[fun_handle].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(fun_handle);
+ }
+
+ let inner = expr_info.ty.inner_with(&self.module.types);
+
+ if let Expression::Math { fun, arg, arg1, .. } = *expr {
+ match fun {
+ crate::MathFunction::Dot => {
+ // if the expression is a Dot product with integer arguments,
+ // then the arguments need baking as well
+ if let TypeInner::Scalar(crate::Scalar { kind, .. }) = *inner {
+ match kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ }
+ _ => {}
+ }
+ }
+ }
+ crate::MathFunction::CountLeadingZeros => {
+ if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ self.need_bake_expressions.insert(arg);
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+
+ /// Helper method used to get a name for a global
+ ///
+ /// Globals have different naming schemes depending on their binding:
+ /// - Globals without bindings use the name from the [`Namer`](crate::proc::Namer)
+ /// - Globals with resource binding are named `_group_X_binding_Y` where `X`
+ /// is the group and `Y` is the binding
+ fn get_global_name(
+ &self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> String {
+ match (&global.binding, global.space) {
+ (&Some(ref br), _) => {
+ format!(
+ "_group_{}_binding_{}_{}",
+ br.group,
+ br.binding,
+ self.entry_point.stage.to_str()
+ )
+ }
+ (&None, crate::AddressSpace::PushConstant) => {
+ format!("_push_constant_binding_{}", self.entry_point.stage.to_str())
+ }
+ (&None, _) => self.names[&NameKey::GlobalVariable(handle)].clone(),
+ }
+ }
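+
+ // A sketch of the names produced above: a global bound to group 0,
+ // binding 1 in a fragment entry point becomes `_group_0_binding_1_fs`,
+ // and an unbound push constant becomes `_push_constant_binding_fs`.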
+
+ /// Helper method used to write a name for a global without additional heap allocation
+ fn write_global_name(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ global: &crate::GlobalVariable,
+ ) -> BackendResult {
+ match (&global.binding, global.space) {
+ (&Some(ref br), _) => write!(
+ self.out,
+ "_group_{}_binding_{}_{}",
+ br.group,
+ br.binding,
+ self.entry_point.stage.to_str()
+ )?,
+ (&None, crate::AddressSpace::PushConstant) => write!(
+ self.out,
+ "_push_constant_binding_{}",
+ self.entry_point.stage.to_str()
+ )?,
+ (&None, _) => write!(
+ self.out,
+ "{}",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?,
+ }
+
+ Ok(())
+ }
+
+ /// Write a GLSL global that will carry a Naga entry point's argument or return value.
+ ///
+ /// A Naga entry point's arguments and return value are rendered in GLSL as
+ /// variables at global scope with the `in` and `out` storage qualifiers.
+ /// The code we generate for `main` loads from all the `in` globals into
+ /// appropriately named locals. Before it returns, `main` assigns the
+ /// components of its return value into all the `out` globals.
+ ///
+ /// This function writes a declaration for one such GLSL global,
+ /// representing a value passed into or returned from [`self.entry_point`]
+ /// that has a [`Location`] binding. The global's name is generated based on
+ /// the location index and the shader stages being connected; see
+ /// [`VaryingName`]. This means we don't need to know the names of
+ /// arguments, just their types and bindings.
+ ///
+ /// Emit nothing for entry point arguments or return values with [`BuiltIn`]
+ /// bindings; `main` will read from or assign to the appropriate GLSL
+ /// special variable; these are pre-declared. As an exception, we do declare
+ /// `gl_Position` or `gl_FragCoord` with the `invariant` qualifier if
+ /// needed.
+ ///
+ /// Use `output` together with [`self.entry_point.stage`] to determine which
+ /// shader stages are being connected, and choose the `in` or `out` storage
+ /// qualifier.
+ ///
+ /// [`self.entry_point`]: Writer::entry_point
+ /// [`self.entry_point.stage`]: crate::EntryPoint::stage
+ /// [`Location`]: crate::Binding::Location
+ /// [`BuiltIn`]: crate::Binding::BuiltIn
+ fn write_varying(
+ &mut self,
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ output: bool,
+ ) -> Result<(), Error> {
+ // For a struct, emit a separate global for each member with a binding.
+ if let crate::TypeInner::Struct { ref members, .. } = self.module.types[ty].inner {
+ for member in members {
+ self.write_varying(member.binding.as_ref(), member.ty, output)?;
+ }
+ return Ok(());
+ }
+
+ let binding = match binding {
+ None => return Ok(()),
+ Some(binding) => binding,
+ };
+
+ let (location, interpolation, sampling, second_blend_source) = match *binding {
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source,
+ } => (location, interpolation, sampling, second_blend_source),
+ crate::Binding::BuiltIn(built_in) => {
+ if let crate::BuiltIn::Position { invariant: true } = built_in {
+ match (self.options.version, self.entry_point.stage) {
+ (
+ Version::Embedded {
+ version: 300,
+ is_webgl: true,
+ },
+ ShaderStage::Fragment,
+ ) => {
+ // `invariant gl_FragCoord` is not allowed in WebGL2 and possibly
+ // OpenGL ES in general (waiting on confirmation).
+ //
+ // See https://github.com/KhronosGroup/WebGL/issues/3518
+ }
+ _ => {
+ writeln!(
+ self.out,
+ "invariant {};",
+ glsl_built_in(
+ built_in,
+ VaryingOptions::from_writer_options(self.options, output)
+ )
+ )?;
+ }
+ }
+ }
+ return Ok(());
+ }
+ };
+
+ // Write the interpolation modifier if needed
+ //
+ // We ignore all interpolation and auxiliary modifiers that aren't used in fragment
+ // shaders' input globals or vertex shaders' output globals.
+ let emit_interpolation_and_auxiliary = match self.entry_point.stage {
+ ShaderStage::Vertex => output,
+ ShaderStage::Fragment => !output,
+ ShaderStage::Compute => false,
+ };
+
+ // Write the I/O locations, if allowed
+ let io_location = if self.options.version.supports_explicit_locations()
+ || !emit_interpolation_and_auxiliary
+ {
+ if self.options.version.supports_io_locations() {
+ if second_blend_source {
+ write!(self.out, "layout(location = {location}, index = 1) ")?;
+ } else {
+ write!(self.out, "layout(location = {location}) ")?;
+ }
+ None
+ } else {
+ Some(VaryingLocation {
+ location,
+ index: second_blend_source as u32,
+ })
+ }
+ } else {
+ None
+ };
+
+ // Write the interpolation qualifier.
+ if let Some(interp) = interpolation {
+ if emit_interpolation_and_auxiliary {
+ write!(self.out, "{} ", glsl_interpolation(interp))?;
+ }
+ }
+
+ // Write the sampling auxiliary qualifier.
+ //
+ // Before GLSL 4.2, the `centroid` and `sample` qualifiers were required to appear
+ // immediately before the `in` / `out` qualifier, so we'll just follow that rule
+ // here, regardless of the version.
+ if let Some(sampling) = sampling {
+ if emit_interpolation_and_auxiliary {
+ if let Some(qualifier) = glsl_sampling(sampling) {
+ write!(self.out, "{qualifier} ")?;
+ }
+ }
+ }
+
+ // Write the input/output qualifier.
+ write!(self.out, "{} ", if output { "out" } else { "in" })?;
+
+ // Write the type
+ // `write_type` adds no leading or trailing spaces
+ self.write_type(ty)?;
+
+ // Finally write the global name and end the global with a `;` and a newline
+ // Leading space is important
+ let vname = VaryingName {
+ binding: &crate::Binding::Location {
+ location,
+ interpolation: None,
+ sampling: None,
+ second_blend_source,
+ },
+ stage: self.entry_point.stage,
+ options: VaryingOptions::from_writer_options(self.options, output),
+ };
+ writeln!(self.out, " {vname};")?;
+
+ if let Some(location) = io_location {
+ self.varying.insert(vname.to_string(), location);
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write functions (both entry points and regular functions)
+ ///
+ /// # Notes
+ /// Adds a newline
+ fn write_function(
+ &mut self,
+ ty: back::FunctionType,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ ) -> BackendResult {
+ // Create a function context for the function being written
+ let ctx = back::FunctionCtx {
+ ty,
+ info,
+ expressions: &func.expressions,
+ named_expressions: &func.named_expressions,
+ };
+
+ self.named_expressions.clear();
+ self.update_expressions_to_bake(func, info);
+
+ // Write the function header
+ //
+ // glsl headers are the same as in c:
+ // `ret_type name(args)`
+ // `ret_type` is the return type
+ // `name` is the function name
+ // `args` is a comma separated list of `type name`
+ // | - `type` is the argument type
+ // | - `name` is the argument name
+
+ // Start by writing the return type, if any; otherwise write void
+ // This is the only place where `void` is a valid type
+ // (though it's more a keyword than a type)
+ if let back::FunctionType::EntryPoint(_) = ctx.ty {
+ write!(self.out, "void")?;
+ } else if let Some(ref result) = func.result {
+ self.write_type(result.ty)?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[result.ty].inner {
+ self.write_array_size(base, size)?
+ }
+ } else {
+ write!(self.out, "void")?;
+ }
+
+ // Write the function name and open parentheses for the argument list
+ let function_name = match ctx.ty {
+ back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
+ back::FunctionType::EntryPoint(_) => "main",
+ };
+ write!(self.out, " {function_name}(")?;
+
+ // Write the comma separated argument list
+ //
+ // We need access to `Self` here so we use the reference passed to the closure as an
+ // argument instead of capturing it, as that would cause a borrow checker error
+ let arguments = match ctx.ty {
+ back::FunctionType::EntryPoint(_) => &[][..],
+ back::FunctionType::Function(_) => &func.arguments,
+ };
+ let arguments: Vec<_> = arguments
+ .iter()
+ .enumerate()
+ .filter(|&(_, arg)| match self.module.types[arg.ty].inner {
+ TypeInner::Sampler { .. } => false,
+ _ => true,
+ })
+ .collect();
+ self.write_slice(&arguments, |this, _, &(i, arg)| {
+ // Write the argument type
+ match this.module.types[arg.ty].inner {
+ // We treat images separately because they might require
+ // writing the storage format
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ // Write the storage format if needed
+ if let TypeInner::Image {
+ class: crate::ImageClass::Storage { format, .. },
+ ..
+ } = this.module.types[arg.ty].inner
+ {
+ write!(this.out, "layout({}) ", glsl_storage_format(format)?)?;
+ }
+
+ // write the type
+ //
+ // This is why we need the leading space because `write_image_type` doesn't add
+ // any spaces at the beginning or end
+ this.write_image_type(dim, arrayed, class)?;
+ }
+ TypeInner::Pointer { base, .. } => {
+ // write parameter qualifiers
+ write!(this.out, "inout ")?;
+ this.write_type(base)?;
+ }
+ // All other types are written by `write_type`
+ _ => {
+ this.write_type(arg.ty)?;
+ }
+ }
+
+ // Write the argument name
+ // The leading space is important
+ write!(this.out, " {}", &this.names[&ctx.argument_key(i as u32)])?;
+
+ // Write array size
+ match this.module.types[arg.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ this.write_array_size(base, size)?;
+ }
+ TypeInner::Pointer { base, .. } => {
+ if let TypeInner::Array { base, size, .. } = this.module.types[base].inner {
+ this.write_array_size(base, size)?;
+ }
+ }
+ _ => {}
+ }
+
+ Ok(())
+ })?;
+
+ // Close the parentheses and open braces to start the function body
+ writeln!(self.out, ") {{")?;
+
+ if self.options.zero_initialize_workgroup_memory
+ && ctx.ty.is_compute_entry_point(self.module)
+ {
+ self.write_workgroup_variables_initialization(&ctx)?;
+ }
+
+ // Compose the function arguments from globals, in case of an entry point.
+ if let back::FunctionType::EntryPoint(ep_index) = ctx.ty {
+ let stage = self.module.entry_points[ep_index as usize].stage;
+ for (index, arg) in func.arguments.iter().enumerate() {
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(arg.ty)?;
+ let name = &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
+ write!(self.out, " {name}")?;
+ write!(self.out, " = ")?;
+ match self.module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ self.write_type(arg.ty)?;
+ write!(self.out, "(")?;
+ for (index, member) in members.iter().enumerate() {
+ let varying_name = VaryingName {
+ binding: member.binding.as_ref().unwrap(),
+ stage,
+ options: VaryingOptions::from_writer_options(self.options, false),
+ };
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{varying_name}")?;
+ }
+ writeln!(self.out, ");")?;
+ }
+ _ => {
+ let varying_name = VaryingName {
+ binding: arg.binding.as_ref().unwrap(),
+ stage,
+ options: VaryingOptions::from_writer_options(self.options, false),
+ };
+ writeln!(self.out, "{varying_name};")?;
+ }
+ }
+ }
+ }
+
+ // Write all function locals
+ // Locals are `type name (= init)?;` where the init part (including the =) is optional
+ //
+ // Always adds a newline
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability) and the type
+ // `write_type` adds no trailing space
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(local.ty)?;
+
+ // Write the local name
+ // The leading space is important
+ write!(self.out, " {}", self.names[&ctx.name_key(handle)])?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = self.module.types[local.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ // Put the equals sign only if there's an initializer
+ // The leading and trailing spaces aren't needed but help with readability
+ write!(self.out, " = ")?;
+
+ // Write the initializer
+ // `write_expr` adds no trailing or leading space/newline
+ self.write_expr(init, &ctx)?;
+ } else if is_value_init_supported(self.module, local.ty) {
+ write!(self.out, " = ")?;
+ self.write_zero_init_value(local.ty)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // Write a statement, the indentation should always be 1 when writing the function body
+ // `write_stmt` adds a newline
+ self.write_stmt(sta, &ctx, back::Level(1))?;
+ }
+
+ // Close braces and add a newline
+ writeln!(self.out, "}}")?;
+
+ Ok(())
+ }
+
+ fn write_workgroup_variables_initialization(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ let mut vars = self
+ .module
+ .global_variables
+ .iter()
+ .filter(|&(handle, var)| {
+ !ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ .peekable();
+
+ if vars.peek().is_some() {
+ let level = back::Level(1);
+
+ writeln!(self.out, "{level}if (gl_LocalInvocationID == uvec3(0u)) {{")?;
+
+ for (handle, var) in vars {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{}{} = ", level.next(), name)?;
+ self.write_zero_init_value(var.ty)?;
+ writeln!(self.out, ";")?;
+ }
+
+ writeln!(self.out, "{level}}}")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ }
+
+ Ok(())
+ }
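+
+ // A sketch of the GLSL emitted above for one workgroup variable `w`,
+ // followed by a workgroup barrier written by `write_barrier`:
+ //
+ // if (gl_LocalInvocationID == uvec3(0u)) {
+ //     w = <zero value>;
+ // }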
+
+ /// Write a list of comma separated `T` values using a writer function `F`.
+ ///
+ /// The writer function `F` receives a mutable reference to `self`, passed as an
+ /// argument to avoid the borrow checker issues a closure capturing `self` would cause; the
+ /// second argument is the 0-based index of the element in the list, and the last argument is
+ /// a reference to the element `T` being written
+ ///
+ /// # Notes
+ /// - Adds no newlines or leading/trailing whitespace
+ /// - The last element won't have a trailing `,`
+ fn write_slice<T, F: FnMut(&mut Self, u32, &T) -> BackendResult>(
+ &mut self,
+ data: &[T],
+ mut f: F,
+ ) -> BackendResult {
+ // Loop through `data` invoking `f` for each element
+ for (index, item) in data.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ f(self, index as u32, item)?;
+ }
+
+ Ok(())
+ }
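+
+ // An illustrative call (a sketch, not from the original source): writing
+ // `1, 2, 3` from a slice, with the closure receiving `self`, the index,
+ // and a reference to the item:
+ //
+ // self.write_slice(&[1u32, 2, 3], |this, _i, v| {
+ //     write!(this.out, "{v}")?;
+ //     Ok(())
+ // })?;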
+
+ /// Helper method used to write global constants
+ fn write_global_constant(&mut self, handle: Handle<crate::Constant>) -> BackendResult {
+ write!(self.out, "const ")?;
+ let constant = &self.module.constants[handle];
+ self.write_type(constant.ty)?;
+ let name = &self.names[&NameKey::Constant(handle)];
+ write!(self.out, " {name}")?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[constant.ty].inner {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_const_expr(constant.init)?;
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
+ /// Helper method used to output a dot product as an arithmetic expression
+ ///
+ fn write_dot_product(
+ &mut self,
+ arg: Handle<crate::Expression>,
+ arg1: Handle<crate::Expression>,
+ size: usize,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ // Write parentheses around the dot product expression to prevent operators
+ // with different precedences from applying earlier.
+ write!(self.out, "(")?;
+
+ // Cycle through all the components of the vector
+ for index in 0..size {
+ let component = back::COMPONENTS[index];
+ // Write the addition to the previous product
+ // This will print an extra '+' at the beginning but that is fine in glsl
+ write!(self.out, " + ")?;
+ // Write the first vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.write_expr(arg, ctx)?;
+ // Access the current component on the first vector
+ write!(self.out, ".{component} * ")?;
+ // Write the second vector expression, this expression is marked to be
+ // cached so unless it can't be cached (for example, it's a Constant)
+ // it shouldn't produce large expressions.
+ self.write_expr(arg1, ctx)?;
+ // Access the current component on the second vector
+ write!(self.out, ".{component}")?;
+ }
+
+ write!(self.out, ")")?;
+ Ok(())
+ }
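+
+ // A sketch of the emitted expression: for two `vec3` arguments `a` and
+ // `b` this produces `( + a.x * b.x + a.y * b.y + a.z * b.z)`; the leading
+ // `+` is harmless in GLSL, as the comment above notes.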
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_struct_body(
+ &mut self,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ ) -> BackendResult {
+ // glsl structs are written as in C
+ // `struct name { members };`
+ // | `struct` is a keyword
+ // | `name` is the struct name
+ // | `members` is a semicolon separated list of `type name`
+ // | `type` is the member type
+ // | `name` is the member name
+ writeln!(self.out, "{{")?;
+
+ for (idx, member) in members.iter().enumerate() {
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+
+ match self.module.types[member.ty].inner {
+ TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ self.write_type(base)?;
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, idx as u32)]
+ )?;
+ // Write [size]
+ self.write_array_size(base, size)?;
+ // Newline is important
+ writeln!(self.out, ";")?;
+ }
+ _ => {
+ // Write the member type
+ // Adds no trailing space
+ self.write_type(member.ty)?;
+
+ // Write the member name and put a semicolon
+ // The leading space is important
+ // All members must have a semicolon even the last one
+ writeln!(
+ self.out,
+ " {};",
+ &self.names[&NameKey::StructMember(handle, idx as u32)]
+ )?;
+ }
+ }
+ }
+
+ write!(self.out, "}}")?;
+ Ok(())
+ }
+
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ sta: &crate::Statement,
+ ctx: &back::FunctionCtx,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::Statement;
+
+ match *sta {
+ // This is where we can generate intermediate constants for some expression types.
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let ptr_class = ctx.resolve_type(handle, &self.module.types).pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ // GLSL can't save a pointer-valued expression in a variable,
+ // but we shouldn't ever need to: they should never be named expressions,
+ // and none of the expression types flagged by bake_ref_count can be pointer-valued.
+ None
+ } else if let Some(name) = ctx.named_expressions.get(&handle) {
+ // The frontend provides names for all variables at the start of writing,
+ // but we write them out step by step, so we need to re-cache them.
+ // Otherwise, we could accidentally write a variable name instead of the full expression.
+ // Also, we use sanitized names! This keeps the backend from generating variables named after reserved keywords.
+ Some(self.namer.call(name))
+ } else if self.need_bake_expressions.contains(&handle) {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else {
+ None
+ };
+
+ // If we are going to write an `ImageLoad` next and the target image
+ // is sampled and we are using the `Restrict` policy for bounds
+ // checking images, we need to write a local holding the clamped lod.
+ if let crate::Expression::ImageLoad {
+ image,
+ level: Some(level_expr),
+ ..
+ } = ctx.expressions[handle]
+ {
+ if let TypeInner::Image {
+ class: crate::ImageClass::Sampled { .. },
+ ..
+ } = *ctx.resolve_type(image, &self.module.types)
+ {
+ if let proc::BoundsCheckPolicy::Restrict = self.policies.image_load {
+ write!(self.out, "{level}")?;
+ self.write_clamped_lod(ctx, handle, image, level_expr)?
+ }
+ }
+ }
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.write_named_expr(handle, name, handle, ctx)?;
+ }
+ }
+ }
+ // Blocks are simple: we just need to write the block statements between braces
+ // We could also just print the statements but this is more readable and maps more
+ // closely to the IR
+ Statement::Block(ref block) => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ // Ifs are written as in C:
+ // ```
+ // if(condition) {
+ // accept
+ // } else {
+ // reject
+ // }
+ // ```
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "if (")?;
+ self.write_expr(condition, ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ // Switches are written as in C:
+ // ```
+ // switch (selector) {
+ // // Fallthrough
+ // case label:
+ // block
+ // // Non fallthrough
+ // case label:
+ // block
+ // break;
+ // default:
+ // block
+ // }
+ // ```
+ // Where the `default` case appears isn't important, but we put it last
+ // so that we don't need to print a `break` for it
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{level}")?;
+ write!(self.out, "switch(")?;
+ self.write_expr(selector, ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ // Write all cases
+ let l2 = level.next();
+ for case in cases {
+ match case.value {
+ crate::SwitchValue::I32(value) => write!(self.out, "{l2}case {value}:")?,
+ crate::SwitchValue::U32(value) => write!(self.out, "{l2}case {value}u:")?,
+ crate::SwitchValue::Default => write!(self.out, "{l2}default:")?,
+ }
+
+ let write_block_braces = !(case.fall_through && case.body.is_empty());
+ if write_block_braces {
+ writeln!(self.out, " {{")?;
+ } else {
+ writeln!(self.out)?;
+ }
+
+ for sta in case.body.iter() {
+ self.write_stmt(sta, ctx, l2.next())?;
+ }
+
+ if !case.fall_through && case.body.last().map_or(true, |s| !s.is_terminator()) {
+ writeln!(self.out, "{}break;", l2.next())?;
+ }
+
+ if write_block_braces {
+ writeln!(self.out, "{l2}}}")?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ // Loops in naga IR are based on wgsl loops; glsl can emulate the behaviour by using a
+ // `while(true)` loop with the continuing block gated at the top of the body, resulting in:
+ // ```
+ // bool loop_init = true;
+ // while(true) {
+ // if (!loop_init) { <continuing> }
+ // loop_init = false;
+ // <body>
+ // }
+ // ```
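+ // When `break_if` is present, an exit check is emitted at the end of
+ // the gated continuing block, so the prefix becomes (sketch):
+ // ```
+ // if (!loop_init) { <continuing> if (<break_if>) { break; } }
+ // ```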
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{level}bool {gate_name} = true;")?;
+ writeln!(self.out, "{level}while(true) {{")?;
+ let l2 = level.next();
+ let l3 = l2.next();
+ writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
+ for sta in continuing {
+ self.write_stmt(sta, ctx, l3)?;
+ }
+ if let Some(condition) = break_if {
+ write!(self.out, "{l3}if (")?;
+ self.write_expr(condition, ctx)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", l3.next())?;
+ writeln!(self.out, "{l3}}}")?;
+ }
+ writeln!(self.out, "{l2}}}")?;
+ writeln!(self.out, "{}{} = false;", level.next(), gate_name)?;
+ } else {
+ writeln!(self.out, "{level}while(true) {{")?;
+ }
+ for sta in body {
+ self.write_stmt(sta, ctx, level.next())?;
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ // Break, continue and return are written as in C
+ // `break;`
+ Statement::Break => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "break;")?
+ }
+ // `continue;`
+ Statement::Continue => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "continue;")?
+ }
+ // `return expr;`, `expr` is optional
+ Statement::Return { value } => {
+ write!(self.out, "{level}")?;
+ match ctx.ty {
+ back::FunctionType::Function(_) => {
+ write!(self.out, "return")?;
+ // Write the expression to be returned if needed
+ if let Some(expr) = value {
+ write!(self.out, " ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ back::FunctionType::EntryPoint(ep_index) => {
+ let mut has_point_size = false;
+ let ep = &self.module.entry_points[ep_index as usize];
+ if let Some(ref result) = ep.function.result {
+ let value = value.unwrap();
+ match self.module.types[result.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let temp_struct_name = match ctx.expressions[value] {
+ crate::Expression::Compose { .. } => {
+ let return_struct = "_tmp_return";
+ write!(
+ self.out,
+ "{} {} = ",
+ &self.names[&NameKey::Type(result.ty)],
+ return_struct
+ )?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{level}")?;
+ Some(return_struct)
+ }
+ _ => None,
+ };
+
+ for (index, member) in members.iter().enumerate() {
+ if let Some(crate::Binding::BuiltIn(
+ crate::BuiltIn::PointSize,
+ )) = member.binding
+ {
+ has_point_size = true;
+ }
+
+ let varying_name = VaryingName {
+ binding: member.binding.as_ref().unwrap(),
+ stage: ep.stage,
+ options: VaryingOptions::from_writer_options(
+ self.options,
+ true,
+ ),
+ };
+ write!(self.out, "{varying_name} = ")?;
+
+ if let Some(struct_name) = temp_struct_name {
+ write!(self.out, "{struct_name}")?;
+ } else {
+ self.write_expr(value, ctx)?;
+ }
+
+ // Write field name
+ writeln!(
+ self.out,
+ ".{};",
+ &self.names
+ [&NameKey::StructMember(result.ty, index as u32)]
+ )?;
+ write!(self.out, "{level}")?;
+ }
+ }
+ _ => {
+ let name = VaryingName {
+ binding: result.binding.as_ref().unwrap(),
+ stage: ep.stage,
+ options: VaryingOptions::from_writer_options(
+ self.options,
+ true,
+ ),
+ };
+ write!(self.out, "{name} = ")?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{level}")?;
+ }
+ }
+ }
+
+ let is_vertex_stage = self.module.entry_points[ep_index as usize].stage
+ == ShaderStage::Vertex;
+ if is_vertex_stage
+ && self
+ .options
+ .writer_flags
+ .contains(WriterFlags::ADJUST_COORDINATE_SPACE)
+ {
+ writeln!(
+ self.out,
+ "gl_Position.yz = vec2(-gl_Position.y, gl_Position.z * 2.0 - gl_Position.w);",
+ )?;
+ write!(self.out, "{level}")?;
+ }
+
+ if is_vertex_stage
+ && self
+ .options
+ .writer_flags
+ .contains(WriterFlags::FORCE_POINT_SIZE)
+ && !has_point_size
+ {
+ writeln!(self.out, "gl_PointSize = 1.0;")?;
+ write!(self.out, "{level}")?;
+ }
+ writeln!(self.out, "return;")?;
+ }
+ }
+ }
+ // This is one of the places where glsl adds to the syntax of C: the `discard`
+ // keyword ceases all further processing in a fragment shader. It's called OpKill
+ // in spir-v, which is why the IR statement is named `Statement::Kill`
+ Statement::Kill => writeln!(self.out, "{level}discard;")?,
+ Statement::Barrier(flags) => {
+ self.write_barrier(flags, level)?;
+ }
+ // Stores in glsl are just variable assignments written as `pointer = value;`
+ Statement::Store { pointer, value } => {
+ write!(self.out, "{level}")?;
+ self.write_expr(pointer, ctx)?;
+ write!(self.out, " = ")?;
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ";")?
+ }
+ Statement::WorkGroupUniformLoad { pointer, result } => {
+ // GLSL doesn't have pointers, which means that this backend needs to
+ // ensure that the actual "loading" happens between the two barriers.
+ // This is done in `Emit` by never emitting a variable name for pointer variables
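+ // Sketch of the emitted shape (names illustrative; the exact barrier
+ // call comes from `write_barrier`):
+ // ```
+ // <barrier>
+ // uint _e3 = shared_value;
+ // <barrier>
+ // ```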
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+
+ let result_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ write!(self.out, "{level}")?;
+ // Expressions cannot have side effects, so just writing the expression here is fine.
+ self.write_named_expr(pointer, result_name, result, ctx)?;
+
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ }
+ // Stores a value into an image.
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{level}")?;
+ self.write_image_store(ctx, image, coordinate, array_index, value)?
+ }
+ // A `Call` is written `name(arguments)` where `arguments` is a comma-separated list of expressions
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ let result = self.module.functions[function].result.as_ref().unwrap();
+ self.write_type(result.ty)?;
+ write!(self.out, " {name}")?;
+ if let TypeInner::Array { base, size, .. } = self.module.types[result.ty].inner
+ {
+ self.write_array_size(base, size)?
+ }
+ write!(self.out, " = ")?;
+ self.named_expressions.insert(expr, name);
+ }
+ write!(self.out, "{}(", &self.names[&NameKey::Function(function)])?;
+ let arguments: Vec<_> = arguments
+ .iter()
+ .enumerate()
+ .filter_map(|(i, arg)| {
+ let arg_ty = self.module.functions[function].arguments[i].ty;
+ match self.module.types[arg_ty].inner {
+ TypeInner::Sampler { .. } => None,
+ _ => Some(*arg),
+ }
+ })
+ .collect();
+ self.write_slice(&arguments, |this, _, arg| this.write_expr(*arg, ctx))?;
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ let res_ty = ctx.resolve_type(result, &self.module.types);
+ self.write_value_type(res_ty)?;
+ write!(self.out, " {res_name} = ")?;
+ self.named_expressions.insert(result, res_name);
+
+ let fun_str = fun.to_glsl();
+ write!(self.out, "atomic{fun_str}(")?;
+ self.write_expr(pointer, ctx)?;
+ write!(self.out, ", ")?;
+ // handle the special cases
+ match *fun {
+ crate::AtomicFunction::Subtract => {
+ // we just wrote `atomicAdd`, so negate the argument
+ write!(self.out, "-")?;
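+ // e.g. an atomic subtraction is emitted as
+ // `atomicAdd(ptr, -value)` (sketch)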
+ }
+ crate::AtomicFunction::Exchange { compare: Some(_) } => {
+ return Err(Error::Custom(
+ "atomic CompareExchange is not implemented".to_string(),
+ ));
+ }
+ _ => {}
+ }
+ self.write_expr(value, ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ Statement::RayQuery { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Write a const expression.
+ ///
+ /// Write `expr`, a handle to an [`Expression`] in the current [`Module`]'s
+ /// constant expression arena, as a GLSL expression.
+ ///
+ /// # Notes
+ /// Adds no newlines or leading/trailing whitespace
+ ///
+ /// [`Expression`]: crate::Expression
+ /// [`Module`]: crate::Module
+ fn write_const_expr(&mut self, expr: Handle<crate::Expression>) -> BackendResult {
+ self.write_possibly_const_expr(
+ expr,
+ &self.module.const_expressions,
+ |expr| &self.info[expr],
+ |writer, expr| writer.write_const_expr(expr),
+ )
+ }
+
+ /// Write [`Expression`] variants that can occur in both runtime and const expressions.
+ ///
+ /// Write `expr`, a handle to an [`Expression`] in the arena `expressions`,
+ /// as a GLSL expression. This must be one of the [`Expression`] variants
+ /// that is allowed to occur in constant expressions.
+ ///
+ /// Use `write_expression` to write subexpressions.
+ ///
+ /// This is the common code for `write_expr`, which handles arbitrary
+ /// runtime expressions, and `write_const_expr`, which only handles
+ /// const-expressions. Each of those callers passes itself (essentially) as
+ /// the `write_expression` callback, so that subexpressions are restricted
+ /// to the appropriate variants.
+ ///
+ /// # Notes
+ /// Adds no newlines or leading/trailing whitespace
+ ///
+ /// [`Expression`]: crate::Expression
+ fn write_possibly_const_expr<'w, I, E>(
+ &'w mut self,
+ expr: Handle<crate::Expression>,
+ expressions: &crate::Arena<crate::Expression>,
+ info: I,
+ write_expression: E,
+ ) -> BackendResult
+ where
+ I: Fn(Handle<crate::Expression>) -> &'w proc::TypeResolution,
+ E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
+ {
+ use crate::Expression;
+
+ match expressions[expr] {
+ Expression::Literal(literal) => {
+ match literal {
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part even when it's zero, which is needed for a valid glsl float constant
+ crate::Literal::F64(value) => write!(self.out, "{:?}LF", value)?,
+ crate::Literal::F32(value) => write!(self.out, "{:?}", value)?,
+ // Unsigned integers need a `u` at the end
+ //
+ // While `core` doesn't necessarily need it, it's allowed and `es` requires it, so we
+ // always write it; the extra branch wouldn't have any benefit in readability
+ crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
+ crate::Literal::I32(value) => write!(self.out, "{}", value)?,
+ crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
+ crate::Literal::I64(_) => {
+ return Err(Error::Custom("GLSL has no 64-bit integer type".into()));
+ }
+ crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
+ return Err(Error::Custom(
+ "Abstract types should not appear in IR presented to backends".into(),
+ ));
+ }
+ }
+ }
+ Expression::Constant(handle) => {
+ let constant = &self.module.constants[handle];
+ if constant.name.is_some() {
+ write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.write_const_expr(constant.init)?;
+ }
+ }
+ Expression::ZeroValue(ty) => {
+ self.write_zero_init_value(ty)?;
+ }
+ Expression::Compose { ty, ref components } => {
+ self.write_type(ty)?;
+
+ if let TypeInner::Array { base, size, .. } = self.module.types[ty].inner {
+ self.write_array_size(base, size)?;
+ }
+
+ write!(self.out, "(")?;
+ for (index, component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ write_expression(self, *component)?;
+ }
+ write!(self.out, ")")?
+ }
+ // `Splat` needs to actually write down a vector, it's not always inferred in GLSL.
+ Expression::Splat { size: _, value } => {
+ let resolved = info(expr).inner_with(&self.module.types);
+ self.write_value_type(resolved)?;
+ write!(self.out, "(")?;
+ write_expression(self, value)?;
+ write!(self.out, ")")?
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write expressions
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ fn write_expr(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+
+ match ctx.expressions[expr] {
+ Expression::Literal(_)
+ | Expression::Constant(_)
+ | Expression::ZeroValue(_)
+ | Expression::Compose { .. }
+ | Expression::Splat { .. } => {
+ self.write_possibly_const_expr(
+ expr,
+ ctx.expressions,
+ |expr| &ctx.info[expr].ty,
+ |writer, expr| writer.write_expr(expr, ctx),
+ )?;
+ }
+ // `Access` is applied to arrays, vectors and matrices and is written as indexing
+ Expression::Access { base, index } => {
+ self.write_expr(base, ctx)?;
+ write!(self.out, "[")?;
+ self.write_expr(index, ctx)?;
+ write!(self.out, "]")?
+ }
+ // `AccessIndex` is the same as `Access` except that the index is a constant and it
+ // can be applied to structs; in that case we need to find the name of the field at
+ // that index and write `base.field_name`
+ Expression::AccessIndex { base, index } => {
+ self.write_expr(base, ctx)?;
+
+ let base_ty_res = &ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&self.module.types);
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, space: _ } => {
+ resolved = &self.module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ match *resolved {
+ TypeInner::Vector { .. } => {
+ // Write vector access as a swizzle
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
+ TypeInner::Struct { .. } => {
+ // This will never panic if the type is a `Struct`; that is not true
+ // for other types, so we can only unwrap while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ self.out,
+ ".{}",
+ &self.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
+ }
+ }
+ // `Swizzle` adds a few letters behind the dot.
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(vector, ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ // Function arguments are written as the argument name
+ Expression::FunctionArgument(pos) => {
+ write!(self.out, "{}", &self.names[&ctx.argument_key(pos)])?
+ }
+ // Global variables need some special work for their name but
+ // `get_global_name` does the work for us
+ Expression::GlobalVariable(handle) => {
+ let global = &self.module.global_variables[handle];
+ self.write_global_name(handle, global)?
+ }
+ // A local is written as its name
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&ctx.name_key(handle)])?
+ }
+ // glsl has no pointers so there's no load operation, just write the pointer expression
+ Expression::Load { pointer } => self.write_expr(pointer, ctx)?,
+ // `ImageSample` is a bit complicated compared to the rest of the IR.
+ //
+ // First there are three variations depending on whether the sample level is
+ // explicitly set, automatic, or biased:
+ // `texture(image, coordinate)` - Automatic sample level
+ // `texture(image, coordinate, bias)` - Bias sample level
+ // `textureLod(image, coordinate, level)` - Zero or Exact sample level
+ //
+ // Furthermore if `depth_ref` is some we need to append it to the coordinate vector
+ Expression::ImageSample {
+ image,
+ sampler: _, //TODO?
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ let dim = match *ctx.resolve_type(image, &self.module.types) {
+ TypeInner::Image { dim, .. } => dim,
+ _ => unreachable!(),
+ };
+
+ if dim == crate::ImageDimension::Cube
+ && array_index.is_some()
+ && depth_ref.is_some()
+ {
+ match level {
+ crate::SampleLevel::Zero
+ | crate::SampleLevel::Exact(_)
+ | crate::SampleLevel::Gradient { .. }
+ | crate::SampleLevel::Bias(_) => {
+ return Err(Error::Custom(String::from(
+ "gsamplerCubeArrayShadow isn't supported in textureGrad, \
+ textureLod or texture with bias",
+ )))
+ }
+ crate::SampleLevel::Auto => {}
+ }
+ }
+
+ // textureLod on sampler2DArrayShadow and samplerCubeShadow does not exist in GLSL.
+ // To emulate this, we will have to use textureGrad with a constant gradient of 0.
+ let workaround_lod_array_shadow_as_grad = (array_index.is_some()
+ || dim == crate::ImageDimension::Cube)
+ && depth_ref.is_some()
+ && gather.is_none()
+ && !self
+ .options
+ .writer_flags
+ .contains(WriterFlags::TEXTURE_SHADOW_LOD);
+
+ // Write the function to be used depending on the sample level
+ let fun_name = match level {
+ crate::SampleLevel::Zero if gather.is_some() => "textureGather",
+ crate::SampleLevel::Auto | crate::SampleLevel::Bias(_) => "texture",
+ crate::SampleLevel::Zero | crate::SampleLevel::Exact(_) => {
+ if workaround_lod_array_shadow_as_grad {
+ "textureGrad"
+ } else {
+ "textureLod"
+ }
+ }
+ crate::SampleLevel::Gradient { .. } => "textureGrad",
+ };
+ let offset_name = match offset {
+ Some(_) => "Offset",
+ None => "",
+ };
+
+ write!(self.out, "{fun_name}{offset_name}(")?;
+
+ // Write the image that will be used
+ self.write_expr(image, ctx)?;
+ // The space here isn't required but it helps with readability
+ write!(self.out, ", ")?;
+
+ // We need the size of the coordinate vector so we can later build a vector
+ // that's `size + 1` if `depth_ref` is some; if the coordinate is neither a
+ // vector nor a scalar we panic, as that's not a valid expression
+ let mut coord_dim = match *ctx.resolve_type(coordinate, &self.module.types) {
+ TypeInner::Vector { size, .. } => size as u8,
+ TypeInner::Scalar { .. } => 1,
+ _ => unreachable!(),
+ };
+
+ if array_index.is_some() {
+ coord_dim += 1;
+ }
+ let merge_depth_ref = depth_ref.is_some() && gather.is_none() && coord_dim < 4;
+ if merge_depth_ref {
+ coord_dim += 1;
+ }
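+ // e.g. a shadow sample over a 2D coordinate folds the reference into
+ // the coordinate vector: `texture(img, vec3(coord, depth_ref))` (sketch)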
+
+ let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es();
+ let is_vec = tex_1d_hack || coord_dim != 1;
+ // Compose a new texture coordinates vector
+ if is_vec {
+ write!(self.out, "vec{}(", coord_dim + tex_1d_hack as u8)?;
+ }
+ self.write_expr(coordinate, ctx)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0.0")?;
+ }
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ if merge_depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(depth_ref.unwrap(), ctx)?;
+ }
+ if is_vec {
+ write!(self.out, ")")?;
+ }
+
+ if let (Some(expr), false) = (depth_ref, merge_depth_ref) {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+
+ match level {
+ // Auto needs no more arguments
+ crate::SampleLevel::Auto => (),
+ // Zero needs level set to 0
+ crate::SampleLevel::Zero => {
+ if workaround_lod_array_shadow_as_grad {
+ let vec_dim = match dim {
+ crate::ImageDimension::Cube => 3,
+ _ => 2,
+ };
+ write!(self.out, ", vec{vec_dim}(0.0), vec{vec_dim}(0.0)")?;
+ } else if gather.is_none() {
+ write!(self.out, ", 0.0")?;
+ }
+ }
+ // Exact and bias require another argument
+ crate::SampleLevel::Exact(expr) => {
+ if workaround_lod_array_shadow_as_grad {
+ log::warn!("Unable to `textureLod` a shadow array, ignoring the LOD");
+ write!(self.out, ", vec2(0,0), vec2(0,0)")?;
+ } else {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+ }
+ crate::SampleLevel::Bias(_) => {
+ // This needs to be done after the offset writing
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ // If we are using sampler2D to replace sampler1D, we also
+ // need to make sure to use vec2 gradients
+ if tex_1d_hack {
+ write!(self.out, ", vec2(")?;
+ self.write_expr(x, ctx)?;
+ write!(self.out, ", 0.0)")?;
+ write!(self.out, ", vec2(")?;
+ self.write_expr(y, ctx)?;
+ write!(self.out, ", 0.0)")?;
+ } else {
+ write!(self.out, ", ")?;
+ self.write_expr(x, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(y, ctx)?;
+ }
+ }
+ }
+
+ if let Some(constant) = offset {
+ write!(self.out, ", ")?;
+ if tex_1d_hack {
+ write!(self.out, "ivec2(")?;
+ }
+ self.write_const_expr(constant)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0)")?;
+ }
+ }
+
+ // Bias is always the last argument
+ if let crate::SampleLevel::Bias(expr) = level {
+ write!(self.out, ", ")?;
+ self.write_expr(expr, ctx)?;
+ }
+
+ if let (Some(component), None) = (gather, depth_ref) {
+ write!(self.out, ", {}", component as usize)?;
+ }
+
+ // End the function
+ write!(self.out, ")")?
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => self.write_image_load(expr, ctx, image, coordinate, array_index, sample, level)?,
+ // Query translates into one of the following:
+ // - textureSize/imageSize
+ // - textureQueryLevels
+ // - textureSamples/imageSamples
+ Expression::ImageQuery { image, query } => {
+ use crate::ImageClass;
+
+ // This will only panic if the module is invalid
+ let (dim, class) = match *ctx.resolve_type(image, &self.module.types) {
+ TypeInner::Image {
+ dim,
+ arrayed: _,
+ class,
+ } => (dim, class),
+ _ => unreachable!(),
+ };
+ let components = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 => 3,
+ crate::ImageDimension::Cube => 2,
+ };
+
+ if let crate::ImageQuery::Size { .. } = query {
+ match components {
+ 1 => write!(self.out, "uint(")?,
+ _ => write!(self.out, "uvec{components}(")?,
+ }
+ } else {
+ write!(self.out, "uint(")?;
+ }
+
+ match query {
+ crate::ImageQuery::Size { level } => {
+ match class {
+ ImageClass::Sampled { multi, .. } | ImageClass::Depth { multi } => {
+ write!(self.out, "textureSize(")?;
+ self.write_expr(image, ctx)?;
+ if let Some(expr) = level {
+ let cast_to_int = matches!(
+ *ctx.resolve_type(expr, &self.module.types),
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ })
+ );
+
+ write!(self.out, ", ")?;
+
+ if cast_to_int {
+ write!(self.out, "int(")?;
+ }
+
+ self.write_expr(expr, ctx)?;
+
+ if cast_to_int {
+ write!(self.out, ")")?;
+ }
+ } else if !multi {
+ // All textureSize calls require an lod argument
+ // except for multisampled samplers
+ write!(self.out, ", 0")?;
+ }
+ }
+ ImageClass::Storage { .. } => {
+ write!(self.out, "imageSize(")?;
+ self.write_expr(image, ctx)?;
+ }
+ }
+ write!(self.out, ")")?;
+ if components != 1 || self.options.version.is_es() {
+ write!(self.out, ".{}", &"xyz"[..components])?;
+ }
+ }
+ crate::ImageQuery::NumLevels => {
+ write!(self.out, "textureQueryLevels(",)?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ")",)?;
+ }
+ crate::ImageQuery::NumLayers => {
+ let fun_name = match class {
+ ImageClass::Sampled { .. } | ImageClass::Depth { .. } => "textureSize",
+ ImageClass::Storage { .. } => "imageSize",
+ };
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(image, ctx)?;
+ // All textureSize calls require an lod argument
+ // except for multisampled samplers
+ if class.is_multisampled() {
+ write!(self.out, ", 0")?;
+ }
+ write!(self.out, ")")?;
+ if components != 1 || self.options.version.is_es() {
+ write!(self.out, ".{}", back::COMPONENTS[components])?;
+ }
+ }
+ crate::ImageQuery::NumSamples => {
+ let fun_name = match class {
+ ImageClass::Sampled { .. } | ImageClass::Depth { .. } => {
+ "textureSamples"
+ }
+ ImageClass::Storage { .. } => "imageSamples",
+ };
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ")",)?;
+ }
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::Unary { op, expr } => {
+ let operator_or_fn = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::LogicalNot => {
+ match *ctx.resolve_type(expr, &self.module.types) {
+ TypeInner::Vector { .. } => "not",
+ _ => "!",
+ }
+ }
+ crate::UnaryOperator::BitwiseNot => "~",
+ };
+ write!(self.out, "{operator_or_fn}(")?;
+
+ self.write_expr(expr, ctx)?;
+
+ write!(self.out, ")")?
+ }
+ // For `Binary` we just write `left op right`, except when dealing with
+ // comparison operations on vectors, as those are implemented with
+ // builtin functions.
+ // Once again we wrap everything in parentheses to avoid precedence issues
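+ // For example, a vector comparison `a < b` becomes `lessThan(a, b)`, and
+ // `a & b` on a `bvec2` is expanded component-wise to
+ // `bvec2(a.x && b.x, a.y && b.y)` (sketch).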
+ Expression::Binary {
+ mut op,
+ left,
+ right,
+ } => {
+ // Holds the kind of operation to emit, since some binary operations
+ // are implemented as function calls or component-wise expansions
+ // rather than plain operators
+ use crate::{BinaryOperator as Bo, ScalarKind as Sk, TypeInner as Ti};
+
+ let left_inner = ctx.resolve_type(left, &self.module.types);
+ let right_inner = ctx.resolve_type(right, &self.module.types);
+
+ let function = match (left_inner, right_inner) {
+ (&Ti::Vector { scalar, .. }, &Ti::Vector { .. }) => match op {
+ Bo::Less
+ | Bo::LessEqual
+ | Bo::Greater
+ | Bo::GreaterEqual
+ | Bo::Equal
+ | Bo::NotEqual => BinaryOperation::VectorCompare,
+ Bo::Modulo if scalar.kind == Sk::Float => BinaryOperation::Modulo,
+ Bo::And if scalar.kind == Sk::Bool => {
+ op = crate::BinaryOperator::LogicalAnd;
+ BinaryOperation::VectorComponentWise
+ }
+ Bo::InclusiveOr if scalar.kind == Sk::Bool => {
+ op = crate::BinaryOperator::LogicalOr;
+ BinaryOperation::VectorComponentWise
+ }
+ _ => BinaryOperation::Other,
+ },
+ _ => match (left_inner.scalar_kind(), right_inner.scalar_kind()) {
+ (Some(Sk::Float), _) | (_, Some(Sk::Float)) => match op {
+ Bo::Modulo => BinaryOperation::Modulo,
+ _ => BinaryOperation::Other,
+ },
+ (Some(Sk::Bool), Some(Sk::Bool)) => match op {
+ Bo::InclusiveOr => {
+ op = crate::BinaryOperator::LogicalOr;
+ BinaryOperation::Other
+ }
+ Bo::And => {
+ op = crate::BinaryOperator::LogicalAnd;
+ BinaryOperation::Other
+ }
+ _ => BinaryOperation::Other,
+ },
+ _ => BinaryOperation::Other,
+ },
+ };
+
+ match function {
+ BinaryOperation::VectorCompare => {
+ let op_str = match op {
+ Bo::Less => "lessThan(",
+ Bo::LessEqual => "lessThanEqual(",
+ Bo::Greater => "greaterThan(",
+ Bo::GreaterEqual => "greaterThanEqual(",
+ Bo::Equal => "equal(",
+ Bo::NotEqual => "notEqual(",
+ _ => unreachable!(),
+ };
+ write!(self.out, "{op_str}")?;
+ self.write_expr(left, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, ")")?;
+ }
+ BinaryOperation::VectorComponentWise => {
+ self.write_value_type(left_inner)?;
+ write!(self.out, "(")?;
+
+ let size = match *left_inner {
+ Ti::Vector { size, .. } => size,
+ _ => unreachable!(),
+ };
+
+ for i in 0..size as usize {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+
+ self.write_expr(left, ctx)?;
+ write!(self.out, ".{}", back::COMPONENTS[i])?;
+
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+
+ self.write_expr(right, ctx)?;
+ write!(self.out, ".{}", back::COMPONENTS[i])?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
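+ //
+ // e.g. for floats, `a % b` is currently emitted as
+ // `(a - b * trunc(a / b))` (sketch)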
+ BinaryOperation::Modulo => {
+ write!(self.out, "(")?;
+
+ // write `e1 - e2 * trunc(e1 / e2)`
+ self.write_expr(left, ctx)?;
+ write!(self.out, " - ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, " * ")?;
+ write!(self.out, "trunc(")?;
+ self.write_expr(left, ctx)?;
+ write!(self.out, " / ")?;
+ self.write_expr(right, ctx)?;
+ write!(self.out, ")")?;
+
+ write!(self.out, ")")?;
+ }
+ BinaryOperation::Other => {
+ write!(self.out, "(")?;
+
+ self.write_expr(left, ctx)?;
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+ self.write_expr(right, ctx)?;
+
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ // `Select` is written as `condition ? accept : reject`
+ // We wrap everything in parentheses to avoid precedence issues
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ let cond_ty = ctx.resolve_type(condition, &self.module.types);
+ let vec_select = matches!(*cond_ty, TypeInner::Vector { .. });
+
+ // TODO: Boolean mix on desktop requires GL_EXT_shader_integer_mix
+ if vec_select {
+ // Glsl defines that for `mix`, when the condition is a boolean, the first
+ // element is picked if the condition is false and the second if it is true
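+ // e.g. a vector select is emitted as
+ // `mix(<reject>, <accept>, <condition>)` (sketch)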
+ write!(self.out, "mix(")?;
+ self.write_expr(reject, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(accept, ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(condition, ctx)?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(condition, ctx)?;
+ write!(self.out, " ? ")?;
+ self.write_expr(accept, ctx)?;
+ write!(self.out, " : ")?;
+ self.write_expr(reject, ctx)?;
+ }
+
+ write!(self.out, ")")?
+ }
+ // `Derivative` is a function call to a glsl provided function
+ Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ let fun_name = if self.options.version.supports_derivative_control() {
+ match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => "dFdxCoarse",
+ (Axis::X, Ctrl::Fine) => "dFdxFine",
+ (Axis::X, Ctrl::None) => "dFdx",
+ (Axis::Y, Ctrl::Coarse) => "dFdyCoarse",
+ (Axis::Y, Ctrl::Fine) => "dFdyFine",
+ (Axis::Y, Ctrl::None) => "dFdy",
+ (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
+ (Axis::Width, Ctrl::Fine) => "fwidthFine",
+ (Axis::Width, Ctrl::None) => "fwidth",
+ }
+ } else {
+ match axis {
+ Axis::X => "dFdx",
+ Axis::Y => "dFdy",
+ Axis::Width => "fwidth",
+ }
+ };
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?
+ }
+ // `Relational` is a normal function call to some glsl provided functions
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_name = match fun {
+ Rf::IsInf => "isinf",
+ Rf::IsNan => "isnan",
+ Rf::All => "all",
+ Rf::Any => "any",
+ };
+ write!(self.out, "{fun_name}(")?;
+
+ self.write_expr(argument, ctx)?;
+
+ write!(self.out, ")")?
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let fun_name = match fun {
+ // comparison
+ Mf::Abs => "abs",
+ Mf::Min => "min",
+ Mf::Max => "max",
+ Mf::Clamp => "clamp",
+ Mf::Saturate => {
+ write!(self.out, "clamp(")?;
+
+ self.write_expr(arg, ctx)?;
+
+ match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Vector { size, .. } => write!(
+ self.out,
+ ", vec{}(0.0), vec{0}(1.0)",
+ back::vector_size_str(size)
+ )?,
+ _ => write!(self.out, ", 0.0, 1.0")?,
+ }
+
+ write!(self.out, ")")?;
+
+ return Ok(());
+ }
+ // trigonometry
+ Mf::Cos => "cos",
+ Mf::Cosh => "cosh",
+ Mf::Sin => "sin",
+ Mf::Sinh => "sinh",
+ Mf::Tan => "tan",
+ Mf::Tanh => "tanh",
+ Mf::Acos => "acos",
+ Mf::Asin => "asin",
+ Mf::Atan => "atan",
+ Mf::Asinh => "asinh",
+ Mf::Acosh => "acosh",
+ Mf::Atanh => "atanh",
+ Mf::Radians => "radians",
+ Mf::Degrees => "degrees",
+ // glsl doesn't have an atan2 function;
+ // use the two-argument variation of the atan function
+ Mf::Atan2 => "atan",
+ // decomposition
+ Mf::Ceil => "ceil",
+ Mf::Floor => "floor",
+ Mf::Round => "roundEven",
+ Mf::Fract => "fract",
+ Mf::Trunc => "trunc",
+ Mf::Modf => MODF_FUNCTION,
+ Mf::Frexp => FREXP_FUNCTION,
+ Mf::Ldexp => "ldexp",
+ // exponent
+ Mf::Exp => "exp",
+ Mf::Exp2 => "exp2",
+ Mf::Log => "log",
+ Mf::Log2 => "log2",
+ Mf::Pow => "pow",
+ // geometry
+ Mf::Dot => match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Vector {
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Float,
+ ..
+ },
+ ..
+ } => "dot",
+ crate::TypeInner::Vector { size, .. } => {
+ return self.write_dot_product(arg, arg1.unwrap(), size as usize, ctx)
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => "outerProduct",
+ Mf::Cross => "cross",
+ Mf::Distance => "distance",
+ Mf::Length => "length",
+ Mf::Normalize => "normalize",
+ Mf::FaceForward => "faceforward",
+ Mf::Reflect => "reflect",
+ Mf::Refract => "refract",
+ // computational
+ Mf::Sign => "sign",
+ Mf::Fma => {
+ if self.options.version.supports_fma_function() {
+ // Use the fma function when available
+ "fma"
+ } else {
+ // No fma support. Transform the function call into an arithmetic expression
+ write!(self.out, "(")?;
+
+ self.write_expr(arg, ctx)?;
+ write!(self.out, " * ")?;
+
+ let arg1 =
+ arg1.ok_or_else(|| Error::Custom("Missing fma arg1".to_owned()))?;
+ self.write_expr(arg1, ctx)?;
+ write!(self.out, " + ")?;
+
+ let arg2 =
+ arg2.ok_or_else(|| Error::Custom("Missing fma arg2".to_owned()))?;
+ self.write_expr(arg2, ctx)?;
+ write!(self.out, ")")?;
+
+ return Ok(());
+ }
+ }
+ Mf::Mix => "mix",
+ Mf::Step => "step",
+ Mf::SmoothStep => "smoothstep",
+ Mf::Sqrt => "sqrt",
+ Mf::InverseSqrt => "inversesqrt",
+ Mf::Inverse => "inverse",
+ Mf::Transpose => "transpose",
+ Mf::Determinant => "determinant",
+ // bits
+ Mf::CountTrailingZeros => {
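+ // GLSL's `findLSB` returns -1 for zero, so clamp the result to 32
+ // to match WGSL, e.g. `min(uint(findLSB(x)), 32u)` for a uint scalar
+ // (sketch)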
+ match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Vector { size, scalar, .. } => {
+ let s = back::vector_size_str(size);
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "min(uvec{s}(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), uvec{s}(32u))")?;
+ } else {
+ write!(self.out, "ivec{s}(min(uvec{s}(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), uvec{s}(32u)))")?;
+ }
+ }
+ crate::TypeInner::Scalar(scalar) => {
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "min(uint(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), 32u)")?;
+ } else {
+ write!(self.out, "int(min(uint(findLSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")), 32u))")?;
+ }
+ }
+ _ => unreachable!(),
+ };
+ return Ok(());
+ }
+ Mf::CountLeadingZeros => {
+ if self.options.version.supports_integer_functions() {
+ match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Vector { size, scalar } => {
+ let s = back::vector_size_str(size);
+
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "uvec{s}(ivec{s}(31) - findMSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "mix(ivec{s}(31) - findMSB(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, "), ivec{s}(0), lessThan(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ", ivec{s}(0)))")?;
+ }
+ }
+ crate::TypeInner::Scalar(scalar) => {
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "uint(31 - findMSB(")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, " < 0 ? 0 : 31 - findMSB(")?;
+ }
+
+ self.write_expr(arg, ctx)?;
+ write!(self.out, "))")?;
+ }
+ _ => unreachable!(),
+ };
+ } else {
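+ // Without integer functions, approximate `findMSB` with a float
+ // log2; e.g. a uint scalar becomes
+ // `uint(31.0 - floor(log2(float(x) + 0.5)))` (sketch)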
+ match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Vector { size, scalar } => {
+ let s = back::vector_size_str(size);
+
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "uvec{s}(")?;
+ write!(self.out, "vec{s}(31.0) - floor(log2(vec{s}(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5)))")?;
+ } else {
+ write!(self.out, "ivec{s}(")?;
+ write!(self.out, "mix(vec{s}(31.0) - floor(log2(vec{s}(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5)), ")?;
+ write!(self.out, "vec{s}(0.0), lessThan(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ", ivec{s}(0u))))")?;
+ }
+ }
+ crate::TypeInner::Scalar(scalar) => {
+ if let crate::ScalarKind::Uint = scalar.kind {
+ write!(self.out, "uint(31.0 - floor(log2(float(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5)))")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, " < 0 ? 0 : int(")?;
+ write!(self.out, "31.0 - floor(log2(float(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ") + 0.5))))")?;
+ }
+ }
+ _ => unreachable!(),
+ };
+ }
+
+ return Ok(());
+ }
+ Mf::CountOneBits => "bitCount",
+ Mf::ReverseBits => "bitfieldReverse",
+ Mf::ExtractBits => "bitfieldExtract",
+ Mf::InsertBits => "bitfieldInsert",
+ Mf::FindLsb => "findLSB",
+ Mf::FindMsb => "findMSB",
+ // data packing
+ Mf::Pack4x8snorm => "packSnorm4x8",
+ Mf::Pack4x8unorm => "packUnorm4x8",
+ Mf::Pack2x16snorm => "packSnorm2x16",
+ Mf::Pack2x16unorm => "packUnorm2x16",
+ Mf::Pack2x16float => "packHalf2x16",
+ // data unpacking
+ Mf::Unpack4x8snorm => "unpackSnorm4x8",
+ Mf::Unpack4x8unorm => "unpackUnorm4x8",
+ Mf::Unpack2x16snorm => "unpackSnorm2x16",
+ Mf::Unpack2x16unorm => "unpackUnorm2x16",
+ Mf::Unpack2x16float => "unpackHalf2x16",
+ };
+
+ let extract_bits = fun == Mf::ExtractBits;
+ let insert_bits = fun == Mf::InsertBits;
+
+ // Some GLSL functions always return signed integers (like findMSB),
+ // so they need to be cast to uint if the argument is also a uint.
+ let ret_might_need_int_to_uint =
+ matches!(fun, Mf::FindLsb | Mf::FindMsb | Mf::CountOneBits | Mf::Abs);
+
+ // Some GLSL functions only accept signed integers (like abs),
+ // so they need their argument cast from uint to int.
+ let arg_might_need_uint_to_int = matches!(fun, Mf::Abs);
+
+ // Check if the argument is an unsigned integer and return the vector size
+ // in case it's a vector
+ let maybe_uint_size = match *ctx.resolve_type(arg, &self.module.types) {
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ }) => Some(None),
+ crate::TypeInner::Vector {
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ },
+ size,
+ } => Some(Some(size)),
+ _ => None,
+ };
+
+ // Cast to uint if the function needs it
+ if ret_might_need_int_to_uint {
+ if let Some(maybe_size) = maybe_uint_size {
+ match maybe_size {
+ Some(size) => write!(self.out, "uvec{}(", size as u8)?,
+ None => write!(self.out, "uint(")?,
+ }
+ }
+ }
+
+ write!(self.out, "{fun_name}(")?;
+
+ // Cast to int if the function needs it
+ if arg_might_need_uint_to_int {
+ if let Some(maybe_size) = maybe_uint_size {
+ match maybe_size {
+ Some(size) => write!(self.out, "ivec{}(", size as u8)?,
+ None => write!(self.out, "int(")?,
+ }
+ }
+ }
+
+ self.write_expr(arg, ctx)?;
+
+ // Close the cast from uint to int
+ if arg_might_need_uint_to_int && maybe_uint_size.is_some() {
+ write!(self.out, ")")?
+ }
+
+ if let Some(arg) = arg1 {
+ write!(self.out, ", ")?;
+ if extract_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ if let Some(arg) = arg2 {
+ write!(self.out, ", ")?;
+ if extract_bits || insert_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ if let Some(arg) = arg3 {
+ write!(self.out, ", ")?;
+ if insert_bits {
+ write!(self.out, "int(")?;
+ self.write_expr(arg, ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(arg, ctx)?;
+ }
+ }
+ write!(self.out, ")")?;
+
+ // Close the cast from int to uint
+ if ret_might_need_int_to_uint && maybe_uint_size.is_some() {
+ write!(self.out, ")")?
+ }
+ }
+ // `As` is always a call.
+ // If `convert` is `Some`, the function name is the target type;
+ // otherwise the function name is one of the glsl provided bitcast functions
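+ // e.g. a converting cast of a float vector to signed integers emits
+ // `ivec3(x)`, while a bitcast emits `floatBitsToInt(x)` (sketch)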
+ Expression::As {
+ expr,
+ kind: target_kind,
+ convert,
+ } => {
+ let inner = ctx.resolve_type(expr, &self.module.types);
+ match convert {
+ Some(width) => {
+ // this is similar to `write_type`, but with the target kind
+ let scalar = glsl_scalar(crate::Scalar {
+ kind: target_kind,
+ width,
+ })?;
+ match *inner {
+ TypeInner::Matrix { columns, rows, .. } => write!(
+ self.out,
+ "{}mat{}x{}",
+ scalar.prefix, columns as u8, rows as u8
+ )?,
+ TypeInner::Vector { size, .. } => {
+ write!(self.out, "{}vec{}", scalar.prefix, size as u8)?
+ }
+ _ => write!(self.out, "{}", scalar.full)?,
+ }
+
+ write!(self.out, "(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?
+ }
+ None => {
+ use crate::ScalarKind as Sk;
+
+ let target_vector_type = match *inner {
+ TypeInner::Vector { size, scalar } => Some(TypeInner::Vector {
+ size,
+ scalar: crate::Scalar {
+ kind: target_kind,
+ width: scalar.width,
+ },
+ }),
+ _ => None,
+ };
+
+ let source_kind = inner.scalar_kind().unwrap();
+
+ match (source_kind, target_kind, target_vector_type) {
+ // No conversion needed
+ (Sk::Sint, Sk::Sint, _)
+ | (Sk::Uint, Sk::Uint, _)
+ | (Sk::Float, Sk::Float, _)
+ | (Sk::Bool, Sk::Bool, _) => {
+ self.write_expr(expr, ctx)?;
+ return Ok(());
+ }
+
+ // Cast to/from floats
+ (Sk::Float, Sk::Sint, _) => write!(self.out, "floatBitsToInt")?,
+ (Sk::Float, Sk::Uint, _) => write!(self.out, "floatBitsToUint")?,
+ (Sk::Sint, Sk::Float, _) => write!(self.out, "intBitsToFloat")?,
+ (Sk::Uint, Sk::Float, _) => write!(self.out, "uintBitsToFloat")?,
+
+ // Cast between vector types
+ (_, _, Some(vector)) => {
+ self.write_value_type(&vector)?;
+ }
+
+ // There is no way to bitcast between Uint/Sint in glsl. Use constructor conversion
+ (Sk::Uint | Sk::Bool, Sk::Sint, None) => write!(self.out, "int")?,
+ (Sk::Sint | Sk::Bool, Sk::Uint, None) => write!(self.out, "uint")?,
+ (Sk::Bool, Sk::Float, None) => write!(self.out, "float")?,
+ (Sk::Sint | Sk::Uint | Sk::Float, Sk::Bool, None) => {
+ write!(self.out, "bool")?
+ }
+
+ (Sk::AbstractInt | Sk::AbstractFloat, _, _)
+ | (_, Sk::AbstractInt | Sk::AbstractFloat, _) => unreachable!(),
+ };
+
+ write!(self.out, "(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ // These expressions never show up in `Emit`.
+ Expression::CallResult(_)
+ | Expression::AtomicResult { .. }
+ | Expression::RayQueryProceedResult
+ | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
+ // `ArrayLength` is written as `expr.length()` and we convert it to a uint
+ Expression::ArrayLength(expr) => {
+ write!(self.out, "uint(")?;
+ self.write_expr(expr, ctx)?;
+ write!(self.out, ".length())")?
+ }
+ // not supported yet
+ Expression::RayQueryGetIntersection { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper function to write the local holding the clamped lod
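+ ///
+ /// For illustration (exact names depend on `BAKE_PREFIX` and
+ /// `CLAMPED_LOD_SUFFIX`), the emitted line looks like
+ /// `int _e7_clamped_lod = clamp(<lod>, 0, textureQueryLevels(<image>) - 1);`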
+ fn write_clamped_lod(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ expr: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ level_expr: Handle<crate::Expression>,
+ ) -> Result<(), Error> {
+ // Define our local and start a call to `clamp`
+ write!(
+ self.out,
+ "int {}{}{} = clamp(",
+ back::BAKE_PREFIX,
+ expr.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ // Write the lod that will be clamped
+ self.write_expr(level_expr, ctx)?;
+ // Set the min value to 0 and start a call to `textureQueryLevels` to get
+ // the maximum value
+ write!(self.out, ", 0, textureQueryLevels(")?;
+ // Write the target image as an argument to `textureQueryLevels`
+ self.write_expr(image, ctx)?;
+ // Close the call to `textureQueryLevels` subtract 1 from it since
+ // the lod argument is 0 based, close the `clamp` call and end the
+ // local declaration statement.
+ writeln!(self.out, ") - 1);")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to retrieve how many elements the coordinate vector
+ /// for an image operation needs.
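+ ///
+ /// For example, a 2D arrayed image needs 3 components, and a 1D image
+ /// under the ES 2D-emulation hack needs 2.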
+ fn get_coordinate_vector_size(&self, dim: crate::ImageDimension, arrayed: bool) -> u8 {
+ // openGL es doesn't have 1D images so we need to work around it
+ let tex_1d_hack = dim == crate::ImageDimension::D1 && self.options.version.is_es();
+ // Get how many components the coordinate vector needs for the dimensions only
+ let tex_coord_size = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 => 3,
+ crate::ImageDimension::Cube => 2,
+ };
+ // Calculate the true size of the coordinate vector by adding 1 for arrayed images
+ // and another 1 if we need to workaround 1D images by making them 2D
+ tex_coord_size + tex_1d_hack as u8 + arrayed as u8
+ }
+
+ /// Helper method to write the coordinate vector for image operations
+ fn write_texture_coord(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ vector_size: u8,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ // Emulate 1D images as 2D for profiles that don't support it (glsl es)
+ tex_1d_hack: bool,
+ ) -> Result<(), Error> {
+ match array_index {
+ // If the image needs an array index we need to add it to the end of our
+ // coordinate vector, to do so we will use the `ivec(ivec, scalar)`
+ // constructor notation (NOTE: the inner `ivec` can also be a scalar, this
+ // is important for 1D arrayed images).
+ Some(layer_expr) => {
+ write!(self.out, "ivec{vector_size}(")?;
+ self.write_expr(coordinate, ctx)?;
+ write!(self.out, ", ")?;
+ // If we are replacing sampler1D with sampler2D we also need
+ // to add another zero to the coordinates vector for the y component
+ if tex_1d_hack {
+ write!(self.out, "0, ")?;
+ }
+ self.write_expr(layer_expr, ctx)?;
+ write!(self.out, ")")?;
+ }
+ // Otherwise write just the expression (and the 1D hack if needed)
+ None => {
+ let uvec_size = match *ctx.resolve_type(coordinate, &self.module.types) {
+ TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ }) => Some(None),
+ TypeInner::Vector {
+ size,
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ ..
+ },
+ } => Some(Some(size as u32)),
+ _ => None,
+ };
+ if tex_1d_hack {
+ write!(self.out, "ivec2(")?;
+ } else if uvec_size.is_some() {
+ match uvec_size {
+ Some(None) => write!(self.out, "int(")?,
+ Some(Some(size)) => write!(self.out, "ivec{size}(")?,
+ _ => {}
+ }
+ }
+ self.write_expr(coordinate, ctx)?;
+ if tex_1d_hack {
+ write!(self.out, ", 0)")?;
+ } else if uvec_size.is_some() {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write the `ImageStore` statement
+ fn write_image_store(
+ &mut self,
+ ctx: &back::FunctionCtx,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ value: Handle<crate::Expression>,
+ ) -> Result<(), Error> {
+ use crate::ImageDimension as IDim;
+
+ // NOTE: openGL requires that `imageStore`s have no effects when the texel is invalid
+ // so we don't need to generate bounds checks (OpenGL 4.2 Core §3.9.20)
+
+ // This will only panic if the module is invalid
+ let dim = match *ctx.resolve_type(image, &self.module.types) {
+ TypeInner::Image { dim, .. } => dim,
+ _ => unreachable!(),
+ };
+
+ // Begin our call to `imageStore`
+ write!(self.out, "imageStore(")?;
+ self.write_expr(image, ctx)?;
+ // Separate the image argument from the coordinates
+ write!(self.out, ", ")?;
+
+ // openGL es doesn't have 1D images so we need to work around it
+ let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es();
+ // Write the coordinate vector
+ self.write_texture_coord(
+ ctx,
+ // Get the size of the coordinate vector
+ self.get_coordinate_vector_size(dim, array_index.is_some()),
+ coordinate,
+ array_index,
+ tex_1d_hack,
+ )?;
+
+ // Separate the coordinate from the value to write and write the expression
+ // of the value to write.
+ write!(self.out, ", ")?;
+ self.write_expr(value, ctx)?;
+ // End the call to `imageStore` and the statement.
+ writeln!(self.out, ");")?;
+
+ Ok(())
+ }
+
+ /// Helper method for writing an `ImageLoad` expression.
+ #[allow(clippy::too_many_arguments)]
+ fn write_image_load(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ level: Option<Handle<crate::Expression>>,
+ ) -> Result<(), Error> {
+ use crate::ImageDimension as IDim;
+
+ // `ImageLoad` is a bit complicated.
+ // There are two functions: one for sampled
+ // images, another for storage images; the former uses `texelFetch` and the
+ // latter uses `imageLoad`.
+ //
+ // Furthermore we have `level` which is always `Some` for sampled images
+ // and `None` for storage images, so we end up with two functions:
+ // - `texelFetch(image, coordinate, level)` for sampled images
+ // - `imageLoad(image, coordinate)` for storage images
+ //
+ // Finally we also have to consider bounds checking, for storage images
+ // this is easy since openGL requires that invalid texels always return
+ // 0, for sampled images we need to either verify that all arguments are
+ // in bounds (`ReadZeroSkipWrite`) or make them a valid texel (`Restrict`).
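+ //
+ // Under `ReadZeroSkipWrite` the whole load becomes a guarded ternary, e.g.
+ // `(all(lessThan(c, textureSize(img, l))) ? texelFetch(img, c, l) : vec4(0))`
+ // (sketch; the guard may also check the lod and sample count).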
+
+ // This will only panic if the module is invalid
+ let (dim, class) = match *ctx.resolve_type(image, &self.module.types) {
+ TypeInner::Image {
+ dim,
+ arrayed: _,
+ class,
+ } => (dim, class),
+ _ => unreachable!(),
+ };
+
+ // Get the name of the function to be used for the load operation
+ // and the policy to be used with it.
+ let (fun_name, policy) = match class {
+ // Sampled images inherit the policy from the user passed policies
+ crate::ImageClass::Sampled { .. } => ("texelFetch", self.policies.image_load),
+ crate::ImageClass::Storage { .. } => {
+ // OpenGL ES 3.1 mentions in Chapter "8.22 Texture Image Loads and Stores" that:
+ // "Invalid image loads will return a vector where the value of R, G, and B components
+ // is 0 and the value of the A component is undefined."
+ //
+ // OpenGL 4.2 Core mentions in Chapter "3.9.20 Texture Image Loads and Stores" that:
+ // "Invalid image loads will return zero."
+ //
+ // So, we only inject bounds checks for ES
+ let policy = if self.options.version.is_es() {
+ self.policies.image_load
+ } else {
+ proc::BoundsCheckPolicy::Unchecked
+ };
+ ("imageLoad", policy)
+ }
+ // TODO: Is there even a function for this?
+ crate::ImageClass::Depth { multi: _ } => {
+ return Err(Error::Custom(
+ "WGSL `textureLoad` from depth textures is not supported in GLSL".to_string(),
+ ))
+ }
+ };
+
+ // openGL es doesn't have 1D images so we need to work around it
+ let tex_1d_hack = dim == IDim::D1 && self.options.version.is_es();
+ // Get the size of the coordinate vector
+ let vector_size = self.get_coordinate_vector_size(dim, array_index.is_some());
+
+ if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy {
+ // To write the bounds checks for `ReadZeroSkipWrite` we will use a
+ // ternary operator since we are in the middle of an expression and
+ // need to return a value.
+ //
+ // NOTE: glsl does short circuit when evaluating logical
+ // expressions, so once a condition has been tested we can
+ // rely on it holding in the checks that follow
+
+ // Write parentheses around the ternary operator to prevent problems with
+ // expressions emitted before or after it having more precedence
+ write!(self.out, "(",)?;
+
+ // The lod check needs to precede the size check since we need
+ // to use the lod to get the size of the image at that level.
+ if let Some(level_expr) = level {
+ self.write_expr(level_expr, ctx)?;
+ write!(self.out, " < textureQueryLevels(",)?;
+ self.write_expr(image, ctx)?;
+ // Chain the next check
+ write!(self.out, ") && ")?;
+ }
+
+ // Check that the sample argument doesn't exceed the number of samples
+ if let Some(sample_expr) = sample {
+ self.write_expr(sample_expr, ctx)?;
+ write!(self.out, " < textureSamples(",)?;
+ self.write_expr(image, ctx)?;
+ // Chain the next check
+ write!(self.out, ") && ")?;
+ }
+
+ // We now need to write the size checks for the coordinates and array index.
+ // If the image is 1D non-arrayed (and no 1D-to-2D hack was needed) we are
+ // comparing scalars, so the less-than operator will suffice. Otherwise we'll
+ // be comparing two vectors, so we'll need the `lessThan` function; it returns
+ // a vector of booleans (one per comparison), so we fold it into a single
+ // scalar boolean with the `all` function, which only returns `true` if all
+ // the elements of the boolean vector are `true`.
+ //
+ // So we'll end with one of the following forms
+ // - `coord < textureSize(image, lod)` for 1D images
+ // - `all(lessThan(coord, textureSize(image, lod)))` for normal images
+ // - `all(lessThan(ivec(coord, array_index), textureSize(image, lod)))`
+ // for arrayed images
+ // - `all(lessThan(coord, textureSize(image)))` for multi sampled images
+
+ if vector_size != 1 {
+ write!(self.out, "all(lessThan(")?;
+ }
+
+ // Write the coordinate vector
+ self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?;
+
+ if vector_size != 1 {
+ // If we used the `lessThan` function we need to separate the
+ // coordinates from the image size.
+ write!(self.out, ", ")?;
+ } else {
+ // If we didn't use it (i.e. 1D images) we perform the comparison
+ // using the less than operator.
+ write!(self.out, " < ")?;
+ }
+
+ // Call `textureSize` to get our image size
+ write!(self.out, "textureSize(")?;
+ self.write_expr(image, ctx)?;
+ // `textureSize` uses the lod as a second argument for mipmapped images
+ if let Some(level_expr) = level {
+ // Separate the image from the lod
+ write!(self.out, ", ")?;
+ self.write_expr(level_expr, ctx)?;
+ }
+ // Close the `textureSize` call
+ write!(self.out, ")")?;
+
+ if vector_size != 1 {
+ // Close the `all` and `lessThan` calls
+ write!(self.out, "))")?;
+ }
+
+ // Finally end the condition part of the ternary operator
+ write!(self.out, " ? ")?;
+ }
+
+ // Begin the call to the function used to load the texel
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(image, ctx)?;
+ write!(self.out, ", ")?;
+
+ // If we are using `Restrict` bounds checking we need to pass valid texel
+ // coordinates, to do so we use the `clamp` function to get a value between
+ // 0 and the image size - 1 (indexing begins at 0)
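+ // For a 2D image this produces, e.g.,
+ // `clamp(coord, ivec2(0), textureSize(img, lod) - ivec2(1))` (sketch)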
+ if let proc::BoundsCheckPolicy::Restrict = policy {
+ write!(self.out, "clamp(")?;
+ }
+
+ // Write the coordinate vector
+ self.write_texture_coord(ctx, vector_size, coordinate, array_index, tex_1d_hack)?;
+
+ // If we are using `Restrict` bounds checking we need to write the rest of the
+ // clamp we initiated before writing the coordinates.
+ if let proc::BoundsCheckPolicy::Restrict = policy {
+ // Write the min value 0
+ if vector_size == 1 {
+ write!(self.out, ", 0")?;
+ } else {
+ write!(self.out, ", ivec{vector_size}(0)")?;
+ }
+ // Start the `textureSize` call to use as the max value.
+ write!(self.out, ", textureSize(")?;
+ self.write_expr(image, ctx)?;
+ // If the image is mipmapped we need to add the lod argument to the
+ // `textureSize` call, but this needs to be the clamped lod, this should
+ // have been generated earlier and put in a local.
+ if class.is_mipmapped() {
+ write!(
+ self.out,
+ ", {}{}{}",
+ back::BAKE_PREFIX,
+ handle.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ }
+ // Close the `textureSize` call
+ write!(self.out, ")")?;
+
+ // Subtract 1 from the `textureSize` call since the coordinates are zero based.
+ if vector_size == 1 {
+ write!(self.out, " - 1")?;
+ } else {
+ write!(self.out, " - ivec{vector_size}(1)")?;
+ }
+
+ // Close the `clamp` call
+ write!(self.out, ")")?;
+
+ // Add the clamped lod (if present) as the second argument to the
+ // image load function.
+ if level.is_some() {
+ write!(
+ self.out,
+ ", {}{}{}",
+ back::BAKE_PREFIX,
+ handle.index(),
+ CLAMPED_LOD_SUFFIX
+ )?;
+ }
+
+ // If a sample argument is needed we need to clamp it between 0 and
+ // the number of samples the image has.
+ if let Some(sample_expr) = sample {
+ write!(self.out, ", clamp(")?;
+ self.write_expr(sample_expr, ctx)?;
+ // Set the min value to 0 and start the call to `textureSamples`
+ write!(self.out, ", 0, textureSamples(")?;
+ self.write_expr(image, ctx)?;
+ // Close the `textureSamples` call, subtract 1 from it since the sample
+ // argument is zero based, and close the `clamp` call
+ writeln!(self.out, ") - 1)")?;
+ }
+ } else if let Some(sample_or_level) = sample.or(level) {
+ // If no bounds checking is needed, just add the sample or level argument
+ // after the coordinates
+ write!(self.out, ", ")?;
+ self.write_expr(sample_or_level, ctx)?;
+ }
+
+ // Close the image load function.
+ write!(self.out, ")")?;
+
+ // If we were using the `ReadZeroSkipWrite` policy we need to end the first branch
+ // (which is taken if the condition is `true`) with a colon (`:`) and write the
+ // second branch which is just a 0 value.
+ if let proc::BoundsCheckPolicy::ReadZeroSkipWrite = policy {
+ // Get the kind of the output value.
+ let kind = match class {
+ // Only sampled images can reach here since storage images
+ // don't need bounds checks and depth images aren't implemented
+ crate::ImageClass::Sampled { kind, .. } => kind,
+ _ => unreachable!(),
+ };
+
+ // End the first branch
+ write!(self.out, " : ")?;
+ // Write the 0 value
+ write!(
+ self.out,
+ "{}vec4(",
+ glsl_scalar(crate::Scalar { kind, width: 4 })?.prefix,
+ )?;
+ self.write_zero_init_scalar(kind)?;
+ // Close the zero value constructor
+ write!(self.out, ")")?;
+ // Close the parentheses surrounding our ternary
+ write!(self.out, ")")?;
+ }
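+
+ // Putting it together: under `ReadZeroSkipWrite` the emitted expression has
+ // the overall shape (identifiers illustrative):
+ //
+ //     (all(lessThan(ivec2(c), textureSize(img, lod))) ? texelFetch(img, ivec2(c), lod) : vec4(0.0))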
+
+ Ok(())
+ }
+
+ fn write_named_expr(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ name: String,
+ // The expression which is being named.
+ // Generally, this is the same as handle, except in WorkGroupUniformLoad
+ named: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ match ctx.info[named].ty {
+ proc::TypeResolution::Handle(ty_handle) => match self.module.types[ty_handle].inner {
+ TypeInner::Struct { .. } => {
+ let ty_name = &self.names[&NameKey::Type(ty_handle)];
+ write!(self.out, "{ty_name}")?;
+ }
+ _ => {
+ self.write_type(ty_handle)?;
+ }
+ },
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(inner)?;
+ }
+ }
+
+ let resolved = ctx.resolve_type(named, &self.module.types);
+
+ write!(self.out, " {name}")?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_expr(handle, ctx)?;
+ writeln!(self.out, ";")?;
+ self.named_expressions.insert(named, name);
+
+ Ok(())
+ }
+
+ /// Helper function that writes the default zero initialization for supported types
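+ ///
+ /// Illustrative output for a couple of types:
+ ///
+ /// ```glsl
+ /// vec3(0.0)        // vector of three floats
+ /// uint[2](0u, 0u)  // array<u32, 2>
+ /// ```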
+ fn write_zero_init_value(&mut self, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &self.module.types[ty].inner;
+ match *inner {
+ TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
+ self.write_zero_init_scalar(scalar.kind)?;
+ }
+ TypeInner::Vector { scalar, .. } => {
+ self.write_value_type(inner)?;
+ write!(self.out, "(")?;
+ self.write_zero_init_scalar(scalar.kind)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Matrix { .. } => {
+ self.write_value_type(inner)?;
+ write!(self.out, "(")?;
+ self.write_zero_init_scalar(crate::ScalarKind::Float)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Array { base, size, .. } => {
+ let count = match size
+ .to_indexable_length(self.module)
+ .expect("Bad array size")
+ {
+ proc::IndexableLength::Known(count) => count,
+ proc::IndexableLength::Dynamic => return Ok(()),
+ };
+ self.write_type(base)?;
+ self.write_array_size(base, size)?;
+ write!(self.out, "(")?;
+ for _ in 1..count {
+ self.write_zero_init_value(base)?;
+ write!(self.out, ", ")?;
+ }
+ // Write the last argument without a trailing comma and space
+ self.write_zero_init_value(base)?;
+ write!(self.out, ")")?;
+ }
+ TypeInner::Struct { ref members, .. } => {
+ let name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{name}(")?;
+ for (index, member) in members.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_zero_init_value(member.ty)?;
+ }
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper function that writes the zero-initialization literal for a scalar kind
+ fn write_zero_init_scalar(&mut self, kind: crate::ScalarKind) -> BackendResult {
+ match kind {
+ crate::ScalarKind::Bool => write!(self.out, "false")?,
+ crate::ScalarKind::Uint => write!(self.out, "0u")?,
+ crate::ScalarKind::Float => write!(self.out, "0.0")?,
+ crate::ScalarKind::Sint => write!(self.out, "0")?,
+ crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => {
+ return Err(Error::Custom(
+ "Abstract types should not appear in IR presented to backends".to_string(),
+ ))
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Issue a memory barrier. Note that to ensure visibility, OpenGL always
+ /// requires a call to the `barrier()` function after a `memoryBarrier*()` call.
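+ ///
+ /// For example, `Barrier::WORK_GROUP` produces:
+ ///
+ /// ```glsl
+ /// memoryBarrierShared();
+ /// barrier();
+ /// ```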
+ fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
+ if flags.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{level}memoryBarrierBuffer();")?;
+ }
+ if flags.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{level}memoryBarrierShared();")?;
+ }
+ writeln!(self.out, "{level}barrier();")?;
+ Ok(())
+ }
+
+ /// Helper function that writes the glsl storage access qualifiers for a [`StorageAccess`](crate::StorageAccess)
+ ///
+ /// glsl allows combining `readonly` and `writeonly`, but a resource declared with
+ /// both can only be used to query information about itself, which isn't what we
+ /// want here, so when the storage access is both `LOAD` and `STORE` we add no
+ /// qualifiers at all.
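+ ///
+ /// E.g. `LOAD`-only access is written as `readonly `, `STORE`-only as
+ /// `writeonly `, and `LOAD | STORE` as no qualifier at all.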
+ fn write_storage_access(&mut self, storage_access: crate::StorageAccess) -> BackendResult {
+ if !storage_access.contains(crate::StorageAccess::STORE) {
+ write!(self.out, "readonly ")?;
+ }
+ if !storage_access.contains(crate::StorageAccess::LOAD) {
+ write!(self.out, "writeonly ")?;
+ }
+ Ok(())
+ }
+
+ /// Helper method used to produce the reflection info that's returned to the user
+ fn collect_reflection_info(&mut self) -> Result<ReflectionInfo, Error> {
+ use std::collections::hash_map::Entry;
+ let info = self.info.get_entry_point(self.entry_point_idx as usize);
+ let mut texture_mapping = crate::FastHashMap::default();
+ let mut uniforms = crate::FastHashMap::default();
+
+ for sampling in info.sampling_set.iter() {
+ let tex_name = self.reflection_names_globals[&sampling.image].clone();
+
+ match texture_mapping.entry(tex_name) {
+ Entry::Vacant(v) => {
+ v.insert(TextureMapping {
+ texture: sampling.image,
+ sampler: Some(sampling.sampler),
+ });
+ }
+ Entry::Occupied(e) => {
+ if e.get().sampler != Some(sampling.sampler) {
+ log::error!("Conflicting samplers for {}", e.key());
+ return Err(Error::ImageMultipleSamplers);
+ }
+ }
+ }
+ }
+
+ let mut push_constant_info = None;
+ for (handle, var) in self.module.global_variables.iter() {
+ if info[handle].is_empty() {
+ continue;
+ }
+ match self.module.types[var.ty].inner {
+ crate::TypeInner::Image { .. } => {
+ let tex_name = self.reflection_names_globals[&handle].clone();
+ match texture_mapping.entry(tex_name) {
+ Entry::Vacant(v) => {
+ v.insert(TextureMapping {
+ texture: handle,
+ sampler: None,
+ });
+ }
+ Entry::Occupied(_) => {
+ // already used with a sampler, do nothing
+ }
+ }
+ }
+ _ => match var.space {
+ crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } => {
+ let name = self.reflection_names_globals[&handle].clone();
+ uniforms.insert(handle, name);
+ }
+ crate::AddressSpace::PushConstant => {
+ let name = self.reflection_names_globals[&handle].clone();
+ push_constant_info = Some((name, var.ty));
+ }
+ _ => (),
+ },
+ }
+ }
+
+ let mut push_constant_segments = Vec::new();
+ let mut push_constant_items = vec![];
+
+ if let Some((name, ty)) = push_constant_info {
+ // We don't have a layouter available to us, so we need to create one.
+ //
+ // This is potentially a bit wasteful, but the set of types in the program
+ // shouldn't be too large.
+ let mut layouter = crate::proc::Layouter::default();
+ layouter.update(self.module.to_ctx()).unwrap();
+
+ // We start with the name of the binding itself.
+ push_constant_segments.push(name);
+
+ // We then recursively collect all the uniform fields of the push constant.
+ self.collect_push_constant_items(
+ ty,
+ &mut push_constant_segments,
+ &layouter,
+ &mut 0,
+ &mut push_constant_items,
+ );
+ }
+
+ Ok(ReflectionInfo {
+ texture_mapping,
+ uniforms,
+ varying: mem::take(&mut self.varying),
+ push_constant_items,
+ })
+ }
+
+ fn collect_push_constant_items(
+ &mut self,
+ ty: Handle<crate::Type>,
+ segments: &mut Vec<String>,
+ layouter: &crate::proc::Layouter,
+ offset: &mut u32,
+ items: &mut Vec<PushConstantItem>,
+ ) {
+ // At this point in the recursion, `segments` contains the path
+ // needed to access `ty` from the root.
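+ //
+ // For example (names hypothetical): for a push constant `pc` whose struct
+ // member `lights` is an array of structs with a `pos` field, the item for
+ // element 1's `pos` gets the access path `pc.lights[1].pos`.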
+
+ let layout = &layouter[ty];
+ *offset = layout.alignment.round_up(*offset);
+ match self.module.types[ty].inner {
+ // All these types map directly to GL uniforms.
+ TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => {
+ // Build the full name, by combining all current segments.
+ let name: String = segments.iter().map(String::as_str).collect();
+ items.push(PushConstantItem {
+ access_path: name,
+ offset: *offset,
+ ty,
+ });
+ *offset += layout.size;
+ }
+ // Arrays are recursed into.
+ TypeInner::Array { base, size, .. } => {
+ let crate::ArraySize::Constant(count) = size else {
+ unreachable!("Cannot have dynamic arrays in push constants");
+ };
+
+ for i in 0..count.get() {
+ // Add the array accessor and recurse.
+ segments.push(format!("[{}]", i));
+ self.collect_push_constant_items(base, segments, layouter, offset, items);
+ segments.pop();
+ }
+
+ // Ensure the stride is kept by rounding up to the alignment.
+ *offset = layout.alignment.round_up(*offset)
+ }
+ TypeInner::Struct { ref members, .. } => {
+ for (index, member) in members.iter().enumerate() {
+ // Add struct accessor and recurse.
+ segments.push(format!(
+ ".{}",
+ self.names[&NameKey::StructMember(ty, index as u32)]
+ ));
+ self.collect_push_constant_items(member.ty, segments, layouter, offset, items);
+ segments.pop();
+ }
+
+ // Ensure ending padding is kept by rounding up to the alignment.
+ *offset = layout.alignment.round_up(*offset)
+ }
+ _ => unreachable!(),
+ }
+ }
+}
+
+/// Structure returned by [`glsl_scalar`]
+///
+/// It contains both a prefix used in other types and the full type name
+struct ScalarString<'a> {
+ /// The prefix used to compose other types
+ prefix: &'a str,
+ /// The name of the scalar type
+ full: &'a str,
+}
+
+/// Helper function that returns scalar related strings
+///
+/// Check [`ScalarString`] for the information provided
+///
+/// # Errors
+/// Returns [`Error::UnsupportedScalar`] for a [`Float`](crate::ScalarKind::Float) with a width that isn't 4 or 8, and for abstract scalars
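+///
+/// A minimal illustration (not compiled as a doc test, since this function is
+/// private to the backend):
+///
+/// ```ignore
+/// let s = glsl_scalar(crate::Scalar { kind: crate::ScalarKind::Float, width: 8 }).unwrap();
+/// assert_eq!((s.prefix, s.full), ("d", "double")); // used to form e.g. `dvec3`
+/// ```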
+const fn glsl_scalar(scalar: crate::Scalar) -> Result<ScalarString<'static>, Error> {
+ use crate::ScalarKind as Sk;
+
+ Ok(match scalar.kind {
+ Sk::Sint => ScalarString {
+ prefix: "i",
+ full: "int",
+ },
+ Sk::Uint => ScalarString {
+ prefix: "u",
+ full: "uint",
+ },
+ Sk::Float => match scalar.width {
+ 4 => ScalarString {
+ prefix: "",
+ full: "float",
+ },
+ 8 => ScalarString {
+ prefix: "d",
+ full: "double",
+ },
+ _ => return Err(Error::UnsupportedScalar(scalar)),
+ },
+ Sk::Bool => ScalarString {
+ prefix: "b",
+ full: "bool",
+ },
+ Sk::AbstractInt | Sk::AbstractFloat => {
+ return Err(Error::UnsupportedScalar(scalar));
+ }
+ })
+}
+
+/// Helper function that returns the glsl variable name for a builtin
+const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'static str {
+ use crate::BuiltIn as Bi;
+
+ match built_in {
+ Bi::Position { .. } => {
+ if options.output {
+ "gl_Position"
+ } else {
+ "gl_FragCoord"
+ }
+ }
+ Bi::ViewIndex if options.targeting_webgl => "int(gl_ViewID_OVR)",
+ Bi::ViewIndex => "gl_ViewIndex",
+ // vertex
+ Bi::BaseInstance => "uint(gl_BaseInstance)",
+ Bi::BaseVertex => "uint(gl_BaseVertex)",
+ Bi::ClipDistance => "gl_ClipDistance",
+ Bi::CullDistance => "gl_CullDistance",
+ Bi::InstanceIndex => {
+ if options.draw_parameters {
+ "(uint(gl_InstanceID) + uint(gl_BaseInstanceARB))"
+ } else {
+ // Must match FIRST_INSTANCE_BINDING
+ "(uint(gl_InstanceID) + naga_vs_first_instance)"
+ }
+ }
+ Bi::PointSize => "gl_PointSize",
+ Bi::VertexIndex => "uint(gl_VertexID)",
+ // fragment
+ Bi::FragDepth => "gl_FragDepth",
+ Bi::PointCoord => "gl_PointCoord",
+ Bi::FrontFacing => "gl_FrontFacing",
+ Bi::PrimitiveIndex => "uint(gl_PrimitiveID)",
+ Bi::SampleIndex => "gl_SampleID",
+ Bi::SampleMask => {
+ if options.output {
+ "gl_SampleMask"
+ } else {
+ "gl_SampleMaskIn"
+ }
+ }
+ // compute
+ Bi::GlobalInvocationId => "gl_GlobalInvocationID",
+ Bi::LocalInvocationId => "gl_LocalInvocationID",
+ Bi::LocalInvocationIndex => "gl_LocalInvocationIndex",
+ Bi::WorkGroupId => "gl_WorkGroupID",
+ Bi::WorkGroupSize => "gl_WorkGroupSize",
+ Bi::NumWorkGroups => "gl_NumWorkGroups",
+ }
+}
+
+/// Helper function that returns the string corresponding to the address space
+const fn glsl_storage_qualifier(space: crate::AddressSpace) -> Option<&'static str> {
+ use crate::AddressSpace as As;
+
+ match space {
+ As::Function => None,
+ As::Private => None,
+ As::Storage { .. } => Some("buffer"),
+ As::Uniform => Some("uniform"),
+ As::Handle => Some("uniform"),
+ As::WorkGroup => Some("shared"),
+ As::PushConstant => Some("uniform"),
+ }
+}
+
+/// Helper function that returns the string corresponding to the glsl interpolation qualifier
+const fn glsl_interpolation(interpolation: crate::Interpolation) -> &'static str {
+ use crate::Interpolation as I;
+
+ match interpolation {
+ I::Perspective => "smooth",
+ I::Linear => "noperspective",
+ I::Flat => "flat",
+ }
+}
+
+/// Return the GLSL auxiliary qualifier for the given sampling value.
+const fn glsl_sampling(sampling: crate::Sampling) -> Option<&'static str> {
+ use crate::Sampling as S;
+
+ match sampling {
+ S::Center => None,
+ S::Centroid => Some("centroid"),
+ S::Sample => Some("sample"),
+ }
+}
+
+/// Helper function that returns the glsl dimension string of [`ImageDimension`](crate::ImageDimension)
+const fn glsl_dimension(dim: crate::ImageDimension) -> &'static str {
+ use crate::ImageDimension as IDim;
+
+ match dim {
+ IDim::D1 => "1D",
+ IDim::D2 => "2D",
+ IDim::D3 => "3D",
+ IDim::Cube => "Cube",
+ }
+}
+
+/// Helper function that returns the glsl storage format string of [`StorageFormat`](crate::StorageFormat)
+fn glsl_storage_format(format: crate::StorageFormat) -> Result<&'static str, Error> {
+ use crate::StorageFormat as Sf;
+
+ Ok(match format {
+ Sf::R8Unorm => "r8",
+ Sf::R8Snorm => "r8_snorm",
+ Sf::R8Uint => "r8ui",
+ Sf::R8Sint => "r8i",
+ Sf::R16Uint => "r16ui",
+ Sf::R16Sint => "r16i",
+ Sf::R16Float => "r16f",
+ Sf::Rg8Unorm => "rg8",
+ Sf::Rg8Snorm => "rg8_snorm",
+ Sf::Rg8Uint => "rg8ui",
+ Sf::Rg8Sint => "rg8i",
+ Sf::R32Uint => "r32ui",
+ Sf::R32Sint => "r32i",
+ Sf::R32Float => "r32f",
+ Sf::Rg16Uint => "rg16ui",
+ Sf::Rg16Sint => "rg16i",
+ Sf::Rg16Float => "rg16f",
+ Sf::Rgba8Unorm => "rgba8",
+ Sf::Rgba8Snorm => "rgba8_snorm",
+ Sf::Rgba8Uint => "rgba8ui",
+ Sf::Rgba8Sint => "rgba8i",
+ Sf::Rgb10a2Uint => "rgb10_a2ui",
+ Sf::Rgb10a2Unorm => "rgb10_a2",
+ Sf::Rg11b10Float => "r11f_g11f_b10f",
+ Sf::Rg32Uint => "rg32ui",
+ Sf::Rg32Sint => "rg32i",
+ Sf::Rg32Float => "rg32f",
+ Sf::Rgba16Uint => "rgba16ui",
+ Sf::Rgba16Sint => "rgba16i",
+ Sf::Rgba16Float => "rgba16f",
+ Sf::Rgba32Uint => "rgba32ui",
+ Sf::Rgba32Sint => "rgba32i",
+ Sf::Rgba32Float => "rgba32f",
+ Sf::R16Unorm => "r16",
+ Sf::R16Snorm => "r16_snorm",
+ Sf::Rg16Unorm => "rg16",
+ Sf::Rg16Snorm => "rg16_snorm",
+ Sf::Rgba16Unorm => "rgba16",
+ Sf::Rgba16Snorm => "rgba16_snorm",
+
+ Sf::Bgra8Unorm => {
+ return Err(Error::Custom(
+ "Support format BGRA8 is not implemented".into(),
+ ))
+ }
+ })
+}
+
+fn is_value_init_supported(module: &crate::Module, ty: Handle<crate::Type>) -> bool {
+ match module.types[ty].inner {
+ TypeInner::Scalar { .. } | TypeInner::Vector { .. } | TypeInner::Matrix { .. } => true,
+ TypeInner::Array { base, size, .. } => {
+ size != crate::ArraySize::Dynamic && is_value_init_supported(module, base)
+ }
+ TypeInner::Struct { ref members, .. } => members
+ .iter()
+ .all(|member| is_value_init_supported(module, member.ty)),
+ _ => false,
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs
new file mode 100644
index 0000000000..b6918ddc42
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/conv.rs
@@ -0,0 +1,222 @@
+use std::borrow::Cow;
+
+use crate::proc::Alignment;
+
+use super::Error;
+
+impl crate::ScalarKind {
+ pub(super) fn to_hlsl_cast(self) -> &'static str {
+ match self {
+ Self::Float => "asfloat",
+ Self::Sint => "asint",
+ Self::Uint => "asuint",
+ Self::Bool | Self::AbstractInt | Self::AbstractFloat => unreachable!(),
+ }
+ }
+}
+
+impl crate::Scalar {
+ /// Helper function that returns scalar related strings
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-scalar>
+ pub(super) const fn to_hlsl_str(self) -> Result<&'static str, Error> {
+ match self.kind {
+ crate::ScalarKind::Sint => Ok("int"),
+ crate::ScalarKind::Uint => Ok("uint"),
+ crate::ScalarKind::Float => match self.width {
+ 2 => Ok("half"),
+ 4 => Ok("float"),
+ 8 => Ok("double"),
+ _ => Err(Error::UnsupportedScalar(self)),
+ },
+ crate::ScalarKind::Bool => Ok("bool"),
+ crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => {
+ Err(Error::UnsupportedScalar(self))
+ }
+ }
+ }
+}
+
+impl crate::TypeInner {
+ pub(super) const fn is_matrix(&self) -> bool {
+ match *self {
+ Self::Matrix { .. } => true,
+ _ => false,
+ }
+ }
+
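+ /// Returns the size of `self` in bytes, as laid out by the HLSL backend.
+ ///
+ /// Worked example (assuming 4-byte floats): a 4-column, 3-row matrix has a
+ /// column stride of 16 bytes (a 3-vector is aligned to 16), so its size is
+ /// `(4 - 1) * 16 + 12 = 60` bytes rather than the tightly packed 48.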
+ pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> u32 {
+ match *self {
+ Self::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ let stride = Alignment::from(rows) * scalar.width as u32;
+ let last_row_size = rows as u32 * scalar.width as u32;
+ ((columns as u32 - 1) * stride) + last_row_size
+ }
+ Self::Array { base, size, stride } => {
+ let count = match size {
+ crate::ArraySize::Constant(size) => size.get(),
+ // A dynamically-sized array has to have at least one element
+ crate::ArraySize::Dynamic => 1,
+ };
+ let last_el_size = gctx.types[base].inner.size_hlsl(gctx);
+ ((count - 1) * stride) + last_el_size
+ }
+ _ => self.size(gctx),
+ }
+ }
+
+ /// Used to generate the name of the wrapped type constructor
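+ ///
+ /// E.g. `array<vec2<f32>, 4>` yields the id `array4_float2_`, so its wrapped
+ /// constructor is named `Constructarray4_float2_`.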
+ pub(super) fn hlsl_type_id<'a>(
+ base: crate::Handle<crate::Type>,
+ gctx: crate::proc::GlobalCtx,
+ names: &'a crate::FastHashMap<crate::proc::NameKey, String>,
+ ) -> Result<Cow<'a, str>, Error> {
+ Ok(match gctx.types[base].inner {
+ crate::TypeInner::Scalar(scalar) => Cow::Borrowed(scalar.to_hlsl_str()?),
+ crate::TypeInner::Vector { size, scalar } => Cow::Owned(format!(
+ "{}{}",
+ scalar.to_hlsl_str()?,
+ crate::back::vector_size_str(size)
+ )),
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => Cow::Owned(format!(
+ "{}{}x{}",
+ scalar.to_hlsl_str()?,
+ crate::back::vector_size_str(columns),
+ crate::back::vector_size_str(rows),
+ )),
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => Cow::Owned(format!(
+ "array{size}_{}_",
+ Self::hlsl_type_id(base, gctx, names)?
+ )),
+ crate::TypeInner::Struct { .. } => {
+ Cow::Borrowed(&names[&crate::proc::NameKey::Type(base)])
+ }
+ _ => unreachable!(),
+ })
+ }
+}
+
+impl crate::StorageFormat {
+ pub(super) const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::R16Float => "float",
+ Self::R8Unorm | Self::R16Unorm => "unorm float",
+ Self::R8Snorm | Self::R16Snorm => "snorm float",
+ Self::R8Uint | Self::R16Uint => "uint",
+ Self::R8Sint | Self::R16Sint => "int",
+
+ Self::Rg16Float => "float2",
+ Self::Rg8Unorm | Self::Rg16Unorm => "unorm float2",
+ Self::Rg8Snorm | Self::Rg16Snorm => "snorm float2",
+
+ Self::Rg8Sint | Self::Rg16Sint => "int2",
+ Self::Rg8Uint | Self::Rg16Uint => "uint2",
+
+ Self::Rg11b10Float => "float3",
+
+ Self::Rgba16Float | Self::R32Float | Self::Rg32Float | Self::Rgba32Float => "float4",
+ Self::Rgba8Unorm | Self::Bgra8Unorm | Self::Rgba16Unorm | Self::Rgb10a2Unorm => {
+ "unorm float4"
+ }
+ Self::Rgba8Snorm | Self::Rgba16Snorm => "snorm float4",
+
+ Self::Rgba8Uint
+ | Self::Rgba16Uint
+ | Self::R32Uint
+ | Self::Rg32Uint
+ | Self::Rgba32Uint
+ | Self::Rgb10a2Uint => "uint4",
+ Self::Rgba8Sint
+ | Self::Rgba16Sint
+ | Self::R32Sint
+ | Self::Rg32Sint
+ | Self::Rgba32Sint => "int4",
+ }
+ }
+}
+
+impl crate::BuiltIn {
+ pub(super) fn to_hlsl_str(self) -> Result<&'static str, Error> {
+ Ok(match self {
+ Self::Position { .. } => "SV_Position",
+ // vertex
+ Self::ClipDistance => "SV_ClipDistance",
+ Self::CullDistance => "SV_CullDistance",
+ Self::InstanceIndex => "SV_InstanceID",
+ Self::VertexIndex => "SV_VertexID",
+ // fragment
+ Self::FragDepth => "SV_Depth",
+ Self::FrontFacing => "SV_IsFrontFace",
+ Self::PrimitiveIndex => "SV_PrimitiveID",
+ Self::SampleIndex => "SV_SampleIndex",
+ Self::SampleMask => "SV_Coverage",
+ // compute
+ Self::GlobalInvocationId => "SV_DispatchThreadID",
+ Self::LocalInvocationId => "SV_GroupThreadID",
+ Self::LocalInvocationIndex => "SV_GroupIndex",
+ Self::WorkGroupId => "SV_GroupID",
+ // The specific semantic we use here doesn't matter, because references
+ // to this field will get replaced with references to `SPECIAL_CBUF_VAR`
+ // in `Writer::write_expr`.
+ Self::NumWorkGroups => "SV_GroupID",
+ Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
+ return Err(Error::Unimplemented(format!("builtin {self:?}")))
+ }
+ Self::PointSize | Self::ViewIndex | Self::PointCoord => {
+ return Err(Error::Custom(format!("Unsupported builtin {self:?}")))
+ }
+ })
+ }
+}
+
+impl crate::Interpolation {
+ /// Return the string corresponding to the HLSL interpolation qualifier.
+ pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
+ match self {
+ // Would be "linear", but it's the default interpolation in SM4 and up
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-struct#interpolation-modifiers-introduced-in-shader-model-4
+ Self::Perspective => None,
+ Self::Linear => Some("noperspective"),
+ Self::Flat => Some("nointerpolation"),
+ }
+ }
+}
+
+impl crate::Sampling {
+ /// Return the HLSL auxiliary qualifier for the given sampling value.
+ pub(super) const fn to_hlsl_str(self) -> Option<&'static str> {
+ match self {
+ Self::Center => None,
+ Self::Centroid => Some("centroid"),
+ Self::Sample => Some("sample"),
+ }
+ }
+}
+
+impl crate::AtomicFunction {
+ /// Return the HLSL suffix for the `InterlockedXxx` method.
+ pub(super) const fn to_hlsl_suffix(self) -> &'static str {
+ match self {
+ Self::Add | Self::Subtract => "Add",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { .. } => "", //TODO
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs
new file mode 100644
index 0000000000..fa6062a1ad
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/help.rs
@@ -0,0 +1,1138 @@
+/*!
+Helpers for the hlsl backend
+
+Important note about `Expression::ImageQuery`/`Expression::ArrayLength` and the hlsl backend:
+
+Due to the way the `GetDimensions` function works in hlsl (<https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>),
+the backend can't translate these queries as plain expressions.
+Instead, it generates a unique wrapper function per `Expression::ImageQuery`, based on the texture info and the query kind.
+See the `WrappedImageQuery` struct, which identifies such a function; the wrappers are generated before writing any statements and expressions.
+This lets the backend treat `Expression::ImageQuery` as an expression and simply call the wrapper function.
+
+For example:
+```wgsl
+let dim_1d = textureDimensions(image_1d);
+```
+
+```hlsl
+int NagaDimensions1D(Texture1D<float4> tex)
+{
+ uint4 ret;
+ tex.GetDimensions(ret.x);
+ return ret.x;
+}
+
+int dim_1d = NagaDimensions1D(image_1d);
+```
+*/
+
+use super::{super::FunctionCtx, BackendResult};
+use crate::{arena::Handle, proc::NameKey};
+use std::fmt::Write;
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedArrayLength {
+ pub(super) writable: bool,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedImageQuery {
+ pub(super) dim: crate::ImageDimension,
+ pub(super) arrayed: bool,
+ pub(super) class: crate::ImageClass,
+ pub(super) query: ImageQuery,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedConstructor {
+ pub(super) ty: Handle<crate::Type>,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedStructMatrixAccess {
+ pub(super) ty: Handle<crate::Type>,
+ pub(super) index: u32,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedMatCx2 {
+ pub(super) columns: crate::VectorSize,
+}
+
+/// The HLSL backend requires its own `ImageQuery` enum.
+///
+/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
+/// The IR version can't be unique per function, because it stores the mipmap level as an expression.
+///
+/// For example:
+/// ```wgsl
+/// let dim_cube_array_lod = textureDimensions(image_cube_array, 1);
+/// let dim_cube_array_lod2 = textureDimensions(image_cube_array, 1);
+/// ```
+///
+/// ```ir
+/// ImageQuery {
+/// image: [1],
+/// query: Size {
+/// level: Some(
+/// [1],
+/// ),
+/// },
+/// },
+/// ImageQuery {
+/// image: [1],
+/// query: Size {
+/// level: Some(
+/// [2],
+/// ),
+/// },
+/// },
+/// ```
+///
+/// HLSL should generate only 1 function for this case.
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) enum ImageQuery {
+ Size,
+ SizeLevel,
+ NumLevels,
+ NumLayers,
+ NumSamples,
+}
+
+impl From<crate::ImageQuery> for ImageQuery {
+ fn from(q: crate::ImageQuery) -> Self {
+ use crate::ImageQuery as Iq;
+ match q {
+ Iq::Size { level: Some(_) } => ImageQuery::SizeLevel,
+ Iq::Size { level: None } => ImageQuery::Size,
+ Iq::NumLevels => ImageQuery::NumLevels,
+ Iq::NumLayers => ImageQuery::NumLayers,
+ Iq::NumSamples => ImageQuery::NumSamples,
+ }
+ }
+}
+
+impl<'a, W: Write> super::Writer<'a, W> {
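+ /// Writes the HLSL type of an image.
+ ///
+ /// E.g. an arrayed `Rgba8Unorm` storage image becomes
+ /// `RWTexture2DArray<unorm float4>`, and a multisampled sampled float image
+ /// becomes `Texture2DMS<float4>`.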
+ pub(super) fn write_image_type(
+ &mut self,
+ dim: crate::ImageDimension,
+ arrayed: bool,
+ class: crate::ImageClass,
+ ) -> BackendResult {
+ let access_str = match class {
+ crate::ImageClass::Storage { .. } => "RW",
+ _ => "",
+ };
+ let dim_str = dim.to_hlsl_str();
+ let arrayed_str = if arrayed { "Array" } else { "" };
+ write!(self.out, "{access_str}Texture{dim_str}{arrayed_str}")?;
+ match class {
+ crate::ImageClass::Depth { multi } => {
+ let multi_str = if multi { "MS" } else { "" };
+ write!(self.out, "{multi_str}<float>")?
+ }
+ crate::ImageClass::Sampled { kind, multi } => {
+ let multi_str = if multi { "MS" } else { "" };
+ let scalar_kind_str = crate::Scalar { kind, width: 4 }.to_hlsl_str()?;
+ write!(self.out, "{multi_str}<{scalar_kind_str}4>")?
+ }
+ crate::ImageClass::Storage { format, .. } => {
+ let storage_format_str = format.to_hlsl_str();
+ write!(self.out, "<{storage_format_str}>")?
+ }
+ }
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_array_length_function_name(
+ &mut self,
+ query: WrappedArrayLength,
+ ) -> BackendResult {
+ let access_str = if query.writable { "RW" } else { "" };
+ write!(self.out, "NagaBufferLength{access_str}",)?;
+
+ Ok(())
+ }
+
+ /// Helper function that writes the wrapped function for `Expression::ArrayLength`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer-getdimensions>
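+ ///
+ /// For the read-only case this generates:
+ ///
+ /// ```hlsl
+ /// uint NagaBufferLength(ByteAddressBuffer buffer)
+ /// {
+ ///     uint ret;
+ ///     buffer.GetDimensions(ret);
+ ///     return ret;
+ /// }
+ /// ```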
+ pub(super) fn write_wrapped_array_length_function(
+ &mut self,
+ wal: WrappedArrayLength,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const ARGUMENT_VARIABLE_NAME: &str = "buffer";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ write!(self.out, "uint ")?;
+ self.write_wrapped_array_length_function_name(wal)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let access_str = if wal.writable { "RW" } else { "" };
+ writeln!(
+ self.out,
+ "{access_str}ByteAddressBuffer {ARGUMENT_VARIABLE_NAME})"
+ )?;
+ // Write function body
+ writeln!(self.out, "{{")?;
+
+ // Write `GetDimensions` function.
+ writeln!(self.out, "{INDENT}uint {RETURN_VARIABLE_NAME};")?;
+ writeln!(
+ self.out,
+ "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions({RETURN_VARIABLE_NAME});"
+ )?;
+
+ // Write return value
+ writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_image_query_function_name(
+ &mut self,
+ query: WrappedImageQuery,
+ ) -> BackendResult {
+ let dim_str = query.dim.to_hlsl_str();
+ let class_str = match query.class {
+ crate::ImageClass::Sampled { multi: true, .. } => "MS",
+ crate::ImageClass::Depth { multi: true } => "DepthMS",
+ crate::ImageClass::Depth { multi: false } => "Depth",
+ crate::ImageClass::Sampled { multi: false, .. } => "",
+ crate::ImageClass::Storage { .. } => "RW",
+ };
+ let arrayed_str = if query.arrayed { "Array" } else { "" };
+ let query_str = match query.query {
+ ImageQuery::Size => "Dimensions",
+ ImageQuery::SizeLevel => "MipDimensions",
+ ImageQuery::NumLevels => "NumLevels",
+ ImageQuery::NumLayers => "NumLayers",
+ ImageQuery::NumSamples => "NumSamples",
+ };
+
+ write!(self.out, "Naga{class_str}{query_str}{dim_str}{arrayed_str}")?;
+
+ Ok(())
+ }
+
+ /// Helper function that writes the wrapped function for `Expression::ImageQuery`
+ ///
+ /// <https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions>
+ pub(super) fn write_wrapped_image_query_function(
+ &mut self,
+ module: &crate::Module,
+ wiq: WrappedImageQuery,
+ expr_handle: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ use crate::{
+ back::{COMPONENTS, INDENT},
+ ImageDimension as IDim,
+ };
+
+ const ARGUMENT_VARIABLE_NAME: &str = "tex";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+ const MIP_LEVEL_PARAM: &str = "mip_level";
+
+ // Write function return type and name
+ let ret_ty = func_ctx.resolve_type(expr_handle, &module.types);
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_image_query_function_name(wiq)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ // The texture is always the first parameter
+ self.write_image_type(wiq.dim, wiq.arrayed, wiq.class)?;
+ write!(self.out, " {ARGUMENT_VARIABLE_NAME}")?;
+ // The mip level is a second parameter, if present
+ if let ImageQuery::SizeLevel = wiq.query {
+ write!(self.out, ", uint {MIP_LEVEL_PARAM}")?;
+ }
+ writeln!(self.out, ")")?;
+
+ // Write function body
+ writeln!(self.out, "{{")?;
+
+ let array_coords = usize::from(wiq.arrayed);
+ // extra parameter is the mip level count or the sample count
+ let extra_coords = match wiq.class {
+ crate::ImageClass::Storage { .. } => 0,
+ crate::ImageClass::Sampled { .. } | crate::ImageClass::Depth { .. } => 1,
+ };
+
+ // GetDimensions Overloaded Methods
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-getdimensions#overloaded-methods
+ let (ret_swizzle, number_of_params) = match wiq.query {
+ ImageQuery::Size | ImageQuery::SizeLevel => {
+ let ret = match wiq.dim {
+ IDim::D1 => "x",
+ IDim::D2 => "xy",
+ IDim::D3 => "xyz",
+ IDim::Cube => "xy",
+ };
+ (ret, ret.len() + array_coords + extra_coords)
+ }
+ ImageQuery::NumLevels | ImageQuery::NumSamples | ImageQuery::NumLayers => {
+ if wiq.arrayed || wiq.dim == IDim::D3 {
+ ("w", 4)
+ } else {
+ ("z", 3)
+ }
+ }
+ };
+
+ // Write `GetDimensions` function.
+ writeln!(self.out, "{INDENT}uint4 {RETURN_VARIABLE_NAME};")?;
+ write!(self.out, "{INDENT}{ARGUMENT_VARIABLE_NAME}.GetDimensions(")?;
+ match wiq.query {
+ ImageQuery::SizeLevel => {
+ write!(self.out, "{MIP_LEVEL_PARAM}, ")?;
+ }
+ _ => match wiq.class {
+ crate::ImageClass::Sampled { multi: true, .. }
+ | crate::ImageClass::Depth { multi: true }
+ | crate::ImageClass::Storage { .. } => {}
+ _ => {
+ // Write zero mipmap level for supported types
+ write!(self.out, "0, ")?;
+ }
+ },
+ }
+
+ for component in COMPONENTS[..number_of_params - 1].iter() {
+ write!(self.out, "{RETURN_VARIABLE_NAME}.{component}, ")?;
+ }
+
+ // Write the last parameter without a trailing comma and space
+ write!(
+ self.out,
+ "{}.{}",
+ RETURN_VARIABLE_NAME,
+ COMPONENTS[number_of_params - 1]
+ )?;
+
+ writeln!(self.out, ");")?;
+
+ // Write return value
+ writeln!(
+ self.out,
+ "{INDENT}return {RETURN_VARIABLE_NAME}.{ret_swizzle};"
+ )?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_constructor_function_name(
+ &mut self,
+ module: &crate::Module,
+ constructor: WrappedConstructor,
+ ) -> BackendResult {
+ let name = crate::TypeInner::hlsl_type_id(constructor.ty, module.to_ctx(), &self.names)?;
+ write!(self.out, "Construct{name}")?;
+ Ok(())
+ }
+
+ /// Helper function that writes the wrapped constructor function used by `Expression::Compose` for structures and arrays.
+ pub(super) fn write_wrapped_constructor_function(
+ &mut self,
+ module: &crate::Module,
+ constructor: WrappedConstructor,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const ARGUMENT_VARIABLE_NAME: &str = "arg";
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ if let crate::TypeInner::Array { base, size, .. } = module.types[constructor.ty].inner {
+ write!(self.out, "typedef ")?;
+ self.write_type(module, constructor.ty)?;
+ write!(self.out, " ret_")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ self.write_array_size(module, base, size)?;
+ writeln!(self.out, ";")?;
+
+ write!(self.out, "ret_")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ } else {
+ self.write_type(module, constructor.ty)?;
+ }
+ write!(self.out, " ")?;
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+
+ let mut write_arg = |i, ty| -> BackendResult {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_type(module, ty)?;
+ write!(self.out, " {ARGUMENT_VARIABLE_NAME}{i}")?;
+ if let crate::TypeInner::Array { base, size, .. } = module.types[ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ Ok(())
+ };
+
+ match module.types[constructor.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (i, member) in members.iter().enumerate() {
+ write_arg(i, member.ty)?;
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ for i in 0..size.get() as usize {
+ write_arg(i, base)?;
+ }
+ }
+ _ => unreachable!(),
+ };
+
+ write!(self.out, ")")?;
+
+ // Write function body
+ writeln!(self.out, " {{")?;
+
+ match module.types[constructor.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let struct_name = &self.names[&NameKey::Type(constructor.ty)];
+ writeln!(
+ self.out,
+ "{INDENT}{struct_name} {RETURN_VARIABLE_NAME} = ({struct_name})0;"
+ )?;
+ for (i, member) in members.iter().enumerate() {
+ let field_name = &self.names[&NameKey::StructMember(constructor.ty, i as u32)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix {
+ columns,
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ for j in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}{RETURN_VARIABLE_NAME}.{field_name}_{j} = {ARGUMENT_VARIABLE_NAME}{i}[{j}];"
+ )?;
+ }
+ }
+ ref other => {
+ // We cast arrays of native HLSL `floatCx2`s to arrays of `matCx2`s
+ // (where the inner matrix is represented by a struct with C `float2` members).
+ // See the module-level block comment in mod.rs for details.
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, member.ty)
+ {
+ write!(
+ self.out,
+ "{}{}.{} = (__mat{}x2",
+ INDENT, RETURN_VARIABLE_NAME, field_name, columns as u8
+ )?;
+ if let crate::TypeInner::Array { base, size, .. } = *other {
+ self.write_array_size(module, base, size)?;
+ }
+ writeln!(self.out, "){ARGUMENT_VARIABLE_NAME}{i};",)?;
+ } else {
+ writeln!(
+ self.out,
+ "{INDENT}{RETURN_VARIABLE_NAME}.{field_name} = {ARGUMENT_VARIABLE_NAME}{i};",
+ )?;
+ }
+ }
+ }
+ }
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => {
+ write!(self.out, "{INDENT}")?;
+ self.write_type(module, base)?;
+ write!(self.out, " {RETURN_VARIABLE_NAME}")?;
+ self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
+ write!(self.out, " = {{ ")?;
+ for i in 0..size.get() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{ARGUMENT_VARIABLE_NAME}{i}")?;
+ }
+ writeln!(self.out, " }};",)?;
+ }
+ _ => unreachable!(),
+ }
+
+ // Write return value
+ writeln!(self.out, "{INDENT}return {RETURN_VARIABLE_NAME};")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_get_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "GetMat{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to get a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_get_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+
+ // Write function return type and name
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let ret_ty = &module.types[member.ty].inner;
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, " ")?;
+ self.write_wrapped_struct_matrix_get_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}")?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ // Write return value
+ write!(self.out, "{INDENT}return ")?;
+ self.write_value_type(module, ret_ty)?;
+ write!(self.out, "(")?;
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+ writeln!(self.out, ");")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMat{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const MATRIX_ARGUMENT_VARIABLE_NAME: &str = "mat";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ self.write_type(module, member.ty)?;
+ write!(self.out, " {MATRIX_ARGUMENT_VARIABLE_NAME}")?;
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}{STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {MATRIX_ARGUMENT_VARIABLE_NAME}[{i}];"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_vec_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMatVec{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a vec2 on a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_vec_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const VECTOR_ARGUMENT_VARIABLE_NAME: &str = "vec";
+ const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let vec_ty = match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { rows, scalar, .. } => {
+ crate::TypeInner::Vector { size: rows, scalar }
+ }
+ _ => unreachable!(),
+ };
+ self.write_value_type(module, &vec_ty)?;
+ write!(
+ self.out,
+ " {VECTOR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}"
+ )?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ writeln!(
+ self.out,
+ "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
+ )?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i} = {VECTOR_ARGUMENT_VARIABLE_NAME}; break; }}"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ writeln!(self.out, "{INDENT}}}")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_wrapped_struct_matrix_set_scalar_function_name(
+ &mut self,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Type(access.ty)];
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+ write!(self.out, "SetMatScalar{field_name}On{name}")?;
+ Ok(())
+ }
+
+ /// Writes a function used to set a float on a matCx2 from within a structure.
+ pub(super) fn write_wrapped_struct_matrix_set_scalar_function(
+ &mut self,
+ module: &crate::Module,
+ access: WrappedStructMatrixAccess,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const STRUCT_ARGUMENT_VARIABLE_NAME: &str = "obj";
+ const SCALAR_ARGUMENT_VARIABLE_NAME: &str = "scalar";
+ const MATRIX_INDEX_ARGUMENT_VARIABLE_NAME: &str = "mat_idx";
+ const VECTOR_INDEX_ARGUMENT_VARIABLE_NAME: &str = "vec_idx";
+
+ // Write function return type and name
+ write!(self.out, "void ")?;
+ self.write_wrapped_struct_matrix_set_scalar_function_name(access)?;
+
+ // Write function parameters
+ write!(self.out, "(")?;
+ let struct_name = &self.names[&NameKey::Type(access.ty)];
+ write!(self.out, "{struct_name} {STRUCT_ARGUMENT_VARIABLE_NAME}, ")?;
+ let member = match module.types[access.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => &members[access.index as usize],
+ _ => unreachable!(),
+ };
+ let scalar_ty = match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { scalar, .. } => crate::TypeInner::Scalar(scalar),
+ _ => unreachable!(),
+ };
+ self.write_value_type(module, &scalar_ty)?;
+ write!(
+ self.out,
+ " {SCALAR_ARGUMENT_VARIABLE_NAME}, uint {MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}, uint {VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}"
+ )?;
+
+ // Write function body
+ writeln!(self.out, ") {{")?;
+
+ writeln!(
+ self.out,
+ "{INDENT}switch({MATRIX_INDEX_ARGUMENT_VARIABLE_NAME}) {{"
+ )?;
+
+ let field_name = &self.names[&NameKey::StructMember(access.ty, access.index)];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix { columns, .. } => {
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ {STRUCT_ARGUMENT_VARIABLE_NAME}.{field_name}_{i}[{VECTOR_INDEX_ARGUMENT_VARIABLE_NAME}] = {SCALAR_ARGUMENT_VARIABLE_NAME}; break; }}"
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ writeln!(self.out, "{INDENT}}}")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ /// Write functions to create special types.
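+ ///
+ /// For example, the `ModfResult` of a `f32` produces a function of this shape
+ /// (the struct and function names are generated; placeholders shown here):
+ ///
+ /// ```hlsl
+ /// __modf_result_f32 naga_modf(float arg) {
+ ///     float other;
+ ///     __modf_result_f32 result;
+ ///     result.fract = modf(arg, other);
+ ///     result.whole = other;
+ ///     return result;
+ /// }
+ /// ```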
+ pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult {
+ for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
+ match type_key {
+ &crate::PredeclaredType::ModfResult { size, width }
+ | &crate::PredeclaredType::FrexpResult { size, width } => {
+ let arg_type_name_owner;
+ let arg_type_name = if let Some(size) = size {
+ arg_type_name_owner = format!(
+ "{}{}",
+ if width == 8 { "double" } else { "float" },
+ size as u8
+ );
+ &arg_type_name_owner
+ } else if width == 8 {
+ "double"
+ } else {
+ "float"
+ };
+
+ let (defined_func_name, called_func_name, second_field_name, sign_multiplier) =
+ if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
+ (super::writer::MODF_FUNCTION, "modf", "whole", "")
+ } else {
+ (
+ super::writer::FREXP_FUNCTION,
+ "frexp",
+ "exp_",
+ "sign(arg) * ",
+ )
+ };
+
+ let struct_name = &self.names[&NameKey::Type(*struct_ty)];
+
+ writeln!(
+ self.out,
+ "{struct_name} {defined_func_name}({arg_type_name} arg) {{
+ {arg_type_name} other;
+ {struct_name} result;
+ result.fract = {sign_multiplier}{called_func_name}(arg, other);
+ result.{second_field_name} = other;
+ return result;
+}}"
+ )?;
+ writeln!(self.out)?;
+ }
+ &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper function that writes wrapped constructor functions for `Compose` expressions
+ pub(super) fn write_wrapped_compose_functions(
+ &mut self,
+ module: &crate::Module,
+ expressions: &crate::Arena<crate::Expression>,
+ ) -> BackendResult {
+ for (handle, _) in expressions.iter() {
+ if let crate::Expression::Compose { ty, .. } = expressions[handle] {
+ match module.types[ty].inner {
+ crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => {
+ let constructor = WrappedConstructor { ty };
+ if self.wrapped.constructors.insert(constructor) {
+ self.write_wrapped_constructor_function(module, constructor)?;
+ }
+ }
+ _ => {}
+ };
+ }
+ }
+ Ok(())
+ }
+
+ /// Helper function that writes various wrapped functions
+ pub(super) fn write_wrapped_functions(
+ &mut self,
+ module: &crate::Module,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ self.write_wrapped_compose_functions(module, func_ctx.expressions)?;
+
+ for (handle, _) in func_ctx.expressions.iter() {
+ match func_ctx.expressions[handle] {
+ crate::Expression::ArrayLength(expr) => {
+ let global_expr = match func_ctx.expressions[expr] {
+ crate::Expression::GlobalVariable(_) => expr,
+ crate::Expression::AccessIndex { base, index: _ } => base,
+ ref other => unreachable!("Array length of {:?}", other),
+ };
+ let global_var = match func_ctx.expressions[global_expr] {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &module.global_variables[var_handle]
+ }
+ ref other => unreachable!("Array length of base {:?}", other),
+ };
+ let storage_access = match global_var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => crate::StorageAccess::default(),
+ };
+ let wal = WrappedArrayLength {
+ writable: storage_access.contains(crate::StorageAccess::STORE),
+ };
+
+ if self.wrapped.array_lengths.insert(wal) {
+ self.write_wrapped_array_length_function(wal)?;
+ }
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ let wiq = match *func_ctx.resolve_type(image, &module.types) {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => WrappedImageQuery {
+ dim,
+ arrayed,
+ class,
+ query: query.into(),
+ },
+ _ => unreachable!("we only query images"),
+ };
+
+ if self.wrapped.image_queries.insert(wiq) {
+ self.write_wrapped_image_query_function(module, wiq, handle, func_ctx)?;
+ }
+ }
+ // Write `WrappedConstructor` for structs that are loaded from `AddressSpace::Storage`
+ // since they will later be used by the fn `write_storage_load`
+ crate::Expression::Load { pointer } => {
+ let pointer_space = func_ctx
+ .resolve_type(pointer, &module.types)
+ .pointer_space();
+
+ if let Some(crate::AddressSpace::Storage { .. }) = pointer_space {
+ if let Some(ty) = func_ctx.info[handle].ty.handle() {
+ write_wrapped_constructor(self, ty, module)?;
+ }
+ }
+
+ fn write_wrapped_constructor<W: Write>(
+ writer: &mut super::Writer<'_, W>,
+ ty: Handle<crate::Type>,
+ module: &crate::Module,
+ ) -> BackendResult {
+ match module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for member in members {
+ write_wrapped_constructor(writer, member.ty, module)?;
+ }
+
+ let constructor = WrappedConstructor { ty };
+ if writer.wrapped.constructors.insert(constructor) {
+ writer
+ .write_wrapped_constructor_function(module, constructor)?;
+ }
+ }
+ crate::TypeInner::Array { base, .. } => {
+ write_wrapped_constructor(writer, base, module)?;
+
+ let constructor = WrappedConstructor { ty };
+ if writer.wrapped.constructors.insert(constructor) {
+ writer
+ .write_wrapped_constructor_function(module, constructor)?;
+ }
+ }
+ _ => {}
+ };
+
+ Ok(())
+ }
+ }
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s
+ // (see top level module docs for details).
+ //
+ // The functions injected here are required to get the matrix accesses working.
+ crate::Expression::AccessIndex { base, index } => {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ let base_ty_handle = match *resolved {
+ crate::TypeInner::Pointer { base, .. } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+ if let crate::TypeInner::Struct { ref members, .. } = *resolved {
+ let member = &members[index as usize];
+
+ match module.types[member.ty].inner {
+ crate::TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ let ty = base_ty_handle.unwrap();
+ let access = WrappedStructMatrixAccess { ty, index };
+
+ if self.wrapped.struct_matrix_access.insert(access) {
+ self.write_wrapped_struct_matrix_get_function(module, access)?;
+ self.write_wrapped_struct_matrix_set_function(module, access)?;
+ self.write_wrapped_struct_matrix_set_vec_function(
+ module, access,
+ )?;
+ self.write_wrapped_struct_matrix_set_scalar_function(
+ module, access,
+ )?;
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ _ => {}
+ };
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn write_texture_coordinates(
+ &mut self,
+ kind: &str,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ mip_level: Option<Handle<crate::Expression>>,
+ module: &crate::Module,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ // HLSL expects the array index to be merged with the coordinate
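+ // e.g. a 2D integer coordinate plus an array index becomes `int3(coord, index)`.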
+ let extra = array_index.is_some() as usize + mip_level.is_some() as usize;
+ if extra == 0 {
+ self.write_expr(module, coordinate, func_ctx)?;
+ } else {
+ let num_coords = match *func_ctx.resolve_type(coordinate, &module.types) {
+ crate::TypeInner::Scalar { .. } => 1,
+ crate::TypeInner::Vector { size, .. } => size as usize,
+ _ => unreachable!(),
+ };
+ write!(self.out, "{}{}(", kind, num_coords + extra)?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ if let Some(expr) = mip_level {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Ok(())
+ }
+
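+ /// Writes the `__matCx2` typedef and its column/element accessor helpers.
+ /// For three columns, the typedef looks like:
+ ///
+ /// ```hlsl
+ /// typedef struct { float2 _0; float2 _1; float2 _2; } __mat3x2;
+ /// ```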
+ pub(super) fn write_mat_cx2_typedef_and_functions(
+ &mut self,
+ WrappedMatCx2 { columns }: WrappedMatCx2,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ // typedef
+ write!(self.out, "typedef struct {{ ")?;
+ for i in 0..columns as u8 {
+ write!(self.out, "float2 _{i}; ")?;
+ }
+ writeln!(self.out, "}} __mat{}x2;", columns as u8)?;
+
+ // __get_col_of_mat
+ writeln!(
+ self.out,
+ "float2 __get_col_of_mat{}x2(__mat{}x2 mat, uint idx) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(self.out, "{INDENT}case {i}: {{ return mat._{i}; }}")?;
+ }
+ writeln!(self.out, "{INDENT}default: {{ return (float2)0; }}")?;
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ // __set_col_of_mat
+ writeln!(
+ self.out,
+ "void __set_col_of_mat{}x2(__mat{}x2 mat, uint idx, float2 value) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(self.out, "{INDENT}case {i}: {{ mat._{i} = value; break; }}")?;
+ }
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ // __set_el_of_mat
+ writeln!(
+ self.out,
+ "void __set_el_of_mat{}x2(__mat{}x2 mat, uint idx, uint vec_idx, float value) {{",
+ columns as u8, columns as u8
+ )?;
+ writeln!(self.out, "{INDENT}switch(idx) {{")?;
+ for i in 0..columns as u8 {
+ writeln!(
+ self.out,
+ "{INDENT}case {i}: {{ mat._{i}[vec_idx] = value; break; }}"
+ )?;
+ }
+ writeln!(self.out, "{INDENT}}}")?;
+ writeln!(self.out, "}}")?;
+
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ pub(super) fn write_all_mat_cx2_typedefs_and_functions(
+ &mut self,
+ module: &crate::Module,
+ ) -> BackendResult {
+ for (handle, _) in module.global_variables.iter() {
+ let global = &module.global_variables[handle];
+
+ if global.space == crate::AddressSpace::Uniform {
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, global.ty)
+ {
+ let entry = WrappedMatCx2 { columns };
+ if self.wrapped.mat_cx2s.insert(entry) {
+ self.write_mat_cx2_typedef_and_functions(entry)?;
+ }
+ }
+ }
+ }
+
+ for (_, ty) in module.types.iter() {
+ if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
+ for member in members.iter() {
+ if let crate::TypeInner::Array { .. } = module.types[member.ty].inner {
+ if let Some(super::writer::MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = super::writer::get_inner_matrix_data(module, member.ty)
+ {
+ let entry = WrappedMatCx2 { columns };
+ if self.wrapped.mat_cx2s.insert(entry) {
+ self.write_mat_cx2_typedef_and_functions(entry)?;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/keywords.rs b/third_party/rust/naga/src/back/hlsl/keywords.rs
new file mode 100644
index 0000000000..059e533ff7
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/keywords.rs
@@ -0,0 +1,904 @@
+// When compiling with FXC without strict mode, these keywords are actually case insensitive.
+// If you compile with strict mode and specify a different casing like "Pass" instead in an identifier, FXC will give this error:
+// "error X3086: alternate cases for 'pass' are deprecated in strict mode"
+// This behavior is not documented anywhere, but as far as I can tell this is the full list.
+pub const RESERVED_CASE_INSENSITIVE: &[&str] = &[
+ "asm",
+ "decl",
+ "pass",
+ "technique",
+ "Texture1D",
+ "Texture2D",
+ "Texture3D",
+ "TextureCube",
+];
+
+pub const RESERVED: &[&str] = &[
+ // FXC keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-keywords.md?plain=1#L99-L118
+ "AppendStructuredBuffer",
+ "asm",
+ "asm_fragment",
+ "BlendState",
+ "bool",
+ "break",
+ "Buffer",
+ "ByteAddressBuffer",
+ "case",
+ "cbuffer",
+ "centroid",
+ "class",
+ "column_major",
+ "compile",
+ "compile_fragment",
+ "CompileShader",
+ "const",
+ "continue",
+ "ComputeShader",
+ "ConsumeStructuredBuffer",
+ "default",
+ "DepthStencilState",
+ "DepthStencilView",
+ "discard",
+ "do",
+ "double",
+ "DomainShader",
+ "dword",
+ "else",
+ "export",
+ "extern",
+ "false",
+ "float",
+ "for",
+ "fxgroup",
+ "GeometryShader",
+ "groupshared",
+ "half",
+ "Hullshader",
+ "if",
+ "in",
+ "inline",
+ "inout",
+ "InputPatch",
+ "int",
+ "interface",
+ "line",
+ "lineadj",
+ "linear",
+ "LineStream",
+ "matrix",
+ "min16float",
+ "min10float",
+ "min16int",
+ "min12int",
+ "min16uint",
+ "namespace",
+ "nointerpolation",
+ "noperspective",
+ "NULL",
+ "out",
+ "OutputPatch",
+ "packoffset",
+ "pass",
+ "pixelfragment",
+ "PixelShader",
+ "point",
+ "PointStream",
+ "precise",
+ "RasterizerState",
+ "RenderTargetView",
+ "return",
+ "register",
+ "row_major",
+ "RWBuffer",
+ "RWByteAddressBuffer",
+ "RWStructuredBuffer",
+ "RWTexture1D",
+ "RWTexture1DArray",
+ "RWTexture2D",
+ "RWTexture2DArray",
+ "RWTexture3D",
+ "sample",
+ "sampler",
+ "SamplerState",
+ "SamplerComparisonState",
+ "shared",
+ "snorm",
+ "stateblock",
+ "stateblock_state",
+ "static",
+ "string",
+ "struct",
+ "switch",
+ "StructuredBuffer",
+ "tbuffer",
+ "technique",
+ "technique10",
+ "technique11",
+ "texture",
+ "Texture1D",
+ "Texture1DArray",
+ "Texture2D",
+ "Texture2DArray",
+ "Texture2DMS",
+ "Texture2DMSArray",
+ "Texture3D",
+ "TextureCube",
+ "TextureCubeArray",
+ "true",
+ "typedef",
+ "triangle",
+ "triangleadj",
+ "TriangleStream",
+ "uint",
+ "uniform",
+ "unorm",
+ "unsigned",
+ "vector",
+ "vertexfragment",
+ "VertexShader",
+ "void",
+ "volatile",
+ "while",
+ // FXC reserved keywords, from https://github.com/MicrosoftDocs/win32/blob/c885cb0c63b0e9be80c6a0e6512473ac6f4e771e/desktop-src/direct3dhlsl/dx-graphics-hlsl-appendix-reserved-words.md?plain=1#L19-L38
+ "auto",
+ "case",
+ "catch",
+ "char",
+ "class",
+ "const_cast",
+ "default",
+ "delete",
+ "dynamic_cast",
+ "enum",
+ "explicit",
+ "friend",
+ "goto",
+ "long",
+ "mutable",
+ "new",
+ "operator",
+ "private",
+ "protected",
+ "public",
+ "reinterpret_cast",
+ "short",
+ "signed",
+ "sizeof",
+ "static_cast",
+ "template",
+ "this",
+ "throw",
+ "try",
+ "typename",
+ "union",
+ "unsigned",
+ "using",
+ "virtual",
+ // FXC intrinsics, from https://github.com/MicrosoftDocs/win32/blob/1682b99e203708f6f5eda972d966e30f3c1588de/desktop-src/direct3dhlsl/dx-graphics-hlsl-intrinsic-functions.md?plain=1#L26-L165
+ "abort",
+ "abs",
+ "acos",
+ "all",
+ "AllMemoryBarrier",
+ "AllMemoryBarrierWithGroupSync",
+ "any",
+ "asdouble",
+ "asfloat",
+ "asin",
+ "asint",
+ "asuint",
+ "atan",
+ "atan2",
+ "ceil",
+ "CheckAccessFullyMapped",
+ "clamp",
+ "clip",
+ "cos",
+ "cosh",
+ "countbits",
+ "cross",
+ "D3DCOLORtoUBYTE4",
+ "ddx",
+ "ddx_coarse",
+ "ddx_fine",
+ "ddy",
+ "ddy_coarse",
+ "ddy_fine",
+ "degrees",
+ "determinant",
+ "DeviceMemoryBarrier",
+ "DeviceMemoryBarrierWithGroupSync",
+ "distance",
+ "dot",
+ "dst",
+ "errorf",
+ "EvaluateAttributeCentroid",
+ "EvaluateAttributeAtSample",
+ "EvaluateAttributeSnapped",
+ "exp",
+ "exp2",
+ "f16tof32",
+ "f32tof16",
+ "faceforward",
+ "firstbithigh",
+ "firstbitlow",
+ "floor",
+ "fma",
+ "fmod",
+ "frac",
+ "frexp",
+ "fwidth",
+ "GetRenderTargetSampleCount",
+ "GetRenderTargetSamplePosition",
+ "GroupMemoryBarrier",
+ "GroupMemoryBarrierWithGroupSync",
+ "InterlockedAdd",
+ "InterlockedAnd",
+ "InterlockedCompareExchange",
+ "InterlockedCompareStore",
+ "InterlockedExchange",
+ "InterlockedMax",
+ "InterlockedMin",
+ "InterlockedOr",
+ "InterlockedXor",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "ldexp",
+ "length",
+ "lerp",
+ "lit",
+ "log",
+ "log10",
+ "log2",
+ "mad",
+ "max",
+ "min",
+ "modf",
+ "msad4",
+ "mul",
+ "noise",
+ "normalize",
+ "pow",
+ "printf",
+ "Process2DQuadTessFactorsAvg",
+ "Process2DQuadTessFactorsMax",
+ "Process2DQuadTessFactorsMin",
+ "ProcessIsolineTessFactors",
+ "ProcessQuadTessFactorsAvg",
+ "ProcessQuadTessFactorsMax",
+ "ProcessQuadTessFactorsMin",
+ "ProcessTriTessFactorsAvg",
+ "ProcessTriTessFactorsMax",
+ "ProcessTriTessFactorsMin",
+ "radians",
+ "rcp",
+ "reflect",
+ "refract",
+ "reversebits",
+ "round",
+ "rsqrt",
+ "saturate",
+ "sign",
+ "sin",
+ "sincos",
+ "sinh",
+ "smoothstep",
+ "sqrt",
+ "step",
+ "tan",
+ "tanh",
+ "tex1D",
+ "tex1Dbias",
+ "tex1Dgrad",
+ "tex1Dlod",
+ "tex1Dproj",
+ "tex2D",
+ "tex2Dbias",
+ "tex2Dgrad",
+ "tex2Dlod",
+ "tex2Dproj",
+ "tex3D",
+ "tex3Dbias",
+ "tex3Dgrad",
+ "tex3Dlod",
+ "tex3Dproj",
+ "texCUBE",
+ "texCUBEbias",
+ "texCUBEgrad",
+ "texCUBElod",
+ "texCUBEproj",
+ "transpose",
+ "trunc",
+ // DXC (reserved) keywords, from https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/include/clang/Basic/TokenKinds.def#L222-L648
+ // with the KEYALL, KEYCXX, BOOLSUPPORT, WCHARSUPPORT, KEYHLSL options enabled (see https://github.com/microsoft/DirectXShaderCompiler/blob/d5d478470d3020a438d3cb810b8d3fe0992e6709/tools/clang/lib/Frontend/CompilerInvocation.cpp#L1199)
+ "auto",
+ "break",
+ "case",
+ "char",
+ "const",
+ "continue",
+ "default",
+ "do",
+ "double",
+ "else",
+ "enum",
+ "extern",
+ "float",
+ "for",
+ "goto",
+ "if",
+ "inline",
+ "int",
+ "long",
+ "register",
+ "return",
+ "short",
+ "signed",
+ "sizeof",
+ "static",
+ "struct",
+ "switch",
+ "typedef",
+ "union",
+ "unsigned",
+ "void",
+ "volatile",
+ "while",
+ "_Alignas",
+ "_Alignof",
+ "_Atomic",
+ "_Complex",
+ "_Generic",
+ "_Imaginary",
+ "_Noreturn",
+ "_Static_assert",
+ "_Thread_local",
+ "__func__",
+ "__objc_yes",
+ "__objc_no",
+ "asm",
+ "bool",
+ "catch",
+ "class",
+ "const_cast",
+ "delete",
+ "dynamic_cast",
+ "explicit",
+ "export",
+ "false",
+ "friend",
+ "mutable",
+ "namespace",
+ "new",
+ "operator",
+ "private",
+ "protected",
+ "public",
+ "reinterpret_cast",
+ "static_cast",
+ "template",
+ "this",
+ "throw",
+ "true",
+ "try",
+ "typename",
+ "typeid",
+ "using",
+ "virtual",
+ "wchar_t",
+ "_Decimal32",
+ "_Decimal64",
+ "_Decimal128",
+ "__null",
+ "__alignof",
+ "__attribute",
+ "__builtin_choose_expr",
+ "__builtin_offsetof",
+ "__builtin_va_arg",
+ "__extension__",
+ "__imag",
+ "__int128",
+ "__label__",
+ "__real",
+ "__thread",
+ "__FUNCTION__",
+ "__PRETTY_FUNCTION__",
+ "__is_nothrow_assignable",
+ "__is_constructible",
+ "__is_nothrow_constructible",
+ "__has_nothrow_assign",
+ "__has_nothrow_move_assign",
+ "__has_nothrow_copy",
+ "__has_nothrow_constructor",
+ "__has_trivial_assign",
+ "__has_trivial_move_assign",
+ "__has_trivial_copy",
+ "__has_trivial_constructor",
+ "__has_trivial_move_constructor",
+ "__has_trivial_destructor",
+ "__has_virtual_destructor",
+ "__is_abstract",
+ "__is_base_of",
+ "__is_class",
+ "__is_convertible_to",
+ "__is_empty",
+ "__is_enum",
+ "__is_final",
+ "__is_literal",
+ "__is_literal_type",
+ "__is_pod",
+ "__is_polymorphic",
+ "__is_trivial",
+ "__is_union",
+ "__is_trivially_constructible",
+ "__is_trivially_copyable",
+ "__is_trivially_assignable",
+ "__underlying_type",
+ "__is_lvalue_expr",
+ "__is_rvalue_expr",
+ "__is_arithmetic",
+ "__is_floating_point",
+ "__is_integral",
+ "__is_complete_type",
+ "__is_void",
+ "__is_array",
+ "__is_function",
+ "__is_reference",
+ "__is_lvalue_reference",
+ "__is_rvalue_reference",
+ "__is_fundamental",
+ "__is_object",
+ "__is_scalar",
+ "__is_compound",
+ "__is_pointer",
+ "__is_member_object_pointer",
+ "__is_member_function_pointer",
+ "__is_member_pointer",
+ "__is_const",
+ "__is_volatile",
+ "__is_standard_layout",
+ "__is_signed",
+ "__is_unsigned",
+ "__is_same",
+ "__is_convertible",
+ "__array_rank",
+ "__array_extent",
+ "__private_extern__",
+ "__module_private__",
+ "__declspec",
+ "__cdecl",
+ "__stdcall",
+ "__fastcall",
+ "__thiscall",
+ "__vectorcall",
+ "cbuffer",
+ "tbuffer",
+ "packoffset",
+ "linear",
+ "centroid",
+ "nointerpolation",
+ "noperspective",
+ "sample",
+ "column_major",
+ "row_major",
+ "in",
+ "out",
+ "inout",
+ "uniform",
+ "precise",
+ "center",
+ "shared",
+ "groupshared",
+ "discard",
+ "snorm",
+ "unorm",
+ "point",
+ "line",
+ "lineadj",
+ "triangle",
+ "triangleadj",
+ "globallycoherent",
+ "interface",
+ "sampler_state",
+ "technique",
+ "indices",
+ "vertices",
+ "primitives",
+ "payload",
+ "Technique",
+ "technique10",
+ "technique11",
+ "__builtin_omp_required_simd_align",
+ "__pascal",
+ "__fp16",
+ "__alignof__",
+ "__asm",
+ "__asm__",
+ "__attribute__",
+ "__complex",
+ "__complex__",
+ "__const",
+ "__const__",
+ "__decltype",
+ "__imag__",
+ "__inline",
+ "__inline__",
+ "__nullptr",
+ "__real__",
+ "__restrict",
+ "__restrict__",
+ "__signed",
+ "__signed__",
+ "__typeof",
+ "__typeof__",
+ "__volatile",
+ "__volatile__",
+ "_Nonnull",
+ "_Nullable",
+ "_Null_unspecified",
+ "__builtin_convertvector",
+ "__char16_t",
+ "__char32_t",
+ // DXC intrinsics, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/utils/hct/gen_intrin_main.txt#L86-L376
+ "D3DCOLORtoUBYTE4",
+ "GetRenderTargetSampleCount",
+ "GetRenderTargetSamplePosition",
+ "abort",
+ "abs",
+ "acos",
+ "all",
+ "AllMemoryBarrier",
+ "AllMemoryBarrierWithGroupSync",
+ "any",
+ "asdouble",
+ "asfloat",
+ "asfloat16",
+ "asint16",
+ "asin",
+ "asint",
+ "asuint",
+ "asuint16",
+ "atan",
+ "atan2",
+ "ceil",
+ "clamp",
+ "clip",
+ "cos",
+ "cosh",
+ "countbits",
+ "cross",
+ "ddx",
+ "ddx_coarse",
+ "ddx_fine",
+ "ddy",
+ "ddy_coarse",
+ "ddy_fine",
+ "degrees",
+ "determinant",
+ "DeviceMemoryBarrier",
+ "DeviceMemoryBarrierWithGroupSync",
+ "distance",
+ "dot",
+ "dst",
+ "EvaluateAttributeAtSample",
+ "EvaluateAttributeCentroid",
+ "EvaluateAttributeSnapped",
+ "GetAttributeAtVertex",
+ "exp",
+ "exp2",
+ "f16tof32",
+ "f32tof16",
+ "faceforward",
+ "firstbithigh",
+ "firstbitlow",
+ "floor",
+ "fma",
+ "fmod",
+ "frac",
+ "frexp",
+ "fwidth",
+ "GroupMemoryBarrier",
+ "GroupMemoryBarrierWithGroupSync",
+ "InterlockedAdd",
+ "InterlockedMin",
+ "InterlockedMax",
+ "InterlockedAnd",
+ "InterlockedOr",
+ "InterlockedXor",
+ "InterlockedCompareStore",
+ "InterlockedExchange",
+ "InterlockedCompareExchange",
+ "InterlockedCompareStoreFloatBitwise",
+ "InterlockedCompareExchangeFloatBitwise",
+ "isfinite",
+ "isinf",
+ "isnan",
+ "ldexp",
+ "length",
+ "lerp",
+ "lit",
+ "log",
+ "log10",
+ "log2",
+ "mad",
+ "max",
+ "min",
+ "modf",
+ "msad4",
+ "mul",
+ "normalize",
+ "pow",
+ "printf",
+ "Process2DQuadTessFactorsAvg",
+ "Process2DQuadTessFactorsMax",
+ "Process2DQuadTessFactorsMin",
+ "ProcessIsolineTessFactors",
+ "ProcessQuadTessFactorsAvg",
+ "ProcessQuadTessFactorsMax",
+ "ProcessQuadTessFactorsMin",
+ "ProcessTriTessFactorsAvg",
+ "ProcessTriTessFactorsMax",
+ "ProcessTriTessFactorsMin",
+ "radians",
+ "rcp",
+ "reflect",
+ "refract",
+ "reversebits",
+ "round",
+ "rsqrt",
+ "saturate",
+ "sign",
+ "sin",
+ "sincos",
+ "sinh",
+ "smoothstep",
+ "source_mark",
+ "sqrt",
+ "step",
+ "tan",
+ "tanh",
+ "tex1D",
+ "tex1Dbias",
+ "tex1Dgrad",
+ "tex1Dlod",
+ "tex1Dproj",
+ "tex2D",
+ "tex2Dbias",
+ "tex2Dgrad",
+ "tex2Dlod",
+ "tex2Dproj",
+ "tex3D",
+ "tex3Dbias",
+ "tex3Dgrad",
+ "tex3Dlod",
+ "tex3Dproj",
+ "texCUBE",
+ "texCUBEbias",
+ "texCUBEgrad",
+ "texCUBElod",
+ "texCUBEproj",
+ "transpose",
+ "trunc",
+ "CheckAccessFullyMapped",
+ "AddUint64",
+ "NonUniformResourceIndex",
+ "WaveIsFirstLane",
+ "WaveGetLaneIndex",
+ "WaveGetLaneCount",
+ "WaveActiveAnyTrue",
+ "WaveActiveAllTrue",
+ "WaveActiveAllEqual",
+ "WaveActiveBallot",
+ "WaveReadLaneAt",
+ "WaveReadLaneFirst",
+ "WaveActiveCountBits",
+ "WaveActiveSum",
+ "WaveActiveProduct",
+ "WaveActiveBitAnd",
+ "WaveActiveBitOr",
+ "WaveActiveBitXor",
+ "WaveActiveMin",
+ "WaveActiveMax",
+ "WavePrefixCountBits",
+ "WavePrefixSum",
+ "WavePrefixProduct",
+ "WaveMatch",
+ "WaveMultiPrefixBitAnd",
+ "WaveMultiPrefixBitOr",
+ "WaveMultiPrefixBitXor",
+ "WaveMultiPrefixCountBits",
+ "WaveMultiPrefixProduct",
+ "WaveMultiPrefixSum",
+ "QuadReadLaneAt",
+ "QuadReadAcrossX",
+ "QuadReadAcrossY",
+ "QuadReadAcrossDiagonal",
+ "QuadAny",
+ "QuadAll",
+ "TraceRay",
+ "ReportHit",
+ "CallShader",
+ "IgnoreHit",
+ "AcceptHitAndEndSearch",
+ "DispatchRaysIndex",
+ "DispatchRaysDimensions",
+ "WorldRayOrigin",
+ "WorldRayDirection",
+ "ObjectRayOrigin",
+ "ObjectRayDirection",
+ "RayTMin",
+ "RayTCurrent",
+ "PrimitiveIndex",
+ "InstanceID",
+ "InstanceIndex",
+ "GeometryIndex",
+ "HitKind",
+ "RayFlags",
+ "ObjectToWorld",
+ "WorldToObject",
+ "ObjectToWorld3x4",
+ "WorldToObject3x4",
+ "ObjectToWorld4x3",
+ "WorldToObject4x3",
+ "dot4add_u8packed",
+ "dot4add_i8packed",
+ "dot2add",
+ "unpack_s8s16",
+ "unpack_u8u16",
+ "unpack_s8s32",
+ "unpack_u8u32",
+ "pack_s8",
+ "pack_u8",
+ "pack_clamp_s8",
+ "pack_clamp_u8",
+ "SetMeshOutputCounts",
+ "DispatchMesh",
+ "IsHelperLane",
+ "AllocateRayQuery",
+ "CreateResourceFromHeap",
+ "and",
+ "or",
+ "select",
+ // DXC resource and other types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/HlslTypes.cpp#L441-#L572
+ "InputPatch",
+ "OutputPatch",
+ "PointStream",
+ "LineStream",
+ "TriangleStream",
+ "Texture1D",
+ "RWTexture1D",
+ "Texture2D",
+ "RWTexture2D",
+ "Texture2DMS",
+ "RWTexture2DMS",
+ "Texture3D",
+ "RWTexture3D",
+ "TextureCube",
+ "RWTextureCube",
+ "Texture1DArray",
+ "RWTexture1DArray",
+ "Texture2DArray",
+ "RWTexture2DArray",
+ "Texture2DMSArray",
+ "RWTexture2DMSArray",
+ "TextureCubeArray",
+ "RWTextureCubeArray",
+ "FeedbackTexture2D",
+ "FeedbackTexture2DArray",
+ "RasterizerOrderedTexture1D",
+ "RasterizerOrderedTexture2D",
+ "RasterizerOrderedTexture3D",
+ "RasterizerOrderedTexture1DArray",
+ "RasterizerOrderedTexture2DArray",
+ "RasterizerOrderedBuffer",
+ "RasterizerOrderedByteAddressBuffer",
+ "RasterizerOrderedStructuredBuffer",
+ "ByteAddressBuffer",
+ "RWByteAddressBuffer",
+ "StructuredBuffer",
+ "RWStructuredBuffer",
+ "AppendStructuredBuffer",
+ "ConsumeStructuredBuffer",
+ "Buffer",
+ "RWBuffer",
+ "SamplerState",
+ "SamplerComparisonState",
+ "ConstantBuffer",
+ "TextureBuffer",
+ "RaytracingAccelerationStructure",
+ // DXC templated types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp
+ // look for `BuiltinTypeDeclBuilder`
+ "matrix",
+ "vector",
+ "TextureBuffer",
+ "ConstantBuffer",
+ "RayQuery",
+ // Naga utilities
+ super::writer::MODF_FUNCTION,
+ super::writer::FREXP_FUNCTION,
+];
+
+// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254
+// + vector and matrix shorthands
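+// Each scalar root expands to the bare name, the vector names `<root>1`
+// through `<root>4`, and the matrix names `<root>1x1` through `<root>4x4`:
+// 23 roots * (1 + 4 + 16) = 483 entries, matching `L` below.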
+pub const TYPES: &[&str] = &{
+ const L: usize = 23 * (1 + 4 + 4 * 4);
+ let mut res = [""; L];
+ let mut c = 0;
+
+ /// For each scalar type, the macro also generates the vector and matrix shorthands.
+ macro_rules! generate {
+ ([$($roots:literal),*], $x:tt) => {
+ $(
+ generate!(@inner push $roots);
+ generate!(@inner $roots, $x);
+ )*
+ };
+
+ (@inner $root:literal, [$($x:literal),*]) => {
+ generate!(@inner vector $root, $($x)*);
+ generate!(@inner matrix $root, $($x)*);
+ };
+
+ (@inner vector $root:literal, $($x:literal)*) => {
+ $(
+ generate!(@inner push concat!($root, $x));
+ )*
+ };
+
+ (@inner matrix $root:literal, $($x:literal)*) => {
+ // Duplicate the list
+ generate!(@inner matrix $root, $($x)*; $($x)*);
+ };
+
+ // The head/tail recursion: pick the first element of the first list and recursively do it for the tail.
+ (@inner matrix $root:literal, $head:literal $($tail:literal)*; $($x:literal)*) => {
+ $(
+ generate!(@inner push concat!($root, $head, "x", $x));
+ )*
+ generate!(@inner matrix $root, $($tail)*; $($x)*);
+
+ };
+
+ // The end of iteration: we exhausted the list
+ (@inner matrix $root:literal, ; $($x:literal)*) => {};
+
+ (@inner push $v:expr) => {
+ res[c] = $v;
+ c += 1;
+ };
+ }
+
+ generate!(
+ [
+ "bool",
+ "int",
+ "uint",
+ "dword",
+ "half",
+ "float",
+ "double",
+ "min10float",
+ "min16float",
+ "min12int",
+ "min16int",
+ "min16uint",
+ "int16_t",
+ "int32_t",
+ "int64_t",
+ "uint16_t",
+ "uint32_t",
+ "uint64_t",
+ "float16_t",
+ "float32_t",
+ "float64_t",
+ "int8_t4_packed",
+ "uint8_t4_packed"
+ ],
+ ["1", "2", "3", "4"]
+ );
+
+ debug_assert!(c == L);
+
+ res
+};
diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs
new file mode 100644
index 0000000000..37ddbd3d67
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/mod.rs
@@ -0,0 +1,302 @@
+/*!
+Backend for [HLSL][hlsl] (High-Level Shading Language).
+
+# Supported shader model versions:
+- 5.0
+- 5.1
+- 6.0
+
+# Layout of values in `uniform` buffers
+
+WGSL's ["Internal Layout of Values"][ilov] rules specify how each WGSL
+type should be stored in `uniform` and `storage` buffers. The HLSL we
+generate must access values in that form, even when it is not what
+HLSL would use normally.
+
+The rules described here only apply to WGSL `uniform` variables. WGSL
+`storage` buffers are translated as HLSL `ByteAddressBuffers`, for
+which we generate `Load` and `Store` method calls with explicit byte
+offsets. WGSL pipeline inputs must be scalars or vectors; they cannot
+be matrices, which is where the interesting problems arise.
+
+## Row- and column-major ordering for matrices
+
+WGSL specifies that matrices in uniform buffers are stored in
+column-major order. This matches HLSL's default, so one might expect
+things to be straightforward. Unfortunately, WGSL and HLSL disagree on
+what indexing a matrix means: in WGSL, `m[i]` retrieves the `i`'th
+*column* of `m`, whereas in HLSL it retrieves the `i`'th *row*. We
+want to avoid translating `m[i]` into some complicated reassembly of a
+vector from individually fetched components, so this is a problem.
+
+However, with a bit of trickery, it is possible to use HLSL's `m[i]`
+as the translation of WGSL's `m[i]`:
+
+- We declare all matrices in uniform buffers in HLSL with the
+ `row_major` qualifier, and transpose the row and column counts: a
+ WGSL `mat3x4<f32>`, say, becomes an HLSL `row_major float3x4`. (Note
+ that WGSL and HLSL type names put the row and column in reverse
+ order.) Since the HLSL type is the transpose of how WebGPU directs
+ the user to store the data, HLSL will load all matrices transposed.
+
+- Since matrices are transposed, an HLSL indexing expression retrieves
+ the "columns" of the intended WGSL value, as desired.
+
+- For vector-matrix multiplication, since `mul(transpose(m), v)` is
+ equivalent to `mul(v, m)` (note the reversal of the arguments), and
+ `mul(v, transpose(m))` is equivalent to `mul(m, v)`, we can
+ translate WGSL `m * v` and `v * m` to HLSL by simply reversing the
+ arguments to `mul`.
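+
+For example, the WGSL product `m * v` can be emitted by reversing the
+arguments to `mul`. A sketch of the resulting HLSL (illustrative, not
+the writer's verbatim output):
+
+```ignore
+// WGSL: m * v. The HLSL `m` was declared `row_major` with transposed
+// dimensions, so it holds the transpose of the WGSL value, and
+// reversing the arguments computes the intended product:
+mul(v, m)
+```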
+
+## Padding in two-row matrices
+
+An HLSL `row_major floatKx2` matrix has padding between its rows that
+the WGSL `matKx2<f32>` matrix it represents does not. HLSL stores all
+matrix rows [aligned on 16-byte boundaries][16bb], whereas WGSL says
+that the columns of a `matKx2<f32>` need only be [aligned as required
+for `vec2<f32>`][ilov], which is [eight-byte alignment][8bb].
+
+To compensate for this, any time a `matKx2<f32>` appears in a WGSL
+`uniform` variable, whether directly as the variable's type or as part
+of a struct/array, we actually emit `K` separate `float2` members, and
+assemble/disassemble the matrix from its columns (in WGSL; rows in
+HLSL) upon load and store.
+
+For example, the following WGSL struct type:
+
+```ignore
+struct Baz {
+ m: mat3x2<f32>,
+}
+```
+
+is rendered as the HLSL struct type:
+
+```ignore
+struct Baz {
+ float2 m_0; float2 m_1; float2 m_2;
+};
+```
+
+The `wrapped_struct_matrix` functions in `help.rs` generate HLSL
+helper functions to access such members, converting between the stored
+form and the HLSL matrix types appropriately. For example, for reading
+the member `m` of the `Baz` struct above, we emit:
+
+```ignore
+float3x2 GetMatmOnBaz(Baz obj) {
+ return float3x2(obj.m_0, obj.m_1, obj.m_2);
+}
+```
+
+We also emit an analogous `Set` function, as well as functions for
+accessing individual columns by dynamic index.
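+
+The analogous `Set` function plausibly has this shape (a sketch
+following the naming scheme above, not the writer's verbatim output):
+
+```ignore
+void SetMatmOnBaz(Baz obj, float3x2 mat) {
+    obj.m_0 = mat[0];
+    obj.m_1 = mat[1];
+    obj.m_2 = mat[2];
+}
+```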
+
+[hlsl]: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl
+[ilov]: https://gpuweb.github.io/gpuweb/wgsl/#internal-value-layout
+[16bb]: https://github.com/microsoft/DirectXShaderCompiler/wiki/Buffer-Packing#constant-buffer-packing
+[8bb]: https://gpuweb.github.io/gpuweb/wgsl/#alignment-and-size
+*/
+
+mod conv;
+mod help;
+mod keywords;
+mod storage;
+mod writer;
+
+use std::fmt::Error as FmtError;
+use thiserror::Error;
+
+use crate::{back, proc};
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BindTarget {
+ pub space: u8,
+ pub register: u32,
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that the map itself can be hashed.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
+
+/// An HLSL shader model version.
+#[allow(non_snake_case, non_camel_case_types)]
+#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum ShaderModel {
+ V5_0,
+ V5_1,
+ V6_0,
+}
+
+impl ShaderModel {
+ pub const fn to_str(self) -> &'static str {
+ match self {
+ Self::V5_0 => "5_0",
+ Self::V5_1 => "5_1",
+ Self::V6_0 => "6_0",
+ }
+ }
+}
+
+impl crate::ShaderStage {
+ pub const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::Vertex => "vs",
+ Self::Fragment => "ps",
+ Self::Compute => "cs",
+ }
+ }
+}
+
+impl crate::ImageDimension {
+ const fn to_hlsl_str(self) -> &'static str {
+ match self {
+ Self::D1 => "1D",
+ Self::D2 => "2D",
+ Self::D3 => "3D",
+ Self::Cube => "Cube",
+ }
+ }
+}
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum EntryPointError {
+ #[error("mapping of {0:?} is missing")]
+ MissingBinding(crate::ResourceBinding),
+}
+
+/// Configuration used in the [`Writer`].
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// The HLSL shader model to be used.
+ pub shader_model: ShaderModel,
+ /// Mapping of resource bindings to binding targets.
+ pub binding_map: BindingMap,
+ /// Don't panic on missing bindings; instead, generate HLSL with fabricated binding locations.
+ pub fake_missing_bindings: bool,
+ /// Add special constants to `SV_VertexIndex` and `SV_InstanceIndex`,
+ /// to make them work like in Vulkan/Metal, with the help of the host.
+ pub special_constants_binding: Option<BindTarget>,
+ /// Bind target of the push constant buffer
+ pub push_constants_target: Option<BindTarget>,
+ /// Should workgroup variables be zero initialized (by polyfilling)?
+ pub zero_initialize_workgroup_memory: bool,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ shader_model: ShaderModel::V5_1,
+ binding_map: BindingMap::default(),
+ fake_missing_bindings: true,
+ special_constants_binding: None,
+ push_constants_target: None,
+ zero_initialize_workgroup_memory: true,
+ }
+ }
+}
+
+impl Options {
+ fn resolve_resource_binding(
+ &self,
+ res_binding: &crate::ResourceBinding,
+ ) -> Result<BindTarget, EntryPointError> {
+ match self.binding_map.get(res_binding) {
+ Some(target) => Ok(target.clone()),
+ None if self.fake_missing_bindings => Ok(BindTarget {
+ space: res_binding.group as u8,
+ register: res_binding.binding,
+ binding_array_size: None,
+ }),
+ None => Err(EntryPointError::MissingBinding(res_binding.clone())),
+ }
+ }
+}
+
+/// Reflection info for entry point names.
+#[derive(Default)]
+pub struct ReflectionInfo {
+ /// Mapping of the entry point names.
+ ///
+ /// Each item in the array corresponds to an entry point index. The real entry point name may be different if one of the
+ /// reserved words is used.
+ ///
+ /// Note: Some entry points may fail translation because of missing bindings.
+ pub entry_point_names: Vec<Result<String, EntryPointError>>,
+}
+
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error(transparent)]
+ IoError(#[from] FmtError),
+ #[error("A scalar with an unsupported width was requested: {0:?}")]
+ UnsupportedScalar(crate::Scalar),
+ #[error("{0}")]
+ Unimplemented(String), // TODO: Error used only during development
+ #[error("{0}")]
+ Custom(String),
+}
+
+#[derive(Default)]
+struct Wrapped {
+ array_lengths: crate::FastHashSet<help::WrappedArrayLength>,
+ image_queries: crate::FastHashSet<help::WrappedImageQuery>,
+ constructors: crate::FastHashSet<help::WrappedConstructor>,
+ struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>,
+ mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>,
+}
+
+impl Wrapped {
+ fn clear(&mut self) {
+ self.array_lengths.clear();
+ self.image_queries.clear();
+ self.constructors.clear();
+ self.struct_matrix_access.clear();
+ self.mat_cx2s.clear();
+ }
+}
+
+pub struct Writer<'a, W> {
+ out: W,
+ names: crate::FastHashMap<proc::NameKey, String>,
+ namer: proc::Namer,
+ /// HLSL backend options
+ options: &'a Options,
+ /// Information about entry point arguments and result types.
+ entry_point_io: Vec<writer::EntryPointInterface>,
+ /// Set of expressions that have associated temporary variables
+ named_expressions: crate::NamedExpressions,
+ wrapped: Wrapped,
+
+ /// A reference to some part of a global variable, lowered to a series of
+ /// byte offset calculations.
+ ///
+ /// See the [`storage`] module for background on why we need this.
+ ///
+ /// Each [`SubAccess`] in the vector is a lowering of some [`Access`] or
+ /// [`AccessIndex`] expression to the level of byte strides and offsets. See
+ /// [`SubAccess`] for details.
+ ///
+ /// This field is a member of [`Writer`] solely to allow re-use of
+ /// the `Vec`'s dynamic allocation. The value is no longer needed
+ /// once HLSL for the access has been generated.
+ ///
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`SubAccess`]: storage::SubAccess
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ temp_access_chain: Vec<storage::SubAccess>,
+ need_bake_expressions: back::NeedBakeExpressions,
+}
diff --git a/third_party/rust/naga/src/back/hlsl/storage.rs b/third_party/rust/naga/src/back/hlsl/storage.rs
new file mode 100644
index 0000000000..1b8a6ec12d
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/storage.rs
@@ -0,0 +1,494 @@
+/*!
+Generating accesses to [`ByteAddressBuffer`] contents.
+
+Naga IR globals in the [`Storage`] address space are rendered as
+[`ByteAddressBuffer`]s or [`RWByteAddressBuffer`]s in HLSL. These
+buffers don't have HLSL types (structs, arrays, etc.); instead, they
+are just raw blocks of bytes, with methods to load and store values of
+specific types at particular byte offsets. This means that Naga must
+translate chains of [`Access`] and [`AccessIndex`] expressions into
+HLSL expressions that compute byte offsets into the buffer.
+
+To generate code for a [`Storage`] access:
+
+- Call [`Writer::fill_access_chain`] on the expression referring to
+ the value. This populates [`Writer::temp_access_chain`] with the
+ appropriate byte offset calculations, as a vector of [`SubAccess`]
+ values.
+
+- Call [`Writer::write_storage_address`] to emit an HLSL expression
+ for a given slice of [`SubAccess`] values.
+
+Naga IR expressions can operate on composite values of any type, but
+[`ByteAddressBuffer`] and [`RWByteAddressBuffer`] have only a fixed
+set of `Load` and `Store` methods, to access one through four
+consecutive 32-bit values. To synthesize a Naga access, you can
+initialize [`temp_access_chain`] to refer to the composite, and then
+temporarily push and pop additional steps on
+[`Writer::temp_access_chain`] to generate accesses to the individual
+elements/members.
+
+The [`temp_access_chain`] field is a member of [`Writer`] solely to
+allow re-use of the `Vec`'s dynamic allocation. Its value is no longer
+needed once HLSL for the access has been generated.
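+
+As an illustrative sketch: accessing `global.member[i]`, where `member`
+sits at byte offset 16 within its struct and the array stride is 8,
+produces a chain containing `Index { value: i, stride: 8 }` and
+`Offset(16)`, which [`Writer::write_storage_address`] renders as the
+HLSL offset expression `i*8+16`.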
+
+[`Storage`]: crate::AddressSpace::Storage
+[`ByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-byteaddressbuffer
+[`RWByteAddressBuffer`]: https://learn.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer
+[`Access`]: crate::Expression::Access
+[`AccessIndex`]: crate::Expression::AccessIndex
+[`Writer::fill_access_chain`]: super::Writer::fill_access_chain
+[`Writer::write_storage_address`]: super::Writer::write_storage_address
+[`Writer::temp_access_chain`]: super::Writer::temp_access_chain
+[`temp_access_chain`]: super::Writer::temp_access_chain
+[`Writer`]: super::Writer
+*/
+
+use super::{super::FunctionCtx, BackendResult, Error};
+use crate::{
+ proc::{Alignment, NameKey, TypeResolution},
+ Handle,
+};
+
+use std::{fmt, mem};
+
+const STORE_TEMP_NAME: &str = "_value";
+
+/// One step in accessing a [`Storage`] global's component or element.
+///
+/// [`Writer::temp_access_chain`] holds a series of these structures,
+/// describing how to compute the byte offset of a particular element
+/// or member of some global variable in the [`Storage`] address
+/// space.
+///
+/// [`Writer::temp_access_chain`]: super::Writer::temp_access_chain
+/// [`Storage`]: crate::AddressSpace::Storage
+#[derive(Debug)]
+pub(super) enum SubAccess {
+ /// Add the given byte offset. This is used for struct members, or
+ /// known components of a vector or matrix. In all those cases,
+ /// the byte offset is a compile-time constant.
+ Offset(u32),
+
+ /// Scale `value` by `stride`, and add that to the current byte
+ /// offset. This is used to compute the offset of an array element
+ /// whose index is computed at runtime.
+ Index {
+ value: Handle<crate::Expression>,
+ stride: u32,
+ },
+}
+
+pub(super) enum StoreValue {
+ Expression(Handle<crate::Expression>),
+ TempIndex {
+ depth: usize,
+ index: u32,
+ ty: TypeResolution,
+ },
+ TempAccess {
+ depth: usize,
+ base: Handle<crate::Type>,
+ member_index: u32,
+ },
+}
+
+impl<W: fmt::Write> super::Writer<'_, W> {
+ pub(super) fn write_storage_address(
+ &mut self,
+ module: &crate::Module,
+ chain: &[SubAccess],
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ if chain.is_empty() {
+ write!(self.out, "0")?;
+ }
+ for (i, access) in chain.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, "+")?;
+ }
+ match *access {
+ SubAccess::Offset(offset) => {
+ write!(self.out, "{offset}")?;
+ }
+ SubAccess::Index { value, stride } => {
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, "*{stride}")?;
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn write_storage_load_sequence<I: Iterator<Item = (TypeResolution, u32)>>(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ sequence: I,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ for (i, (ty_resolution, offset)) in sequence.enumerate() {
+ // add the index temporarily
+ self.temp_access_chain.push(SubAccess::Offset(offset));
+ if i != 0 {
+ write!(self.out, ", ")?;
+ };
+ self.write_storage_load(module, var_handle, ty_resolution, func_ctx)?;
+ self.temp_access_chain.pop();
+ }
+ Ok(())
+ }
+
+ /// Emit code to access a [`Storage`] global's component.
+ ///
+ /// Emit HLSL to access the component of `var_handle`, a global
+ /// variable in the [`Storage`] address space, whose type is
+ /// `result_ty` and whose location within the global is given by
+ /// [`self.temp_access_chain`]. See the [`storage`] module's
+ /// documentation for background.
+ ///
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`self.temp_access_chain`]: super::Writer::temp_access_chain
+ pub(super) fn write_storage_load(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ result_ty: TypeResolution,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ match *result_ty.inner_with(&module.types) {
+ crate::TypeInner::Scalar(scalar) => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ let cast = scalar.kind.to_hlsl_cast();
+ write!(self.out, "{cast}({var_name}.Load(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, "))")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Vector { size, scalar } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ let cast = scalar.kind.to_hlsl_cast();
+ write!(self.out, "{}({}.Load{}(", cast, var_name, size as u8)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, "))")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ write!(
+ self.out,
+ "{}{}x{}(",
+ scalar.to_hlsl_str()?,
+ columns as u8,
+ rows as u8,
+ )?;
+
+ // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
+ let row_stride = Alignment::from(rows) * scalar.width as u32;
+ let iter = (0..columns as u32).map(|i| {
+ let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
+ (TypeResolution::Value(ty_inner), i * row_stride)
+ });
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ stride,
+ } => {
+ let constructor = super::help::WrappedConstructor {
+ ty: result_ty.handle().unwrap(),
+ };
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ write!(self.out, "(")?;
+ let iter = (0..size.get()).map(|i| (TypeResolution::Handle(base), stride * i));
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ let constructor = super::help::WrappedConstructor {
+ ty: result_ty.handle().unwrap(),
+ };
+ self.write_wrapped_constructor_function_name(module, constructor)?;
+ write!(self.out, "(")?;
+ let iter = members
+ .iter()
+ .map(|m| (TypeResolution::Handle(m.ty), m.offset));
+ self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+ }
+
+ fn write_store_value(
+ &mut self,
+ module: &crate::Module,
+ value: &StoreValue,
+ func_ctx: &FunctionCtx,
+ ) -> BackendResult {
+ match *value {
+ StoreValue::Expression(expr) => self.write_expr(module, expr, func_ctx)?,
+ StoreValue::TempIndex {
+ depth,
+ index,
+ ty: _,
+ } => write!(self.out, "{STORE_TEMP_NAME}{depth}[{index}]")?,
+ StoreValue::TempAccess {
+ depth,
+ base,
+ member_index,
+ } => {
+ let name = &self.names[&NameKey::StructMember(base, member_index)];
+ write!(self.out, "{STORE_TEMP_NAME}{depth}.{name}")?
+ }
+ }
+ Ok(())
+ }
+
+ /// Helper function to write down the Store operation on a `ByteAddressBuffer`.
+ pub(super) fn write_storage_store(
+ &mut self,
+ module: &crate::Module,
+ var_handle: Handle<crate::GlobalVariable>,
+ value: StoreValue,
+ func_ctx: &FunctionCtx,
+ level: crate::back::Level,
+ ) -> BackendResult {
+ let temp_resolution;
+ let ty_resolution = match value {
+ StoreValue::Expression(expr) => &func_ctx.info[expr].ty,
+ StoreValue::TempIndex {
+ depth: _,
+ index: _,
+ ref ty,
+ } => ty,
+ StoreValue::TempAccess {
+ depth: _,
+ base,
+ member_index,
+ } => {
+ let ty_handle = match module.types[base].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ members[member_index as usize].ty
+ }
+ _ => unreachable!(),
+ };
+ temp_resolution = TypeResolution::Handle(ty_handle);
+ &temp_resolution
+ }
+ };
+ match *ty_resolution.inner_with(&module.types) {
+ crate::TypeInner::Scalar(_) => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{level}{var_name}.Store(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, ", asuint(")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, "));")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Vector { size, .. } => {
+ // working around the borrow checker in `self.write_expr`
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{}{}.Store{}(", level, var_name, size as u8)?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ write!(self.out, ", asuint(")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, "));")?;
+ self.temp_access_chain = chain;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ let depth = level.0 + 1;
+ write!(
+ self.out,
+ "{}{}{}x{} {}{} = ",
+ level.next(),
+ scalar.to_hlsl_str()?,
+ columns as u8,
+ rows as u8,
+ STORE_TEMP_NAME,
+ depth,
+ )?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+
+ // Note: Matrices containing vec3s, due to padding, act like they contain vec4s.
+ let row_stride = Alignment::from(rows) * scalar.width as u32;
+
+ // then iterate the stores
+ for i in 0..columns as u32 {
+ self.temp_access_chain
+ .push(SubAccess::Offset(i * row_stride));
+ let ty_inner = crate::TypeInner::Vector { size: rows, scalar };
+ let sv = StoreValue::TempIndex {
+ depth,
+ index: i,
+ ty: TypeResolution::Value(ty_inner),
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size: crate::ArraySize::Constant(size),
+ stride,
+ } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ write!(self.out, "{}", level.next())?;
+ self.write_value_type(module, &module.types[base].inner)?;
+ let depth = level.next().0;
+ write!(self.out, " {STORE_TEMP_NAME}{depth}")?;
+ self.write_array_size(module, base, crate::ArraySize::Constant(size))?;
+ write!(self.out, " = ")?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ // then iterate the stores
+ for i in 0..size.get() {
+ self.temp_access_chain.push(SubAccess::Offset(i * stride));
+ let sv = StoreValue::TempIndex {
+ depth,
+ index: i,
+ ty: TypeResolution::Handle(base),
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ // first, assign the value to a temporary
+ writeln!(self.out, "{level}{{")?;
+ let depth = level.next().0;
+ let struct_ty = ty_resolution.handle().unwrap();
+ let struct_name = &self.names[&NameKey::Type(struct_ty)];
+ write!(
+ self.out,
+ "{}{} {}{} = ",
+ level.next(),
+ struct_name,
+ STORE_TEMP_NAME,
+ depth
+ )?;
+ self.write_store_value(module, &value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ // then iterate the stores
+ for (i, member) in members.iter().enumerate() {
+ self.temp_access_chain
+ .push(SubAccess::Offset(member.offset));
+ let sv = StoreValue::TempAccess {
+ depth,
+ base: struct_ty,
+ member_index: i as u32,
+ };
+ self.write_storage_store(module, var_handle, sv, func_ctx, level.next())?;
+ self.temp_access_chain.pop();
+ }
+ // done
+ writeln!(self.out, "{level}}}")?;
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+ }
+
+ /// Set [`temp_access_chain`] to compute the byte offset of `cur_expr`.
+ ///
+ /// The `cur_expr` expression must be a reference to a global
+ /// variable in the [`Storage`] address space, or a chain of
+ /// [`Access`] and [`AccessIndex`] expressions referring to some
+ /// component of such a global.
+ ///
+ /// [`temp_access_chain`]: super::Writer::temp_access_chain
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ pub(super) fn fill_access_chain(
+ &mut self,
+ module: &crate::Module,
+ mut cur_expr: Handle<crate::Expression>,
+ func_ctx: &FunctionCtx,
+ ) -> Result<Handle<crate::GlobalVariable>, Error> {
+ enum AccessIndex {
+ Expression(Handle<crate::Expression>),
+ Constant(u32),
+ }
+ enum Parent<'a> {
+ Array { stride: u32 },
+ Struct(&'a [crate::StructMember]),
+ }
+ self.temp_access_chain.clear();
+
+ loop {
+ let (next_expr, access_index) = match func_ctx.expressions[cur_expr] {
+ crate::Expression::GlobalVariable(handle) => return Ok(handle),
+ crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)),
+ crate::Expression::AccessIndex { base, index } => {
+ (base, AccessIndex::Constant(index))
+ }
+ ref other => {
+ return Err(Error::Unimplemented(format!("Pointer access of {other:?}")))
+ }
+ };
+
+ let parent = match *func_ctx.resolve_type(next_expr, &module.types) {
+ crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members),
+ crate::TypeInner::Array { stride, .. } => Parent::Array { stride },
+ crate::TypeInner::Vector { scalar, .. } => Parent::Array {
+ stride: scalar.width as u32,
+ },
+ crate::TypeInner::Matrix { rows, scalar, .. } => Parent::Array {
+ // The stride between a matrix's columns is the aligned size of one
+ // column vector, which is determined by the row count.
+ stride: Alignment::from(rows) * scalar.width as u32,
+ },
+ _ => unreachable!(),
+ },
+ crate::TypeInner::ValuePointer { scalar, .. } => Parent::Array {
+ stride: scalar.width as u32,
+ },
+ _ => unreachable!(),
+ };
+
+ let sub = match (parent, access_index) {
+ (Parent::Array { stride }, AccessIndex::Expression(value)) => {
+ SubAccess::Index { value, stride }
+ }
+ (Parent::Array { stride }, AccessIndex::Constant(index)) => {
+ SubAccess::Offset(stride * index)
+ }
+ (Parent::Struct(members), AccessIndex::Constant(index)) => {
+ SubAccess::Offset(members[index as usize].offset)
+ }
+ (Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(),
+ };
+
+ self.temp_access_chain.push(sub);
+ cur_expr = next_expr;
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs
new file mode 100644
index 0000000000..43f7212837
--- /dev/null
+++ b/third_party/rust/naga/src/back/hlsl/writer.rs
@@ -0,0 +1,3366 @@
+use super::{
+ help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess},
+ storage::StoreValue,
+ BackendResult, Error, Options,
+};
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, Module, ScalarKind, ShaderStage, TypeInner,
+};
+use std::{fmt, mem};
+
+const LOCATION_SEMANTIC: &str = "LOC";
+const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
+const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
+const SPECIAL_FIRST_VERTEX: &str = "first_vertex";
+const SPECIAL_FIRST_INSTANCE: &str = "first_instance";
+const SPECIAL_OTHER: &str = "other";
+
+pub(crate) const MODF_FUNCTION: &str = "naga_modf";
+pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
+
+struct EpStructMember {
+ name: String,
+ ty: Handle<crate::Type>,
+ // technically, this should always be `Some`
+ binding: Option<crate::Binding>,
+ index: u32,
+}
+
+/// Contains the information required to generate the wrapped structure
+/// for all of an entry point's arguments.
+struct EntryPointBinding {
+ /// Name of the fake EP argument that contains the struct
+ /// with all the flattened input data.
+ arg_name: String,
+ /// Generated structure name
+ ty_name: String,
+ /// Members of generated structure
+ members: Vec<EpStructMember>,
+}
+
+pub(super) struct EntryPointInterface {
+ /// If `Some`, the input of an entry point is gathered in a special
+ /// struct with members sorted by binding.
+ /// The `EntryPointBinding::members` array is sorted by index,
+ /// so that we can walk it in `write_ep_arguments_initialization`.
+ input: Option<EntryPointBinding>,
+ /// If `Some`, the output of an entry point is flattened.
+ /// The `EntryPointBinding::members` array is sorted by binding,
+ /// so that we can walk it in the `Statement::Return` handler.
+ output: Option<EntryPointBinding>,
+}
+
+#[derive(Clone, Eq, PartialEq, PartialOrd, Ord)]
+enum InterfaceKey {
+ Location(u32),
+ BuiltIn(crate::BuiltIn),
+ Other,
+}
+
+impl InterfaceKey {
+ const fn new(binding: Option<&crate::Binding>) -> Self {
+ match binding {
+ Some(&crate::Binding::Location { location, .. }) => Self::Location(location),
+ Some(&crate::Binding::BuiltIn(built_in)) => Self::BuiltIn(built_in),
+ None => Self::Other,
+ }
+ }
+}
+
+#[derive(Copy, Clone, PartialEq)]
+enum Io {
+ Input,
+ Output,
+}
+
+impl<'a, W: fmt::Write> super::Writer<'a, W> {
+ pub fn new(out: W, options: &'a Options) -> Self {
+ Self {
+ out,
+ names: crate::FastHashMap::default(),
+ namer: proc::Namer::default(),
+ options,
+ entry_point_io: Vec::new(),
+ named_expressions: crate::NamedExpressions::default(),
+ wrapped: super::Wrapped::default(),
+ temp_access_chain: Vec::new(),
+ need_bake_expressions: Default::default(),
+ }
+ }
+
+ fn reset(&mut self, module: &Module) {
+ self.names.clear();
+ self.namer.reset(
+ module,
+ super::keywords::RESERVED,
+ super::keywords::TYPES,
+ super::keywords::RESERVED_CASE_INSENSITIVE,
+ &[],
+ &mut self.names,
+ );
+ self.entry_point_io.clear();
+ self.named_expressions.clear();
+ self.wrapped.clear();
+ self.need_bake_expressions.clear();
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// Clears the `need_bake_expressions` set before adding to it.
+ fn update_expressions_to_bake(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ ) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+ for (fun_handle, expr) in func.expressions.iter() {
+ let expr_info = &info[fun_handle];
+ let min_ref_count = func.expressions[fun_handle].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(fun_handle);
+ }
+
+ if let Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } = *expr
+ {
+ match fun {
+ crate::MathFunction::Asinh
+ | crate::MathFunction::Acosh
+ | crate::MathFunction::Atanh
+ | crate::MathFunction::Unpack2x16float
+ | crate::MathFunction::Unpack2x16snorm
+ | crate::MathFunction::Unpack2x16unorm
+ | crate::MathFunction::Unpack4x8snorm
+ | crate::MathFunction::Unpack4x8unorm
+ | crate::MathFunction::Pack2x16float
+ | crate::MathFunction::Pack2x16snorm
+ | crate::MathFunction::Pack2x16unorm
+ | crate::MathFunction::Pack4x8snorm
+ | crate::MathFunction::Pack4x8unorm => {
+ self.need_bake_expressions.insert(arg);
+ }
+ crate::MathFunction::ExtractBits => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ self.need_bake_expressions.insert(arg2.unwrap());
+ }
+ crate::MathFunction::InsertBits => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ self.need_bake_expressions.insert(arg2.unwrap());
+ self.need_bake_expressions.insert(arg3.unwrap());
+ }
+ crate::MathFunction::CountLeadingZeros => {
+ let inner = info[fun_handle].ty.inner_with(&module.types);
+ if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ self.need_bake_expressions.insert(arg);
+ }
+ }
+ _ => {}
+ }
+ }
+
+ if let Expression::Derivative { axis, ctrl, expr } = *expr {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
+ self.need_bake_expressions.insert(expr);
+ }
+ }
+ }
+ }
+
+ pub fn write(
+ &mut self,
+ module: &Module,
+ module_info: &valid::ModuleInfo,
+ ) -> Result<super::ReflectionInfo, Error> {
+ self.reset(module);
+
+ // Write special constants, if needed
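+ // If enabled, the emitted preamble has roughly this shape (register
+ // and space come from the bind target):
+ //   struct NagaConstants { int first_vertex; int first_instance; uint other; };
+ //   ConstantBuffer<NagaConstants> _NagaConstants: register(b0, space1);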
+ if let Some(ref bt) = self.options.special_constants_binding {
+ writeln!(self.out, "struct {SPECIAL_CBUF_TYPE} {{")?;
+ writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_VERTEX)?;
+ writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_FIRST_INSTANCE)?;
+ writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
+ writeln!(self.out, "}};")?;
+ write!(
+ self.out,
+ "ConstantBuffer<{}> {}: register(b{}",
+ SPECIAL_CBUF_TYPE, SPECIAL_CBUF_VAR, bt.register
+ )?;
+ if bt.space != 0 {
+ write!(self.out, ", space{}", bt.space)?;
+ }
+ writeln!(self.out, ");")?;
+
+ // Extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Save all entry point output types
+ let ep_results = module
+ .entry_points
+ .iter()
+ .map(|ep| (ep.stage, ep.function.result.clone()))
+ .collect::<Vec<(ShaderStage, Option<crate::FunctionResult>)>>();
+
+ self.write_all_mat_cx2_typedefs_and_functions(module)?;
+
+ // Write all structs
+ for (handle, ty) in module.types.iter() {
+ if let TypeInner::Struct { ref members, span } = ty.inner {
+ if module.types[members.last().unwrap().ty]
+ .inner
+ .is_dynamically_sized(&module.types)
+ {
+ // unsized arrays can only be in storage buffers,
+ // for which we use `ByteAddressBuffer` anyway.
+ continue;
+ }
+
+ let ep_result = ep_results.iter().find(|e| {
+ if let Some(ref result) = e.1 {
+ result.ty == handle
+ } else {
+ false
+ }
+ });
+
+ self.write_struct(
+ module,
+ handle,
+ members,
+ span,
+ ep_result.map(|r| (r.0, Io::Output)),
+ )?;
+ writeln!(self.out)?;
+ }
+ }
+
+ self.write_special_functions(module)?;
+
+ self.write_wrapped_compose_functions(module, &module.const_expressions)?;
+
+ // Write all named constants
+ let mut constants = module
+ .constants
+ .iter()
+ .filter(|&(_, c)| c.name.is_some())
+ .peekable();
+ while let Some((handle, _)) = constants.next() {
+ self.write_global_constant(module, handle)?;
+ // Add extra newline for readability on last iteration
+ if constants.peek().is_none() {
+ writeln!(self.out)?;
+ }
+ }
+
+ // Write all globals
+ for (ty, _) in module.global_variables.iter() {
+ self.write_global(module, ty)?;
+ }
+
+ if !module.global_variables.is_empty() {
+ // Add extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Write all entry points wrapped structs
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let ep_name = self.names[&NameKey::EntryPoint(index as u16)].clone();
+ let ep_io = self.write_ep_interface(module, &ep.function, ep.stage, &ep_name)?;
+ self.entry_point_io.push(ep_io);
+ }
+
+ // Write all regular functions
+ for (handle, function) in module.functions.iter() {
+ let info = &module_info[handle];
+
+ // Check if all of the globals are accessible
+ if !self.options.fake_missing_bindings {
+ if let Some((var_handle, _)) =
+ module
+ .global_variables
+ .iter()
+ .find(|&(var_handle, var)| match var.binding {
+ Some(ref binding) if !info[var_handle].is_empty() => {
+ self.options.resolve_resource_binding(binding).is_err()
+ }
+ _ => false,
+ })
+ {
+ log::info!(
+ "Skipping function {:?} (name {:?}) because global {:?} is inaccessible",
+ handle,
+ function.name,
+ var_handle
+ );
+ continue;
+ }
+ }
+
+ let ctx = back::FunctionCtx {
+ ty: back::FunctionType::Function(handle),
+ info,
+ expressions: &function.expressions,
+ named_expressions: &function.named_expressions,
+ };
+ let name = self.names[&NameKey::Function(handle)].clone();
+
+ self.write_wrapped_functions(module, &ctx)?;
+
+ self.write_function(module, name.as_str(), function, &ctx, info)?;
+
+ writeln!(self.out)?;
+ }
+
+ let mut entry_point_names = Vec::with_capacity(module.entry_points.len());
+
+ // Write all entry points
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let info = module_info.get_entry_point(index);
+
+ if !self.options.fake_missing_bindings {
+ let mut ep_error = None;
+ for (var_handle, var) in module.global_variables.iter() {
+ match var.binding {
+ Some(ref binding) if !info[var_handle].is_empty() => {
+ if let Err(err) = self.options.resolve_resource_binding(binding) {
+ ep_error = Some(err);
+ break;
+ }
+ }
+ _ => {}
+ }
+ }
+ if let Some(err) = ep_error {
+ entry_point_names.push(Err(err));
+ continue;
+ }
+ }
+
+ let ctx = back::FunctionCtx {
+ ty: back::FunctionType::EntryPoint(index as u16),
+ info,
+ expressions: &ep.function.expressions,
+ named_expressions: &ep.function.named_expressions,
+ };
+
+ self.write_wrapped_functions(module, &ctx)?;
+
+ if ep.stage == ShaderStage::Compute {
+ // HLSL calls the workgroup size "num threads".
+ let num_threads = ep.workgroup_size;
+ writeln!(
+ self.out,
+ "[numthreads({}, {}, {})]",
+ num_threads[0], num_threads[1], num_threads[2]
+ )?;
+ }
+
+ let name = self.names[&NameKey::EntryPoint(index as u16)].clone();
+ self.write_function(module, &name, &ep.function, &ctx, info)?;
+
+ if index < module.entry_points.len() - 1 {
+ writeln!(self.out)?;
+ }
+
+ entry_point_names.push(Ok(name));
+ }
+
+ Ok(super::ReflectionInfo { entry_point_names })
+ }
+
+ fn write_modifier(&mut self, binding: &crate::Binding) -> BackendResult {
+ match *binding {
+ crate::Binding::BuiltIn(crate::BuiltIn::Position { invariant: true }) => {
+ write!(self.out, "precise ")?;
+ }
+ crate::Binding::Location {
+ interpolation,
+ sampling,
+ ..
+ } => {
+ if let Some(interpolation) = interpolation {
+ if let Some(string) = interpolation.to_hlsl_str() {
+ write!(self.out, "{string} ")?
+ }
+ }
+
+ if let Some(sampling) = sampling {
+ if let Some(string) = sampling.to_hlsl_str() {
+ write!(self.out, "{string} ")?
+ }
+ }
+ }
+ crate::Binding::BuiltIn(_) => {}
+ }
+
+ Ok(())
+ }
+
+ //TODO: we could force fragment outputs to always go through the `entry_point_io.output` path
+ // if they are structs, so that the `stage` argument here could be omitted.
+ fn write_semantic(
+ &mut self,
+ binding: &crate::Binding,
+ stage: Option<(ShaderStage, Io)>,
+ ) -> BackendResult {
+ match *binding {
+ crate::Binding::BuiltIn(builtin) => {
+ let builtin_str = builtin.to_hlsl_str()?;
+ write!(self.out, " : {builtin_str}")?;
+ }
+ crate::Binding::Location {
+ second_blend_source: true,
+ ..
+ } => {
+ write!(self.out, " : SV_Target1")?;
+ }
+ crate::Binding::Location {
+ location,
+ second_blend_source: false,
+ ..
+ } => {
+ if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
+ write!(self.out, " : SV_Target{location}")?;
+ } else {
+ write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ fn write_interface_struct(
+ &mut self,
+ module: &Module,
+ shader_stage: (ShaderStage, Io),
+ struct_name: String,
+ mut members: Vec<EpStructMember>,
+ ) -> Result<EntryPointBinding, Error> {
+ // Sort the members so that the user-defined varyings come first,
+ // in ascending location order, followed by built-ins. This allows VS
+ // and FS interfaces to match with regard to order.
+ members.sort_by_key(|m| InterfaceKey::new(m.binding.as_ref()));
+
+ write!(self.out, "struct {struct_name}")?;
+ writeln!(self.out, " {{")?;
+ for m in members.iter() {
+ write!(self.out, "{}", back::INDENT)?;
+ if let Some(ref binding) = m.binding {
+ self.write_modifier(binding)?;
+ }
+ self.write_type(module, m.ty)?;
+ write!(self.out, " {}", &m.name)?;
+ if let Some(ref binding) = m.binding {
+ self.write_semantic(binding, Some(shader_stage))?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ writeln!(self.out, "}};")?;
+ writeln!(self.out)?;
+
+ match shader_stage.1 {
+ Io::Input => {
+ // bring back the original order
+ members.sort_by_key(|m| m.index);
+ }
+ Io::Output => {
+ // keep it sorted by binding
+ }
+ }
+
+ Ok(EntryPointBinding {
+ arg_name: self.namer.call(struct_name.to_lowercase().as_str()),
+ ty_name: struct_name,
+ members,
+ })
+ }
+
+ /// Flatten all entry point arguments into a single struct.
+ /// This is needed because we must re-order them: user locations first,
+ /// then built-ins.
+ fn write_ep_input_struct(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ stage: ShaderStage,
+ entry_point_name: &str,
+ ) -> Result<EntryPointBinding, Error> {
+ let struct_name = format!("{stage:?}Input_{entry_point_name}");
+
+ let mut fake_members = Vec::new();
+ for arg in func.arguments.iter() {
+ match module.types[arg.ty].inner {
+ TypeInner::Struct { ref members, .. } => {
+ for member in members.iter() {
+ let name = self.namer.call_or(&member.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name,
+ ty: member.ty,
+ binding: member.binding.clone(),
+ index,
+ });
+ }
+ }
+ _ => {
+ let member_name = self.namer.call_or(&arg.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name: member_name,
+ ty: arg.ty,
+ binding: arg.binding.clone(),
+ index,
+ });
+ }
+ }
+ }
+
+ self.write_interface_struct(module, (stage, Io::Input), struct_name, fake_members)
+ }
+
+ /// Flatten all entry point results into a single struct.
+ /// This is needed because the results must be re-ordered:
+ /// user locations first, then built-ins.
+ fn write_ep_output_struct(
+ &mut self,
+ module: &Module,
+ result: &crate::FunctionResult,
+ stage: ShaderStage,
+ entry_point_name: &str,
+ ) -> Result<EntryPointBinding, Error> {
+ let struct_name = format!("{stage:?}Output_{entry_point_name}");
+
+ let mut fake_members = Vec::new();
+ let empty = [];
+ let members = match module.types[result.ty].inner {
+ TypeInner::Struct { ref members, .. } => members,
+ ref other => {
+ log::error!("Unexpected {:?} output type without a binding", other);
+ &empty[..]
+ }
+ };
+
+ for member in members.iter() {
+ let member_name = self.namer.call_or(&member.name, "member");
+ let index = fake_members.len() as u32;
+ fake_members.push(EpStructMember {
+ name: member_name,
+ ty: member.ty,
+ binding: member.binding.clone(),
+ index,
+ });
+ }
+
+ self.write_interface_struct(module, (stage, Io::Output), struct_name, fake_members)
+ }
+
+ /// Writes special interface structures for an entry point. The special structures have
+ /// all the fields flattened into them and sorted by binding. They are only needed for
+ /// VS outputs and FS inputs, so that these interfaces match.
+ fn write_ep_interface(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ stage: ShaderStage,
+ ep_name: &str,
+ ) -> Result<EntryPointInterface, Error> {
+ Ok(EntryPointInterface {
+ input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment {
+ Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
+ } else {
+ None
+ },
+ output: match func.result {
+ Some(ref fr) if fr.binding.is_none() && stage == ShaderStage::Vertex => {
+ Some(self.write_ep_output_struct(module, fr, stage, ep_name)?)
+ }
+ _ => None,
+ },
+ })
+ }
+
+ /// Write an entry point preface that initializes the arguments as specified in IR.
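+ ///
+ /// Illustrative sketch (hypothetical names): if the arguments were
+ /// flattened into `VertexInput_vs vertexinput_vs`, a struct argument is
+ /// re-assembled member by member, roughly as
+ ///
+ /// ```hlsl
+ /// VertexIn in_1 = { vertexinput_vs.pos, vertexinput_vs.uv };
+ /// ```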
+ fn write_ep_arguments_initialization(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ ep_index: u16,
+ ) -> BackendResult {
+ let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
+ Some(ep_input) => ep_input,
+ None => return Ok(()),
+ };
+ let mut fake_iter = ep_input.members.iter();
+ for (arg_index, arg) in func.arguments.iter().enumerate() {
+ write!(self.out, "{}", back::INDENT)?;
+ self.write_type(module, arg.ty)?;
+ let arg_name = &self.names[&NameKey::EntryPointArgument(ep_index, arg_index as u32)];
+ write!(self.out, " {arg_name}")?;
+ match module.types[arg.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ self.write_array_size(module, base, size)?;
+ let fake_member = fake_iter.next().unwrap();
+ writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ }
+ TypeInner::Struct { ref members, .. } => {
+ write!(self.out, " = {{ ")?;
+ for index in 0..members.len() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ let fake_member = fake_iter.next().unwrap();
+ write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
+ }
+ writeln!(self.out, " }};")?;
+ }
+ _ => {
+ let fake_member = fake_iter.next().unwrap();
+ writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ }
+ }
+ }
+ assert!(fake_iter.next().is_none());
+ Ok(())
+ }
+
+ /// Helper method used to write global variables
+ /// # Notes
+ /// Always adds a newline
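+ ///
+ /// Illustrative sketch of the declaration shapes this produces, by
+ /// address space (names and registers are hypothetical):
+ ///
+ /// ```hlsl
+ /// static float a = 0.0; // Private
+ /// groupshared float b; // WorkGroup
+ /// cbuffer g : register(b0) { float4 g; } // Uniform
+ /// RWByteAddressBuffer buf : register(u0); // Storage (read-write)
+ /// ```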
+ fn write_global(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::GlobalVariable>,
+ ) -> BackendResult {
+ let global = &module.global_variables[handle];
+ let inner = &module.types[global.ty].inner;
+
+ if let Some(ref binding) = global.binding {
+ if let Err(err) = self.options.resolve_resource_binding(binding) {
+ log::info!(
+ "Skipping global {:?} (name {:?}) for being inaccessible: {}",
+ handle,
+ global.name,
+ err,
+ );
+ return Ok(());
+ }
+ }
+
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-variable-register
+ let register_ty = match global.space {
+ crate::AddressSpace::Function => unreachable!("Function address space"),
+ crate::AddressSpace::Private => {
+ write!(self.out, "static ")?;
+ self.write_type(module, global.ty)?;
+ ""
+ }
+ crate::AddressSpace::WorkGroup => {
+ write!(self.out, "groupshared ")?;
+ self.write_type(module, global.ty)?;
+ ""
+ }
+ crate::AddressSpace::Uniform => {
+ // constant buffer declarations are expected to be inlined, e.g.
+ // `cbuffer foo: register(b0) { field1: type1; }`
+ write!(self.out, "cbuffer")?;
+ "b"
+ }
+ crate::AddressSpace::Storage { access } => {
+ let (prefix, register) = if access.contains(crate::StorageAccess::STORE) {
+ ("RW", "u")
+ } else {
+ ("", "t")
+ };
+ write!(self.out, "{prefix}ByteAddressBuffer")?;
+ register
+ }
+ crate::AddressSpace::Handle => {
+ let handle_ty = match *inner {
+ TypeInner::BindingArray { ref base, .. } => &module.types[*base].inner,
+ _ => inner,
+ };
+
+ let register = match *handle_ty {
+ TypeInner::Sampler { .. } => "s",
+ // all storage textures are UAV, unconditionally
+ TypeInner::Image {
+ class: crate::ImageClass::Storage { .. },
+ ..
+ } => "u",
+ _ => "t",
+ };
+ self.write_type(module, global.ty)?;
+ register
+ }
+ crate::AddressSpace::PushConstant => {
+ // The type of the push constants will be wrapped in `ConstantBuffer`
+ write!(self.out, "ConstantBuffer<")?;
+ "b"
+ }
+ };
+
+ // If the global is a push constant, write the type now because it will be
+ // a generic argument to `ConstantBuffer`
+ if global.space == crate::AddressSpace::PushConstant {
+ self.write_global_type(module, global.ty)?;
+
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ // Close the angled brackets for the generic argument
+ write!(self.out, ">")?;
+ }
+
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, " {name}")?;
+
+ // Push constants need to be assigned a binding explicitly by the consumer
+ // since naga has no way to know the binding from the shader alone
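+ // A sketch of the resulting declaration, assuming the consumer supplied
+ // register 0 in space 1 (type and variable names are hypothetical):
+ // ConstantBuffer<PushConstants> pc : register(b0, space1);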
+ if global.space == crate::AddressSpace::PushConstant {
+ let target = self
+ .options
+ .push_constants_target
+ .as_ref()
+ .expect("No bind target was defined for the push constants block");
+ write!(self.out, ": register(b{}", target.register)?;
+ if target.space != 0 {
+ write!(self.out, ", space{}", target.space)?;
+ }
+ write!(self.out, ")")?;
+ }
+
+ if let Some(ref binding) = global.binding {
+ // this was already resolved earlier when we started evaluating an entry point.
+ let bt = self.options.resolve_resource_binding(binding).unwrap();
+
+ // need to write the binding array size if the type was emitted with `write_type`
+ if let TypeInner::BindingArray { base, size, .. } = module.types[global.ty].inner {
+ if let Some(overridden_size) = bt.binding_array_size {
+ write!(self.out, "[{overridden_size}]")?;
+ } else {
+ self.write_array_size(module, base, size)?;
+ }
+ }
+
+ write!(self.out, " : register({}{}", register_ty, bt.register)?;
+ if bt.space != 0 {
+ write!(self.out, ", space{}", bt.space)?;
+ }
+ write!(self.out, ")")?;
+ } else {
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ if global.space == crate::AddressSpace::Private {
+ write!(self.out, " = ")?;
+ if let Some(init) = global.init {
+ self.write_const_expression(module, init)?;
+ } else {
+ self.write_default_init(module, global.ty)?;
+ }
+ }
+ }
+
+ if global.space == crate::AddressSpace::Uniform {
+ write!(self.out, " {{ ")?;
+
+ self.write_global_type(module, global.ty)?;
+
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?;
+
+ // need to write the array size if the type was emitted with `write_type`
+ if let TypeInner::Array { base, size, .. } = module.types[global.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ writeln!(self.out, "; }}")?;
+ } else {
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global constants
+ ///
+ /// # Notes
+ /// Ends in a newline
+ fn write_global_constant(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ write!(self.out, "static const ")?;
+ let constant = &module.constants[handle];
+ self.write_type(module, constant.ty)?;
+ let name = &self.names[&NameKey::Constant(handle)];
+ write!(self.out, " {}", name)?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = module.types[constant.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_const_expression(module, constant.init)?;
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
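+ /// Write the bracketed size suffix for an array type, recursing through
+ /// nested arrays, so that (as a sketch) `array<array<f32, 4>, 2>` comes
+ /// out as `[2][4]`; the base type and name are written by the callers.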
+ pub(super) fn write_array_size(
+ &mut self,
+ module: &Module,
+ base: Handle<crate::Type>,
+ size: crate::ArraySize,
+ ) -> BackendResult {
+ write!(self.out, "[")?;
+
+ match size {
+ crate::ArraySize::Constant(size) => {
+ write!(self.out, "{size}")?;
+ }
+ crate::ArraySize::Dynamic => unreachable!(),
+ }
+
+ write!(self.out, "]")?;
+
+ if let TypeInner::Array {
+ base: next_base,
+ size: next_size,
+ ..
+ } = module.types[base].inner
+ {
+ self.write_array_size(module, next_base, next_size)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
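+ ///
+ /// HLSL does not round a type's size up to its alignment, so explicit
+ /// `int` padding members are emitted to reproduce the IR offsets.
+ /// Illustrative sketch (hypothetical struct):
+ ///
+ /// ```hlsl
+ /// struct Foo {
+ /// float a;
+ /// int _pad1_0; // bytes 4..8
+ /// double b; // offset 8 in the IR
+ /// };
+ /// ```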
+ fn write_struct(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ span: u32,
+ shader_stage: Option<(ShaderStage, Io)>,
+ ) -> BackendResult {
+ // Write struct name
+ let struct_name = &self.names[&NameKey::Type(handle)];
+ writeln!(self.out, "struct {struct_name} {{")?;
+
+ let mut last_offset = 0;
+ for (index, member) in members.iter().enumerate() {
+ if member.binding.is_none() && member.offset > last_offset {
+ // using int as padding should work as long as the backend
+ // doesn't support a type that's less than 4 bytes in size
+ // (Error::UnsupportedScalar catches this)
+ let padding = (member.offset - last_offset) / 4;
+ for i in 0..padding {
+ writeln!(self.out, "{}int _pad{}_{};", back::INDENT, index, i)?;
+ }
+ }
+ let ty_inner = &module.types[member.ty].inner;
+ last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx());
+
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+
+ match module.types[member.ty].inner {
+ TypeInner::Array { base, size, .. } => {
+ // HLSL arrays are written as `type name[size]`
+
+ self.write_global_type(module, member.ty)?;
+
+ // Write `name`
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, index as u32)]
+ )?;
+ // Write [size]
+ self.write_array_size(module, base, size)?;
+ }
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ TypeInner::Matrix {
+ rows,
+ columns,
+ scalar,
+ } if member.binding.is_none() && rows == crate::VectorSize::Bi => {
+ let vec_ty = crate::TypeInner::Vector { size: rows, scalar };
+ let field_name_key = NameKey::StructMember(handle, index as u32);
+
+ for i in 0..columns as u8 {
+ if i != 0 {
+ write!(self.out, "; ")?;
+ }
+ self.write_value_type(module, &vec_ty)?;
+ write!(self.out, " {}_{}", &self.names[&field_name_key], i)?;
+ }
+ }
+ _ => {
+ // Write modifier before type
+ if let Some(ref binding) = member.binding {
+ self.write_modifier(binding)?;
+ }
+
+ // Even though Naga IR matrices are column-major, we must describe
+ // matrices passed from the CPU as being in row-major order.
+ // See the module-level block comment in mod.rs for details.
+ if let TypeInner::Matrix { .. } = module.types[member.ty].inner {
+ write!(self.out, "row_major ")?;
+ }
+
+ // Write the member type and name
+ self.write_type(module, member.ty)?;
+ write!(
+ self.out,
+ " {}",
+ &self.names[&NameKey::StructMember(handle, index as u32)]
+ )?;
+ }
+ }
+
+ if let Some(ref binding) = member.binding {
+ self.write_semantic(binding, shader_stage)?;
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ // add padding at the end since sizes of types don't get rounded up to their alignment in HLSL
+ if members.last().unwrap().binding.is_none() && span > last_offset {
+ let padding = (span - last_offset) / 4;
+ for i in 0..padding {
+ writeln!(self.out, "{}int _end_pad_{};", back::INDENT, i)?;
+ }
+ }
+
+ writeln!(self.out, "}};")?;
+ Ok(())
+ }
+
+ /// Helper method used to write non-image/sampler types used by globals and structs
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
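+ ///
+ /// For the special `matCx2` layout this writes the wrapper name instead,
+ /// e.g. a `mat3x2<f32>` member comes out as `__mat3x2` (a sketch; see the
+ /// module-level comment in mod.rs), and other matrices get `row_major`.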
+ pub(super) fn write_global_type(
+ &mut self,
+ module: &Module,
+ ty: Handle<crate::Type>,
+ ) -> BackendResult {
+ let matrix_data = get_inner_matrix_data(module, ty);
+
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = matrix_data
+ {
+ write!(self.out, "__mat{}x2", columns as u8)?;
+ } else {
+ // Even though Naga IR matrices are column-major, we must describe
+ // matrices passed from the CPU as being in row-major order.
+ // See the module-level block comment in mod.rs for details.
+ if matrix_data.is_some() {
+ write!(self.out, "row_major ")?;
+ }
+
+ self.write_type(module, ty)?;
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write non-image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &module.types[ty].inner;
+ match *inner {
+ TypeInner::Struct { .. } => write!(self.out, "{}", self.names[&NameKey::Type(ty)])?,
+ // HLSL arrays have the size separated from the base type
+ TypeInner::Array { base, .. } | TypeInner::BindingArray { base, .. } => {
+ self.write_type(module, base)?
+ }
+ ref other => self.write_value_type(module, other)?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ pub(super) fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ TypeInner::Scalar(scalar) | TypeInner::Atomic(scalar) => {
+ write!(self.out, "{}", scalar.to_hlsl_str()?)?;
+ }
+ TypeInner::Vector { size, scalar } => {
+ write!(
+ self.out,
+ "{}{}",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(size)
+ )?;
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ // The IR supports only float matrices
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-matrix
+
+ // Because of the implicit transpose all matrices have in HLSL, we need to transpose the size as well.
+ write!(
+ self.out,
+ "{}{}x{}",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ )?;
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ self.write_image_type(dim, arrayed, class)?;
+ }
+ TypeInner::Sampler { comparison } => {
+ let sampler = if comparison {
+ "SamplerComparisonState"
+ } else {
+ "SamplerState"
+ };
+ write!(self.out, "{sampler}")?;
+ }
+ // HLSL arrays are written as `type name[size]`.
+ // This code writes only the `[size]` part;
+ // the base `type` and `name` must be written by the caller.
+ TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
+ self.write_array_size(module, base, size)?;
+ }
+ _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write functions
+ /// # Notes
+ /// Ends in a newline
+ fn write_function(
+ &mut self,
+ module: &Module,
+ name: &str,
+ func: &crate::Function,
+ func_ctx: &back::FunctionCtx<'_>,
+ info: &valid::FunctionInfo,
+ ) -> BackendResult {
+ // Function Declaration Syntax - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-function-syntax
+
+ self.update_expressions_to_bake(module, func, info);
+
+ // Write modifier
+ if let Some(crate::FunctionResult {
+ binding:
+ Some(
+ ref binding @ crate::Binding::BuiltIn(crate::BuiltIn::Position {
+ invariant: true,
+ }),
+ ),
+ ..
+ }) = func.result
+ {
+ self.write_modifier(binding)?;
+ }
+
+ // Write return type
+ if let Some(ref result) = func.result {
+ match func_ctx.ty {
+ back::FunctionType::Function(_) => {
+ self.write_type(module, result.ty)?;
+ }
+ back::FunctionType::EntryPoint(index) => {
+ if let Some(ref ep_output) = self.entry_point_io[index as usize].output {
+ write!(self.out, "{}", ep_output.ty_name)?;
+ } else {
+ self.write_type(module, result.ty)?;
+ }
+ }
+ }
+ } else {
+ write!(self.out, "void")?;
+ }
+
+ // Write function name
+ write!(self.out, " {name}(")?;
+
+ let need_workgroup_variables_initialization =
+ self.need_workgroup_variables_initialization(func_ctx, module);
+
+ // Write function arguments. Entry points may use a flattened interface struct instead.
+ match func_ctx.ty {
+ back::FunctionType::Function(handle) => {
+ for (index, arg) in func.arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // Write argument type
+ let arg_ty = match module.types[arg.ty].inner {
+ // pointers in function arguments are expected and resolve to `inout`
+ TypeInner::Pointer { base, .. } => {
+ //TODO: can we narrow this down to just `in` when possible?
+ write!(self.out, "inout ")?;
+ base
+ }
+ _ => arg.ty,
+ };
+ self.write_type(module, arg_ty)?;
+
+ let argument_name =
+ &self.names[&NameKey::FunctionArgument(handle, index as u32)];
+
+ // Write argument name. Space is important.
+ write!(self.out, " {argument_name}")?;
+ if let TypeInner::Array { base, size, .. } = module.types[arg_ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ }
+ }
+ back::FunctionType::EntryPoint(ep_index) => {
+ if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
+ write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?;
+ } else {
+ let stage = module.entry_points[ep_index as usize].stage;
+ for (index, arg) in func.arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_type(module, arg.ty)?;
+
+ let argument_name =
+ &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)];
+
+ write!(self.out, " {argument_name}")?;
+ if let TypeInner::Array { base, size, .. } = module.types[arg.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ if let Some(ref binding) = arg.binding {
+ self.write_semantic(binding, Some((stage, Io::Input)))?;
+ }
+ }
+
+ if need_workgroup_variables_initialization {
+ if !func.arguments.is_empty() {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
+ }
+ }
+ }
+ }
+ // End of arguments
+ write!(self.out, ")")?;
+
+ // Write the semantic if present
+ if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
+ let stage = module.entry_points[index as usize].stage;
+ if let Some(crate::FunctionResult {
+ binding: Some(ref binding),
+ ..
+ }) = func.result
+ {
+ self.write_semantic(binding, Some((stage, Io::Output)))?;
+ }
+ }
+
+ // Function body start
+ writeln!(self.out)?;
+ writeln!(self.out, "{{")?;
+
+ if need_workgroup_variables_initialization {
+ self.write_workgroup_variables_initialization(func_ctx, module)?;
+ }
+
+ if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
+ self.write_ep_arguments_initialization(module, func, index)?;
+ }
+
+ // Write function local variables
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability)
+ write!(self.out, "{}", back::INDENT)?;
+
+ // Write the local type and name
+ // The leading space before the name is important
+ self.write_type(module, local.ty)?;
+ write!(self.out, " {}", self.names[&func_ctx.name_key(handle)])?;
+ // Write size for array type
+ if let TypeInner::Array { base, size, .. } = module.types[local.ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+
+ write!(self.out, " = ")?;
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+ self.write_expr(module, init, func_ctx)?;
+ } else {
+ // Zero initialize local variables
+ self.write_default_init(module, local.ty)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ if !func.local_variables.is_empty() {
+ writeln!(self.out)?;
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // The indentation should always be 1 when writing the function body
+ self.write_stmt(module, sta, func_ctx, back::Level(1))?;
+ }
+
+ writeln!(self.out, "}}")?;
+
+ self.named_expressions.clear();
+
+ Ok(())
+ }
+
+ fn need_workgroup_variables_initialization(
+ &mut self,
+ func_ctx: &back::FunctionCtx,
+ module: &Module,
+ ) -> bool {
+ self.options.zero_initialize_workgroup_memory
+ && func_ctx.ty.is_compute_entry_point(module)
+ && module.global_variables.iter().any(|(handle, var)| {
+ !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ }
+
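+ /// Zero-initialize all workgroup variables used by this entry point from
+ /// the first invocation in the group, then synchronize. A sketch of the
+ /// emitted shape (assuming the workgroup barrier lowers to
+ /// `GroupMemoryBarrierWithGroupSync`):
+ ///
+ /// ```hlsl
+ /// if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {
+ /// // each used workgroup variable = its zero value;
+ /// }
+ /// GroupMemoryBarrierWithGroupSync();
+ /// ```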
+ fn write_workgroup_variables_initialization(
+ &mut self,
+ func_ctx: &back::FunctionCtx,
+ module: &Module,
+ ) -> BackendResult {
+ let level = back::Level(1);
+
+ writeln!(
+ self.out,
+ "{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
+ )?;
+
+ let vars = module.global_variables.iter().filter(|&(handle, var)| {
+ !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ });
+
+ for (handle, var) in vars {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{}{} = ", level.next(), name)?;
+ self.write_default_init(module, var.ty)?;
+ writeln!(self.out, ";")?;
+ }
+
+ writeln!(self.out, "{level}}}")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)
+ }
+
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
+ fn write_stmt(
+ &mut self,
+ module: &Module,
+ stmt: &crate::Statement,
+ func_ctx: &back::FunctionCtx<'_>,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::Statement;
+
+ match *stmt {
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let ptr_class = func_ctx.resolve_type(handle, &module.types).pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ // HLSL can't save a pointer-valued expression in a variable,
+ // but we shouldn't ever need to: they should never be named expressions,
+ // and none of the expression types flagged by bake_ref_count can be pointer-valued.
+ None
+ } else if let Some(name) = func_ctx.named_expressions.get(&handle) {
+ // The front end provides names for all variables at the start of writing,
+ // but we write them out step by step, so each name has to be re-registered
+ // with the namer; otherwise we could accidentally write a variable name
+ // instead of the full expression. This also sanitizes the names, which
+ // keeps the backend from emitting variables named after reserved keywords.
+ Some(self.namer.call(name))
+ } else if self.need_bake_expressions.contains(&handle) {
+ Some(format!("_expr{}", handle.index()))
+ } else {
+ None
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.write_named_expr(module, handle, name, handle, func_ctx)?;
+ }
+ }
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Block(ref block) => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, level.next())?
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "if (")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ let l2 = level.next();
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Kill => writeln!(self.out, "{level}discard;")?,
+ Statement::Return { value: None } => {
+ writeln!(self.out, "{level}return;")?;
+ }
+ Statement::Return { value: Some(expr) } => {
+ let base_ty_res = &func_ctx.info[expr].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ if let TypeInner::Pointer { base, space: _ } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ if let TypeInner::Struct { .. } = *resolved {
+ // We can safely unwrap here, since we know we are working with a struct
+ let ty = base_ty_res.handle().unwrap();
+ let struct_name = &self.names[&NameKey::Type(ty)];
+ let variable_name = self.namer.call(&struct_name.to_lowercase());
+ write!(self.out, "{level}const {struct_name} {variable_name} = ",)?;
+ self.write_expr(module, expr, func_ctx)?;
+ writeln!(self.out, ";")?;
+
+ // for entry point returns, we may need to reshuffle the outputs into a different struct
+ let ep_output = match func_ctx.ty {
+ back::FunctionType::Function(_) => None,
+ back::FunctionType::EntryPoint(index) => {
+ self.entry_point_io[index as usize].output.as_ref()
+ }
+ };
+ let final_name = match ep_output {
+ Some(ep_output) => {
+ let final_name = self.namer.call(&variable_name);
+ write!(
+ self.out,
+ "{}const {} {} = {{ ",
+ level, ep_output.ty_name, final_name,
+ )?;
+ for (index, m) in ep_output.members.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ let member_name = &self.names[&NameKey::StructMember(ty, m.index)];
+ write!(self.out, "{variable_name}.{member_name}")?;
+ }
+ writeln!(self.out, " }};")?;
+ final_name
+ }
+ None => variable_name,
+ };
+ writeln!(self.out, "{level}return {final_name};")?;
+ } else {
+ write!(self.out, "{level}return ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ writeln!(self.out, ";")?
+ }
+ }
+ Statement::Store { pointer, value } => {
+ let ty_inner = func_ctx.resolve_type(pointer, &module.types);
+ if let Some(crate::AddressSpace::Storage { .. }) = ty_inner.pointer_space() {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ self.write_storage_store(
+ module,
+ var_handle,
+ StoreValue::Expression(value),
+ func_ctx,
+ level,
+ )?;
+ } else {
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ //
+ // We handle matrix Stores here directly (including sub accesses for Vectors and Scalars).
+ // Loads are handled by `Expression::AccessIndex` (since sub accesses work fine for Loads).
+ struct MatrixAccess {
+ base: Handle<crate::Expression>,
+ index: u32,
+ }
+ enum Index {
+ Expression(Handle<crate::Expression>),
+ Static(u32),
+ }
+
+ let get_members = |expr: Handle<crate::Expression>| {
+ let resolved = func_ctx.resolve_type(expr, &module.types);
+ match *resolved {
+ TypeInner::Pointer { base, .. } => match module.types[base].inner {
+ TypeInner::Struct { ref members, .. } => Some(members),
+ _ => None,
+ },
+ _ => None,
+ }
+ };
+
+ let mut matrix = None;
+ let mut vector = None;
+ let mut scalar = None;
+
+ let mut current_expr = pointer;
+ for _ in 0..3 {
+ let resolved = func_ctx.resolve_type(current_expr, &module.types);
+
+ match (resolved, &func_ctx.expressions[current_expr]) {
+ (
+ &TypeInner::Pointer { base: ty, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) if matches!(
+ module.types[ty].inner,
+ TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ }
+ ) && get_members(base)
+ .map(|members| members[index as usize].binding.is_none())
+ == Some(true) =>
+ {
+ matrix = Some(MatrixAccess { base, index });
+ break;
+ }
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::Access { base, index },
+ ) => {
+ vector = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ vector = Some(Index::Static(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::Access { base, index },
+ ) => {
+ scalar = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ scalar = Some(Index::Static(index));
+ current_expr = base;
+ }
+ _ => break,
+ }
+ }
+
+ write!(self.out, "{level}")?;
+
+ if let Some(MatrixAccess { index, base }) = matrix {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let resolved = base_ty_res.inner_with(&module.types);
+ let ty = match *resolved {
+ TypeInner::Pointer { base, .. } => base,
+ _ => base_ty_res.handle().unwrap(),
+ };
+
+ if let Some(Index::Static(vec_index)) = vector {
+ self.write_expr(module, base, func_ctx)?;
+ write!(
+ self.out,
+ ".{}_{}",
+ &self.names[&NameKey::StructMember(ty, index)],
+ vec_index
+ )?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, "[")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ write!(self.out, "]")?;
+ }
+
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ } else {
+ let access = WrappedStructMatrixAccess { ty, index };
+ match (&vector, &scalar) {
+ (&Some(_), &Some(_)) => {
+ self.write_wrapped_struct_matrix_set_scalar_function_name(
+ access,
+ )?;
+ }
+ (&Some(_), &None) => {
+ self.write_wrapped_struct_matrix_set_vec_function_name(access)?;
+ }
+ (&None, _) => {
+ self.write_wrapped_struct_matrix_set_function_name(access)?;
+ }
+ }
+
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+
+ if let Some(Index::Expression(vec_index)) = vector {
+ write!(self.out, ", ")?;
+ self.write_expr(module, vec_index, func_ctx)?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, ", ")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
+ } else {
+ // We handle `Store`s to __matCx2 column vectors and scalar elements via
+ // the previously injected functions __set_col_of_matCx2 / __set_el_of_matCx2.
+ struct MatrixData {
+ columns: crate::VectorSize,
+ base: Handle<crate::Expression>,
+ }
+
+ enum Index {
+ Expression(Handle<crate::Expression>),
+ Static(u32),
+ }
+
+ let mut matrix = None;
+ let mut vector = None;
+ let mut scalar = None;
+
+ let mut current_expr = pointer;
+ for _ in 0..3 {
+ let resolved = func_ctx.resolve_type(current_expr, &module.types);
+ match (resolved, &func_ctx.expressions[current_expr]) {
+ (
+ &TypeInner::ValuePointer {
+ size: Some(crate::VectorSize::Bi),
+ ..
+ },
+ &crate::Expression::Access { base, index },
+ ) => {
+ vector = Some(index);
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::Access { base, index },
+ ) => {
+ scalar = Some(Index::Expression(index));
+ current_expr = base;
+ }
+ (
+ &TypeInner::ValuePointer { size: None, .. },
+ &crate::Expression::AccessIndex { base, index },
+ ) => {
+ scalar = Some(Index::Static(index));
+ current_expr = base;
+ }
+ _ => {
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(
+ module,
+ current_expr,
+ func_ctx,
+ true,
+ ) {
+ matrix = Some(MatrixData {
+ columns,
+ base: current_expr,
+ });
+ }
+
+ break;
+ }
+ }
+ }
+
+ if let (Some(MatrixData { columns, base }), Some(vec_index)) =
+ (matrix, vector)
+ {
+ if scalar.is_some() {
+ write!(self.out, "__set_el_of_mat{}x2", columns as u8)?;
+ } else {
+ write!(self.out, "__set_col_of_mat{}x2", columns as u8)?;
+ }
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, vec_index, func_ctx)?;
+
+ if let Some(scalar_index) = scalar {
+ write!(self.out, ", ")?;
+ match scalar_index {
+ Index::Static(index) => {
+ write!(self.out, "{index}")?;
+ }
+ Index::Expression(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+
+ writeln!(self.out, ");")?;
+ } else {
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, " = ")?;
+
+ // We cast the RHS of this store in cases where the LHS
+ // is a struct member with type:
+ // - matCx2 or
+ // - a (possibly nested) array of matCx2's
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(
+ module, pointer, func_ctx, false,
+ ) {
+ let mut resolved = func_ctx.resolve_type(pointer, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ write!(self.out, "(__mat{}x2", columns as u8)?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, ")")?;
+ }
+
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?
+ }
+ }
+ }
+ }
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
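+ // A loop with a `continuing` block or a `break_if` has no direct HLSL
+ // counterpart, so it is lowered to a gated `while(true)`. A sketch of
+ // the emitted shape (the gate name goes through the namer):
+ //
+ // bool loop_init = true;
+ // while(true) {
+ // if (!loop_init) {
+ // <continuing statements>
+ // if (<break_if condition>) { break; }
+ // }
+ // loop_init = false;
+ // <body statements>
+ // }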
+ let l2 = level.next();
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{level}bool {gate_name} = true;")?;
+ writeln!(self.out, "{level}while(true) {{")?;
+ writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
+ let l3 = l2.next();
+ for sta in continuing.iter() {
+ self.write_stmt(module, sta, func_ctx, l3)?;
+ }
+ if let Some(condition) = break_if {
+ write!(self.out, "{l3}if (")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", l3.next())?;
+ writeln!(self.out, "{l3}}}")?;
+ }
+ writeln!(self.out, "{l2}}}")?;
+ writeln!(self.out, "{l2}{gate_name} = false;")?;
+ } else {
+ writeln!(self.out, "{level}while(true) {{")?;
+ }
+
+ for sta in body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Break => writeln!(self.out, "{level}break;")?,
+ Statement::Continue => writeln!(self.out, "{level}continue;")?,
+ Statement::Barrier(barrier) => {
+ self.write_barrier(barrier, level)?;
+ }
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{level}")?;
+ self.write_expr(module, image, func_ctx)?;
+
+ write!(self.out, "[")?;
+ if let Some(index) = array_index {
+ // An array index is accepted only for texture_storage_2d_array, so we can safely use int3(coordinate, array_index) here
+ write!(self.out, "int3(")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr(module, coordinate, func_ctx)?;
+ }
+ write!(self.out, "]")?;
+
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ";")?;
+ }
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ write!(self.out, "const ")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ let expr_ty = &func_ctx.info[expr].ty;
+ match *expr_ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+ write!(self.out, " {name} = ")?;
+ self.named_expressions.insert(expr, name);
+ }
+ let func_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{func_name}(")?;
+ for (index, argument) in arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_expr(module, *argument, func_ctx)?;
+ }
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
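+ // Naga atomics map onto the HLSL Interlocked* intrinsics, with the old
+ // value returned through the last argument. A sketch of the emitted
+ // form for workgroup memory:
+ //
+ // int _expr42; InterlockedAdd(dest, value, _expr42);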
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ match func_ctx.info[result].ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+
+ // Validation ensures that `pointer` has a `Pointer` type.
+ let pointer_space = func_ctx
+ .resolve_type(pointer, &module.types)
+ .pointer_space()
+ .unwrap();
+
+ let fun_str = fun.to_hlsl_suffix();
+ write!(self.out, " {res_name}; ")?;
+ match pointer_space {
+ crate::AddressSpace::WorkGroup => {
+ write!(self.out, "Interlocked{fun_str}(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ }
+ crate::AddressSpace::Storage { .. } => {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ // The call to `self.write_storage_address` wants
+ // mutable access to all of `self`, so temporarily take
+ // ownership of our reusable access chain buffer.
+ let chain = mem::take(&mut self.temp_access_chain);
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "{var_name}.Interlocked{fun_str}(")?;
+ self.write_storage_address(module, &chain, func_ctx)?;
+ self.temp_access_chain = chain;
+ }
+ ref other => {
+ return Err(Error::Custom(format!(
+ "invalid address space {other:?} for atomic statement"
+ )))
+ }
+ }
+ write!(self.out, ", ")?;
+ // handle the special cases
+ match *fun {
+ crate::AtomicFunction::Subtract => {
+ // we just wrote `InterlockedAdd`, so negate the argument
+ write!(self.out, "-")?;
+ }
+ crate::AtomicFunction::Exchange { compare: Some(_) } => {
+ return Err(Error::Unimplemented("atomic CompareExchange".to_string()));
+ }
+ _ => {}
+ }
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ", {res_name});")?;
+ self.named_expressions.insert(result, res_name);
+ }
+ Statement::WorkGroupUniformLoad { pointer, result } => {
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ write!(self.out, "{level}")?;
+ let name = format!("_expr{}", result.index());
+ self.write_named_expr(module, pointer, name, result, func_ctx)?;
+
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ }
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{level}")?;
+ write!(self.out, "switch(")?;
+ self.write_expr(module, selector, func_ctx)?;
+ writeln!(self.out, ") {{")?;
+
+ // Write all cases
+ let indent_level_1 = level.next();
+ let indent_level_2 = indent_level_1.next();
+
+ for (i, case) in cases.iter().enumerate() {
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ write!(self.out, "{indent_level_1}case {value}:")?
+ }
+ crate::SwitchValue::U32(value) => {
+ write!(self.out, "{indent_level_1}case {value}u:")?
+ }
+ crate::SwitchValue::Default => {
+ write!(self.out, "{indent_level_1}default:")?
+ }
+ }
+
+ // The new block is not only stylistic, it plays a role here:
+ // We might end up having to write the same case body
+ // multiple times due to FXC not supporting fallthrough.
+ // Therefore, some `Expression`s written by `Statement::Emit`
+ // will end up having the same name (`_expr<handle_index>`).
+ // So we need to put each case in its own scope.
+ let write_block_braces = !(case.fall_through && case.body.is_empty());
+ if write_block_braces {
+ writeln!(self.out, " {{")?;
+ } else {
+ writeln!(self.out)?;
+ }
+
+ // Although FXC does support a series of case clauses before
+ // a block[^yes], it does not support fallthrough from a
+ // non-empty case block to the next[^no]. If this case has a
+ // non-empty body with a fallthrough, emulate that by
+ // duplicating the bodies of all the cases it would fall
+ // into as extensions of this case's own body. This makes
+ // the HLSL output potentially quadratic in the size of the
+ // Naga IR.
+ //
+ // [^yes]: ```hlsl
+ // case 1:
+ // case 2: do_stuff()
+ // ```
+ // [^no]: ```hlsl
+ // case 1: do_this();
+ // case 2: do_that();
+ // ```
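+ //
+ // As a sketch, `case 1: do_this(); /* falls through */ case 2: do_that();`
+ // is emitted roughly as:
+ // case 1: {
+ // { do_this(); }
+ // { do_that(); }
+ // break;
+ // }
+ // case 2: { do_that(); break; }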
+ if case.fall_through && !case.body.is_empty() {
+ let curr_len = i + 1;
+ let end_case_idx = curr_len
+ + cases
+ .iter()
+ .skip(curr_len)
+ .position(|case| !case.fall_through)
+ .unwrap();
+ let indent_level_3 = indent_level_2.next();
+ for case in &cases[i..=end_case_idx] {
+ writeln!(self.out, "{indent_level_2}{{")?;
+ let prev_len = self.named_expressions.len();
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, indent_level_3)?;
+ }
+ // Clear all named expressions that were previously inserted by the statements in the block
+ self.named_expressions.truncate(prev_len);
+ writeln!(self.out, "{indent_level_2}}}")?;
+ }
+
+ let last_case = &cases[end_case_idx];
+ if last_case.body.last().map_or(true, |s| !s.is_terminator()) {
+ writeln!(self.out, "{indent_level_2}break;")?;
+ }
+ } else {
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, indent_level_2)?;
+ }
+ if !case.fall_through
+ && case.body.last().map_or(true, |s| !s.is_terminator())
+ {
+ writeln!(self.out, "{indent_level_2}break;")?;
+ }
+ }
+
+ if write_block_braces {
+ writeln!(self.out, "{indent_level_1}}}")?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::RayQuery { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ fn write_const_expression(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ ) -> BackendResult {
+ self.write_possibly_const_expression(
+ module,
+ expr,
+ &module.const_expressions,
+ |writer, expr| writer.write_const_expression(module, expr),
+ )
+ }
+
+ fn write_possibly_const_expression<E>(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ expressions: &crate::Arena<crate::Expression>,
+ write_expression: E,
+ ) -> BackendResult
+ where
+ E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
+ {
+ use crate::Expression;
+
+ match expressions[expr] {
+ Expression::Literal(literal) => match literal {
+ // Floats are written using `Debug` instead of `Display` because it always appends the
+ // decimal part, even if it's zero
+ crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
+ crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
+ crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
+ crate::Literal::I32(value) => write!(self.out, "{}", value)?,
+ crate::Literal::I64(value) => write!(self.out, "{}L", value)?,
+ crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
+ crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
+ return Err(Error::Custom(
+ "Abstract types should not appear in IR presented to backends".into(),
+ ));
+ }
+ },
+ Expression::Constant(handle) => {
+ let constant = &module.constants[handle];
+ if constant.name.is_some() {
+ write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.write_const_expression(module, constant.init)?;
+ }
+ }
+ Expression::ZeroValue(ty) => self.write_default_init(module, ty)?,
+ Expression::Compose { ty, ref components } => {
+ match module.types[ty].inner {
+ TypeInner::Struct { .. } | TypeInner::Array { .. } => {
+ self.write_wrapped_constructor_function_name(
+ module,
+ WrappedConstructor { ty },
+ )?;
+ }
+ _ => {
+ self.write_type(module, ty)?;
+ }
+ };
+ write!(self.out, "(")?;
+ for (index, component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ write_expression(self, *component)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Expression::Splat { size, value } => {
+ // HLSL does not support single-value vector constructors:
+ // if we write, for example, int4(0), dxc returns an error:
+ // error: too few elements in vector initialization (expected 4 elements, have 1)
+ // so we write the value once and swizzle it, e.g. `(0).xxxx`.
+ let number_of_components = match size {
+ crate::VectorSize::Bi => "xx",
+ crate::VectorSize::Tri => "xxx",
+ crate::VectorSize::Quad => "xxxx",
+ };
+ write!(self.out, "(")?;
+ write_expression(self, value)?;
+ write!(self.out, ").{number_of_components}")?
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Helper method to write expressions
+ ///
+ /// # Notes
+ /// Doesn't add any newlines or leading/trailing spaces
+ pub(super) fn write_expr(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ // Handle the special semantics of vertex_index/instance_index
+ let ff_input = if self.options.special_constants_binding.is_some() {
+ func_ctx.is_fixed_function_input(expr, module)
+ } else {
+ None
+ };
+ let closing_bracket = match ff_input {
+ Some(crate::BuiltIn::VertexIndex) => {
+ write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX} + ")?;
+ ")"
+ }
+ Some(crate::BuiltIn::InstanceIndex) => {
+ write!(self.out, "({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE} + ",)?;
+ ")"
+ }
+ Some(crate::BuiltIn::NumWorkGroups) => {
+ // Note: despite their names (`FIRST_VERTEX` and `FIRST_INSTANCE`),
+ // in compute shaders the special constants contain the number
+ // of workgroups, which we are using here.
+ write!(
+ self.out,
+ "uint3({SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_VERTEX}, {SPECIAL_CBUF_VAR}.{SPECIAL_FIRST_INSTANCE}, {SPECIAL_CBUF_VAR}.{SPECIAL_OTHER})",
+ )?;
+ return Ok(());
+ }
+ _ => "",
+ };
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{name}{closing_bracket}")?;
+ return Ok(());
+ }
+
+ let expression = &func_ctx.expressions[expr];
+
+ match *expression {
+ Expression::Literal(_)
+ | Expression::Constant(_)
+ | Expression::ZeroValue(_)
+ | Expression::Compose { .. }
+ | Expression::Splat { .. } => {
+ self.write_possibly_const_expression(
+ module,
+ expr,
+ func_ctx.expressions,
+ |writer, expr| writer.write_expr(module, expr, func_ctx),
+ )?;
+ }
+ // All of the multiplication can be expressed as `mul`,
+ // except vector * vector, which needs to use the "*" operator.
+ Expression::Binary {
+ op: crate::BinaryOperator::Multiply,
+ left,
+ right,
+ } if func_ctx.resolve_type(left, &module.types).is_matrix()
+ || func_ctx.resolve_type(right, &module.types).is_matrix() =>
+ {
+ // We intentionally flip the order of multiplication as our matrices are implicitly transposed.
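+ // For example, an IR `m * v` is emitted as `mul(v, m)`.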
+ write!(self.out, "mul(")?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) != sign(right) return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+
+ // While HLSL supports float operands with the % operator, it is only
+ // defined in cases where both sides are either positive or negative.
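+ // So a float `a % b` from the IR is emitted as `fmod(a, b)` instead.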
+ Expression::Binary {
+ op: crate::BinaryOperator::Modulo,
+ left,
+ right,
+ } if func_ctx.resolve_type(left, &module.types).scalar_kind()
+ == Some(crate::ScalarKind::Float) =>
+ {
+ write!(self.out, "fmod(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Binary { op, left, right } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, " {} ", crate::back::binary_operation_str(op))?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Access { base, index } => {
+ if let Some(crate::AddressSpace::Storage { .. }) =
+ func_ctx.resolve_type(expr, &module.types).pointer_space()
+ {
+ // do nothing, the chain is written on `Load`/`Store`
+ } else {
+ // We use the function __get_col_of_matCx2 here in cases
+ // where `base`s type resolves to a matCx2 and is part of a
+ // struct member with type of (possibly nested) array of matCx2's.
+ //
+ // Note that this only works for `Load`s and we handle
+ // `Store`s differently in `Statement::Store`.
+ if let Some(MatrixType {
+ columns,
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
+ {
+ write!(self.out, "__get_col_of_mat{}x2(", columns as u8)?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, ")")?;
+ return Ok(());
+ }
+
+ let resolved = func_ctx.resolve_type(base, &module.types);
+
+ let non_uniform_qualifier = match *resolved {
+ TypeInner::BindingArray { .. } => {
+ let uniformity = &func_ctx.info[index].uniformity;
+
+ uniformity.non_uniform_result.is_some()
+ }
+ _ => false,
+ };
+
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, "[")?;
+ if non_uniform_qualifier {
+ write!(self.out, "NonUniformResourceIndex(")?;
+ }
+ self.write_expr(module, index, func_ctx)?;
+ if non_uniform_qualifier {
+ write!(self.out, ")")?;
+ }
+ write!(self.out, "]")?;
+ }
+ }
+ Expression::AccessIndex { base, index } => {
+ if let Some(crate::AddressSpace::Storage { .. }) =
+ func_ctx.resolve_type(expr, &module.types).pointer_space()
+ {
+ // do nothing, the chain is written on `Load`/`Store`
+ } else {
+ fn write_access<W: fmt::Write>(
+ writer: &mut super::Writer<'_, W>,
+ resolved: &TypeInner,
+ base_ty_handle: Option<Handle<crate::Type>>,
+ index: u32,
+ ) -> BackendResult {
+ match *resolved {
+ // We specifically lift the ValuePointer to this case. While `[0]` is valid
+ // HLSL for any vector behind a value pointer, FXC completely miscompiles
+ // it and generates completely nonsensical DXBC.
+ //
+ // See https://github.com/gfx-rs/naga/issues/2095 for more details.
+ TypeInner::Vector { .. } | TypeInner::ValuePointer { .. } => {
+ // Write vector access as a swizzle
+ write!(writer.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::BindingArray { .. } => write!(writer.out, "[{index}]")?,
+ TypeInner::Struct { .. } => {
+ // This will never panic when the type is a `Struct`; that is not
+ // true for other types, so we can only unwrap inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ writer.out,
+ ".{}",
+ &writer.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => {
+ return Err(Error::Custom(format!("Cannot index {other:?}")))
+ }
+ }
+ Ok(())
+ }
+
+ // We write the matrix column access in a special way since
+ // the type of `base` is our special __matCx2 struct.
+ if let Some(MatrixType {
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ ..
+ }) = get_inner_matrix_of_struct_array_member(module, base, func_ctx, true)
+ {
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, "._{index}")?;
+ return Ok(());
+ }
+
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, .. } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ // We treat matrices of the form `matCx2` as a sequence of C `vec2`s.
+ // See the module-level block comment in mod.rs for details.
+ //
+ // We handle matrix reconstruction here for Loads.
+ // Stores are handled directly by `Statement::Store`.
+ if let TypeInner::Struct { ref members, .. } = *resolved {
+ let member = &members[index as usize];
+
+ match module.types[member.ty].inner {
+ TypeInner::Matrix {
+ rows: crate::VectorSize::Bi,
+ ..
+ } if member.binding.is_none() => {
+ let ty = base_ty_handle.unwrap();
+ self.write_wrapped_struct_matrix_get_function_name(
+ WrappedStructMatrixAccess { ty, index },
+ )?;
+ write!(self.out, "(")?;
+ self.write_expr(module, base, func_ctx)?;
+ write!(self.out, ")")?;
+ return Ok(());
+ }
+ _ => {}
+ }
+ }
+
+ self.write_expr(module, base, func_ctx)?;
+ write_access(self, resolved, base_ty_handle, index)?;
+ }
+ }
+ Expression::FunctionArgument(pos) => {
+ let key = func_ctx.argument_key(pos);
+ let name = &self.names[&key];
+ write!(self.out, "{name}")?;
+ }
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ use crate::SampleLevel as Sl;
+ const COMPONENTS: [&str; 4] = ["", "Green", "Blue", "Alpha"];
+
+ let (base_str, component_str) = match gather {
+ Some(component) => ("Gather", COMPONENTS[component as usize]),
+ None => ("Sample", ""),
+ };
+ let cmp_str = match depth_ref {
+ Some(_) => "Cmp",
+ None => "",
+ };
+ let level_str = match level {
+ Sl::Zero if gather.is_none() => "LevelZero",
+ Sl::Auto | Sl::Zero => "",
+ Sl::Exact(_) => "Level",
+ Sl::Bias(_) => "Bias",
+ Sl::Gradient { .. } => "Grad",
+ };
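+ // The pieces above concatenate into a real texture method, e.g. a
+ // green-component gather with a depth comparison becomes
+ // `.GatherCmpGreen(...)`, and a biased sample becomes `.SampleBias(...)`.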
+
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ".{base_str}{cmp_str}{component_str}{level_str}(")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_texture_coordinates(
+ "float",
+ coordinate,
+ array_index,
+ None,
+ module,
+ func_ctx,
+ )?;
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ match level {
+ Sl::Auto | Sl::Zero => {}
+ Sl::Exact(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Bias(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Gradient { x, y } => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, x, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, y, func_ctx)?;
+ }
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
+ self.write_const_expression(module, offset)?;
+ write!(self.out, ")")?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::ImageQuery { image, query } => {
+ // use wrapped image query function
+ if let TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } = *func_ctx.resolve_type(image, &module.types)
+ {
+ let wrapped_image_query = WrappedImageQuery {
+ dim,
+ arrayed,
+ class,
+ query: query.into(),
+ };
+
+ self.write_wrapped_image_query_function_name(wrapped_image_query)?;
+ write!(self.out, "(")?;
+ // Image always first param
+ self.write_expr(module, image, func_ctx)?;
+ if let crate::ImageQuery::Size { level: Some(level) } = query {
+ write!(self.out, ", ")?;
+ self.write_expr(module, level, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ }
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ".Load(")?;
+
+ self.write_texture_coordinates(
+ "int",
+ coordinate,
+ array_index,
+ level,
+ module,
+ func_ctx,
+ )?;
+
+ if let Some(sample) = sample {
+ write!(self.out, ", ")?;
+ self.write_expr(module, sample, func_ctx)?;
+ }
+
+ // close bracket for Load function
+ write!(self.out, ")")?;
+
+ // return x component if return type is scalar
+ if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
+ write!(self.out, ".x")?;
+ }
+ }
+ Expression::GlobalVariable(handle) => match module.global_variables[handle].space {
+ crate::AddressSpace::Storage { .. } => {}
+ _ => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{name}")?;
+ }
+ },
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
+ }
+ Expression::Load { pointer } => {
+ match func_ctx
+ .resolve_type(pointer, &module.types)
+ .pointer_space()
+ {
+ Some(crate::AddressSpace::Storage { .. }) => {
+ let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
+ let result_ty = func_ctx.info[expr].ty.clone();
+ self.write_storage_load(module, var_handle, result_ty, func_ctx)?;
+ }
+ _ => {
+ let mut close_paren = false;
+
+ // We cast the value loaded to a native HLSL floatCx2
+ // in cases where it is of type:
+ // - __matCx2 or
+ // - a (possibly nested) array of __matCx2's
+ if let Some(MatrixType {
+ rows: crate::VectorSize::Bi,
+ width: 4,
+ ..
+ }) = get_inner_matrix_of_struct_array_member(
+ module, pointer, func_ctx, false,
+ )
+ .or_else(|| get_inner_matrix_of_global_uniform(module, pointer, func_ctx))
+ {
+ let mut resolved = func_ctx.resolve_type(pointer, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ }
+
+ write!(self.out, "((")?;
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_type(module, base)?;
+ self.write_array_size(module, base, size)?;
+ } else {
+ self.write_value_type(module, resolved)?;
+ }
+ write!(self.out, ")")?;
+ close_paren = true;
+ }
+
+ self.write_expr(module, pointer, func_ctx)?;
+
+ if close_paren {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ }
+ Expression::Unary { op, expr } => {
+ // https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-operators#unary-operators
+ let op_str = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::LogicalNot => "!",
+ crate::UnaryOperator::BitwiseNot => "~",
+ };
+ write!(self.out, "{op_str}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let inner = func_ctx.resolve_type(expr, &module.types);
+ match convert {
+ Some(dst_width) => {
+ let scalar = crate::Scalar {
+ kind,
+ width: dst_width,
+ };
+ match *inner {
+ TypeInner::Vector { size, .. } => {
+ write!(
+ self.out,
+ "{}{}(",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(size)
+ )?;
+ }
+ TypeInner::Scalar(_) => {
+ write!(self.out, "{}(", scalar.to_hlsl_str()?,)?;
+ }
+ TypeInner::Matrix { columns, rows, .. } => {
+ write!(
+ self.out,
+ "{}{}x{}(",
+ scalar.to_hlsl_str()?,
+ back::vector_size_str(columns),
+ back::vector_size_str(rows)
+ )?;
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_expr expression::as {inner:?}"
+ )));
+ }
+ };
+ }
+ None => {
+ write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
+ }
+ }
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ enum Function {
+ Asincosh { is_sin: bool },
+ Atanh,
+ ExtractBits,
+ InsertBits,
+ Pack2x16float,
+ Pack2x16snorm,
+ Pack2x16unorm,
+ Pack4x8snorm,
+ Pack4x8unorm,
+ Unpack2x16float,
+ Unpack2x16snorm,
+ Unpack2x16unorm,
+ Unpack4x8snorm,
+ Unpack4x8unorm,
+ Regular(&'static str),
+ MissingIntOverload(&'static str),
+ MissingIntReturnType(&'static str),
+ CountTrailingZeros,
+ CountLeadingZeros,
+ }
+
+ let fun = match fun {
+ // comparison
+ Mf::Abs => Function::Regular("abs"),
+ Mf::Min => Function::Regular("min"),
+ Mf::Max => Function::Regular("max"),
+ Mf::Clamp => Function::Regular("clamp"),
+ Mf::Saturate => Function::Regular("saturate"),
+ // trigonometry
+ Mf::Cos => Function::Regular("cos"),
+ Mf::Cosh => Function::Regular("cosh"),
+ Mf::Sin => Function::Regular("sin"),
+ Mf::Sinh => Function::Regular("sinh"),
+ Mf::Tan => Function::Regular("tan"),
+ Mf::Tanh => Function::Regular("tanh"),
+ Mf::Acos => Function::Regular("acos"),
+ Mf::Asin => Function::Regular("asin"),
+ Mf::Atan => Function::Regular("atan"),
+ Mf::Atan2 => Function::Regular("atan2"),
+ Mf::Asinh => Function::Asincosh { is_sin: true },
+ Mf::Acosh => Function::Asincosh { is_sin: false },
+ Mf::Atanh => Function::Atanh,
+ Mf::Radians => Function::Regular("radians"),
+ Mf::Degrees => Function::Regular("degrees"),
+ // decomposition
+ Mf::Ceil => Function::Regular("ceil"),
+ Mf::Floor => Function::Regular("floor"),
+ Mf::Round => Function::Regular("round"),
+ Mf::Fract => Function::Regular("frac"),
+ Mf::Trunc => Function::Regular("trunc"),
+ Mf::Modf => Function::Regular(MODF_FUNCTION),
+ Mf::Frexp => Function::Regular(FREXP_FUNCTION),
+ Mf::Ldexp => Function::Regular("ldexp"),
+ // exponent
+ Mf::Exp => Function::Regular("exp"),
+ Mf::Exp2 => Function::Regular("exp2"),
+ Mf::Log => Function::Regular("log"),
+ Mf::Log2 => Function::Regular("log2"),
+ Mf::Pow => Function::Regular("pow"),
+ // geometry
+ Mf::Dot => Function::Regular("dot"),
+ //Mf::Outer => ,
+ Mf::Cross => Function::Regular("cross"),
+ Mf::Distance => Function::Regular("distance"),
+ Mf::Length => Function::Regular("length"),
+ Mf::Normalize => Function::Regular("normalize"),
+ Mf::FaceForward => Function::Regular("faceforward"),
+ Mf::Reflect => Function::Regular("reflect"),
+ Mf::Refract => Function::Regular("refract"),
+ // computational
+ Mf::Sign => Function::Regular("sign"),
+ Mf::Fma => Function::Regular("mad"),
+ Mf::Mix => Function::Regular("lerp"),
+ Mf::Step => Function::Regular("step"),
+ Mf::SmoothStep => Function::Regular("smoothstep"),
+ Mf::Sqrt => Function::Regular("sqrt"),
+ Mf::InverseSqrt => Function::Regular("rsqrt"),
+ //Mf::Inverse =>,
+ Mf::Transpose => Function::Regular("transpose"),
+ Mf::Determinant => Function::Regular("determinant"),
+ // bits
+ Mf::CountTrailingZeros => Function::CountTrailingZeros,
+ Mf::CountLeadingZeros => Function::CountLeadingZeros,
+ Mf::CountOneBits => Function::MissingIntOverload("countbits"),
+ Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
+ Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"),
+ Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"),
+ Mf::ExtractBits => Function::ExtractBits,
+ Mf::InsertBits => Function::InsertBits,
+ // Data Packing
+ Mf::Pack2x16float => Function::Pack2x16float,
+ Mf::Pack2x16snorm => Function::Pack2x16snorm,
+ Mf::Pack2x16unorm => Function::Pack2x16unorm,
+ Mf::Pack4x8snorm => Function::Pack4x8snorm,
+ Mf::Pack4x8unorm => Function::Pack4x8unorm,
+ // Data Unpacking
+ Mf::Unpack2x16float => Function::Unpack2x16float,
+ Mf::Unpack2x16snorm => Function::Unpack2x16snorm,
+ Mf::Unpack2x16unorm => Function::Unpack2x16unorm,
+ Mf::Unpack4x8snorm => Function::Unpack4x8snorm,
+ Mf::Unpack4x8unorm => Function::Unpack4x8unorm,
+ _ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
+ };
+
+ match fun {
+ Function::Asincosh { is_sin } => {
+ write!(self.out, "log(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " + sqrt(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " * ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ match is_sin {
+ true => write!(self.out, " + 1.0))")?,
+ false => write!(self.out, " - 1.0))")?,
+ }
+ }
+ Function::Atanh => {
+ write!(self.out, "0.5 * log((1.0 + ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") / (1.0 - ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ }
+ Function::ExtractBits => {
+ // e: T,
+ // offset: u32,
+ // count: u32
+ // T is u32 or i32 or vecN<u32> or vecN<i32>
+ if let (Some(offset), Some(count)) = (arg1, arg2) {
+ let scalar_width: u8 = 32;
+ // Works for signed and unsigned
+ // (count == 0 ? 0 : (e << (32 - count - offset)) >> (32 - count))
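+ // e.g. (hypothetical values) e = 0x1234, offset = 8, count = 4:
+ // (e << 20) >> 28 isolates bits [8, 12), yielding 0x2.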
+ write!(self.out, "(")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, " == 0 ? 0 : (")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << ({scalar_width} - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, " - ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ")) >> ({scalar_width} - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, "))")?;
+ }
+ }
+ Function::InsertBits => {
+ // e: T,
+ // newbits: T,
+ // offset: u32,
+ // count: u32
+ // returns T
+ // T is i32, u32, vecN<i32>, or vecN<u32>
+ if let (Some(newbits), Some(offset), Some(count)) = (arg1, arg2, arg3) {
+ let scalar_width: u8 = 32;
+ let scalar_max: u32 = 0xFFFFFFFF;
+ // mask = ((0xFFFFFFFFu >> (32 - count)) << offset)
+ // (count == 0 ? e : ((e & ~mask) | ((newbits << offset) & mask)))
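+ // e.g. (hypothetical values) offset = 8, count = 4:
+ // mask = ((0xFFFFFFFFu >> 28) << 8) = 0xF00, so bits [8, 12) are replaced.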
+ write!(self.out, "(")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, " == 0 ? ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " : ")?;
+ write!(self.out, "(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " & ~")?;
+ // mask
+ write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, ")) << ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ")")?;
+ // end mask
+ write!(self.out, ") | ((")?;
+ self.write_expr(module, newbits, func_ctx)?;
+ write!(self.out, " << ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ") & ")?;
+ // mask
+ write!(self.out, "(({scalar_max}u >> ({scalar_width}u - ")?;
+ self.write_expr(module, count, func_ctx)?;
+ write!(self.out, ")) << ")?;
+ self.write_expr(module, offset, func_ctx)?;
+ write!(self.out, ")")?;
+ // end mask
+ write!(self.out, "))")?;
+ }
+ }
+ Function::Pack2x16float => {
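+ // f32tof16 returns the half-float bits in the low 16 bits of a uint;
+ // component 0 lands in the low half, component 1 in the high half.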
+ write!(self.out, "(f32tof16(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[0]) | f32tof16(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[1]) << 16)")?;
+ }
+ Function::Pack2x16snorm => {
+ let scale = 32767;
+
+ write!(self.out, "uint((int(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[0], -1.0, 1.0) * {scale}.0)) & 0xFFFF) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[1], -1.0, 1.0) * {scale}.0)) & 0xFFFF) << 16))",)?;
+ }
+ Function::Pack2x16unorm => {
+ let scale = 65535;
+
+ write!(self.out, "(uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[1], 0.0, 1.0) * {scale}.0)) << 16)")?;
+ }
+ Function::Pack4x8snorm => {
+ let scale = 127;
+
+ write!(self.out, "uint((int(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[0], -1.0, 1.0) * {scale}.0)) & 0xFF) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[1], -1.0, 1.0) * {scale}.0)) & 0xFF) << 8) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[2], -1.0, 1.0) * {scale}.0)) & 0xFF) << 16) | ((int(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[3], -1.0, 1.0) * {scale}.0)) & 0xFF) << 24))",)?;
+ }
+ Function::Pack4x8unorm => {
+ let scale = 255;
+
+ write!(self.out, "(uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[0], 0.0, 1.0) * {scale}.0)) | uint(round(clamp(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[1], 0.0, 1.0) * {scale}.0)) << 8 | uint(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ "[2], 0.0, 1.0) * {scale}.0)) << 16 | uint(round(clamp("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "[3], 0.0, 1.0) * {scale}.0)) << 24)")?;
+ }
+
+ Function::Unpack2x16float => {
+ write!(self.out, "float2(f16tof32(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "), f16tof32((")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") >> 16))")?;
+ }
+ Function::Unpack2x16snorm => {
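+ // The int2 cast makes the >> arithmetic: the low half is sign-extended
+ // via `<< 16 >> 16`, the high half via a plain `>> 16`.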
+ let scale = 32767;
+
+ write!(self.out, "(float2(int2(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 16, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") >> 16) / {scale}.0)")?;
+ }
+ Function::Unpack2x16unorm => {
+ let scale = 65535;
+
+ write!(self.out, "(float2(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " & 0xFFFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 16) / {scale}.0)")?;
+ }
+ Function::Unpack4x8snorm => {
+ let scale = 127;
+
+ write!(self.out, "(float4(int4(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 24, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 16, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " << 8, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ") >> 24) / {scale}.0)")?;
+ }
+ Function::Unpack4x8unorm => {
+ let scale = 255;
+
+ write!(self.out, "(float4(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " & 0xFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 8 & 0xFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 16 & 0xFF, ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " >> 24) / {scale}.0)")?;
+ }
+ Function::Regular(fun_name) => {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ if let Some(arg) = arg1 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ if let Some(arg) = arg2 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ if let Some(arg) = arg3 {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ write!(self.out, ")")?
+ }
+ Function::MissingIntOverload(fun_name) => {
+ let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind();
+ if let Some(ScalarKind::Sint) = scalar_kind {
+ write!(self.out, "asint({fun_name}(asuint(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ } else {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Function::MissingIntReturnType(fun_name) => {
+ let scalar_kind = func_ctx.resolve_type(arg, &module.types).scalar_kind();
+ if let Some(ScalarKind::Sint) = scalar_kind {
+ write!(self.out, "asint({fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Function::CountTrailingZeros => {
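+ // HLSL's firstbitlow returns ~0 for a zero input, so clamping
+ // with min(32, ...) yields the 32 that WGSL requires.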
+ match *func_ctx.resolve_type(arg, &module.types) {
+ TypeInner::Vector { size, scalar } => {
+ let s = match size {
+ crate::VectorSize::Bi => ".xx",
+ crate::VectorSize::Tri => ".xxx",
+ crate::VectorSize::Quad => ".xxxx",
+ };
+
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "min((32u){s}, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "asint(min((32u){s}, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ TypeInner::Scalar(scalar) => {
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "min(32u, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "asint(min(32u, firstbitlow(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ return Ok(());
+ }
+ Function::CountLeadingZeros => {
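+ // For uints, firstbithigh(0) is ~0, and 31u - ~0u wraps to 32u,
+ // matching WGSL; negative ints need the explicit `< 0` check below
+ // because firstbithigh searches for the first zero bit on them.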
+ match *func_ctx.resolve_type(arg, &module.types) {
+ TypeInner::Vector { size, scalar } => {
+ let s = match size {
+ crate::VectorSize::Bi => ".xx",
+ crate::VectorSize::Tri => ".xxx",
+ crate::VectorSize::Quad => ".xxxx",
+ };
+
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "((31u){s} - firstbithigh(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(
+ self.out,
+ " < (0){s} ? (0){s} : (31){s} - asint(firstbithigh("
+ )?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ TypeInner::Scalar(scalar) => {
+ if let ScalarKind::Uint = scalar.kind {
+ write!(self.out, "(31u - firstbithigh(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, "))")?;
+ } else {
+ write!(self.out, "(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, " < 0 ? 0 : 31 - asint(firstbithigh(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ write!(self.out, ")))")?;
+ }
+ }
+ _ => unreachable!(),
+ }
+
+ return Ok(());
+ }
+ }
+ }
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(module, vector, func_ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ Expression::ArrayLength(expr) => {
+ let var_handle = match func_ctx.expressions[expr] {
+ Expression::AccessIndex { base, index: _ } => {
+ match func_ctx.expressions[base] {
+ Expression::GlobalVariable(handle) => handle,
+ _ => unreachable!(),
+ }
+ }
+ Expression::GlobalVariable(handle) => handle,
+ _ => unreachable!(),
+ };
+
+ let var = &module.global_variables[var_handle];
+ let (offset, stride) = match module.types[var.ty].inner {
+ TypeInner::Array { stride, .. } => (0, stride),
+ TypeInner::Struct { ref members, .. } => {
+ let last = members.last().unwrap();
+ let stride = match module.types[last.ty].inner {
+ TypeInner::Array { stride, .. } => stride,
+ _ => unreachable!(),
+ };
+ (last.offset, stride)
+ }
+ _ => unreachable!(),
+ };
+
+ let storage_access = match var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => crate::StorageAccess::default(),
+ };
+ let wrapped_array_length = WrappedArrayLength {
+ writable: storage_access.contains(crate::StorageAccess::STORE),
+ };
+
+ write!(self.out, "((")?;
+ self.write_wrapped_array_length_function_name(wrapped_array_length)?;
+ let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
+ write!(self.out, "({var_name}) - {offset}) / {stride})")?
+ }
+ Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ if axis == Axis::Width && (ctrl == Ctrl::Coarse || ctrl == Ctrl::Fine) {
+ let tail = match ctrl {
+ Ctrl::Coarse => "coarse",
+ Ctrl::Fine => "fine",
+ Ctrl::None => unreachable!(),
+ };
+ write!(self.out, "abs(ddx_{tail}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")) + abs(ddy_{tail}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, "))")?
+ } else {
+ let fun_str = match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => "ddx_coarse",
+ (Axis::X, Ctrl::Fine) => "ddx_fine",
+ (Axis::X, Ctrl::None) => "ddx",
+ (Axis::Y, Ctrl::Coarse) => "ddy_coarse",
+ (Axis::Y, Ctrl::Fine) => "ddy_fine",
+ (Axis::Y, Ctrl::None) => "ddy",
+ (Axis::Width, Ctrl::Coarse | Ctrl::Fine) => unreachable!(),
+ (Axis::Width, Ctrl::None) => "fwidth",
+ };
+ write!(self.out, "{fun_str}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ }
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_str = match fun {
+ Rf::All => "all",
+ Rf::Any => "any",
+ Rf::IsNan => "isnan",
+ Rf::IsInf => "isinf",
+ };
+ write!(self.out, "{fun_str}(")?;
+ self.write_expr(module, argument, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, condition, func_ctx)?;
+ write!(self.out, " ? ")?;
+ self.write_expr(module, accept, func_ctx)?;
+ write!(self.out, " : ")?;
+ self.write_expr(module, reject, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ // Not supported yet
+ Expression::RayQueryGetIntersection { .. } => unreachable!(),
+ // Nothing to do here, since call expression already cached
+ Expression::CallResult(_)
+ | Expression::AtomicResult { .. }
+ | Expression::WorkGroupUniformLoadResult { .. }
+ | Expression::RayQueryProceedResult => {}
+ }
+
+ if !closing_bracket.is_empty() {
+ write!(self.out, "{closing_bracket}")?;
+ }
+ Ok(())
+ }
+
+ fn write_named_expr(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Expression>,
+ name: String,
+ // The expression which is being named.
+ // Generally, this is the same as handle, except in WorkGroupUniformLoad
+ named: Handle<crate::Expression>,
+ ctx: &back::FunctionCtx,
+ ) -> BackendResult {
+ match ctx.info[named].ty {
+ proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
+ TypeInner::Struct { .. } => {
+ let ty_name = &self.names[&NameKey::Type(ty_handle)];
+ write!(self.out, "{ty_name}")?;
+ }
+ _ => {
+ self.write_type(module, ty_handle)?;
+ }
+ },
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(module, inner)?;
+ }
+ }
+
+ let resolved = ctx.resolve_type(named, &module.types);
+
+ write!(self.out, " {name}")?;
+ // If the rhs is an array type, we should write the array size
+ if let TypeInner::Array { base, size, .. } = *resolved {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, " = ")?;
+ self.write_expr(module, handle, ctx)?;
+ writeln!(self.out, ";")?;
+ self.named_expressions.insert(named, name);
+
+ Ok(())
+ }
+
+ /// Helper function that writes the default zero initialization
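+ /// (e.g. `(float4)0`; for arrays the size is appended, as in `(int[2])0`)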
+ fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ write!(self.out, "(")?;
+ self.write_type(module, ty)?;
+ if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
+ self.write_array_size(module, base, size)?;
+ }
+ write!(self.out, ")0")?;
+ Ok(())
+ }
+
+ fn write_barrier(&mut self, barrier: crate::Barrier, level: back::Level) -> BackendResult {
+ if barrier.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{level}DeviceMemoryBarrierWithGroupSync();")?;
+ }
+ if barrier.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
+ }
+ Ok(())
+ }
+}
+
+pub(super) struct MatrixType {
+ pub(super) columns: crate::VectorSize,
+ pub(super) rows: crate::VectorSize,
+ pub(super) width: crate::Bytes,
+}
+
+pub(super) fn get_inner_matrix_data(
+ module: &Module,
+ handle: Handle<crate::Type>,
+) -> Option<MatrixType> {
+ match module.types[handle].inner {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => Some(MatrixType {
+ columns,
+ rows,
+ width: scalar.width,
+ }),
+ TypeInner::Array { base, .. } => get_inner_matrix_data(module, base),
+ _ => None,
+ }
+}
+
+/// Returns the matrix data if the access chain starting at `base`:
+/// - starts with an expression with resolved type of [`TypeInner::Matrix`] if `direct = true`
+/// - contains one or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
+/// - ends at an expression with resolved type of [`TypeInner::Struct`]
+pub(super) fn get_inner_matrix_of_struct_array_member(
+ module: &Module,
+ base: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ direct: bool,
+) -> Option<MatrixType> {
+ let mut mat_data = None;
+ let mut array_base = None;
+
+ let mut current_base = base;
+ loop {
+ let mut resolved = func_ctx.resolve_type(current_base, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ };
+
+ match *resolved {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ mat_data = Some(MatrixType {
+ columns,
+ rows,
+ width: scalar.width,
+ })
+ }
+ TypeInner::Array { base, .. } => {
+ array_base = Some(base);
+ }
+ TypeInner::Struct { .. } => {
+ if let Some(array_base) = array_base {
+ if direct {
+ return mat_data;
+ } else {
+ return get_inner_matrix_data(module, array_base);
+ }
+ }
+
+ break;
+ }
+ _ => break,
+ }
+
+ current_base = match func_ctx.expressions[current_base] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ _ => break,
+ };
+ }
+ None
+}
+
+/// Returns the matrix data if the access chain starting at `base`:
+/// - starts with an expression with resolved type of [`TypeInner::Matrix`]
+/// - contains zero or more expressions with resolved type of [`TypeInner::Array`] of [`TypeInner::Matrix`]
+/// - ends with an [`Expression::GlobalVariable`](crate::Expression::GlobalVariable) in [`AddressSpace::Uniform`](crate::AddressSpace::Uniform)
+fn get_inner_matrix_of_global_uniform(
+ module: &Module,
+ base: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+) -> Option<MatrixType> {
+ let mut mat_data = None;
+ let mut array_base = None;
+
+ let mut current_base = base;
+ loop {
+ let mut resolved = func_ctx.resolve_type(current_base, &module.types);
+ if let TypeInner::Pointer { base, .. } = *resolved {
+ resolved = &module.types[base].inner;
+ };
+
+ match *resolved {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ mat_data = Some(MatrixType {
+ columns,
+ rows,
+ width: scalar.width,
+ })
+ }
+ TypeInner::Array { base, .. } => {
+ array_base = Some(base);
+ }
+ _ => break,
+ }
+
+ current_base = match func_ctx.expressions[current_base] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ crate::Expression::GlobalVariable(handle)
+ if module.global_variables[handle].space == crate::AddressSpace::Uniform =>
+ {
+ return mat_data.or_else(|| {
+ array_base.and_then(|array_base| get_inner_matrix_data(module, array_base))
+ })
+ }
+ _ => break,
+ };
+ }
+ None
+}
diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs
new file mode 100644
index 0000000000..8100b930e9
--- /dev/null
+++ b/third_party/rust/naga/src/back/mod.rs
@@ -0,0 +1,273 @@
+/*!
+Backend functions that export shader [`Module`](super::Module)s into binary and text formats.
+*/
+#![allow(dead_code)] // can be dead if none of the enabled backends need it
+
+#[cfg(feature = "dot-out")]
+pub mod dot;
+#[cfg(feature = "glsl-out")]
+pub mod glsl;
+#[cfg(feature = "hlsl-out")]
+pub mod hlsl;
+#[cfg(feature = "msl-out")]
+pub mod msl;
+#[cfg(feature = "spv-out")]
+pub mod spv;
+#[cfg(feature = "wgsl-out")]
+pub mod wgsl;
+
+const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
+const INDENT: &str = " ";
+const BAKE_PREFIX: &str = "_e";
+
+type NeedBakeExpressions = crate::FastHashSet<crate::Handle<crate::Expression>>;
+
+#[derive(Clone, Copy)]
+struct Level(usize);
+
+impl Level {
+ const fn next(&self) -> Self {
+ Level(self.0 + 1)
+ }
+}
+
+impl std::fmt::Display for Level {
+ fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+ (0..self.0).try_for_each(|_| formatter.write_str(INDENT))
+ }
+}
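+// Sketch: `Level(2)` displays as INDENT repeated twice; `next` yields the
+// indentation for one nesting level deeper.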
+
+/// Whether we're generating an entry point or a regular function.
+///
+/// Backend languages often require different code for a [`Function`]
+/// depending on whether it represents an [`EntryPoint`] or not.
+/// Backends can pass one of these values to common code to select the
+/// right behavior.
+///
+/// These values also carry enough information to find the `Function`
+/// in the [`Module`]: the `Handle` for a regular function, or the
+/// index into [`Module::entry_points`] for an entry point.
+///
+/// [`Function`]: crate::Function
+/// [`EntryPoint`]: crate::EntryPoint
+/// [`Module`]: crate::Module
+/// [`Module::entry_points`]: crate::Module::entry_points
+enum FunctionType {
+ /// A regular function.
+ Function(crate::Handle<crate::Function>),
+ /// An [`EntryPoint`], and its index in [`Module::entry_points`].
+ ///
+ /// [`EntryPoint`]: crate::EntryPoint
+ /// [`Module::entry_points`]: crate::Module::entry_points
+ EntryPoint(crate::proc::EntryPointIndex),
+}
+
+impl FunctionType {
+ fn is_compute_entry_point(&self, module: &crate::Module) -> bool {
+ match *self {
+ FunctionType::EntryPoint(index) => {
+ module.entry_points[index as usize].stage == crate::ShaderStage::Compute
+ }
+ FunctionType::Function(_) => false,
+ }
+ }
+}
+
+/// Helper structure that stores data needed when writing the function
+struct FunctionCtx<'a> {
+ /// The current function being written
+ ty: FunctionType,
+ /// Analysis about the function
+ info: &'a crate::valid::FunctionInfo,
+ /// The expression arena of the current function being written
+ expressions: &'a crate::Arena<crate::Expression>,
+ /// Map of expressions that have associated variable names
+ named_expressions: &'a crate::NamedExpressions,
+}
+
+impl FunctionCtx<'_> {
+ fn resolve_type<'a>(
+ &'a self,
+ handle: crate::Handle<crate::Expression>,
+ types: &'a crate::UniqueArena<crate::Type>,
+ ) -> &'a crate::TypeInner {
+ self.info[handle].ty.inner_with(types)
+ }
+
+ /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a local in the current function
+ const fn name_key(&self, local: crate::Handle<crate::LocalVariable>) -> crate::proc::NameKey {
+ match self.ty {
+ FunctionType::Function(handle) => crate::proc::NameKey::FunctionLocal(handle, local),
+ FunctionType::EntryPoint(idx) => crate::proc::NameKey::EntryPointLocal(idx, local),
+ }
+ }
+
+ /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a function argument.
+ ///
+ /// # Panics
+ /// - If the number of function arguments is less than or equal to `arg`
+ const fn argument_key(&self, arg: u32) -> crate::proc::NameKey {
+ match self.ty {
+ FunctionType::Function(handle) => crate::proc::NameKey::FunctionArgument(handle, arg),
+ FunctionType::EntryPoint(ep_index) => {
+ crate::proc::NameKey::EntryPointArgument(ep_index, arg)
+ }
+ }
+ }
+
+ // Returns the built-in binding if the given expression points to a fixed-function pipeline input.
+ fn is_fixed_function_input(
+ &self,
+ mut expression: crate::Handle<crate::Expression>,
+ module: &crate::Module,
+ ) -> Option<crate::BuiltIn> {
+ let ep_function = match self.ty {
+ FunctionType::Function(_) => return None,
+ FunctionType::EntryPoint(ep_index) => &module.entry_points[ep_index as usize].function,
+ };
+ let mut built_in = None;
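+ // Walk up the access chain, remembering the built-in of any struct
+ // member we pass through; report it once we reach the argument root.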
+ loop {
+ match self.expressions[expression] {
+ crate::Expression::FunctionArgument(arg_index) => {
+ return match ep_function.arguments[arg_index as usize].binding {
+ Some(crate::Binding::BuiltIn(bi)) => Some(bi),
+ _ => built_in,
+ };
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ match *self.resolve_type(base, &module.types) {
+ crate::TypeInner::Struct { ref members, .. } => {
+ if let Some(crate::Binding::BuiltIn(bi)) =
+ members[index as usize].binding
+ {
+ built_in = Some(bi);
+ }
+ }
+ _ => return None,
+ }
+ expression = base;
+ }
+ _ => return None,
+ }
+ }
+ }
+}
+
+impl crate::Expression {
+ /// Returns the ref count, upon reaching which this expression
+ /// should be considered for baking.
+ ///
+ /// Note: we have to cache any expressions that depend on the control flow;
+ /// otherwise they may accidentally be moved into non-uniform control flow.
+ /// See the [module-level documentation][emit] for details.
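+ /// For example, a `Load` is baked as soon as it is used once, while
+ /// most other expressions are baked only when referenced at least twice.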
+ ///
+ /// [emit]: index.html#expression-evaluation-time
+ const fn bake_ref_count(&self) -> usize {
+ match *self {
+ // accesses are never cached, only loads are
+ crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => usize::MAX,
+ // sampling may use the control flow, and image ops look better by themselves
+ crate::Expression::ImageSample { .. } | crate::Expression::ImageLoad { .. } => 1,
+ // derivatives use the control flow
+ crate::Expression::Derivative { .. } => 1,
+ // TODO: We need a better fix for named `Load` expressions
+ // More info - https://github.com/gfx-rs/naga/pull/914
+ // And https://github.com/gfx-rs/naga/issues/910
+ crate::Expression::Load { .. } => 1,
+ // cache expressions that are referenced multiple times
+ _ => 2,
+ }
+ }
+}
+
+/// Helper function that returns the string corresponding to the [`BinaryOperator`](crate::BinaryOperator)
+/// # Notes
+/// Used by `glsl-out`, `msl-out`, `wgsl-out`, `hlsl-out`.
+const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str {
+ use crate::BinaryOperator as Bo;
+ match op {
+ Bo::Add => "+",
+ Bo::Subtract => "-",
+ Bo::Multiply => "*",
+ Bo::Divide => "/",
+ Bo::Modulo => "%",
+ Bo::Equal => "==",
+ Bo::NotEqual => "!=",
+ Bo::Less => "<",
+ Bo::LessEqual => "<=",
+ Bo::Greater => ">",
+ Bo::GreaterEqual => ">=",
+ Bo::And => "&",
+ Bo::ExclusiveOr => "^",
+ Bo::InclusiveOr => "|",
+ Bo::LogicalAnd => "&&",
+ Bo::LogicalOr => "||",
+ Bo::ShiftLeft => "<<",
+ Bo::ShiftRight => ">>",
+ }
+}
+
+/// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize)
+/// # Notes
+/// Used by `msl-out`, `wgsl-out`, `hlsl-out`.
+const fn vector_size_str(size: crate::VectorSize) -> &'static str {
+ match size {
+ crate::VectorSize::Bi => "2",
+ crate::VectorSize::Tri => "3",
+ crate::VectorSize::Quad => "4",
+ }
+}
+
+impl crate::TypeInner {
+ const fn is_handle(&self) -> bool {
+ match *self {
+ crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => true,
+ _ => false,
+ }
+ }
+}
+
+impl crate::Statement {
+ /// Returns true if the statement directly terminates the current block.
+ ///
+ /// Used to decide whether case blocks require an explicit `break`.
+ pub const fn is_terminator(&self) -> bool {
+ match *self {
+ crate::Statement::Break
+ | crate::Statement::Continue
+ | crate::Statement::Return { .. }
+ | crate::Statement::Kill => true,
+ _ => false,
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Ray flags, for a [`RayDesc`]'s `flags` field.
+ ///
+ /// Note that these exactly correspond to the SPIR-V "Ray Flags" mask, and
+ /// the SPIR-V backend passes them directly through to the
+ /// `OpRayQueryInitializeKHR` instruction. (We have to choose something, so
+ /// we might as well make one back end's life easier.)
+ ///
+ /// [`RayDesc`]: crate::Module::generate_ray_desc_type
+ #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
+ pub struct RayFlag: u32 {
+ const OPAQUE = 0x01;
+ const NO_OPAQUE = 0x02;
+ const TERMINATE_ON_FIRST_HIT = 0x04;
+ const SKIP_CLOSEST_HIT_SHADER = 0x08;
+ const CULL_BACK_FACING = 0x10;
+ const CULL_FRONT_FACING = 0x20;
+ const CULL_OPAQUE = 0x40;
+ const CULL_NO_OPAQUE = 0x80;
+ const SKIP_TRIANGLES = 0x100;
+ const SKIP_AABBS = 0x200;
+ }
+}
+
+#[repr(u32)]
+enum RayIntersectionType {
+ Triangle = 1,
+ BoundingBox = 4,
+}
diff --git a/third_party/rust/naga/src/back/msl/keywords.rs b/third_party/rust/naga/src/back/msl/keywords.rs
new file mode 100644
index 0000000000..f0025bf239
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/keywords.rs
@@ -0,0 +1,342 @@
+// MSL - Metal Shading Language Specification:
+// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+//
+// C++ - Standard for Programming Language C++ (N4431)
+// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf
+pub const RESERVED: &[&str] = &[
+ // Standard for Programming Language C++ (N4431): 2.5 Alternative tokens
+ "and",
+ "bitor",
+ "or",
+ "xor",
+ "compl",
+ "bitand",
+ "and_eq",
+ "or_eq",
+ "xor_eq",
+ "not",
+ "not_eq",
+ // Standard for Programming Language C++ (N4431): 2.11 Keywords
+ "alignas",
+ "alignof",
+ "asm",
+ "auto",
+ "bool",
+ "break",
+ "case",
+ "catch",
+ "char",
+ "char16_t",
+ "char32_t",
+ "class",
+ "const",
+ "constexpr",
+ "const_cast",
+ "continue",
+ "decltype",
+ "default",
+ "delete",
+ "do",
+ "double",
+ "dynamic_cast",
+ "else",
+ "enum",
+ "explicit",
+ "export",
+ "extern",
+ "false",
+ "float",
+ "for",
+ "friend",
+ "goto",
+ "if",
+ "inline",
+ "int",
+ "long",
+ "mutable",
+ "namespace",
+ "new",
+ "noexcept",
+ "nullptr",
+ "operator",
+ "private",
+ "protected",
+ "public",
+ "register",
+ "reinterpret_cast",
+ "return",
+ "short",
+ "signed",
+ "sizeof",
+ "static",
+ "static_assert",
+ "static_cast",
+ "struct",
+ "switch",
+ "template",
+ "this",
+ "thread_local",
+ "throw",
+ "true",
+ "try",
+ "typedef",
+ "typeid",
+ "typename",
+ "union",
+ "unsigned",
+ "using",
+ "virtual",
+ "void",
+ "volatile",
+ "wchar_t",
+ "while",
+ // Metal Shading Language Specification: 1.4.4 Restrictions
+ "main",
+ // Metal Shading Language Specification: 2.1 Scalar Data Types
+ "int8_t",
+ "uchar",
+ "uint8_t",
+ "int16_t",
+ "ushort",
+ "uint16_t",
+ "int32_t",
+ "uint",
+ "uint32_t",
+ "int64_t",
+ "uint64_t",
+ "half",
+ "bfloat",
+ "size_t",
+ "ptrdiff_t",
+ // Metal Shading Language Specification: 2.2 Vector Data Types
+ "bool2",
+ "bool3",
+ "bool4",
+ "char2",
+ "char3",
+ "char4",
+ "short2",
+ "short3",
+ "short4",
+ "int2",
+ "int3",
+ "int4",
+ "long2",
+ "long3",
+ "long4",
+ "uchar2",
+ "uchar3",
+ "uchar4",
+ "ushort2",
+ "ushort3",
+ "ushort4",
+ "uint2",
+ "uint3",
+ "uint4",
+ "ulong2",
+ "ulong3",
+ "ulong4",
+ "half2",
+ "half3",
+ "half4",
+ "bfloat2",
+ "bfloat3",
+ "bfloat4",
+ "float2",
+ "float3",
+ "float4",
+ "vec",
+ // Metal Shading Language Specification: 2.2.3 Packed Vector Types
+ "packed_bool2",
+ "packed_bool3",
+ "packed_bool4",
+ "packed_char2",
+ "packed_char3",
+ "packed_char4",
+ "packed_short2",
+ "packed_short3",
+ "packed_short4",
+ "packed_int2",
+ "packed_int3",
+ "packed_int4",
+ "packed_uchar2",
+ "packed_uchar3",
+ "packed_uchar4",
+ "packed_ushort2",
+ "packed_ushort3",
+ "packed_ushort4",
+ "packed_uint2",
+ "packed_uint3",
+ "packed_uint4",
+ "packed_half2",
+ "packed_half3",
+ "packed_half4",
+ "packed_bfloat2",
+ "packed_bfloat3",
+ "packed_bfloat4",
+ "packed_float2",
+ "packed_float3",
+ "packed_float4",
+ "packed_long2",
+ "packed_long3",
+ "packed_long4",
+ "packed_vec",
+ // Metal Shading Language Specification: 2.3 Matrix Data Types
+ "half2x2",
+ "half2x3",
+ "half2x4",
+ "half3x2",
+ "half3x3",
+ "half3x4",
+ "half4x2",
+ "half4x3",
+ "half4x4",
+ "float2x2",
+ "float2x3",
+ "float2x4",
+ "float3x2",
+ "float3x3",
+ "float3x4",
+ "float4x2",
+ "float4x3",
+ "float4x4",
+ "matrix",
+ // Metal Shading Language Specification: 2.6 Atomic Data Types
+ "atomic",
+ "atomic_int",
+ "atomic_uint",
+ "atomic_bool",
+ "atomic_ulong",
+ "atomic_float",
+ // Metal Shading Language Specification: 2.20 Type Conversions and Re-interpreting Data
+ "as_type",
+ // Metal Shading Language Specification: 4 Address Spaces
+ "device",
+ "constant",
+ "thread",
+ "threadgroup",
+ "threadgroup_imageblock",
+ "ray_data",
+ "object_data",
+ // Metal Shading Language Specification: 5.1 Functions
+ "vertex",
+ "fragment",
+ "kernel",
+ // Metal Shading Language Specification: 6.1 Namespace and Header Files
+ "metal",
+ // C99 / C++ extension:
+ "restrict",
+ // Metal reserved types in <metal_types>:
+ "llong",
+ "ullong",
+ "quad",
+ "complex",
+ "imaginary",
+ // Constants in <metal_types>:
+ "CHAR_BIT",
+ "SCHAR_MAX",
+ "SCHAR_MIN",
+ "UCHAR_MAX",
+ "CHAR_MAX",
+ "CHAR_MIN",
+ "USHRT_MAX",
+ "SHRT_MAX",
+ "SHRT_MIN",
+ "UINT_MAX",
+ "INT_MAX",
+ "INT_MIN",
+ "ULONG_MAX",
+ "LONG_MAX",
+ "LONG_MIN",
+ "ULLONG_MAX",
+ "LLONG_MAX",
+ "LLONG_MIN",
+ "FLT_DIG",
+ "FLT_MANT_DIG",
+ "FLT_MAX_10_EXP",
+ "FLT_MAX_EXP",
+ "FLT_MIN_10_EXP",
+ "FLT_MIN_EXP",
+ "FLT_RADIX",
+ "FLT_MAX",
+ "FLT_MIN",
+ "FLT_EPSILON",
+ "FLT_DECIMAL_DIG",
+ "FP_ILOGB0",
+ "FP_ILOGB0",
+ "FP_ILOGBNAN",
+ "FP_ILOGBNAN",
+ "MAXFLOAT",
+ "HUGE_VALF",
+ "INFINITY",
+ "NAN",
+ "M_E_F",
+ "M_LOG2E_F",
+ "M_LOG10E_F",
+ "M_LN2_F",
+ "M_LN10_F",
+ "M_PI_F",
+ "M_PI_2_F",
+ "M_PI_4_F",
+ "M_1_PI_F",
+ "M_2_PI_F",
+ "M_2_SQRTPI_F",
+ "M_SQRT2_F",
+ "M_SQRT1_2_F",
+ "HALF_DIG",
+ "HALF_MANT_DIG",
+ "HALF_MAX_10_EXP",
+ "HALF_MAX_EXP",
+ "HALF_MIN_10_EXP",
+ "HALF_MIN_EXP",
+ "HALF_RADIX",
+ "HALF_MAX",
+ "HALF_MIN",
+ "HALF_EPSILON",
+ "HALF_DECIMAL_DIG",
+ "MAXHALF",
+ "HUGE_VALH",
+ "M_E_H",
+ "M_LOG2E_H",
+ "M_LOG10E_H",
+ "M_LN2_H",
+ "M_LN10_H",
+ "M_PI_H",
+ "M_PI_2_H",
+ "M_PI_4_H",
+ "M_1_PI_H",
+ "M_2_PI_H",
+ "M_2_SQRTPI_H",
+ "M_SQRT2_H",
+ "M_SQRT1_2_H",
+ "DBL_DIG",
+ "DBL_MANT_DIG",
+ "DBL_MAX_10_EXP",
+ "DBL_MAX_EXP",
+ "DBL_MIN_10_EXP",
+ "DBL_MIN_EXP",
+ "DBL_RADIX",
+ "DBL_MAX",
+ "DBL_MIN",
+ "DBL_EPSILON",
+ "DBL_DECIMAL_DIG",
+ "MAXDOUBLE",
+ "HUGE_VAL",
+ "M_E",
+ "M_LOG2E",
+ "M_LOG10E",
+ "M_LN2",
+ "M_LN10",
+ "M_PI",
+ "M_PI_2",
+ "M_PI_4",
+ "M_1_PI",
+ "M_2_PI",
+ "M_2_SQRTPI",
+ "M_SQRT2",
+ "M_SQRT1_2",
+ // Naga utilities
+ "DefaultConstructible",
+ super::writer::FREXP_FUNCTION,
+ super::writer::MODF_FUNCTION,
+];
diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs
new file mode 100644
index 0000000000..5ef18730c9
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/mod.rs
@@ -0,0 +1,541 @@
+/*!
+Backend for [MSL][msl] (Metal Shading Language).
+
+## Binding model
+
+Metal's bindings are flat per resource. Since there isn't an obvious mapping
+from SPIR-V's descriptor sets, we require a separate mapping provided in the options.
+This mapping may have one or more resource end points for each descriptor set + index
+pair.
+
+## Entry points
+
+Even though MSL and our IR appear to be similar in that the entry points in both can
+accept arguments and return values, the restrictions are different.
+MSL allows the varyings to be either in separate arguments, or inside a single
+`[[stage_in]]` struct. We gather input varyings and form this artificial structure.
+We also add all the (non-Private) globals into the arguments.
+
+At the beginning of the entry point, we assign the local constants and re-compose
+the arguments as they are declared on IR side, so that the rest of the logic can
+pretend that MSL doesn't have all the restrictions it has.
+
+For the result type, if it's a structure, we re-compose it with a temporary value
+holding the result.
+
+[msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+*/
+
+use crate::{arena::Handle, proc::index, valid::ModuleInfo};
+use std::fmt::{Error as FmtError, Write};
+
+mod keywords;
+pub mod sampler;
+mod writer;
+
+pub use writer::Writer;
+
+pub type Slot = u8;
+pub type InlineSamplerIndex = u8;
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum BindSamplerTarget {
+ Resource(Slot),
+ Inline(InlineSamplerIndex),
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
+pub struct BindTarget {
+ pub buffer: Option<Slot>,
+ pub texture: Option<Slot>,
+ pub sampler: Option<BindSamplerTarget>,
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+ pub mutable: bool,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that the map itself can be hashed.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTarget>;
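+// A hypothetical mapping sketch: route bind group 0, binding 1 to texture slot 0:
+//     let mut map = BindingMap::new();
+//     map.insert(
+//         crate::ResourceBinding { group: 0, binding: 1 },
+//         BindTarget { texture: Some(0), ..Default::default() },
+//     );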
+
+#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
+pub struct EntryPointResources {
+ pub resources: BindingMap,
+
+ pub push_constant_buffer: Option<Slot>,
+
+ /// The slot of a buffer that contains an array of `u32`,
+ /// one for the size of each bound buffer that contains a runtime array,
+ /// in order of [`crate::GlobalVariable`] declarations.
+ pub sizes_buffer: Option<Slot>,
+}
+
+pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>;
+
+enum ResolvedBinding {
+ BuiltIn(crate::BuiltIn),
+ Attribute(u32),
+ Color {
+ location: u32,
+ second_blend_source: bool,
+ },
+ User {
+ prefix: &'static str,
+ index: u32,
+ interpolation: Option<ResolvedInterpolation>,
+ },
+ Resource(BindTarget),
+}
+
+#[derive(Copy, Clone)]
+enum ResolvedInterpolation {
+ CenterPerspective,
+ CenterNoPerspective,
+ CentroidPerspective,
+ CentroidNoPerspective,
+ SamplePerspective,
+ SampleNoPerspective,
+ Flat,
+}
+
+// Note: some of these should be removed in favor of proper IR validation.
+
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+ #[error(transparent)]
+ Format(#[from] FmtError),
+ #[error("bind target {0:?} is empty")]
+ UnimplementedBindTarget(BindTarget),
+ #[error("composing of {0:?} is not implemented yet")]
+ UnsupportedCompose(Handle<crate::Type>),
+ #[error("operation {0:?} is not implemented yet")]
+ UnsupportedBinaryOp(crate::BinaryOperator),
+ #[error("standard function '{0}' is not implemented yet")]
+ UnsupportedCall(String),
+ #[error("feature '{0}' is not implemented yet")]
+ FeatureNotImplemented(String),
+ #[error("module is not valid")]
+ Validation,
+ #[error("BuiltIn {0:?} is not supported")]
+ UnsupportedBuiltIn(crate::BuiltIn),
+ #[error("capability {0:?} is not supported")]
+ CapabilityNotSupported(crate::valid::Capabilities),
+ #[error("attribute '{0}' is not supported for target MSL version")]
+ UnsupportedAttribute(String),
+ #[error("function '{0}' is not supported for target MSL version")]
+ UnsupportedFunction(String),
+ #[error("can not use writeable storage buffers in fragment stage prior to MSL 1.2")]
+ UnsupportedWriteableStorageBuffer,
+ #[error("can not use writeable storage textures in {0:?} stage prior to MSL 1.2")]
+ UnsupportedWriteableStorageTexture(crate::ShaderStage),
+ #[error("can not use read-write storage textures prior to MSL 1.2")]
+ UnsupportedRWStorageTexture,
+ #[error("array of '{0}' is not supported for target MSL version")]
+ UnsupportedArrayOf(String),
+ #[error("array of type '{0:?}' is not supported")]
+ UnsupportedArrayOfType(Handle<crate::Type>),
+ #[error("ray tracing is not supported prior to MSL 2.3")]
+ UnsupportedRayTracing,
+}
+
+#[derive(Clone, Debug, PartialEq, thiserror::Error)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum EntryPointError {
+ #[error("global '{0}' doesn't have a binding")]
+ MissingBinding(String),
+ #[error("mapping of {0:?} is missing")]
+ MissingBindTarget(crate::ResourceBinding),
+ #[error("mapping for push constants is missing")]
+ MissingPushConstants,
+ #[error("mapping for sizes buffer is missing")]
+ MissingSizesBuffer,
+}
+
+/// Points in the MSL code where we might emit a pipeline input or output.
+///
+/// Note that, even though vertex shaders' outputs are always fragment
+/// shaders' inputs, we still need to distinguish `VertexOutput` and
+/// `FragmentInput`, since there are certain differences in the way
+/// [`ResolvedBinding`s] are represented on either side.
+///
+/// [`ResolvedBinding`s]: ResolvedBinding
+#[derive(Clone, Copy, Debug)]
+enum LocationMode {
+ /// Input to the vertex shader.
+ VertexInput,
+
+ /// Output from the vertex shader.
+ VertexOutput,
+
+ /// Input to the fragment shader.
+ FragmentInput,
+
+ /// Output from the fragment shader.
+ FragmentOutput,
+
+ /// Compute shader input or output.
+ Uniform,
+}
+
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Options {
+ /// (Major, Minor) target version of the Metal Shading Language.
+ pub lang_version: (u8, u8),
+ /// Map of entry-point resources, indexed by entry point function name, to slots.
+ pub per_entry_point_map: EntryPointResourceMap,
+ /// Samplers to be inlined into the code.
+ pub inline_samplers: Vec<sampler::InlineSampler>,
+ /// Make it possible to link different stages via SPIRV-Cross.
+ pub spirv_cross_compatibility: bool,
+ /// Don't panic on missing bindings; instead, generate invalid MSL.
+ pub fake_missing_bindings: bool,
+ /// Bounds checking policies.
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub bounds_check_policies: index::BoundsCheckPolicies,
+ /// Should workgroup variables be zero initialized (by polyfilling)?
+ pub zero_initialize_workgroup_memory: bool,
+}
+
+impl Default for Options {
+ fn default() -> Self {
+ Options {
+ lang_version: (1, 0),
+ per_entry_point_map: EntryPointResourceMap::default(),
+ inline_samplers: Vec::new(),
+ spirv_cross_compatibility: false,
+ fake_missing_bindings: true,
+ bounds_check_policies: index::BoundsCheckPolicies::default(),
+ zero_initialize_workgroup_memory: true,
+ }
+ }
+}
+
+/// A subset of options that are meant to be changed per pipeline.
+#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// Allow `BuiltIn::PointSize` and inject it if it doesn't exist.
+ ///
+ /// Metal doesn't like this for non-point primitive topologies and requires it for
+ /// point primitive topologies.
+ ///
+ /// Enable this for vertex shaders with point primitive topologies.
+ pub allow_and_force_point_size: bool,
+}
+
+impl Options {
+ fn resolve_local_binding(
+ &self,
+ binding: &crate::Binding,
+ mode: LocationMode,
+ ) -> Result<ResolvedBinding, Error> {
+ match *binding {
+ crate::Binding::BuiltIn(mut built_in) => {
+ match built_in {
+ crate::BuiltIn::Position { ref mut invariant } => {
+ if *invariant && self.lang_version < (2, 1) {
+ return Err(Error::UnsupportedAttribute("invariant".to_string()));
+ }
+
+ // The 'invariant' attribute may only appear on vertex
+ // shader outputs, not fragment shader inputs.
+ if !matches!(mode, LocationMode::VertexOutput) {
+ *invariant = false;
+ }
+ }
+ crate::BuiltIn::BaseInstance if self.lang_version < (1, 2) => {
+ return Err(Error::UnsupportedAttribute("base_instance".to_string()));
+ }
+ crate::BuiltIn::InstanceIndex if self.lang_version < (1, 2) => {
+ return Err(Error::UnsupportedAttribute("instance_id".to_string()));
+ }
+ // macOS: Since Metal 2.2
+ // iOS: Since Metal 2.3 (check depends on https://github.com/gfx-rs/naga/issues/2164)
+ crate::BuiltIn::PrimitiveIndex if self.lang_version < (2, 2) => {
+ return Err(Error::UnsupportedAttribute("primitive_id".to_string()));
+ }
+ _ => {}
+ }
+
+ Ok(ResolvedBinding::BuiltIn(built_in))
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source,
+ } => match mode {
+ LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)),
+ LocationMode::FragmentOutput => {
+ if second_blend_source && self.lang_version < (1, 2) {
+ return Err(Error::UnsupportedAttribute(
+ "second_blend_source".to_string(),
+ ));
+ }
+ Ok(ResolvedBinding::Color {
+ location,
+ second_blend_source,
+ })
+ }
+ LocationMode::VertexOutput | LocationMode::FragmentInput => {
+ Ok(ResolvedBinding::User {
+ prefix: if self.spirv_cross_compatibility {
+ "locn"
+ } else {
+ "loc"
+ },
+ index: location,
+ interpolation: {
+ // unwrap: The verifier ensures that vertex shader outputs and fragment
+ // shader inputs always have fully specified interpolation, and that
+ // sampling is `None` only for Flat interpolation.
+ let interpolation = interpolation.unwrap();
+ let sampling = sampling.unwrap_or(crate::Sampling::Center);
+ Some(ResolvedInterpolation::from_binding(interpolation, sampling))
+ },
+ })
+ }
+ LocationMode::Uniform => {
+ log::error!(
+ "Unexpected Binding::Location({}) for the Uniform mode",
+ location
+ );
+ Err(Error::Validation)
+ }
+ },
+ }
+ }
+
+ fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
+ self.per_entry_point_map.get(&ep.name)
+ }
+
+ fn get_resource_binding_target(
+ &self,
+ ep: &crate::EntryPoint,
+ res_binding: &crate::ResourceBinding,
+ ) -> Option<&BindTarget> {
+ self.get_entry_point_resources(ep)
+ .and_then(|res| res.resources.get(res_binding))
+ }
+
+ fn resolve_resource_binding(
+ &self,
+ ep: &crate::EntryPoint,
+ res_binding: &crate::ResourceBinding,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let target = self.get_resource_binding_target(ep, res_binding);
+ match target {
+ Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingBindTarget(res_binding.clone())),
+ }
+ }
+
+ fn resolve_push_constants(
+ &self,
+ ep: &crate::EntryPoint,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let slot = self
+ .get_entry_point_resources(ep)
+ .and_then(|res| res.push_constant_buffer);
+ match slot {
+ Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
+ buffer: Some(slot),
+ ..Default::default()
+ })),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingPushConstants),
+ }
+ }
+
+ fn resolve_sizes_buffer(
+ &self,
+ ep: &crate::EntryPoint,
+ ) -> Result<ResolvedBinding, EntryPointError> {
+ let slot = self
+ .get_entry_point_resources(ep)
+ .and_then(|res| res.sizes_buffer);
+ match slot {
+ Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
+ buffer: Some(slot),
+ ..Default::default()
+ })),
+ None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
+ prefix: "fake",
+ index: 0,
+ interpolation: None,
+ }),
+ None => Err(EntryPointError::MissingSizesBuffer),
+ }
+ }
+}
+
+impl ResolvedBinding {
+ fn as_inline_sampler<'a>(&self, options: &'a Options) -> Option<&'a sampler::InlineSampler> {
+ match *self {
+ Self::Resource(BindTarget {
+ sampler: Some(BindSamplerTarget::Inline(index)),
+ ..
+ }) => Some(&options.inline_samplers[index as usize]),
+ _ => None,
+ }
+ }
+
+ const fn as_bind_target(&self) -> Option<&BindTarget> {
+ match *self {
+ Self::Resource(ref target) => Some(target),
+ _ => None,
+ }
+ }
+
+ fn try_fmt<W: Write>(&self, out: &mut W) -> Result<(), Error> {
+ write!(out, " [[")?;
+ match *self {
+ Self::BuiltIn(built_in) => {
+ use crate::BuiltIn as Bi;
+ let name = match built_in {
+ Bi::Position { invariant: false } => "position",
+ Bi::Position { invariant: true } => "position, invariant",
+ // vertex
+ Bi::BaseInstance => "base_instance",
+ Bi::BaseVertex => "base_vertex",
+ Bi::ClipDistance => "clip_distance",
+ Bi::InstanceIndex => "instance_id",
+ Bi::PointSize => "point_size",
+ Bi::VertexIndex => "vertex_id",
+ // fragment
+ Bi::FragDepth => "depth(any)",
+ Bi::PointCoord => "point_coord",
+ Bi::FrontFacing => "front_facing",
+ Bi::PrimitiveIndex => "primitive_id",
+ Bi::SampleIndex => "sample_id",
+ Bi::SampleMask => "sample_mask",
+ // compute
+ Bi::GlobalInvocationId => "thread_position_in_grid",
+ Bi::LocalInvocationId => "thread_position_in_threadgroup",
+ Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
+ Bi::WorkGroupId => "threadgroup_position_in_grid",
+ Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
+ Bi::NumWorkGroups => "threadgroups_per_grid",
+ Bi::CullDistance | Bi::ViewIndex => {
+ return Err(Error::UnsupportedBuiltIn(built_in))
+ }
+ };
+ write!(out, "{name}")?;
+ }
+ Self::Attribute(index) => write!(out, "attribute({index})")?,
+ Self::Color {
+ location,
+ second_blend_source,
+ } => {
+ if second_blend_source {
+ write!(out, "color({location}) index(1)")?
+ } else {
+ write!(out, "color({location})")?
+ }
+ }
+ Self::User {
+ prefix,
+ index,
+ interpolation,
+ } => {
+ write!(out, "user({prefix}{index})")?;
+ if let Some(interpolation) = interpolation {
+ write!(out, ", ")?;
+ interpolation.try_fmt(out)?;
+ }
+ }
+ Self::Resource(ref target) => {
+ if let Some(id) = target.buffer {
+ write!(out, "buffer({id})")?;
+ } else if let Some(id) = target.texture {
+ write!(out, "texture({id})")?;
+ } else if let Some(BindSamplerTarget::Resource(id)) = target.sampler {
+ write!(out, "sampler({id})")?;
+ } else {
+ return Err(Error::UnimplementedBindTarget(target.clone()));
+ }
+ }
+ }
+ write!(out, "]]")?;
+ Ok(())
+ }
+}
+
+impl ResolvedInterpolation {
+ const fn from_binding(interpolation: crate::Interpolation, sampling: crate::Sampling) -> Self {
+ use crate::Interpolation as I;
+ use crate::Sampling as S;
+
+ match (interpolation, sampling) {
+ (I::Perspective, S::Center) => Self::CenterPerspective,
+ (I::Perspective, S::Centroid) => Self::CentroidPerspective,
+ (I::Perspective, S::Sample) => Self::SamplePerspective,
+ (I::Linear, S::Center) => Self::CenterNoPerspective,
+ (I::Linear, S::Centroid) => Self::CentroidNoPerspective,
+ (I::Linear, S::Sample) => Self::SampleNoPerspective,
+ (I::Flat, _) => Self::Flat,
+ }
+ }
+
+ fn try_fmt<W: Write>(self, out: &mut W) -> Result<(), Error> {
+ let identifier = match self {
+ Self::CenterPerspective => "center_perspective",
+ Self::CenterNoPerspective => "center_no_perspective",
+ Self::CentroidPerspective => "centroid_perspective",
+ Self::CentroidNoPerspective => "centroid_no_perspective",
+ Self::SamplePerspective => "sample_perspective",
+ Self::SampleNoPerspective => "sample_no_perspective",
+ Self::Flat => "flat",
+ };
+ out.write_str(identifier)?;
+ Ok(())
+ }
+}
+
+/// Information about a translated module that is required
+/// for the use of the result.
+pub struct TranslationInfo {
+ /// Mapping of the entry point names. Each item in the array
+ /// corresponds to an entry point index.
+ ///
+ /// Note: Some entry points may fail translation because of missing bindings.
+ pub entry_point_names: Vec<Result<String, EntryPointError>>,
+}
+
+pub fn write_string(
+ module: &crate::Module,
+ info: &ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+) -> Result<(String, TranslationInfo), Error> {
+ let mut w = writer::Writer::new(String::new());
+ let info = w.write(module, info, options, pipeline_options)?;
+ Ok((w.finish(), info))
+}
+
+#[test]
+fn test_error_size() {
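+ // Guard against accidental growth of `Error`, which is carried in
+ // every `Result` this backend returns.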
+ use std::mem::size_of;
+ assert_eq!(size_of::<Error>(), 32);
+}
diff --git a/third_party/rust/naga/src/back/msl/sampler.rs b/third_party/rust/naga/src/back/msl/sampler.rs
new file mode 100644
index 0000000000..0bf987076d
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/sampler.rs
@@ -0,0 +1,176 @@
+#[cfg(feature = "deserialize")]
+use serde::Deserialize;
+#[cfg(feature = "serialize")]
+use serde::Serialize;
+use std::{num::NonZeroU32, ops::Range};
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Coord {
+ Normalized,
+ Pixel,
+}
+
+impl Default for Coord {
+ fn default() -> Self {
+ Self::Normalized
+ }
+}
+
+impl Coord {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Normalized => "normalized",
+ Self::Pixel => "pixel",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Address {
+ Repeat,
+ MirroredRepeat,
+ ClampToEdge,
+ ClampToZero,
+ ClampToBorder,
+}
+
+impl Default for Address {
+ fn default() -> Self {
+ Self::ClampToEdge
+ }
+}
+
+impl Address {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Repeat => "repeat",
+ Self::MirroredRepeat => "mirrored_repeat",
+ Self::ClampToEdge => "clamp_to_edge",
+ Self::ClampToZero => "clamp_to_zero",
+ Self::ClampToBorder => "clamp_to_border",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum BorderColor {
+ TransparentBlack,
+ OpaqueBlack,
+ OpaqueWhite,
+}
+
+impl Default for BorderColor {
+ fn default() -> Self {
+ Self::TransparentBlack
+ }
+}
+
+impl BorderColor {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::TransparentBlack => "transparent_black",
+ Self::OpaqueBlack => "opaque_black",
+ Self::OpaqueWhite => "opaque_white",
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum Filter {
+ Nearest,
+ Linear,
+}
+
+impl Filter {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Nearest => "nearest",
+ Self::Linear => "linear",
+ }
+ }
+}
+
+impl Default for Filter {
+ fn default() -> Self {
+ Self::Nearest
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub enum CompareFunc {
+ Never,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+ Equal,
+ NotEqual,
+ Always,
+}
+
+impl Default for CompareFunc {
+ fn default() -> Self {
+ Self::Never
+ }
+}
+
+impl CompareFunc {
+ pub const fn as_str(&self) -> &'static str {
+ match *self {
+ Self::Never => "never",
+ Self::Less => "less",
+ Self::LessEqual => "less_equal",
+ Self::Greater => "greater",
+ Self::GreaterEqual => "greater_equal",
+ Self::Equal => "equal",
+ Self::NotEqual => "not_equal",
+ Self::Always => "always",
+ }
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+pub struct InlineSampler {
+ pub coord: Coord,
+ pub address: [Address; 3],
+ pub border_color: BorderColor,
+ pub mag_filter: Filter,
+ pub min_filter: Filter,
+ pub mip_filter: Option<Filter>,
+ pub lod_clamp: Option<Range<f32>>,
+ pub max_anisotropy: Option<NonZeroU32>,
+ pub compare_func: CompareFunc,
+}
+
+impl Eq for InlineSampler {}
+
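+// `f32` implements neither `Eq` nor `Hash`, so equality is asserted
+// manually and `lod_clamp` is hashed through the bit patterns of its
+// endpoints below.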
+#[allow(renamed_and_removed_lints)]
+#[allow(clippy::derive_hash_xor_eq)]
+impl std::hash::Hash for InlineSampler {
+ fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
+ self.coord.hash(hasher);
+ self.address.hash(hasher);
+ self.border_color.hash(hasher);
+ self.mag_filter.hash(hasher);
+ self.min_filter.hash(hasher);
+ self.mip_filter.hash(hasher);
+ self.lod_clamp
+ .as_ref()
+ .map(|range| (range.start.to_bits(), range.end.to_bits()))
+ .hash(hasher);
+ self.max_anisotropy.hash(hasher);
+ self.compare_func.hash(hasher);
+ }
+}
diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs
new file mode 100644
index 0000000000..1e496b5f50
--- /dev/null
+++ b/third_party/rust/naga/src/back/msl/writer.rs
@@ -0,0 +1,4659 @@
+use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, TranslationInfo};
+use crate::{
+ arena::Handle,
+ back,
+ proc::index,
+ proc::{self, NameKey, TypeResolution},
+ valid, FastHashMap, FastHashSet,
+};
+use bit_set::BitSet;
+use std::{
+ fmt::{Display, Error as FmtError, Formatter, Write},
+ iter,
+};
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+const NAMESPACE: &str = "metal";
+// The name of the array member of the Metal struct types we generate to
+// represent Naga `Array` types. See the comments in `Writer::write_type_defs`
+// for details.
+const WRAPPED_ARRAY_FIELD: &str = "inner";
+// This is a hack: we need to pass a pointer to an atomic,
+// but generally the backend doesn't put "&" in front of every pointer.
+// More general handling of pointers still needs to be implemented here.
+const ATOMIC_REFERENCE: &str = "&";
+
+const RT_NAMESPACE: &str = "metal::raytracing";
+const RAY_QUERY_TYPE: &str = "_RayQuery";
+const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector";
+const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
+const RAY_QUERY_FIELD_READY: &str = "ready";
+const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";
+
+pub(crate) const MODF_FUNCTION: &str = "naga_modf";
+pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
+
+/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
+///
+/// The `sizes` slice determines whether this function writes a
+/// scalar, vector, or matrix type:
+///
+/// - An empty slice produces a scalar type.
+/// - A one-element slice produces a vector type.
+/// - A two-element slice `[ROWS, COLUMNS]` produces a matrix of the given size.
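+///
+/// Illustrative outputs (values assumed, not from the source): `F32` with
+/// `&[]` writes `float`; with `&[Quad]`, `metal::float4`; and with
+/// `&[Bi, Quad]`, `metal::float4x2` (MSL names matrices columns-by-rows).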
+fn put_numeric_type(
+ out: &mut impl Write,
+ scalar: crate::Scalar,
+ sizes: &[crate::VectorSize],
+) -> Result<(), FmtError> {
+ match (scalar, sizes) {
+ (scalar, &[]) => {
+ write!(out, "{}", scalar.to_msl_name())
+ }
+ (scalar, &[rows]) => {
+ write!(
+ out,
+ "{}::{}{}",
+ NAMESPACE,
+ scalar.to_msl_name(),
+ back::vector_size_str(rows)
+ )
+ }
+ (scalar, &[rows, columns]) => {
+ write!(
+ out,
+ "{}::{}{}x{}",
+ NAMESPACE,
+ scalar.to_msl_name(),
+ back::vector_size_str(columns),
+ back::vector_size_str(rows)
+ )
+ }
+ (_, _) => Ok(()), // not meaningful
+ }
+}
+
+/// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions.
+const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e";
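+// e.g. an `ImageLoad` expression whose handle index is 7 would cache its
+// clamped level in a temporary named `clamped_lod_e7` (index illustrative).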
+
+struct TypeContext<'a> {
+ handle: Handle<crate::Type>,
+ gctx: proc::GlobalCtx<'a>,
+ names: &'a FastHashMap<NameKey, String>,
+ access: crate::StorageAccess,
+ binding: Option<&'a super::ResolvedBinding>,
+ first_time: bool,
+}
+
+impl<'a> Display for TypeContext<'a> {
+ fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
+ let ty = &self.gctx.types[self.handle];
+ if ty.needs_alias() && !self.first_time {
+ let name = &self.names[&NameKey::Type(self.handle)];
+ return write!(out, "{name}");
+ }
+
+ match ty.inner {
+ crate::TypeInner::Scalar(scalar) => put_numeric_type(out, scalar, &[]),
+ crate::TypeInner::Atomic(scalar) => {
+ write!(out, "{}::atomic_{}", NAMESPACE, scalar.to_msl_name())
+ }
+ crate::TypeInner::Vector { size, scalar } => put_numeric_type(out, scalar, &[size]),
+ crate::TypeInner::Matrix { columns, rows, .. } => {
+ put_numeric_type(out, crate::Scalar::F32, &[rows, columns])
+ }
+ crate::TypeInner::Pointer { base, space } => {
+ let sub = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+ let space_name = match space.to_msl_name() {
+ Some(name) => name,
+ None => return Ok(()),
+ };
+ write!(out, "{space_name} {sub}&")
+ }
+ crate::TypeInner::ValuePointer {
+ size,
+ scalar,
+ space,
+ } => {
+ match space.to_msl_name() {
+ Some(name) => write!(out, "{name} ")?,
+ None => return Ok(()),
+ };
+ match size {
+ Some(rows) => put_numeric_type(out, scalar, &[rows])?,
+ None => put_numeric_type(out, scalar, &[])?,
+ };
+
+ write!(out, "&")
+ }
+ crate::TypeInner::Array { base, .. } => {
+ let sub = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+ // Array lengths go at the end of the type definition,
+ // so just print the element type here.
+ write!(out, "{sub}")
+ }
+ crate::TypeInner::Struct { .. } => unreachable!(),
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let dim_str = match dim {
+ crate::ImageDimension::D1 => "1d",
+ crate::ImageDimension::D2 => "2d",
+ crate::ImageDimension::D3 => "3d",
+ crate::ImageDimension::Cube => "cube",
+ };
+ let (texture_str, msaa_str, kind, access) = match class {
+ crate::ImageClass::Sampled { kind, multi } => {
+ let (msaa_str, access) = if multi {
+ ("_ms", "read")
+ } else {
+ ("", "sample")
+ };
+ ("texture", msaa_str, kind, access)
+ }
+ crate::ImageClass::Depth { multi } => {
+ let (msaa_str, access) = if multi {
+ ("_ms", "read")
+ } else {
+ ("", "sample")
+ };
+ ("depth", msaa_str, crate::ScalarKind::Float, access)
+ }
+ crate::ImageClass::Storage { format, .. } => {
+ let access = if self
+ .access
+ .contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
+ {
+ "read_write"
+ } else if self.access.contains(crate::StorageAccess::STORE) {
+ "write"
+ } else if self.access.contains(crate::StorageAccess::LOAD) {
+ "read"
+ } else {
+ log::warn!(
+ "Storage access for {:?} (name '{}'): {:?}",
+ self.handle,
+ ty.name.as_deref().unwrap_or_default(),
+ self.access
+ );
+ unreachable!("module is not valid");
+ };
+ ("texture", "", format.into(), access)
+ }
+ };
+ let base_name = crate::Scalar { kind, width: 4 }.to_msl_name();
+ let array_str = if arrayed { "_array" } else { "" };
+ write!(
+ out,
+ "{NAMESPACE}::{texture_str}{dim_str}{msaa_str}{array_str}<{base_name}, {NAMESPACE}::access::{access}>",
+ )
+ }
+ crate::TypeInner::Sampler { comparison: _ } => {
+ write!(out, "{NAMESPACE}::sampler")
+ }
+ crate::TypeInner::AccelerationStructure => {
+ write!(out, "{RT_NAMESPACE}::instance_acceleration_structure")
+ }
+ crate::TypeInner::RayQuery => {
+ write!(out, "{RAY_QUERY_TYPE}")
+ }
+ crate::TypeInner::BindingArray { base, size } => {
+ let base_tyname = Self {
+ handle: base,
+ first_time: false,
+ ..*self
+ };
+
+ if let Some(&super::ResolvedBinding::Resource(super::BindTarget {
+ binding_array_size: Some(override_size),
+ ..
+ })) = self.binding
+ {
+ write!(out, "{NAMESPACE}::array<{base_tyname}, {override_size}>")
+ } else if let crate::ArraySize::Constant(size) = size {
+ write!(out, "{NAMESPACE}::array<{base_tyname}, {size}>")
+ } else {
+ unreachable!("metal requires all arrays be constant sized");
+ }
+ }
+ }
+ }
+}
+
+struct TypedGlobalVariable<'a> {
+ module: &'a crate::Module,
+ names: &'a FastHashMap<NameKey, String>,
+ handle: Handle<crate::GlobalVariable>,
+ usage: valid::GlobalUse,
+ binding: Option<&'a super::ResolvedBinding>,
+ reference: bool,
+}
+
+impl<'a> TypedGlobalVariable<'a> {
+ fn try_fmt<W: Write>(&self, out: &mut W) -> BackendResult {
+ let var = &self.module.global_variables[self.handle];
+ let name = &self.names[&NameKey::GlobalVariable(self.handle)];
+
+ let storage_access = match var.space {
+ crate::AddressSpace::Storage { access } => access,
+ _ => match self.module.types[var.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => access,
+ crate::TypeInner::BindingArray { base, .. } => {
+ match self.module.types[base].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => access,
+ _ => crate::StorageAccess::default(),
+ }
+ }
+ _ => crate::StorageAccess::default(),
+ },
+ };
+ let ty_name = TypeContext {
+ handle: var.ty,
+ gctx: self.module.to_ctx(),
+ names: self.names,
+ access: storage_access,
+ binding: self.binding,
+ first_time: false,
+ };
+
+ let (space, access, reference) = match var.space.to_msl_name() {
+ Some(space) if self.reference => {
+ let access = if var.space.needs_access_qualifier()
+ && !self.usage.contains(valid::GlobalUse::WRITE)
+ {
+ "const"
+ } else {
+ ""
+ };
+ (space, access, "&")
+ }
+ _ => ("", "", ""),
+ };
+
+ Ok(write!(
+ out,
+ "{}{}{}{}{}{} {}",
+ space,
+ if space.is_empty() { "" } else { " " },
+ ty_name,
+ if access.is_empty() { "" } else { " " },
+ access,
+ reference,
+ name,
+ )?)
+ }
+}
+
+pub struct Writer<W> {
+ out: W,
+ names: FastHashMap<NameKey, String>,
+ named_expressions: crate::NamedExpressions,
+ /// Set of expressions that need to be baked to avoid unnecessary repetition in output
+ need_bake_expressions: back::NeedBakeExpressions,
+ namer: proc::Namer,
+ #[cfg(test)]
+ put_expression_stack_pointers: FastHashSet<*const ()>,
+ #[cfg(test)]
+ put_block_stack_pointers: FastHashSet<*const ()>,
+ /// Set of (struct type, struct field index) denoting which fields require
+ /// padding inserted **before** them (i.e. between fields at index - 1 and index)
+ struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
+}
+
+impl crate::Scalar {
+ const fn to_msl_name(self) -> &'static str {
+ use crate::ScalarKind as Sk;
+ match self {
+ Self {
+ kind: Sk::Float,
+ width: _,
+ } => "float",
+ Self {
+ kind: Sk::Sint,
+ width: _,
+ } => "int",
+ Self {
+ kind: Sk::Uint,
+ width: _,
+ } => "uint",
+ Self {
+ kind: Sk::Bool,
+ width: _,
+ } => "bool",
+ Self {
+ kind: Sk::AbstractInt | Sk::AbstractFloat,
+ width: _,
+ } => unreachable!(),
+ }
+ }
+}
+
+const fn separate(need_separator: bool) -> &'static str {
+ if need_separator {
+ ","
+ } else {
+ ""
+ }
+}
+
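+/// Decide whether struct member `index` should be emitted as a packed vector.
+///
+/// Illustrative case (assumed): a three-component vector of 4-byte scalars
+/// whose successor starts exactly 12 bytes after it is "tight", so it is
+/// reported for packing to keep MSL offsets in sync with the IR layout.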
+fn should_pack_struct_member(
+ members: &[crate::StructMember],
+ span: u32,
+ index: usize,
+ module: &crate::Module,
+) -> Option<crate::Scalar> {
+ let member = &members[index];
+
+ let ty_inner = &module.types[member.ty].inner;
+ let last_offset = member.offset + ty_inner.size(module.to_ctx());
+ let next_offset = match members.get(index + 1) {
+ Some(next) => next.offset,
+ None => span,
+ };
+ let is_tight = next_offset == last_offset;
+
+ match *ty_inner {
+ crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ scalar: scalar @ crate::Scalar { width: 4, .. },
+ } if is_tight => Some(scalar),
+ _ => None,
+ }
+}
+
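+/// Whether `ty` is, or ends in, a dynamically sized array, in which case the
+/// generated code must derive the array's length from the bound buffer's size.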
+fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
+ match arena[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ if let Some(member) = members.last() {
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = arena[member.ty].inner
+ {
+ return true;
+ }
+ }
+ false
+ }
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => true,
+ _ => false,
+ }
+}
+
+impl crate::AddressSpace {
+ /// Returns true if global variables in this address space are
+ /// passed in function arguments. These arguments need to be
+ /// passed through any functions called from the entry point.
+ const fn needs_pass_through(&self) -> bool {
+ match *self {
+ Self::Uniform
+ | Self::Storage { .. }
+ | Self::Private
+ | Self::WorkGroup
+ | Self::PushConstant
+ | Self::Handle => true,
+ Self::Function => false,
+ }
+ }
+
+ /// Returns true if the address space may need a "const" qualifier.
+ const fn needs_access_qualifier(&self) -> bool {
+ match *self {
+ // Note: we ignore the storage access here, and instead
+ // rely on how functions actually use the global. This means we
+ // may end up with "const" even if the binding is read-write,
+ // and that should be OK.
+ Self::Storage { .. } => true,
+ // These should always be read-write.
+ Self::Private | Self::WorkGroup => false,
+ // These translate to `constant` address space, no need for qualifiers.
+ Self::Uniform | Self::PushConstant => false,
+ // Not applicable.
+ Self::Handle | Self::Function => false,
+ }
+ }
+
+ const fn to_msl_name(self) -> Option<&'static str> {
+ match self {
+ Self::Handle => None,
+ Self::Uniform | Self::PushConstant => Some("constant"),
+ Self::Storage { .. } => Some("device"),
+ Self::Private | Self::Function => Some("thread"),
+ Self::WorkGroup => Some("threadgroup"),
+ }
+ }
+}
+
+impl crate::Type {
+ // Returns `true` if we need to emit an alias for this type.
+ const fn needs_alias(&self) -> bool {
+ use crate::TypeInner as Ti;
+
+ match self.inner {
+ // value types are concise enough; we only alias them if they are named
+ Ti::Scalar(_)
+ | Ti::Vector { .. }
+ | Ti::Matrix { .. }
+ | Ti::Atomic(_)
+ | Ti::Pointer { .. }
+ | Ti::ValuePointer { .. } => self.name.is_some(),
+ // composite types are better off aliased, regardless of the name
+ Ti::Struct { .. } | Ti::Array { .. } => true,
+ // handle types may be different, depending on the global var access, so we always inline them
+ Ti::Image { .. }
+ | Ti::Sampler { .. }
+ | Ti::AccelerationStructure
+ | Ti::RayQuery
+ | Ti::BindingArray { .. } => false,
+ }
+ }
+}
+
+enum FunctionOrigin {
+ Handle(Handle<crate::Function>),
+ EntryPoint(proc::EntryPointIndex),
+}
+
+/// A level of detail argument.
+///
+/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
+/// save the clamped level of detail in a temporary variable whose name is based
+/// on the handle of the `ImageLoad` expression. But for other policies, we just
+/// use the expression directly.
+///
+/// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+#[derive(Clone, Copy)]
+enum LevelOfDetail {
+ Direct(Handle<crate::Expression>),
+ Restricted(Handle<crate::Expression>),
+}
+
+/// Values needed to select a particular texel for [`ImageLoad`] and [`ImageStore`].
+///
+/// When this is used in code paths unconcerned with the `Restrict` bounds check
+/// policy, the `LevelOfDetail` enum introduces an unneeded match, since `level`
+/// will always be either `None` or `Some(Direct(_))`. But this turns out not to
+/// be too awkward. If that changes, we can revisit.
+///
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+/// [`ImageStore`]: crate::Statement::ImageStore
+struct TexelAddress {
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ level: Option<LevelOfDetail>,
+}
+
+struct ExpressionContext<'a> {
+ function: &'a crate::Function,
+ origin: FunctionOrigin,
+ info: &'a valid::FunctionInfo,
+ module: &'a crate::Module,
+ mod_info: &'a valid::ModuleInfo,
+ pipeline_options: &'a PipelineOptions,
+ lang_version: (u8, u8),
+ policies: index::BoundsCheckPolicies,
+
+ /// A bitset containing the `Expression` handle indexes of expressions used
+ /// as indices in `ReadZeroSkipWrite`-policy accesses. These may need to be
+ /// cached in temporary variables. See `index::find_checked_indexes` for
+ /// details.
+ guarded_indices: BitSet,
+}
+
+impl<'a> ExpressionContext<'a> {
+ fn resolve_type(&self, handle: Handle<crate::Expression>) -> &'a crate::TypeInner {
+ self.info[handle].ty.inner_with(&self.module.types)
+ }
+
+ /// Return true if calls to `image`'s `read` and `write` methods should supply a level of detail.
+ ///
+ /// Only mipmapped images need to specify a level of detail. Since 1D
+ /// textures cannot have mipmaps, MSL requires that the level argument to
+ /// texture1d queries and accesses be a constexpr 0. It's easiest
+ /// just to omit the level entirely for 1D textures.
+ fn image_needs_lod(&self, image: Handle<crate::Expression>) -> bool {
+ let image_ty = self.resolve_type(image);
+ if let crate::TypeInner::Image { dim, class, .. } = *image_ty {
+ class.is_mipmapped() && dim != crate::ImageDimension::D1
+ } else {
+ false
+ }
+ }
+
+ fn choose_bounds_check_policy(
+ &self,
+ pointer: Handle<crate::Expression>,
+ ) -> index::BoundsCheckPolicy {
+ self.policies
+ .choose_policy(pointer, &self.module.types, self.info)
+ }
+
+ fn access_needs_check(
+ &self,
+ base: Handle<crate::Expression>,
+ index: index::GuardedIndex,
+ ) -> Option<index::IndexableLength> {
+ index::access_needs_check(base, index, self.module, self.function, self.info)
+ }
+
+ fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
+ match self.function.expressions[expr_handle] {
+ crate::Expression::AccessIndex { base, index } => {
+ let ty = match *self.resolve_type(base) {
+ crate::TypeInner::Pointer { base, .. } => &self.module.types[base].inner,
+ ref ty => ty,
+ };
+ match *ty {
+ crate::TypeInner::Struct {
+ ref members, span, ..
+ } => should_pack_struct_member(members, span, index as usize, self.module),
+ _ => None,
+ }
+ }
+ _ => None,
+ }
+ }
+}
+
+struct StatementContext<'a> {
+ expression: ExpressionContext<'a>,
+ result_struct: Option<&'a str>,
+}
+
+impl<W: Write> Writer<W> {
+ /// Creates a new `Writer` instance.
+ pub fn new(out: W) -> Self {
+ Writer {
+ out,
+ names: FastHashMap::default(),
+ named_expressions: Default::default(),
+ need_bake_expressions: Default::default(),
+ namer: proc::Namer::default(),
+ #[cfg(test)]
+ put_expression_stack_pointers: Default::default(),
+ #[cfg(test)]
+ put_block_stack_pointers: Default::default(),
+ struct_member_pads: FastHashSet::default(),
+ }
+ }
+
+ /// Finishes writing and returns the output.
+ // See https://github.com/rust-lang/rust-clippy/issues/4979.
+ #[allow(clippy::missing_const_for_fn)]
+ pub fn finish(self) -> W {
+ self.out
+ }
+
+ fn put_call_parameters(
+ &mut self,
+ parameters: impl Iterator<Item = Handle<crate::Expression>>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
+ writer.put_expression(expr, context, true)
+ })
+ }
+
+ fn put_call_parameters_impl<C, E>(
+ &mut self,
+ parameters: impl Iterator<Item = Handle<crate::Expression>>,
+ ctx: &C,
+ put_expression: E,
+ ) -> BackendResult
+ where
+ E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
+ {
+ write!(self.out, "(")?;
+ for (i, handle) in parameters.enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ put_expression(self, ctx, handle)?;
+ }
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_level_of_detail(
+ &mut self,
+ level: LevelOfDetail,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match level {
+ LevelOfDetail::Direct(expr) => self.put_expression(expr, context, true)?,
+ LevelOfDetail::Restricted(load) => {
+ write!(self.out, "{}{}", CLAMPED_LOD_LOAD_PREFIX, load.index())?
+ }
+ }
+ Ok(())
+ }
+
+ fn put_image_query(
+ &mut self,
+ image: Handle<crate::Expression>,
+ query: &str,
+ level: Option<LevelOfDetail>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_{query}(")?;
+ if let Some(level) = level {
+ self.put_level_of_detail(level, context)?;
+ }
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_image_size_query(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: Option<LevelOfDetail>,
+ kind: crate::ScalarKind,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // Note: MSL only has separate width/height/depth queries,
+ // so we compose the result from them.
+ let dim = match *context.resolve_type(image) {
+ crate::TypeInner::Image { dim, .. } => dim,
+ ref other => unreachable!("Unexpected type {:?}", other),
+ };
+ let scalar = crate::Scalar { kind, width: 4 };
+ let coordinate_type = scalar.to_msl_name();
+ match dim {
+ crate::ImageDimension::D1 => {
+ // Since 1D textures never have mipmaps, MSL requires that the
+ // `level` argument be a constexpr 0. It's simplest for us just
+ // to pass `None` and omit the level entirely.
+ if kind == crate::ScalarKind::Uint {
+ // No need to construct a vector. No cast needed.
+ self.put_image_query(image, "width", None, context)?;
+ } else {
+ // There's no definition for `int` in the `metal` namespace.
+ write!(self.out, "int(")?;
+ self.put_image_query(image, "width", None, context)?;
+ write!(self.out, ")")?;
+ }
+ }
+ crate::ImageDimension::D2 => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "height", level, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::ImageDimension::D3 => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}3(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "height", level, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_query(image, "depth", level, context)?;
+ write!(self.out, ")")?;
+ }
+ crate::ImageDimension::Cube => {
+ write!(self.out, "{NAMESPACE}::{coordinate_type}2(")?;
+ self.put_image_query(image, "width", level, context)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn put_cast_to_uint_scalar_or_vector(
+ &mut self,
+ expr: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // coordinates in IR are int, but Metal expects uint
+ match *context.resolve_type(expr) {
+ crate::TypeInner::Scalar(_) => {
+ put_numeric_type(&mut self.out, crate::Scalar::U32, &[])?
+ }
+ crate::TypeInner::Vector { size, .. } => {
+ put_numeric_type(&mut self.out, crate::Scalar::U32, &[size])?
+ }
+ _ => return Err(Error::Validation),
+ };
+
+ write!(self.out, "(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ fn put_image_sample_level(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: crate::SampleLevel,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let has_levels = context.image_needs_lod(image);
+ match level {
+ crate::SampleLevel::Auto => {}
+ crate::SampleLevel::Zero => {
+ //TODO: do we support Zero on `Sampled` image classes?
+ }
+ _ if !has_levels => {
+ log::warn!("1D image can't be sampled with level {:?}", level);
+ }
+ crate::SampleLevel::Exact(h) => {
+ write!(self.out, ", {NAMESPACE}::level(")?;
+ self.put_expression(h, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::SampleLevel::Bias(h) => {
+ write!(self.out, ", {NAMESPACE}::bias(")?;
+ self.put_expression(h, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ write!(self.out, ", {NAMESPACE}::gradient2d(")?;
+ self.put_expression(x, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(y, context, true)?;
+ write!(self.out, ")")?;
+ }
+ }
+ Ok(())
+ }
+
+ fn put_image_coordinate_limits(
+ &mut self,
+ image: Handle<crate::Expression>,
+ level: Option<LevelOfDetail>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
+ write!(self.out, " - 1")?;
+ Ok(())
+ }
+
+ /// General function for writing restricted image indexes.
+ ///
+ /// This is used to produce restricted mip levels, array indices, and sample
+ /// indices for [`ImageLoad`] and [`ImageStore`] accesses under the
+ /// [`Restrict`] bounds check policy.
+ ///
+ /// This function writes an expression of the form:
+ ///
+ /// ```ignore
+ ///
+ /// metal::min(uint(INDEX), IMAGE.LIMIT_METHOD() - 1)
+ ///
+ /// ```
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`ImageStore`]: crate::Statement::ImageStore
+ /// [`Restrict`]: index::BoundsCheckPolicy::Restrict
+ fn put_restricted_scalar_image_index(
+ &mut self,
+ image: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ limit_method: &str,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ write!(self.out, "{NAMESPACE}::min(uint(")?;
+ self.put_expression(index, context, true)?;
+ write!(self.out, "), ")?;
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".{limit_method}() - 1)")?;
+ Ok(())
+ }
+
+ fn put_restricted_texel_address(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // Write the coordinate.
+ write!(self.out, "{NAMESPACE}::min(")?;
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ write!(self.out, ", ")?;
+ self.put_image_coordinate_limits(image, address.level, context)?;
+ write!(self.out, ")")?;
+
+ // Write the array index, if present.
+ if let Some(array_index) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_restricted_scalar_image_index(image, array_index, "get_array_size", context)?;
+ }
+
+ // Write the sample index, if present.
+ if let Some(sample) = address.sample {
+ write!(self.out, ", ")?;
+ self.put_restricted_scalar_image_index(image, sample, "get_num_samples", context)?;
+ }
+
+ // The level of detail should be clamped and cached by
+ // `put_cache_restricted_level`, so we don't need to clamp it here.
+ if let Some(level) = address.level {
+ write!(self.out, ", ")?;
+ self.put_level_of_detail(level, context)?;
+ }
+
+ Ok(())
+ }
+
+ /// Write an expression that is true if the given image access is in bounds.
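+ ///
+ /// Sketch of the emitted shape for a 2D array access (names assumed):
+ /// `uint(level) < image.get_num_mip_levels() && uint(index) <
+ /// image.get_array_size() && metal::all(uint2(coord) <
+ /// metal::uint2(image.get_width(level), image.get_height(level)))`.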
+ fn put_image_access_bounds_check(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let mut conjunction = "";
+
+ // First, check the level of detail. Only if that is in bounds can we
+ // use it to find the appropriate bounds for the coordinates.
+ let level = if let Some(level) = address.level {
+ write!(self.out, "uint(")?;
+ self.put_level_of_detail(level, context)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_num_mip_levels()")?;
+ conjunction = " && ";
+ Some(level)
+ } else {
+ None
+ };
+
+ // Check sample index, if present.
+ if let Some(sample) = address.sample {
+ write!(self.out, "uint(")?;
+ self.put_expression(sample, context, true)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_num_samples()")?;
+ conjunction = " && ";
+ }
+
+ // Check array index, if present.
+ if let Some(array_index) = address.array_index {
+ write!(self.out, "{conjunction}uint(")?;
+ self.put_expression(array_index, context, true)?;
+ write!(self.out, ") < ")?;
+ self.put_expression(image, context, true)?;
+ write!(self.out, ".get_array_size()")?;
+ conjunction = " && ";
+ }
+
+ // Finally, check if the coordinates are within bounds.
+ let coord_is_vector = match *context.resolve_type(address.coordinate) {
+ crate::TypeInner::Vector { .. } => true,
+ _ => false,
+ };
+ write!(self.out, "{conjunction}")?;
+ if coord_is_vector {
+ write!(self.out, "{NAMESPACE}::all(")?;
+ }
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ write!(self.out, " < ")?;
+ self.put_image_size_query(image, level, crate::ScalarKind::Uint, context)?;
+ if coord_is_vector {
+ write!(self.out, ")")?;
+ }
+
+ Ok(())
+ }
+
+ fn put_image_load(
+ &mut self,
+ load: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ mut address: TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.policies.image_load {
+ proc::BoundsCheckPolicy::Restrict => {
+ // Use the cached restricted level of detail, if any. Omit the
+ // level altogether for 1D textures.
+ if address.level.is_some() {
+ address.level = if context.image_needs_lod(image) {
+ Some(LevelOfDetail::Restricted(load))
+ } else {
+ None
+ }
+ }
+
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".read(")?;
+ self.put_restricted_texel_address(image, &address, context)?;
+ write!(self.out, ")")?;
+ }
+ proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ write!(self.out, "(")?;
+ self.put_image_access_bounds_check(image, &address, context)?;
+ write!(self.out, " ? ")?;
+ self.put_unchecked_image_load(image, &address, context)?;
+ write!(self.out, ": DefaultConstructible())")?;
+ }
+ proc::BoundsCheckPolicy::Unchecked => {
+ self.put_unchecked_image_load(image, &address, context)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_image_load(
+ &mut self,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".read(")?;
+ // coordinates in IR are int, but Metal expects uint
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, context)?;
+ if let Some(expr) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, context, true)?;
+ }
+ if let Some(sample) = address.sample {
+ write!(self.out, ", ")?;
+ self.put_expression(sample, context, true)?;
+ }
+ if let Some(level) = address.level {
+ if context.image_needs_lod(image) {
+ write!(self.out, ", ")?;
+ self.put_level_of_detail(level, context)?;
+ }
+ }
+ write!(self.out, ")")?;
+
+ Ok(())
+ }
+
+ fn put_image_store(
+ &mut self,
+ level: back::Level,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ value: Handle<crate::Expression>,
+ context: &StatementContext,
+ ) -> BackendResult {
+ match context.expression.policies.image_store {
+ proc::BoundsCheckPolicy::Restrict => {
+ // We don't have a restricted level value, because we don't
+ // support writes to mipmapped textures.
+ debug_assert!(address.level.is_none());
+
+ write!(self.out, "{level}")?;
+ self.put_expression(image, &context.expression, false)?;
+ write!(self.out, ".write(")?;
+ self.put_expression(value, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ self.put_restricted_texel_address(image, address, &context.expression)?;
+ writeln!(self.out, ");")?;
+ }
+ proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ write!(self.out, "{level}if (")?;
+ self.put_image_access_bounds_check(image, address, &context.expression)?;
+ writeln!(self.out, ") {{")?;
+ self.put_unchecked_image_store(level.next(), image, address, value, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ proc::BoundsCheckPolicy::Unchecked => {
+ self.put_unchecked_image_store(level, image, address, value, context)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_image_store(
+ &mut self,
+ level: back::Level,
+ image: Handle<crate::Expression>,
+ address: &TexelAddress,
+ value: Handle<crate::Expression>,
+ context: &StatementContext,
+ ) -> BackendResult {
+ write!(self.out, "{level}")?;
+ self.put_expression(image, &context.expression, false)?;
+ write!(self.out, ".write(")?;
+ self.put_expression(value, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ // coordinates in IR are int, but Metal expects uint
+ self.put_cast_to_uint_scalar_or_vector(address.coordinate, &context.expression)?;
+ if let Some(expr) = address.array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, &context.expression, true)?;
+ }
+ writeln!(self.out, ");")?;
+
+ Ok(())
+ }
+
+ /// Write the maximum valid index of the dynamically sized array at the end of `handle`.
+ ///
+ /// The 'maximum valid index' is simply one less than the array's length.
+ ///
+ /// This emits an expression of the form `a / b`, so the caller must
+ /// parenthesize its output if it will be applying operators of higher
+ /// precedence.
+ ///
+ /// `handle` must be the handle of a global variable whose final member is a
+ /// dynamically sized array.
+ fn put_dynamic_array_max_index(
+ &mut self,
+ handle: Handle<crate::GlobalVariable>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let global = &context.module.global_variables[handle];
+ let (offset, array_ty) = match context.module.types[global.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => match members.last() {
+ Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
+ None => return Err(Error::Validation),
+ },
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => (0, global.ty),
+ _ => return Err(Error::Validation),
+ };
+
+ let (size, stride) = match context.module.types[array_ty].inner {
+ crate::TypeInner::Array { base, stride, .. } => (
+ context.module.types[base]
+ .inner
+ .size(context.module.to_ctx()),
+ stride,
+ ),
+ _ => return Err(Error::Validation),
+ };
+
+ // When the stride is larger than the element size, the final element's
+ // stride worth of bytes has padding following the value. But the buffer
+ // size in `_buffer_sizes.sizeN` may not include this padding - it only
+ // needs to be large enough to hold the actual values' bytes.
+ //
+ // So subtract off the size to get a byte size that falls at the start or within
+ // the final element. Then divide by the stride to get one less than the length;
+ // callers that want the length itself add one. This works even if the buffer size does include the
+ // stride padding, since division rounds towards zero (MSL 2.4 §6.1). It will fail
+ // if there are zero elements in the array, but the WebGPU `validating shader binding`
+ // rules, together with draw-time validation when `minBindingSize` is zero,
+ // prevent that.
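+ //
+ // Illustrative numbers (assumed): offset 0, element size 12, stride 16,
+ // and `_buffer_sizes.sizeN` = 44 give (44 - 0 - 12) / 16 = 2, the
+ // maximum valid index of a three-element array.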
+ write!(
+ self.out,
+ "(_buffer_sizes.size{idx} - {offset} - {size}) / {stride}",
+ idx = handle.index(),
+ offset = offset,
+ size = size,
+ stride = stride,
+ )?;
+ Ok(())
+ }
+
+ fn put_atomic_fetch(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ key: &str,
+ value: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ self.put_atomic_operation(pointer, "fetch_", key, value, context)
+ }
+
+ fn put_atomic_operation(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ key1: &str,
+ key2: &str,
+ value: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // If the pointer we're passing to the atomic operation needs to be conditional
+ // for `ReadZeroSkipWrite`, the condition needs to *surround* the atomic op, and
+ // the pointer operand should be unchecked.
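+ //
+ // Illustrative result for a fetch-add (names assumed): `IN_BOUNDS ?
+ // metal::atomic_fetch_add_explicit(&PTR, VALUE,
+ // metal::memory_order_relaxed) : DefaultConstructible()`.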
+ let policy = context.choose_bounds_check_policy(pointer);
+ let checked = policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(pointer, context, back::Level(0), "")?;
+
+ // If requested and successfully put bounds checks, continue the ternary expression.
+ if checked {
+ write!(self.out, " ? ")?;
+ }
+
+ write!(
+ self.out,
+ "{NAMESPACE}::atomic_{key1}{key2}_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, context)?;
+ write!(self.out, ", ")?;
+ self.put_expression(value, context, true)?;
+ write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
+
+ // Finish the ternary expression.
+ if checked {
+ write!(self.out, " : DefaultConstructible()")?;
+ }
+
+ Ok(())
+ }
+
+ /// Emit the dot product as an expanded, component-wise arithmetic expression.
+ ///
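+ /// For an illustrative two-component call (names assumed) this writes
+ /// `( + a.x * b.x + a.y * b.y)`; the leading `+` is valid MSL.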
+ fn put_dot_product(
+ &mut self,
+ arg: Handle<crate::Expression>,
+ arg1: Handle<crate::Expression>,
+ size: usize,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ // Write parentheses around the dot product expression to prevent operators
+ // with different precedences from applying earlier.
+ write!(self.out, "(")?;
+
+ // Cycle through all the components of the vector
+ for index in 0..size {
+ let component = back::COMPONENTS[index];
+ // Write the addition to the previous product
+ // This will print an extra '+' at the beginning, but that is fine in MSL
+ write!(self.out, " + ")?;
+ // Write the first vector expression; it is marked to be
+ // cached, so unless it can't be cached (for example, a Constant)
+ // it shouldn't produce a large expression.
+ self.put_expression(arg, context, true)?;
+ // Access the current component on the first vector
+ write!(self.out, ".{component} * ")?;
+ // Write the second vector expression; it is marked to be
+ // cached, so unless it can't be cached (for example, a Constant)
+ // it shouldn't produce a large expression.
+ self.put_expression(arg1, context, true)?;
+ // Access the current component on the second vector
+ write!(self.out, ".{component}")?;
+ }
+
+ write!(self.out, ")")?;
+ Ok(())
+ }
+
+ /// Emit code for the sign(i32) expression.
+ ///
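+ /// Sketch of the scalar output (name assumed):
+ /// `metal::select(metal::select(-1, 1, (x > 0)), 0, (x == 0))`.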
+ fn put_isign(
+ &mut self,
+ arg: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ write!(self.out, "{NAMESPACE}::select({NAMESPACE}::select(")?;
+ match context.resolve_type(arg) {
+ &crate::TypeInner::Vector { size, .. } => {
+ let size = back::vector_size_str(size);
+ write!(self.out, "int{size}(-1), int{size}(1)")?;
+ }
+ _ => {
+ write!(self.out, "-1, 1")?;
+ }
+ }
+ write!(self.out, ", (")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " > 0)), 0, (")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " == 0))")?;
+ Ok(())
+ }
+
+ fn put_const_expression(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ module: &crate::Module,
+ mod_info: &valid::ModuleInfo,
+ ) -> BackendResult {
+ self.put_possibly_const_expression(
+ expr_handle,
+ &module.const_expressions,
+ module,
+ mod_info,
+ &(module, mod_info),
+ |&(_, mod_info), expr| &mod_info[expr],
+ |writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info),
+ )
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ fn put_possibly_const_expression<C, I, E>(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ expressions: &crate::Arena<crate::Expression>,
+ module: &crate::Module,
+ mod_info: &valid::ModuleInfo,
+ ctx: &C,
+ get_expr_ty: I,
+ put_expression: E,
+ ) -> BackendResult
+ where
+ I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
+ E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
+ {
+ match expressions[expr_handle] {
+ crate::Expression::Literal(literal) => match literal {
+ crate::Literal::F64(_) => {
+ return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
+ }
+ crate::Literal::F32(value) => {
+ if value.is_infinite() {
+ let sign = if value.is_sign_negative() { "-" } else { "" };
+ write!(self.out, "{sign}INFINITY")?;
+ } else if value.is_nan() {
+ write!(self.out, "NAN")?;
+ } else {
+ let suffix = if value.fract() == 0.0 { ".0" } else { "" };
+ write!(self.out, "{value}{suffix}")?;
+ }
+ }
+ crate::Literal::U32(value) => {
+ write!(self.out, "{value}u")?;
+ }
+ crate::Literal::I32(value) => {
+ write!(self.out, "{value}")?;
+ }
+ crate::Literal::I64(value) => {
+ write!(self.out, "{value}L")?;
+ }
+ crate::Literal::Bool(value) => {
+ write!(self.out, "{value}")?;
+ }
+ crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
+ return Err(Error::Validation);
+ }
+ },
+ crate::Expression::Constant(handle) => {
+ let constant = &module.constants[handle];
+ if constant.name.is_some() {
+ write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.put_const_expression(constant.init, module, mod_info)?;
+ }
+ }
+ crate::Expression::ZeroValue(ty) => {
+ let ty_name = TypeContext {
+ handle: ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name} {{}}")?;
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ let ty_name = TypeContext {
+ handle: ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name}")?;
+ match module.types[ty].inner {
+ crate::TypeInner::Scalar(_)
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. } => {
+ self.put_call_parameters_impl(
+ components.iter().copied(),
+ ctx,
+ put_expression,
+ )?;
+ }
+ crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
+ write!(self.out, " {{")?;
+ for (index, &component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // insert padding initialization, if needed
+ if self.struct_member_pads.contains(&(ty, index as u32)) {
+ write!(self.out, "{{}}, ")?;
+ }
+ put_expression(self, ctx, component)?;
+ }
+ write!(self.out, "}}")?;
+ }
+ _ => return Err(Error::UnsupportedCompose(ty)),
+ }
+ }
+ crate::Expression::Splat { size, value } => {
+ let scalar = match *get_expr_ty(ctx, value).inner_with(&module.types) {
+ crate::TypeInner::Scalar(scalar) => scalar,
+ _ => return Err(Error::Validation),
+ };
+ put_numeric_type(&mut self.out, scalar, &[size])?;
+ write!(self.out, "(")?;
+ put_expression(self, ctx, value)?;
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Emit code for the expression `expr_handle`.
+ ///
+ /// The `is_scoped` argument is true if the surrounding operators have the
+ /// precedence of the comma operator, or lower. So, for example:
+ ///
+ /// - Pass `true` for `is_scoped` when writing function arguments, an
+ /// expression statement, an initializer expression, or anything already
+ /// wrapped in parenthesis.
+ ///
+ /// - Pass `false` if it is an operand of a `?:` operator, a `[]`, or really
+ /// almost anything else.
+ fn put_expression(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ // Add to the set in order to track the stack size.
+ #[cfg(test)]
+ #[allow(trivial_casts)]
+ self.put_expression_stack_pointers
+ .insert(&expr_handle as *const _ as *const ());
+
+ if let Some(name) = self.named_expressions.get(&expr_handle) {
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+
+ let expression = &context.function.expressions[expr_handle];
+ log::trace!("expression {:?} = {:?}", expr_handle, expression);
+ match *expression {
+ crate::Expression::Literal(_)
+ | crate::Expression::Constant(_)
+ | crate::Expression::ZeroValue(_)
+ | crate::Expression::Compose { .. }
+ | crate::Expression::Splat { .. } => {
+ self.put_possibly_const_expression(
+ expr_handle,
+ &context.function.expressions,
+ context.module,
+ context.mod_info,
+ context,
+ |context, expr: Handle<crate::Expression>| &context.info[expr].ty,
+ |writer, context, expr| writer.put_expression(expr, context, true),
+ )?;
+ }
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => {
+ // This is an acceptable place to generate a `ReadZeroSkipWrite` check.
+ // Since `put_bounds_checks` and `put_access_chain` handle an entire
+ // access chain at a time, recursing back through `put_expression` only
+ // for index expressions and the base object, we will never see intermediate
+ // `Access` or `AccessIndex` expressions here.
+ let policy = context.choose_bounds_check_policy(base);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(
+ expr_handle,
+ context,
+ back::Level(0),
+ if is_scoped { "" } else { "(" },
+ )?
+ {
+ write!(self.out, " ? ")?;
+ self.put_access_chain(expr_handle, policy, context)?;
+ write!(self.out, " : DefaultConstructible()")?;
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ } else {
+ self.put_access_chain(expr_handle, policy, context)?;
+ }
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.put_wrapped_expression_for_packed_vec3_access(vector, context, false)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ write!(self.out, "{}", back::COMPONENTS[sc as usize])?;
+ }
+ }
+ crate::Expression::FunctionArgument(index) => {
+ let name_key = match context.origin {
+ FunctionOrigin::Handle(handle) => NameKey::FunctionArgument(handle, index),
+ FunctionOrigin::EntryPoint(ep_index) => {
+ NameKey::EntryPointArgument(ep_index, index)
+ }
+ };
+ let name = &self.names[&name_key];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::LocalVariable(handle) => {
+ let name_key = match context.origin {
+ FunctionOrigin::Handle(fun_handle) => {
+ NameKey::FunctionLocal(fun_handle, handle)
+ }
+ FunctionOrigin::EntryPoint(ep_index) => {
+ NameKey::EntryPointLocal(ep_index, handle)
+ }
+ };
+ let name = &self.names[&name_key];
+ write!(self.out, "{name}")?;
+ }
+ crate::Expression::Load { pointer } => self.put_load(pointer, context, is_scoped)?,
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ let main_op = match gather {
+ Some(_) => "gather",
+ None => "sample",
+ };
+ let comparison_op = match depth_ref {
+ Some(_) => "_compare",
+ None => "",
+ };
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".{main_op}{comparison_op}(")?;
+ self.put_expression(sampler, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(coordinate, context, true)?;
+ if let Some(expr) = array_index {
+ write!(self.out, ", ")?;
+ self.put_expression(expr, context, true)?;
+ }
+ if let Some(dref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.put_expression(dref, context, true)?;
+ }
+
+ self.put_image_sample_level(image, level, context)?;
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.put_const_expression(offset, context.module, context.mod_info)?;
+ }
+
+ match gather {
+ None | Some(crate::SwizzleComponent::X) => {}
+ Some(component) => {
+ let is_cube_map = match *context.resolve_type(image) {
+ crate::TypeInner::Image {
+ dim: crate::ImageDimension::Cube,
+ ..
+ } => true,
+ _ => false,
+ };
+ // Offset always comes before the gather, except
+ // in cube maps where it's not applicable
+ if offset.is_none() && !is_cube_map {
+ write!(self.out, ", {NAMESPACE}::int2(0)")?;
+ }
+ let letter = back::COMPONENTS[component as usize];
+ write!(self.out, ", {NAMESPACE}::component::{letter}")?;
+ }
+ }
+ write!(self.out, ")")?;
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let address = TexelAddress {
+ coordinate,
+ array_index,
+ sample,
+ level: level.map(LevelOfDetail::Direct),
+ };
+ self.put_image_load(expr_handle, image, address, context)?;
+ }
+ // Note: signed integers are expected for all of these queries,
+ // so a conversion is needed.
+ crate::Expression::ImageQuery { image, query } => match query {
+ crate::ImageQuery::Size { level } => {
+ self.put_image_size_query(
+ image,
+ level.map(LevelOfDetail::Direct),
+ crate::ScalarKind::Uint,
+ context,
+ )?;
+ }
+ crate::ImageQuery::NumLevels => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_num_mip_levels()")?;
+ }
+ crate::ImageQuery::NumLayers => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_array_size()")?;
+ }
+ crate::ImageQuery::NumSamples => {
+ self.put_expression(image, context, false)?;
+ write!(self.out, ".get_num_samples()")?;
+ }
+ },
+ crate::Expression::Unary { op, expr } => {
+ let op_str = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::LogicalNot => "!",
+ crate::UnaryOperator::BitwiseNot => "~",
+ };
+ write!(self.out, "{op_str}(")?;
+ self.put_expression(expr, context, false)?;
+ write!(self.out, ")")?;
+ }
+ crate::Expression::Binary { op, left, right } => {
+ let op_str = crate::back::binary_operation_str(op);
+ let kind = context
+ .resolve_type(left)
+ .scalar_kind()
+ .ok_or(Error::UnsupportedBinaryOp(op))?;
+
+ // TODO: handle undefined behavior of BinaryOperator::Modulo
+ //
+ // sint:
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ // if sign(left) == -1 || sign(right) == -1 return result as defined by WGSL
+ //
+ // uint:
+ // if right == 0 return 0
+ //
+ // float:
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+
+ if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float {
+ write!(self.out, "{NAMESPACE}::fmod(")?;
+ self.put_expression(left, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(right, context, true)?;
+ write!(self.out, ")")?;
+ } else {
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+
+ // Cast packed vector if necessary
+ // Packed vector - matrix multiplications are not supported in MSL
+ if op == crate::BinaryOperator::Multiply
+ && matches!(
+ context.resolve_type(right),
+ &crate::TypeInner::Matrix { .. }
+ )
+ {
+ self.put_wrapped_expression_for_packed_vec3_access(left, context, false)?;
+ } else {
+ self.put_expression(left, context, false)?;
+ }
+
+ write!(self.out, " {op_str} ")?;
+
+ // See comment above
+ if op == crate::BinaryOperator::Multiply
+ && matches!(context.resolve_type(left), &crate::TypeInner::Matrix { .. })
+ {
+ self.put_wrapped_expression_for_packed_vec3_access(right, context, false)?;
+ } else {
+ self.put_expression(right, context, false)?;
+ }
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ }
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => match *context.resolve_type(condition) {
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ }) => {
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+ self.put_expression(condition, context, false)?;
+ write!(self.out, " ? ")?;
+ self.put_expression(accept, context, false)?;
+ write!(self.out, " : ")?;
+ self.put_expression(reject, context, false)?;
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ crate::TypeInner::Vector {
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ },
+ ..
+ } => {
+ write!(self.out, "{NAMESPACE}::select(")?;
+ self.put_expression(reject, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(accept, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(condition, context, true)?;
+ write!(self.out, ")")?;
+ }
+ _ => return Err(Error::Validation),
+ },
+ crate::Expression::Derivative { axis, expr, .. } => {
+ use crate::DerivativeAxis as Axis;
+ let op = match axis {
+ Axis::X => "dfdx",
+ Axis::Y => "dfdy",
+ Axis::Width => "fwidth",
+ };
+ write!(self.out, "{NAMESPACE}::{op}")?;
+ self.put_call_parameters(iter::once(expr), context)?;
+ }
+ crate::Expression::Relational { fun, argument } => {
+ let op = match fun {
+ crate::RelationalFunction::Any => "any",
+ crate::RelationalFunction::All => "all",
+ crate::RelationalFunction::IsNan => "isnan",
+ crate::RelationalFunction::IsInf => "isinf",
+ };
+ write!(self.out, "{NAMESPACE}::{op}")?;
+ self.put_call_parameters(iter::once(argument), context)?;
+ }
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let arg_type = context.resolve_type(arg);
+ let scalar_argument = match arg_type {
+ &crate::TypeInner::Scalar(_) => true,
+ _ => false,
+ };
+
+ let fun_name = match fun {
+ // comparison
+ Mf::Abs => "abs",
+ Mf::Min => "min",
+ Mf::Max => "max",
+ Mf::Clamp => "clamp",
+ Mf::Saturate => "saturate",
+ // trigonometry
+ Mf::Cos => "cos",
+ Mf::Cosh => "cosh",
+ Mf::Sin => "sin",
+ Mf::Sinh => "sinh",
+ Mf::Tan => "tan",
+ Mf::Tanh => "tanh",
+ Mf::Acos => "acos",
+ Mf::Asin => "asin",
+ Mf::Atan => "atan",
+ Mf::Atan2 => "atan2",
+ Mf::Asinh => "asinh",
+ Mf::Acosh => "acosh",
+ Mf::Atanh => "atanh",
+ Mf::Radians => "",
+ Mf::Degrees => "",
+ // decomposition
+ Mf::Ceil => "ceil",
+ Mf::Floor => "floor",
+ Mf::Round => "rint",
+ Mf::Fract => "fract",
+ Mf::Trunc => "trunc",
+ Mf::Modf => MODF_FUNCTION,
+ Mf::Frexp => FREXP_FUNCTION,
+ Mf::Ldexp => "ldexp",
+ // exponent
+ Mf::Exp => "exp",
+ Mf::Exp2 => "exp2",
+ Mf::Log => "log",
+ Mf::Log2 => "log2",
+ Mf::Pow => "pow",
+ // geometry
+ Mf::Dot => match *context.resolve_type(arg) {
+ crate::TypeInner::Vector {
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Float,
+ ..
+ },
+ ..
+ } => "dot",
+ crate::TypeInner::Vector { size, .. } => {
+ return self.put_dot_product(arg, arg1.unwrap(), size as usize, context)
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
+ Mf::Cross => "cross",
+ Mf::Distance => "distance",
+ Mf::Length if scalar_argument => "abs",
+ Mf::Length => "length",
+ Mf::Normalize => "normalize",
+ Mf::FaceForward => "faceforward",
+ Mf::Reflect => "reflect",
+ Mf::Refract => "refract",
+ // computational
+ Mf::Sign => match arg_type.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => {
+ return self.put_isign(arg, context);
+ }
+ _ => "sign",
+ },
+ Mf::Fma => "fma",
+ Mf::Mix => "mix",
+ Mf::Step => "step",
+ Mf::SmoothStep => "smoothstep",
+ Mf::Sqrt => "sqrt",
+ Mf::InverseSqrt => "rsqrt",
+ Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
+ Mf::Transpose => "transpose",
+ Mf::Determinant => "determinant",
+ // bits
+ Mf::CountTrailingZeros => "ctz",
+ Mf::CountLeadingZeros => "clz",
+ Mf::CountOneBits => "popcount",
+ Mf::ReverseBits => "reverse_bits",
+ Mf::ExtractBits => "extract_bits",
+ Mf::InsertBits => "insert_bits",
+ Mf::FindLsb => "",
+ Mf::FindMsb => "",
+ // data packing
+ Mf::Pack4x8snorm => "pack_float_to_snorm4x8",
+ Mf::Pack4x8unorm => "pack_float_to_unorm4x8",
+ Mf::Pack2x16snorm => "pack_float_to_snorm2x16",
+ Mf::Pack2x16unorm => "pack_float_to_unorm2x16",
+ Mf::Pack2x16float => "",
+ // data unpacking
+ Mf::Unpack4x8snorm => "unpack_snorm4x8_to_float",
+ Mf::Unpack4x8unorm => "unpack_unorm4x8_to_float",
+ Mf::Unpack2x16snorm => "unpack_snorm2x16_to_float",
+ Mf::Unpack2x16unorm => "unpack_unorm2x16_to_float",
+ Mf::Unpack2x16float => "",
+ };
+
+ match fun {
+ Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => {
+ // reverse_bits is listed as requiring MSL 2.1, but that is a
+ // copy/paste error. Looking at previous snapshots on
+ // web.archive.org, it's present in MSL 1.2.
+ //
+ // https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
+ // also talks about MSL 1.2 adding "New integer
+ // functions to extract, insert, and reverse bits, as
+ // described in Integer Functions."
+ if context.lang_version < (1, 2) {
+ return Err(Error::UnsupportedFunction(fun_name.to_string()));
+ }
+ }
+ _ => {}
+ }
+
+ if fun == Mf::Distance && scalar_argument {
+ write!(self.out, "{NAMESPACE}::abs(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, " - ")?;
+ self.put_expression(arg1.unwrap(), context, false)?;
+ write!(self.out, ")")?;
+ } else if fun == Mf::FindLsb {
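+ // MSL's `ctz(0)` returns 32 (the bit width), so the expression
+ // emitted below, `(((ctz(x) + 1) % 33) - 1)`, maps a zero input
+ // to -1 while leaving every other trailing-zero count unchanged
+ // (e.g. 8 -> 3).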
+ write!(self.out, "((({NAMESPACE}::ctz(")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ") + 1) % 33) - 1)")?;
+ } else if fun == Mf::FindMsb {
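+ // For reference, the text emitted for a signed scalar `x` looks
+ // roughly like (illustrative):
+ //
+ // metal::select(
+ // 31 - metal::clz(metal::select(x, ~x, x < 0)),
+ // int(-1), x == 0 || x == -1)
+ //
+ // i.e. negative values are complemented before counting leading
+ // zeros, and 0 and -1 (which have no answer) map to -1.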
+ let inner = context.resolve_type(arg);
+
+ write!(self.out, "{NAMESPACE}::select(31 - {NAMESPACE}::clz(")?;
+
+ if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
+ write!(self.out, "{NAMESPACE}::select(")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ", ~")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " < 0)")?;
+ } else {
+ self.put_expression(arg, context, true)?;
+ }
+
+ write!(self.out, "), ")?;
+
+ // Write the result type explicitly, or Metal will complain that
+ // `select` is ambiguous.
+ match *inner {
+ crate::TypeInner::Vector { size, scalar } => {
+ let size = back::vector_size_str(size);
+ if let crate::ScalarKind::Sint = scalar.kind {
+ write!(self.out, "int{size}")?;
+ } else {
+ write!(self.out, "uint{size}")?;
+ }
+ }
+ crate::TypeInner::Scalar(scalar) => {
+ if let crate::ScalarKind::Sint = scalar.kind {
+ write!(self.out, "int")?;
+ } else {
+ write!(self.out, "uint")?;
+ }
+ }
+ _ => (),
+ }
+
+ write!(self.out, "(-1), ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " == 0 || ")?;
+ self.put_expression(arg, context, true)?;
+ write!(self.out, " == -1)")?;
+ } else if fun == Mf::Unpack2x16float {
+ write!(self.out, "float2(as_type<half2>(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, "))")?;
+ } else if fun == Mf::Pack2x16float {
+ write!(self.out, "as_type<uint>(half2(")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, "))")?;
+ } else if fun == Mf::Radians {
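+ // 0.017453292519943295474 is pi/180; the `Degrees` branch below
+ // inlines 180/pi. MSL has no `radians`/`degrees` builtins, so the
+ // conversions are written out.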
+ write!(self.out, "((")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, ") * 0.017453292519943295474)")?;
+ } else if fun == Mf::Degrees {
+ write!(self.out, "((")?;
+ self.put_expression(arg, context, false)?;
+ write!(self.out, ") * 57.295779513082322865)")?;
+ } else if fun == Mf::Modf || fun == Mf::Frexp {
+ write!(self.out, "{fun_name}")?;
+ self.put_call_parameters(iter::once(arg), context)?;
+ } else {
+ write!(self.out, "{NAMESPACE}::{fun_name}")?;
+ self.put_call_parameters(
+ iter::once(arg).chain(arg1).chain(arg2).chain(arg3),
+ context,
+ )?;
+ }
+ }
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => match *context.resolve_type(expr) {
+ crate::TypeInner::Scalar(src) | crate::TypeInner::Vector { scalar: src, .. } => {
+ let target_scalar = crate::Scalar {
+ kind,
+ width: convert.unwrap_or(src.width),
+ };
+ let is_bool_cast =
+ kind == crate::ScalarKind::Bool || src.kind == crate::ScalarKind::Bool;
+ let op = match convert {
+ Some(w) if w == src.width || is_bool_cast => "static_cast",
+ Some(8) if kind == crate::ScalarKind::Float => {
+ return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
+ }
+ Some(_) => return Err(Error::Validation),
+ None => "as_type",
+ };
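+ // `static_cast` performs a value conversion (and is always used
+ // for bool casts), while `as_type` reinterprets the bits, which
+ // MSL only allows between types of the same size. Illustrative:
+ // `as_type<uint>(f)` vs `static_cast<uint>(f)`.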
+ write!(self.out, "{op}<")?;
+ match *context.resolve_type(expr) {
+ crate::TypeInner::Vector { size, .. } => {
+ put_numeric_type(&mut self.out, target_scalar, &[size])?
+ }
+ _ => put_numeric_type(&mut self.out, target_scalar, &[])?,
+ };
+ write!(self.out, ">(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ }
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ let target_scalar = crate::Scalar {
+ kind,
+ width: convert.unwrap_or(scalar.width),
+ };
+ put_numeric_type(&mut self.out, target_scalar, &[rows, columns])?;
+ write!(self.out, "(")?;
+ self.put_expression(expr, context, true)?;
+ write!(self.out, ")")?;
+ }
+ _ => return Err(Error::Validation),
+ },
+ // has to be a named expression
+ crate::Expression::CallResult(_)
+ | crate::Expression::AtomicResult { .. }
+ | crate::Expression::WorkGroupUniformLoadResult { .. }
+ | crate::Expression::RayQueryProceedResult => {
+ unreachable!()
+ }
+ crate::Expression::ArrayLength(expr) => {
+ // Find the global to which the array belongs.
+ let global = match context.function.expressions[expr] {
+ crate::Expression::AccessIndex { base, .. } => {
+ match context.function.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => handle,
+ _ => return Err(Error::Validation),
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => handle,
+ _ => return Err(Error::Validation),
+ };
+
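+ // The emitted expression is `1 + <largest valid index>`, where the
+ // max index is computed from the runtime size recorded in the
+ // `_buffer_sizes` struct (see `put_dynamic_array_max_index`).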
+ if !is_scoped {
+ write!(self.out, "(")?;
+ }
+ write!(self.out, "1 + ")?;
+ self.put_dynamic_array_max_index(global, context)?;
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ }
+ crate::Expression::RayQueryGetIntersection { query, committed } => {
+ if context.lang_version < (2, 4) {
+ return Err(Error::UnsupportedRayTracing);
+ }
+
+ if !committed {
+ unimplemented!()
+ }
+ let ty = context.module.special_types.ray_intersection.unwrap();
+ let type_name = &self.names[&NameKey::Type(ty)];
+ write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?;
+ self.put_expression(query, context, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?;
+ let fields = [
+ "distance",
+ "user_instance_id", // req Metal 2.4
+ "instance_id",
+ "", // SBT offset
+ "geometry_id",
+ "primitive_id",
+ "triangle_barycentric_coord",
+ "triangle_front_facing",
+ "", // padding
+ "object_to_world_transform", // req Metal 2.4
+ "world_to_object_transform", // req Metal 2.4
+ ];
+ for field in fields {
+ write!(self.out, ", ")?;
+ if field.is_empty() {
+ write!(self.out, "{{}}")?;
+ } else {
+ self.put_expression(query, context, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?;
+ }
+ }
+ write!(self.out, "}}")?;
+ }
+ }
+ Ok(())
+ }
+
+ /// Used by expressions like Swizzle and Binary, since they need packed `vec3`s to be cast to a plain `vec3`.
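+ ///
+ /// For example (illustrative), a packed `float3` struct member `s.m`
+ /// used in such an expression is wrapped as:
+ ///
+ /// ```ignore
+ /// metal::float3(s.m)
+ /// ```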
+ fn put_wrapped_expression_for_packed_vec3_access(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ if let Some(scalar) = context.get_packed_vec_kind(expr_handle) {
+ write!(self.out, "{}::{}3(", NAMESPACE, scalar.to_msl_name())?;
+ self.put_expression(expr_handle, context, is_scoped)?;
+ write!(self.out, ")")?;
+ } else {
+ self.put_expression(expr_handle, context, is_scoped)?;
+ }
+ Ok(())
+ }
+
+ /// Write a `GuardedIndex` as a Metal expression.
+ fn put_index(
+ &mut self,
+ index: index::GuardedIndex,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ match index {
+ index::GuardedIndex::Expression(expr) => {
+ self.put_expression(expr, context, is_scoped)?
+ }
+ index::GuardedIndex::Known(value) => write!(self.out, "{value}")?,
+ }
+ Ok(())
+ }
+
+ /// Emit an index bounds check condition for `chain`, if required.
+ ///
+ /// `chain` is a subtree of `Access` and `AccessIndex` expressions,
+ /// operating either on a pointer to a value, or on a value directly. If we cannot
+ /// statically determine that all indexing operations in `chain` are within
+ /// bounds, then write a conditional expression to check them dynamically,
+ /// and return true. All accesses in the chain are checked by the generated
+ /// expression.
+ ///
+ /// This assumes that the [`BoundsCheckPolicy`] for `chain` is [`ReadZeroSkipWrite`].
+ ///
+ /// The text written is of the form:
+ ///
+ /// ```ignore
+ /// {level}{prefix}uint(i) < 4 && uint(j) < 10
+ /// ```
+ ///
+ /// where `{level}` and `{prefix}` are the arguments to this function. For [`Store`]
+ /// statements, presumably these arguments start an indented `if` statement; for
+ /// [`Load`] expressions, the caller is probably building up a ternary `?:`
+ /// expression. In either case, what is written is not a complete syntactic structure
+ /// in its own right, and the caller will have to finish it off if we return `true`.
+ ///
+ /// If no expression is written, return false.
+ ///
+ /// [`BoundsCheckPolicy`]: index::BoundsCheckPolicy
+ /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
+ /// [`Store`]: crate::Statement::Store
+ /// [`Load`]: crate::Expression::Load
+ #[allow(unused_variables)]
+ fn put_bounds_checks(
+ &mut self,
+ mut chain: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ level: back::Level,
+ prefix: &'static str,
+ ) -> Result<bool, Error> {
+ let mut check_written = false;
+
+ // Iterate over the access chain, handling each expression.
+ loop {
+ // Produce a `GuardedIndex`, so we can share code between the
+ // `Access` and `AccessIndex` cases.
+ let (base, guarded_index) = match context.function.expressions[chain] {
+ crate::Expression::Access { base, index } => {
+ (base, Some(index::GuardedIndex::Expression(index)))
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ // Don't try to check indices into structs. Validation already took
+ // care of them, and index::needs_guard doesn't handle that case.
+ let mut base_inner = context.resolve_type(base);
+ if let crate::TypeInner::Pointer { base, .. } = *base_inner {
+ base_inner = &context.module.types[base].inner;
+ }
+ match *base_inner {
+ crate::TypeInner::Struct { .. } => (base, None),
+ _ => (base, Some(index::GuardedIndex::Known(index))),
+ }
+ }
+ _ => break,
+ };
+
+ if let Some(index) = guarded_index {
+ if let Some(length) = context.access_needs_check(base, index) {
+ if check_written {
+ write!(self.out, " && ")?;
+ } else {
+ write!(self.out, "{level}{prefix}")?;
+ check_written = true;
+ }
+
+ // Check that the index falls within bounds. Do this with a single
+ // comparison, by casting the index to `uint` first, so that negative
+ // indices become large positive values.
+ write!(self.out, "uint(")?;
+ self.put_index(index, context, true)?;
+ self.out.write_str(") < ")?;
+ match length {
+ index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
+ index::IndexableLength::Dynamic => {
+ let global = context
+ .function
+ .originating_global(base)
+ .ok_or(Error::Validation)?;
+ write!(self.out, "1 + ")?;
+ self.put_dynamic_array_max_index(global, context)?
+ }
+ }
+ }
+ }
+
+ chain = base
+ }
+
+ Ok(check_written)
+ }
+
+ /// Write the access chain `chain`.
+ ///
+ /// `chain` is a subtree of [`Access`] and [`AccessIndex`] expressions,
+ /// operating either on a pointer to a value, or on a value directly.
+ ///
+ /// Generate bounds checks code only if `policy` is [`Restrict`]. The
+ /// [`ReadZeroSkipWrite`] policy requires checks before any accesses take place, so
+ /// that must be handled in the caller.
+ ///
+ /// Handle the entire chain, recursing back into `put_expression` only for index
+ /// expressions and the base expression that originates the pointer or composite value
+ /// being accessed. This allows `put_expression` to assume that any `Access` or
+ /// `AccessIndex` expressions it sees are the top of a chain, so it can emit
+ /// `ReadZeroSkipWrite` checks.
+ ///
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
+ /// [`ReadZeroSkipWrite`]: crate::proc::index::BoundsCheckPolicy::ReadZeroSkipWrite
+ fn put_access_chain(
+ &mut self,
+ chain: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match context.function.expressions[chain] {
+ crate::Expression::Access { base, index } => {
+ let mut base_ty = context.resolve_type(base);
+
+ // Look through any pointers to see what we're really indexing.
+ if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
+ base_ty = &context.module.types[base].inner;
+ }
+
+ self.put_subscripted_access_chain(
+ base,
+ base_ty,
+ index::GuardedIndex::Expression(index),
+ policy,
+ context,
+ )?;
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ let base_resolution = &context.info[base].ty;
+ let mut base_ty = base_resolution.inner_with(&context.module.types);
+ let mut base_ty_handle = base_resolution.handle();
+
+ // Look through any pointers to see what we're really indexing.
+ if let crate::TypeInner::Pointer { base, space: _ } = *base_ty {
+ base_ty = &context.module.types[base].inner;
+ base_ty_handle = Some(base);
+ }
+
+ // Handle structs and anything else that can use `.x` syntax here, so
+ // `put_subscripted_access_chain` won't have to handle the absurd case of
+ // indexing a struct with an expression.
+ match *base_ty {
+ crate::TypeInner::Struct { .. } => {
+ let base_ty = base_ty_handle.unwrap();
+ self.put_access_chain(base, policy, context)?;
+ let name = &self.names[&NameKey::StructMember(base_ty, index)];
+ write!(self.out, ".{name}")?;
+ }
+ crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Vector { .. } => {
+ self.put_access_chain(base, policy, context)?;
+ // Prior to Metal v2.1, component access (`.x`) wasn't available
+ // for packed vectors, but array indexing is.
+ if context.get_packed_vec_kind(base).is_some() {
+ write!(self.out, "[{index}]")?;
+ } else {
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
+ }
+ }
+ _ => {
+ self.put_subscripted_access_chain(
+ base,
+ base_ty,
+ index::GuardedIndex::Known(index),
+ policy,
+ context,
+ )?;
+ }
+ }
+ }
+ _ => self.put_expression(chain, context, false)?,
+ }
+
+ Ok(())
+ }
+
+ /// Write a `[]`-style access of `base` by `index`.
+ ///
+ /// If `policy` is [`Restrict`], then generate code as needed to force all index
+ /// values within bounds.
+ ///
+ /// The `base_ty` argument must be the type we are actually indexing, like [`Array`] or
+ /// [`Vector`]. In other words, it's `base`'s type with any surrounding [`Pointer`]
+ /// removed. Our callers often already have this handy.
+ ///
+ /// This only emits `[]` expressions; it doesn't handle struct member accesses or
+ /// referencing vector components by name.
+ ///
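+ /// For example (illustrative), indexing a wrapped fixed-size array of
+ /// length 10 under [`Restrict`] produces:
+ ///
+ /// ```ignore
+ /// arr.inner[metal::min(unsigned(i), 9u)]
+ /// ```
+ ///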
+ /// [`Restrict`]: crate::proc::index::BoundsCheckPolicy::Restrict
+ /// [`Array`]: crate::TypeInner::Array
+ /// [`Vector`]: crate::TypeInner::Vector
+ /// [`Pointer`]: crate::TypeInner::Pointer
+ fn put_subscripted_access_chain(
+ &mut self,
+ base: Handle<crate::Expression>,
+ base_ty: &crate::TypeInner,
+ index: index::GuardedIndex,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let accessing_wrapped_array = matches!(
+ *base_ty,
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(_),
+ ..
+ }
+ );
+
+ self.put_access_chain(base, policy, context)?;
+ if accessing_wrapped_array {
+ write!(self.out, ".{WRAPPED_ARRAY_FIELD}")?;
+ }
+ write!(self.out, "[")?;
+
+ // Decide whether this index needs to be clamped to fall within range.
+ let restriction_needed = if policy == index::BoundsCheckPolicy::Restrict {
+ context.access_needs_check(base, index)
+ } else {
+ None
+ };
+ if let Some(limit) = restriction_needed {
+ write!(self.out, "{NAMESPACE}::min(unsigned(")?;
+ self.put_index(index, context, true)?;
+ write!(self.out, "), ")?;
+ match limit {
+ index::IndexableLength::Known(limit) => {
+ write!(self.out, "{}u", limit - 1)?;
+ }
+ index::IndexableLength::Dynamic => {
+ let global = context
+ .function
+ .originating_global(base)
+ .ok_or(Error::Validation)?;
+ self.put_dynamic_array_max_index(global, context)?;
+ }
+ }
+ write!(self.out, ")")?;
+ } else {
+ self.put_index(index, context, true)?;
+ }
+
+ write!(self.out, "]")?;
+
+ Ok(())
+ }
+
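+ /// Write a load from `pointer`, honoring its bounds check policy.
+ ///
+ /// Under `ReadZeroSkipWrite` the load becomes a ternary that falls
+ /// back to a zero value, e.g. (illustrative):
+ ///
+ /// ```ignore
+ /// uint(i) < 4 ? arr.inner[i] : DefaultConstructible()
+ /// ```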
+ fn put_load(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ is_scoped: bool,
+ ) -> BackendResult {
+ // Since access chains never cross between address spaces, we can just
+ // check the index bounds check policy once at the top.
+ let policy = context.choose_bounds_check_policy(pointer);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(
+ pointer,
+ context,
+ back::Level(0),
+ if is_scoped { "" } else { "(" },
+ )?
+ {
+ write!(self.out, " ? ")?;
+ self.put_unchecked_load(pointer, policy, context)?;
+ write!(self.out, " : DefaultConstructible()")?;
+
+ if !is_scoped {
+ write!(self.out, ")")?;
+ }
+ } else {
+ self.put_unchecked_load(pointer, policy, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_load(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ let is_atomic_pointer = context
+ .resolve_type(pointer)
+ .is_atomic_pointer(&context.module.types);
+
+ if is_atomic_pointer {
+ write!(
+ self.out,
+ "{NAMESPACE}::atomic_load_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, context)?;
+ write!(self.out, ", {NAMESPACE}::memory_order_relaxed)")?;
+ } else {
+ // We don't dereference with `*` here: pointer arguments are passed
+ // to functions as `&` references, not `*` pointers, so they need no
+ // explicit dereferencing.
+ self.put_access_chain(pointer, policy, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_return_value(
+ &mut self,
+ level: back::Level,
+ expr_handle: Handle<crate::Expression>,
+ result_struct: Option<&str>,
+ context: &ExpressionContext,
+ ) -> BackendResult {
+ match result_struct {
+ Some(struct_name) => {
+ let mut has_point_size = false;
+ let result_ty = context.function.result.as_ref().unwrap().ty;
+ match context.module.types[result_ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let tmp = "_tmp";
+ write!(self.out, "{level}const auto {tmp} = ")?;
+ self.put_expression(expr_handle, context, true)?;
+ writeln!(self.out, ";")?;
+ write!(self.out, "{level}return {struct_name} {{")?;
+
+ let mut is_first = true;
+
+ for (index, member) in members.iter().enumerate() {
+ if let Some(crate::Binding::BuiltIn(crate::BuiltIn::PointSize)) =
+ member.binding
+ {
+ has_point_size = true;
+ if !context.pipeline_options.allow_and_force_point_size {
+ continue;
+ }
+ }
+
+ let comma = if is_first { "" } else { "," };
+ is_first = false;
+ let name = &self.names[&NameKey::StructMember(result_ty, index as u32)];
+ // HACK: we are forcefully deduplicating the expression here
+ // to convert from a wrapped struct to a raw array, e.g.
+ // `float gl_ClipDistance1 [[clip_distance]] [1];`.
+ if let crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(size),
+ ..
+ } = context.module.types[member.ty].inner
+ {
+ write!(self.out, "{comma} {{")?;
+ for j in 0..size.get() {
+ if j != 0 {
+ write!(self.out, ",")?;
+ }
+ write!(self.out, "{tmp}.{name}.{WRAPPED_ARRAY_FIELD}[{j}]")?;
+ }
+ write!(self.out, "}}")?;
+ } else {
+ write!(self.out, "{comma} {tmp}.{name}")?;
+ }
+ }
+ }
+ _ => {
+ write!(self.out, "{level}return {struct_name} {{ ")?;
+ self.put_expression(expr_handle, context, true)?;
+ }
+ }
+
+ if let FunctionOrigin::EntryPoint(ep_index) = context.origin {
+ let stage = context.module.entry_points[ep_index as usize].stage;
+ if context.pipeline_options.allow_and_force_point_size
+ && stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // point size was injected and comes last
+ write!(self.out, ", 1.0")?;
+ }
+ }
+ write!(self.out, " }}")?;
+ }
+ None => {
+ write!(self.out, "{level}return ")?;
+ self.put_expression(expr_handle, context, true)?;
+ }
+ }
+ writeln!(self.out, ";")?;
+ Ok(())
+ }
+
+ /// Helper method used to find which expressions of a given function require baking
+ ///
+ /// # Notes
+ /// This function overwrites the contents of `self.need_bake_expressions`
+ fn update_expressions_to_bake(
+ &mut self,
+ func: &crate::Function,
+ info: &valid::FunctionInfo,
+ context: &ExpressionContext,
+ ) {
+ use crate::Expression;
+ self.need_bake_expressions.clear();
+
+ for (expr_handle, expr) in func.expressions.iter() {
+ // Expressions whose reference count is above the
+ // threshold should always be stored in temporaries.
+ let expr_info = &info[expr_handle];
+ let min_ref_count = func.expressions[expr_handle].bake_ref_count();
+ if min_ref_count <= expr_info.ref_count {
+ self.need_bake_expressions.insert(expr_handle);
+ } else {
+ match expr_info.ty {
+ // force ray desc to be baked: it's used multiple times internally
+ TypeResolution::Handle(h)
+ if Some(h) == context.module.special_types.ray_desc =>
+ {
+ self.need_bake_expressions.insert(expr_handle);
+ }
+ _ => {}
+ }
+ }
+
+ if let Expression::Math { fun, arg, arg1, .. } = *expr {
+ match fun {
+ crate::MathFunction::Dot => {
+ // WGSL's `dot` function works on any `vecN` type, but Metal's only
+ // works on floating-point vectors, so we emit inline code for
+ // integer vector `dot` calls. But that code uses each argument `N`
+ // times, once for each component (see `put_dot_product`), so to
+ // avoid duplicated evaluation, we must bake integer operands.
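+ // (For an `int3` dot, for instance, `put_dot_product` expands to
+ // roughly `a.x * b.x + a.y * b.y + a.z * b.z`.)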
+
+ // check what kind of product this is depending
+ // on the resolve type of the Dot function itself
+ let inner = context.resolve_type(expr_handle);
+ if let crate::TypeInner::Scalar(scalar) = *inner {
+ match scalar.kind {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
+ self.need_bake_expressions.insert(arg);
+ self.need_bake_expressions.insert(arg1.unwrap());
+ }
+ _ => {}
+ }
+ }
+ }
+ crate::MathFunction::FindMsb => {
+ self.need_bake_expressions.insert(arg);
+ }
+ crate::MathFunction::Sign => {
+ // WGSL's `sign` function also works on signed ints, but Metal's only
+ // works on floating-point values, so we emit inline code for integer `sign`
+ // calls. But that code uses each argument 2 times (see `put_isign`),
+ // so to avoid duplicated evaluation, we must bake the argument.
+ let inner = context.resolve_type(expr_handle);
+ if inner.scalar_kind() == Some(crate::ScalarKind::Sint) {
+ self.need_bake_expressions.insert(arg);
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+ }
+
+ fn start_baking_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ context: &ExpressionContext,
+ name: &str,
+ ) -> BackendResult {
+ match context.info[handle].ty {
+ TypeResolution::Handle(ty_handle) => {
+ let ty_name = TypeContext {
+ handle: ty_handle,
+ gctx: context.module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name}")?;
+ }
+ TypeResolution::Value(crate::TypeInner::Scalar(scalar)) => {
+ put_numeric_type(&mut self.out, scalar, &[])?;
+ }
+ TypeResolution::Value(crate::TypeInner::Vector { size, scalar }) => {
+ put_numeric_type(&mut self.out, scalar, &[size])?;
+ }
+ TypeResolution::Value(crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ }) => {
+ put_numeric_type(&mut self.out, scalar, &[rows, columns])?;
+ }
+ TypeResolution::Value(ref other) => {
+ log::warn!("Type {:?} isn't a known local", other); //TEMP!
+ return Err(Error::FeatureNotImplemented("weird local type".to_string()));
+ }
+ }
+
+ //TODO: figure out the naming scheme that wouldn't collide with user names.
+ write!(self.out, " {name} = ")?;
+
+ Ok(())
+ }
+
+ /// Cache a clamped level of detail value, if necessary.
+ ///
+ /// [`ImageLoad`] accesses covered by [`BoundsCheckPolicy::Restrict`] use a
+ /// properly clamped level of detail value both in the access itself, and
+ /// for fetching the size of the requested MIP level, needed to clamp the
+ /// coordinates. To avoid recomputing this clamped level of detail, we cache
+ /// it in a temporary variable, as part of the [`Emit`] statement covering
+ /// the [`ImageLoad`] expression.
+ ///
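+ /// The cached value comes out as, e.g. (illustrative; the actual name
+ /// is `CLAMPED_LOD_LOAD_PREFIX` plus the expression index):
+ ///
+ /// ```ignore
+ /// uint clamped_lod_e15 = metal::min(uint(lod), tex.get_num_mip_levels() - 1);
+ /// ```
+ ///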
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`BoundsCheckPolicy::Restrict`]: index::BoundsCheckPolicy::Restrict
+ /// [`Emit`]: crate::Statement::Emit
+ fn put_cache_restricted_level(
+ &mut self,
+ load: Handle<crate::Expression>,
+ image: Handle<crate::Expression>,
+ mip_level: Option<Handle<crate::Expression>>,
+ indent: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ // Does this image access actually require (or even permit) a
+ // level-of-detail, and does the policy require us to restrict it?
+ let level_of_detail = match mip_level {
+ Some(level) => level,
+ None => return Ok(()),
+ };
+
+ if context.expression.policies.image_load != index::BoundsCheckPolicy::Restrict
+ || !context.expression.image_needs_lod(image)
+ {
+ return Ok(());
+ }
+
+ write!(
+ self.out,
+ "{}uint {}{} = ",
+ indent,
+ CLAMPED_LOD_LOAD_PREFIX,
+ load.index(),
+ )?;
+ self.put_restricted_scalar_image_index(
+ image,
+ level_of_detail,
+ "get_num_mip_levels",
+ &context.expression,
+ )?;
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ fn put_block(
+ &mut self,
+ level: back::Level,
+ statements: &[crate::Statement],
+ context: &StatementContext,
+ ) -> BackendResult {
+ // Add to the set in order to track the stack size.
+ #[cfg(test)]
+ #[allow(trivial_casts)]
+ self.put_block_stack_pointers
+ .insert(&level as *const _ as *const ());
+
+ for statement in statements {
+ log::trace!("statement[{}] {:?}", level.0, statement);
+ match *statement {
+ crate::Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ // `ImageLoad` expressions covered by the `Restrict` bounds check policy
+ // may need to cache a clamped version of their level-of-detail argument.
+ if let crate::Expression::ImageLoad {
+ image,
+ level: mip_level,
+ ..
+ } = context.expression.function.expressions[handle]
+ {
+ self.put_cache_restricted_level(
+ handle, image, mip_level, level, context,
+ )?;
+ }
+
+ let ptr_class = context.expression.resolve_type(handle).pointer_space();
+ let expr_name = if ptr_class.is_some() {
+ None // don't bake pointer expressions (just yet)
+ } else if let Some(name) =
+ context.expression.function.named_expressions.get(&handle)
+ {
+ // The `crate::Function::named_expressions` table holds
+ // expressions that should be saved in temporaries once they
+ // are `Emit`ted. We only add them to `self.named_expressions`
+ // when we reach the `Emit` that covers them, so that we don't
+ // try to use their names before we've actually initialized
+ // the temporary that holds them.
+ //
+ // Don't assume the names in `named_expressions` are unique,
+ // or even valid. Use the `Namer`.
+ Some(self.namer.call(name))
+ } else {
+ // If this expression is an index that we're going to first compare
+ // against a limit, and then actually use as an index, then we may
+ // want to cache it in a temporary, to avoid evaluating it twice.
+ let bake = context.expression.guarded_indices.contains(handle.index())
+ || self.need_bake_expressions.contains(&handle);
+
+ if bake {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else {
+ None
+ }
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.start_baking_expression(handle, &context.expression, &name)?;
+ self.put_expression(handle, &context.expression, true)?;
+ self.named_expressions.insert(handle, name);
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+ crate::Statement::Block(ref block) => {
+ if !block.is_empty() {
+ writeln!(self.out, "{level}{{")?;
+ self.put_block(level.next(), block, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}if (")?;
+ self.put_expression(condition, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ self.put_block(level.next(), accept, context)?;
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+ self.put_block(level.next(), reject, context)?;
+ }
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ write!(self.out, "{level}switch(")?;
+ self.put_expression(selector, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ let lcase = level.next();
+ for case in cases.iter() {
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ write!(self.out, "{lcase}case {value}:")?;
+ }
+ crate::SwitchValue::U32(value) => {
+ write!(self.out, "{lcase}case {value}u:")?;
+ }
+ crate::SwitchValue::Default => {
+ write!(self.out, "{lcase}default:")?;
+ }
+ }
+
+ let write_block_braces = !(case.fall_through && case.body.is_empty());
+ if write_block_braces {
+ writeln!(self.out, " {{")?;
+ } else {
+ writeln!(self.out)?;
+ }
+
+ self.put_block(lcase.next(), &case.body, context)?;
+ if !case.fall_through
+ && case.body.last().map_or(true, |s| !s.is_terminator())
+ {
+ writeln!(self.out, "{}break;", lcase.next())?;
+ }
+
+ if write_block_braces {
+ writeln!(self.out, "{lcase}}}")?;
+ }
+ }
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
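+ // When there is a continuing block or a break-if condition,
+ // lower the loop to a gated `while(true)`, e.g. (illustrative):
+ //
+ // bool loop_init = true;
+ // while(true) {
+ // if (!loop_init) {
+ // <continuing>
+ // if (<break_if>) { break; }
+ // }
+ // loop_init = false;
+ // <body>
+ // }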
+ if !continuing.is_empty() || break_if.is_some() {
+ let gate_name = self.namer.call("loop_init");
+ writeln!(self.out, "{level}bool {gate_name} = true;")?;
+ writeln!(self.out, "{level}while(true) {{")?;
+ let lif = level.next();
+ let lcontinuing = lif.next();
+ writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
+ self.put_block(lcontinuing, continuing, context)?;
+ if let Some(condition) = break_if {
+ write!(self.out, "{lcontinuing}if (")?;
+ self.put_expression(condition, &context.expression, true)?;
+ writeln!(self.out, ") {{")?;
+ writeln!(self.out, "{}break;", lcontinuing.next())?;
+ writeln!(self.out, "{lcontinuing}}}")?;
+ }
+ writeln!(self.out, "{lif}}}")?;
+ writeln!(self.out, "{lif}{gate_name} = false;")?;
+ } else {
+ writeln!(self.out, "{level}while(true) {{")?;
+ }
+ self.put_block(level.next(), body, context)?;
+ writeln!(self.out, "{level}}}")?;
+ }
+ crate::Statement::Break => {
+ writeln!(self.out, "{level}break;")?;
+ }
+ crate::Statement::Continue => {
+ writeln!(self.out, "{level}continue;")?;
+ }
+ crate::Statement::Return {
+ value: Some(expr_handle),
+ } => {
+ self.put_return_value(
+ level,
+ expr_handle,
+ context.result_struct,
+ &context.expression,
+ )?;
+ }
+ crate::Statement::Return { value: None } => {
+ writeln!(self.out, "{level}return;")?;
+ }
+ crate::Statement::Kill => {
+ writeln!(self.out, "{level}{NAMESPACE}::discard_fragment();")?;
+ }
+ crate::Statement::Barrier(flags) => {
+ self.write_barrier(flags, level)?;
+ }
+ crate::Statement::Store { pointer, value } => {
+ self.put_store(pointer, value, level, context)?
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ let address = TexelAddress {
+ coordinate,
+ array_index,
+ sample: None,
+ level: None,
+ };
+ self.put_image_store(level, image, &address, value, context)?
+ }
+ crate::Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ self.start_baking_expression(expr, &context.expression, &name)?;
+ self.named_expressions.insert(expr, name);
+ }
+ let fun_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{fun_name}(")?;
+ // first, write down the actual arguments
+ for (i, &handle) in arguments.iter().enumerate() {
+ if i != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.put_expression(handle, &context.expression, true)?;
+ }
+ // follow-up with any global resources used
+ let mut separate = !arguments.is_empty();
+ let fun_info = &context.expression.mod_info[function];
+ let mut supports_array_length = false;
+ for (handle, var) in context.expression.module.global_variables.iter() {
+ if fun_info[handle].is_empty() {
+ continue;
+ }
+ if var.space.needs_pass_through() {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ if separate {
+ write!(self.out, ", ")?;
+ } else {
+ separate = true;
+ }
+ write!(self.out, "{name}")?;
+ }
+ supports_array_length |=
+ needs_array_length(var.ty, &context.expression.module.types);
+ }
+ if supports_array_length {
+ if separate {
+ write!(self.out, ", ")?;
+ }
+ write!(self.out, "_buffer_sizes")?;
+ }
+
+ // done
+ writeln!(self.out, ");")?;
+ }
+ crate::Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_baking_expression(result, &context.expression, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+ match *fun {
+ crate::AtomicFunction::Add => {
+ self.put_atomic_fetch(pointer, "add", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Subtract => {
+ self.put_atomic_fetch(pointer, "sub", value, &context.expression)?;
+ }
+ crate::AtomicFunction::And => {
+ self.put_atomic_fetch(pointer, "and", value, &context.expression)?;
+ }
+ crate::AtomicFunction::InclusiveOr => {
+ self.put_atomic_fetch(pointer, "or", value, &context.expression)?;
+ }
+ crate::AtomicFunction::ExclusiveOr => {
+ self.put_atomic_fetch(pointer, "xor", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Min => {
+ self.put_atomic_fetch(pointer, "min", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Max => {
+ self.put_atomic_fetch(pointer, "max", value, &context.expression)?;
+ }
+ crate::AtomicFunction::Exchange { compare: None } => {
+ self.put_atomic_operation(
+ pointer,
+ "exchange",
+ "",
+ value,
+ &context.expression,
+ )?;
+ }
+ crate::AtomicFunction::Exchange { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "atomic CompareExchange".to_string(),
+ ));
+ }
+ }
+ // done
+ writeln!(self.out, ";")?;
+ }
+ crate::Statement::WorkGroupUniformLoad { pointer, result } => {
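+ // Bracket the load with workgroup barriers: the first makes
+ // prior writes visible, the second keeps any invocation from
+ // racing ahead and overwriting the value before everyone has
+ // read it.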
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+
+ write!(self.out, "{level}")?;
+ let name = self.namer.call("");
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.put_load(pointer, &context.expression, true)?;
+ self.named_expressions.insert(result, name);
+
+ writeln!(self.out, ";")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
+ }
+ crate::Statement::RayQuery { query, ref fun } => {
+ if context.expression.lang_version < (2, 4) {
+ return Err(Error::UnsupportedRayTracing);
+ }
+
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ //TODO: how to deal with winding?
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?;
+ {
+ let f_opaque = back::RayFlag::CULL_OPAQUE.bits();
+ let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode(("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?;
+ writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?;
+ }
+ {
+ let f_opaque = back::RayFlag::OPAQUE.bits();
+ let f_no_opaque = back::RayFlag::NO_OPAQUE.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?;
+ writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?;
+ }
+ {
+ let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits();
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection(("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ writeln!(self.out, ".flags & {flag}) != 0);")?;
+ }
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?;
+ self.put_expression(query, &context.expression, true)?;
+ write!(
+ self.out,
+ ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray("
+ )?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".origin, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".dir, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".tmin, ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".tmax), ")?;
+ self.put_expression(acceleration_structure, &context.expression, true)?;
+ write!(self.out, ", ")?;
+ self.put_expression(descriptor, &context.expression, true)?;
+ write!(self.out, ".cull_mask);")?;
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?;
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ write!(self.out, "{level}")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.named_expressions.insert(result, name);
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?;
+ //TODO: actually proceed?
+
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?;
+ }
+ crate::RayQueryFunction::Terminate => {
+ write!(self.out, "{level}")?;
+ self.put_expression(query, &context.expression, true)?;
+ writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?;
+ }
+ }
+ }
+ }
+ }
+
+ // un-emit expressions
+ //TODO: take care of loop/continuing?
+ for statement in statements {
+ if let crate::Statement::Emit(ref range) = *statement {
+ for handle in range.clone() {
+ self.named_expressions.remove(&handle);
+ }
+ }
+ }
+ Ok(())
+ }
+
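+ /// Write a store of `value` through `pointer`, honoring its bounds
+ /// check policy.
+ ///
+ /// Under `ReadZeroSkipWrite` the store is wrapped in a guard, e.g.
+ /// (illustrative):
+ ///
+ /// ```ignore
+ /// if (uint(i) < 4) {
+ /// arr.inner[i] = value;
+ /// }
+ /// ```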
+ fn put_store(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ level: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ let policy = context.expression.choose_bounds_check_policy(pointer);
+ if policy == index::BoundsCheckPolicy::ReadZeroSkipWrite
+ && self.put_bounds_checks(pointer, &context.expression, level, "if (")?
+ {
+ writeln!(self.out, ") {{")?;
+ self.put_unchecked_store(pointer, value, policy, level.next(), context)?;
+ writeln!(self.out, "{level}}}")?;
+ } else {
+ self.put_unchecked_store(pointer, value, policy, level, context)?;
+ }
+
+ Ok(())
+ }
+
+ fn put_unchecked_store(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ policy: index::BoundsCheckPolicy,
+ level: back::Level,
+ context: &StatementContext,
+ ) -> BackendResult {
+ let is_atomic_pointer = context
+ .expression
+ .resolve_type(pointer)
+ .is_atomic_pointer(&context.expression.module.types);
+
+ if is_atomic_pointer {
+ write!(
+ self.out,
+ "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
+ )?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, ", ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ", {NAMESPACE}::memory_order_relaxed);")?;
+ } else {
+ write!(self.out, "{level}")?;
+ self.put_access_chain(pointer, policy, &context.expression)?;
+ write!(self.out, " = ")?;
+ self.put_expression(value, &context.expression, true)?;
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ pub fn write(
+ &mut self,
+ module: &crate::Module,
+ info: &valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+ ) -> Result<TranslationInfo, Error> {
+ self.names.clear();
+ self.namer.reset(
+ module,
+ super::keywords::RESERVED,
+ &[],
+ &[],
+ &[CLAMPED_LOD_LOAD_PREFIX],
+ &mut self.names,
+ );
+ self.struct_member_pads.clear();
+
+ writeln!(
+ self.out,
+ "// language: metal{}.{}",
+ options.lang_version.0, options.lang_version.1
+ )?;
+ writeln!(self.out, "#include <metal_stdlib>")?;
+ writeln!(self.out, "#include <simd/simd.h>")?;
+ writeln!(self.out)?;
+ // Work around Metal bug where `uint` is not available by default
+ writeln!(self.out, "using {NAMESPACE}::uint;")?;
+
+ let mut uses_ray_query = false;
+ for (_, ty) in module.types.iter() {
+ match ty.inner {
+ crate::TypeInner::AccelerationStructure => {
+ if options.lang_version < (2, 4) {
+ return Err(Error::UnsupportedRayTracing);
+ }
+ }
+ crate::TypeInner::RayQuery => {
+ if options.lang_version < (2, 4) {
+ return Err(Error::UnsupportedRayTracing);
+ }
+ uses_ray_query = true;
+ }
+ _ => (),
+ }
+ }
+
+ if module.special_types.ray_desc.is_some()
+ || module.special_types.ray_intersection.is_some()
+ {
+ if options.lang_version < (2, 4) {
+ return Err(Error::UnsupportedRayTracing);
+ }
+ }
+
+ if uses_ray_query {
+ self.put_ray_query_type()?;
+ }
+
+ if options
+ .bounds_check_policies
+ .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
+ {
+ self.put_default_constructible()?;
+ }
+ writeln!(self.out)?;
+
+ {
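+ // Collect the globals whose dynamically sized arrays need a runtime
+ // length, and emit a `_mslBufferSizes` struct with one `uint sizeN`
+ // field per such global; functions and entry points that need it
+ // receive this struct as the `_buffer_sizes` argument.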
+ let mut indices = vec![];
+ for (handle, var) in module.global_variables.iter() {
+ if needs_array_length(var.ty, &module.types) {
+ let idx = handle.index();
+ indices.push(idx);
+ }
+ }
+
+ if !indices.is_empty() {
+ writeln!(self.out, "struct _mslBufferSizes {{")?;
+
+ for idx in indices {
+ writeln!(self.out, "{}uint size{};", back::INDENT, idx)?;
+ }
+
+ writeln!(self.out, "}};")?;
+ writeln!(self.out)?;
+ }
+ };
+
+ self.write_type_defs(module)?;
+ self.write_global_constants(module, info)?;
+ self.write_functions(module, info, options, pipeline_options)
+ }
+
+ /// Write the definition for the `DefaultConstructible` class.
+ ///
+ /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
+ /// produce 'zero' values for any type, including structs, arrays, and so
+ /// on. We could do this by emitting default constructor applications, but
+ /// that would entail printing the name of the type, which is more trouble
+ /// than you'd think. Instead, we just construct this magic C++14 class that
+ /// can be converted to any type that can be default constructed, using
+ /// template parameter inference to detect which type is needed, so we don't
+ /// have to figure out the name.
+ ///
+ /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
+ fn put_default_constructible(&mut self) -> BackendResult {
+ let tab = back::INDENT;
+ writeln!(self.out, "struct DefaultConstructible {{")?;
+ writeln!(self.out, "{tab}template<typename T>")?;
+ writeln!(self.out, "{tab}operator T() && {{")?;
+ writeln!(self.out, "{tab}{tab}return T {{}};")?;
+ writeln!(self.out, "{tab}}}")?;
+ writeln!(self.out, "}};")?;
+ Ok(())
+ }
+
+ fn put_ray_query_type(&mut self) -> BackendResult {
+ let tab = back::INDENT;
+ writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?;
+ let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>");
+ writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?;
+ writeln!(
+ self.out,
+ "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};"
+ )?;
+ writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?;
+ writeln!(self.out, "}};")?;
+ writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?;
+ let v_triangle = back::RayIntersectionType::Triangle as u32;
+ let v_bbox = back::RayIntersectionType::BoundingBox as u32;
+ writeln!(
+ self.out,
+ "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : "
+ )?;
+ writeln!(
+ self.out,
+ "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;"
+ )?;
+ writeln!(self.out, "}}")?;
+ Ok(())
+ }
+
+ fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
+ for (handle, ty) in module.types.iter() {
+ if !ty.needs_alias() {
+ continue;
+ }
+ let name = &self.names[&NameKey::Type(handle)];
+ match ty.inner {
+ // Naga IR can pass around arrays by value, but Metal, following
+ // C++, performs an array-to-pointer conversion (C++ [conv.array])
+ // on expressions of array type, so assigning the array by value
+ // isn't possible. However, Metal *does* assign structs by
+ // value. So in our Metal output, we wrap all array types in
+ // synthetic struct types:
+ //
+ // struct type1 {
+ // float inner[10]
+ // };
+ //
+ // Then we carefully include `.inner` (`WRAPPED_ARRAY_FIELD`) in
+ // any expression that actually wants access to the array.
+ crate::TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ let base_name = TypeContext {
+ handle: base,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+
+ match size {
+ crate::ArraySize::Constant(size) => {
+ writeln!(self.out, "struct {name} {{")?;
+ writeln!(
+ self.out,
+ "{}{} {}[{}];",
+ back::INDENT,
+ base_name,
+ WRAPPED_ARRAY_FIELD,
+ size
+ )?;
+ writeln!(self.out, "}};")?;
+ }
+ crate::ArraySize::Dynamic => {
+ writeln!(self.out, "typedef {base_name} {name}[1];")?;
+ }
+ }
+ }
+ crate::TypeInner::Struct {
+ ref members, span, ..
+ } => {
+ writeln!(self.out, "struct {name} {{")?;
+ let mut last_offset = 0;
+ for (index, member) in members.iter().enumerate() {
+ if member.offset > last_offset {
+ self.struct_member_pads.insert((handle, index as u32));
+ let pad = member.offset - last_offset;
+ writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?;
+ }
+ let ty_inner = &module.types[member.ty].inner;
+ last_offset = member.offset + ty_inner.size(module.to_ctx());
+
+ let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
+
+ // If the member should be packed (as is the case for a misaligned vec3), emit a packed vector type
+ match should_pack_struct_member(members, span, index, module) {
+ Some(scalar) => {
+ writeln!(
+ self.out,
+ "{}{}::packed_{}3 {};",
+ back::INDENT,
+ NAMESPACE,
+ scalar.to_msl_name(),
+ member_name
+ )?;
+ }
+ None => {
+ let base_name = TypeContext {
+ handle: member.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ writeln!(
+ self.out,
+ "{}{} {};",
+ back::INDENT,
+ base_name,
+ member_name
+ )?;
+
+ // 3-component vectors are padded to 4 components, so
+ // account for the extra scalar in the offset
+ if let crate::TypeInner::Vector {
+ size: crate::VectorSize::Tri,
+ scalar,
+ } = *ty_inner
+ {
+ last_offset += scalar.width as u32;
+ }
+ }
+ }
+ }
+ writeln!(self.out, "}};")?;
+ }
+ _ => {
+ let ty_name = TypeContext {
+ handle,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: true,
+ };
+ writeln!(self.out, "typedef {ty_name} {name};")?;
+ }
+ }
+ }
+
+ // Write functions to create special types.
+ for (type_key, struct_ty) in module.special_types.predeclared_types.iter() {
+ match type_key {
+ &crate::PredeclaredType::ModfResult { size, width }
+ | &crate::PredeclaredType::FrexpResult { size, width } => {
+ let arg_type_name_owner;
+ let arg_type_name = if let Some(size) = size {
+ arg_type_name_owner = format!(
+ "{NAMESPACE}::{}{}",
+ if width == 8 { "double" } else { "float" },
+ size as u8
+ );
+ &arg_type_name_owner
+ } else if width == 8 {
+ "double"
+ } else {
+ "float"
+ };
+
+ let other_type_name_owner;
+ let (defined_func_name, called_func_name, other_type_name) =
+ if matches!(type_key, &crate::PredeclaredType::ModfResult { .. }) {
+ (MODF_FUNCTION, "modf", arg_type_name)
+ } else {
+ let other_type_name = if let Some(size) = size {
+ other_type_name_owner = format!("int{}", size as u8);
+ &other_type_name_owner
+ } else {
+ "int"
+ };
+ (FREXP_FUNCTION, "frexp", other_type_name)
+ };
+
+ let struct_name = &self.names[&NameKey::Type(*struct_ty)];
+
+ writeln!(self.out)?;
+ writeln!(
+ self.out,
+ "{} {defined_func_name}({arg_type_name} arg) {{
+ {other_type_name} other;
+ {arg_type_name} fract = {NAMESPACE}::{called_func_name}(arg, other);
+ return {}{{ fract, other }};
+}}",
+ struct_name, struct_name
+ )?;
+ }
+ &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Writes all named constants
+ fn write_global_constants(
+ &mut self,
+ module: &crate::Module,
+ mod_info: &valid::ModuleInfo,
+ ) -> BackendResult {
+ let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
+
+ for (handle, constant) in constants {
+ let ty_name = TypeContext {
+ handle: constant.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let name = &self.names[&NameKey::Constant(handle)];
+ write!(self.out, "constant {ty_name} {name} = ")?;
+ self.put_const_expression(constant.init, module, mod_info)?;
+ writeln!(self.out, ";")?;
+ }
+
+ Ok(())
+ }
+
+ fn put_inline_sampler_properties(
+ &mut self,
+ level: back::Level,
+ sampler: &sm::InlineSampler,
+ ) -> BackendResult {
+ for (&letter, address) in ['s', 't', 'r'].iter().zip(sampler.address.iter()) {
+ writeln!(
+ self.out,
+ "{}{}::{}_address::{},",
+ level,
+ NAMESPACE,
+ letter,
+ address.as_str(),
+ )?;
+ }
+ writeln!(
+ self.out,
+ "{}{}::mag_filter::{},",
+ level,
+ NAMESPACE,
+ sampler.mag_filter.as_str(),
+ )?;
+ writeln!(
+ self.out,
+ "{}{}::min_filter::{},",
+ level,
+ NAMESPACE,
+ sampler.min_filter.as_str(),
+ )?;
+ if let Some(filter) = sampler.mip_filter {
+ writeln!(
+ self.out,
+ "{}{}::mip_filter::{},",
+ level,
+ NAMESPACE,
+ filter.as_str(),
+ )?;
+ }
+ // avoid setting it on platforms that don't support it
+ if sampler.border_color != sm::BorderColor::TransparentBlack {
+ writeln!(
+ self.out,
+ "{}{}::border_color::{},",
+ level,
+ NAMESPACE,
+ sampler.border_color.as_str(),
+ )?;
+ }
+ //TODO: I'm not able to feed this in a way that MSL likes:
+ //>error: use of undeclared identifier 'lod_clamp'
+ //>error: no member named 'max_anisotropy' in namespace 'metal'
+ if false {
+ if let Some(ref lod) = sampler.lod_clamp {
+ writeln!(self.out, "{}lod_clamp({},{}),", level, lod.start, lod.end,)?;
+ }
+ if let Some(aniso) = sampler.max_anisotropy {
+ writeln!(self.out, "{}max_anisotropy({}),", level, aniso.get(),)?;
+ }
+ }
+ if sampler.compare_func != sm::CompareFunc::Never {
+ writeln!(
+ self.out,
+ "{}{}::compare_func::{},",
+ level,
+ NAMESPACE,
+ sampler.compare_func.as_str(),
+ )?;
+ }
+ writeln!(
+ self.out,
+ "{}{}::coord::{}",
+ level,
+ NAMESPACE,
+ sampler.coord.as_str()
+ )?;
+ Ok(())
+ }
+
+ // Returns the array of mapped entry point names.
+ fn write_functions(
+ &mut self,
+ module: &crate::Module,
+ mod_info: &valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: &PipelineOptions,
+ ) -> Result<TranslationInfo, Error> {
+ let mut pass_through_globals = Vec::new();
+ for (fun_handle, fun) in module.functions.iter() {
+ log::trace!(
+ "function {:?}, handle {:?}",
+ fun.name.as_deref().unwrap_or("(anonymous)"),
+ fun_handle
+ );
+
+ let fun_info = &mod_info[fun_handle];
+ pass_through_globals.clear();
+ let mut supports_array_length = false;
+ for (handle, var) in module.global_variables.iter() {
+ if !fun_info[handle].is_empty() {
+ if var.space.needs_pass_through() {
+ pass_through_globals.push(handle);
+ }
+ supports_array_length |= needs_array_length(var.ty, &module.types);
+ }
+ }
+
+ writeln!(self.out)?;
+ let fun_name = &self.names[&NameKey::Function(fun_handle)];
+ match fun.result {
+ Some(ref result) => {
+ let ty_name = TypeContext {
+ handle: result.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{ty_name}")?;
+ }
+ None => {
+ write!(self.out, "void")?;
+ }
+ }
+ writeln!(self.out, " {fun_name}(")?;
+
+ for (index, arg) in fun.arguments.iter().enumerate() {
+ let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)];
+ let param_type_name = TypeContext {
+ handle: arg.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let separator = separate(
+ !pass_through_globals.is_empty()
+ || index + 1 != fun.arguments.len()
+ || supports_array_length,
+ );
+ writeln!(
+ self.out,
+ "{}{} {}{}",
+ back::INDENT,
+ param_type_name,
+ name,
+ separator
+ )?;
+ }
+ for (index, &handle) in pass_through_globals.iter().enumerate() {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage: fun_info[handle],
+ binding: None,
+ reference: true,
+ };
+ let separator =
+ separate(index + 1 != pass_through_globals.len() || supports_array_length);
+ write!(self.out, "{}", back::INDENT)?;
+ tyvar.try_fmt(&mut self.out)?;
+ writeln!(self.out, "{separator}")?;
+ }
+
+ if supports_array_length {
+ writeln!(
+ self.out,
+ "{}constant _mslBufferSizes& _buffer_sizes",
+ back::INDENT
+ )?;
+ }
+
+ writeln!(self.out, ") {{")?;
+
+ let guarded_indices =
+ index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
+
+ let context = StatementContext {
+ expression: ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::Handle(fun_handle),
+ info: fun_info,
+ lang_version: options.lang_version,
+ policies: options.bounds_check_policies,
+ guarded_indices,
+ module,
+ mod_info,
+ pipeline_options,
+ },
+ result_struct: None,
+ };
+
+ for (local_handle, local) in fun.local_variables.iter() {
+ let ty_name = TypeContext {
+ handle: local.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)];
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, local_name)?;
+ match local.init {
+ Some(value) => {
+ write!(self.out, " = ")?;
+ self.put_expression(value, &context.expression, true)?;
+ }
+ None => {
+ write!(self.out, " = {{}}")?;
+ }
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ self.update_expressions_to_bake(fun, fun_info, &context.expression);
+ self.put_block(back::Level(1), &fun.body, &context)?;
+ writeln!(self.out, "}}")?;
+ self.named_expressions.clear();
+ }
+
+ let mut info = TranslationInfo {
+ entry_point_names: Vec::with_capacity(module.entry_points.len()),
+ };
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let fun = &ep.function;
+ let fun_info = mod_info.get_entry_point(ep_index);
+ let mut ep_error = None;
+
+ log::trace!(
+ "entry point {:?}, index {:?}",
+ fun.name.as_deref().unwrap_or("(anonymous)"),
+ ep_index
+ );
+
+ // Is any global variable used by this entry point dynamically sized?
+ let supports_array_length = module
+ .global_variables
+ .iter()
+ .filter(|&(handle, _)| !fun_info[handle].is_empty())
+ .any(|(_, var)| needs_array_length(var.ty, &module.types));
+
+ // skip this entry point if any global bindings are missing,
+ // or their types are incompatible.
+ if !options.fake_missing_bindings {
+ for (var_handle, var) in module.global_variables.iter() {
+ if fun_info[var_handle].is_empty() {
+ continue;
+ }
+ match var.space {
+ crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::Handle => {
+ let br = match var.binding {
+ Some(ref br) => br,
+ None => {
+ let var_name = var.name.clone().unwrap_or_default();
+ ep_error =
+ Some(super::EntryPointError::MissingBinding(var_name));
+ break;
+ }
+ };
+ let target = options.get_resource_binding_target(ep, br);
+ let good = match target {
+ Some(target) => {
+ let binding_ty = match module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => {
+ &module.types[base].inner
+ }
+ ref ty => ty,
+ };
+ match *binding_ty {
+ crate::TypeInner::Image { .. } => target.texture.is_some(),
+ crate::TypeInner::Sampler { .. } => {
+ target.sampler.is_some()
+ }
+ _ => target.buffer.is_some(),
+ }
+ }
+ None => false,
+ };
+ if !good {
+ ep_error =
+ Some(super::EntryPointError::MissingBindTarget(br.clone()));
+ break;
+ }
+ }
+ crate::AddressSpace::PushConstant => {
+ if let Err(e) = options.resolve_push_constants(ep) {
+ ep_error = Some(e);
+ break;
+ }
+ }
+ crate::AddressSpace::Function
+ | crate::AddressSpace::Private
+ | crate::AddressSpace::WorkGroup => {}
+ }
+ }
+ if supports_array_length {
+ if let Err(err) = options.resolve_sizes_buffer(ep) {
+ ep_error = Some(err);
+ }
+ }
+ }
+
+ if let Some(err) = ep_error {
+ info.entry_point_names.push(Err(err));
+ continue;
+ }
+ let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
+ info.entry_point_names.push(Ok(fun_name.clone()));
+
+ writeln!(self.out)?;
+
+ let (em_str, in_mode, out_mode) = match ep.stage {
+ crate::ShaderStage::Vertex => (
+ "vertex",
+ LocationMode::VertexInput,
+ LocationMode::VertexOutput,
+ ),
+ crate::ShaderStage::Fragment { .. } => (
+ "fragment",
+ LocationMode::FragmentInput,
+ LocationMode::FragmentOutput,
+ ),
+ crate::ShaderStage::Compute { .. } => {
+ ("kernel", LocationMode::Uniform, LocationMode::Uniform)
+ }
+ };
+
+ // Since `Namer.reset` wasn't expecting struct members to be
+ // suddenly injected into another namespace like this,
+ // `self.names` doesn't keep them distinct from other variables.
+ // Generate fresh names for these arguments, and remember the
+ // mapping.
+ let mut flattened_member_names = FastHashMap::default();
+ // Varyings' members get their own namespace
+ let mut varyings_namer = crate::proc::Namer::default();
+
+ // List all the Naga `EntryPoint`'s `Function`'s arguments,
+ // flattening structs into their members. In Metal, we will pass
+ // each of these values to the entry point as a separate argument—
+ // except for the varyings, handled next.
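+ // As an illustration (argument names are hypothetical): an argument
+ // `in: VertexInput` whose struct has members `pos` and `uv` yields two
+ // `NameKey::StructMember` entries here, while a plain `index: u32`
+ // argument yields a single `NameKey::EntryPointArgument` entry.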
+ let mut flattened_arguments = Vec::new();
+ for (arg_index, arg) in fun.arguments.iter().enumerate() {
+ match module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (member_index, member) in members.iter().enumerate() {
+ let member_index = member_index as u32;
+ flattened_arguments.push((
+ NameKey::StructMember(arg.ty, member_index),
+ member.ty,
+ member.binding.as_ref(),
+ ));
+ let name_key = NameKey::StructMember(arg.ty, member_index);
+ let name = match member.binding {
+ Some(crate::Binding::Location { .. }) => {
+ varyings_namer.call(&self.names[&name_key])
+ }
+ _ => self.namer.call(&self.names[&name_key]),
+ };
+ flattened_member_names.insert(name_key, name);
+ }
+ }
+ _ => flattened_arguments.push((
+ NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
+ arg.ty,
+ arg.binding.as_ref(),
+ )),
+ }
+ }
+
+ // Identify the varyings among the argument values, and emit a
+ // struct type named `<fun>Input` to hold them.
+ let stage_in_name = format!("{fun_name}Input");
+ let varyings_member_name = self.namer.call("varyings");
+ let mut has_varyings = false;
+ if !flattened_arguments.is_empty() {
+ writeln!(self.out, "struct {stage_in_name} {{")?;
+ for &(ref name_key, ty, binding) in flattened_arguments.iter() {
+ let binding = match binding {
+ Some(ref binding @ &crate::Binding::Location { .. }) => binding,
+ _ => continue,
+ };
+ has_varyings = true;
+ let name = match *name_key {
+ NameKey::StructMember(..) => &flattened_member_names[name_key],
+ _ => &self.names[name_key],
+ };
+ let ty_name = TypeContext {
+ handle: ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let resolved = options.resolve_local_binding(binding, in_mode)?;
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out, ";")?;
+ }
+ writeln!(self.out, "}};")?;
+ }
+
+ // If the function returns a value, define a struct type to hold it,
+ // named `<fun>Output`.
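+ // As a sketch (names assumed, not verbatim output), a vertex entry
+ // point `vs` returning a `[[position]]` value might produce:
+ //
+ // struct vsOutput {
+ // metal::float4 member [[position]];
+ // };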
+ let stage_out_name = format!("{fun_name}Output");
+ let result_member_name = self.namer.call("member");
+ let result_type_name = match fun.result {
+ Some(ref result) => {
+ let mut result_members = Vec::new();
+ if let crate::TypeInner::Struct { ref members, .. } =
+ module.types[result.ty].inner
+ {
+ for (member_index, member) in members.iter().enumerate() {
+ result_members.push((
+ &self.names[&NameKey::StructMember(result.ty, member_index as u32)],
+ member.ty,
+ member.binding.as_ref(),
+ ));
+ }
+ } else {
+ result_members.push((
+ &result_member_name,
+ result.ty,
+ result.binding.as_ref(),
+ ));
+ }
+
+ writeln!(self.out, "struct {stage_out_name} {{")?;
+ let mut has_point_size = false;
+ for (name, ty, binding) in result_members {
+ let ty_name = TypeContext {
+ handle: ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: true,
+ };
+ let binding = binding.ok_or(Error::Validation)?;
+
+ if let crate::Binding::BuiltIn(crate::BuiltIn::PointSize) = *binding {
+ has_point_size = true;
+ if !pipeline_options.allow_and_force_point_size {
+ continue;
+ }
+ }
+
+ let array_len = match module.types[ty].inner {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(size),
+ ..
+ } => Some(size),
+ _ => None,
+ };
+ let resolved = options.resolve_local_binding(binding, out_mode)?;
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ if let Some(array_len) = array_len {
+ write!(self.out, " [{array_len}]")?;
+ }
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out, ";")?;
+ }
+
+ if pipeline_options.allow_and_force_point_size
+ && ep.stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // inject the point size output last
+ writeln!(
+ self.out,
+ "{}float _point_size [[point_size]];",
+ back::INDENT
+ )?;
+ }
+ writeln!(self.out, "}};")?;
+ &stage_out_name
+ }
+ None => "void",
+ };
+
+ // Write the entry point function's name, and begin its argument list.
+ writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?;
+ let mut is_first_argument = true;
+
+ // If we have produced a struct holding the `EntryPoint`'s
+ // `Function`'s arguments' varyings, pass that struct first.
+ if has_varyings {
+ writeln!(
+ self.out,
+ " {stage_in_name} {varyings_member_name} [[stage_in]]"
+ )?;
+ is_first_argument = false;
+ }
+
+ let mut local_invocation_id = None;
+
+ // Then pass the remaining arguments not included in the varyings
+ // struct.
+ for &(ref name_key, ty, binding) in flattened_arguments.iter() {
+ let binding = match binding {
+ Some(binding @ &crate::Binding::BuiltIn { .. }) => binding,
+ _ => continue,
+ };
+ let name = match *name_key {
+ NameKey::StructMember(..) => &flattened_member_names[name_key],
+ _ => &self.names[name_key],
+ };
+
+ if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
+ local_invocation_id = Some(name_key);
+ }
+
+ let ty_name = TypeContext {
+ handle: ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ let resolved = options.resolve_local_binding(binding, in_mode)?;
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ write!(self.out, "{separator} {ty_name} {name}")?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out)?;
+ }
+
+ let need_workgroup_variables_initialization =
+ self.need_workgroup_variables_initialization(options, ep, module, fun_info);
+
+ if need_workgroup_variables_initialization && local_invocation_id.is_none() {
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ writeln!(
+ self.out,
+ "{separator} {NAMESPACE}::uint3 __local_invocation_id [[thread_position_in_threadgroup]]"
+ )?;
+ }
+
+ // The global variables used by this entry point and its callees
+ // are passed as arguments. `Private` globals are an exception: they
+ // don't outlive this invocation, so we declare them below as locals
+ // within the entry point.
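+ // An illustrative sketch (the binding index is assumed): a
+ // storage-buffer global may appear in the argument list as
+ //
+ // , device Data& data [[buffer(0)]]
+ //
+ // with the `[[buffer(n)]]` slot taken from the resolved binding below.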
+ for (handle, var) in module.global_variables.iter() {
+ let usage = fun_info[handle];
+ if usage.is_empty() || var.space == crate::AddressSpace::Private {
+ continue;
+ }
+
+ if options.lang_version < (1, 2) {
+ match var.space {
+ // This restriction is not documented in the MSL spec
+ // but validation will fail if it is not upheld.
+ //
+ // We infer the required version from the "Function
+ // Buffer Read-Writes" section of [what's new], where
+ // the feature sets listed correspond with the ones
+ // supporting MSL 1.2.
+ //
+ // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
+ crate::AddressSpace::Storage { access }
+ if access.contains(crate::StorageAccess::STORE)
+ && ep.stage == crate::ShaderStage::Fragment =>
+ {
+ return Err(Error::UnsupportedWriteableStorageBuffer)
+ }
+ crate::AddressSpace::Handle => {
+ match module.types[var.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => {
+ // This restriction is not documented in the MSL spec
+ // but validation will fail if it is not upheld.
+ //
+ // We infer the required version from the "Function
+ // Texture Read-Writes" section of [what's new], where
+ // the feature sets listed correspond with the ones
+ // supporting MSL 1.2.
+ //
+ // [what's new]: https://developer.apple.com/library/archive/documentation/Miscellaneous/Conceptual/MetalProgrammingGuide/WhatsNewiniOS10tvOS10andOSX1012/WhatsNewiniOS10tvOS10andOSX1012.html
+ if access.contains(crate::StorageAccess::STORE)
+ && (ep.stage == crate::ShaderStage::Vertex
+ || ep.stage == crate::ShaderStage::Fragment)
+ {
+ return Err(Error::UnsupportedWriteableStorageTexture(
+ ep.stage,
+ ));
+ }
+
+ if access.contains(
+ crate::StorageAccess::LOAD | crate::StorageAccess::STORE,
+ ) {
+ return Err(Error::UnsupportedRWStorageTexture);
+ }
+ }
+ _ => {}
+ }
+ }
+ _ => {}
+ }
+ }
+
+ // Check min MSL version for binding arrays
+ match var.space {
+ crate::AddressSpace::Handle => match module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => {
+ match module.types[base].inner {
+ crate::TypeInner::Sampler { .. } => {
+ if options.lang_version < (2, 0) {
+ return Err(Error::UnsupportedArrayOf(
+ "samplers".to_string(),
+ ));
+ }
+ }
+ crate::TypeInner::Image { class, .. } => match class {
+ crate::ImageClass::Sampled { .. }
+ | crate::ImageClass::Depth { .. }
+ | crate::ImageClass::Storage {
+ access: crate::StorageAccess::LOAD,
+ ..
+ } => {
+ // Array of textures since:
+ // - iOS: Metal 1.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
+ // - macOS: Metal 2
+
+ if options.lang_version < (2, 0) {
+ return Err(Error::UnsupportedArrayOf(
+ "textures".to_string(),
+ ));
+ }
+ }
+ crate::ImageClass::Storage {
+ access: crate::StorageAccess::STORE,
+ ..
+ } => {
+ // Array of write-only textures since:
+ // - iOS: Metal 2.2 (check depends on https://github.com/gfx-rs/naga/issues/2164)
+ // - macOS: Metal 2
+
+ if options.lang_version < (2, 0) {
+ return Err(Error::UnsupportedArrayOf(
+ "write-only textures".to_string(),
+ ));
+ }
+ }
+ crate::ImageClass::Storage { .. } => {
+ return Err(Error::UnsupportedArrayOf(
+ "read-write textures".to_string(),
+ ));
+ }
+ },
+ _ => {
+ return Err(Error::UnsupportedArrayOfType(base));
+ }
+ }
+ }
+ _ => {}
+ },
+ _ => {}
+ }
+
+ // The resolves have already been checked above in the
+ // `!fake_missing_bindings` case, so failures here can be ignored.
+ let resolved = match var.space {
+ crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
+ crate::AddressSpace::WorkGroup => None,
+ _ => options
+ .resolve_resource_binding(ep, var.binding.as_ref().unwrap())
+ .ok(),
+ };
+ if let Some(ref resolved) = resolved {
+ // Inline samplers are defined in the EP body instead.
+ if resolved.as_inline_sampler(options).is_some() {
+ continue;
+ }
+ }
+
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage,
+ binding: resolved.as_ref(),
+ reference: true,
+ };
+ let separator = if is_first_argument {
+ is_first_argument = false;
+ ' '
+ } else {
+ ','
+ };
+ write!(self.out, "{separator} ")?;
+ tyvar.try_fmt(&mut self.out)?;
+ if let Some(resolved) = resolved {
+ resolved.try_fmt(&mut self.out)?;
+ }
+ if let Some(value) = var.init {
+ write!(self.out, " = ")?;
+ self.put_const_expression(value, module, mod_info)?;
+ }
+ writeln!(self.out)?;
+ }
+
+ // If this entry point uses any variable-length arrays, their sizes
+ // are passed as a final struct-typed argument.
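+ // Sketch of the emitted argument (the buffer index is assumed):
+ //
+ // , constant _mslBufferSizes& _buffer_sizes [[buffer(30)]]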
+ if supports_array_length {
+ // this is checked earlier
+ let resolved = options.resolve_sizes_buffer(ep).unwrap();
+ let separator = if module.global_variables.is_empty() {
+ ' '
+ } else {
+ ','
+ };
+ write!(
+ self.out,
+ "{separator} constant _mslBufferSizes& _buffer_sizes",
+ )?;
+ resolved.try_fmt(&mut self.out)?;
+ writeln!(self.out)?;
+ }
+
+ // end of the entry point argument list
+ writeln!(self.out, ") {{")?;
+
+ if need_workgroup_variables_initialization {
+ self.write_workgroup_variables_initialization(
+ module,
+ mod_info,
+ fun_info,
+ local_invocation_id,
+ )?;
+ }
+
+ // Metal doesn't support private mutable variables outside of functions,
+ // so we put them here, just like the locals.
+ for (handle, var) in module.global_variables.iter() {
+ let usage = fun_info[handle];
+ if usage.is_empty() {
+ continue;
+ }
+ if var.space == crate::AddressSpace::Private {
+ let tyvar = TypedGlobalVariable {
+ module,
+ names: &self.names,
+ handle,
+ usage,
+ binding: None,
+ reference: false,
+ };
+ write!(self.out, "{}", back::INDENT)?;
+ tyvar.try_fmt(&mut self.out)?;
+ match var.init {
+ Some(value) => {
+ write!(self.out, " = ")?;
+ self.put_const_expression(value, module, mod_info)?;
+ writeln!(self.out, ";")?;
+ }
+ None => {
+ writeln!(self.out, " = {{}};")?;
+ }
+ };
+ } else if let Some(ref binding) = var.binding {
+ // write an inline sampler
+ let resolved = options.resolve_resource_binding(ep, binding).unwrap();
+ if let Some(sampler) = resolved.as_inline_sampler(options) {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ writeln!(
+ self.out,
+ "{}constexpr {}::sampler {}(",
+ back::INDENT,
+ NAMESPACE,
+ name
+ )?;
+ self.put_inline_sampler_properties(back::Level(2), sampler)?;
+ writeln!(self.out, "{});", back::INDENT)?;
+ }
+ }
+ }
+
+ // Now take the arguments that we gathered into structs, and the
+ // structs that we flattened into arguments, and emit local
+ // variables with initializers that put everything back the way the
+ // body code expects.
+ //
+ // If we had to generate fresh names for struct members passed as
+ // arguments, be sure to use those names when rebuilding the struct.
+ //
+ // "Each day, I change some zeros to ones, and some ones to zeros.
+ // The rest, I leave alone."
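+ // A sketch with hypothetical names: for a flattened struct argument
+ // `VertexInput { pos, uv }` where `uv` is a location-bound varying,
+ // we emit
+ //
+ // const VertexInput in = { pos, varyings.uv };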
+ for (arg_index, arg) in fun.arguments.iter().enumerate() {
+ let arg_name =
+ &self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
+ match module.types[arg.ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ let struct_name = &self.names[&NameKey::Type(arg.ty)];
+ write!(
+ self.out,
+ "{}const {} {} = {{ ",
+ back::INDENT,
+ struct_name,
+ arg_name
+ )?;
+ for (member_index, member) in members.iter().enumerate() {
+ let key = NameKey::StructMember(arg.ty, member_index as u32);
+ let name = &flattened_member_names[&key];
+ if member_index != 0 {
+ write!(self.out, ", ")?;
+ }
+ // insert padding initialization, if needed
+ if self
+ .struct_member_pads
+ .contains(&(arg.ty, member_index as u32))
+ {
+ write!(self.out, "{{}}, ")?;
+ }
+ if let Some(crate::Binding::Location { .. }) = member.binding {
+ write!(self.out, "{varyings_member_name}.")?;
+ }
+ write!(self.out, "{name}")?;
+ }
+ writeln!(self.out, " }};")?;
+ }
+ _ => {
+ if let Some(crate::Binding::Location { .. }) = arg.binding {
+ writeln!(
+ self.out,
+ "{}const auto {} = {}.{};",
+ back::INDENT,
+ arg_name,
+ varyings_member_name,
+ arg_name
+ )?;
+ }
+ }
+ }
+ }
+
+ let guarded_indices =
+ index::find_checked_indexes(module, fun, fun_info, options.bounds_check_policies);
+
+ let context = StatementContext {
+ expression: ExpressionContext {
+ function: fun,
+ origin: FunctionOrigin::EntryPoint(ep_index as _),
+ info: fun_info,
+ lang_version: options.lang_version,
+ policies: options.bounds_check_policies,
+ guarded_indices,
+ module,
+ mod_info,
+ pipeline_options,
+ },
+ result_struct: Some(&stage_out_name),
+ };
+
+ // Finally, declare all the local variables that we need
+ //TODO: we can postpone this till the relevant expressions are emitted
+ for (local_handle, local) in fun.local_variables.iter() {
+ let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)];
+ let ty_name = TypeContext {
+ handle: local.ty,
+ gctx: module.to_ctx(),
+ names: &self.names,
+ access: crate::StorageAccess::empty(),
+ binding: None,
+ first_time: false,
+ };
+ write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
+ match local.init {
+ Some(value) => {
+ write!(self.out, " = ")?;
+ self.put_expression(value, &context.expression, true)?;
+ }
+ None => {
+ write!(self.out, " = {{}}")?;
+ }
+ };
+ writeln!(self.out, ";")?;
+ }
+
+ self.update_expressions_to_bake(fun, fun_info, &context.expression);
+ self.put_block(back::Level(1), &fun.body, &context)?;
+ writeln!(self.out, "}}")?;
+ if ep_index + 1 != module.entry_points.len() {
+ writeln!(self.out)?;
+ }
+ self.named_expressions.clear();
+ }
+
+ Ok(info)
+ }
+
+ fn write_barrier(&mut self, flags: crate::Barrier, level: back::Level) -> BackendResult {
+ // Note: OR-ing bitflags requires `__HAVE_MEMFLAG_OPERATORS__`,
+ // so we try to avoid it here.
+ if flags.is_empty() {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_none);",
+ )?;
+ }
+ if flags.contains(crate::Barrier::STORAGE) {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_device);",
+ )?;
+ }
+ if flags.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
+ )?;
+ }
+ Ok(())
+ }
+}
+
+/// Initializing workgroup variables is trickier for Metal because we
+/// have to deal with atomics at the type level (which don't have a copy
+/// constructor).
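+///
+/// As a rough sketch of the generated MSL (names and the wrapped-array
+/// field are illustrative), a workgroup array of atomics is zeroed like:
+///
+/// ```text
+/// if (metal::all(__local_invocation_id == metal::uint3(0u))) {
+/// for (int __i0 = 0; __i0 < N; __i0++) {
+/// metal::atomic_store_explicit(&w.inner[__i0], 0,
+/// metal::memory_order_relaxed);
+/// }
+/// }
+/// metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+/// ```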
+mod workgroup_mem_init {
+ use crate::EntryPoint;
+
+ use super::*;
+
+ enum Access {
+ GlobalVariable(Handle<crate::GlobalVariable>),
+ StructMember(Handle<crate::Type>, u32),
+ Array(usize),
+ }
+
+ impl Access {
+ fn write<W: Write>(
+ &self,
+ writer: &mut W,
+ names: &FastHashMap<NameKey, String>,
+ ) -> Result<(), core::fmt::Error> {
+ match *self {
+ Access::GlobalVariable(handle) => {
+ write!(writer, "{}", &names[&NameKey::GlobalVariable(handle)])
+ }
+ Access::StructMember(handle, index) => {
+ write!(writer, ".{}", &names[&NameKey::StructMember(handle, index)])
+ }
+ Access::Array(depth) => write!(writer, ".{WRAPPED_ARRAY_FIELD}[__i{depth}]"),
+ }
+ }
+ }
+
+ struct AccessStack {
+ stack: Vec<Access>,
+ array_depth: usize,
+ }
+
+ impl AccessStack {
+ const fn new() -> Self {
+ Self {
+ stack: Vec::new(),
+ array_depth: 0,
+ }
+ }
+
+ fn enter_array<R>(&mut self, cb: impl FnOnce(&mut Self, usize) -> R) -> R {
+ let array_depth = self.array_depth;
+ self.stack.push(Access::Array(array_depth));
+ self.array_depth += 1;
+ let res = cb(self, array_depth);
+ self.stack.pop();
+ self.array_depth -= 1;
+ res
+ }
+
+ fn enter<R>(&mut self, new: Access, cb: impl FnOnce(&mut Self) -> R) -> R {
+ self.stack.push(new);
+ let res = cb(self);
+ self.stack.pop();
+ res
+ }
+
+ fn write<W: Write>(
+ &self,
+ writer: &mut W,
+ names: &FastHashMap<NameKey, String>,
+ ) -> Result<(), core::fmt::Error> {
+ for next in self.stack.iter() {
+ next.write(writer, names)?;
+ }
+ Ok(())
+ }
+ }
+
+ impl<W: Write> Writer<W> {
+ pub(super) fn need_workgroup_variables_initialization(
+ &mut self,
+ options: &Options,
+ ep: &EntryPoint,
+ module: &crate::Module,
+ fun_info: &valid::FunctionInfo,
+ ) -> bool {
+ options.zero_initialize_workgroup_memory
+ && ep.stage == crate::ShaderStage::Compute
+ && module.global_variables.iter().any(|(handle, var)| {
+ !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ }
+
+ pub(super) fn write_workgroup_variables_initialization(
+ &mut self,
+ module: &crate::Module,
+ module_info: &valid::ModuleInfo,
+ fun_info: &valid::FunctionInfo,
+ local_invocation_id: Option<&NameKey>,
+ ) -> BackendResult {
+ let level = back::Level(1);
+
+ writeln!(
+ self.out,
+ "{}if ({}::all({} == {}::uint3(0u))) {{",
+ level,
+ NAMESPACE,
+ local_invocation_id
+ .map(|name_key| self.names[name_key].as_str())
+ .unwrap_or("__local_invocation_id"),
+ NAMESPACE,
+ )?;
+
+ let mut access_stack = AccessStack::new();
+
+ let vars = module.global_variables.iter().filter(|&(handle, var)| {
+ !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ });
+
+ for (handle, var) in vars {
+ access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ var.ty,
+ access_stack,
+ level.next(),
+ )
+ })?;
+ }
+
+ writeln!(self.out, "{level}}}")?;
+ self.write_barrier(crate::Barrier::WORK_GROUP, level)
+ }
+
+ fn write_workgroup_variable_initialization(
+ &mut self,
+ module: &crate::Module,
+ module_info: &valid::ModuleInfo,
+ ty: Handle<crate::Type>,
+ access_stack: &mut AccessStack,
+ level: back::Level,
+ ) -> BackendResult {
+ if module_info[ty].contains(valid::TypeFlags::CONSTRUCTIBLE) {
+ write!(self.out, "{level}")?;
+ access_stack.write(&mut self.out, &self.names)?;
+ writeln!(self.out, " = {{}};")?;
+ } else {
+ match module.types[ty].inner {
+ crate::TypeInner::Atomic { .. } => {
+ write!(
+ self.out,
+ "{level}{NAMESPACE}::atomic_store_explicit({ATOMIC_REFERENCE}"
+ )?;
+ access_stack.write(&mut self.out, &self.names)?;
+ writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
+ }
+ crate::TypeInner::Array { base, size, .. } => {
+ let count = match size.to_indexable_length(module).expect("Bad array size")
+ {
+ proc::IndexableLength::Known(count) => count,
+ proc::IndexableLength::Dynamic => unreachable!(),
+ };
+
+ access_stack.enter_array(|access_stack, array_depth| {
+ writeln!(
+ self.out,
+ "{level}for (int __i{array_depth} = 0; __i{array_depth} < {count}; __i{array_depth}++) {{"
+ )?;
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ base,
+ access_stack,
+ level.next(),
+ )?;
+ writeln!(self.out, "{level}}}")?;
+ BackendResult::Ok(())
+ })?;
+ }
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (index, member) in members.iter().enumerate() {
+ access_stack.enter(
+ Access::StructMember(ty, index as u32),
+ |access_stack| {
+ self.write_workgroup_variable_initialization(
+ module,
+ module_info,
+ member.ty,
+ access_stack,
+ level,
+ )
+ },
+ )?;
+ }
+ }
+ _ => unreachable!(),
+ }
+ }
+
+ Ok(())
+ }
+ }
+}
+
+#[test]
+fn test_stack_size() {
+ use crate::valid::{Capabilities, ValidationFlags};
+ // create a module with at least one expression nested
+ let mut module = crate::Module::default();
+ let mut fun = crate::Function::default();
+ let const_expr = fun.expressions.append(
+ crate::Expression::Literal(crate::Literal::F32(1.0)),
+ Default::default(),
+ );
+ let nested_expr = fun.expressions.append(
+ crate::Expression::Unary {
+ op: crate::UnaryOperator::Negate,
+ expr: const_expr,
+ },
+ Default::default(),
+ );
+ fun.body.push(
+ crate::Statement::Emit(fun.expressions.range_from(1)),
+ Default::default(),
+ );
+ fun.body.push(
+ crate::Statement::If {
+ condition: nested_expr,
+ accept: crate::Block::new(),
+ reject: crate::Block::new(),
+ },
+ Default::default(),
+ );
+ let _ = module.functions.append(fun, Default::default());
+ // analyse the module
+ let info = crate::valid::Validator::new(ValidationFlags::empty(), Capabilities::empty())
+ .validate(&module)
+ .unwrap();
+ // process the module
+ let mut writer = Writer::new(String::new());
+ writer
+ .write(&module, &info, &Default::default(), &Default::default())
+ .unwrap();
+
+ {
+ // check expression stack
+ let mut addresses_start = usize::MAX;
+ let mut addresses_end = 0usize;
+ for pointer in writer.put_expression_stack_pointers {
+ addresses_start = addresses_start.min(pointer as usize);
+ addresses_end = addresses_end.max(pointer as usize);
+ }
+ let stack_size = addresses_end - addresses_start;
+ // check the size (in debug only)
+ // last observed macOS value: 20528 (CI)
+ if !(11000..=25000).contains(&stack_size) {
+ panic!("`put_expression` stack size {stack_size} has changed!");
+ }
+ }
+
+ {
+ // check block stack
+ let mut addresses_start = usize::MAX;
+ let mut addresses_end = 0usize;
+ for pointer in writer.put_block_stack_pointers {
+ addresses_start = addresses_start.min(pointer as usize);
+ addresses_end = addresses_end.max(pointer as usize);
+ }
+ let stack_size = addresses_end - addresses_start;
+ // check the size (in debug only)
+ // last observed macOS value: 19152 (CI)
+ if !(9000..=20000).contains(&stack_size) {
+ panic!("`put_block` stack size {stack_size} has changed!");
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs
new file mode 100644
index 0000000000..6c96fa09e3
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/block.rs
@@ -0,0 +1,2368 @@
+/*!
+Implementations for `BlockContext` methods.
+*/
+
+use super::{
+ helpers, index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext,
+ Dimension, Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer,
+ WriterFlags,
+};
+use crate::{arena::Handle, proc::TypeResolution, Statement};
+use spirv::Word;
+
+fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
+ match *type_inner {
+ crate::TypeInner::Scalar(_) => Dimension::Scalar,
+ crate::TypeInner::Vector { .. } => Dimension::Vector,
+ crate::TypeInner::Matrix { .. } => Dimension::Matrix,
+ _ => unreachable!(),
+ }
+}
+
+/// The results of emitting code for a left-hand-side expression.
+///
+/// On success, `write_expression_pointer` returns one of these.
+enum ExpressionPointer {
+ /// The pointer to the expression's value is available, as the value of the
+ /// expression with the given id.
+ Ready { pointer_id: Word },
+
+ /// The access expression must be conditional on the value of `condition`, a boolean
+ /// expression that is true if all indices are in bounds. If `condition` is true, then
+ /// `access` is an `OpAccessChain` instruction that will compute a pointer to the
+ /// expression's value. If `condition` is false, then executing `access` would be
+ /// undefined behavior.
+ Conditional {
+ condition: Word,
+ access: Instruction,
+ },
+}
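+// A sketch of how `Conditional` is consumed (ids are illustrative): the
+// caller emits a selection on `condition` and performs the access only on
+// the in-bounds path, roughly:
+//
+// OpBranchConditional %condition %in_bounds %merge
+// %in_bounds = OpLabel
+// %ptr = OpAccessChain ...
+// %val = OpLoad %T %ptr
+// OpBranch %merge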
+
+/// The termination statement to be added to the end of the block
+pub enum BlockExit {
+ /// Generates an OpReturn (void return)
+ Return,
+ /// Generates an OpBranch to the specified block
+ Branch {
+ /// The branch target block
+ target: Word,
+ },
+ /// Translates a loop `break if` into an `OpBranchConditional` that
+ /// branches to the merge block if the condition is true (the merge block
+ /// is passed through [`LoopContext::break_id`]), or else to the loop
+ /// header (passed through [`preamble_id`]).
+ ///
+ /// [`preamble_id`]: Self::BreakIf::preamble_id
+ BreakIf {
+ /// The condition of the `break if`
+ condition: Handle<crate::Expression>,
+ /// The loop header block id
+ preamble_id: Word,
+ },
+}
+
+#[derive(Debug)]
+pub(crate) struct DebugInfoInner<'a> {
+ pub source_code: &'a str,
+ pub source_file_id: Word,
+}
+
+impl Writer {
+ // Flip Y coordinate to adjust for coordinate space difference
+ // between SPIR-V and our IR.
+ // The `position_id` argument is a pointer to a `vecN<f32>`,
+ // whose `y` component we will negate.
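+ // A sketch of the emitted instructions (result ids are illustrative):
+ //
+ // %py = OpAccessChain %_ptr_Output_float %position %uint_1
+ // %v = OpLoad %float %py
+ // %n = OpFNegate %float %v
+ // OpStore %py %n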
+ fn write_epilogue_position_y_flip(
+ &mut self,
+ position_id: Word,
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: Some(spirv::StorageClass::Output),
+ }));
+ let index_y_id = self.get_index_constant(1);
+ let access_id = self.id_gen.next();
+ body.push(Instruction::access_chain(
+ float_ptr_type_id,
+ access_id,
+ position_id,
+ &[index_y_id],
+ ));
+
+ let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let load_id = self.id_gen.next();
+ body.push(Instruction::load(float_type_id, load_id, access_id, None));
+
+ let neg_id = self.id_gen.next();
+ body.push(Instruction::unary(
+ spirv::Op::FNegate,
+ float_type_id,
+ neg_id,
+ load_id,
+ ));
+
+ body.push(Instruction::store(access_id, neg_id, None));
+ Ok(())
+ }
+
+ // Clamp fragment depth between 0 and 1.
+ fn write_epilogue_frag_depth_clamp(
+ &mut self,
+ frag_depth_id: Word,
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
+ let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));
+
+ let original_id = self.id_gen.next();
+ body.push(Instruction::load(
+ float_type_id,
+ original_id,
+ frag_depth_id,
+ None,
+ ));
+
+ let clamp_id = self.id_gen.next();
+ body.push(Instruction::ext_inst(
+ self.gl450_ext_inst_id,
+ spirv::GLOp::FClamp,
+ float_type_id,
+ clamp_id,
+ &[original_id, zero_scalar_id, one_scalar_id],
+ ));
+
+ body.push(Instruction::store(frag_depth_id, clamp_id, None));
+ Ok(())
+ }
+
+ fn write_entry_point_return(
+ &mut self,
+ value_id: Word,
+ ir_result: &crate::FunctionResult,
+ result_members: &[ResultMember],
+ body: &mut Vec<Instruction>,
+ ) -> Result<(), Error> {
+ for (index, res_member) in result_members.iter().enumerate() {
+ let member_value_id = match ir_result.binding {
+ Some(_) => value_id,
+ None => {
+ let member_value_id = self.id_gen.next();
+ body.push(Instruction::composite_extract(
+ res_member.type_id,
+ member_value_id,
+ value_id,
+ &[index as u32],
+ ));
+ member_value_id
+ }
+ };
+
+ body.push(Instruction::store(res_member.id, member_value_id, None));
+
+ match res_member.built_in {
+ Some(crate::BuiltIn::Position { .. })
+ if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
+ {
+ self.write_epilogue_position_y_flip(res_member.id, body)?;
+ }
+ Some(crate::BuiltIn::FragDepth)
+ if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
+ {
+ self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
+ }
+ _ => {}
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<'w> BlockContext<'w> {
+ /// Decide whether to put off emitting instructions for `expr_handle`.
+ ///
+ /// We would like to gather together chains of `Access` and `AccessIndex`
+ /// Naga expressions into a single `OpAccessChain` SPIR-V instruction. To do
+ /// this, we don't generate instructions for these exprs when we first
+ /// encounter them. Their ids in `self.cached.ids` are left as zero. Then,
+ /// once we encounter a `Load` or `Store` expression that actually needs the
+ /// chain's value, we call `write_expression_pointer` to handle the whole
+ /// thing in one fell swoop.
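+ /// As a sketch (hypothetical source), a WGSL access like `a[i].x`
+ /// through a pointer emits nothing for the intermediate steps; the
+ /// consuming `Load` later produces a single
+ /// `OpAccessChain %ptr %a %i %x_index` followed by one `OpLoad`.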
+ fn is_intermediate(&self, expr_handle: Handle<crate::Expression>) -> bool {
+ match self.ir_function.expressions[expr_handle] {
+ crate::Expression::GlobalVariable(handle) => {
+ match self.ir_module.global_variables[handle].space {
+ crate::AddressSpace::Handle => false,
+ _ => true,
+ }
+ }
+ crate::Expression::LocalVariable(_) => true,
+ crate::Expression::FunctionArgument(index) => {
+ let arg = &self.ir_function.arguments[index as usize];
+ self.ir_module.types[arg.ty].inner.pointer_space().is_some()
+ }
+
+ // The chain rule: if this `Access...`'s `base` operand was
+ // previously omitted, then omit this one, too.
+ _ => self.cached.ids[expr_handle.index()] == 0,
+ }
+ }
+
+ /// Cache an expression for a value.
+ pub(super) fn cache_expression_value(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ let is_named_expression = self
+ .ir_function
+ .named_expressions
+ .contains_key(&expr_handle);
+
+ if self.fun_info[expr_handle].ref_count == 0 && !is_named_expression {
+ return Ok(());
+ }
+
+ let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
+ let id = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
+ crate::Expression::Constant(handle) => {
+ let init = self.ir_module.constants[handle].init;
+ self.writer.constant_ids[init.index()]
+ }
+ crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
+ crate::Expression::Compose { ty, ref components } => {
+ self.temp_list.clear();
+ if self.expression_constness.is_const(expr_handle) {
+ self.temp_list.extend(
+ crate::proc::flatten_compose(
+ ty,
+ components,
+ &self.ir_function.expressions,
+ &self.ir_module.types,
+ )
+ .map(|component| self.cached[component]),
+ );
+ self.writer
+ .get_constant_composite(LookupType::Handle(ty), &self.temp_list)
+ } else {
+ self.temp_list
+ .extend(components.iter().map(|&component| self.cached[component]));
+
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ id,
+ &self.temp_list,
+ ));
+ id
+ }
+ }
+ crate::Expression::Splat { size, value } => {
+ let value_id = self.cached[value];
+ let components = &[value_id; 4][..size as usize];
+
+ if self.expression_constness.is_const(expr_handle) {
+ let ty = self
+ .writer
+ .get_expression_lookup_type(&self.fun_info[expr_handle].ty);
+ self.writer.get_constant_composite(ty, components)
+ } else {
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ id,
+ components,
+ ));
+ id
+ }
+ }
+ crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => {
+ // See `is_intermediate`; we'll handle this later in
+ // `write_expression_pointer`.
+ 0
+ }
+ crate::Expression::Access { base, index } => {
+ let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types);
+ match *base_ty_inner {
+ crate::TypeInner::Vector { .. } => {
+ self.write_vector_access(expr_handle, base, index, block)?
+ }
+ // Only binding arrays in the Handle address space will take this path (due to `is_intermediate`)
+ crate::TypeInner::BindingArray {
+ base: binding_type, ..
+ } => {
+ let space = match self.ir_function.expressions[base] {
+ crate::Expression::GlobalVariable(gvar) => {
+ self.ir_module.global_variables[gvar].space
+ }
+ _ => unreachable!(),
+ };
+ let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
+ base: binding_type,
+ class: helpers::map_storage_class(space),
+ });
+
+ let result_id = match self.write_expression_pointer(
+ expr_handle,
+ block,
+ Some(binding_array_false_pointer),
+ )? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Texture array out-of-bounds handling",
+ ));
+ }
+ };
+
+ let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));
+
+ let load_id = self.gen_id();
+ block.body.push(Instruction::load(
+ binding_type_id,
+ load_id,
+ result_id,
+ None,
+ ));
+
+ // Subsequent image operations require the image/sampler to be decorated as NonUniform
+ // if the image/sampler binding array was accessed with a non-uniform index
+ // see VUID-RuntimeSpirv-NonUniform-06274
+ if self.fun_info[index].uniformity.non_uniform_result.is_some() {
+ self.writer
+ .decorate_non_uniform_binding_array_access(load_id)?;
+ }
+
+ load_id
+ }
+ ref other => {
+ log::error!(
+ "Unable to access base {:?} of type {:?}",
+ self.ir_function.expressions[base],
+ other
+ );
+ return Err(Error::Validation(
+ "only vectors may be dynamically indexed by value",
+ ));
+ }
+ }
+ }
+ crate::Expression::AccessIndex { base, index: _ } if self.is_intermediate(base) => {
+ // See `is_intermediate`; we'll handle this later in
+ // `write_expression_pointer`.
+ 0
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. }
+ | crate::TypeInner::Array { .. }
+ | crate::TypeInner::Struct { .. } => {
+ // We never need bounds checks here: dynamically sized arrays can
+ // only appear behind pointers, and are thus handled by the
+ // `is_intermediate` case above. Everything else's size is
+ // statically known and checked in validation.
+ let id = self.gen_id();
+ let base_id = self.cached[base];
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ id,
+ base_id,
+ &[index],
+ ));
+ id
+ }
+ // Only binding arrays in the Handle address space will take this path (due to `is_intermediate`)
+ crate::TypeInner::BindingArray {
+ base: binding_type, ..
+ } => {
+ let space = match self.ir_function.expressions[base] {
+ crate::Expression::GlobalVariable(gvar) => {
+ self.ir_module.global_variables[gvar].space
+ }
+ _ => unreachable!(),
+ };
+ let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
+ base: binding_type,
+ class: helpers::map_storage_class(space),
+ });
+
+ let result_id = match self.write_expression_pointer(
+ expr_handle,
+ block,
+ Some(binding_array_false_pointer),
+ )? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Texture array out-of-bounds handling",
+ ));
+ }
+ };
+
+ let binding_type_id = self.get_type_id(LookupType::Handle(binding_type));
+
+ let load_id = self.gen_id();
+ block.body.push(Instruction::load(
+ binding_type_id,
+ load_id,
+ result_id,
+ None,
+ ));
+
+ load_id
+ }
+ ref other => {
+ log::error!("Unable to access index of {:?}", other);
+ return Err(Error::FeatureNotImplemented("access index for type"));
+ }
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ self.writer.global_variables[handle.index()].access_id
+ }
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ let vector_id = self.cached[vector];
+ self.temp_list.clear();
+ for &sc in pattern[..size as usize].iter() {
+ self.temp_list.push(sc as Word);
+ }
+ let id = self.gen_id();
+ block.body.push(Instruction::vector_shuffle(
+ result_type_id,
+ id,
+ vector_id,
+ vector_id,
+ &self.temp_list,
+ ));
+ id
+ }
+ crate::Expression::Unary { op, expr } => {
+ let id = self.gen_id();
+ let expr_id = self.cached[expr];
+ let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
+
+ let spirv_op = match op {
+ crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
+ Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
+ _ => return Err(Error::Validation("Unexpected kind for negation")),
+ },
+ crate::UnaryOperator::LogicalNot => spirv::Op::LogicalNot,
+ crate::UnaryOperator::BitwiseNot => spirv::Op::Not,
+ };
+
+ block
+ .body
+ .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
+ id
+ }
+ crate::Expression::Binary { op, left, right } => {
+ let id = self.gen_id();
+ let left_id = self.cached[left];
+ let right_id = self.cached[right];
+
+ let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
+ let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
+
+ let left_dimension = get_dimension(left_ty_inner);
+ let right_dimension = get_dimension(right_ty_inner);
+
+ let mut reverse_operands = false;
+
+ let spirv_op = match op {
+ crate::BinaryOperator::Add => match *left_ty_inner {
+ crate::TypeInner::Scalar(scalar)
+ | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
+ crate::ScalarKind::Float => spirv::Op::FAdd,
+ _ => spirv::Op::IAdd,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ self.write_matrix_matrix_column_op(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ columns,
+ rows,
+ scalar.width,
+ spirv::Op::FAdd,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Subtract => match *left_ty_inner {
+ crate::TypeInner::Scalar(scalar)
+ | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
+ crate::ScalarKind::Float => spirv::Op::FSub,
+ _ => spirv::Op::ISub,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ self.write_matrix_matrix_column_op(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ columns,
+ rows,
+ scalar.width,
+ spirv::Op::FSub,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) {
+ (Dimension::Scalar, Dimension::Vector) => {
+ self.write_vector_scalar_mult(
+ block,
+ id,
+ result_type_id,
+ right_id,
+ left_id,
+ right_ty_inner,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ (Dimension::Vector, Dimension::Scalar) => {
+ self.write_vector_scalar_mult(
+ block,
+ id,
+ result_type_id,
+ left_id,
+ right_id,
+ left_ty_inner,
+ );
+
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ (Dimension::Vector, Dimension::Matrix) => spirv::Op::VectorTimesMatrix,
+ (Dimension::Matrix, Dimension::Scalar) => spirv::Op::MatrixTimesScalar,
+ (Dimension::Scalar, Dimension::Matrix) => {
+ reverse_operands = true;
+ spirv::Op::MatrixTimesScalar
+ }
+ (Dimension::Matrix, Dimension::Vector) => spirv::Op::MatrixTimesVector,
+ (Dimension::Matrix, Dimension::Matrix) => spirv::Op::MatrixTimesMatrix,
+ (Dimension::Vector, Dimension::Vector)
+ | (Dimension::Scalar, Dimension::Scalar)
+ if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) =>
+ {
+ spirv::Op::FMul
+ }
+ (Dimension::Vector, Dimension::Vector)
+ | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
+ },
+ crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
+ Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
+ // TODO: handle undefined behavior
+ // if right == 0 return 0
+ // if left == min(type_of(left)) && right == -1 return 0
+ Some(crate::ScalarKind::Sint) => spirv::Op::SRem,
+ // TODO: handle undefined behavior
+ // if right == 0 return 0
+ Some(crate::ScalarKind::Uint) => spirv::Op::UMod,
+ // TODO: handle undefined behavior
+ // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798
+ Some(crate::ScalarKind::Float) => spirv::Op::FRem,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+ spirv::Op::IEqual
+ }
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => {
+ spirv::Op::INotEqual
+ }
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
+ Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
+ Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
+ _ => unimplemented!(),
+ },
+ crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
+ _ => spirv::Op::BitwiseAnd,
+ },
+ crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
+ crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
+ _ => spirv::Op::BitwiseOr,
+ },
+ crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
+ crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
+ crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
+ crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
+ Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
+ _ => unimplemented!(),
+ },
+ };
+
+ block.body.push(Instruction::binary(
+ spirv_op,
+ result_type_id,
+ id,
+ if reverse_operands { right_id } else { left_id },
+ if reverse_operands { left_id } else { right_id },
+ ));
+ id
+ }
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+ enum MathOp {
+ Ext(spirv::GLOp),
+ Custom(Instruction),
+ }
+
+ let arg0_id = self.cached[arg];
+ let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
+ let arg_scalar_kind = arg_ty.scalar_kind();
+ let arg1_id = match arg1 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+ let arg2_id = match arg2 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+ let arg3_id = match arg3 {
+ Some(handle) => self.cached[handle],
+ None => 0,
+ };
+
+ let id = self.gen_id();
+ let math_op = match fun {
+ // comparison
+ Mf::Abs => {
+ match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
+ Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
+ Some(crate::ScalarKind::Uint) => {
+ MathOp::Custom(Instruction::unary(
+ spirv::Op::CopyObject, // do nothing
+ result_type_id,
+ id,
+ arg0_id,
+ ))
+ }
+ other => unimplemented!("Unexpected abs({:?})", other),
+ }
+ }
+ Mf::Min => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
+ other => unimplemented!("Unexpected min({:?})", other),
+ }),
+ Mf::Max => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
+ other => unimplemented!("Unexpected max({:?})", other),
+ }),
+ Mf::Clamp => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FClamp,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SClamp,
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::UClamp,
+ other => unimplemented!("Unexpected max({:?})", other),
+ }),
+ Mf::Saturate => {
+ let (maybe_size, scalar) = match *arg_ty {
+ crate::TypeInner::Vector { size, scalar } => (Some(size), scalar),
+ crate::TypeInner::Scalar(scalar) => (None, scalar),
+ ref other => unimplemented!("Unexpected saturate({:?})", other),
+ };
+ let scalar = crate::Scalar::float(scalar.width);
+ let mut arg1_id = self.writer.get_constant_scalar_with(0, scalar)?;
+ let mut arg2_id = self.writer.get_constant_scalar_with(1, scalar)?;
+
+ if let Some(size) = maybe_size {
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(size as _, arg1_id);
+
+ arg1_id = self.writer.get_constant_composite(ty, &self.temp_list);
+
+ self.temp_list.fill(arg2_id);
+
+ arg2_id = self.writer.get_constant_composite(ty, &self.temp_list);
+ }
+
+ MathOp::Custom(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FClamp,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, arg2_id],
+ ))
+ }
+ // trigonometry
+ Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
+ Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
+ Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
+ Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
+ Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
+ Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
+ Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
+ Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
+ Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
+ Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
+ Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
+ Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
+ Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
+ Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
+ Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
+ // decomposition
+ Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
+ Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
+ Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
+ Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
+ Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
+ Mf::Modf => MathOp::Ext(spirv::GLOp::ModfStruct),
+ Mf::Frexp => MathOp::Ext(spirv::GLOp::FrexpStruct),
+ Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
+ // geometry
+ Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Vector {
+ scalar:
+ crate::Scalar {
+ kind: crate::ScalarKind::Float,
+ ..
+ },
+ ..
+ } => MathOp::Custom(Instruction::binary(
+ spirv::Op::Dot,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ )),
+ // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available
+ crate::TypeInner::Vector { size, .. } => {
+ self.write_dot_product(
+ id,
+ result_type_id,
+ arg0_id,
+ arg1_id,
+ size as u32,
+ block,
+ );
+ self.cached[expr_handle] = id;
+ return Ok(());
+ }
+ _ => unreachable!(
+ "Correct TypeInner for dot product should be already validated"
+ ),
+ },
+ Mf::Outer => MathOp::Custom(Instruction::binary(
+ spirv::Op::OuterProduct,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ )),
+ Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
+ Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
+ Mf::Length => MathOp::Ext(spirv::GLOp::Length),
+ Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
+ Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
+ Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
+ Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
+ // exponent
+ Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
+ Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
+ Mf::Log => MathOp::Ext(spirv::GLOp::Log),
+ Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
+ Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
+ // computational
+ Mf::Sign => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
+ other => unimplemented!("Unexpected sign({:?})", other),
+ }),
+ Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
+ Mf::Mix => {
+ let selector = arg2.unwrap();
+ let selector_ty =
+ self.fun_info[selector].ty.inner_with(&self.ir_module.types);
+ match (arg_ty, selector_ty) {
+ // if the selector is a scalar, we need to splat it
+ (
+ &crate::TypeInner::Vector { size, .. },
+ &crate::TypeInner::Scalar(scalar),
+ ) => {
+ let selector_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ }));
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, arg2_id);
+
+ let selector_id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ selector_type_id,
+ selector_id,
+ &self.temp_list,
+ ));
+
+ MathOp::Custom(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FMix,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, selector_id],
+ ))
+ }
+ _ => MathOp::Ext(spirv::GLOp::FMix),
+ }
+ }
+ Mf::Step => MathOp::Ext(spirv::GLOp::Step),
+ Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
+ Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
+ Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
+ Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
+ Mf::Transpose => MathOp::Custom(Instruction::unary(
+ spirv::Op::Transpose,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
+ Mf::ReverseBits => MathOp::Custom(Instruction::unary(
+ spirv::Op::BitReverse,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::CountTrailingZeros => {
+ let uint_id = match *arg_ty {
+ crate::TypeInner::Vector { size, mut scalar } => {
+ scalar.kind = crate::ScalarKind::Uint;
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(
+ size as _,
+ self.writer.get_constant_scalar_with(32, scalar)?,
+ );
+
+ self.writer.get_constant_composite(ty, &self.temp_list)
+ }
+ crate::TypeInner::Scalar(mut scalar) => {
+ scalar.kind = crate::ScalarKind::Uint;
+ self.writer.get_constant_scalar_with(32, scalar)?
+ }
+ _ => unreachable!(),
+ };
+
+ let lsb_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FindILsb,
+ result_type_id,
+ lsb_id,
+ &[arg0_id],
+ ));
+
+ MathOp::Custom(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ result_type_id,
+ id,
+ &[uint_id, lsb_id],
+ ))
+ }
+ Mf::CountLeadingZeros => {
+ let (int_type_id, int_id) = match *arg_ty {
+ crate::TypeInner::Vector { size, mut scalar } => {
+ scalar.kind = crate::ScalarKind::Sint;
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(
+ size as _,
+ self.writer.get_constant_scalar_with(31, scalar)?,
+ );
+
+ (
+ self.get_type_id(ty),
+ self.writer.get_constant_composite(ty, &self.temp_list),
+ )
+ }
+ crate::TypeInner::Scalar(mut scalar) => {
+ scalar.kind = crate::ScalarKind::Sint;
+ (
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar,
+ pointer_space: None,
+ })),
+ self.writer.get_constant_scalar_with(31, scalar)?,
+ )
+ }
+ _ => unreachable!(),
+ };
+
+ let msb_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::FindUMsb,
+ int_type_id,
+ msb_id,
+ &[arg0_id],
+ ));
+
+ MathOp::Custom(Instruction::binary(
+ spirv::Op::ISub,
+ result_type_id,
+ id,
+ int_id,
+ msb_id,
+ ))
+ }
+ Mf::CountOneBits => MathOp::Custom(Instruction::unary(
+ spirv::Op::BitCount,
+ result_type_id,
+ id,
+ arg0_id,
+ )),
+ Mf::ExtractBits => {
+ let op = match arg_scalar_kind {
+ Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
+ Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
+ other => unimplemented!("Unexpected sign({:?})", other),
+ };
+ MathOp::Custom(Instruction::ternary(
+ op,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ arg2_id,
+ ))
+ }
+ Mf::InsertBits => MathOp::Custom(Instruction::quaternary(
+ spirv::Op::BitFieldInsert,
+ result_type_id,
+ id,
+ arg0_id,
+ arg1_id,
+ arg2_id,
+ arg3_id,
+ )),
+ Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb),
+ Mf::FindMsb => MathOp::Ext(match arg_scalar_kind {
+ Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
+ Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
+ other => unimplemented!("Unexpected findMSB({:?})", other),
+ }),
+ Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
+ Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
+ Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
+ Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
+ Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
+ Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
+ Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
+ Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
+ Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
+ Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
+ };
+
+ block.body.push(match math_op {
+ MathOp::Ext(op) => Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ op,
+ result_type_id,
+ id,
+ &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
+ ),
+ MathOp::Custom(inst) => inst,
+ });
+ id
+ }
+ crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id,
+ crate::Expression::Load { pointer } => {
+ match self.write_expression_pointer(pointer, block, None)? {
+ ExpressionPointer::Ready { pointer_id } => {
+ let id = self.gen_id();
+ let atomic_space =
+ match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Pointer { base, space } => {
+ match self.ir_module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Some(space),
+ _ => None,
+ }
+ }
+ _ => None,
+ };
+ let instruction = if let Some(space) = atomic_space {
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ Instruction::atomic_load(
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ )
+ } else {
+ Instruction::load(result_type_id, id, pointer_id, None)
+ };
+ block.body.push(instruction);
+ id
+ }
+ ExpressionPointer::Conditional { condition, access } => {
+ //TODO: support atomics?
+ self.write_conditional_indexed_load(
+ result_type_id,
+ condition,
+ block,
+ move |id_gen, block| {
+ // The in-bounds path. Perform the access and the load.
+ let pointer_id = access.result_id.unwrap();
+ let value_id = id_gen.next();
+ block.body.push(access);
+ block.body.push(Instruction::load(
+ result_type_id,
+ value_id,
+ pointer_id,
+ None,
+ ));
+ value_id
+ },
+ )
+ }
+ }
+ }
+ crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
+ crate::Expression::CallResult(_)
+ | crate::Expression::AtomicResult { .. }
+ | crate::Expression::WorkGroupUniformLoadResult { .. }
+ | crate::Expression::RayQueryProceedResult => self.cached[expr_handle],
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ use crate::ScalarKind as Sk;
+
+ let expr_id = self.cached[expr];
+ let (src_scalar, src_size, is_matrix) =
+ match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Scalar(scalar) => (scalar, None, false),
+ crate::TypeInner::Vector { scalar, size } => (scalar, Some(size), false),
+ crate::TypeInner::Matrix { scalar, .. } => (scalar, None, true),
+ ref other => {
+ log::error!("As source {:?}", other);
+ return Err(Error::Validation("Unexpected Expression::As source"));
+ }
+ };
+
+ enum Cast {
+ Identity,
+ Unary(spirv::Op),
+ Binary(spirv::Op, Word),
+ Ternary(spirv::Op, Word, Word),
+ }
+
+ let cast = if is_matrix {
+ // we only support identity casts for matrices
+ Cast::Unary(spirv::Op::CopyObject)
+ } else {
+ match (src_scalar.kind, kind, convert) {
+ // Filter out identity casts. Some Adreno drivers are
+ // confused by no-op OpBitcast instructions.
+ (src_kind, kind, convert)
+ if src_kind == kind
+ && convert.filter(|&width| width != src_scalar.width).is_none() =>
+ {
+ Cast::Identity
+ }
+ (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject),
+ (_, _, None) => Cast::Unary(spirv::Op::Bitcast),
+ // casting to a bool - generate `OpXxxNotEqual`
+ (_, Sk::Bool, Some(_)) => {
+ let op = match src_scalar.kind {
+ Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
+ Sk::Float => spirv::Op::FUnordNotEqual,
+ Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(),
+ };
+ let zero_scalar_id =
+ self.writer.get_constant_scalar_with(0, src_scalar)?;
+ let zero_id = match src_size {
+ Some(size) => {
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ scalar: src_scalar,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(size as _, zero_scalar_id);
+
+ self.writer.get_constant_composite(ty, &self.temp_list)
+ }
+ None => zero_scalar_id,
+ };
+
+ Cast::Binary(op, zero_id)
+ }
+ // casting from a bool - generate `OpSelect`
+ (Sk::Bool, _, Some(dst_width)) => {
+ let dst_scalar = crate::Scalar {
+ kind,
+ width: dst_width,
+ };
+ let zero_scalar_id =
+ self.writer.get_constant_scalar_with(0, dst_scalar)?;
+ let one_scalar_id =
+ self.writer.get_constant_scalar_with(1, dst_scalar)?;
+ let (accept_id, reject_id) = match src_size {
+ Some(size) => {
+ let ty = LocalType::Value {
+ vector_size: Some(size),
+ scalar: dst_scalar,
+ pointer_space: None,
+ }
+ .into();
+
+ self.temp_list.clear();
+ self.temp_list.resize(size as _, zero_scalar_id);
+
+ let vec0_id =
+ self.writer.get_constant_composite(ty, &self.temp_list);
+
+ self.temp_list.fill(one_scalar_id);
+
+ let vec1_id =
+ self.writer.get_constant_composite(ty, &self.temp_list);
+
+ (vec1_id, vec0_id)
+ }
+ None => (one_scalar_id, zero_scalar_id),
+ };
+
+ Cast::Ternary(spirv::Op::Select, accept_id, reject_id)
+ }
+ (Sk::Float, Sk::Uint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToU),
+ (Sk::Float, Sk::Sint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToS),
+ (Sk::Float, Sk::Float, Some(dst_width))
+ if src_scalar.width != dst_width =>
+ {
+ Cast::Unary(spirv::Op::FConvert)
+ }
+ (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF),
+ (Sk::Sint, Sk::Sint, Some(dst_width)) if src_scalar.width != dst_width => {
+ Cast::Unary(spirv::Op::SConvert)
+ }
+ (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF),
+ (Sk::Uint, Sk::Uint, Some(dst_width)) if src_scalar.width != dst_width => {
+ Cast::Unary(spirv::Op::UConvert)
+ }
+ // We assume it's either an identity cast, or int-uint.
+ _ => Cast::Unary(spirv::Op::Bitcast),
+ }
+ };
+
+ let id = self.gen_id();
+ let instruction = match cast {
+ Cast::Identity => None,
+ Cast::Unary(op) => Some(Instruction::unary(op, result_type_id, id, expr_id)),
+ Cast::Binary(op, operand) => Some(Instruction::binary(
+ op,
+ result_type_id,
+ id,
+ expr_id,
+ operand,
+ )),
+ Cast::Ternary(op, op1, op2) => Some(Instruction::ternary(
+ op,
+ result_type_id,
+ id,
+ expr_id,
+ op1,
+ op2,
+ )),
+ };
+ if let Some(instruction) = instruction {
+ block.body.push(instruction);
+ id
+ } else {
+ expr_id
+ }
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => self.write_image_load(
+ result_type_id,
+ image,
+ coordinate,
+ array_index,
+ level,
+ sample,
+ block,
+ )?,
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => self.write_image_sample(
+ result_type_id,
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ block,
+ )?,
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ let id = self.gen_id();
+ let mut condition_id = self.cached[condition];
+ let accept_id = self.cached[accept];
+ let reject_id = self.cached[reject];
+
+ let condition_ty = self.fun_info[condition]
+ .ty
+ .inner_with(&self.ir_module.types);
+ let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
+
+ if let (
+ &crate::TypeInner::Scalar(
+ condition_scalar @ crate::Scalar {
+ kind: crate::ScalarKind::Bool,
+ ..
+ },
+ ),
+ &crate::TypeInner::Vector { size, .. },
+ ) = (condition_ty, object_ty)
+ {
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, condition_id);
+
+ let bool_vector_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(size),
+ scalar: condition_scalar,
+ pointer_space: None,
+ }));
+
+ let id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ bool_vector_type_id,
+ id,
+ &self.temp_list,
+ ));
+ condition_id = id
+ }
+
+ let instruction =
+ Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
+ block.body.push(instruction);
+ id
+ }
+ crate::Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ match ctrl {
+ Ctrl::Coarse | Ctrl::Fine => {
+ self.writer.require_any(
+ "DerivativeControl",
+ &[spirv::Capability::DerivativeControl],
+ )?;
+ }
+ Ctrl::None => {}
+ }
+ let id = self.gen_id();
+ let expr_id = self.cached[expr];
+ let op = match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => spirv::Op::DPdxCoarse,
+ (Axis::X, Ctrl::Fine) => spirv::Op::DPdxFine,
+ (Axis::X, Ctrl::None) => spirv::Op::DPdx,
+ (Axis::Y, Ctrl::Coarse) => spirv::Op::DPdyCoarse,
+ (Axis::Y, Ctrl::Fine) => spirv::Op::DPdyFine,
+ (Axis::Y, Ctrl::None) => spirv::Op::DPdy,
+ (Axis::Width, Ctrl::Coarse) => spirv::Op::FwidthCoarse,
+ (Axis::Width, Ctrl::Fine) => spirv::Op::FwidthFine,
+ (Axis::Width, Ctrl::None) => spirv::Op::Fwidth,
+ };
+ block
+ .body
+ .push(Instruction::derivative(op, result_type_id, id, expr_id));
+ id
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ self.write_image_query(result_type_id, image, query, block)?
+ }
+ crate::Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+ let arg_id = self.cached[argument];
+ let op = match fun {
+ Rf::All => spirv::Op::All,
+ Rf::Any => spirv::Op::Any,
+ Rf::IsNan => spirv::Op::IsNan,
+ Rf::IsInf => spirv::Op::IsInf,
+ };
+ let id = self.gen_id();
+ block
+ .body
+ .push(Instruction::relational(op, result_type_id, id, arg_id));
+ id
+ }
+ crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
+ crate::Expression::RayQueryGetIntersection { query, committed } => {
+ if !committed {
+ return Err(Error::FeatureNotImplemented("candidate intersection"));
+ }
+ self.write_ray_query_get_intersection(query, block)
+ }
+ };
+
+ self.cached[expr_handle] = id;
+ Ok(())
+ }
+
+ /// Build an `OpAccessChain` instruction.
+ ///
+ /// Emit any needed bounds-checking expressions to `block`.
+ ///
+ /// In some cases we need to generate a different return type than the IR
+ /// gives us: pointers to binding arrays of handles (such as images or
+ /// samplers) don't exist in the IR, but we must create them to build an
+ /// access chain in SPIR-V.
+ ///
+ /// On success, the return value is an [`ExpressionPointer`] value; see the
+ /// documentation for that type.
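+ ///
+ /// For example (a sketch; ids are illustrative), a pointer expression for
+ /// `a[i].x` is walked from the outer `AccessIndex` down to the
+ /// `GlobalVariable` root, accumulating one index operand per step, and
+ /// ends up emitting a single `OpAccessChain %ptr_ty %a %i %c0`.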
+ fn write_expression_pointer(
+ &mut self,
+ mut expr_handle: Handle<crate::Expression>,
+ block: &mut Block,
+ return_type_override: Option<LookupType>,
+ ) -> Result<ExpressionPointer, Error> {
+ let result_lookup_ty = match self.fun_info[expr_handle].ty {
+ TypeResolution::Handle(ty_handle) => match return_type_override {
+ // The return type override is a special case for handle binding
+ // arrays: the OpAccessChain needs to return a pointer, but indexing
+ // into a handle binding array in the IR just gives you the type of
+ // the binding.
+ Some(ty) => ty,
+ None => LookupType::Handle(ty_handle),
+ },
+ TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()),
+ };
+ let result_type_id = self.get_type_id(result_lookup_ty);
+
+ // The id of the boolean `and` of all dynamic bounds checks up to this point. If
+ // `None`, then we haven't done any dynamic bounds checks yet.
+ //
+ // When we have a chain of bounds checks, we combine them with `OpLogicalAnd`, not
+ // a short-circuit branch. This means we might do comparisons we don't need to,
+ // but we expect these checks to almost always succeed, and keeping branches to a
+ // minimum is essential.
+ let mut accumulated_checks = None;
+ // Is true if we are accessing into a binding array with a non-uniform index.
+ let mut is_non_uniform_binding_array = false;
+
+ self.temp_list.clear();
+ let root_id = loop {
+ expr_handle = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::Access { base, index } => {
+ if let crate::Expression::GlobalVariable(var_handle) =
+ self.ir_function.expressions[base]
+ {
+ // The access chain needs to be decorated as NonUniform;
+ // see VUID-RuntimeSpirv-NonUniform-06274.
+ let gvar = &self.ir_module.global_variables[var_handle];
+ if let crate::TypeInner::BindingArray { .. } =
+ self.ir_module.types[gvar.ty].inner
+ {
+ is_non_uniform_binding_array =
+ self.fun_info[index].uniformity.non_uniform_result.is_some();
+ }
+ }
+
+ let index_id = match self.write_bounds_check(base, index, block)? {
+ BoundsCheckResult::KnownInBounds(known_index) => {
+ // Even if the index is known, `OpAccessIndex`
+ // requires expression operands, not literals.
+ let scalar = crate::Literal::U32(known_index);
+ self.writer.get_constant_scalar(scalar)
+ }
+ BoundsCheckResult::Computed(computed_index_id) => computed_index_id,
+ BoundsCheckResult::Conditional(comparison_id) => {
+ match accumulated_checks {
+ Some(prior_checks) => {
+ let combined = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::LogicalAnd,
+ self.writer.get_bool_type_id(),
+ combined,
+ prior_checks,
+ comparison_id,
+ ));
+ accumulated_checks = Some(combined);
+ }
+ None => {
+ // Start a fresh chain of checks.
+ accumulated_checks = Some(comparison_id);
+ }
+ }
+
+ // Either way, the index to use is unchanged.
+ self.cached[index]
+ }
+ };
+ self.temp_list.push(index_id);
+ base
+ }
+ crate::Expression::AccessIndex { base, index } => {
+ let const_id = self.get_index_constant(index);
+ self.temp_list.push(const_id);
+ base
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let gv = &self.writer.global_variables[handle.index()];
+ break gv.access_id;
+ }
+ crate::Expression::LocalVariable(variable) => {
+ let local_var = &self.function.variables[&variable];
+ break local_var.id;
+ }
+ crate::Expression::FunctionArgument(index) => {
+ break self.function.parameter_id(index);
+ }
+ ref other => unimplemented!("Unexpected pointer expression {:?}", other),
+ }
+ };
+
+ let (pointer_id, expr_pointer) = if self.temp_list.is_empty() {
+ (
+ root_id,
+ ExpressionPointer::Ready {
+ pointer_id: root_id,
+ },
+ )
+ } else {
+ self.temp_list.reverse();
+ let pointer_id = self.gen_id();
+ let access =
+ Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);
+
+ // If we generated some bounds checks, we need to leave it to our
+ // caller to generate the branch, the access, the load or store, and
+ // the zero value (for loads). Otherwise, we can emit the access
+ // ourselves, and just hand them the id of the pointer.
+ let expr_pointer = match accumulated_checks {
+ Some(condition) => ExpressionPointer::Conditional { condition, access },
+ None => {
+ block.body.push(access);
+ ExpressionPointer::Ready { pointer_id }
+ }
+ };
+ (pointer_id, expr_pointer)
+ };
+ // Subsequent load, store, and atomic operations require the pointer to be
+ // decorated as NonUniform if the binding array was accessed with a
+ // non-uniform index; see VUID-RuntimeSpirv-NonUniform-06274.
+ if is_non_uniform_binding_array {
+ self.writer
+ .decorate_non_uniform_binding_array_access(pointer_id)?;
+ }
+
+ Ok(expr_pointer)
+ }
+
+ /// Build the instructions for matrix-matrix column operations
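+ ///
+ /// SPIR-V has no instruction for component-wise addition or subtraction of
+ /// matrices, so (a sketch of the lowering) each column is pulled out of
+ /// both operands with `OpCompositeExtract`, combined with the binary `op`,
+ /// and the result columns are reassembled with `OpCompositeConstruct`.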
+ #[allow(clippy::too_many_arguments)]
+ fn write_matrix_matrix_column_op(
+ &mut self,
+ block: &mut Block,
+ result_id: Word,
+ result_type_id: Word,
+ left_id: Word,
+ right_id: Word,
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: u8,
+ op: spirv::Op,
+ ) {
+ self.temp_list.clear();
+
+ let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(rows),
+ scalar: crate::Scalar::float(width),
+ pointer_space: None,
+ }));
+
+ for index in 0..columns as u32 {
+ let column_id_left = self.gen_id();
+ let column_id_right = self.gen_id();
+ let column_id_res = self.gen_id();
+
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ column_id_left,
+ left_id,
+ &[index],
+ ));
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ column_id_right,
+ right_id,
+ &[index],
+ ));
+ block.body.push(Instruction::binary(
+ op,
+ vector_type_id,
+ column_id_res,
+ column_id_left,
+ column_id_right,
+ ));
+
+ self.temp_list.push(column_id_res);
+ }
+
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ result_id,
+ &self.temp_list,
+ ));
+ }
+
+ /// Build the instructions for vector-scalar multiplication
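+ ///
+ /// `OpVectorTimesScalar` only accepts floating-point operands, so for
+ /// integer vectors (a sketch of the lowering) the scalar is splatted into
+ /// a vector with `OpCompositeConstruct` and multiplied with `OpIMul`
+ /// instead.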
+ fn write_vector_scalar_mult(
+ &mut self,
+ block: &mut Block,
+ result_id: Word,
+ result_type_id: Word,
+ vector_id: Word,
+ scalar_id: Word,
+ vector: &crate::TypeInner,
+ ) {
+ let (size, kind) = match *vector {
+ crate::TypeInner::Vector {
+ size,
+ scalar: crate::Scalar { kind, .. },
+ } => (size, kind),
+ _ => unreachable!(),
+ };
+
+ let (op, operand_id) = match kind {
+ crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id),
+ _ => {
+ let operand_id = self.gen_id();
+ self.temp_list.clear();
+ self.temp_list.resize(size as usize, scalar_id);
+ block.body.push(Instruction::composite_construct(
+ result_type_id,
+ operand_id,
+ &self.temp_list,
+ ));
+ (spirv::Op::IMul, operand_id)
+ }
+ };
+
+ block.body.push(Instruction::binary(
+ op,
+ result_type_id,
+ result_id,
+ vector_id,
+ operand_id,
+ ));
+ }
+
+ /// Build the instructions for the arithmetic expression of a dot product
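+ ///
+ /// `OpDot` is defined only for floating-point vectors, so integer dot
+ /// products are unrolled: each pair of components is multiplied with
+ /// `OpIMul` and accumulated with `OpIAdd`, starting from a null (zero)
+ /// constant.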
+ fn write_dot_product(
+ &mut self,
+ result_id: Word,
+ result_type_id: Word,
+ arg0_id: Word,
+ arg1_id: Word,
+ size: u32,
+ block: &mut Block,
+ ) {
+ let mut partial_sum = self.writer.get_constant_null(result_type_id);
+ let last_component = size - 1;
+ for index in 0..=last_component {
+ // compute the product of the current components
+ let a_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ a_id,
+ arg0_id,
+ &[index],
+ ));
+ let b_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ b_id,
+ arg1_id,
+ &[index],
+ ));
+ let prod_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::IMul,
+ result_type_id,
+ prod_id,
+ a_id,
+ b_id,
+ ));
+
+ // choose the id for the next sum, depending on current index
+ let id = if index == last_component {
+ result_id
+ } else {
+ self.gen_id()
+ };
+
+ // sum the computed product with the partial sum
+ block.body.push(Instruction::binary(
+ spirv::Op::IAdd,
+ result_type_id,
+ id,
+ partial_sum,
+ prod_id,
+ ));
+ // this iteration's sum becomes the partial sum for the next one
+ partial_sum = id;
+ }
+ }
+
+ pub(super) fn write_block(
+ &mut self,
+ label_id: Word,
+ naga_block: &crate::Block,
+ exit: BlockExit,
+ loop_context: LoopContext,
+ debug_info: Option<&DebugInfoInner>,
+ ) -> Result<(), Error> {
+ let mut block = Block::new(label_id);
+ for (statement, span) in naga_block.span_iter() {
+ if let (Some(debug_info), false) = (
+ debug_info,
+ matches!(
+ statement,
+ &(Statement::Block(..)
+ | Statement::Break
+ | Statement::Continue
+ | Statement::Kill
+ | Statement::Return { .. }
+ | Statement::Loop { .. })
+ ),
+ ) {
+ let loc: crate::SourceLocation = span.location(debug_info.source_code);
+ block.body.push(Instruction::line(
+ debug_info.source_file_id,
+ loc.line_number,
+ loc.line_position,
+ ));
+ };
+ match *statement {
+ crate::Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ // omit const expressions as we've already cached those
+ if !self.expression_constness.is_const(handle) {
+ self.cache_expression_value(handle, &mut block)?;
+ }
+ }
+ }
+ crate::Statement::Block(ref block_statements) => {
+ let scope_id = self.gen_id();
+ self.function.consume(block, Instruction::branch(scope_id));
+
+ let merge_id = self.gen_id();
+ self.write_block(
+ scope_id,
+ block_statements,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ debug_info,
+ )?;
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let condition_id = self.cached[condition];
+
+ let merge_id = self.gen_id();
+ block.body.push(Instruction::selection_merge(
+ merge_id,
+ spirv::SelectionControl::NONE,
+ ));
+
+ let accept_id = if accept.is_empty() {
+ None
+ } else {
+ Some(self.gen_id())
+ };
+ let reject_id = if reject.is_empty() {
+ None
+ } else {
+ Some(self.gen_id())
+ };
+
+ self.function.consume(
+ block,
+ Instruction::branch_conditional(
+ condition_id,
+ accept_id.unwrap_or(merge_id),
+ reject_id.unwrap_or(merge_id),
+ ),
+ );
+
+ if let Some(block_id) = accept_id {
+ self.write_block(
+ block_id,
+ accept,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ debug_info,
+ )?;
+ }
+ if let Some(block_id) = reject_id {
+ self.write_block(
+ block_id,
+ reject,
+ BlockExit::Branch { target: merge_id },
+ loop_context,
+ debug_info,
+ )?;
+ }
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ let selector_id = self.cached[selector];
+
+ let merge_id = self.gen_id();
+ block.body.push(Instruction::selection_merge(
+ merge_id,
+ spirv::SelectionControl::NONE,
+ ));
+
+ let mut default_id = None;
+ // id of previous empty fall-through case
+ let mut last_id = None;
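+ // Cases with an empty body that fall through share a label with the
+ // next case, so e.g. a WGSL `case 1, 2: { body }` lowers (a sketch)
+ // to a single OpSwitch target used for both selector values.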
+
+ let mut raw_cases = Vec::with_capacity(cases.len());
+ let mut case_ids = Vec::with_capacity(cases.len());
+ for case in cases.iter() {
+ // take id of previous empty fall-through case or generate a new one
+ let label_id = last_id.take().unwrap_or_else(|| self.gen_id());
+
+ if case.fall_through && case.body.is_empty() {
+ last_id = Some(label_id);
+ }
+
+ case_ids.push(label_id);
+
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ raw_cases.push(super::instructions::Case {
+ value: value as Word,
+ label_id,
+ });
+ }
+ crate::SwitchValue::U32(value) => {
+ raw_cases.push(super::instructions::Case { value, label_id });
+ }
+ crate::SwitchValue::Default => {
+ default_id = Some(label_id);
+ }
+ }
+ }
+
+ let default_id = default_id.unwrap();
+
+ self.function.consume(
+ block,
+ Instruction::switch(selector_id, default_id, &raw_cases),
+ );
+
+ let inner_context = LoopContext {
+ break_id: Some(merge_id),
+ ..loop_context
+ };
+
+ for (i, (case, label_id)) in cases
+ .iter()
+ .zip(case_ids.iter())
+ .filter(|&(case, _)| !(case.fall_through && case.body.is_empty()))
+ .enumerate()
+ {
+ let case_finish_id = if case.fall_through {
+ case_ids[i + 1]
+ } else {
+ merge_id
+ };
+ self.write_block(
+ *label_id,
+ &case.body,
+ BlockExit::Branch {
+ target: case_finish_id,
+ },
+ inner_context,
+ debug_info,
+ )?;
+ }
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ let preamble_id = self.gen_id();
+ self.function
+ .consume(block, Instruction::branch(preamble_id));
+
+ let merge_id = self.gen_id();
+ let body_id = self.gen_id();
+ let continuing_id = self.gen_id();
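+ // The loop lowers to this block structure (a sketch):
+ //
+ //   preamble:   OpLoopMerge %merge %continuing; OpBranch %body
+ //   body:       ... OpBranch %continuing (or a jump out to %merge)
+ //   continuing: ... OpBranch %preamble, or OpBranchConditional for
+ //               `break_if`
+ //   merge:      execution resumes here after the loop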
+
+ // SPIR-V requires the continuing block to be named by `OpLoopMerge`,
+ // so we have to start a new block (the preamble) to hold that instruction.
+ block = Block::new(preamble_id);
+ // HACK: the loop statement begins with a branch instruction, so we
+ // need to emit the `OpLine` debug info before the merge instruction.
+ if let Some(debug_info) = debug_info {
+ let loc: crate::SourceLocation = span.location(debug_info.source_code);
+ block.body.push(Instruction::line(
+ debug_info.source_file_id,
+ loc.line_number,
+ loc.line_position,
+ ))
+ }
+ block.body.push(Instruction::loop_merge(
+ merge_id,
+ continuing_id,
+ spirv::SelectionControl::NONE,
+ ));
+ self.function.consume(block, Instruction::branch(body_id));
+
+ self.write_block(
+ body_id,
+ body,
+ BlockExit::Branch {
+ target: continuing_id,
+ },
+ LoopContext {
+ continuing_id: Some(continuing_id),
+ break_id: Some(merge_id),
+ },
+ debug_info,
+ )?;
+
+ let exit = match break_if {
+ Some(condition) => BlockExit::BreakIf {
+ condition,
+ preamble_id,
+ },
+ None => BlockExit::Branch {
+ target: preamble_id,
+ },
+ };
+
+ self.write_block(
+ continuing_id,
+ continuing,
+ exit,
+ LoopContext {
+ continuing_id: None,
+ break_id: Some(merge_id),
+ },
+ debug_info,
+ )?;
+
+ block = Block::new(merge_id);
+ }
+ crate::Statement::Break => {
+ self.function
+ .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
+ return Ok(());
+ }
+ crate::Statement::Continue => {
+ self.function.consume(
+ block,
+ Instruction::branch(loop_context.continuing_id.unwrap()),
+ );
+ return Ok(());
+ }
+ crate::Statement::Return { value: Some(value) } => {
+ let value_id = self.cached[value];
+ let instruction = match self.function.entry_point_context {
+ // If this is an entry point, and we need to return anything,
+ // let's instead store the output variables and return `void`.
+ Some(ref context) => {
+ self.writer.write_entry_point_return(
+ value_id,
+ self.ir_function.result.as_ref().unwrap(),
+ &context.results,
+ &mut block.body,
+ )?;
+ Instruction::return_void()
+ }
+ None => Instruction::return_value(value_id),
+ };
+ self.function.consume(block, instruction);
+ return Ok(());
+ }
+ crate::Statement::Return { value: None } => {
+ self.function.consume(block, Instruction::return_void());
+ return Ok(());
+ }
+ crate::Statement::Kill => {
+ self.function.consume(block, Instruction::kill());
+ return Ok(());
+ }
+ crate::Statement::Barrier(flags) => {
+ self.writer.write_barrier(flags, &mut block);
+ }
+ crate::Statement::Store { pointer, value } => {
+ let value_id = self.cached[value];
+ match self.write_expression_pointer(pointer, &mut block, None)? {
+ ExpressionPointer::Ready { pointer_id } => {
+ let atomic_space = match *self.fun_info[pointer]
+ .ty
+ .inner_with(&self.ir_module.types)
+ {
+ crate::TypeInner::Pointer { base, space } => {
+ match self.ir_module.types[base].inner {
+ crate::TypeInner::Atomic { .. } => Some(space),
+ _ => None,
+ }
+ }
+ _ => None,
+ };
+ let instruction = if let Some(space) = atomic_space {
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ Instruction::atomic_store(
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ } else {
+ Instruction::store(pointer_id, value_id, None)
+ };
+ block.body.push(instruction);
+ }
+ ExpressionPointer::Conditional { condition, access } => {
+ let mut selection = Selection::start(&mut block, ());
+ selection.if_true(self, condition, ());
+
+ // The in-bounds path. Perform the access and the store.
+ let pointer_id = access.result_id.unwrap();
+ selection.block().body.push(access);
+ selection
+ .block()
+ .body
+ .push(Instruction::store(pointer_id, value_id, None));
+
+ // Finish the in-bounds block and start the merge block. This
+ // is the block we'll leave current on return.
+ selection.finish(self, ());
+ }
+ };
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
+ crate::Statement::Call {
+ function: local_function,
+ ref arguments,
+ result,
+ } => {
+ let id = self.gen_id();
+ self.temp_list.clear();
+ for &argument in arguments {
+ self.temp_list.push(self.cached[argument]);
+ }
+
+ let type_id = match result {
+ Some(expr) => {
+ self.cached[expr] = id;
+ self.get_expression_type_id(&self.fun_info[expr].ty)
+ }
+ None => self.writer.void_type,
+ };
+
+ block.body.push(Instruction::function_call(
+ type_id,
+ id,
+ self.writer.lookup_function[&local_function],
+ &self.temp_list,
+ ));
+ }
+ crate::Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ let id = self.gen_id();
+ let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
+
+ self.cached[result] = id;
+
+ let pointer_id =
+ match self.write_expression_pointer(pointer, &mut block, None)? {
+ ExpressionPointer::Ready { pointer_id } => pointer_id,
+ ExpressionPointer::Conditional { .. } => {
+ return Err(Error::FeatureNotImplemented(
+ "Atomics out-of-bounds handling",
+ ));
+ }
+ };
+
+ let space = self.fun_info[pointer]
+ .ty
+ .inner_with(&self.ir_module.types)
+ .pointer_space()
+ .unwrap();
+ let (semantics, scope) = space.to_spirv_semantics_and_scope();
+ let scope_constant_id = self.get_scope_constant(scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ let value_id = self.cached[value];
+ let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
+
+ let instruction = match *fun {
+ crate::AtomicFunction::Add => Instruction::atomic_binary(
+ spirv::Op::AtomicIAdd,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::Subtract => Instruction::atomic_binary(
+ spirv::Op::AtomicISub,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::And => Instruction::atomic_binary(
+ spirv::Op::AtomicAnd,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary(
+ spirv::Op::AtomicOr,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary(
+ spirv::Op::AtomicXor,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ ),
+ crate::AtomicFunction::Min => {
+ let spirv_op = match *value_inner {
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ }) => spirv::Op::AtomicSMin,
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ }) => spirv::Op::AtomicUMin,
+ _ => unimplemented!(),
+ };
+ Instruction::atomic_binary(
+ spirv_op,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Max => {
+ let spirv_op = match *value_inner {
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Sint,
+ width: _,
+ }) => spirv::Op::AtomicSMax,
+ crate::TypeInner::Scalar(crate::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: _,
+ }) => spirv::Op::AtomicUMax,
+ _ => unimplemented!(),
+ };
+ Instruction::atomic_binary(
+ spirv_op,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Exchange { compare: None } => {
+ Instruction::atomic_binary(
+ spirv::Op::AtomicExchange,
+ result_type_id,
+ id,
+ pointer_id,
+ scope_constant_id,
+ semantics_id,
+ value_id,
+ )
+ }
+ crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
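+ // A compare-exchange result in the IR is an (original value,
+ // exchanged) pair; lower it (a sketch) to
+ // `OpAtomicCompareExchange` plus an `OpIEqual` against the
+ // comparand, assembled with the `OpCompositeConstruct` below.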
+ let scalar_type_id = match *value_inner {
+ crate::TypeInner::Scalar(scalar) => {
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar,
+ pointer_space: None,
+ }))
+ }
+ _ => unimplemented!(),
+ };
+ let bool_type_id =
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::BOOL,
+ pointer_space: None,
+ }));
+
+ let cas_result_id = self.gen_id();
+ let equality_result_id = self.gen_id();
+ let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
+ cas_instr.set_type(scalar_type_id);
+ cas_instr.set_result(cas_result_id);
+ cas_instr.add_operand(pointer_id);
+ cas_instr.add_operand(scope_constant_id);
+ cas_instr.add_operand(semantics_id); // semantics if equal
+ cas_instr.add_operand(semantics_id); // semantics if not equal
+ cas_instr.add_operand(value_id);
+ cas_instr.add_operand(self.cached[cmp]);
+ block.body.push(cas_instr);
+ block.body.push(Instruction::binary(
+ spirv::Op::IEqual,
+ bool_type_id,
+ equality_result_id,
+ cas_result_id,
+ self.cached[cmp],
+ ));
+ Instruction::composite_construct(
+ result_type_id,
+ id,
+ &[cas_result_id, equality_result_id],
+ )
+ }
+ };
+
+ block.body.push(instruction);
+ }
+ crate::Statement::WorkGroupUniformLoad { pointer, result } => {
+ self.writer
+ .write_barrier(crate::Barrier::WORK_GROUP, &mut block);
+ let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
+ // Load through the pointer between the two barriers, just as a
+ // plain `Expression::Load` would.
+ match self.write_expression_pointer(pointer, &mut block, None)? {
+ ExpressionPointer::Ready { pointer_id } => {
+ let id = self.gen_id();
+ block.body.push(Instruction::load(
+ result_type_id,
+ id,
+ pointer_id,
+ None,
+ ));
+ self.cached[result] = id;
+ }
+ ExpressionPointer::Conditional { condition, access } => {
+ self.cached[result] = self.write_conditional_indexed_load(
+ result_type_id,
+ condition,
+ &mut block,
+ move |id_gen, block| {
+ // The in-bounds path. Perform the access and the load.
+ let pointer_id = access.result_id.unwrap();
+ let value_id = id_gen.next();
+ block.body.push(access);
+ block.body.push(Instruction::load(
+ result_type_id,
+ value_id,
+ pointer_id,
+ None,
+ ));
+ value_id
+ },
+ )
+ }
+ }
+ self.writer
+ .write_barrier(crate::Barrier::WORK_GROUP, &mut block);
+ }
+ crate::Statement::RayQuery { query, ref fun } => {
+ self.write_ray_query_function(query, fun, &mut block);
+ }
+ }
+ }
+
+ let termination = match exit {
+ // We're generating code for the top-level Block of the function, so we
+ // need to end it with some kind of return instruction.
+ BlockExit::Return => match self.ir_function.result {
+ Some(ref result) if self.function.entry_point_context.is_none() => {
+ let type_id = self.get_type_id(LookupType::Handle(result.ty));
+ let null_id = self.writer.get_constant_null(type_id);
+ Instruction::return_value(null_id)
+ }
+ _ => Instruction::return_void(),
+ },
+ BlockExit::Branch { target } => Instruction::branch(target),
+ BlockExit::BreakIf {
+ condition,
+ preamble_id,
+ } => {
+ let condition_id = self.cached[condition];
+
+ Instruction::branch_conditional(
+ condition_id,
+ loop_context.break_id.unwrap(),
+ preamble_id,
+ )
+ }
+ };
+
+ self.function.consume(block, termination);
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs
new file mode 100644
index 0000000000..5b6226db85
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/helpers.rs
@@ -0,0 +1,109 @@
+use crate::{Handle, UniqueArena};
+use spirv::Word;
+
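+/// Pack `bytes` into a vector of SPIR-V words, little-endian within each
+/// word, as SPIR-V requires for literal strings. For example,
+/// `bytes_to_words(b"abcd")` yields `[0x6463_6261]`.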
+pub(super) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> {
+ bytes
+ .chunks(4)
+ .map(|chars| chars.iter().rev().fold(0u32, |u, c| (u << 8) | *c as u32))
+ .collect()
+}
+
+pub(super) fn string_to_words(input: &str) -> Vec<Word> {
+ let bytes = input.as_bytes();
+ let mut words = bytes_to_words(bytes);
+
+ if bytes.len() % 4 == 0 {
+ // The string must be NUL-terminated; when the bytes exactly fill the
+ // last word, append a word of zeros. Otherwise the unused high bytes
+ // of the final word are already zero and serve as the terminator.
+ words.push(0x0u32);
+ }
+
+ words
+}
+
+pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::StorageClass {
+ match space {
+ crate::AddressSpace::Handle => spirv::StorageClass::UniformConstant,
+ crate::AddressSpace::Function => spirv::StorageClass::Function,
+ crate::AddressSpace::Private => spirv::StorageClass::Private,
+ crate::AddressSpace::Storage { .. } => spirv::StorageClass::StorageBuffer,
+ crate::AddressSpace::Uniform => spirv::StorageClass::Uniform,
+ crate::AddressSpace::WorkGroup => spirv::StorageClass::Workgroup,
+ crate::AddressSpace::PushConstant => spirv::StorageClass::PushConstant,
+ }
+}
+
+pub(super) fn contains_builtin(
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ arena: &UniqueArena<crate::Type>,
+ built_in: crate::BuiltIn,
+) -> bool {
+ if let Some(&crate::Binding::BuiltIn(bi)) = binding {
+ bi == built_in
+ } else if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner {
+ members
+ .iter()
+ .any(|member| contains_builtin(member.binding.as_ref(), member.ty, arena, built_in))
+ } else {
+ false // unreachable
+ }
+}
+
+impl crate::AddressSpace {
+ pub(super) const fn to_spirv_semantics_and_scope(
+ self,
+ ) -> (spirv::MemorySemantics, spirv::Scope) {
+ match self {
+ Self::Storage { .. } => (spirv::MemorySemantics::UNIFORM_MEMORY, spirv::Scope::Device),
+ Self::WorkGroup => (
+ spirv::MemorySemantics::WORKGROUP_MEMORY,
+ spirv::Scope::Workgroup,
+ ),
+ _ => (spirv::MemorySemantics::empty(), spirv::Scope::Invocation),
+ }
+ }
+}
+
+/// Return true if the global requires a type decorated with `Block`.
+///
+/// Vulkan spec v1.3 §15.6.2, "Descriptor Set Interface", says:
+///
+/// > Variables identified with the `Uniform` storage class are used to
+/// > access transparent buffer backed resources. Such variables must
+/// > be:
+/// >
+/// > - typed as `OpTypeStruct`, or an array of this type,
+/// >
+/// > - identified with a `Block` or `BufferBlock` decoration, and
+/// >
+/// > - laid out explicitly using the `Offset`, `ArrayStride`, and
+/// > `MatrixStride` decorations as specified in §15.6.4, "Offset
+/// > and Stride Assignment."
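+///
+/// For example (a sketch), a `var<uniform> scale: f32;` global gets wrapped
+/// in a synthesized struct holding a single `f32` so that the struct can
+/// carry the `Block` decoration, while a storage struct whose last member is
+/// a runtime-sized array cannot be copied or wrapped and is decorated
+/// directly.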
+// See `back::spv::GlobalVariable::access_id` for details.
+pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariable) -> bool {
+ match var.space {
+ crate::AddressSpace::Uniform
+ | crate::AddressSpace::Storage { .. }
+ | crate::AddressSpace::PushConstant => {}
+ _ => return false,
+ };
+ match ir_module.types[var.ty].inner {
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => match members.last() {
+ Some(member) => match ir_module.types[member.ty].inner {
+ // Structs with dynamically sized arrays can't be copied and can't be wrapped.
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } => false,
+ _ => true,
+ },
+ None => false,
+ },
+ crate::TypeInner::BindingArray { .. } => false,
+ // If it's not a structure or a binding array, wrap it so we can apply the "Block" decoration.
+ _ => true,
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/image.rs b/third_party/rust/naga/src/back/spv/image.rs
new file mode 100644
index 0000000000..c0fc41cbb6
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/image.rs
@@ -0,0 +1,1210 @@
+/*!
+Generating SPIR-V for image operations.
+*/
+
+use super::{
+ selection::{MergeTuple, Selection},
+ Block, BlockContext, Error, IdGenerator, Instruction, LocalType, LookupType,
+};
+use crate::arena::Handle;
+use spirv::Word;
+
+/// Information about a vector of coordinates.
+///
+/// The coordinate vectors expected by SPIR-V `OpImageRead` and `OpImageFetch`
+/// supply the array index for arrayed images as an additional component at
+/// the end, whereas Naga's `ImageLoad`, `ImageStore`, and `ImageSample` carry
+/// the array index as a separate field.
+///
+/// In the process of generating code to compute the combined vector, we also
+/// produce SPIR-V types and vector lengths that are useful elsewhere. This
+/// struct gathers that information into one place, with standard names.
+struct ImageCoordinates {
+ /// The SPIR-V id of the combined coordinate/index vector value.
+ ///
+ /// Note: when indexing a non-arrayed 1D image, this will be a scalar.
+ value_id: Word,
+
+ /// The SPIR-V id of the type of `value`.
+ type_id: Word,
+
+ /// The number of components in `value`, if it is a vector, or `None` if it
+ /// is a scalar.
+ size: Option<crate::VectorSize>,
+}
+
+/// A trait for image access (load or store) code generators.
+///
+/// Types implementing this trait hold information about an `ImageStore` or
+/// `ImageLoad` operation that is not affected by the bounds check policy. The
+/// `generate` method emits code for the access, given the results of bounds
+/// checking.
+///
+/// The [`image`] bounds checks policy affects access coordinates, level of
+/// detail, and sample index, but never the image id, result type (if any), or
+/// the specific SPIR-V instruction used. Types that implement this trait gather
+/// together the latter category, so we don't have to plumb them through the
+/// bounds-checking code.
+///
+/// [`image`]: crate::proc::BoundsCheckPolicies::index
+trait Access {
+ /// The Rust type that represents SPIR-V values and types for this access.
+ ///
+ /// For operations like loads, this is `Word`. For operations like stores,
+ /// this is `()`.
+ ///
+ /// For `ReadZeroSkipWrite`, this will be the type of the selection
+ /// construct that performs the bounds checks, so it must implement
+ /// `MergeTuple`.
+ type Output: MergeTuple + Copy + Clone;
+
+ /// Write an image access to `block`.
+ ///
+ /// Access the texel at `coordinates_id`. The optional `level_id` indicates
+ /// the level of detail, and `sample_id` is the index of the sample to
+ /// access in a multisampled texel.
+ ///
+ /// This method assumes that `coordinates_id` has already had the image array
+ /// index, if any, folded in, as done by `write_image_coordinates`.
+ ///
+ /// Return the value id produced by the instruction, if any.
+ ///
+ /// Use `id_gen` to generate SPIR-V ids as necessary.
+ fn generate(
+ &self,
+ id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Self::Output;
+
+ /// Return the SPIR-V type of the value produced by the code written by
+ /// `generate`. If the access does not produce a value, `Self::Output`
+ /// should be `()`.
+ fn result_type(&self) -> Self::Output;
+
+ /// Construct the SPIR-V 'zero' value to be returned for an out-of-bounds
+ /// access under the `ReadZeroSkipWrite` policy. If the access does not
+ /// produce a value, `Self::Output` should be `()`.
+ fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Self::Output;
+}
+
+/// Texel access information for an [`ImageLoad`] expression.
+///
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+struct Load {
+ /// The specific opcode we'll use to perform the fetch. Storage images
+ /// require `OpImageRead`, while sampled images require `OpImageFetch`.
+ opcode: spirv::Op,
+
+ /// The type id produced by the actual image access instruction.
+ type_id: Word,
+
+ /// The id of the image being accessed.
+ image_id: Word,
+}
+
+impl Load {
+ fn from_image_expr(
+ ctx: &mut BlockContext<'_>,
+ image_id: Word,
+ image_class: crate::ImageClass,
+ result_type_id: Word,
+ ) -> Result<Load, Error> {
+ let opcode = match image_class {
+ crate::ImageClass::Storage { .. } => spirv::Op::ImageRead,
+ crate::ImageClass::Depth { .. } | crate::ImageClass::Sampled { .. } => {
+ spirv::Op::ImageFetch
+ }
+ };
+
+ // `OpImageRead` and `OpImageFetch` instructions produce vec4<f32>
+ // values. Most of the time, we can just use `result_type_id` for
+ // this. The exception is that `Expression::ImageLoad` from a depth
+ // image produces a scalar `f32`, so in that case we need to find
+ // the right SPIR-V type for the access instruction here.
+ let type_id = match image_class {
+ crate::ImageClass::Depth { .. } => {
+ ctx.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Quad),
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }))
+ }
+ _ => result_type_id,
+ };
+
+ Ok(Load {
+ opcode,
+ type_id,
+ image_id,
+ })
+ }
+}
+
+impl Access for Load {
+ type Output = Word;
+
+ /// Write an instruction to access a given texel of this image.
+ fn generate(
+ &self,
+ id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Word {
+ let texel_id = id_gen.next();
+ let mut instruction = Instruction::image_fetch_or_read(
+ self.opcode,
+ self.type_id,
+ texel_id,
+ self.image_id,
+ coordinates_id,
+ );
+
+ match (level_id, sample_id) {
+ (None, None) => {}
+ (Some(level_id), None) => {
+ instruction.add_operand(spirv::ImageOperands::LOD.bits());
+ instruction.add_operand(level_id);
+ }
+ (None, Some(sample_id)) => {
+ instruction.add_operand(spirv::ImageOperands::SAMPLE.bits());
+ instruction.add_operand(sample_id);
+ }
+ // There's no such thing as a multi-sampled mipmap.
+ (Some(_), Some(_)) => unreachable!(),
+ }
+
+ block.body.push(instruction);
+
+ texel_id
+ }
+
+ fn result_type(&self) -> Word {
+ self.type_id
+ }
+
+ fn out_of_bounds_value(&self, ctx: &mut BlockContext<'_>) -> Word {
+ ctx.writer.get_constant_null(self.type_id)
+ }
+}
+
+/// Texel access information for an [`ImageStore`] statement.
+///
+/// [`ImageStore`]: crate::Statement::ImageStore
+struct Store {
+ /// The id of the image being written to.
+ image_id: Word,
+
+ /// The value we're going to write to the texel.
+ value_id: Word,
+}
+
+impl Access for Store {
+ /// Stores don't generate any value.
+ type Output = ();
+
+ fn generate(
+ &self,
+ _id_gen: &mut IdGenerator,
+ coordinates_id: Word,
+ _level_id: Option<Word>,
+ _sample_id: Option<Word>,
+ block: &mut Block,
+ ) {
+ block.body.push(Instruction::image_write(
+ self.image_id,
+ coordinates_id,
+ self.value_id,
+ ));
+ }
+
+ /// Stores don't generate any value, so this just returns `()`.
+ fn result_type(&self) {}
+
+ /// Stores don't generate any value, so this just returns `()`.
+ fn out_of_bounds_value(&self, _ctx: &mut BlockContext<'_>) {}
+}
+
+impl<'w> BlockContext<'w> {
+ /// Extend image coordinates with an array index, if necessary.
+ ///
+ /// Whereas [`Expression::ImageLoad`] and [`ImageSample`] treat the array
+ /// index as a separate operand from the coordinates, SPIR-V image access
+ /// instructions include the array index in the `coordinates` operand. This
+ /// function builds a SPIR-V coordinate vector from a Naga coordinate vector
+ /// and array index, if one is supplied, and returns an `ImageCoordinates`
+ /// struct describing what it built.
+ ///
+ /// If `array_index` is `Some(expr)`, then this function constructs a new
+ /// vector that is `coordinates` with `array_index` concatenated onto the
+ /// end: a `vec2` becomes a `vec3`, a scalar becomes a `vec2`, and so on.
+ ///
+ /// If `array_index` is `None`, then the return value uses `coordinates`
+ /// unchanged. Note that, when indexing a non-arrayed 1D image, this will be
+ /// a scalar value.
+ ///
+ /// If needed, this function generates code to convert the array index,
+ /// always an integer scalar, to match the component type of `coordinates`.
+ /// Naga's `ImageLoad` and SPIR-V's `OpImageRead`, `OpImageFetch`, and
+ /// `OpImageWrite` all use integer coordinates, while Naga's `ImageSample`
+ /// and SPIR-V's `OpImageSample...` instructions all take floating-point
+ /// coordinate vectors.
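+ ///
+ /// For example (a sketch), loading from a 2D array texture with
+ /// `vec2<i32>` coordinates and a `u32` array index emits an `OpBitcast`
+ /// of the index to `i32` followed by an `OpCompositeConstruct` producing
+ /// the combined `vec3<i32>` coordinate.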
+ ///
+ /// [`Expression::ImageLoad`]: crate::Expression::ImageLoad
+ /// [`ImageSample`]: crate::Expression::ImageSample
+ fn write_image_coordinates(
+ &mut self,
+ coordinates: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<ImageCoordinates, Error> {
+ use crate::TypeInner as Ti;
+ use crate::VectorSize as Vs;
+
+ let coordinates_id = self.cached[coordinates];
+ let ty = &self.fun_info[coordinates].ty;
+ let inner_ty = ty.inner_with(&self.ir_module.types);
+
+ // If there's no array index, the image coordinates are exactly the
+ // `coordinate` field of the `Expression::ImageLoad`. No work is needed.
+ let array_index = match array_index {
+ None => {
+ let value_id = coordinates_id;
+ let type_id = self.get_expression_type_id(ty);
+ let size = match *inner_ty {
+ Ti::Scalar { .. } => None,
+ Ti::Vector { size, .. } => Some(size),
+ _ => return Err(Error::Validation("coordinate type")),
+ };
+ return Ok(ImageCoordinates {
+ value_id,
+ type_id,
+ size,
+ });
+ }
+ Some(ix) => ix,
+ };
+
+ // Find the component type of `coordinates`, and figure out the size the
+ // combined coordinate vector will have.
+ let (component_scalar, size) = match *inner_ty {
+ Ti::Scalar(scalar @ crate::Scalar { width: 4, .. }) => (scalar, Some(Vs::Bi)),
+ Ti::Vector {
+ scalar: scalar @ crate::Scalar { width: 4, .. },
+ size: Vs::Bi,
+ } => (scalar, Some(Vs::Tri)),
+ Ti::Vector {
+ scalar: scalar @ crate::Scalar { width: 4, .. },
+ size: Vs::Tri,
+ } => (scalar, Some(Vs::Quad)),
+ Ti::Vector { size: Vs::Quad, .. } => {
+ return Err(Error::Validation("extending vec4 coordinate"));
+ }
+ ref other => {
+ log::error!("wrong coordinate type {:?}", other);
+ return Err(Error::Validation("coordinate type"));
+ }
+ };
+
+ // Convert the index to the coordinate component type, if necessary.
+ let array_index_id = self.cached[array_index];
+ let ty = &self.fun_info[array_index].ty;
+ let inner_ty = ty.inner_with(&self.ir_module.types);
+ let array_index_scalar = match *inner_ty {
+ Ti::Scalar(
+ scalar @ crate::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ width: 4,
+ },
+ ) => scalar,
+ _ => unreachable!("we only allow i32 and u32"),
+ };
+ let cast = match (component_scalar.kind, array_index_scalar.kind) {
+ (crate::ScalarKind::Sint, crate::ScalarKind::Sint)
+ | (crate::ScalarKind::Uint, crate::ScalarKind::Uint) => None,
+ (crate::ScalarKind::Sint, crate::ScalarKind::Uint)
+ | (crate::ScalarKind::Uint, crate::ScalarKind::Sint) => Some(spirv::Op::Bitcast),
+ (crate::ScalarKind::Float, crate::ScalarKind::Sint) => Some(spirv::Op::ConvertSToF),
+ (crate::ScalarKind::Float, crate::ScalarKind::Uint) => Some(spirv::Op::ConvertUToF),
+ (crate::ScalarKind::Bool, _) => unreachable!("we don't allow bool for component"),
+ (_, crate::ScalarKind::Bool | crate::ScalarKind::Float) => {
+ unreachable!("we don't allow bool or float for array index")
+ }
+ (crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat, _)
+ | (_, crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat) => {
+ unreachable!("abstract types should never reach backends")
+ }
+ };
+ let reconciled_array_index_id = if let Some(cast) = cast {
+ let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: component_scalar,
+ pointer_space: None,
+ }));
+ let reconciled_id = self.gen_id();
+ block.body.push(Instruction::unary(
+ cast,
+ component_ty_id,
+ reconciled_id,
+ array_index_id,
+ ));
+ reconciled_id
+ } else {
+ array_index_id
+ };
+
+ // Find the SPIR-V type for the combined coordinates/index vector.
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: size,
+ scalar: component_scalar,
+ pointer_space: None,
+ }));
+
+ // Schmear the coordinates and index together.
+ let value_id = self.gen_id();
+ block.body.push(Instruction::composite_construct(
+ type_id,
+ value_id,
+ &[coordinates_id, reconciled_array_index_id],
+ ));
+ Ok(ImageCoordinates {
+ value_id,
+ type_id,
+ size,
+ })
+ }
+
+ pub(super) fn get_handle_id(&mut self, expr_handle: Handle<crate::Expression>) -> Word {
+ let id = match self.ir_function.expressions[expr_handle] {
+ crate::Expression::GlobalVariable(handle) => {
+ self.writer.global_variables[handle.index()].handle_id
+ }
+ crate::Expression::FunctionArgument(i) => {
+ self.function.parameters[i as usize].handle_id
+ }
+ crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => {
+ self.cached[expr_handle]
+ }
+ ref other => unreachable!("Unexpected image expression {:?}", other),
+ };
+
+ if id == 0 {
+ unreachable!(
+ "Image expression {:?} doesn't have a handle ID",
+ expr_handle
+ );
+ }
+
+ id
+ }
+
+ /// Generate a vector or scalar 'one' for arithmetic on `coordinates`.
+ ///
+ /// If `coordinates` is a scalar, return a scalar one. Otherwise, return
+ /// a vector of ones.
+ fn write_coordinate_one(&mut self, coordinates: &ImageCoordinates) -> Result<Word, Error> {
+ let one = self.get_scope_constant(1);
+ match coordinates.size {
+ None => Ok(one),
+ Some(vector_size) => {
+ let ones = [one; 4];
+ let id = self.gen_id();
+ Instruction::constant_composite(
+ coordinates.type_id,
+ id,
+ &ones[..vector_size as usize],
+ )
+ .to_words(&mut self.writer.logical_layout.declarations);
+ Ok(id)
+ }
+ }
+ }
+
+ /// Generate code to restrict `input` to fall between zero and one less than
+ /// `size_id`.
+ ///
+ /// Both must be 32-bit scalar integer values, whose type is given by
+ /// `type_id`. The computed value is also of type `type_id`.
+ fn restrict_scalar(
+ &mut self,
+ type_id: Word,
+ input_id: Word,
+ size_id: Word,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let i32_one_id = self.get_scope_constant(1);
+
+ // Subtract one from `size` to get the largest valid value.
+ let limit_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ type_id,
+ limit_id,
+ size_id,
+ i32_one_id,
+ ));
+
+ // Use an unsigned minimum, to handle both positive out-of-range values
+ // and negative values in a single instruction: negative values of
+ // `input_id` get treated as very large positive values.
+ let restricted_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ type_id,
+ restricted_id,
+ &[input_id, limit_id],
+ ));
+
+ Ok(restricted_id)
+ }
+
+ /// Write instructions to query the size of an image.
+ ///
+ /// This takes care of selecting the right instruction depending on whether
+ /// a level of detail parameter is present.
+ fn write_coordinate_bounds(
+ &mut self,
+ type_id: Word,
+ image_id: Word,
+ level_id: Option<Word>,
+ block: &mut Block,
+ ) -> Word {
+ let coordinate_bounds_id = self.gen_id();
+ match level_id {
+ Some(level_id) => {
+ // A level of detail was provided, so fetch the image size for
+ // that level.
+ let mut inst = Instruction::image_query(
+ spirv::Op::ImageQuerySizeLod,
+ type_id,
+ coordinate_bounds_id,
+ image_id,
+ );
+ inst.add_operand(level_id);
+ block.body.push(inst);
+ }
+ _ => {
+ // No level of detail was given.
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySize,
+ type_id,
+ coordinate_bounds_id,
+ image_id,
+ ));
+ }
+ }
+
+ coordinate_bounds_id
+ }
+
+ /// Write code to restrict coordinates for an image reference.
+ ///
+ /// First, clamp the level of detail or sample index to fall within bounds.
+ /// Then, obtain the image size, possibly using the clamped level of detail.
+ /// Finally, use an unsigned minimum instruction to force all coordinates
+ /// into range.
+ ///
+ /// Return a triple `(COORDS, LEVEL, SAMPLE)`, where `COORDS` is a coordinate
+ /// vector (including the array index, if any), `LEVEL` is an optional level
+ /// of detail, and `SAMPLE` is an optional sample index, all guaranteed to
+ /// be in-bounds for `image_id`.
+ ///
+ /// The result is usually a vector, but it is a scalar when indexing
+ /// non-arrayed 1D images.
+ fn write_restricted_coordinates(
+ &mut self,
+ image_id: Word,
+ coordinates: ImageCoordinates,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ ) -> Result<(Word, Option<Word>, Option<Word>), Error> {
+ self.writer.require_any(
+ "the `Restrict` image bounds check policy",
+ &[spirv::Capability::ImageQuery],
+ )?;
+
+ let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::I32,
+ pointer_space: None,
+ }));
+
+ // If `level` is `Some`, clamp it to fall within bounds. This must
+ // happen first, because we'll use it to query the image size for
+ // clamping the actual coordinates.
+ let level_id = level_id
+ .map(|level_id| {
+ // Find the number of mipmap levels in this image.
+ let num_levels_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ i32_type_id,
+ num_levels_id,
+ image_id,
+ ));
+
+ self.restrict_scalar(i32_type_id, level_id, num_levels_id, block)
+ })
+ .transpose()?;
+
+ // If `sample_id` is `Some`, clamp it to fall within bounds.
+ let sample_id = sample_id
+ .map(|sample_id| {
+ // Find the number of samples per texel.
+ let num_samples_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ i32_type_id,
+ num_samples_id,
+ image_id,
+ ));
+
+ self.restrict_scalar(i32_type_id, sample_id, num_samples_id, block)
+ })
+ .transpose()?;
+
+ // Obtain the image bounds, including the array element count.
+ let coordinate_bounds_id =
+ self.write_coordinate_bounds(coordinates.type_id, image_id, level_id, block);
+
+ // Compute maximum valid values from the bounds.
+ let ones = self.write_coordinate_one(&coordinates)?;
+ let coordinate_limit_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ coordinates.type_id,
+ coordinate_limit_id,
+ coordinate_bounds_id,
+ ones,
+ ));
+
+ // Restrict the coordinates to fall within those bounds.
+ //
+ // Use an unsigned minimum, to handle both positive out-of-range values
+ // and negative values in a single instruction: negative values of
+ // `coordinates` get treated as very large positive values.
+ let restricted_coordinates_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ coordinates.type_id,
+ restricted_coordinates_id,
+ &[coordinates.value_id, coordinate_limit_id],
+ ));
+
+ Ok((restricted_coordinates_id, level_id, sample_id))
+ }
+
+ fn write_conditional_image_access<A: Access>(
+ &mut self,
+ image_id: Word,
+ coordinates: ImageCoordinates,
+ level_id: Option<Word>,
+ sample_id: Option<Word>,
+ block: &mut Block,
+ access: &A,
+ ) -> Result<A::Output, Error> {
+ self.writer.require_any(
+ "the `ReadZeroSkipWrite` image bounds check policy",
+ &[spirv::Capability::ImageQuery],
+ )?;
+
+ let bool_type_id = self.writer.get_bool_type_id();
+ let i32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::I32,
+ pointer_space: None,
+ }));
+
+ let null_id = access.out_of_bounds_value(self);
+
+ let mut selection = Selection::start(block, access.result_type());
+
+ // If `level_id` is `Some`, check whether it is within bounds. This must
+ // happen first, because we'll be supplying this as an argument when we
+ // query the image size.
+ if let Some(level_id) = level_id {
+ // Find the number of mipmap levels in this image.
+ let num_levels_id = self.gen_id();
+ selection.block().body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ i32_type_id,
+ num_levels_id,
+ image_id,
+ ));
+
+ let lod_cond_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ bool_type_id,
+ lod_cond_id,
+ level_id,
+ num_levels_id,
+ ));
+
+ selection.if_true(self, lod_cond_id, null_id);
+ }
+
+ // If `sample_id` is `Some`, check whether it is in bounds.
+ if let Some(sample_id) = sample_id {
+ // Find the number of samples per texel.
+ let num_samples_id = self.gen_id();
+ selection.block().body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ i32_type_id,
+ num_samples_id,
+ image_id,
+ ));
+
+ let samples_cond_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ bool_type_id,
+ samples_cond_id,
+ sample_id,
+ num_samples_id,
+ ));
+
+ selection.if_true(self, samples_cond_id, null_id);
+ }
+
+ // Obtain the image bounds, including any array element count.
+ let coordinate_bounds_id = self.write_coordinate_bounds(
+ coordinates.type_id,
+ image_id,
+ level_id,
+ selection.block(),
+ );
+
+ // Compare the coordinates against the bounds.
+ let coords_bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: coordinates.size,
+ scalar: crate::Scalar::BOOL,
+ pointer_space: None,
+ }));
+ let coords_conds_id = self.gen_id();
+ selection.block().body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ coords_bool_type_id,
+ coords_conds_id,
+ coordinates.value_id,
+ coordinate_bounds_id,
+ ));
+
+ // If the comparison above was a vector comparison, then we need to
+ // check that all components of the comparison are true.
+ let coords_cond_id = if coords_bool_type_id != bool_type_id {
+ let id = self.gen_id();
+ selection.block().body.push(Instruction::relational(
+ spirv::Op::All,
+ bool_type_id,
+ id,
+ coords_conds_id,
+ ));
+ id
+ } else {
+ coords_conds_id
+ };
+
+ selection.if_true(self, coords_cond_id, null_id);
+
+ // All conditions are met. We can carry out the access.
+ let texel_id = access.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ level_id,
+ sample_id,
+ selection.block(),
+ );
+
+ // This, then, is the value of the 'true' branch.
+ Ok(selection.finish(self, texel_id))
+ }
+
+ /// Generate code for an `ImageLoad` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageLoad` variant.
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn write_image_load(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ level: Option<Handle<crate::Expression>>,
+ sample: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let image_id = self.get_handle_id(image);
+ let image_type = self.fun_info[image].ty.inner_with(&self.ir_module.types);
+ let image_class = match *image_type {
+ crate::TypeInner::Image { class, .. } => class,
+ _ => return Err(Error::Validation("image type")),
+ };
+
+ let access = Load::from_image_expr(self, image_id, image_class, result_type_id)?;
+ let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
+
+ let level_id = level.map(|expr| self.cached[expr]);
+ let sample_id = sample.map(|expr| self.cached[expr]);
+
+ // Perform the access, according to the bounds check policy.
+ let access_id = match self.writer.bounds_check_policies.image_load {
+ crate::proc::BoundsCheckPolicy::Restrict => {
+ let (coords, level_id, sample_id) = self.write_restricted_coordinates(
+ image_id,
+ coordinates,
+ level_id,
+ sample_id,
+ block,
+ )?;
+ access.generate(&mut self.writer.id_gen, coords, level_id, sample_id, block)
+ }
+ crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => self
+ .write_conditional_image_access(
+ image_id,
+ coordinates,
+ level_id,
+ sample_id,
+ block,
+ &access,
+ )?,
+ crate::proc::BoundsCheckPolicy::Unchecked => access.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ level_id,
+ sample_id,
+ block,
+ ),
+ };
+
+ // For depth images, `ImageLoad` expressions produce a single f32,
+ // whereas the SPIR-V instructions always produce a vec4. So we may have
+ // to pull out the component we need.
+ let result_id = if result_type_id == access.result_type() {
+ // The instruction produced the type we expected. We can use
+ // its result as-is.
+ access_id
+ } else {
+ // For `ImageClass::Depth` images, SPIR-V gave us four components,
+ // but we only want the first one.
+ let component_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ component_id,
+ access_id,
+ &[0],
+ ));
+ component_id
+ };
+
+ Ok(result_id)
+ }
+
+ /// Generate code for an `ImageSample` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageSample` variant.
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn write_image_sample(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ sampler: Handle<crate::Expression>,
+ gather: Option<crate::SwizzleComponent>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ offset: Option<Handle<crate::Expression>>,
+ level: crate::SampleLevel,
+ depth_ref: Option<Handle<crate::Expression>>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ use super::instructions::SampleLod;
+ // image
+ let image_id = self.get_handle_id(image);
+ let image_type = self.fun_info[image].ty.handle().unwrap();
+ // SPIR-V doesn't know about our `Depth` class; sampling such an image
+ // returns a `vec4<f32>`, so we need to grab the first component of it.
+ let needs_sub_access = match self.ir_module.types[image_type].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Depth { .. },
+ ..
+ } => depth_ref.is_none() && gather.is_none(),
+ _ => false,
+ };
+ let sample_result_type_id = if needs_sub_access {
+ self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Quad),
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }))
+ } else {
+ result_type_id
+ };
+
+ // OpTypeSampledImage
+ let image_type_id = self.get_type_id(LookupType::Handle(image_type));
+ let sampled_image_type_id =
+ self.get_type_id(LookupType::Local(LocalType::SampledImage { image_type_id }));
+
+ let sampler_id = self.get_handle_id(sampler);
+ let coordinates_id = self
+ .write_image_coordinates(coordinate, array_index, block)?
+ .value_id;
+
+ let sampled_image_id = self.gen_id();
+ block.body.push(Instruction::sampled_image(
+ sampled_image_type_id,
+ sampled_image_id,
+ image_id,
+ sampler_id,
+ ));
+ let id = self.gen_id();
+
+ let depth_id = depth_ref.map(|handle| self.cached[handle]);
+ let mut mask = spirv::ImageOperands::empty();
+ mask.set(spirv::ImageOperands::CONST_OFFSET, offset.is_some());
+
+ let mut main_instruction = match (level, gather) {
+ (_, Some(component)) => {
+ let component_id = self.get_index_constant(component as u32);
+ let mut inst = Instruction::image_gather(
+ sample_result_type_id,
+ id,
+ sampled_image_id,
+ coordinates_id,
+ component_id,
+ depth_id,
+ );
+ if !mask.is_empty() {
+ inst.add_operand(mask.bits());
+ }
+ inst
+ }
+ (crate::SampleLevel::Zero, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let zero_id = self.writer.get_constant_scalar(crate::Literal::F32(0.0));
+
+ mask |= spirv::ImageOperands::LOD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(zero_id);
+
+ inst
+ }
+ (crate::SampleLevel::Auto, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Implicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+ if !mask.is_empty() {
+ inst.add_operand(mask.bits());
+ }
+ inst
+ }
+ (crate::SampleLevel::Exact(lod_handle), None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let lod_id = self.cached[lod_handle];
+ mask |= spirv::ImageOperands::LOD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(lod_id);
+
+ inst
+ }
+ (crate::SampleLevel::Bias(bias_handle), None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Implicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let bias_id = self.cached[bias_handle];
+ mask |= spirv::ImageOperands::BIAS;
+ inst.add_operand(mask.bits());
+ inst.add_operand(bias_id);
+
+ inst
+ }
+ (crate::SampleLevel::Gradient { x, y }, None) => {
+ let mut inst = Instruction::image_sample(
+ sample_result_type_id,
+ id,
+ SampleLod::Explicit,
+ sampled_image_id,
+ coordinates_id,
+ depth_id,
+ );
+
+ let x_id = self.cached[x];
+ let y_id = self.cached[y];
+ mask |= spirv::ImageOperands::GRAD;
+ inst.add_operand(mask.bits());
+ inst.add_operand(x_id);
+ inst.add_operand(y_id);
+
+ inst
+ }
+ };
+
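+ // Image-operand ids must follow the mask in ascending bit order;
+ // `CONST_OFFSET` (0x8) outranks `BIAS` (0x1), `LOD` (0x2), and `GRAD`
+ // (0x4), so the offset id is appended after the operands added above.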
+ if let Some(offset_const) = offset {
+ let offset_id = self.writer.constant_ids[offset_const.index()];
+ main_instruction.add_operand(offset_id);
+ }
+
+ block.body.push(main_instruction);
+
+ let id = if needs_sub_access {
+ let sub_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ sub_id,
+ id,
+ &[0],
+ ));
+ sub_id
+ } else {
+ id
+ };
+
+ Ok(id)
+ }
+
+ /// Generate code for an `ImageQuery` expression.
+ ///
+ /// The arguments are the components of an `Expression::ImageQuery` variant.
+ pub(super) fn write_image_query(
+ &mut self,
+ result_type_id: Word,
+ image: Handle<crate::Expression>,
+ query: crate::ImageQuery,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ use crate::{ImageClass as Ic, ImageDimension as Id, ImageQuery as Iq};
+
+ let image_id = self.get_handle_id(image);
+ let image_type = self.fun_info[image].ty.handle().unwrap();
+ let (dim, arrayed, class) = match self.ir_module.types[image_type].inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => (dim, arrayed, class),
+ _ => {
+ return Err(Error::Validation("image type"));
+ }
+ };
+
+ self.writer
+ .require_any("image queries", &[spirv::Capability::ImageQuery])?;
+
+ let id = match query {
+ Iq::Size { level } => {
+ let dim_coords = match dim {
+ Id::D1 => 1,
+ Id::D2 | Id::Cube => 2,
+ Id::D3 => 3,
+ };
+ let array_coords = usize::from(arrayed);
+ let vector_size = match dim_coords + array_coords {
+ 2 => Some(crate::VectorSize::Bi),
+ 3 => Some(crate::VectorSize::Tri),
+ 4 => Some(crate::VectorSize::Quad),
+ _ => None,
+ };
+ let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size,
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ }));
+
+ let (query_op, level_id) = match class {
+ Ic::Sampled { multi: true, .. }
+ | Ic::Depth { multi: true }
+ | Ic::Storage { .. } => (spirv::Op::ImageQuerySize, None),
+ _ => {
+ let level_id = match level {
+ Some(expr) => self.cached[expr],
+ None => self.get_index_constant(0),
+ };
+ (spirv::Op::ImageQuerySizeLod, Some(level_id))
+ }
+ };
+
+ // The ID of the vector returned by SPIR-V, which contains the dimensions
+ // as well as the layer count.
+ let id_extended = self.gen_id();
+ let mut inst = Instruction::image_query(
+ query_op,
+ extended_size_type_id,
+ id_extended,
+ image_id,
+ );
+ if let Some(expr_id) = level_id {
+ inst.add_operand(expr_id);
+ }
+ block.body.push(inst);
+
+ if result_type_id != extended_size_type_id {
+ let id = self.gen_id();
+ let components = match dim {
+ // cube faces are square, so pick the width and duplicate it for both dimensions
+ Id::Cube => &[0u32, 0][..],
+ _ => &[0u32, 1, 2, 3][..dim_coords],
+ };
+ block.body.push(Instruction::vector_shuffle(
+ result_type_id,
+ id,
+ id_extended,
+ id_extended,
+ components,
+ ));
+
+ id
+ } else {
+ id_extended
+ }
+ }
+ Iq::NumLevels => {
+ let query_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQueryLevels,
+ result_type_id,
+ query_id,
+ image_id,
+ ));
+
+ query_id
+ }
+ Iq::NumLayers => {
+ let vec_size = match dim {
+ Id::D1 => crate::VectorSize::Bi,
+ Id::D2 | Id::Cube => crate::VectorSize::Tri,
+ Id::D3 => crate::VectorSize::Quad,
+ };
+ let extended_size_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(vec_size),
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ }));
+ let id_extended = self.gen_id();
+ let mut inst = Instruction::image_query(
+ spirv::Op::ImageQuerySizeLod,
+ extended_size_type_id,
+ id_extended,
+ image_id,
+ );
+ inst.add_operand(self.get_index_constant(0));
+ block.body.push(inst);
+
+ let extract_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ extract_id,
+ id_extended,
+ &[vec_size as u32 - 1],
+ ));
+
+ extract_id
+ }
+ Iq::NumSamples => {
+ let query_id = self.gen_id();
+ block.body.push(Instruction::image_query(
+ spirv::Op::ImageQuerySamples,
+ result_type_id,
+ query_id,
+ image_id,
+ ));
+
+ query_id
+ }
+ };
+
+ Ok(id)
+ }
+
+ pub(super) fn write_image_store(
+ &mut self,
+ image: Handle<crate::Expression>,
+ coordinate: Handle<crate::Expression>,
+ array_index: Option<Handle<crate::Expression>>,
+ value: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ let image_id = self.get_handle_id(image);
+ let coordinates = self.write_image_coordinates(coordinate, array_index, block)?;
+ let value_id = self.cached[value];
+
+ let write = Store { image_id, value_id };
+
+ match *self.fun_info[image].ty.inner_with(&self.ir_module.types) {
+ crate::TypeInner::Image {
+ class:
+ crate::ImageClass::Storage {
+ format: crate::StorageFormat::Bgra8Unorm,
+ ..
+ },
+ ..
+ } => self.writer.require_any(
+ "Bgra8Unorm storage write",
+ &[spirv::Capability::StorageImageWriteWithoutFormat],
+ )?,
+ _ => {}
+ }
+
+ match self.writer.bounds_check_policies.image_store {
+ crate::proc::BoundsCheckPolicy::Restrict => {
+ let (coords, _, _) =
+ self.write_restricted_coordinates(image_id, coordinates, None, None, block)?;
+ write.generate(&mut self.writer.id_gen, coords, None, None, block);
+ }
+ crate::proc::BoundsCheckPolicy::ReadZeroSkipWrite => {
+ self.write_conditional_image_access(
+ image_id,
+ coordinates,
+ None,
+ None,
+ block,
+ &write,
+ )?;
+ }
+ crate::proc::BoundsCheckPolicy::Unchecked => {
+ write.generate(
+ &mut self.writer.id_gen,
+ coordinates.value_id,
+ None,
+ None,
+ block,
+ );
+ }
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/index.rs b/third_party/rust/naga/src/back/spv/index.rs
new file mode 100644
index 0000000000..92e0f88d9a
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/index.rs
@@ -0,0 +1,421 @@
+/*!
+Bounds-checking for SPIR-V output.
+*/
+
+use super::{
+ helpers::global_needs_wrapper, selection::Selection, Block, BlockContext, Error, IdGenerator,
+ Instruction, Word,
+};
+use crate::{arena::Handle, proc::BoundsCheckPolicy};
+
+/// The results of performing a bounds check.
+///
+/// On success, `write_bounds_check` returns a value of this type.
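+///
+/// See `write_vector_access` below for how each variant is typically
+/// consumed: `KnownInBounds` becomes a literal `OpCompositeExtract` index,
+/// `Computed` a dynamic access, and `Conditional` an access wrapped in
+/// `write_conditional_indexed_load`.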
+pub(super) enum BoundsCheckResult {
+ /// The index is statically known and in bounds, with the given value.
+ KnownInBounds(u32),
+
+ /// The given instruction computes the index to be used.
+ Computed(Word),
+
+ /// The given instruction computes a boolean condition which is true
+ /// if the index is in bounds.
+ Conditional(Word),
+}
+
+/// A value that we either know at translation time, or need to compute at runtime.
+pub(super) enum MaybeKnown<T> {
+ /// The value is known at shader translation time.
+ Known(T),
+
+ /// The value is computed by the instruction with the given id.
+ Computed(Word),
+}
+
+impl<'w> BlockContext<'w> {
+ /// Emit code to compute the length of a run-time array.
+ ///
+ /// Given `array`, an expression referring to a runtime-sized array, return the
+ /// instruction id for the array's length.
+ pub(super) fn write_runtime_array_length(
+ &mut self,
+ array: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ // Naga IR permits runtime-sized arrays as global variables or as the
+ // final member of a struct that is a global variable. SPIR-V permits
+ // only the latter, so this back end wraps bare runtime-sized arrays
+ // in a made-up struct; see `helpers::global_needs_wrapper` and its uses.
+ // This code must handle both cases.
+ let (structure_id, last_member_index) = match self.ir_function.expressions[array] {
+ crate::Expression::AccessIndex { base, index } => {
+ match self.ir_function.expressions[base] {
+ crate::Expression::GlobalVariable(handle) => (
+ self.writer.global_variables[handle.index()].access_id,
+ index,
+ ),
+ _ => return Err(Error::Validation("array length expression")),
+ }
+ }
+ crate::Expression::GlobalVariable(handle) => {
+ let global = &self.ir_module.global_variables[handle];
+ if !global_needs_wrapper(self.ir_module, global) {
+ return Err(Error::Validation("array length expression"));
+ }
+
+ (self.writer.global_variables[handle.index()].var_id, 0)
+ }
+ _ => return Err(Error::Validation("array length expression")),
+ };
+
+ let length_id = self.gen_id();
+ block.body.push(Instruction::array_length(
+ self.writer.get_uint_type_id(),
+ length_id,
+ structure_id,
+ last_member_index,
+ ));
+
+ Ok(length_id)
+ }
+
+ /// Compute the length of a subscriptable value.
+ ///
+ /// Given `sequence`, an expression referring to some indexable type, return
+ /// its length. The result may either be computed by SPIR-V instructions, or
+ /// known at shader translation time.
+ ///
+ /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
+ /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
+ /// sized, or use a specializable constant as its length.
+ fn write_sequence_length(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<MaybeKnown<u32>, Error> {
+ let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types);
+ match sequence_ty.indexable_length(self.ir_module) {
+ Ok(crate::proc::IndexableLength::Known(known_length)) => {
+ Ok(MaybeKnown::Known(known_length))
+ }
+ Ok(crate::proc::IndexableLength::Dynamic) => {
+ let length_id = self.write_runtime_array_length(sequence, block)?;
+ Ok(MaybeKnown::Computed(length_id))
+ }
+ Err(err) => {
+ log::error!("Sequence length for {:?} failed: {}", sequence, err);
+ Err(Error::Validation("indexable length"))
+ }
+ }
+ }
+
+ /// Compute the maximum valid index of a subscriptable value.
+ ///
+ /// Given `sequence`, an expression referring to some indexable type, return
+ /// its maximum valid index, that is, one less than its length. The result may
+ /// either be computed, or known at shader translation time.
+ ///
+ /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
+ /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
+ /// sized, or use a specializable constant as its length.
+ fn write_sequence_max_index(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<MaybeKnown<u32>, Error> {
+ match self.write_sequence_length(sequence, block)? {
+ MaybeKnown::Known(known_length) => {
+ // We should have thrown out all attempts to subscript zero-length
+ // sequences during validation, so the following subtraction should never
+ // underflow.
+ assert!(known_length > 0);
+ // Compute the max index from the length now.
+ Ok(MaybeKnown::Known(known_length - 1))
+ }
+ MaybeKnown::Computed(length_id) => {
+ // Emit code to compute the max index from the length.
+ let const_one_id = self.get_index_constant(1);
+ let max_index_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ISub,
+ self.writer.get_uint_type_id(),
+ max_index_id,
+ length_id,
+ const_one_id,
+ ));
+ Ok(MaybeKnown::Computed(max_index_id))
+ }
+ }
+ }
+
+ /// Restrict an index to be in range for a vector, matrix, or array.
+ ///
+ /// This is used to implement `BoundsCheckPolicy::Restrict`. An in-bounds
+ /// index is left unchanged. An out-of-bounds index is replaced with some
+ /// arbitrary in-bounds index. Note,this is not necessarily clamping; for
+ /// example, negative indices might be changed to refer to the last element
+ /// of the sequence, not the first, as clamping would do.
+ ///
+ /// Either return the restricted index value, if known, or add instructions
+ /// to `block` to compute it, and return the id of the result. See the
+ /// documentation for `BoundsCheckResult` for details.
+ ///
+ /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
+ /// `Pointer` to any of those, or a `ValuePointer`. An array may be
+ /// fixed-size, dynamically sized, or use a specializable constant as its
+ /// length.
+ pub(super) fn write_restricted_index(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let index_id = self.cached[index];
+
+ // Get the sequence's maximum valid index. Return early if we've already
+ // done the bounds check.
+ let max_index_id = match self.write_sequence_max_index(sequence, block)? {
+ MaybeKnown::Known(known_max_index) => {
+ if let Ok(known_index) = self
+ .ir_module
+ .to_ctx()
+ .eval_expr_to_u32_from(index, &self.ir_function.expressions)
+ {
+ // Both the index and length are known at compile time.
+ //
+ // In strict WGSL compliance mode, out-of-bounds indices cannot be
+ // reported at shader translation time, and must be replaced with
+ // in-bounds indices at run time. So we cannot assume that
+ // validation ensured the index was in bounds. Restrict now.
+ let restricted = std::cmp::min(known_index, known_max_index);
+ return Ok(BoundsCheckResult::KnownInBounds(restricted));
+ }
+
+ self.get_index_constant(known_max_index)
+ }
+ MaybeKnown::Computed(max_index_id) => max_index_id,
+ };
+
+ // One or the other of the index or length is dynamic, so emit code for
+ // BoundsCheckPolicy::Restrict.
+ let restricted_index_id = self.gen_id();
+ block.body.push(Instruction::ext_inst(
+ self.writer.gl450_ext_inst_id,
+ spirv::GLOp::UMin,
+ self.writer.get_uint_type_id(),
+ restricted_index_id,
+ &[index_id, max_index_id],
+ ));
+ Ok(BoundsCheckResult::Computed(restricted_index_id))
+ }
+
+ /// Write an index bounds comparison to `block`, if needed.
+ ///
+ /// If we're able to determine statically that `index` is in bounds for
+ /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual
+ /// value of the index. (In principle, one could know that the index is in
+ /// bounds without knowing its specific value, but in our simple-minded
+ /// situation, we always know it.)
+ ///
+ /// If instead we must generate code to perform the comparison at run time,
+ /// return `Conditional(comparison_id)`, where `comparison_id` is an
+ /// instruction producing a boolean value that is true if `index` is in
+ /// bounds for `sequence`.
+ ///
+ /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
+ /// `Pointer` to any of those, or a `ValuePointer`. An array may be
+ /// fixed-size, dynamically sized, or use a specializable constant as its
+ /// length.
+ fn write_index_comparison(
+ &mut self,
+ sequence: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let index_id = self.cached[index];
+
+ // Get the sequence's length. Return early if we've already done the
+ // bounds check.
+ let length_id = match self.write_sequence_length(sequence, block)? {
+ MaybeKnown::Known(known_length) => {
+ if let Ok(known_index) = self
+ .ir_module
+ .to_ctx()
+ .eval_expr_to_u32_from(index, &self.ir_function.expressions)
+ {
+ // Both the index and length are known at compile time.
+ //
+ // It would be nice to assume that, since we are using the
+ // `ReadZeroSkipWrite` policy, we are not in strict WGSL
+ // compliance mode, and thus we can count on the validator to have
+ // rejected any programs with known out-of-bounds indices, and
+ // thus just return `KnownInBounds` here without actually
+ // checking.
+ //
+ // But it's also reasonable to expect that bounds check policies
+ // and error reporting policies should be able to vary
+ // independently without introducing security holes. So, we should
+ // support the case where bad indices do not cause validation
+ // errors, and are handled via `ReadZeroSkipWrite`.
+ //
+ // In theory, when `known_index` is bad, we could return a new
+ // `KnownOutOfBounds` variant here. But it's simpler just to fall
+ // through and let the bounds check take place. The shader is
+ // broken anyway, so it doesn't make sense to invest in emitting
+ // the ideal code for it.
+ if known_index < known_length {
+ return Ok(BoundsCheckResult::KnownInBounds(known_index));
+ }
+ }
+
+ self.get_index_constant(known_length)
+ }
+ MaybeKnown::Computed(length_id) => length_id,
+ };
+
+ // Compare the index against the length.
+ let condition_id = self.gen_id();
+ block.body.push(Instruction::binary(
+ spirv::Op::ULessThan,
+ self.writer.get_bool_type_id(),
+ condition_id,
+ index_id,
+ length_id,
+ ));
+
+ // Indicate that we did generate the check.
+ Ok(BoundsCheckResult::Conditional(condition_id))
+ }
+
+ /// Emit a conditional load for `BoundsCheckPolicy::ReadZeroSkipWrite`.
+ ///
+ /// Generate code to load a value of `result_type` if `condition` is true,
+ /// and generate a null value of that type if it is false. Call `emit_load`
+ /// to emit the instructions to perform the load. Return the id of the
+ /// merged value of the two branches.
+ pub(super) fn write_conditional_indexed_load<F>(
+ &mut self,
+ result_type: Word,
+ condition: Word,
+ block: &mut Block,
+ emit_load: F,
+ ) -> Word
+ where
+ F: FnOnce(&mut IdGenerator, &mut Block) -> Word,
+ {
+ // For the out-of-bounds case, we produce a zero value.
+ let null_id = self.writer.get_constant_null(result_type);
+
+ let mut selection = Selection::start(block, result_type);
+
+ // As it turns out, we don't actually need a full 'if-then-else'
+ // structure for this: SPIR-V constants are declared up front, so the
+ // 'else' block would have no instructions. Instead we emit something
+ // like this:
+ //
+ // result = zero;
+ // if in_bounds {
+ // result = do the load;
+ // }
+ // use result;
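+ //
+ // The merge is expressed as an `OpPhi` that yields `null_id` when we
+ // arrive from the out-of-bounds path and the loaded value otherwise;
+ // `selection.finish` emits that instruction.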
+
+ // Continue only if the index was in bounds. Otherwise, branch to the
+ // merge block.
+ selection.if_true(self, condition, null_id);
+
+ // The in-bounds path. Perform the access and the load.
+ let loaded_value = emit_load(&mut self.writer.id_gen, selection.block());
+
+ selection.finish(self, loaded_value)
+ }
+
+ /// Emit code for bounds checks for an array, vector, or matrix access.
+ ///
+ /// This implements either `index_bounds_check_policy` or
+ /// `buffer_bounds_check_policy`, depending on the address space of the
+ /// pointer being accessed.
+ ///
+ /// Return a `BoundsCheckResult` indicating how the index should be
+ /// consumed. See that type's documentation for details.
+ pub(super) fn write_bounds_check(
+ &mut self,
+ base: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<BoundsCheckResult, Error> {
+ let policy = self.writer.bounds_check_policies.choose_policy(
+ base,
+ &self.ir_module.types,
+ self.fun_info,
+ );
+
+ Ok(match policy {
+ BoundsCheckPolicy::Restrict => self.write_restricted_index(base, index, block)?,
+ BoundsCheckPolicy::ReadZeroSkipWrite => {
+ self.write_index_comparison(base, index, block)?
+ }
+ BoundsCheckPolicy::Unchecked => BoundsCheckResult::Computed(self.cached[index]),
+ })
+ }
+
+ /// Emit code to subscript a vector by value with a computed index.
+ ///
+ /// Return the id of the element value.
+ pub(super) fn write_vector_access(
+ &mut self,
+ expr_handle: Handle<crate::Expression>,
+ base: Handle<crate::Expression>,
+ index: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<Word, Error> {
+ let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
+
+ let base_id = self.cached[base];
+ let index_id = self.cached[index];
+
+ let result_id = match self.write_bounds_check(base, index, block)? {
+ BoundsCheckResult::KnownInBounds(known_index) => {
+ let result_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ result_type_id,
+ result_id,
+ base_id,
+ &[known_index],
+ ));
+ result_id
+ }
+ BoundsCheckResult::Computed(computed_index_id) => {
+ let result_id = self.gen_id();
+ block.body.push(Instruction::vector_extract_dynamic(
+ result_type_id,
+ result_id,
+ base_id,
+ computed_index_id,
+ ));
+ result_id
+ }
+ BoundsCheckResult::Conditional(comparison_id) => {
+ // Run-time bounds checks were required. Emit
+ // conditional load.
+ self.write_conditional_indexed_load(
+ result_type_id,
+ comparison_id,
+ block,
+ |id_gen, block| {
+ // The in-bounds path. Generate the access.
+ let element_id = id_gen.next();
+ block.body.push(Instruction::vector_extract_dynamic(
+ result_type_id,
+ element_id,
+ base_id,
+ index_id,
+ ));
+ element_id
+ },
+ )
+ }
+ };
+
+ Ok(result_id)
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs
new file mode 100644
index 0000000000..b963793ad3
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/instructions.rs
@@ -0,0 +1,1100 @@
+use super::{block::DebugInfoInner, helpers};
+use spirv::{Op, Word};
+
+pub(super) enum Signedness {
+ Unsigned = 0,
+ Signed = 1,
+}
+
+pub(super) enum SampleLod {
+ Explicit,
+ Implicit,
+}
+
+pub(super) struct Case {
+ pub value: Word,
+ pub label_id: Word,
+}
+
+impl super::Instruction {
+ //
+ // Debug Instructions
+ //
+
+ pub(super) fn string(name: &str, id: Word) -> Self {
+ let mut instruction = Self::new(Op::String);
+ instruction.set_result(id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn source(
+ source_language: spirv::SourceLanguage,
+ version: u32,
+ source: &Option<DebugInfoInner>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Source);
+ instruction.add_operand(source_language as u32);
+ instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes()));
+ if let Some(source) = source.as_ref() {
+ instruction.add_operand(source.source_file_id);
+ instruction.add_operands(helpers::string_to_words(source.source_code));
+ }
+ instruction
+ }
+
+ pub(super) fn name(target_id: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::Name);
+ instruction.add_operand(target_id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn member_name(target_id: Word, member: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::MemberName);
+ instruction.add_operand(target_id);
+ instruction.add_operand(member);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn line(file: Word, line: Word, column: Word) -> Self {
+ let mut instruction = Self::new(Op::Line);
+ instruction.add_operand(file);
+ instruction.add_operand(line);
+ instruction.add_operand(column);
+ instruction
+ }
+
+ pub(super) const fn no_line() -> Self {
+ Self::new(Op::NoLine)
+ }
+
+ //
+ // Annotation Instructions
+ //
+
+ pub(super) fn decorate(
+ target_id: Word,
+ decoration: spirv::Decoration,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::Decorate);
+ instruction.add_operand(target_id);
+ instruction.add_operand(decoration as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ pub(super) fn member_decorate(
+ target_id: Word,
+ member_index: Word,
+ decoration: spirv::Decoration,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::MemberDecorate);
+ instruction.add_operand(target_id);
+ instruction.add_operand(member_index);
+ instruction.add_operand(decoration as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ //
+ // Extension Instructions
+ //
+
+ pub(super) fn extension(name: &str) -> Self {
+ let mut instruction = Self::new(Op::Extension);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn ext_inst_import(id: Word, name: &str) -> Self {
+ let mut instruction = Self::new(Op::ExtInstImport);
+ instruction.set_result(id);
+ instruction.add_operands(helpers::string_to_words(name));
+ instruction
+ }
+
+ pub(super) fn ext_inst(
+ set_id: Word,
+ op: spirv::GLOp,
+ result_type_id: Word,
+ id: Word,
+ operands: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ExtInst);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(set_id);
+ instruction.add_operand(op as u32);
+ for operand in operands {
+ instruction.add_operand(*operand)
+ }
+ instruction
+ }
+
+ //
+ // Mode-Setting Instructions
+ //
+
+ pub(super) fn memory_model(
+ addressing_model: spirv::AddressingModel,
+ memory_model: spirv::MemoryModel,
+ ) -> Self {
+ let mut instruction = Self::new(Op::MemoryModel);
+ instruction.add_operand(addressing_model as u32);
+ instruction.add_operand(memory_model as u32);
+ instruction
+ }
+
+ pub(super) fn entry_point(
+ execution_model: spirv::ExecutionModel,
+ entry_point_id: Word,
+ name: &str,
+ interface_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::EntryPoint);
+ instruction.add_operand(execution_model as u32);
+ instruction.add_operand(entry_point_id);
+ instruction.add_operands(helpers::string_to_words(name));
+
+ for interface_id in interface_ids {
+ instruction.add_operand(*interface_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn execution_mode(
+ entry_point_id: Word,
+ execution_mode: spirv::ExecutionMode,
+ args: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ExecutionMode);
+ instruction.add_operand(entry_point_id);
+ instruction.add_operand(execution_mode as u32);
+ for arg in args {
+ instruction.add_operand(*arg);
+ }
+ instruction
+ }
+
+ pub(super) fn capability(capability: spirv::Capability) -> Self {
+ let mut instruction = Self::new(Op::Capability);
+ instruction.add_operand(capability as u32);
+ instruction
+ }
+
+ //
+ // Type-Declaration Instructions
+ //
+
+ pub(super) fn type_void(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeVoid);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_bool(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeBool);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_int(id: Word, width: Word, signedness: Signedness) -> Self {
+ let mut instruction = Self::new(Op::TypeInt);
+ instruction.set_result(id);
+ instruction.add_operand(width);
+ instruction.add_operand(signedness as u32);
+ instruction
+ }
+
+ pub(super) fn type_float(id: Word, width: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeFloat);
+ instruction.set_result(id);
+ instruction.add_operand(width);
+ instruction
+ }
+
+ pub(super) fn type_vector(
+ id: Word,
+ component_type_id: Word,
+ component_count: crate::VectorSize,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeVector);
+ instruction.set_result(id);
+ instruction.add_operand(component_type_id);
+ instruction.add_operand(component_count as u32);
+ instruction
+ }
+
+ pub(super) fn type_matrix(
+ id: Word,
+ column_type_id: Word,
+ column_count: crate::VectorSize,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeMatrix);
+ instruction.set_result(id);
+ instruction.add_operand(column_type_id);
+ instruction.add_operand(column_count as u32);
+ instruction
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn type_image(
+ id: Word,
+ sampled_type_id: Word,
+ dim: spirv::Dim,
+ flags: super::ImageTypeFlags,
+ image_format: spirv::ImageFormat,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypeImage);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_type_id);
+ instruction.add_operand(dim as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::DEPTH) as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::ARRAYED) as u32);
+ instruction.add_operand(flags.contains(super::ImageTypeFlags::MULTISAMPLED) as u32);
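+ // The "Sampled" operand: 1 means the image is used with a sampler,
+ // 2 means it is read/written without one (a storage image).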
+ instruction.add_operand(if flags.contains(super::ImageTypeFlags::SAMPLED) {
+ 1
+ } else {
+ 2
+ });
+ instruction.add_operand(image_format as u32);
+ instruction
+ }
+
+ pub(super) fn type_sampler(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeSampler);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_acceleration_structure(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeAccelerationStructureKHR);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_ray_query(id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeRayQueryKHR);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn type_sampled_image(id: Word, image_type_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeSampledImage);
+ instruction.set_result(id);
+ instruction.add_operand(image_type_id);
+ instruction
+ }
+
+ pub(super) fn type_array(id: Word, element_type_id: Word, length_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeArray);
+ instruction.set_result(id);
+ instruction.add_operand(element_type_id);
+ instruction.add_operand(length_id);
+ instruction
+ }
+
+ pub(super) fn type_runtime_array(id: Word, element_type_id: Word) -> Self {
+ let mut instruction = Self::new(Op::TypeRuntimeArray);
+ instruction.set_result(id);
+ instruction.add_operand(element_type_id);
+ instruction
+ }
+
+ pub(super) fn type_struct(id: Word, member_ids: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::TypeStruct);
+ instruction.set_result(id);
+
+ for member_id in member_ids {
+ instruction.add_operand(*member_id)
+ }
+
+ instruction
+ }
+
+ pub(super) fn type_pointer(
+ id: Word,
+ storage_class: spirv::StorageClass,
+ type_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::TypePointer);
+ instruction.set_result(id);
+ instruction.add_operand(storage_class as u32);
+ instruction.add_operand(type_id);
+ instruction
+ }
+
+ pub(super) fn type_function(id: Word, return_type_id: Word, parameter_ids: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::TypeFunction);
+ instruction.set_result(id);
+ instruction.add_operand(return_type_id);
+
+ for parameter_id in parameter_ids {
+ instruction.add_operand(*parameter_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Constant-Creation Instructions
+ //
+
+ pub(super) fn constant_null(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantNull);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant_true(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantTrue);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant_false(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::ConstantFalse);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn constant_32bit(result_type_id: Word, id: Word, value: Word) -> Self {
+ Self::constant(result_type_id, id, &[value])
+ }
+
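+ // SPIR-V encodes literals wider than 32 bits low-order word first.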
+ pub(super) fn constant_64bit(result_type_id: Word, id: Word, low: Word, high: Word) -> Self {
+ Self::constant(result_type_id, id, &[low, high])
+ }
+
+ pub(super) fn constant(result_type_id: Word, id: Word, values: &[Word]) -> Self {
+ let mut instruction = Self::new(Op::Constant);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for value in values {
+ instruction.add_operand(*value);
+ }
+
+ instruction
+ }
+
+ pub(super) fn constant_composite(
+ result_type_id: Word,
+ id: Word,
+ constituent_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::ConstantComposite);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for constituent_id in constituent_ids {
+ instruction.add_operand(*constituent_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Memory Instructions
+ //
+
+ pub(super) fn variable(
+ result_type_id: Word,
+ id: Word,
+ storage_class: spirv::StorageClass,
+ initializer_id: Option<Word>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Variable);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(storage_class as u32);
+
+ if let Some(initializer_id) = initializer_id {
+ instruction.add_operand(initializer_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn load(
+ result_type_id: Word,
+ id: Word,
+ pointer_id: Word,
+ memory_access: Option<spirv::MemoryAccess>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Load);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer_id);
+
+ if let Some(memory_access) = memory_access {
+ instruction.add_operand(memory_access.bits());
+ }
+
+ instruction
+ }
+
+ pub(super) fn atomic_load(
+ result_type_id: Word,
+ id: Word,
+ pointer_id: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::AtomicLoad);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction
+ }
+
+ pub(super) fn store(
+ pointer_id: Word,
+ value_id: Word,
+ memory_access: Option<spirv::MemoryAccess>,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Store);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(value_id);
+
+ if let Some(memory_access) = memory_access {
+ instruction.add_operand(memory_access.bits());
+ }
+
+ instruction
+ }
+
+ pub(super) fn atomic_store(
+ pointer_id: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ value_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::AtomicStore);
+ instruction.add_operand(pointer_id);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction.add_operand(value_id);
+ instruction
+ }
+
+ pub(super) fn access_chain(
+ result_type_id: Word,
+ id: Word,
+ base_id: Word,
+ index_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::AccessChain);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(base_id);
+
+ for index_id in index_ids {
+ instruction.add_operand(*index_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn array_length(
+ result_type_id: Word,
+ id: Word,
+ structure_id: Word,
+ array_member: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::ArrayLength);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(structure_id);
+ instruction.add_operand(array_member);
+ instruction
+ }
+
+ //
+ // Function Instructions
+ //
+
+ pub(super) fn function(
+ return_type_id: Word,
+ id: Word,
+ function_control: spirv::FunctionControl,
+ function_type_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Function);
+ instruction.set_type(return_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(function_control.bits());
+ instruction.add_operand(function_type_id);
+ instruction
+ }
+
+ pub(super) fn function_parameter(result_type_id: Word, id: Word) -> Self {
+ let mut instruction = Self::new(Op::FunctionParameter);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) const fn function_end() -> Self {
+ Self::new(Op::FunctionEnd)
+ }
+
+ pub(super) fn function_call(
+ result_type_id: Word,
+ id: Word,
+ function_id: Word,
+ argument_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::FunctionCall);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(function_id);
+
+ for argument_id in argument_ids {
+ instruction.add_operand(*argument_id);
+ }
+
+ instruction
+ }
+
+ //
+ // Image Instructions
+ //
+
+ pub(super) fn sampled_image(
+ result_type_id: Word,
+ id: Word,
+ image: Word,
+ sampler: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::SampledImage);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction.add_operand(sampler);
+ instruction
+ }
+
+ pub(super) fn image_sample(
+ result_type_id: Word,
+ id: Word,
+ lod: SampleLod,
+ sampled_image: Word,
+ coordinates: Word,
+ depth_ref: Option<Word>,
+ ) -> Self {
+ let op = match (lod, depth_ref) {
+ (SampleLod::Explicit, None) => Op::ImageSampleExplicitLod,
+ (SampleLod::Implicit, None) => Op::ImageSampleImplicitLod,
+ (SampleLod::Explicit, Some(_)) => Op::ImageSampleDrefExplicitLod,
+ (SampleLod::Implicit, Some(_)) => Op::ImageSampleDrefImplicitLod,
+ };
+
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_image);
+ instruction.add_operand(coordinates);
+ if let Some(dref) = depth_ref {
+ instruction.add_operand(dref);
+ }
+
+ instruction
+ }
+
+ pub(super) fn image_gather(
+ result_type_id: Word,
+ id: Word,
+ sampled_image: Word,
+ coordinates: Word,
+ component_id: Word,
+ depth_ref: Option<Word>,
+ ) -> Self {
+ let op = match depth_ref {
+ None => Op::ImageGather,
+ Some(_) => Op::ImageDrefGather,
+ };
+
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(sampled_image);
+ instruction.add_operand(coordinates);
+ if let Some(dref) = depth_ref {
+ instruction.add_operand(dref);
+ } else {
+ instruction.add_operand(component_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn image_fetch_or_read(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ image: Word,
+ coordinates: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction.add_operand(coordinates);
+ instruction
+ }
+
+ pub(super) fn image_write(image: Word, coordinates: Word, value: Word) -> Self {
+ let mut instruction = Self::new(Op::ImageWrite);
+ instruction.add_operand(image);
+ instruction.add_operand(coordinates);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ pub(super) fn image_query(op: Op, result_type_id: Word, id: Word, image: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(image);
+ instruction
+ }
+
+ //
+ // Ray Query Instructions
+ //
+ #[allow(clippy::too_many_arguments)]
+ pub(super) fn ray_query_initialize(
+ query: Word,
+ acceleration_structure: Word,
+ ray_flags: Word,
+ cull_mask: Word,
+ ray_origin: Word,
+ ray_tmin: Word,
+ ray_dir: Word,
+ ray_tmax: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::RayQueryInitializeKHR);
+ instruction.add_operand(query);
+ instruction.add_operand(acceleration_structure);
+ instruction.add_operand(ray_flags);
+ instruction.add_operand(cull_mask);
+ instruction.add_operand(ray_origin);
+ instruction.add_operand(ray_tmin);
+ instruction.add_operand(ray_dir);
+ instruction.add_operand(ray_tmax);
+ instruction
+ }
+
+ pub(super) fn ray_query_proceed(result_type_id: Word, id: Word, query: Word) -> Self {
+ let mut instruction = Self::new(Op::RayQueryProceedKHR);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(query);
+ instruction
+ }
+
+ pub(super) fn ray_query_get_intersection(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ query: Word,
+ intersection: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(query);
+ instruction.add_operand(intersection);
+ instruction
+ }
+
+ //
+ // Conversion Instructions
+ //
+ pub(super) fn unary(op: Op, result_type_id: Word, id: Word, value: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ //
+ // Composite Instructions
+ //
+
+ pub(super) fn composite_construct(
+ result_type_id: Word,
+ id: Word,
+ constituent_ids: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::CompositeConstruct);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ for constituent_id in constituent_ids {
+ instruction.add_operand(*constituent_id);
+ }
+
+ instruction
+ }
+
+ pub(super) fn composite_extract(
+ result_type_id: Word,
+ id: Word,
+ composite_id: Word,
+ indices: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::CompositeExtract);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ instruction.add_operand(composite_id);
+ for index in indices {
+ instruction.add_operand(*index);
+ }
+
+ instruction
+ }
+
+ pub(super) fn vector_extract_dynamic(
+ result_type_id: Word,
+ id: Word,
+ vector_id: Word,
+ index_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::VectorExtractDynamic);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+
+ instruction.add_operand(vector_id);
+ instruction.add_operand(index_id);
+
+ instruction
+ }
+
+ pub(super) fn vector_shuffle(
+ result_type_id: Word,
+ id: Word,
+ v1_id: Word,
+ v2_id: Word,
+ components: &[Word],
+ ) -> Self {
+ let mut instruction = Self::new(Op::VectorShuffle);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(v1_id);
+ instruction.add_operand(v2_id);
+
+ for &component in components {
+ instruction.add_operand(component);
+ }
+
+ instruction
+ }
+
+ //
+ // Arithmetic Instructions
+ //
+ pub(super) fn binary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction
+ }
+
+ pub(super) fn ternary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ operand_3: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction.add_operand(operand_3);
+ instruction
+ }
+
+ pub(super) fn quaternary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ operand_1: Word,
+ operand_2: Word,
+ operand_3: Word,
+ operand_4: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(operand_1);
+ instruction.add_operand(operand_2);
+ instruction.add_operand(operand_3);
+ instruction.add_operand(operand_4);
+ instruction
+ }
+
+ pub(super) fn relational(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(expr_id);
+ instruction
+ }
+
+ pub(super) fn atomic_binary(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ pointer: Word,
+ scope_id: Word,
+ semantics_id: Word,
+ value: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(pointer);
+ instruction.add_operand(scope_id);
+ instruction.add_operand(semantics_id);
+ instruction.add_operand(value);
+ instruction
+ }
+
+ //
+ // Bit Instructions
+ //
+
+ //
+ // Relational and Logical Instructions
+ //
+
+ //
+ // Derivative Instructions
+ //
+
+ pub(super) fn derivative(op: Op, result_type_id: Word, id: Word, expr_id: Word) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(expr_id);
+ instruction
+ }
+
+ //
+ // Control-Flow Instructions
+ //
+
+ pub(super) fn phi(
+ result_type_id: Word,
+ result_id: Word,
+ var_parent_pairs: &[(Word, Word)],
+ ) -> Self {
+ let mut instruction = Self::new(Op::Phi);
+ instruction.add_operand(result_type_id);
+ instruction.add_operand(result_id);
+ for &(variable, parent) in var_parent_pairs {
+ instruction.add_operand(variable);
+ instruction.add_operand(parent);
+ }
+ instruction
+ }
+
+ pub(super) fn selection_merge(
+ merge_id: Word,
+ selection_control: spirv::SelectionControl,
+ ) -> Self {
+ let mut instruction = Self::new(Op::SelectionMerge);
+ instruction.add_operand(merge_id);
+ instruction.add_operand(selection_control.bits());
+ instruction
+ }
+
+ pub(super) fn loop_merge(
+ merge_id: Word,
+ continuing_id: Word,
+ selection_control: spirv::SelectionControl,
+ ) -> Self {
+ let mut instruction = Self::new(Op::LoopMerge);
+ instruction.add_operand(merge_id);
+ instruction.add_operand(continuing_id);
+ instruction.add_operand(selection_control.bits());
+ instruction
+ }
+
+ pub(super) fn label(id: Word) -> Self {
+ let mut instruction = Self::new(Op::Label);
+ instruction.set_result(id);
+ instruction
+ }
+
+ pub(super) fn branch(id: Word) -> Self {
+ let mut instruction = Self::new(Op::Branch);
+ instruction.add_operand(id);
+ instruction
+ }
+
+ // TODO: branch weights are not implemented.
+ pub(super) fn branch_conditional(
+ condition_id: Word,
+ true_label: Word,
+ false_label: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::BranchConditional);
+ instruction.add_operand(condition_id);
+ instruction.add_operand(true_label);
+ instruction.add_operand(false_label);
+ instruction
+ }
+
+ pub(super) fn switch(selector_id: Word, default_id: Word, cases: &[Case]) -> Self {
+ let mut instruction = Self::new(Op::Switch);
+ instruction.add_operand(selector_id);
+ instruction.add_operand(default_id);
+ for case in cases {
+ instruction.add_operand(case.value);
+ instruction.add_operand(case.label_id);
+ }
+ instruction
+ }
+
+ pub(super) fn select(
+ result_type_id: Word,
+ id: Word,
+ condition_id: Word,
+ accept_id: Word,
+ reject_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::Select);
+ instruction.add_operand(result_type_id);
+ instruction.add_operand(id);
+ instruction.add_operand(condition_id);
+ instruction.add_operand(accept_id);
+ instruction.add_operand(reject_id);
+ instruction
+ }
+
+ pub(super) const fn kill() -> Self {
+ Self::new(Op::Kill)
+ }
+
+ pub(super) const fn return_void() -> Self {
+ Self::new(Op::Return)
+ }
+
+ pub(super) fn return_value(value_id: Word) -> Self {
+ let mut instruction = Self::new(Op::ReturnValue);
+ instruction.add_operand(value_id);
+ instruction
+ }
+
+ //
+ // Atomic Instructions
+ //
+
+ //
+ // Primitive Instructions
+ //
+
+ // Barriers
+
+ pub(super) fn control_barrier(
+ exec_scope_id: Word,
+ mem_scope_id: Word,
+ semantics_id: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::ControlBarrier);
+ instruction.add_operand(exec_scope_id);
+ instruction.add_operand(mem_scope_id);
+ instruction.add_operand(semantics_id);
+ instruction
+ }
+}
+
+impl From<crate::StorageFormat> for spirv::ImageFormat {
+ fn from(format: crate::StorageFormat) -> Self {
+ use crate::StorageFormat as Sf;
+ match format {
+ Sf::R8Unorm => Self::R8,
+ Sf::R8Snorm => Self::R8Snorm,
+ Sf::R8Uint => Self::R8ui,
+ Sf::R8Sint => Self::R8i,
+ Sf::R16Uint => Self::R16ui,
+ Sf::R16Sint => Self::R16i,
+ Sf::R16Float => Self::R16f,
+ Sf::Rg8Unorm => Self::Rg8,
+ Sf::Rg8Snorm => Self::Rg8Snorm,
+ Sf::Rg8Uint => Self::Rg8ui,
+ Sf::Rg8Sint => Self::Rg8i,
+ Sf::R32Uint => Self::R32ui,
+ Sf::R32Sint => Self::R32i,
+ Sf::R32Float => Self::R32f,
+ Sf::Rg16Uint => Self::Rg16ui,
+ Sf::Rg16Sint => Self::Rg16i,
+ Sf::Rg16Float => Self::Rg16f,
+ Sf::Rgba8Unorm => Self::Rgba8,
+ Sf::Rgba8Snorm => Self::Rgba8Snorm,
+ Sf::Rgba8Uint => Self::Rgba8ui,
+ Sf::Rgba8Sint => Self::Rgba8i,
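+ // SPIR-V has no BGRA storage format, so writes to such images rely on
+ // the `StorageImageWriteWithoutFormat` capability (see `write_image_store`).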
+ Sf::Bgra8Unorm => Self::Unknown,
+ Sf::Rgb10a2Uint => Self::Rgb10a2ui,
+ Sf::Rgb10a2Unorm => Self::Rgb10A2,
+ Sf::Rg11b10Float => Self::R11fG11fB10f,
+ Sf::Rg32Uint => Self::Rg32ui,
+ Sf::Rg32Sint => Self::Rg32i,
+ Sf::Rg32Float => Self::Rg32f,
+ Sf::Rgba16Uint => Self::Rgba16ui,
+ Sf::Rgba16Sint => Self::Rgba16i,
+ Sf::Rgba16Float => Self::Rgba16f,
+ Sf::Rgba32Uint => Self::Rgba32ui,
+ Sf::Rgba32Sint => Self::Rgba32i,
+ Sf::Rgba32Float => Self::Rgba32f,
+ Sf::R16Unorm => Self::R16,
+ Sf::R16Snorm => Self::R16Snorm,
+ Sf::Rg16Unorm => Self::Rg16,
+ Sf::Rg16Snorm => Self::Rg16Snorm,
+ Sf::Rgba16Unorm => Self::Rgba16,
+ Sf::Rgba16Snorm => Self::Rgba16Snorm,
+ }
+ }
+}
+
+impl From<crate::ImageDimension> for spirv::Dim {
+ fn from(dim: crate::ImageDimension) -> Self {
+ use crate::ImageDimension as Id;
+ match dim {
+ Id::D1 => Self::Dim1D,
+ Id::D2 => Self::Dim2D,
+ Id::D3 => Self::Dim3D,
+ Id::Cube => Self::DimCube,
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/layout.rs b/third_party/rust/naga/src/back/spv/layout.rs
new file mode 100644
index 0000000000..39117a3d2a
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/layout.rs
@@ -0,0 +1,210 @@
+use super::{Instruction, LogicalLayout, PhysicalLayout};
+use spirv::{Op, Word, MAGIC_NUMBER};
+use std::iter;
+
+// https://github.com/KhronosGroup/SPIRV-Headers/pull/195
+const GENERATOR: Word = 28;
+
+impl PhysicalLayout {
+ pub(super) const fn new(version: Word) -> Self {
+ PhysicalLayout {
+ magic_number: MAGIC_NUMBER,
+ version,
+ generator: GENERATOR,
+ bound: 0,
+ instruction_schema: 0x0u32,
+ }
+ }
+
+ pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(iter::once(self.magic_number));
+ sink.extend(iter::once(self.version));
+ sink.extend(iter::once(self.generator));
+ sink.extend(iter::once(self.bound));
+ sink.extend(iter::once(self.instruction_schema));
+ }
+}
+
+impl super::recyclable::Recyclable for PhysicalLayout {
+ fn recycle(self) -> Self {
+ PhysicalLayout {
+ magic_number: self.magic_number,
+ version: self.version,
+ generator: self.generator,
+ instruction_schema: self.instruction_schema,
+ bound: 0,
+ }
+ }
+}
+
+impl LogicalLayout {
+ pub(super) fn in_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(self.capabilities.iter().cloned());
+ sink.extend(self.extensions.iter().cloned());
+ sink.extend(self.ext_inst_imports.iter().cloned());
+ sink.extend(self.memory_model.iter().cloned());
+ sink.extend(self.entry_points.iter().cloned());
+ sink.extend(self.execution_modes.iter().cloned());
+ sink.extend(self.debugs.iter().cloned());
+ sink.extend(self.annotations.iter().cloned());
+ sink.extend(self.declarations.iter().cloned());
+ sink.extend(self.function_declarations.iter().cloned());
+ sink.extend(self.function_definitions.iter().cloned());
+ }
+}
+
+impl super::recyclable::Recyclable for LogicalLayout {
+ fn recycle(self) -> Self {
+ Self {
+ capabilities: self.capabilities.recycle(),
+ extensions: self.extensions.recycle(),
+ ext_inst_imports: self.ext_inst_imports.recycle(),
+ memory_model: self.memory_model.recycle(),
+ entry_points: self.entry_points.recycle(),
+ execution_modes: self.execution_modes.recycle(),
+ debugs: self.debugs.recycle(),
+ annotations: self.annotations.recycle(),
+ declarations: self.declarations.recycle(),
+ function_declarations: self.function_declarations.recycle(),
+ function_definitions: self.function_definitions.recycle(),
+ }
+ }
+}
+
+impl Instruction {
+ pub(super) const fn new(op: Op) -> Self {
+ Instruction {
+ op,
+ wc: 1, // Always start at 1 for the first word, which packs WC and Op
+ type_id: None,
+ result_id: None,
+ operands: vec![],
+ }
+ }
+
+ #[allow(clippy::panic)]
+ pub(super) fn set_type(&mut self, id: Word) {
+ assert!(self.type_id.is_none(), "Type can only be set once");
+ self.type_id = Some(id);
+ self.wc += 1;
+ }
+
+ #[allow(clippy::panic)]
+ pub(super) fn set_result(&mut self, id: Word) {
+ assert!(self.result_id.is_none(), "Result can only be set once");
+ self.result_id = Some(id);
+ self.wc += 1;
+ }
+
+ pub(super) fn add_operand(&mut self, operand: Word) {
+ self.operands.push(operand);
+ self.wc += 1;
+ }
+
+ pub(super) fn add_operands(&mut self, operands: Vec<Word>) {
+ for operand in operands.into_iter() {
+ self.add_operand(operand)
+ }
+ }
+
+ pub(super) fn to_words(&self, sink: &mut impl Extend<Word>) {
+ sink.extend(Some(self.wc << 16 | self.op as u32));
+ sink.extend(self.type_id);
+ sink.extend(self.result_id);
+ sink.extend(self.operands.iter().cloned());
+ }
+}
+
+impl Instruction {
+ #[cfg(test)]
+ fn validate(&self, words: &[Word]) {
+ let mut inst_index = 0;
+ let (wc, op) = ((words[inst_index] >> 16) as u16, words[inst_index] as u16);
+ inst_index += 1;
+
+ assert_eq!(wc, words.len() as u16);
+ assert_eq!(op, self.op as u16);
+
+ if self.type_id.is_some() {
+ assert_eq!(words[inst_index], self.type_id.unwrap());
+ inst_index += 1;
+ }
+
+ if self.result_id.is_some() {
+ assert_eq!(words[inst_index], self.result_id.unwrap());
+ inst_index += 1;
+ }
+
+ for (op_index, i) in (inst_index..wc as usize).enumerate() {
+ assert_eq!(words[i], self.operands[op_index]);
+ }
+ }
+}
+
+#[test]
+fn test_physical_layout_in_words() {
+ let bound = 5;
+ let version = 0x10203;
+
+ let mut output = vec![];
+ let mut layout = PhysicalLayout::new(version);
+ layout.bound = bound;
+
+ layout.in_words(&mut output);
+
+ assert_eq!(&output, &[MAGIC_NUMBER, version, GENERATOR, bound, 0,]);
+}
+
+#[test]
+fn test_logical_layout_in_words() {
+ let mut output = vec![];
+ let mut layout = LogicalLayout::default();
+ let layout_vectors = 11;
+ let mut instructions = Vec::with_capacity(layout_vectors);
+
+ let vector_names = &[
+ "Capabilities",
+ "Extensions",
+ "External Instruction Imports",
+ "Memory Model",
+ "Entry Points",
+ "Execution Modes",
+ "Debugs",
+ "Annotations",
+ "Declarations",
+ "Function Declarations",
+ "Function Definitions",
+ ];
+
+ for (i, _) in vector_names.iter().enumerate().take(layout_vectors) {
+ let mut dummy_instruction = Instruction::new(Op::Constant);
+ dummy_instruction.set_type((i + 1) as u32);
+ dummy_instruction.set_result((i + 2) as u32);
+ dummy_instruction.add_operand((i + 3) as u32);
+ dummy_instruction.add_operands(super::helpers::string_to_words(
+ format!("This is the vector: {}", vector_names[i]).as_str(),
+ ));
+ instructions.push(dummy_instruction);
+ }
+
+ instructions[0].to_words(&mut layout.capabilities);
+ instructions[1].to_words(&mut layout.extensions);
+ instructions[2].to_words(&mut layout.ext_inst_imports);
+ instructions[3].to_words(&mut layout.memory_model);
+ instructions[4].to_words(&mut layout.entry_points);
+ instructions[5].to_words(&mut layout.execution_modes);
+ instructions[6].to_words(&mut layout.debugs);
+ instructions[7].to_words(&mut layout.annotations);
+ instructions[8].to_words(&mut layout.declarations);
+ instructions[9].to_words(&mut layout.function_declarations);
+ instructions[10].to_words(&mut layout.function_definitions);
+
+ layout.in_words(&mut output);
+
+ let mut index: usize = 0;
+ for instruction in instructions {
+ let wc = instruction.wc as usize;
+ instruction.validate(&output[index..index + wc]);
+ index += wc;
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs
new file mode 100644
index 0000000000..b7d57be0d4
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/mod.rs
@@ -0,0 +1,748 @@
+/*!
+Backend for [SPIR-V][spv] (Standard Portable Intermediate Representation).
+
+[spv]: https://www.khronos.org/registry/SPIR-V/
+*/
+
+mod block;
+mod helpers;
+mod image;
+mod index;
+mod instructions;
+mod layout;
+mod ray;
+mod recyclable;
+mod selection;
+mod writer;
+
+pub use spirv::Capability;
+
+use crate::arena::Handle;
+use crate::proc::{BoundsCheckPolicies, TypeResolution};
+
+use spirv::Word;
+use std::ops;
+use thiserror::Error;
+
+#[derive(Clone)]
+struct PhysicalLayout {
+ magic_number: Word,
+ version: Word,
+ generator: Word,
+ bound: Word,
+ instruction_schema: Word,
+}
+
+#[derive(Default)]
+struct LogicalLayout {
+ capabilities: Vec<Word>,
+ extensions: Vec<Word>,
+ ext_inst_imports: Vec<Word>,
+ memory_model: Vec<Word>,
+ entry_points: Vec<Word>,
+ execution_modes: Vec<Word>,
+ debugs: Vec<Word>,
+ annotations: Vec<Word>,
+ declarations: Vec<Word>,
+ function_declarations: Vec<Word>,
+ function_definitions: Vec<Word>,
+}
+
+struct Instruction {
+ op: spirv::Op,
+ wc: u32,
+ type_id: Option<Word>,
+ result_id: Option<Word>,
+ operands: Vec<Word>,
+}
+
+const BITS_PER_BYTE: crate::Bytes = 8;
+
+#[derive(Clone, Debug, Error)]
+pub enum Error {
+ #[error("The requested entry point couldn't be found")]
+ EntryPointNotFound,
+ #[error("target SPIRV-{0}.{1} is not supported")]
+ UnsupportedVersion(u8, u8),
+ #[error("using {0} requires at least one of the capabilities {1:?}, but none are available")]
+ MissingCapabilities(&'static str, Vec<Capability>),
+ #[error("unimplemented {0}")]
+ FeatureNotImplemented(&'static str),
+ #[error("module is not validated properly: {0}")]
+ Validation(&'static str),
+}
+
+#[derive(Default)]
+struct IdGenerator(Word);
+
+impl IdGenerator {
+ fn next(&mut self) -> Word {
+ self.0 += 1;
+ self.0
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct DebugInfo<'a> {
+ pub source_code: &'a str,
+ pub file_name: &'a std::path::Path,
+}
+
+/// A SPIR-V block to which we are still adding instructions.
+///
+/// A `Block` represents a SPIR-V block that does not yet have a termination
+/// instruction like `OpBranch` or `OpReturn`.
+///
+/// The `OpLabel` that starts the block is implicit. It will be emitted based on
+/// `label_id` when we write the block to a `LogicalLayout`.
+///
+/// To terminate a `Block`, pass the block and the termination instruction to
+/// `Function::consume`. This takes ownership of the `Block` and transforms it
+/// into a `TerminatedBlock`.
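+///
+/// A minimal sketch of the lifecycle (assuming `ctx` is a [`BlockContext`]):
+///
+/// ```ignore
+/// let label_id = ctx.gen_id();
+/// let mut block = Block::new(label_id);
+/// // ... push instructions onto `block.body` ...
+/// ctx.function.consume(block, Instruction::return_void());
+/// ```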
+struct Block {
+ label_id: Word,
+ body: Vec<Instruction>,
+}
+
+/// A SPIR-V block that ends with a termination instruction.
+struct TerminatedBlock {
+ label_id: Word,
+ body: Vec<Instruction>,
+}
+
+impl Block {
+ const fn new(label_id: Word) -> Self {
+ Block {
+ label_id,
+ body: Vec::new(),
+ }
+ }
+}
+
+struct LocalVariable {
+ id: Word,
+ instruction: Instruction,
+}
+
+struct ResultMember {
+ id: Word,
+ type_id: Word,
+ built_in: Option<crate::BuiltIn>,
+}
+
+struct EntryPointContext {
+ argument_ids: Vec<Word>,
+ results: Vec<ResultMember>,
+}
+
+#[derive(Default)]
+struct Function {
+ signature: Option<Instruction>,
+ parameters: Vec<FunctionArgument>,
+ variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
+ blocks: Vec<TerminatedBlock>,
+ entry_point_context: Option<EntryPointContext>,
+}
+
+impl Function {
+ fn consume(&mut self, mut block: Block, termination: Instruction) {
+ block.body.push(termination);
+ self.blocks.push(TerminatedBlock {
+ label_id: block.label_id,
+ body: block.body,
+ })
+ }
+
+ fn parameter_id(&self, index: u32) -> Word {
+ match self.entry_point_context {
+ Some(ref context) => context.argument_ids[index as usize],
+ None => self.parameters[index as usize]
+ .instruction
+ .result_id
+ .unwrap(),
+ }
+ }
+}
+
+/// Characteristics of a SPIR-V `OpTypeImage` type.
+///
+/// SPIR-V requires non-composite types to be unique, including images. Since we
+/// use `LocalType` for this deduplication, it's essential that `LocalImageType`
+/// be equal whenever the corresponding `OpTypeImage`s would be. To reduce the
+/// likelihood of mistakes, we use fields that correspond exactly to the
+/// operands of an `OpTypeImage` instruction, using the actual SPIR-V types
+/// where practical.
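+///
+/// A small sketch of the correspondence: a multisampled, non-arrayed depth
+/// image maps to
+///
+/// ```ignore
+/// let local = LocalImageType::from_inner(
+///     crate::ImageDimension::D2,
+///     false, // arrayed
+///     crate::ImageClass::Depth { multi: true },
+/// );
+/// // local.flags == DEPTH | SAMPLED | MULTISAMPLED
+/// // local.image_format == spirv::ImageFormat::Unknown
+/// ```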
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+struct LocalImageType {
+ sampled_type: crate::ScalarKind,
+ dim: spirv::Dim,
+ flags: ImageTypeFlags,
+ image_format: spirv::ImageFormat,
+}
+
+bitflags::bitflags! {
+ /// Flags corresponding to the boolean(-ish) parameters to OpTypeImage.
+ #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+ pub struct ImageTypeFlags: u8 {
+ const DEPTH = 0x1;
+ const ARRAYED = 0x2;
+ const MULTISAMPLED = 0x4;
+ const SAMPLED = 0x8;
+ }
+}
+
+impl LocalImageType {
+ /// Construct a `LocalImageType` from the fields of a `TypeInner::Image`.
+ fn from_inner(dim: crate::ImageDimension, arrayed: bool, class: crate::ImageClass) -> Self {
+ let make_flags = |multi: bool, other: ImageTypeFlags| -> ImageTypeFlags {
+ let mut flags = other;
+ flags.set(ImageTypeFlags::ARRAYED, arrayed);
+ flags.set(ImageTypeFlags::MULTISAMPLED, multi);
+ flags
+ };
+
+ let dim = spirv::Dim::from(dim);
+
+ match class {
+ crate::ImageClass::Sampled { kind, multi } => LocalImageType {
+ sampled_type: kind,
+ dim,
+ flags: make_flags(multi, ImageTypeFlags::SAMPLED),
+ image_format: spirv::ImageFormat::Unknown,
+ },
+ crate::ImageClass::Depth { multi } => LocalImageType {
+ sampled_type: crate::ScalarKind::Float,
+ dim,
+ flags: make_flags(multi, ImageTypeFlags::DEPTH | ImageTypeFlags::SAMPLED),
+ image_format: spirv::ImageFormat::Unknown,
+ },
+ crate::ImageClass::Storage { format, access: _ } => LocalImageType {
+ sampled_type: crate::ScalarKind::from(format),
+ dim,
+ flags: make_flags(false, ImageTypeFlags::empty()),
+ image_format: format.into(),
+ },
+ }
+ }
+}
+
+/// A SPIR-V type constructed during code generation.
+///
+/// This is the variant of [`LookupType`] used to represent types that might not
+/// be available in the arena. Variants are present here for one of two reasons:
+///
+/// - They represent types synthesized during code generation, as explained
+/// in the documentation for [`LookupType`].
+///
+/// - They represent types for which SPIR-V forbids duplicate `OpType...`
+/// instructions, requiring deduplication.
+///
+/// This is not a complete copy of [`TypeInner`]: for example, SPIR-V generation
+/// never synthesizes new struct types, so `LocalType` has nothing for that.
+///
+/// Each `LocalType` variant should be handled identically to its analogous
+/// `TypeInner` variant. You can use the [`make_local`] function to help with
+/// this, by converting everything possible to a `LocalType` before inspecting
+/// it.
+///
+/// ## `LocalType` equality and SPIR-V `OpType` uniqueness
+///
+/// The definition of `Eq` on `LocalType` is carefully chosen to help us follow
+/// certain SPIR-V rules. SPIR-V §2.8 requires some classes of `OpType...`
+/// instructions to be unique; for example, you can't have two `OpTypeInt 32 1`
+/// instructions in the same module. All 32-bit signed integers must use the
+/// same type id.
+///
+/// All SPIR-V types that must be unique can be represented as a `LocalType`,
+/// and two `LocalType`s are always `Eq` if SPIR-V would require them to use the
+/// same `OpType...` instruction. This lets us avoid duplicates by recording the
+/// ids of the type instructions we've already generated in a hash table,
+/// [`Writer::lookup_type`], keyed by `LocalType`.
+///
+/// As another example, [`LocalImageType`], stored in the `LocalType::Image`
+/// variant, is designed to help us deduplicate `OpTypeImage` instructions. See
+/// its documentation for details.
+///
+/// `LocalType` also includes variants like `Pointer` that do not need to be
+/// unique, but it is harmless to avoid the duplication.
+///
+/// As it always must, the `Hash` implementation respects the `Eq` relation.
+///
+/// [`TypeInner`]: crate::TypeInner
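+///
+/// A sketch of what this deduplication buys us: requesting the same scalar
+/// type twice yields one id, and hence a single `OpTypeInt` in the output:
+///
+/// ```ignore
+/// let u32_ty = LocalType::Value {
+///     vector_size: None,
+///     scalar: crate::Scalar::U32,
+///     pointer_space: None,
+/// };
+/// assert_eq!(
+///     writer.get_type_id(u32_ty.into()),
+///     writer.get_type_id(u32_ty.into()),
+/// );
+/// ```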
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+enum LocalType {
+ /// A scalar, vector, or pointer to one of those.
+ Value {
+ /// If `None`, this represents a scalar type. If `Some`, this represents
+ /// a vector type of the given size.
+ vector_size: Option<crate::VectorSize>,
+ scalar: crate::Scalar,
+ pointer_space: Option<spirv::StorageClass>,
+ },
+ /// A matrix of floating-point values.
+ Matrix {
+ columns: crate::VectorSize,
+ rows: crate::VectorSize,
+ width: crate::Bytes,
+ },
+ Pointer {
+ base: Handle<crate::Type>,
+ class: spirv::StorageClass,
+ },
+ Image(LocalImageType),
+ SampledImage {
+ image_type_id: Word,
+ },
+ Sampler,
+ /// Equivalent to a [`LocalType::Pointer`] whose `base` is a Naga IR [`BindingArray`]. SPIR-V
+ /// permits duplicated `OpTypePointer` ids, so it's fine to have two different [`LocalType`]
+ /// representations for pointer types.
+ ///
+ /// [`BindingArray`]: crate::TypeInner::BindingArray
+ PointerToBindingArray {
+ base: Handle<crate::Type>,
+ size: u32,
+ space: crate::AddressSpace,
+ },
+ BindingArray {
+ base: Handle<crate::Type>,
+ size: u32,
+ },
+ AccelerationStructure,
+ RayQuery,
+}
+
+/// A type encountered during SPIR-V generation.
+///
+/// In the process of writing SPIR-V, we need to synthesize various types for
+/// intermediate results and such: pointer types, vector/matrix component types,
+/// or even booleans, which usually appear in SPIR-V code even when they're not
+/// used by the module source.
+///
+/// However, we can't use `crate::Type` or `crate::TypeInner` for these, as the
+/// type arena may not contain what we need (it only contains types used
+/// directly by other parts of the IR), and the IR module is immutable, so we
+/// can't add anything to it.
+///
+/// So for local use in the SPIR-V writer, we use this type, which holds either
+/// a handle into the arena, or a [`LocalType`] containing something synthesized
+/// locally.
+///
+/// This is very similar to the [`proc::TypeResolution`] enum, with `LocalType`
+/// playing the role of `TypeInner`. However, `LocalType` also has other
+/// properties needed for SPIR-V generation; see the description of
+/// [`LocalType`] for details.
+///
+/// [`proc::TypeResolution`]: crate::proc::TypeResolution
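+///
+/// A sketch of the two cases (assuming `writer: &mut Writer` and a
+/// `ty_handle` taken from the module's type arena):
+///
+/// ```ignore
+/// // A type defined in the IR module:
+/// let arena_id = writer.get_type_id(LookupType::Handle(ty_handle));
+/// // A type synthesized by the writer itself:
+/// let local_id = writer.get_type_id(LookupType::Local(LocalType::Sampler));
+/// ```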
+#[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)]
+enum LookupType {
+ Handle(Handle<crate::Type>),
+ Local(LocalType),
+}
+
+impl From<LocalType> for LookupType {
+ fn from(local: LocalType) -> Self {
+ Self::Local(local)
+ }
+}
+
+#[derive(Debug, PartialEq, Clone, Hash, Eq)]
+struct LookupFunctionType {
+ parameter_type_ids: Vec<Word>,
+ return_type_id: Word,
+}
+
+fn make_local(inner: &crate::TypeInner) -> Option<LocalType> {
+ Some(match *inner {
+ crate::TypeInner::Scalar(scalar) | crate::TypeInner::Atomic(scalar) => LocalType::Value {
+ vector_size: None,
+ scalar,
+ pointer_space: None,
+ },
+ crate::TypeInner::Vector { size, scalar } => LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ },
+ crate::TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => LocalType::Matrix {
+ columns,
+ rows,
+ width: scalar.width,
+ },
+ crate::TypeInner::Pointer { base, space } => LocalType::Pointer {
+ base,
+ class: helpers::map_storage_class(space),
+ },
+ crate::TypeInner::ValuePointer {
+ size,
+ scalar,
+ space,
+ } => LocalType::Value {
+ vector_size: size,
+ scalar,
+ pointer_space: Some(helpers::map_storage_class(space)),
+ },
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
+ crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
+ crate::TypeInner::AccelerationStructure => LocalType::AccelerationStructure,
+ crate::TypeInner::RayQuery => LocalType::RayQuery,
+ crate::TypeInner::Array { .. }
+ | crate::TypeInner::Struct { .. }
+ | crate::TypeInner::BindingArray { .. } => return None,
+ })
+}
+
+#[derive(Debug)]
+enum Dimension {
+ Scalar,
+ Vector,
+ Matrix,
+}
+
+/// A map from evaluated [`Expression`](crate::Expression)s to their SPIR-V ids.
+///
+/// When we [emit] code to evaluate a given `Expression`, we record the
+/// SPIR-V id of its value here, under its `Handle<Expression>` index.
+///
+/// A `CachedExpressions` value can be indexed by a `Handle<Expression>` value.
+///
+/// [emit]: index.html#expression-evaluation-time-and-scope
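+///
+/// A sketch of both directions (assuming `ctx` is a [`BlockContext`] and
+/// `expr` is a `Handle<crate::Expression>`):
+///
+/// ```ignore
+/// let id = ctx.gen_id();
+/// ctx.cached[expr] = id; // record; panics if already cached
+/// let operand = ctx.cached[expr]; // look up; panics if not cached
+/// ```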
+#[derive(Default)]
+struct CachedExpressions {
+ ids: Vec<Word>,
+}
+impl CachedExpressions {
+ fn reset(&mut self, length: usize) {
+ self.ids.clear();
+ self.ids.resize(length, 0);
+ }
+}
+impl ops::Index<Handle<crate::Expression>> for CachedExpressions {
+ type Output = Word;
+ fn index(&self, h: Handle<crate::Expression>) -> &Word {
+ let id = &self.ids[h.index()];
+ if *id == 0 {
+ unreachable!("Expression {:?} is not cached!", h);
+ }
+ id
+ }
+}
+impl ops::IndexMut<Handle<crate::Expression>> for CachedExpressions {
+ fn index_mut(&mut self, h: Handle<crate::Expression>) -> &mut Word {
+ let id = &mut self.ids[h.index()];
+ if *id != 0 {
+ unreachable!("Expression {:?} is already cached!", h);
+ }
+ id
+ }
+}
+impl recyclable::Recyclable for CachedExpressions {
+ fn recycle(self) -> Self {
+ CachedExpressions {
+ ids: self.ids.recycle(),
+ }
+ }
+}
+
+#[derive(Eq, Hash, PartialEq)]
+enum CachedConstant {
+ Literal(crate::Literal),
+ Composite {
+ ty: LookupType,
+ constituent_ids: Vec<Word>,
+ },
+ ZeroValue(Word),
+}
+
+#[derive(Clone)]
+struct GlobalVariable {
+ /// ID of the OpVariable that declares the global.
+ ///
+ /// If you need the variable's value, use [`access_id`] instead of this
+ /// field. If we wrapped the Naga IR `GlobalVariable`'s type in a struct to
+ /// comply with Vulkan's requirements, then this points to the `OpVariable`
+ /// with the synthesized struct type, whereas `access_id` points to the
+ /// field of said struct that holds the variable's actual value.
+ ///
+ /// This is used to compute the `access_id` pointer in function prologues,
+ /// and for `ArrayLength` expressions, which do need the wrapper struct.
+ ///
+ /// [`access_id`]: GlobalVariable::access_id
+ var_id: Word,
+
+ /// For `AddressSpace::Handle` variables, this ID is recorded in the function
+ /// prelude block (and reset before every function) as `OpLoad` of the variable.
+ /// It is then used for all the global ops, such as `OpImageSample`.
+ handle_id: Word,
+
+ /// Actual ID used to access this variable.
+ /// For wrapped buffer variables, this ID is `OpAccessChain` into the
+ /// wrapper. Otherwise, the same as `var_id`.
+ ///
+ /// Vulkan requires that globals in the `StorageBuffer` and `Uniform` storage
+ /// classes be structs with the `Block` decoration, but WGSL and Naga IR
+ /// make no such requirement. So for such variables, we generate a wrapper struct
+ /// type with a single element of the type given by Naga, generate an
+ /// `OpAccessChain` for that member in the function prelude, and use that pointer
+ /// to refer to the global in the function body. This is the id of that access,
+ /// updated for each function in `write_function`.
+ access_id: Word,
+}
+
+impl GlobalVariable {
+ const fn dummy() -> Self {
+ Self {
+ var_id: 0,
+ handle_id: 0,
+ access_id: 0,
+ }
+ }
+
+ const fn new(id: Word) -> Self {
+ Self {
+ var_id: id,
+ handle_id: 0,
+ access_id: 0,
+ }
+ }
+
+ /// Prepare `self` for use within a single function.
+ fn reset_for_function(&mut self) {
+ self.handle_id = 0;
+ self.access_id = 0;
+ }
+}
+
+struct FunctionArgument {
+ /// Actual instruction of the argument.
+ instruction: Instruction,
+ handle_id: Word,
+}
+
+/// General information needed to emit SPIR-V for Naga statements.
+struct BlockContext<'w> {
+ /// The writer handling the module to which this code belongs.
+ writer: &'w mut Writer,
+
+ /// The [`Module`](crate::Module) for which we're generating code.
+ ir_module: &'w crate::Module,
+
+ /// The [`Function`](crate::Function) for which we're generating code.
+ ir_function: &'w crate::Function,
+
+ /// Information module validation produced about
+ /// [`ir_function`](BlockContext::ir_function).
+ fun_info: &'w crate::valid::FunctionInfo,
+
+ /// The [`spv::Function`](Function) to which we are contributing SPIR-V instructions.
+ function: &'w mut Function,
+
+ /// SPIR-V ids for expressions we've evaluated.
+ cached: CachedExpressions,
+
+ /// The `Writer`'s temporary vector, for convenience.
+ temp_list: Vec<Word>,
+
+ /// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`
+ expression_constness: crate::proc::ExpressionConstnessTracker,
+}
+
+impl BlockContext<'_> {
+ fn gen_id(&mut self) -> Word {
+ self.writer.id_gen.next()
+ }
+
+ fn get_type_id(&mut self, lookup_type: LookupType) -> Word {
+ self.writer.get_type_id(lookup_type)
+ }
+
+ fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
+ self.writer.get_expression_type_id(tr)
+ }
+
+ fn get_index_constant(&mut self, index: Word) -> Word {
+ self.writer.get_constant_scalar(crate::Literal::U32(index))
+ }
+
+ fn get_scope_constant(&mut self, scope: Word) -> Word {
+ self.writer
+ .get_constant_scalar(crate::Literal::I32(scope as _))
+ }
+}
+
+#[derive(Clone, Copy, Default)]
+struct LoopContext {
+ continuing_id: Option<Word>,
+ break_id: Option<Word>,
+}
+
+pub struct Writer {
+ physical_layout: PhysicalLayout,
+ logical_layout: LogicalLayout,
+ id_gen: IdGenerator,
+
+ /// The set of capabilities modules are permitted to use.
+ ///
+ /// This is initialized from `Options::capabilities`.
+ capabilities_available: Option<crate::FastHashSet<Capability>>,
+
+ /// The set of capabilities used by this module.
+ ///
+ /// If `capabilities_available` is `Some`, then this is always a subset of
+ /// that.
+ capabilities_used: crate::FastIndexSet<Capability>,
+
+ /// The set of spirv extensions used.
+ extensions_used: crate::FastIndexSet<&'static str>,
+
+ debugs: Vec<Instruction>,
+ annotations: Vec<Instruction>,
+ flags: WriterFlags,
+ bounds_check_policies: BoundsCheckPolicies,
+ zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
+ void_type: Word,
+ // TODO: convert most of these into vectors, addressable by handle indices
+ lookup_type: crate::FastHashMap<LookupType, Word>,
+ lookup_function: crate::FastHashMap<Handle<crate::Function>, Word>,
+ lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>,
+ /// Indexed by const-expression handle indexes
+ constant_ids: Vec<Word>,
+ cached_constants: crate::FastHashMap<CachedConstant, Word>,
+ global_variables: Vec<GlobalVariable>,
+ binding_map: BindingMap,
+
+ // Cached expressions are only meaningful within a BlockContext, but we
+ // retain the table here between functions to save heap allocations.
+ saved_cached: CachedExpressions,
+
+ gl450_ext_inst_id: Word,
+
+ // Just a temporary list of SPIR-V ids
+ temp_list: Vec<Word>,
+}
+
+bitflags::bitflags! {
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct WriterFlags: u32 {
+ /// Include debug labels for everything.
+ const DEBUG = 0x1;
+ /// Flip Y coordinate of `BuiltIn::Position` output.
+ const ADJUST_COORDINATE_SPACE = 0x2;
+ /// Emit `OpName` for input/output locations.
+ /// Contrary to the spec, some drivers treat `OpName` as semantic and
+ /// do not allow conflicting names.
+ const LABEL_VARYINGS = 0x4;
+ /// Emit `PointSize` output builtin to vertex shaders, which is
+ /// required for drawing with `PointList` topology.
+ const FORCE_POINT_SIZE = 0x8;
+ /// Clamp `BuiltIn::FragDepth` output between 0 and 1.
+ const CLAMP_FRAG_DEPTH = 0x10;
+ }
+}
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BindingInfo {
+ /// If the binding is an unsized binding array, this overrides the size.
+ pub binding_array_size: Option<u32>,
+}
+
+// Using `BTreeMap` instead of `HashMap` so that the map itself can be hashed.
+pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindingInfo>;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum ZeroInitializeWorkgroupMemoryMode {
+ /// Via `VK_KHR_zero_initialize_workgroup_memory` or Vulkan 1.3
+ Native,
+ /// Via assignments + barrier
+ Polyfill,
+ None,
+}
+
+#[derive(Debug, Clone)]
+pub struct Options<'a> {
+ /// (Major, Minor) target version of SPIR-V.
+ pub lang_version: (u8, u8),
+
+ /// Configuration flags for the writer.
+ pub flags: WriterFlags,
+
+ /// Map of resources to information about the binding.
+ pub binding_map: BindingMap,
+
+ /// If given, the set of capabilities modules are allowed to use. Code that
+ /// requires capabilities beyond these is rejected with an error.
+ ///
+ /// If this is `None`, all capabilities are permitted.
+ pub capabilities: Option<crate::FastHashSet<Capability>>,
+
+ /// How should the generated code handle array, vector, matrix, or image
+ /// texel indices that are out of range?
+ pub bounds_check_policies: BoundsCheckPolicies,
+
+ /// Dictates how workgroup variables should be zero-initialized.
+ pub zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
+
+ pub debug_info: Option<DebugInfo<'a>>,
+}
+
+impl<'a> Default for Options<'a> {
+ fn default() -> Self {
+ let mut flags = WriterFlags::ADJUST_COORDINATE_SPACE
+ | WriterFlags::LABEL_VARYINGS
+ | WriterFlags::CLAMP_FRAG_DEPTH;
+ if cfg!(debug_assertions) {
+ flags |= WriterFlags::DEBUG;
+ }
+ Options {
+ lang_version: (1, 0),
+ flags,
+ binding_map: BindingMap::default(),
+ capabilities: None,
+ bounds_check_policies: crate::proc::BoundsCheckPolicies::default(),
+ zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill,
+ debug_info: None,
+ }
+ }
+}
+
+// A subset of options meant to be changed per pipeline.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct PipelineOptions {
+ /// The stage of the entry point.
+ pub shader_stage: crate::ShaderStage,
+ /// The name of the entry point.
+ ///
+ /// If no matching entry point is found while creating a [`Writer`], an error is returned.
+ pub entry_point: String,
+}
+
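+/// Serialize a module into a fresh vector of SPIR-V words.
+///
+/// A minimal sketch of a call site (assuming `module` has already passed
+/// validation, which produced `info`):
+///
+/// ```ignore
+/// let words = write_vec(&module, &info, &Options::default(), None)?;
+/// ```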
+pub fn write_vec(
+ module: &crate::Module,
+ info: &crate::valid::ModuleInfo,
+ options: &Options,
+ pipeline_options: Option<&PipelineOptions>,
+) -> Result<Vec<u32>, Error> {
+ let mut words: Vec<u32> = Vec::new();
+ let mut w = Writer::new(options)?;
+
+ w.write(
+ module,
+ info,
+ pipeline_options,
+ &options.debug_info,
+ &mut words,
+ )?;
+ Ok(words)
+}
diff --git a/third_party/rust/naga/src/back/spv/ray.rs b/third_party/rust/naga/src/back/spv/ray.rs
new file mode 100644
index 0000000000..bc2c4ce3c6
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/ray.rs
@@ -0,0 +1,261 @@
+/*!
+Generating SPIR-V for ray query operations.
+*/
+
+use super::{Block, BlockContext, Instruction, LocalType, LookupType};
+use crate::arena::Handle;
+
+impl<'w> BlockContext<'w> {
+ pub(super) fn write_ray_query_function(
+ &mut self,
+ query: Handle<crate::Expression>,
+ function: &crate::RayQueryFunction,
+ block: &mut Block,
+ ) {
+ let query_id = self.cached[query];
+ match *function {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ // Note: composite extract indices and types must match `generate_ray_desc_type`
+ let desc_id = self.cached[descriptor];
+ let acc_struct_id = self.get_handle_id(acceleration_structure);
+
+ let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ }));
+ let ray_flags_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ flag_type_id,
+ ray_flags_id,
+ desc_id,
+ &[0],
+ ));
+ let cull_mask_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ flag_type_id,
+ cull_mask_id,
+ desc_id,
+ &[1],
+ ));
+
+ let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let tmin_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ scalar_type_id,
+ tmin_id,
+ desc_id,
+ &[2],
+ ));
+ let tmax_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ scalar_type_id,
+ tmax_id,
+ desc_id,
+ &[3],
+ ));
+
+ let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let ray_origin_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ ray_origin_id,
+ desc_id,
+ &[4],
+ ));
+ let ray_dir_id = self.gen_id();
+ block.body.push(Instruction::composite_extract(
+ vector_type_id,
+ ray_dir_id,
+ desc_id,
+ &[5],
+ ));
+
+ block.body.push(Instruction::ray_query_initialize(
+ query_id,
+ acc_struct_id,
+ ray_flags_id,
+ cull_mask_id,
+ ray_origin_id,
+ tmin_id,
+ ray_dir_id,
+ tmax_id,
+ ));
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ let id = self.gen_id();
+ self.cached[result] = id;
+ let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
+
+ block
+ .body
+ .push(Instruction::ray_query_proceed(result_type_id, id, query_id));
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ }
+
+ pub(super) fn write_ray_query_get_intersection(
+ &mut self,
+ query: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> spirv::Word {
+ let query_id = self.cached[query];
+ let intersection_id = self.writer.get_constant_scalar(crate::Literal::U32(
+ spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
+ ));
+
+ let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ }));
+ let kind_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionTypeKHR,
+ flag_type_id,
+ kind_id,
+ query_id,
+ intersection_id,
+ ));
+ let instance_custom_index_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionInstanceCustomIndexKHR,
+ flag_type_id,
+ instance_custom_index_id,
+ query_id,
+ intersection_id,
+ ));
+ let instance_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionInstanceIdKHR,
+ flag_type_id,
+ instance_id,
+ query_id,
+ intersection_id,
+ ));
+ let sbt_record_offset_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR,
+ flag_type_id,
+ sbt_record_offset_id,
+ query_id,
+ intersection_id,
+ ));
+ let geometry_index_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionGeometryIndexKHR,
+ flag_type_id,
+ geometry_index_id,
+ query_id,
+ intersection_id,
+ ));
+ let primitive_index_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionPrimitiveIndexKHR,
+ flag_type_id,
+ primitive_index_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let scalar_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let t_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionTKHR,
+ scalar_type_id,
+ t_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let barycentrics_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Bi),
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ }));
+ let barycentrics_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionBarycentricsKHR,
+ barycentrics_type_id,
+ barycentrics_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let bool_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::BOOL,
+ pointer_space: None,
+ }));
+ let front_face_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionFrontFaceKHR,
+ bool_type_id,
+ front_face_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let transform_type_id = self.get_type_id(LookupType::Local(LocalType::Matrix {
+ columns: crate::VectorSize::Quad,
+ rows: crate::VectorSize::Tri,
+ width: 4,
+ }));
+ let object_to_world_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionObjectToWorldKHR,
+ transform_type_id,
+ object_to_world_id,
+ query_id,
+ intersection_id,
+ ));
+ let world_to_object_id = self.gen_id();
+ block.body.push(Instruction::ray_query_get_intersection(
+ spirv::Op::RayQueryGetIntersectionWorldToObjectKHR,
+ transform_type_id,
+ world_to_object_id,
+ query_id,
+ intersection_id,
+ ));
+
+ let id = self.gen_id();
+ let intersection_type_id = self.get_type_id(LookupType::Handle(
+ self.ir_module.special_types.ray_intersection.unwrap(),
+ ));
+ // Note: the arguments must match the `generate_ray_intersection_type` layout
+ block.body.push(Instruction::composite_construct(
+ intersection_type_id,
+ id,
+ &[
+ kind_id,
+ t_id,
+ instance_custom_index_id,
+ instance_id,
+ sbt_record_offset_id,
+ geometry_index_id,
+ primitive_index_id,
+ barycentrics_id,
+ front_face_id,
+ object_to_world_id,
+ world_to_object_id,
+ ],
+ ));
+ id
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/recyclable.rs b/third_party/rust/naga/src/back/spv/recyclable.rs
new file mode 100644
index 0000000000..cd1466e3c7
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/recyclable.rs
@@ -0,0 +1,67 @@
+/*!
+Reusing collections' previous allocations.
+*/
+
+/// A value that can be reset to its initial state, retaining its current allocations.
+///
+/// Naga attempts to lower the cost of SPIR-V generation by allowing clients to
+/// reuse the same `Writer` for multiple Module translations. Reusing a `Writer`
+/// means that the `Vec`s, `HashMap`s, and other heap-allocated structures the
+/// `Writer` uses internally begin the translation with heap-allocated buffers
+/// ready to use.
+///
+/// But this approach introduces the risk of `Writer` state leaking from one
+/// module to the next. When a developer adds fields to `Writer` or its internal
+/// types, they must remember to reset their contents between modules.
+///
+/// One trick to ensure that every field has been accounted for is to use Rust's
+/// struct literal syntax to construct a new, reset value. If a developer adds a
+/// field, but neglects to update the reset code, the compiler will complain
+/// that a field is missing from the literal. This trait's `recycle` method
+/// takes `self` by value, and returns `Self` by value, encouraging the use of
+/// struct literal expressions in its implementation.
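+///
+/// A minimal sketch of the pattern, for a hypothetical two-field type:
+///
+/// ```ignore
+/// impl Recyclable for Caches {
+///     fn recycle(self) -> Self {
+///         Caches {
+///             // Adding a field to `Caches` without updating this literal
+///             // is a compile error, so no state can leak between modules.
+///             ids: self.ids.recycle(),
+///             names: self.names.recycle(),
+///         }
+///     }
+/// }
+/// ```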
+pub trait Recyclable {
+ /// Clear `self`, retaining its current memory allocations.
+ ///
+ /// Shrink the buffer if it's currently much larger than was actually used.
+ /// This prevents a module with exceptionally large allocations from causing
+ /// the `Writer` to retain more memory than it needs indefinitely.
+ fn recycle(self) -> Self;
+}
+
+// Stock values for various collections.
+
+impl<T> Recyclable for Vec<T> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K, V, S: Clone> Recyclable for std::collections::HashMap<K, V, S> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K, S: Clone> Recyclable for std::collections::HashSet<K, S> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K, S: Clone> Recyclable for indexmap::IndexSet<K, S> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
+
+impl<K: Ord, V> Recyclable for std::collections::BTreeMap<K, V> {
+ fn recycle(mut self) -> Self {
+ self.clear();
+ self
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/selection.rs b/third_party/rust/naga/src/back/spv/selection.rs
new file mode 100644
index 0000000000..788b1f10ab
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/selection.rs
@@ -0,0 +1,257 @@
+/*!
+Generate SPIR-V conditional structures.
+
+Builders for `if` structures with `and`s.
+
+The types in this module track the information needed to emit SPIR-V code
+for complex conditional structures, like those whose conditions involve
+short-circuiting 'and' and 'or' structures. These track labels and can emit
+`OpPhi` instructions to merge values produced along different paths.
+
+This currently only supports exactly the forms Naga uses, so it doesn't
+support `or` or `else`, and only supports zero or one merged value.
+
+Naga needs to emit code roughly like this:
+
+```ignore
+
+ value = DEFAULT;
+ if COND1 && COND2 {
+ value = THEN_VALUE;
+ }
+ // use value
+
+```
+
+Assuming `ctx` and `block` are a mutable references to a [`BlockContext`]
+and the current [`Block`], and `merge_type` is the SPIR-V type for the
+merged value `value`, we can build SPIR-V for the code above like so:
+
+```ignore
+
+ let mut cond = Selection::start(block, merge_type);
+ // ... compute `cond1` ...
+ cond.if_true(ctx, cond1, DEFAULT);
+ // ... compute `cond2` ...
+ cond.if_true(ctx, cond2, DEFAULT);
+ // ... compute THEN_VALUE
+ let merged_value = cond.finish(ctx, THEN_VALUE);
+
+```
+
+After this, `merged_value` is either `DEFAULT` or `THEN_VALUE`, depending on
+the path by which the merged block was reached.
+
+This takes care of writing all branch instructions, including an
+`OpSelectionMerge` annotation in the header block; starting new blocks and
+assigning them labels; and emitting the `OpPhi` that gathers together the
+right sources for the merged values, for every path through the selection
+construct.
+
+When there is no merged value to produce, you can pass `()` for `merge_type`
+and the merge values. In this case no `OpPhi` instructions are produced, and
+the `finish` method returns `()`.
+
+To enforce proper nesting, a `Selection` takes ownership of the `&mut Block`
+pointer for the duration of its lifetime. To obtain the block for generating
+code in the selection's body, call the `Selection::block` method.
+*/
+
+use super::{Block, BlockContext, Instruction};
+use spirv::Word;
+
+/// A private struct recording what we know about the selection construct so far.
+pub(super) struct Selection<'b, M: MergeTuple> {
+ /// The block pointer we're emitting code into.
+ block: &'b mut Block,
+
+ /// The label of the selection construct's merge block, or `None` if we
+ /// haven't yet written the `OpSelectionMerge` merge instruction.
+ merge_label: Option<Word>,
+
+ /// A set of `(VALUES, PARENT)` pairs, used to build `OpPhi` instructions in
+ /// the merge block. Each `PARENT` is the label of a predecessor block of
+ /// the merge block. The corresponding `VALUES` holds the ids of the values
+ /// that `PARENT` contributes to the merged values.
+ ///
+ /// We emit all branches to the merge block, so we know all its
+ /// predecessors. And we refuse to emit a branch unless we're given the
+ /// values the branching block contributes to the merge, so we always have
+ /// everything we need to emit the correct phis, by construction.
+ values: Vec<(M, Word)>,
+
+ /// The types of the values in each element of `values`.
+ merge_types: M,
+}
+
+impl<'b, M: MergeTuple> Selection<'b, M> {
+ /// Start a new selection construct.
+ ///
+ /// The `block` argument indicates the selection's header block.
+ ///
+ /// The `merge_types` argument should be a `Word` or tuple of `Word`s, each
+ /// value being the SPIR-V result type id of an `OpPhi` instruction that
+ /// will be written to the selection's merge block when this selection's
+ /// [`finish`] method is called. This argument may also be `()`, for
+ /// selections that produce no values.
+ ///
+ /// (This function writes no code to `block` itself; it simply constructs a
+ /// fresh `Selection`.)
+ ///
+ /// [`finish`]: Selection::finish
+ pub(super) fn start(block: &'b mut Block, merge_types: M) -> Self {
+ Selection {
+ block,
+ merge_label: None,
+ values: vec![],
+ merge_types,
+ }
+ }
+
+ pub(super) fn block(&mut self) -> &mut Block {
+ self.block
+ }
+
+ /// Branch to a successor block if `cond` is true, otherwise merge.
+ ///
+ /// If `cond` is false, branch to the merge block, using `values` as the
+ /// merged values. Otherwise, proceed to a new block.
+ ///
+ /// The `values` argument must be the same shape as the `merge_types`
+ /// argument passed to `Selection::start`.
+ pub(super) fn if_true(&mut self, ctx: &mut BlockContext, cond: Word, values: M) {
+ self.values.push((values, self.block.label_id));
+
+ let merge_label = self.make_merge_label(ctx);
+ let next_label = ctx.gen_id();
+ ctx.function.consume(
+ std::mem::replace(self.block, Block::new(next_label)),
+ Instruction::branch_conditional(cond, next_label, merge_label),
+ );
+ }
+
+ /// Emit an unconditional branch to the merge block, and compute merged
+ /// values.
+ ///
+ /// Use `final_values` as the merged values contributed by the current
+ /// block, and transition to the merge block, emitting `OpPhi` instructions
+ /// to produce the merged values. This must be the same shape as the
+ /// `merge_types` argument passed to [`Selection::start`].
+ ///
+ /// Return the SPIR-V ids of the merged values. This value has the same
+ /// shape as the `merge_types` argument passed to `Selection::start`.
+ pub(super) fn finish(self, ctx: &mut BlockContext, final_values: M) -> M {
+ match self {
+ Selection {
+ merge_label: None, ..
+ } => {
+ // We didn't actually emit any branches, so `self.values` must
+ // be empty, and `final_values` are the only sources we have for
+ // the merged values. Easy peasy.
+ final_values
+ }
+
+ Selection {
+ block,
+ merge_label: Some(merge_label),
+ mut values,
+ merge_types,
+ } => {
+ // Emit the final branch and transition to the merge block.
+ values.push((final_values, block.label_id));
+ ctx.function.consume(
+ std::mem::replace(block, Block::new(merge_label)),
+ Instruction::branch(merge_label),
+ );
+
+ // Now that we're in the merge block, build the phi instructions.
+ merge_types.write_phis(ctx, block, &values)
+ }
+ }
+ }
+
+ /// Return the id of the merge block, writing a merge instruction if needed.
+ fn make_merge_label(&mut self, ctx: &mut BlockContext) -> Word {
+ match self.merge_label {
+ None => {
+ let merge_label = ctx.gen_id();
+ self.block.body.push(Instruction::selection_merge(
+ merge_label,
+ spirv::SelectionControl::NONE,
+ ));
+ self.merge_label = Some(merge_label);
+ merge_label
+ }
+ Some(merge_label) => merge_label,
+ }
+ }
+}
+
+/// A trait to help `Selection` manage any number of merged values.
+///
+/// Some selection constructs, like a `ReadZeroSkipWrite` bounds check on a
+/// [`Load`] expression, produce a single merged value. Others produce no merged
+/// value, like a bounds check on a [`Store`] statement.
+///
+/// To let `Selection` work nicely with both cases, we let the merge type
+/// argument passed to [`Selection::start`] be any type that implements this
+/// `MergeTuple` trait. `MergeTuple` is then implemented for `()`, `Word`,
+/// `(Word, Word)`, and so on.
+///
+/// A `MergeTuple` type can represent either a bunch of SPIR-V types or values:
+/// the `merge_types` argument to `Selection::start` holds type ids, whereas the
+/// `values` arguments to the [`if_true`] and [`finish`] methods are value ids.
+/// The set of merged values returned by `finish` is likewise a tuple of value ids.
+///
+/// In fact, since Naga only uses zero- and single-valued selection constructs
+/// at present, we only implement `MergeTuple` for `()` and `Word`. But if you
+/// add more cases, feel free to add more implementations. Once const generics
+/// are available, we could have a single implementation of `MergeTuple` for all
+/// lengths of arrays, and be done with it.
+///
+/// [`Load`]: crate::Expression::Load
+/// [`Store`]: crate::Statement::Store
+/// [`if_true`]: Selection::if_true
+/// [`finish`]: Selection::finish
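+///
+/// A sketch of the single-valued case, with `Word` as the merge tuple (all
+/// ids here are assumed to have been computed earlier):
+///
+/// ```ignore
+/// let mut sel = Selection::start(block, texel_type_id); // merge_types: Word
+/// sel.if_true(ctx, in_bounds_id, zero_id); // contribute zero when out of bounds
+/// let texel_id = sel.finish(ctx, loaded_id); // an `OpPhi` of zero_id / loaded_id
+/// ```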
+pub(super) trait MergeTuple: Sized {
+ /// Write OpPhi instructions for the given set of predecessors.
+ ///
+ /// The `predecessors` vector should be a vector of `(VALUES, LABEL)` pairs,
+ /// where each `VALUES` holds the values contributed by the branch from
+ /// `LABEL`, which should be one of the current block's predecessors.
+ fn write_phis(
+ self,
+ ctx: &mut BlockContext,
+ block: &mut Block,
+ predecessors: &[(Self, Word)],
+ ) -> Self;
+}
+
+/// Selections that produce a single merged value.
+///
+/// For example, `ImageLoad` with `BoundsCheckPolicy::ReadZeroSkipWrite` either
+/// returns a texel value or zeros.
+impl MergeTuple for Word {
+ fn write_phis(
+ self,
+ ctx: &mut BlockContext,
+ block: &mut Block,
+ predecessors: &[(Word, Word)],
+ ) -> Word {
+ let merged_value = ctx.gen_id();
+ block
+ .body
+ .push(Instruction::phi(self, merged_value, predecessors));
+ merged_value
+ }
+}
+
+/// Selections that produce no merged values.
+///
+/// For example, `ImageStore` under `BoundsCheckPolicy::ReadZeroSkipWrite`
+/// either does the store or skips it, but in neither case does it produce a
+/// value.
+impl MergeTuple for () {
+ /// No phis need to be generated.
+ fn write_phis(self, _: &mut BlockContext, _: &mut Block, _: &[((), Word)]) {}
+}
diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs
new file mode 100644
index 0000000000..4db86c93a7
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/writer.rs
@@ -0,0 +1,2063 @@
+use super::{
+ block::DebugInfoInner,
+ helpers::{contains_builtin, global_needs_wrapper, map_storage_class},
+ make_local, Block, BlockContext, CachedConstant, CachedExpressions, DebugInfo,
+ EntryPointContext, Error, Function, FunctionArgument, GlobalVariable, IdGenerator, Instruction,
+ LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, LoopContext, Options,
+ PhysicalLayout, PipelineOptions, ResultMember, Writer, WriterFlags, BITS_PER_BYTE,
+};
+use crate::{
+ arena::{Handle, UniqueArena},
+ back::spv::BindingInfo,
+ proc::{Alignment, TypeResolution},
+ valid::{FunctionInfo, ModuleInfo},
+};
+use spirv::Word;
+use std::collections::hash_map::Entry;
+
+struct FunctionInterface<'a> {
+ varying_ids: &'a mut Vec<Word>,
+ stage: crate::ShaderStage,
+}
+
+impl Function {
+ fn to_words(&self, sink: &mut impl Extend<Word>) {
+ self.signature.as_ref().unwrap().to_words(sink);
+ for argument in self.parameters.iter() {
+ argument.instruction.to_words(sink);
+ }
+ for (index, block) in self.blocks.iter().enumerate() {
+ Instruction::label(block.label_id).to_words(sink);
+ if index == 0 {
+ for local_var in self.variables.values() {
+ local_var.instruction.to_words(sink);
+ }
+ }
+ for instruction in block.body.iter() {
+ instruction.to_words(sink);
+ }
+ }
+ }
+}
+
+impl Writer {
+ pub fn new(options: &Options) -> Result<Self, Error> {
+ let (major, minor) = options.lang_version;
+ if major != 1 {
+ return Err(Error::UnsupportedVersion(major, minor));
+ }
+ let raw_version = ((major as u32) << 16) | ((minor as u32) << 8);
+
+ let mut capabilities_used = crate::FastIndexSet::default();
+ capabilities_used.insert(spirv::Capability::Shader);
+
+ let mut id_gen = IdGenerator::default();
+ let gl450_ext_inst_id = id_gen.next();
+ let void_type = id_gen.next();
+
+ Ok(Writer {
+ physical_layout: PhysicalLayout::new(raw_version),
+ logical_layout: LogicalLayout::default(),
+ id_gen,
+ capabilities_available: options.capabilities.clone(),
+ capabilities_used,
+ extensions_used: crate::FastIndexSet::default(),
+ debugs: vec![],
+ annotations: vec![],
+ flags: options.flags,
+ bounds_check_policies: options.bounds_check_policies,
+ zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
+ void_type,
+ lookup_type: crate::FastHashMap::default(),
+ lookup_function: crate::FastHashMap::default(),
+ lookup_function_type: crate::FastHashMap::default(),
+ constant_ids: Vec::new(),
+ cached_constants: crate::FastHashMap::default(),
+ global_variables: Vec::new(),
+ binding_map: options.binding_map.clone(),
+ saved_cached: CachedExpressions::default(),
+ gl450_ext_inst_id,
+ temp_list: Vec::new(),
+ })
+ }
+
+ /// Reset `Writer` to its initial state, retaining any allocations.
+ ///
+ /// Why not just implement `Recyclable` for `Writer`? By design,
+ /// `Recyclable::recycle` requires ownership of the value, not just
+ /// `&mut`; see the trait documentation. But we need to use this method
+ /// from functions like `Writer::write`, which only have `&mut Writer`.
+ /// Workarounds include unsafe code (`std::ptr::read`, then `write`, ugh)
+ /// or something like a `Default` impl that returns an oddly-initialized
+ /// `Writer`, which is worse.
+ fn reset(&mut self) {
+ use super::recyclable::Recyclable;
+ use std::mem::take;
+
+ let mut id_gen = IdGenerator::default();
+ let gl450_ext_inst_id = id_gen.next();
+ let void_type = id_gen.next();
+
+ // Every field of the old writer that is not determined by the `Options`
+ // passed to `Writer::new` should be reset somehow.
+ let fresh = Writer {
+ // Copied from the old Writer:
+ flags: self.flags,
+ bounds_check_policies: self.bounds_check_policies,
+ zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
+ capabilities_available: take(&mut self.capabilities_available),
+ binding_map: take(&mut self.binding_map),
+
+ // Initialized afresh:
+ id_gen,
+ void_type,
+ gl450_ext_inst_id,
+
+ // Recycled:
+ capabilities_used: take(&mut self.capabilities_used).recycle(),
+ extensions_used: take(&mut self.extensions_used).recycle(),
+ physical_layout: self.physical_layout.clone().recycle(),
+ logical_layout: take(&mut self.logical_layout).recycle(),
+ debugs: take(&mut self.debugs).recycle(),
+ annotations: take(&mut self.annotations).recycle(),
+ lookup_type: take(&mut self.lookup_type).recycle(),
+ lookup_function: take(&mut self.lookup_function).recycle(),
+ lookup_function_type: take(&mut self.lookup_function_type).recycle(),
+ constant_ids: take(&mut self.constant_ids).recycle(),
+ cached_constants: take(&mut self.cached_constants).recycle(),
+ global_variables: take(&mut self.global_variables).recycle(),
+ saved_cached: take(&mut self.saved_cached).recycle(),
+ temp_list: take(&mut self.temp_list).recycle(),
+ };
+
+ *self = fresh;
+
+ self.capabilities_used.insert(spirv::Capability::Shader);
+ }
+
+ /// Indicate that the code requires any one of the listed capabilities.
+ ///
+ /// If nothing in `capabilities` appears in the available capabilities
+ /// specified in the [`Options`] from which this `Writer` was created,
+ /// return an error. The `what` string is used in the error message to
+ /// explain what provoked the requirement. (If no available capabilities were
+ /// given, assume everything is available.)
+ ///
+ /// The first acceptable capability will be added to this `Writer`'s
+ /// [`capabilities_used`] table, and an `OpCapability` emitted for it in the
+ /// result. For this reason, more specific capabilities should be listed
+ /// before more general.
+ ///
+ /// [`capabilities_used`]: Writer::capabilities_used
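+ ///
+ /// A sketch of a typical call site:
+ ///
+ /// ```ignore
+ /// self.require_any("ray queries", &[spirv::Capability::RayQueryKHR])?;
+ /// ```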
+ pub(super) fn require_any(
+ &mut self,
+ what: &'static str,
+ capabilities: &[spirv::Capability],
+ ) -> Result<(), Error> {
+ match *capabilities {
+ [] => Ok(()),
+ [first, ..] => {
+ // Find the first acceptable capability, or return an error if
+ // there is none.
+ let selected = match self.capabilities_available {
+ None => first,
+ Some(ref available) => {
+ match capabilities.iter().find(|cap| available.contains(cap)) {
+ Some(&cap) => cap,
+ None => {
+ return Err(Error::MissingCapabilities(what, capabilities.to_vec()))
+ }
+ }
+ }
+ };
+ self.capabilities_used.insert(selected);
+ Ok(())
+ }
+ }
+ }
+
+ /// Indicate that the code uses the given extension.
+ pub(super) fn use_extension(&mut self, extension: &'static str) {
+ self.extensions_used.insert(extension);
+ }
+
+ pub(super) fn get_type_id(&mut self, lookup_ty: LookupType) -> Word {
+ match self.lookup_type.entry(lookup_ty) {
+ Entry::Occupied(e) => *e.get(),
+ Entry::Vacant(e) => {
+ let local = match lookup_ty {
+ LookupType::Handle(_handle) => unreachable!("Handles are populated at start"),
+ LookupType::Local(local) => local,
+ };
+
+ let id = self.id_gen.next();
+ e.insert(id);
+ self.write_type_declaration_local(id, local);
+ id
+ }
+ }
+ }
+
+ pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType {
+ match *tr {
+ TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
+ TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()),
+ }
+ }
+
+ pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
+ let lookup_ty = self.get_expression_lookup_type(tr);
+ self.get_type_id(lookup_ty)
+ }
+
+ pub(super) fn get_pointer_id(
+ &mut self,
+ arena: &UniqueArena<crate::Type>,
+ handle: Handle<crate::Type>,
+ class: spirv::StorageClass,
+ ) -> Result<Word, Error> {
+ let ty_id = self.get_type_id(LookupType::Handle(handle));
+ if let crate::TypeInner::Pointer { .. } = arena[handle].inner {
+ return Ok(ty_id);
+ }
+ let lookup_type = LookupType::Local(LocalType::Pointer {
+ base: handle,
+ class,
+ });
+ Ok(if let Some(&id) = self.lookup_type.get(&lookup_type) {
+ id
+ } else {
+ let id = self.id_gen.next();
+ let instruction = Instruction::type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_type.insert(lookup_type, id);
+ id
+ })
+ }
+
+ pub(super) fn get_uint_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_float_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_uint3_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_float_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
+ let lookup_type = LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::F32,
+ pointer_space: Some(class),
+ });
+ if let Some(&id) = self.lookup_type.get(&lookup_type) {
+ id
+ } else {
+ let id = self.id_gen.next();
+ let ty_id = self.get_float_type_id();
+ let instruction = Instruction::type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_type.insert(lookup_type, id);
+ id
+ }
+ }
+
+ pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
+ let lookup_type = LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ scalar: crate::Scalar::U32,
+ pointer_space: Some(class),
+ });
+ if let Some(&id) = self.lookup_type.get(&lookup_type) {
+ id
+ } else {
+ let id = self.id_gen.next();
+ let ty_id = self.get_uint3_type_id();
+ let instruction = Instruction::type_pointer(id, class, ty_id);
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_type.insert(lookup_type, id);
+ id
+ }
+ }
+
+ pub(super) fn get_bool_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar::BOOL,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn get_bool3_type_id(&mut self) -> Word {
+ let local_type = LocalType::Value {
+ vector_size: Some(crate::VectorSize::Tri),
+ scalar: crate::Scalar::BOOL,
+ pointer_space: None,
+ };
+ self.get_type_id(local_type.into())
+ }
+
+ pub(super) fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) {
+ self.annotations
+ .push(Instruction::decorate(id, decoration, operands));
+ }
+
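+ /// Write the SPIR-V function for `ir_function` and return its id.
+ ///
+ /// If `interface` is `Some`, the function is an entry point: its arguments
+ /// and result are passed through `Input`/`Output` variables emitted by
+ /// `write_varying` and loaded/stored in a prelude block, rather than
+ /// through ordinary function parameters and a return value.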
+ fn write_function(
+ &mut self,
+ ir_function: &crate::Function,
+ info: &FunctionInfo,
+ ir_module: &crate::Module,
+ mut interface: Option<FunctionInterface>,
+ debug_info: &Option<DebugInfoInner>,
+ ) -> Result<Word, Error> {
+ let mut function = Function::default();
+
+ let prelude_id = self.id_gen.next();
+ let mut prelude = Block::new(prelude_id);
+ let mut ep_context = EntryPointContext {
+ argument_ids: Vec::new(),
+ results: Vec::new(),
+ };
+
+ let mut local_invocation_id = None;
+
+ let mut parameter_type_ids = Vec::with_capacity(ir_function.arguments.len());
+ for argument in ir_function.arguments.iter() {
+ let class = spirv::StorageClass::Input;
+ let handle_ty = ir_module.types[argument.ty].inner.is_handle();
+ let argument_type_id = match handle_ty {
+ true => self.get_pointer_id(
+ &ir_module.types,
+ argument.ty,
+ spirv::StorageClass::UniformConstant,
+ )?,
+ false => self.get_type_id(LookupType::Handle(argument.ty)),
+ };
+
+ if let Some(ref mut iface) = interface {
+ let id = if let Some(ref binding) = argument.binding {
+ let name = argument.name.as_deref();
+
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ argument.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(argument_type_id, id, varying_id, None));
+
+ if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
+ local_invocation_id = Some(id);
+ }
+
+ id
+ } else if let crate::TypeInner::Struct { ref members, .. } =
+ ir_module.types[argument.ty].inner
+ {
+ let struct_id = self.id_gen.next();
+ let mut constituent_ids = Vec::with_capacity(members.len());
+ for member in members {
+ let type_id = self.get_type_id(LookupType::Handle(member.ty));
+ let name = member.name.as_deref();
+ let binding = member.binding.as_ref().unwrap();
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ member.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(type_id, id, varying_id, None));
+ constituent_ids.push(id);
+
+ // This must match the `LocalInvocationId` check in the
+ // non-struct path above; the saved id gates the workgroup
+ // zero-initialization polyfill.
+ if binding == &crate::Binding::BuiltIn(crate::BuiltIn::LocalInvocationId) {
+ local_invocation_id = Some(id);
+ }
+ }
+ prelude.body.push(Instruction::composite_construct(
+ argument_type_id,
+ struct_id,
+ &constituent_ids,
+ ));
+ struct_id
+ } else {
+ unreachable!("Missing argument binding on an entry point");
+ };
+ ep_context.argument_ids.push(id);
+ } else {
+ let argument_id = self.id_gen.next();
+ let instruction = Instruction::function_parameter(argument_type_id, argument_id);
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = argument.name {
+ self.debugs.push(Instruction::name(argument_id, name));
+ }
+ }
+ function.parameters.push(FunctionArgument {
+ instruction,
+ handle_id: if handle_ty {
+ let id = self.id_gen.next();
+ prelude.body.push(Instruction::load(
+ self.get_type_id(LookupType::Handle(argument.ty)),
+ id,
+ argument_id,
+ None,
+ ));
+ id
+ } else {
+ 0
+ },
+ });
+ parameter_type_ids.push(argument_type_id);
+ };
+ }
+
+ let return_type_id = match ir_function.result {
+ Some(ref result) => {
+ if let Some(ref mut iface) = interface {
+ let mut has_point_size = false;
+ let class = spirv::StorageClass::Output;
+ if let Some(ref binding) = result.binding {
+ has_point_size |=
+ *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
+ let type_id = self.get_type_id(LookupType::Handle(result.ty));
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ None,
+ result.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ ep_context.results.push(ResultMember {
+ id: varying_id,
+ type_id,
+ built_in: binding.to_built_in(),
+ });
+ } else if let crate::TypeInner::Struct { ref members, .. } =
+ ir_module.types[result.ty].inner
+ {
+ for member in members {
+ let type_id = self.get_type_id(LookupType::Handle(member.ty));
+ let name = member.name.as_deref();
+ let binding = member.binding.as_ref().unwrap();
+ has_point_size |=
+ *binding == crate::Binding::BuiltIn(crate::BuiltIn::PointSize);
+ let varying_id = self.write_varying(
+ ir_module,
+ iface.stage,
+ class,
+ name,
+ member.ty,
+ binding,
+ )?;
+ iface.varying_ids.push(varying_id);
+ ep_context.results.push(ResultMember {
+ id: varying_id,
+ type_id,
+ built_in: binding.to_built_in(),
+ });
+ }
+ } else {
+ unreachable!("Missing result binding on an entry point");
+ }
+
+ if self.flags.contains(WriterFlags::FORCE_POINT_SIZE)
+ && iface.stage == crate::ShaderStage::Vertex
+ && !has_point_size
+ {
+ // add point size artificially
+ let varying_id = self.id_gen.next();
+ let pointer_type_id = self.get_float_pointer_type_id(class);
+ Instruction::variable(pointer_type_id, varying_id, class, None)
+ .to_words(&mut self.logical_layout.declarations);
+ self.decorate(
+ varying_id,
+ spirv::Decoration::BuiltIn,
+ &[spirv::BuiltIn::PointSize as u32],
+ );
+ iface.varying_ids.push(varying_id);
+
+ let default_value_id = self.get_constant_scalar(crate::Literal::F32(1.0));
+ prelude
+ .body
+ .push(Instruction::store(varying_id, default_value_id, None));
+ }
+ self.void_type
+ } else {
+ self.get_type_id(LookupType::Handle(result.ty))
+ }
+ }
+ None => self.void_type,
+ };
+
+ let lookup_function_type = LookupFunctionType {
+ parameter_type_ids,
+ return_type_id,
+ };
+
+ let function_id = self.id_gen.next();
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = ir_function.name {
+ self.debugs.push(Instruction::name(function_id, name));
+ }
+ }
+
+ let function_type = self.get_function_type(lookup_function_type);
+ function.signature = Some(Instruction::function(
+ return_type_id,
+ function_id,
+ spirv::FunctionControl::empty(),
+ function_type,
+ ));
+
+ if interface.is_some() {
+ function.entry_point_context = Some(ep_context);
+ }
+
+ // fill up the `GlobalVariable::access_id`
+ for gv in self.global_variables.iter_mut() {
+ gv.reset_for_function();
+ }
+ for (handle, var) in ir_module.global_variables.iter() {
+ if info[handle].is_empty() {
+ continue;
+ }
+
+ let mut gv = self.global_variables[handle.index()].clone();
+ if let Some(ref mut iface) = interface {
+ // SPIR-V 1.4+ requires all global variables used by the
+ // function to be listed in the entry point's interface.
+ if self.physical_layout.version >= 0x10400 {
+ iface.varying_ids.push(gv.var_id);
+ }
+ }
+
+ // Handle globals are pre-emitted and should be loaded automatically.
+ //
+ // Binding arrays are skipped here: the array itself cannot be loaded; the result is loaded only after indexing.
+ let is_binding_array = match ir_module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { .. } => true,
+ _ => false,
+ };
+
+ if var.space == crate::AddressSpace::Handle && !is_binding_array {
+ let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
+ let id = self.id_gen.next();
+ prelude
+ .body
+ .push(Instruction::load(var_type_id, id, gv.var_id, None));
+ gv.access_id = gv.var_id;
+ gv.handle_id = id;
+ } else if global_needs_wrapper(ir_module, var) {
+ let class = map_storage_class(var.space);
+ let pointer_type_id = self.get_pointer_id(&ir_module.types, var.ty, class)?;
+ let index_id = self.get_index_constant(0);
+
+ let id = self.id_gen.next();
+ prelude.body.push(Instruction::access_chain(
+ pointer_type_id,
+ id,
+ gv.var_id,
+ &[index_id],
+ ));
+ gv.access_id = id;
+ } else {
+ // by default, the variable ID is accessed as is
+ gv.access_id = gv.var_id;
+ };
+
+ // work around borrow checking in the presence of `self.xxx()` calls
+ self.global_variables[handle.index()] = gv;
+ }
+
+ // Create a `BlockContext` for generating SPIR-V for the function's
+ // body.
+ let mut context = BlockContext {
+ ir_module,
+ ir_function,
+ fun_info: info,
+ function: &mut function,
+ // Re-use the cached expression table from prior functions.
+ cached: std::mem::take(&mut self.saved_cached),
+
+ // Steal the Writer's temp list for a bit.
+ temp_list: std::mem::take(&mut self.temp_list),
+ writer: self,
+ expression_constness: crate::proc::ExpressionConstnessTracker::from_arena(
+ &ir_function.expressions,
+ ),
+ };
+
+ // fill up the pre-emitted and const expressions
+ context.cached.reset(ir_function.expressions.len());
+ for (handle, expr) in ir_function.expressions.iter() {
+ if (expr.needs_pre_emit() && !matches!(*expr, crate::Expression::LocalVariable(_)))
+ || context.expression_constness.is_const(handle)
+ {
+ context.cache_expression_value(handle, &mut prelude)?;
+ }
+ }
+
+ for (handle, variable) in ir_function.local_variables.iter() {
+ let id = context.gen_id();
+
+ if context.writer.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = variable.name {
+ context.writer.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ let init_word = variable.init.map(|constant| context.cached[constant]);
+ let pointer_type_id = context.writer.get_pointer_id(
+ &ir_module.types,
+ variable.ty,
+ spirv::StorageClass::Function,
+ )?;
+ let instruction = Instruction::variable(
+ pointer_type_id,
+ id,
+ spirv::StorageClass::Function,
+ init_word.or_else(|| match ir_module.types[variable.ty].inner {
+ crate::TypeInner::RayQuery => None,
+ _ => {
+ let type_id = context.get_type_id(LookupType::Handle(variable.ty));
+ Some(context.writer.write_constant_null(type_id))
+ }
+ }),
+ );
+ context
+ .function
+ .variables
+ .insert(handle, LocalVariable { id, instruction });
+ }
+
+ // cache local variable expressions
+ for (handle, expr) in ir_function.expressions.iter() {
+ if matches!(*expr, crate::Expression::LocalVariable(_)) {
+ context.cache_expression_value(handle, &mut prelude)?;
+ }
+ }
+
+ let next_id = context.gen_id();
+
+ context
+ .function
+ .consume(prelude, Instruction::branch(next_id));
+
+ let workgroup_vars_init_exit_block_id =
+ match (context.writer.zero_initialize_workgroup_memory, interface) {
+ (
+ super::ZeroInitializeWorkgroupMemoryMode::Polyfill,
+ Some(
+ ref mut interface @ FunctionInterface {
+ stage: crate::ShaderStage::Compute,
+ ..
+ },
+ ),
+ ) => context.writer.generate_workgroup_vars_init_block(
+ next_id,
+ ir_module,
+ info,
+ local_invocation_id,
+ interface,
+ context.function,
+ ),
+ _ => None,
+ };
+
+ let main_id = if let Some(exit_id) = workgroup_vars_init_exit_block_id {
+ exit_id
+ } else {
+ next_id
+ };
+
+ context.write_block(
+ main_id,
+ &ir_function.body,
+ super::block::BlockExit::Return,
+ LoopContext::default(),
+ debug_info.as_ref(),
+ )?;
+
+ // Consume the `BlockContext`, ending its borrows and letting the
+ // `Writer` steal back its cached expression table and temp_list.
+ let BlockContext {
+ cached, temp_list, ..
+ } = context;
+ self.saved_cached = cached;
+ self.temp_list = temp_list;
+
+ function.to_words(&mut self.logical_layout.function_definitions);
+ Instruction::function_end().to_words(&mut self.logical_layout.function_definitions);
+
+ Ok(function_id)
+ }
+
+ fn write_execution_mode(
+ &mut self,
+ function_id: Word,
+ mode: spirv::ExecutionMode,
+ ) -> Result<(), Error> {
+ //self.check(mode.required_capabilities())?;
+ Instruction::execution_mode(function_id, mode, &[])
+ .to_words(&mut self.logical_layout.execution_modes);
+ Ok(())
+ }
+
+ // TODO Move to instructions module
+ fn write_entry_point(
+ &mut self,
+ entry_point: &crate::EntryPoint,
+ info: &FunctionInfo,
+ ir_module: &crate::Module,
+ debug_info: &Option<DebugInfoInner>,
+ ) -> Result<Instruction, Error> {
+ let mut interface_ids = Vec::new();
+ let function_id = self.write_function(
+ &entry_point.function,
+ info,
+ ir_module,
+ Some(FunctionInterface {
+ varying_ids: &mut interface_ids,
+ stage: entry_point.stage,
+ }),
+ debug_info,
+ )?;
+
+ let exec_model = match entry_point.stage {
+ crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex,
+ crate::ShaderStage::Fragment => {
+ self.write_execution_mode(function_id, spirv::ExecutionMode::OriginUpperLeft)?;
+ if let Some(ref result) = entry_point.function.result {
+ if contains_builtin(
+ result.binding.as_ref(),
+ result.ty,
+ &ir_module.types,
+ crate::BuiltIn::FragDepth,
+ ) {
+ self.write_execution_mode(
+ function_id,
+ spirv::ExecutionMode::DepthReplacing,
+ )?;
+ }
+ }
+ spirv::ExecutionModel::Fragment
+ }
+ crate::ShaderStage::Compute => {
+ let execution_mode = spirv::ExecutionMode::LocalSize;
+ //self.check(execution_mode.required_capabilities())?;
+ Instruction::execution_mode(
+ function_id,
+ execution_mode,
+ &entry_point.workgroup_size,
+ )
+ .to_words(&mut self.logical_layout.execution_modes);
+ spirv::ExecutionModel::GLCompute
+ }
+ };
+ //self.check(exec_model.required_capabilities())?;
+
+ Ok(Instruction::entry_point(
+ exec_model,
+ function_id,
+ &entry_point.name,
+ interface_ids.as_slice(),
+ ))
+ }
+
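+ /// Build the `OpTypeInt`/`OpTypeFloat`/`OpTypeBool` instruction for
+ /// `scalar`, requesting capabilities such as `Int8`, `Int16`, `Int64`,
+ /// or `Float64` for non-32-bit widths as a side effect.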
+ fn make_scalar(&mut self, id: Word, scalar: crate::Scalar) -> Instruction {
+ use crate::ScalarKind as Sk;
+
+ let bits = (scalar.width * BITS_PER_BYTE) as u32;
+ match scalar.kind {
+ Sk::Sint | Sk::Uint => {
+ let signedness = if scalar.kind == Sk::Sint {
+ super::instructions::Signedness::Signed
+ } else {
+ super::instructions::Signedness::Unsigned
+ };
+ let cap = match bits {
+ 8 => Some(spirv::Capability::Int8),
+ 16 => Some(spirv::Capability::Int16),
+ 64 => Some(spirv::Capability::Int64),
+ _ => None,
+ };
+ if let Some(cap) = cap {
+ self.capabilities_used.insert(cap);
+ }
+ Instruction::type_int(id, bits, signedness)
+ }
+ Sk::Float => {
+ if bits == 64 {
+ self.capabilities_used.insert(spirv::Capability::Float64);
+ }
+ Instruction::type_float(id, bits)
+ }
+ Sk::Bool => Instruction::type_bool(id),
+ Sk::AbstractInt | Sk::AbstractFloat => {
+ unreachable!("abstract types should never reach the backend");
+ }
+ }
+ }
+
+ fn request_type_capabilities(&mut self, inner: &crate::TypeInner) -> Result<(), Error> {
+ match *inner {
+ crate::TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ let sampled = match class {
+ crate::ImageClass::Sampled { .. } => true,
+ crate::ImageClass::Depth { .. } => true,
+ crate::ImageClass::Storage { format, .. } => {
+ self.request_image_format_capabilities(format.into())?;
+ false
+ }
+ };
+
+ match dim {
+ crate::ImageDimension::D1 => {
+ if sampled {
+ self.require_any("sampled 1D images", &[spirv::Capability::Sampled1D])?;
+ } else {
+ self.require_any("1D storage images", &[spirv::Capability::Image1D])?;
+ }
+ }
+ crate::ImageDimension::Cube if arrayed => {
+ if sampled {
+ self.require_any(
+ "sampled cube array images",
+ &[spirv::Capability::SampledCubeArray],
+ )?;
+ } else {
+ self.require_any(
+ "cube array storage images",
+ &[spirv::Capability::ImageCubeArray],
+ )?;
+ }
+ }
+ _ => {}
+ }
+ }
+ crate::TypeInner::AccelerationStructure => {
+ self.require_any("Acceleration Structure", &[spirv::Capability::RayQueryKHR])?;
+ }
+ crate::TypeInner::RayQuery => {
+ self.require_any("Ray Query", &[spirv::Capability::RayQueryKHR])?;
+ }
+ _ => {}
+ }
+ Ok(())
+ }
+
+ fn write_type_declaration_local(&mut self, id: Word, local_ty: LocalType) {
+ let instruction = match local_ty {
+ LocalType::Value {
+ vector_size: None,
+ scalar,
+ pointer_space: None,
+ } => self.make_scalar(id, scalar),
+ LocalType::Value {
+ vector_size: Some(size),
+ scalar,
+ pointer_space: None,
+ } => {
+ let scalar_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar,
+ pointer_space: None,
+ }));
+ Instruction::type_vector(id, scalar_id, size)
+ }
+ LocalType::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let vector_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(rows),
+ scalar: crate::Scalar::float(width),
+ pointer_space: None,
+ }));
+ Instruction::type_matrix(id, vector_id, columns)
+ }
+ LocalType::Pointer { base, class } => {
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ Instruction::type_pointer(id, class, type_id)
+ }
+ LocalType::Value {
+ vector_size,
+ scalar,
+ pointer_space: Some(class),
+ } => {
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size,
+ scalar,
+ pointer_space: None,
+ }));
+ Instruction::type_pointer(id, class, type_id)
+ }
+ LocalType::Image(image) => {
+ let local_type = LocalType::Value {
+ vector_size: None,
+ scalar: crate::Scalar {
+ kind: image.sampled_type,
+ width: 4,
+ },
+ pointer_space: None,
+ };
+ let type_id = self.get_type_id(LookupType::Local(local_type));
+ Instruction::type_image(id, type_id, image.dim, image.flags, image.image_format)
+ }
+ LocalType::Sampler => Instruction::type_sampler(id),
+ LocalType::SampledImage { image_type_id } => {
+ Instruction::type_sampled_image(id, image_type_id)
+ }
+ LocalType::BindingArray { base, size } => {
+ let inner_ty = self.get_type_id(LookupType::Handle(base));
+ let scalar_id = self.get_constant_scalar(crate::Literal::U32(size));
+ Instruction::type_array(id, inner_ty, scalar_id)
+ }
+ LocalType::PointerToBindingArray { base, size, space } => {
+ let inner_ty =
+ self.get_type_id(LookupType::Local(LocalType::BindingArray { base, size }));
+ let class = map_storage_class(space);
+ Instruction::type_pointer(id, class, inner_ty)
+ }
+ LocalType::AccelerationStructure => Instruction::type_acceleration_structure(id),
+ LocalType::RayQuery => Instruction::type_ray_query(id),
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ }
+
+ fn write_type_declaration_arena(
+ &mut self,
+ arena: &UniqueArena<crate::Type>,
+ handle: Handle<crate::Type>,
+ ) -> Result<Word, Error> {
+ let ty = &arena[handle];
+ let id = if let Some(local) = make_local(&ty.inner) {
+ // This type can be represented as a `LocalType`, so check if we've
+ // already written an instruction for it. If not, do so now, with
+ // `write_type_declaration_local`.
+ match self.lookup_type.entry(LookupType::Local(local)) {
+ // We already have an id for this `LocalType`.
+ Entry::Occupied(e) => *e.get(),
+
+ // It's a type we haven't seen before.
+ Entry::Vacant(e) => {
+ let id = self.id_gen.next();
+ e.insert(id);
+
+ self.write_type_declaration_local(id, local);
+
+ // If it's a type that needs SPIR-V capabilities, request them now,
+ // so write_type_declaration_local can stay infallible.
+ self.request_type_capabilities(&ty.inner)?;
+
+ id
+ }
+ }
+ } else {
+ use spirv::Decoration;
+
+ let id = self.id_gen.next();
+ let instruction = match ty.inner {
+ crate::TypeInner::Array { base, size, stride } => {
+ self.decorate(id, Decoration::ArrayStride, &[stride]);
+
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ match size {
+ crate::ArraySize::Constant(length) => {
+ let length_id = self.get_index_constant(length.get());
+ Instruction::type_array(id, type_id, length_id)
+ }
+ crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
+ }
+ }
+ crate::TypeInner::BindingArray { base, size } => {
+ let type_id = self.get_type_id(LookupType::Handle(base));
+ match size {
+ crate::ArraySize::Constant(length) => {
+ let length_id = self.get_index_constant(length.get());
+ Instruction::type_array(id, type_id, length_id)
+ }
+ crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
+ }
+ }
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => {
+ let mut has_runtime_array = false;
+ let mut member_ids = Vec::with_capacity(members.len());
+ for (index, member) in members.iter().enumerate() {
+ let member_ty = &arena[member.ty];
+ match member_ty.inner {
+ crate::TypeInner::Array {
+ base: _,
+ size: crate::ArraySize::Dynamic,
+ stride: _,
+ } => {
+ has_runtime_array = true;
+ }
+ _ => (),
+ }
+ self.decorate_struct_member(id, index, member, arena)?;
+ let member_id = self.get_type_id(LookupType::Handle(member.ty));
+ member_ids.push(member_id);
+ }
+ if has_runtime_array {
+ self.decorate(id, Decoration::Block, &[]);
+ }
+ Instruction::type_struct(id, member_ids.as_slice())
+ }
+
+ // These all have `LocalType` representations, so they should have been
+ // handled by `write_type_declaration_local` above.
+ crate::TypeInner::Scalar(_)
+ | crate::TypeInner::Atomic(_)
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. }
+ | crate::TypeInner::Pointer { .. }
+ | crate::TypeInner::ValuePointer { .. }
+ | crate::TypeInner::Image { .. }
+ | crate::TypeInner::Sampler { .. }
+ | crate::TypeInner::AccelerationStructure
+ | crate::TypeInner::RayQuery => unreachable!(),
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ id
+ };
+
+ // Add this handle as a new alias for that type.
+ self.lookup_type.insert(LookupType::Handle(handle), id);
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = ty.name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ Ok(id)
+ }
+
+ fn request_image_format_capabilities(
+ &mut self,
+ format: spirv::ImageFormat,
+ ) -> Result<(), Error> {
+ use spirv::ImageFormat as If;
+ match format {
+ If::Rg32f
+ | If::Rg16f
+ | If::R11fG11fB10f
+ | If::R16f
+ | If::Rgba16
+ | If::Rgb10A2
+ | If::Rg16
+ | If::Rg8
+ | If::R16
+ | If::R8
+ | If::Rgba16Snorm
+ | If::Rg16Snorm
+ | If::Rg8Snorm
+ | If::R16Snorm
+ | If::R8Snorm
+ | If::Rg32i
+ | If::Rg16i
+ | If::Rg8i
+ | If::R16i
+ | If::R8i
+ | If::Rgb10a2ui
+ | If::Rg32ui
+ | If::Rg16ui
+ | If::Rg8ui
+ | If::R16ui
+ | If::R8ui => self.require_any(
+ "storage image format",
+ &[spirv::Capability::StorageImageExtendedFormats],
+ ),
+ If::R64ui | If::R64i => self.require_any(
+ "64-bit integer storage image format",
+ &[spirv::Capability::Int64ImageEXT],
+ ),
+ If::Unknown
+ | If::Rgba32f
+ | If::Rgba16f
+ | If::R32f
+ | If::Rgba8
+ | If::Rgba8Snorm
+ | If::Rgba32i
+ | If::Rgba16i
+ | If::Rgba8i
+ | If::R32i
+ | If::Rgba32ui
+ | If::Rgba16ui
+ | If::Rgba8ui
+ | If::R32ui => Ok(()),
+ }
+ }
+
+ pub(super) fn get_index_constant(&mut self, index: Word) -> Word {
+ self.get_constant_scalar(crate::Literal::U32(index))
+ }
+
+ pub(super) fn get_constant_scalar_with(
+ &mut self,
+ value: u8,
+ scalar: crate::Scalar,
+ ) -> Result<Word, Error> {
+ Ok(
+ self.get_constant_scalar(crate::Literal::new(value, scalar).ok_or(
+ Error::Validation("Unexpected kind and/or width for Literal"),
+ )?),
+ )
+ }
+
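+ /// Return the id of an `OpConstant` for `value`, reusing a previously
+ /// written instruction if this literal is already in the constant cache.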
+ pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
+ let scalar = CachedConstant::Literal(value);
+ if let Some(&id) = self.cached_constants.get(&scalar) {
+ return id;
+ }
+ let id = self.id_gen.next();
+ self.write_constant_scalar(id, &value, None);
+ self.cached_constants.insert(scalar, id);
+ id
+ }
+
+ fn write_constant_scalar(
+ &mut self,
+ id: Word,
+ value: &crate::Literal,
+ debug_name: Option<&String>,
+ ) {
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(name) = debug_name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+ let type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: None,
+ scalar: value.scalar(),
+ pointer_space: None,
+ }));
+ let instruction = match *value {
+ crate::Literal::F64(value) => {
+ let bits = value.to_bits();
+ Instruction::constant_64bit(type_id, id, bits as u32, (bits >> 32) as u32)
+ }
+ crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
+ crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
+ crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
+ crate::Literal::I64(value) => {
+ Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
+ }
+ crate::Literal::Bool(true) => Instruction::constant_true(type_id, id),
+ crate::Literal::Bool(false) => Instruction::constant_false(type_id, id),
+ crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
+ unreachable!("Abstract types should not appear in IR presented to backends");
+ }
+ };
+
+ instruction.to_words(&mut self.logical_layout.declarations);
+ }
+
+ pub(super) fn get_constant_composite(
+ &mut self,
+ ty: LookupType,
+ constituent_ids: &[Word],
+ ) -> Word {
+ let composite = CachedConstant::Composite {
+ ty,
+ constituent_ids: constituent_ids.to_vec(),
+ };
+ if let Some(&id) = self.cached_constants.get(&composite) {
+ return id;
+ }
+ let id = self.id_gen.next();
+ self.write_constant_composite(id, ty, constituent_ids, None);
+ self.cached_constants.insert(composite, id);
+ id
+ }
+
+ fn write_constant_composite(
+ &mut self,
+ id: Word,
+ ty: LookupType,
+ constituent_ids: &[Word],
+ debug_name: Option<&String>,
+ ) {
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(name) = debug_name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+ let type_id = self.get_type_id(ty);
+ Instruction::constant_composite(type_id, id, constituent_ids)
+ .to_words(&mut self.logical_layout.declarations);
+ }
+
+ pub(super) fn get_constant_null(&mut self, type_id: Word) -> Word {
+ let null = CachedConstant::ZeroValue(type_id);
+ if let Some(&id) = self.cached_constants.get(&null) {
+ return id;
+ }
+ let id = self.write_constant_null(type_id);
+ self.cached_constants.insert(null, id);
+ id
+ }
+
+ pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word {
+ let null_id = self.id_gen.next();
+ Instruction::constant_null(type_id, null_id)
+ .to_words(&mut self.logical_layout.declarations);
+ null_id
+ }
+
+ fn write_constant_expr(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ ir_module: &crate::Module,
+ mod_info: &ModuleInfo,
+ ) -> Result<Word, Error> {
+ let id = match ir_module.const_expressions[handle] {
+ crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
+ crate::Expression::Constant(constant) => {
+ let constant = &ir_module.constants[constant];
+ self.constant_ids[constant.init.index()]
+ }
+ crate::Expression::ZeroValue(ty) => {
+ let type_id = self.get_type_id(LookupType::Handle(ty));
+ self.get_constant_null(type_id)
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ let component_ids: Vec<_> = crate::proc::flatten_compose(
+ ty,
+ components,
+ &ir_module.const_expressions,
+ &ir_module.types,
+ )
+ .map(|component| self.constant_ids[component.index()])
+ .collect();
+ self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
+ }
+ crate::Expression::Splat { size, value } => {
+ let value_id = self.constant_ids[value.index()];
+ let component_ids = &[value_id; 4][..size as usize];
+
+ let ty = self.get_expression_lookup_type(&mod_info[handle]);
+
+ self.get_constant_composite(ty, component_ids)
+ }
+ _ => unreachable!(),
+ };
+
+ self.constant_ids[handle.index()] = id;
+
+ Ok(id)
+ }
+
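+ /// Write an `OpControlBarrier` for `flags`. `Barrier::STORAGE` widens the
+ /// memory scope to `Device` and sets the `UniformMemory` semantics bit,
+ /// while `Barrier::WORK_GROUP` sets `WorkgroupMemory`; the execution scope
+ /// is always `Workgroup`, and the ordering is always `AcquireRelease`.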
+ pub(super) fn write_barrier(&mut self, flags: crate::Barrier, block: &mut Block) {
+ let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
+ spirv::Scope::Device
+ } else {
+ spirv::Scope::Workgroup
+ };
+ let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
+ semantics.set(
+ spirv::MemorySemantics::UNIFORM_MEMORY,
+ flags.contains(crate::Barrier::STORAGE),
+ );
+ semantics.set(
+ spirv::MemorySemantics::WORKGROUP_MEMORY,
+ flags.contains(crate::Barrier::WORK_GROUP),
+ );
+ let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32);
+ let mem_scope_id = self.get_index_constant(memory_scope as u32);
+ let semantics_id = self.get_index_constant(semantics.bits());
+ block.body.push(Instruction::control_barrier(
+ exec_scope_id,
+ mem_scope_id,
+ semantics_id,
+ ));
+ }
+
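+ /// Generate a prelude that zero-initializes every `workgroup` global
+ /// variable the entry point uses, then synchronizes. The emitted control
+ /// flow corresponds roughly to this WGSL sketch:
+ ///
+ /// ```ignore
+ /// if (all(local_invocation_id == vec3<u32>(0u))) {
+ ///     // store a null constant to each live workgroup variable
+ /// }
+ /// workgroupBarrier();
+ /// ```
+ ///
+ /// Returns the id of the block in which the function body should
+ /// continue, or `None` if there was nothing to initialize.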
+ fn generate_workgroup_vars_init_block(
+ &mut self,
+ entry_id: Word,
+ ir_module: &crate::Module,
+ info: &FunctionInfo,
+ local_invocation_id: Option<Word>,
+ interface: &mut FunctionInterface,
+ function: &mut Function,
+ ) -> Option<Word> {
+ let body = ir_module
+ .global_variables
+ .iter()
+ .filter(|&(handle, var)| {
+ !info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
+ })
+ .map(|(handle, var)| {
+ // It's safe to use `var_id` here, not `access_id`, because only
+ // variables in the `Uniform` and `StorageBuffer` address spaces
+ // get wrapped, and we're initializing `WorkGroup` variables.
+ let var_id = self.global_variables[handle.index()].var_id;
+ let var_type_id = self.get_type_id(LookupType::Handle(var.ty));
+ let init_word = self.get_constant_null(var_type_id);
+ Instruction::store(var_id, init_word, None)
+ })
+ .collect::<Vec<_>>();
+
+ if body.is_empty() {
+ return None;
+ }
+
+ let uint3_type_id = self.get_uint3_type_id();
+
+ let mut pre_if_block = Block::new(entry_id);
+
+ let local_invocation_id = if let Some(local_invocation_id) = local_invocation_id {
+ local_invocation_id
+ } else {
+ let varying_id = self.id_gen.next();
+ let class = spirv::StorageClass::Input;
+ let pointer_type_id = self.get_uint3_pointer_type_id(class);
+
+ Instruction::variable(pointer_type_id, varying_id, class, None)
+ .to_words(&mut self.logical_layout.declarations);
+
+ self.decorate(
+ varying_id,
+ spirv::Decoration::BuiltIn,
+ &[spirv::BuiltIn::LocalInvocationId as u32],
+ );
+
+ interface.varying_ids.push(varying_id);
+ let id = self.id_gen.next();
+ pre_if_block
+ .body
+ .push(Instruction::load(uint3_type_id, id, varying_id, None));
+
+ id
+ };
+
+ let zero_id = self.get_constant_null(uint3_type_id);
+ let bool3_type_id = self.get_bool3_type_id();
+
+ let eq_id = self.id_gen.next();
+ pre_if_block.body.push(Instruction::binary(
+ spirv::Op::IEqual,
+ bool3_type_id,
+ eq_id,
+ local_invocation_id,
+ zero_id,
+ ));
+
+ let condition_id = self.id_gen.next();
+ let bool_type_id = self.get_bool_type_id();
+ pre_if_block.body.push(Instruction::relational(
+ spirv::Op::All,
+ bool_type_id,
+ condition_id,
+ eq_id,
+ ));
+
+ let merge_id = self.id_gen.next();
+ pre_if_block.body.push(Instruction::selection_merge(
+ merge_id,
+ spirv::SelectionControl::NONE,
+ ));
+
+ let accept_id = self.id_gen.next();
+ function.consume(
+ pre_if_block,
+ Instruction::branch_conditional(condition_id, accept_id, merge_id),
+ );
+
+ let accept_block = Block {
+ label_id: accept_id,
+ body,
+ };
+ function.consume(accept_block, Instruction::branch(merge_id));
+
+ let mut post_if_block = Block::new(merge_id);
+
+ self.write_barrier(crate::Barrier::WORK_GROUP, &mut post_if_block);
+
+ let next_id = self.id_gen.next();
+ function.consume(post_if_block, Instruction::branch(next_id));
+ Some(next_id)
+ }
+
+ /// Generate an `OpVariable` for one value in an [`EntryPoint`]'s IO interface.
+ ///
+ /// The [`Binding`]s of the arguments and result of an [`EntryPoint`]'s
+ /// [`Function`] describe a SPIR-V shader interface. In SPIR-V, the
+ /// interface is represented by global variables in the `Input` and `Output`
+ /// storage classes, with decorations indicating which builtin or location
+ /// each variable corresponds to.
+ ///
+ /// This function emits a single global `OpVariable` for a single value from
+ /// the interface, and adds appropriate decorations to indicate which
+ /// builtin or location it represents, how it should be interpolated, and so
+ /// on. The `class` argument gives the variable's SPIR-V storage class,
+ /// which should be either [`Input`] or [`Output`].
+ ///
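+ /// For a fragment input such as WGSL `@location(0) color: vec4<f32>`,
+ /// the output corresponds roughly to this disassembly (illustrative):
+ ///
+ /// ```text
+ /// %ptr = OpTypePointer Input %v4float
+ /// %var = OpVariable %ptr Input
+ ///        OpDecorate %var Location 0
+ /// ```
+ ///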
+ /// [`Binding`]: crate::Binding
+ /// [`Function`]: crate::Function
+ /// [`EntryPoint`]: crate::EntryPoint
+ /// [`Input`]: spirv::StorageClass::Input
+ /// [`Output`]: spirv::StorageClass::Output
+ fn write_varying(
+ &mut self,
+ ir_module: &crate::Module,
+ stage: crate::ShaderStage,
+ class: spirv::StorageClass,
+ debug_name: Option<&str>,
+ ty: Handle<crate::Type>,
+ binding: &crate::Binding,
+ ) -> Result<Word, Error> {
+ let id = self.id_gen.next();
+ let pointer_type_id = self.get_pointer_id(&ir_module.types, ty, class)?;
+ Instruction::variable(pointer_type_id, id, class, None)
+ .to_words(&mut self.logical_layout.declarations);
+
+ if self
+ .flags
+ .contains(WriterFlags::DEBUG | WriterFlags::LABEL_VARYINGS)
+ {
+ if let Some(name) = debug_name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ use spirv::{BuiltIn, Decoration};
+
+ match *binding {
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source,
+ } => {
+ self.decorate(id, Decoration::Location, &[location]);
+
+ let no_decorations =
+ // VUID-StandaloneSpirv-Flat-06202
+ // > The Flat, NoPerspective, Sample, and Centroid decorations
+ // > must not be used on variables with the Input storage class in a vertex shader
+ (class == spirv::StorageClass::Input && stage == crate::ShaderStage::Vertex) ||
+ // VUID-StandaloneSpirv-Flat-06201
+ // > The Flat, NoPerspective, Sample, and Centroid decorations
+ // > must not be used on variables with the Output storage class in a fragment shader
+ (class == spirv::StorageClass::Output && stage == crate::ShaderStage::Fragment);
+
+ if !no_decorations {
+ match interpolation {
+ // Perspective-correct interpolation is the default in SPIR-V.
+ None | Some(crate::Interpolation::Perspective) => (),
+ Some(crate::Interpolation::Flat) => {
+ self.decorate(id, Decoration::Flat, &[]);
+ }
+ Some(crate::Interpolation::Linear) => {
+ self.decorate(id, Decoration::NoPerspective, &[]);
+ }
+ }
+ match sampling {
+ // Center sampling is the default in SPIR-V.
+ None | Some(crate::Sampling::Center) => (),
+ Some(crate::Sampling::Centroid) => {
+ self.decorate(id, Decoration::Centroid, &[]);
+ }
+ Some(crate::Sampling::Sample) => {
+ self.require_any(
+ "per-sample interpolation",
+ &[spirv::Capability::SampleRateShading],
+ )?;
+ self.decorate(id, Decoration::Sample, &[]);
+ }
+ }
+ }
+ if second_blend_source {
+ self.decorate(id, Decoration::Index, &[1]);
+ }
+ }
+ crate::Binding::BuiltIn(built_in) => {
+ use crate::BuiltIn as Bi;
+ let built_in = match built_in {
+ Bi::Position { invariant } => {
+ if invariant {
+ self.decorate(id, Decoration::Invariant, &[]);
+ }
+
+ if class == spirv::StorageClass::Output {
+ BuiltIn::Position
+ } else {
+ BuiltIn::FragCoord
+ }
+ }
+ Bi::ViewIndex => {
+ self.require_any("`view_index` built-in", &[spirv::Capability::MultiView])?;
+ BuiltIn::ViewIndex
+ }
+ // vertex
+ Bi::BaseInstance => BuiltIn::BaseInstance,
+ Bi::BaseVertex => BuiltIn::BaseVertex,
+ Bi::ClipDistance => {
+ self.require_any(
+ "`clip_distance` built-in",
+ &[spirv::Capability::ClipDistance],
+ )?;
+ BuiltIn::ClipDistance
+ }
+ Bi::CullDistance => {
+ self.require_any(
+ "`cull_distance` built-in",
+ &[spirv::Capability::CullDistance],
+ )?;
+ BuiltIn::CullDistance
+ }
+ Bi::InstanceIndex => BuiltIn::InstanceIndex,
+ Bi::PointSize => BuiltIn::PointSize,
+ Bi::VertexIndex => BuiltIn::VertexIndex,
+ // fragment
+ Bi::FragDepth => BuiltIn::FragDepth,
+ Bi::PointCoord => BuiltIn::PointCoord,
+ Bi::FrontFacing => BuiltIn::FrontFacing,
+ Bi::PrimitiveIndex => {
+ self.require_any(
+ "`primitive_index` built-in",
+ &[spirv::Capability::Geometry],
+ )?;
+ BuiltIn::PrimitiveId
+ }
+ Bi::SampleIndex => {
+ self.require_any(
+ "`sample_index` built-in",
+ &[spirv::Capability::SampleRateShading],
+ )?;
+
+ BuiltIn::SampleId
+ }
+ Bi::SampleMask => BuiltIn::SampleMask,
+ // compute
+ Bi::GlobalInvocationId => BuiltIn::GlobalInvocationId,
+ Bi::LocalInvocationId => BuiltIn::LocalInvocationId,
+ Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
+ Bi::WorkGroupId => BuiltIn::WorkgroupId,
+ Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
+ Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
+ };
+
+ self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);
+
+ use crate::ScalarKind as Sk;
+
+ // Per the Vulkan spec, `VUID-StandaloneSpirv-Flat-04744`:
+ //
+ // > Any variable with integer or double-precision floating-
+ // > point type and with Input storage class in a fragment
+ // > shader, must be decorated Flat
+ if class == spirv::StorageClass::Input && stage == crate::ShaderStage::Fragment {
+ let is_flat = match ir_module.types[ty].inner {
+ crate::TypeInner::Scalar(scalar)
+ | crate::TypeInner::Vector { scalar, .. } => match scalar.kind {
+ Sk::Uint | Sk::Sint | Sk::Bool => true,
+ Sk::Float => false,
+ Sk::AbstractInt | Sk::AbstractFloat => {
+ return Err(Error::Validation(
+ "Abstract types should not appear in IR presented to backends",
+ ))
+ }
+ },
+ _ => false,
+ };
+
+ if is_flat {
+ self.decorate(id, Decoration::Flat, &[]);
+ }
+ }
+ }
+ }
+
+ Ok(id)
+ }
+
+ fn write_global_variable(
+ &mut self,
+ ir_module: &crate::Module,
+ global_variable: &crate::GlobalVariable,
+ ) -> Result<Word, Error> {
+ use spirv::Decoration;
+
+ let id = self.id_gen.next();
+ let class = map_storage_class(global_variable.space);
+
+ //self.check(class.required_capabilities())?;
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = global_variable.name {
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+
+ let storage_access = match global_variable.space {
+ crate::AddressSpace::Storage { access } => Some(access),
+ _ => match ir_module.types[global_variable.ty].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => Some(access),
+ _ => None,
+ },
+ };
+ if let Some(storage_access) = storage_access {
+ if !storage_access.contains(crate::StorageAccess::LOAD) {
+ self.decorate(id, Decoration::NonReadable, &[]);
+ }
+ if !storage_access.contains(crate::StorageAccess::STORE) {
+ self.decorate(id, Decoration::NonWritable, &[]);
+ }
+ }
+
+ // Note: we should be able to substitute `binding_array<Foo, 0>`,
+ // but there is still code that tries to register the pre-substituted type,
+ // and that code fails when the size is 0.
+ let mut substitute_inner_type_lookup = None;
+ if let Some(ref res_binding) = global_variable.binding {
+ self.decorate(id, Decoration::DescriptorSet, &[res_binding.group]);
+ self.decorate(id, Decoration::Binding, &[res_binding.binding]);
+
+ if let Some(&BindingInfo {
+ binding_array_size: Some(remapped_binding_array_size),
+ }) = self.binding_map.get(res_binding)
+ {
+ if let crate::TypeInner::BindingArray { base, .. } =
+ ir_module.types[global_variable.ty].inner
+ {
+ substitute_inner_type_lookup =
+ Some(LookupType::Local(LocalType::PointerToBindingArray {
+ base,
+ size: remapped_binding_array_size,
+ space: global_variable.space,
+ }))
+ }
+ }
+ };
+
+ let init_word = global_variable
+ .init
+ .map(|constant| self.constant_ids[constant.index()]);
+ let inner_type_id = self.get_type_id(
+ substitute_inner_type_lookup.unwrap_or(LookupType::Handle(global_variable.ty)),
+ );
+
+ // generate the wrapping structure if needed
+ let pointer_type_id = if global_needs_wrapper(ir_module, global_variable) {
+ let wrapper_type_id = self.id_gen.next();
+
+ self.decorate(wrapper_type_id, Decoration::Block, &[]);
+ let member = crate::StructMember {
+ name: None,
+ ty: global_variable.ty,
+ binding: None,
+ offset: 0,
+ };
+ self.decorate_struct_member(wrapper_type_id, 0, &member, &ir_module.types)?;
+
+ Instruction::type_struct(wrapper_type_id, &[inner_type_id])
+ .to_words(&mut self.logical_layout.declarations);
+
+ let pointer_type_id = self.id_gen.next();
+ Instruction::type_pointer(pointer_type_id, class, wrapper_type_id)
+ .to_words(&mut self.logical_layout.declarations);
+
+ pointer_type_id
+ } else {
+ // This is a global variable in the Storage address space. The only
+ // way it could have `global_needs_wrapper() == false` is if it has
+ // a runtime-sized or binding array.
+ // Runtime-sized arrays were decorated while iterating through the
+ // struct's members; binding arrays, however, still need the `Block`
+ // decoration, which is added here.
+ if let crate::AddressSpace::Storage { .. } = global_variable.space {
+ match ir_module.types[global_variable.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => {
+ let decorated_id = self.get_type_id(LookupType::Handle(base));
+ self.decorate(decorated_id, Decoration::Block, &[]);
+ }
+ _ => (),
+ };
+ }
+ if substitute_inner_type_lookup.is_some() {
+ inner_type_id
+ } else {
+ self.get_pointer_id(&ir_module.types, global_variable.ty, class)?
+ }
+ };
+
+ let init_word = match (global_variable.space, self.zero_initialize_workgroup_memory) {
+ (crate::AddressSpace::Private, _)
+ | (crate::AddressSpace::WorkGroup, super::ZeroInitializeWorkgroupMemoryMode::Native) => {
+ init_word.or_else(|| Some(self.get_constant_null(inner_type_id)))
+ }
+ _ => init_word,
+ };
+
+ Instruction::variable(pointer_type_id, id, class, init_word)
+ .to_words(&mut self.logical_layout.declarations);
+ Ok(id)
+ }
+
+ /// Write the necessary decorations for a struct member.
+ ///
+ /// Emit decorations for the `index`'th member of the struct type
+ /// designated by `struct_id`, described by `member`.
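+ ///
+ /// For instance, a `mat4x4<f32>` member at byte offset 64 would receive
+ /// roughly the following (illustrative disassembly):
+ ///
+ /// ```text
+ /// OpMemberDecorate %struct 1 Offset 64
+ /// OpMemberDecorate %struct 1 ColMajor
+ /// OpMemberDecorate %struct 1 MatrixStride 16
+ /// ```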
+ fn decorate_struct_member(
+ &mut self,
+ struct_id: Word,
+ index: usize,
+ member: &crate::StructMember,
+ arena: &UniqueArena<crate::Type>,
+ ) -> Result<(), Error> {
+ use spirv::Decoration;
+
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::Offset,
+ &[member.offset],
+ ));
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(ref name) = member.name {
+ self.debugs
+ .push(Instruction::member_name(struct_id, index as u32, name));
+ }
+ }
+
+ // Matrices and arrays of matrices both require decorations,
+ // so "see through" an array to determine if they're needed.
+ let member_array_subty_inner = match arena[member.ty].inner {
+ crate::TypeInner::Array { base, .. } => &arena[base].inner,
+ ref other => other,
+ };
+ if let crate::TypeInner::Matrix {
+ columns: _,
+ rows,
+ scalar,
+ } = *member_array_subty_inner
+ {
+ let byte_stride = Alignment::from(rows) * scalar.width as u32;
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::ColMajor,
+ &[],
+ ));
+ self.annotations.push(Instruction::member_decorate(
+ struct_id,
+ index as u32,
+ Decoration::MatrixStride,
+ &[byte_stride],
+ ));
+ }
+
+ Ok(())
+ }
+
+ fn get_function_type(&mut self, lookup_function_type: LookupFunctionType) -> Word {
+ match self
+ .lookup_function_type
+ .entry(lookup_function_type.clone())
+ {
+ Entry::Occupied(e) => *e.get(),
+ Entry::Vacant(_) => {
+ let id = self.id_gen.next();
+ let instruction = Instruction::type_function(
+ id,
+ lookup_function_type.return_type_id,
+ &lookup_function_type.parameter_type_ids,
+ );
+ instruction.to_words(&mut self.logical_layout.declarations);
+ self.lookup_function_type.insert(lookup_function_type, id);
+ id
+ }
+ }
+ }
+
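+ /// Set the module header's id bound to one past the largest id generated,
+ /// as the SPIR-V physical layout requires.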
+ fn write_physical_layout(&mut self) {
+ self.physical_layout.bound = self.id_gen.0 + 1;
+ }
+
+ fn write_logical_layout(
+ &mut self,
+ ir_module: &crate::Module,
+ mod_info: &ModuleInfo,
+ ep_index: Option<usize>,
+ debug_info: &Option<DebugInfo>,
+ ) -> Result<(), Error> {
+ fn has_view_index_check(
+ ir_module: &crate::Module,
+ binding: Option<&crate::Binding>,
+ ty: Handle<crate::Type>,
+ ) -> bool {
+ match ir_module.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => members.iter().any(|member| {
+ has_view_index_check(ir_module, member.binding.as_ref(), member.ty)
+ }),
+ _ => binding == Some(&crate::Binding::BuiltIn(crate::BuiltIn::ViewIndex)),
+ }
+ }
+
+ let has_storage_buffers =
+ ir_module
+ .global_variables
+ .iter()
+ .any(|(_, var)| match var.space {
+ crate::AddressSpace::Storage { .. } => true,
+ _ => false,
+ });
+ let has_view_index = ir_module
+ .entry_points
+ .iter()
+ .flat_map(|entry| entry.function.arguments.iter())
+ .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty));
+ let has_ray_query = ir_module.special_types.ray_desc.is_some()
+ | ir_module.special_types.ray_intersection.is_some();
+
+ if self.physical_layout.version < 0x10300 && has_storage_buffers {
+ // enable the StorageBuffer storage class on SPIR-V < 1.3
+ Instruction::extension("SPV_KHR_storage_buffer_storage_class")
+ .to_words(&mut self.logical_layout.extensions);
+ }
+ if has_view_index {
+ Instruction::extension("SPV_KHR_multiview")
+ .to_words(&mut self.logical_layout.extensions)
+ }
+ if has_ray_query {
+ Instruction::extension("SPV_KHR_ray_query")
+ .to_words(&mut self.logical_layout.extensions)
+ }
+ Instruction::type_void(self.void_type).to_words(&mut self.logical_layout.declarations);
+ Instruction::ext_inst_import(self.gl450_ext_inst_id, "GLSL.std.450")
+ .to_words(&mut self.logical_layout.ext_inst_imports);
+
+ let mut debug_info_inner = None;
+ if self.flags.contains(WriterFlags::DEBUG) {
+ if let Some(debug_info) = debug_info.as_ref() {
+ let source_file_id = self.id_gen.next();
+ self.debugs.push(Instruction::string(
+ &debug_info.file_name.display().to_string(),
+ source_file_id,
+ ));
+
+ debug_info_inner = Some(DebugInfoInner {
+ source_code: debug_info.source_code,
+ source_file_id,
+ });
+ self.debugs.push(Instruction::source(
+ spirv::SourceLanguage::Unknown,
+ 0,
+ &debug_info_inner,
+ ));
+ }
+ }
+
+ // write all types
+ for (handle, _) in ir_module.types.iter() {
+ self.write_type_declaration_arena(&ir_module.types, handle)?;
+ }
+
+ // write all const-expressions as constants
+ self.constant_ids
+ .resize(ir_module.const_expressions.len(), 0);
+ for (handle, _) in ir_module.const_expressions.iter() {
+ self.write_constant_expr(handle, ir_module, mod_info)?;
+ }
+ debug_assert!(self.constant_ids.iter().all(|&id| id != 0));
+
+ // write the name of constants on their respective const-expression initializer
+ if self.flags.contains(WriterFlags::DEBUG) {
+ for (_, constant) in ir_module.constants.iter() {
+ if let Some(ref name) = constant.name {
+ let id = self.constant_ids[constant.init.index()];
+ self.debugs.push(Instruction::name(id, name));
+ }
+ }
+ }
+
+ // write all global variables
+ for (handle, var) in ir_module.global_variables.iter() {
+ // If a single entry point was specified, only write `OpVariable` instructions
+ // for the globals it actually uses. Emit dummies for the others,
+ // to preserve the indices in `global_variables`.
+ let gvar = match ep_index {
+ Some(index) if mod_info.get_entry_point(index)[handle].is_empty() => {
+ GlobalVariable::dummy()
+ }
+ _ => {
+ let id = self.write_global_variable(ir_module, var)?;
+ GlobalVariable::new(id)
+ }
+ };
+ self.global_variables.push(gvar);
+ }
+
+ // write all functions
+ for (handle, ir_function) in ir_module.functions.iter() {
+ let info = &mod_info[handle];
+ if let Some(index) = ep_index {
+ let ep_info = mod_info.get_entry_point(index);
+ // If this function uses globals that we omitted from the SPIR-V
+ // because the entry point and its callees didn't use them,
+ // then we must skip it.
+ if !ep_info.dominates_global_use(info) {
+ log::info!("Skip function {:?}", ir_function.name);
+ continue;
+ }
+
+ // Skip functions that are not compatible with this entry point's stage.
+ //
+ // When validation is enabled, it rejects modules whose entry points try to call
+ // incompatible functions, so if we got this far, then any functions incompatible
+ // with our selected entry point must not be used.
+ //
+ // When validation is disabled, `fun_info.available_stages` is always just
+ // `ShaderStages::all()`, so this will write all functions in the module, and
+ // the downstream SPIR-V consumer will catch any problems.
+ if !info.available_stages.contains(ep_info.available_stages) {
+ continue;
+ }
+ }
+ let id = self.write_function(ir_function, info, ir_module, None, &debug_info_inner)?;
+ self.lookup_function.insert(handle, id);
+ }
+
+ // write all entry points, or only the selected one
+ for (index, ir_ep) in ir_module.entry_points.iter().enumerate() {
+ if ep_index.is_some() && ep_index != Some(index) {
+ continue;
+ }
+ let info = mod_info.get_entry_point(index);
+ let ep_instruction =
+ self.write_entry_point(ir_ep, info, ir_module, &debug_info_inner)?;
+ ep_instruction.to_words(&mut self.logical_layout.entry_points);
+ }
+
+ for capability in self.capabilities_used.iter() {
+ Instruction::capability(*capability).to_words(&mut self.logical_layout.capabilities);
+ }
+ for extension in self.extensions_used.iter() {
+ Instruction::extension(extension).to_words(&mut self.logical_layout.extensions);
+ }
+ if ir_module.entry_points.is_empty() {
+ // SPIR-V doesn't like modules without entry points
+ Instruction::capability(spirv::Capability::Linkage)
+ .to_words(&mut self.logical_layout.capabilities);
+ }
+
+ let addressing_model = spirv::AddressingModel::Logical;
+ let memory_model = spirv::MemoryModel::GLSL450;
+ //self.check(addressing_model.required_capabilities())?;
+ //self.check(memory_model.required_capabilities())?;
+
+ Instruction::memory_model(addressing_model, memory_model)
+ .to_words(&mut self.logical_layout.memory_model);
+
+ if self.flags.contains(WriterFlags::DEBUG) {
+ for debug in self.debugs.iter() {
+ debug.to_words(&mut self.logical_layout.debugs);
+ }
+ }
+
+ for annotation in self.annotations.iter() {
+ annotation.to_words(&mut self.logical_layout.annotations);
+ }
+
+ Ok(())
+ }
+
+ pub fn write(
+ &mut self,
+ ir_module: &crate::Module,
+ info: &ModuleInfo,
+ pipeline_options: Option<&PipelineOptions>,
+ debug_info: &Option<DebugInfo>,
+ words: &mut Vec<Word>,
+ ) -> Result<(), Error> {
+ self.reset();
+
+ // Try to find the entry point and corresponding index
+ let ep_index = match pipeline_options {
+ Some(po) => {
+ let index = ir_module
+ .entry_points
+ .iter()
+ .position(|ep| po.shader_stage == ep.stage && po.entry_point == ep.name)
+ .ok_or(Error::EntryPointNotFound)?;
+ Some(index)
+ }
+ None => None,
+ };
+
+ self.write_logical_layout(ir_module, info, ep_index, debug_info)?;
+ self.write_physical_layout();
+
+ self.physical_layout.in_words(words);
+ self.logical_layout.in_words(words);
+ Ok(())
+ }
+
+ /// Return the set of capabilities the last module written used.
+ pub const fn get_capabilities_used(&self) -> &crate::FastIndexSet<spirv::Capability> {
+ &self.capabilities_used
+ }
+
+ pub fn decorate_non_uniform_binding_array_access(&mut self, id: Word) -> Result<(), Error> {
+ self.require_any("NonUniformEXT", &[spirv::Capability::ShaderNonUniform])?;
+ self.use_extension("SPV_EXT_descriptor_indexing");
+ self.decorate(id, spirv::Decoration::NonUniform, &[]);
+ Ok(())
+ }
+}
+
+#[test]
+fn test_write_physical_layout() {
+ let mut writer = Writer::new(&Options::default()).unwrap();
+ assert_eq!(writer.physical_layout.bound, 0);
+ writer.write_physical_layout();
+ assert_eq!(writer.physical_layout.bound, 3);
+}
diff --git a/third_party/rust/naga/src/back/wgsl/mod.rs b/third_party/rust/naga/src/back/wgsl/mod.rs
new file mode 100644
index 0000000000..d731b1ca0c
--- /dev/null
+++ b/third_party/rust/naga/src/back/wgsl/mod.rs
@@ -0,0 +1,52 @@
+/*!
+Backend for [WGSL][wgsl] (WebGPU Shading Language).
+
+[wgsl]: https://gpuweb.github.io/gpuweb/wgsl.html
+*/
+
+mod writer;
+
+use thiserror::Error;
+
+pub use writer::{Writer, WriterFlags};
+
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error(transparent)]
+ FmtError(#[from] std::fmt::Error),
+ #[error("{0}")]
+ Custom(String),
+ #[error("{0}")]
+ Unimplemented(String), // TODO: Error used only during development
+ #[error("Unsupported math function: {0:?}")]
+ UnsupportedMathFunction(crate::MathFunction),
+ #[error("Unsupported relational function: {0:?}")]
+ UnsupportedRelationalFunction(crate::RelationalFunction),
+}
+
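+/// Write `module` as WGSL source and return it as a `String`.
+///
+/// A minimal usage sketch, assuming `module` has already been built and
+/// validated:
+///
+/// ```ignore
+/// let info = naga::valid::Validator::new(
+///     naga::valid::ValidationFlags::all(),
+///     naga::valid::Capabilities::all(),
+/// )
+/// .validate(&module)?;
+/// let wgsl = naga::back::wgsl::write_string(&module, &info, WriterFlags::empty())?;
+/// ```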
+pub fn write_string(
+ module: &crate::Module,
+ info: &crate::valid::ModuleInfo,
+ flags: WriterFlags,
+) -> Result<String, Error> {
+ let mut w = Writer::new(String::new(), flags);
+ w.write(module, info)?;
+ let output = w.finish();
+ Ok(output)
+}
+
+impl crate::AtomicFunction {
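+ /// Return the suffix appended to `atomic` to form the WGSL builtin
+ /// name, e.g. `Add` for `atomicAdd` and `CompareExchangeWeak` for
+ /// `atomicCompareExchangeWeak`.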
+ const fn to_wgsl(self) -> &'static str {
+ match self {
+ Self::Add => "Add",
+ Self::Subtract => "Sub",
+ Self::And => "And",
+ Self::InclusiveOr => "Or",
+ Self::ExclusiveOr => "Xor",
+ Self::Min => "Min",
+ Self::Max => "Max",
+ Self::Exchange { compare: None } => "Exchange",
+ Self::Exchange { .. } => "CompareExchangeWeak",
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs
new file mode 100644
index 0000000000..c737934f5e
--- /dev/null
+++ b/third_party/rust/naga/src/back/wgsl/writer.rs
@@ -0,0 +1,1961 @@
+use super::Error;
+use crate::{
+ back,
+ proc::{self, NameKey},
+ valid, Handle, Module, ShaderStage, TypeInner,
+};
+use std::fmt::Write;
+
+/// Shorthand result used internally by the backend
+type BackendResult = Result<(), Error>;
+
+/// WGSL [attribute](https://gpuweb.github.io/gpuweb/wgsl/#attributes)
+enum Attribute {
+ Binding(u32),
+ BuiltIn(crate::BuiltIn),
+ Group(u32),
+ Invariant,
+ Interpolate(Option<crate::Interpolation>, Option<crate::Sampling>),
+ Location(u32),
+ SecondBlendSource,
+ Stage(ShaderStage),
+ WorkGroupSize([u32; 3]),
+}
+
+/// The WGSL form that `write_expr_with_indirection` should use to render a Naga
+/// expression.
+///
+/// Sometimes a Naga `Expression` alone doesn't provide enough information to
+/// choose the right rendering for it in WGSL. For example, one natural WGSL
+/// rendering of a Naga `LocalVariable(x)` expression might be `&x`, since
+/// `LocalVariable` produces a pointer to the local variable's storage. But when
+/// rendering a `Store` statement, the `pointer` operand must be the left hand
+/// side of a WGSL assignment, so the proper rendering is `x`.
+///
+/// The caller of `write_expr_with_indirection` must provide an `Indirection` value
+/// to indicate how ambiguous expressions should be rendered.
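+///
+/// For example, given a local `var x: i32;` (illustrative WGSL):
+///
+/// ```ignore
+/// x = 10;     // `Store` pointer rendered with `Indirection::Reference`
+/// let p = &x; // the same pointer rendered with `Indirection::Ordinary`
+/// ```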
+#[derive(Clone, Copy, Debug)]
+enum Indirection {
+ /// Render pointer-construction expressions as WGSL `ptr`-typed expressions.
+ ///
+ /// This is the right choice for most cases. Whenever a Naga pointer
+ /// expression is not the `pointer` operand of a `Load` or `Store`, it
+ /// must be a WGSL pointer expression.
+ Ordinary,
+
+ /// Render pointer-construction expressions as WGSL reference-typed
+ /// expressions.
+ ///
+ /// For example, this is the right choice for the `pointer` operand when
+ /// rendering a `Store` statement as a WGSL assignment.
+ Reference,
+}
+
+bitflags::bitflags! {
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Eq, PartialEq)]
+ pub struct WriterFlags: u32 {
+ /// Always annotate the type information instead of inferring.
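+ ///
+ /// For example, with this flag set a local is written as
+ /// `let x: f32 = 1.0;` rather than `let x = 1.0;`.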
+ const EXPLICIT_TYPES = 0x1;
+ }
+}
+
+pub struct Writer<W> {
+ out: W,
+ flags: WriterFlags,
+ names: crate::FastHashMap<NameKey, String>,
+ namer: proc::Namer,
+ named_expressions: crate::NamedExpressions,
+ ep_results: Vec<(ShaderStage, Handle<crate::Type>)>,
+}
+
+impl<W: Write> Writer<W> {
+ pub fn new(out: W, flags: WriterFlags) -> Self {
+ Writer {
+ out,
+ flags,
+ names: crate::FastHashMap::default(),
+ namer: proc::Namer::default(),
+ named_expressions: crate::NamedExpressions::default(),
+ ep_results: vec![],
+ }
+ }
+
+ fn reset(&mut self, module: &Module) {
+ self.names.clear();
+ self.namer.reset(
+ module,
+ crate::keywords::wgsl::RESERVED,
+ // an identifier must not start with two underscores
+ &[],
+ &[],
+ &["__"],
+ &mut self.names,
+ );
+ self.named_expressions.clear();
+ self.ep_results.clear();
+ }
+
+ fn is_builtin_wgsl_struct(&self, module: &Module, handle: Handle<crate::Type>) -> bool {
+ module
+ .special_types
+ .predeclared_types
+ .values()
+ .any(|t| *t == handle)
+ }
+
+ pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
+ self.reset(module);
+
+ // Save all ep result types
+ for ep in module.entry_points.iter() {
+ if let Some(ref result) = ep.function.result {
+ self.ep_results.push((ep.stage, result.ty));
+ }
+ }
+
+ // Write all structs
+ for (handle, ty) in module.types.iter() {
+            if let TypeInner::Struct { ref members, .. } = ty.inner {
+                if !self.is_builtin_wgsl_struct(module, handle) {
+                    self.write_struct(module, handle, members)?;
+                    writeln!(self.out)?;
+                }
+            }
+ }
+
+ // Write all named constants
+ let mut constants = module
+ .constants
+ .iter()
+ .filter(|&(_, c)| c.name.is_some())
+ .peekable();
+ while let Some((handle, _)) = constants.next() {
+ self.write_global_constant(module, handle)?;
+ // Add extra newline for readability on last iteration
+ if constants.peek().is_none() {
+ writeln!(self.out)?;
+ }
+ }
+
+ // Write all globals
+        for (handle, global) in module.global_variables.iter() {
+            self.write_global(module, global, handle)?;
+ }
+
+ if !module.global_variables.is_empty() {
+ // Add extra newline for readability
+ writeln!(self.out)?;
+ }
+
+ // Write all regular functions
+ for (handle, function) in module.functions.iter() {
+ let fun_info = &info[handle];
+
+ let func_ctx = back::FunctionCtx {
+ ty: back::FunctionType::Function(handle),
+ info: fun_info,
+ expressions: &function.expressions,
+ named_expressions: &function.named_expressions,
+ };
+
+ // Write the function
+ self.write_function(module, function, &func_ctx)?;
+
+ writeln!(self.out)?;
+ }
+
+ // Write all entry points
+ for (index, ep) in module.entry_points.iter().enumerate() {
+ let attributes = match ep.stage {
+ ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
+ ShaderStage::Compute => vec![
+ Attribute::Stage(ShaderStage::Compute),
+ Attribute::WorkGroupSize(ep.workgroup_size),
+ ],
+ };
+
+ self.write_attributes(&attributes)?;
+ // Add a newline after attribute
+ writeln!(self.out)?;
+
+ let func_ctx = back::FunctionCtx {
+ ty: back::FunctionType::EntryPoint(index as u16),
+ info: info.get_entry_point(index),
+ expressions: &ep.function.expressions,
+ named_expressions: &ep.function.named_expressions,
+ };
+ self.write_function(module, &ep.function, &func_ctx)?;
+
+ if index < module.entry_points.len() - 1 {
+ writeln!(self.out)?;
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write struct name
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_struct_name(&mut self, module: &Module, handle: Handle<crate::Type>) -> BackendResult {
+ if module.types[handle].name.is_none() {
+ if let Some(&(stage, _)) = self.ep_results.iter().find(|&&(_, ty)| ty == handle) {
+ let name = match stage {
+ ShaderStage::Compute => "ComputeOutput",
+ ShaderStage::Fragment => "FragmentOutput",
+ ShaderStage::Vertex => "VertexOutput",
+ };
+
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+ }
+
+ write!(self.out, "{}", self.names[&NameKey::Type(handle)])?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write
+ /// [functions](https://gpuweb.github.io/gpuweb/wgsl/#functions)
+ ///
+ /// # Notes
+ /// Ends in a newline
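+    ///
+    /// Illustrative shape of the output (names, arguments and types depend
+    /// on the module being written):
+    ///
+    /// ```text
+    /// fn blur(radius: f32) -> f32 {
+    ///     var acc: f32;
+    ///     ...
+    ///     return acc;
+    /// }
+    /// ```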
+ fn write_function(
+ &mut self,
+ module: &Module,
+ func: &crate::Function,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ let func_name = match func_ctx.ty {
+ back::FunctionType::EntryPoint(index) => &self.names[&NameKey::EntryPoint(index)],
+ back::FunctionType::Function(handle) => &self.names[&NameKey::Function(handle)],
+ };
+
+ // Write function name
+ write!(self.out, "fn {func_name}(")?;
+
+ // Write function arguments
+ for (index, arg) in func.arguments.iter().enumerate() {
+ // Write argument attribute if a binding is present
+ if let Some(ref binding) = arg.binding {
+ self.write_attributes(&map_binding_to_attribute(binding))?;
+ }
+ // Write argument name
+ let argument_name = &self.names[&func_ctx.argument_key(index as u32)];
+
+ write!(self.out, "{argument_name}: ")?;
+ // Write argument type
+ self.write_type(module, arg.ty)?;
+ if index < func.arguments.len() - 1 {
+ // Add a separator between args
+ write!(self.out, ", ")?;
+ }
+ }
+
+ write!(self.out, ")")?;
+
+ // Write function return type
+ if let Some(ref result) = func.result {
+ write!(self.out, " -> ")?;
+ if let Some(ref binding) = result.binding {
+ self.write_attributes(&map_binding_to_attribute(binding))?;
+ }
+ self.write_type(module, result.ty)?;
+ }
+
+ write!(self.out, " {{")?;
+ writeln!(self.out)?;
+
+ // Write function local variables
+ for (handle, local) in func.local_variables.iter() {
+ // Write indentation (only for readability)
+ write!(self.out, "{}", back::INDENT)?;
+
+ // Write the local name
+ write!(self.out, "var {}: ", self.names[&func_ctx.name_key(handle)])?;
+
+ // Write the local type
+ self.write_type(module, local.ty)?;
+
+ // Write the local initializer if needed
+ if let Some(init) = local.init {
+                // Put the equals sign only if there's an initializer
+ // The leading and trailing spaces aren't needed but help with readability
+ write!(self.out, " = ")?;
+
+                // Write the initializer expression
+                // `write_expr` adds no trailing or leading space/newline
+ self.write_expr(module, init, func_ctx)?;
+ }
+
+ // Finish the local with `;` and add a newline (only for readability)
+ writeln!(self.out, ";")?
+ }
+
+ if !func.local_variables.is_empty() {
+ writeln!(self.out)?;
+ }
+
+ // Write the function body (statement list)
+ for sta in func.body.iter() {
+ // The indentation should always be 1 when writing the function body
+ self.write_stmt(module, sta, func_ctx, back::Level(1))?;
+ }
+
+ writeln!(self.out, "}}")?;
+
+ self.named_expressions.clear();
+
+ Ok(())
+ }
+
+    /// Helper method used to write attributes
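+    ///
+    /// For example, `[Attribute::Group(0), Attribute::Binding(1)]` is written
+    /// as `@group(0) @binding(1) ` (each attribute ends in a trailing space).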
+ fn write_attributes(&mut self, attributes: &[Attribute]) -> BackendResult {
+ for attribute in attributes {
+ match *attribute {
+ Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
+ Attribute::SecondBlendSource => write!(self.out, "@second_blend_source ")?,
+ Attribute::BuiltIn(builtin_attrib) => {
+ let builtin = builtin_str(builtin_attrib)?;
+ write!(self.out, "@builtin({builtin}) ")?;
+ }
+ Attribute::Stage(shader_stage) => {
+ let stage_str = match shader_stage {
+ ShaderStage::Vertex => "vertex",
+ ShaderStage::Fragment => "fragment",
+ ShaderStage::Compute => "compute",
+ };
+ write!(self.out, "@{stage_str} ")?;
+ }
+ Attribute::WorkGroupSize(size) => {
+ write!(
+ self.out,
+ "@workgroup_size({}, {}, {}) ",
+ size[0], size[1], size[2]
+ )?;
+ }
+ Attribute::Binding(id) => write!(self.out, "@binding({id}) ")?,
+ Attribute::Group(id) => write!(self.out, "@group({id}) ")?,
+ Attribute::Invariant => write!(self.out, "@invariant ")?,
+ Attribute::Interpolate(interpolation, sampling) => {
+ if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
+ write!(
+ self.out,
+ "@interpolate({}, {}) ",
+ interpolation_str(
+ interpolation.unwrap_or(crate::Interpolation::Perspective)
+ ),
+ sampling_str(sampling.unwrap_or(crate::Sampling::Center))
+ )?;
+ } else if interpolation.is_some()
+ && interpolation != Some(crate::Interpolation::Perspective)
+ {
+ write!(
+ self.out,
+ "@interpolate({}) ",
+ interpolation_str(
+ interpolation.unwrap_or(crate::Interpolation::Perspective)
+ )
+ )?;
+ }
+ }
+ };
+ }
+ Ok(())
+ }
+
+ /// Helper method used to write structs
+ ///
+ /// # Notes
+ /// Ends in a newline
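+    ///
+    /// Illustrative output, assuming hypothetical member names:
+    ///
+    /// ```text
+    /// struct VertexOutput {
+    ///     @builtin(position) position: vec4<f32>,
+    ///     @location(0) uv: vec2<f32>,
+    /// }
+    /// ```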
+ fn write_struct(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Type>,
+ members: &[crate::StructMember],
+ ) -> BackendResult {
+ write!(self.out, "struct ")?;
+ self.write_struct_name(module, handle)?;
+ write!(self.out, " {{")?;
+ writeln!(self.out)?;
+ for (index, member) in members.iter().enumerate() {
+ // The indentation is only for readability
+ write!(self.out, "{}", back::INDENT)?;
+ if let Some(ref binding) = member.binding {
+ self.write_attributes(&map_binding_to_attribute(binding))?;
+ }
+ // Write struct member name and type
+ let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
+ write!(self.out, "{member_name}: ")?;
+ self.write_type(module, member.ty)?;
+ write!(self.out, ",")?;
+ writeln!(self.out)?;
+ }
+
+ write!(self.out, "}}")?;
+
+ writeln!(self.out)?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write non image/sampler types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
+ fn write_type(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ let inner = &module.types[ty].inner;
+ match *inner {
+ TypeInner::Struct { .. } => self.write_struct_name(module, ty)?,
+ ref other => self.write_value_type(module, other)?,
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write value types
+ ///
+ /// # Notes
+ /// Adds no trailing or leading whitespace
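+    ///
+    /// A few illustrative renderings:
+    ///
+    /// ```text
+    /// vec3<f32>
+    /// atomic<u32>
+    /// array<f32, 4>
+    /// ptr<function, f32>
+    /// texture_storage_2d<rgba8unorm,write>
+    /// ```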
+ fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
+ match *inner {
+ TypeInner::Vector { size, scalar } => write!(
+ self.out,
+ "vec{}<{}>",
+ back::vector_size_str(size),
+ scalar_kind_str(scalar),
+ )?,
+ TypeInner::Sampler { comparison: false } => {
+ write!(self.out, "sampler")?;
+ }
+ TypeInner::Sampler { comparison: true } => {
+ write!(self.out, "sampler_comparison")?;
+ }
+ TypeInner::Image {
+ dim,
+ arrayed,
+ class,
+ } => {
+ // More about texture types: https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
+ use crate::ImageClass as Ic;
+
+ let dim_str = image_dimension_str(dim);
+ let arrayed_str = if arrayed { "_array" } else { "" };
+ let (class_str, multisampled_str, format_str, storage_str) = match class {
+ Ic::Sampled { kind, multi } => (
+ "",
+ if multi { "multisampled_" } else { "" },
+ scalar_kind_str(crate::Scalar { kind, width: 4 }),
+ "",
+ ),
+ Ic::Depth { multi } => {
+ ("depth_", if multi { "multisampled_" } else { "" }, "", "")
+ }
+ Ic::Storage { format, access } => (
+ "storage_",
+ "",
+ storage_format_str(format),
+ if access.contains(crate::StorageAccess::LOAD | crate::StorageAccess::STORE)
+ {
+ ",read_write"
+ } else if access.contains(crate::StorageAccess::LOAD) {
+ ",read"
+ } else {
+ ",write"
+ },
+ ),
+ };
+ write!(
+ self.out,
+ "texture_{class_str}{multisampled_str}{dim_str}{arrayed_str}"
+ )?;
+
+ if !format_str.is_empty() {
+ write!(self.out, "<{format_str}{storage_str}>")?;
+ }
+ }
+ TypeInner::Scalar(scalar) => {
+ write!(self.out, "{}", scalar_kind_str(scalar))?;
+ }
+ TypeInner::Atomic(scalar) => {
+ write!(self.out, "atomic<{}>", scalar_kind_str(scalar))?;
+ }
+ TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ } => {
+ // More info https://gpuweb.github.io/gpuweb/wgsl/#array-types
+ // array<A, 3> -- Constant array
+ // array<A> -- Dynamic array
+ write!(self.out, "array<")?;
+ match size {
+ crate::ArraySize::Constant(len) => {
+ self.write_type(module, base)?;
+ write!(self.out, ", {len}")?;
+ }
+ crate::ArraySize::Dynamic => {
+ self.write_type(module, base)?;
+ }
+ }
+ write!(self.out, ">")?;
+ }
+ TypeInner::BindingArray { base, size } => {
+ // More info https://github.com/gpuweb/gpuweb/issues/2105
+ write!(self.out, "binding_array<")?;
+ match size {
+ crate::ArraySize::Constant(len) => {
+ self.write_type(module, base)?;
+ write!(self.out, ", {len}")?;
+ }
+ crate::ArraySize::Dynamic => {
+ self.write_type(module, base)?;
+ }
+ }
+ write!(self.out, ">")?;
+ }
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ write!(
+ self.out,
+ "mat{}x{}<{}>",
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ scalar_kind_str(scalar)
+ )?;
+ }
+ TypeInner::Pointer { base, space } => {
+ let (address, maybe_access) = address_space_str(space);
+                // Everything but `AddressSpace::Handle` gives us an address space
+                // name, but Naga IR never produces pointers to handles, so it
+                // doesn't matter much how we write such a type. Just write it as
+                // the base type alone.
+ if let Some(space) = address {
+ write!(self.out, "ptr<{space}, ")?;
+ }
+ self.write_type(module, base)?;
+ if address.is_some() {
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ }
+ }
+ TypeInner::ValuePointer {
+ size: None,
+ scalar,
+ space,
+ } => {
+ let (address, maybe_access) = address_space_str(space);
+ if let Some(space) = address {
+ write!(self.out, "ptr<{}, {}", space, scalar_kind_str(scalar))?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ } else {
+ return Err(Error::Unimplemented(format!(
+ "ValuePointer to AddressSpace::Handle {inner:?}"
+ )));
+ }
+ }
+ TypeInner::ValuePointer {
+ size: Some(size),
+ scalar,
+ space,
+ } => {
+ let (address, maybe_access) = address_space_str(space);
+ if let Some(space) = address {
+ write!(
+ self.out,
+ "ptr<{}, vec{}<{}>",
+ space,
+ back::vector_size_str(size),
+ scalar_kind_str(scalar)
+ )?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ } else {
+ return Err(Error::Unimplemented(format!(
+ "ValuePointer to AddressSpace::Handle {inner:?}"
+ )));
+ }
+ write!(self.out, ">")?;
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!("write_value_type {inner:?}")));
+ }
+ }
+
+ Ok(())
+ }
+ /// Helper method used to write statements
+ ///
+ /// # Notes
+ /// Always adds a newline
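+    ///
+    /// Illustrative output for an `Emit` that bakes a binary expression
+    /// followed by an `If` (assuming `back::BAKE_PREFIX` is `_e`):
+    ///
+    /// ```text
+    /// let _e5 = (x + y);
+    /// if cond {
+    ///     return _e5;
+    /// }
+    /// ```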
+ fn write_stmt(
+ &mut self,
+ module: &Module,
+ stmt: &crate::Statement,
+ func_ctx: &back::FunctionCtx<'_>,
+ level: back::Level,
+ ) -> BackendResult {
+ use crate::{Expression, Statement};
+
+ match *stmt {
+ Statement::Emit(ref range) => {
+ for handle in range.clone() {
+ let info = &func_ctx.info[handle];
+ let expr_name = if let Some(name) = func_ctx.named_expressions.get(&handle) {
+                        // The frontend provides names for all named expressions up
+                        // front, but we emit them one at a time, so each name must
+                        // be re-registered with the namer as we go. Otherwise we
+                        // could accidentally write a variable name instead of the
+                        // full expression. We also use the sanitized names, which
+                        // keeps the backend from emitting identifiers that collide
+                        // with reserved keywords.
+ Some(self.namer.call(name))
+ } else {
+ let expr = &func_ctx.expressions[handle];
+ let min_ref_count = expr.bake_ref_count();
+                        // Forcibly bake certain expressions to help with readability
+ let required_baking_expr = match *expr {
+ Expression::ImageLoad { .. }
+ | Expression::ImageQuery { .. }
+ | Expression::ImageSample { .. } => true,
+ _ => false,
+ };
+ if min_ref_count <= info.ref_count || required_baking_expr {
+ Some(format!("{}{}", back::BAKE_PREFIX, handle.index()))
+ } else {
+ None
+ }
+ };
+
+ if let Some(name) = expr_name {
+ write!(self.out, "{level}")?;
+ self.start_named_expr(module, handle, func_ctx, &name)?;
+ self.write_expr(module, handle, func_ctx)?;
+ self.named_expressions.insert(handle, name);
+ writeln!(self.out, ";")?;
+ }
+ }
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "if ")?;
+ self.write_expr(module, condition, func_ctx)?;
+ writeln!(self.out, " {{")?;
+
+ let l2 = level.next();
+ for sta in accept {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+ // If there are no statements in the reject block we skip writing it
+ // This is only for readability
+ if !reject.is_empty() {
+ writeln!(self.out, "{level}}} else {{")?;
+
+ for sta in reject {
+ // Increase indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Return { value } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "return")?;
+ if let Some(return_value) = value {
+ // The leading space is important
+ write!(self.out, " ")?;
+ self.write_expr(module, return_value, func_ctx)?;
+ }
+ writeln!(self.out, ";")?;
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Kill => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "discard;")?
+ }
+ Statement::Store { pointer, value } => {
+ write!(self.out, "{level}")?;
+
+ let is_atomic_pointer = func_ctx
+ .resolve_type(pointer, &module.types)
+ .is_atomic_pointer(&module.types);
+
+ if is_atomic_pointer {
+ write!(self.out, "atomicStore(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr_with_indirection(
+ module,
+ pointer,
+ func_ctx,
+ Indirection::Reference,
+ )?;
+ write!(self.out, " = ")?;
+ self.write_expr(module, value, func_ctx)?;
+ }
+ writeln!(self.out, ";")?
+ }
+ Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ if let Some(expr) = result {
+ let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
+ self.start_named_expr(module, expr, func_ctx, &name)?;
+ self.named_expressions.insert(expr, name);
+ }
+ let func_name = &self.names[&NameKey::Function(function)];
+ write!(self.out, "{func_name}(")?;
+ for (index, &argument) in arguments.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ self.write_expr(module, argument, func_ctx)?;
+ }
+ writeln!(self.out, ");")?
+ }
+ Statement::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+
+ let fun_str = fun.to_wgsl();
+ write!(self.out, "atomic{fun_str}(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ write!(self.out, ", ")?;
+ self.write_expr(module, cmp, func_ctx)?;
+ }
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ");")?
+ }
+ Statement::WorkGroupUniformLoad { pointer, result } => {
+ write!(self.out, "{level}")?;
+ // TODO: Obey named expressions here.
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+ write!(self.out, "workgroupUniformLoad(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "textureStore(")?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(array_index_expr) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index_expr, func_ctx)?;
+ }
+ write!(self.out, ", ")?;
+ self.write_expr(module, value, func_ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ // TODO: copy-paste from glsl-out
+ Statement::Block(ref block) => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "{{")?;
+ for sta in block.iter() {
+ // Increase the indentation to help with readability
+ self.write_stmt(module, sta, func_ctx, level.next())?
+ }
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ // Start the switch
+ write!(self.out, "{level}")?;
+ write!(self.out, "switch ")?;
+ self.write_expr(module, selector, func_ctx)?;
+ writeln!(self.out, " {{")?;
+
+ let l2 = level.next();
+ let mut new_case = true;
+ for case in cases {
+ if case.fall_through && !case.body.is_empty() {
+ // TODO: we could do the same workaround as we did for the HLSL backend
+ return Err(Error::Unimplemented(
+ "fall-through switch case block".into(),
+ ));
+ }
+
+ match case.value {
+ crate::SwitchValue::I32(value) => {
+ if new_case {
+ write!(self.out, "{l2}case ")?;
+ }
+ write!(self.out, "{value}")?;
+ }
+ crate::SwitchValue::U32(value) => {
+ if new_case {
+ write!(self.out, "{l2}case ")?;
+ }
+ write!(self.out, "{value}u")?;
+ }
+ crate::SwitchValue::Default => {
+ if new_case {
+ if case.fall_through {
+ write!(self.out, "{l2}case ")?;
+ } else {
+ write!(self.out, "{l2}")?;
+ }
+ }
+ write!(self.out, "default")?;
+ }
+ }
+
+ new_case = !case.fall_through;
+
+ if case.fall_through {
+ write!(self.out, ", ")?;
+ } else {
+ writeln!(self.out, ": {{")?;
+ }
+
+ for sta in case.body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2.next())?;
+ }
+
+ if !case.fall_through {
+ writeln!(self.out, "{l2}}}")?;
+ }
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ write!(self.out, "{level}")?;
+ writeln!(self.out, "loop {{")?;
+
+ let l2 = level.next();
+ for sta in body.iter() {
+ self.write_stmt(module, sta, func_ctx, l2)?;
+ }
+
+                // The `continuing` block is optional, so we skip it when empty.
+                // However, a `break if` belongs to the `continuing` block, so we
+                // must still emit the block whenever a `break if` exists, even if
+                // `continuing` itself is empty.
+ if !continuing.is_empty() || break_if.is_some() {
+ writeln!(self.out, "{l2}continuing {{")?;
+ for sta in continuing.iter() {
+ self.write_stmt(module, sta, func_ctx, l2.next())?;
+ }
+
+ // The `break if` is always the last
+ // statement of the `continuing` block
+ if let Some(condition) = break_if {
+ // The trailing space is important
+ write!(self.out, "{}break if ", l2.next())?;
+ self.write_expr(module, condition, func_ctx)?;
+ // Close the `break if` statement
+ writeln!(self.out, ";")?;
+ }
+
+ writeln!(self.out, "{l2}}}")?;
+ }
+
+ writeln!(self.out, "{level}}}")?
+ }
+ Statement::Break => {
+ writeln!(self.out, "{level}break;")?;
+ }
+ Statement::Continue => {
+ writeln!(self.out, "{level}continue;")?;
+ }
+ Statement::Barrier(barrier) => {
+ if barrier.contains(crate::Barrier::STORAGE) {
+ writeln!(self.out, "{level}storageBarrier();")?;
+ }
+
+ if barrier.contains(crate::Barrier::WORK_GROUP) {
+ writeln!(self.out, "{level}workgroupBarrier();")?;
+ }
+ }
+ Statement::RayQuery { .. } => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Return the sort of indirection that `expr`'s plain form evaluates to.
+ ///
+ /// An expression's 'plain form' is the most general rendition of that
+ /// expression into WGSL, lacking `&` or `*` operators:
+ ///
+ /// - The plain form of `LocalVariable(x)` is simply `x`, which is a reference
+ /// to the local variable's storage.
+ ///
+ /// - The plain form of `GlobalVariable(g)` is simply `g`, which is usually a
+ /// reference to the global variable's storage. However, globals in the
+ /// `Handle` address space are immutable, and `GlobalVariable` expressions for
+ /// those produce the value directly, not a pointer to it. Such
+ /// `GlobalVariable` expressions are `Ordinary`.
+ ///
+ /// - `Access` and `AccessIndex` are `Reference` when their `base` operand is a
+ /// pointer. If they are applied directly to a composite value, they are
+ /// `Ordinary`.
+ ///
+ /// Note that `FunctionArgument` expressions are never `Reference`, even when
+ /// the argument's type is `Pointer`. `FunctionArgument` always evaluates to the
+ /// argument's value directly, so any pointer it produces is merely the value
+ /// passed by the caller.
+ fn plain_form_indirection(
+ &self,
+ expr: Handle<crate::Expression>,
+ module: &Module,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> Indirection {
+ use crate::Expression as Ex;
+
+ // Named expressions are `let` expressions, which apply the Load Rule,
+ // so if their type is a Naga pointer, then that must be a WGSL pointer
+ // as well.
+ if self.named_expressions.contains_key(&expr) {
+ return Indirection::Ordinary;
+ }
+
+ match func_ctx.expressions[expr] {
+ Ex::LocalVariable(_) => Indirection::Reference,
+ Ex::GlobalVariable(handle) => {
+ let global = &module.global_variables[handle];
+ match global.space {
+ crate::AddressSpace::Handle => Indirection::Ordinary,
+ _ => Indirection::Reference,
+ }
+ }
+ Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
+ let base_ty = func_ctx.resolve_type(base, &module.types);
+ match *base_ty {
+ crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } => {
+ Indirection::Reference
+ }
+ _ => Indirection::Ordinary,
+ }
+ }
+ _ => Indirection::Ordinary,
+ }
+ }
+
+ fn start_named_expr(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx,
+ name: &str,
+ ) -> BackendResult {
+ // Write variable name
+ write!(self.out, "let {name}")?;
+ if self.flags.contains(WriterFlags::EXPLICIT_TYPES) {
+ write!(self.out, ": ")?;
+ let ty = &func_ctx.info[handle].ty;
+ // Write variable type
+ match *ty {
+ proc::TypeResolution::Handle(handle) => {
+ self.write_type(module, handle)?;
+ }
+ proc::TypeResolution::Value(ref inner) => {
+ self.write_value_type(module, inner)?;
+ }
+ }
+ }
+
+ write!(self.out, " = ")?;
+ Ok(())
+ }
+
+ /// Write the ordinary WGSL form of `expr`.
+ ///
+ /// See `write_expr_with_indirection` for details.
+ fn write_expr(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ ) -> BackendResult {
+ self.write_expr_with_indirection(module, expr, func_ctx, Indirection::Ordinary)
+ }
+
+ /// Write `expr` as a WGSL expression with the requested indirection.
+ ///
+ /// In terms of the WGSL grammar, the resulting expression is a
+ /// `singular_expression`. It may be parenthesized. This makes it suitable
+ /// for use as the operand of a unary or binary operator without worrying
+ /// about precedence.
+ ///
+ /// This does not produce newlines or indentation.
+ ///
+ /// The `requested` argument indicates (roughly) whether Naga
+ /// `Pointer`-valued expressions represent WGSL references or pointers. See
+ /// `Indirection` for details.
+ fn write_expr_with_indirection(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ requested: Indirection,
+ ) -> BackendResult {
+ // If the plain form of the expression is not what we need, emit the
+ // operator necessary to correct that.
+ let plain = self.plain_form_indirection(expr, module, func_ctx);
+ match (requested, plain) {
+ (Indirection::Ordinary, Indirection::Reference) => {
+ write!(self.out, "(&")?;
+ self.write_expr_plain_form(module, expr, func_ctx, plain)?;
+ write!(self.out, ")")?;
+ }
+ (Indirection::Reference, Indirection::Ordinary) => {
+ write!(self.out, "(*")?;
+ self.write_expr_plain_form(module, expr, func_ctx, plain)?;
+ write!(self.out, ")")?;
+ }
+ (_, _) => self.write_expr_plain_form(module, expr, func_ctx, plain)?,
+ }
+
+ Ok(())
+ }
+
+ fn write_const_expression(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ ) -> BackendResult {
+ self.write_possibly_const_expression(
+ module,
+ expr,
+ &module.const_expressions,
+ |writer, expr| writer.write_const_expression(module, expr),
+ )
+ }
+
+ fn write_possibly_const_expression<E>(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ expressions: &crate::Arena<crate::Expression>,
+ write_expression: E,
+ ) -> BackendResult
+ where
+ E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
+ {
+ use crate::Expression;
+
+ match expressions[expr] {
+ Expression::Literal(literal) => match literal {
+ crate::Literal::F32(value) => write!(self.out, "{}f", value)?,
+ crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
+ crate::Literal::I32(value) => {
+ // `-2147483648i` is not valid WGSL. The most negative `i32`
+ // value can only be expressed in WGSL using AbstractInt and
+ // a unary negation operator.
+ if value == i32::MIN {
+ write!(self.out, "i32(-2147483648)")?;
+ } else {
+ write!(self.out, "{}i", value)?;
+ }
+ }
+ crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
+ crate::Literal::F64(value) => write!(self.out, "{:?}lf", value)?,
+ crate::Literal::I64(_) => {
+ return Err(Error::Custom("unsupported i64 literal".to_string()));
+ }
+ crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => {
+ return Err(Error::Custom(
+ "Abstract types should not appear in IR presented to backends".into(),
+ ));
+ }
+ },
+ Expression::Constant(handle) => {
+ let constant = &module.constants[handle];
+ if constant.name.is_some() {
+ write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
+ } else {
+ self.write_const_expression(module, constant.init)?;
+ }
+ }
+ Expression::ZeroValue(ty) => {
+ self.write_type(module, ty)?;
+ write!(self.out, "()")?;
+ }
+ Expression::Compose { ty, ref components } => {
+ self.write_type(module, ty)?;
+ write!(self.out, "(")?;
+ for (index, component) in components.iter().enumerate() {
+ if index != 0 {
+ write!(self.out, ", ")?;
+ }
+ write_expression(self, *component)?;
+ }
+ write!(self.out, ")")?
+ }
+ Expression::Splat { size, value } => {
+ let size = back::vector_size_str(size);
+ write!(self.out, "vec{size}(")?;
+ write_expression(self, value)?;
+ write!(self.out, ")")?;
+ }
+ _ => unreachable!(),
+ }
+
+ Ok(())
+ }
+
+ /// Write the 'plain form' of `expr`.
+ ///
+ /// An expression's 'plain form' is the most general rendition of that
+ /// expression into WGSL, lacking `&` or `*` operators. The plain forms of
+ /// `LocalVariable(x)` and `GlobalVariable(g)` are simply `x` and `g`. Such
+ /// Naga expressions represent both WGSL pointers and references; it's the
+ /// caller's responsibility to distinguish those cases appropriately.
+ fn write_expr_plain_form(
+ &mut self,
+ module: &Module,
+ expr: Handle<crate::Expression>,
+ func_ctx: &back::FunctionCtx<'_>,
+ indirection: Indirection,
+ ) -> BackendResult {
+ use crate::Expression;
+
+ if let Some(name) = self.named_expressions.get(&expr) {
+ write!(self.out, "{name}")?;
+ return Ok(());
+ }
+
+ let expression = &func_ctx.expressions[expr];
+
+ // Write the plain WGSL form of a Naga expression.
+ //
+ // The plain form of `LocalVariable` and `GlobalVariable` expressions is
+ // simply the variable name; `*` and `&` operators are never emitted.
+ //
+        // The plain forms of `Access` and `AccessIndex` expressions are WGSL
+ // `postfix_expression` forms for member/component access and
+ // subscripting.
+ match *expression {
+ Expression::Literal(_)
+ | Expression::Constant(_)
+ | Expression::ZeroValue(_)
+ | Expression::Compose { .. }
+ | Expression::Splat { .. } => {
+ self.write_possibly_const_expression(
+ module,
+ expr,
+ func_ctx.expressions,
+ |writer, expr| writer.write_expr(module, expr, func_ctx),
+ )?;
+ }
+ Expression::FunctionArgument(pos) => {
+ let name_key = func_ctx.argument_key(pos);
+ let name = &self.names[&name_key];
+ write!(self.out, "{name}")?;
+ }
+ Expression::Binary { op, left, right } => {
+ write!(self.out, "(")?;
+ self.write_expr(module, left, func_ctx)?;
+ write!(self.out, " {} ", back::binary_operation_str(op))?;
+ self.write_expr(module, right, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Access { base, index } => {
+ self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
+ write!(self.out, "[")?;
+ self.write_expr(module, index, func_ctx)?;
+ write!(self.out, "]")?
+ }
+ Expression::AccessIndex { base, index } => {
+ let base_ty_res = &func_ctx.info[base].ty;
+ let mut resolved = base_ty_res.inner_with(&module.types);
+
+ self.write_expr_with_indirection(module, base, func_ctx, indirection)?;
+
+ let base_ty_handle = match *resolved {
+ TypeInner::Pointer { base, space: _ } => {
+ resolved = &module.types[base].inner;
+ Some(base)
+ }
+ _ => base_ty_res.handle(),
+ };
+
+ match *resolved {
+ TypeInner::Vector { .. } => {
+ // Write vector access as a swizzle
+ write!(self.out, ".{}", back::COMPONENTS[index as usize])?
+ }
+ TypeInner::Matrix { .. }
+ | TypeInner::Array { .. }
+ | TypeInner::BindingArray { .. }
+ | TypeInner::ValuePointer { .. } => write!(self.out, "[{index}]")?,
+ TypeInner::Struct { .. } => {
+                        // This will never panic because the type is a `Struct`; the
+                        // same is not guaranteed for other types, so we can only
+                        // unwrap while inside this match arm
+ let ty = base_ty_handle.unwrap();
+
+ write!(
+ self.out,
+ ".{}",
+ &self.names[&NameKey::StructMember(ty, index)]
+ )?
+ }
+ ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
+ }
+ }
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather: None,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ use crate::SampleLevel as Sl;
+
+ let suffix_cmp = match depth_ref {
+ Some(_) => "Compare",
+ None => "",
+ };
+ let suffix_level = match level {
+ Sl::Auto => "",
+ Sl::Zero | Sl::Exact(_) => "Level",
+ Sl::Bias(_) => "Bias",
+ Sl::Gradient { .. } => "Grad",
+ };
+
+ write!(self.out, "textureSample{suffix_cmp}{suffix_level}(")?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ match level {
+ Sl::Auto => {}
+ Sl::Zero => {
+ // Level 0 is implied for depth comparison
+ if depth_ref.is_none() {
+ write!(self.out, ", 0.0")?;
+ }
+ }
+ Sl::Exact(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Bias(expr) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, expr, func_ctx)?;
+ }
+ Sl::Gradient { x, y } => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, x, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, y, func_ctx)?;
+ }
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.write_const_expression(module, offset)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+
+ Expression::ImageSample {
+ image,
+ sampler,
+ gather: Some(component),
+ coordinate,
+ array_index,
+ offset,
+ level: _,
+ depth_ref,
+ } => {
+ let suffix_cmp = match depth_ref {
+ Some(_) => "Compare",
+ None => "",
+ };
+
+ write!(self.out, "textureGather{suffix_cmp}(")?;
+ match *func_ctx.resolve_type(image, &module.types) {
+ TypeInner::Image {
+ class: crate::ImageClass::Depth { multi: _ },
+ ..
+ } => {}
+ _ => {
+ write!(self.out, "{}, ", component as u8)?;
+ }
+ }
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, sampler, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+
+ if let Some(depth_ref) = depth_ref {
+ write!(self.out, ", ")?;
+ self.write_expr(module, depth_ref, func_ctx)?;
+ }
+
+ if let Some(offset) = offset {
+ write!(self.out, ", ")?;
+ self.write_const_expression(module, offset)?;
+ }
+
+ write!(self.out, ")")?;
+ }
+ Expression::ImageQuery { image, query } => {
+ use crate::ImageQuery as Iq;
+
+ let texture_function = match query {
+ Iq::Size { .. } => "textureDimensions",
+ Iq::NumLevels => "textureNumLevels",
+ Iq::NumLayers => "textureNumLayers",
+ Iq::NumSamples => "textureNumSamples",
+ };
+
+ write!(self.out, "{texture_function}(")?;
+ self.write_expr(module, image, func_ctx)?;
+ if let Iq::Size { level: Some(level) } = query {
+ write!(self.out, ", ")?;
+ self.write_expr(module, level, func_ctx)?;
+ };
+ write!(self.out, ")")?;
+ }
+
+ Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ write!(self.out, "textureLoad(")?;
+ self.write_expr(module, image, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, coordinate, func_ctx)?;
+ if let Some(array_index) = array_index {
+ write!(self.out, ", ")?;
+ self.write_expr(module, array_index, func_ctx)?;
+ }
+ if let Some(index) = sample.or(level) {
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ write!(self.out, ")")?;
+ }
+ Expression::GlobalVariable(handle) => {
+ let name = &self.names[&NameKey::GlobalVariable(handle)];
+ write!(self.out, "{name}")?;
+ }
+
+ Expression::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let inner = func_ctx.resolve_type(expr, &module.types);
+ match *inner {
+ TypeInner::Matrix {
+ columns,
+ rows,
+ scalar,
+ } => {
+ let scalar = crate::Scalar {
+ kind,
+ width: convert.unwrap_or(scalar.width),
+ };
+ let scalar_kind_str = scalar_kind_str(scalar);
+ write!(
+ self.out,
+ "mat{}x{}<{}>",
+ back::vector_size_str(columns),
+ back::vector_size_str(rows),
+ scalar_kind_str
+ )?;
+ }
+ TypeInner::Vector {
+ size,
+ scalar: crate::Scalar { width, .. },
+ } => {
+ let scalar = crate::Scalar {
+ kind,
+ width: convert.unwrap_or(width),
+ };
+ let vector_size_str = back::vector_size_str(size);
+ let scalar_kind_str = scalar_kind_str(scalar);
+ if convert.is_some() {
+ write!(self.out, "vec{vector_size_str}<{scalar_kind_str}>")?;
+ } else {
+ write!(self.out, "bitcast<vec{vector_size_str}<{scalar_kind_str}>>")?;
+ }
+ }
+ TypeInner::Scalar(crate::Scalar { width, .. }) => {
+ let scalar = crate::Scalar {
+ kind,
+ width: convert.unwrap_or(width),
+ };
+ let scalar_kind_str = scalar_kind_str(scalar);
+ if convert.is_some() {
+ write!(self.out, "{scalar_kind_str}")?
+ } else {
+ write!(self.out, "bitcast<{scalar_kind_str}>")?
+ }
+ }
+ _ => {
+ return Err(Error::Unimplemented(format!(
+ "write_expr expression::as {inner:?}"
+ )));
+ }
+ };
+ write!(self.out, "(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+ Expression::Load { pointer } => {
+ let is_atomic_pointer = func_ctx
+ .resolve_type(pointer, &module.types)
+ .is_atomic_pointer(&module.types);
+
+ if is_atomic_pointer {
+ write!(self.out, "atomicLoad(")?;
+ self.write_expr(module, pointer, func_ctx)?;
+ write!(self.out, ")")?;
+ } else {
+ self.write_expr_with_indirection(
+ module,
+ pointer,
+ func_ctx,
+ Indirection::Reference,
+ )?;
+ }
+ }
+ Expression::LocalVariable(handle) => {
+ write!(self.out, "{}", self.names[&func_ctx.name_key(handle)])?
+ }
+ Expression::ArrayLength(expr) => {
+ write!(self.out, "arrayLength(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?;
+ }
+
+ Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ enum Function {
+ Regular(&'static str),
+ }
+
+ let function = match fun {
+ Mf::Abs => Function::Regular("abs"),
+ Mf::Min => Function::Regular("min"),
+ Mf::Max => Function::Regular("max"),
+ Mf::Clamp => Function::Regular("clamp"),
+ Mf::Saturate => Function::Regular("saturate"),
+ // trigonometry
+ Mf::Cos => Function::Regular("cos"),
+ Mf::Cosh => Function::Regular("cosh"),
+ Mf::Sin => Function::Regular("sin"),
+ Mf::Sinh => Function::Regular("sinh"),
+ Mf::Tan => Function::Regular("tan"),
+ Mf::Tanh => Function::Regular("tanh"),
+ Mf::Acos => Function::Regular("acos"),
+ Mf::Asin => Function::Regular("asin"),
+ Mf::Atan => Function::Regular("atan"),
+ Mf::Atan2 => Function::Regular("atan2"),
+ Mf::Asinh => Function::Regular("asinh"),
+ Mf::Acosh => Function::Regular("acosh"),
+ Mf::Atanh => Function::Regular("atanh"),
+ Mf::Radians => Function::Regular("radians"),
+ Mf::Degrees => Function::Regular("degrees"),
+ // decomposition
+ Mf::Ceil => Function::Regular("ceil"),
+ Mf::Floor => Function::Regular("floor"),
+ Mf::Round => Function::Regular("round"),
+ Mf::Fract => Function::Regular("fract"),
+ Mf::Trunc => Function::Regular("trunc"),
+ Mf::Modf => Function::Regular("modf"),
+ Mf::Frexp => Function::Regular("frexp"),
+ Mf::Ldexp => Function::Regular("ldexp"),
+ // exponent
+ Mf::Exp => Function::Regular("exp"),
+ Mf::Exp2 => Function::Regular("exp2"),
+ Mf::Log => Function::Regular("log"),
+ Mf::Log2 => Function::Regular("log2"),
+ Mf::Pow => Function::Regular("pow"),
+ // geometry
+ Mf::Dot => Function::Regular("dot"),
+ Mf::Cross => Function::Regular("cross"),
+ Mf::Distance => Function::Regular("distance"),
+ Mf::Length => Function::Regular("length"),
+ Mf::Normalize => Function::Regular("normalize"),
+ Mf::FaceForward => Function::Regular("faceForward"),
+ Mf::Reflect => Function::Regular("reflect"),
+ Mf::Refract => Function::Regular("refract"),
+ // computational
+ Mf::Sign => Function::Regular("sign"),
+ Mf::Fma => Function::Regular("fma"),
+ Mf::Mix => Function::Regular("mix"),
+ Mf::Step => Function::Regular("step"),
+ Mf::SmoothStep => Function::Regular("smoothstep"),
+ Mf::Sqrt => Function::Regular("sqrt"),
+ Mf::InverseSqrt => Function::Regular("inverseSqrt"),
+ Mf::Transpose => Function::Regular("transpose"),
+ Mf::Determinant => Function::Regular("determinant"),
+ // bits
+ Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"),
+ Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"),
+ Mf::CountOneBits => Function::Regular("countOneBits"),
+ Mf::ReverseBits => Function::Regular("reverseBits"),
+ Mf::ExtractBits => Function::Regular("extractBits"),
+ Mf::InsertBits => Function::Regular("insertBits"),
+ Mf::FindLsb => Function::Regular("firstTrailingBit"),
+ Mf::FindMsb => Function::Regular("firstLeadingBit"),
+ // data packing
+ Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"),
+ Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"),
+ Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"),
+ Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"),
+ Mf::Pack2x16float => Function::Regular("pack2x16float"),
+ // data unpacking
+ Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"),
+ Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"),
+ Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"),
+ Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"),
+ Mf::Unpack2x16float => Function::Regular("unpack2x16float"),
+ Mf::Inverse | Mf::Outer => {
+ return Err(Error::UnsupportedMathFunction(fun));
+ }
+ };
+
+ match function {
+ Function::Regular(fun_name) => {
+ write!(self.out, "{fun_name}(")?;
+ self.write_expr(module, arg, func_ctx)?;
+ for arg in IntoIterator::into_iter([arg1, arg2, arg3]).flatten() {
+ write!(self.out, ", ")?;
+ self.write_expr(module, arg, func_ctx)?;
+ }
+ write!(self.out, ")")?
+ }
+ }
+ }
+
+ Expression::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ self.write_expr(module, vector, func_ctx)?;
+ write!(self.out, ".")?;
+ for &sc in pattern[..size as usize].iter() {
+ self.out.write_char(back::COMPONENTS[sc as usize])?;
+ }
+ }
+ Expression::Unary { op, expr } => {
+ let unary = match op {
+ crate::UnaryOperator::Negate => "-",
+ crate::UnaryOperator::LogicalNot => "!",
+ crate::UnaryOperator::BitwiseNot => "~",
+ };
+
+ write!(self.out, "{unary}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+
+ write!(self.out, ")")?
+ }
+
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ write!(self.out, "select(")?;
+ self.write_expr(module, reject, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, accept, func_ctx)?;
+ write!(self.out, ", ")?;
+ self.write_expr(module, condition, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Derivative { axis, ctrl, expr } => {
+ use crate::{DerivativeAxis as Axis, DerivativeControl as Ctrl};
+ let op = match (axis, ctrl) {
+ (Axis::X, Ctrl::Coarse) => "dpdxCoarse",
+ (Axis::X, Ctrl::Fine) => "dpdxFine",
+ (Axis::X, Ctrl::None) => "dpdx",
+ (Axis::Y, Ctrl::Coarse) => "dpdyCoarse",
+ (Axis::Y, Ctrl::Fine) => "dpdyFine",
+ (Axis::Y, Ctrl::None) => "dpdy",
+ (Axis::Width, Ctrl::Coarse) => "fwidthCoarse",
+ (Axis::Width, Ctrl::Fine) => "fwidthFine",
+ (Axis::Width, Ctrl::None) => "fwidth",
+ };
+ write!(self.out, "{op}(")?;
+ self.write_expr(module, expr, func_ctx)?;
+ write!(self.out, ")")?
+ }
+ Expression::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+
+ let fun_name = match fun {
+ Rf::All => "all",
+ Rf::Any => "any",
+ _ => return Err(Error::UnsupportedRelationalFunction(fun)),
+ };
+ write!(self.out, "{fun_name}(")?;
+
+ self.write_expr(module, argument, func_ctx)?;
+
+ write!(self.out, ")")?
+ }
+ // Not supported yet
+ Expression::RayQueryGetIntersection { .. } => unreachable!(),
+            // Nothing to do here, since the call expression has already been cached
+ Expression::CallResult(_)
+ | Expression::AtomicResult { .. }
+ | Expression::RayQueryProceedResult
+ | Expression::WorkGroupUniformLoadResult { .. } => {}
+ }
+
+ Ok(())
+ }
+
+ /// Helper method used to write global variables
+ /// # Notes
+ /// Always adds a newline
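+    ///
+    /// Illustrative output for a storage buffer with hypothetical names:
+    ///
+    /// ```text
+    /// @group(0) @binding(0)
+    /// var<storage, read_write> data: array<u32>;
+    /// ```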
+ fn write_global(
+ &mut self,
+ module: &Module,
+ global: &crate::GlobalVariable,
+ handle: Handle<crate::GlobalVariable>,
+ ) -> BackendResult {
+ // Write group and binding attributes if present
+ if let Some(ref binding) = global.binding {
+ self.write_attributes(&[
+ Attribute::Group(binding.group),
+ Attribute::Binding(binding.binding),
+ ])?;
+ writeln!(self.out)?;
+ }
+
+ // First write global name and address space if supported
+ write!(self.out, "var")?;
+ let (address, maybe_access) = address_space_str(global.space);
+ if let Some(space) = address {
+ write!(self.out, "<{space}")?;
+ if let Some(access) = maybe_access {
+ write!(self.out, ", {access}")?;
+ }
+ write!(self.out, ">")?;
+ }
+ write!(
+ self.out,
+ " {}: ",
+ &self.names[&NameKey::GlobalVariable(handle)]
+ )?;
+
+ // Write global type
+ self.write_type(module, global.ty)?;
+
+ // Write initializer
+ if let Some(init) = global.init {
+ write!(self.out, " = ")?;
+ self.write_const_expression(module, init)?;
+ }
+
+ // End with semicolon
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ /// Helper method used to write global constants
+ ///
+ /// # Notes
+ /// Ends in a newline
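+    ///
+    /// Illustrative output with a hypothetical name:
+    ///
+    /// ```text
+    /// const entry_count: u32 = 16u;
+    /// ```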
+ fn write_global_constant(
+ &mut self,
+ module: &Module,
+ handle: Handle<crate::Constant>,
+ ) -> BackendResult {
+ let name = &self.names[&NameKey::Constant(handle)];
+ // First write only constant name
+ write!(self.out, "const {name}: ")?;
+ self.write_type(module, module.constants[handle].ty)?;
+ write!(self.out, " = ")?;
+ let init = module.constants[handle].init;
+ self.write_const_expression(module, init)?;
+ writeln!(self.out, ";")?;
+
+ Ok(())
+ }
+
+ // See https://github.com/rust-lang/rust-clippy/issues/4979.
+ #[allow(clippy::missing_const_for_fn)]
+ pub fn finish(self) -> W {
+ self.out
+ }
+}
+
+fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> {
+ use crate::BuiltIn as Bi;
+
+ Ok(match built_in {
+ Bi::VertexIndex => "vertex_index",
+ Bi::InstanceIndex => "instance_index",
+ Bi::Position { .. } => "position",
+ Bi::FrontFacing => "front_facing",
+ Bi::FragDepth => "frag_depth",
+ Bi::LocalInvocationId => "local_invocation_id",
+ Bi::LocalInvocationIndex => "local_invocation_index",
+ Bi::GlobalInvocationId => "global_invocation_id",
+ Bi::WorkGroupId => "workgroup_id",
+ Bi::NumWorkGroups => "num_workgroups",
+ Bi::SampleIndex => "sample_index",
+ Bi::SampleMask => "sample_mask",
+ Bi::PrimitiveIndex => "primitive_index",
+ Bi::ViewIndex => "view_index",
+ Bi::BaseInstance
+ | Bi::BaseVertex
+ | Bi::ClipDistance
+ | Bi::CullDistance
+ | Bi::PointSize
+ | Bi::PointCoord
+ | Bi::WorkGroupSize => {
+ return Err(Error::Custom(format!("Unsupported builtin {built_in:?}")))
+ }
+ })
+}
+
+const fn image_dimension_str(dim: crate::ImageDimension) -> &'static str {
+ use crate::ImageDimension as IDim;
+
+ match dim {
+ IDim::D1 => "1d",
+ IDim::D2 => "2d",
+ IDim::D3 => "3d",
+ IDim::Cube => "cube",
+ }
+}
+
+const fn scalar_kind_str(scalar: crate::Scalar) -> &'static str {
+ use crate::Scalar;
+ use crate::ScalarKind as Sk;
+
+ match scalar {
+ Scalar {
+ kind: Sk::Float,
+ width: 8,
+ } => "f64",
+ Scalar {
+ kind: Sk::Float,
+ width: 4,
+ } => "f32",
+ Scalar {
+ kind: Sk::Sint,
+ width: 4,
+ } => "i32",
+ Scalar {
+ kind: Sk::Uint,
+ width: 4,
+ } => "u32",
+ Scalar {
+ kind: Sk::Bool,
+ width: 1,
+ } => "bool",
+ _ => unreachable!(),
+ }
+}
+
+const fn storage_format_str(format: crate::StorageFormat) -> &'static str {
+ use crate::StorageFormat as Sf;
+
+ match format {
+ Sf::R8Unorm => "r8unorm",
+ Sf::R8Snorm => "r8snorm",
+ Sf::R8Uint => "r8uint",
+ Sf::R8Sint => "r8sint",
+ Sf::R16Uint => "r16uint",
+ Sf::R16Sint => "r16sint",
+ Sf::R16Float => "r16float",
+ Sf::Rg8Unorm => "rg8unorm",
+ Sf::Rg8Snorm => "rg8snorm",
+ Sf::Rg8Uint => "rg8uint",
+ Sf::Rg8Sint => "rg8sint",
+ Sf::R32Uint => "r32uint",
+ Sf::R32Sint => "r32sint",
+ Sf::R32Float => "r32float",
+ Sf::Rg16Uint => "rg16uint",
+ Sf::Rg16Sint => "rg16sint",
+ Sf::Rg16Float => "rg16float",
+ Sf::Rgba8Unorm => "rgba8unorm",
+ Sf::Rgba8Snorm => "rgba8snorm",
+ Sf::Rgba8Uint => "rgba8uint",
+ Sf::Rgba8Sint => "rgba8sint",
+ Sf::Bgra8Unorm => "bgra8unorm",
+ Sf::Rgb10a2Uint => "rgb10a2uint",
+ Sf::Rgb10a2Unorm => "rgb10a2unorm",
+ Sf::Rg11b10Float => "rg11b10float",
+ Sf::Rg32Uint => "rg32uint",
+ Sf::Rg32Sint => "rg32sint",
+ Sf::Rg32Float => "rg32float",
+ Sf::Rgba16Uint => "rgba16uint",
+ Sf::Rgba16Sint => "rgba16sint",
+ Sf::Rgba16Float => "rgba16float",
+ Sf::Rgba32Uint => "rgba32uint",
+ Sf::Rgba32Sint => "rgba32sint",
+ Sf::Rgba32Float => "rgba32float",
+ Sf::R16Unorm => "r16unorm",
+ Sf::R16Snorm => "r16snorm",
+ Sf::Rg16Unorm => "rg16unorm",
+ Sf::Rg16Snorm => "rg16snorm",
+ Sf::Rgba16Unorm => "rgba16unorm",
+ Sf::Rgba16Snorm => "rgba16snorm",
+ }
+}
+
+/// Helper function that returns the string corresponding to the WGSL interpolation qualifier
+const fn interpolation_str(interpolation: crate::Interpolation) -> &'static str {
+ use crate::Interpolation as I;
+
+ match interpolation {
+ I::Perspective => "perspective",
+ I::Linear => "linear",
+ I::Flat => "flat",
+ }
+}
+
+/// Return the WGSL auxiliary qualifier for the given sampling value.
+const fn sampling_str(sampling: crate::Sampling) -> &'static str {
+ use crate::Sampling as S;
+
+ match sampling {
+ S::Center => "",
+ S::Centroid => "centroid",
+ S::Sample => "sample",
+ }
+}
+
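+/// Return the WGSL address space name and optional access qualifier for
+/// `space`, as `(address, access)`.
+///
+/// For example, `As::Storage` with `STORE` access yields
+/// `(Some("storage"), Some("read_write"))`, while `As::Handle` yields
+/// `(None, None)` because handle types are never spelled with an address
+/// space in WGSL.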
+const fn address_space_str(
+ space: crate::AddressSpace,
+) -> (Option<&'static str>, Option<&'static str>) {
+ use crate::AddressSpace as As;
+
+ (
+ Some(match space {
+ As::Private => "private",
+ As::Uniform => "uniform",
+ As::Storage { access } => {
+ if access.contains(crate::StorageAccess::STORE) {
+ return (Some("storage"), Some("read_write"));
+ } else {
+ "storage"
+ }
+ }
+ As::PushConstant => "push_constant",
+ As::WorkGroup => "workgroup",
+ As::Handle => return (None, None),
+ As::Function => "function",
+ }),
+ None,
+ )
+}
+
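+/// Map a Naga `Binding` to the list of WGSL attributes that express it.
+///
+/// For example, `Binding::BuiltIn(BuiltIn::Position { invariant: true })`
+/// maps to `[Attribute::BuiltIn(...), Attribute::Invariant]`.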
+fn map_binding_to_attribute(binding: &crate::Binding) -> Vec<Attribute> {
+ match *binding {
+ crate::Binding::BuiltIn(built_in) => {
+ if let crate::BuiltIn::Position { invariant: true } = built_in {
+ vec![Attribute::BuiltIn(built_in), Attribute::Invariant]
+ } else {
+ vec![Attribute::BuiltIn(built_in)]
+ }
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source: false,
+ } => vec![
+ Attribute::Location(location),
+ Attribute::Interpolate(interpolation, sampling),
+ ],
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ second_blend_source: true,
+ } => vec![
+ Attribute::Location(location),
+ Attribute::SecondBlendSource,
+ Attribute::Interpolate(interpolation, sampling),
+ ],
+ }
+}